Showing 94 changed files with 7120 additions and 0 deletions
.gitignore
0 → 100644
1 | +# Created by https://www.gitignore.io/api/django | ||
2 | +# Edit at https://www.gitignore.io/?templates=django | ||
3 | + | ||
4 | +### Django ### | ||
5 | +*.log | ||
6 | +*.pot | ||
7 | +*.pyc | ||
8 | +__pycache__/ | ||
9 | +local_settings.py | ||
10 | +db.sqlite3 | ||
11 | +db.sqlite3-journal | ||
12 | +media | ||
13 | +static/ | ||
14 | +/static | ||
15 | +/backend/env | ||
16 | +env/ | ||
17 | +/env | ||
18 | + | ||
19 | +# If your build process includes running collectstatic, then you probably don't need or want to include staticfiles/ | ||
20 | +# in your Git repository. Update and uncomment the following line accordingly. | ||
21 | +# <django-project-name>/staticfiles/ | ||
22 | + | ||
23 | +### Django.Python Stack ### | ||
24 | +# Byte-compiled / optimized / DLL files | ||
25 | +*.py[cod] | ||
26 | +*$py.class | ||
27 | + | ||
28 | +# C extensions | ||
29 | +*.so | ||
30 | + | ||
31 | +# Distribution / packaging | ||
32 | +.Python | ||
33 | +build/ | ||
34 | +develop-eggs/ | ||
35 | +dist/ | ||
36 | +downloads/ | ||
37 | +eggs/ | ||
38 | +.eggs/ | ||
39 | +lib/ | ||
40 | +lib64/ | ||
41 | +parts/ | ||
42 | +sdist/ | ||
43 | +var/ | ||
44 | +wheels/ | ||
45 | +pip-wheel-metadata/ | ||
46 | +share/python-wheels/ | ||
47 | +*.egg-info/ | ||
48 | +.installed.cfg | ||
49 | +*.egg | ||
50 | +MANIFEST | ||
51 | + | ||
52 | +# PyInstaller | ||
53 | +# Usually these files are written by a python script from a template | ||
54 | +# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
55 | +*.manifest | ||
56 | +*.spec | ||
57 | + | ||
58 | +# Installer logs | ||
59 | +pip-log.txt | ||
60 | +pip-delete-this-directory.txt | ||
61 | + | ||
62 | +# Unit test / coverage reports | ||
63 | +htmlcov/ | ||
64 | +.tox/ | ||
65 | +.nox/ | ||
66 | +.coverage | ||
67 | +.coverage.* | ||
68 | +.cache | ||
69 | +nosetests.xml | ||
70 | +coverage.xml | ||
71 | +*.cover | ||
72 | +.hypothesis/ | ||
73 | +.pytest_cache/ | ||
74 | + | ||
75 | +# Translations | ||
76 | +*.mo | ||
77 | + | ||
78 | +# Scrapy stuff: | ||
79 | +.scrapy | ||
80 | + | ||
81 | +# Sphinx documentation | ||
82 | +docs/_build/ | ||
83 | + | ||
84 | +# PyBuilder | ||
85 | +target/ | ||
86 | + | ||
87 | +# pyenv | ||
88 | +.python-version | ||
89 | + | ||
90 | +# pipenv | ||
91 | +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
92 | +# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
93 | +# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
94 | +# install all needed dependencies. | ||
95 | +#Pipfile.lock | ||
96 | + | ||
97 | +# celery beat schedule file | ||
98 | +celerybeat-schedule | ||
99 | + | ||
100 | +# SageMath parsed files | ||
101 | +*.sage.py | ||
102 | + | ||
103 | +# Spyder project settings | ||
104 | +.spyderproject | ||
105 | +.spyproject | ||
106 | + | ||
107 | +# Rope project settings | ||
108 | +.ropeproject | ||
109 | + | ||
110 | +# Mr Developer | ||
111 | +.mr.developer.cfg | ||
112 | +.project | ||
113 | +.pydevproject | ||
114 | + | ||
115 | +# mkdocs documentation | ||
116 | +/site | ||
117 | + | ||
118 | +# mypy | ||
119 | +.mypy_cache/ | ||
120 | +.dmypy.json | ||
121 | +dmypy.json | ||
122 | + | ||
123 | +# Pyre type checker | ||
124 | +.pyre/ | ||
125 | + | ||
126 | +# End of https://www.gitignore.io/api/django | ||
127 | + | ||
128 | +.DS_Store | ||
129 | +node_modules | ||
130 | +/dist | ||
131 | + | ||
132 | +# local env files | ||
133 | +.env.local | ||
134 | +.env.*.local | ||
135 | + | ||
136 | +# Log files | ||
137 | +npm-debug.log* | ||
138 | +yarn-debug.log* | ||
139 | +yarn-error.log* | ||
140 | + | ||
141 | +# Editor directories and files | ||
142 | +.idea | ||
143 | +.vscode | ||
144 | +*.suo | ||
145 | +*.ntvs* | ||
146 | +*.njsproj | ||
147 | +*.sln | ||
148 | +*.sw? | ||
149 | + |
No preview for this file type
img/profit_hunter.png
0 → 100644

40.3 KB
web/backend/api/__init__.py
0 → 100644
1 | +import os | ||
2 | +from django.conf.global_settings import FILE_UPLOAD_MAX_MEMORY_SIZE | ||
3 | +import datetime | ||
4 | +import uuid | ||
5 | + | ||
6 | +# URL prefix for each media file | ||
7 | +MEDIA_URL = '/media/' # must always end with a trailing / | ||
8 | +# MEDIA_URL = 'http://static.myservice.com/media/' (use when media files are copied to another server) | ||
9 | +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||
10 | + | ||
11 | +MEDIA_ROOT = os.path.join(BASE_DIR, 'media') | ||
12 | + | ||
13 | +# Maximum file upload size: 100 MB (100 * 1024 * 1024 bytes) | ||
14 | +#FILE_UPLOAD_MAX_MEMORY_SIZE = 104857600 | ||
15 | + | ||
16 | +# Generate the path and file name used to store the uploaded file | ||
17 | +# A new folder is created for each day | ||
18 | +def file_upload_path(filename): | ||
19 | + ext = filename.split('.')[-1] | ||
20 | + d = datetime.datetime.now() | ||
21 | + filepath = d.strftime('%Y-%m-%d') | ||
22 | + suffix = d.strftime("%Y%m%d%H%M%S") | ||
23 | + filename = "%s.%s"%(suffix, ext) | ||
24 | + return os.path.join( MEDIA_ROOT , filepath, filename) | ||
25 | + | ||
26 | +# Called from the model's FileField (upload_to) | ||
27 | +def file_upload_path_for_db(instance, filename): | ||
28 | + return file_upload_path(filename) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
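For illustration only (not part of the commit): a minimal usage sketch of file_upload_path above. The MEDIA_ROOT value and the timestamp are assumptions.

```
# Hypothetical call to api.file_upload_path; assumes MEDIA_ROOT resolves to
# '/srv/backend/media' and the call happens on 2020-05-02 at 12:19:07.
from api import file_upload_path

print(file_upload_path('myvideo.mp4'))
# -> /srv/backend/media/2020-05-02/20200502121907.mp4
```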
web/backend/api/admin.py
0 → 100644
web/backend/api/apps.py
0 → 100644
web/backend/api/migrations/0001_initial.py
0 → 100644
1 | +# Generated by Django 3.0.5 on 2020-05-02 12:19 | ||
2 | + | ||
3 | +import api | ||
4 | +from django.db import migrations, models | ||
5 | + | ||
6 | + | ||
7 | +class Migration(migrations.Migration): | ||
8 | + | ||
9 | + initial = True | ||
10 | + | ||
11 | + dependencies = [ | ||
12 | + ] | ||
13 | + | ||
14 | + operations = [ | ||
15 | + migrations.CreateModel( | ||
16 | + name='Video', | ||
17 | + fields=[ | ||
18 | + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), | ||
19 | + ('videourl', models.CharField(blank=True, max_length=1000)), | ||
20 | + ('title', models.CharField(max_length=200)), | ||
21 | + ('tags', models.CharField(max_length=500)), | ||
22 | + ], | ||
23 | + ), | ||
24 | + migrations.CreateModel( | ||
25 | + name='VideoFile', | ||
26 | + fields=[ | ||
27 | + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), | ||
28 | + ('file_save_name', models.FileField(upload_to=api.file_upload_path_for_db)), | ||
29 | + ('file_origin_name', models.CharField(max_length=100)), | ||
30 | + ('file_path', models.CharField(max_length=100)), | ||
31 | + ], | ||
32 | + ), | ||
33 | + ] |
1 | +# Generated by Django 3.0.5 on 2020-05-31 05:39 | ||
2 | + | ||
3 | +from django.db import migrations, models | ||
4 | + | ||
5 | + | ||
6 | +class Migration(migrations.Migration): | ||
7 | + | ||
8 | + dependencies = [ | ||
9 | + ('api', '0001_initial'), | ||
10 | + ] | ||
11 | + | ||
12 | + operations = [ | ||
13 | + migrations.AddField( | ||
14 | + model_name='video', | ||
15 | + name='threshold', | ||
16 | + field=models.CharField(default=20, max_length=20), | ||
17 | + preserve_default=False, | ||
18 | + ), | ||
19 | + ] |
web/backend/api/migrations/__init__.py
0 → 100644
File mode changed
web/backend/api/models.py
0 → 100644
1 | +from django.db import models | ||
2 | +from api import file_upload_path_for_db | ||
3 | +# Create your models here. | ||
4 | + | ||
5 | + | ||
6 | +class Video(models.Model): | ||
7 | + videourl = models.CharField(max_length=1000, blank=True) | ||
8 | + title = models.CharField(max_length=200) | ||
9 | + threshold = models.CharField(max_length=20) | ||
10 | + tags = models.CharField(max_length=500) | ||
11 | + | ||
12 | + | ||
13 | +class VideoFile(models.Model): | ||
14 | + file_save_name = models.FileField(upload_to=file_upload_path_for_db, blank=False, null=False) | ||
15 | + # Original name of the uploaded file | ||
16 | + file_origin_name = models.CharField(max_length=100) | ||
17 | + # Path where the file is stored | ||
18 | + file_path = models.CharField(max_length=100) | ||
19 | + | ||
20 | + def __str__(self): | ||
21 | + return self.file_save_name.name |
web/backend/api/serializers.py
0 → 100644
1 | +from rest_framework import serializers | ||
2 | +from api.models import Video, VideoFile | ||
3 | + | ||
4 | + | ||
5 | +class VideoSerializer(serializers.ModelSerializer): | ||
6 | + | ||
7 | + class Meta: | ||
8 | + model = Video | ||
9 | + fields = '__all__' | ||
10 | + | ||
11 | + | ||
12 | +class VideoFileSerializer(serializers.ModelSerializer): | ||
13 | + | ||
14 | + class Meta: | ||
15 | + model = VideoFile | ||
16 | + fields = '__all__' |
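As a quick illustration (not part of the commit), the VideoSerializer above can be exercised from `python manage.py shell`; the field values below are made up.

```
# Hypothetical usage of api.serializers.VideoSerializer with dummy values,
# run inside `python manage.py shell` so Django settings are loaded.
from api.models import Video
from api.serializers import VideoSerializer

video = Video(videourl='', title='demo', threshold='10', tags='sports,soccer')
print(VideoSerializer(video).data)
# -> {'id': None, 'videourl': '', 'title': 'demo', 'threshold': '10', 'tags': 'sports,soccer'}
```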
web/backend/api/tests.py
0 → 100644
web/backend/api/urls.py
0 → 100644
1 | +from django.urls import path, include | ||
2 | +from django.conf.urls import url | ||
3 | +from api.views import VideoFileUploadView, VideoFileList, FileListView | ||
4 | +from . import views | ||
5 | +from rest_framework.routers import DefaultRouter | ||
6 | +from django.views.generic import TemplateView | ||
7 | + | ||
8 | +router = DefaultRouter() | ||
9 | +router.register('db/videofile', views.VideoFileViewSet) | ||
10 | +router.register('db/video', views.VideoViewSet) | ||
11 | + | ||
12 | +urlpatterns = [ | ||
13 | + # FBV | ||
14 | + path('api/upload', VideoFileUploadView.as_view(), name="file-upload"), | ||
15 | + path('api/upload/<int:pk>/', VideoFileList.as_view(), name="file-list"), | ||
16 | + path('api/file', FileListView.as_view(), name="file"), | ||
17 | + url(r'^(?P<path>.*)$', TemplateView.as_view(template_name='index.html')), | ||
18 | + # path('api/upload', views.VideoFile_Upload), | ||
19 | + path('', include(router.urls)), | ||
20 | +] |
web/backend/api/views.py
0 → 100644
1 | +from rest_framework import status | ||
2 | +from rest_framework.views import APIView | ||
3 | +from rest_framework.response import Response | ||
4 | +from rest_framework.parsers import MultiPartParser, FormParser | ||
5 | +from rest_framework.response import Response | ||
6 | +from rest_framework import viewsets | ||
7 | +import os | ||
8 | +from django.http.request import QueryDict | ||
9 | +from django.http import Http404 | ||
10 | +from django.http import JsonResponse | ||
11 | +from django.shortcuts import get_object_or_404, render | ||
12 | +from api.models import Video, VideoFile | ||
13 | +from api.serializers import VideoSerializer, VideoFileSerializer | ||
14 | +from api import file_upload_path | ||
15 | + | ||
16 | +import subprocess | ||
17 | +import shlex | ||
18 | +import json | ||
19 | +# Create your views here. | ||
20 | +import sys | ||
21 | +sys.path.insert(0, "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria") | ||
22 | +import inference_pb | ||
23 | + | ||
24 | +def with_ffprobe(filename): | ||
25 | + | ||
26 | + result = subprocess.check_output( | ||
27 | + f'ffprobe -v quiet -show_streams -select_streams v:0 -of json "{filename}"', | ||
28 | + shell=True).decode() | ||
29 | + fields = json.loads(result)['streams'][0] | ||
30 | + duration = int(float(fields['duration'])) | ||
31 | + return duration | ||
32 | + | ||
33 | +def index(request): | ||
34 | + return render(request, template_name='index.html') | ||
35 | + | ||
36 | +class VideoViewSet(viewsets.ModelViewSet): | ||
37 | + queryset = Video.objects.all() | ||
38 | + serializer_class = VideoSerializer | ||
39 | + | ||
40 | +class VideoFileViewSet(viewsets.ModelViewSet): | ||
41 | + queryset = VideoFile.objects.all() | ||
42 | + serializer_class = VideoFileSerializer | ||
43 | + | ||
44 | + | ||
45 | +class VideoFileUploadView(APIView): | ||
46 | + parser_classes = (MultiPartParser, FormParser) | ||
47 | + | ||
48 | + def get(self, request, format=None): | ||
49 | + videoFiles = VideoFile.objects.all() | ||
50 | + serializer = VideoFileSerializer(videoFiles, many=True) | ||
51 | + return Response(serializer.data) | ||
52 | + | ||
53 | + def post(self, req, *args, **kwargs): | ||
54 | + # Video duration in seconds | ||
55 | + runTime = 0 | ||
56 | + # Pull the submitted data out of the request (QueryDict) | ||
57 | + new_data = req.data.dict() | ||
58 | + | ||
59 | + # Uploaded file object from the request | ||
60 | + file_name = req.data['file'] | ||
61 | + threshold = req.data['threshold'] | ||
62 | + | ||
63 | + # Build the full path where the file will be saved | ||
64 | + new_file_full_name = file_upload_path(file_name.name) | ||
65 | + # Directory path of the newly generated file | ||
66 | + file_path = '-'.join(new_file_full_name.split('-')[0:-1]) | ||
67 | + | ||
68 | + new_data['file_path'] = file_path | ||
69 | + new_data['file_origin_name'] = req.data['file'].name | ||
70 | + new_data['file_save_name'] = req.data['file'] | ||
71 | + | ||
72 | + new_query_dict = QueryDict('', mutable=True) | ||
73 | + new_query_dict.update(new_data) | ||
74 | + file_serializer = VideoFileSerializer(data = new_query_dict) | ||
75 | + | ||
76 | + if file_serializer.is_valid(): | ||
77 | + file_serializer.save() | ||
78 | + # Get and print the video duration | ||
79 | + runTime = with_ffprobe('/'+file_serializer.data['file_save_name']) | ||
80 | + print(runTime) | ||
81 | + print(threshold) | ||
82 | + process = subprocess.Popen(['./runMediaPipe.sh %s %s' %(file_serializer.data['file_save_name'],runTime,)], shell = True) | ||
83 | + process.wait() | ||
84 | + | ||
85 | + | ||
86 | + result = inference_pb.inference_pb('/tmp/mediapipe/features.pb', threshold) | ||
87 | + | ||
88 | + return Response(result, status=status.HTTP_201_CREATED) | ||
89 | + else: | ||
90 | + return Response(file_serializer.errors, status=status.HTTP_400_BAD_REQUEST) | ||
91 | + | ||
92 | + | ||
93 | +class VideoFileList(APIView): | ||
94 | + | ||
95 | + def get_object(self, pk): | ||
96 | + try: | ||
97 | + return VideoFile.objects.get(pk=pk) | ||
98 | + except VideoFile.DoesNotExist: | ||
99 | + raise Http404 | ||
100 | + | ||
101 | + def get(self, request, pk, format=None): | ||
102 | + video = self.get_object(pk) | ||
103 | + serializer = VideoFileSerializer(video) | ||
104 | + return Response(serializer.data) | ||
105 | + | ||
106 | + def put(self, request, pk, format=None): | ||
107 | + video = self.get_object(pk) | ||
108 | + serializer = VideoFileSerializer(video, data=request.data) | ||
109 | + if serializer.is_valid(): | ||
110 | + serializer.save() | ||
111 | + return Response(serializer.data) | ||
112 | + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) | ||
113 | + | ||
114 | + def delete(self, request, pk, format=None): | ||
115 | + video = self.get_object(pk) | ||
116 | + video.delete() | ||
117 | + return Response(status=status.HTTP_204_NO_CONTENT) | ||
118 | + | ||
119 | + | ||
120 | +class FileListView(APIView): | ||
121 | + def get(self, request): | ||
122 | + data = { | ||
123 | + "search": '', | ||
124 | + "limit": 10, | ||
125 | + "skip": 0, | ||
126 | + "order": "time", | ||
127 | + "fileList": [ | ||
128 | + { | ||
129 | + "name": "1.png", | ||
130 | + "created": "2020-04-30", | ||
131 | + "size": 10234, | ||
132 | + "isFolder": False, | ||
133 | + "deletedDate": "", | ||
134 | + }, | ||
135 | + { | ||
136 | + "name": "2.png", | ||
137 | + "created": "2020-04-30", | ||
138 | + "size": 3145, | ||
139 | + "isFolder": False, | ||
140 | + "deletedDate": "", | ||
141 | + }, | ||
142 | + { | ||
143 | + "name": "3.png", | ||
144 | + "created": "2020-05-01", | ||
145 | + "size": 5653, | ||
146 | + "isFolder": False, | ||
147 | + "deletedDate": "", | ||
148 | + }, | ||
149 | + ] | ||
150 | + } | ||
151 | + return Response(data) | ||
152 | + def post(self, request, format=None): | ||
153 | + data = { | ||
154 | + "isSuccess": True, | ||
155 | + "File": { | ||
156 | + "name": "test.jpg", | ||
157 | + "created": "2020-05-02", | ||
158 | + "deletedDate": "", | ||
159 | + "size": 2312, | ||
160 | + "isFolder": False | ||
161 | + } | ||
162 | + } | ||
163 | + return Response(data) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
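For reference (not part of the commit), a sketch of how a client could call the upload endpoint handled by VideoFileUploadView; the host, port, and file name are assumptions.

```
# Hypothetical client request against POST /api/upload (VideoFileUploadView).
# Assumes the Django dev server is on localhost:8000 and sample.mp4 exists locally.
import requests

with open('sample.mp4', 'rb') as f:
    resp = requests.post(
        'http://localhost:8000/api/upload',
        files={'file': f},          # read by MultiPartParser as req.data['file']
        data={'threshold': '10'},   # passed to inference_pb (number of recommended videos)
    )
print(resp.status_code, resp.json())
```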
web/backend/backend/__init__.py
0 → 100644
File mode changed
web/backend/backend/asgi.py
0 → 100644
1 | +""" | ||
2 | +ASGI config for backend project. | ||
3 | + | ||
4 | +It exposes the ASGI callable as a module-level variable named ``application``. | ||
5 | + | ||
6 | +For more information on this file, see | ||
7 | +https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/ | ||
8 | +""" | ||
9 | + | ||
10 | +import os | ||
11 | + | ||
12 | +from django.core.asgi import get_asgi_application | ||
13 | + | ||
14 | +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings') | ||
15 | + | ||
16 | +application = get_asgi_application() |
web/backend/backend/settings.py
0 → 100644
1 | +""" | ||
2 | +Django settings for backend project. | ||
3 | + | ||
4 | +Generated by 'django-admin startproject' using Django 3.0.5. | ||
5 | + | ||
6 | +For more information on this file, see | ||
7 | +https://docs.djangoproject.com/en/3.0/topics/settings/ | ||
8 | + | ||
9 | +For the full list of settings and their values, see | ||
10 | +https://docs.djangoproject.com/en/3.0/ref/settings/ | ||
11 | +""" | ||
12 | + | ||
13 | +import os | ||
14 | + | ||
15 | +# Build paths inside the project like this: os.path.join(BASE_DIR, ...) | ||
16 | +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||
17 | + | ||
18 | + | ||
19 | +# Quick-start development settings - unsuitable for production | ||
20 | +# See https://docs.djangoproject.com/en/3.0/howto/deployment/checklist/ | ||
21 | + | ||
22 | +# SECURITY WARNING: keep the secret key used in production secret! | ||
23 | +SECRET_KEY = '7e^4!u6019jww&=-!mu%r$hz6jy#=i+i9@9m_44+ga^#%7#e0l' | ||
24 | + | ||
25 | +# SECURITY WARNING: don't run with debug turned on in production! | ||
26 | +DEBUG = True | ||
27 | + | ||
28 | +ALLOWED_HOSTS = ['*'] | ||
29 | + | ||
30 | + | ||
31 | +# Static File DIR Settings | ||
32 | +STATIC_URL = '/static/' | ||
33 | +STATIC_ROOT = os.path.join(BASE_DIR, 'static') | ||
34 | + | ||
35 | + | ||
36 | +FRONTEND_DIR = os.path.join(os.path.abspath('../'), 'frontend') | ||
37 | +STATICFILES_DIRS = [ | ||
38 | + os.path.join(FRONTEND_DIR, 'dist/'), | ||
39 | +] | ||
40 | +# Application definition | ||
41 | + | ||
42 | +REST_FRAMEWORK = { | ||
43 | + 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', | ||
44 | + 'PAGE_SIZE': 10 | ||
45 | +} | ||
46 | + | ||
47 | +INSTALLED_APPS = [ | ||
48 | + 'django.contrib.admin', | ||
49 | + 'django.contrib.auth', | ||
50 | + 'django.contrib.contenttypes', | ||
51 | + 'django.contrib.sessions', | ||
52 | + 'django.contrib.messages', | ||
53 | + 'django.contrib.staticfiles', | ||
54 | + 'rest_framework', | ||
55 | + 'corsheaders', | ||
56 | + 'api' | ||
57 | +] | ||
58 | + | ||
59 | +MIDDLEWARE = [ | ||
60 | + 'django.middleware.security.SecurityMiddleware', | ||
61 | + 'django.contrib.sessions.middleware.SessionMiddleware', | ||
62 | + 'corsheaders.middleware.CorsMiddleware', # must come before CommonMiddleware | ||
63 | + 'django.middleware.common.CommonMiddleware', | ||
64 | + # 'django.middleware.csrf.CsrfViewMiddleware', | ||
65 | + 'django.contrib.auth.middleware.AuthenticationMiddleware', | ||
66 | + 'django.contrib.messages.middleware.MessageMiddleware', | ||
67 | + 'django.middleware.clickjacking.XFrameOptionsMiddleware', | ||
68 | +] | ||
69 | + | ||
70 | +ROOT_URLCONF = 'backend.urls' | ||
71 | + | ||
72 | +TEMPLATES = [ | ||
73 | + { | ||
74 | + 'BACKEND': 'django.template.backends.django.DjangoTemplates', | ||
75 | + 'DIRS': [ | ||
76 | + os.path.join(BASE_DIR, 'templates'), | ||
77 | + ], | ||
78 | + 'APP_DIRS': True, | ||
79 | + 'OPTIONS': { | ||
80 | + 'context_processors': [ | ||
81 | + 'django.template.context_processors.debug', | ||
82 | + 'django.template.context_processors.request', | ||
83 | + 'django.contrib.auth.context_processors.auth', | ||
84 | + 'django.contrib.messages.context_processors.messages', | ||
85 | + ], | ||
86 | + }, | ||
87 | + }, | ||
88 | +] | ||
89 | + | ||
90 | +WSGI_APPLICATION = 'backend.wsgi.application' | ||
91 | + | ||
92 | + | ||
93 | +# Database | ||
94 | +# https://docs.djangoproject.com/en/3.0/ref/settings/#databases | ||
95 | + | ||
96 | +DATABASES = { | ||
97 | + 'default': { | ||
98 | + 'ENGINE': 'django.db.backends.sqlite3', | ||
99 | + 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), | ||
100 | + } | ||
101 | +} | ||
102 | + | ||
103 | + | ||
104 | +# Password validation | ||
105 | +# https://docs.djangoproject.com/en/3.0/ref/settings/#auth-password-validators | ||
106 | + | ||
107 | +AUTH_PASSWORD_VALIDATORS = [ | ||
108 | + { | ||
109 | + 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', | ||
110 | + }, | ||
111 | + { | ||
112 | + 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', | ||
113 | + }, | ||
114 | + { | ||
115 | + 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', | ||
116 | + }, | ||
117 | + { | ||
118 | + 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', | ||
119 | + }, | ||
120 | +] | ||
121 | + | ||
122 | + | ||
123 | +# Internationalization | ||
124 | +# https://docs.djangoproject.com/en/3.0/topics/i18n/ | ||
125 | + | ||
126 | +LANGUAGE_CODE = 'en-us' | ||
127 | + | ||
128 | +TIME_ZONE = 'UTC' | ||
129 | + | ||
130 | +USE_I18N = True | ||
131 | + | ||
132 | +USE_L10N = True | ||
133 | + | ||
134 | +USE_TZ = True | ||
135 | + | ||
136 | + | ||
137 | +# CORS settings | ||
138 | + | ||
139 | +CORS_ORIGIN_ALLOW_ALL = True | ||
140 | +CORS_ALLOW_CREDENTIALS = True | ||
141 | +CORS_ORIGIN_WHITELIST = [ | ||
142 | + "http://127.0.0.1:12233", | ||
143 | + "http://localhost:8080", | ||
144 | + "http://127.0.0.1:9000" | ||
145 | +] | ||
146 | +CORS_URLS_REGEX = r'^/api/.*$' | ||
147 | +CORS_ALLOW_METHODS = ( | ||
148 | + 'DELETE', | ||
149 | + 'GET', | ||
150 | + 'OPTIONS', | ||
151 | + 'PATCH', | ||
152 | + 'POST', | ||
153 | + 'PUT', | ||
154 | +) | ||
155 | + | ||
156 | +CORS_ALLOW_HEADERS = ( | ||
157 | + 'accept', | ||
158 | + 'accept-encoding', | ||
159 | + 'authorization', | ||
160 | + 'content-type', | ||
161 | + 'dnt', | ||
162 | + 'origin', | ||
163 | + 'user-agent', | ||
164 | + 'x-csrftoken', | ||
165 | + 'x-requested-with', | ||
166 | +) |
web/backend/backend/urls.py
0 → 100644
1 | +"""backend URL Configuration | ||
2 | + | ||
3 | +The `urlpatterns` list routes URLs to views. For more information please see: | ||
4 | + https://docs.djangoproject.com/en/3.0/topics/http/urls/ | ||
5 | +Examples: | ||
6 | +Function views | ||
7 | + 1. Add an import: from my_app import views | ||
8 | + 2. Add a URL to urlpatterns: path('', views.home, name='home') | ||
9 | +Class-based views | ||
10 | + 1. Add an import: from other_app.views import Home | ||
11 | + 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') | ||
12 | +Including another URLconf | ||
13 | + 1. Import the include() function: from django.urls import include, path | ||
14 | + 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) | ||
15 | +""" | ||
16 | +from django.contrib import admin | ||
17 | +from django.urls import path | ||
18 | +from django.conf.urls import url, include | ||
19 | +from django.conf import settings | ||
20 | +from django.conf.urls.static import static | ||
21 | + | ||
22 | +urlpatterns = [ | ||
23 | + path('', include('api.urls')), | ||
24 | + path('admin/', admin.site.urls), | ||
25 | +] | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
web/backend/backend/wsgi.py
0 → 100644
1 | +""" | ||
2 | +WSGI config for backend project. | ||
3 | + | ||
4 | +It exposes the WSGI callable as a module-level variable named ``application``. | ||
5 | + | ||
6 | +For more information on this file, see | ||
7 | +https://docs.djangoproject.com/en/3.0/howto/deployment/wsgi/ | ||
8 | +""" | ||
9 | + | ||
10 | +import os | ||
11 | + | ||
12 | +from django.core.wsgi import get_wsgi_application | ||
13 | + | ||
14 | +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings') | ||
15 | + | ||
16 | +application = get_wsgi_application() |
web/backend/convertPb2Tfrecord.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import glob, os | ||
3 | +import numpy | ||
4 | + | ||
5 | +def _make_bytes(int_array): | ||
6 | + if bytes == str: # Python2 | ||
7 | + return ''.join(map(chr, int_array)) | ||
8 | + else: | ||
9 | + return bytes(int_array) | ||
10 | + | ||
11 | + | ||
12 | +def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0): | ||
13 | + """Quantizes float32 `features` into string.""" | ||
14 | + assert features.dtype == 'float32' | ||
15 | + assert len(features.shape) == 1 # 1-D array | ||
16 | + features = numpy.clip(features, min_quantized_value, max_quantized_value) | ||
17 | + quantize_range = max_quantized_value - min_quantized_value | ||
18 | + features = (features - min_quantized_value) * (255.0 / quantize_range) | ||
19 | + features = [int(round(f)) for f in features] | ||
20 | + | ||
21 | + return _make_bytes(features) | ||
22 | + | ||
23 | + | ||
24 | +# Context and sequence feature specs for parsing features.pb | ||
25 | + | ||
26 | +contexts = { | ||
27 | + 'AUDIO/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
28 | + 'AUDIO/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
29 | + 'RGB/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
30 | + 'RGB/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
31 | + 'clip/data_path': tf.io.FixedLenFeature([], tf.string), | ||
32 | + 'clip/end/timestamp': tf.io.FixedLenFeature([], tf.int64), | ||
33 | + 'clip/start/timestamp': tf.io.FixedLenFeature([], tf.int64) | ||
34 | +} | ||
35 | + | ||
36 | +features = { | ||
37 | + 'AUDIO/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
38 | + 'AUDIO/feature/timestamp': tf.io.VarLenFeature(tf.int64), | ||
39 | + 'RGB/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
40 | + 'RGB/feature/timestamp': tf.io.VarLenFeature(tf.int64) | ||
41 | + | ||
42 | +} | ||
43 | + | ||
44 | + | ||
45 | +def parse_exmp(serial_exmp): | ||
46 | + _, sequence_parsed = tf.io.parse_single_sequence_example( | ||
47 | + serialized=serial_exmp, | ||
48 | + context_features=contexts, | ||
49 | + sequence_features=features) | ||
50 | + | ||
51 | + sequence_parsed = tf.contrib.learn.run_n(sequence_parsed)[0] | ||
52 | + | ||
53 | + audio = sequence_parsed['AUDIO/feature/floats'].values | ||
54 | + rgb = sequence_parsed['RGB/feature/floats'].values | ||
55 | + | ||
56 | + # print(audio.values) | ||
57 | + # print(type(audio.values)) | ||
58 | + | ||
59 | + # audio has 128 values and rgb has 1024 values per second (each later quantized to 8-bit) | ||
60 | + audio_slices = [audio[128 * i: 128 * (i + 1)] for i in range(len(audio) // 128)] | ||
61 | + rgb_slices = [rgb[1024 * i: 1024 * (i + 1)] for i in range(len(rgb) // 1024)] | ||
62 | + | ||
63 | + byte_audio = [] | ||
64 | + byte_rgb = [] | ||
65 | + | ||
66 | + for seg in audio_slices: | ||
67 | + audio_seg = quantize(seg) | ||
68 | + byte_audio.append(audio_seg) | ||
69 | + | ||
70 | + for seg in rgb_slices: | ||
71 | + rgb_seg = quantize(seg) | ||
72 | + byte_rgb.append(rgb_seg) | ||
73 | + | ||
74 | + return byte_audio, byte_rgb | ||
75 | + | ||
76 | + | ||
77 | +def make_exmp(id, labels, audio, rgb): | ||
78 | + audio_features = [] | ||
79 | + rgb_features = [] | ||
80 | + | ||
81 | + for embedding in audio: | ||
82 | + embedding_feature = tf.train.Feature( | ||
83 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
84 | + audio_features.append(embedding_feature) | ||
85 | + | ||
86 | + for embedding in rgb: | ||
87 | + embedding_feature = tf.train.Feature( | ||
88 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
89 | + rgb_features.append(embedding_feature) | ||
90 | + | ||
91 | + # Construct a SequenceExample in the YT8M input format | ||
92 | + seq_exmp = tf.train.SequenceExample( | ||
93 | + context=tf.train.Features( | ||
94 | + feature={ | ||
95 | + 'id': tf.train.Feature(bytes_list=tf.train.BytesList( | ||
96 | + value=[id.encode('utf-8')])), | ||
97 | + 'labels': tf.train.Feature(int64_list=tf.train.Int64List( | ||
98 | + value=[labels])) | ||
99 | + }), | ||
100 | + feature_lists=tf.train.FeatureLists( | ||
101 | + feature_list={ | ||
102 | + 'audio': tf.train.FeatureList( | ||
103 | + feature=audio_features | ||
104 | + ), | ||
105 | + 'rgb': tf.train.FeatureList( | ||
106 | + feature=rgb_features | ||
107 | + ) | ||
108 | + }) | ||
109 | + ) | ||
110 | + serialized = seq_exmp.SerializeToString() | ||
111 | + return serialized | ||
112 | + | ||
113 | + | ||
114 | +if __name__ == '__main__': | ||
115 | + filename = '/tmp/mediapipe/features.pb' | ||
116 | + | ||
117 | + sequence_example = open(filename, 'rb').read() | ||
118 | + | ||
119 | + audio, rgb = parse_exmp(sequence_example) | ||
120 | + | ||
121 | + id = 'test_001' | ||
122 | + | ||
123 | + labels = 1 | ||
124 | + | ||
125 | + tmp_example = make_exmp(id, labels, audio, rgb) | ||
126 | + | ||
127 | + decoded = tf.train.SequenceExample.FromString(tmp_example) | ||
128 | + print(decoded) | ||
129 | + | ||
130 | + # then you can write tmp_example to tfrecord files | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
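The closing comment in convertPb2Tfrecord.py leaves the TFRecord write step to the reader. A minimal sketch of that step (assuming the TensorFlow 2-style tf.io API and a writable /tmp path; the output path is an assumption) could look like this:

```
# Hypothetical follow-up to convertPb2Tfrecord.py: persist the serialized
# SequenceExample (tmp_example) as a single-record TFRecord file.
import tensorflow as tf

def write_tfrecord(serialized_example, output_path='/tmp/mediapipe/features.tfrecord'):
    # TFRecordWriter writes one record per write() call.
    with tf.io.TFRecordWriter(output_path) as writer:
        writer.write(serialized_example)

# e.g. write_tfrecord(tmp_example)
```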
web/backend/manage.py
0 → 100644
1 | +#!/usr/bin/env python | ||
2 | +"""Django's command-line utility for administrative tasks.""" | ||
3 | +import os | ||
4 | +import sys | ||
5 | + | ||
6 | + | ||
7 | +def main(): | ||
8 | + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings') | ||
9 | + try: | ||
10 | + from django.core.management import execute_from_command_line | ||
11 | + except ImportError as exc: | ||
12 | + raise ImportError( | ||
13 | + "Couldn't import Django. Are you sure it's installed and " | ||
14 | + "available on your PYTHONPATH environment variable? Did you " | ||
15 | + "forget to activate a virtual environment?" | ||
16 | + ) from exc | ||
17 | + execute_from_command_line(sys.argv) | ||
18 | + | ||
19 | + | ||
20 | +if __name__ == '__main__': | ||
21 | + main() |
web/backend/requirements.txt
0 → 100644
web/backend/run
0 → 100644
web/backend/runMediaPipe.sh
0 → 100644
1 | +#!/bin/bash | ||
2 | +cd ../../../mediapipe | ||
3 | +. venv/bin/activate | ||
4 | + | ||
5 | +/usr/local/bazel/2.0.0/lib/bazel/bin/bazel version && \ | ||
6 | +alias bazel='/usr/local/bazel/2.0.0/lib/bazel/bin/bazel' | ||
7 | + | ||
8 | +python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \ | ||
9 | + --path_to_input_video=/$1 \ | ||
10 | + --clip_end_time_sec=$2 | ||
11 | + | ||
12 | +GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features \ | ||
13 | + --calculator_graph_config_file=mediapipe/graphs/youtube8m/feature_extraction.pbtxt \ | ||
14 | + --input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.pb \ | ||
15 | + --output_side_packets=output_sequence_example=/tmp/mediapipe/features.pb | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
web/backend/templates/index.html
0 → 100644
1 | +{% load static %} | ||
2 | +<!doctype html> | ||
3 | +<html lang="en"> | ||
4 | + | ||
5 | +<head> | ||
6 | + <meta charset=utf-8> | ||
7 | + <meta http-equiv=X-UA-Compatible content="IE=edge"> | ||
8 | + <meta name=viewport content="width=device-width,initial-scale=1"> | ||
9 | + <link rel=icon href="{% static 'favicon.ico' %}"> | ||
10 | + <title>ThrowBox</title> | ||
11 | + <link href="{% static '/css/app.1dc1d4aa.css' %}" rel=preload as=style> | ||
12 | + <link href="{% static '/css/chunk-vendors.e4bdc0d1.css' %}" rel=preload as=style> | ||
13 | + <link href="{% static '/js/app.1e0612ff.js' %}" rel=preload as=script> | ||
14 | + <link href="{% static '/js/chunk-vendors.951af5fd.js' %}" rel=preload as=script> | ||
15 | + <link href="{% static '/css/chunk-vendors.e4bdc0d1.css' %}" rel=stylesheet> | ||
16 | + <link href="{% static '/css/app.1dc1d4aa.css' %}" rel=stylesheet> | ||
17 | +</head> | ||
18 | + | ||
19 | +<body> | ||
20 | + <noscript><strong>We're sorry but ThrowBox doesn't work properly without JavaScript enabled. Please enable it to | ||
21 | + continue.</strong></noscript> | ||
22 | + <div id=app> | ||
23 | + </div> | ||
24 | + <script src="{% static '/js/chunk-vendors.951af5fd.js' %}"></script> | ||
25 | + <script src="{% static '/js/app.1e0612ff.js' %}"></script> | ||
26 | +</body> | ||
27 | + | ||
28 | +</html> | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
web/backend/vue2djangoTemplate.py
0 → 100644
web/backend/yt8m/__init__.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. |
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Calculate or keep track of the interpolated average precision. | ||
15 | + | ||
16 | +It provides an interface for calculating interpolated average precision for an | ||
17 | +entire list or the top-n ranked items. For the definition of the | ||
18 | +(non-)interpolated average precision: | ||
19 | +http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf | ||
20 | + | ||
21 | +Example usages: | ||
22 | +1) Use it as a static function call to directly calculate average precision for | ||
23 | +a short ranked list in the memory. | ||
24 | + | ||
25 | +``` | ||
26 | +import random | ||
27 | + | ||
28 | +p = np.array([random.random() for _ in xrange(10)]) | ||
29 | +a = np.array([random.choice([0, 1]) for _ in xrange(10)]) | ||
30 | + | ||
31 | +ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a) | ||
32 | +``` | ||
33 | + | ||
34 | +2) Use it as an object for long ranked list that cannot be stored in memory or | ||
35 | +the case where partial predictions can be observed at a time (Tensorflow | ||
36 | +predictions). In this case, we first call the function accumulate many times | ||
37 | +to process parts of the ranked list. After processing all the parts, we call | ||
38 | +peek_interpolated_ap_at_n. | ||
39 | +``` | ||
40 | +p1 = np.array([random.random() for _ in xrange(5)]) | ||
41 | +a1 = np.array([random.choice([0, 1]) for _ in xrange(5)]) | ||
42 | +p2 = np.array([random.random() for _ in xrange(5)]) | ||
43 | +a2 = np.array([random.choice([0, 1]) for _ in xrange(5)]) | ||
44 | + | ||
45 | +# interpolated average precision at 10 using 1000 break points | ||
46 | +calculator = average_precision_calculator.AveragePrecisionCalculator(10) | ||
47 | +calculator.accumulate(p1, a1) | ||
48 | +calculator.accumulate(p2, a2) | ||
49 | +ap3 = calculator.peek_ap_at_n() | ||
50 | +``` | ||
51 | +""" | ||
52 | + | ||
53 | +import heapq | ||
54 | +import random | ||
55 | +import numbers | ||
56 | + | ||
57 | +import numpy | ||
58 | + | ||
59 | + | ||
60 | +class AveragePrecisionCalculator(object): | ||
61 | + """Calculate the average precision and average precision at n.""" | ||
62 | + | ||
63 | + def __init__(self, top_n=None): | ||
64 | + """Construct an AveragePrecisionCalculator to calculate average precision. | ||
65 | + | ||
66 | + This class is used to calculate the average precision for a single label. | ||
67 | + | ||
68 | + Args: | ||
69 | + top_n: A positive Integer specifying the average precision at n, or None | ||
70 | + to use all provided data points. | ||
71 | + | ||
72 | + Raises: | ||
73 | + ValueError: An error occurred when the top_n is not a positive integer. | ||
74 | + """ | ||
75 | + if not ((isinstance(top_n, int) and top_n >= 0) or top_n is None): | ||
76 | + raise ValueError("top_n must be a positive integer or None.") | ||
77 | + | ||
78 | + self._top_n = top_n # average precision at n | ||
79 | + self._total_positives = 0 # total number of positives have seen | ||
80 | + self._heap = [] # max heap of (prediction, actual) | ||
81 | + | ||
82 | + @property | ||
83 | + def heap_size(self): | ||
84 | + """Gets the heap size maintained in the class.""" | ||
85 | + return len(self._heap) | ||
86 | + | ||
87 | + @property | ||
88 | + def num_accumulated_positives(self): | ||
89 | + """Gets the number of positive samples that have been accumulated.""" | ||
90 | + return self._total_positives | ||
91 | + | ||
92 | + def accumulate(self, predictions, actuals, num_positives=None): | ||
93 | + """Accumulate the predictions and their ground truth labels. | ||
94 | + | ||
95 | + After the function call, we may call peek_ap_at_n to actually calculate | ||
96 | + the average precision. | ||
97 | + Note predictions and actuals must have the same shape. | ||
98 | + | ||
99 | + Args: | ||
100 | + predictions: a list storing the prediction scores. | ||
101 | + actuals: a list storing the ground truth labels. Any value larger than 0 | ||
102 | + will be treated as positives, otherwise as negatives. num_positives = If | ||
103 | + the 'predictions' and 'actuals' inputs aren't complete, then it's | ||
104 | + possible some true positives were missed in them. In that case, you can | ||
105 | + provide 'num_positives' in order to accurately track recall. | ||
106 | + | ||
107 | + Raises: | ||
108 | + ValueError: An error occurred when the format of the input is not the | ||
109 | + numpy 1-D array or the shape of predictions and actuals does not match. | ||
110 | + """ | ||
111 | + if len(predictions) != len(actuals): | ||
112 | + raise ValueError("the shape of predictions and actuals does not match.") | ||
113 | + | ||
114 | + if num_positives is not None: | ||
115 | + if not isinstance(num_positives, numbers.Number) or num_positives < 0: | ||
116 | + raise ValueError( | ||
117 | + "'num_positives' was provided but it was a negative number.") | ||
118 | + | ||
119 | + if num_positives is not None: | ||
120 | + self._total_positives += num_positives | ||
121 | + else: | ||
122 | + self._total_positives += numpy.size( | ||
123 | + numpy.where(numpy.array(actuals) > 1e-5)) | ||
124 | + topk = self._top_n | ||
125 | + heap = self._heap | ||
126 | + | ||
127 | + for i in range(numpy.size(predictions)): | ||
128 | + if topk is None or len(heap) < topk: | ||
129 | + heapq.heappush(heap, (predictions[i], actuals[i])) | ||
130 | + else: | ||
131 | + if predictions[i] > heap[0][0]: # heap[0] is the smallest | ||
132 | + heapq.heappop(heap) | ||
133 | + heapq.heappush(heap, (predictions[i], actuals[i])) | ||
134 | + | ||
135 | + def clear(self): | ||
136 | + """Clear the accumulated predictions.""" | ||
137 | + self._heap = [] | ||
138 | + self._total_positives = 0 | ||
139 | + | ||
140 | + def peek_ap_at_n(self): | ||
141 | + """Peek the non-interpolated average precision at n. | ||
142 | + | ||
143 | + Returns: | ||
144 | + The non-interpolated average precision at n (default 0). | ||
145 | + If n is larger than the length of the ranked list, | ||
146 | + the average precision will be returned. | ||
147 | + """ | ||
148 | + if self.heap_size <= 0: | ||
149 | + return 0 | ||
150 | + predlists = numpy.array(list(zip(*self._heap))) | ||
151 | + | ||
152 | + ap = self.ap_at_n(predlists[0], | ||
153 | + predlists[1], | ||
154 | + n=self._top_n, | ||
155 | + total_num_positives=self._total_positives) | ||
156 | + return ap | ||
157 | + | ||
158 | + @staticmethod | ||
159 | + def ap(predictions, actuals): | ||
160 | + """Calculate the non-interpolated average precision. | ||
161 | + | ||
162 | + Args: | ||
163 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
164 | + actuals: a numpy 1-D array storing the ground truth labels. Any value | ||
165 | + larger than 0 will be treated as positives, otherwise as negatives. | ||
166 | + | ||
167 | + Returns: | ||
168 | + The non-interpolated average precision at n. | ||
169 | + If n is larger than the length of the ranked list, | ||
170 | + the average precision will be returned. | ||
171 | + | ||
172 | + Raises: | ||
173 | + ValueError: An error occurred when the format of the input is not the | ||
174 | + numpy 1-D array or the shape of predictions and actuals does not match. | ||
175 | + """ | ||
176 | + return AveragePrecisionCalculator.ap_at_n(predictions, actuals, n=None) | ||
177 | + | ||
178 | + @staticmethod | ||
179 | + def ap_at_n(predictions, actuals, n=20, total_num_positives=None): | ||
180 | + """Calculate the non-interpolated average precision. | ||
181 | + | ||
182 | + Args: | ||
183 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
184 | + actuals: a numpy 1-D array storing the ground truth labels. Any value | ||
185 | + larger than 0 will be treated as positives, otherwise as negatives. | ||
186 | + n: the top n items to be considered in ap@n. | ||
187 | + total_num_positives : (optionally) you can specify the number of total | ||
188 | + positive in the list. If specified, it will be used in calculation. | ||
189 | + | ||
190 | + Returns: | ||
191 | + The non-interpolated average precision at n. | ||
192 | + If n is larger than the length of the ranked list, | ||
193 | + the average precision will be returned. | ||
194 | + | ||
195 | + Raises: | ||
196 | + ValueError: An error occurred when | ||
197 | + 1) the format of the input is not the numpy 1-D array; | ||
198 | + 2) the shape of predictions and actuals does not match; | ||
199 | + 3) the input n is not a positive integer. | ||
200 | + """ | ||
201 | + if len(predictions) != len(actuals): | ||
202 | + raise ValueError("the shape of predictions and actuals does not match.") | ||
203 | + | ||
204 | + if n is not None: | ||
205 | + if not isinstance(n, int) or n <= 0: | ||
206 | + raise ValueError("n must be 'None' or a positive integer." | ||
207 | + " It was '%s'." % n) | ||
208 | + | ||
209 | + ap = 0.0 | ||
210 | + | ||
211 | + predictions = numpy.array(predictions) | ||
212 | + actuals = numpy.array(actuals) | ||
213 | + | ||
214 | + # add a shuffler to avoid overestimating the ap | ||
215 | + predictions, actuals = AveragePrecisionCalculator._shuffle( | ||
216 | + predictions, actuals) | ||
217 | + sortidx = sorted(range(len(predictions)), | ||
218 | + key=lambda k: predictions[k], | ||
219 | + reverse=True) | ||
220 | + | ||
221 | + if total_num_positives is None: | ||
222 | + numpos = numpy.size(numpy.where(actuals > 0)) | ||
223 | + else: | ||
224 | + numpos = total_num_positives | ||
225 | + | ||
226 | + if numpos == 0: | ||
227 | + return 0 | ||
228 | + | ||
229 | + if n is not None: | ||
230 | + numpos = min(numpos, n) | ||
231 | + delta_recall = 1.0 / numpos | ||
232 | + poscount = 0.0 | ||
233 | + | ||
234 | + # calculate the ap | ||
235 | + r = len(sortidx) | ||
236 | + if n is not None: | ||
237 | + r = min(r, n) | ||
238 | + for i in range(r): | ||
239 | + if actuals[sortidx[i]] > 0: | ||
240 | + poscount += 1 | ||
241 | + ap += poscount / (i + 1) * delta_recall | ||
242 | + return ap | ||
243 | + | ||
244 | + @staticmethod | ||
245 | + def _shuffle(predictions, actuals): | ||
246 | + random.seed(0) | ||
247 | + suffidx = random.sample(range(len(predictions)), len(predictions)) | ||
248 | + predictions = predictions[suffidx] | ||
249 | + actuals = actuals[suffidx] | ||
250 | + return predictions, actuals | ||
251 | + | ||
252 | + @staticmethod | ||
253 | + def _zero_one_normalize(predictions, epsilon=1e-7): | ||
254 | + """Normalize the predictions to the range between 0.0 and 1.0. | ||
255 | + | ||
256 | + For some predictions like SVM predictions, we need to normalize them before | ||
257 | + calculate the interpolated average precision. The normalization will not | ||
258 | + change the rank in the original list and thus won't change the average | ||
259 | + precision. | ||
260 | + | ||
261 | + Args: | ||
262 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
263 | + epsilon: a small constant to avoid denominator being zero. | ||
264 | + | ||
265 | + Returns: | ||
266 | + The normalized prediction. | ||
267 | + """ | ||
268 | + denominator = numpy.max(predictions) - numpy.min(predictions) | ||
269 | + ret = (predictions - numpy.min(predictions)) / numpy.max( | ||
270 | + denominator, epsilon) | ||
271 | + return ret |
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Utility to convert the output of batch prediction into a CSV submission. | ||
15 | + | ||
16 | +It converts the JSON files created by the command | ||
17 | +'gcloud beta ml jobs submit prediction' into a CSV file ready for submission. | ||
18 | +""" | ||
19 | + | ||
20 | +import json | ||
21 | +import tensorflow as tf | ||
22 | + | ||
23 | +from builtins import range | ||
24 | +from tensorflow import app | ||
25 | +from tensorflow import flags | ||
26 | +from tensorflow import gfile | ||
27 | +from tensorflow import logging | ||
28 | + | ||
29 | +FLAGS = flags.FLAGS | ||
30 | + | ||
31 | +if __name__ == "__main__": | ||
32 | + | ||
33 | + flags.DEFINE_string( | ||
34 | + "json_prediction_files_pattern", None, | ||
35 | + "Pattern specifying the list of JSON files that the command " | ||
36 | + "'gcloud beta ml jobs submit prediction' outputs. These files are " | ||
37 | + "located in the output path of the prediction command and are prefixed " | ||
38 | + "with 'prediction.results'.") | ||
39 | + flags.DEFINE_string( | ||
40 | + "csv_output_file", None, | ||
41 | + "The file to save the predictions converted to the CSV format.") | ||
42 | + | ||
43 | + | ||
44 | +def get_csv_header(): | ||
45 | + return "VideoId,LabelConfidencePairs\n" | ||
46 | + | ||
47 | + | ||
48 | +def to_csv_row(json_data): | ||
49 | + | ||
50 | + video_id = json_data["video_id"] | ||
51 | + | ||
52 | + class_indexes = json_data["class_indexes"] | ||
53 | + predictions = json_data["predictions"] | ||
54 | + | ||
55 | + if isinstance(video_id, list): | ||
56 | + video_id = video_id[0] | ||
57 | + class_indexes = class_indexes[0] | ||
58 | + predictions = predictions[0] | ||
59 | + | ||
60 | + if len(class_indexes) != len(predictions): | ||
61 | + raise ValueError( | ||
62 | + "The number of indexes (%s) and predictions (%s) must be equal." % | ||
63 | + (len(class_indexes), len(predictions))) | ||
64 | + | ||
65 | + return (video_id.decode("utf-8") + "," + | ||
66 | + " ".join("%i %f" % (class_indexes[i], predictions[i]) | ||
67 | + for i in range(len(class_indexes))) + "\n") | ||
68 | + | ||
69 | + | ||
70 | +def main(unused_argv): | ||
71 | + logging.set_verbosity(tf.logging.INFO) | ||
72 | + | ||
73 | + if not FLAGS.json_prediction_files_pattern: | ||
74 | + raise ValueError( | ||
75 | + "The flag --json_prediction_files_pattern must be specified.") | ||
76 | + | ||
77 | + if not FLAGS.csv_output_file: | ||
78 | + raise ValueError("The flag --csv_output_file must be specified.") | ||
79 | + | ||
80 | + logging.info("Looking for prediction files with pattern: %s", | ||
81 | + FLAGS.json_prediction_files_pattern) | ||
82 | + | ||
83 | + file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern) | ||
84 | + logging.info("Found files: %s", file_paths) | ||
85 | + | ||
86 | + logging.info("Writing submission file to: %s", FLAGS.csv_output_file) | ||
87 | + with gfile.Open(FLAGS.csv_output_file, "w+") as output_file: | ||
88 | + output_file.write(get_csv_header()) | ||
89 | + | ||
90 | + for file_path in file_paths: | ||
91 | + logging.info("processing file: %s", file_path) | ||
92 | + | ||
93 | + with gfile.Open(file_path) as input_file: | ||
94 | + | ||
95 | + for line in input_file: | ||
96 | + json_data = json.loads(line) | ||
97 | + output_file.write(to_csv_row(json_data)) | ||
98 | + | ||
99 | + output_file.flush() | ||
100 | + logging.info("done") | ||
101 | + | ||
102 | + | ||
103 | +if __name__ == "__main__": | ||
104 | + app.run() |
web/backend/yt8m/esot3ria/inference_pb.py
0 → 100644
1 | +import numpy as np | ||
2 | +import tensorflow as tf | ||
3 | +from tensorflow import logging | ||
4 | +from tensorflow import gfile | ||
5 | +import operator | ||
6 | +import pb_util as pbutil | ||
7 | +import video_recommender as recommender | ||
8 | +import video_util as videoutil | ||
9 | + | ||
10 | +# Define file paths. | ||
11 | +MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/model/inference_model/segment_inference_model" | ||
12 | +VOCAB_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/vocabulary.csv" | ||
13 | +VIDEO_TAGS_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/kaggle_solution_40k.csv" | ||
14 | +TAG_VECTOR_MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/tag_vectors.model" | ||
15 | +VIDEO_VECTOR_MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/video_vectors.model" | ||
16 | +SEGMENT_LABEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/segment_label_ids.csv" | ||
17 | + | ||
18 | +# Define parameters. | ||
19 | +TAG_TOP_K = 5 | ||
20 | +VIDEO_TOP_K = 10 | ||
21 | + | ||
22 | + | ||
23 | +def get_segments(batch_video_mtx, batch_num_frames, segment_size): | ||
24 | + """Get segment-level inputs from frame-level features.""" | ||
25 | + video_batch_size = batch_video_mtx.shape[0] | ||
26 | + max_frame = batch_video_mtx.shape[1] | ||
27 | + feature_dim = batch_video_mtx.shape[-1] | ||
28 | + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size | ||
29 | + padded_segment_sizes *= segment_size | ||
30 | + segment_mask = ( | ||
31 | + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame))) | ||
32 | + | ||
33 | + # Segment bags. | ||
34 | + frame_bags = batch_video_mtx.reshape((-1, feature_dim)) | ||
35 | + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape( | ||
36 | + (-1, segment_size, feature_dim)) | ||
37 | + | ||
38 | + # Segment num frames. | ||
39 | + segment_start_times = np.arange(0, max_frame, segment_size) | ||
40 | + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times | ||
41 | + num_segment_bags = num_segments.reshape((-1)) | ||
42 | + valid_segment_mask = num_segment_bags > 0 | ||
43 | + segment_num_frames = num_segment_bags[valid_segment_mask] | ||
44 | + segment_num_frames[segment_num_frames > segment_size] = segment_size | ||
45 | + | ||
46 | + max_segment_num = (max_frame + segment_size - 1) // segment_size | ||
47 | + video_idxs = np.tile( | ||
48 | + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num]) | ||
49 | + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1]) | ||
50 | + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2)) | ||
51 | + video_segment_ids = idx_bags[valid_segment_mask] | ||
52 | + | ||
53 | + return { | ||
54 | + "video_batch": segment_frames, | ||
55 | + "num_frames_batch": segment_num_frames, | ||
56 | + "video_segment_ids": video_segment_ids | ||
57 | + } | ||
58 | + | ||
59 | + | ||
60 | +def format_predictions(video_ids, predictions, top_k, whitelisted_cls_mask=None): | ||
61 | + batch_size = len(video_ids) | ||
62 | + for video_index in range(batch_size): | ||
63 | + video_prediction = predictions[video_index] | ||
64 | + if whitelisted_cls_mask is not None: | ||
65 | + # Whitelist classes. | ||
66 | + video_prediction *= whitelisted_cls_mask | ||
67 | + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:] | ||
68 | + line = [(class_index, predictions[video_index][class_index]) | ||
69 | + for class_index in top_indices] | ||
70 | + line = sorted(line, key=lambda p: -p[1]) | ||
71 | + yield (video_ids[video_index] + "," + | ||
72 | + " ".join("%i %g" % (label, score) for (label, score) in line) + | ||
73 | + "\n").encode("utf8") | ||
74 | + | ||
75 | + | ||
76 | +def normalize_tag(tag): | ||
77 | + if isinstance(tag, str): | ||
78 | + new_tag = tag.lower().replace('[^a-zA-Z]', ' ') | ||
79 | + if new_tag.find(" (") != -1: | ||
80 | + new_tag = new_tag[:new_tag.find(" (")] | ||
81 | + new_tag = new_tag.replace(" ", "-") | ||
82 | + return new_tag | ||
83 | + else: | ||
84 | + return tag | ||
85 | + | ||
86 | + | ||
87 | +def inference_pb(file_path, threshold): | ||
88 | + VIDEO_TOP_K = int(threshold) | ||
89 | + inference_result = {} | ||
90 | + with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: | ||
91 | + | ||
92 | + # 0. Import SequenceExample type target from pb. | ||
93 | + target_video = pbutil.convert_pb(file_path) | ||
94 | + | ||
95 | + # 1. Load video features from pb. | ||
96 | + video_id_batch_val = np.array([b'video']) | ||
97 | + n_frames = len(target_video.feature_lists.feature_list['rgb'].feature) | ||
98 | + # Restrict frame size to 300 | ||
99 | + if n_frames > 300: | ||
100 | + n_frames = 300 | ||
101 | + video_batch_val = np.zeros((300, 1152)) | ||
102 | + for i in range(n_frames): | ||
103 | + video_batch_rgb_raw = target_video.feature_lists.feature_list['rgb'].feature[i].bytes_list.value[0] | ||
104 | + video_batch_rgb = np.array(tf.cast(tf.decode_raw(video_batch_rgb_raw, tf.float32), tf.float32).eval()) | ||
105 | + video_batch_audio_raw = target_video.feature_lists.feature_list['audio'].feature[i].bytes_list.value[0] | ||
106 | + video_batch_audio = np.array(tf.cast(tf.decode_raw(video_batch_audio_raw, tf.float32), tf.float32).eval()) | ||
107 | + video_batch_val[i] = np.concatenate([video_batch_rgb, video_batch_audio], axis=0) | ||
108 | + video_batch_val = np.array([video_batch_val]) | ||
109 | + num_frames_batch_val = np.array([n_frames]) | ||
110 | + | ||
111 | + # Restore checkpoint and meta-graph file. | ||
112 | + if not gfile.Exists(MODEL_PATH + ".meta"): | ||
113 | + raise IOError("Cannot find %s. Did you run eval.py?" % MODEL_PATH) | ||
114 | + meta_graph_location = MODEL_PATH + ".meta" | ||
115 | + logging.info("loading meta-graph: " + meta_graph_location) | ||
116 | + | ||
117 | + with tf.device("/cpu:0"): | ||
118 | + saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True) | ||
119 | + logging.info("restoring variables from " + MODEL_PATH) | ||
120 | + saver.restore(sess, MODEL_PATH) | ||
121 | + input_tensor = tf.get_collection("input_batch_raw")[0] | ||
122 | + num_frames_tensor = tf.get_collection("num_frames")[0] | ||
123 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
124 | + | ||
125 | + # Workaround for num_epochs issue. | ||
126 | + def set_up_init_ops(variables): | ||
127 | + init_op_list = [] | ||
128 | + for variable in list(variables): | ||
129 | + if "train_input" in variable.name: | ||
130 | + init_op_list.append(tf.assign(variable, 1)) | ||
131 | + variables.remove(variable) | ||
132 | + init_op_list.append(tf.variables_initializer(variables)) | ||
133 | + return init_op_list | ||
134 | + | ||
135 | + sess.run( | ||
136 | + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) | ||
137 | + | ||
138 | + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | ||
139 | + dtype=np.float32) | ||
140 | + with tf.io.gfile.GFile(SEGMENT_LABEL_PATH) as fobj: | ||
141 | + for line in fobj: | ||
142 | + try: | ||
143 | + cls_id = int(line) | ||
144 | + whitelisted_cls_mask[cls_id] = 1. | ||
145 | + except ValueError: | ||
146 | + # Simply skip the non-integer line. | ||
147 | + continue | ||
148 | + | ||
149 | + # 2. Make segment features. | ||
150 | + results = get_segments(video_batch_val, num_frames_batch_val, 5) | ||
151 | + video_segment_ids = results["video_segment_ids"] | ||
152 | + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]] | ||
153 | + video_id_batch_val = np.array([ | ||
154 | + "%s:%d" % (x.decode("utf8"), y) | ||
155 | + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1]) | ||
156 | + ]) | ||
157 | + video_batch_val = results["video_batch"] | ||
158 | + num_frames_batch_val = results["num_frames_batch"] | ||
159 | + if input_tensor.get_shape()[1] != video_batch_val.shape[1]: | ||
160 | + raise ValueError("max_frames mismatch. Please re-run the eval.py " | ||
161 | + "with correct segment_labels settings.") | ||
162 | + | ||
163 | + predictions_val, = sess.run([predictions_tensor], | ||
164 | + feed_dict={ | ||
165 | + input_tensor: video_batch_val, | ||
166 | + num_frames_tensor: num_frames_batch_val | ||
167 | + }) | ||
168 | + | ||
169 | + # 3. Make vocabularies. | ||
170 | + voca_dict = {} | ||
171 | + vocabs = open(VOCAB_PATH, 'r') | ||
172 | + while True: | ||
173 | + line = vocabs.readline() | ||
174 | + if not line: break | ||
175 | + vocab_dict_item = line.split(",") | ||
176 | + if vocab_dict_item[0] != "Index": | ||
177 | + voca_dict[vocab_dict_item[0]] = vocab_dict_item[3] | ||
178 | + vocabs.close() | ||
179 | + | ||
180 | + # 4. Make combined scores. | ||
181 | + combined_scores = {} | ||
182 | + for line in format_predictions(video_id_batch_val, predictions_val, TAG_TOP_K, whitelisted_cls_mask): | ||
183 | + segment_id, preds = line.decode("utf8").split(",") | ||
184 | + preds = preds.split(" ") | ||
185 | + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] | ||
186 | + pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)] | ||
187 | + for i in range(len(pred_cls_ids)): | ||
188 | + if pred_cls_ids[i] in combined_scores: | ||
189 | + combined_scores[pred_cls_ids[i]] += pred_cls_scores[i] | ||
190 | + else: | ||
191 | + combined_scores[pred_cls_ids[i]] = pred_cls_scores[i] | ||
192 | + | ||
193 | + combined_scores = sorted(combined_scores.items(), key=operator.itemgetter(1), reverse=True) | ||
194 | + denominator = float(combined_scores[0][1] + combined_scores[1][1] | ||
195 | + + combined_scores[2][1] + combined_scores[3][1] + combined_scores[4][1]) | ||
196 | + | ||
197 | + tag_result = [] | ||
198 | + for itemIndex in range(TAG_TOP_K): | ||
199 | + segment_tag = str(voca_dict[str(combined_scores[itemIndex][0])]) | ||
200 | + normalized_tag = normalize_tag(segment_tag) | ||
201 | + tag_percentage = format(combined_scores[itemIndex][1] / denominator, ".3f") | ||
202 | + tag_result.append((normalized_tag, tag_percentage)) | ||
203 | + | ||
204 | + # 5. Create recommend videos info, Combine results. | ||
205 | + recommend_video_ids = recommender.recommend_videos(tag_result, TAG_VECTOR_MODEL_PATH, | ||
206 | + VIDEO_VECTOR_MODEL_PATH, VIDEO_TOP_K) | ||
207 | + video_result = [videoutil.getVideoInfo(ids, VIDEO_TAGS_PATH, TAG_TOP_K) for ids in recommend_video_ids] | ||
208 | + | ||
209 | + inference_result = { | ||
210 | + "tag_result": tag_result, | ||
211 | + "video_result": video_result | ||
212 | + } | ||
213 | + | ||
214 | + # 6. Dispose instances. | ||
215 | + sess.close() | ||
216 | + | ||
217 | + return inference_result | ||
218 | + | ||
219 | + | ||
220 | +if __name__ == '__main__': | ||
221 | + filepath = "/tmp/mediapipe/features.pb" | ||
222 | + result = inference_pb(filepath, 5)  # e.g. recommend the top 5 videos | ||
223 | + print(result) |
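Note on step 4 above: the per-segment predictions are summed per class and then normalized by the total of the top five combined scores. A minimal standalone sketch of that aggregation, using made-up scores rather than real model output:

    import operator

    # Hypothetical (class_id, score) pairs from two segments of one video.
    segment_preds = [[(3, 0.9), (7, 0.4)], [(3, 0.8), (11, 0.3)]]

    combined = {}
    for preds in segment_preds:
        for cls_id, score in preds:
            combined[cls_id] = combined.get(cls_id, 0.0) + score

    ranked = sorted(combined.items(), key=operator.itemgetter(1), reverse=True)
    denominator = sum(score for _, score in ranked)  # inference.py uses the top-5 sum
    print([(cls_id, format(score / denominator, ".3f")) for cls_id, score in ranked])
    # [(3, '0.708'), (7, '0.167'), (11, '0.125')]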
web/backend/yt8m/esot3ria/model/eval/events.out.tfevents.1591170123.mlvc-nogadahalf12-instance
0 → 100644
No preview for this file type
web/backend/yt8m/esot3ria/model/inference_model/segment_inference_model.data-00000-of-00001
0 → 100644
This file is too large to display.
1 | +{"model": "FrameLevelLogisticModel", "feature_sizes": "1024,128", "feature_names": "rgb,audio", "frame_features": true, "label_loss": "CrossEntropyLoss"} | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
web/backend/yt8m/esot3ria/pb_util.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import numpy | ||
3 | + | ||
4 | + | ||
5 | +def _make_bytes(int_array): | ||
6 | + if bytes == str: # Python2 | ||
7 | + return ''.join(map(chr, int_array)) | ||
8 | + else: | ||
9 | + return bytes(int_array) | ||
10 | + | ||
11 | + | ||
12 | +def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0): | ||
13 | + """Quantizes float32 `features` into string.""" | ||
14 | + assert features.dtype == 'float32' | ||
15 | + assert len(features.shape) == 1 # 1-D array | ||
16 | + features = numpy.clip(features, min_quantized_value, max_quantized_value) | ||
17 | + quantize_range = max_quantized_value - min_quantized_value | ||
18 | + features = (features - min_quantized_value) * (255.0 / quantize_range) | ||
19 | + features = [int(round(f)) for f in features] | ||
20 | + | ||
21 | + return _make_bytes(features) | ||
22 | + | ||
23 | + | ||
24 | +# Feature specs for parsing a MediaPipe features.pb SequenceExample. | ||
25 | + | ||
26 | +contexts = { | ||
27 | + 'AUDIO/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
28 | + 'AUDIO/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
29 | + 'RGB/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
30 | + 'RGB/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
31 | + 'clip/data_path': tf.io.FixedLenFeature([], tf.string), | ||
32 | + 'clip/end/timestamp': tf.io.FixedLenFeature([], tf.int64), | ||
33 | + 'clip/start/timestamp': tf.io.FixedLenFeature([], tf.int64) | ||
34 | +} | ||
35 | + | ||
36 | +features = { | ||
37 | + 'AUDIO/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
38 | + 'AUDIO/feature/timestamp': tf.io.VarLenFeature(tf.int64), | ||
39 | + 'RGB/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
40 | + 'RGB/feature/timestamp': tf.io.VarLenFeature(tf.int64) | ||
41 | + | ||
42 | +} | ||
43 | + | ||
44 | + | ||
45 | +def parse_exmp(serial_exmp): | ||
46 | + _, sequence_parsed = tf.io.parse_single_sequence_example( | ||
47 | + serialized=serial_exmp, | ||
48 | + context_features=contexts, | ||
49 | + sequence_features=features) | ||
50 | + | ||
51 | + sequence_parsed = tf.contrib.learn.run_n(sequence_parsed)[0] | ||
52 | + | ||
53 | + audio = sequence_parsed['AUDIO/feature/floats'].values | ||
54 | + rgb = sequence_parsed['RGB/feature/floats'].values | ||
55 | + | ||
56 | + # print(audio.values) | ||
57 | + # print(type(audio.values)) | ||
58 | + | ||
59 | + # Per second, audio is 128 8-bit values and rgb is 1024 8-bit values. | ||
60 | + audio_slices = [audio[128 * i: 128 * (i + 1)] for i in range(len(audio) // 128)] | ||
61 | + rgb_slices = [rgb[1024 * i: 1024 * (i + 1)] for i in range(len(rgb) // 1024)] | ||
62 | + | ||
63 | + byte_audio = [] | ||
64 | + byte_rgb = [] | ||
65 | + | ||
66 | + for seg in audio_slices: | ||
67 | + # audio_seg = quantize(seg) | ||
68 | + audio_seg = _make_bytes(seg) | ||
69 | + byte_audio.append(audio_seg) | ||
70 | + | ||
71 | + for seg in rgb_slices: | ||
72 | + # rgb_seg = quantize(seg) | ||
73 | + rgb_seg = _make_bytes(seg) | ||
74 | + byte_rgb.append(rgb_seg) | ||
75 | + | ||
76 | + return byte_audio, byte_rgb | ||
77 | + | ||
78 | + | ||
79 | +def make_exmp(id, audio, rgb): | ||
80 | + audio_features = [] | ||
81 | + rgb_features = [] | ||
82 | + | ||
83 | + for embedding in audio: | ||
84 | + embedding_feature = tf.train.Feature( | ||
85 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
86 | + audio_features.append(embedding_feature) | ||
87 | + | ||
88 | + for embedding in rgb: | ||
89 | + embedding_feature = tf.train.Feature( | ||
90 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
91 | + rgb_features.append(embedding_feature) | ||
92 | + | ||
93 | + # Construct a YT8M-style SequenceExample. | ||
94 | + seq_exmp = tf.train.SequenceExample( | ||
95 | + context=tf.train.Features( | ||
96 | + feature={ | ||
97 | + 'id': tf.train.Feature(bytes_list=tf.train.BytesList( | ||
98 | + value=[id.encode('utf-8')])) | ||
99 | + }), | ||
100 | + feature_lists=tf.train.FeatureLists( | ||
101 | + feature_list={ | ||
102 | + 'audio': tf.train.FeatureList( | ||
103 | + feature=audio_features | ||
104 | + ), | ||
105 | + 'rgb': tf.train.FeatureList( | ||
106 | + feature=rgb_features | ||
107 | + ) | ||
108 | + }) | ||
109 | + ) | ||
110 | + serialized = seq_exmp.SerializeToString() | ||
111 | + return serialized | ||
112 | + | ||
113 | + | ||
114 | +def convert_pb(filename): | ||
115 | + sequence_example = open(filename, 'rb').read() | ||
116 | + | ||
117 | + audio, rgb = parse_exmp(sequence_example) | ||
118 | + tmp_example = make_exmp('video', audio, rgb) | ||
119 | + | ||
120 | + decoded = tf.train.SequenceExample.FromString(tmp_example) | ||
121 | + return decoded |
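A minimal usage sketch for this module (imported as pbutil in inference.py above); it assumes the same TF 1.x environment and a MediaPipe-generated features.pb at the path inference.py uses:

    import pb_util as pbutil

    # Convert the MediaPipe SequenceExample into the YT8M-style layout.
    decoded = pbutil.convert_pb("/tmp/mediapipe/features.pb")

    # The result is a tf.train.SequenceExample with 'rgb'/'audio' feature lists.
    n_frames = len(decoded.feature_lists.feature_list['rgb'].feature)
    print("id:", decoded.context.feature['id'].bytes_list.value[0])
    print("frames:", n_frames)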
web/backend/yt8m/esot3ria/readpb.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import numpy as np | ||
3 | + | ||
4 | +frame_lvl_record = "test0000.tfrecord" | ||
5 | + | ||
6 | +feat_rgb = [] | ||
7 | +feat_audio = [] | ||
8 | + | ||
9 | +for example in tf.python_io.tf_record_iterator(frame_lvl_record): | ||
10 | + tf_seq_example = tf.train.SequenceExample.FromString(example) | ||
11 | + test = tf_seq_example.SerializeToString() | ||
12 | + n_frames = len(tf_seq_example.feature_lists.feature_list['audio'].feature) | ||
13 | + sess = tf.InteractiveSession() | ||
14 | + rgb_frame = [] | ||
15 | + audio_frame = [] | ||
16 | + # iterate through frames | ||
17 | + for i in range(n_frames): | ||
18 | + rgb_frame.append(tf.cast(tf.decode_raw( | ||
19 | + tf_seq_example.feature_lists.feature_list['rgb'] | ||
20 | + .feature[i].bytes_list.value[0], tf.uint8) | ||
21 | + , tf.float32).eval()) | ||
22 | + audio_frame.append(tf.cast(tf.decode_raw( | ||
23 | + tf_seq_example.feature_lists.feature_list['audio'] | ||
24 | + .feature[i].bytes_list.value[0], tf.uint8) | ||
25 | + , tf.float32).eval()) | ||
26 | + | ||
27 | + sess.close() | ||
28 | + | ||
29 | + feat_audio.append(audio_frame) | ||
30 | + feat_rgb.append(rgb_frame) | ||
31 | + break | ||
32 | + | ||
33 | +print('The first video has %d frames' %len(feat_rgb[0])) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
1 | +import nltk | ||
2 | +import gensim | ||
3 | +import pandas as pd | ||
4 | + | ||
5 | +# Load files. | ||
6 | +nltk.download('stopwords') | ||
7 | +vocab = pd.read_csv('../vocabulary.csv') | ||
8 | + | ||
9 | +# Lowercase the corpus and strip parenthesized qualifiers from names. | ||
10 | +vocab['WikiDescription'] = vocab['WikiDescription'].str.lower().str.replace('[^a-zA-Z0-9]', ' ') | ||
11 | +for i in range(len(vocab['Name'])): | ||
12 | + name = vocab['Name'][i] | ||
13 | + if isinstance(name, str) and name.find(" (") != -1: | ||
14 | + vocab['Name'][i] = name[:name.find(" (")] | ||
15 | +vocab['Name'] = vocab['Name'].str.lower() | ||
16 | + | ||
17 | +# Combine multi-word names (e.g. mobile phone -> mobile-phone). | ||
18 | +for name in vocab['Name']: | ||
19 | + if isinstance(name, str) and name.find(" ") != -1: | ||
20 | + combined_name = name.replace(" ", "-") | ||
21 | + for i in range(len(vocab['WikiDescription'])): | ||
22 | + if isinstance(vocab['WikiDescription'][i], str): | ||
23 | + vocab['WikiDescription'][i] = vocab['WikiDescription'][i].replace(name, combined_name) | ||
24 | + | ||
25 | + | ||
26 | +# Remove stopwords from corpus. | ||
27 | +stop_re = '\\b'+'\\b|\\b'.join(nltk.corpus.stopwords.words('english'))+'\\b' | ||
28 | +vocab['WikiDescription'] = vocab['WikiDescription'].str.replace(stop_re, '') | ||
29 | +vocab['WikiDescription'] = vocab['WikiDescription'].str.split() | ||
30 | + | ||
31 | +# Tokenize corpus. | ||
32 | +tokenlist = [x for x in vocab['WikiDescription'] if str(x) != 'nan'] | ||
33 | +phrases = gensim.models.phrases.Phrases(tokenlist) | ||
34 | +phraser = gensim.models.phrases.Phraser(phrases) | ||
35 | +vocab_phrased = phraser[tokenlist] | ||
36 | + | ||
37 | +# Vectorize tags. | ||
38 | +w2v = gensim.models.word2vec.Word2Vec(sentences=tokenlist, min_count=1) | ||
39 | +w2v.save('tag_vectors.model') | ||
40 | + | ||
41 | +# word_vectors = w2v.wv | ||
42 | +# vocabs = word_vectors.vocab.keys() | ||
43 | +# word_vectors_list = [word_vectors[v] for v in vocabs] |
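A quick sanity check of the saved embeddings (a sketch; 'game' is only an illustrative query and must exist in the vocabulary actually trained on):

    from gensim.models import Word2Vec

    tag_vectors = Word2Vec.load('tag_vectors.model').wv
    # Nearest tags in the embedding space.
    print(tag_vectors.most_similar('game', topn=5))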
web/backend/yt8m/esot3ria/tag_vectors.model
0 → 100644
This file is too large to display.
1 | +from gensim.models import Word2Vec | ||
2 | +import numpy as np | ||
3 | + | ||
4 | +def recommend_videos(tags, tag_model_path, video_model_path, top_k): | ||
5 | + tag_vectors = Word2Vec.load(tag_model_path).wv | ||
6 | + video_vectors = Word2Vec().wv.load(video_model_path) | ||
7 | + error_tags = [] | ||
8 | + | ||
9 | + video_vector = np.zeros(100) | ||
10 | + for (tag, weight) in tags: | ||
11 | + if tag in tag_vectors.vocab: | ||
12 | + video_vector = video_vector + (tag_vectors[tag] * float(weight)) | ||
13 | + else: | ||
14 | + # Pass if tag is unknown | ||
15 | + if tag not in error_tags: | ||
16 | + error_tags.append(tag) | ||
17 | + | ||
18 | + similar_ids = [x[0] for x in video_vectors.similar_by_vector(video_vector, top_k)] | ||
19 | + return similar_ids |
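Usage sketch, with (tag, weight) pairs in the same string-weight form that inference.py builds; the model paths and tags here are placeholders:

    import recommender

    tags = [("mobile-phone", "0.361"), ("smartphone", "0.204"), ("telephone", "0.117")]
    video_ids = recommender.recommend_videos(
        tags,
        tag_model_path="tag_vectors.model",
        video_model_path="video_vectors.model",
        top_k=10)
    print(video_ids)  # YT8M pseudo-ids of the most similar videos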
web/backend/yt8m/esot3ria/video_util.py
0 → 100644
1 | +import requests | ||
2 | +import pandas as pd | ||
3 | + | ||
4 | +base_URL = 'https://data.yt8m.org/2/j/i/' | ||
5 | +youtube_url = 'https://www.youtube.com/watch?v=' | ||
6 | + | ||
7 | + | ||
8 | +def getURL(vid_id): | ||
9 | + URL = base_URL + vid_id[:-2] + '/' + vid_id + '.js' | ||
10 | + response = requests.get(URL, verify = False) | ||
11 | + if response.status_code == 200: | ||
12 | + return youtube_url + response.text[10:-3] | ||
13 | + | ||
14 | + | ||
15 | +def getVideoInfo(vid_id, video_tags_path, top_k): | ||
16 | + video_url = getURL(vid_id) | ||
17 | + | ||
18 | + entire_video_tags = pd.read_csv(video_tags_path) | ||
19 | + video_tags_info = entire_video_tags.loc[entire_video_tags["vid_id"] == vid_id] | ||
20 | + video_tags = [] | ||
21 | + for i in range(1, top_k + 1): | ||
22 | + video_tag_tuple = video_tags_info["segment" + str(i)].values[0] # ex: "mobile-phone:0.361" | ||
23 | + video_tags.append(video_tag_tuple.split(":")[0]) | ||
24 | + | ||
25 | + return { | ||
26 | + "video_url": video_url, | ||
27 | + "video_tags": video_tags | ||
28 | + } |
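Example call (a sketch, assuming the tags CSV has the vid_id and segmentN columns this function expects; 'XwFj' is an illustrative YT8M pseudo-id and the lookup issues a request to data.yt8m.org):

    import video_util as videoutil

    info = videoutil.getVideoInfo("XwFj", "kaggle_solution_40k.csv", 5)
    print(info["video_url"])   # resolved YouTube watch URL
    print(info["video_tags"])  # top tags stored for that video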
1 | +import pandas as pd | ||
2 | +import numpy as np | ||
3 | +from gensim.models import Word2Vec | ||
4 | + | ||
5 | +BATCH_SIZE = 1000 | ||
6 | + | ||
7 | + | ||
8 | +def vectorization_video(): | ||
9 | + print('[0.1 0.2]') | ||
10 | + | ||
11 | + | ||
12 | +if __name__ == '__main__': | ||
13 | + tag_vectors = Word2Vec.load("tag_vectors.model").wv | ||
14 | + video_vectors = Word2Vec().wv # Empty model | ||
15 | + | ||
16 | + # Load video recommendation tags. | ||
17 | + video_tags = pd.read_csv('kaggle_solution_40k.csv') | ||
18 | + | ||
19 | + # Define batch variables. | ||
20 | + batch_video_ids = [] | ||
21 | + batch_video_vectors = [] | ||
22 | + error_tags = [] | ||
23 | + | ||
24 | + for i, row in video_tags.iterrows(): | ||
25 | + video_id = row[0] | ||
26 | + video_vector = np.zeros(100) | ||
27 | + for segment_index in range(1, 6): | ||
28 | + tag, weight = row[segment_index].split(":") | ||
29 | + if tag in tag_vectors.vocab: | ||
30 | + video_vector = video_vector + (tag_vectors[tag] * float(weight)) | ||
31 | + else: | ||
32 | + # Pass if tag is unknown | ||
33 | + if tag not in error_tags: | ||
34 | + error_tags.append(tag) | ||
35 | + | ||
36 | + batch_video_ids.append(video_id) | ||
37 | + batch_video_vectors.append(video_vector) | ||
38 | + # Add video vectors. | ||
39 | + if (i+1) % BATCH_SIZE == 0: | ||
40 | + video_vectors.add(batch_video_ids, batch_video_vectors) | ||
41 | + batch_video_ids = [] | ||
42 | + batch_video_vectors = [] | ||
43 | + print("Video vectors created: ", i+1) | ||
44 | + | ||
45 | + # Add rest of video vectors. | ||
46 | + video_vectors.add(batch_video_ids, batch_video_vectors) | ||
47 | + print("error tags: ") | ||
48 | + print(error_tags) | ||
49 | + | ||
50 | + video_vectors.save("video_vectors.model") | ||
51 | + | ||
52 | + # Usage | ||
53 | + # video_vectors = Word2Vec().wv.load("video_vectors.model") | ||
54 | + # video_vectors.most_similar("XwFj", topn=5) |
web/backend/yt8m/eval.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Binary for evaluating Tensorflow models on the YouTube-8M dataset.""" | ||
15 | + | ||
16 | +import json | ||
17 | +import os | ||
18 | +import time | ||
19 | + | ||
20 | +from absl import logging | ||
21 | +import eval_util | ||
22 | +import frame_level_models | ||
23 | +import losses | ||
24 | +import readers | ||
25 | +import tensorflow as tf | ||
26 | +from tensorflow import flags | ||
27 | +from tensorflow.python.lib.io import file_io | ||
28 | +import utils | ||
29 | +import video_level_models | ||
30 | + | ||
31 | +FLAGS = flags.FLAGS | ||
32 | + | ||
33 | +if __name__ == "__main__": | ||
34 | + # Dataset flags. | ||
35 | + flags.DEFINE_string( | ||
36 | + "train_dir", "/tmp/yt8m_model/", | ||
37 | + "The directory to load the model files from. " | ||
38 | + "The tensorboard metrics files are also saved to this " | ||
39 | + "directory.") | ||
40 | + flags.DEFINE_string( | ||
41 | + "eval_data_pattern", "", | ||
42 | + "File glob defining the evaluation dataset in tensorflow.SequenceExample " | ||
43 | + "format. The SequenceExamples are expected to have an 'rgb' byte array " | ||
44 | + "sequence feature as well as a 'labels' int64 context feature.") | ||
45 | + flags.DEFINE_bool( | ||
46 | + "segment_labels", False, | ||
47 | + "If set, then --eval_data_pattern must be frame-level features (but with" | ||
48 | + " segment_labels). Otherwise, --eval_data_pattern must be aggregated " | ||
49 | + "video-level features. The model must also be set appropriately (i.e. to " | ||
50 | + "read 3D batches vs. 4D batches).") | ||
51 | + | ||
52 | + # Other flags. | ||
53 | + flags.DEFINE_integer("batch_size", 1024, | ||
54 | + "How many examples to process per batch.") | ||
55 | + flags.DEFINE_integer("num_readers", 8, | ||
56 | + "How many threads to use for reading input files.") | ||
57 | + flags.DEFINE_boolean("run_once", False, "Whether to run eval only once.") | ||
58 | + flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.") | ||
59 | + | ||
60 | + | ||
61 | +def find_class_by_name(name, modules): | ||
62 | + """Searches the provided modules for the named class and returns it.""" | ||
63 | + modules = [getattr(module, name, None) for module in modules] | ||
64 | + return next(a for a in modules if a) | ||
65 | + | ||
66 | + | ||
67 | +def get_input_evaluation_tensors(reader, | ||
68 | + data_pattern, | ||
69 | + batch_size=1024, | ||
70 | + num_readers=1): | ||
71 | + """Creates the section of the graph which reads the evaluation data. | ||
72 | + | ||
73 | + Args: | ||
74 | + reader: A class which parses the training data. | ||
75 | + data_pattern: A 'glob' style path to the data files. | ||
76 | + batch_size: How many examples to process at a time. | ||
77 | + num_readers: How many I/O threads to use. | ||
78 | + | ||
79 | + Returns: | ||
80 | + A tuple containing the features tensor, labels tensor, and optionally a | ||
81 | + tensor containing the number of frames per video. The exact dimensions | ||
82 | + depend on the reader being used. | ||
83 | + | ||
84 | + Raises: | ||
85 | + IOError: If no files matching the given pattern were found. | ||
86 | + """ | ||
87 | + logging.info("Using batch size of %d for evaluation.", batch_size) | ||
88 | + with tf.name_scope("eval_input"): | ||
89 | + files = tf.io.gfile.glob(data_pattern) | ||
90 | + if not files: | ||
91 | + raise IOError("Unable to find the evaluation files.") | ||
92 | + logging.info("number of evaluation files: %d", len(files)) | ||
93 | + filename_queue = tf.train.string_input_producer(files, | ||
94 | + shuffle=False, | ||
95 | + num_epochs=1) | ||
96 | + eval_data = [ | ||
97 | + reader.prepare_reader(filename_queue) for _ in range(num_readers) | ||
98 | + ] | ||
99 | + return tf.train.batch_join(eval_data, | ||
100 | + batch_size=batch_size, | ||
101 | + capacity=3 * batch_size, | ||
102 | + allow_smaller_final_batch=True, | ||
103 | + enqueue_many=True) | ||
104 | + | ||
105 | + | ||
106 | +def build_graph(reader, | ||
107 | + model, | ||
108 | + eval_data_pattern, | ||
109 | + label_loss_fn, | ||
110 | + batch_size=1024, | ||
111 | + num_readers=1): | ||
112 | + """Creates the Tensorflow graph for evaluation. | ||
113 | + | ||
114 | + Args: | ||
115 | + reader: The data file reader. It should inherit from BaseReader. | ||
116 | + model: The core model (e.g. logistic or neural net). It should inherit from | ||
117 | + BaseModel. | ||
118 | + eval_data_pattern: glob path to the evaluation data files. | ||
119 | + label_loss_fn: What kind of loss to apply to the model. It should inherit | ||
120 | + from BaseLoss. | ||
121 | + batch_size: How many examples to process at a time. | ||
122 | + num_readers: How many threads to use for I/O operations. | ||
123 | + """ | ||
124 | + | ||
125 | + global_step = tf.Variable(0, trainable=False, name="global_step") | ||
126 | + input_data_dict = get_input_evaluation_tensors(reader, | ||
127 | + eval_data_pattern, | ||
128 | + batch_size=batch_size, | ||
129 | + num_readers=num_readers) | ||
130 | + video_id_batch = input_data_dict["video_ids"] | ||
131 | + model_input_raw = input_data_dict["video_matrix"] | ||
132 | + labels_batch = input_data_dict["labels"] | ||
133 | + num_frames = input_data_dict["num_frames"] | ||
134 | + tf.compat.v1.summary.histogram("model_input_raw", model_input_raw) | ||
135 | + | ||
136 | + feature_dim = len(model_input_raw.get_shape()) - 1 | ||
137 | + | ||
138 | + # Normalize input features. | ||
139 | + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) | ||
140 | + | ||
141 | + with tf.compat.v1.variable_scope("tower"): | ||
142 | + result = model.create_model(model_input, | ||
143 | + num_frames=num_frames, | ||
144 | + vocab_size=reader.num_classes, | ||
145 | + labels=labels_batch, | ||
146 | + is_training=False) | ||
147 | + | ||
148 | + predictions = result["predictions"] | ||
149 | + tf.compat.v1.summary.histogram("model_activations", predictions) | ||
150 | + if "loss" in result.keys(): | ||
151 | + label_loss = result["loss"] | ||
152 | + else: | ||
153 | + label_loss = label_loss_fn.calculate_loss(predictions, labels_batch) | ||
154 | + | ||
155 | + tf.compat.v1.add_to_collection("global_step", global_step) | ||
156 | + tf.compat.v1.add_to_collection("loss", label_loss) | ||
157 | + tf.compat.v1.add_to_collection("predictions", predictions) | ||
158 | + tf.compat.v1.add_to_collection("input_batch", model_input) | ||
159 | + tf.compat.v1.add_to_collection("input_batch_raw", model_input_raw) | ||
160 | + tf.compat.v1.add_to_collection("video_id_batch", video_id_batch) | ||
161 | + tf.compat.v1.add_to_collection("num_frames", num_frames) | ||
162 | + tf.compat.v1.add_to_collection("labels", tf.cast(labels_batch, tf.float32)) | ||
163 | + if FLAGS.segment_labels: | ||
164 | + tf.compat.v1.add_to_collection("label_weights", | ||
165 | + input_data_dict["label_weights"]) | ||
166 | + tf.compat.v1.add_to_collection("summary_op", tf.compat.v1.summary.merge_all()) | ||
167 | + | ||
168 | + | ||
169 | +def evaluation_loop(fetches, saver, summary_writer, evl_metrics, | ||
170 | + last_global_step_val): | ||
171 | + """Run the evaluation loop once. | ||
172 | + | ||
173 | + Args: | ||
174 | + fetches: a dict of tensors to be run within Session. | ||
175 | + saver: a tensorflow saver to restore the model. | ||
176 | + summary_writer: a tensorflow summary_writer | ||
177 | + evl_metrics: an EvaluationMetrics object. | ||
178 | + last_global_step_val: the global step used in the previous evaluation. | ||
179 | + | ||
180 | + Returns: | ||
181 | + The global_step used in the latest model. | ||
182 | + """ | ||
183 | + | ||
184 | + global_step_val = -1 | ||
185 | + with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions( | ||
186 | + allow_growth=True))) as sess: | ||
187 | + latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir) | ||
188 | + if latest_checkpoint: | ||
189 | + logging.info("Loading checkpoint for eval: %s", latest_checkpoint) | ||
190 | + # Restores from checkpoint | ||
191 | + saver.restore(sess, latest_checkpoint) | ||
192 | + # Assuming model_checkpoint_path looks something like: | ||
193 | + # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it. | ||
194 | + global_step_val = os.path.basename(latest_checkpoint).split("-")[-1] | ||
195 | + | ||
196 | + # Save model | ||
197 | + if FLAGS.segment_labels: | ||
198 | + inference_model_name = "segment_inference_model" | ||
199 | + else: | ||
200 | + inference_model_name = "inference_model" | ||
201 | + saver.save( | ||
202 | + sess, | ||
203 | + os.path.join(FLAGS.train_dir, "inference_model", | ||
204 | + inference_model_name)) | ||
205 | + else: | ||
206 | + logging.info("No checkpoint file found.") | ||
207 | + return global_step_val | ||
208 | + | ||
209 | + if global_step_val == last_global_step_val: | ||
210 | + logging.info( | ||
211 | + "skip this checkpoint global_step_val=%s " | ||
212 | + "(same as the previous one).", global_step_val) | ||
213 | + return global_step_val | ||
214 | + | ||
215 | + sess.run([tf.local_variables_initializer()]) | ||
216 | + | ||
217 | + # Start the queue runners. | ||
218 | + coord = tf.train.Coordinator() | ||
219 | + try: | ||
220 | + threads = [] | ||
221 | + for qr in tf.compat.v1.get_collection(tf.GraphKeys.QUEUE_RUNNERS): | ||
222 | + threads.extend( | ||
223 | + qr.create_threads(sess, coord=coord, daemon=True, start=True)) | ||
224 | + logging.info("enter eval_once loop global_step_val = %s. ", | ||
225 | + global_step_val) | ||
226 | + | ||
227 | + evl_metrics.clear() | ||
228 | + | ||
229 | + examples_processed = 0 | ||
230 | + while not coord.should_stop(): | ||
231 | + batch_start_time = time.time() | ||
232 | + output_data_dict = sess.run(fetches) | ||
233 | + seconds_per_batch = time.time() - batch_start_time | ||
234 | + labels_val = output_data_dict["labels"] | ||
235 | + summary_val = output_data_dict["summary"] | ||
236 | + example_per_second = labels_val.shape[0] / seconds_per_batch | ||
237 | + examples_processed += labels_val.shape[0] | ||
238 | + | ||
239 | + predictions = output_data_dict["predictions"] | ||
240 | + if FLAGS.segment_labels: | ||
241 | + # This is a workaround to ignore the unrated labels. | ||
242 | + predictions *= output_data_dict["label_weights"] | ||
243 | + iteration_info_dict = evl_metrics.accumulate(predictions, labels_val, | ||
244 | + output_data_dict["loss"]) | ||
245 | + iteration_info_dict["examples_per_second"] = example_per_second | ||
246 | + | ||
247 | + iterinfo = utils.AddGlobalStepSummary( | ||
248 | + summary_writer, | ||
249 | + global_step_val, | ||
250 | + iteration_info_dict, | ||
251 | + summary_scope="SegEval" if FLAGS.segment_labels else "Eval") | ||
252 | + logging.info("examples_processed: %d | %s", examples_processed, | ||
253 | + iterinfo) | ||
254 | + | ||
255 | + except tf.errors.OutOfRangeError as e: | ||
256 | + logging.info( | ||
257 | + "Done with batched inference. Now calculating global performance " | ||
258 | + "metrics.") | ||
259 | + # calculate the metrics for the entire epoch | ||
260 | + epoch_info_dict = evl_metrics.get() | ||
261 | + epoch_info_dict["epoch_id"] = global_step_val | ||
262 | + | ||
263 | + summary_writer.add_summary(summary_val, global_step_val) | ||
264 | + epochinfo = utils.AddEpochSummary( | ||
265 | + summary_writer, | ||
266 | + global_step_val, | ||
267 | + epoch_info_dict, | ||
268 | + summary_scope="SegEval" if FLAGS.segment_labels else "Eval") | ||
269 | + logging.info(epochinfo) | ||
270 | + evl_metrics.clear() | ||
271 | + except Exception as e: # pylint: disable=broad-except | ||
272 | + logging.info("Unexpected exception: %s", str(e)) | ||
273 | + coord.request_stop(e) | ||
274 | + | ||
275 | + coord.request_stop() | ||
276 | + coord.join(threads, stop_grace_period_secs=10) | ||
277 | + logging.info("Total: examples_processed: %d", examples_processed) | ||
278 | + | ||
279 | + return global_step_val | ||
280 | + | ||
281 | + | ||
282 | +def evaluate(): | ||
283 | + """Starts main evaluation loop.""" | ||
284 | + tf.compat.v1.set_random_seed(0) # for reproducibility | ||
285 | + | ||
286 | + # Write json of flags | ||
287 | + model_flags_path = os.path.join(FLAGS.train_dir, "model_flags.json") | ||
288 | + if not file_io.file_exists(model_flags_path): | ||
289 | + raise IOError(("Cannot find file %s. Did you run train.py on the same " | ||
290 | + "--train_dir?") % model_flags_path) | ||
291 | + flags_dict = json.loads(file_io.FileIO(model_flags_path, mode="r").read()) | ||
292 | + | ||
293 | + with tf.Graph().as_default(): | ||
294 | + # convert feature_names and feature_sizes to lists of values | ||
295 | + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( | ||
296 | + flags_dict["feature_names"], flags_dict["feature_sizes"]) | ||
297 | + | ||
298 | + if flags_dict["frame_features"]: | ||
299 | + reader = readers.YT8MFrameFeatureReader( | ||
300 | + feature_names=feature_names, | ||
301 | + feature_sizes=feature_sizes, | ||
302 | + segment_labels=FLAGS.segment_labels) | ||
303 | + else: | ||
304 | + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, | ||
305 | + feature_sizes=feature_sizes) | ||
306 | + | ||
307 | + model = find_class_by_name(flags_dict["model"], | ||
308 | + [frame_level_models, video_level_models])() | ||
309 | + label_loss_fn = find_class_by_name(flags_dict["label_loss"], [losses])() | ||
310 | + | ||
311 | + if not FLAGS.eval_data_pattern: | ||
312 | + raise IOError("'eval_data_pattern' was not specified. Nothing to " | ||
313 | + "evaluate.") | ||
314 | + | ||
315 | + build_graph(reader=reader, | ||
316 | + model=model, | ||
317 | + eval_data_pattern=FLAGS.eval_data_pattern, | ||
318 | + label_loss_fn=label_loss_fn, | ||
319 | + num_readers=FLAGS.num_readers, | ||
320 | + batch_size=FLAGS.batch_size) | ||
321 | + logging.info("built evaluation graph") | ||
322 | + | ||
323 | + # A dict of tensors to be run in Session. | ||
324 | + fetches = { | ||
325 | + "video_id": tf.compat.v1.get_collection("video_id_batch")[0], | ||
326 | + "predictions": tf.compat.v1.get_collection("predictions")[0], | ||
327 | + "labels": tf.compat.v1.get_collection("labels")[0], | ||
328 | + "loss": tf.compat.v1.get_collection("loss")[0], | ||
329 | + "summary": tf.compat.v1.get_collection("summary_op")[0] | ||
330 | + } | ||
331 | + if FLAGS.segment_labels: | ||
332 | + fetches["label_weights"] = tf.compat.v1.get_collection("label_weights")[0] | ||
333 | + | ||
334 | + saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables()) | ||
335 | + summary_writer = tf.compat.v1.summary.FileWriter( | ||
336 | + os.path.join(FLAGS.train_dir, "eval"), | ||
337 | + graph=tf.compat.v1.get_default_graph()) | ||
338 | + | ||
339 | + evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k, | ||
340 | + None) | ||
341 | + | ||
342 | + last_global_step_val = -1 | ||
343 | + while True: | ||
344 | + last_global_step_val = evaluation_loop(fetches, saver, summary_writer, | ||
345 | + evl_metrics, last_global_step_val) | ||
346 | + if FLAGS.run_once: | ||
347 | + break | ||
348 | + | ||
349 | + | ||
350 | +def main(unused_argv): | ||
351 | + logging.set_verbosity(logging.INFO) | ||
352 | + logging.info("tensorflow version: %s", tf.__version__) | ||
353 | + evaluate() | ||
354 | + | ||
355 | + | ||
356 | +if __name__ == "__main__": | ||
357 | + tf.compat.v1.app.run() |
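The segment_inference_model that inference.py restores is written by this script when --segment_labels is set and a checkpoint exists in --train_dir. A typical invocation might look like this (paths are placeholders):

    python eval.py \
      --train_dir=/path/to/yt8m_model \
      --eval_data_pattern="/path/to/validate/*.tfrecord" \
      --segment_labels --run_once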
web/backend/yt8m/eval_util.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Provides functions to help with evaluating models.""" | ||
15 | +import average_precision_calculator as ap_calculator | ||
16 | +import mean_average_precision_calculator as map_calculator | ||
17 | +import numpy | ||
18 | +from tensorflow.python.platform import gfile | ||
19 | + | ||
20 | + | ||
21 | +def flatten(l): | ||
22 | + """Merges a list of lists into a single list. """ | ||
23 | + return [item for sublist in l for item in sublist] | ||
24 | + | ||
25 | + | ||
26 | +def calculate_hit_at_one(predictions, actuals): | ||
27 | + """Performs a local (numpy) calculation of the hit at one. | ||
28 | + | ||
29 | + Args: | ||
30 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
31 | + 'batch' x 'num_classes'. | ||
32 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
33 | + 'num_classes'. | ||
34 | + | ||
35 | + Returns: | ||
36 | + float: The average hit at one across the entire batch. | ||
37 | + """ | ||
38 | + top_prediction = numpy.argmax(predictions, 1) | ||
39 | + hits = actuals[numpy.arange(actuals.shape[0]), top_prediction] | ||
40 | + return numpy.average(hits) | ||
41 | + | ||
42 | + | ||
43 | +def calculate_precision_at_equal_recall_rate(predictions, actuals): | ||
44 | + """Performs a local (numpy) calculation of the PERR. | ||
45 | + | ||
46 | + Args: | ||
47 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
48 | + 'batch' x 'num_classes'. | ||
49 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
50 | + 'num_classes'. | ||
51 | + | ||
52 | + Returns: | ||
53 | + float: The average precision at equal recall rate across the entire batch. | ||
54 | + """ | ||
55 | + aggregated_precision = 0.0 | ||
56 | + num_videos = actuals.shape[0] | ||
57 | + for row in numpy.arange(num_videos): | ||
58 | + num_labels = int(numpy.sum(actuals[row])) | ||
59 | + top_indices = numpy.argpartition(predictions[row], | ||
60 | + -num_labels)[-num_labels:] | ||
61 | + item_precision = 0.0 | ||
62 | + for label_index in top_indices: | ||
63 | + if predictions[row][label_index] > 0: | ||
64 | + item_precision += actuals[row][label_index] | ||
65 | + item_precision /= top_indices.size | ||
66 | + aggregated_precision += item_precision | ||
67 | + aggregated_precision /= num_videos | ||
68 | + return aggregated_precision | ||
69 | + | ||
70 | + | ||
71 | +def calculate_gap(predictions, actuals, top_k=20): | ||
72 | + """Performs a local (numpy) calculation of the global average precision. | ||
73 | + | ||
74 | + Only the top_k predictions are taken for each of the videos. | ||
75 | + | ||
76 | + Args: | ||
77 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
78 | + 'batch' x 'num_classes'. | ||
79 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
80 | + 'num_classes'. | ||
81 | + top_k: How many predictions to use per video. | ||
82 | + | ||
83 | + Returns: | ||
84 | + float: The global average precision. | ||
85 | + """ | ||
86 | + gap_calculator = ap_calculator.AveragePrecisionCalculator() | ||
87 | + sparse_predictions, sparse_labels, num_positives = top_k_by_class( | ||
88 | + predictions, actuals, top_k) | ||
89 | + gap_calculator.accumulate(flatten(sparse_predictions), flatten(sparse_labels), | ||
90 | + sum(num_positives)) | ||
91 | + return gap_calculator.peek_ap_at_n() | ||
92 | + | ||
93 | + | ||
94 | +def top_k_by_class(predictions, labels, k=20): | ||
95 | + """Extracts the top k predictions for each video, sorted by class. | ||
96 | + | ||
97 | + Args: | ||
98 | + predictions: A numpy matrix containing the outputs of the model. Dimensions | ||
99 | + are 'batch' x 'num_classes'. | ||
100 | + k: the top k non-zero entries to preserve in each prediction. | ||
101 | + | ||
102 | + Returns: | ||
103 | + A tuple (predictions,labels, true_positives). 'predictions' and 'labels' | ||
104 | + are lists of lists of floats. 'true_positives' is a list of scalars. The | ||
105 | + length of the lists are equal to the number of classes. The entries in the | ||
106 | + predictions variable are probability predictions, and | ||
107 | + the corresponding entries in the labels variable are the ground truth for | ||
108 | + those predictions. The entries in 'true_positives' are the number of true | ||
109 | + positives for each class in the ground truth. | ||
110 | + | ||
111 | + Raises: | ||
112 | + ValueError: An error occurred when the k is not a positive integer. | ||
113 | + """ | ||
114 | + if k <= 0: | ||
115 | + raise ValueError("k must be a positive integer.") | ||
116 | + k = min(k, predictions.shape[1]) | ||
117 | + num_classes = predictions.shape[1] | ||
118 | + prediction_triplets = [] | ||
119 | + for video_index in range(predictions.shape[0]): | ||
120 | + prediction_triplets.extend( | ||
121 | + top_k_triplets(predictions[video_index], labels[video_index], k)) | ||
122 | + out_predictions = [[] for _ in range(num_classes)] | ||
123 | + out_labels = [[] for _ in range(num_classes)] | ||
124 | + for triplet in prediction_triplets: | ||
125 | + out_predictions[triplet[0]].append(triplet[1]) | ||
126 | + out_labels[triplet[0]].append(triplet[2]) | ||
127 | + out_true_positives = [numpy.sum(labels[:, i]) for i in range(num_classes)] | ||
128 | + | ||
129 | + return out_predictions, out_labels, out_true_positives | ||
130 | + | ||
131 | + | ||
132 | +def top_k_triplets(predictions, labels, k=20): | ||
133 | + """Get the top_k for a 1-d numpy array. | ||
134 | + | ||
135 | + Returns a sparse list of triplets in | ||
136 | + (class_index, prediction, label) format. | ||
137 | + """ | ||
138 | + m = len(predictions) | ||
139 | + k = min(k, m) | ||
140 | + indices = numpy.argpartition(predictions, -k)[-k:] | ||
141 | + return [(index, predictions[index], labels[index]) for index in indices] | ||
142 | + | ||
143 | + | ||
144 | +class EvaluationMetrics(object): | ||
145 | + """A class to store the evaluation metrics.""" | ||
146 | + | ||
147 | + def __init__(self, num_class, top_k, top_n): | ||
148 | + """Construct an EvaluationMetrics object to store the evaluation metrics. | ||
149 | + | ||
150 | + Args: | ||
151 | + num_class: A positive integer specifying the number of classes. | ||
152 | + top_k: A positive integer specifying how many predictions are considered | ||
153 | + per video. | ||
154 | + top_n: A positive Integer specifying the average precision at n, or None | ||
155 | + to use all provided data points. | ||
156 | + | ||
157 | + Raises: | ||
158 | + ValueError: An error occurred when MeanAveragePrecisionCalculator cannot | ||
159 | + not be constructed. | ||
160 | + """ | ||
161 | + self.sum_hit_at_one = 0.0 | ||
162 | + self.sum_perr = 0.0 | ||
163 | + self.sum_loss = 0.0 | ||
164 | + self.map_calculator = map_calculator.MeanAveragePrecisionCalculator( | ||
165 | + num_class, top_n=top_n) | ||
166 | + self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator() | ||
167 | + self.top_k = top_k | ||
168 | + self.num_examples = 0 | ||
169 | + | ||
170 | + def accumulate(self, predictions, labels, loss): | ||
171 | + """Accumulate the metrics calculated locally for this mini-batch. | ||
172 | + | ||
173 | + Args: | ||
174 | + predictions: A numpy matrix containing the outputs of the model. | ||
175 | + Dimensions are 'batch' x 'num_classes'. | ||
176 | + labels: A numpy matrix containing the ground truth labels. Dimensions are | ||
177 | + 'batch' x 'num_classes'. | ||
178 | + loss: A numpy array containing the loss for each sample. | ||
179 | + | ||
180 | + Returns: | ||
181 | + dictionary: A dictionary storing the metrics for the mini-batch. | ||
182 | + | ||
183 | + Raises: | ||
184 | + ValueError: An error occurred when the shape of predictions and actuals | ||
185 | + does not match. | ||
186 | + """ | ||
187 | + batch_size = labels.shape[0] | ||
188 | + mean_hit_at_one = calculate_hit_at_one(predictions, labels) | ||
189 | + mean_perr = calculate_precision_at_equal_recall_rate(predictions, labels) | ||
190 | + mean_loss = numpy.mean(loss) | ||
191 | + | ||
192 | + # Take the top 20 predictions. | ||
193 | + sparse_predictions, sparse_labels, num_positives = top_k_by_class( | ||
194 | + predictions, labels, self.top_k) | ||
195 | + self.map_calculator.accumulate(sparse_predictions, sparse_labels, | ||
196 | + num_positives) | ||
197 | + self.global_ap_calculator.accumulate(flatten(sparse_predictions), | ||
198 | + flatten(sparse_labels), | ||
199 | + sum(num_positives)) | ||
200 | + | ||
201 | + self.num_examples += batch_size | ||
202 | + self.sum_hit_at_one += mean_hit_at_one * batch_size | ||
203 | + self.sum_perr += mean_perr * batch_size | ||
204 | + self.sum_loss += mean_loss * batch_size | ||
205 | + | ||
206 | + return {"hit_at_one": mean_hit_at_one, "perr": mean_perr, "loss": mean_loss} | ||
207 | + | ||
208 | + def get(self): | ||
209 | + """Calculate the evaluation metrics for the whole epoch. | ||
210 | + | ||
211 | + Raises: | ||
212 | + ValueError: If no examples were accumulated. | ||
213 | + | ||
214 | + Returns: | ||
215 | + dictionary: a dictionary storing the evaluation metrics for the epoch. The | ||
216 | + dictionary has the fields: avg_hit_at_one, avg_perr, avg_loss, and | ||
217 | + aps (default nan). | ||
218 | + """ | ||
219 | + if self.num_examples <= 0: | ||
220 | + raise ValueError("total_sample must be positive.") | ||
221 | + avg_hit_at_one = self.sum_hit_at_one / self.num_examples | ||
222 | + avg_perr = self.sum_perr / self.num_examples | ||
223 | + avg_loss = self.sum_loss / self.num_examples | ||
224 | + | ||
225 | + aps = self.map_calculator.peek_map_at_n() | ||
226 | + gap = self.global_ap_calculator.peek_ap_at_n() | ||
227 | + | ||
228 | + epoch_info_dict = { | ||
229 | + "avg_hit_at_one": avg_hit_at_one, | ||
230 | + "avg_perr": avg_perr, | ||
231 | + "avg_loss": avg_loss, | ||
232 | + "aps": aps, | ||
233 | + "gap": gap | ||
234 | + } | ||
235 | + return epoch_info_dict | ||
236 | + | ||
237 | + def clear(self): | ||
238 | + """Clear the evaluation metrics and reset the EvaluationMetrics object.""" | ||
239 | + self.sum_hit_at_one = 0.0 | ||
240 | + self.sum_perr = 0.0 | ||
241 | + self.sum_loss = 0.0 | ||
242 | + self.map_calculator.clear() | ||
243 | + self.global_ap_calculator.clear() | ||
244 | + self.num_examples = 0 |
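A tiny worked example of the per-batch metrics above, with made-up predictions for two videos and three classes (run from the yt8m directory so the calculator modules import):

    import numpy
    import eval_util

    predictions = numpy.array([[0.9, 0.2, 0.1],
                               [0.1, 0.8, 0.7]])
    actuals = numpy.array([[1, 0, 0],
                           [0, 0, 1]])

    # Video 0: top prediction is class 0, which is labeled -> hit.
    # Video 1: top prediction is class 1, which is not labeled -> miss.
    print(eval_util.calculate_hit_at_one(predictions, actuals))                      # 0.5
    print(eval_util.calculate_precision_at_equal_recall_rate(predictions, actuals))  # 0.5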
web/backend/yt8m/export_model.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Utilities to export a model for batch prediction.""" | ||
15 | + | ||
16 | +import tensorflow as tf | ||
17 | +import tensorflow.contrib.slim as slim | ||
18 | + | ||
19 | +from tensorflow.python.saved_model import builder as saved_model_builder | ||
20 | +from tensorflow.python.saved_model import signature_constants | ||
21 | +from tensorflow.python.saved_model import signature_def_utils | ||
22 | +from tensorflow.python.saved_model import tag_constants | ||
23 | +from tensorflow.python.saved_model import utils as saved_model_utils | ||
24 | + | ||
25 | +_TOP_PREDICTIONS_IN_OUTPUT = 20 | ||
26 | + | ||
27 | + | ||
28 | +class ModelExporter(object): | ||
29 | + | ||
30 | + def __init__(self, frame_features, model, reader): | ||
31 | + self.frame_features = frame_features | ||
32 | + self.model = model | ||
33 | + self.reader = reader | ||
34 | + | ||
35 | + with tf.Graph().as_default() as graph: | ||
36 | + self.inputs, self.outputs = self.build_inputs_and_outputs() | ||
37 | + self.graph = graph | ||
38 | + self.saver = tf.train.Saver(tf.trainable_variables(), sharded=True) | ||
39 | + | ||
40 | + def export_model(self, model_dir, global_step_val, last_checkpoint): | ||
41 | + """Exports the model so that it can used for batch predictions.""" | ||
42 | + | ||
43 | + with self.graph.as_default(): | ||
44 | + with tf.Session() as session: | ||
45 | + session.run(tf.global_variables_initializer()) | ||
46 | + self.saver.restore(session, last_checkpoint) | ||
47 | + | ||
48 | + signature = signature_def_utils.build_signature_def( | ||
49 | + inputs=self.inputs, | ||
50 | + outputs=self.outputs, | ||
51 | + method_name=signature_constants.PREDICT_METHOD_NAME) | ||
52 | + | ||
53 | + signature_map = { | ||
54 | + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature | ||
55 | + } | ||
56 | + | ||
57 | + model_builder = saved_model_builder.SavedModelBuilder(model_dir) | ||
58 | + model_builder.add_meta_graph_and_variables( | ||
59 | + session, | ||
60 | + tags=[tag_constants.SERVING], | ||
61 | + signature_def_map=signature_map, | ||
62 | + clear_devices=True) | ||
63 | + model_builder.save() | ||
64 | + | ||
65 | + def build_inputs_and_outputs(self): | ||
66 | + if self.frame_features: | ||
67 | + serialized_examples = tf.placeholder(tf.string, shape=(None,)) | ||
68 | + | ||
69 | + fn = lambda x: self.build_prediction_graph(x) | ||
70 | + video_id_output, top_indices_output, top_predictions_output = (tf.map_fn( | ||
71 | + fn, serialized_examples, dtype=(tf.string, tf.int32, tf.float32))) | ||
72 | + | ||
73 | + else: | ||
74 | + serialized_examples = tf.placeholder(tf.string, shape=(None,)) | ||
75 | + | ||
76 | + video_id_output, top_indices_output, top_predictions_output = ( | ||
77 | + self.build_prediction_graph(serialized_examples)) | ||
78 | + | ||
79 | + inputs = { | ||
80 | + "example_bytes": | ||
81 | + saved_model_utils.build_tensor_info(serialized_examples) | ||
82 | + } | ||
83 | + | ||
84 | + outputs = { | ||
85 | + "video_id": | ||
86 | + saved_model_utils.build_tensor_info(video_id_output), | ||
87 | + "class_indexes": | ||
88 | + saved_model_utils.build_tensor_info(top_indices_output), | ||
89 | + "predictions": | ||
90 | + saved_model_utils.build_tensor_info(top_predictions_output) | ||
91 | + } | ||
92 | + | ||
93 | + return inputs, outputs | ||
94 | + | ||
95 | + def build_prediction_graph(self, serialized_examples): | ||
96 | + input_data_dict = ( | ||
97 | + self.reader.prepare_serialized_examples(serialized_examples)) | ||
98 | + video_id = input_data_dict["video_ids"] | ||
99 | + model_input_raw = input_data_dict["video_matrix"] | ||
100 | + labels_batch = input_data_dict["labels"] | ||
101 | + num_frames = input_data_dict["num_frames"] | ||
102 | + | ||
103 | + feature_dim = len(model_input_raw.get_shape()) - 1 | ||
104 | + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) | ||
105 | + | ||
106 | + with tf.variable_scope("tower"): | ||
107 | + result = self.model.create_model(model_input, | ||
108 | + num_frames=num_frames, | ||
109 | + vocab_size=self.reader.num_classes, | ||
110 | + labels=labels_batch, | ||
111 | + is_training=False) | ||
112 | + | ||
113 | + for variable in slim.get_model_variables(): | ||
114 | + tf.summary.histogram(variable.op.name, variable) | ||
115 | + | ||
116 | + predictions = result["predictions"] | ||
117 | + | ||
118 | + top_predictions, top_indices = tf.nn.top_k(predictions, | ||
119 | + _TOP_PREDICTIONS_IN_OUTPUT) | ||
120 | + return video_id, top_indices, top_predictions |
web/backend/yt8m/export_model_mediapipe.py
0 → 100644
1 | +# Lint as: python3 | ||
2 | +import numpy as np | ||
3 | +import tensorflow as tf | ||
4 | +from tensorflow import app | ||
5 | +from tensorflow import flags | ||
6 | + | ||
7 | +FLAGS = flags.FLAGS | ||
8 | + | ||
9 | + | ||
10 | +def main(unused_argv): | ||
11 | + # Get the input tensor names to be replaced. | ||
12 | + tf.reset_default_graph() | ||
13 | + meta_graph_location = FLAGS.checkpoint_file + ".meta" | ||
14 | + tf.train.import_meta_graph(meta_graph_location, clear_devices=True) | ||
15 | + | ||
16 | + input_tensor_name = tf.get_collection("input_batch_raw")[0].name | ||
17 | + num_frames_tensor_name = tf.get_collection("num_frames")[0].name | ||
18 | + | ||
19 | + # Create output graph. | ||
20 | + saver = tf.train.Saver() | ||
21 | + tf.reset_default_graph() | ||
22 | + | ||
23 | + input_feature_placeholder = tf.placeholder( | ||
24 | + tf.float32, shape=(None, None, 1152)) | ||
25 | + num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1)) | ||
26 | + | ||
27 | + saver = tf.train.import_meta_graph( | ||
28 | + meta_graph_location, | ||
29 | + input_map={ | ||
30 | + input_tensor_name: input_feature_placeholder, | ||
31 | + num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1) | ||
32 | + }, | ||
33 | + clear_devices=True) | ||
34 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
35 | + | ||
36 | + with tf.Session() as sess: | ||
37 | + print("restoring variables from " + FLAGS.checkpoint_file) | ||
38 | + saver.restore(sess, FLAGS.checkpoint_file) | ||
39 | + tf.saved_model.simple_save( | ||
40 | + sess, | ||
41 | + FLAGS.output_dir, | ||
42 | + inputs={'rgb_and_audio': input_feature_placeholder, | ||
43 | + 'num_frames': num_frames_placeholder}, | ||
44 | + outputs={'predictions': predictions_tensor}) | ||
45 | + | ||
46 | + # Try running inference. | ||
47 | + predictions = sess.run( | ||
48 | + [predictions_tensor], | ||
49 | + feed_dict={ | ||
50 | + input_feature_placeholder: np.zeros((3, 7, 1152), dtype=np.float32), | ||
51 | + num_frames_placeholder: np.array([[7]], dtype=np.int32)}) | ||
52 | + print('Test inference:', predictions) | ||
53 | + | ||
54 | + print('Model saved to ', FLAGS.output_dir) | ||
55 | + | ||
56 | + | ||
57 | +if __name__ == '__main__': | ||
58 | + flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.') | ||
59 | + flags.DEFINE_string('output_dir', None, 'SavedModel output directory.') | ||
60 | + app.run(main) |
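Both flags are required; an example invocation with placeholder paths:

    python export_model_mediapipe.py \
      --checkpoint_file=/path/to/inference_model/segment_inference_model \
      --output_dir=/tmp/yt8m_saved_model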
web/backend/yt8m/frame_level_models.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains a collection of models which operate on variable-length sequences.""" | ||
15 | +import math | ||
16 | + | ||
17 | +import model_utils as utils | ||
18 | +import models | ||
19 | +import tensorflow as tf | ||
20 | +from tensorflow import flags | ||
21 | +import tensorflow.contrib.slim as slim | ||
22 | +import video_level_models | ||
23 | + | ||
24 | +FLAGS = flags.FLAGS | ||
25 | +flags.DEFINE_integer("iterations", 30, "Number of frames per batch for DBoF.") | ||
26 | +flags.DEFINE_bool("dbof_add_batch_norm", True, | ||
27 | + "Adds batch normalization to the DBoF model.") | ||
28 | +flags.DEFINE_bool( | ||
29 | + "sample_random_frames", True, | ||
30 | + "If true, samples random frames (for frame-level models). If false, a random " | ||
31 | + "sequence of frames is sampled instead.") | ||
32 | +flags.DEFINE_integer("dbof_cluster_size", 8192, | ||
33 | + "Number of units in the DBoF cluster layer.") | ||
34 | +flags.DEFINE_integer("dbof_hidden_size", 1024, | ||
35 | + "Number of units in the DBoF hidden layer.") | ||
36 | +flags.DEFINE_string( | ||
37 | + "dbof_pooling_method", "max", | ||
38 | + "The pooling method used in the DBoF cluster layer. " | ||
39 | + "Choices are 'average' and 'max'.") | ||
40 | +flags.DEFINE_string( | ||
41 | + "dbof_activation", "sigmoid", | ||
42 | + "The nonlinear activation method for cluster and hidden dense layer, e.g., " | ||
43 | + "sigmoid, relu6, etc.") | ||
44 | +flags.DEFINE_string( | ||
45 | + "video_level_classifier_model", "MoeModel", | ||
46 | + "Some Frame-Level models can be decomposed into a " | ||
47 | + "generalized pooling operation followed by a " | ||
48 | + "classifier layer") | ||
49 | +flags.DEFINE_integer("lstm_cells", 1024, "Number of LSTM cells.") | ||
50 | +flags.DEFINE_integer("lstm_layers", 2, "Number of LSTM layers.") | ||
51 | + | ||
52 | + | ||
53 | +class FrameLevelLogisticModel(models.BaseModel): | ||
54 | + """Creates a logistic classifier over the aggregated frame-level features.""" | ||
55 | + | ||
56 | + def create_model(self, model_input, vocab_size, num_frames, **unused_params): | ||
57 | + """See base class. | ||
58 | + | ||
59 | + This class is intended to be an example for implementors of frame level | ||
60 | + models. If you want to train a model over averaged features it is more | ||
61 | + efficient to average them beforehand rather than on the fly. | ||
62 | + | ||
63 | + Args: | ||
64 | + model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of | ||
65 | + input features. | ||
66 | + vocab_size: The number of classes in the dataset. | ||
67 | + num_frames: A vector of length 'batch' which indicates the number of | ||
68 | + frames for each video (before padding). | ||
69 | + | ||
70 | + Returns: | ||
71 | + A dictionary with a tensor containing the probability predictions of the | ||
72 | + model in the 'predictions' key. The dimensions of the tensor are | ||
73 | + 'batch_size' x 'num_classes'. | ||
74 | + """ | ||
75 | + num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32) | ||
76 | + feature_size = model_input.get_shape().as_list()[2] | ||
77 | + | ||
78 | + denominators = tf.reshape(tf.tile(num_frames, [1, feature_size]), | ||
79 | + [-1, feature_size]) | ||
80 | + avg_pooled = tf.reduce_sum(model_input, axis=[1]) / denominators | ||
81 | + | ||
82 | + output = slim.fully_connected(avg_pooled, | ||
83 | + vocab_size, | ||
84 | + activation_fn=tf.nn.sigmoid, | ||
85 | + weights_regularizer=slim.l2_regularizer(1e-8)) | ||
86 | + return {"predictions": output} | ||
87 | + | ||
88 | + | ||
89 | +class DbofModel(models.BaseModel): | ||
90 | + """Creates a Deep Bag of Frames model. | ||
91 | + | ||
92 | + The model projects the features for each frame into a higher dimensional | ||
93 | + 'clustering' space, pools across frames in that space, and then | ||
94 | + uses a configurable video-level model to classify the now aggregated features. | ||
95 | + | ||
96 | + The model will randomly sample either frames or sequences of frames during | ||
97 | + training to speed up convergence. | ||
98 | + """ | ||
99 | + | ||
100 | + ACT_FN_MAP = { | ||
101 | + "sigmoid": tf.nn.sigmoid, | ||
102 | + "relu6": tf.nn.relu6, | ||
103 | + } | ||
104 | + | ||
105 | + def create_model(self, | ||
106 | + model_input, | ||
107 | + vocab_size, | ||
108 | + num_frames, | ||
109 | + iterations=None, | ||
110 | + add_batch_norm=None, | ||
111 | + sample_random_frames=None, | ||
112 | + cluster_size=None, | ||
113 | + hidden_size=None, | ||
114 | + is_training=True, | ||
115 | + **unused_params): | ||
116 | + """See base class. | ||
117 | + | ||
118 | + Args: | ||
119 | + model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of | ||
120 | + input features. | ||
121 | + vocab_size: The number of classes in the dataset. | ||
122 | + num_frames: A vector of length 'batch' which indicates the number of | ||
123 | + frames for each video (before padding). | ||
124 | + iterations: the number of frames to be sampled. | ||
125 | + add_batch_norm: whether to add batch norm during training. | ||
126 | + sample_random_frames: whether to sample random frames or random sequences. | ||
127 | + cluster_size: the output neuron number of the cluster layer. | ||
128 | + hidden_size: the output neuron number of the hidden layer. | ||
129 | + is_training: whether to build the graph in training mode. | ||
130 | + | ||
131 | + Returns: | ||
132 | + A dictionary with a tensor containing the probability predictions of the | ||
133 | + model in the 'predictions' key. The dimensions of the tensor are | ||
134 | + 'batch_size' x 'num_classes'. | ||
135 | + """ | ||
136 | + iterations = iterations or FLAGS.iterations | ||
137 | + add_batch_norm = add_batch_norm or FLAGS.dbof_add_batch_norm | ||
138 | + random_frames = sample_random_frames or FLAGS.sample_random_frames | ||
139 | + cluster_size = cluster_size or FLAGS.dbof_cluster_size | ||
140 | + hidden1_size = hidden_size or FLAGS.dbof_hidden_size | ||
141 | + act_fn = self.ACT_FN_MAP.get(FLAGS.dbof_activation) | ||
142 | + assert act_fn is not None, ("dbof_activation is not valid: %s." % | ||
143 | + FLAGS.dbof_activation) | ||
144 | + | ||
145 | + num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32) | ||
146 | + if random_frames: | ||
147 | + model_input = utils.SampleRandomFrames(model_input, num_frames, | ||
148 | + iterations) | ||
149 | + else: | ||
150 | + model_input = utils.SampleRandomSequence(model_input, num_frames, | ||
151 | + iterations) | ||
152 | + max_frames = model_input.get_shape().as_list()[1] | ||
153 | + feature_size = model_input.get_shape().as_list()[2] | ||
154 | + reshaped_input = tf.reshape(model_input, [-1, feature_size]) | ||
155 | + tf.compat.v1.summary.histogram("input_hist", reshaped_input) | ||
156 | + | ||
157 | + if add_batch_norm: | ||
158 | + reshaped_input = slim.batch_norm(reshaped_input, | ||
159 | + center=True, | ||
160 | + scale=True, | ||
161 | + is_training=is_training, | ||
162 | + scope="input_bn") | ||
163 | + | ||
164 | + cluster_weights = tf.compat.v1.get_variable( | ||
165 | + "cluster_weights", [feature_size, cluster_size], | ||
166 | + initializer=tf.random_normal_initializer(stddev=1 / | ||
167 | + math.sqrt(feature_size))) | ||
168 | + tf.compat.v1.summary.histogram("cluster_weights", cluster_weights) | ||
169 | + activation = tf.matmul(reshaped_input, cluster_weights) | ||
170 | + if add_batch_norm: | ||
171 | + activation = slim.batch_norm(activation, | ||
172 | + center=True, | ||
173 | + scale=True, | ||
174 | + is_training=is_training, | ||
175 | + scope="cluster_bn") | ||
176 | + else: | ||
177 | + cluster_biases = tf.compat.v1.get_variable( | ||
178 | + "cluster_biases", [cluster_size], | ||
179 | + initializer=tf.random_normal_initializer(stddev=1 / | ||
180 | + math.sqrt(feature_size))) | ||
181 | + tf.compat.v1.summary.histogram("cluster_biases", cluster_biases) | ||
182 | + activation += cluster_biases | ||
183 | + activation = act_fn(activation) | ||
184 | + tf.compat.v1.summary.histogram("cluster_output", activation) | ||
185 | + | ||
186 | + activation = tf.reshape(activation, [-1, max_frames, cluster_size]) | ||
187 | + activation = utils.FramePooling(activation, FLAGS.dbof_pooling_method) | ||
188 | + | ||
189 | + hidden1_weights = tf.compat.v1.get_variable( | ||
190 | + "hidden1_weights", [cluster_size, hidden1_size], | ||
191 | + initializer=tf.random_normal_initializer(stddev=1 / | ||
192 | + math.sqrt(cluster_size))) | ||
193 | + tf.compat.v1.summary.histogram("hidden1_weights", hidden1_weights) | ||
194 | + activation = tf.matmul(activation, hidden1_weights) | ||
195 | + if add_batch_norm: | ||
196 | + activation = slim.batch_norm(activation, | ||
197 | + center=True, | ||
198 | + scale=True, | ||
199 | + is_training=is_training, | ||
200 | + scope="hidden1_bn") | ||
201 | + else: | ||
202 | + hidden1_biases = tf.compat.v1.get_variable( | ||
203 | + "hidden1_biases", [hidden1_size], | ||
204 | + initializer=tf.random_normal_initializer(stddev=0.01)) | ||
205 | + tf.compat.v1.summary.histogram("hidden1_biases", hidden1_biases) | ||
206 | + activation += hidden1_biases | ||
207 | + activation = act_fn(activation) | ||
208 | + tf.compat.v1.summary.histogram("hidden1_output", activation) | ||
209 | + | ||
210 | + aggregated_model = getattr(video_level_models, | ||
211 | + FLAGS.video_level_classifier_model) | ||
212 | + return aggregated_model().create_model(model_input=activation, | ||
213 | + vocab_size=vocab_size, | ||
214 | + **unused_params) | ||
215 | + | ||
216 | + | ||
217 | +class LstmModel(models.BaseModel): | ||
218 | + """Creates a model which uses a stack of LSTMs to represent the video.""" | ||
219 | + | ||
220 | + def create_model(self, model_input, vocab_size, num_frames, **unused_params): | ||
221 | + """See base class. | ||
222 | + | ||
223 | + Args: | ||
224 | + model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of | ||
225 | + input features. | ||
226 | + vocab_size: The number of classes in the dataset. | ||
227 | + num_frames: A vector of length 'batch' which indicates the number of | ||
228 | + frames for each video (before padding). | ||
229 | + | ||
230 | + Returns: | ||
231 | + A dictionary with a tensor containing the probability predictions of the | ||
232 | + model in the 'predictions' key. The dimensions of the tensor are | ||
233 | + 'batch_size' x 'num_classes'. | ||
234 | + """ | ||
235 | + lstm_size = FLAGS.lstm_cells | ||
236 | + number_of_layers = FLAGS.lstm_layers | ||
237 | + | ||
238 | + stacked_lstm = tf.contrib.rnn.MultiRNNCell([ | ||
239 | + tf.contrib.rnn.BasicLSTMCell(lstm_size, forget_bias=1.0) | ||
240 | + for _ in range(number_of_layers) | ||
241 | + ]) | ||
242 | + | ||
243 | + _, state = tf.nn.dynamic_rnn(stacked_lstm, | ||
244 | + model_input, | ||
245 | + sequence_length=num_frames, | ||
246 | + dtype=tf.float32) | ||
247 | + | ||
248 | + aggregated_model = getattr(video_level_models, | ||
249 | + FLAGS.video_level_classifier_model) | ||
250 | + | ||
251 | + return aggregated_model().create_model(model_input=state[-1].h, | ||
252 | + vocab_size=vocab_size, | ||
253 | + **unused_params) |
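FrameLevelLogisticModel above averages the zero-padded frame features by summing over the frame axis and dividing by each video's true frame count rather than by max_frames, so the padding never dilutes the mean. A minimal NumPy sketch of that pooling step, using toy shapes and data rather than anything from this repo:

```python
import numpy as np

batch_size, max_frames, num_features = 2, 4, 3
rng = np.random.default_rng(0)

# Toy batch: frames beyond a video's true length are zero padding.
model_input = rng.random((batch_size, max_frames, num_features))
num_frames = np.array([2, 4])  # true (pre-padding) frame counts per video
mask = (np.arange(max_frames)[None, :] < num_frames[:, None]).astype(model_input.dtype)
model_input *= mask[:, :, None]

# Sum over the frame axis, then divide by the true frame count, not max_frames.
avg_pooled = model_input.sum(axis=1) / num_frames[:, None]
print(avg_pooled.shape)  # (2, 3): one pooled feature vector per video
```

The TF graph in the model does the same thing with tf.reduce_sum and a tiled denominator before the final fully connected sigmoid layer.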
web/backend/yt8m/inference.py
0 → 100644
1 | +# Copyright 2017 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Binary for generating predictions over a set of videos.""" | ||
15 | + | ||
16 | +from __future__ import print_function | ||
17 | + | ||
18 | +import glob | ||
19 | +import heapq | ||
20 | +import json | ||
21 | +import os | ||
22 | +import tarfile | ||
23 | +import tempfile | ||
24 | +import time | ||
25 | +import numpy as np | ||
26 | + | ||
27 | +import readers | ||
28 | +from six.moves import urllib | ||
29 | +import tensorflow as tf | ||
30 | +from tensorflow import app | ||
31 | +from tensorflow import flags | ||
32 | +from tensorflow import gfile | ||
33 | +from tensorflow import logging | ||
34 | +from tensorflow.python.lib.io import file_io | ||
35 | +import utils | ||
36 | + | ||
37 | +FLAGS = flags.FLAGS | ||
38 | + | ||
39 | +if __name__ == "__main__": | ||
40 | + # Input | ||
41 | + flags.DEFINE_string( | ||
42 | + "train_dir", "", "The directory to load the model files from. We assume " | ||
43 | +      "that you have already run eval.py on this directory, such that " | ||
44 | + "inference_model.* files already exist.") | ||
45 | + flags.DEFINE_string( | ||
46 | + "input_data_pattern", "", | ||
47 | + "File glob defining the evaluation dataset in tensorflow.SequenceExample " | ||
48 | + "format. The SequenceExamples are expected to have an 'rgb' byte array " | ||
49 | + "sequence feature as well as a 'labels' int64 context feature.") | ||
50 | + flags.DEFINE_string( | ||
51 | + "input_model_tgz", "", | ||
52 | + "If given, must be path to a .tgz file that was written " | ||
53 | + "by this binary using flag --output_model_tgz. In this " | ||
54 | + "case, the .tgz file will be untarred to " | ||
55 | + "--untar_model_dir and the model will be used for " | ||
56 | + "inference.") | ||
57 | + flags.DEFINE_string( | ||
58 | + "untar_model_dir", "/tmp/yt8m-model", | ||
59 | + "If --input_model_tgz is given, then this directory will " | ||
60 | + "be created and the contents of the .tgz file will be " | ||
61 | + "untarred here.") | ||
62 | + flags.DEFINE_bool( | ||
63 | + "segment_labels", False, | ||
64 | + "If set, then --input_data_pattern must be frame-level features (but with" | ||
65 | + " segment_labels). Otherwise, --input_data_pattern must be aggregated " | ||
66 | + "video-level features. The model must also be set appropriately (i.e. to " | ||
67 | +      "read 3D batches vs. 4D batches).") | ||
68 | + flags.DEFINE_integer("segment_max_pred", 100000, | ||
69 | + "Limit total number of segment outputs per entity.") | ||
70 | + flags.DEFINE_string( | ||
71 | + "segment_label_ids_file", | ||
72 | + "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv", | ||
73 | + "The file that contains the segment label ids.") | ||
74 | + | ||
75 | + # Output | ||
76 | + flags.DEFINE_string("output_file", "", "The file to save the predictions to.") | ||
77 | + flags.DEFINE_string( | ||
78 | + "output_model_tgz", "", | ||
79 | +      "If given, should be a filename with a .tgz extension; " | ||
80 | +      "the model graph and checkpoint will be bundled into this " | ||
81 | +      "gzipped tar. This file can be uploaded to Kaggle for the " | ||
82 | + "top 10 participants.") | ||
83 | + flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.") | ||
84 | + | ||
85 | + # Other flags. | ||
86 | + flags.DEFINE_integer("batch_size", 512, | ||
87 | + "How many examples to process per batch.") | ||
88 | + flags.DEFINE_integer("num_readers", 1, | ||
89 | + "How many threads to use for reading input files.") | ||
90 | + | ||
91 | + | ||
92 | +def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None): | ||
93 | +  """Creates an output line for the submission file.""" | ||
94 | + batch_size = len(video_ids) | ||
95 | + for video_index in range(batch_size): | ||
96 | + video_prediction = predictions[video_index] | ||
97 | + if whitelisted_cls_mask is not None: | ||
98 | + # Whitelist classes. | ||
99 | + video_prediction *= whitelisted_cls_mask | ||
100 | + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:] | ||
101 | + line = [(class_index, predictions[video_index][class_index]) | ||
102 | + for class_index in top_indices] | ||
103 | + line = sorted(line, key=lambda p: -p[1]) | ||
104 | + yield (video_ids[video_index] + "," + | ||
105 | + " ".join("%i %g" % (label, score) for (label, score) in line) + | ||
106 | + "\n").encode("utf8") | ||
107 | + | ||
108 | + | ||
109 | +def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1): | ||
110 | + """Creates the section of the graph which reads the input data. | ||
111 | + | ||
112 | + Args: | ||
113 | + reader: A class which parses the input data. | ||
114 | + data_pattern: A 'glob' style path to the data files. | ||
115 | + batch_size: How many examples to process at a time. | ||
116 | + num_readers: How many I/O threads to use. | ||
117 | + | ||
118 | + Returns: | ||
119 | + A tuple containing the features tensor, labels tensor, and optionally a | ||
120 | + tensor containing the number of frames per video. The exact dimensions | ||
121 | + depend on the reader being used. | ||
122 | + | ||
123 | + Raises: | ||
124 | + IOError: If no files matching the given pattern were found. | ||
125 | + """ | ||
126 | + with tf.name_scope("input"): | ||
127 | + files = gfile.Glob(data_pattern) | ||
128 | + if not files: | ||
129 | + raise IOError("Unable to find input files. data_pattern='" + | ||
130 | + data_pattern + "'") | ||
131 | + logging.info("number of input files: " + str(len(files))) | ||
132 | + filename_queue = tf.train.string_input_producer(files, | ||
133 | + num_epochs=1, | ||
134 | + shuffle=False) | ||
135 | + examples_and_labels = [ | ||
136 | + reader.prepare_reader(filename_queue) for _ in range(num_readers) | ||
137 | + ] | ||
138 | + | ||
139 | + input_data_dict = (tf.train.batch_join(examples_and_labels, | ||
140 | + batch_size=batch_size, | ||
141 | + allow_smaller_final_batch=True, | ||
142 | + enqueue_many=True)) | ||
143 | + video_id_batch = input_data_dict["video_ids"] | ||
144 | + video_batch = input_data_dict["video_matrix"] | ||
145 | + num_frames_batch = input_data_dict["num_frames"] | ||
146 | + return video_id_batch, video_batch, num_frames_batch | ||
147 | + | ||
148 | + | ||
149 | +def get_segments(batch_video_mtx, batch_num_frames, segment_size): | ||
150 | + """Get segment-level inputs from frame-level features.""" | ||
151 | + video_batch_size = batch_video_mtx.shape[0] | ||
152 | + max_frame = batch_video_mtx.shape[1] | ||
153 | + feature_dim = batch_video_mtx.shape[-1] | ||
154 | + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size | ||
155 | + padded_segment_sizes *= segment_size | ||
156 | + segment_mask = ( | ||
157 | + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame))) | ||
158 | + | ||
159 | + # Segment bags. | ||
160 | + frame_bags = batch_video_mtx.reshape((-1, feature_dim)) | ||
161 | + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape( | ||
162 | + (-1, segment_size, feature_dim)) | ||
163 | + | ||
164 | + # Segment num frames. | ||
165 | + segment_start_times = np.arange(0, max_frame, segment_size) | ||
166 | + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times | ||
167 | + num_segment_bags = num_segments.reshape((-1)) | ||
168 | + valid_segment_mask = num_segment_bags > 0 | ||
169 | + segment_num_frames = num_segment_bags[valid_segment_mask] | ||
170 | + segment_num_frames[segment_num_frames > segment_size] = segment_size | ||
171 | + | ||
172 | + max_segment_num = (max_frame + segment_size - 1) // segment_size | ||
173 | + video_idxs = np.tile( | ||
174 | + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num]) | ||
175 | + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1]) | ||
176 | + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2)) | ||
177 | + video_segment_ids = idx_bags[valid_segment_mask] | ||
178 | + | ||
179 | + return { | ||
180 | + "video_batch": segment_frames, | ||
181 | + "num_frames_batch": segment_num_frames, | ||
182 | + "video_segment_ids": video_segment_ids | ||
183 | + } | ||
184 | + | ||
185 | + | ||
186 | +def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ||
187 | + top_k): | ||
188 | + """Inference function.""" | ||
189 | + with tf.Session(config=tf.ConfigProto( | ||
190 | + allow_soft_placement=True)) as sess, gfile.Open(out_file_location, | ||
191 | + "w+") as out_file: | ||
192 | + video_id_batch, video_batch, num_frames_batch = get_input_data_tensors( | ||
193 | + reader, data_pattern, batch_size) | ||
194 | + inference_model_name = "segment_inference_model" if FLAGS.segment_labels else "inference_model" | ||
195 | + checkpoint_file = os.path.join(train_dir, "inference_model", | ||
196 | + inference_model_name) | ||
197 | + if not gfile.Exists(checkpoint_file + ".meta"): | ||
198 | + raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file) | ||
199 | + meta_graph_location = checkpoint_file + ".meta" | ||
200 | + logging.info("loading meta-graph: " + meta_graph_location) | ||
201 | + | ||
202 | + if FLAGS.output_model_tgz: | ||
203 | + with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar: | ||
204 | + for model_file in glob.glob(checkpoint_file + ".*"): | ||
205 | + tar.add(model_file, arcname=os.path.basename(model_file)) | ||
206 | + tar.add(os.path.join(train_dir, "model_flags.json"), | ||
207 | + arcname="model_flags.json") | ||
208 | + print("Tarred model onto " + FLAGS.output_model_tgz) | ||
209 | + with tf.device("/cpu:0"): | ||
210 | + saver = tf.train.import_meta_graph(meta_graph_location, | ||
211 | + clear_devices=True) | ||
212 | + logging.info("restoring variables from " + checkpoint_file) | ||
213 | + saver.restore(sess, checkpoint_file) | ||
214 | + input_tensor = tf.get_collection("input_batch_raw")[0] | ||
215 | + num_frames_tensor = tf.get_collection("num_frames")[0] | ||
216 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
217 | + | ||
218 | + # Workaround for num_epochs issue. | ||
219 | + def set_up_init_ops(variables): | ||
220 | + init_op_list = [] | ||
221 | + for variable in list(variables): | ||
222 | + if "train_input" in variable.name: | ||
223 | + init_op_list.append(tf.assign(variable, 1)) | ||
224 | + variables.remove(variable) | ||
225 | + init_op_list.append(tf.variables_initializer(variables)) | ||
226 | + return init_op_list | ||
227 | + | ||
228 | + sess.run( | ||
229 | + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) | ||
230 | + | ||
231 | + coord = tf.train.Coordinator() | ||
232 | + threads = tf.train.start_queue_runners(sess=sess, coord=coord) | ||
233 | + num_examples_processed = 0 | ||
234 | + start_time = time.time() | ||
235 | + whitelisted_cls_mask = None | ||
236 | + if FLAGS.segment_labels: | ||
237 | + final_out_file = out_file | ||
238 | + out_file = tempfile.NamedTemporaryFile() | ||
239 | + logging.info( | ||
240 | +          "Segment predictions will first be written to temp file: %s", | ||
241 | + out_file.name) | ||
242 | + if FLAGS.segment_label_ids_file: | ||
243 | + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | ||
244 | + dtype=np.float32) | ||
245 | + segment_label_ids_file = FLAGS.segment_label_ids_file | ||
246 | + if segment_label_ids_file.startswith("http"): | ||
247 | + logging.info("Retrieving segment ID whitelist files from %s...", | ||
248 | + segment_label_ids_file) | ||
249 | + segment_label_ids_file, _ = urllib.request.urlretrieve( | ||
250 | + segment_label_ids_file) | ||
251 | + with tf.io.gfile.GFile(segment_label_ids_file) as fobj: | ||
252 | + for line in fobj: | ||
253 | + try: | ||
254 | + cls_id = int(line) | ||
255 | + whitelisted_cls_mask[cls_id] = 1. | ||
256 | + except ValueError: | ||
257 | + # Simply skip the non-integer line. | ||
258 | + continue | ||
259 | + | ||
260 | + out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) | ||
261 | + | ||
262 | + try: | ||
263 | + while not coord.should_stop(): | ||
264 | + video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( | ||
265 | + [video_id_batch, video_batch, num_frames_batch]) | ||
266 | + if FLAGS.segment_labels: | ||
267 | + results = get_segments(video_batch_val, num_frames_batch_val, 5) | ||
268 | + video_segment_ids = results["video_segment_ids"] | ||
269 | + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]] | ||
270 | + video_id_batch_val = np.array([ | ||
271 | + "%s:%d" % (x.decode("utf8"), y) | ||
272 | + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1]) | ||
273 | + ]) | ||
274 | + video_batch_val = results["video_batch"] | ||
275 | + num_frames_batch_val = results["num_frames_batch"] | ||
276 | + if input_tensor.get_shape()[1] != video_batch_val.shape[1]: | ||
277 | +          raise ValueError("max_frames mismatch. Please re-run eval.py " | ||
278 | + "with correct segment_labels settings.") | ||
279 | + | ||
280 | + predictions_val, = sess.run([predictions_tensor], | ||
281 | + feed_dict={ | ||
282 | + input_tensor: video_batch_val, | ||
283 | + num_frames_tensor: num_frames_batch_val | ||
284 | + }) | ||
285 | + now = time.time() | ||
286 | + num_examples_processed += len(video_batch_val) | ||
287 | + elapsed_time = now - start_time | ||
288 | + logging.info("num examples processed: " + str(num_examples_processed) + | ||
289 | + " elapsed seconds: " + "{0:.2f}".format(elapsed_time) + | ||
290 | + " examples/sec: %.2f" % | ||
291 | + (num_examples_processed / elapsed_time)) | ||
292 | + for line in format_lines(video_id_batch_val, predictions_val, top_k, | ||
293 | + whitelisted_cls_mask): | ||
294 | + out_file.write(line) | ||
295 | + out_file.flush() | ||
296 | + | ||
297 | + except tf.errors.OutOfRangeError: | ||
298 | + logging.info("Done with inference. The output file was written to " + | ||
299 | + out_file.name) | ||
300 | + finally: | ||
301 | + coord.request_stop() | ||
302 | + | ||
303 | + if FLAGS.segment_labels: | ||
304 | + # Re-read the file and do heap sort. | ||
305 | + # Create multiple heaps. | ||
306 | + logging.info("Post-processing segment predictions...") | ||
307 | + heaps = {} | ||
308 | + out_file.seek(0, 0) | ||
309 | + for line in out_file: | ||
310 | + segment_id, preds = line.decode("utf8").split(",") | ||
311 | + if segment_id == "VideoId": | ||
312 | + # Skip the headline. | ||
313 | + continue | ||
314 | + preds = preds.split(" ") | ||
315 | + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] | ||
316 | + pred_cls_scores = [ | ||
317 | + float(preds[idx]) for idx in range(1, len(preds), 2) | ||
318 | + ] | ||
319 | + for cls, score in zip(pred_cls_ids, pred_cls_scores): | ||
320 | + if not whitelisted_cls_mask[cls]: | ||
321 | + # Skip non-whitelisted classes. | ||
322 | + continue | ||
323 | + if cls not in heaps: | ||
324 | + heaps[cls] = [] | ||
325 | + if len(heaps[cls]) >= FLAGS.segment_max_pred: | ||
326 | + heapq.heappushpop(heaps[cls], (score, segment_id)) | ||
327 | + else: | ||
328 | + heapq.heappush(heaps[cls], (score, segment_id)) | ||
329 | + logging.info("Writing sorted segment predictions to: %s", | ||
330 | + final_out_file.name) | ||
331 | + final_out_file.write("Class,Segments\n") | ||
332 | + for cls, cls_heap in heaps.items(): | ||
333 | + cls_heap.sort(key=lambda x: x[0], reverse=True) | ||
334 | + final_out_file.write("%d,%s\n" % | ||
335 | + (cls, " ".join([x[1] for x in cls_heap]))) | ||
336 | + final_out_file.close() | ||
337 | + | ||
338 | + out_file.close() | ||
339 | + | ||
340 | + coord.join(threads) | ||
341 | + sess.close() | ||
342 | + | ||
343 | + | ||
344 | +def main(unused_argv): | ||
345 | + logging.set_verbosity(tf.logging.INFO) | ||
346 | + if FLAGS.input_model_tgz: | ||
347 | + if FLAGS.train_dir: | ||
348 | + raise ValueError("You cannot supply --train_dir if supplying " | ||
349 | + "--input_model_tgz") | ||
350 | + # Untar. | ||
351 | + if not os.path.exists(FLAGS.untar_model_dir): | ||
352 | + os.makedirs(FLAGS.untar_model_dir) | ||
353 | + tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir) | ||
354 | + FLAGS.train_dir = FLAGS.untar_model_dir | ||
355 | + | ||
356 | + flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json") | ||
357 | + if not file_io.file_exists(flags_dict_file): | ||
358 | + raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file) | ||
359 | + flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read()) | ||
360 | + | ||
361 | + # convert feature_names and feature_sizes to lists of values | ||
362 | + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( | ||
363 | + flags_dict["feature_names"], flags_dict["feature_sizes"]) | ||
364 | + | ||
365 | + if flags_dict["frame_features"]: | ||
366 | + reader = readers.YT8MFrameFeatureReader(feature_names=feature_names, | ||
367 | + feature_sizes=feature_sizes) | ||
368 | + else: | ||
369 | + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, | ||
370 | + feature_sizes=feature_sizes) | ||
371 | + | ||
372 | + if not FLAGS.output_file: | ||
373 | + raise ValueError("'output_file' was not specified. " | ||
374 | + "Unable to continue with inference.") | ||
375 | + | ||
376 | + if not FLAGS.input_data_pattern: | ||
377 | + raise ValueError("'input_data_pattern' was not specified. " | ||
378 | + "Unable to continue with inference.") | ||
379 | + | ||
380 | + inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern, | ||
381 | + FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k) | ||
382 | + | ||
383 | + | ||
384 | +if __name__ == "__main__": | ||
385 | + app.run() |
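For reference, a standalone sketch of how format_lines assembles each submission row: np.argpartition picks out the top_k scores without a full sort, the (label, score) pairs are then ordered by descending score, and the row is joined as "VideoId,label score label score ...". The video id and scores below are invented for illustration:

```python
import numpy as np

predictions = np.array([0.10, 0.95, 0.40, 0.05, 0.80])  # scores for one video over 5 classes
top_k = 3

# argpartition finds the top_k largest scores in O(num_classes).
top_indices = np.argpartition(predictions, -top_k)[-top_k:]
pairs = sorted(((i, predictions[i]) for i in top_indices), key=lambda p: -p[1])
line = "video123," + " ".join("%i %g" % (label, score) for label, score in pairs)
print(line)  # video123,1 0.95 4 0.8 2 0.4
```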
web/backend/yt8m/inference_per_segment.py
0 → 100644
1 | +# Copyright 2017 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Binary for generating per-segment predictions over a set of videos.""" | ||
15 | + | ||
16 | +from __future__ import print_function | ||
17 | + | ||
18 | +import glob | ||
19 | +import heapq | ||
20 | +import json | ||
21 | +import os | ||
22 | +import tarfile | ||
23 | +import tempfile | ||
24 | +import time | ||
25 | +import numpy as np | ||
26 | + | ||
27 | +import readers | ||
28 | +from six.moves import urllib | ||
29 | +import tensorflow as tf | ||
30 | +from tensorflow import app | ||
31 | +from tensorflow import flags | ||
32 | +from tensorflow import gfile | ||
33 | +from tensorflow import logging | ||
34 | +from tensorflow.python.lib.io import file_io | ||
35 | +import utils | ||
36 | +from collections import Counter | ||
37 | +import operator | ||
38 | + | ||
39 | +FLAGS = flags.FLAGS | ||
40 | + | ||
41 | +if __name__ == "__main__": | ||
42 | + # Input | ||
43 | + flags.DEFINE_string( | ||
44 | + "train_dir", "", "The directory to load the model files from. We assume " | ||
45 | +      "that you have already run eval.py on this directory, such that " | ||
46 | + "inference_model.* files already exist.") | ||
47 | + flags.DEFINE_string( | ||
48 | + "input_data_pattern", "", | ||
49 | + "File glob defining the evaluation dataset in tensorflow.SequenceExample " | ||
50 | + "format. The SequenceExamples are expected to have an 'rgb' byte array " | ||
51 | + "sequence feature as well as a 'labels' int64 context feature.") | ||
52 | + flags.DEFINE_string( | ||
53 | + "input_model_tgz", "", | ||
54 | + "If given, must be path to a .tgz file that was written " | ||
55 | + "by this binary using flag --output_model_tgz. In this " | ||
56 | + "case, the .tgz file will be untarred to " | ||
57 | + "--untar_model_dir and the model will be used for " | ||
58 | + "inference.") | ||
59 | + flags.DEFINE_string( | ||
60 | + "untar_model_dir", "/tmp/yt8m-model", | ||
61 | + "If --input_model_tgz is given, then this directory will " | ||
62 | + "be created and the contents of the .tgz file will be " | ||
63 | + "untarred here.") | ||
64 | + flags.DEFINE_bool( | ||
65 | + "segment_labels", False, | ||
66 | + "If set, then --input_data_pattern must be frame-level features (but with" | ||
67 | + " segment_labels). Otherwise, --input_data_pattern must be aggregated " | ||
68 | + "video-level features. The model must also be set appropriately (i.e. to " | ||
69 | +      "read 3D batches vs. 4D batches).") | ||
70 | + flags.DEFINE_integer("segment_max_pred", 100000, | ||
71 | + "Limit total number of segment outputs per entity.") | ||
72 | + flags.DEFINE_string( | ||
73 | + "segment_label_ids_file", | ||
74 | + "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv", | ||
75 | + "The file that contains the segment label ids.") | ||
76 | + | ||
77 | + # Output | ||
78 | + flags.DEFINE_string("output_file", "", "The file to save the predictions to.") | ||
79 | + flags.DEFINE_string( | ||
80 | + "output_model_tgz", "", | ||
81 | +      "If given, should be a filename with a .tgz extension; " | ||
82 | +      "the model graph and checkpoint will be bundled into this " | ||
83 | +      "gzipped tar. This file can be uploaded to Kaggle for the " | ||
84 | + "top 10 participants.") | ||
85 | + flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.") | ||
86 | + | ||
87 | + # Other flags. | ||
88 | + flags.DEFINE_integer("batch_size", 512, | ||
89 | + "How many examples to process per batch.") | ||
90 | + flags.DEFINE_integer("num_readers", 1, | ||
91 | + "How many threads to use for reading input files.") | ||
92 | + | ||
93 | + | ||
94 | +def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None): | ||
95 | +  """Creates an output line for the submission file.""" | ||
96 | + batch_size = len(video_ids) | ||
97 | + for video_index in range(batch_size): | ||
98 | + video_prediction = predictions[video_index] | ||
99 | + if whitelisted_cls_mask is not None: | ||
100 | + # Whitelist classes. | ||
101 | + video_prediction *= whitelisted_cls_mask | ||
102 | + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:] | ||
103 | + line = [(class_index, predictions[video_index][class_index]) | ||
104 | + for class_index in top_indices] | ||
105 | + line = sorted(line, key=lambda p: -p[1]) | ||
106 | + yield (video_ids[video_index] + "," + | ||
107 | + " ".join("%i %g" % (label, score) for (label, score) in line) + | ||
108 | + "\n").encode("utf8") | ||
109 | + | ||
110 | + | ||
111 | +def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1): | ||
112 | + """Creates the section of the graph which reads the input data. | ||
113 | + | ||
114 | + Args: | ||
115 | + reader: A class which parses the input data. | ||
116 | + data_pattern: A 'glob' style path to the data files. | ||
117 | + batch_size: How many examples to process at a time. | ||
118 | + num_readers: How many I/O threads to use. | ||
119 | + | ||
120 | + Returns: | ||
121 | + A tuple containing the features tensor, labels tensor, and optionally a | ||
122 | + tensor containing the number of frames per video. The exact dimensions | ||
123 | + depend on the reader being used. | ||
124 | + | ||
125 | + Raises: | ||
126 | + IOError: If no files matching the given pattern were found. | ||
127 | + """ | ||
128 | + with tf.name_scope("input"): | ||
129 | + files = gfile.Glob(data_pattern) | ||
130 | + if not files: | ||
131 | + raise IOError("Unable to find input files. data_pattern='" + | ||
132 | + data_pattern + "'") | ||
133 | + logging.info("number of input files: " + str(len(files))) | ||
134 | + filename_queue = tf.train.string_input_producer(files, | ||
135 | + num_epochs=1, | ||
136 | + shuffle=False) | ||
137 | + examples_and_labels = [ | ||
138 | + reader.prepare_reader(filename_queue) for _ in range(num_readers) | ||
139 | + ] | ||
140 | + | ||
141 | + input_data_dict = (tf.train.batch_join(examples_and_labels, | ||
142 | + batch_size=batch_size, | ||
143 | + allow_smaller_final_batch=True, | ||
144 | + enqueue_many=True)) | ||
145 | + video_id_batch = input_data_dict["video_ids"] | ||
146 | + video_batch = input_data_dict["video_matrix"] | ||
147 | + num_frames_batch = input_data_dict["num_frames"] | ||
148 | + return video_id_batch, video_batch, num_frames_batch | ||
149 | + | ||
150 | + | ||
151 | +def get_segments(batch_video_mtx, batch_num_frames, segment_size): | ||
152 | + """Get segment-level inputs from frame-level features.""" | ||
153 | + video_batch_size = batch_video_mtx.shape[0] | ||
154 | + max_frame = batch_video_mtx.shape[1] | ||
155 | + feature_dim = batch_video_mtx.shape[-1] | ||
156 | + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size | ||
157 | + padded_segment_sizes *= segment_size | ||
158 | + segment_mask = ( | ||
159 | + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame))) | ||
160 | + | ||
161 | + # Segment bags. | ||
162 | + frame_bags = batch_video_mtx.reshape((-1, feature_dim)) | ||
163 | + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape( | ||
164 | + (-1, segment_size, feature_dim)) | ||
165 | + | ||
166 | + # Segment num frames. | ||
167 | + segment_start_times = np.arange(0, max_frame, segment_size) | ||
168 | + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times | ||
169 | + num_segment_bags = num_segments.reshape((-1)) | ||
170 | + valid_segment_mask = num_segment_bags > 0 | ||
171 | + segment_num_frames = num_segment_bags[valid_segment_mask] | ||
172 | + segment_num_frames[segment_num_frames > segment_size] = segment_size | ||
173 | + | ||
174 | + max_segment_num = (max_frame + segment_size - 1) // segment_size | ||
175 | + video_idxs = np.tile( | ||
176 | + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num]) | ||
177 | + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1]) | ||
178 | + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2)) | ||
179 | + video_segment_ids = idx_bags[valid_segment_mask] | ||
180 | + | ||
181 | + return { | ||
182 | + "video_batch": segment_frames, | ||
183 | + "num_frames_batch": segment_num_frames, | ||
184 | + "video_segment_ids": video_segment_ids | ||
185 | + } | ||
186 | + | ||
187 | + | ||
188 | +def normalize_tag(tag): | ||
189 | + if isinstance(tag, str): | ||
190 | +    new_tag = tag.lower() | ||
191 | +    if new_tag.find(" (") != -1:  # drop a trailing parenthetical qualifier | ||
192 | +      new_tag = new_tag[:new_tag.find(" (")] | ||
193 | +    new_tag = "".join(c if c.isalpha() else " " for c in new_tag).replace(" ", "-")  # letters only, hyphen-joined | ||
194 | +    return new_tag | ||
195 | + else: | ||
196 | + return tag | ||
197 | + | ||
198 | + | ||
199 | +def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ||
200 | + top_k): | ||
201 | + """Inference function.""" | ||
202 | + with tf.Session(config=tf.ConfigProto( | ||
203 | + allow_soft_placement=True)) as sess, gfile.Open(out_file_location, | ||
204 | + "w+") as out_file: | ||
205 | + video_id_batch, video_batch, num_frames_batch = get_input_data_tensors( | ||
206 | + reader, data_pattern, batch_size) | ||
207 | + inference_model_name = "segment_inference_model" if FLAGS.segment_labels else "inference_model" | ||
208 | + checkpoint_file = os.path.join(train_dir, "inference_model", | ||
209 | + inference_model_name) | ||
210 | + if not gfile.Exists(checkpoint_file + ".meta"): | ||
211 | + raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file) | ||
212 | + meta_graph_location = checkpoint_file + ".meta" | ||
213 | + logging.info("loading meta-graph: " + meta_graph_location) | ||
214 | + | ||
215 | + if FLAGS.output_model_tgz: | ||
216 | + with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar: | ||
217 | + for model_file in glob.glob(checkpoint_file + ".*"): | ||
218 | + tar.add(model_file, arcname=os.path.basename(model_file)) | ||
219 | + tar.add(os.path.join(train_dir, "model_flags.json"), | ||
220 | + arcname="model_flags.json") | ||
221 | + print("Tarred model onto " + FLAGS.output_model_tgz) | ||
222 | + with tf.device("/cpu:0"): | ||
223 | + saver = tf.train.import_meta_graph(meta_graph_location, | ||
224 | + clear_devices=True) | ||
225 | + logging.info("restoring variables from " + checkpoint_file) | ||
226 | + saver.restore(sess, checkpoint_file) | ||
227 | + input_tensor = tf.get_collection("input_batch_raw")[0] | ||
228 | + num_frames_tensor = tf.get_collection("num_frames")[0] | ||
229 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
230 | + | ||
231 | + # Workaround for num_epochs issue. | ||
232 | + def set_up_init_ops(variables): | ||
233 | + init_op_list = [] | ||
234 | + for variable in list(variables): | ||
235 | + if "train_input" in variable.name: | ||
236 | + init_op_list.append(tf.assign(variable, 1)) | ||
237 | + variables.remove(variable) | ||
238 | + init_op_list.append(tf.variables_initializer(variables)) | ||
239 | + return init_op_list | ||
240 | + | ||
241 | + sess.run( | ||
242 | + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) | ||
243 | + | ||
244 | + coord = tf.train.Coordinator() | ||
245 | + threads = tf.train.start_queue_runners(sess=sess, coord=coord) | ||
246 | + num_examples_processed = 0 | ||
247 | + start_time = time.time() | ||
248 | + whitelisted_cls_mask = None | ||
249 | + if FLAGS.segment_labels: | ||
250 | + final_out_file = out_file | ||
251 | + out_file = tempfile.NamedTemporaryFile() | ||
252 | + logging.info( | ||
253 | +          "Segment predictions will first be written to temp file: %s", | ||
254 | + out_file.name) | ||
255 | + if FLAGS.segment_label_ids_file: | ||
256 | + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | ||
257 | + dtype=np.float32) | ||
258 | + segment_label_ids_file = FLAGS.segment_label_ids_file | ||
259 | + if segment_label_ids_file.startswith("http"): | ||
260 | + logging.info("Retrieving segment ID whitelist files from %s...", | ||
261 | + segment_label_ids_file) | ||
262 | + segment_label_ids_file, _ = urllib.request.urlretrieve( | ||
263 | + segment_label_ids_file) | ||
264 | + with tf.io.gfile.GFile(segment_label_ids_file) as fobj: | ||
265 | + for line in fobj: | ||
266 | + try: | ||
267 | + cls_id = int(line) | ||
268 | + whitelisted_cls_mask[cls_id] = 1. | ||
269 | + except ValueError: | ||
270 | + # Simply skip the non-integer line. | ||
271 | + continue | ||
272 | + | ||
273 | + out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) | ||
274 | + | ||
275 | +    # ========================================================= | ||
276 | +    # Load vocabulary.csv into a dict mapping label index to the | ||
277 | +    # tag name (column 3 of each row), used when formatting rows. | ||
278 | +    # ========================================================= | ||
279 | +    voca_dict = {} | ||
280 | +    with open("./vocabulary.csv", "r") as vocabs: | ||
281 | +      for line in vocabs: | ||
282 | +        vocab_dict_item = line.split(",") | ||
283 | +        if vocab_dict_item[0] == "Index": | ||
284 | +          continue  # skip the header row | ||
285 | +        voca_dict[vocab_dict_item[0]] = vocab_dict_item[3] | ||
286 | + | ||
287 | + try: | ||
288 | + while not coord.should_stop(): | ||
289 | + video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( | ||
290 | + [video_id_batch, video_batch, num_frames_batch]) | ||
291 | + if FLAGS.segment_labels: | ||
292 | + results = get_segments(video_batch_val, num_frames_batch_val, 5) | ||
293 | + video_segment_ids = results["video_segment_ids"] | ||
294 | + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]] | ||
295 | + video_id_batch_val = np.array([ | ||
296 | + "%s:%d" % (x.decode("utf8"), y) | ||
297 | + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1]) | ||
298 | + ]) | ||
299 | + video_batch_val = results["video_batch"] | ||
300 | + num_frames_batch_val = results["num_frames_batch"] | ||
301 | + if input_tensor.get_shape()[1] != video_batch_val.shape[1]: | ||
302 | +          raise ValueError("max_frames mismatch. Please re-run eval.py " | ||
303 | + "with correct segment_labels settings.") | ||
304 | + | ||
305 | + predictions_val, = sess.run([predictions_tensor], | ||
306 | + feed_dict={ | ||
307 | + input_tensor: video_batch_val, | ||
308 | + num_frames_tensor: num_frames_batch_val | ||
309 | + }) | ||
310 | + now = time.time() | ||
311 | + num_examples_processed += len(video_batch_val) | ||
312 | + elapsed_time = now - start_time | ||
313 | + logging.info("num examples processed: " + str(num_examples_processed) + | ||
314 | + " elapsed seconds: " + "{0:.2f}".format(elapsed_time) + | ||
315 | + " examples/sec: %.2f" % | ||
316 | + (num_examples_processed / elapsed_time)) | ||
317 | + for line in format_lines(video_id_batch_val, predictions_val, top_k, | ||
318 | + whitelisted_cls_mask): | ||
319 | + out_file.write(line) | ||
320 | + out_file.flush() | ||
321 | + | ||
322 | + except tf.errors.OutOfRangeError: | ||
323 | + logging.info("Done with inference. The output file was written to " + | ||
324 | + out_file.name) | ||
325 | + finally: | ||
326 | + coord.request_stop() | ||
327 | + | ||
328 | + if FLAGS.segment_labels: | ||
329 | + # Re-read the file and do heap sort. | ||
330 | + # Create multiple heaps. | ||
331 | + logging.info("Post-processing segment predictions...") | ||
332 | + segment_id_list = [] | ||
333 | + segment_classes = [] | ||
334 | + cls_result_arr = [] | ||
335 | + cls_score_dict = {} | ||
336 | + out_file.seek(0, 0) | ||
337 | + old_seg_name = '0000' | ||
338 | + counter = 0 | ||
339 | + for line in out_file: | ||
340 | + counter += 1 | ||
341 | +      if counter % 5000 == 0: | ||
342 | + print(counter, " processed") | ||
343 | + segment_id, preds = line.decode("utf8").split(",") | ||
344 | + if segment_id == "VideoId": | ||
345 | + # Skip the headline. | ||
346 | + continue | ||
347 | + | ||
348 | + preds = preds.split(" ") | ||
349 | + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] | ||
350 | + pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)] | ||
351 | + #======================================= | ||
352 | + segment_id = str(segment_id.split(":")[0]) | ||
353 | + if segment_id not in segment_id_list: | ||
354 | + segment_id_list.append(str(segment_id)) | ||
355 | + segment_classes.append("") | ||
356 | + | ||
357 | + index = segment_id_list.index(segment_id) | ||
358 | + | ||
359 | + if old_seg_name != segment_id: | ||
360 | + cls_score_dict[segment_id] = {} | ||
361 | + old_seg_name = segment_id | ||
362 | + | ||
363 | +        for cls_idx in range(len(pred_cls_ids)): | ||
364 | +          # Append this segment's predicted classes to the video's class list. | ||
365 | +          segment_classes[index] += str(pred_cls_ids[cls_idx]) + " " | ||
366 | +          # Accumulate each class's score across the video's segments. | ||
367 | +          cls_score_dict[segment_id][pred_cls_ids[cls_idx]] = ( | ||
368 | +              cls_score_dict[segment_id].get(pred_cls_ids[cls_idx], 0) + pred_cls_scores[cls_idx]) | ||
369 | + | ||
370 | +    for segs, item in zip(segment_id_list, segment_classes): | ||
371 | +      print('====== R E C O R D ======') | ||
372 | +      cls_arr = item.split(" ")[:-1] | ||
373 | + | ||
374 | +      cls_arr = list(map(int, cls_arr)) | ||
375 | +      cls_arr = sorted(cls_arr)  # sort by class id | ||
376 | + | ||
377 | +      result_string = "" | ||
378 | + | ||
379 | +      temp = cls_score_dict[segs] | ||
380 | +      temp = sorted(temp.items(), key=operator.itemgetter(1), reverse=True)  # sort by accumulated score, descending | ||
381 | +      denominator = float(sum(score for _, score in temp[:5])) | ||
382 | +      # Emit the top_k highest-scoring classes, normalized by the top-5 score mass. | ||
383 | +      for itemIndex in range(0, top_k): | ||
384 | +        # Normalize tag name | ||
385 | +        segment_tag = str(voca_dict[str(temp[itemIndex][0])]) | ||
386 | +        normalized_tag = normalize_tag(segment_tag) | ||
387 | +        result_string = result_string + normalized_tag + ":" + format(temp[itemIndex][1] / denominator, ".3f") + "," | ||
388 | + | ||
389 | +      cls_result_arr.append(result_string[:-1]) | ||
390 | +      logging.info(segs + " : " + result_string[:-1]) | ||
391 | + #======================================= | ||
392 | + final_out_file.write("vid_id,segment1,segment2,segment3,segment4,segment5\n") | ||
393 | +    for seg_id, class_indices in zip(segment_id_list, cls_result_arr): | ||
394 | +      final_out_file.write("%s,%s\n" % (seg_id, str(class_indices))) | ||
395 | + final_out_file.close() | ||
396 | + | ||
397 | + out_file.close() | ||
398 | + | ||
399 | + coord.join(threads) | ||
400 | + sess.close() | ||
401 | + | ||
402 | +def main(unused_argv): | ||
403 | + logging.set_verbosity(tf.logging.INFO) | ||
404 | + if FLAGS.input_model_tgz: | ||
405 | + if FLAGS.train_dir: | ||
406 | + raise ValueError("You cannot supply --train_dir if supplying " | ||
407 | + "--input_model_tgz") | ||
408 | + # Untar. | ||
409 | + if not os.path.exists(FLAGS.untar_model_dir): | ||
410 | + os.makedirs(FLAGS.untar_model_dir) | ||
411 | + tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir) | ||
412 | + FLAGS.train_dir = FLAGS.untar_model_dir | ||
413 | + | ||
414 | + flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json") | ||
415 | + if not file_io.file_exists(flags_dict_file): | ||
416 | + raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file) | ||
417 | + flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read()) | ||
418 | + | ||
419 | + # convert feature_names and feature_sizes to lists of values | ||
420 | + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( | ||
421 | + flags_dict["feature_names"], flags_dict["feature_sizes"]) | ||
422 | + | ||
423 | + if flags_dict["frame_features"]: | ||
424 | + reader = readers.YT8MFrameFeatureReader(feature_names=feature_names, | ||
425 | + feature_sizes=feature_sizes) | ||
426 | + else: | ||
427 | + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, | ||
428 | + feature_sizes=feature_sizes) | ||
429 | + | ||
430 | + if not FLAGS.output_file: | ||
431 | + raise ValueError("'output_file' was not specified. " | ||
432 | + "Unable to continue with inference.") | ||
433 | + | ||
434 | + if not FLAGS.input_data_pattern: | ||
435 | + raise ValueError("'input_data_pattern' was not specified. " | ||
436 | + "Unable to continue with inference.") | ||
437 | + | ||
438 | + inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern, | ||
439 | + FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k) | ||
440 | + | ||
441 | + | ||
442 | +if __name__ == "__main__": | ||
443 | + app.run() |
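The post-processing loop above accumulates class scores over every segment of a video, ranks them, and rescales the top entries by the combined score of the top five before writing one row per video (in the script each class id is additionally mapped to a tag name through vocabulary.csv and normalize_tag). A small self-contained sketch of that normalization, with invented class ids and scores:

```python
import operator

# Scores for one video, accumulated over all of its segments (toy values).
cls_scores = {3: 2.4, 17: 1.1, 42: 0.9, 7: 0.4, 99: 0.2, 5: 0.1}
top_k = 5

ranked = sorted(cls_scores.items(), key=operator.itemgetter(1), reverse=True)
denominator = float(sum(score for _, score in ranked[:5]))  # top-5 score mass
row = ",".join("%d:%.3f" % (cls, score / denominator) for cls, score in ranked[:top_k])
print(row)  # 3:0.480,17:0.220,42:0.180,7:0.080,99:0.040
```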
web/backend/yt8m/losses.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Provides definitions for non-regularized training or test losses.""" | ||
15 | + | ||
16 | +import tensorflow as tf | ||
17 | + | ||
18 | + | ||
19 | +class BaseLoss(object): | ||
20 | + """Inherit from this class when implementing new losses.""" | ||
21 | + | ||
22 | + def calculate_loss(self, unused_predictions, unused_labels, **unused_params): | ||
23 | + """Calculates the average loss of the examples in a mini-batch. | ||
24 | + | ||
25 | + Args: | ||
26 | + unused_predictions: a 2-d tensor storing the prediction scores, in which | ||
27 | + each row represents a sample in the mini-batch and each column | ||
28 | + represents a class. | ||
29 | + unused_labels: a 2-d tensor storing the labels, which has the same shape | ||
30 | + as the unused_predictions. The labels must be in the range of 0 and 1. | ||
31 | + unused_params: loss specific parameters. | ||
32 | + | ||
33 | + Returns: | ||
34 | + A scalar loss tensor. | ||
35 | + """ | ||
36 | + raise NotImplementedError() | ||
37 | + | ||
38 | + | ||
39 | +class CrossEntropyLoss(BaseLoss): | ||
40 | + """Calculate the cross entropy loss between the predictions and labels.""" | ||
41 | + | ||
42 | + def calculate_loss(self, | ||
43 | + predictions, | ||
44 | + labels, | ||
45 | + label_weights=None, | ||
46 | + **unused_params): | ||
47 | + with tf.name_scope("loss_xent"): | ||
48 | + epsilon = 1e-5 | ||
49 | + float_labels = tf.cast(labels, tf.float32) | ||
50 | + cross_entropy_loss = float_labels * tf.math.log(predictions + epsilon) + ( | ||
51 | + 1 - float_labels) * tf.math.log(1 - predictions + epsilon) | ||
52 | + cross_entropy_loss = tf.negative(cross_entropy_loss) | ||
53 | + if label_weights is not None: | ||
54 | + cross_entropy_loss *= label_weights | ||
55 | + return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1)) | ||
56 | + | ||
57 | + | ||
58 | +class HingeLoss(BaseLoss): | ||
59 | + """Calculate the hinge loss between the predictions and labels. | ||
60 | + | ||
61 | + Note the subgradient is used in the backpropagation, and thus the optimization | ||
62 | +  may converge more slowly. The predictions trained by the hinge loss are between -1 | ||
63 | + and +1. | ||
64 | + """ | ||
65 | + | ||
66 | + def calculate_loss(self, predictions, labels, b=1.0, **unused_params): | ||
67 | + with tf.name_scope("loss_hinge"): | ||
68 | + float_labels = tf.cast(labels, tf.float32) | ||
69 | + all_zeros = tf.zeros(tf.shape(float_labels), dtype=tf.float32) | ||
70 | + all_ones = tf.ones(tf.shape(float_labels), dtype=tf.float32) | ||
71 | + sign_labels = tf.subtract(tf.scalar_mul(2, float_labels), all_ones) | ||
72 | + hinge_loss = tf.maximum( | ||
73 | + all_zeros, | ||
74 | + tf.scalar_mul(b, all_ones) - sign_labels * predictions) | ||
75 | + return tf.reduce_mean(tf.reduce_sum(hinge_loss, 1)) | ||
76 | + | ||
77 | + | ||
78 | +class SoftmaxLoss(BaseLoss): | ||
79 | + """Calculate the softmax loss between the predictions and labels. | ||
80 | + | ||
81 | + The function calculates the loss in the following way: first we feed the | ||
82 | +  predictions to the softmax activation function and then we compute the | ||
83 | +  negative dot product between the log of the softmax activations and the | ||
84 | +  normalized ground-truth labels. | ||
85 | + | ||
86 | +  It is an extension of the one-hot label case: it allows more than one | ||
87 | +  positive label per sample. | ||
88 | + """ | ||
89 | + | ||
90 | + def calculate_loss(self, predictions, labels, **unused_params): | ||
91 | + with tf.name_scope("loss_softmax"): | ||
92 | + epsilon = 10e-8 | ||
93 | + float_labels = tf.cast(labels, tf.float32) | ||
94 | + # l1 normalization (labels are no less than 0) | ||
95 | + label_rowsum = tf.maximum(tf.reduce_sum(float_labels, 1, keep_dims=True), | ||
96 | + epsilon) | ||
97 | + norm_float_labels = tf.div(float_labels, label_rowsum) | ||
98 | + softmax_outputs = tf.nn.softmax(predictions) | ||
99 | + softmax_loss = tf.negative( | ||
100 | + tf.reduce_sum(tf.multiply(norm_float_labels, tf.log(softmax_outputs)), | ||
101 | + 1)) | ||
102 | + return tf.reduce_mean(softmax_loss) |
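As a quick sanity check of CrossEntropyLoss.calculate_loss, the same arithmetic in NumPy: per-class binary cross entropy with the epsilon guard, summed over classes and averaged over the batch. The predictions and labels are toy values chosen only for illustration:

```python
import numpy as np

epsilon = 1e-5
predictions = np.array([[0.9, 0.2],   # 2 examples x 2 classes
                        [0.3, 0.7]])
labels = np.array([[1.0, 0.0],
                   [0.0, 1.0]])

# Per-class binary cross entropy, summed over classes, averaged over the batch.
xent = -(labels * np.log(predictions + epsilon) +
         (1 - labels) * np.log(1 - predictions + epsilon))
loss = np.mean(np.sum(xent, axis=1))
print(round(loss, 4))  # ~0.5209
```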
web/backend/yt8m/mean_average_precision_calculator.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Calculate the mean average precision. | ||
15 | + | ||
16 | +It provides an interface for calculating mean average precision | ||
17 | +for an entire list or the top-n ranked items. | ||
18 | + | ||
19 | +Example usages: | ||
20 | +We first call the function accumulate many times to process parts of the ranked | ||
21 | +list. After processing all the parts, we call peek_map_at_n | ||
22 | +to calculate the mean average precision. | ||
23 | + | ||
24 | +``` | ||
25 | +import random | ||
26 | +import numpy as np | ||
27 | +p = np.array([[random.random() for _ in range(50)] for _ in range(1000)]) | ||
28 | +a = np.array([[random.choice([0, 1]) for _ in range(50)] | ||
29 | +              for _ in range(1000)]) | ||
30 | + | ||
31 | +# mean average precision for 50 classes. | ||
32 | +calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator( | ||
33 | + num_class=50) | ||
34 | +calculator.accumulate(p, a) | ||
35 | +aps = calculator.peek_map_at_n() | ||
36 | +``` | ||
37 | +""" | ||
38 | + | ||
39 | +import average_precision_calculator | ||
40 | + | ||
41 | + | ||
42 | +class MeanAveragePrecisionCalculator(object): | ||
43 | + """This class is to calculate mean average precision.""" | ||
44 | + | ||
45 | + def __init__(self, num_class, filter_empty_classes=True, top_n=None): | ||
46 | + """Construct a calculator to calculate the (macro) average precision. | ||
47 | + | ||
48 | + Args: | ||
49 | + num_class: A positive Integer specifying the number of classes. | ||
50 | + filter_empty_classes: whether to filter classes without any positives. | ||
51 | + top_n: A positive Integer specifying the average precision at n, or None | ||
52 | + to use all provided data points. | ||
53 | + | ||
54 | + Raises: | ||
55 | + ValueError: An error occurred when num_class is not a positive integer; | ||
56 | + or the top_n_array is not a list of positive integers. | ||
57 | + """ | ||
58 | + if not isinstance(num_class, int) or num_class <= 1: | ||
59 | +      raise ValueError("num_class must be an integer greater than 1.") | ||
60 | + | ||
61 | + self._ap_calculators = [] # member of AveragePrecisionCalculator | ||
62 | + self._num_class = num_class # total number of classes | ||
63 | + self._filter_empty_classes = filter_empty_classes | ||
64 | + for _ in range(num_class): | ||
65 | + self._ap_calculators.append( | ||
66 | + average_precision_calculator.AveragePrecisionCalculator(top_n=top_n)) | ||
67 | + | ||
68 | + def accumulate(self, predictions, actuals, num_positives=None): | ||
69 | + """Accumulate the predictions and their ground truth labels. | ||
70 | + | ||
71 | + Args: | ||
72 | + predictions: A list of lists storing the prediction scores. The outer | ||
73 | + dimension corresponds to classes. | ||
74 | + actuals: A list of lists storing the ground truth labels. The dimensions | ||
75 | + should correspond to the predictions input. Any value larger than 0 will | ||
76 | + be treated as positives, otherwise as negatives. | ||
77 | + num_positives: If provided, it is a list of numbers representing the | ||
78 | + number of true positives for each class. If not provided, the number of | ||
79 | + true positives will be inferred from the 'actuals' array. | ||
80 | + | ||
81 | + Raises: | ||
82 | + ValueError: An error occurred when the shape of predictions and actuals | ||
83 | + does not match. | ||
84 | + """ | ||
85 | + if not num_positives: | ||
86 | + num_positives = [None for i in range(self._num_class)] | ||
87 | + | ||
88 | + calculators = self._ap_calculators | ||
89 | + for i in range(self._num_class): | ||
90 | + calculators[i].accumulate(predictions[i], actuals[i], num_positives[i]) | ||
91 | + | ||
92 | + def clear(self): | ||
93 | + for calculator in self._ap_calculators: | ||
94 | + calculator.clear() | ||
95 | + | ||
96 | + def is_empty(self): | ||
97 | + return ([calculator.heap_size for calculator in self._ap_calculators | ||
98 | + ] == [0 for _ in range(self._num_class)]) | ||
99 | + | ||
100 | + def peek_map_at_n(self): | ||
101 | + """Peek the non-interpolated mean average precision at n. | ||
102 | + | ||
103 | + Returns: | ||
104 | + An array of non-interpolated average precision at n (default 0) for each | ||
105 | + class. | ||
106 | + """ | ||
107 | + aps = [] | ||
108 | + for i in range(self._num_class): | ||
109 | + if (not self._filter_empty_classes or | ||
110 | + self._ap_calculators[i].num_accumulated_positives > 0): | ||
111 | + ap = self._ap_calculators[i].peek_ap_at_n() | ||
112 | + aps.append(ap) | ||
113 | + return aps |
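For intuition, the non-interpolated average precision these calculators accumulate can be computed directly for a toy ranking: take the precision at the rank of every positively labeled item and average over the positives; peek_map_at_n then returns one such value per class. The helper below is a standalone reference, not part of the repo:

```python
def average_precision(scores, labels):
  """Non-interpolated AP: mean of precision@rank over the positive items."""
  ranked = sorted(zip(scores, labels), key=lambda p: -p[0])
  hits, total = 0, 0.0
  for rank, (_, label) in enumerate(ranked, start=1):
    if label:
      hits += 1
      total += hits / rank
  return total / max(hits, 1)

# Two positives, retrieved at ranks 1 and 3 -> (1/1 + 2/3) / 2
print(average_precision([0.9, 0.8, 0.3, 0.2], [1, 0, 1, 0]))  # 0.8333...
```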
web/backend/yt8m/model_utils.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains a collection of util functions for model construction.""" | ||
15 | +import numpy | ||
16 | +import tensorflow as tf | ||
17 | +from tensorflow import logging | ||
18 | +from tensorflow import flags | ||
19 | +import tensorflow.contrib.slim as slim | ||
20 | + | ||
21 | + | ||
22 | +def SampleRandomSequence(model_input, num_frames, num_samples): | ||
23 | + """Samples a random sequence of frames of size num_samples. | ||
24 | + | ||
25 | + Args: | ||
26 | + model_input: A tensor of size batch_size x max_frames x feature_size | ||
27 | + num_frames: A tensor of size batch_size x 1 | ||
28 | + num_samples: A scalar | ||
29 | + | ||
30 | + Returns: | ||
31 | + `model_input`: A tensor of size batch_size x num_samples x feature_size | ||
32 | + """ | ||
33 | + | ||
34 | + batch_size = tf.shape(model_input)[0] | ||
35 | + frame_index_offset = tf.tile(tf.expand_dims(tf.range(num_samples), 0), | ||
36 | + [batch_size, 1]) | ||
37 | + max_start_frame_index = tf.maximum(num_frames - num_samples, 0) | ||
38 | + start_frame_index = tf.cast( | ||
39 | + tf.multiply(tf.random_uniform([batch_size, 1]), | ||
40 | + tf.cast(max_start_frame_index + 1, tf.float32)), tf.int32) | ||
41 | + frame_index = tf.minimum(start_frame_index + frame_index_offset, | ||
42 | + tf.cast(num_frames - 1, tf.int32)) | ||
43 | + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1), | ||
44 | + [1, num_samples]) | ||
45 | + index = tf.stack([batch_index, frame_index], 2) | ||
46 | + return tf.gather_nd(model_input, index) | ||
47 | + | ||
48 | + | ||
49 | +def SampleRandomFrames(model_input, num_frames, num_samples): | ||
50 | + """Samples a random set of frames of size num_samples. | ||
51 | + | ||
52 | + Args: | ||
53 | + model_input: A tensor of size batch_size x max_frames x feature_size | ||
54 | + num_frames: A tensor of size batch_size x 1 | ||
55 | + num_samples: A scalar | ||
56 | + | ||
57 | + Returns: | ||
58 | + `model_input`: A tensor of size batch_size x num_samples x feature_size | ||
59 | + """ | ||
60 | + batch_size = tf.shape(model_input)[0] | ||
61 | + frame_index = tf.cast( | ||
62 | + tf.multiply(tf.random_uniform([batch_size, num_samples]), | ||
63 | + tf.tile(tf.cast(num_frames, tf.float32), [1, num_samples])), | ||
64 | + tf.int32) | ||
65 | + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1), | ||
66 | + [1, num_samples]) | ||
67 | + index = tf.stack([batch_index, frame_index], 2) | ||
68 | + return tf.gather_nd(model_input, index) | ||
69 | + | ||
70 | + | ||
71 | +def FramePooling(frames, method, **unused_params): | ||
72 | + """Pools over the frames of a video. | ||
73 | + | ||
74 | + Args: | ||
75 | + frames: A tensor with shape [batch_size, num_frames, feature_size]. | ||
76 | + method: "average", "max", "attention", or "none". | ||
77 | + | ||
78 | + Returns: | ||
79 | + A tensor with shape [batch_size, feature_size] for average or max | ||
80 | + pooling. A tensor with shape [batch_size*num_frames, feature_size] | ||
81 | + for none pooling. | ||
82 | + | ||
83 | + Raises: | ||
84 | + ValueError: if method is other than "average", "max", or | ||
85 | + "none". | ||
86 | + """ | ||
87 | + if method == "average": | ||
88 | + return tf.reduce_mean(frames, 1) | ||
89 | + elif method == "max": | ||
90 | + return tf.reduce_max(frames, 1) | ||
91 | + elif method == "none": | ||
92 | + feature_size = frames.shape_as_list()[2] | ||
93 | + return tf.reshape(frames, [-1, feature_size]) | ||
94 | + else: | ||
95 | + raise ValueError("Unrecognized pooling method: %s" % method) |
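As a quick illustration of the sampling and pooling helpers above, here is a minimal graph-mode sketch under the TensorFlow 1.x API this file targets; the shapes and values are made up for the example:

import tensorflow as tf
import model_utils

frames = tf.random_uniform([8, 300, 1152])  # batch x max_frames x feature_size
num_frames = tf.fill([8, 1], 120)           # frames actually present per video
sampled = model_utils.SampleRandomFrames(frames, num_frames, num_samples=30)
pooled = model_utils.FramePooling(sampled, "average")  # shape [8, 1152]

with tf.Session() as sess:
  print(sess.run(pooled).shape)  # (8, 1152)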
web/backend/yt8m/models.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains the base class for models.""" | ||
15 | + | ||
16 | + | ||
17 | +class BaseModel(object): | ||
18 | + """Inherit from this class when implementing new models.""" | ||
19 | + | ||
20 | + def create_model(self, unused_model_input, **unused_params): | ||
21 | + raise NotImplementedError() |
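train.py (further down in this diff) calls model.create_model(...) and expects a dict containing at least a "predictions" tensor, so a new model only has to override that one method. An illustrative subclass, not part of the diff, assuming TensorFlow 1.x:

import tensorflow as tf
import models

class TinyLogisticModel(models.BaseModel):
  """Toy example: one sigmoid layer over the (video-level) input features."""

  def create_model(self, model_input, vocab_size, **unused_params):
    # model_input: [batch_size, feature_size]; vocab_size: number of classes.
    output = tf.layers.dense(model_input, vocab_size, activation=tf.nn.sigmoid)
    return {"predictions": output}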
web/backend/yt8m/readers.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Provides readers configured for different datasets.""" | ||
15 | + | ||
16 | +import tensorflow as tf | ||
17 | +import utils | ||
18 | + | ||
19 | + | ||
20 | +def resize_axis(tensor, axis, new_size, fill_value=0): | ||
21 | + """Truncates or pads a tensor to new_size on on a given axis. | ||
22 | + | ||
23 | + Truncate or extend tensor such that tensor.shape[axis] == new_size. If the | ||
24 | + size increases, the padding will be performed at the end, using fill_value. | ||
25 | + | ||
26 | + Args: | ||
27 | + tensor: The tensor to be resized. | ||
28 | + axis: An integer representing the dimension to be sliced. | ||
29 | + new_size: An integer or 0d tensor representing the new value for | ||
30 | + tensor.shape[axis]. | ||
31 | + fill_value: Value to use to fill any new entries in the tensor. Will be cast | ||
32 | + to the type of tensor. | ||
33 | + | ||
34 | + Returns: | ||
35 | + The resized tensor. | ||
36 | + """ | ||
37 | + tensor = tf.convert_to_tensor(tensor) | ||
38 | + shape = tf.unstack(tf.shape(tensor)) | ||
39 | + | ||
40 | + pad_shape = shape[:] | ||
41 | + pad_shape[axis] = tf.maximum(0, new_size - shape[axis]) | ||
42 | + | ||
43 | + shape[axis] = tf.minimum(shape[axis], new_size) | ||
44 | + shape = tf.stack(shape) | ||
45 | + | ||
46 | + resized = tf.concat([ | ||
47 | + tf.slice(tensor, tf.zeros_like(shape), shape), | ||
48 | + tf.fill(tf.stack(pad_shape), tf.cast(fill_value, tensor.dtype)) | ||
49 | + ], axis) | ||
50 | + | ||
51 | + # Update shape. | ||
52 | + new_shape = tensor.get_shape().as_list() # A copy is being made. | ||
53 | + new_shape[axis] = new_size | ||
54 | + resized.set_shape(new_shape) | ||
55 | + return resized | ||
56 | + | ||
57 | + | ||
58 | +class BaseReader(object): | ||
59 | + """Inherit from this class when implementing new readers.""" | ||
60 | + | ||
61 | + def prepare_reader(self, unused_filename_queue): | ||
62 | + """Create a thread for generating prediction and label tensors.""" | ||
63 | + raise NotImplementedError() | ||
64 | + | ||
65 | + | ||
66 | +class YT8MAggregatedFeatureReader(BaseReader): | ||
67 | + """Reads TFRecords of pre-aggregated Examples. | ||
68 | + | ||
69 | + The TFRecords must contain Examples with a sparse int64 'labels' feature and | ||
70 | + a fixed length float32 feature, obtained from the features in 'feature_names'. | ||
71 | + The float features are assumed to be an average of dequantized values. | ||
72 | + """ | ||
73 | + | ||
74 | + def __init__( # pylint: disable=dangerous-default-value | ||
75 | + self, | ||
76 | + num_classes=3862, | ||
77 | + feature_sizes=[1024, 128], | ||
78 | + feature_names=["mean_rgb", "mean_audio"]): | ||
79 | + """Construct a YT8MAggregatedFeatureReader. | ||
80 | + | ||
81 | + Args: | ||
82 | + num_classes: a positive integer for the number of classes. | ||
83 | + feature_sizes: positive integer(s) for the feature dimensions as a list. | ||
84 | + feature_names: the feature name(s) in the tensorflow record as a list. | ||
85 | + """ | ||
86 | + | ||
87 | + assert len(feature_names) == len(feature_sizes), ( | ||
88 | + "length of feature_names (={}) != length of feature_sizes (={})".format( | ||
89 | + len(feature_names), len(feature_sizes))) | ||
90 | + | ||
91 | + self.num_classes = num_classes | ||
92 | + self.feature_sizes = feature_sizes | ||
93 | + self.feature_names = feature_names | ||
94 | + | ||
95 | + def prepare_reader(self, filename_queue, batch_size=1024): | ||
96 | + """Creates a single reader thread for pre-aggregated YouTube 8M Examples. | ||
97 | + | ||
98 | + Args: | ||
99 | + filename_queue: A tensorflow queue of filename locations. | ||
100 | + batch_size: batch size used for feature output. | ||
101 | + | ||
102 | + Returns: | ||
103 | + A dict of video indexes, features, labels, and frame counts. | ||
104 | + """ | ||
105 | + reader = tf.TFRecordReader() | ||
106 | + _, serialized_examples = reader.read_up_to(filename_queue, batch_size) | ||
107 | + | ||
108 | + tf.add_to_collection("serialized_examples", serialized_examples) | ||
109 | + return self.prepare_serialized_examples(serialized_examples) | ||
110 | + | ||
111 | + def prepare_serialized_examples(self, serialized_examples): | ||
112 | + """Parse a single video-level TF Example.""" | ||
113 | + # set the mapping from the fields to data types in the proto | ||
114 | + num_features = len(self.feature_names) | ||
115 | + assert num_features > 0, "self.feature_names is empty!" | ||
116 | + assert len(self.feature_names) == len(self.feature_sizes), \ | ||
117 | + "length of feature_names (={}) != length of feature_sizes (={})".format( | ||
118 | + len(self.feature_names), len(self.feature_sizes)) | ||
119 | + | ||
120 | + feature_map = { | ||
121 | + "id": tf.io.FixedLenFeature([], tf.string), | ||
122 | + "labels": tf.io.VarLenFeature(tf.int64) | ||
123 | + } | ||
124 | + for feature_index in range(num_features): | ||
125 | + feature_map[self.feature_names[feature_index]] = tf.FixedLenFeature( | ||
126 | + [self.feature_sizes[feature_index]], tf.float32) | ||
127 | + | ||
128 | + features = tf.parse_example(serialized_examples, features=feature_map) | ||
129 | + labels = tf.sparse_to_indicator(features["labels"], self.num_classes) | ||
130 | + labels.set_shape([None, self.num_classes]) | ||
131 | + concatenated_features = tf.concat( | ||
132 | + [features[feature_name] for feature_name in self.feature_names], 1) | ||
133 | + | ||
134 | + output_dict = { | ||
135 | + "video_ids": features["id"], | ||
136 | + "video_matrix": concatenated_features, | ||
137 | + "labels": labels, | ||
138 | + "num_frames": tf.ones([tf.shape(serialized_examples)[0]]) | ||
139 | + } | ||
140 | + | ||
141 | + return output_dict | ||
142 | + | ||
143 | + | ||
144 | +class YT8MFrameFeatureReader(BaseReader): | ||
145 | + """Reads TFRecords of SequenceExamples. | ||
146 | + | ||
147 | + The TFRecords must contain SequenceExamples with the sparse int64 'labels' | ||
148 | + context feature and a fixed length byte-quantized feature vector, obtained | ||
149 | + from the features in 'feature_names'. The quantized features will be mapped | ||
150 | + back into a range between min_quantized_value and max_quantized_value. | ||
151 | + """ | ||
152 | + | ||
153 | + def __init__( # pylint: disable=dangerous-default-value | ||
154 | + self, | ||
155 | + num_classes=3862, | ||
156 | + feature_sizes=[1024, 128], | ||
157 | + feature_names=["rgb", "audio"], | ||
158 | + max_frames=300, | ||
159 | + segment_labels=False, | ||
160 | + segment_size=5): | ||
161 | + """Construct a YT8MFrameFeatureReader. | ||
162 | + | ||
163 | + Args: | ||
164 | + num_classes: a positive integer for the number of classes. | ||
165 | + feature_sizes: positive integer(s) for the feature dimensions as a list. | ||
166 | + feature_names: the feature name(s) in the tensorflow record as a list. | ||
167 | + max_frames: the maximum number of frames to process. | ||
168 | + segment_labels: whether to read segment-level labels instead. | ||
169 | + segment_size: the segment_size used for reading segments. | ||
170 | + """ | ||
171 | + | ||
172 | + assert len(feature_names) == len(feature_sizes), ( | ||
173 | + "length of feature_names (={}) != length of feature_sizes (={})".format( | ||
174 | + len(feature_names), len(feature_sizes))) | ||
175 | + | ||
176 | + self.num_classes = num_classes | ||
177 | + self.feature_sizes = feature_sizes | ||
178 | + self.feature_names = feature_names | ||
179 | + self.max_frames = max_frames | ||
180 | + self.segment_labels = segment_labels | ||
181 | + self.segment_size = segment_size | ||
182 | + | ||
183 | + def get_video_matrix(self, features, feature_size, max_frames, | ||
184 | + max_quantized_value, min_quantized_value): | ||
185 | + """Decodes features from an input string and quantizes it. | ||
186 | + | ||
187 | + Args: | ||
188 | + features: raw feature values | ||
189 | + feature_size: length of each frame feature vector | ||
190 | + max_frames: number of frames (rows) in the output feature_matrix | ||
191 | + max_quantized_value: the maximum of the quantized value. | ||
192 | + min_quantized_value: the minimum of the quantized value. | ||
193 | + | ||
194 | + Returns: | ||
195 | + feature_matrix: matrix of all frame-features | ||
196 | + num_frames: number of frames in the sequence | ||
197 | + """ | ||
198 | + decoded_features = tf.reshape( | ||
199 | + tf.cast(tf.decode_raw(features, tf.uint8), tf.float32), | ||
200 | + [-1, feature_size]) | ||
201 | + | ||
202 | + num_frames = tf.minimum(tf.shape(decoded_features)[0], max_frames) | ||
203 | + feature_matrix = utils.Dequantize(decoded_features, max_quantized_value, | ||
204 | + min_quantized_value) | ||
205 | + feature_matrix = resize_axis(feature_matrix, 0, max_frames) | ||
206 | + return feature_matrix, num_frames | ||
207 | + | ||
208 | + def prepare_reader(self, | ||
209 | + filename_queue, | ||
210 | + max_quantized_value=2, | ||
211 | + min_quantized_value=-2): | ||
212 | + """Creates a single reader thread for YouTube8M SequenceExamples. | ||
213 | + | ||
214 | + Args: | ||
215 | + filename_queue: A tensorflow queue of filename locations. | ||
216 | + max_quantized_value: the maximum of the quantized value. | ||
217 | + min_quantized_value: the minimum of the quantized value. | ||
218 | + | ||
219 | + Returns: | ||
220 | + A dict of video indexes, video features, labels, and frame counts. | ||
221 | + """ | ||
222 | + reader = tf.TFRecordReader() | ||
223 | + _, serialized_example = reader.read(filename_queue) | ||
224 | + | ||
225 | + return self.prepare_serialized_examples(serialized_example, | ||
226 | + max_quantized_value, | ||
227 | + min_quantized_value) | ||
228 | + | ||
229 | + def prepare_serialized_examples(self, | ||
230 | + serialized_example, | ||
231 | + max_quantized_value=2, | ||
232 | + min_quantized_value=-2): | ||
233 | + """Parse single serialized SequenceExample from the TFRecords.""" | ||
234 | + | ||
235 | + # Read/parse frame/segment-level labels. | ||
236 | + context_features = { | ||
237 | + "id": tf.io.FixedLenFeature([], tf.string), | ||
238 | + } | ||
239 | + if self.segment_labels: | ||
240 | + context_features.update({ | ||
241 | + # There is no need to read the end time, since we assume every segment | ||
242 | + # has the same size. | ||
243 | + "segment_labels": tf.io.VarLenFeature(tf.int64), | ||
244 | + "segment_start_times": tf.io.VarLenFeature(tf.int64), | ||
245 | + "segment_scores": tf.io.VarLenFeature(tf.float32) | ||
246 | + }) | ||
247 | + else: | ||
248 | + context_features.update({"labels": tf.io.VarLenFeature(tf.int64)}) | ||
249 | + sequence_features = { | ||
250 | + feature_name: tf.io.FixedLenSequenceFeature([], dtype=tf.string) | ||
251 | + for feature_name in self.feature_names | ||
252 | + } | ||
253 | + contexts, features = tf.io.parse_single_sequence_example( | ||
254 | + serialized_example, | ||
255 | + context_features=context_features, | ||
256 | + sequence_features=sequence_features) | ||
257 | + | ||
258 | + # loads (potentially) different types of features and concatenates them | ||
259 | + num_features = len(self.feature_names) | ||
260 | + assert num_features > 0, "No feature selected: feature_names is empty!" | ||
261 | + | ||
262 | + assert len(self.feature_names) == len(self.feature_sizes), ( | ||
263 | + "length of feature_names (={}) != length of feature_sizes (={})".format( | ||
264 | + len(self.feature_names), len(self.feature_sizes))) | ||
265 | + | ||
266 | + num_frames = -1 # the number of frames in the video | ||
267 | + feature_matrices = [None] * num_features # an array of different features | ||
268 | + for feature_index in range(num_features): | ||
269 | + feature_matrix, num_frames_in_this_feature = self.get_video_matrix( | ||
270 | + features[self.feature_names[feature_index]], | ||
271 | + self.feature_sizes[feature_index], self.max_frames, | ||
272 | + max_quantized_value, min_quantized_value) | ||
273 | + if num_frames == -1: | ||
274 | + num_frames = num_frames_in_this_feature | ||
275 | + | ||
276 | + feature_matrices[feature_index] = feature_matrix | ||
277 | + | ||
278 | + # cap the number of frames at self.max_frames | ||
279 | + num_frames = tf.minimum(num_frames, self.max_frames) | ||
280 | + | ||
281 | + # concatenate different features | ||
282 | + video_matrix = tf.concat(feature_matrices, 1) | ||
283 | + | ||
284 | + # Partition frame-level feature matrix to segment-level feature matrix. | ||
285 | + if self.segment_labels: | ||
286 | + start_times = contexts["segment_start_times"].values | ||
287 | + # Here we assume all the segments that start at the same start time have | ||
288 | + # the same segment_size. | ||
289 | + uniq_start_times, seg_idxs = tf.unique(start_times, | ||
290 | + out_idx=tf.dtypes.int64) | ||
291 | + # TODO(zhengxu): Ensure the segment_sizes are all the same. | ||
292 | + segment_size = self.segment_size | ||
293 | + # Range gather matrix, e.g., [[0,1,2],[1,2,3]] for segment_size == 3. | ||
294 | + range_mtx = tf.expand_dims(uniq_start_times, axis=-1) + tf.expand_dims( | ||
295 | + tf.range(0, segment_size, dtype=tf.int64), axis=0) | ||
296 | + # Shape: [num_segment, segment_size, feature_dim]. | ||
297 | + batch_video_matrix = tf.gather_nd(video_matrix, | ||
298 | + tf.expand_dims(range_mtx, axis=-1)) | ||
299 | + num_segment = tf.shape(batch_video_matrix)[0] | ||
300 | + batch_video_ids = tf.reshape(tf.tile([contexts["id"]], [num_segment]), | ||
301 | + (num_segment,)) | ||
302 | + batch_frames = tf.reshape(tf.tile([segment_size], [num_segment]), | ||
303 | + (num_segment,)) | ||
304 | + | ||
305 | + # For segment labels, not all labels are exhaustively rated, so we only | ||
306 | + # evaluate the rated labels. | ||
307 | + | ||
308 | + # Label indices for each segment, shape: [num_segment, 2]. | ||
309 | + label_indices = tf.stack([seg_idxs, contexts["segment_labels"].values], | ||
310 | + axis=-1) | ||
311 | + label_values = contexts["segment_scores"].values | ||
312 | + sparse_labels = tf.sparse.SparseTensor(label_indices, label_values, | ||
313 | + (num_segment, self.num_classes)) | ||
314 | + batch_labels = tf.sparse.to_dense(sparse_labels, validate_indices=False) | ||
315 | + | ||
316 | + sparse_label_weights = tf.sparse.SparseTensor( | ||
317 | + label_indices, tf.ones_like(label_values, dtype=tf.float32), | ||
318 | + (num_segment, self.num_classes)) | ||
319 | + batch_label_weights = tf.sparse.to_dense(sparse_label_weights, | ||
320 | + validate_indices=False) | ||
321 | + else: | ||
322 | + # Process video-level labels. | ||
323 | + label_indices = contexts["labels"].values | ||
324 | + sparse_labels = tf.sparse.SparseTensor( | ||
325 | + tf.expand_dims(label_indices, axis=-1), | ||
326 | + tf.ones_like(contexts["labels"].values, dtype=tf.bool), | ||
327 | + (self.num_classes,)) | ||
328 | + labels = tf.sparse.to_dense(sparse_labels, | ||
329 | + default_value=False, | ||
330 | + validate_indices=False) | ||
331 | + # convert to batch format. | ||
332 | + batch_video_ids = tf.expand_dims(contexts["id"], 0) | ||
333 | + batch_video_matrix = tf.expand_dims(video_matrix, 0) | ||
334 | + batch_labels = tf.expand_dims(labels, 0) | ||
335 | + batch_frames = tf.expand_dims(num_frames, 0) | ||
336 | + batch_label_weights = None | ||
337 | + | ||
338 | + output_dict = { | ||
339 | + "video_ids": batch_video_ids, | ||
340 | + "video_matrix": batch_video_matrix, | ||
341 | + "labels": batch_labels, | ||
342 | + "num_frames": batch_frames, | ||
343 | + } | ||
344 | + if batch_label_weights is not None: | ||
345 | + output_dict["label_weights"] = batch_label_weights | ||
346 | + | ||
347 | + return output_dict |
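A short sketch of how resize_axis above pads and truncates, assuming the TensorFlow 1.x graph API used throughout this file and that the repo's utils module is importable (readers.py imports it at module level):

import tensorflow as tf
from readers import resize_axis

x = tf.constant([[1, 2], [3, 4], [5, 6]])       # shape [3, 2]
padded = resize_axis(x, axis=0, new_size=5)     # two rows of zeros appended
truncated = resize_axis(x, axis=0, new_size=2)  # only the first two rows kept

with tf.Session() as sess:
  print(sess.run(padded))
  print(sess.run(truncated))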
web/backend/yt8m/segment_eval_inference.py
0 → 100644
1 | +"""Eval mAP@N metric from inference file.""" | ||
2 | + | ||
3 | +from __future__ import absolute_import | ||
4 | +from __future__ import division | ||
5 | +from __future__ import print_function | ||
6 | + | ||
7 | +from absl import app | ||
8 | +from absl import flags | ||
9 | + | ||
10 | +import mean_average_precision_calculator as map_calculator | ||
11 | +import numpy as np | ||
12 | +import tensorflow as tf | ||
13 | + | ||
14 | +flags.DEFINE_string( | ||
15 | + "eval_data_pattern", "", | ||
16 | + "File glob defining the evaluation dataset in tensorflow.SequenceExample " | ||
17 | + "format. The SequenceExamples are expected to have an 'rgb' byte array " | ||
18 | + "sequence feature as well as a 'labels' int64 context feature.") | ||
19 | +flags.DEFINE_string( | ||
20 | + "label_cache", "", | ||
21 | + "The path for the label cache file. Leave blank for not to cache.") | ||
22 | +flags.DEFINE_string("submission_file", "", | ||
23 | + "The segment submission file generated by inference.py.") | ||
24 | +flags.DEFINE_integer( | ||
25 | + "top_n", 0, | ||
26 | + "The cap per-class predictions by a maximum of N. Use 0 for not capping.") | ||
27 | + | ||
28 | +FLAGS = flags.FLAGS | ||
29 | + | ||
30 | + | ||
31 | +class Labels(object): | ||
32 | + """Contains the class to hold label objects. | ||
33 | + | ||
34 | + This class can serialize and de-serialize the groundtruths. | ||
35 | + The ground truth is in a mapping from (segment_id, class_id) -> label_score. | ||
36 | + """ | ||
37 | + | ||
38 | + def __init__(self, labels): | ||
39 | + """__init__ method.""" | ||
40 | + self._labels = labels | ||
41 | + | ||
42 | + @property | ||
43 | + def labels(self): | ||
44 | + """Return the ground truth mapping. See class docstring for details.""" | ||
45 | + return self._labels | ||
46 | + | ||
47 | + def to_file(self, file_name): | ||
48 | + """Materialize the GT mapping to file.""" | ||
49 | + with tf.gfile.Open(file_name, "w") as fobj: | ||
50 | + for k, v in self._labels.items(): | ||
51 | + seg_id, label = k | ||
52 | + line = "%s,%s,%s\n" % (seg_id, label, v) | ||
53 | + fobj.write(line) | ||
54 | + | ||
55 | + @classmethod | ||
56 | + def from_file(cls, file_name): | ||
57 | + """Read the GT mapping from cached file.""" | ||
58 | + labels = {} | ||
59 | + with tf.gfile.Open(file_name) as fobj: | ||
60 | + for line in fobj: | ||
61 | + line = line.strip().strip("\n") | ||
62 | + seg_id, label, score = line.split(",") | ||
63 | + labels[(seg_id, int(label))] = float(score) | ||
64 | + return cls(labels) | ||
65 | + | ||
66 | + | ||
67 | +def read_labels(data_pattern, cache_path=""): | ||
68 | + """Read labels from TFRecords. | ||
69 | + | ||
70 | + Args: | ||
71 | + data_pattern: the data pattern to the TFRecords. | ||
72 | + cache_path: the cache path for the label file. | ||
73 | + | ||
74 | + Returns: | ||
75 | + a Labels object. | ||
76 | + """ | ||
77 | + if cache_path: | ||
78 | + if tf.gfile.Exists(cache_path): | ||
79 | + tf.logging.info("Reading cached labels from %s..." % cache_path) | ||
80 | + return Labels.from_file(cache_path) | ||
81 | + tf.enable_eager_execution() | ||
82 | + data_paths = tf.gfile.Glob(data_pattern) | ||
83 | + ds = tf.data.TFRecordDataset(data_paths, num_parallel_reads=50) | ||
84 | + context_features = { | ||
85 | + "id": tf.FixedLenFeature([], tf.string), | ||
86 | + "segment_labels": tf.VarLenFeature(tf.int64), | ||
87 | + "segment_start_times": tf.VarLenFeature(tf.int64), | ||
88 | + "segment_scores": tf.VarLenFeature(tf.float32) | ||
89 | + } | ||
90 | + | ||
91 | + def _parse_se_func(sequence_example): | ||
92 | + return tf.parse_single_sequence_example(sequence_example, | ||
93 | + context_features=context_features) | ||
94 | + | ||
95 | + ds = ds.map(_parse_se_func) | ||
96 | + rated_labels = {} | ||
97 | + tf.logging.info("Reading labels from TFRecords...") | ||
98 | + last_batch = 0 | ||
99 | + batch_size = 5000 | ||
100 | + for cxt_feature_val, _ in ds: | ||
101 | + video_id = cxt_feature_val["id"].numpy() | ||
102 | + segment_labels = cxt_feature_val["segment_labels"].values.numpy() | ||
103 | + segment_start_times = cxt_feature_val["segment_start_times"].values.numpy() | ||
104 | + segment_scores = cxt_feature_val["segment_scores"].values.numpy() | ||
105 | + for label, start_time, score in zip(segment_labels, segment_start_times, | ||
106 | + segment_scores): | ||
107 | + rated_labels[("%s:%d" % (video_id, start_time), label)] = score | ||
108 | + batch_id = len(rated_labels) // batch_size | ||
109 | + if batch_id != last_batch: | ||
110 | + tf.logging.info("%d examples processed.", len(rated_labels)) | ||
111 | + last_batch = batch_id | ||
112 | + tf.logging.info("Finish reading labels from TFRecords...") | ||
113 | + labels_obj = Labels(rated_labels) | ||
114 | + if cache_path: | ||
115 | + tf.logging.info("Caching labels to %s..." % cache_path) | ||
116 | + labels_obj.to_file(cache_path) | ||
117 | + return labels_obj | ||
118 | + | ||
119 | + | ||
120 | +def read_segment_predictions(file_path, labels, top_n=None): | ||
121 | + """Read segement predictions. | ||
122 | + | ||
123 | + Args: | ||
124 | + file_path: the submission file path. | ||
125 | + labels: a Labels object containing the eval labels. | ||
126 | + top_n: the per-class prediction cap. | ||
127 | + | ||
128 | + Returns: | ||
129 | + a segment prediction list for each class. | ||
130 | + """ | ||
131 | + cls_preds = {} # A label_id to pred list mapping. | ||
132 | + with tf.gfile.Open(file_path) as fobj: | ||
133 | + tf.logging.info("Reading predictions from %s..." % file_path) | ||
134 | + for line in fobj: | ||
135 | + label_id, pred_ids_val = line.split(",") | ||
136 | + pred_ids = pred_ids_val.split(" ") | ||
137 | + if top_n: | ||
138 | + pred_ids = pred_ids[:top_n] | ||
139 | + pred_ids = [ | ||
140 | + pred_id for pred_id in pred_ids | ||
141 | + if (pred_id, int(label_id)) in labels.labels | ||
142 | + ] | ||
143 | + cls_preds[int(label_id)] = pred_ids | ||
144 | + if len(cls_preds) % 50 == 0: | ||
145 | + tf.logging.info("Processed %d classes..." % len(cls_preds)) | ||
146 | + tf.logging.info("Finish reading predictions.") | ||
147 | + return cls_preds | ||
148 | + | ||
149 | + | ||
150 | +def main(unused_argv): | ||
151 | + """Entry function of the script.""" | ||
152 | + if not FLAGS.submission_file: | ||
153 | + raise ValueError("You must input submission file.") | ||
154 | + eval_labels = read_labels(FLAGS.eval_data_pattern, | ||
155 | + cache_path=FLAGS.label_cache) | ||
156 | + tf.logging.info("Total rated segments: %d." % len(eval_labels.labels)) | ||
157 | + positive_counter = {} | ||
158 | + for k, v in eval_labels.labels.items(): | ||
159 | + _, label_id = k | ||
160 | + if v > 0: | ||
161 | + positive_counter[label_id] = positive_counter.get(label_id, 0) + 1 | ||
162 | + | ||
163 | + seg_preds = read_segment_predictions(FLAGS.submission_file, | ||
164 | + eval_labels, | ||
165 | + top_n=FLAGS.top_n) | ||
166 | + map_cal = map_calculator.MeanAveragePrecisionCalculator(len(seg_preds)) | ||
167 | + seg_labels = [] | ||
168 | + seg_scored_preds = [] | ||
169 | + num_positives = [] | ||
170 | + for label_id in sorted(seg_preds): | ||
171 | + class_preds = seg_preds[label_id] | ||
172 | + seg_label = [eval_labels.labels[(pred, label_id)] for pred in class_preds] | ||
173 | + seg_labels.append(seg_label) | ||
174 | + seg_scored_pred = [] | ||
175 | + if class_preds: | ||
176 | + seg_scored_pred = [ | ||
177 | + float(x) / len(class_preds) for x in range(len(class_preds), 0, -1) | ||
178 | + ] | ||
179 | + seg_scored_preds.append(seg_scored_pred) | ||
180 | + num_positives.append(positive_counter[label_id]) | ||
181 | + map_cal.accumulate(seg_scored_preds, seg_labels, num_positives) | ||
182 | + map_at_n = np.mean(map_cal.peek_map_at_n()) | ||
183 | + tf.logging.info("Num classes: %d | mAP@%d: %.6f" % | ||
184 | + (len(seg_preds), FLAGS.top_n, map_at_n)) | ||
185 | + | ||
186 | + | ||
187 | +if __name__ == "__main__": | ||
188 | + app.run(main) |
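A typical invocation of the script above, using only the flags it defines (all paths are placeholders):

python segment_eval_inference.py \
  --eval_data_pattern="validate/*.tfrecord" \
  --label_cache=/tmp/segment_labels.cache \
  --submission_file=/tmp/submission.csv \
  --top_n=100000

Each submission line is expected to look like label_id,segment_id1 segment_id2 ..., where a segment id has the video_id:start_time form built when the labels are read.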
web/backend/yt8m/segment_label_ids.csv
0 → 100644
1 | +Index | ||
2 | +3 | ||
3 | +7 | ||
4 | +8 | ||
5 | +11 | ||
6 | +12 | ||
7 | +17 | ||
8 | +18 | ||
9 | +19 | ||
10 | +21 | ||
11 | +22 | ||
12 | +23 | ||
13 | +28 | ||
14 | +31 | ||
15 | +30 | ||
16 | +32 | ||
17 | +33 | ||
18 | +34 | ||
19 | +41 | ||
20 | +43 | ||
21 | +45 | ||
22 | +46 | ||
23 | +48 | ||
24 | +53 | ||
25 | +54 | ||
26 | +52 | ||
27 | +55 | ||
28 | +58 | ||
29 | +59 | ||
30 | +60 | ||
31 | +61 | ||
32 | +65 | ||
33 | +68 | ||
34 | +73 | ||
35 | +71 | ||
36 | +74 | ||
37 | +75 | ||
38 | +76 | ||
39 | +77 | ||
40 | +80 | ||
41 | +83 | ||
42 | +90 | ||
43 | +88 | ||
44 | +89 | ||
45 | +92 | ||
46 | +95 | ||
47 | +100 | ||
48 | +101 | ||
49 | +99 | ||
50 | +104 | ||
51 | +105 | ||
52 | +109 | ||
53 | +113 | ||
54 | +112 | ||
55 | +115 | ||
56 | +116 | ||
57 | +118 | ||
58 | +120 | ||
59 | +121 | ||
60 | +123 | ||
61 | +125 | ||
62 | +127 | ||
63 | +131 | ||
64 | +128 | ||
65 | +129 | ||
66 | +130 | ||
67 | +137 | ||
68 | +141 | ||
69 | +143 | ||
70 | +145 | ||
71 | +148 | ||
72 | +152 | ||
73 | +151 | ||
74 | +156 | ||
75 | +155 | ||
76 | +158 | ||
77 | +160 | ||
78 | +164 | ||
79 | +163 | ||
80 | +169 | ||
81 | +170 | ||
82 | +172 | ||
83 | +171 | ||
84 | +173 | ||
85 | +174 | ||
86 | +175 | ||
87 | +176 | ||
88 | +178 | ||
89 | +182 | ||
90 | +184 | ||
91 | +186 | ||
92 | +188 | ||
93 | +187 | ||
94 | +192 | ||
95 | +191 | ||
96 | +190 | ||
97 | +194 | ||
98 | +197 | ||
99 | +196 | ||
100 | +198 | ||
101 | +201 | ||
102 | +202 | ||
103 | +200 | ||
104 | +199 | ||
105 | +205 | ||
106 | +204 | ||
107 | +209 | ||
108 | +207 | ||
109 | +206 | ||
110 | +210 | ||
111 | +213 | ||
112 | +214 | ||
113 | +220 | ||
114 | +218 | ||
115 | +217 | ||
116 | +226 | ||
117 | +227 | ||
118 | +231 | ||
119 | +232 | ||
120 | +229 | ||
121 | +233 | ||
122 | +235 | ||
123 | +237 | ||
124 | +244 | ||
125 | +240 | ||
126 | +249 | ||
127 | +246 | ||
128 | +248 | ||
129 | +239 | ||
130 | +250 | ||
131 | +245 | ||
132 | +255 | ||
133 | +253 | ||
134 | +256 | ||
135 | +261 | ||
136 | +259 | ||
137 | +263 | ||
138 | +262 | ||
139 | +266 | ||
140 | +267 | ||
141 | +268 | ||
142 | +269 | ||
143 | +271 | ||
144 | +276 | ||
145 | +273 | ||
146 | +277 | ||
147 | +274 | ||
148 | +278 | ||
149 | +279 | ||
150 | +280 | ||
151 | +288 | ||
152 | +291 | ||
153 | +295 | ||
154 | +294 | ||
155 | +293 | ||
156 | +297 | ||
157 | +296 | ||
158 | +300 | ||
159 | +299 | ||
160 | +303 | ||
161 | +302 | ||
162 | +304 | ||
163 | +305 | ||
164 | +313 | ||
165 | +307 | ||
166 | +311 | ||
167 | +310 | ||
168 | +312 | ||
169 | +316 | ||
170 | +318 | ||
171 | +321 | ||
172 | +322 | ||
173 | +331 | ||
174 | +333 | ||
175 | +329 | ||
176 | +330 | ||
177 | +334 | ||
178 | +343 | ||
179 | +349 | ||
180 | +340 | ||
181 | +344 | ||
182 | +348 | ||
183 | +358 | ||
184 | +347 | ||
185 | +359 | ||
186 | +355 | ||
187 | +361 | ||
188 | +360 | ||
189 | +364 | ||
190 | +365 | ||
191 | +368 | ||
192 | +369 | ||
193 | +366 | ||
194 | +370 | ||
195 | +374 | ||
196 | +380 | ||
197 | +373 | ||
198 | +385 | ||
199 | +384 | ||
200 | +388 | ||
201 | +389 | ||
202 | +382 | ||
203 | +393 | ||
204 | +381 | ||
205 | +390 | ||
206 | +394 | ||
207 | +399 | ||
208 | +397 | ||
209 | +396 | ||
210 | +402 | ||
211 | +400 | ||
212 | +398 | ||
213 | +401 | ||
214 | +405 | ||
215 | +406 | ||
216 | +410 | ||
217 | +408 | ||
218 | +416 | ||
219 | +415 | ||
220 | +419 | ||
221 | +422 | ||
222 | +414 | ||
223 | +421 | ||
224 | +424 | ||
225 | +429 | ||
226 | +418 | ||
227 | +427 | ||
228 | +434 | ||
229 | +428 | ||
230 | +435 | ||
231 | +430 | ||
232 | +441 | ||
233 | +439 | ||
234 | +437 | ||
235 | +443 | ||
236 | +440 | ||
237 | +442 | ||
238 | +445 | ||
239 | +446 | ||
240 | +448 | ||
241 | +454 | ||
242 | +444 | ||
243 | +453 | ||
244 | +455 | ||
245 | +451 | ||
246 | +452 | ||
247 | +458 | ||
248 | +460 | ||
249 | +465 | ||
250 | +457 | ||
251 | +463 | ||
252 | +462 | ||
253 | +461 | ||
254 | +464 | ||
255 | +469 | ||
256 | +468 | ||
257 | +472 | ||
258 | +473 | ||
259 | +471 | ||
260 | +475 | ||
261 | +474 | ||
262 | +477 | ||
263 | +485 | ||
264 | +491 | ||
265 | +488 | ||
266 | +482 | ||
267 | +490 | ||
268 | +496 | ||
269 | +494 | ||
270 | +483 | ||
271 | +495 | ||
272 | +493 | ||
273 | +507 | ||
274 | +501 | ||
275 | +499 | ||
276 | +503 | ||
277 | +498 | ||
278 | +514 | ||
279 | +504 | ||
280 | +502 | ||
281 | +506 | ||
282 | +508 | ||
283 | +511 | ||
284 | +527 | ||
285 | +526 | ||
286 | +532 | ||
287 | +513 | ||
288 | +519 | ||
289 | +525 | ||
290 | +518 | ||
291 | +528 | ||
292 | +522 | ||
293 | +523 | ||
294 | +535 | ||
295 | +539 | ||
296 | +540 | ||
297 | +533 | ||
298 | +521 | ||
299 | +541 | ||
300 | +547 | ||
301 | +550 | ||
302 | +544 | ||
303 | +549 | ||
304 | +551 | ||
305 | +554 | ||
306 | +543 | ||
307 | +548 | ||
308 | +557 | ||
309 | +560 | ||
310 | +552 | ||
311 | +559 | ||
312 | +563 | ||
313 | +565 | ||
314 | +567 | ||
315 | +555 | ||
316 | +576 | ||
317 | +568 | ||
318 | +564 | ||
319 | +573 | ||
320 | +581 | ||
321 | +580 | ||
322 | +572 | ||
323 | +571 | ||
324 | +584 | ||
325 | +590 | ||
326 | +585 | ||
327 | +587 | ||
328 | +588 | ||
329 | +592 | ||
330 | +598 | ||
331 | +597 | ||
332 | +599 | ||
333 | +603 | ||
334 | +600 | ||
335 | +604 | ||
336 | +605 | ||
337 | +614 | ||
338 | +602 | ||
339 | +610 | ||
340 | +608 | ||
341 | +611 | ||
342 | +612 | ||
343 | +613 | ||
344 | +617 | ||
345 | +620 | ||
346 | +607 | ||
347 | +624 | ||
348 | +627 | ||
349 | +625 | ||
350 | +631 | ||
351 | +629 | ||
352 | +638 | ||
353 | +632 | ||
354 | +634 | ||
355 | +644 | ||
356 | +641 | ||
357 | +642 | ||
358 | +646 | ||
359 | +652 | ||
360 | +647 | ||
361 | +637 | ||
362 | +661 | ||
363 | +635 | ||
364 | +658 | ||
365 | +648 | ||
366 | +663 | ||
367 | +668 | ||
368 | +664 | ||
369 | +656 | ||
370 | +666 | ||
371 | +671 | ||
372 | +683 | ||
373 | +675 | ||
374 | +669 | ||
375 | +676 | ||
376 | +667 | ||
377 | +691 | ||
378 | +685 | ||
379 | +673 | ||
380 | +688 | ||
381 | +702 | ||
382 | +684 | ||
383 | +679 | ||
384 | +694 | ||
385 | +686 | ||
386 | +689 | ||
387 | +680 | ||
388 | +693 | ||
389 | +703 | ||
390 | +697 | ||
391 | +698 | ||
392 | +692 | ||
393 | +705 | ||
394 | +706 | ||
395 | +712 | ||
396 | +711 | ||
397 | +709 | ||
398 | +710 | ||
399 | +726 | ||
400 | +713 | ||
401 | +721 | ||
402 | +720 | ||
403 | +715 | ||
404 | +717 | ||
405 | +730 | ||
406 | +728 | ||
407 | +723 | ||
408 | +716 | ||
409 | +722 | ||
410 | +718 | ||
411 | +732 | ||
412 | +724 | ||
413 | +736 | ||
414 | +725 | ||
415 | +742 | ||
416 | +727 | ||
417 | +735 | ||
418 | +740 | ||
419 | +748 | ||
420 | +738 | ||
421 | +746 | ||
422 | +751 | ||
423 | +749 | ||
424 | +752 | ||
425 | +754 | ||
426 | +760 | ||
427 | +763 | ||
428 | +756 | ||
429 | +758 | ||
430 | +766 | ||
431 | +764 | ||
432 | +757 | ||
433 | +780 | ||
434 | +767 | ||
435 | +769 | ||
436 | +771 | ||
437 | +786 | ||
438 | +785 | ||
439 | +781 | ||
440 | +787 | ||
441 | +778 | ||
442 | +783 | ||
443 | +792 | ||
444 | +791 | ||
445 | +795 | ||
446 | +788 | ||
447 | +805 | ||
448 | +802 | ||
449 | +801 | ||
450 | +793 | ||
451 | +796 | ||
452 | +804 | ||
453 | +803 | ||
454 | +797 | ||
455 | +814 | ||
456 | +813 | ||
457 | +789 | ||
458 | +808 | ||
459 | +818 | ||
460 | +816 | ||
461 | +817 | ||
462 | +811 | ||
463 | +820 | ||
464 | +826 | ||
465 | +829 | ||
466 | +824 | ||
467 | +821 | ||
468 | +825 | ||
469 | +822 | ||
470 | +835 | ||
471 | +833 | ||
472 | +843 | ||
473 | +823 | ||
474 | +827 | ||
475 | +830 | ||
476 | +832 | ||
477 | +837 | ||
478 | +852 | ||
479 | +844 | ||
480 | +841 | ||
481 | +812 | ||
482 | +847 | ||
483 | +862 | ||
484 | +869 | ||
485 | +860 | ||
486 | +838 | ||
487 | +870 | ||
488 | +846 | ||
489 | +858 | ||
490 | +854 | ||
491 | +880 | ||
492 | +876 | ||
493 | +857 | ||
494 | +859 | ||
495 | +877 | ||
496 | +871 | ||
497 | +855 | ||
498 | +875 | ||
499 | +861 | ||
500 | +867 | ||
501 | +892 | ||
502 | +898 | ||
503 | +888 | ||
504 | +884 | ||
505 | +887 | ||
506 | +891 | ||
507 | +906 | ||
508 | +900 | ||
509 | +878 | ||
510 | +885 | ||
511 | +883 | ||
512 | +901 | ||
513 | +903 | ||
514 | +907 | ||
515 | +930 | ||
516 | +897 | ||
517 | +914 | ||
518 | +917 | ||
519 | +910 | ||
520 | +905 | ||
521 | +909 | ||
522 | +933 | ||
523 | +932 | ||
524 | +922 | ||
525 | +913 | ||
526 | +923 | ||
527 | +931 | ||
528 | +911 | ||
529 | +937 | ||
530 | +918 | ||
531 | +955 | ||
532 | +915 | ||
533 | +944 | ||
534 | +952 | ||
535 | +945 | ||
536 | +948 | ||
537 | +946 | ||
538 | +970 | ||
539 | +974 | ||
540 | +958 | ||
541 | +925 | ||
542 | +979 | ||
543 | +942 | ||
544 | +965 | ||
545 | +975 | ||
546 | +950 | ||
547 | +982 | ||
548 | +940 | ||
549 | +973 | ||
550 | +962 | ||
551 | +972 | ||
552 | +957 | ||
553 | +984 | ||
554 | +983 | ||
555 | +964 | ||
556 | +1007 | ||
557 | +971 | ||
558 | +981 | ||
559 | +954 | ||
560 | +993 | ||
561 | +991 | ||
562 | +996 | ||
563 | +1005 | ||
564 | +1015 | ||
565 | +1009 | ||
566 | +995 | ||
567 | +986 | ||
568 | +1000 | ||
569 | +985 | ||
570 | +980 | ||
571 | +1016 | ||
572 | +1011 | ||
573 | +999 | ||
574 | +1002 | ||
575 | +994 | ||
576 | +1013 | ||
577 | +1010 | ||
578 | +992 | ||
579 | +1008 | ||
580 | +1036 | ||
581 | +1025 | ||
582 | +1012 | ||
583 | +990 | ||
584 | +1037 | ||
585 | +1040 | ||
586 | +1031 | ||
587 | +1019 | ||
588 | +1052 | ||
589 | +1001 | ||
590 | +1055 | ||
591 | +1032 | ||
592 | +1069 | ||
593 | +1058 | ||
594 | +1014 | ||
595 | +1023 | ||
596 | +1030 | ||
597 | +1061 | ||
598 | +1035 | ||
599 | +1034 | ||
600 | +1053 | ||
601 | +1045 | ||
602 | +1046 | ||
603 | +1067 | ||
604 | +1060 | ||
605 | +1049 | ||
606 | +1056 | ||
607 | +1074 | ||
608 | +1066 | ||
609 | +1044 | ||
610 | +1038 | ||
611 | +1073 | ||
612 | +1077 | ||
613 | +1068 | ||
614 | +1057 | ||
615 | +1072 | ||
616 | +1104 | ||
617 | +1083 | ||
618 | +1089 | ||
619 | +1087 | ||
620 | +1099 | ||
621 | +1076 | ||
622 | +1086 | ||
623 | +1098 | ||
624 | +1094 | ||
625 | +1095 | ||
626 | +1096 | ||
627 | +1101 | ||
628 | +1107 | ||
629 | +1105 | ||
630 | +1117 | ||
631 | +1093 | ||
632 | +1106 | ||
633 | +1122 | ||
634 | +1119 | ||
635 | +1103 | ||
636 | +1128 | ||
637 | +1120 | ||
638 | +1126 | ||
639 | +1102 | ||
640 | +1115 | ||
641 | +1124 | ||
642 | +1123 | ||
643 | +1131 | ||
644 | +1136 | ||
645 | +1144 | ||
646 | +1121 | ||
647 | +1137 | ||
648 | +1132 | ||
649 | +1133 | ||
650 | +1157 | ||
651 | +1134 | ||
652 | +1143 | ||
653 | +1159 | ||
654 | +1164 | ||
655 | +1155 | ||
656 | +1142 | ||
657 | +1150 | ||
658 | +1148 | ||
659 | +1161 | ||
660 | +1165 | ||
661 | +1147 | ||
662 | +1162 | ||
663 | +1152 | ||
664 | +1174 | ||
665 | +1160 | ||
666 | +1166 | ||
667 | +1190 | ||
668 | +1175 | ||
669 | +1167 | ||
670 | +1156 | ||
671 | +1180 | ||
672 | +1171 | ||
673 | +1179 | ||
674 | +1172 | ||
675 | +1186 | ||
676 | +1188 | ||
677 | +1201 | ||
678 | +1177 | ||
679 | +1208 | ||
680 | +1183 | ||
681 | +1189 | ||
682 | +1192 | ||
683 | +1209 | ||
684 | +1214 | ||
685 | +1197 | ||
686 | +1168 | ||
687 | +1202 | ||
688 | +1205 | ||
689 | +1203 | ||
690 | +1199 | ||
691 | +1219 | ||
692 | +1217 | ||
693 | +1187 | ||
694 | +1206 | ||
695 | +1210 | ||
696 | +1241 | ||
697 | +1221 | ||
698 | +1218 | ||
699 | +1223 | ||
700 | +1236 | ||
701 | +1212 | ||
702 | +1237 | ||
703 | +1195 | ||
704 | +1216 | ||
705 | +1247 | ||
706 | +1234 | ||
707 | +1240 | ||
708 | +1257 | ||
709 | +1224 | ||
710 | +1243 | ||
711 | +1259 | ||
712 | +1242 | ||
713 | +1282 | ||
714 | +1222 | ||
715 | +1254 | ||
716 | +1227 | ||
717 | +1235 | ||
718 | +1269 | ||
719 | +1258 | ||
720 | +1290 | ||
721 | +1275 | ||
722 | +1262 | ||
723 | +1252 | ||
724 | +1248 | ||
725 | +1272 | ||
726 | +1246 | ||
727 | +1225 | ||
728 | +1245 | ||
729 | +1277 | ||
730 | +1298 | ||
731 | +1288 | ||
732 | +1271 | ||
733 | +1265 | ||
734 | +1286 | ||
735 | +1260 | ||
736 | +1266 | ||
737 | +1296 | ||
738 | +1280 | ||
739 | +1285 | ||
740 | +1293 | ||
741 | +1276 | ||
742 | +1287 | ||
743 | +1289 | ||
744 | +1261 | ||
745 | +1264 | ||
746 | +1295 | ||
747 | +1291 | ||
748 | +1283 | ||
749 | +1311 | ||
750 | +1303 | ||
751 | +1330 | ||
752 | +1315 | ||
753 | +1300 | ||
754 | +1333 | ||
755 | +1307 | ||
756 | +1325 | ||
757 | +1334 | ||
758 | +1316 | ||
759 | +1314 | ||
760 | +1317 | ||
761 | +1310 | ||
762 | +1329 | ||
763 | +1324 | ||
764 | +1339 | ||
765 | +1346 | ||
766 | +1342 | ||
767 | +1352 | ||
768 | +1321 | ||
769 | +1376 | ||
770 | +1366 | ||
771 | +1308 | ||
772 | +1345 | ||
773 | +1348 | ||
774 | +1386 | ||
775 | +1383 | ||
776 | +1372 | ||
777 | +1367 | ||
778 | +1400 | ||
779 | +1382 | ||
780 | +1375 | ||
781 | +1392 | ||
782 | +1380 | ||
783 | +1371 | ||
784 | +1393 | ||
785 | +1389 | ||
786 | +1353 | ||
787 | +1387 | ||
788 | +1374 | ||
789 | +1379 | ||
790 | +1381 | ||
791 | +1359 | ||
792 | +1360 | ||
793 | +1396 | ||
794 | +1399 | ||
795 | +1365 | ||
796 | +1424 | ||
797 | +1373 | ||
798 | +1411 | ||
799 | +1401 | ||
800 | +1397 | ||
801 | +1395 | ||
802 | +1412 | ||
803 | +1394 | ||
804 | +1368 | ||
805 | +1423 | ||
806 | +1391 | ||
807 | +1435 | ||
808 | +1409 | ||
809 | +1443 | ||
810 | +1402 | ||
811 | +1425 | ||
812 | +1415 | ||
813 | +1421 | ||
814 | +1426 | ||
815 | +1433 | ||
816 | +1420 | ||
817 | +1452 | ||
818 | +1436 | ||
819 | +1430 | ||
820 | +1408 | ||
821 | +1458 | ||
822 | +1429 | ||
823 | +1453 | ||
824 | +1454 | ||
825 | +1447 | ||
826 | +1472 | ||
827 | +1486 | ||
828 | +1468 | ||
829 | +1461 | ||
830 | +1467 | ||
831 | +1484 | ||
832 | +1457 | ||
833 | +1444 | ||
834 | +1450 | ||
835 | +1451 | ||
836 | +1459 | ||
837 | +1462 | ||
838 | +1449 | ||
839 | +1476 | ||
840 | +1470 | ||
841 | +1471 | ||
842 | +1498 | ||
843 | +1488 | ||
844 | +1442 | ||
845 | +1480 | ||
846 | +1456 | ||
847 | +1466 | ||
848 | +1505 | ||
849 | +1517 | ||
850 | +1464 | ||
851 | +1503 | ||
852 | +1490 | ||
853 | +1519 | ||
854 | +1481 | ||
855 | +1493 | ||
856 | +1463 | ||
857 | +1532 | ||
858 | +1487 | ||
859 | +1501 | ||
860 | +1500 | ||
861 | +1495 | ||
862 | +1509 | ||
863 | +1535 | ||
864 | +1506 | ||
865 | +1521 | ||
866 | +1580 | ||
867 | +1540 | ||
868 | +1502 | ||
869 | +1520 | ||
870 | +1496 | ||
871 | +1569 | ||
872 | +1515 | ||
873 | +1489 | ||
874 | +1507 | ||
875 | +1527 | ||
876 | +1545 | ||
877 | +1560 | ||
878 | +1510 | ||
879 | +1514 | ||
880 | +1526 | ||
881 | +1594 | ||
882 | +1511 | ||
883 | +1572 | ||
884 | +1548 | ||
885 | +1584 | ||
886 | +1556 | ||
887 | +1588 | ||
888 | +1628 | ||
889 | +1555 | ||
890 | +1568 | ||
891 | +1550 | ||
892 | +1622 | ||
893 | +1563 | ||
894 | +1603 | ||
895 | +1616 | ||
896 | +1576 | ||
897 | +1549 | ||
898 | +1537 | ||
899 | +1593 | ||
900 | +1618 | ||
901 | +1645 | ||
902 | +1624 | ||
903 | +1617 | ||
904 | +1634 | ||
905 | +1595 | ||
906 | +1597 | ||
907 | +1590 | ||
908 | +1632 | ||
909 | +1575 | ||
910 | +1559 | ||
911 | +1625 | ||
912 | +1615 | ||
913 | +1591 | ||
914 | +1630 | ||
915 | +1608 | ||
916 | +1621 | ||
917 | +1589 | ||
918 | +1646 | ||
919 | +1643 | ||
920 | +1652 | ||
921 | +1627 | ||
922 | +1611 | ||
923 | +1626 | ||
924 | +1613 | ||
925 | +1639 | ||
926 | +1655 | ||
927 | +1620 | ||
928 | +1602 | ||
929 | +1651 | ||
930 | +1653 | ||
931 | +1669 | ||
932 | +1638 | ||
933 | +1696 | ||
934 | +1649 | ||
935 | +1675 | ||
936 | +1660 | ||
937 | +1683 | ||
938 | +1666 | ||
939 | +1671 | ||
940 | +1703 | ||
941 | +1716 | ||
942 | +1637 | ||
943 | +1672 | ||
944 | +1676 | ||
945 | +1692 | ||
946 | +1711 | ||
947 | +1680 | ||
948 | +1641 | ||
949 | +1688 | ||
950 | +1708 | ||
951 | +1704 | ||
952 | +1690 | ||
953 | +1674 | ||
954 | +1718 | ||
955 | +1699 | ||
956 | +1723 | ||
957 | +1756 | ||
958 | +1700 | ||
959 | +1662 | ||
960 | +1715 | ||
961 | +1657 | ||
962 | +1733 | ||
963 | +1728 | ||
964 | +1670 | ||
965 | +1712 | ||
966 | +1685 | ||
967 | +1724 | ||
968 | +1735 | ||
969 | +1714 | ||
970 | +1730 | ||
971 | +1747 | ||
972 | +1656 | ||
973 | +1737 | ||
974 | +1705 | ||
975 | +1693 | ||
976 | +1713 | ||
977 | +1689 | ||
978 | +1753 | ||
979 | +1739 | ||
980 | +1721 | ||
981 | +1725 | ||
982 | +1749 | ||
983 | +1732 | ||
984 | +1743 | ||
985 | +1731 | ||
986 | +1767 | ||
987 | +1738 | ||
988 | +1831 | ||
989 | +1771 | ||
990 | +1726 | ||
991 | +1746 | ||
992 | +1776 | ||
993 | +1775 | ||
994 | +1799 | ||
995 | +1774 | ||
996 | +1780 | ||
997 | +1781 | ||
998 | +1769 | ||
999 | +1805 | ||
1000 | +1788 | ||
1001 | +1801 |
web/backend/yt8m/train.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Binary for training Tensorflow models on the YouTube-8M dataset.""" | ||
15 | + | ||
16 | +import json | ||
17 | +import os | ||
18 | +import time | ||
19 | + | ||
20 | +import eval_util | ||
21 | +import export_model | ||
22 | +import losses | ||
23 | +import frame_level_models | ||
24 | +import video_level_models | ||
25 | +import readers | ||
26 | +import tensorflow as tf | ||
27 | +import tensorflow.contrib.slim as slim | ||
28 | +from tensorflow.python.lib.io import file_io | ||
29 | +from tensorflow import app | ||
30 | +from tensorflow import flags | ||
31 | +from tensorflow import gfile | ||
32 | +from tensorflow import logging | ||
33 | +from tensorflow.python.client import device_lib | ||
34 | +import utils | ||
35 | + | ||
36 | +FLAGS = flags.FLAGS | ||
37 | + | ||
38 | +if __name__ == "__main__": | ||
39 | + # Dataset flags. | ||
40 | + flags.DEFINE_string("train_dir", "/tmp/yt8m_model/", | ||
41 | + "The directory to save the model files in.") | ||
42 | + flags.DEFINE_string( | ||
43 | + "train_data_pattern", "", | ||
44 | + "File glob for the training dataset. If the files refer to Frame Level " | ||
45 | + "features (i.e. tensorflow.SequenceExample), then set --reader_type " | ||
46 | + "format. The (Sequence)Examples are expected to have 'rgb' byte array " | ||
47 | + "sequence feature as well as a 'labels' int64 context feature.") | ||
48 | + flags.DEFINE_string("feature_names", "mean_rgb", "Name of the feature " | ||
49 | + "to use for training.") | ||
50 | + flags.DEFINE_string("feature_sizes", "1024", "Length of the feature vectors.") | ||
51 | + | ||
52 | + # Model flags. | ||
53 | + flags.DEFINE_bool( | ||
54 | + "frame_features", False, | ||
55 | + "If set, then --train_data_pattern must be frame-level features. " | ||
56 | + "Otherwise, --train_data_pattern must be aggregated video-level " | ||
57 | + "features. The model must also be set appropriately (i.e. to read 3D " | ||
58 | + "batches VS 4D batches.") | ||
59 | + flags.DEFINE_bool( | ||
60 | + "segment_labels", False, | ||
61 | + "If set, then --train_data_pattern must be frame-level features (but with" | ||
62 | + " segment_labels). Otherwise, --train_data_pattern must be aggregated " | ||
63 | + "video-level features. The model must also be set appropriately (i.e. to " | ||
64 | + "read 3D batches VS 4D batches.") | ||
65 | + flags.DEFINE_string( | ||
66 | + "model", "LogisticModel", | ||
67 | + "Which architecture to use for the model. Models are defined " | ||
68 | + "in models.py.") | ||
69 | + flags.DEFINE_bool( | ||
70 | + "start_new_model", False, | ||
71 | + "If set, this will not resume from a checkpoint and will instead create a" | ||
72 | + " new model instance.") | ||
73 | + | ||
74 | + # Training flags. | ||
75 | + flags.DEFINE_integer( | ||
76 | + "num_gpu", 1, "The maximum number of GPU devices to use for training. " | ||
77 | + "Flag only applies if GPUs are installed") | ||
78 | + flags.DEFINE_integer("batch_size", 1024, | ||
79 | + "How many examples to process per batch for training.") | ||
80 | + flags.DEFINE_string("label_loss", "CrossEntropyLoss", | ||
81 | + "Which loss function to use for training the model.") | ||
82 | + flags.DEFINE_float( | ||
83 | + "regularization_penalty", 1.0, | ||
84 | + "How much weight to give to the regularization loss (the label loss has " | ||
85 | + "a weight of 1).") | ||
86 | + flags.DEFINE_float("base_learning_rate", 0.01, | ||
87 | + "Which learning rate to start with.") | ||
88 | + flags.DEFINE_float( | ||
89 | + "learning_rate_decay", 0.95, | ||
90 | + "Learning rate decay factor to be applied every " | ||
91 | + "learning_rate_decay_examples.") | ||
92 | + flags.DEFINE_float( | ||
93 | + "learning_rate_decay_examples", 4000000, | ||
94 | + "Multiply current learning rate by learning_rate_decay " | ||
95 | + "every learning_rate_decay_examples.") | ||
96 | + flags.DEFINE_integer( | ||
97 | + "num_epochs", 1000, "How many passes to make over the dataset before " | ||
98 | + "halting training.") | ||
99 | + flags.DEFINE_integer( | ||
100 | + "max_steps", None, | ||
101 | + "The maximum number of iterations of the training loop.") | ||
102 | + flags.DEFINE_integer( | ||
103 | + "export_model_steps", 1000, | ||
104 | + "The period, in number of steps, with which the model " | ||
105 | + "is exported for batch prediction.") | ||
106 | + | ||
107 | + # Other flags. | ||
108 | + flags.DEFINE_integer("num_readers", 8, | ||
109 | + "How many threads to use for reading input files.") | ||
110 | + flags.DEFINE_string("optimizer", "AdamOptimizer", | ||
111 | + "What optimizer class to use.") | ||
112 | + flags.DEFINE_float("clip_gradient_norm", 1.0, "Norm to clip gradients to.") | ||
113 | + flags.DEFINE_bool( | ||
114 | + "log_device_placement", False, | ||
115 | + "Whether to write the device on which every op will run into the " | ||
116 | + "logs on startup.") | ||
117 | + | ||
118 | + | ||
119 | +def validate_class_name(flag_value, category, modules, expected_superclass): | ||
120 | + """Checks that the given string matches a class of the expected type. | ||
121 | + | ||
122 | + Args: | ||
123 | + flag_value: A string naming the class to instantiate. | ||
124 | + category: A string used to further describe the class in error messages (e.g. | ||
125 | + 'model', 'reader', 'loss'). | ||
126 | + modules: A list of modules to search for the given class. | ||
127 | + expected_superclass: A class that the given class should inherit from. | ||
128 | + | ||
129 | + Raises: | ||
130 | + FlagsError: If the given class could not be found or if the first class | ||
131 | + found with that name doesn't inherit from the expected superclass. | ||
132 | + | ||
133 | + Returns: | ||
134 | + True if a class was found that matches the given constraints. | ||
135 | + """ | ||
136 | + candidates = [getattr(module, flag_value, None) for module in modules] | ||
137 | + for candidate in candidates: | ||
138 | + if not candidate: | ||
139 | + continue | ||
140 | + if not issubclass(candidate, expected_superclass): | ||
141 | + raise flags.FlagsError( | ||
142 | + "%s '%s' doesn't inherit from %s." % | ||
143 | + (category, flag_value, expected_superclass.__name__)) | ||
144 | + return True | ||
145 | + raise flags.FlagsError("Unable to find %s '%s'." % (category, flag_value)) | ||
146 | + | ||
147 | + | ||
148 | +def get_input_data_tensors(reader, | ||
149 | + data_pattern, | ||
150 | + batch_size=1000, | ||
151 | + num_epochs=None, | ||
152 | + num_readers=1): | ||
153 | + """Creates the section of the graph which reads the training data. | ||
154 | + | ||
155 | + Args: | ||
156 | + reader: A class which parses the training data. | ||
157 | + data_pattern: A 'glob' style path to the data files. | ||
158 | + batch_size: How many examples to process at a time. | ||
159 | + num_epochs: How many passes to make over the training data. Set to 'None' to | ||
160 | + run indefinitely. | ||
161 | + num_readers: How many I/O threads to use. | ||
162 | + | ||
163 | + Returns: | ||
164 | + A tuple containing the features tensor, labels tensor, and optionally a | ||
165 | + tensor containing the number of frames per video. The exact dimensions | ||
166 | + depend on the reader being used. | ||
167 | + | ||
168 | + Raises: | ||
169 | + IOError: If no files matching the given pattern were found. | ||
170 | + """ | ||
171 | + logging.info("Using batch size of " + str(batch_size) + " for training.") | ||
172 | + with tf.name_scope("train_input"): | ||
173 | + files = gfile.Glob(data_pattern) | ||
174 | + if not files: | ||
175 | + raise IOError("Unable to find training files. data_pattern='" + | ||
176 | + data_pattern + "'.") | ||
177 | + logging.info("Number of training files: %s.", str(len(files))) | ||
178 | + filename_queue = tf.train.string_input_producer(files, | ||
179 | + num_epochs=num_epochs, | ||
180 | + shuffle=True) | ||
181 | + training_data = [ | ||
182 | + reader.prepare_reader(filename_queue) for _ in range(num_readers) | ||
183 | + ] | ||
184 | + | ||
185 | + return tf.train.shuffle_batch_join(training_data, | ||
186 | + batch_size=batch_size, | ||
187 | + capacity=batch_size * 5, | ||
188 | + min_after_dequeue=batch_size, | ||
189 | + allow_smaller_final_batch=True, | ||
190 | + enqueue_many=True) | ||
191 | + | ||
192 | + | ||
193 | +def find_class_by_name(name, modules): | ||
194 | + """Searches the provided modules for the named class and returns it.""" | ||
195 | + modules = [getattr(module, name, None) for module in modules] | ||
196 | + return next(a for a in modules if a) | ||
197 | + | ||
198 | + | ||
199 | +def build_graph(reader, | ||
200 | + model, | ||
201 | + train_data_pattern, | ||
202 | + label_loss_fn=losses.CrossEntropyLoss(), | ||
203 | + batch_size=1000, | ||
204 | + base_learning_rate=0.01, | ||
205 | + learning_rate_decay_examples=1000000, | ||
206 | + learning_rate_decay=0.95, | ||
207 | + optimizer_class=tf.train.AdamOptimizer, | ||
208 | + clip_gradient_norm=1.0, | ||
209 | + regularization_penalty=1, | ||
210 | + num_readers=1, | ||
211 | + num_epochs=None): | ||
212 | + """Creates the Tensorflow graph. | ||
213 | + | ||
214 | + This will only be called once in the life of | ||
215 | + a training model, because after the graph is created the model will be | ||
216 | + restored from a meta graph file rather than being recreated. | ||
217 | + | ||
218 | + Args: | ||
219 | + reader: The data file reader. It should inherit from BaseReader. | ||
220 | + model: The core model (e.g. logistic or neural net). It should inherit from | ||
221 | + BaseModel. | ||
222 | + train_data_pattern: glob path to the training data files. | ||
223 | + label_loss_fn: What kind of loss to apply to the model. It should inherit | ||
224 | + from BaseLoss. | ||
225 | + batch_size: How many examples to process at a time. | ||
226 | + base_learning_rate: What learning rate to initialize the optimizer with. | ||
227 | + optimizer_class: Which optimization algorithm to use. | ||
228 | + clip_gradient_norm: Magnitude of the gradient to clip to. | ||
229 | + regularization_penalty: How much weight to give the regularization loss | ||
230 | + compared to the label loss. | ||
231 | + num_readers: How many threads to use for I/O operations. | ||
232 | + num_epochs: How many passes to make over the data. 'None' means an unlimited | ||
233 | + number of passes. | ||
234 | + """ | ||
235 | + | ||
236 | + global_step = tf.Variable(0, trainable=False, name="global_step") | ||
237 | + | ||
238 | + local_device_protos = device_lib.list_local_devices() | ||
239 | + gpus = [x.name for x in local_device_protos if x.device_type == "GPU"] | ||
240 | + gpus = gpus[:FLAGS.num_gpu] | ||
241 | + num_gpus = len(gpus) | ||
242 | + | ||
243 | + if num_gpus > 0: | ||
244 | + logging.info("Using the following GPUs to train: " + str(gpus)) | ||
245 | + num_towers = num_gpus | ||
246 | + device_string = "/gpu:%d" | ||
247 | + else: | ||
248 | + logging.info("No GPUs found. Training on CPU.") | ||
249 | + num_towers = 1 | ||
250 | + device_string = "/cpu:%d" | ||
251 | + | ||
252 | + learning_rate = tf.train.exponential_decay(base_learning_rate, | ||
253 | + global_step * batch_size * | ||
254 | + num_towers, | ||
255 | + learning_rate_decay_examples, | ||
256 | + learning_rate_decay, | ||
257 | + staircase=True) | ||
258 | + tf.summary.scalar("learning_rate", learning_rate) | ||
259 | + | ||
260 | + optimizer = optimizer_class(learning_rate) | ||
261 | + input_data_dict = (get_input_data_tensors(reader, | ||
262 | + train_data_pattern, | ||
263 | + batch_size=batch_size * num_towers, | ||
264 | + num_readers=num_readers, | ||
265 | + num_epochs=num_epochs)) | ||
266 | + model_input_raw = input_data_dict["video_matrix"] | ||
267 | + labels_batch = input_data_dict["labels"] | ||
268 | + num_frames = input_data_dict["num_frames"] | ||
269 | + logging.info("model_input_shape: %s", model_input_raw.shape) | ||
270 | + tf.summary.histogram("model/input_raw", model_input_raw) | ||
271 | + | ||
272 | + feature_dim = len(model_input_raw.get_shape()) - 1 | ||
273 | + | ||
274 | + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) | ||
275 | + | ||
276 | + tower_inputs = tf.split(model_input, num_towers) | ||
277 | + tower_labels = tf.split(labels_batch, num_towers) | ||
278 | + tower_num_frames = tf.split(num_frames, num_towers) | ||
279 | + tower_gradients = [] | ||
280 | + tower_predictions = [] | ||
281 | + tower_label_losses = [] | ||
282 | + tower_reg_losses = [] | ||
283 | + for i in range(num_towers): | ||
284 | + # These 'with' statements have to be nested: a parenthesized multi-line | ||
285 | + # 'with' statement is not supported before Python 3.10. | ||
286 | + with tf.device(device_string % i): | ||
287 | + with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)): | ||
288 | + with (slim.arg_scope([slim.model_variable, slim.variable], | ||
289 | + device="/cpu:0" if num_gpus != 1 else "/gpu:0")): | ||
290 | + result = model.create_model(tower_inputs[i], | ||
291 | + num_frames=tower_num_frames[i], | ||
292 | + vocab_size=reader.num_classes, | ||
293 | + labels=tower_labels[i]) | ||
294 | + for variable in slim.get_model_variables(): | ||
295 | + tf.summary.histogram(variable.op.name, variable) | ||
296 | + | ||
297 | + predictions = result["predictions"] | ||
298 | + tower_predictions.append(predictions) | ||
299 | + | ||
300 | + if "loss" in result.keys(): | ||
301 | + label_loss = result["loss"] | ||
302 | + else: | ||
303 | + label_loss = label_loss_fn.calculate_loss(predictions, | ||
304 | + tower_labels[i]) | ||
305 | + | ||
306 | + if "regularization_loss" in result.keys(): | ||
307 | + reg_loss = result["regularization_loss"] | ||
308 | + else: | ||
309 | + reg_loss = tf.constant(0.0) | ||
310 | + | ||
311 | + reg_losses = tf.losses.get_regularization_losses() | ||
312 | + if reg_losses: | ||
313 | + reg_loss += tf.add_n(reg_losses) | ||
314 | + | ||
315 | + tower_reg_losses.append(reg_loss) | ||
316 | + | ||
317 | + # Adds update_ops (e.g., moving average updates in batch normalization) as | ||
318 | + # a dependency to the train_op. | ||
319 | + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) | ||
320 | + if "update_ops" in result.keys(): | ||
321 | + update_ops += result["update_ops"] | ||
322 | + if update_ops: | ||
323 | + with tf.control_dependencies(update_ops): | ||
324 | + barrier = tf.no_op(name="gradient_barrier") | ||
325 | + with tf.control_dependencies([barrier]): | ||
326 | + label_loss = tf.identity(label_loss) | ||
327 | + | ||
328 | + tower_label_losses.append(label_loss) | ||
329 | + | ||
330 | + # Incorporate the L2 weight penalties etc. | ||
331 | + final_loss = regularization_penalty * reg_loss + label_loss | ||
332 | + gradients = optimizer.compute_gradients( | ||
333 | + final_loss, colocate_gradients_with_ops=False) | ||
334 | + tower_gradients.append(gradients) | ||
335 | + label_loss = tf.reduce_mean(tf.stack(tower_label_losses)) | ||
336 | + tf.summary.scalar("label_loss", label_loss) | ||
337 | + if regularization_penalty != 0: | ||
338 | + reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses)) | ||
339 | + tf.summary.scalar("reg_loss", reg_loss) | ||
340 | + merged_gradients = utils.combine_gradients(tower_gradients) | ||
341 | + | ||
342 | + if clip_gradient_norm > 0: | ||
343 | + with tf.name_scope("clip_grads"): | ||
344 | + merged_gradients = utils.clip_gradient_norms(merged_gradients, | ||
345 | + clip_gradient_norm) | ||
346 | + | ||
347 | + train_op = optimizer.apply_gradients(merged_gradients, | ||
348 | + global_step=global_step) | ||
349 | + | ||
350 | + tf.add_to_collection("global_step", global_step) | ||
351 | + tf.add_to_collection("loss", label_loss) | ||
352 | + tf.add_to_collection("predictions", tf.concat(tower_predictions, 0)) | ||
353 | + tf.add_to_collection("input_batch_raw", model_input_raw) | ||
354 | + tf.add_to_collection("input_batch", model_input) | ||
355 | + tf.add_to_collection("num_frames", num_frames) | ||
356 | + tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32)) | ||
357 | + tf.add_to_collection("train_op", train_op) | ||
358 | + | ||
359 | + | ||
360 | +class Trainer(object): | ||
361 | + """A Trainer to train a Tensorflow graph.""" | ||
362 | + | ||
363 | + def __init__(self, | ||
364 | + cluster, | ||
365 | + task, | ||
366 | + train_dir, | ||
367 | + model, | ||
368 | + reader, | ||
369 | + model_exporter, | ||
370 | + log_device_placement=True, | ||
371 | + max_steps=None, | ||
372 | + export_model_steps=1000): | ||
373 | + """"Creates a Trainer. | ||
374 | + | ||
375 | + Args: | ||
376 | + cluster: A tf.train.ClusterSpec if the execution is distributed. None | ||
377 | + otherwise. | ||
378 | + task: A TaskSpec describing the job type and the task index. | ||
379 | + """ | ||
380 | + | ||
381 | + self.cluster = cluster | ||
382 | + self.task = task | ||
383 | + self.is_master = (task.type == "master" and task.index == 0) | ||
384 | + self.train_dir = train_dir | ||
385 | + self.config = tf.ConfigProto(allow_soft_placement=True, | ||
386 | + log_device_placement=log_device_placement) | ||
387 | + self.config.gpu_options.allow_growth = True | ||
388 | + self.model = model | ||
389 | + self.reader = reader | ||
390 | + self.model_exporter = model_exporter | ||
391 | + self.max_steps = max_steps | ||
392 | + self.max_steps_reached = False | ||
393 | + self.export_model_steps = export_model_steps | ||
394 | + self.last_model_export_step = 0 | ||
395 | + | ||
396 | + | ||
397 | +# if self.is_master and self.task.index > 0: | ||
398 | +# raise StandardError("%s: Only one replica of master expected", | ||
399 | +# task_as_string(self.task)) | ||
400 | + | ||
401 | + def run(self, start_new_model=False): | ||
402 | + """Performs training on the currently defined Tensorflow graph. | ||
403 | + | ||
404 | + The loop runs until the supervisor is asked to stop, the epoch limit is | ||
405 | + reached, or max_steps is exceeded; nothing is returned. | ||
406 | + """ | ||
407 | + if self.is_master and start_new_model: | ||
408 | + self.remove_training_directory(self.train_dir) | ||
409 | + | ||
410 | + if not os.path.exists(self.train_dir): | ||
411 | + os.makedirs(self.train_dir) | ||
412 | + | ||
413 | + model_flags_dict = { | ||
414 | + "model": FLAGS.model, | ||
415 | + "feature_sizes": FLAGS.feature_sizes, | ||
416 | + "feature_names": FLAGS.feature_names, | ||
417 | + "frame_features": FLAGS.frame_features, | ||
418 | + "label_loss": FLAGS.label_loss, | ||
419 | + } | ||
420 | + flags_json_path = os.path.join(FLAGS.train_dir, "model_flags.json") | ||
421 | + if file_io.file_exists(flags_json_path): | ||
422 | + existing_flags = json.load(file_io.FileIO(flags_json_path, mode="r")) | ||
423 | + if existing_flags != model_flags_dict: | ||
424 | + logging.error( | ||
425 | + "Model flags do not match existing file %s. Please " | ||
426 | + "delete the file, change --train_dir, or pass flag " | ||
427 | + "--start_new_model", flags_json_path) | ||
428 | + logging.error("Ran model with flags: %s", str(model_flags_dict)) | ||
429 | + logging.error("Previously ran with flags: %s", str(existing_flags)) | ||
430 | + exit(1) | ||
431 | + else: | ||
432 | + # Write the file. | ||
433 | + with file_io.FileIO(flags_json_path, mode="w") as fout: | ||
434 | + fout.write(json.dumps(model_flags_dict)) | ||
435 | + | ||
436 | + target, device_fn = self.start_server_if_distributed() | ||
437 | + | ||
438 | + meta_filename = self.get_meta_filename(start_new_model, self.train_dir) | ||
439 | + | ||
440 | + with tf.Graph().as_default() as graph: | ||
441 | + if meta_filename: | ||
442 | + saver = self.recover_model(meta_filename) | ||
443 | + | ||
444 | + with tf.device(device_fn): | ||
445 | + if not meta_filename: | ||
446 | + saver = self.build_model(self.model, self.reader) | ||
447 | + | ||
448 | + global_step = tf.get_collection("global_step")[0] | ||
449 | + loss = tf.get_collection("loss")[0] | ||
450 | + predictions = tf.get_collection("predictions")[0] | ||
451 | + labels = tf.get_collection("labels")[0] | ||
452 | + train_op = tf.get_collection("train_op")[0] | ||
453 | + init_op = tf.global_variables_initializer() | ||
454 | + | ||
455 | + sv = tf.train.Supervisor(graph, | ||
456 | + logdir=self.train_dir, | ||
457 | + init_op=init_op, | ||
458 | + is_chief=self.is_master, | ||
459 | + global_step=global_step, | ||
460 | + save_model_secs=15 * 60, | ||
461 | + save_summaries_secs=120, | ||
462 | + saver=saver) | ||
463 | + | ||
464 | + logging.info("%s: Starting managed session.", task_as_string(self.task)) | ||
465 | + with sv.managed_session(target, config=self.config) as sess: | ||
466 | + try: | ||
467 | + logging.info("%s: Entering training loop.", task_as_string(self.task)) | ||
468 | + while (not sv.should_stop()) and (not self.max_steps_reached): | ||
469 | + batch_start_time = time.time() | ||
470 | + _, global_step_val, loss_val, predictions_val, labels_val = sess.run( | ||
471 | + [train_op, global_step, loss, predictions, labels]) | ||
472 | + seconds_per_batch = time.time() - batch_start_time | ||
473 | + examples_per_second = labels_val.shape[0] / seconds_per_batch | ||
474 | + | ||
475 | + if self.max_steps and self.max_steps <= global_step_val: | ||
476 | + self.max_steps_reached = True | ||
477 | + | ||
478 | + if self.is_master and global_step_val % 10 == 0 and self.train_dir: | ||
479 | + eval_start_time = time.time() | ||
480 | + hit_at_one = eval_util.calculate_hit_at_one(predictions_val, | ||
481 | + labels_val) | ||
482 | + perr = eval_util.calculate_precision_at_equal_recall_rate( | ||
483 | + predictions_val, labels_val) | ||
484 | + gap = eval_util.calculate_gap(predictions_val, labels_val) | ||
485 | + eval_end_time = time.time() | ||
486 | + eval_time = eval_end_time - eval_start_time | ||
487 | + | ||
488 | + logging.info("training step " + str(global_step_val) + " | Loss: " + | ||
489 | + ("%.2f" % loss_val) + " Examples/sec: " + | ||
490 | + ("%.2f" % examples_per_second) + " | Hit@1: " + | ||
491 | + ("%.2f" % hit_at_one) + " PERR: " + ("%.2f" % perr) + | ||
492 | + " GAP: " + ("%.2f" % gap)) | ||
493 | + | ||
494 | + sv.summary_writer.add_summary( | ||
495 | + utils.MakeSummary("model/Training_Hit@1", hit_at_one), | ||
496 | + global_step_val) | ||
497 | + sv.summary_writer.add_summary( | ||
498 | + utils.MakeSummary("model/Training_Perr", perr), global_step_val) | ||
499 | + sv.summary_writer.add_summary( | ||
500 | + utils.MakeSummary("model/Training_GAP", gap), global_step_val) | ||
501 | + sv.summary_writer.add_summary( | ||
502 | + utils.MakeSummary("global_step/Examples/Second", | ||
503 | + examples_per_second), global_step_val) | ||
504 | + sv.summary_writer.flush() | ||
505 | + | ||
506 | + # Exporting the model every x steps | ||
507 | + time_to_export = ((self.last_model_export_step == 0) or | ||
508 | + (global_step_val - self.last_model_export_step >= | ||
509 | + self.export_model_steps)) | ||
510 | + | ||
511 | + if self.is_master and time_to_export: | ||
512 | + self.export_model(global_step_val, sv.saver, sv.save_path, sess) | ||
513 | + self.last_model_export_step = global_step_val | ||
514 | + else: | ||
515 | + logging.info("training step " + str(global_step_val) + " | Loss: " + | ||
516 | + ("%.2f" % loss_val) + " Examples/sec: " + | ||
517 | + ("%.2f" % examples_per_second)) | ||
518 | + except tf.errors.OutOfRangeError: | ||
519 | + logging.info("%s: Done training -- epoch limit reached.", | ||
520 | + task_as_string(self.task)) | ||
521 | + | ||
522 | + logging.info("%s: Exited training loop.", task_as_string(self.task)) | ||
523 | + sv.Stop() | ||
524 | + | ||
525 | + def export_model(self, global_step_val, saver, save_path, session): | ||
526 | + | ||
527 | + # If the model has already been exported at this step, return. | ||
528 | + if global_step_val == self.last_model_export_step: | ||
529 | + return | ||
530 | + | ||
531 | + saver.save(session, save_path, global_step_val) | ||
532 | + | ||
533 | + def start_server_if_distributed(self): | ||
534 | + """Starts a server if the execution is distributed.""" | ||
535 | + | ||
536 | + if self.cluster: | ||
537 | + logging.info("%s: Starting trainer within cluster %s.", | ||
538 | + task_as_string(self.task), self.cluster.as_dict()) | ||
539 | + server = start_server(self.cluster, self.task) | ||
540 | + target = server.target | ||
541 | + device_fn = tf.train.replica_device_setter( | ||
542 | + ps_device="/job:ps", | ||
543 | + worker_device="/job:%s/task:%d" % (self.task.type, self.task.index), | ||
544 | + cluster=self.cluster) | ||
545 | + else: | ||
546 | + target = "" | ||
547 | + device_fn = "" | ||
548 | + return (target, device_fn) | ||
549 | + | ||
550 | + def remove_training_directory(self, train_dir): | ||
551 | + """Removes the training directory.""" | ||
552 | + try: | ||
553 | + logging.info("%s: Removing existing train directory.", | ||
554 | + task_as_string(self.task)) | ||
555 | + gfile.DeleteRecursively(train_dir) | ||
556 | + except Exception: | ||
557 | + logging.error( | ||
558 | + "%s: Failed to delete directory " + train_dir + | ||
559 | + " when starting a new model. Please delete it manually and" + | ||
560 | + " try again.", task_as_string(self.task)) | ||
561 | + | ||
562 | + def get_meta_filename(self, start_new_model, train_dir): | ||
563 | + if start_new_model: | ||
564 | + logging.info("%s: Flag 'start_new_model' is set. Building a new model.", | ||
565 | + task_as_string(self.task)) | ||
566 | + return None | ||
567 | + | ||
568 | + latest_checkpoint = tf.train.latest_checkpoint(train_dir) | ||
569 | + if not latest_checkpoint: | ||
570 | + logging.info("%s: No checkpoint file found. Building a new model.", | ||
571 | + task_as_string(self.task)) | ||
572 | + return None | ||
573 | + | ||
574 | + meta_filename = latest_checkpoint + ".meta" | ||
575 | + if not gfile.Exists(meta_filename): | ||
576 | + logging.info("%s: No meta graph file found. Building a new model.", | ||
577 | + task_as_string(self.task)) | ||
578 | + return None | ||
579 | + else: | ||
580 | + return meta_filename | ||
581 | + | ||
582 | + def recover_model(self, meta_filename): | ||
583 | + logging.info("%s: Restoring from meta graph file %s", | ||
584 | + task_as_string(self.task), meta_filename) | ||
585 | + return tf.train.import_meta_graph(meta_filename) | ||
586 | + | ||
587 | + def build_model(self, model, reader): | ||
588 | + """Find the model and build the graph.""" | ||
589 | + | ||
590 | + label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])() | ||
591 | + optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train]) | ||
592 | + | ||
593 | + build_graph(reader=reader, | ||
594 | + model=model, | ||
595 | + optimizer_class=optimizer_class, | ||
596 | + clip_gradient_norm=FLAGS.clip_gradient_norm, | ||
597 | + train_data_pattern=FLAGS.train_data_pattern, | ||
598 | + label_loss_fn=label_loss_fn, | ||
599 | + base_learning_rate=FLAGS.base_learning_rate, | ||
600 | + learning_rate_decay=FLAGS.learning_rate_decay, | ||
601 | + learning_rate_decay_examples=FLAGS.learning_rate_decay_examples, | ||
602 | + regularization_penalty=FLAGS.regularization_penalty, | ||
603 | + num_readers=FLAGS.num_readers, | ||
604 | + batch_size=FLAGS.batch_size, | ||
605 | + num_epochs=FLAGS.num_epochs) | ||
606 | + | ||
607 | + return tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=0.25) | ||
608 | + | ||
609 | + | ||
610 | +def get_reader(): | ||
611 | + # Convert feature_names and feature_sizes to lists of values. | ||
612 | + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( | ||
613 | + FLAGS.feature_names, FLAGS.feature_sizes) | ||
614 | + | ||
615 | + if FLAGS.frame_features: | ||
616 | + reader = readers.YT8MFrameFeatureReader(feature_names=feature_names, | ||
617 | + feature_sizes=feature_sizes, | ||
618 | + segment_labels=FLAGS.segment_labels) | ||
619 | + else: | ||
620 | + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, | ||
621 | + feature_sizes=feature_sizes) | ||
622 | + | ||
623 | + return reader | ||
624 | + | ||
625 | + | ||
626 | +class ParameterServer(object): | ||
627 | + """A parameter server to serve variables in a distributed execution.""" | ||
628 | + | ||
629 | + def __init__(self, cluster, task): | ||
630 | + """Creates a ParameterServer. | ||
631 | + | ||
632 | + Args: | ||
633 | + cluster: A tf.train.ClusterSpec if the execution is distributed. None | ||
634 | + otherwise. | ||
635 | + task: A TaskSpec describing the job type and the task index. | ||
636 | + """ | ||
637 | + | ||
638 | + self.cluster = cluster | ||
639 | + self.task = task | ||
640 | + | ||
641 | + def run(self): | ||
642 | + """Starts the parameter server.""" | ||
643 | + | ||
644 | + logging.info("%s: Starting parameter server within cluster %s.", | ||
645 | + task_as_string(self.task), self.cluster.as_dict()) | ||
646 | + server = start_server(self.cluster, self.task) | ||
647 | + server.join() | ||
648 | + | ||
649 | + | ||
650 | +def start_server(cluster, task): | ||
651 | + """Creates a Server. | ||
652 | + | ||
653 | + Args: | ||
654 | + cluster: A tf.train.ClusterSpec if the execution is distributed. None | ||
655 | + otherwise. | ||
656 | + task: A TaskSpec describing the job type and the task index. | ||
657 | + """ | ||
658 | + | ||
659 | + if not task.type: | ||
660 | + raise ValueError("%s: The task type must be specified." % | ||
661 | + task_as_string(task)) | ||
662 | + if task.index is None: | ||
663 | + raise ValueError("%s: The task index must be specified." % | ||
664 | + task_as_string(task)) | ||
665 | + | ||
666 | + # Create and start a server. | ||
667 | + return tf.train.Server(tf.train.ClusterSpec(cluster), | ||
668 | + protocol="grpc", | ||
669 | + job_name=task.type, | ||
670 | + task_index=task.index) | ||
671 | + | ||
672 | + | ||
673 | +def task_as_string(task): | ||
674 | + return "/job:%s/task:%s" % (task.type, task.index) | ||
675 | + | ||
676 | + | ||
677 | +def main(unused_argv): | ||
678 | + # Load the environment. | ||
679 | + env = json.loads(os.environ.get("TF_CONFIG", "{}")) | ||
680 | + | ||
681 | + # Load the cluster data from the environment. | ||
682 | + cluster_data = env.get("cluster", None) | ||
683 | + cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None | ||
684 | + | ||
685 | + # Load the task data from the environment. | ||
686 | + task_data = env.get("task", None) or {"type": "master", "index": 0} | ||
687 | + task = type("TaskSpec", (object,), task_data) | ||
688 | + | ||
689 | + # Logging the version. | ||
690 | + logging.set_verbosity(tf.logging.INFO) | ||
691 | + logging.info("%s: Tensorflow version: %s.", task_as_string(task), | ||
692 | + tf.__version__) | ||
693 | + | ||
694 | + # Dispatch to a master, a worker, or a parameter server. | ||
695 | + if not cluster or task.type == "master" or task.type == "worker": | ||
696 | + model = find_class_by_name(FLAGS.model, | ||
697 | + [frame_level_models, video_level_models])() | ||
698 | + | ||
699 | + reader = get_reader() | ||
700 | + | ||
701 | + model_exporter = export_model.ModelExporter( | ||
702 | + frame_features=FLAGS.frame_features, model=model, reader=reader) | ||
703 | + | ||
704 | + Trainer(cluster, task, FLAGS.train_dir, model, reader, model_exporter, | ||
705 | + FLAGS.log_device_placement, FLAGS.max_steps, | ||
706 | + FLAGS.export_model_steps).run(start_new_model=FLAGS.start_new_model) | ||
707 | + | ||
708 | + elif task.type == "ps": | ||
709 | + ParameterServer(cluster, task).run() | ||
710 | + else: | ||
711 | + raise ValueError("%s: Invalid task_type: %s." % | ||
712 | + (task_as_string(task), task.type)) | ||
713 | + | ||
714 | + | ||
715 | +if __name__ == "__main__": | ||
716 | + app.run() |
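The `main()` entry point above reads the cluster layout and the local task from the `TF_CONFIG` environment variable, then dispatches to a `Trainer` (master/worker) or a `ParameterServer` (ps). Below is a minimal sketch of that plumbing; the host names and ports are placeholders, not values used by this repository.

```
import json
import os

# Hypothetical cluster layout for a distributed run; hosts/ports are made up.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "master": ["10.0.0.1:2222"],
        "worker": ["10.0.0.2:2222"],
        "ps": ["10.0.0.3:2222"],
    },
    # This particular process acts as the single master replica.
    "task": {"type": "master", "index": 0},
})

# main() recovers the pieces roughly like this:
env = json.loads(os.environ.get("TF_CONFIG", "{}"))
cluster_data = env.get("cluster", None)  # dict of job name -> address list
task_data = env.get("task", None) or {"type": "master", "index": 0}

# type() builds a throwaway class whose attributes mirror the task dict,
# so the rest of the code can read task.type and task.index.
task = type("TaskSpec", (object,), task_data)
print(task.type, task.index)  # -> master 0
```

With no `TF_CONFIG` set at all, the defaults above make the process a lone master, which is the single-machine training path.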
web/backend/yt8m/utils.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains a collection of util functions for training and evaluating.""" | ||
15 | + | ||
16 | +import numpy | ||
17 | +import tensorflow as tf | ||
18 | +from tensorflow import logging | ||
19 | + | ||
20 | +try: | ||
21 | + xrange # Python 2 | ||
22 | +except NameError: | ||
23 | + xrange = range # Python 3 | ||
24 | + | ||
25 | + | ||
26 | +def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2): | ||
27 | + """Dequantize the feature from the byte format to the float format. | ||
28 | + | ||
29 | + Args: | ||
30 | + feat_vector: the input 1-d vector. | ||
31 | + max_quantized_value: the maximum of the quantized value. | ||
32 | + min_quantized_value: the minimum of the quantized value. | ||
33 | + | ||
34 | + Returns: | ||
35 | + A float vector which has the same shape as feat_vector. | ||
36 | + """ | ||
37 | + assert max_quantized_value > min_quantized_value | ||
38 | + quantized_range = max_quantized_value - min_quantized_value | ||
39 | + scalar = quantized_range / 255.0 | ||
40 | + bias = (quantized_range / 512.0) + min_quantized_value | ||
41 | + return feat_vector * scalar + bias | ||
42 | + | ||
43 | + | ||
44 | +def MakeSummary(name, value): | ||
45 | + """Creates a tf.Summary proto with the given name and value.""" | ||
46 | + summary = tf.Summary() | ||
47 | + val = summary.value.add() | ||
48 | + val.tag = str(name) | ||
49 | + val.simple_value = float(value) | ||
50 | + return summary | ||
51 | + | ||
52 | + | ||
53 | +def AddGlobalStepSummary(summary_writer, | ||
54 | + global_step_val, | ||
55 | + global_step_info_dict, | ||
56 | + summary_scope="Eval"): | ||
57 | + """Add the global_step summary to the Tensorboard. | ||
58 | + | ||
59 | + Args: | ||
60 | + summary_writer: Tensorflow summary_writer. | ||
61 | + global_step_val: an int value of the global step. | ||
62 | + global_step_info_dict: a dictionary of the evaluation metrics calculated for | ||
63 | + a mini-batch. | ||
64 | + summary_scope: Train or Eval. | ||
65 | + | ||
66 | + Returns: | ||
67 | + A string of this global_step summary | ||
68 | + """ | ||
69 | + this_hit_at_one = global_step_info_dict["hit_at_one"] | ||
70 | + this_perr = global_step_info_dict["perr"] | ||
71 | + this_loss = global_step_info_dict["loss"] | ||
72 | + examples_per_second = global_step_info_dict.get("examples_per_second", -1) | ||
73 | + | ||
74 | + summary_writer.add_summary( | ||
75 | + MakeSummary("GlobalStep/" + summary_scope + "_Hit@1", this_hit_at_one), | ||
76 | + global_step_val) | ||
77 | + summary_writer.add_summary( | ||
78 | + MakeSummary("GlobalStep/" + summary_scope + "_Perr", this_perr), | ||
79 | + global_step_val) | ||
80 | + summary_writer.add_summary( | ||
81 | + MakeSummary("GlobalStep/" + summary_scope + "_Loss", this_loss), | ||
82 | + global_step_val) | ||
83 | + | ||
84 | + if examples_per_second != -1: | ||
85 | + summary_writer.add_summary( | ||
86 | + MakeSummary("GlobalStep/" + summary_scope + "_Example_Second", | ||
87 | + examples_per_second), global_step_val) | ||
88 | + | ||
89 | + summary_writer.flush() | ||
90 | + info = ( | ||
91 | + "global_step {0} | Batch Hit@1: {1:.3f} | Batch PERR: {2:.3f} | Batch " | ||
92 | + "Loss: {3:.3f} | Examples_per_sec: {4:.3f}").format( | ||
93 | + global_step_val, this_hit_at_one, this_perr, this_loss, | ||
94 | + examples_per_second) | ||
95 | + return info | ||
96 | + | ||
97 | + | ||
98 | +def AddEpochSummary(summary_writer, | ||
99 | + global_step_val, | ||
100 | + epoch_info_dict, | ||
101 | + summary_scope="Eval"): | ||
102 | + """Add the epoch summary to the Tensorboard. | ||
103 | + | ||
104 | + Args: | ||
105 | + summary_writer: Tensorflow summary_writer. | ||
106 | + global_step_val: an int value of the global step. | ||
107 | + epoch_info_dict: a dictionary of the evaluation metrics calculated for the | ||
108 | + whole epoch. | ||
109 | + summary_scope: Train or Eval. | ||
110 | + | ||
111 | + Returns: | ||
112 | + A string of this global_step summary | ||
113 | + """ | ||
114 | + epoch_id = epoch_info_dict["epoch_id"] | ||
115 | + avg_hit_at_one = epoch_info_dict["avg_hit_at_one"] | ||
116 | + avg_perr = epoch_info_dict["avg_perr"] | ||
117 | + avg_loss = epoch_info_dict["avg_loss"] | ||
118 | + aps = epoch_info_dict["aps"] | ||
119 | + gap = epoch_info_dict["gap"] | ||
120 | + mean_ap = numpy.mean(aps) | ||
121 | + | ||
122 | + summary_writer.add_summary( | ||
123 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one), | ||
124 | + global_step_val) | ||
125 | + summary_writer.add_summary( | ||
126 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Perr", avg_perr), | ||
127 | + global_step_val) | ||
128 | + summary_writer.add_summary( | ||
129 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Loss", avg_loss), | ||
130 | + global_step_val) | ||
131 | + summary_writer.add_summary( | ||
132 | + MakeSummary("Epoch/" + summary_scope + "_MAP", mean_ap), global_step_val) | ||
133 | + summary_writer.add_summary( | ||
134 | + MakeSummary("Epoch/" + summary_scope + "_GAP", gap), global_step_val) | ||
135 | + summary_writer.flush() | ||
136 | + | ||
137 | + info = ("epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} " | ||
138 | + "| MAP: {3:.3f} | GAP: {4:.3f} | Avg_Loss: {5:3f} | num_classes: {6}" | ||
139 | + ).format(epoch_id, avg_hit_at_one, avg_perr, mean_ap, gap, avg_loss, | ||
140 | + len(aps)) | ||
141 | + return info | ||
142 | + | ||
143 | + | ||
144 | +def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes): | ||
145 | + """Extract the list of feature names and the dimensionality of each feature | ||
146 | + | ||
147 | + from string of comma separated values. | ||
148 | + | ||
149 | + Args: | ||
150 | + feature_names: string containing comma separated list of feature names | ||
151 | + feature_sizes: string containing comma separated list of feature sizes | ||
152 | + | ||
153 | + Returns: | ||
154 | + List of the feature names and list of the dimensionality of each feature. | ||
155 | + Elements in the first/second list are strings/integers. | ||
156 | + """ | ||
157 | + list_of_feature_names = [ | ||
158 | + name.strip() for name in feature_names.split(",") | ||
159 | + ] | ||
160 | + list_of_feature_sizes = [ | ||
161 | + int(size) for size in feature_sizes.split(",") | ||
162 | + ] | ||
163 | + if len(list_of_feature_names) != len(list_of_feature_sizes): | ||
164 | + logging.error("length of the feature names (=" + | ||
165 | + str(len(list_of_feature_names)) + ") != length of feature " | ||
166 | + "sizes (=" + str(len(list_of_feature_sizes)) + ")") | ||
167 | + | ||
168 | + return list_of_feature_names, list_of_feature_sizes | ||
169 | + | ||
170 | + | ||
171 | +def clip_gradient_norms(gradients_to_variables, max_norm): | ||
172 | + """Clips the gradients by the given value. | ||
173 | + | ||
174 | + Args: | ||
175 | + gradients_to_variables: A list of gradient to variable pairs (tuples). | ||
176 | + max_norm: the maximum norm value. | ||
177 | + | ||
178 | + Returns: | ||
179 | + A list of clipped gradient to variable pairs. | ||
180 | + """ | ||
181 | + clipped_grads_and_vars = [] | ||
182 | + for grad, var in gradients_to_variables: | ||
183 | + if grad is not None: | ||
184 | + if isinstance(grad, tf.IndexedSlices): | ||
185 | + tmp = tf.clip_by_norm(grad.values, max_norm) | ||
186 | + grad = tf.IndexedSlices(tmp, grad.indices, grad.dense_shape) | ||
187 | + else: | ||
188 | + grad = tf.clip_by_norm(grad, max_norm) | ||
189 | + clipped_grads_and_vars.append((grad, var)) | ||
190 | + return clipped_grads_and_vars | ||
191 | + | ||
192 | + | ||
193 | +def combine_gradients(tower_grads): | ||
194 | + """Calculate the combined gradient for each shared variable across all towers. | ||
195 | + | ||
196 | + Note that this function provides a synchronization point across all towers. | ||
197 | + | ||
198 | + Args: | ||
199 | + tower_grads: List of lists of (gradient, variable) tuples. The outer list | ||
200 | + is over towers. The inner list is over the (gradient, variable) pairs | ||
201 | + computed for the variables on that tower. | ||
202 | + | ||
203 | + Returns: | ||
204 | + List of pairs of (gradient, variable) where the gradient has been summed | ||
205 | + across all towers. | ||
206 | + """ | ||
207 | + filtered_grads = [ | ||
208 | + [x for x in grad_list if x[0] is not None] for grad_list in tower_grads | ||
209 | + ] | ||
210 | + final_grads = [] | ||
211 | + for i in xrange(len(filtered_grads[0])): | ||
212 | + grads = [filtered_grads[t][i] for t in xrange(len(filtered_grads))] | ||
213 | + grad = tf.stack([x[0] for x in grads], 0) | ||
214 | + grad = tf.reduce_sum(grad, 0) | ||
215 | + final_grads.append(( | ||
216 | + grad, | ||
217 | + filtered_grads[0][i][1], | ||
218 | + )) | ||
219 | + | ||
220 | + return final_grads |
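A short usage sketch for two of the helpers above, assuming it is run from `web/backend/yt8m` so `utils` is importable. The feature names and sizes mirror the comma-separated flag format the training code passes in; the concrete values are only illustrative.

```
import numpy as np

from utils import Dequantize, GetListOfFeatureNamesAndSizes

# Dequantize maps byte-quantized features (0..255) back to roughly
# [min_quantized_value, max_quantized_value]; with the defaults that is
# approximately [-2, 2].
byte_features = np.array([0.0, 128.0, 255.0])
print(Dequantize(byte_features))  # ~[-1.99, 0.02, 2.01]

# The flag strings are split on commas, stripped, and converted to ints.
names, sizes = GetListOfFeatureNamesAndSizes("mean_rgb, mean_audio",
                                             "1024, 128")
print(names)  # ['mean_rgb', 'mean_audio']
print(sizes)  # [1024, 128]
```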
web/backend/yt8m/video_level_models.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains model definitions.""" | ||
15 | +import math | ||
16 | + | ||
17 | +import models | ||
18 | +import tensorflow as tf | ||
19 | +import utils | ||
20 | + | ||
21 | +from tensorflow import flags | ||
22 | +import tensorflow.contrib.slim as slim | ||
23 | + | ||
24 | +FLAGS = flags.FLAGS | ||
25 | +flags.DEFINE_integer( | ||
26 | + "moe_num_mixtures", 2, | ||
27 | + "The number of mixtures (excluding the dummy 'expert') used for MoeModel.") | ||
28 | + | ||
29 | + | ||
30 | +class LogisticModel(models.BaseModel): | ||
31 | + """Logistic model with L2 regularization.""" | ||
32 | + | ||
33 | + def create_model(self, | ||
34 | + model_input, | ||
35 | + vocab_size, | ||
36 | + l2_penalty=1e-8, | ||
37 | + **unused_params): | ||
38 | + """Creates a logistic model. | ||
39 | + | ||
40 | + Args: | ||
41 | + model_input: 'batch' x 'num_features' matrix of input features. | ||
42 | + vocab_size: The number of classes in the dataset. | ||
43 | + | ||
44 | + Returns: | ||
45 | + A dictionary with a tensor containing the probability predictions of the | ||
46 | + model in the 'predictions' key. The dimensions of the tensor are | ||
47 | + batch_size x num_classes. | ||
48 | + """ | ||
49 | + output = slim.fully_connected( | ||
50 | + model_input, | ||
51 | + vocab_size, | ||
52 | + activation_fn=tf.nn.sigmoid, | ||
53 | + weights_regularizer=slim.l2_regularizer(l2_penalty)) | ||
54 | + return {"predictions": output} | ||
55 | + | ||
56 | + | ||
57 | +class MoeModel(models.BaseModel): | ||
58 | + """A softmax over a mixture of logistic models (with L2 regularization).""" | ||
59 | + | ||
60 | + def create_model(self, | ||
61 | + model_input, | ||
62 | + vocab_size, | ||
63 | + num_mixtures=None, | ||
64 | + l2_penalty=1e-8, | ||
65 | + **unused_params): | ||
66 | + """Creates a Mixture of (Logistic) Experts model. | ||
67 | + | ||
68 | + The model consists of a per-class softmax distribution over a | ||
69 | + configurable number of logistic classifiers. One of the classifiers in the | ||
70 | + mixture is not trained, and always predicts 0. | ||
71 | + | ||
72 | + Args: | ||
73 | + model_input: 'batch_size' x 'num_features' matrix of input features. | ||
74 | + vocab_size: The number of classes in the dataset. | ||
75 | + num_mixtures: The number of mixtures (excluding a dummy 'expert' that | ||
76 | + always predicts the non-existence of an entity). | ||
77 | + l2_penalty: How much to penalize the squared magnitudes of parameter | ||
78 | + values. | ||
79 | + | ||
80 | + Returns: | ||
81 | + A dictionary with a tensor containing the probability predictions of the | ||
82 | + model in the 'predictions' key. The dimensions of the tensor are | ||
83 | + batch_size x num_classes. | ||
84 | + """ | ||
85 | + num_mixtures = num_mixtures or FLAGS.moe_num_mixtures | ||
86 | + | ||
87 | + gate_activations = slim.fully_connected( | ||
88 | + model_input, | ||
89 | + vocab_size * (num_mixtures + 1), | ||
90 | + activation_fn=None, | ||
91 | + biases_initializer=None, | ||
92 | + weights_regularizer=slim.l2_regularizer(l2_penalty), | ||
93 | + scope="gates") | ||
94 | + expert_activations = slim.fully_connected( | ||
95 | + model_input, | ||
96 | + vocab_size * num_mixtures, | ||
97 | + activation_fn=None, | ||
98 | + weights_regularizer=slim.l2_regularizer(l2_penalty), | ||
99 | + scope="experts") | ||
100 | + | ||
101 | + gating_distribution = tf.nn.softmax( | ||
102 | + tf.reshape( | ||
103 | + gate_activations, | ||
104 | + [-1, num_mixtures + 1])) # (Batch * #Labels) x (num_mixtures + 1) | ||
105 | + expert_distribution = tf.nn.sigmoid( | ||
106 | + tf.reshape(expert_activations, | ||
107 | + [-1, num_mixtures])) # (Batch * #Labels) x num_mixtures | ||
108 | + | ||
109 | + final_probabilities_by_class_and_batch = tf.reduce_sum( | ||
110 | + gating_distribution[:, :num_mixtures] * expert_distribution, 1) | ||
111 | + final_probabilities = tf.reshape(final_probabilities_by_class_and_batch, | ||
112 | + [-1, vocab_size]) | ||
113 | + return {"predictions": final_probabilities} |
web/backend/yt8m/vocabulary.csv
0 → 100644
This diff could not be displayed because it is too large.
web/frontend/.browserslistrc
0 → 100644
web/frontend/.editorconfig
0 → 100644
web/frontend/.eslintrc.js
0 → 100644
1 | +module.exports = { | ||
2 | + root: true, | ||
3 | + env: { | ||
4 | + browser: true, | ||
5 | + es6: true, | ||
6 | + node: true, | ||
7 | + }, | ||
8 | + extends: [ | ||
9 | + 'plugin:vue/essential', | ||
10 | + '@vue/airbnb', | ||
11 | + ], | ||
12 | + parserOptions: { | ||
13 | + parser: 'babel-eslint', | ||
14 | + }, | ||
15 | + rules: { | ||
16 | + 'no-console': process.env.NODE_ENV === 'production' ? 'warn' : 'off', | ||
17 | + 'no-debugger': process.env.NODE_ENV === 'production' ? 'warn' : 'off', | ||
18 | + 'linebreak-style': ['error', 'unix'], | ||
19 | + 'max-len': ['error', { code: 200 }], | ||
20 | + }, | ||
21 | +}; |
web/frontend/.gitignore
0 → 100644
web/frontend/README.md
0 → 100644
1 | +# front | ||
2 | + | ||
3 | +## Project setup | ||
4 | +``` | ||
5 | +yarn install | ||
6 | +``` | ||
7 | + | ||
8 | +### Compiles and hot-reloads for development | ||
9 | +``` | ||
10 | +yarn serve | ||
11 | +``` | ||
12 | + | ||
13 | +### Compiles and minifies for production | ||
14 | +``` | ||
15 | +yarn build | ||
16 | +``` | ||
17 | + | ||
18 | +### Lints and fixes files | ||
19 | +``` | ||
20 | +yarn lint | ||
21 | +``` | ||
22 | + | ||
23 | +### Customize configuration | ||
24 | +See [Configuration Reference](https://cli.vuejs.org/config/). |
web/frontend/babel.config.js
0 → 100644
web/frontend/package-lock.json
0 → 100644
This diff could not be displayed because it is too large.
web/frontend/package.json
0 → 100644
1 | +{ | ||
2 | + "name": "front", | ||
3 | + "version": "0.1.0", | ||
4 | + "private": true, | ||
5 | + "scripts": { | ||
6 | + "serve": "vue-cli-service serve", | ||
7 | + "build": "vue-cli-service build", | ||
8 | + "lint": "vue-cli-service lint" | ||
9 | + }, | ||
10 | + "dependencies": { | ||
11 | + "@mdi/font": "^3.6.95", | ||
12 | + "axios": "^0.19.2", | ||
13 | + "core-js": "^3.6.5", | ||
14 | + "moment": "^2.24.0", | ||
15 | + "roboto-fontface": "*", | ||
16 | + "vue": "^2.6.11", | ||
17 | + "vue-router": "^3.1.6", | ||
18 | + "vuetify": "^2.2.11", | ||
19 | + "vuex": "^3.1.3" | ||
20 | + }, | ||
21 | + "devDependencies": { | ||
22 | + "@vue/cli-plugin-babel": "~4.3.0", | ||
23 | + "@vue/cli-plugin-eslint": "~4.3.0", | ||
24 | + "@vue/cli-plugin-router": "~4.3.0", | ||
25 | + "@vue/cli-plugin-vuex": "~4.3.0", | ||
26 | + "@vue/cli-service": "~4.3.0", | ||
27 | + "@vue/eslint-config-airbnb": "^5.0.2", | ||
28 | + "babel-eslint": "^10.1.0", | ||
29 | + "eslint": "^6.7.2", | ||
30 | + "eslint-plugin-import": "^2.20.2", | ||
31 | + "eslint-plugin-vue": "^6.2.2", | ||
32 | + "node-sass": "^4.12.0", | ||
33 | + "sass": "^1.19.0", | ||
34 | + "sass-loader": "^8.0.2", | ||
35 | + "vue-cli-plugin-vuetify": "~2.0.5", | ||
36 | + "vue-template-compiler": "^2.6.11", | ||
37 | + "vuetify-loader": "^1.3.0" | ||
38 | + } | ||
39 | +} |
web/frontend/public/favicon.ico
0 → 100644
No preview for this file type
web/frontend/public/index.html
0 → 100644
1 | +<!DOCTYPE html> | ||
2 | +<html lang="en"> | ||
3 | + <head> | ||
4 | + <meta charset="utf-8"> | ||
5 | + <meta http-equiv="X-UA-Compatible" content="IE=edge"> | ||
6 | + <meta name="viewport" content="width=device-width,initial-scale=1.0"> | ||
7 | + <link rel="icon" href="<%= BASE_URL %>favicon.ico"> | ||
8 | + <title>Profit-Hunter</title> | ||
9 | + </head> | ||
10 | + <body> | ||
11 | + <noscript> | ||
12 | + <strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong> | ||
13 | + </noscript> | ||
14 | + <div id="app"></div> | ||
15 | + <!-- built files will be auto injected --> | ||
16 | + </body> | ||
17 | +</html> |
web/frontend/src/App.vue
0 → 100644
1 | +<template> | ||
2 | + <v-app> | ||
3 | + <v-app-bar app color="#ffffff" elevation="1" hide-on-scroll> | ||
4 | + <v-icon size="35" class="mr-1" color="grey700">mdi-youtube</v-icon> | ||
5 | + <div style="color: #343a40; font-size: 20px; font-weight: 500;">Youtube Auto Tagger</div> | ||
6 | + <v-spacer></v-spacer> | ||
7 | + <v-tooltip bottom> | ||
8 | + <template v-slot:activator="{ on }"> | ||
9 | + <v-btn icon v-on="on" color="grey700" @click="clickInfo=true"> | ||
10 | + <v-icon size="30">mdi-information</v-icon> | ||
11 | + </v-btn> | ||
12 | + </template> | ||
13 | + <span>Service Info</span> | ||
14 | + </v-tooltip> | ||
15 | + </v-app-bar> | ||
16 | + <v-dialog v-model="clickInfo" max-width="600" class="pa-2"> | ||
17 | + <v-card elevation="0" outlined class="pa-4"> | ||
18 | + <v-row justify="center" class="mx-0 mb-8 mt-2"> | ||
19 | + <v-icon color="primary">mdi-power-on</v-icon> | ||
20 | + <div | ||
21 | + style="text-align: center; font-size: 22px; font-weight: 400; color: #343a40;" | ||
22 | + >Information of This Service</div> | ||
23 | + <v-icon color="primary">mdi-power-on</v-icon> | ||
24 | + </v-row> | ||
25 | + <v-btn color="primary" class="mt-2" block>Description of service</v-btn> | ||
26 | + <v-btn color="primary" class="my-2" block>Term of this service</v-btn> | ||
27 | + <v-btn color="primary" block>Used Opensource</v-btn> | ||
28 | + </v-card> | ||
29 | + </v-dialog> | ||
30 | + <v-content> | ||
31 | + <router-view /> | ||
32 | + </v-content> | ||
33 | + <v-footer> | ||
34 | + <v-row justify="center"> | ||
35 | + <v-avatar size="25" tile style="border-radius: 4px"> | ||
36 | + <v-img src="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1/uploads/99f7d5c73e506d2c5c0072a21f362181/logo.69342704.png"></v-img> | ||
37 | + </v-avatar> | ||
38 | + <a href="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1"> | ||
39 | + <div | ||
40 | + style="margin-left: 4px; font-size: 16px; color: #5a5a5a; font-weight: 400" | ||
41 | + >Profit-Hunter</div> | ||
42 | + </a> | ||
43 | + </v-row> | ||
44 | + </v-footer> | ||
45 | + </v-app> | ||
46 | +</template> | ||
47 | +<script> | ||
48 | +export default { | ||
49 | + name: 'App', | ||
50 | + data() { | ||
51 | + return { | ||
52 | + clickInfo: false, | ||
53 | + }; | ||
54 | + }, | ||
55 | +}; | ||
56 | +</script> | ||
57 | +<style lang="scss"> | ||
58 | +a:hover { | ||
59 | + text-decoration: none; | ||
60 | +} | ||
61 | +a:link { | ||
62 | + text-decoration: none; | ||
63 | +} | ||
64 | +a:visited { | ||
65 | + text-decoration: none; | ||
66 | +} | ||
67 | +</style> |
web/frontend/src/components/description.vue
0 → 100644
web/frontend/src/main.js
0 → 100644
1 | +import Vue from 'vue'; | ||
2 | +import App from './App.vue'; | ||
3 | +import router from './router'; | ||
4 | +import store from './store'; | ||
5 | +import vuetify from './plugins/vuetify'; | ||
6 | +import 'roboto-fontface/css/roboto/roboto-fontface.css'; | ||
7 | +import '@mdi/font/css/materialdesignicons.css'; | ||
8 | + | ||
9 | +Vue.config.productionTip = false; | ||
10 | + | ||
11 | +new Vue({ | ||
12 | + router, | ||
13 | + store, | ||
14 | + vuetify, | ||
15 | + render: (h) => h(App), | ||
16 | +}).$mount('#app'); |
web/frontend/src/plugins/vuetify.js
0 → 100644
1 | +import Vue from 'vue'; | ||
2 | +import Vuetify from 'vuetify/lib'; | ||
3 | + | ||
4 | +Vue.use(Vuetify); | ||
5 | + | ||
6 | +export default new Vuetify({ | ||
7 | + theme: { | ||
8 | + themes: { | ||
9 | + light: { | ||
10 | + primary: '#343a40', | ||
11 | + secondary: '#506980', | ||
12 | + accent: '#505B80', | ||
13 | + error: '#FF5252', | ||
14 | + info: '#2196F3', | ||
15 | + blue: '#173f5f', | ||
16 | + lightblue: '#72b1e4', | ||
17 | + success: '#2779bd', | ||
18 | + warning: '#12283a', | ||
19 | + grey300: '#eceeef', | ||
20 | + grey500: '#aaaaaa', | ||
21 | + grey700: '#5a5a5a', | ||
22 | + grey900: '#212529', | ||
23 | + }, | ||
24 | + dark: { | ||
25 | + primary: '#343a40', | ||
26 | + secondary: '#506980', | ||
27 | + accent: '#505B80', | ||
28 | + error: '#FF5252', | ||
29 | + info: '#2196F3', | ||
30 | + blue: '#173f5f', | ||
31 | + lightblue: '#72b1e4', | ||
32 | + success: '#2779bd', | ||
33 | + warning: '#12283a', | ||
34 | + grey300: '#eceeef', | ||
35 | + grey500: '#aaaaaa', | ||
36 | + grey700: '#5a5a5a', | ||
37 | + grey900: '#212529', | ||
38 | + }, | ||
39 | + }, | ||
40 | + }, | ||
41 | +}); |
web/frontend/src/router/index.js
0 → 100644
1 | +import Vue from 'vue'; | ||
2 | +import VueRouter from 'vue-router'; | ||
3 | +import axios from 'axios'; | ||
4 | +import Home from '../views/Home.vue'; | ||
5 | + | ||
6 | +Vue.prototype.$axios = axios; | ||
7 | +const apiRootPath = process.env.NODE_ENV !== 'production' | ||
8 | + ? 'http://localhost:8000/api/' | ||
9 | + : '/api/'; | ||
10 | +Vue.prototype.$apiRootPath = apiRootPath; | ||
11 | +axios.defaults.baseURL = apiRootPath; | ||
12 | + | ||
13 | +Vue.use(VueRouter); | ||
14 | + | ||
15 | +const routes = [ | ||
16 | + { | ||
17 | + path: '/', | ||
18 | + name: 'Home', | ||
19 | + component: Home, | ||
20 | + }, | ||
21 | +]; | ||
22 | + | ||
23 | +const router = new VueRouter({ | ||
24 | + mode: 'history', | ||
25 | + base: process.env.BASE_URL, | ||
26 | + routes, | ||
27 | +}); | ||
28 | + | ||
29 | +export default router; |
web/frontend/src/store/index.js
0 → 100644
web/frontend/src/views/Home.vue
0 → 100644
1 | +<template> | ||
2 | + <v-sheet> | ||
3 | + <v-overlay v-model="loadingProcess"> | ||
4 | + <v-progress-circular :size="120" width="10" color="primary" indeterminate></v-progress-circular> | ||
5 | + <div | ||
6 | + style="color: #ffffff; font-size: 22px; margin-top: 20px; margin-left: -40px;" | ||
7 | + >Analyzing Your Video...</div> | ||
8 | + </v-overlay> | ||
9 | + <v-layout justify-center> | ||
10 | + <v-flex xs12 sm8 md6 lg4> | ||
11 | + <v-row justify="center" class="mx-0 mt-12"> | ||
12 | + <div style="font-size: 34px; font-weight: 500; color: #343a40;">WELCOME</div> | ||
13 | + </v-row> | ||
14 | + <v-card elevation="0"> | ||
15 | + <!-- 데스크톱 화면 설명 요약 --> | ||
16 | + <v-row justify="center" class="mx-0 mt-12" v-if="$vuetify.breakpoint.mdAndUp"> | ||
17 | + <v-flex md7 class="mt-1"> | ||
18 | + <div | ||
19 | + style="font-size: 20px; font-weight: 300; color: #888;" | ||
20 | + >This is a video auto-tagging service</div> | ||
21 | + <div | ||
22 | + style="font-size: 20px; font-weight: 300; color: #888;" | ||
23 | + >Designed for YouTube videos</div> | ||
24 | + <div | ||
25 | + style="font-size: 20px; font-weight: 300; color: #888;" | ||
26 | + >It takes a few minutes to analyze your video!</div> | ||
27 | + </v-flex> | ||
28 | + <v-flex md5> | ||
29 | + <v-card width="240" elevation="0" class="ml-5"> | ||
30 | + <v-img | ||
31 | + width="240" | ||
32 | + src="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1/uploads/b70e4a173c2b7d5fa6ab73d48582dd6e/youtubelogoBlack.326653df.png" | ||
33 | + ></v-img> | ||
34 | + </v-card> | ||
35 | + </v-flex> | ||
36 | + </v-row> | ||
37 | + | ||
38 | + <!-- 모바일 화면 설명 요약 --> | ||
39 | + <v-card elevation="0" class="mt-8" v-else> | ||
40 | + <div | ||
41 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
42 | + >This is a video auto-tagging service</div> | ||
43 | + <div | ||
44 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
45 | + >Designed for YouTube videos</div> | ||
46 | + <div | ||
47 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
48 | + >It takes a few minutes to analyze your video!</div> | ||
49 | + <v-img | ||
50 | + style="margin: auto; margin-top: 20px" | ||
51 | + width="180" | ||
52 | + src="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1/uploads/b70e4a173c2b7d5fa6ab73d48582dd6e/youtubelogoBlack.326653df.png" | ||
53 | + ></v-img> | ||
54 | + </v-card> | ||
55 | + | ||
56 | + <!-- Set Threshold --> | ||
57 | + <div | ||
58 | + class="mt-10" | ||
59 | + style="font-size: 24px; text-align: center; font-weight: 400; color: #5a5a5a;" | ||
60 | + >How to start this service</div> | ||
61 | + <div | ||
62 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center; margin-bottom: 15px" | ||
63 | + > | ||
64 | + <div>Set the threshold for the</div> | ||
65 | + <div>recommended YouTube links</div> | ||
66 | + </div> | ||
67 | + <v-row style="max-width: 300px; margin: auto"> | ||
68 | + <v-slider v-model="threshold" :thumb-size="20" thumb-label="always" :min="2" :max="15"></v-slider> | ||
69 | + </v-row> | ||
70 | + | ||
71 | + <!-- Upload Video --> | ||
72 | + <div | ||
73 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
74 | + >Then, Just Upload your Video</div> | ||
75 | + <v-row justify="center" class="mx-0 mt-2"> | ||
76 | + <v-card | ||
77 | + max-width="500" | ||
78 | + outlined | ||
79 | + height="120" | ||
80 | + class="pa-9" | ||
81 | + @dragover.prevent | ||
82 | + @dragenter.prevent | ||
83 | + @drop.prevent="onDrop" | ||
84 | + > | ||
85 | + <v-btn | ||
86 | + style="text-transform: none" | ||
87 | + @click="clickUploadButton" | ||
88 | + text | ||
89 | + large | ||
90 | + color="primary" | ||
91 | + >CLICK or DRAG & DROP</v-btn> | ||
92 | + <input ref="fileInput" style="display: none" type="file" @change="onFileChange" /> | ||
93 | + </v-card> | ||
94 | + </v-row> | ||
95 | + | ||
96 | + <!-- 결과 화면 --> | ||
97 | + <div | ||
98 | + style="font-size: 24px; text-align: center; font-weight: 400; color: #5a5a5a;" | ||
99 | + class="mt-10" | ||
100 | + >The Results of Analyzed Video</div> | ||
101 | + <v-card outlined class="pa-2 mx-5 mt-6" elevation="0" min-height="67"> | ||
102 | + <div | ||
103 | + style="margin-left: 5px; margin-top: -18px; background-color: #fff; width: 110px; text-align: center;font-size: 14px; color: #5a5a5a; font-weight: 500" | ||
104 | + >Generated Tags</div> | ||
105 | + <v-chip-group column> | ||
106 | + <v-chip color="secondary" v-for="(tag, index) in generatedTag" :key="index">{{ tag[0] }} : {{tag[1]}}</v-chip> | ||
107 | + </v-chip-group> | ||
108 | + </v-card> | ||
109 | + <v-card outlined class="pa-3 mx-5 mt-8" elevation="0" min-height="67"> | ||
110 | + <div | ||
111 | + style="margin-left: 5px; margin-top: -22px; margin-bottom: 5px; background-color: #fff; width: 140px; text-align: center;font-size: 14px; color: #5a5a5a; font-weight: 500" | ||
112 | + >Related Youtube Link</div> | ||
113 | + <v-flex style="margin-bottom: 2px" v-for="(url, index) in YoutubeUrl" :key="index"> | ||
114 | + <div> | ||
115 | + <a style="color: #343a40; font-size: 16px; font-weight: 500" :href="url">{{url}}</a> | ||
116 | + </div> | ||
117 | + </v-flex> | ||
118 | + </v-card> | ||
119 | + <div | ||
120 | + class="mt-3" | ||
121 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
122 | + >If the video is analyzed successfully,</div> | ||
123 | + <div | ||
124 | + class="mb-5" | ||
125 | + style="font-size: 20px; font-weight: 300; color: #888; text-align: center" | ||
126 | + >the results show up in each of the boxes!</div> | ||
127 | + </v-card> | ||
128 | + | ||
129 | + </v-flex> | ||
130 | + </v-layout> | ||
131 | + </v-sheet> | ||
132 | +</template> | ||
133 | +<script> | ||
134 | +export default { | ||
135 | + name: 'Home', | ||
136 | + data() { | ||
137 | + return { | ||
138 | + videoFile: '', | ||
139 | + YoutubeUrl: [], | ||
140 | + generatedTag: [], | ||
141 | + threshold: 5, | ||
142 | + successDialog: false, | ||
143 | + errorDialog: false, | ||
144 | + loadingProcess: false, | ||
145 | + }; | ||
146 | + }, | ||
147 | + created() { | ||
148 | + // this.YoutubeUrl = []; | ||
149 | + // this.generatedTag = []; | ||
150 | + }, | ||
151 | + methods: { | ||
152 | + loadVideoInfo() {}, | ||
153 | + uploadVideo(files) { | ||
154 | + this.loadingProcess = true; | ||
155 | + const formData = new FormData(); | ||
156 | + formData.append('file', files[0]); | ||
157 | + formData.append('threshold', this.threshold); | ||
158 | + console.log(files[0]); | ||
159 | + this.$axios | ||
160 | + .post('/upload', formData, { | ||
161 | + headers: { 'Content-Type': 'multipart/form-data' }, | ||
162 | + }) | ||
163 | + .then((r) => { | ||
164 | + const tag = r.data.tag_result; | ||
165 | + const url = r.data.video_result; | ||
166 | + url.forEach((element) => { | ||
167 | + this.YoutubeUrl.push(element.video_url); | ||
168 | + }); | ||
169 | + this.generatedTag = tag; | ||
170 | + this.loadingProcess = false; | ||
171 | + this.successDialog = true; | ||
172 | + console.log(tag, url); | ||
173 | + }) | ||
174 | + .catch((e) => { | ||
175 | + this.loadingProcess = false; | ||
176 | + this.errorDialog = true; | ||
177 | + }); | ||
178 | + }, | ||
179 | + onDrop(event) { | ||
180 | + this.uploadVideo(event.dataTransfer.files); | ||
181 | + }, | ||
182 | + clickUploadButton() { | ||
183 | + this.$refs.fileInput.click(); | ||
184 | + }, | ||
185 | + onFileChange(event) { | ||
186 | + this.uploadVideo(event.target.files); | ||
187 | + }, | ||
188 | + }, | ||
189 | +}; | ||
190 | +</script> |
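For exercising the same flow without the browser, here is a hedged sketch of the request `Home.vue` sends: a multipart POST with a `file` and a `threshold` field, aimed at the development backend URL configured in `router/index.js`. The sample file path is a placeholder, and the shape of `tag_result` is inferred from how the template renders it.

```
import requests  # assumes the third-party 'requests' package is installed

API_URL = "http://localhost:8000/api/upload"  # dev base URL + '/upload'

with open("sample_video.mp4", "rb") as video:  # placeholder path
    response = requests.post(
        API_URL,
        files={"file": video},
        data={"threshold": 5},  # same default the slider starts at
        timeout=600,            # analysis can take a few minutes
    )

response.raise_for_status()
result = response.json()
print(result["tag_result"])    # rendered as "tag : score" chips in the UI
print([item["video_url"] for item in result["video_result"]])  # related links
```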
web/frontend/vue.config.js
0 → 100644
web/frontend/yarn.lock
0 → 100644
This diff could not be displayed because it is too large.