Showing 94 changed files with 4619 additions and 0 deletions
.gitignore
0 → 100644
1 | +# Created by https://www.gitignore.io/api/django | ||
2 | +# Edit at https://www.gitignore.io/?templates=django | ||
3 | + | ||
4 | +### Django ### | ||
5 | +*.log | ||
6 | +*.pot | ||
7 | +*.pyc | ||
8 | +__pycache__/ | ||
9 | +local_settings.py | ||
10 | +db.sqlite3 | ||
11 | +db.sqlite3-journal | ||
12 | +media | ||
13 | +static/ | ||
14 | +/static | ||
15 | +/backend/env | ||
16 | +env/ | ||
17 | +/env | ||
18 | + | ||
19 | +# If your build process includes running collectstatic, then you probably don't need or want to include staticfiles/ | ||
20 | +# in your Git repository. Update and uncomment the following line accordingly. | ||
21 | +# <django-project-name>/staticfiles/ | ||
22 | + | ||
23 | +### Django.Python Stack ### | ||
24 | +# Byte-compiled / optimized / DLL files | ||
25 | +*.py[cod] | ||
26 | +*$py.class | ||
27 | + | ||
28 | +# C extensions | ||
29 | +*.so | ||
30 | + | ||
31 | +# Distribution / packaging | ||
32 | +.Python | ||
33 | +build/ | ||
34 | +develop-eggs/ | ||
35 | +dist/ | ||
36 | +downloads/ | ||
37 | +eggs/ | ||
38 | +.eggs/ | ||
39 | +lib/ | ||
40 | +lib64/ | ||
41 | +parts/ | ||
42 | +sdist/ | ||
43 | +var/ | ||
44 | +wheels/ | ||
45 | +pip-wheel-metadata/ | ||
46 | +share/python-wheels/ | ||
47 | +*.egg-info/ | ||
48 | +.installed.cfg | ||
49 | +*.egg | ||
50 | +MANIFEST | ||
51 | + | ||
52 | +# PyInstaller | ||
53 | +# Usually these files are written by a python script from a template | ||
54 | +# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
55 | +*.manifest | ||
56 | +*.spec | ||
57 | + | ||
58 | +# Installer logs | ||
59 | +pip-log.txt | ||
60 | +pip-delete-this-directory.txt | ||
61 | + | ||
62 | +# Unit test / coverage reports | ||
63 | +htmlcov/ | ||
64 | +.tox/ | ||
65 | +.nox/ | ||
66 | +.coverage | ||
67 | +.coverage.* | ||
68 | +.cache | ||
69 | +nosetests.xml | ||
70 | +coverage.xml | ||
71 | +*.cover | ||
72 | +.hypothesis/ | ||
73 | +.pytest_cache/ | ||
74 | + | ||
75 | +# Translations | ||
76 | +*.mo | ||
77 | + | ||
78 | +# Scrapy stuff: | ||
79 | +.scrapy | ||
80 | + | ||
81 | +# Sphinx documentation | ||
82 | +docs/_build/ | ||
83 | + | ||
84 | +# PyBuilder | ||
85 | +target/ | ||
86 | + | ||
87 | +# pyenv | ||
88 | +.python-version | ||
89 | + | ||
90 | +# pipenv | ||
91 | +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
92 | +# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
93 | +# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
94 | +# install all needed dependencies. | ||
95 | +#Pipfile.lock | ||
96 | + | ||
97 | +# celery beat schedule file | ||
98 | +celerybeat-schedule | ||
99 | + | ||
100 | +# SageMath parsed files | ||
101 | +*.sage.py | ||
102 | + | ||
103 | +# Spyder project settings | ||
104 | +.spyderproject | ||
105 | +.spyproject | ||
106 | + | ||
107 | +# Rope project settings | ||
108 | +.ropeproject | ||
109 | + | ||
110 | +# Mr Developer | ||
111 | +.mr.developer.cfg | ||
112 | +.project | ||
113 | +.pydevproject | ||
114 | + | ||
115 | +# mkdocs documentation | ||
116 | +/site | ||
117 | + | ||
118 | +# mypy | ||
119 | +.mypy_cache/ | ||
120 | +.dmypy.json | ||
121 | +dmypy.json | ||
122 | + | ||
123 | +# Pyre type checker | ||
124 | +.pyre/ | ||
125 | + | ||
126 | +# End of https://www.gitignore.io/api/django | ||
127 | + | ||
128 | +.DS_Store | ||
129 | +node_modules | ||
130 | +/dist | ||
131 | + | ||
132 | +# local env files | ||
133 | +.env.local | ||
134 | +.env.*.local | ||
135 | + | ||
136 | +# Log files | ||
137 | +npm-debug.log* | ||
138 | +yarn-debug.log* | ||
139 | +yarn-error.log* | ||
140 | + | ||
141 | +# Editor directories and files | ||
142 | +.idea | ||
143 | +.vscode | ||
144 | +*.suo | ||
145 | +*.ntvs* | ||
146 | +*.njsproj | ||
147 | +*.sln | ||
148 | +*.sw? | ||
149 | + |
img/profit_hunter.png
0 → 100644

40.3 KB
web/backend/api/__init__.py
0 → 100644
1 | +import os | ||
2 | +import datetime | ||
5 | + | ||
6 | +# URL prefix for each media file | ||
7 | +MEDIA_URL = '/media/' # must always end with a slash | ||
8 | +# MEDIA_URL = 'http://static.myservice.com/media/' # use when media files are copied to another server | ||
9 | +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||
10 | + | ||
11 | +MEDIA_ROOT = os.path.join(BASE_DIR, 'media') | ||
12 | + | ||
13 | +# Maximum file upload size: 100 MB (100 * 1024 * 1024 bytes) | ||
14 | +#FILE_UPLOAD_MAX_MEMORY_SIZE = 104857600 | ||
15 | + | ||
16 | +# Build the storage path and file name for an uploaded file. | ||
17 | +# A new folder is created per day. | ||
18 | +def file_upload_path(filename): | ||
19 | + ext = filename.split('.')[-1] | ||
20 | + d = datetime.datetime.now() | ||
21 | + filepath = d.strftime('%Y-%m-%d') | ||
22 | + suffix = d.strftime("%Y%m%d%H%M%S") | ||
23 | + filename = "%s.%s" % (suffix, ext) | ||
24 | + return os.path.join(MEDIA_ROOT, filepath, filename) | ||
25 | + | ||
26 | +# Called from the model's FileField (upload_to callback) | ||
27 | +def file_upload_path_for_db(instance, filename): | ||
28 | + return file_upload_path(filename) | ||
\ No newline at end of file
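For orientation, a minimal usage sketch of the helper above, assuming the backend package is importable as `api` (the timestamp in the example output is illustrative):

```
from api import file_upload_path

# Builds <MEDIA_ROOT>/<YYYY-MM-DD>/<YYYYMMDDHHMMSS>.<ext>, e.g.
# .../media/2020-05-02/20200502121930.mp4
print(file_upload_path("lecture.mp4"))
```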
web/backend/api/admin.py
0 → 100644
web/backend/api/apps.py
0 → 100644
web/backend/api/migrations/0001_initial.py
0 → 100644
1 | +# Generated by Django 3.0.5 on 2020-05-02 12:19 | ||
2 | + | ||
3 | +import api | ||
4 | +from django.db import migrations, models | ||
5 | + | ||
6 | + | ||
7 | +class Migration(migrations.Migration): | ||
8 | + | ||
9 | + initial = True | ||
10 | + | ||
11 | + dependencies = [ | ||
12 | + ] | ||
13 | + | ||
14 | + operations = [ | ||
15 | + migrations.CreateModel( | ||
16 | + name='Video', | ||
17 | + fields=[ | ||
18 | + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), | ||
19 | + ('videourl', models.CharField(blank=True, max_length=1000)), | ||
20 | + ('title', models.CharField(max_length=200)), | ||
21 | + ('tags', models.CharField(max_length=500)), | ||
22 | + ], | ||
23 | + ), | ||
24 | + migrations.CreateModel( | ||
25 | + name='VideoFile', | ||
26 | + fields=[ | ||
27 | + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), | ||
28 | + ('file_save_name', models.FileField(upload_to=api.file_upload_path_for_db)), | ||
29 | + ('file_origin_name', models.CharField(max_length=100)), | ||
30 | + ('file_path', models.CharField(max_length=100)), | ||
31 | + ], | ||
32 | + ), | ||
33 | + ] |
web/backend/api/migrations/0002_video_threshold.py
0 → 100644
1 | +# Generated by Django 3.0.5 on 2020-05-31 05:39 | ||
2 | + | ||
3 | +from django.db import migrations, models | ||
4 | + | ||
5 | + | ||
6 | +class Migration(migrations.Migration): | ||
7 | + | ||
8 | + dependencies = [ | ||
9 | + ('api', '0001_initial'), | ||
10 | + ] | ||
11 | + | ||
12 | + operations = [ | ||
13 | + migrations.AddField( | ||
14 | + model_name='video', | ||
15 | + name='threshold', | ||
16 | + field=models.CharField(default=20, max_length=20), | ||
17 | + preserve_default=False, | ||
18 | + ), | ||
19 | + ] |
web/backend/api/migrations/__init__.py
0 → 100644
File mode changed
web/backend/api/models.py
0 → 100644
1 | +from django.db import models | ||
2 | +from api import file_upload_path_for_db | ||
3 | +# Create your models here. | ||
4 | + | ||
5 | + | ||
6 | +class Video(models.Model): | ||
7 | + videourl = models.CharField(max_length=1000, blank=True) | ||
8 | + title = models.CharField(max_length=200) | ||
9 | + threshold = models.CharField(max_length=20) | ||
10 | + tags = models.CharField(max_length=500) | ||
11 | + | ||
12 | + | ||
13 | +class VideoFile(models.Model): | ||
14 | + file_save_name = models.FileField(upload_to=file_upload_path_for_db, blank=False, null=False) | ||
15 | + # Original name of the uploaded file | ||
16 | + file_origin_name = models.CharField(max_length=100) | ||
17 | + # Path where the file is stored | ||
18 | + file_path = models.CharField(max_length=100) | ||
19 | + | ||
20 | + def __str__(self): | ||
21 | + return self.file_save_name.name | ||
web/backend/api/serializers.py
0 → 100644
1 | +from rest_framework import serializers | ||
2 | +from api.models import Video, VideoFile | ||
3 | + | ||
4 | + | ||
5 | +class VideoSerializer(serializers.ModelSerializer): | ||
6 | + | ||
7 | + class Meta: | ||
8 | + model = Video | ||
9 | + fields = '__all__' | ||
10 | + | ||
11 | + | ||
12 | +class VideoFileSerializer(serializers.ModelSerializer): | ||
13 | + | ||
14 | + class Meta: | ||
15 | + model = VideoFile | ||
16 | + fields = '__all__' |
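A brief sketch of how these ModelSerializers validate plain data, assuming a configured Django environment (the values are illustrative):

```
from api.serializers import VideoSerializer

# Validate an untrusted payload against the Video model's fields.
serializer = VideoSerializer(data={
    "videourl": "http://example.com/v/1",
    "title": "demo clip",
    "threshold": "10",
    "tags": "music,live",
})
print(serializer.is_valid(), serializer.errors)
```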
web/backend/api/tests.py
0 → 100644
web/backend/api/urls.py
0 → 100644
1 | +from django.urls import path, include | ||
2 | +from django.conf.urls import url | ||
3 | +from api.views import VideoFileUploadView, VideoFileList, FileListView | ||
4 | +from . import views | ||
5 | +from rest_framework.routers import DefaultRouter | ||
6 | +from django.views.generic import TemplateView | ||
7 | + | ||
8 | +router = DefaultRouter() | ||
9 | +router.register('db/videofile', views.VideoFileViewSet) | ||
10 | +router.register('db/video', views.VideoViewSet) | ||
11 | + | ||
12 | +urlpatterns = [ | ||
13 | + # FBV | ||
14 | + path('api/upload', VideoFileUploadView.as_view(), name="file-upload"), | ||
15 | + path('api/upload/<int:pk>/', VideoFileList.as_view(), name="file-list"), | ||
16 | + path('api/file', FileListView.as_view(), name="file"), | ||
17 | + # path('api/upload', views.VideoFile_Upload), | ||
18 | + path('', include(router.urls)), | ||
19 | + # Catch-all for the SPA entry point: kept last so it cannot shadow the API and router routes. | ||
20 | + url(r'^(?P<path>.*)$', TemplateView.as_view(template_name='index.html')), | ||
21 | +] | ||
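A hypothetical client call against the upload route above; the host and port are assumptions, while the form field names `file` and `threshold` match what `VideoFileUploadView` reads:

```
import requests

with open("sample.mp4", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/upload",
        files={"file": f},
        data={"threshold": "10"},
    )
print(resp.status_code, resp.json())
```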
web/backend/api/views.py
0 → 100644
1 | +from rest_framework import status | ||
2 | +from rest_framework.views import APIView | ||
3 | +from rest_framework.response import Response | ||
4 | +from rest_framework.parsers import MultiPartParser, FormParser | ||
6 | +from rest_framework import viewsets | ||
7 | +import os | ||
8 | +from django.http.request import QueryDict | ||
9 | +from django.http import Http404 | ||
10 | +from django.http import JsonResponse | ||
11 | +from django.shortcuts import get_object_or_404, render | ||
12 | +from api.models import Video, VideoFile | ||
13 | +from api.serializers import VideoSerializer, VideoFileSerializer | ||
14 | +from api import file_upload_path | ||
15 | + | ||
16 | +import subprocess | ||
17 | +import shlex | ||
18 | +import json | ||
19 | +# Create your views here. | ||
20 | +import sys | ||
21 | +sys.path.insert(0, "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria") | ||
22 | +import inference_pb | ||
23 | + | ||
24 | +def with_ffprobe(filename): | ||
25 | + | ||
26 | + result = subprocess.check_output( | ||
27 | + f'ffprobe -v quiet -show_streams -select_streams v:0 -of json "{filename}"', | ||
28 | + shell=True).decode() | ||
29 | + fields = json.loads(result)['streams'][0] | ||
30 | + duration = int(float(fields['duration'])) | ||
31 | + return duration | ||
32 | + | ||
33 | +def index(request): | ||
34 | + return render(request, template_name='index.html') | ||
35 | + | ||
36 | +class VideoViewSet(viewsets.ModelViewSet): | ||
37 | + queryset = Video.objects.all() | ||
38 | + serializer_class = VideoSerializer | ||
39 | + | ||
40 | +class VideoFileViewSet(viewsets.ModelViewSet): | ||
41 | + queryset = VideoFile.objects.all() | ||
42 | + serializer_class = VideoFileSerializer | ||
43 | + | ||
44 | + | ||
45 | +class VideoFileUploadView(APIView): | ||
46 | + parser_classes = (MultiPartParser, FormParser) | ||
47 | + | ||
48 | + def get(self, request, format=None): | ||
49 | + videoFiles = VideoFile.objects.all() | ||
50 | + serializer = VideoFileSerializer(videoFiles, many=True) | ||
51 | + return Response(serializer.data) | ||
52 | + | ||
53 | + def post(self, req, *args, **kwargs): | ||
54 | + # Video duration in seconds | ||
55 | + runTime = 0 | ||
56 | + # Pull the submitted data out of the request (QueryDict) | ||
57 | + new_data = req.data.dict() | ||
58 | + | ||
59 | + # Uploaded file object from the request | ||
60 | + file_name = req.data['file'] | ||
61 | + threshold = req.data['threshold'] | ||
62 | + | ||
63 | + # Build the full path where the file will be stored | ||
64 | + new_file_full_name = file_upload_path(file_name.name) | ||
65 | + # Directory portion of the newly generated path | ||
66 | + file_path = '-'.join(new_file_full_name.split('-')[0:-1]) | ||
67 | + | ||
68 | + new_data['file_path'] = file_path | ||
69 | + new_data['file_origin_name'] = req.data['file'].name | ||
70 | + new_data['file_save_name'] = req.data['file'] | ||
71 | + | ||
72 | + new_query_dict = QueryDict('', mutable=True) | ||
73 | + new_query_dict.update(new_data) | ||
74 | + file_serializer = VideoFileSerializer(data=new_query_dict) | ||
75 | + | ||
76 | + if file_serializer.is_valid(): | ||
77 | + file_serializer.save() | ||
78 | + # Measure the video duration | ||
79 | + runTime = with_ffprobe('/'+file_serializer.data['file_save_name']) | ||
80 | + print(runTime) | ||
81 | + print(threshold) | ||
82 | + process = subprocess.Popen(['./runMediaPipe.sh %s %s' % (file_serializer.data['file_save_name'], runTime)], shell=True) | ||
83 | + process.wait() | ||
84 | + | ||
85 | + | ||
86 | + result = inference_pb.inference_pb('/tmp/mediapipe/features.pb', threshold) | ||
87 | + | ||
88 | + return Response(result, status=status.HTTP_201_CREATED) | ||
89 | + else: | ||
90 | + return Response(file_serializer.errors, status=status.HTTP_400_BAD_REQUEST) | ||
91 | + | ||
92 | + | ||
93 | +class VideoFileList(APIView): | ||
94 | + | ||
95 | + def get_object(self, pk): | ||
96 | + try: | ||
97 | + return VideoFile.objects.get(pk=pk) | ||
98 | + except VideoFile.DoesNotExist: | ||
99 | + raise Http404 | ||
100 | + | ||
101 | + def get(self, request, pk, format=None): | ||
102 | + video = self.get_object(pk) | ||
103 | + serializer = VideoFileSerializer(video) | ||
104 | + return Response(serializer.data) | ||
105 | + | ||
106 | + def put(self, request, pk, format=None): | ||
107 | + video = self.get_object(pk) | ||
108 | + serializer = VideoFileSerializer(video, data=request.data) | ||
109 | + if serializer.is_valid(): | ||
110 | + serializer.save() | ||
111 | + return Response(serializer.data) | ||
112 | + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) | ||
113 | + | ||
114 | + def delete(self, request, pk, format=None): | ||
115 | + video = self.get_object(pk) | ||
116 | + video.delete() | ||
117 | + return Response(status=status.HTTP_204_NO_CONTENT) | ||
118 | + | ||
119 | + | ||
120 | +class FileListView(APIView): | ||
121 | + def get(self, request): | ||
122 | + data = { | ||
123 | + "search": '', | ||
124 | + "limit": 10, | ||
125 | + "skip": 0, | ||
126 | + "order": "time", | ||
127 | + "fileList": [ | ||
128 | + { | ||
129 | + "name": "1.png", | ||
130 | + "created": "2020-04-30", | ||
131 | + "size": 10234, | ||
132 | + "isFolder": False, | ||
133 | + "deletedDate": "", | ||
134 | + }, | ||
135 | + { | ||
136 | + "name": "2.png", | ||
137 | + "created": "2020-04-30", | ||
138 | + "size": 3145, | ||
139 | + "isFolder": False, | ||
140 | + "deletedDate": "", | ||
141 | + }, | ||
142 | + { | ||
143 | + "name": "3.png", | ||
144 | + "created": "2020-05-01", | ||
145 | + "size": 5653, | ||
146 | + "isFolder": False, | ||
147 | + "deletedDate": "", | ||
148 | + }, | ||
149 | + ] | ||
150 | + } | ||
151 | + return Response(data) | ||
152 | + def post(self, request, format=None): | ||
153 | + data = { | ||
154 | + "isSuccess": True, | ||
155 | + "File": { | ||
156 | + "name": "test.jpg", | ||
157 | + "created": "2020-05-02", | ||
158 | + "deletedDate": "", | ||
159 | + "size": 2312, | ||
160 | + "isFolder": False | ||
161 | + } | ||
162 | + } | ||
163 | + return Response(data) | ||
\ No newline at end of file
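Because `with_ffprobe` interpolates the file name into a shell command string, a safer equivalent passes an argument list and skips the shell entirely; a sketch, not the code above:

```
import json
import subprocess

def probe_duration(filename: str) -> int:
    # Same ffprobe flags as with_ffprobe, but without shell=True,
    # so the file name is never interpreted by a shell.
    out = subprocess.check_output(
        ["ffprobe", "-v", "quiet", "-show_streams",
         "-select_streams", "v:0", "-of", "json", filename])
    return int(float(json.loads(out)["streams"][0]["duration"]))
```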
web/backend/backend/__init__.py
0 → 100644
File mode changed
web/backend/backend/asgi.py
0 → 100644
1 | +""" | ||
2 | +ASGI config for backend project. | ||
3 | + | ||
4 | +It exposes the ASGI callable as a module-level variable named ``application``. | ||
5 | + | ||
6 | +For more information on this file, see | ||
7 | +https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/ | ||
8 | +""" | ||
9 | + | ||
10 | +import os | ||
11 | + | ||
12 | +from django.core.asgi import get_asgi_application | ||
13 | + | ||
14 | +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings') | ||
15 | + | ||
16 | +application = get_asgi_application() |
web/backend/backend/settings.py
0 → 100644
1 | +""" | ||
2 | +Django settings for backend project. | ||
3 | + | ||
4 | +Generated by 'django-admin startproject' using Django 3.0.5. | ||
5 | + | ||
6 | +For more information on this file, see | ||
7 | +https://docs.djangoproject.com/en/3.0/topics/settings/ | ||
8 | + | ||
9 | +For the full list of settings and their values, see | ||
10 | +https://docs.djangoproject.com/en/3.0/ref/settings/ | ||
11 | +""" | ||
12 | + | ||
13 | +import os | ||
14 | + | ||
15 | +# Build paths inside the project like this: os.path.join(BASE_DIR, ...) | ||
16 | +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||
17 | + | ||
18 | + | ||
19 | +# Quick-start development settings - unsuitable for production | ||
20 | +# See https://docs.djangoproject.com/en/3.0/howto/deployment/checklist/ | ||
21 | + | ||
22 | +# SECURITY WARNING: keep the secret key used in production secret! | ||
23 | +SECRET_KEY = '7e^4!u6019jww&=-!mu%r$hz6jy#=i+i9@9m_44+ga^#%7#e0l' | ||
24 | + | ||
25 | +# SECURITY WARNING: don't run with debug turned on in production! | ||
26 | +DEBUG = True | ||
27 | + | ||
28 | +ALLOWED_HOSTS = ['*'] | ||
29 | + | ||
30 | + | ||
31 | +# Static File DIR Settings | ||
32 | +STATIC_URL = '/static/' | ||
33 | +STATIC_ROOT = os.path.join(BASE_DIR, 'static') | ||
34 | + | ||
35 | + | ||
36 | +FRONTEND_DIR = os.path.join(os.path.abspath('../'), 'frontend') | ||
37 | +STATICFILES_DIRS = [ | ||
38 | + os.path.join(FRONTEND_DIR, 'dist/'), | ||
39 | +] | ||
40 | +# Application definition | ||
41 | + | ||
42 | +REST_FRAMEWORK = { | ||
43 | + 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', | ||
44 | + 'PAGE_SIZE': 10 | ||
45 | +} | ||
46 | + | ||
47 | +INSTALLED_APPS = [ | ||
48 | + 'django.contrib.admin', | ||
49 | + 'django.contrib.auth', | ||
50 | + 'django.contrib.contenttypes', | ||
51 | + 'django.contrib.sessions', | ||
52 | + 'django.contrib.messages', | ||
53 | + 'django.contrib.staticfiles', | ||
54 | + 'rest_framework', | ||
55 | + 'corsheaders', | ||
56 | + 'api' | ||
57 | +] | ||
58 | + | ||
59 | +MIDDLEWARE = [ | ||
60 | + 'django.middleware.security.SecurityMiddleware', | ||
61 | + 'django.contrib.sessions.middleware.SessionMiddleware', | ||
62 | + 'django.middleware.common.CommonMiddleware', | ||
63 | + # 'django.middleware.csrf.CsrfViewMiddleware', | ||
64 | + 'django.contrib.auth.middleware.AuthenticationMiddleware', | ||
65 | + 'django.contrib.messages.middleware.MessageMiddleware', | ||
66 | + 'django.middleware.clickjacking.XFrameOptionsMiddleware', | ||
67 | + 'corsheaders.middleware.CorsMiddleware', | ||
68 | +] | ||
69 | + | ||
70 | +ROOT_URLCONF = 'backend.urls' | ||
71 | + | ||
72 | +TEMPLATES = [ | ||
73 | + { | ||
74 | + 'BACKEND': 'django.template.backends.django.DjangoTemplates', | ||
75 | + 'DIRS': [ | ||
76 | + os.path.join(BASE_DIR, 'templates'), | ||
77 | + ], | ||
78 | + 'APP_DIRS': True, | ||
79 | + 'OPTIONS': { | ||
80 | + 'context_processors': [ | ||
81 | + 'django.template.context_processors.debug', | ||
82 | + 'django.template.context_processors.request', | ||
83 | + 'django.contrib.auth.context_processors.auth', | ||
84 | + 'django.contrib.messages.context_processors.messages', | ||
85 | + ], | ||
86 | + }, | ||
87 | + }, | ||
88 | +] | ||
89 | + | ||
90 | +WSGI_APPLICATION = 'backend.wsgi.application' | ||
91 | + | ||
92 | + | ||
93 | +# Database | ||
94 | +# https://docs.djangoproject.com/en/3.0/ref/settings/#databases | ||
95 | + | ||
96 | +DATABASES = { | ||
97 | + 'default': { | ||
98 | + 'ENGINE': 'django.db.backends.sqlite3', | ||
99 | + 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), | ||
100 | + } | ||
101 | +} | ||
102 | + | ||
103 | + | ||
104 | +# Password validation | ||
105 | +# https://docs.djangoproject.com/en/3.0/ref/settings/#auth-password-validators | ||
106 | + | ||
107 | +AUTH_PASSWORD_VALIDATORS = [ | ||
108 | + { | ||
109 | + 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', | ||
110 | + }, | ||
111 | + { | ||
112 | + 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', | ||
113 | + }, | ||
114 | + { | ||
115 | + 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', | ||
116 | + }, | ||
117 | + { | ||
118 | + 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', | ||
119 | + }, | ||
120 | +] | ||
121 | + | ||
122 | + | ||
123 | +# Internationalization | ||
124 | +# https://docs.djangoproject.com/en/3.0/topics/i18n/ | ||
125 | + | ||
126 | +LANGUAGE_CODE = 'en-us' | ||
127 | + | ||
128 | +TIME_ZONE = 'UTC' | ||
129 | + | ||
130 | +USE_I18N = True | ||
131 | + | ||
132 | +USE_L10N = True | ||
133 | + | ||
134 | +USE_TZ = True | ||
135 | + | ||
136 | + | ||
137 | +# CORS settings | ||
138 | + | ||
139 | +CORS_ORIGIN_ALLOW_ALL = True | ||
140 | +CORS_ALLOW_CREDENTIALS = True | ||
141 | +CORS_ORIGIN_WHITELIST = [ | ||
142 | + "http://127.0.0.1:12233", | ||
143 | + "http://localhost:8080", | ||
144 | + "http://127.0.0.1:9000" | ||
145 | +] | ||
146 | +CORS_URLS_REGEX = r'^/api/.*$' | ||
147 | +CORS_ALLOW_METHODS = ( | ||
148 | + 'DELETE', | ||
149 | + 'GET', | ||
150 | + 'OPTIONS', | ||
151 | + 'PATCH', | ||
152 | + 'POST', | ||
153 | + 'PUT', | ||
154 | +) | ||
155 | + | ||
156 | +CORS_ALLOW_HEADERS = ( | ||
157 | + 'accept', | ||
158 | + 'accept-encoding', | ||
159 | + 'authorization', | ||
160 | + 'content-type', | ||
161 | + 'dnt', | ||
162 | + 'origin', | ||
163 | + 'user-agent', | ||
164 | + 'x-csrftoken', | ||
165 | + 'x-requested-with', | ||
166 | +) |
web/backend/backend/urls.py
0 → 100644
1 | +"""backend URL Configuration | ||
2 | + | ||
3 | +The `urlpatterns` list routes URLs to views. For more information please see: | ||
4 | + https://docs.djangoproject.com/en/3.0/topics/http/urls/ | ||
5 | +Examples: | ||
6 | +Function views | ||
7 | + 1. Add an import: from my_app import views | ||
8 | + 2. Add a URL to urlpatterns: path('', views.home, name='home') | ||
9 | +Class-based views | ||
10 | + 1. Add an import: from other_app.views import Home | ||
11 | + 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') | ||
12 | +Including another URLconf | ||
13 | + 1. Import the include() function: from django.urls import include, path | ||
14 | + 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) | ||
15 | +""" | ||
16 | +from django.contrib import admin | ||
17 | +from django.urls import path | ||
18 | +from django.conf.urls import url, include | ||
19 | +from django.conf import settings | ||
20 | +from django.conf.urls.static import static | ||
21 | + | ||
22 | +urlpatterns = [ | ||
23 | + path('', include('api.urls')), | ||
24 | + path('admin/', admin.site.urls), | ||
25 | +] | ||
\ No newline at end of file
web/backend/backend/wsgi.py
0 → 100644
1 | +""" | ||
2 | +WSGI config for backend project. | ||
3 | + | ||
4 | +It exposes the WSGI callable as a module-level variable named ``application``. | ||
5 | + | ||
6 | +For more information on this file, see | ||
7 | +https://docs.djangoproject.com/en/3.0/howto/deployment/wsgi/ | ||
8 | +""" | ||
9 | + | ||
10 | +import os | ||
11 | + | ||
12 | +from django.core.wsgi import get_wsgi_application | ||
13 | + | ||
14 | +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings') | ||
15 | + | ||
16 | +application = get_wsgi_application() |
web/backend/convertPb2Tfrecord.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import glob, os | ||
3 | +import numpy | ||
4 | + | ||
5 | +def _make_bytes(int_array): | ||
6 | + if bytes == str: # Python2 | ||
7 | + return ''.join(map(chr, int_array)) | ||
8 | + else: | ||
9 | + return bytes(int_array) | ||
10 | + | ||
11 | + | ||
12 | +def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0): | ||
13 | + """Quantizes float32 `features` into string.""" | ||
14 | + assert features.dtype == 'float32' | ||
15 | + assert len(features.shape) == 1 # 1-D array | ||
16 | + features = numpy.clip(features, min_quantized_value, max_quantized_value) | ||
17 | + quantize_range = max_quantized_value - min_quantized_value | ||
18 | + features = (features - min_quantized_value) * (255.0 / quantize_range) | ||
19 | + features = [int(round(f)) for f in features] | ||
20 | + | ||
21 | + return _make_bytes(features) | ||
22 | + | ||
23 | + | ||
24 | +# for parse feature.pb | ||
25 | + | ||
26 | +contexts = { | ||
27 | + 'AUDIO/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
28 | + 'AUDIO/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
29 | + 'RGB/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
30 | + 'RGB/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
31 | + 'clip/data_path': tf.io.FixedLenFeature([], tf.string), | ||
32 | + 'clip/end/timestamp': tf.io.FixedLenFeature([], tf.int64), | ||
33 | + 'clip/start/timestamp': tf.io.FixedLenFeature([], tf.int64) | ||
34 | +} | ||
35 | + | ||
36 | +features = { | ||
37 | + 'AUDIO/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
38 | + 'AUDIO/feature/timestamp': tf.io.VarLenFeature(tf.int64), | ||
39 | + 'RGB/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
40 | + 'RGB/feature/timestamp': tf.io.VarLenFeature(tf.int64) | ||
41 | + | ||
42 | +} | ||
43 | + | ||
44 | + | ||
45 | +def parse_exmp(serial_exmp): | ||
46 | + _, sequence_parsed = tf.io.parse_single_sequence_example( | ||
47 | + serialized=serial_exmp, | ||
48 | + context_features=contexts, | ||
49 | + sequence_features=features) | ||
50 | + | ||
51 | + sequence_parsed = tf.contrib.learn.run_n(sequence_parsed)[0] | ||
52 | + | ||
53 | + audio = sequence_parsed['AUDIO/feature/floats'].values | ||
54 | + rgb = sequence_parsed['RGB/feature/floats'].values | ||
55 | + | ||
56 | + # print(audio.values) | ||
57 | + # print(type(audio.values)) | ||
58 | + | ||
59 | + # audio is 128 8bit, rgb is 1024 8bit for every second | ||
60 | + audio_slices = [audio[128 * i: 128 * (i + 1)] for i in range(len(audio) // 128)] | ||
61 | + rgb_slices = [rgb[1024 * i: 1024 * (i + 1)] for i in range(len(rgb) // 1024)] | ||
62 | + | ||
63 | + byte_audio = [] | ||
64 | + byte_rgb = [] | ||
65 | + | ||
66 | + for seg in audio_slices: | ||
67 | + audio_seg = quantize(seg) | ||
68 | + byte_audio.append(audio_seg) | ||
69 | + | ||
70 | + for seg in rgb_slices: | ||
71 | + rgb_seg = quantize(seg) | ||
72 | + byte_rgb.append(rgb_seg) | ||
73 | + | ||
74 | + return byte_audio, byte_rgb | ||
75 | + | ||
76 | + | ||
77 | +def make_exmp(id, labels, audio, rgb): | ||
78 | + audio_features = [] | ||
79 | + rgb_features = [] | ||
80 | + | ||
81 | + for embedding in audio: | ||
82 | + embedding_feature = tf.train.Feature( | ||
83 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
84 | + audio_features.append(embedding_feature) | ||
85 | + | ||
86 | + for embedding in rgb: | ||
87 | + embedding_feature = tf.train.Feature( | ||
88 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
89 | + rgb_features.append(embedding_feature) | ||
90 | + | ||
91 | + # for construct yt8m data | ||
92 | + seq_exmp = tf.train.SequenceExample( | ||
93 | + context=tf.train.Features( | ||
94 | + feature={ | ||
95 | + 'id': tf.train.Feature(bytes_list=tf.train.BytesList( | ||
96 | + value=[id.encode('utf-8')])), | ||
97 | + 'labels': tf.train.Feature(int64_list=tf.train.Int64List( | ||
98 | + value=[labels])) | ||
99 | + }), | ||
100 | + feature_lists=tf.train.FeatureLists( | ||
101 | + feature_list={ | ||
102 | + 'audio': tf.train.FeatureList( | ||
103 | + feature=audio_features | ||
104 | + ), | ||
105 | + 'rgb': tf.train.FeatureList( | ||
106 | + feature=rgb_features | ||
107 | + ) | ||
108 | + }) | ||
109 | + ) | ||
110 | + serialized = seq_exmp.SerializeToString() | ||
111 | + return serialized | ||
112 | + | ||
113 | + | ||
114 | +if __name__ == '__main__': | ||
115 | + filename = '/tmp/mediapipe/features.pb' | ||
116 | + | ||
117 | + sequence_example = open(filename, 'rb').read() | ||
118 | + | ||
119 | + audio, rgb = parse_exmp(sequence_example) | ||
120 | + | ||
121 | + id = 'test_001' | ||
122 | + | ||
123 | + labels = 1 | ||
124 | + | ||
125 | + tmp_example = make_exmp(id, labels, audio, rgb) | ||
126 | + | ||
127 | + decoded = tf.train.SequenceExample.FromString(tmp_example) | ||
128 | + print(decoded) | ||
129 | + | ||
130 | + # then you can write tmp_example to tfrecord files | ||
\ No newline at end of file
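Following the script's closing comment, one way to persist the serialized example; a sketch in which the output path is an assumption:

```
import tensorflow as tf

# 'tmp_example' is the serialized SequenceExample returned by make_exmp above.
with tf.io.TFRecordWriter("/tmp/mediapipe/features.tfrecord") as writer:
    writer.write(tmp_example)
```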
web/backend/manage.py
0 → 100644
1 | +#!/usr/bin/env python | ||
2 | +"""Django's command-line utility for administrative tasks.""" | ||
3 | +import os | ||
4 | +import sys | ||
5 | + | ||
6 | + | ||
7 | +def main(): | ||
8 | + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings') | ||
9 | + try: | ||
10 | + from django.core.management import execute_from_command_line | ||
11 | + except ImportError as exc: | ||
12 | + raise ImportError( | ||
13 | + "Couldn't import Django. Are you sure it's installed and " | ||
14 | + "available on your PYTHONPATH environment variable? Did you " | ||
15 | + "forget to activate a virtual environment?" | ||
16 | + ) from exc | ||
17 | + execute_from_command_line(sys.argv) | ||
18 | + | ||
19 | + | ||
20 | +if __name__ == '__main__': | ||
21 | + main() |
web/backend/requirements.txt
0 → 100644
web/backend/run
0 → 100644
web/backend/runMediaPipe.sh
0 → 100644
1 | +#!/bin/bash | ||
2 | +cd ../../../mediapipe | ||
3 | +. venv/bin/activate | ||
4 | + | ||
5 | +/usr/local/bazel/2.0.0/lib/bazel/bin/bazel version && \ | ||
6 | +alias bazel='/usr/local/bazel/2.0.0/lib/bazel/bin/bazel' | ||
7 | + | ||
8 | +python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \ | ||
9 | + --path_to_input_video=/$1 \ | ||
10 | + --clip_end_time_sec=$2 | ||
11 | + | ||
12 | +GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features \ | ||
13 | + --calculator_graph_config_file=mediapipe/graphs/youtube8m/feature_extraction.pbtxt \ | ||
14 | + --input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.pb \ | ||
15 | + --output_side_packets=output_sequence_example=/tmp/mediapipe/features.pb | ||
\ No newline at end of file
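For reference, `VideoFileUploadView` invokes this script as `./runMediaPipe.sh <saved-file-path> <duration-seconds>`, so `$1` is the uploaded video's path and `$2` is its length as reported by ffprobe.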
web/backend/templates/index.html
0 → 100644
1 | +{% load static %} | ||
2 | +<!doctype html> | ||
3 | +<html lang="en"> | ||
4 | + | ||
5 | +<head> | ||
6 | + <meta charset=utf-8> | ||
7 | + <meta http-equiv=X-UA-Compatible content="IE=edge"> | ||
8 | + <meta name=viewport content="width=device-width,initial-scale=1"> | ||
9 | + <link rel=icon href="{% static 'favicon.ico' %}"> | ||
10 | + <title>ThrowBox</title> | ||
11 | + <link href="{% static '/css/app.1dc1d4aa.css' %}" rel=preload as=style> | ||
12 | + <link href="{% static '/css/chunk-vendors.e4bdc0d1.css' %}" rel=preload as=style> | ||
13 | + <link href="{% static '/js/app.1e0612ff.js' %}" rel=preload as=script> | ||
14 | + <link href="{% static '/js/chunk-vendors.951af5fd.js' %}" rel=preload as=script> | ||
15 | + <link href="{% static '/css/chunk-vendors.e4bdc0d1.css' %}" rel=stylesheet> | ||
16 | + <link href="{% static '/css/app.1dc1d4aa.css' %}" rel=stylesheet> | ||
17 | +</head> | ||
18 | + | ||
19 | +<body> | ||
20 | + <noscript><strong>We're sorry but ThrowBox doesn't work properly without JavaScript enabled. Please enable it to | ||
21 | + continue.</strong></noscript> | ||
22 | + <div id=app> | ||
23 | + </div> | ||
24 | + <script src="{% static '/js/chunk-vendors.951af5fd.js' %}"></script> | ||
25 | + <script src="{% static '/js/app.1e0612ff.js' %}"></script> | ||
26 | +</body> | ||
27 | + | ||
28 | +</html> | ||
\ No newline at end of file
web/backend/vue2djangoTemplate.py
0 → 100644
web/backend/yt8m/__init__.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. |
web/backend/yt8m/average_precision_calculator.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Calculate or keep track of the interpolated average precision. | ||
15 | + | ||
16 | +It provides an interface for calculating interpolated average precision for an | ||
17 | +entire list or the top-n ranked items. For the definition of the | ||
18 | +(non-)interpolated average precision: | ||
19 | +http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf | ||
20 | + | ||
21 | +Example usages: | ||
22 | +1) Use it as a static function call to directly calculate average precision for | ||
23 | +a short ranked list in the memory. | ||
24 | + | ||
25 | +``` | ||
26 | +import random | ||
27 | + | ||
28 | +p = np.array([random.random() for _ in xrange(10)]) | ||
29 | +a = np.array([random.choice([0, 1]) for _ in xrange(10)]) | ||
30 | + | ||
31 | +ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a) | ||
32 | +``` | ||
33 | + | ||
34 | +2) Use it as an object for long ranked list that cannot be stored in memory or | ||
35 | +the case where partial predictions can be observed at a time (Tensorflow | ||
36 | +predictions). In this case, we first call the function accumulate many times | ||
37 | +to process parts of the ranked list. After processing all the parts, we call | ||
38 | +peek_interpolated_ap_at_n. | ||
39 | +``` | ||
40 | +p1 = np.array([random.random() for _ in xrange(5)]) | ||
41 | +a1 = np.array([random.choice([0, 1]) for _ in xrange(5)]) | ||
42 | +p2 = np.array([random.random() for _ in xrange(5)]) | ||
43 | +a2 = np.array([random.choice([0, 1]) for _ in xrange(5)]) | ||
44 | + | ||
45 | +# interpolated average precision at 10 using 1000 break points | ||
46 | +calculator = average_precision_calculator.AveragePrecisionCalculator(10) | ||
47 | +calculator.accumulate(p1, a1) | ||
48 | +calculator.accumulate(p2, a2) | ||
49 | +ap3 = calculator.peek_ap_at_n() | ||
50 | +``` | ||
51 | +""" | ||
52 | + | ||
53 | +import heapq | ||
54 | +import random | ||
55 | +import numbers | ||
56 | + | ||
57 | +import numpy | ||
58 | + | ||
59 | + | ||
60 | +class AveragePrecisionCalculator(object): | ||
61 | + """Calculate the average precision and average precision at n.""" | ||
62 | + | ||
63 | + def __init__(self, top_n=None): | ||
64 | + """Construct an AveragePrecisionCalculator to calculate average precision. | ||
65 | + | ||
66 | + This class is used to calculate the average precision for a single label. | ||
67 | + | ||
68 | + Args: | ||
69 | + top_n: A positive Integer specifying the average precision at n, or None | ||
70 | + to use all provided data points. | ||
71 | + | ||
72 | + Raises: | ||
73 | + ValueError: An error occurred when the top_n is not a positive integer. | ||
74 | + """ | ||
75 | + if not ((isinstance(top_n, int) and top_n >= 0) or top_n is None): | ||
76 | + raise ValueError("top_n must be a positive integer or None.") | ||
77 | + | ||
78 | + self._top_n = top_n # average precision at n | ||
79 | + self._total_positives = 0 # total number of positives have seen | ||
80 | + self._heap = [] # max heap of (prediction, actual) | ||
81 | + | ||
82 | + @property | ||
83 | + def heap_size(self): | ||
84 | + """Gets the heap size maintained in the class.""" | ||
85 | + return len(self._heap) | ||
86 | + | ||
87 | + @property | ||
88 | + def num_accumulated_positives(self): | ||
89 | + """Gets the number of positive samples that have been accumulated.""" | ||
90 | + return self._total_positives | ||
91 | + | ||
92 | + def accumulate(self, predictions, actuals, num_positives=None): | ||
93 | + """Accumulate the predictions and their ground truth labels. | ||
94 | + | ||
95 | + After the function call, we may call peek_ap_at_n to actually calculate | ||
96 | + the average precision. | ||
97 | + Note predictions and actuals must have the same shape. | ||
98 | + | ||
99 | + Args: | ||
100 | + predictions: a list storing the prediction scores. | ||
101 | + actuals: a list storing the ground truth labels. Any value larger than 0 | ||
102 | + will be treated as positives, otherwise as negatives. | ||
103 | + num_positives: If the 'predictions' and 'actuals' inputs aren't complete, | ||
104 | + then it's possible some true positives were missed in them. In that case, | ||
105 | + you can provide 'num_positives' in order to accurately track recall. | ||
106 | + | ||
107 | + Raises: | ||
108 | + ValueError: An error occurred when the format of the input is not the | ||
109 | + numpy 1-D array or the shape of predictions and actuals does not match. | ||
110 | + """ | ||
111 | + if len(predictions) != len(actuals): | ||
112 | + raise ValueError("the shape of predictions and actuals does not match.") | ||
113 | + | ||
114 | + if num_positives is not None: | ||
115 | + if not isinstance(num_positives, numbers.Number) or num_positives < 0: | ||
116 | + raise ValueError( | ||
117 | + "'num_positives' was provided but it was a negative number.") | ||
118 | + | ||
119 | + if num_positives is not None: | ||
120 | + self._total_positives += num_positives | ||
121 | + else: | ||
122 | + self._total_positives += numpy.size( | ||
123 | + numpy.where(numpy.array(actuals) > 1e-5)) | ||
124 | + topk = self._top_n | ||
125 | + heap = self._heap | ||
126 | + | ||
127 | + for i in range(numpy.size(predictions)): | ||
128 | + if topk is None or len(heap) < topk: | ||
129 | + heapq.heappush(heap, (predictions[i], actuals[i])) | ||
130 | + else: | ||
131 | + if predictions[i] > heap[0][0]: # heap[0] is the smallest | ||
132 | + heapq.heappop(heap) | ||
133 | + heapq.heappush(heap, (predictions[i], actuals[i])) | ||
134 | + | ||
135 | + def clear(self): | ||
136 | + """Clear the accumulated predictions.""" | ||
137 | + self._heap = [] | ||
138 | + self._total_positives = 0 | ||
139 | + | ||
140 | + def peek_ap_at_n(self): | ||
141 | + """Peek the non-interpolated average precision at n. | ||
142 | + | ||
143 | + Returns: | ||
144 | + The non-interpolated average precision at n (default 0). | ||
145 | + If n is larger than the length of the ranked list, | ||
146 | + the average precision will be returned. | ||
147 | + """ | ||
148 | + if self.heap_size <= 0: | ||
149 | + return 0 | ||
150 | + predlists = numpy.array(list(zip(*self._heap))) | ||
151 | + | ||
152 | + ap = self.ap_at_n(predlists[0], | ||
153 | + predlists[1], | ||
154 | + n=self._top_n, | ||
155 | + total_num_positives=self._total_positives) | ||
156 | + return ap | ||
157 | + | ||
158 | + @staticmethod | ||
159 | + def ap(predictions, actuals): | ||
160 | + """Calculate the non-interpolated average precision. | ||
161 | + | ||
162 | + Args: | ||
163 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
164 | + actuals: a numpy 1-D array storing the ground truth labels. Any value | ||
165 | + larger than 0 will be treated as positives, otherwise as negatives. | ||
166 | + | ||
167 | + Returns: | ||
168 | + The non-interpolated average precision at n. | ||
169 | + If n is larger than the length of the ranked list, | ||
170 | + the average precision will be returned. | ||
171 | + | ||
172 | + Raises: | ||
173 | + ValueError: An error occurred when the format of the input is not the | ||
174 | + numpy 1-D array or the shape of predictions and actuals does not match. | ||
175 | + """ | ||
176 | + return AveragePrecisionCalculator.ap_at_n(predictions, actuals, n=None) | ||
177 | + | ||
178 | + @staticmethod | ||
179 | + def ap_at_n(predictions, actuals, n=20, total_num_positives=None): | ||
180 | + """Calculate the non-interpolated average precision. | ||
181 | + | ||
182 | + Args: | ||
183 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
184 | + actuals: a numpy 1-D array storing the ground truth labels. Any value | ||
185 | + larger than 0 will be treated as positives, otherwise as negatives. | ||
186 | + n: the top n items to be considered in ap@n. | ||
187 | + total_num_positives : (optionally) you can specify the number of total | ||
188 | + positive in the list. If specified, it will be used in calculation. | ||
189 | + | ||
190 | + Returns: | ||
191 | + The non-interpolated average precision at n. | ||
192 | + If n is larger than the length of the ranked list, | ||
193 | + the average precision will be returned. | ||
194 | + | ||
195 | + Raises: | ||
196 | + ValueError: An error occurred when | ||
197 | + 1) the format of the input is not the numpy 1-D array; | ||
198 | + 2) the shape of predictions and actuals does not match; | ||
199 | + 3) the input n is not a positive integer. | ||
200 | + """ | ||
201 | + if len(predictions) != len(actuals): | ||
202 | + raise ValueError("the shape of predictions and actuals does not match.") | ||
203 | + | ||
204 | + if n is not None: | ||
205 | + if not isinstance(n, int) or n <= 0: | ||
206 | + raise ValueError("n must be 'None' or a positive integer." | ||
207 | + " It was '%s'." % n) | ||
208 | + | ||
209 | + ap = 0.0 | ||
210 | + | ||
211 | + predictions = numpy.array(predictions) | ||
212 | + actuals = numpy.array(actuals) | ||
213 | + | ||
214 | + # add a shuffler to avoid overestimating the ap | ||
215 | + predictions, actuals = AveragePrecisionCalculator._shuffle( | ||
216 | + predictions, actuals) | ||
217 | + sortidx = sorted(range(len(predictions)), | ||
218 | + key=lambda k: predictions[k], | ||
219 | + reverse=True) | ||
220 | + | ||
221 | + if total_num_positives is None: | ||
222 | + numpos = numpy.size(numpy.where(actuals > 0)) | ||
223 | + else: | ||
224 | + numpos = total_num_positives | ||
225 | + | ||
226 | + if numpos == 0: | ||
227 | + return 0 | ||
228 | + | ||
229 | + if n is not None: | ||
230 | + numpos = min(numpos, n) | ||
231 | + delta_recall = 1.0 / numpos | ||
232 | + poscount = 0.0 | ||
233 | + | ||
234 | + # calculate the ap | ||
235 | + r = len(sortidx) | ||
236 | + if n is not None: | ||
237 | + r = min(r, n) | ||
238 | + for i in range(r): | ||
239 | + if actuals[sortidx[i]] > 0: | ||
240 | + poscount += 1 | ||
241 | + ap += poscount / (i + 1) * delta_recall | ||
242 | + return ap | ||
243 | + | ||
244 | + @staticmethod | ||
245 | + def _shuffle(predictions, actuals): | ||
246 | + random.seed(0) | ||
247 | + suffidx = random.sample(range(len(predictions)), len(predictions)) | ||
248 | + predictions = predictions[suffidx] | ||
249 | + actuals = actuals[suffidx] | ||
250 | + return predictions, actuals | ||
251 | + | ||
252 | + @staticmethod | ||
253 | + def _zero_one_normalize(predictions, epsilon=1e-7): | ||
254 | + """Normalize the predictions to the range between 0.0 and 1.0. | ||
255 | + | ||
256 | + For some predictions like SVM predictions, we need to normalize them before | ||
257 | + calculate the interpolated average precision. The normalization will not | ||
258 | + change the rank in the original list and thus won't change the average | ||
259 | + precision. | ||
260 | + | ||
261 | + Args: | ||
262 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
263 | + epsilon: a small constant to avoid denominator being zero. | ||
264 | + | ||
265 | + Returns: | ||
266 | + The normalized prediction. | ||
267 | + """ | ||
268 | + denominator = numpy.max(predictions) - numpy.min(predictions) | ||
269 | + ret = (predictions - numpy.min(predictions)) / numpy.max( | ||
270 | + denominator, epsilon) | ||
271 | + return ret |
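A small worked example of the static `ap` call, assuming `AveragePrecisionCalculator` from the module above is in scope; the numbers are chosen so the result is easy to verify by hand:

```
import numpy as np

p = np.array([0.9, 0.4, 0.8, 0.1])  # prediction scores
a = np.array([1, 0, 1, 0])          # ground-truth labels
# Both positives rank first, so each recall step has precision 1.0:
# ap = (1/1) * 0.5 + (2/2) * 0.5 = 1.0
print(AveragePrecisionCalculator.ap(p, a))
```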
web/backend/yt8m/convert_prediction_from_json_to_csv.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Utility to convert the output of batch prediction into a CSV submission. | ||
15 | + | ||
16 | +It converts the JSON files created by the command | ||
17 | +'gcloud beta ml jobs submit prediction' into a CSV file ready for submission. | ||
18 | +""" | ||
19 | + | ||
20 | +import json | ||
21 | +import tensorflow as tf | ||
22 | + | ||
23 | +from builtins import range | ||
24 | +from tensorflow import app | ||
25 | +from tensorflow import flags | ||
26 | +from tensorflow import gfile | ||
27 | +from tensorflow import logging | ||
28 | + | ||
29 | +FLAGS = flags.FLAGS | ||
30 | + | ||
31 | +if __name__ == "__main__": | ||
32 | + | ||
33 | + flags.DEFINE_string( | ||
34 | + "json_prediction_files_pattern", None, | ||
35 | + "Pattern specifying the list of JSON files that the command " | ||
36 | + "'gcloud beta ml jobs submit prediction' outputs. These files are " | ||
37 | + "located in the output path of the prediction command and are prefixed " | ||
38 | + "with 'prediction.results'.") | ||
39 | + flags.DEFINE_string( | ||
40 | + "csv_output_file", None, | ||
41 | + "The file to save the predictions converted to the CSV format.") | ||
42 | + | ||
43 | + | ||
44 | +def get_csv_header(): | ||
45 | + return "VideoId,LabelConfidencePairs\n" | ||
46 | + | ||
47 | + | ||
48 | +def to_csv_row(json_data): | ||
49 | + | ||
50 | + video_id = json_data["video_id"] | ||
51 | + | ||
52 | + class_indexes = json_data["class_indexes"] | ||
53 | + predictions = json_data["predictions"] | ||
54 | + | ||
55 | + if isinstance(video_id, list): | ||
56 | + video_id = video_id[0] | ||
57 | + class_indexes = class_indexes[0] | ||
58 | + predictions = predictions[0] | ||
59 | + | ||
60 | + if len(class_indexes) != len(predictions): | ||
61 | + raise ValueError( | ||
62 | + "The number of indexes (%s) and predictions (%s) must be equal." % | ||
63 | + (len(class_indexes), len(predictions))) | ||
64 | + | ||
65 | + return (video_id.decode("utf-8") + "," + | ||
66 | + " ".join("%i %f" % (class_indexes[i], predictions[i]) | ||
67 | + for i in range(len(class_indexes))) + "\n") | ||
68 | + | ||
69 | + | ||
70 | +def main(unused_argv): | ||
71 | + logging.set_verbosity(tf.logging.INFO) | ||
72 | + | ||
73 | + if not FLAGS.json_prediction_files_pattern: | ||
74 | + raise ValueError( | ||
75 | + "The flag --json_prediction_files_pattern must be specified.") | ||
76 | + | ||
77 | + if not FLAGS.csv_output_file: | ||
78 | + raise ValueError("The flag --csv_output_file must be specified.") | ||
79 | + | ||
80 | + logging.info("Looking for prediction files with pattern: %s", | ||
81 | + FLAGS.json_prediction_files_pattern) | ||
82 | + | ||
83 | + file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern) | ||
84 | + logging.info("Found files: %s", file_paths) | ||
85 | + | ||
86 | + logging.info("Writing submission file to: %s", FLAGS.csv_output_file) | ||
87 | + with gfile.Open(FLAGS.csv_output_file, "w+") as output_file: | ||
88 | + output_file.write(get_csv_header()) | ||
89 | + | ||
90 | + for file_path in file_paths: | ||
91 | + logging.info("processing file: %s", file_path) | ||
92 | + | ||
93 | + with gfile.Open(file_path) as input_file: | ||
94 | + | ||
95 | + for line in input_file: | ||
96 | + json_data = json.loads(line) | ||
97 | + output_file.write(to_csv_row(json_data)) | ||
98 | + | ||
99 | + output_file.flush() | ||
100 | + logging.info("done") | ||
101 | + | ||
102 | + | ||
103 | +if __name__ == "__main__": | ||
104 | + app.run() |
web/backend/yt8m/esot3ria/inference_pb.py
0 → 100644
1 | +import numpy as np | ||
2 | +import tensorflow as tf | ||
3 | +from tensorflow import logging | ||
4 | +from tensorflow import gfile | ||
5 | +import operator | ||
6 | +import pb_util as pbutil | ||
7 | +import video_recommender as recommender | ||
8 | +import video_util as videoutil | ||
9 | + | ||
10 | +# Define file paths. | ||
11 | +MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/model/inference_model/segment_inference_model" | ||
12 | +VOCAB_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/vocabulary.csv" | ||
13 | +VIDEO_TAGS_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/kaggle_solution_40k.csv" | ||
14 | +TAG_VECTOR_MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/tag_vectors.model" | ||
15 | +VIDEO_VECTOR_MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/video_vectors.model" | ||
16 | +SEGMENT_LABEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/segment_label_ids.csv" | ||
17 | + | ||
18 | +# Define parameters. | ||
19 | +TAG_TOP_K = 5 | ||
20 | +VIDEO_TOP_K = 10 | ||
21 | + | ||
22 | + | ||
23 | +def get_segments(batch_video_mtx, batch_num_frames, segment_size): | ||
24 | + """Get segment-level inputs from frame-level features.""" | ||
25 | + video_batch_size = batch_video_mtx.shape[0] | ||
26 | + max_frame = batch_video_mtx.shape[1] | ||
27 | + feature_dim = batch_video_mtx.shape[-1] | ||
28 | + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size | ||
29 | + padded_segment_sizes *= segment_size | ||
30 | + segment_mask = ( | ||
31 | + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame))) | ||
32 | + | ||
33 | + # Segment bags. | ||
34 | + frame_bags = batch_video_mtx.reshape((-1, feature_dim)) | ||
35 | + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape( | ||
36 | + (-1, segment_size, feature_dim)) | ||
37 | + | ||
38 | + # Segment num frames. | ||
39 | + segment_start_times = np.arange(0, max_frame, segment_size) | ||
40 | + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times | ||
41 | + num_segment_bags = num_segments.reshape((-1)) | ||
42 | + valid_segment_mask = num_segment_bags > 0 | ||
43 | + segment_num_frames = num_segment_bags[valid_segment_mask] | ||
44 | + segment_num_frames[segment_num_frames > segment_size] = segment_size | ||
45 | + | ||
46 | + max_segment_num = (max_frame + segment_size - 1) // segment_size | ||
47 | + video_idxs = np.tile( | ||
48 | + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num]) | ||
49 | + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1]) | ||
50 | + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2)) | ||
51 | + video_segment_ids = idx_bags[valid_segment_mask] | ||
52 | + | ||
53 | + return { | ||
54 | + "video_batch": segment_frames, | ||
55 | + "num_frames_batch": segment_num_frames, | ||
56 | + "video_segment_ids": video_segment_ids | ||
57 | + } | ||
58 | + | ||
59 | + | ||
60 | +def format_predictions(video_ids, predictions, top_k, whitelisted_cls_mask=None): | ||
61 | + batch_size = len(video_ids) | ||
62 | + for video_index in range(batch_size): | ||
63 | + video_prediction = predictions[video_index] | ||
64 | + if whitelisted_cls_mask is not None: | ||
65 | + # Whitelist classes. | ||
66 | + video_prediction *= whitelisted_cls_mask | ||
67 | + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:] | ||
68 | + line = [(class_index, predictions[video_index][class_index]) | ||
69 | + for class_index in top_indices] | ||
70 | + line = sorted(line, key=lambda p: -p[1]) | ||
71 | + yield (video_ids[video_index] + "," + | ||
72 | + " ".join("%i %g" % (label, score) for (label, score) in line) + | ||
73 | + "\n").encode("utf8") | ||
74 | + | ||
75 | + | ||
76 | +def normalize_tag(tag): | ||
77 | + if isinstance(tag, str): | ||
78 | + new_tag = tag.lower().replace('[^a-zA-Z]', ' ') # NOTE: str.replace is literal, not regex, so this call is effectively a no-op | ||
79 | + if new_tag.find(" (") != -1: | ||
80 | + new_tag = new_tag[:new_tag.find(" (")] | ||
81 | + new_tag = new_tag.replace(" ", "-") | ||
82 | + return new_tag | ||
83 | + else: | ||
84 | + return tag | ||
85 | + | ||
86 | + | ||
87 | +def inference_pb(file_path, threshold): | ||
88 | + VIDEO_TOP_K = int(threshold) | ||
89 | + inference_result = {} | ||
90 | + with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: | ||
91 | + | ||
92 | + # 0. Import SequenceExample type target from pb. | ||
93 | + target_video = pbutil.convert_pb(file_path) | ||
94 | + | ||
95 | + # 1. Load video features from pb. | ||
96 | + video_id_batch_val = np.array([b'video']) | ||
97 | + n_frames = len(target_video.feature_lists.feature_list['rgb'].feature) | ||
98 | + # Restrict frame size to 300 | ||
99 | + if n_frames > 300: | ||
100 | + n_frames = 300 | ||
101 | + video_batch_val = np.zeros((300, 1152)) | ||
102 | + for i in range(n_frames): | ||
103 | + video_batch_rgb_raw = target_video.feature_lists.feature_list['rgb'].feature[i].bytes_list.value[0] | ||
104 | + video_batch_rgb = np.array(tf.cast(tf.decode_raw(video_batch_rgb_raw, tf.float32), tf.float32).eval()) | ||
105 | + video_batch_audio_raw = target_video.feature_lists.feature_list['audio'].feature[i].bytes_list.value[0] | ||
106 | + video_batch_audio = np.array(tf.cast(tf.decode_raw(video_batch_audio_raw, tf.float32), tf.float32).eval()) | ||
107 | + video_batch_val[i] = np.concatenate([video_batch_rgb, video_batch_audio], axis=0) | ||
108 | + video_batch_val = np.array([video_batch_val]) | ||
109 | + num_frames_batch_val = np.array([n_frames]) | ||
110 | + | ||
111 | + # Restore checkpoint and meta-graph file. | ||
112 | + if not gfile.Exists(MODEL_PATH + ".meta"): | ||
113 | + raise IOError("Cannot find %s. Did you run eval.py?" % MODEL_PATH) | ||
114 | + meta_graph_location = MODEL_PATH + ".meta" | ||
115 | + logging.info("loading meta-graph: " + meta_graph_location) | ||
116 | + | ||
117 | + with tf.device("/cpu:0"): | ||
118 | + saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True) | ||
119 | + logging.info("restoring variables from " + MODEL_PATH) | ||
120 | + saver.restore(sess, MODEL_PATH) | ||
121 | + input_tensor = tf.get_collection("input_batch_raw")[0] | ||
122 | + num_frames_tensor = tf.get_collection("num_frames")[0] | ||
123 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
124 | + | ||
125 | + # Workaround for num_epochs issue. | ||
126 | + def set_up_init_ops(variables): | ||
127 | + init_op_list = [] | ||
128 | + for variable in list(variables): | ||
129 | + if "train_input" in variable.name: | ||
130 | + init_op_list.append(tf.assign(variable, 1)) | ||
131 | + variables.remove(variable) | ||
132 | + init_op_list.append(tf.variables_initializer(variables)) | ||
133 | + return init_op_list | ||
134 | + | ||
135 | + sess.run( | ||
136 | + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) | ||
137 | + | ||
138 | + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | ||
139 | + dtype=np.float32) | ||
140 | + with tf.io.gfile.GFile(SEGMENT_LABEL_PATH) as fobj: | ||
141 | + for line in fobj: | ||
142 | + try: | ||
143 | + cls_id = int(line) | ||
144 | + whitelisted_cls_mask[cls_id] = 1. | ||
145 | + except ValueError: | ||
146 | + # Simply skip the non-integer line. | ||
147 | + continue | ||
148 | + | ||
149 | + # 2. Make segment features. | ||
150 | + results = get_segments(video_batch_val, num_frames_batch_val, 5) | ||
151 | + video_segment_ids = results["video_segment_ids"] | ||
152 | + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]] | ||
153 | + video_id_batch_val = np.array([ | ||
154 | + "%s:%d" % (x.decode("utf8"), y) | ||
155 | + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1]) | ||
156 | + ]) | ||
157 | + video_batch_val = results["video_batch"] | ||
158 | + num_frames_batch_val = results["num_frames_batch"] | ||
159 | + if input_tensor.get_shape()[1] != video_batch_val.shape[1]: | ||
160 | + raise ValueError("max_frames mismatch. Please re-run the eval.py " | ||
161 | + "with correct segment_labels settings.") | ||
162 | + | ||
163 | + predictions_val, = sess.run([predictions_tensor], | ||
164 | + feed_dict={ | ||
165 | + input_tensor: video_batch_val, | ||
166 | + num_frames_tensor: num_frames_batch_val | ||
167 | + }) | ||
168 | + | ||
169 | + # 3. Make vocabularies. | ||
170 | + voca_dict = {} | ||
171 | + with open(VOCAB_PATH, 'r') as vocabs: | ||
172 | + for line in vocabs: | ||
173 | + vocab_dict_item = line.split(",") | ||
174 | + if vocab_dict_item[0] != "Index": | ||
175 | + voca_dict[vocab_dict_item[0]] = vocab_dict_item[3] | ||
179 | + | ||
180 | + # 4. Make combined scores. | ||
181 | + combined_scores = {} | ||
182 | + for line in format_predictions(video_id_batch_val, predictions_val, TAG_TOP_K, whitelisted_cls_mask): | ||
183 | + segment_id, preds = line.decode("utf8").split(",") | ||
184 | + preds = preds.split(" ") | ||
185 | + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] | ||
186 | + pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)] | ||
187 | + for i in range(len(pred_cls_ids)): | ||
188 | + if pred_cls_ids[i] in combined_scores: | ||
189 | + combined_scores[pred_cls_ids[i]] += pred_cls_scores[i] | ||
190 | + else: | ||
191 | + combined_scores[pred_cls_ids[i]] = pred_cls_scores[i] | ||
192 | + | ||
193 | + combined_scores = sorted(combined_scores.items(), key=operator.itemgetter(1), reverse=True) | ||
194 | + denominator = float(sum(score for _, score in combined_scores[:5])) | ||
196 | + | ||
197 | + tag_result = [] | ||
198 | + for itemIndex in range(TAG_TOP_K): | ||
199 | + segment_tag = str(voca_dict[str(combined_scores[itemIndex][0])]) | ||
200 | + normalized_tag = normalize_tag(segment_tag) | ||
201 | + tag_percentage = format(combined_scores[itemIndex][1] / denominator, ".3f") | ||
202 | + tag_result.append((normalized_tag, tag_percentage)) | ||
203 | + | ||
204 | + # 5. Create recommend videos info, Combine results. | ||
205 | + recommend_video_ids = recommender.recommend_videos(tag_result, TAG_VECTOR_MODEL_PATH, | ||
206 | + VIDEO_VECTOR_MODEL_PATH, VIDEO_TOP_K) | ||
207 | + video_result = [videoutil.getVideoInfo(ids, VIDEO_TAGS_PATH, TAG_TOP_K) for ids in recommend_video_ids] | ||
208 | + | ||
209 | + inference_result = { | ||
210 | + "tag_result": tag_result, | ||
211 | + "video_result": video_result | ||
212 | + } | ||
213 | + | ||
214 | + # 6. Dispose instances. | ||
215 | + sess.close() | ||
216 | + | ||
217 | + return inference_result | ||
218 | + | ||
219 | + | ||
220 | +if __name__ == '__main__': | ||
221 | + filepath = "/tmp/mediapipe/features.pb" | ||
222 | + result = inference_pb(filepath, 5)  # second arg is the recommendation count; 5 here is illustrative | ||
223 | + print(result) |
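
For readers tracing step 4 above, the aggregation is simple enough to run standalone: per-segment (class, score) pairs are summed per class, sorted descending, and each of the top tags is normalized by the summed mass of the five best classes. A minimal sketch with made-up scores:

```python
# Standalone sketch of the tag-aggregation step above; scores are illustrative.
import operator

segment_preds = [(7, 0.9), (3, 0.5), (7, 0.4), (11, 0.3), (3, 0.2), (19, 0.1), (23, 0.05)]
combined_scores = {}
for cls_id, score in segment_preds:
    combined_scores[cls_id] = combined_scores.get(cls_id, 0.0) + score

combined_scores = sorted(combined_scores.items(), key=operator.itemgetter(1), reverse=True)
denominator = float(sum(score for _, score in combined_scores[:5]))
tag_result = [(cls_id, format(score / denominator, ".3f"))
              for cls_id, score in combined_scores[:5]]
print(tag_result)  # [(7, '0.531'), (3, '0.286'), (11, '0.122'), (19, '0.041'), (23, '0.020')]
```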
web/backend/yt8m/esot3ria/model/eval/events.out.tfevents.1591170123.mlvc-nogadahalf12-instance
0 → 100644
No preview for this file type
web/backend/yt8m/esot3ria/model/inference_model/segment_inference_model.data-00000-of-00001
0 → 100644
This file is too large to display.
1 | +{"model": "FrameLevelLogisticModel", "feature_sizes": "1024,128", "feature_names": "rgb,audio", "frame_features": true, "label_loss": "CrossEntropyLoss"} | ||
\ No newline at end of file
web/backend/yt8m/esot3ria/pb_util.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import numpy | ||
3 | + | ||
4 | + | ||
5 | +def _make_bytes(int_array): | ||
6 | + if bytes == str: # Python2 | ||
7 | + return ''.join(map(chr, int_array)) | ||
8 | + else: | ||
9 | + return bytes(int_array) | ||
10 | + | ||
11 | + | ||
12 | +def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0): | ||
13 | + """Quantizes float32 `features` into string.""" | ||
14 | + assert features.dtype == 'float32' | ||
15 | + assert len(features.shape) == 1 # 1-D array | ||
16 | + features = numpy.clip(features, min_quantized_value, max_quantized_value) | ||
17 | + quantize_range = max_quantized_value - min_quantized_value | ||
18 | + features = (features - min_quantized_value) * (255.0 / quantize_range) | ||
19 | + features = [int(round(f)) for f in features] | ||
20 | + | ||
21 | + return _make_bytes(features) | ||
22 | + | ||
23 | + | ||
24 | +# for parse feature.pb | ||
25 | + | ||
26 | +contexts = { | ||
27 | + 'AUDIO/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
28 | + 'AUDIO/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
29 | + 'RGB/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
30 | + 'RGB/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
31 | + 'clip/data_path': tf.io.FixedLenFeature([], tf.string), | ||
32 | + 'clip/end/timestamp': tf.io.FixedLenFeature([], tf.int64), | ||
33 | + 'clip/start/timestamp': tf.io.FixedLenFeature([], tf.int64) | ||
34 | +} | ||
35 | + | ||
36 | +features = { | ||
37 | + 'AUDIO/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
38 | + 'AUDIO/feature/timestamp': tf.io.VarLenFeature(tf.int64), | ||
39 | + 'RGB/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
40 | + 'RGB/feature/timestamp': tf.io.VarLenFeature(tf.int64) | ||
41 | + | ||
42 | +} | ||
43 | + | ||
44 | + | ||
45 | +def parse_exmp(serial_exmp): | ||
46 | + _, sequence_parsed = tf.io.parse_single_sequence_example( | ||
47 | + serialized=serial_exmp, | ||
48 | + context_features=contexts, | ||
49 | + sequence_features=features) | ||
50 | + | ||
51 | + sequence_parsed = tf.contrib.learn.run_n(sequence_parsed)[0] | ||
52 | + | ||
53 | + audio = sequence_parsed['AUDIO/feature/floats'].values | ||
54 | + rgb = sequence_parsed['RGB/feature/floats'].values | ||
55 | + | ||
56 | + # print(audio.values) | ||
57 | + # print(type(audio.values)) | ||
58 | + | ||
59 | + # each second yields 128 audio floats and 1024 rgb floats | ||
60 | + audio_slices = [audio[128 * i: 128 * (i + 1)] for i in range(len(audio) // 128)] | ||
61 | + rgb_slices = [rgb[1024 * i: 1024 * (i + 1)] for i in range(len(rgb) // 1024)] | ||
62 | + | ||
63 | + byte_audio = [] | ||
64 | + byte_rgb = [] | ||
65 | + | ||
66 | + for seg in audio_slices: | ||
67 | + # audio_seg = quantize(seg) | ||
68 | + audio_seg = _make_bytes(seg) | ||
69 | + byte_audio.append(audio_seg) | ||
70 | + | ||
71 | + for seg in rgb_slices: | ||
72 | + # rgb_seg = quantize(seg) | ||
73 | + rgb_seg = _make_bytes(seg) | ||
74 | + byte_rgb.append(rgb_seg) | ||
75 | + | ||
76 | + return byte_audio, byte_rgb | ||
77 | + | ||
78 | + | ||
79 | +def make_exmp(id, audio, rgb): | ||
80 | + audio_features = [] | ||
81 | + rgb_features = [] | ||
82 | + | ||
83 | + for embedding in audio: | ||
84 | + embedding_feature = tf.train.Feature( | ||
85 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
86 | + audio_features.append(embedding_feature) | ||
87 | + | ||
88 | + for embedding in rgb: | ||
89 | + embedding_feature = tf.train.Feature( | ||
90 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
91 | + rgb_features.append(embedding_feature) | ||
92 | + | ||
93 | + # for construct yt8m data | ||
94 | + seq_exmp = tf.train.SequenceExample( | ||
95 | + context=tf.train.Features( | ||
96 | + feature={ | ||
97 | + 'id': tf.train.Feature(bytes_list=tf.train.BytesList( | ||
98 | + value=[id.encode('utf-8')])) | ||
99 | + }), | ||
100 | + feature_lists=tf.train.FeatureLists( | ||
101 | + feature_list={ | ||
102 | + 'audio': tf.train.FeatureList( | ||
103 | + feature=audio_features | ||
104 | + ), | ||
105 | + 'rgb': tf.train.FeatureList( | ||
106 | + feature=rgb_features | ||
107 | + ) | ||
108 | + }) | ||
109 | + ) | ||
110 | + serialized = seq_exmp.SerializeToString() | ||
111 | + return serialized | ||
112 | + | ||
113 | + | ||
114 | +def convert_pb(filename): | ||
115 | + sequence_example = open(filename, 'rb').read() | ||
116 | + | ||
117 | + audio, rgb = parse_exmp(sequence_example) | ||
118 | + tmp_example = make_exmp('video', audio, rgb) | ||
119 | + | ||
120 | + decoded = tf.train.SequenceExample.FromString(tmp_example) | ||
121 | + return decoded |
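
A hedged usage sketch for the module above: `convert_pb` reads a serialized MediaPipe SequenceExample from disk, re-packs its RGB/audio float lists into the YT8M byte-list layout, and returns the decoded proto. The import alias and path mirror the inference script and are assumptions:

```python
# Hypothetical round-trip check (path and import alias are assumptions).
import pb_util as pbutil

example = pbutil.convert_pb("/tmp/mediapipe/features.pb")
n_frames = len(example.feature_lists.feature_list["rgb"].feature)
print("decoded %d frames of rgb features" % n_frames)
```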
web/backend/yt8m/esot3ria/readpb.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import numpy as np | ||
3 | + | ||
4 | +frame_lvl_record = "test0000.tfrecord" | ||
5 | + | ||
6 | +feat_rgb = [] | ||
7 | +feat_audio = [] | ||
8 | + | ||
9 | +for example in tf.python_io.tf_record_iterator(frame_lvl_record): | ||
10 | + tf_seq_example = tf.train.SequenceExample.FromString(example) | ||
11 | + test = tf_seq_example.SerializeToString() | ||
12 | + n_frames = len(tf_seq_example.feature_lists.feature_list['audio'].feature) | ||
13 | + sess = tf.InteractiveSession() | ||
14 | + rgb_frame = [] | ||
15 | + audio_frame = [] | ||
16 | + # iterate through frames | ||
17 | + for i in range(n_frames): | ||
18 | + rgb_frame.append(tf.cast(tf.decode_raw( | ||
19 | + tf_seq_example.feature_lists.feature_list['rgb'] | ||
20 | + .feature[i].bytes_list.value[0], tf.uint8) | ||
21 | + , tf.float32).eval()) | ||
22 | + audio_frame.append(tf.cast(tf.decode_raw( | ||
23 | + tf_seq_example.feature_lists.feature_list['audio'] | ||
24 | + .feature[i].bytes_list.value[0], tf.uint8) | ||
25 | + , tf.float32).eval()) | ||
26 | + | ||
27 | + sess.close() | ||
28 | + | ||
29 | + feat_audio.append(audio_frame) | ||
30 | + feat_rgb.append(rgb_frame) | ||
31 | + break | ||
32 | + | ||
33 | +print('The first video has %d frames' % len(feat_rgb[0])) | ||
\ No newline at end of file
1 | +import nltk | ||
2 | +import gensim | ||
3 | +import pandas as pd | ||
4 | + | ||
5 | +# Load files. | ||
6 | +nltk.download('stopwords') | ||
7 | +vocab = pd.read_csv('../vocabulary.csv') | ||
8 | + | ||
9 | +# Lower corpus and Remove () from name. | ||
10 | +vocab['WikiDescription'] = vocab['WikiDescription'].str.lower().str.replace('[^a-zA-Z0-9]', ' ') | ||
11 | +for i in range(len(vocab['Name'])): | ||
12 | + name = vocab['Name'][i] | ||
13 | + if isinstance(name, str) and name.find(" (") != -1: | ||
14 | + vocab['Name'][i] = name[:name.find(" (")] | ||
15 | +vocab['Name'] = vocab['Name'].str.lower() | ||
16 | + | ||
17 | +# Combine multi-word names (mobile phone -> mobile-phone). | ||
18 | +for name in vocab['Name']: | ||
19 | + if isinstance(name, str) and name.find(" ") != -1: | ||
20 | + combined_name = name.replace(" ", "-") | ||
21 | + for i in range(len(vocab['WikiDescription'])): | ||
22 | + if isinstance(vocab['WikiDescription'][i], str): | ||
23 | + vocab['WikiDescription'][i] = vocab['WikiDescription'][i].replace(name, combined_name) | ||
24 | + | ||
25 | + | ||
26 | +# Remove stopwords from corpus. | ||
27 | +stop_re = '\\b'+'\\b|\\b'.join(nltk.corpus.stopwords.words('english'))+'\\b' | ||
28 | +vocab['WikiDescription'] = vocab['WikiDescription'].str.replace(stop_re, '') | ||
29 | +vocab['WikiDescription'] = vocab['WikiDescription'].str.split() | ||
30 | + | ||
31 | +# Tokenize corpus. | ||
32 | +tokenlist = [x for x in vocab['WikiDescription'] if str(x) != 'nan'] | ||
33 | +phrases = gensim.models.phrases.Phrases(tokenlist) | ||
34 | +phraser = gensim.models.phrases.Phraser(phrases) | ||
35 | +vocab_phrased = phraser[tokenlist] | ||
36 | + | ||
37 | +# Vectorize tags. | ||
38 | +w2v = gensim.models.word2vec.Word2Vec(sentences=tokenlist, min_count=1) | ||
39 | +w2v.save('tag_vectors.model') | ||
40 | + | ||
41 | +# word_vectors = w2v.wv | ||
42 | +# vocabs = word_vectors.vocab.keys() | ||
43 | +# word_vectors_list = [word_vectors[v] for v in vocabs] |
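
Once `tag_vectors.model` is saved, it behaves like any gensim word2vec model; a quick sanity check might look like the sketch below (the query tag is an assumption and must exist in the trained vocabulary):

```python
# Sanity-check sketch for the saved tag vectors; the query tag is an assumption.
import gensim

w2v = gensim.models.word2vec.Word2Vec.load('tag_vectors.model')
print(w2v.wv.most_similar('guitar', topn=5))
```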
web/backend/yt8m/esot3ria/tag_vectors.model
0 → 100644
This file is too large to display.
1 | +from gensim.models import Word2Vec | ||
2 | +import numpy as np | ||
3 | + | ||
4 | +def recommend_videos(tags, tag_model_path, video_model_path, top_k): | ||
5 | + tag_vectors = Word2Vec.load(tag_model_path).wv | ||
6 | + video_vectors = Word2Vec().wv.load(video_model_path) | ||
7 | + error_tags = [] | ||
8 | + | ||
9 | + video_vector = np.zeros(100) | ||
10 | + for (tag, weight) in tags: | ||
11 | + if tag in tag_vectors.vocab: | ||
12 | + video_vector = video_vector + (tag_vectors[tag] * float(weight)) | ||
13 | + else: | ||
14 | + # Pass if tag is unknown | ||
15 | + if tag not in error_tags: | ||
16 | + error_tags.append(tag) | ||
17 | + | ||
18 | + similar_ids = [x[0] for x in video_vectors.similar_by_vector(video_vector, top_k)] | ||
19 | + return similar_ids |
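
The `tags` argument mirrors the `tag_result` built by `inference_pb`: `(tag, weight)` pairs whose weights are stringified shares of the top-5 score mass. A hedged call sketch, with model paths and tags as assumptions:

```python
# Usage sketch; paths and tags are illustrative assumptions.
tags = [("mobile-phone", "0.361"), ("smartphone", "0.217"), ("telephone", "0.128")]
ids = recommend_videos(tags, "tag_vectors.model", "video_vectors.model", top_k=10)
print(ids)  # ten YT8M video ids, most similar first
```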
web/backend/yt8m/esot3ria/video_util.py
0 → 100644
1 | +import requests | ||
2 | +import pandas as pd | ||
3 | + | ||
4 | +base_URL = 'https://data.yt8m.org/2/j/i/' | ||
5 | +youtube_url = 'https://www.youtube.com/watch?v=' | ||
6 | + | ||
7 | + | ||
8 | +def getURL(vid_id): | ||
9 | + URL = base_URL + vid_id[:-2] + '/' + vid_id + '.js' | ||
10 | + response = requests.get(URL, verify=False) | ||
11 | + if response.status_code == 200: | ||
12 | + return youtube_url + response.text[10:-3] | ||
13 | + | ||
14 | + | ||
15 | +def getVideoInfo(vid_id, video_tags_path, top_k): | ||
16 | + video_url = getURL(vid_id) | ||
17 | + | ||
18 | + entire_video_tags = pd.read_csv(video_tags_path) | ||
19 | + video_tags_info = entire_video_tags.loc[entire_video_tags["vid_id"] == vid_id] | ||
20 | + video_tags = [] | ||
21 | + for i in range(1, top_k + 1): | ||
22 | + video_tag_tuple = video_tags_info["segment" + str(i)].values[0] # ex: "mobile-phone:0.361" | ||
23 | + video_tags.append(video_tag_tuple.split(":")[0]) | ||
24 | + | ||
25 | + return { | ||
26 | + "video_url": video_url, | ||
27 | + "video_tags": video_tags | ||
28 | + } |
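
`getVideoInfo` resolves a 4-character YT8M id to a watchable URL through the data.yt8m.org lookup and attaches the video's stored tags. A sketch, where the id and CSV path are assumptions (the CSV must carry `vid_id` plus `segment1`..`segmentN` columns, as read above):

```python
# Hypothetical call; the id and CSV path are assumptions.
info = getVideoInfo("XwFj", "video_tags.csv", top_k=5)
print(info["video_url"], info["video_tags"])
```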
1 | +import pandas as pd | ||
2 | +import numpy as np | ||
3 | +from gensim.models import Word2Vec | ||
4 | + | ||
5 | +BATCH_SIZE = 1000 | ||
6 | + | ||
7 | + | ||
8 | +def vectorization_video(): | ||
9 | + print('[0.1 0.2]') | ||
10 | + | ||
11 | + | ||
12 | +if __name__ == '__main__': | ||
13 | + tag_vectors = Word2Vec.load("tag_vectors.model").wv | ||
14 | + video_vectors = Word2Vec().wv # Empty model | ||
15 | + | ||
16 | + # Load video recommendation tags. | ||
17 | + video_tags = pd.read_csv('kaggle_solution_40k.csv') | ||
18 | + | ||
19 | + # Define batch variables. | ||
20 | + batch_video_ids = [] | ||
21 | + batch_video_vectors = [] | ||
22 | + error_tags = [] | ||
23 | + | ||
24 | + for i, row in video_tags.iterrows(): | ||
25 | + video_id = row[0] | ||
26 | + video_vector = np.zeros(100) | ||
27 | + for segment_index in range(1, 6): | ||
28 | + tag, weight = row[segment_index].split(":") | ||
29 | + if tag in tag_vectors.vocab: | ||
30 | + video_vector = video_vector + (tag_vectors[tag] * float(weight)) | ||
31 | + else: | ||
32 | + # Pass if tag is unknown | ||
33 | + if tag not in error_tags: | ||
34 | + error_tags.append(tag) | ||
35 | + | ||
36 | + batch_video_ids.append(video_id) | ||
37 | + batch_video_vectors.append(video_vector) | ||
38 | + # Add video vectors. | ||
39 | + if (i+1) % BATCH_SIZE == 0: | ||
40 | + video_vectors.add(batch_video_ids, batch_video_vectors) | ||
41 | + batch_video_ids = [] | ||
42 | + batch_video_vectors = [] | ||
43 | + print("Video vectors created: ", i+1) | ||
44 | + | ||
45 | + # Add rest of video vectors. | ||
46 | + video_vectors.add(batch_video_ids, batch_video_vectors) | ||
47 | + print("error tags: ") | ||
48 | + print(error_tags) | ||
49 | + | ||
50 | + video_vectors.save("video_vectors.model") | ||
51 | + | ||
52 | + # Usage | ||
53 | + # video_vectors = Word2Vec().wv.load("video_vectors.model") | ||
54 | + # video_vectors.most_similar("XwFj", topn=5) |
This file is too large to display.
web/backend/yt8m/eval.py
0 → 100644
web/backend/yt8m/eval_util.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Provides functions to help with evaluating models.""" | ||
15 | +import average_precision_calculator as ap_calculator | ||
16 | +import mean_average_precision_calculator as map_calculator | ||
17 | +import numpy | ||
18 | +from tensorflow.python.platform import gfile | ||
19 | + | ||
20 | + | ||
21 | +def flatten(l): | ||
22 | + """Merges a list of lists into a single list. """ | ||
23 | + return [item for sublist in l for item in sublist] | ||
24 | + | ||
25 | + | ||
26 | +def calculate_hit_at_one(predictions, actuals): | ||
27 | + """Performs a local (numpy) calculation of the hit at one. | ||
28 | + | ||
29 | + Args: | ||
30 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
31 | + 'batch' x 'num_classes'. | ||
32 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
33 | + 'num_classes'. | ||
34 | + | ||
35 | + Returns: | ||
36 | + float: The average hit at one across the entire batch. | ||
37 | + """ | ||
38 | + top_prediction = numpy.argmax(predictions, 1) | ||
39 | + hits = actuals[numpy.arange(actuals.shape[0]), top_prediction] | ||
40 | + return numpy.average(hits) | ||
41 | + | ||
42 | + | ||
43 | +def calculate_precision_at_equal_recall_rate(predictions, actuals): | ||
44 | + """Performs a local (numpy) calculation of the PERR. | ||
45 | + | ||
46 | + Args: | ||
47 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
48 | + 'batch' x 'num_classes'. | ||
49 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
50 | + 'num_classes'. | ||
51 | + | ||
52 | + Returns: | ||
53 | + float: The average precision at equal recall rate across the entire batch. | ||
54 | + """ | ||
55 | + aggregated_precision = 0.0 | ||
56 | + num_videos = actuals.shape[0] | ||
57 | + for row in numpy.arange(num_videos): | ||
58 | + num_labels = int(numpy.sum(actuals[row])) | ||
59 | + top_indices = numpy.argpartition(predictions[row], | ||
60 | + -num_labels)[-num_labels:] | ||
61 | + item_precision = 0.0 | ||
62 | + for label_index in top_indices: | ||
63 | + if predictions[row][label_index] > 0: | ||
64 | + item_precision += actuals[row][label_index] | ||
65 | + item_precision /= top_indices.size | ||
66 | + aggregated_precision += item_precision | ||
67 | + aggregated_precision /= num_videos | ||
68 | + return aggregated_precision | ||
69 | + | ||
70 | + | ||
71 | +def calculate_gap(predictions, actuals, top_k=20): | ||
72 | + """Performs a local (numpy) calculation of the global average precision. | ||
73 | + | ||
74 | + Only the top_k predictions are taken for each of the videos. | ||
75 | + | ||
76 | + Args: | ||
77 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
78 | + 'batch' x 'num_classes'. | ||
79 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
80 | + 'num_classes'. | ||
81 | + top_k: How many predictions to use per video. | ||
82 | + | ||
83 | + Returns: | ||
84 | + float: The global average precision. | ||
85 | + """ | ||
86 | + gap_calculator = ap_calculator.AveragePrecisionCalculator() | ||
87 | + sparse_predictions, sparse_labels, num_positives = top_k_by_class( | ||
88 | + predictions, actuals, top_k) | ||
89 | + gap_calculator.accumulate(flatten(sparse_predictions), flatten(sparse_labels), | ||
90 | + sum(num_positives)) | ||
91 | + return gap_calculator.peek_ap_at_n() | ||
92 | + | ||
93 | + | ||
94 | +def top_k_by_class(predictions, labels, k=20): | ||
95 | + """Extracts the top k predictions for each video, sorted by class. | ||
96 | + | ||
97 | + Args: | ||
98 | + predictions: A numpy matrix containing the outputs of the model. Dimensions | ||
99 | + are 'batch' x 'num_classes'. | ||
99 | + labels: A numpy matrix containing the ground truth labels. Dimensions are | ||
99 | + 'batch' x 'num_classes'. | ||
100 | + k: the top k non-zero entries to preserve in each prediction. | ||
101 | + | ||
102 | + Returns: | ||
103 | + A tuple (predictions,labels, true_positives). 'predictions' and 'labels' | ||
104 | + are lists of lists of floats. 'true_positives' is a list of scalars. The | ||
105 | + length of the lists are equal to the number of classes. The entries in the | ||
106 | + predictions variable are probability predictions, and | ||
107 | + the corresponding entries in the labels variable are the ground truth for | ||
108 | + those predictions. The entries in 'true_positives' are the number of true | ||
109 | + positives for each class in the ground truth. | ||
110 | + | ||
111 | + Raises: | ||
112 | + ValueError: An error occurred when the k is not a positive integer. | ||
113 | + """ | ||
114 | + if k <= 0: | ||
115 | + raise ValueError("k must be a positive integer.") | ||
116 | + k = min(k, predictions.shape[1]) | ||
117 | + num_classes = predictions.shape[1] | ||
118 | + prediction_triplets = [] | ||
119 | + for video_index in range(predictions.shape[0]): | ||
120 | + prediction_triplets.extend( | ||
121 | + top_k_triplets(predictions[video_index], labels[video_index], k)) | ||
122 | + out_predictions = [[] for _ in range(num_classes)] | ||
123 | + out_labels = [[] for _ in range(num_classes)] | ||
124 | + for triplet in prediction_triplets: | ||
125 | + out_predictions[triplet[0]].append(triplet[1]) | ||
126 | + out_labels[triplet[0]].append(triplet[2]) | ||
127 | + out_true_positives = [numpy.sum(labels[:, i]) for i in range(num_classes)] | ||
128 | + | ||
129 | + return out_predictions, out_labels, out_true_positives | ||
130 | + | ||
131 | + | ||
132 | +def top_k_triplets(predictions, labels, k=20): | ||
133 | + """Get the top_k for a 1-d numpy array. | ||
134 | + | ||
135 | + Returns a sparse list of tuples in | ||
136 | + (class_index, prediction, label) format | ||
137 | + """ | ||
138 | + m = len(predictions) | ||
139 | + k = min(k, m) | ||
140 | + indices = numpy.argpartition(predictions, -k)[-k:] | ||
141 | + return [(index, predictions[index], labels[index]) for index in indices] | ||
142 | + | ||
143 | + | ||
144 | +class EvaluationMetrics(object): | ||
145 | + """A class to store the evaluation metrics.""" | ||
146 | + | ||
147 | + def __init__(self, num_class, top_k, top_n): | ||
148 | + """Construct an EvaluationMetrics object to store the evaluation metrics. | ||
149 | + | ||
150 | + Args: | ||
151 | + num_class: A positive integer specifying the number of classes. | ||
152 | + top_k: A positive integer specifying how many predictions are considered | ||
153 | + per video. | ||
154 | + top_n: A positive Integer specifying the average precision at n, or None | ||
155 | + to use all provided data points. | ||
156 | + | ||
157 | + Raises: | ||
158 | + ValueError: An error occurred when MeanAveragePrecisionCalculator cannot | ||
159 | + be constructed. | ||
160 | + """ | ||
161 | + self.sum_hit_at_one = 0.0 | ||
162 | + self.sum_perr = 0.0 | ||
163 | + self.sum_loss = 0.0 | ||
164 | + self.map_calculator = map_calculator.MeanAveragePrecisionCalculator( | ||
165 | + num_class, top_n=top_n) | ||
166 | + self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator() | ||
167 | + self.top_k = top_k | ||
168 | + self.num_examples = 0 | ||
169 | + | ||
170 | + def accumulate(self, predictions, labels, loss): | ||
171 | + """Accumulate the metrics calculated locally for this mini-batch. | ||
172 | + | ||
173 | + Args: | ||
174 | + predictions: A numpy matrix containing the outputs of the model. | ||
175 | + Dimensions are 'batch' x 'num_classes'. | ||
176 | + labels: A numpy matrix containing the ground truth labels. Dimensions are | ||
177 | + 'batch' x 'num_classes'. | ||
178 | + loss: A numpy array containing the loss for each sample. | ||
179 | + | ||
180 | + Returns: | ||
181 | + dictionary: A dictionary storing the metrics for the mini-batch. | ||
182 | + | ||
183 | + Raises: | ||
184 | + ValueError: An error occurred when the shape of predictions and actuals | ||
185 | + does not match. | ||
186 | + """ | ||
187 | + batch_size = labels.shape[0] | ||
188 | + mean_hit_at_one = calculate_hit_at_one(predictions, labels) | ||
189 | + mean_perr = calculate_precision_at_equal_recall_rate(predictions, labels) | ||
190 | + mean_loss = numpy.mean(loss) | ||
191 | + | ||
192 | + # Take the top 20 predictions. | ||
193 | + sparse_predictions, sparse_labels, num_positives = top_k_by_class( | ||
194 | + predictions, labels, self.top_k) | ||
195 | + self.map_calculator.accumulate(sparse_predictions, sparse_labels, | ||
196 | + num_positives) | ||
197 | + self.global_ap_calculator.accumulate(flatten(sparse_predictions), | ||
198 | + flatten(sparse_labels), | ||
199 | + sum(num_positives)) | ||
200 | + | ||
201 | + self.num_examples += batch_size | ||
202 | + self.sum_hit_at_one += mean_hit_at_one * batch_size | ||
203 | + self.sum_perr += mean_perr * batch_size | ||
204 | + self.sum_loss += mean_loss * batch_size | ||
205 | + | ||
206 | + return {"hit_at_one": mean_hit_at_one, "perr": mean_perr, "loss": mean_loss} | ||
207 | + | ||
208 | + def get(self): | ||
209 | + """Calculate the evaluation metrics for the whole epoch. | ||
210 | + | ||
211 | + Raises: | ||
212 | + ValueError: If no examples were accumulated. | ||
213 | + | ||
214 | + Returns: | ||
215 | + dictionary: a dictionary storing the evaluation metrics for the epoch. The | ||
216 | + dictionary has the fields: avg_hit_at_one, avg_perr, avg_loss, and | ||
217 | + aps (default nan). | ||
218 | + """ | ||
219 | + if self.num_examples <= 0: | ||
220 | + raise ValueError("total_sample must be positive.") | ||
221 | + avg_hit_at_one = self.sum_hit_at_one / self.num_examples | ||
222 | + avg_perr = self.sum_perr / self.num_examples | ||
223 | + avg_loss = self.sum_loss / self.num_examples | ||
224 | + | ||
225 | + aps = self.map_calculator.peek_map_at_n() | ||
226 | + gap = self.global_ap_calculator.peek_ap_at_n() | ||
227 | + | ||
228 | + epoch_info_dict = { | ||
229 | + "avg_hit_at_one": avg_hit_at_one, | ||
230 | + "avg_perr": avg_perr, | ||
231 | + "avg_loss": avg_loss, | ||
232 | + "aps": aps, | ||
233 | + "gap": gap | ||
234 | + } | ||
235 | + return epoch_info_dict | ||
236 | + | ||
237 | + def clear(self): | ||
238 | + """Clear the evaluation metrics and reset the EvaluationMetrics object.""" | ||
239 | + self.sum_hit_at_one = 0.0 | ||
240 | + self.sum_perr = 0.0 | ||
241 | + self.sum_loss = 0.0 | ||
242 | + self.map_calculator.clear() | ||
243 | + self.global_ap_calculator.clear() | ||
244 | + self.num_examples = 0 |
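
A self-contained smoke test for `EvaluationMetrics`, assuming the two calculator modules above are importable; the data is random and purely illustrative:

```python
# Smoke-test sketch (random data; illustrative only).
import numpy

metrics = EvaluationMetrics(num_class=10, top_k=5, top_n=None)
predictions = numpy.random.rand(8, 10)
labels = (numpy.random.rand(8, 10) > 0.7).astype(numpy.float32)
loss = numpy.random.rand(8)
print(metrics.accumulate(predictions, labels, loss))  # batch hit@1 / PERR / loss
print(metrics.get())  # epoch-level metrics, including gap
```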
web/backend/yt8m/export_model.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Utilities to export a model for batch prediction.""" | ||
15 | + | ||
16 | +import tensorflow as tf | ||
17 | +import tensorflow.contrib.slim as slim | ||
18 | + | ||
19 | +from tensorflow.python.saved_model import builder as saved_model_builder | ||
20 | +from tensorflow.python.saved_model import signature_constants | ||
21 | +from tensorflow.python.saved_model import signature_def_utils | ||
22 | +from tensorflow.python.saved_model import tag_constants | ||
23 | +from tensorflow.python.saved_model import utils as saved_model_utils | ||
24 | + | ||
25 | +_TOP_PREDICTIONS_IN_OUTPUT = 20 | ||
26 | + | ||
27 | + | ||
28 | +class ModelExporter(object): | ||
29 | + | ||
30 | + def __init__(self, frame_features, model, reader): | ||
31 | + self.frame_features = frame_features | ||
32 | + self.model = model | ||
33 | + self.reader = reader | ||
34 | + | ||
35 | + with tf.Graph().as_default() as graph: | ||
36 | + self.inputs, self.outputs = self.build_inputs_and_outputs() | ||
37 | + self.graph = graph | ||
38 | + self.saver = tf.train.Saver(tf.trainable_variables(), sharded=True) | ||
39 | + | ||
40 | + def export_model(self, model_dir, global_step_val, last_checkpoint): | ||
41 | + """Exports the model so that it can used for batch predictions.""" | ||
42 | + | ||
43 | + with self.graph.as_default(): | ||
44 | + with tf.Session() as session: | ||
45 | + session.run(tf.global_variables_initializer()) | ||
46 | + self.saver.restore(session, last_checkpoint) | ||
47 | + | ||
48 | + signature = signature_def_utils.build_signature_def( | ||
49 | + inputs=self.inputs, | ||
50 | + outputs=self.outputs, | ||
51 | + method_name=signature_constants.PREDICT_METHOD_NAME) | ||
52 | + | ||
53 | + signature_map = { | ||
54 | + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature | ||
55 | + } | ||
56 | + | ||
57 | + model_builder = saved_model_builder.SavedModelBuilder(model_dir) | ||
58 | + model_builder.add_meta_graph_and_variables( | ||
59 | + session, | ||
60 | + tags=[tag_constants.SERVING], | ||
61 | + signature_def_map=signature_map, | ||
62 | + clear_devices=True) | ||
63 | + model_builder.save() | ||
64 | + | ||
65 | + def build_inputs_and_outputs(self): | ||
66 | + if self.frame_features: | ||
67 | + serialized_examples = tf.placeholder(tf.string, shape=(None,)) | ||
68 | + | ||
69 | + fn = lambda x: self.build_prediction_graph(x) | ||
70 | + video_id_output, top_indices_output, top_predictions_output = (tf.map_fn( | ||
71 | + fn, serialized_examples, dtype=(tf.string, tf.int32, tf.float32))) | ||
72 | + | ||
73 | + else: | ||
74 | + serialized_examples = tf.placeholder(tf.string, shape=(None,)) | ||
75 | + | ||
76 | + video_id_output, top_indices_output, top_predictions_output = ( | ||
77 | + self.build_prediction_graph(serialized_examples)) | ||
78 | + | ||
79 | + inputs = { | ||
80 | + "example_bytes": | ||
81 | + saved_model_utils.build_tensor_info(serialized_examples) | ||
82 | + } | ||
83 | + | ||
84 | + outputs = { | ||
85 | + "video_id": | ||
86 | + saved_model_utils.build_tensor_info(video_id_output), | ||
87 | + "class_indexes": | ||
88 | + saved_model_utils.build_tensor_info(top_indices_output), | ||
89 | + "predictions": | ||
90 | + saved_model_utils.build_tensor_info(top_predictions_output) | ||
91 | + } | ||
92 | + | ||
93 | + return inputs, outputs | ||
94 | + | ||
95 | + def build_prediction_graph(self, serialized_examples): | ||
96 | + input_data_dict = ( | ||
97 | + self.reader.prepare_serialized_examples(serialized_examples)) | ||
98 | + video_id = input_data_dict["video_ids"] | ||
99 | + model_input_raw = input_data_dict["video_matrix"] | ||
100 | + labels_batch = input_data_dict["labels"] | ||
101 | + num_frames = input_data_dict["num_frames"] | ||
102 | + | ||
103 | + feature_dim = len(model_input_raw.get_shape()) - 1 | ||
104 | + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) | ||
105 | + | ||
106 | + with tf.variable_scope("tower"): | ||
107 | + result = self.model.create_model(model_input, | ||
108 | + num_frames=num_frames, | ||
109 | + vocab_size=self.reader.num_classes, | ||
110 | + labels=labels_batch, | ||
111 | + is_training=False) | ||
112 | + | ||
113 | + for variable in slim.get_model_variables(): | ||
114 | + tf.summary.histogram(variable.op.name, variable) | ||
115 | + | ||
116 | + predictions = result["predictions"] | ||
117 | + | ||
118 | + top_predictions, top_indices = tf.nn.top_k(predictions, | ||
119 | + _TOP_PREDICTIONS_IN_OUTPUT) | ||
120 | + return video_id, top_indices, top_predictions |
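
The exported directory can be reloaded with the standard TF1 SavedModel loader; a minimal sketch in which the export path is an assumption:

```python
# Hedged sketch: reload the exported SavedModel (export path is an assumption).
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tag_constants.SERVING], "/tmp/exported_model")
    # The 'example_bytes' input and 'video_id'/'class_indexes'/'predictions'
    # outputs defined above are reachable via the default serving signature.
```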
web/backend/yt8m/export_model_mediapipe.py
0 → 100644
1 | +# Lint as: python3 | ||
2 | +import numpy as np | ||
3 | +import tensorflow as tf | ||
4 | +from tensorflow import app | ||
5 | +from tensorflow import flags | ||
6 | + | ||
7 | +FLAGS = flags.FLAGS | ||
8 | + | ||
9 | + | ||
10 | +def main(unused_argv): | ||
11 | + # Get the input tensor names to be replaced. | ||
12 | + tf.reset_default_graph() | ||
13 | + meta_graph_location = FLAGS.checkpoint_file + ".meta" | ||
14 | + tf.train.import_meta_graph(meta_graph_location, clear_devices=True) | ||
15 | + | ||
16 | + input_tensor_name = tf.get_collection("input_batch_raw")[0].name | ||
17 | + num_frames_tensor_name = tf.get_collection("num_frames")[0].name | ||
18 | + | ||
19 | + # Create output graph. | ||
20 | + saver = tf.train.Saver() | ||
21 | + tf.reset_default_graph() | ||
22 | + | ||
23 | + input_feature_placeholder = tf.placeholder( | ||
24 | + tf.float32, shape=(None, None, 1152)) | ||
25 | + num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1)) | ||
26 | + | ||
27 | + saver = tf.train.import_meta_graph( | ||
28 | + meta_graph_location, | ||
29 | + input_map={ | ||
30 | + input_tensor_name: input_feature_placeholder, | ||
31 | + num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1) | ||
32 | + }, | ||
33 | + clear_devices=True) | ||
34 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
35 | + | ||
36 | + with tf.Session() as sess: | ||
37 | + print("restoring variables from " + FLAGS.checkpoint_file) | ||
38 | + saver.restore(sess, FLAGS.checkpoint_file) | ||
39 | + tf.saved_model.simple_save( | ||
40 | + sess, | ||
41 | + FLAGS.output_dir, | ||
42 | + inputs={'rgb_and_audio': input_feature_placeholder, | ||
43 | + 'num_frames': num_frames_placeholder}, | ||
44 | + outputs={'predictions': predictions_tensor}) | ||
45 | + | ||
46 | + # Try running inference. | ||
47 | + predictions = sess.run( | ||
48 | + [predictions_tensor], | ||
49 | + feed_dict={ | ||
50 | + input_feature_placeholder: np.zeros((3, 7, 1152), dtype=np.float32), | ||
51 | + num_frames_placeholder: np.array([[7]], dtype=np.int32)}) | ||
52 | + print('Test inference:', predictions) | ||
53 | + | ||
54 | + print('Model saved to ', FLAGS.output_dir) | ||
55 | + | ||
56 | + | ||
57 | +if __name__ == '__main__': | ||
58 | + flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.') | ||
59 | + flags.DEFINE_string('output_dir', None, 'SavedModel output directory.') | ||
60 | + app.run(main) |
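
Presumably this script is invoked along the lines of `python export_model_mediapipe.py --checkpoint_file=/path/to/inference_model --output_dir=/path/to/saved_model` (paths are placeholders); the trailing test inference confirms that the remapped placeholders feed through to `predictions` before the SavedModel is written.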
web/backend/yt8m/frame_level_models.py
0 → 100644
web/backend/yt8m/inference.py
0 → 100644
web/backend/yt8m/inference_per_segment.py
0 → 100644
web/backend/yt8m/losses.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Provides definitions for non-regularized training or test losses.""" | ||
15 | + | ||
16 | +import tensorflow as tf | ||
17 | + | ||
18 | + | ||
19 | +class BaseLoss(object): | ||
20 | + """Inherit from this class when implementing new losses.""" | ||
21 | + | ||
22 | + def calculate_loss(self, unused_predictions, unused_labels, **unused_params): | ||
23 | + """Calculates the average loss of the examples in a mini-batch. | ||
24 | + | ||
25 | + Args: | ||
26 | + unused_predictions: a 2-d tensor storing the prediction scores, in which | ||
27 | + each row represents a sample in the mini-batch and each column | ||
28 | + represents a class. | ||
29 | + unused_labels: a 2-d tensor storing the labels, which has the same shape | ||
30 | + as the unused_predictions. The labels must be in the range of 0 and 1. | ||
31 | + unused_params: loss specific parameters. | ||
32 | + | ||
33 | + Returns: | ||
34 | + A scalar loss tensor. | ||
35 | + """ | ||
36 | + raise NotImplementedError() | ||
37 | + | ||
38 | + | ||
39 | +class CrossEntropyLoss(BaseLoss): | ||
40 | + """Calculate the cross entropy loss between the predictions and labels.""" | ||
41 | + | ||
42 | + def calculate_loss(self, | ||
43 | + predictions, | ||
44 | + labels, | ||
45 | + label_weights=None, | ||
46 | + **unused_params): | ||
47 | + with tf.name_scope("loss_xent"): | ||
48 | + epsilon = 1e-5 | ||
49 | + float_labels = tf.cast(labels, tf.float32) | ||
50 | + cross_entropy_loss = float_labels * tf.math.log(predictions + epsilon) + ( | ||
51 | + 1 - float_labels) * tf.math.log(1 - predictions + epsilon) | ||
52 | + cross_entropy_loss = tf.negative(cross_entropy_loss) | ||
53 | + if label_weights is not None: | ||
54 | + cross_entropy_loss *= label_weights | ||
55 | + return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1)) | ||
56 | + | ||
57 | + | ||
58 | +class HingeLoss(BaseLoss): | ||
59 | + """Calculate the hinge loss between the predictions and labels. | ||
60 | + | ||
61 | + Note the subgradient is used in the backpropagation, and thus the optimization | ||
62 | + may converge slower. The predictions trained by the hinge loss are between -1 | ||
63 | + and +1. | ||
64 | + """ | ||
65 | + | ||
66 | + def calculate_loss(self, predictions, labels, b=1.0, **unused_params): | ||
67 | + with tf.name_scope("loss_hinge"): | ||
68 | + float_labels = tf.cast(labels, tf.float32) | ||
69 | + all_zeros = tf.zeros(tf.shape(float_labels), dtype=tf.float32) | ||
70 | + all_ones = tf.ones(tf.shape(float_labels), dtype=tf.float32) | ||
71 | + sign_labels = tf.subtract(tf.scalar_mul(2, float_labels), all_ones) | ||
72 | + hinge_loss = tf.maximum( | ||
73 | + all_zeros, | ||
74 | + tf.scalar_mul(b, all_ones) - sign_labels * predictions) | ||
75 | + return tf.reduce_mean(tf.reduce_sum(hinge_loss, 1)) | ||
76 | + | ||
77 | + | ||
78 | +class SoftmaxLoss(BaseLoss): | ||
79 | + """Calculate the softmax loss between the predictions and labels. | ||
80 | + | ||
81 | + The function calculates the loss in the following way: first we feed the | ||
82 | + predictions to the softmax activation function and then we calculate | ||
83 | + the minus linear dot product between the logged softmax activations and the | ||
84 | + normalized ground truth label. | ||
85 | + | ||
86 | + It is an extension to the one-hot label. It allows for more than one positive | ||
87 | + label for each sample. | ||
88 | + """ | ||
89 | + | ||
90 | + def calculate_loss(self, predictions, labels, **unused_params): | ||
91 | + with tf.name_scope("loss_softmax"): | ||
92 | + epsilon = 10e-8 | ||
93 | + float_labels = tf.cast(labels, tf.float32) | ||
94 | + # l1 normalization (labels are no less than 0) | ||
95 | + label_rowsum = tf.maximum(tf.reduce_sum(float_labels, 1, keep_dims=True), | ||
96 | + epsilon) | ||
97 | + norm_float_labels = tf.div(float_labels, label_rowsum) | ||
98 | + softmax_outputs = tf.nn.softmax(predictions) | ||
99 | + softmax_loss = tf.negative( | ||
100 | + tf.reduce_sum(tf.multiply(norm_float_labels, tf.log(softmax_outputs)), | ||
101 | + 1)) | ||
102 | + return tf.reduce_mean(softmax_loss) |
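
A tiny TF1 session check of `CrossEntropyLoss`, with illustrative probabilities (predictions are assumed to already lie in (0, 1), as the epsilon-guarded logs require):

```python
# Illustrative TF1 check of CrossEntropyLoss; values are made up.
import tensorflow as tf

predictions = tf.constant([[0.9, 0.2], [0.1, 0.8]])
labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])
loss = CrossEntropyLoss().calculate_loss(predictions, labels)
with tf.Session() as sess:
    print(sess.run(loss))  # ~0.33: small, since predictions match labels
```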
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Calculate the mean average precision. | ||
15 | + | ||
16 | +It provides an interface for calculating mean average precision | ||
17 | +for an entire list or the top-n ranked items. | ||
18 | + | ||
19 | +Example usages: | ||
20 | +We first call the function accumulate many times to process parts of the ranked | ||
21 | +list. After processing all the parts, we call peek_map_at_n | ||
22 | +to calculate the mean average precision. | ||
23 | + | ||
24 | +``` | ||
25 | +import random | ||
26 | + | ||
27 | +p = np.array([[random.random() for _ in range(50)] for _ in range(1000)]) | ||
28 | +a = np.array([[random.choice([0, 1]) for _ in range(50)] | ||
29 | + for _ in range(1000)]) | ||
30 | + | ||
31 | +# mean average precision for 50 classes. | ||
32 | +calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator( | ||
33 | + num_class=50) | ||
34 | +calculator.accumulate(p, a) | ||
35 | +aps = calculator.peek_map_at_n() | ||
36 | +``` | ||
37 | +""" | ||
38 | + | ||
39 | +import average_precision_calculator | ||
40 | + | ||
41 | + | ||
42 | +class MeanAveragePrecisionCalculator(object): | ||
43 | + """This class is to calculate mean average precision.""" | ||
44 | + | ||
45 | + def __init__(self, num_class, filter_empty_classes=True, top_n=None): | ||
46 | + """Construct a calculator to calculate the (macro) average precision. | ||
47 | + | ||
48 | + Args: | ||
49 | + num_class: A positive Integer specifying the number of classes. | ||
50 | + filter_empty_classes: whether to filter classes without any positives. | ||
51 | + top_n: A positive Integer specifying the average precision at n, or None | ||
52 | + to use all provided data points. | ||
53 | + | ||
54 | + Raises: | ||
55 | + ValueError: An error occurred when num_class is not an integer greater | ||
56 | + than 1, or when top_n is not a positive integer. | ||
57 | + """ | ||
58 | + if not isinstance(num_class, int) or num_class <= 1: | ||
59 | + raise ValueError("num_class must be a positive integer.") | ||
60 | + | ||
61 | + self._ap_calculators = [] # member of AveragePrecisionCalculator | ||
62 | + self._num_class = num_class # total number of classes | ||
63 | + self._filter_empty_classes = filter_empty_classes | ||
64 | + for _ in range(num_class): | ||
65 | + self._ap_calculators.append( | ||
66 | + average_precision_calculator.AveragePrecisionCalculator(top_n=top_n)) | ||
67 | + | ||
68 | + def accumulate(self, predictions, actuals, num_positives=None): | ||
69 | + """Accumulate the predictions and their ground truth labels. | ||
70 | + | ||
71 | + Args: | ||
72 | + predictions: A list of lists storing the prediction scores. The outer | ||
73 | + dimension corresponds to classes. | ||
74 | + actuals: A list of lists storing the ground truth labels. The dimensions | ||
75 | + should correspond to the predictions input. Any value larger than 0 will | ||
76 | + be treated as positives, otherwise as negatives. | ||
77 | + num_positives: If provided, it is a list of numbers representing the | ||
78 | + number of true positives for each class. If not provided, the number of | ||
79 | + true positives will be inferred from the 'actuals' array. | ||
80 | + | ||
81 | + Raises: | ||
82 | + ValueError: An error occurred when the shape of predictions and actuals | ||
83 | + does not match. | ||
84 | + """ | ||
85 | + if not num_positives: | ||
86 | + num_positives = [None for i in range(self._num_class)] | ||
87 | + | ||
88 | + calculators = self._ap_calculators | ||
89 | + for i in range(self._num_class): | ||
90 | + calculators[i].accumulate(predictions[i], actuals[i], num_positives[i]) | ||
91 | + | ||
92 | + def clear(self): | ||
93 | + for calculator in self._ap_calculators: | ||
94 | + calculator.clear() | ||
95 | + | ||
96 | + def is_empty(self): | ||
97 | + return ([calculator.heap_size for calculator in self._ap_calculators | ||
98 | + ] == [0 for _ in range(self._num_class)]) | ||
99 | + | ||
100 | + def peek_map_at_n(self): | ||
101 | + """Peek the non-interpolated mean average precision at n. | ||
102 | + | ||
103 | + Returns: | ||
104 | + An array of non-interpolated average precision at n (default 0) for each | ||
105 | + class. | ||
106 | + """ | ||
107 | + aps = [] | ||
108 | + for i in range(self._num_class): | ||
109 | + if (not self._filter_empty_classes or | ||
110 | + self._ap_calculators[i].num_accumulated_positives > 0): | ||
111 | + ap = self._ap_calculators[i].peek_ap_at_n() | ||
112 | + aps.append(ap) | ||
113 | + return aps |
web/backend/yt8m/model_utils.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains a collection of util functions for model construction.""" | ||
15 | +import numpy | ||
16 | +import tensorflow as tf | ||
17 | +from tensorflow import logging | ||
18 | +from tensorflow import flags | ||
19 | +import tensorflow.contrib.slim as slim | ||
20 | + | ||
21 | + | ||
22 | +def SampleRandomSequence(model_input, num_frames, num_samples): | ||
23 | + """Samples a random sequence of frames of size num_samples. | ||
24 | + | ||
25 | + Args: | ||
26 | + model_input: A tensor of size batch_size x max_frames x feature_size | ||
27 | + num_frames: A tensor of size batch_size x 1 | ||
28 | + num_samples: A scalar | ||
29 | + | ||
30 | + Returns: | ||
31 | + `model_input`: A tensor of size batch_size x num_samples x feature_size | ||
32 | + """ | ||
33 | + | ||
34 | + batch_size = tf.shape(model_input)[0] | ||
35 | + frame_index_offset = tf.tile(tf.expand_dims(tf.range(num_samples), 0), | ||
36 | + [batch_size, 1]) | ||
37 | + max_start_frame_index = tf.maximum(num_frames - num_samples, 0) | ||
38 | + start_frame_index = tf.cast( | ||
39 | + tf.multiply(tf.random_uniform([batch_size, 1]), | ||
40 | + tf.cast(max_start_frame_index + 1, tf.float32)), tf.int32) | ||
41 | + frame_index = tf.minimum(start_frame_index + frame_index_offset, | ||
42 | + tf.cast(num_frames - 1, tf.int32)) | ||
43 | + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1), | ||
44 | + [1, num_samples]) | ||
45 | + index = tf.stack([batch_index, frame_index], 2) | ||
46 | + return tf.gather_nd(model_input, index) | ||
47 | + | ||
48 | + | ||
49 | +def SampleRandomFrames(model_input, num_frames, num_samples): | ||
50 | + """Samples a random set of frames of size num_samples. | ||
51 | + | ||
52 | + Args: | ||
53 | + model_input: A tensor of size batch_size x max_frames x feature_size | ||
54 | + num_frames: A tensor of size batch_size x 1 | ||
55 | + num_samples: A scalar | ||
56 | + | ||
57 | + Returns: | ||
58 | + `model_input`: A tensor of size batch_size x num_samples x feature_size | ||
59 | + """ | ||
60 | + batch_size = tf.shape(model_input)[0] | ||
61 | + frame_index = tf.cast( | ||
62 | + tf.multiply(tf.random_uniform([batch_size, num_samples]), | ||
63 | + tf.tile(tf.cast(num_frames, tf.float32), [1, num_samples])), | ||
64 | + tf.int32) | ||
65 | + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1), | ||
66 | + [1, num_samples]) | ||
67 | + index = tf.stack([batch_index, frame_index], 2) | ||
68 | + return tf.gather_nd(model_input, index) | ||
69 | + | ||
70 | + | ||
71 | +def FramePooling(frames, method, **unused_params): | ||
72 | + """Pools over the frames of a video. | ||
73 | + | ||
74 | + Args: | ||
75 | + frames: A tensor with shape [batch_size, num_frames, feature_size]. | ||
76 | + method: "average", "max", "attention", or "none". | ||
77 | + | ||
78 | + Returns: | ||
79 | + A tensor with shape [batch_size, feature_size] for average, max, or | ||
80 | + attention pooling. A tensor with shape [batch_size*num_frames, feature_size] | ||
81 | + for none pooling. | ||
82 | + | ||
83 | + Raises: | ||
84 | + ValueError: if method is other than "average", "max", "attention", or | ||
85 | + "none". | ||
86 | + """ | ||
87 | + if method == "average": | ||
88 | + return tf.reduce_mean(frames, 1) | ||
89 | + elif method == "max": | ||
90 | + return tf.reduce_max(frames, 1) | ||
91 | + elif method == "none": | ||
92 | + feature_size = frames.shape_as_list()[2] | ||
93 | + return tf.reshape(frames, [-1, feature_size]) | ||
94 | + else: | ||
95 | + raise ValueError("Unrecognized pooling method: %s" % method) |
web/backend/yt8m/models.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains the base class for models.""" | ||
15 | + | ||
16 | + | ||
17 | +class BaseModel(object): | ||
18 | + """Inherit from this class when implementing new models.""" | ||
19 | + | ||
20 | + def create_model(self, unused_model_input, **unused_params): | ||
21 | + raise NotImplementedError() |
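
Concrete models elsewhere in this package subclass `BaseModel` and return a dict containing a `predictions` tensor (as `export_model.py` above expects). An illustrative stub, not part of the repo:

```python
# Illustrative subclass (not from the repo): a one-layer logistic stub.
import tensorflow as tf

class ToyLogisticModel(BaseModel):

  def create_model(self, model_input, vocab_size, **unused_params):
    output = tf.layers.dense(model_input, vocab_size, activation=tf.nn.sigmoid)
    return {"predictions": output}
```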
web/backend/yt8m/readers.py
0 → 100644
web/backend/yt8m/segment_eval_inference.py
0 → 100644
1 | +"""Eval mAP@N metric from inference file.""" | ||
2 | + | ||
3 | +from __future__ import absolute_import | ||
4 | +from __future__ import division | ||
5 | +from __future__ import print_function | ||
6 | + | ||
7 | +from absl import app | ||
8 | +from absl import flags | ||
9 | + | ||
10 | +import mean_average_precision_calculator as map_calculator | ||
11 | +import numpy as np | ||
12 | +import tensorflow as tf | ||
13 | + | ||
14 | +flags.DEFINE_string( | ||
15 | + "eval_data_pattern", "", | ||
16 | + "File glob defining the evaluation dataset in tensorflow.SequenceExample " | ||
17 | + "format. The SequenceExamples are expected to have an 'rgb' byte array " | ||
18 | + "sequence feature as well as a 'labels' int64 context feature.") | ||
19 | +flags.DEFINE_string( | ||
20 | + "label_cache", "", | ||
21 | + "The path for the label cache file. Leave blank for not to cache.") | ||
22 | +flags.DEFINE_string("submission_file", "", | ||
23 | + "The segment submission file generated by inference.py.") | ||
24 | +flags.DEFINE_integer( | ||
25 | + "top_n", 0, | ||
26 | + "The cap per-class predictions by a maximum of N. Use 0 for not capping.") | ||
27 | + | ||
28 | +FLAGS = flags.FLAGS | ||
29 | + | ||
30 | + | ||
31 | +class Labels(object): | ||
32 | + """Contains the class to hold label objects. | ||
33 | + | ||
34 | + This class can serialize and de-serialize the groundtruths. | ||
35 | + The ground truth is in a mapping from (segment_id, class_id) -> label_score. | ||
36 | + """ | ||
37 | + | ||
38 | + def __init__(self, labels): | ||
39 | + """__init__ method.""" | ||
40 | + self._labels = labels | ||
41 | + | ||
42 | + @property | ||
43 | + def labels(self): | ||
44 | + """Return the ground truth mapping. See class docstring for details.""" | ||
45 | + return self._labels | ||
46 | + | ||
47 | + def to_file(self, file_name): | ||
48 | + """Materialize the GT mapping to file.""" | ||
49 | + with tf.gfile.Open(file_name, "w") as fobj: | ||
50 | + for k, v in self._labels.items(): | ||
51 | + seg_id, label = k | ||
52 | + line = "%s,%s,%s\n" % (seg_id, label, v) | ||
53 | + fobj.write(line) | ||
54 | + | ||
55 | + @classmethod | ||
56 | + def from_file(cls, file_name): | ||
57 | + """Read the GT mapping from cached file.""" | ||
58 | + labels = {} | ||
59 | + with tf.gfile.Open(file_name) as fobj: | ||
60 | + for line in fobj: | ||
61 | + line = line.strip().strip("\n") | ||
62 | + seg_id, label, score = line.split(",") | ||
63 | + labels[(seg_id, int(label))] = float(score) | ||
64 | + return cls(labels) | ||
65 | + | ||
66 | + | ||
67 | +def read_labels(data_pattern, cache_path=""): | ||
68 | + """Read labels from TFRecords. | ||
69 | + | ||
70 | + Args: | ||
71 | + data_pattern: the data pattern to the TFRecords. | ||
72 | + cache_path: the cache path for the label file. | ||
73 | + | ||
74 | + Returns: | ||
75 | + a Labels object. | ||
76 | + """ | ||
77 | + if cache_path: | ||
78 | + if tf.gfile.Exists(cache_path): | ||
79 | + tf.logging.info("Reading cached labels from %s..." % cache_path) | ||
80 | + return Labels.from_file(cache_path) | ||
81 | + tf.enable_eager_execution() | ||
82 | + data_paths = tf.gfile.Glob(data_pattern) | ||
83 | + ds = tf.data.TFRecordDataset(data_paths, num_parallel_reads=50) | ||
84 | + context_features = { | ||
85 | + "id": tf.FixedLenFeature([], tf.string), | ||
86 | + "segment_labels": tf.VarLenFeature(tf.int64), | ||
87 | + "segment_start_times": tf.VarLenFeature(tf.int64), | ||
88 | + "segment_scores": tf.VarLenFeature(tf.float32) | ||
89 | + } | ||
90 | + | ||
91 | + def _parse_se_func(sequence_example): | ||
92 | + return tf.parse_single_sequence_example(sequence_example, | ||
93 | + context_features=context_features) | ||
94 | + | ||
95 | + ds = ds.map(_parse_se_func) | ||
96 | + rated_labels = {} | ||
97 | + tf.logging.info("Reading labels from TFRecords...") | ||
98 | + last_batch = 0 | ||
99 | + batch_size = 5000 | ||
100 | + for cxt_feature_val, _ in ds: | ||
101 | + video_id = cxt_feature_val["id"].numpy() | ||
102 | + segment_labels = cxt_feature_val["segment_labels"].values.numpy() | ||
103 | + segment_start_times = cxt_feature_val["segment_start_times"].values.numpy() | ||
104 | + segment_scores = cxt_feature_val["segment_scores"].values.numpy() | ||
105 | + for label, start_time, score in zip(segment_labels, segment_start_times, | ||
106 | + segment_scores): | ||
107 | + rated_labels[("%s:%d" % (video_id, start_time), label)] = score | ||
108 | + batch_id = len(rated_labels) // batch_size | ||
109 | + if batch_id != last_batch: | ||
110 | + tf.logging.info("%d examples processed.", len(rated_labels)) | ||
111 | + last_batch = batch_id | ||
112 | + tf.logging.info("Finish reading labels from TFRecords...") | ||
113 | + labels_obj = Labels(rated_labels) | ||
114 | + if cache_path: | ||
115 | + tf.logging.info("Caching labels to %s..." % cache_path) | ||
116 | + labels_obj.to_file(cache_path) | ||
117 | + return labels_obj | ||
118 | + | ||
119 | + | ||
120 | +def read_segment_predictions(file_path, labels, top_n=None): | ||
121 | + """Read segement predictions. | ||
122 | + | ||
123 | + Args: | ||
124 | + file_path: the submission file path. | ||
125 | + labels: a Labels object containing the eval labels. | ||
126 | + top_n: the per-class prediction cap. | ||
127 | + | ||
128 | + Returns: | ||
129 | + a list of segment predictions for each class. | ||
130 | + """ | ||
131 | + cls_preds = {} # A label_id to pred list mapping. | ||
132 | + with tf.gfile.Open(file_path) as fobj: | ||
133 | + tf.logging.info("Reading predictions from %s..." % file_path) | ||
134 | + for line in fobj: | ||
135 | + label_id, pred_ids_val = line.split(",") | ||
136 | + pred_ids = pred_ids_val.split(" ") | ||
137 | + if top_n: | ||
138 | + pred_ids = pred_ids[:top_n] | ||
139 | + pred_ids = [ | ||
140 | + pred_id for pred_id in pred_ids | ||
141 | + if (pred_id, int(label_id)) in labels.labels | ||
142 | + ] | ||
143 | + cls_preds[int(label_id)] = pred_ids | ||
144 | + if len(cls_preds) % 50 == 0: | ||
145 | + tf.logging.info("Processed %d classes..." % len(cls_preds)) | ||
146 | + tf.logging.info("Finished reading predictions.") | ||
147 | + return cls_preds | ||
148 | + | ||
149 | + | ||
150 | +def main(unused_argv): | ||
151 | + """Entry function of the script.""" | ||
152 | + if not FLAGS.submission_file: | ||
153 | + raise ValueError("You must provide a submission file.") | ||
154 | + eval_labels = read_labels(FLAGS.eval_data_pattern, | ||
155 | + cache_path=FLAGS.label_cache) | ||
156 | + tf.logging.info("Total rated segments: %d." % len(eval_labels.labels)) | ||
157 | + positive_counter = {} | ||
158 | + for k, v in eval_labels.labels.items(): | ||
159 | + _, label_id = k | ||
160 | + if v > 0: | ||
161 | + positive_counter[label_id] = positive_counter.get(label_id, 0) + 1 | ||
162 | + | ||
163 | + seg_preds = read_segment_predictions(FLAGS.submission_file, | ||
164 | + eval_labels, | ||
165 | + top_n=FLAGS.top_n) | ||
166 | + map_cal = map_calculator.MeanAveragePrecisionCalculator(len(seg_preds)) | ||
167 | + seg_labels = [] | ||
168 | + seg_scored_preds = [] | ||
169 | + num_positives = [] | ||
170 | + for label_id in sorted(seg_preds): | ||
171 | + class_preds = seg_preds[label_id] | ||
172 | + seg_label = [eval_labels.labels[(pred, label_id)] for pred in class_preds] | ||
173 | + seg_labels.append(seg_label) | ||
174 | + seg_scored_pred = [] | ||
175 | + if class_preds: | ||
176 | + seg_scored_pred = [ | ||
177 | + float(x) / len(class_preds) for x in range(len(class_preds), 0, -1) | ||
178 | + ] | ||
179 | + seg_scored_preds.append(seg_scored_pred) | ||
180 | + num_positives.append(positive_counter[label_id]) | ||
181 | + map_cal.accumulate(seg_scored_preds, seg_labels, num_positives) | ||
182 | + map_at_n = np.mean(map_cal.peek_map_at_n()) | ||
183 | + tf.logging.info("Num classes: %d | mAP@%d: %.6f" % | ||
184 | + (len(seg_preds), FLAGS.top_n, map_at_n)) | ||
185 | + | ||
186 | + | ||
187 | +if __name__ == "__main__": | ||
188 | + app.run(main) |
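For orientation, a minimal sketch of how this evaluator is driven. `read_segment_predictions` expects one line per class in the submission file: a label id, a comma, then space-separated segment ids of the form `video_id:start_time` (the same key format `read_labels` builds). The flag names below come from `main()`; the script name and paths are placeholders, not part of the repo.

```
# Hypothetical invocation (flag names from main() above; the script name and
# paths are placeholders):
#   python segment_eval.py \
#     --eval_data_pattern="validate*.tfrecord" \
#     --label_cache="/tmp/labels.cache" \
#     --submission_file="submission.csv" \
#     --top_n=100000

# Building one submission line in the format read_segment_predictions parses:
label_id = 3
pred_ids = ["Ab01:60", "Ab01:115"]  # "video_id:start_time" segment ids
line = "%d,%s" % (label_id, " ".join(pred_ids))
assert line == "3,Ab01:60 Ab01:115"
```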
web/backend/yt8m/segment_label_ids.csv
0 → 100644
1 | +Index | ||
2 | +3 | ||
3 | +7 | ||
4 | +8 | ||
5 | +11 | ||
6 | +12 | ||
7 | +17 | ||
8 | +18 | ||
9 | +19 | ||
10 | +21 | ||
11 | +22 | ||
12 | +23 | ||
13 | +28 | ||
14 | +31 | ||
15 | +30 | ||
16 | +32 | ||
17 | +33 | ||
18 | +34 | ||
19 | +41 | ||
20 | +43 | ||
21 | +45 | ||
22 | +46 | ||
23 | +48 | ||
24 | +53 | ||
25 | +54 | ||
26 | +52 | ||
27 | +55 | ||
28 | +58 | ||
29 | +59 | ||
30 | +60 | ||
31 | +61 | ||
32 | +65 | ||
33 | +68 | ||
34 | +73 | ||
35 | +71 | ||
36 | +74 | ||
37 | +75 | ||
38 | +76 | ||
39 | +77 | ||
40 | +80 | ||
41 | +83 | ||
42 | +90 | ||
43 | +88 | ||
44 | +89 | ||
45 | +92 | ||
46 | +95 | ||
47 | +100 | ||
48 | +101 | ||
49 | +99 | ||
50 | +104 | ||
51 | +105 | ||
52 | +109 | ||
53 | +113 | ||
54 | +112 | ||
55 | +115 | ||
56 | +116 | ||
57 | +118 | ||
58 | +120 | ||
59 | +121 | ||
60 | +123 | ||
61 | +125 | ||
62 | +127 | ||
63 | +131 | ||
64 | +128 | ||
65 | +129 | ||
66 | +130 | ||
67 | +137 | ||
68 | +141 | ||
69 | +143 | ||
70 | +145 | ||
71 | +148 | ||
72 | +152 | ||
73 | +151 | ||
74 | +156 | ||
75 | +155 | ||
76 | +158 | ||
77 | +160 | ||
78 | +164 | ||
79 | +163 | ||
80 | +169 | ||
81 | +170 | ||
82 | +172 | ||
83 | +171 | ||
84 | +173 | ||
85 | +174 | ||
86 | +175 | ||
87 | +176 | ||
88 | +178 | ||
89 | +182 | ||
90 | +184 | ||
91 | +186 | ||
92 | +188 | ||
93 | +187 | ||
94 | +192 | ||
95 | +191 | ||
96 | +190 | ||
97 | +194 | ||
98 | +197 | ||
99 | +196 | ||
100 | +198 | ||
101 | +201 | ||
102 | +202 | ||
103 | +200 | ||
104 | +199 | ||
105 | +205 | ||
106 | +204 | ||
107 | +209 | ||
108 | +207 | ||
109 | +206 | ||
110 | +210 | ||
111 | +213 | ||
112 | +214 | ||
113 | +220 | ||
114 | +218 | ||
115 | +217 | ||
116 | +226 | ||
117 | +227 | ||
118 | +231 | ||
119 | +232 | ||
120 | +229 | ||
121 | +233 | ||
122 | +235 | ||
123 | +237 | ||
124 | +244 | ||
125 | +240 | ||
126 | +249 | ||
127 | +246 | ||
128 | +248 | ||
129 | +239 | ||
130 | +250 | ||
131 | +245 | ||
132 | +255 | ||
133 | +253 | ||
134 | +256 | ||
135 | +261 | ||
136 | +259 | ||
137 | +263 | ||
138 | +262 | ||
139 | +266 | ||
140 | +267 | ||
141 | +268 | ||
142 | +269 | ||
143 | +271 | ||
144 | +276 | ||
145 | +273 | ||
146 | +277 | ||
147 | +274 | ||
148 | +278 | ||
149 | +279 | ||
150 | +280 | ||
151 | +288 | ||
152 | +291 | ||
153 | +295 | ||
154 | +294 | ||
155 | +293 | ||
156 | +297 | ||
157 | +296 | ||
158 | +300 | ||
159 | +299 | ||
160 | +303 | ||
161 | +302 | ||
162 | +304 | ||
163 | +305 | ||
164 | +313 | ||
165 | +307 | ||
166 | +311 | ||
167 | +310 | ||
168 | +312 | ||
169 | +316 | ||
170 | +318 | ||
171 | +321 | ||
172 | +322 | ||
173 | +331 | ||
174 | +333 | ||
175 | +329 | ||
176 | +330 | ||
177 | +334 | ||
178 | +343 | ||
179 | +349 | ||
180 | +340 | ||
181 | +344 | ||
182 | +348 | ||
183 | +358 | ||
184 | +347 | ||
185 | +359 | ||
186 | +355 | ||
187 | +361 | ||
188 | +360 | ||
189 | +364 | ||
190 | +365 | ||
191 | +368 | ||
192 | +369 | ||
193 | +366 | ||
194 | +370 | ||
195 | +374 | ||
196 | +380 | ||
197 | +373 | ||
198 | +385 | ||
199 | +384 | ||
200 | +388 | ||
201 | +389 | ||
202 | +382 | ||
203 | +393 | ||
204 | +381 | ||
205 | +390 | ||
206 | +394 | ||
207 | +399 | ||
208 | +397 | ||
209 | +396 | ||
210 | +402 | ||
211 | +400 | ||
212 | +398 | ||
213 | +401 | ||
214 | +405 | ||
215 | +406 | ||
216 | +410 | ||
217 | +408 | ||
218 | +416 | ||
219 | +415 | ||
220 | +419 | ||
221 | +422 | ||
222 | +414 | ||
223 | +421 | ||
224 | +424 | ||
225 | +429 | ||
226 | +418 | ||
227 | +427 | ||
228 | +434 | ||
229 | +428 | ||
230 | +435 | ||
231 | +430 | ||
232 | +441 | ||
233 | +439 | ||
234 | +437 | ||
235 | +443 | ||
236 | +440 | ||
237 | +442 | ||
238 | +445 | ||
239 | +446 | ||
240 | +448 | ||
241 | +454 | ||
242 | +444 | ||
243 | +453 | ||
244 | +455 | ||
245 | +451 | ||
246 | +452 | ||
247 | +458 | ||
248 | +460 | ||
249 | +465 | ||
250 | +457 | ||
251 | +463 | ||
252 | +462 | ||
253 | +461 | ||
254 | +464 | ||
255 | +469 | ||
256 | +468 | ||
257 | +472 | ||
258 | +473 | ||
259 | +471 | ||
260 | +475 | ||
261 | +474 | ||
262 | +477 | ||
263 | +485 | ||
264 | +491 | ||
265 | +488 | ||
266 | +482 | ||
267 | +490 | ||
268 | +496 | ||
269 | +494 | ||
270 | +483 | ||
271 | +495 | ||
272 | +493 | ||
273 | +507 | ||
274 | +501 | ||
275 | +499 | ||
276 | +503 | ||
277 | +498 | ||
278 | +514 | ||
279 | +504 | ||
280 | +502 | ||
281 | +506 | ||
282 | +508 | ||
283 | +511 | ||
284 | +527 | ||
285 | +526 | ||
286 | +532 | ||
287 | +513 | ||
288 | +519 | ||
289 | +525 | ||
290 | +518 | ||
291 | +528 | ||
292 | +522 | ||
293 | +523 | ||
294 | +535 | ||
295 | +539 | ||
296 | +540 | ||
297 | +533 | ||
298 | +521 | ||
299 | +541 | ||
300 | +547 | ||
301 | +550 | ||
302 | +544 | ||
303 | +549 | ||
304 | +551 | ||
305 | +554 | ||
306 | +543 | ||
307 | +548 | ||
308 | +557 | ||
309 | +560 | ||
310 | +552 | ||
311 | +559 | ||
312 | +563 | ||
313 | +565 | ||
314 | +567 | ||
315 | +555 | ||
316 | +576 | ||
317 | +568 | ||
318 | +564 | ||
319 | +573 | ||
320 | +581 | ||
321 | +580 | ||
322 | +572 | ||
323 | +571 | ||
324 | +584 | ||
325 | +590 | ||
326 | +585 | ||
327 | +587 | ||
328 | +588 | ||
329 | +592 | ||
330 | +598 | ||
331 | +597 | ||
332 | +599 | ||
333 | +603 | ||
334 | +600 | ||
335 | +604 | ||
336 | +605 | ||
337 | +614 | ||
338 | +602 | ||
339 | +610 | ||
340 | +608 | ||
341 | +611 | ||
342 | +612 | ||
343 | +613 | ||
344 | +617 | ||
345 | +620 | ||
346 | +607 | ||
347 | +624 | ||
348 | +627 | ||
349 | +625 | ||
350 | +631 | ||
351 | +629 | ||
352 | +638 | ||
353 | +632 | ||
354 | +634 | ||
355 | +644 | ||
356 | +641 | ||
357 | +642 | ||
358 | +646 | ||
359 | +652 | ||
360 | +647 | ||
361 | +637 | ||
362 | +661 | ||
363 | +635 | ||
364 | +658 | ||
365 | +648 | ||
366 | +663 | ||
367 | +668 | ||
368 | +664 | ||
369 | +656 | ||
370 | +666 | ||
371 | +671 | ||
372 | +683 | ||
373 | +675 | ||
374 | +669 | ||
375 | +676 | ||
376 | +667 | ||
377 | +691 | ||
378 | +685 | ||
379 | +673 | ||
380 | +688 | ||
381 | +702 | ||
382 | +684 | ||
383 | +679 | ||
384 | +694 | ||
385 | +686 | ||
386 | +689 | ||
387 | +680 | ||
388 | +693 | ||
389 | +703 | ||
390 | +697 | ||
391 | +698 | ||
392 | +692 | ||
393 | +705 | ||
394 | +706 | ||
395 | +712 | ||
396 | +711 | ||
397 | +709 | ||
398 | +710 | ||
399 | +726 | ||
400 | +713 | ||
401 | +721 | ||
402 | +720 | ||
403 | +715 | ||
404 | +717 | ||
405 | +730 | ||
406 | +728 | ||
407 | +723 | ||
408 | +716 | ||
409 | +722 | ||
410 | +718 | ||
411 | +732 | ||
412 | +724 | ||
413 | +736 | ||
414 | +725 | ||
415 | +742 | ||
416 | +727 | ||
417 | +735 | ||
418 | +740 | ||
419 | +748 | ||
420 | +738 | ||
421 | +746 | ||
422 | +751 | ||
423 | +749 | ||
424 | +752 | ||
425 | +754 | ||
426 | +760 | ||
427 | +763 | ||
428 | +756 | ||
429 | +758 | ||
430 | +766 | ||
431 | +764 | ||
432 | +757 | ||
433 | +780 | ||
434 | +767 | ||
435 | +769 | ||
436 | +771 | ||
437 | +786 | ||
438 | +785 | ||
439 | +781 | ||
440 | +787 | ||
441 | +778 | ||
442 | +783 | ||
443 | +792 | ||
444 | +791 | ||
445 | +795 | ||
446 | +788 | ||
447 | +805 | ||
448 | +802 | ||
449 | +801 | ||
450 | +793 | ||
451 | +796 | ||
452 | +804 | ||
453 | +803 | ||
454 | +797 | ||
455 | +814 | ||
456 | +813 | ||
457 | +789 | ||
458 | +808 | ||
459 | +818 | ||
460 | +816 | ||
461 | +817 | ||
462 | +811 | ||
463 | +820 | ||
464 | +826 | ||
465 | +829 | ||
466 | +824 | ||
467 | +821 | ||
468 | +825 | ||
469 | +822 | ||
470 | +835 | ||
471 | +833 | ||
472 | +843 | ||
473 | +823 | ||
474 | +827 | ||
475 | +830 | ||
476 | +832 | ||
477 | +837 | ||
478 | +852 | ||
479 | +844 | ||
480 | +841 | ||
481 | +812 | ||
482 | +847 | ||
483 | +862 | ||
484 | +869 | ||
485 | +860 | ||
486 | +838 | ||
487 | +870 | ||
488 | +846 | ||
489 | +858 | ||
490 | +854 | ||
491 | +880 | ||
492 | +876 | ||
493 | +857 | ||
494 | +859 | ||
495 | +877 | ||
496 | +871 | ||
497 | +855 | ||
498 | +875 | ||
499 | +861 | ||
500 | +867 | ||
501 | +892 | ||
502 | +898 | ||
503 | +888 | ||
504 | +884 | ||
505 | +887 | ||
506 | +891 | ||
507 | +906 | ||
508 | +900 | ||
509 | +878 | ||
510 | +885 | ||
511 | +883 | ||
512 | +901 | ||
513 | +903 | ||
514 | +907 | ||
515 | +930 | ||
516 | +897 | ||
517 | +914 | ||
518 | +917 | ||
519 | +910 | ||
520 | +905 | ||
521 | +909 | ||
522 | +933 | ||
523 | +932 | ||
524 | +922 | ||
525 | +913 | ||
526 | +923 | ||
527 | +931 | ||
528 | +911 | ||
529 | +937 | ||
530 | +918 | ||
531 | +955 | ||
532 | +915 | ||
533 | +944 | ||
534 | +952 | ||
535 | +945 | ||
536 | +948 | ||
537 | +946 | ||
538 | +970 | ||
539 | +974 | ||
540 | +958 | ||
541 | +925 | ||
542 | +979 | ||
543 | +942 | ||
544 | +965 | ||
545 | +975 | ||
546 | +950 | ||
547 | +982 | ||
548 | +940 | ||
549 | +973 | ||
550 | +962 | ||
551 | +972 | ||
552 | +957 | ||
553 | +984 | ||
554 | +983 | ||
555 | +964 | ||
556 | +1007 | ||
557 | +971 | ||
558 | +981 | ||
559 | +954 | ||
560 | +993 | ||
561 | +991 | ||
562 | +996 | ||
563 | +1005 | ||
564 | +1015 | ||
565 | +1009 | ||
566 | +995 | ||
567 | +986 | ||
568 | +1000 | ||
569 | +985 | ||
570 | +980 | ||
571 | +1016 | ||
572 | +1011 | ||
573 | +999 | ||
574 | +1002 | ||
575 | +994 | ||
576 | +1013 | ||
577 | +1010 | ||
578 | +992 | ||
579 | +1008 | ||
580 | +1036 | ||
581 | +1025 | ||
582 | +1012 | ||
583 | +990 | ||
584 | +1037 | ||
585 | +1040 | ||
586 | +1031 | ||
587 | +1019 | ||
588 | +1052 | ||
589 | +1001 | ||
590 | +1055 | ||
591 | +1032 | ||
592 | +1069 | ||
593 | +1058 | ||
594 | +1014 | ||
595 | +1023 | ||
596 | +1030 | ||
597 | +1061 | ||
598 | +1035 | ||
599 | +1034 | ||
600 | +1053 | ||
601 | +1045 | ||
602 | +1046 | ||
603 | +1067 | ||
604 | +1060 | ||
605 | +1049 | ||
606 | +1056 | ||
607 | +1074 | ||
608 | +1066 | ||
609 | +1044 | ||
610 | +1038 | ||
611 | +1073 | ||
612 | +1077 | ||
613 | +1068 | ||
614 | +1057 | ||
615 | +1072 | ||
616 | +1104 | ||
617 | +1083 | ||
618 | +1089 | ||
619 | +1087 | ||
620 | +1099 | ||
621 | +1076 | ||
622 | +1086 | ||
623 | +1098 | ||
624 | +1094 | ||
625 | +1095 | ||
626 | +1096 | ||
627 | +1101 | ||
628 | +1107 | ||
629 | +1105 | ||
630 | +1117 | ||
631 | +1093 | ||
632 | +1106 | ||
633 | +1122 | ||
634 | +1119 | ||
635 | +1103 | ||
636 | +1128 | ||
637 | +1120 | ||
638 | +1126 | ||
639 | +1102 | ||
640 | +1115 | ||
641 | +1124 | ||
642 | +1123 | ||
643 | +1131 | ||
644 | +1136 | ||
645 | +1144 | ||
646 | +1121 | ||
647 | +1137 | ||
648 | +1132 | ||
649 | +1133 | ||
650 | +1157 | ||
651 | +1134 | ||
652 | +1143 | ||
653 | +1159 | ||
654 | +1164 | ||
655 | +1155 | ||
656 | +1142 | ||
657 | +1150 | ||
658 | +1148 | ||
659 | +1161 | ||
660 | +1165 | ||
661 | +1147 | ||
662 | +1162 | ||
663 | +1152 | ||
664 | +1174 | ||
665 | +1160 | ||
666 | +1166 | ||
667 | +1190 | ||
668 | +1175 | ||
669 | +1167 | ||
670 | +1156 | ||
671 | +1180 | ||
672 | +1171 | ||
673 | +1179 | ||
674 | +1172 | ||
675 | +1186 | ||
676 | +1188 | ||
677 | +1201 | ||
678 | +1177 | ||
679 | +1208 | ||
680 | +1183 | ||
681 | +1189 | ||
682 | +1192 | ||
683 | +1209 | ||
684 | +1214 | ||
685 | +1197 | ||
686 | +1168 | ||
687 | +1202 | ||
688 | +1205 | ||
689 | +1203 | ||
690 | +1199 | ||
691 | +1219 | ||
692 | +1217 | ||
693 | +1187 | ||
694 | +1206 | ||
695 | +1210 | ||
696 | +1241 | ||
697 | +1221 | ||
698 | +1218 | ||
699 | +1223 | ||
700 | +1236 | ||
701 | +1212 | ||
702 | +1237 | ||
703 | +1195 | ||
704 | +1216 | ||
705 | +1247 | ||
706 | +1234 | ||
707 | +1240 | ||
708 | +1257 | ||
709 | +1224 | ||
710 | +1243 | ||
711 | +1259 | ||
712 | +1242 | ||
713 | +1282 | ||
714 | +1222 | ||
715 | +1254 | ||
716 | +1227 | ||
717 | +1235 | ||
718 | +1269 | ||
719 | +1258 | ||
720 | +1290 | ||
721 | +1275 | ||
722 | +1262 | ||
723 | +1252 | ||
724 | +1248 | ||
725 | +1272 | ||
726 | +1246 | ||
727 | +1225 | ||
728 | +1245 | ||
729 | +1277 | ||
730 | +1298 | ||
731 | +1288 | ||
732 | +1271 | ||
733 | +1265 | ||
734 | +1286 | ||
735 | +1260 | ||
736 | +1266 | ||
737 | +1296 | ||
738 | +1280 | ||
739 | +1285 | ||
740 | +1293 | ||
741 | +1276 | ||
742 | +1287 | ||
743 | +1289 | ||
744 | +1261 | ||
745 | +1264 | ||
746 | +1295 | ||
747 | +1291 | ||
748 | +1283 | ||
749 | +1311 | ||
750 | +1303 | ||
751 | +1330 | ||
752 | +1315 | ||
753 | +1300 | ||
754 | +1333 | ||
755 | +1307 | ||
756 | +1325 | ||
757 | +1334 | ||
758 | +1316 | ||
759 | +1314 | ||
760 | +1317 | ||
761 | +1310 | ||
762 | +1329 | ||
763 | +1324 | ||
764 | +1339 | ||
765 | +1346 | ||
766 | +1342 | ||
767 | +1352 | ||
768 | +1321 | ||
769 | +1376 | ||
770 | +1366 | ||
771 | +1308 | ||
772 | +1345 | ||
773 | +1348 | ||
774 | +1386 | ||
775 | +1383 | ||
776 | +1372 | ||
777 | +1367 | ||
778 | +1400 | ||
779 | +1382 | ||
780 | +1375 | ||
781 | +1392 | ||
782 | +1380 | ||
783 | +1371 | ||
784 | +1393 | ||
785 | +1389 | ||
786 | +1353 | ||
787 | +1387 | ||
788 | +1374 | ||
789 | +1379 | ||
790 | +1381 | ||
791 | +1359 | ||
792 | +1360 | ||
793 | +1396 | ||
794 | +1399 | ||
795 | +1365 | ||
796 | +1424 | ||
797 | +1373 | ||
798 | +1411 | ||
799 | +1401 | ||
800 | +1397 | ||
801 | +1395 | ||
802 | +1412 | ||
803 | +1394 | ||
804 | +1368 | ||
805 | +1423 | ||
806 | +1391 | ||
807 | +1435 | ||
808 | +1409 | ||
809 | +1443 | ||
810 | +1402 | ||
811 | +1425 | ||
812 | +1415 | ||
813 | +1421 | ||
814 | +1426 | ||
815 | +1433 | ||
816 | +1420 | ||
817 | +1452 | ||
818 | +1436 | ||
819 | +1430 | ||
820 | +1408 | ||
821 | +1458 | ||
822 | +1429 | ||
823 | +1453 | ||
824 | +1454 | ||
825 | +1447 | ||
826 | +1472 | ||
827 | +1486 | ||
828 | +1468 | ||
829 | +1461 | ||
830 | +1467 | ||
831 | +1484 | ||
832 | +1457 | ||
833 | +1444 | ||
834 | +1450 | ||
835 | +1451 | ||
836 | +1459 | ||
837 | +1462 | ||
838 | +1449 | ||
839 | +1476 | ||
840 | +1470 | ||
841 | +1471 | ||
842 | +1498 | ||
843 | +1488 | ||
844 | +1442 | ||
845 | +1480 | ||
846 | +1456 | ||
847 | +1466 | ||
848 | +1505 | ||
849 | +1517 | ||
850 | +1464 | ||
851 | +1503 | ||
852 | +1490 | ||
853 | +1519 | ||
854 | +1481 | ||
855 | +1493 | ||
856 | +1463 | ||
857 | +1532 | ||
858 | +1487 | ||
859 | +1501 | ||
860 | +1500 | ||
861 | +1495 | ||
862 | +1509 | ||
863 | +1535 | ||
864 | +1506 | ||
865 | +1521 | ||
866 | +1580 | ||
867 | +1540 | ||
868 | +1502 | ||
869 | +1520 | ||
870 | +1496 | ||
871 | +1569 | ||
872 | +1515 | ||
873 | +1489 | ||
874 | +1507 | ||
875 | +1527 | ||
876 | +1545 | ||
877 | +1560 | ||
878 | +1510 | ||
879 | +1514 | ||
880 | +1526 | ||
881 | +1594 | ||
882 | +1511 | ||
883 | +1572 | ||
884 | +1548 | ||
885 | +1584 | ||
886 | +1556 | ||
887 | +1588 | ||
888 | +1628 | ||
889 | +1555 | ||
890 | +1568 | ||
891 | +1550 | ||
892 | +1622 | ||
893 | +1563 | ||
894 | +1603 | ||
895 | +1616 | ||
896 | +1576 | ||
897 | +1549 | ||
898 | +1537 | ||
899 | +1593 | ||
900 | +1618 | ||
901 | +1645 | ||
902 | +1624 | ||
903 | +1617 | ||
904 | +1634 | ||
905 | +1595 | ||
906 | +1597 | ||
907 | +1590 | ||
908 | +1632 | ||
909 | +1575 | ||
910 | +1559 | ||
911 | +1625 | ||
912 | +1615 | ||
913 | +1591 | ||
914 | +1630 | ||
915 | +1608 | ||
916 | +1621 | ||
917 | +1589 | ||
918 | +1646 | ||
919 | +1643 | ||
920 | +1652 | ||
921 | +1627 | ||
922 | +1611 | ||
923 | +1626 | ||
924 | +1613 | ||
925 | +1639 | ||
926 | +1655 | ||
927 | +1620 | ||
928 | +1602 | ||
929 | +1651 | ||
930 | +1653 | ||
931 | +1669 | ||
932 | +1638 | ||
933 | +1696 | ||
934 | +1649 | ||
935 | +1675 | ||
936 | +1660 | ||
937 | +1683 | ||
938 | +1666 | ||
939 | +1671 | ||
940 | +1703 | ||
941 | +1716 | ||
942 | +1637 | ||
943 | +1672 | ||
944 | +1676 | ||
945 | +1692 | ||
946 | +1711 | ||
947 | +1680 | ||
948 | +1641 | ||
949 | +1688 | ||
950 | +1708 | ||
951 | +1704 | ||
952 | +1690 | ||
953 | +1674 | ||
954 | +1718 | ||
955 | +1699 | ||
956 | +1723 | ||
957 | +1756 | ||
958 | +1700 | ||
959 | +1662 | ||
960 | +1715 | ||
961 | +1657 | ||
962 | +1733 | ||
963 | +1728 | ||
964 | +1670 | ||
965 | +1712 | ||
966 | +1685 | ||
967 | +1724 | ||
968 | +1735 | ||
969 | +1714 | ||
970 | +1730 | ||
971 | +1747 | ||
972 | +1656 | ||
973 | +1737 | ||
974 | +1705 | ||
975 | +1693 | ||
976 | +1713 | ||
977 | +1689 | ||
978 | +1753 | ||
979 | +1739 | ||
980 | +1721 | ||
981 | +1725 | ||
982 | +1749 | ||
983 | +1732 | ||
984 | +1743 | ||
985 | +1731 | ||
986 | +1767 | ||
987 | +1738 | ||
988 | +1831 | ||
989 | +1771 | ||
990 | +1726 | ||
991 | +1746 | ||
992 | +1776 | ||
993 | +1775 | ||
994 | +1799 | ||
995 | +1774 | ||
996 | +1780 | ||
997 | +1781 | ||
998 | +1769 | ||
999 | +1805 | ||
1000 | +1788 | ||
1001 | +1801 |
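The file above lists the segment-level label ids used by the backend: an `Index` header, then one integer id per line (the ids are not sorted). A minimal, hypothetical loader in case the whitelist is needed elsewhere:

```
# Hypothetical loader for segment_label_ids.csv: skip the "Index" header and
# parse one integer label id per line (the file is not sorted).
def load_segment_label_ids(path="segment_label_ids.csv"):
    with open(path) as fobj:
        assert next(fobj).strip() == "Index"
        return [int(line) for line in fobj if line.strip()]

whitelist = set(load_segment_label_ids())
```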
web/backend/yt8m/train.py
0 → 100644
web/backend/yt8m/utils.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains a collection of util functions for training and evaluating.""" | ||
15 | + | ||
16 | +import numpy | ||
17 | +import tensorflow as tf | ||
18 | +from tensorflow import logging | ||
19 | + | ||
20 | +try: | ||
21 | + xrange # Python 2 | ||
22 | +except NameError: | ||
23 | + xrange = range # Python 3 | ||
24 | + | ||
25 | + | ||
26 | +def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2): | ||
27 | + """Dequantize the feature from the byte format to the float format. | ||
28 | + | ||
29 | + Args: | ||
30 | + feat_vector: the input 1-d vector. | ||
31 | + max_quantized_value: the maximum of the quantized value. | ||
32 | + min_quantized_value: the minimum of the quantized value. | ||
33 | + | ||
34 | + Returns: | ||
35 | + A float vector which has the same shape as feat_vector. | ||
36 | + """ | ||
37 | + assert max_quantized_value > min_quantized_value | ||
38 | + quantized_range = max_quantized_value - min_quantized_value | ||
39 | + scalar = quantized_range / 255.0 | ||
40 | + bias = (quantized_range / 512.0) + min_quantized_value | ||
41 | + return feat_vector * scalar + bias | ||
42 | + | ||
43 | + | ||
44 | +def MakeSummary(name, value): | ||
45 | + """Creates a tf.Summary proto with the given name and value.""" | ||
46 | + summary = tf.Summary() | ||
47 | + val = summary.value.add() | ||
48 | + val.tag = str(name) | ||
49 | + val.simple_value = float(value) | ||
50 | + return summary | ||
51 | + | ||
52 | + | ||
53 | +def AddGlobalStepSummary(summary_writer, | ||
54 | + global_step_val, | ||
55 | + global_step_info_dict, | ||
56 | + summary_scope="Eval"): | ||
57 | + """Add the global_step summary to the Tensorboard. | ||
58 | + | ||
59 | + Args: | ||
60 | + summary_writer: Tensorflow summary_writer. | ||
61 | + global_step_val: an int value of the global step. | ||
62 | + global_step_info_dict: a dictionary of the evaluation metrics calculated for | ||
63 | + a mini-batch. | ||
64 | + summary_scope: Train or Eval. | ||
65 | + | ||
66 | + Returns: | ||
67 | + A string of this global_step summary | ||
68 | + """ | ||
69 | + this_hit_at_one = global_step_info_dict["hit_at_one"] | ||
70 | + this_perr = global_step_info_dict["perr"] | ||
71 | + this_loss = global_step_info_dict["loss"] | ||
72 | + examples_per_second = global_step_info_dict.get("examples_per_second", -1) | ||
73 | + | ||
74 | + summary_writer.add_summary( | ||
75 | + MakeSummary("GlobalStep/" + summary_scope + "_Hit@1", this_hit_at_one), | ||
76 | + global_step_val) | ||
77 | + summary_writer.add_summary( | ||
78 | + MakeSummary("GlobalStep/" + summary_scope + "_Perr", this_perr), | ||
79 | + global_step_val) | ||
80 | + summary_writer.add_summary( | ||
81 | + MakeSummary("GlobalStep/" + summary_scope + "_Loss", this_loss), | ||
82 | + global_step_val) | ||
83 | + | ||
84 | + if examples_per_second != -1: | ||
85 | + summary_writer.add_summary( | ||
86 | + MakeSummary("GlobalStep/" + summary_scope + "_Example_Second", | ||
87 | + examples_per_second), global_step_val) | ||
88 | + | ||
89 | + summary_writer.flush() | ||
90 | + info = ( | ||
91 | + "global_step {0} | Batch Hit@1: {1:.3f} | Batch PERR: {2:.3f} | Batch " | ||
92 | + "Loss: {3:.3f} | Examples_per_sec: {4:.3f}").format( | ||
93 | + global_step_val, this_hit_at_one, this_perr, this_loss, | ||
94 | + examples_per_second) | ||
95 | + return info | ||
96 | + | ||
97 | + | ||
98 | +def AddEpochSummary(summary_writer, | ||
99 | + global_step_val, | ||
100 | + epoch_info_dict, | ||
101 | + summary_scope="Eval"): | ||
102 | + """Add the epoch summary to the Tensorboard. | ||
103 | + | ||
104 | + Args: | ||
105 | + summary_writer: Tensorflow summary_writer. | ||
106 | + global_step_val: an int value of the global step. | ||
107 | + epoch_info_dict: a dictionary of the evaluation metrics calculated for the | ||
108 | + whole epoch. | ||
109 | + summary_scope: Train or Eval. | ||
110 | + | ||
111 | + Returns: | ||
112 | + A string of this global_step summary | ||
113 | + """ | ||
114 | + epoch_id = epoch_info_dict["epoch_id"] | ||
115 | + avg_hit_at_one = epoch_info_dict["avg_hit_at_one"] | ||
116 | + avg_perr = epoch_info_dict["avg_perr"] | ||
117 | + avg_loss = epoch_info_dict["avg_loss"] | ||
118 | + aps = epoch_info_dict["aps"] | ||
119 | + gap = epoch_info_dict["gap"] | ||
120 | + mean_ap = numpy.mean(aps) | ||
121 | + | ||
122 | + summary_writer.add_summary( | ||
123 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one), | ||
124 | + global_step_val) | ||
125 | + summary_writer.add_summary( | ||
126 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Perr", avg_perr), | ||
127 | + global_step_val) | ||
128 | + summary_writer.add_summary( | ||
129 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Loss", avg_loss), | ||
130 | + global_step_val) | ||
131 | + summary_writer.add_summary( | ||
132 | + MakeSummary("Epoch/" + summary_scope + "_MAP", mean_ap), global_step_val) | ||
133 | + summary_writer.add_summary( | ||
134 | + MakeSummary("Epoch/" + summary_scope + "_GAP", gap), global_step_val) | ||
135 | + summary_writer.flush() | ||
136 | + | ||
137 | + info = ("epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} " | ||
138 | + "| MAP: {3:.3f} | GAP: {4:.3f} | Avg_Loss: {5:.3f} | num_classes: {6}" | ||
139 | + ).format(epoch_id, avg_hit_at_one, avg_perr, mean_ap, gap, avg_loss, | ||
140 | + len(aps)) | ||
141 | + return info | ||
142 | + | ||
143 | + | ||
144 | +def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes): | ||
145 | + """Extract the list of feature names and the dimensionality of each | ||
146 | + feature from a string of comma-separated values. | ||
147 | + | ||
148 | + | ||
149 | + Args: | ||
150 | + feature_names: string containing comma separated list of feature names | ||
151 | + feature_sizes: string containing comma separated list of feature sizes | ||
152 | + | ||
153 | + Returns: | ||
154 | + List of the feature names and list of the dimensionality of each feature. | ||
155 | + Elements in the first/second list are strings/integers. | ||
156 | + """ | ||
157 | + list_of_feature_names = [ | ||
158 | + feature_names.strip() for feature_names in feature_names.split(",") | ||
159 | + ] | ||
160 | + list_of_feature_sizes = [ | ||
161 | + int(feature_sizes) for feature_sizes in feature_sizes.split(",") | ||
162 | + ] | ||
163 | + if len(list_of_feature_names) != len(list_of_feature_sizes): | ||
164 | + logging.error("length of the feature names (=" + | ||
165 | + str(len(list_of_feature_names)) + ") != length of feature " | ||
166 | + "sizes (=" + str(len(list_of_feature_sizes)) + ")") | ||
167 | + | ||
168 | + return list_of_feature_names, list_of_feature_sizes | ||
169 | + | ||
170 | + | ||
171 | +def clip_gradient_norms(gradients_to_variables, max_norm): | ||
172 | + """Clips the gradients by the given value. | ||
173 | + | ||
174 | + Args: | ||
175 | + gradients_to_variables: A list of gradient to variable pairs (tuples). | ||
176 | + max_norm: the maximum norm value. | ||
177 | + | ||
178 | + Returns: | ||
179 | + A list of clipped gradient to variable pairs. | ||
180 | + """ | ||
181 | + clipped_grads_and_vars = [] | ||
182 | + for grad, var in gradients_to_variables: | ||
183 | + if grad is not None: | ||
184 | + if isinstance(grad, tf.IndexedSlices): | ||
185 | + tmp = tf.clip_by_norm(grad.values, max_norm) | ||
186 | + grad = tf.IndexedSlices(tmp, grad.indices, grad.dense_shape) | ||
187 | + else: | ||
188 | + grad = tf.clip_by_norm(grad, max_norm) | ||
189 | + clipped_grads_and_vars.append((grad, var)) | ||
190 | + return clipped_grads_and_vars | ||
191 | + | ||
192 | + | ||
193 | +def combine_gradients(tower_grads): | ||
194 | + """Calculate the combined gradient for each shared variable across all towers. | ||
195 | + | ||
196 | + Note that this function provides a synchronization point across all towers. | ||
197 | + | ||
198 | + Args: | ||
199 | + tower_grads: List of lists of (gradient, variable) tuples. The outer list is | ||
200 | + over individual gradients. The inner list is over the gradient calculation | ||
201 | + for each tower. | ||
202 | + | ||
203 | + Returns: | ||
204 | + List of pairs of (gradient, variable) where the gradient has been summed | ||
205 | + across all towers. | ||
206 | + """ | ||
207 | + filtered_grads = [ | ||
208 | + [x for x in grad_list if x[0] is not None] for grad_list in tower_grads | ||
209 | + ] | ||
210 | + final_grads = [] | ||
211 | + for i in xrange(len(filtered_grads[0])): | ||
212 | + grads = [filtered_grads[t][i] for t in xrange(len(filtered_grads))] | ||
213 | + grad = tf.stack([x[0] for x in grads], 0) | ||
214 | + grad = tf.reduce_sum(grad, 0) | ||
215 | + final_grads.append(( | ||
216 | + grad, | ||
217 | + filtered_grads[0][i][1], | ||
218 | + )) | ||
219 | + | ||
220 | + return final_grads |
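Two of these helpers are easy to sanity-check in isolation. A minimal sketch, assuming the repo's `utils.py` is importable; the values and the 1024/128 rgb/audio split are illustrative, not something this file enforces:

```
import numpy as np
from utils import Dequantize, GetListOfFeatureNamesAndSizes

# Dequantize maps byte-quantized features back to floats in
# [min_quantized_value, max_quantized_value].
raw = np.array([0, 128, 255], dtype=np.float32)
print(Dequantize(raw))  # approximately [-1.99, 0.02, 2.01]

# GetListOfFeatureNamesAndSizes splits the comma-separated flag strings,
# stripping whitespace from names and casting sizes to int.
names, sizes = GetListOfFeatureNamesAndSizes("rgb, audio", "1024,128")
assert names == ["rgb", "audio"] and sizes == [1024, 128]
```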
web/backend/yt8m/video_level_models.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains model definitions.""" | ||
15 | +import math | ||
16 | + | ||
17 | +import models | ||
18 | +import tensorflow as tf | ||
19 | +import utils | ||
20 | + | ||
21 | +from tensorflow import flags | ||
22 | +import tensorflow.contrib.slim as slim | ||
23 | + | ||
24 | +FLAGS = flags.FLAGS | ||
25 | +flags.DEFINE_integer( | ||
26 | + "moe_num_mixtures", 2, | ||
27 | + "The number of mixtures (excluding the dummy 'expert') used for MoeModel.") | ||
28 | + | ||
29 | + | ||
30 | +class LogisticModel(models.BaseModel): | ||
31 | + """Logistic model with L2 regularization.""" | ||
32 | + | ||
33 | + def create_model(self, | ||
34 | + model_input, | ||
35 | + vocab_size, | ||
36 | + l2_penalty=1e-8, | ||
37 | + **unused_params): | ||
38 | + """Creates a logistic model. | ||
39 | + | ||
40 | + Args: | ||
41 | + model_input: 'batch' x 'num_features' matrix of input features. | ||
42 | + vocab_size: The number of classes in the dataset. | ||
43 | + | ||
44 | + Returns: | ||
45 | + A dictionary with a tensor containing the probability predictions of the | ||
46 | + model in the 'predictions' key. The dimensions of the tensor are | ||
47 | + batch_size x num_classes. | ||
48 | + """ | ||
49 | + output = slim.fully_connected( | ||
50 | + model_input, | ||
51 | + vocab_size, | ||
52 | + activation_fn=tf.nn.sigmoid, | ||
53 | + weights_regularizer=slim.l2_regularizer(l2_penalty)) | ||
54 | + return {"predictions": output} | ||
55 | + | ||
56 | + | ||
57 | +class MoeModel(models.BaseModel): | ||
58 | + """A softmax over a mixture of logistic models (with L2 regularization).""" | ||
59 | + | ||
60 | + def create_model(self, | ||
61 | + model_input, | ||
62 | + vocab_size, | ||
63 | + num_mixtures=None, | ||
64 | + l2_penalty=1e-8, | ||
65 | + **unused_params): | ||
66 | + """Creates a Mixture of (Logistic) Experts model. | ||
67 | + | ||
68 | + The model consists of a per-class softmax distribution over a | ||
69 | + configurable number of logistic classifiers. One of the classifiers in the | ||
70 | + mixture is not trained, and always predicts 0. | ||
71 | + | ||
72 | + Args: | ||
73 | + model_input: 'batch_size' x 'num_features' matrix of input features. | ||
74 | + vocab_size: The number of classes in the dataset. | ||
75 | + num_mixtures: The number of mixtures (excluding a dummy 'expert' that | ||
76 | + always predicts the non-existence of an entity). | ||
77 | + l2_penalty: How much to penalize the squared magnitudes of parameter | ||
78 | + values. | ||
79 | + | ||
80 | + Returns: | ||
81 | + A dictionary with a tensor containing the probability predictions of the | ||
82 | + model in the 'predictions' key. The dimensions of the tensor are | ||
83 | + batch_size x num_classes. | ||
84 | + """ | ||
85 | + num_mixtures = num_mixtures or FLAGS.moe_num_mixtures | ||
86 | + | ||
87 | + gate_activations = slim.fully_connected( | ||
88 | + model_input, | ||
89 | + vocab_size * (num_mixtures + 1), | ||
90 | + activation_fn=None, | ||
91 | + biases_initializer=None, | ||
92 | + weights_regularizer=slim.l2_regularizer(l2_penalty), | ||
93 | + scope="gates") | ||
94 | + expert_activations = slim.fully_connected( | ||
95 | + model_input, | ||
96 | + vocab_size * num_mixtures, | ||
97 | + activation_fn=None, | ||
98 | + weights_regularizer=slim.l2_regularizer(l2_penalty), | ||
99 | + scope="experts") | ||
100 | + | ||
101 | + gating_distribution = tf.nn.softmax( | ||
102 | + tf.reshape( | ||
103 | + gate_activations, | ||
104 | + [-1, num_mixtures + 1])) # (Batch * #Labels) x (num_mixtures + 1) | ||
105 | + expert_distribution = tf.nn.sigmoid( | ||
106 | + tf.reshape(expert_activations, | ||
107 | + [-1, num_mixtures])) # (Batch * #Labels) x num_mixtures | ||
108 | + | ||
109 | + final_probabilities_by_class_and_batch = tf.reduce_sum( | ||
110 | + gating_distribution[:, :num_mixtures] * expert_distribution, 1) | ||
111 | + final_probabilities = tf.reshape(final_probabilities_by_class_and_batch, | ||
112 | + [-1, vocab_size]) | ||
113 | + return {"predictions": final_probabilities} |
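One note on the mixture arithmetic above: the gate softmax has `num_mixtures + 1` columns but only the first `num_mixtures` are multiplied with expert sigmoids, so the last gate column effectively scales a constant zero prediction (the untrained dummy expert). A minimal shape check, assuming TF 1.x graph mode (which the slim code requires) and an illustrative vocabulary size:

```
# Minimal shape check; vocab_size=3862 is assumed here for illustration.
import tensorflow as tf
import video_level_models

model_input = tf.placeholder(tf.float32, [None, 1152])  # e.g. 1024 rgb + 128 audio
moe = video_level_models.MoeModel()
out = moe.create_model(model_input, vocab_size=3862, num_mixtures=2)
print(out["predictions"].shape)  # (?, 3862): one probability per class
```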
web/backend/yt8m/vocabulary.csv
0 → 100644
web/frontend/.browserslistrc
0 → 100644
web/frontend/.editorconfig
0 → 100644
web/frontend/.eslintrc.js
0 → 100644
1 | +module.exports = { | ||
2 | + root: true, | ||
3 | + env: { | ||
4 | + browser: true, | ||
5 | + es6: true, | ||
6 | + node: true, | ||
7 | + }, | ||
8 | + extends: [ | ||
9 | + 'plugin:vue/essential', | ||
10 | + '@vue/airbnb', | ||
11 | + ], | ||
12 | + parserOptions: { | ||
13 | + parser: 'babel-eslint', | ||
14 | + }, | ||
15 | + rules: { | ||
16 | + 'no-console': process.env.NODE_ENV === 'production' ? 'warn' : 'off', | ||
17 | + 'no-debugger': process.env.NODE_ENV === 'production' ? 'warn' : 'off', | ||
18 | + 'linebreak-style': ['error', 'unix'], | ||
19 | + 'max-len': ['error', { code: 200 }], | ||
20 | + }, | ||
21 | +}; |
web/frontend/.gitignore
0 → 100644
web/frontend/README.md
0 → 100644
1 | +# front | ||
2 | + | ||
3 | +## Project setup | ||
4 | +``` | ||
5 | +yarn install | ||
6 | +``` | ||
7 | + | ||
8 | +### Compiles and hot-reloads for development | ||
9 | +``` | ||
10 | +yarn serve | ||
11 | +``` | ||
12 | + | ||
13 | +### Compiles and minifies for production | ||
14 | +``` | ||
15 | +yarn build | ||
16 | +``` | ||
17 | + | ||
18 | +### Lints and fixes files | ||
19 | +``` | ||
20 | +yarn lint | ||
21 | +``` | ||
22 | + | ||
23 | +### Customize configuration | ||
24 | +See [Configuration Reference](https://cli.vuejs.org/config/). |
web/frontend/babel.config.js
0 → 100644
web/frontend/package-lock.json
0 → 100644
web/frontend/package.json
0 → 100644
1 | +{ | ||
2 | + "name": "front", | ||
3 | + "version": "0.1.0", | ||
4 | + "private": true, | ||
5 | + "scripts": { | ||
6 | + "serve": "vue-cli-service serve", | ||
7 | + "build": "vue-cli-service build", | ||
8 | + "lint": "vue-cli-service lint" | ||
9 | + }, | ||
10 | + "dependencies": { | ||
11 | + "@mdi/font": "^3.6.95", | ||
12 | + "axios": "^0.19.2", | ||
13 | + "core-js": "^3.6.5", | ||
14 | + "moment": "^2.24.0", | ||
15 | + "roboto-fontface": "*", | ||
16 | + "vue": "^2.6.11", | ||
17 | + "vue-router": "^3.1.6", | ||
18 | + "vuetify": "^2.2.11", | ||
19 | + "vuex": "^3.1.3" | ||
20 | + }, | ||
21 | + "devDependencies": { | ||
22 | + "@vue/cli-plugin-babel": "~4.3.0", | ||
23 | + "@vue/cli-plugin-eslint": "~4.3.0", | ||
24 | + "@vue/cli-plugin-router": "~4.3.0", | ||
25 | + "@vue/cli-plugin-vuex": "~4.3.0", | ||
26 | + "@vue/cli-service": "~4.3.0", | ||
27 | + "@vue/eslint-config-airbnb": "^5.0.2", | ||
28 | + "babel-eslint": "^10.1.0", | ||
29 | + "eslint": "^6.7.2", | ||
30 | + "eslint-plugin-import": "^2.20.2", | ||
31 | + "eslint-plugin-vue": "^6.2.2", | ||
32 | + "node-sass": "^4.12.0", | ||
33 | + "sass": "^1.19.0", | ||
34 | + "sass-loader": "^8.0.2", | ||
35 | + "vue-cli-plugin-vuetify": "~2.0.5", | ||
36 | + "vue-template-compiler": "^2.6.11", | ||
37 | + "vuetify-loader": "^1.3.0" | ||
38 | + } | ||
39 | +} |
web/frontend/public/favicon.ico
0 → 100644
web/frontend/public/index.html
0 → 100644
1 | +<!DOCTYPE html> | ||
2 | +<html lang="en"> | ||
3 | + <head> | ||
4 | + <meta charset="utf-8"> | ||
5 | + <meta http-equiv="X-UA-Compatible" content="IE=edge"> | ||
6 | + <meta name="viewport" content="width=device-width,initial-scale=1.0"> | ||
7 | + <link rel="icon" href="<%= BASE_URL %>favicon.ico"> | ||
8 | + <title>Profit-Hunter</title> | ||
9 | + </head> | ||
10 | + <body> | ||
11 | + <noscript> | ||
12 | + <strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong> | ||
13 | + </noscript> | ||
14 | + <div id="app"></div> | ||
15 | + <!-- built files will be auto injected --> | ||
16 | + </body> | ||
17 | +</html> |
web/frontend/src/App.vue
0 → 100644
1 | +<template> | ||
2 | + <v-app> | ||
3 | + <v-app-bar app color="#ffffff" elevation="1" hide-on-scroll> | ||
4 | + <v-icon size="35" class="mr-1" color="grey700">mdi-youtube</v-icon> | ||
5 | + <div style="color: #343a40; font-size: 20px; font-weight: 500;">Youtube Auto Tagger</div> | ||
6 | + <v-spacer></v-spacer> | ||
7 | + <v-tooltip bottom> | ||
8 | + <template v-slot:activator="{ on }"> | ||
9 | + <v-btn icon v-on="on" color="grey700" @click="clickInfo=true"> | ||
10 | + <v-icon size="30">mdi-information</v-icon> | ||
11 | + </v-btn> | ||
12 | + </template> | ||
13 | + <span>Service Info</span> | ||
14 | + </v-tooltip> | ||
15 | + </v-app-bar> | ||
16 | + <v-dialog v-model="clickInfo" max-width="600" class="pa-2"> | ||
17 | + <v-card elevation="0" outlined class="pa-4"> | ||
18 | + <v-row justify="center" class="mx-0 mb-8 mt-2"> | ||
19 | + <v-icon color="primary">mdi-power-on</v-icon> | ||
20 | + <div | ||
21 | + style="text-align: center; font-size: 22px; font-weight: 400; color: #343a40;" | ||
22 | + >Information about This Service</div> | ||
23 | + <v-icon color="primary">mdi-power-on</v-icon> | ||
24 | + </v-row> | ||
25 | + <v-btn color="primary" class="mt-2" block>Description of the service</v-btn> | ||
26 | + <v-btn color="primary" class="my-2" block>Terms of this service</v-btn> | ||
27 | + <v-btn color="primary" block>Open-source software used</v-btn> | ||
28 | + </v-card> | ||
29 | + </v-dialog> | ||
30 | + <v-content> | ||
31 | + <router-view /> | ||
32 | + </v-content> | ||
33 | + <v-footer> | ||
34 | + <v-row justify="center"> | ||
35 | + <v-avatar size="25" tile style="border-radius: 4px"> | ||
36 | + <v-img src="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1/uploads/99f7d5c73e506d2c5c0072a21f362181/logo.69342704.png"></v-img> | ||
37 | + </v-avatar> | ||
38 | + <a href="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1"> | ||
39 | + <div | ||
40 | + style="margin-left: 4px; font-size: 16px; color: #5a5a5a; font-weight: 400" | ||
41 | + >Profit-Hunter</div> | ||
42 | + </a> | ||
43 | + </v-row> | ||
44 | + </v-footer> | ||
45 | + </v-app> | ||
46 | +</template> | ||
47 | +<script> | ||
48 | +export default { | ||
49 | + name: 'App', | ||
50 | + data() { | ||
51 | + return { | ||
52 | + clickInfo: false, | ||
53 | + }; | ||
54 | + }, | ||
55 | +}; | ||
56 | +</script> | ||
57 | +<style lang="scss"> | ||
58 | +a:hover { | ||
59 | + text-decoration: none; | ||
60 | +} | ||
61 | +a:link { | ||
62 | + text-decoration: none; | ||
63 | +} | ||
64 | +a:visited { | ||
65 | + text-decoration: none; | ||
66 | +} | ||
67 | +</style> |
web/frontend/src/components/description.vue
0 → 100644
web/frontend/src/main.js
0 → 100644
1 | +import Vue from 'vue'; | ||
2 | +import App from './App.vue'; | ||
3 | +import router from './router'; | ||
4 | +import store from './store'; | ||
5 | +import vuetify from './plugins/vuetify'; | ||
6 | +import 'roboto-fontface/css/roboto/roboto-fontface.css'; | ||
7 | +import '@mdi/font/css/materialdesignicons.css'; | ||
8 | + | ||
9 | +Vue.config.productionTip = false; | ||
10 | + | ||
11 | +new Vue({ | ||
12 | + router, | ||
13 | + store, | ||
14 | + vuetify, | ||
15 | + render: (h) => h(App), | ||
16 | +}).$mount('#app'); |
web/frontend/src/plugins/vuetify.js
0 → 100644
1 | +import Vue from 'vue'; | ||
2 | +import Vuetify from 'vuetify/lib'; | ||
3 | + | ||
4 | +Vue.use(Vuetify); | ||
5 | + | ||
6 | +export default new Vuetify({ | ||
7 | + theme: { | ||
8 | + themes: { | ||
9 | + light: { | ||
10 | + primary: '#343a40', | ||
11 | + secondary: '#506980', | ||
12 | + accent: '#505B80', | ||
13 | + error: '#FF5252', | ||
14 | + info: '#2196F3', | ||
15 | + blue: '#173f5f', | ||
16 | + lightblue: '#72b1e4', | ||
17 | + success: '#2779bd', | ||
18 | + warning: '#12283a', | ||
19 | + grey300: '#eceeef', | ||
20 | + grey500: '#aaaaaa', | ||
21 | + grey700: '#5a5a5a', | ||
22 | + grey900: '#212529', | ||
23 | + }, | ||
24 | + dark: { | ||
25 | + primary: '#343a40', | ||
26 | + secondary: '#506980', | ||
27 | + accent: '#505B80', | ||
28 | + error: '#FF5252', | ||
29 | + info: '#2196F3', | ||
30 | + blue: '#173f5f', | ||
31 | + lightblue: '#72b1e4', | ||
32 | + success: '#2779bd', | ||
33 | + warning: '#12283a', | ||
34 | + grey300: '#eceeef', | ||
35 | + grey500: '#aaaaaa', | ||
36 | + grey700: '#5a5a5a', | ||
37 | + grey900: '#212529', | ||
38 | + }, | ||
39 | + }, | ||
40 | + }, | ||
41 | +}); |
web/frontend/src/router/index.js
0 → 100644
1 | +import Vue from 'vue'; | ||
2 | +import VueRouter from 'vue-router'; | ||
3 | +import axios from 'axios'; | ||
4 | +import Home from '../views/Home.vue'; | ||
5 | + | ||
6 | +Vue.prototype.$axios = axios; | ||
7 | +const apiRootPath = process.env.NODE_ENV !== 'production' | ||
8 | + ? 'http://localhost:8000/api/' | ||
9 | + : '/api/'; | ||
10 | +Vue.prototype.$apiRootPath = apiRootPath; | ||
11 | +axios.defaults.baseURL = apiRootPath; | ||
12 | + | ||
13 | +Vue.use(VueRouter); | ||
14 | + | ||
15 | +const routes = [ | ||
16 | + { | ||
17 | + path: '/', | ||
18 | + name: 'Home', | ||
19 | + component: Home, | ||
20 | + }, | ||
21 | +]; | ||
22 | + | ||
23 | +const router = new VueRouter({ | ||
24 | + mode: 'history', | ||
25 | + base: process.env.BASE_URL, | ||
26 | + routes, | ||
27 | +}); | ||
28 | + | ||
29 | +export default router; |
web/frontend/src/store/index.js
0 → 100644
web/frontend/src/views/Home.vue
0 → 100644
1 | +<template> | ||
2 | + <v-sheet> | ||
3 | + <v-overlay v-model="loadingProcess"> | ||
4 | + <v-progress-circular :size="120" width="10" color="primary" indeterminate></v-progress-circular> | ||
5 | + <div | ||
6 | + style="color: #ffffff; font-size: 22px; margin-top: 20px; margin-left: -40px;" | ||
7 | + >Analyzing Your Video...</div> | ||
8 | + </v-overlay> | ||
9 | + <v-layout justify-center> | ||
10 | + <v-flex xs12 sm8 md6 lg4> | ||
11 | + <v-row justify="center" class="mx-0 mt-12"> | ||
12 | + <div style="font-size: 34px; font-weight: 500; color: #343a40;">WELCOME</div> | ||
13 | + </v-row> | ||
14 | + <v-card elevation="0"> | ||
15 | + <!-- Desktop view: summary description --> | ||
16 | + <v-row justify="center" class="mx-0 mt-12" v-if="$vuetify.breakpoint.mdAndUp"> | ||
17 | + <v-flex md7 class="mt-1"> | ||
18 | + <div | ||
19 | + style="font-size: 20px; font-weight: 300; color: #888;" | ||
20 | + >This is a video auto-tagging service</div> | ||
21 | + <div | ||
22 | + style="font-size: 20px; font-weight: 300; color: #888;" | ||
23 | + >designed for YouTube videos.</div> | ||
24 | + <div | ||
25 | + style="font-size: 20px; font-weight: 300; color: #888;" | ||
26 | + >It takes a few minutes to analyze your video!</div> | ||
27 | + </v-flex> | ||
28 | + <v-flex md5> | ||
29 | + <v-card width="240" elevation="0" class="ml-5"> | ||
30 | + <v-img | ||
31 | + width="240" | ||
32 | + src="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1/uploads/b70e4a173c2b7d5fa6ab73d48582dd6e/youtubelogoBlack.326653df.png" | ||
33 | + ></v-img> | ||
34 | + </v-card> | ||
35 | + </v-flex> | ||
36 | + </v-row> | ||
37 | + | ||
38 | + <!-- Mobile view: summary description --> | ||
39 | + <v-card elevation="0" class="mt-8" v-else> | ||
40 | + <div | ||
41 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
42 | + >This is a video auto-tagging service</div> | ||
43 | + <div | ||
44 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
45 | + >designed for YouTube videos.</div> | ||
46 | + <div | ||
47 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
48 | + >It takes a few minutes to analyze your video!</div> | ||
49 | + <v-img | ||
50 | + style="margin: auto; margin-top: 20px" | ||
51 | + width="180" | ||
52 | + src="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1/uploads/b70e4a173c2b7d5fa6ab73d48582dd6e/youtubelogoBlack.326653df.png" | ||
53 | + ></v-img> | ||
54 | + </v-card> | ||
55 | + | ||
56 | + <!-- Set Threshold --> | ||
57 | + <div | ||
58 | + class="mt-10" | ||
59 | + style="font-size: 24px; text-align: center; font-weight: 400; color: #5a5a5a;" | ||
60 | + >How to start this service</div> | ||
61 | + <div | ||
62 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center; margin-bottom: 15px" | ||
63 | + > | ||
64 | + <div>Set the threshold for</div> | ||
65 | + <div>recommended YouTube links</div> | ||
66 | + </div> | ||
67 | + <v-row style="max-width: 300px; margin: auto"> | ||
68 | + <v-slider v-model="threshold" :thumb-size="20" thumb-label="always" :min="2" :max="15"></v-slider> | ||
69 | + </v-row> | ||
70 | + | ||
71 | + <!-- Upload Video --> | ||
72 | + <div | ||
73 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
74 | + >Then, just upload your video</div> | ||
75 | + <v-row justify="center" class="mx-0 mt-2"> | ||
76 | + <v-card | ||
77 | + max-width="500" | ||
78 | + outlined | ||
79 | + height="120" | ||
80 | + class="pa-9" | ||
81 | + @dragover.prevent | ||
82 | + @dragenter.prevent | ||
83 | + @drop.prevent="onDrop" | ||
84 | + > | ||
85 | + <v-btn | ||
86 | + style="text-transform: none" | ||
87 | + @click="clickUploadButton" | ||
88 | + text | ||
89 | + large | ||
90 | + color="primary" | ||
91 | + >CLICK or DRAG & DROP</v-btn> | ||
92 | + <input ref="fileInput" style="display: none" type="file" @change="onFileChange" /> | ||
93 | + </v-card> | ||
94 | + </v-row> | ||
95 | + | ||
96 | + <!-- Results view --> | ||
97 | + <div | ||
98 | + style="font-size: 24px; text-align: center; font-weight: 400; color: #5a5a5a;" | ||
99 | + class="mt-10" | ||
100 | + >The Results of Analyzed Video</div> | ||
101 | + <v-card outlined class="pa-2 mx-5 mt-6" elevation="0" min-height="67"> | ||
102 | + <div | ||
103 | + style="margin-left: 5px; margin-top: -18px; background-color: #fff; width: 110px; text-align: center;font-size: 14px; color: #5a5a5a; font-weight: 500" | ||
104 | + >Generated Tags</div> | ||
105 | + <v-chip-group column> | ||
106 | + <v-chip color="secondary" v-for="(tag, index) in generatedTag" :key="index">{{ tag[0] }} : {{tag[1]}}</v-chip> | ||
107 | + </v-chip-group> | ||
108 | + </v-card> | ||
109 | + <v-card outlined class="pa-3 mx-5 mt-8" elevation="0" min-height="67"> | ||
110 | + <div | ||
111 | + style="margin-left: 5px; margin-top: -22px; margin-bottom: 5px; background-color: #fff; width: 140px; text-align: center;font-size: 14px; color: #5a5a5a; font-weight: 500" | ||
112 | + >Related Youtube Link</div> | ||
113 | + <v-flex style="margin-bottom: 2px" v-for="(url) in YoutubeUrl" :key="url.id"> | ||
114 | + <div> | ||
115 | + <a style="color: #343a40; font-size: 16px; font-weight: 500" :href="url">{{url}}</a> | ||
116 | + </div> | ||
117 | + </v-flex> | ||
118 | + </v-card> | ||
119 | + <div | ||
120 | + class="mt-3" | ||
121 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
122 | + >If the video is analyzed successfully,</div> | ||
123 | + <div | ||
124 | + class="mb-5" | ||
125 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
126 | + >results show up in each of the boxes!</div> | ||
127 | + </v-card> | ||
128 | + | ||
129 | + </v-flex> | ||
130 | + </v-layout> | ||
131 | + </v-sheet> | ||
132 | +</template> | ||
133 | +<script> | ||
134 | +export default { | ||
135 | + name: 'Home', | ||
136 | + data() { | ||
137 | + return { | ||
138 | + videoFile: '', | ||
139 | + YoutubeUrl: [], | ||
140 | + generatedTag: [], | ||
141 | + threshold: 5, | ||
142 | + successDialog: false, | ||
143 | + errorDialog: false, | ||
144 | + loadingProcess: false, | ||
145 | + }; | ||
146 | + }, | ||
147 | + created() { | ||
148 | + // this.YoutubeUrl = []; | ||
149 | + // this.generatedTag = []; | ||
150 | + }, | ||
151 | + methods: { | ||
152 | + loadVideoInfo() {}, | ||
153 | + uploadVideo(files) { | ||
154 | + this.loadingProcess = true; | ||
155 | + const formData = new FormData(); | ||
156 | + formData.append('file', files[0]); | ||
157 | + formData.append('threshold', this.threshold); | ||
158 | + console.log(files[0]); | ||
159 | + this.$axios | ||
160 | + .post('/upload', formData, { | ||
161 | + headers: { 'Content-Type': 'multipart/form-data' }, | ||
162 | + }) | ||
163 | + .then((r) => { | ||
164 | + const tag = r.data.tag_result; | ||
165 | + const url = r.data.video_result; | ||
166 | + url.forEach((element) => { | ||
167 | + this.YoutubeUrl.push(element.video_url); | ||
168 | + }); | ||
169 | + this.generatedTag = tag; | ||
170 | + this.loadingProcess = false; | ||
171 | + this.successDialog = true; | ||
172 | + console.log(tag, url); | ||
173 | + }) | ||
174 | + .catch((e) => { | ||
175 | + this.errorDialog = true; | ||
176 | + console.log(e.message); | ||
177 | + }); | ||
178 | + }, | ||
179 | + onDrop(event) { | ||
180 | + this.uploadVideo(event.dataTransfer.files); | ||
181 | + }, | ||
182 | + clickUploadButton() { | ||
183 | + this.$refs.fileInput.click(); | ||
184 | + }, | ||
185 | + onFileChange(event) { | ||
186 | + this.uploadVideo(event.target.files); | ||
187 | + }, | ||
188 | + }, | ||
189 | +}; | ||
190 | +</script> |
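Reading the component above, the backend contract is: a multipart POST to `/api/upload` with fields `file` and `threshold`, answered with JSON carrying `tag_result` (pairs rendered as `tag[0] : tag[1]` chips) and `video_result` (objects with a `video_url`). A hypothetical smoke test of that contract from Python:

```
# Hypothetical smoke test mirroring the axios call in Home.vue; the endpoint
# and field names are read off the component, the sample path is a placeholder.
import requests

with open("sample.mp4", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/upload",
        files={"file": f},
        data={"threshold": 5},
    )
resp.raise_for_status()
body = resp.json()
for tag, score in body["tag_result"]:  # rendered as "tag : score" chips
    print(tag, score)
for item in body["video_result"]:  # related YouTube links
    print(item["video_url"])
```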
web/frontend/vue.config.js
0 → 100644
web/frontend/yarn.lock
0 → 100644