윤영빈

copied original project

Showing 94 changed files with 4619 additions and 0 deletions
1 +# Created by https://www.gitignore.io/api/django
2 +# Edit at https://www.gitignore.io/?templates=django
3 +
4 +### Django ###
5 +*.log
6 +*.pot
7 +*.pyc
8 +__pycache__/
9 +local_settings.py
10 +db.sqlite3
11 +db.sqlite3-journal
12 +media
13 +static/
14 +/static
15 +/backend/env
16 +env/
17 +/env
18 +
19 +# If your build process includes running collectstatic, then you probably don't need or want to include staticfiles/
20 +# in your Git repository. Update and uncomment the following line accordingly.
21 +# <django-project-name>/staticfiles/
22 +
23 +### Django.Python Stack ###
24 +# Byte-compiled / optimized / DLL files
25 +*.py[cod]
26 +*$py.class
27 +
28 +# C extensions
29 +*.so
30 +
31 +# Distribution / packaging
32 +.Python
33 +build/
34 +develop-eggs/
35 +dist/
36 +downloads/
37 +eggs/
38 +.eggs/
39 +lib/
40 +lib64/
41 +parts/
42 +sdist/
43 +var/
44 +wheels/
45 +pip-wheel-metadata/
46 +share/python-wheels/
47 +*.egg-info/
48 +.installed.cfg
49 +*.egg
50 +MANIFEST
51 +
52 +# PyInstaller
53 +# Usually these files are written by a python script from a template
54 +# before PyInstaller builds the exe, so as to inject date/other infos into it.
55 +*.manifest
56 +*.spec
57 +
58 +# Installer logs
59 +pip-log.txt
60 +pip-delete-this-directory.txt
61 +
62 +# Unit test / coverage reports
63 +htmlcov/
64 +.tox/
65 +.nox/
66 +.coverage
67 +.coverage.*
68 +.cache
69 +nosetests.xml
70 +coverage.xml
71 +*.cover
72 +.hypothesis/
73 +.pytest_cache/
74 +
75 +# Translations
76 +*.mo
77 +
78 +# Scrapy stuff:
79 +.scrapy
80 +
81 +# Sphinx documentation
82 +docs/_build/
83 +
84 +# PyBuilder
85 +target/
86 +
87 +# pyenv
88 +.python-version
89 +
90 +# pipenv
91 +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 +# However, in case of collaboration, if having platform-specific dependencies or dependencies
93 +# having no cross-platform support, pipenv may install dependencies that don't work, or not
94 +# install all needed dependencies.
95 +#Pipfile.lock
96 +
97 +# celery beat schedule file
98 +celerybeat-schedule
99 +
100 +# SageMath parsed files
101 +*.sage.py
102 +
103 +# Spyder project settings
104 +.spyderproject
105 +.spyproject
106 +
107 +# Rope project settings
108 +.ropeproject
109 +
110 +# Mr Developer
111 +.mr.developer.cfg
112 +.project
113 +.pydevproject
114 +
115 +# mkdocs documentation
116 +/site
117 +
118 +# mypy
119 +.mypy_cache/
120 +.dmypy.json
121 +dmypy.json
122 +
123 +# Pyre type checker
124 +.pyre/
125 +
126 +# End of https://www.gitignore.io/api/django
127 +
128 +.DS_Store
129 +node_modules
130 +/dist
131 +
132 +# local env files
133 +.env.local
134 +.env.*.local
135 +
136 +# Log files
137 +npm-debug.log*
138 +yarn-debug.log*
139 +yarn-error.log*
140 +
141 +# Editor directories and files
142 +.idea
143 +.vscode
144 +*.suo
145 +*.ntvs*
146 +*.njsproj
147 +*.sln
148 +*.sw?
149 +
1 +import os
2 +from django.conf.global_settings import FILE_UPLOAD_MAX_MEMORY_SIZE
3 +import datetime
4 +import uuid
5 +
6 +# 각 media 파일에 대한 URL Prefix
7 +MEDIA_URL = '/media/' # 항상 / 로 끝나도록 설정
8 +# MEDIA_URL = 'http://static.myservice.com/media/' 다른 서버로 media 파일 복사시
9 +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
10 +
11 +MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
12 +
13 +# 파일 업로드 사이즈 100M ( 100 * 1024 * 1024 )
14 +#FILE_UPLOAD_MAX_MEMORY_SIZE = 104857600
15 +
16 +# 실제 파일을 저장할 경로 및 파일 명 생성
17 +# 폴더는 일별로 생성됨
18 +def file_upload_path( filename):
19 + ext = filename.split('.')[-1]
20 + d = datetime.datetime.now()
21 + filepath = d.strftime('%Y-%m-%d')
22 + suffix = d.strftime("%Y%m%d%H%M%S")
23 + filename = "%s.%s"%(suffix, ext)
24 + return os.path.join( MEDIA_ROOT , filepath, filename)
25 +
26 +# DB 필드에서 호출
27 +def file_upload_path_for_db( intance, filename):
28 + return file_upload_path(filename)
...\ No newline at end of file ...\ No newline at end of file
1 +from django.contrib import admin
2 +
3 +# Register your models here.
1 +from django.apps import AppConfig
2 +
3 +
4 +class ApiConfig(AppConfig):
5 + name = 'api'
1 +# Generated by Django 3.0.5 on 2020-05-02 12:19
2 +
3 +import api
4 +from django.db import migrations, models
5 +
6 +
7 +class Migration(migrations.Migration):
8 +
9 + initial = True
10 +
11 + dependencies = [
12 + ]
13 +
14 + operations = [
15 + migrations.CreateModel(
16 + name='Video',
17 + fields=[
18 + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
19 + ('videourl', models.CharField(blank=True, max_length=1000)),
20 + ('title', models.CharField(max_length=200)),
21 + ('tags', models.CharField(max_length=500)),
22 + ],
23 + ),
24 + migrations.CreateModel(
25 + name='VideoFile',
26 + fields=[
27 + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
28 + ('file_save_name', models.FileField(upload_to=api.file_upload_path_for_db)),
29 + ('file_origin_name', models.CharField(max_length=100)),
30 + ('file_path', models.CharField(max_length=100)),
31 + ],
32 + ),
33 + ]
1 +# Generated by Django 3.0.5 on 2020-05-31 05:39
2 +
3 +from django.db import migrations, models
4 +
5 +
6 +class Migration(migrations.Migration):
7 +
8 + dependencies = [
9 + ('api', '0001_initial'),
10 + ]
11 +
12 + operations = [
13 + migrations.AddField(
14 + model_name='video',
15 + name='threshold',
16 + field=models.CharField(default=20, max_length=20),
17 + preserve_default=False,
18 + ),
19 + ]
1 +from django.db import models
2 +from api import file_upload_path_for_db
3 +# Create your models here.
4 +
5 +
6 +class Video(models.Model):
7 + videourl = models.CharField(max_length=1000, blank=True)
8 + title = models.CharField(max_length=200)
9 + threshold = models.CharField(max_length=20)
10 + tags = models.CharField(max_length=500)
11 +
12 +
13 +class VideoFile(models.Model):
14 + file_save_name = models.FileField(upload_to=file_upload_path_for_db, blank=False, null=False)
15 + # 파일의 원래 이름
16 + file_origin_name = models.CharField(max_length=100)
17 + # 파일 저장 경로
18 + file_path = models.CharField(max_length=100)
19 +
20 + def __str__(self):
21 + return self.file.name
1 +from rest_framework import serializers
2 +from api.models import Video, VideoFile
3 +
4 +
5 +class VideoSerializer(serializers.ModelSerializer):
6 +
7 + class Meta:
8 + model = Video
9 + fields = '__all__'
10 +
11 +
12 +class VideoFileSerializer(serializers.ModelSerializer):
13 +
14 + class Meta:
15 + model = VideoFile
16 + fields = '__all__'
1 +from django.test import TestCase
2 +
3 +# Create your tests here.
1 +from django.urls import path, include
2 +from django.conf.urls import url
3 +from api.views import VideoFileUploadView, VideoFileList, FileListView
4 +from . import views
5 +from rest_framework.routers import DefaultRouter
6 +from django.views.generic import TemplateView
7 +
8 +router = DefaultRouter()
9 +router.register('db/videofile', views.VideoFileViewSet)
10 +router.register('db/video', views.VideoViewSet)
11 +
12 +urlpatterns = [
13 + # FBV
14 + path('api/upload', VideoFileUploadView.as_view(), name="file-upload"),
15 + path('api/upload/<int:pk>/', VideoFileList.as_view(), name="file-list"),
16 + path('api/file', FileListView.as_view(), name="file"),
17 + url(r'^(?P<path>.*)$', TemplateView.as_view(template_name='index.html')),
18 + # path('api/upload', views.VideoFile_Upload),
19 + path('', include(router.urls)),
20 +]
1 +from rest_framework import status
2 +from rest_framework.views import APIView
3 +from rest_framework.response import Response
4 +from rest_framework.parsers import MultiPartParser, FormParser
5 +from rest_framework.response import Response
6 +from rest_framework import viewsets
7 +import os
8 +from django.http.request import QueryDict
9 +from django.http import Http404
10 +from django.http import JsonResponse
11 +from django.shortcuts import get_object_or_404, render
12 +from api.models import Video, VideoFile
13 +from api.serializers import VideoSerializer, VideoFileSerializer
14 +from api import file_upload_path
15 +
16 +import subprocess
17 +import shlex
18 +import json
19 +# Create your views here.
20 +import sys
21 +sys.path.insert(0, "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria")
22 +import inference_pb
23 +
24 +def with_ffprobe(filename):
25 +
26 + result = subprocess.check_output(
27 + f'ffprobe -v quiet -show_streams -select_streams v:0 -of json "{filename}"',
28 + shell=True).decode()
29 + fields = json.loads(result)['streams'][0]
30 + duration = int(float(fields['duration']))
31 + return duration
32 +
33 +def index(request):
34 + return render(request, template_name='index.html')
35 +
36 +class VideoViewSet(viewsets.ModelViewSet):
37 + queryset = Video.objects.all()
38 + serializer_class = VideoSerializer
39 +
40 +class VideoFileViewSet(viewsets.ModelViewSet):
41 + queryset = VideoFile.objects.all()
42 + serializer_class = VideoFileSerializer
43 +
44 +
45 +class VideoFileUploadView(APIView):
46 + parser_classes = (MultiPartParser, FormParser)
47 +
48 + def get(self, request, format=None):
49 + videoFiles = VideoFile.objects.all()
50 + serializer = VideoFileSerializer(videoFiles, many=True)
51 + return Response(serializer.data)
52 +
53 + def post(self, req, *args, **kwargs):
54 + # 동영상 길이
55 + runTime = 0
56 + # 요청된 데이터를 꺼냄( QueryDict)
57 + new_data = req.data.dict()
58 +
59 + # 요청된 파일 객체
60 + file_name = req.data['file']
61 + threshold = req.data['threshold']
62 +
63 + # 저장될 파일의 풀path를 생성
64 + new_file_full_name = file_upload_path(file_name.name)
65 + # 새롭게 생성된 파일의 경로
66 + file_path = '-'.join(new_file_full_name.split('-')[0:-1])
67 +
68 + new_data['file_path'] = file_path
69 + new_data['file_origin_name'] = req.data['file'].name
70 + new_data['file_save_name'] = req.data['file']
71 +
72 + new_query_dict = QueryDict('', mutable=True)
73 + new_query_dict.update(new_data)
74 + file_serializer = VideoFileSerializer(data = new_query_dict)
75 +
76 + if file_serializer.is_valid():
77 + file_serializer.save()
78 + # 동영상 길이 출력
79 + runTime = with_ffprobe('/'+file_serializer.data['file_save_name'])
80 + print(runTime)
81 + print(threshold)
82 + process = subprocess.Popen(['./runMediaPipe.sh %s %s' %(file_serializer.data['file_save_name'],runTime,)], shell = True)
83 + process.wait()
84 +
85 +
86 + result = inference_pb.inference_pb('/tmp/mediapipe/features.pb', threshold)
87 +
88 + return Response(result, status=status.HTTP_201_CREATED)
89 + else:
90 + return Response(file_serializer.errors, status=status.HTTP_400_BAD_REQUEST)
91 +
92 +
93 +class VideoFileList(APIView):
94 +
95 + def get_object(self, pk):
96 + try:
97 + return VideoFile.objects.get(pk=pk)
98 + except VideoFile.DoesNotExist:
99 + raise Http404
100 +
101 + def get(self, request, pk, format=None):
102 + video = self.get_object(pk)
103 + serializer = VideoFileSerializer(video)
104 + return Response(serializer.data)
105 +
106 + def put(self, request, pk, format=None):
107 + video = self.get_object(pk)
108 + serializer = VideoFileSerializer(video, data=request.data)
109 + if serializer.is_valid():
110 + serializer.save()
111 + return Response(serializer.data)
112 + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
113 +
114 + def delete(self, request, pk, format=None):
115 + video = self.get_object(pk)
116 + video.delete()
117 + return Response(status=status.HTTP_204_NO_CONTENT)
118 +
119 +
120 +class FileListView(APIView):
121 + def get(self, request):
122 + data = {
123 + "search": '',
124 + "limit": 10,
125 + "skip": 0,
126 + "order": "time",
127 + "fileList": [
128 + {
129 + "name": "1.png",
130 + "created": "2020-04-30",
131 + "size": 10234,
132 + "isFolder": False,
133 + "deletedDate": "",
134 + },
135 + {
136 + "name": "2.png",
137 + "created": "2020-04-30",
138 + "size": 3145,
139 + "isFolder": False,
140 + "deletedDate": "",
141 + },
142 + {
143 + "name": "3.png",
144 + "created": "2020-05-01",
145 + "size": 5653,
146 + "isFolder": False,
147 + "deletedDate": "",
148 + },
149 + ]
150 + }
151 + return Response(data)
152 + def post(self, request, format=None):
153 + data = {
154 + "isSuccess": True,
155 + "File": {
156 + "name": "test.jpg",
157 + "created": "2020-05-02",
158 + "deletedDate": "",
159 + "size": 2312,
160 + "isFolder": False
161 + }
162 + }
163 + return Response(data)
...\ No newline at end of file ...\ No newline at end of file
1 +"""
2 +ASGI config for backend project.
3 +
4 +It exposes the ASGI callable as a module-level variable named ``application``.
5 +
6 +For more information on this file, see
7 +https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
8 +"""
9 +
10 +import os
11 +
12 +from django.core.asgi import get_asgi_application
13 +
14 +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings')
15 +
16 +application = get_asgi_application()
1 +"""
2 +Django settings for backend project.
3 +
4 +Generated by 'django-admin startproject' using Django 3.0.5.
5 +
6 +For more information on this file, see
7 +https://docs.djangoproject.com/en/3.0/topics/settings/
8 +
9 +For the full list of settings and their values, see
10 +https://docs.djangoproject.com/en/3.0/ref/settings/
11 +"""
12 +
13 +import os
14 +
15 +# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
16 +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17 +
18 +
19 +# Quick-start development settings - unsuitable for production
20 +# See https://docs.djangoproject.com/en/3.0/howto/deployment/checklist/
21 +
22 +# SECURITY WARNING: keep the secret key used in production secret!
23 +SECRET_KEY = '7e^4!u6019jww&=-!mu%r$hz6jy#=i+i9@9m_44+ga^#%7#e0l'
24 +
25 +# SECURITY WARNING: don't run with debug turned on in production!
26 +DEBUG = True
27 +
28 +ALLOWED_HOSTS = ['*']
29 +
30 +
31 +# Static File DIR Settings
32 +STATIC_URL = '/static/'
33 +STATIC_ROOT = os.path.join(BASE_DIR, 'static')
34 +
35 +
36 +FRONTEND_DIR = os.path.join(os.path.abspath('../'), 'frontend')
37 +STATICFILES_DIRS = [
38 + os.path.join(FRONTEND_DIR, 'dist/'),
39 +]
40 +# Application definition
41 +
42 +REST_FRAMEWORK = {
43 + 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
44 + 'PAGE_SIZE': 10
45 +}
46 +
47 +INSTALLED_APPS = [
48 + 'django.contrib.admin',
49 + 'django.contrib.auth',
50 + 'django.contrib.contenttypes',
51 + 'django.contrib.sessions',
52 + 'django.contrib.messages',
53 + 'django.contrib.staticfiles',
54 + 'rest_framework',
55 + 'corsheaders',
56 + 'api'
57 +]
58 +
59 +MIDDLEWARE = [
60 + 'django.middleware.security.SecurityMiddleware',
61 + 'django.contrib.sessions.middleware.SessionMiddleware',
62 + 'django.middleware.common.CommonMiddleware',
63 + # 'django.middleware.csrf.CsrfViewMiddleware',
64 + 'django.contrib.auth.middleware.AuthenticationMiddleware',
65 + 'django.contrib.messages.middleware.MessageMiddleware',
66 + 'django.middleware.clickjacking.XFrameOptionsMiddleware',
67 + 'corsheaders.middleware.CorsMiddleware',
68 +]
69 +
70 +ROOT_URLCONF = 'backend.urls'
71 +
72 +TEMPLATES = [
73 + {
74 + 'BACKEND': 'django.template.backends.django.DjangoTemplates',
75 + 'DIRS': [
76 + os.path.join(BASE_DIR, 'templates'),
77 + ],
78 + 'APP_DIRS': True,
79 + 'OPTIONS': {
80 + 'context_processors': [
81 + 'django.template.context_processors.debug',
82 + 'django.template.context_processors.request',
83 + 'django.contrib.auth.context_processors.auth',
84 + 'django.contrib.messages.context_processors.messages',
85 + ],
86 + },
87 + },
88 +]
89 +
90 +WSGI_APPLICATION = 'backend.wsgi.application'
91 +
92 +
93 +# Database
94 +# https://docs.djangoproject.com/en/3.0/ref/settings/#databases
95 +
96 +DATABASES = {
97 + 'default': {
98 + 'ENGINE': 'django.db.backends.sqlite3',
99 + 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
100 + }
101 +}
102 +
103 +
104 +# Password validation
105 +# https://docs.djangoproject.com/en/3.0/ref/settings/#auth-password-validators
106 +
107 +AUTH_PASSWORD_VALIDATORS = [
108 + {
109 + 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
110 + },
111 + {
112 + 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
113 + },
114 + {
115 + 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
116 + },
117 + {
118 + 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
119 + },
120 +]
121 +
122 +
123 +# Internationalization
124 +# https://docs.djangoproject.com/en/3.0/topics/i18n/
125 +
126 +LANGUAGE_CODE = 'en-us'
127 +
128 +TIME_ZONE = 'UTC'
129 +
130 +USE_I18N = True
131 +
132 +USE_L10N = True
133 +
134 +USE_TZ = True
135 +
136 +
137 +# CORS settings
138 +
139 +CORS_ORIGIN_ALLOW_ALL = True
140 +CORS_ALLOW_CREDENTIALS = True
141 +CORS_ORIGIN_WHITELIST = [
142 + "http://127.0.0.1:12233",
143 + "http://localhost:8080",
144 + "http://127.0.0.1:9000"
145 +]
146 +CORS_URLS_REGEX = r'^/api/.*$'
147 +CORS_ALLOW_METHODS = (
148 + 'DELETE',
149 + 'GET',
150 + 'OPTIONS',
151 + 'PATCH',
152 + 'POST',
153 + 'PUT',
154 +)
155 +
156 +CORS_ALLOW_HEADERS = (
157 + 'accept',
158 + 'accept-encoding',
159 + 'authorization',
160 + 'content-type',
161 + 'dnt',
162 + 'origin',
163 + 'user-agent',
164 + 'x-csrftoken',
165 + 'x-requested-with',
166 +)
1 +"""backend URL Configuration
2 +
3 +The `urlpatterns` list routes URLs to views. For more information please see:
4 + https://docs.djangoproject.com/en/3.0/topics/http/urls/
5 +Examples:
6 +Function views
7 + 1. Add an import: from my_app import views
8 + 2. Add a URL to urlpatterns: path('', views.home, name='home')
9 +Class-based views
10 + 1. Add an import: from other_app.views import Home
11 + 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
12 +Including another URLconf
13 + 1. Import the include() function: from django.urls import include, path
14 + 2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
15 +"""
16 +from django.contrib import admin
17 +from django.urls import path
18 +from django.conf.urls import url, include
19 +from django.conf import settings
20 +from django.conf.urls.static import static
21 +
22 +urlpatterns = [
23 + path('', include('api.urls')),
24 + path('admin/', admin.site.urls),
25 +]
...\ No newline at end of file ...\ No newline at end of file
1 +"""
2 +WSGI config for backend project.
3 +
4 +It exposes the WSGI callable as a module-level variable named ``application``.
5 +
6 +For more information on this file, see
7 +https://docs.djangoproject.com/en/3.0/howto/deployment/wsgi/
8 +"""
9 +
10 +import os
11 +
12 +from django.core.wsgi import get_wsgi_application
13 +
14 +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings')
15 +
16 +application = get_wsgi_application()
1 +import tensorflow as tf
2 +import glob, os
3 +import numpy
4 +
5 +def _make_bytes(int_array):
6 + if bytes == str: # Python2
7 + return ''.join(map(chr, int_array))
8 + else:
9 + return bytes(int_array)
10 +
11 +
12 +def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0):
13 + """Quantizes float32 `features` into string."""
14 + assert features.dtype == 'float32'
15 + assert len(features.shape) == 1 # 1-D array
16 + features = numpy.clip(features, min_quantized_value, max_quantized_value)
17 + quantize_range = max_quantized_value - min_quantized_value
18 + features = (features - min_quantized_value) * (255.0 / quantize_range)
19 + features = [int(round(f)) for f in features]
20 +
21 + return _make_bytes(features)
22 +
23 +
24 +# for parse feature.pb
25 +
26 +contexts = {
27 + 'AUDIO/feature/dimensions': tf.io.FixedLenFeature([], tf.int64),
28 + 'AUDIO/feature/rate': tf.io.FixedLenFeature([], tf.float32),
29 + 'RGB/feature/dimensions': tf.io.FixedLenFeature([], tf.int64),
30 + 'RGB/feature/rate': tf.io.FixedLenFeature([], tf.float32),
31 + 'clip/data_path': tf.io.FixedLenFeature([], tf.string),
32 + 'clip/end/timestamp': tf.io.FixedLenFeature([], tf.int64),
33 + 'clip/start/timestamp': tf.io.FixedLenFeature([], tf.int64)
34 +}
35 +
36 +features = {
37 + 'AUDIO/feature/floats': tf.io.VarLenFeature(dtype=tf.float32),
38 + 'AUDIO/feature/timestamp': tf.io.VarLenFeature(tf.int64),
39 + 'RGB/feature/floats': tf.io.VarLenFeature(dtype=tf.float32),
40 + 'RGB/feature/timestamp': tf.io.VarLenFeature(tf.int64)
41 +
42 +}
43 +
44 +
45 +def parse_exmp(serial_exmp):
46 + _, sequence_parsed = tf.io.parse_single_sequence_example(
47 + serialized=serial_exmp,
48 + context_features=contexts,
49 + sequence_features=features)
50 +
51 + sequence_parsed = tf.contrib.learn.run_n(sequence_parsed)[0]
52 +
53 + audio = sequence_parsed['AUDIO/feature/floats'].values
54 + rgb = sequence_parsed['RGB/feature/floats'].values
55 +
56 + # print(audio.values)
57 + # print(type(audio.values))
58 +
59 + # audio is 128 8bit, rgb is 1024 8bit for every second
60 + audio_slices = [audio[128 * i: 128 * (i + 1)] for i in range(len(audio) // 128)]
61 + rgb_slices = [rgb[1024 * i: 1024 * (i + 1)] for i in range(len(rgb) // 1024)]
62 +
63 + byte_audio = []
64 + byte_rgb = []
65 +
66 + for seg in audio_slices:
67 + audio_seg = quantize(seg)
68 + byte_audio.append(audio_seg)
69 +
70 + for seg in rgb_slices:
71 + rgb_seg = quantize(seg)
72 + byte_rgb.append(rgb_seg)
73 +
74 + return byte_audio, byte_rgb
75 +
76 +
77 +def make_exmp(id, labels, audio, rgb):
78 + audio_features = []
79 + rgb_features = []
80 +
81 + for embedding in audio:
82 + embedding_feature = tf.train.Feature(
83 + bytes_list=tf.train.BytesList(value=[embedding]))
84 + audio_features.append(embedding_feature)
85 +
86 + for embedding in rgb:
87 + embedding_feature = tf.train.Feature(
88 + bytes_list=tf.train.BytesList(value=[embedding]))
89 + rgb_features.append(embedding_feature)
90 +
91 + # for construct yt8m data
92 + seq_exmp = tf.train.SequenceExample(
93 + context=tf.train.Features(
94 + feature={
95 + 'id': tf.train.Feature(bytes_list=tf.train.BytesList(
96 + value=[id.encode('utf-8')])),
97 + 'labels': tf.train.Feature(int64_list=tf.train.Int64List(
98 + value=[labels]))
99 + }),
100 + feature_lists=tf.train.FeatureLists(
101 + feature_list={
102 + 'audio': tf.train.FeatureList(
103 + feature=audio_features
104 + ),
105 + 'rgb': tf.train.FeatureList(
106 + feature=rgb_features
107 + )
108 + })
109 + )
110 + serialized = seq_exmp.SerializeToString()
111 + return serialized
112 +
113 +
114 +if __name__ == '__main__':
115 + filename = '/tmp/mediapipe/features.pb'
116 +
117 + sequence_example = open(filename, 'rb').read()
118 +
119 + audio, rgb = parse_exmp(sequence_example)
120 +
121 + id = 'test_001'
122 +
123 + labels = 1
124 +
125 + tmp_example = make_exmp(id, labels, audio, rgb)
126 +
127 + decoded = tf.train.SequenceExample.FromString(tmp_example)
128 + print(decoded)
129 +
130 + # then you can write tmp_example to tfrecord files
...\ No newline at end of file ...\ No newline at end of file
1 +#!/usr/bin/env python
2 +"""Django's command-line utility for administrative tasks."""
3 +import os
4 +import sys
5 +
6 +
7 +def main():
8 + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings')
9 + try:
10 + from django.core.management import execute_from_command_line
11 + except ImportError as exc:
12 + raise ImportError(
13 + "Couldn't import Django. Are you sure it's installed and "
14 + "available on your PYTHONPATH environment variable? Did you "
15 + "forget to activate a virtual environment?"
16 + ) from exc
17 + execute_from_command_line(sys.argv)
18 +
19 +
20 +if __name__ == '__main__':
21 + main()
1 +Django==3.0.5
2 +django-cors-headers==3.2.1
3 +djangorestframework==3.11.0
4 +tensorflow==1.15
5 +pandas==1.0.4
6 +gensim==3.8.3
...\ No newline at end of file ...\ No newline at end of file
1 +#!/bin/bash
2 +. env/bin/activate
3 +python manage.py migrate
4 +python manage.py runserver 0.0.0.0:8000
...\ No newline at end of file ...\ No newline at end of file
1 +#!/bin/bash
2 +cd ../../../mediapipe
3 +. venv/bin/activate
4 +
5 +/usr/local/bazel/2.0.0/lib/bazel/bin/bazel version && \
6 +alias bazel='/usr/local/bazel/2.0.0/lib/bazel/bin/bazel'
7 +
8 +python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \
9 + --path_to_input_video=/$1 \
10 + --clip_end_time_sec=$2
11 +
12 +GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features \
13 + --calculator_graph_config_file=mediapipe/graphs/youtube8m/feature_extraction.pbtxt \
14 + --input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.pb \
15 + --output_side_packets=output_sequence_example=/tmp/mediapipe/features.pb
...\ No newline at end of file ...\ No newline at end of file
1 +{% load static %}
2 +<!doctype html>
3 +<html lang="en">
4 +
5 +<head>
6 + <meta charset=utf-8>
7 + <meta http-equiv=X-UA-Compatible content="IE=edge">
8 + <meta name=viewport content="width=device-width,initial-scale=1">
9 + <link rel=icon href="{% static 'favicon.ico' %}">
10 + <title>ThrowBox</title>
11 + <link href="{% static '/css/app.1dc1d4aa.css' %}" rel=preload as=style>
12 + <link href="{% static '/css/chunk-vendors.e4bdc0d1.css' %}" rel=preload as=style>
13 + <link href="{% static '/js/app.1e0612ff.js' %}" rel=preload as=script>
14 + <link href="{% static '/js/chunk-vendors.951af5fd.js' %}" rel=preload as=script>
15 + <link href="{% static '/css/chunk-vendors.e4bdc0d1.css' %}" rel=stylesheet>
16 + <link href="{% static '/css/app.1dc1d4aa.css' %}" rel=stylesheet>
17 +</head>
18 +
19 +<body>
20 + <noscript><strong>We're sorry but ThrowBox doesn't work properly without JavaScript enabled. Please enable it to
21 + continue.</strong></noscript>
22 + <div id=app>
23 + </div>
24 + <script src="{% static '/js/chunk-vendors.951af5fd.js' %}"></script>
25 + <script src="{% static '/js/app.1e0612ff.js' %}"></script>
26 +</body>
27 +
28 +</html>
...\ No newline at end of file ...\ No newline at end of file
1 +import os
2 +read = open("static/index.html", 'r')
3 +write = open("templates/index.html", 'w')
4 +
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Calculate or keep track of the interpolated average precision.
15 +
16 +It provides an interface for calculating interpolated average precision for an
17 +entire list or the top-n ranked items. For the definition of the
18 +(non-)interpolated average precision:
19 +http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf
20 +
21 +Example usages:
22 +1) Use it as a static function call to directly calculate average precision for
23 +a short ranked list in the memory.
24 +
25 +```
26 +import random
27 +
28 +p = np.array([random.random() for _ in xrange(10)])
29 +a = np.array([random.choice([0, 1]) for _ in xrange(10)])
30 +
31 +ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a)
32 +```
33 +
34 +2) Use it as an object for long ranked list that cannot be stored in memory or
35 +the case where partial predictions can be observed at a time (Tensorflow
36 +predictions). In this case, we first call the function accumulate many times
37 +to process parts of the ranked list. After processing all the parts, we call
38 +peek_interpolated_ap_at_n.
39 +```
40 +p1 = np.array([random.random() for _ in xrange(5)])
41 +a1 = np.array([random.choice([0, 1]) for _ in xrange(5)])
42 +p2 = np.array([random.random() for _ in xrange(5)])
43 +a2 = np.array([random.choice([0, 1]) for _ in xrange(5)])
44 +
45 +# interpolated average precision at 10 using 1000 break points
46 +calculator = average_precision_calculator.AveragePrecisionCalculator(10)
47 +calculator.accumulate(p1, a1)
48 +calculator.accumulate(p2, a2)
49 +ap3 = calculator.peek_ap_at_n()
50 +```
51 +"""
52 +
53 +import heapq
54 +import random
55 +import numbers
56 +
57 +import numpy
58 +
59 +
60 +class AveragePrecisionCalculator(object):
61 + """Calculate the average precision and average precision at n."""
62 +
63 + def __init__(self, top_n=None):
64 + """Construct an AveragePrecisionCalculator to calculate average precision.
65 +
66 + This class is used to calculate the average precision for a single label.
67 +
68 + Args:
69 + top_n: A positive Integer specifying the average precision at n, or None
70 + to use all provided data points.
71 +
72 + Raises:
73 + ValueError: An error occurred when the top_n is not a positive integer.
74 + """
75 + if not ((isinstance(top_n, int) and top_n >= 0) or top_n is None):
76 + raise ValueError("top_n must be a positive integer or None.")
77 +
78 + self._top_n = top_n # average precision at n
79 + self._total_positives = 0 # total number of positives have seen
80 + self._heap = [] # max heap of (prediction, actual)
81 +
82 + @property
83 + def heap_size(self):
84 + """Gets the heap size maintained in the class."""
85 + return len(self._heap)
86 +
87 + @property
88 + def num_accumulated_positives(self):
89 + """Gets the number of positive samples that have been accumulated."""
90 + return self._total_positives
91 +
92 + def accumulate(self, predictions, actuals, num_positives=None):
93 + """Accumulate the predictions and their ground truth labels.
94 +
95 + After the function call, we may call peek_ap_at_n to actually calculate
96 + the average precision.
97 + Note predictions and actuals must have the same shape.
98 +
99 + Args:
100 + predictions: a list storing the prediction scores.
101 + actuals: a list storing the ground truth labels. Any value larger than 0
102 + will be treated as positives, otherwise as negatives. num_positives = If
103 + the 'predictions' and 'actuals' inputs aren't complete, then it's
104 + possible some true positives were missed in them. In that case, you can
105 + provide 'num_positives' in order to accurately track recall.
106 +
107 + Raises:
108 + ValueError: An error occurred when the format of the input is not the
109 + numpy 1-D array or the shape of predictions and actuals does not match.
110 + """
111 + if len(predictions) != len(actuals):
112 + raise ValueError("the shape of predictions and actuals does not match.")
113 +
114 + if num_positives is not None:
115 + if not isinstance(num_positives, numbers.Number) or num_positives < 0:
116 + raise ValueError(
117 + "'num_positives' was provided but it was a negative number.")
118 +
119 + if num_positives is not None:
120 + self._total_positives += num_positives
121 + else:
122 + self._total_positives += numpy.size(
123 + numpy.where(numpy.array(actuals) > 1e-5))
124 + topk = self._top_n
125 + heap = self._heap
126 +
127 + for i in range(numpy.size(predictions)):
128 + if topk is None or len(heap) < topk:
129 + heapq.heappush(heap, (predictions[i], actuals[i]))
130 + else:
131 + if predictions[i] > heap[0][0]: # heap[0] is the smallest
132 + heapq.heappop(heap)
133 + heapq.heappush(heap, (predictions[i], actuals[i]))
134 +
135 + def clear(self):
136 + """Clear the accumulated predictions."""
137 + self._heap = []
138 + self._total_positives = 0
139 +
140 + def peek_ap_at_n(self):
141 + """Peek the non-interpolated average precision at n.
142 +
143 + Returns:
144 + The non-interpolated average precision at n (default 0).
145 + If n is larger than the length of the ranked list,
146 + the average precision will be returned.
147 + """
148 + if self.heap_size <= 0:
149 + return 0
150 + predlists = numpy.array(list(zip(*self._heap)))
151 +
152 + ap = self.ap_at_n(predlists[0],
153 + predlists[1],
154 + n=self._top_n,
155 + total_num_positives=self._total_positives)
156 + return ap
157 +
158 + @staticmethod
159 + def ap(predictions, actuals):
160 + """Calculate the non-interpolated average precision.
161 +
162 + Args:
163 + predictions: a numpy 1-D array storing the sparse prediction scores.
164 + actuals: a numpy 1-D array storing the ground truth labels. Any value
165 + larger than 0 will be treated as positives, otherwise as negatives.
166 +
167 + Returns:
168 + The non-interpolated average precision at n.
169 + If n is larger than the length of the ranked list,
170 + the average precision will be returned.
171 +
172 + Raises:
173 + ValueError: An error occurred when the format of the input is not the
174 + numpy 1-D array or the shape of predictions and actuals does not match.
175 + """
176 + return AveragePrecisionCalculator.ap_at_n(predictions, actuals, n=None)
177 +
178 + @staticmethod
179 + def ap_at_n(predictions, actuals, n=20, total_num_positives=None):
180 + """Calculate the non-interpolated average precision.
181 +
182 + Args:
183 + predictions: a numpy 1-D array storing the sparse prediction scores.
184 + actuals: a numpy 1-D array storing the ground truth labels. Any value
185 + larger than 0 will be treated as positives, otherwise as negatives.
186 + n: the top n items to be considered in ap@n.
187 + total_num_positives : (optionally) you can specify the number of total
188 + positive in the list. If specified, it will be used in calculation.
189 +
190 + Returns:
191 + The non-interpolated average precision at n.
192 + If n is larger than the length of the ranked list,
193 + the average precision will be returned.
194 +
195 + Raises:
196 + ValueError: An error occurred when
197 + 1) the format of the input is not the numpy 1-D array;
198 + 2) the shape of predictions and actuals does not match;
199 + 3) the input n is not a positive integer.
200 + """
201 + if len(predictions) != len(actuals):
202 + raise ValueError("the shape of predictions and actuals does not match.")
203 +
204 + if n is not None:
205 + if not isinstance(n, int) or n <= 0:
206 + raise ValueError("n must be 'None' or a positive integer."
207 + " It was '%s'." % n)
208 +
209 + ap = 0.0
210 +
211 + predictions = numpy.array(predictions)
212 + actuals = numpy.array(actuals)
213 +
214 + # add a shuffler to avoid overestimating the ap
215 + predictions, actuals = AveragePrecisionCalculator._shuffle(
216 + predictions, actuals)
217 + sortidx = sorted(range(len(predictions)),
218 + key=lambda k: predictions[k],
219 + reverse=True)
220 +
221 + if total_num_positives is None:
222 + numpos = numpy.size(numpy.where(actuals > 0))
223 + else:
224 + numpos = total_num_positives
225 +
226 + if numpos == 0:
227 + return 0
228 +
229 + if n is not None:
230 + numpos = min(numpos, n)
231 + delta_recall = 1.0 / numpos
232 + poscount = 0.0
233 +
234 + # calculate the ap
235 + r = len(sortidx)
236 + if n is not None:
237 + r = min(r, n)
238 + for i in range(r):
239 + if actuals[sortidx[i]] > 0:
240 + poscount += 1
241 + ap += poscount / (i + 1) * delta_recall
242 + return ap
243 +
244 + @staticmethod
245 + def _shuffle(predictions, actuals):
246 + random.seed(0)
247 + suffidx = random.sample(range(len(predictions)), len(predictions))
248 + predictions = predictions[suffidx]
249 + actuals = actuals[suffidx]
250 + return predictions, actuals
251 +
252 + @staticmethod
253 + def _zero_one_normalize(predictions, epsilon=1e-7):
254 + """Normalize the predictions to the range between 0.0 and 1.0.
255 +
256 + For some predictions like SVM predictions, we need to normalize them before
257 + calculate the interpolated average precision. The normalization will not
258 + change the rank in the original list and thus won't change the average
259 + precision.
260 +
261 + Args:
262 + predictions: a numpy 1-D array storing the sparse prediction scores.
263 + epsilon: a small constant to avoid denominator being zero.
264 +
265 + Returns:
266 + The normalized prediction.
267 + """
268 + denominator = numpy.max(predictions) - numpy.min(predictions)
269 + ret = (predictions - numpy.min(predictions)) / numpy.max(
270 + denominator, epsilon)
271 + return ret
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Utility to convert the output of batch prediction into a CSV submission.
15 +
16 +It converts the JSON files created by the command
17 +'gcloud beta ml jobs submit prediction' into a CSV file ready for submission.
18 +"""
19 +
20 +import json
21 +import tensorflow as tf
22 +
23 +from builtins import range
24 +from tensorflow import app
25 +from tensorflow import flags
26 +from tensorflow import gfile
27 +from tensorflow import logging
28 +
29 +FLAGS = flags.FLAGS
30 +
31 +if __name__ == "__main__":
32 +
33 + flags.DEFINE_string(
34 + "json_prediction_files_pattern", None,
35 + "Pattern specifying the list of JSON files that the command "
36 + "'gcloud beta ml jobs submit prediction' outputs. These files are "
37 + "located in the output path of the prediction command and are prefixed "
38 + "with 'prediction.results'.")
39 + flags.DEFINE_string(
40 + "csv_output_file", None,
41 + "The file to save the predictions converted to the CSV format.")
42 +
43 +
44 +def get_csv_header():
45 + return "VideoId,LabelConfidencePairs\n"
46 +
47 +
48 +def to_csv_row(json_data):
49 +
50 + video_id = json_data["video_id"]
51 +
52 + class_indexes = json_data["class_indexes"]
53 + predictions = json_data["predictions"]
54 +
55 + if isinstance(video_id, list):
56 + video_id = video_id[0]
57 + class_indexes = class_indexes[0]
58 + predictions = predictions[0]
59 +
60 + if len(class_indexes) != len(predictions):
61 + raise ValueError(
62 + "The number of indexes (%s) and predictions (%s) must be equal." %
63 + (len(class_indexes), len(predictions)))
64 +
65 + return (video_id.decode("utf-8") + "," +
66 + " ".join("%i %f" % (class_indexes[i], predictions[i])
67 + for i in range(len(class_indexes))) + "\n")
68 +
69 +
70 +def main(unused_argv):
71 + logging.set_verbosity(tf.logging.INFO)
72 +
73 + if not FLAGS.json_prediction_files_pattern:
74 + raise ValueError(
75 + "The flag --json_prediction_files_pattern must be specified.")
76 +
77 + if not FLAGS.csv_output_file:
78 + raise ValueError("The flag --csv_output_file must be specified.")
79 +
80 + logging.info("Looking for prediction files with pattern: %s",
81 + FLAGS.json_prediction_files_pattern)
82 +
83 + file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern)
84 + logging.info("Found files: %s", file_paths)
85 +
86 + logging.info("Writing submission file to: %s", FLAGS.csv_output_file)
87 + with gfile.Open(FLAGS.csv_output_file, "w+") as output_file:
88 + output_file.write(get_csv_header())
89 +
90 + for file_path in file_paths:
91 + logging.info("processing file: %s", file_path)
92 +
93 + with gfile.Open(file_path) as input_file:
94 +
95 + for line in input_file:
96 + json_data = json.loads(line)
97 + output_file.write(to_csv_row(json_data))
98 +
99 + output_file.flush()
100 + logging.info("done")
101 +
102 +
103 +if __name__ == "__main__":
104 + app.run()
1 +import numpy as np
2 +import tensorflow as tf
3 +from tensorflow import logging
4 +from tensorflow import gfile
5 +import operator
6 +import pb_util as pbutil
7 +import video_recommender as recommender
8 +import video_util as videoutil
9 +
10 +# Define file paths.
11 +MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/model/inference_model/segment_inference_model"
12 +VOCAB_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/vocabulary.csv"
13 +VIDEO_TAGS_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/kaggle_solution_40k.csv"
14 +TAG_VECTOR_MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/tag_vectors.model"
15 +VIDEO_VECTOR_MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/video_vectors.model"
16 +SEGMENT_LABEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/segment_label_ids.csv"
17 +
18 +# Define parameters.
19 +TAG_TOP_K = 5
20 +VIDEO_TOP_K = 10
21 +
22 +
23 +def get_segments(batch_video_mtx, batch_num_frames, segment_size):
24 + """Get segment-level inputs from frame-level features."""
25 + video_batch_size = batch_video_mtx.shape[0]
26 + max_frame = batch_video_mtx.shape[1]
27 + feature_dim = batch_video_mtx.shape[-1]
28 + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size
29 + padded_segment_sizes *= segment_size
30 + segment_mask = (
31 + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame)))
32 +
33 + # Segment bags.
34 + frame_bags = batch_video_mtx.reshape((-1, feature_dim))
35 + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape(
36 + (-1, segment_size, feature_dim))
37 +
38 + # Segment num frames.
39 + segment_start_times = np.arange(0, max_frame, segment_size)
40 + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times
41 + num_segment_bags = num_segments.reshape((-1))
42 + valid_segment_mask = num_segment_bags > 0
43 + segment_num_frames = num_segment_bags[valid_segment_mask]
44 + segment_num_frames[segment_num_frames > segment_size] = segment_size
45 +
46 + max_segment_num = (max_frame + segment_size - 1) // segment_size
47 + video_idxs = np.tile(
48 + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num])
49 + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1])
50 + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2))
51 + video_segment_ids = idx_bags[valid_segment_mask]
52 +
53 + return {
54 + "video_batch": segment_frames,
55 + "num_frames_batch": segment_num_frames,
56 + "video_segment_ids": video_segment_ids
57 + }
58 +
59 +
60 +def format_predictions(video_ids, predictions, top_k, whitelisted_cls_mask=None):
61 + batch_size = len(video_ids)
62 + for video_index in range(batch_size):
63 + video_prediction = predictions[video_index]
64 + if whitelisted_cls_mask is not None:
65 + # Whitelist classes.
66 + video_prediction *= whitelisted_cls_mask
67 + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:]
68 + line = [(class_index, predictions[video_index][class_index])
69 + for class_index in top_indices]
70 + line = sorted(line, key=lambda p: -p[1])
71 + yield (video_ids[video_index] + "," +
72 + " ".join("%i %g" % (label, score) for (label, score) in line) +
73 + "\n").encode("utf8")
74 +
75 +
76 +def normalize_tag(tag):
77 + if isinstance(tag, str):
78 + new_tag = tag.lower().replace('[^a-zA-Z]', ' ')
79 + if new_tag.find(" (") != -1:
80 + new_tag = new_tag[:new_tag.find(" (")]
81 + new_tag = new_tag.replace(" ", "-")
82 + return new_tag
83 + else:
84 + return tag
85 +
86 +
87 +def inference_pb(file_path, threshold):
88 + VIDEO_TOP_K = int(threshold)
89 + inference_result = {}
90 + with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
91 +
92 + # 0. Import SequenceExample type target from pb.
93 + target_video = pbutil.convert_pb(file_path)
94 +
95 + # 1. Load video features from pb.
96 + video_id_batch_val = np.array([b'video'])
97 + n_frames = len(target_video.feature_lists.feature_list['rgb'].feature)
98 + # Restrict frame size to 300
99 + if n_frames > 300:
100 + n_frames = 300
101 + video_batch_val = np.zeros((300, 1152))
102 + for i in range(n_frames):
103 + video_batch_rgb_raw = target_video.feature_lists.feature_list['rgb'].feature[i].bytes_list.value[0]
104 + video_batch_rgb = np.array(tf.cast(tf.decode_raw(video_batch_rgb_raw, tf.float32), tf.float32).eval())
105 + video_batch_audio_raw = target_video.feature_lists.feature_list['audio'].feature[i].bytes_list.value[0]
106 + video_batch_audio = np.array(tf.cast(tf.decode_raw(video_batch_audio_raw, tf.float32), tf.float32).eval())
107 + video_batch_val[i] = np.concatenate([video_batch_rgb, video_batch_audio], axis=0)
108 + video_batch_val = np.array([video_batch_val])
109 + num_frames_batch_val = np.array([n_frames])
110 +
111 + # Restore checkpoint and meta-graph file.
112 + if not gfile.Exists(MODEL_PATH + ".meta"):
113 + raise IOError("Cannot find %s. Did you run eval.py?" % MODEL_PATH)
114 + meta_graph_location = MODEL_PATH + ".meta"
115 + logging.info("loading meta-graph: " + meta_graph_location)
116 +
117 + with tf.device("/cpu:0"):
118 + saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
119 + logging.info("restoring variables from " + MODEL_PATH)
120 + saver.restore(sess, MODEL_PATH)
121 + input_tensor = tf.get_collection("input_batch_raw")[0]
122 + num_frames_tensor = tf.get_collection("num_frames")[0]
123 + predictions_tensor = tf.get_collection("predictions")[0]
124 +
125 + # Workaround for num_epochs issue.
126 + def set_up_init_ops(variables):
127 + init_op_list = []
128 + for variable in list(variables):
129 + if "train_input" in variable.name:
130 + init_op_list.append(tf.assign(variable, 1))
131 + variables.remove(variable)
132 + init_op_list.append(tf.variables_initializer(variables))
133 + return init_op_list
134 +
135 + sess.run(
136 + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))
137 +
138 + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
139 + dtype=np.float32)
140 + with tf.io.gfile.GFile(SEGMENT_LABEL_PATH) as fobj:
141 + for line in fobj:
142 + try:
143 + cls_id = int(line)
144 + whitelisted_cls_mask[cls_id] = 1.
145 + except ValueError:
146 + # Simply skip the non-integer line.
147 + continue
148 +
149 + # 2. Make segment features.
150 + results = get_segments(video_batch_val, num_frames_batch_val, 5)
151 + video_segment_ids = results["video_segment_ids"]
152 + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]]
153 + video_id_batch_val = np.array([
154 + "%s:%d" % (x.decode("utf8"), y)
155 + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1])
156 + ])
157 + video_batch_val = results["video_batch"]
158 + num_frames_batch_val = results["num_frames_batch"]
159 + if input_tensor.get_shape()[1] != video_batch_val.shape[1]:
160 + raise ValueError("max_frames mismatch. Please re-run the eval.py "
161 + "with correct segment_labels settings.")
162 +
163 + predictions_val, = sess.run([predictions_tensor],
164 + feed_dict={
165 + input_tensor: video_batch_val,
166 + num_frames_tensor: num_frames_batch_val
167 + })
168 +
169 + # 3. Make vocabularies.
170 + voca_dict = {}
171 + vocabs = open(VOCAB_PATH, 'r')
172 + while True:
173 + line = vocabs.readline()
174 + if not line: break
175 + vocab_dict_item = line.split(",")
176 + if vocab_dict_item[0] != "Index":
177 + voca_dict[vocab_dict_item[0]] = vocab_dict_item[3]
178 + vocabs.close()
179 +
180 + # 4. Make combined scores.
181 + combined_scores = {}
182 + for line in format_predictions(video_id_batch_val, predictions_val, TAG_TOP_K, whitelisted_cls_mask):
183 + segment_id, preds = line.decode("utf8").split(",")
184 + preds = preds.split(" ")
185 + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
186 + pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)]
187 + for i in range(len(pred_cls_ids)):
188 + if pred_cls_ids[i] in combined_scores:
189 + combined_scores[pred_cls_ids[i]] += pred_cls_scores[i]
190 + else:
191 + combined_scores[pred_cls_ids[i]] = pred_cls_scores[i]
192 +
193 + combined_scores = sorted(combined_scores.items(), key=operator.itemgetter(1), reverse=True)
194 + demoninator = float(combined_scores[0][1] + combined_scores[1][1]
195 + + combined_scores[2][1] + combined_scores[3][1] + combined_scores[4][1])
196 +
197 + tag_result = []
198 + for itemIndex in range(TAG_TOP_K):
199 + segment_tag = str(voca_dict[str(combined_scores[itemIndex][0])])
200 + normalized_tag = normalize_tag(segment_tag)
201 + tag_percentage = format(combined_scores[itemIndex][1] / demoninator, ".3f")
202 + tag_result.append((normalized_tag, tag_percentage))
203 +
204 + # 5. Create recommend videos info, Combine results.
205 + recommend_video_ids = recommender.recommend_videos(tag_result, TAG_VECTOR_MODEL_PATH,
206 + VIDEO_VECTOR_MODEL_PATH, VIDEO_TOP_K)
207 + video_result = [videoutil.getVideoInfo(ids, VIDEO_TAGS_PATH, TAG_TOP_K) for ids in recommend_video_ids]
208 +
209 + inference_result = {
210 + "tag_result": tag_result,
211 + "video_result": video_result
212 + }
213 +
214 + # 6. Dispose instances.
215 + sess.close()
216 +
217 + return inference_result
218 +
219 +
220 +if __name__ == '__main__':
221 + filepath = "/tmp/mediapipe/features.pb"
222 + result = inference_pb(filepath)
223 + print(result)
This diff could not be displayed because it is too large.
1 +model_checkpoint_path: "/root/volume/youtube-8m/saved_model/inference_model/segment_inference_model"
2 +all_model_checkpoint_paths: "/root/volume/youtube-8m/saved_model/inference_model/segment_inference_model"
1 +{"model": "FrameLevelLogisticModel", "feature_sizes": "1024,128", "feature_names": "rgb,audio", "frame_features": true, "label_loss": "CrossEntropyLoss"}
...\ No newline at end of file ...\ No newline at end of file
1 +import tensorflow as tf
2 +import numpy
3 +
4 +
5 +def _make_bytes(int_array):
6 + if bytes == str: # Python2
7 + return ''.join(map(chr, int_array))
8 + else:
9 + return bytes(int_array)
10 +
11 +
12 +def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0):
13 + """Quantizes float32 `features` into string."""
14 + assert features.dtype == 'float32'
15 + assert len(features.shape) == 1 # 1-D array
16 + features = numpy.clip(features, min_quantized_value, max_quantized_value)
17 + quantize_range = max_quantized_value - min_quantized_value
18 + features = (features - min_quantized_value) * (255.0 / quantize_range)
19 + features = [int(round(f)) for f in features]
20 +
21 + return _make_bytes(features)
22 +
23 +
24 +# for parse feature.pb
25 +
26 +contexts = {
27 + 'AUDIO/feature/dimensions': tf.io.FixedLenFeature([], tf.int64),
28 + 'AUDIO/feature/rate': tf.io.FixedLenFeature([], tf.float32),
29 + 'RGB/feature/dimensions': tf.io.FixedLenFeature([], tf.int64),
30 + 'RGB/feature/rate': tf.io.FixedLenFeature([], tf.float32),
31 + 'clip/data_path': tf.io.FixedLenFeature([], tf.string),
32 + 'clip/end/timestamp': tf.io.FixedLenFeature([], tf.int64),
33 + 'clip/start/timestamp': tf.io.FixedLenFeature([], tf.int64)
34 +}
35 +
36 +features = {
37 + 'AUDIO/feature/floats': tf.io.VarLenFeature(dtype=tf.float32),
38 + 'AUDIO/feature/timestamp': tf.io.VarLenFeature(tf.int64),
39 + 'RGB/feature/floats': tf.io.VarLenFeature(dtype=tf.float32),
40 + 'RGB/feature/timestamp': tf.io.VarLenFeature(tf.int64)
41 +
42 +}
43 +
44 +
45 +def parse_exmp(serial_exmp):
46 + _, sequence_parsed = tf.io.parse_single_sequence_example(
47 + serialized=serial_exmp,
48 + context_features=contexts,
49 + sequence_features=features)
50 +
51 + sequence_parsed = tf.contrib.learn.run_n(sequence_parsed)[0]
52 +
53 + audio = sequence_parsed['AUDIO/feature/floats'].values
54 + rgb = sequence_parsed['RGB/feature/floats'].values
55 +
56 + # print(audio.values)
57 + # print(type(audio.values))
58 +
59 + # audio is 128 8bit, rgb is 1024 8bit for every second
60 + audio_slices = [audio[128 * i: 128 * (i + 1)] for i in range(len(audio) // 128)]
61 + rgb_slices = [rgb[1024 * i: 1024 * (i + 1)] for i in range(len(rgb) // 1024)]
62 +
63 + byte_audio = []
64 + byte_rgb = []
65 +
66 + for seg in audio_slices:
67 + # audio_seg = quantize(seg)
68 + audio_seg = _make_bytes(seg)
69 + byte_audio.append(audio_seg)
70 +
71 + for seg in rgb_slices:
72 + # rgb_seg = quantize(seg)
73 + rgb_seg = _make_bytes(seg)
74 + byte_rgb.append(rgb_seg)
75 +
76 + return byte_audio, byte_rgb
77 +
78 +
79 +def make_exmp(id, audio, rgb):
80 + audio_features = []
81 + rgb_features = []
82 +
83 + for embedding in audio:
84 + embedding_feature = tf.train.Feature(
85 + bytes_list=tf.train.BytesList(value=[embedding]))
86 + audio_features.append(embedding_feature)
87 +
88 + for embedding in rgb:
89 + embedding_feature = tf.train.Feature(
90 + bytes_list=tf.train.BytesList(value=[embedding]))
91 + rgb_features.append(embedding_feature)
92 +
93 + # for construct yt8m data
94 + seq_exmp = tf.train.SequenceExample(
95 + context=tf.train.Features(
96 + feature={
97 + 'id': tf.train.Feature(bytes_list=tf.train.BytesList(
98 + value=[id.encode('utf-8')]))
99 + }),
100 + feature_lists=tf.train.FeatureLists(
101 + feature_list={
102 + 'audio': tf.train.FeatureList(
103 + feature=audio_features
104 + ),
105 + 'rgb': tf.train.FeatureList(
106 + feature=rgb_features
107 + )
108 + })
109 + )
110 + serialized = seq_exmp.SerializeToString()
111 + return serialized
112 +
113 +
114 +def convert_pb(filename):
115 + sequence_example = open(filename, 'rb').read()
116 +
117 + audio, rgb = parse_exmp(sequence_example)
118 + tmp_example = make_exmp('video', audio, rgb)
119 +
120 + decoded = tf.train.SequenceExample.FromString(tmp_example)
121 + return decoded
1 +import tensorflow as tf
2 +import numpy as np
3 +
4 +frame_lvl_record = "test0000.tfrecord"
5 +
6 +feat_rgb = []
7 +feat_audio = []
8 +
9 +for example in tf.python_io.tf_record_iterator(frame_lvl_record):
10 + tf_seq_example = tf.train.SequenceExample.FromString(example)
11 + test = tf_seq_example.SerializeToString()
12 + n_frames = len(tf_seq_example.feature_lists.feature_list['audio'].feature)
13 + sess = tf.InteractiveSession()
14 + rgb_frame = []
15 + audio_frame = []
16 + # iterate through frames
17 + for i in range(n_frames):
18 + rgb_frame.append(tf.cast(tf.decode_raw(
19 + tf_seq_example.feature_lists.feature_list['rgb']
20 + .feature[i].bytes_list.value[0], tf.uint8)
21 + , tf.float32).eval())
22 + audio_frame.append(tf.cast(tf.decode_raw(
23 + tf_seq_example.feature_lists.feature_list['audio']
24 + .feature[i].bytes_list.value[0], tf.uint8)
25 + , tf.float32).eval())
26 +
27 + sess.close()
28 +
29 + feat_audio.append(audio_frame)
30 + feat_rgb.append(rgb_frame)
31 + break
32 +
33 +print('The first video has %d frames' %len(feat_rgb[0]))
...\ No newline at end of file ...\ No newline at end of file
1 +import nltk
2 +import gensim
3 +import pandas as pd
4 +
5 +# Load files.
6 +nltk.download('stopwords')
7 +vocab = pd.read_csv('../vocabulary.csv')
8 +
9 +# Lower corpus and Remove () from name.
10 +vocab['WikiDescription'] = vocab['WikiDescription'].str.lower().str.replace('[^a-zA-Z0-9]', ' ')
11 +for i in range(vocab['Name'].__len__()):
12 + name = vocab['Name'][i]
13 + if isinstance(name, str) and name.find(" (") != -1:
14 + vocab['Name'][i] = name[:name.find(" (")]
15 +vocab['Name'] = vocab['Name'].str.lower()
16 +
17 +# Combine separated names.(mobile phone -> mobile-phone)
18 +for name in vocab['Name']:
19 + if isinstance(name, str) and name.find(" ") != -1:
20 + combined_name = name.replace(" ", "-")
21 + for i in range(vocab['WikiDescription'].__len__()):
22 + if isinstance(vocab['WikiDescription'][i], str):
23 + vocab['WikiDescription'][i] = vocab['WikiDescription'][i].replace(name, combined_name)
24 +
25 +
26 +# Remove stopwords from corpus.
27 +stop_re = '\\b'+'\\b|\\b'.join(nltk.corpus.stopwords.words('english'))+'\\b'
28 +vocab['WikiDescription'] = vocab['WikiDescription'].str.replace(stop_re, '')
29 +vocab['WikiDescription'] = vocab['WikiDescription'].str.split()
30 +
31 +# Tokenize corpus.
32 +tokenlist = [x for x in vocab['WikiDescription'] if str(x) != 'nan']
33 +phrases = gensim.models.phrases.Phrases(tokenlist)
34 +phraser = gensim.models.phrases.Phraser(phrases)
35 +vocab_phrased = phraser[tokenlist]
36 +
37 +# Vectorize tags.
38 +w2v = gensim.models.word2vec.Word2Vec(sentences=tokenlist, min_count=1)
39 +w2v.save('tag_vectors.model')
40 +
41 +# word_vectors = w2v.wv
42 +# vocabs = word_vectors.vocab.keys()
43 +# word_vectors_list = [word_vectors[v] for v in vocabs]
This file is too large to display.
1 +from gensim.models import Word2Vec
2 +import numpy as np
3 +
4 +def recommend_videos(tags, tag_model_path, video_model_path, top_k):
5 + tag_vectors = Word2Vec.load(tag_model_path).wv
6 + video_vectors = Word2Vec().wv.load(video_model_path)
7 + error_tags = []
8 +
9 + video_vector = np.zeros(100)
10 + for (tag, weight) in tags:
11 + if tag in tag_vectors.vocab:
12 + video_vector = video_vector + (tag_vectors[tag] * float(weight))
13 + else:
14 + # Pass if tag is unknown
15 + if tag not in error_tags:
16 + error_tags.append(tag)
17 +
18 + similar_ids = [x[0] for x in video_vectors.similar_by_vector(video_vector, top_k)]
19 + return similar_ids
1 +import requests
2 +import pandas as pd
3 +
4 +base_URL = 'https://data.yt8m.org/2/j/i/'
5 +youtube_url = 'https://www.youtube.com/watch?v='
6 +
7 +
8 +def getURL(vid_id):
9 + URL = base_URL + vid_id[:-2] + '/' + vid_id + '.js'
10 + response = requests.get(URL, verify = False)
11 + if response.status_code == 200:
12 + return youtube_url + response.text[10:-3]
13 +
14 +
15 +def getVideoInfo(vid_id, video_tags_path, top_k):
16 + video_url = getURL(vid_id)
17 +
18 + entire_video_tags = pd.read_csv(video_tags_path)
19 + video_tags_info = entire_video_tags.loc[entire_video_tags["vid_id"] == vid_id]
20 + video_tags = []
21 + for i in range(1, top_k + 1):
22 + video_tag_tuple = video_tags_info["segment" + str(i)].values[0] # ex: "mobile-phone:0.361"
23 + video_tags.append(video_tag_tuple.split(":")[0])
24 +
25 + return {
26 + "video_url": video_url,
27 + "video_tags": video_tags
28 + }
1 +import pandas as pd
2 +import numpy as np
3 +from gensim.models import Word2Vec
4 +
5 +BATCH_SIZE = 1000
6 +
7 +
8 +def vectorization_video():
9 + print('[0.1 0.2]')
10 +
11 +
12 +if __name__ == '__main__':
13 + tag_vectors = Word2Vec.load("tag_vectors.model").wv
14 + video_vectors = Word2Vec().wv # Empty model
15 +
16 + # Load video recommendation tags.
17 + video_tags = pd.read_csv('kaggle_solution_40k.csv')
18 +
19 + # Define batch variables.
20 + batch_video_ids = []
21 + batch_video_vectors = []
22 + error_tags = []
23 +
24 + for i, row in video_tags.iterrows():
25 + video_id = row[0]
26 + video_vector = np.zeros(100)
27 + for segment_index in range(1, 6):
28 + tag, weight = row[segment_index].split(":")
29 + if tag in tag_vectors.vocab:
30 + video_vector = video_vector + (tag_vectors[tag] * float(weight))
31 + else:
32 + # Pass if tag is unknown
33 + if tag not in error_tags:
34 + error_tags.append(tag)
35 +
36 + batch_video_ids.append(video_id)
37 + batch_video_vectors.append(video_vector)
38 + # Add video vectors.
39 + if (i+1) % BATCH_SIZE == 0:
40 + video_vectors.add(batch_video_ids, batch_video_vectors)
41 + batch_video_ids = []
42 + batch_video_vectors = []
43 + print("Video vectors created: ", i+1)
44 +
45 + # Add rest of video vectors.
46 + video_vectors.add(batch_video_ids, batch_video_vectors)
47 + print("error tags: ")
48 + print(error_tags)
49 +
50 + video_vectors.save("video_vectors.model")
51 +
52 + # Usage
53 + # video_vectors = Word2Vec().wv.load("video_vectors.model")
54 + # video_vectors.most_similar("XwFj", topn=5)
This file is too large to display.
This diff is collapsed. Click to expand it.
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Provides functions to help with evaluating models."""
15 +import average_precision_calculator as ap_calculator
16 +import mean_average_precision_calculator as map_calculator
17 +import numpy
18 +from tensorflow.python.platform import gfile
19 +
20 +
21 +def flatten(l):
22 + """Merges a list of lists into a single list. """
23 + return [item for sublist in l for item in sublist]
24 +
25 +
26 +def calculate_hit_at_one(predictions, actuals):
27 + """Performs a local (numpy) calculation of the hit at one.
28 +
29 + Args:
30 + predictions: Matrix containing the outputs of the model. Dimensions are
31 + 'batch' x 'num_classes'.
32 + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x
33 + 'num_classes'.
34 +
35 + Returns:
36 + float: The average hit at one across the entire batch.
37 + """
38 + top_prediction = numpy.argmax(predictions, 1)
39 + hits = actuals[numpy.arange(actuals.shape[0]), top_prediction]
40 + return numpy.average(hits)
41 +
42 +
43 +def calculate_precision_at_equal_recall_rate(predictions, actuals):
44 + """Performs a local (numpy) calculation of the PERR.
45 +
46 + Args:
47 + predictions: Matrix containing the outputs of the model. Dimensions are
48 + 'batch' x 'num_classes'.
49 + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x
50 + 'num_classes'.
51 +
52 + Returns:
53 + float: The average precision at equal recall rate across the entire batch.
54 + """
55 + aggregated_precision = 0.0
56 + num_videos = actuals.shape[0]
57 + for row in numpy.arange(num_videos):
58 + num_labels = int(numpy.sum(actuals[row]))
59 + top_indices = numpy.argpartition(predictions[row],
60 + -num_labels)[-num_labels:]
61 + item_precision = 0.0
62 + for label_index in top_indices:
63 + if predictions[row][label_index] > 0:
64 + item_precision += actuals[row][label_index]
65 + item_precision /= top_indices.size
66 + aggregated_precision += item_precision
67 + aggregated_precision /= num_videos
68 + return aggregated_precision
69 +
70 +
71 +def calculate_gap(predictions, actuals, top_k=20):
72 + """Performs a local (numpy) calculation of the global average precision.
73 +
74 + Only the top_k predictions are taken for each of the videos.
75 +
76 + Args:
77 + predictions: Matrix containing the outputs of the model. Dimensions are
78 + 'batch' x 'num_classes'.
79 + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x
80 + 'num_classes'.
81 + top_k: How many predictions to use per video.
82 +
83 + Returns:
84 + float: The global average precision.
85 + """
86 + gap_calculator = ap_calculator.AveragePrecisionCalculator()
87 + sparse_predictions, sparse_labels, num_positives = top_k_by_class(
88 + predictions, actuals, top_k)
89 + gap_calculator.accumulate(flatten(sparse_predictions), flatten(sparse_labels),
90 + sum(num_positives))
91 + return gap_calculator.peek_ap_at_n()
92 +
93 +
94 +def top_k_by_class(predictions, labels, k=20):
95 + """Extracts the top k predictions for each video, sorted by class.
96 +
97 + Args:
98 + predictions: A numpy matrix containing the outputs of the model. Dimensions
99 + are 'batch' x 'num_classes'.
100 + k: the top k non-zero entries to preserve in each prediction.
101 +
102 + Returns:
103 + A tuple (predictions,labels, true_positives). 'predictions' and 'labels'
104 + are lists of lists of floats. 'true_positives' is a list of scalars. The
105 + length of the lists are equal to the number of classes. The entries in the
106 + predictions variable are probability predictions, and
107 + the corresponding entries in the labels variable are the ground truth for
108 + those predictions. The entries in 'true_positives' are the number of true
109 + positives for each class in the ground truth.
110 +
111 + Raises:
112 + ValueError: An error occurred when the k is not a positive integer.
113 + """
114 + if k <= 0:
115 + raise ValueError("k must be a positive integer.")
116 + k = min(k, predictions.shape[1])
117 + num_classes = predictions.shape[1]
118 + prediction_triplets = []
119 + for video_index in range(predictions.shape[0]):
120 + prediction_triplets.extend(
121 + top_k_triplets(predictions[video_index], labels[video_index], k))
122 + out_predictions = [[] for _ in range(num_classes)]
123 + out_labels = [[] for _ in range(num_classes)]
124 + for triplet in prediction_triplets:
125 + out_predictions[triplet[0]].append(triplet[1])
126 + out_labels[triplet[0]].append(triplet[2])
127 + out_true_positives = [numpy.sum(labels[:, i]) for i in range(num_classes)]
128 +
129 + return out_predictions, out_labels, out_true_positives
130 +
131 +
132 +def top_k_triplets(predictions, labels, k=20):
133 + """Get the top_k for a 1-d numpy array.
134 +
135 + Returns a sparse list of tuples in
136 + (prediction, class) format
137 + """
138 + m = len(predictions)
139 + k = min(k, m)
140 + indices = numpy.argpartition(predictions, -k)[-k:]
141 + return [(index, predictions[index], labels[index]) for index in indices]
142 +
143 +
144 +class EvaluationMetrics(object):
145 + """A class to store the evaluation metrics."""
146 +
147 + def __init__(self, num_class, top_k, top_n):
148 + """Construct an EvaluationMetrics object to store the evaluation metrics.
149 +
150 + Args:
151 + num_class: A positive integer specifying the number of classes.
152 + top_k: A positive integer specifying how many predictions are considered
153 + per video.
154 + top_n: A positive Integer specifying the average precision at n, or None
155 + to use all provided data points.
156 +
157 + Raises:
158 + ValueError: An error occurred when MeanAveragePrecisionCalculator cannot
159 + not be constructed.
160 + """
161 + self.sum_hit_at_one = 0.0
162 + self.sum_perr = 0.0
163 + self.sum_loss = 0.0
164 + self.map_calculator = map_calculator.MeanAveragePrecisionCalculator(
165 + num_class, top_n=top_n)
166 + self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator()
167 + self.top_k = top_k
168 + self.num_examples = 0
169 +
170 + def accumulate(self, predictions, labels, loss):
171 + """Accumulate the metrics calculated locally for this mini-batch.
172 +
173 + Args:
174 + predictions: A numpy matrix containing the outputs of the model.
175 + Dimensions are 'batch' x 'num_classes'.
176 + labels: A numpy matrix containing the ground truth labels. Dimensions are
177 + 'batch' x 'num_classes'.
178 + loss: A numpy array containing the loss for each sample.
179 +
180 + Returns:
181 + dictionary: A dictionary storing the metrics for the mini-batch.
182 +
183 + Raises:
184 + ValueError: An error occurred when the shape of predictions and actuals
185 + does not match.
186 + """
187 + batch_size = labels.shape[0]
188 + mean_hit_at_one = calculate_hit_at_one(predictions, labels)
189 + mean_perr = calculate_precision_at_equal_recall_rate(predictions, labels)
190 + mean_loss = numpy.mean(loss)
191 +
192 + # Take the top 20 predictions.
193 + sparse_predictions, sparse_labels, num_positives = top_k_by_class(
194 + predictions, labels, self.top_k)
195 + self.map_calculator.accumulate(sparse_predictions, sparse_labels,
196 + num_positives)
197 + self.global_ap_calculator.accumulate(flatten(sparse_predictions),
198 + flatten(sparse_labels),
199 + sum(num_positives))
200 +
201 + self.num_examples += batch_size
202 + self.sum_hit_at_one += mean_hit_at_one * batch_size
203 + self.sum_perr += mean_perr * batch_size
204 + self.sum_loss += mean_loss * batch_size
205 +
206 + return {"hit_at_one": mean_hit_at_one, "perr": mean_perr, "loss": mean_loss}
207 +
208 + def get(self):
209 + """Calculate the evaluation metrics for the whole epoch.
210 +
211 + Raises:
212 + ValueError: If no examples were accumulated.
213 +
214 + Returns:
215 + dictionary: a dictionary storing the evaluation metrics for the epoch. The
216 + dictionary has the fields: avg_hit_at_one, avg_perr, avg_loss, and
217 + aps (default nan).
218 + """
219 + if self.num_examples <= 0:
220 + raise ValueError("total_sample must be positive.")
221 + avg_hit_at_one = self.sum_hit_at_one / self.num_examples
222 + avg_perr = self.sum_perr / self.num_examples
223 + avg_loss = self.sum_loss / self.num_examples
224 +
225 + aps = self.map_calculator.peek_map_at_n()
226 + gap = self.global_ap_calculator.peek_ap_at_n()
227 +
228 + epoch_info_dict = {
229 + "avg_hit_at_one": avg_hit_at_one,
230 + "avg_perr": avg_perr,
231 + "avg_loss": avg_loss,
232 + "aps": aps,
233 + "gap": gap
234 + }
235 + return epoch_info_dict
236 +
237 + def clear(self):
238 + """Clear the evaluation metrics and reset the EvaluationMetrics object."""
239 + self.sum_hit_at_one = 0.0
240 + self.sum_perr = 0.0
241 + self.sum_loss = 0.0
242 + self.map_calculator.clear()
243 + self.global_ap_calculator.clear()
244 + self.num_examples = 0
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Utilities to export a model for batch prediction."""
15 +
16 +import tensorflow as tf
17 +import tensorflow.contrib.slim as slim
18 +
19 +from tensorflow.python.saved_model import builder as saved_model_builder
20 +from tensorflow.python.saved_model import signature_constants
21 +from tensorflow.python.saved_model import signature_def_utils
22 +from tensorflow.python.saved_model import tag_constants
23 +from tensorflow.python.saved_model import utils as saved_model_utils
24 +
25 +_TOP_PREDICTIONS_IN_OUTPUT = 20
26 +
27 +
28 +class ModelExporter(object):
29 +
30 + def __init__(self, frame_features, model, reader):
31 + self.frame_features = frame_features
32 + self.model = model
33 + self.reader = reader
34 +
35 + with tf.Graph().as_default() as graph:
36 + self.inputs, self.outputs = self.build_inputs_and_outputs()
37 + self.graph = graph
38 + self.saver = tf.train.Saver(tf.trainable_variables(), sharded=True)
39 +
40 + def export_model(self, model_dir, global_step_val, last_checkpoint):
41 + """Exports the model so that it can used for batch predictions."""
42 +
43 + with self.graph.as_default():
44 + with tf.Session() as session:
45 + session.run(tf.global_variables_initializer())
46 + self.saver.restore(session, last_checkpoint)
47 +
48 + signature = signature_def_utils.build_signature_def(
49 + inputs=self.inputs,
50 + outputs=self.outputs,
51 + method_name=signature_constants.PREDICT_METHOD_NAME)
52 +
53 + signature_map = {
54 + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
55 + }
56 +
57 + model_builder = saved_model_builder.SavedModelBuilder(model_dir)
58 + model_builder.add_meta_graph_and_variables(
59 + session,
60 + tags=[tag_constants.SERVING],
61 + signature_def_map=signature_map,
62 + clear_devices=True)
63 + model_builder.save()
64 +
65 + def build_inputs_and_outputs(self):
66 + if self.frame_features:
67 + serialized_examples = tf.placeholder(tf.string, shape=(None,))
68 +
69 + fn = lambda x: self.build_prediction_graph(x)
70 + video_id_output, top_indices_output, top_predictions_output = (tf.map_fn(
71 + fn, serialized_examples, dtype=(tf.string, tf.int32, tf.float32)))
72 +
73 + else:
74 + serialized_examples = tf.placeholder(tf.string, shape=(None,))
75 +
76 + video_id_output, top_indices_output, top_predictions_output = (
77 + self.build_prediction_graph(serialized_examples))
78 +
79 + inputs = {
80 + "example_bytes":
81 + saved_model_utils.build_tensor_info(serialized_examples)
82 + }
83 +
84 + outputs = {
85 + "video_id":
86 + saved_model_utils.build_tensor_info(video_id_output),
87 + "class_indexes":
88 + saved_model_utils.build_tensor_info(top_indices_output),
89 + "predictions":
90 + saved_model_utils.build_tensor_info(top_predictions_output)
91 + }
92 +
93 + return inputs, outputs
94 +
95 + def build_prediction_graph(self, serialized_examples):
96 + input_data_dict = (
97 + self.reader.prepare_serialized_examples(serialized_examples))
98 + video_id = input_data_dict["video_ids"]
99 + model_input_raw = input_data_dict["video_matrix"]
100 + labels_batch = input_data_dict["labels"]
101 + num_frames = input_data_dict["num_frames"]
102 +
103 + feature_dim = len(model_input_raw.get_shape()) - 1
104 + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
105 +
106 + with tf.variable_scope("tower"):
107 + result = self.model.create_model(model_input,
108 + num_frames=num_frames,
109 + vocab_size=self.reader.num_classes,
110 + labels=labels_batch,
111 + is_training=False)
112 +
113 + for variable in slim.get_model_variables():
114 + tf.summary.histogram(variable.op.name, variable)
115 +
116 + predictions = result["predictions"]
117 +
118 + top_predictions, top_indices = tf.nn.top_k(predictions,
119 + _TOP_PREDICTIONS_IN_OUTPUT)
120 + return video_id, top_indices, top_predictions
1 +# Lint as: python3
2 +import numpy as np
3 +import tensorflow as tf
4 +from tensorflow import app
5 +from tensorflow import flags
6 +
7 +FLAGS = flags.FLAGS
8 +
9 +
10 +def main(unused_argv):
11 + # Get the input tensor names to be replaced.
12 + tf.reset_default_graph()
13 + meta_graph_location = FLAGS.checkpoint_file + ".meta"
14 + tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
15 +
16 + input_tensor_name = tf.get_collection("input_batch_raw")[0].name
17 + num_frames_tensor_name = tf.get_collection("num_frames")[0].name
18 +
19 + # Create output graph.
20 + saver = tf.train.Saver()
21 + tf.reset_default_graph()
22 +
23 + input_feature_placeholder = tf.placeholder(
24 + tf.float32, shape=(None, None, 1152))
25 + num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1))
26 +
27 + saver = tf.train.import_meta_graph(
28 + meta_graph_location,
29 + input_map={
30 + input_tensor_name: input_feature_placeholder,
31 + num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1)
32 + },
33 + clear_devices=True)
34 + predictions_tensor = tf.get_collection("predictions")[0]
35 +
36 + with tf.Session() as sess:
37 + print("restoring variables from " + FLAGS.checkpoint_file)
38 + saver.restore(sess, FLAGS.checkpoint_file)
39 + tf.saved_model.simple_save(
40 + sess,
41 + FLAGS.output_dir,
42 + inputs={'rgb_and_audio': input_feature_placeholder,
43 + 'num_frames': num_frames_placeholder},
44 + outputs={'predictions': predictions_tensor})
45 +
46 + # Try running inference.
47 + predictions = sess.run(
48 + [predictions_tensor],
49 + feed_dict={
50 + input_feature_placeholder: np.zeros((3, 7, 1152), dtype=np.float32),
51 + num_frames_placeholder: np.array([[7]], dtype=np.int32)})
52 + print('Test inference:', predictions)
53 +
54 + print('Model saved to ', FLAGS.output_dir)
55 +
56 +
57 +if __name__ == '__main__':
58 + flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.')
59 + flags.DEFINE_string('output_dir', None, 'SavedModel output directory.')
60 + app.run(main)
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Provides definitions for non-regularized training or test losses."""
15 +
16 +import tensorflow as tf
17 +
18 +
19 +class BaseLoss(object):
20 + """Inherit from this class when implementing new losses."""
21 +
22 + def calculate_loss(self, unused_predictions, unused_labels, **unused_params):
23 + """Calculates the average loss of the examples in a mini-batch.
24 +
25 + Args:
26 + unused_predictions: a 2-d tensor storing the prediction scores, in which
27 + each row represents a sample in the mini-batch and each column
28 + represents a class.
29 + unused_labels: a 2-d tensor storing the labels, which has the same shape
30 + as the unused_predictions. The labels must be in the range of 0 and 1.
31 + unused_params: loss specific parameters.
32 +
33 + Returns:
34 + A scalar loss tensor.
35 + """
36 + raise NotImplementedError()
37 +
38 +
39 +class CrossEntropyLoss(BaseLoss):
40 + """Calculate the cross entropy loss between the predictions and labels."""
41 +
42 + def calculate_loss(self,
43 + predictions,
44 + labels,
45 + label_weights=None,
46 + **unused_params):
47 + with tf.name_scope("loss_xent"):
48 + epsilon = 1e-5
49 + float_labels = tf.cast(labels, tf.float32)
50 + cross_entropy_loss = float_labels * tf.math.log(predictions + epsilon) + (
51 + 1 - float_labels) * tf.math.log(1 - predictions + epsilon)
52 + cross_entropy_loss = tf.negative(cross_entropy_loss)
53 + if label_weights is not None:
54 + cross_entropy_loss *= label_weights
55 + return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1))
56 +
57 +
58 +class HingeLoss(BaseLoss):
59 + """Calculate the hinge loss between the predictions and labels.
60 +
61 + Note the subgradient is used in the backpropagation, and thus the optimization
62 + may converge slower. The predictions trained by the hinge loss are between -1
63 + and +1.
64 + """
65 +
66 + def calculate_loss(self, predictions, labels, b=1.0, **unused_params):
67 + with tf.name_scope("loss_hinge"):
68 + float_labels = tf.cast(labels, tf.float32)
69 + all_zeros = tf.zeros(tf.shape(float_labels), dtype=tf.float32)
70 + all_ones = tf.ones(tf.shape(float_labels), dtype=tf.float32)
71 + sign_labels = tf.subtract(tf.scalar_mul(2, float_labels), all_ones)
72 + hinge_loss = tf.maximum(
73 + all_zeros,
74 + tf.scalar_mul(b, all_ones) - sign_labels * predictions)
75 + return tf.reduce_mean(tf.reduce_sum(hinge_loss, 1))
76 +
77 +
78 +class SoftmaxLoss(BaseLoss):
79 + """Calculate the softmax loss between the predictions and labels.
80 +
81 + The function calculates the loss in the following way: first we feed the
82 + predictions to the softmax activation function and then we calculate
83 + the minus linear dot product between the logged softmax activations and the
84 + normalized ground truth label.
85 +
86 + It is an extension to the one-hot label. It allows for more than one positive
87 + labels for each sample.
88 + """
89 +
90 + def calculate_loss(self, predictions, labels, **unused_params):
91 + with tf.name_scope("loss_softmax"):
92 + epsilon = 10e-8
93 + float_labels = tf.cast(labels, tf.float32)
94 + # l1 normalization (labels are no less than 0)
95 + label_rowsum = tf.maximum(tf.reduce_sum(float_labels, 1, keep_dims=True),
96 + epsilon)
97 + norm_float_labels = tf.div(float_labels, label_rowsum)
98 + softmax_outputs = tf.nn.softmax(predictions)
99 + softmax_loss = tf.negative(
100 + tf.reduce_sum(tf.multiply(norm_float_labels, tf.log(softmax_outputs)),
101 + 1))
102 + return tf.reduce_mean(softmax_loss)
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Calculate the mean average precision.
15 +
16 +It provides an interface for calculating mean average precision
17 +for an entire list or the top-n ranked items.
18 +
19 +Example usages:
20 +We first call the function accumulate many times to process parts of the ranked
21 +list. After processing all the parts, we call peek_map_at_n
22 +to calculate the mean average precision.
23 +
24 +```
25 +import random
26 +
27 +p = np.array([[random.random() for _ in xrange(50)] for _ in xrange(1000)])
28 +a = np.array([[random.choice([0, 1]) for _ in xrange(50)]
29 + for _ in xrange(1000)])
30 +
31 +# mean average precision for 50 classes.
32 +calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator(
33 + num_class=50)
34 +calculator.accumulate(p, a)
35 +aps = calculator.peek_map_at_n()
36 +```
37 +"""
38 +
39 +import average_precision_calculator
40 +
41 +
42 +class MeanAveragePrecisionCalculator(object):
43 + """This class is to calculate mean average precision."""
44 +
45 + def __init__(self, num_class, filter_empty_classes=True, top_n=None):
46 + """Construct a calculator to calculate the (macro) average precision.
47 +
48 + Args:
49 + num_class: A positive Integer specifying the number of classes.
50 + filter_empty_classes: whether to filter classes without any positives.
51 + top_n: A positive Integer specifying the average precision at n, or None
52 + to use all provided data points.
53 +
54 + Raises:
55 + ValueError: An error occurred when num_class is not a positive integer;
56 + or the top_n_array is not a list of positive integers.
57 + """
58 + if not isinstance(num_class, int) or num_class <= 1:
59 + raise ValueError("num_class must be a positive integer.")
60 +
61 + self._ap_calculators = [] # member of AveragePrecisionCalculator
62 + self._num_class = num_class # total number of classes
63 + self._filter_empty_classes = filter_empty_classes
64 + for _ in range(num_class):
65 + self._ap_calculators.append(
66 + average_precision_calculator.AveragePrecisionCalculator(top_n=top_n))
67 +
68 + def accumulate(self, predictions, actuals, num_positives=None):
69 + """Accumulate the predictions and their ground truth labels.
70 +
71 + Args:
72 + predictions: A list of lists storing the prediction scores. The outer
73 + dimension corresponds to classes.
74 + actuals: A list of lists storing the ground truth labels. The dimensions
75 + should correspond to the predictions input. Any value larger than 0 will
76 + be treated as positives, otherwise as negatives.
77 + num_positives: If provided, it is a list of numbers representing the
78 + number of true positives for each class. If not provided, the number of
79 + true positives will be inferred from the 'actuals' array.
80 +
81 + Raises:
82 + ValueError: An error occurred when the shape of predictions and actuals
83 + does not match.
84 + """
85 + if not num_positives:
86 + num_positives = [None for i in range(self._num_class)]
87 +
88 + calculators = self._ap_calculators
89 + for i in range(self._num_class):
90 + calculators[i].accumulate(predictions[i], actuals[i], num_positives[i])
91 +
92 + def clear(self):
93 + for calculator in self._ap_calculators:
94 + calculator.clear()
95 +
96 + def is_empty(self):
97 + return ([calculator.heap_size for calculator in self._ap_calculators
98 + ] == [0 for _ in range(self._num_class)])
99 +
100 + def peek_map_at_n(self):
101 + """Peek the non-interpolated mean average precision at n.
102 +
103 + Returns:
104 + An array of non-interpolated average precision at n (default 0) for each
105 + class.
106 + """
107 + aps = []
108 + for i in range(self._num_class):
109 + if (not self._filter_empty_classes or
110 + self._ap_calculators[i].num_accumulated_positives > 0):
111 + ap = self._ap_calculators[i].peek_ap_at_n()
112 + aps.append(ap)
113 + return aps
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains a collection of util functions for model construction."""
15 +import numpy
16 +import tensorflow as tf
17 +from tensorflow import logging
18 +from tensorflow import flags
19 +import tensorflow.contrib.slim as slim
20 +
21 +
22 +def SampleRandomSequence(model_input, num_frames, num_samples):
23 + """Samples a random sequence of frames of size num_samples.
24 +
25 + Args:
26 + model_input: A tensor of size batch_size x max_frames x feature_size
27 + num_frames: A tensor of size batch_size x 1
28 + num_samples: A scalar
29 +
30 + Returns:
31 + `model_input`: A tensor of size batch_size x num_samples x feature_size
32 + """
33 +
34 + batch_size = tf.shape(model_input)[0]
35 + frame_index_offset = tf.tile(tf.expand_dims(tf.range(num_samples), 0),
36 + [batch_size, 1])
37 + max_start_frame_index = tf.maximum(num_frames - num_samples, 0)
38 + start_frame_index = tf.cast(
39 + tf.multiply(tf.random_uniform([batch_size, 1]),
40 + tf.cast(max_start_frame_index + 1, tf.float32)), tf.int32)
41 + frame_index = tf.minimum(start_frame_index + frame_index_offset,
42 + tf.cast(num_frames - 1, tf.int32))
43 + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1),
44 + [1, num_samples])
45 + index = tf.stack([batch_index, frame_index], 2)
46 + return tf.gather_nd(model_input, index)
47 +
48 +
49 +def SampleRandomFrames(model_input, num_frames, num_samples):
50 + """Samples a random set of frames of size num_samples.
51 +
52 + Args:
53 + model_input: A tensor of size batch_size x max_frames x feature_size
54 + num_frames: A tensor of size batch_size x 1
55 + num_samples: A scalar
56 +
57 + Returns:
58 + `model_input`: A tensor of size batch_size x num_samples x feature_size
59 + """
60 + batch_size = tf.shape(model_input)[0]
61 + frame_index = tf.cast(
62 + tf.multiply(tf.random_uniform([batch_size, num_samples]),
63 + tf.tile(tf.cast(num_frames, tf.float32), [1, num_samples])),
64 + tf.int32)
65 + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1),
66 + [1, num_samples])
67 + index = tf.stack([batch_index, frame_index], 2)
68 + return tf.gather_nd(model_input, index)
69 +
70 +
71 +def FramePooling(frames, method, **unused_params):
72 + """Pools over the frames of a video.
73 +
74 + Args:
75 + frames: A tensor with shape [batch_size, num_frames, feature_size].
76 + method: "average", "max", "attention", or "none".
77 +
78 + Returns:
79 + A tensor with shape [batch_size, feature_size] for average, max, or
80 + attention pooling. A tensor with shape [batch_size*num_frames, feature_size]
81 + for none pooling.
82 +
83 + Raises:
84 + ValueError: if method is other than "average", "max", "attention", or
85 + "none".
86 + """
87 + if method == "average":
88 + return tf.reduce_mean(frames, 1)
89 + elif method == "max":
90 + return tf.reduce_max(frames, 1)
91 + elif method == "none":
92 + feature_size = frames.shape_as_list()[2]
93 + return tf.reshape(frames, [-1, feature_size])
94 + else:
95 + raise ValueError("Unrecognized pooling method: %s" % method)
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains the base class for models."""
15 +
16 +
17 +class BaseModel(object):
18 + """Inherit from this class when implementing new models."""
19 +
20 + def create_model(self, unused_model_input, **unused_params):
21 + raise NotImplementedError()
This diff is collapsed. Click to expand it.
1 +"""Eval mAP@N metric from inference file."""
2 +
3 +from __future__ import absolute_import
4 +from __future__ import division
5 +from __future__ import print_function
6 +
7 +from absl import app
8 +from absl import flags
9 +
10 +import mean_average_precision_calculator as map_calculator
11 +import numpy as np
12 +import tensorflow as tf
13 +
14 +flags.DEFINE_string(
15 + "eval_data_pattern", "",
16 + "File glob defining the evaluation dataset in tensorflow.SequenceExample "
17 + "format. The SequenceExamples are expected to have an 'rgb' byte array "
18 + "sequence feature as well as a 'labels' int64 context feature.")
19 +flags.DEFINE_string(
20 + "label_cache", "",
21 + "The path for the label cache file. Leave blank for not to cache.")
22 +flags.DEFINE_string("submission_file", "",
23 + "The segment submission file generated by inference.py.")
24 +flags.DEFINE_integer(
25 + "top_n", 0,
26 + "The cap per-class predictions by a maximum of N. Use 0 for not capping.")
27 +
28 +FLAGS = flags.FLAGS
29 +
30 +
31 +class Labels(object):
32 + """Contains the class to hold label objects.
33 +
34 + This class can serialize and de-serialize the groundtruths.
35 + The ground truth is in a mapping from (segment_id, class_id) -> label_score.
36 + """
37 +
38 + def __init__(self, labels):
39 + """__init__ method."""
40 + self._labels = labels
41 +
42 + @property
43 + def labels(self):
44 + """Return the ground truth mapping. See class docstring for details."""
45 + return self._labels
46 +
47 + def to_file(self, file_name):
48 + """Materialize the GT mapping to file."""
49 + with tf.gfile.Open(file_name, "w") as fobj:
50 + for k, v in self._labels.items():
51 + seg_id, label = k
52 + line = "%s,%s,%s\n" % (seg_id, label, v)
53 + fobj.write(line)
54 +
55 + @classmethod
56 + def from_file(cls, file_name):
57 + """Read the GT mapping from cached file."""
58 + labels = {}
59 + with tf.gfile.Open(file_name) as fobj:
60 + for line in fobj:
61 + line = line.strip().strip("\n")
62 + seg_id, label, score = line.split(",")
63 + labels[(seg_id, int(label))] = float(score)
64 + return cls(labels)
65 +
66 +
67 +def read_labels(data_pattern, cache_path=""):
68 + """Read labels from TFRecords.
69 +
70 + Args:
71 + data_pattern: the data pattern to the TFRecords.
72 + cache_path: the cache path for the label file.
73 +
74 + Returns:
75 + a Labels object.
76 + """
77 + if cache_path:
78 + if tf.gfile.Exists(cache_path):
79 + tf.logging.info("Reading cached labels from %s..." % cache_path)
80 + return Labels.from_file(cache_path)
81 + tf.enable_eager_execution()
82 + data_paths = tf.gfile.Glob(data_pattern)
83 + ds = tf.data.TFRecordDataset(data_paths, num_parallel_reads=50)
84 + context_features = {
85 + "id": tf.FixedLenFeature([], tf.string),
86 + "segment_labels": tf.VarLenFeature(tf.int64),
87 + "segment_start_times": tf.VarLenFeature(tf.int64),
88 + "segment_scores": tf.VarLenFeature(tf.float32)
89 + }
90 +
91 + def _parse_se_func(sequence_example):
92 + return tf.parse_single_sequence_example(sequence_example,
93 + context_features=context_features)
94 +
95 + ds = ds.map(_parse_se_func)
96 + rated_labels = {}
97 + tf.logging.info("Reading labels from TFRecords...")
98 + last_batch = 0
99 + batch_size = 5000
100 + for cxt_feature_val, _ in ds:
101 + video_id = cxt_feature_val["id"].numpy()
102 + segment_labels = cxt_feature_val["segment_labels"].values.numpy()
103 + segment_start_times = cxt_feature_val["segment_start_times"].values.numpy()
104 + segment_scores = cxt_feature_val["segment_scores"].values.numpy()
105 + for label, start_time, score in zip(segment_labels, segment_start_times,
106 + segment_scores):
107 + rated_labels[("%s:%d" % (video_id, start_time), label)] = score
108 + batch_id = len(rated_labels) // batch_size
109 + if batch_id != last_batch:
110 + tf.logging.info("%d examples processed.", len(rated_labels))
111 + last_batch = batch_id
112 + tf.logging.info("Finish reading labels from TFRecords...")
113 + labels_obj = Labels(rated_labels)
114 + if cache_path:
115 + tf.logging.info("Caching labels to %s..." % cache_path)
116 + labels_obj.to_file(cache_path)
117 + return labels_obj
118 +
119 +
120 +def read_segment_predictions(file_path, labels, top_n=None):
121 + """Read segement predictions.
122 +
123 + Args:
124 + file_path: the submission file path.
125 + labels: a Labels object containing the eval labels.
126 + top_n: the per-class class capping.
127 +
128 + Returns:
129 + a segment prediction list for each classes.
130 + """
131 + cls_preds = {} # A label_id to pred list mapping.
132 + with tf.gfile.Open(file_path) as fobj:
133 + tf.logging.info("Reading predictions from %s..." % file_path)
134 + for line in fobj:
135 + label_id, pred_ids_val = line.split(",")
136 + pred_ids = pred_ids_val.split(" ")
137 + if top_n:
138 + pred_ids = pred_ids[:top_n]
139 + pred_ids = [
140 + pred_id for pred_id in pred_ids
141 + if (pred_id, int(label_id)) in labels.labels
142 + ]
143 + cls_preds[int(label_id)] = pred_ids
144 + if len(cls_preds) % 50 == 0:
145 + tf.logging.info("Processed %d classes..." % len(cls_preds))
146 + tf.logging.info("Finish reading predictions.")
147 + return cls_preds
148 +
149 +
150 +def main(unused_argv):
151 + """Entry function of the script."""
152 + if not FLAGS.submission_file:
153 + raise ValueError("You must input submission file.")
154 + eval_labels = read_labels(FLAGS.eval_data_pattern,
155 + cache_path=FLAGS.label_cache)
156 + tf.logging.info("Total rated segments: %d." % len(eval_labels.labels))
157 + positive_counter = {}
158 + for k, v in eval_labels.labels.items():
159 + _, label_id = k
160 + if v > 0:
161 + positive_counter[label_id] = positive_counter.get(label_id, 0) + 1
162 +
163 + seg_preds = read_segment_predictions(FLAGS.submission_file,
164 + eval_labels,
165 + top_n=FLAGS.top_n)
166 + map_cal = map_calculator.MeanAveragePrecisionCalculator(len(seg_preds))
167 + seg_labels = []
168 + seg_scored_preds = []
169 + num_positives = []
170 + for label_id in sorted(seg_preds):
171 + class_preds = seg_preds[label_id]
172 + seg_label = [eval_labels.labels[(pred, label_id)] for pred in class_preds]
173 + seg_labels.append(seg_label)
174 + seg_scored_pred = []
175 + if class_preds:
176 + seg_scored_pred = [
177 + float(x) / len(class_preds) for x in range(len(class_preds), 0, -1)
178 + ]
179 + seg_scored_preds.append(seg_scored_pred)
180 + num_positives.append(positive_counter[label_id])
181 + map_cal.accumulate(seg_scored_preds, seg_labels, num_positives)
182 + map_at_n = np.mean(map_cal.peek_map_at_n())
183 + tf.logging.info("Num classes: %d | mAP@%d: %.6f" %
184 + (len(seg_preds), FLAGS.top_n, map_at_n))
185 +
186 +
187 +if __name__ == "__main__":
188 + app.run(main)
1 +Index
2 +3
3 +7
4 +8
5 +11
6 +12
7 +17
8 +18
9 +19
10 +21
11 +22
12 +23
13 +28
14 +31
15 +30
16 +32
17 +33
18 +34
19 +41
20 +43
21 +45
22 +46
23 +48
24 +53
25 +54
26 +52
27 +55
28 +58
29 +59
30 +60
31 +61
32 +65
33 +68
34 +73
35 +71
36 +74
37 +75
38 +76
39 +77
40 +80
41 +83
42 +90
43 +88
44 +89
45 +92
46 +95
47 +100
48 +101
49 +99
50 +104
51 +105
52 +109
53 +113
54 +112
55 +115
56 +116
57 +118
58 +120
59 +121
60 +123
61 +125
62 +127
63 +131
64 +128
65 +129
66 +130
67 +137
68 +141
69 +143
70 +145
71 +148
72 +152
73 +151
74 +156
75 +155
76 +158
77 +160
78 +164
79 +163
80 +169
81 +170
82 +172
83 +171
84 +173
85 +174
86 +175
87 +176
88 +178
89 +182
90 +184
91 +186
92 +188
93 +187
94 +192
95 +191
96 +190
97 +194
98 +197
99 +196
100 +198
101 +201
102 +202
103 +200
104 +199
105 +205
106 +204
107 +209
108 +207
109 +206
110 +210
111 +213
112 +214
113 +220
114 +218
115 +217
116 +226
117 +227
118 +231
119 +232
120 +229
121 +233
122 +235
123 +237
124 +244
125 +240
126 +249
127 +246
128 +248
129 +239
130 +250
131 +245
132 +255
133 +253
134 +256
135 +261
136 +259
137 +263
138 +262
139 +266
140 +267
141 +268
142 +269
143 +271
144 +276
145 +273
146 +277
147 +274
148 +278
149 +279
150 +280
151 +288
152 +291
153 +295
154 +294
155 +293
156 +297
157 +296
158 +300
159 +299
160 +303
161 +302
162 +304
163 +305
164 +313
165 +307
166 +311
167 +310
168 +312
169 +316
170 +318
171 +321
172 +322
173 +331
174 +333
175 +329
176 +330
177 +334
178 +343
179 +349
180 +340
181 +344
182 +348
183 +358
184 +347
185 +359
186 +355
187 +361
188 +360
189 +364
190 +365
191 +368
192 +369
193 +366
194 +370
195 +374
196 +380
197 +373
198 +385
199 +384
200 +388
201 +389
202 +382
203 +393
204 +381
205 +390
206 +394
207 +399
208 +397
209 +396
210 +402
211 +400
212 +398
213 +401
214 +405
215 +406
216 +410
217 +408
218 +416
219 +415
220 +419
221 +422
222 +414
223 +421
224 +424
225 +429
226 +418
227 +427
228 +434
229 +428
230 +435
231 +430
232 +441
233 +439
234 +437
235 +443
236 +440
237 +442
238 +445
239 +446
240 +448
241 +454
242 +444
243 +453
244 +455
245 +451
246 +452
247 +458
248 +460
249 +465
250 +457
251 +463
252 +462
253 +461
254 +464
255 +469
256 +468
257 +472
258 +473
259 +471
260 +475
261 +474
262 +477
263 +485
264 +491
265 +488
266 +482
267 +490
268 +496
269 +494
270 +483
271 +495
272 +493
273 +507
274 +501
275 +499
276 +503
277 +498
278 +514
279 +504
280 +502
281 +506
282 +508
283 +511
284 +527
285 +526
286 +532
287 +513
288 +519
289 +525
290 +518
291 +528
292 +522
293 +523
294 +535
295 +539
296 +540
297 +533
298 +521
299 +541
300 +547
301 +550
302 +544
303 +549
304 +551
305 +554
306 +543
307 +548
308 +557
309 +560
310 +552
311 +559
312 +563
313 +565
314 +567
315 +555
316 +576
317 +568
318 +564
319 +573
320 +581
321 +580
322 +572
323 +571
324 +584
325 +590
326 +585
327 +587
328 +588
329 +592
330 +598
331 +597
332 +599
333 +603
334 +600
335 +604
336 +605
337 +614
338 +602
339 +610
340 +608
341 +611
342 +612
343 +613
344 +617
345 +620
346 +607
347 +624
348 +627
349 +625
350 +631
351 +629
352 +638
353 +632
354 +634
355 +644
356 +641
357 +642
358 +646
359 +652
360 +647
361 +637
362 +661
363 +635
364 +658
365 +648
366 +663
367 +668
368 +664
369 +656
370 +666
371 +671
372 +683
373 +675
374 +669
375 +676
376 +667
377 +691
378 +685
379 +673
380 +688
381 +702
382 +684
383 +679
384 +694
385 +686
386 +689
387 +680
388 +693
389 +703
390 +697
391 +698
392 +692
393 +705
394 +706
395 +712
396 +711
397 +709
398 +710
399 +726
400 +713
401 +721
402 +720
403 +715
404 +717
405 +730
406 +728
407 +723
408 +716
409 +722
410 +718
411 +732
412 +724
413 +736
414 +725
415 +742
416 +727
417 +735
418 +740
419 +748
420 +738
421 +746
422 +751
423 +749
424 +752
425 +754
426 +760
427 +763
428 +756
429 +758
430 +766
431 +764
432 +757
433 +780
434 +767
435 +769
436 +771
437 +786
438 +785
439 +781
440 +787
441 +778
442 +783
443 +792
444 +791
445 +795
446 +788
447 +805
448 +802
449 +801
450 +793
451 +796
452 +804
453 +803
454 +797
455 +814
456 +813
457 +789
458 +808
459 +818
460 +816
461 +817
462 +811
463 +820
464 +826
465 +829
466 +824
467 +821
468 +825
469 +822
470 +835
471 +833
472 +843
473 +823
474 +827
475 +830
476 +832
477 +837
478 +852
479 +844
480 +841
481 +812
482 +847
483 +862
484 +869
485 +860
486 +838
487 +870
488 +846
489 +858
490 +854
491 +880
492 +876
493 +857
494 +859
495 +877
496 +871
497 +855
498 +875
499 +861
500 +867
501 +892
502 +898
503 +888
504 +884
505 +887
506 +891
507 +906
508 +900
509 +878
510 +885
511 +883
512 +901
513 +903
514 +907
515 +930
516 +897
517 +914
518 +917
519 +910
520 +905
521 +909
522 +933
523 +932
524 +922
525 +913
526 +923
527 +931
528 +911
529 +937
530 +918
531 +955
532 +915
533 +944
534 +952
535 +945
536 +948
537 +946
538 +970
539 +974
540 +958
541 +925
542 +979
543 +942
544 +965
545 +975
546 +950
547 +982
548 +940
549 +973
550 +962
551 +972
552 +957
553 +984
554 +983
555 +964
556 +1007
557 +971
558 +981
559 +954
560 +993
561 +991
562 +996
563 +1005
564 +1015
565 +1009
566 +995
567 +986
568 +1000
569 +985
570 +980
571 +1016
572 +1011
573 +999
574 +1002
575 +994
576 +1013
577 +1010
578 +992
579 +1008
580 +1036
581 +1025
582 +1012
583 +990
584 +1037
585 +1040
586 +1031
587 +1019
588 +1052
589 +1001
590 +1055
591 +1032
592 +1069
593 +1058
594 +1014
595 +1023
596 +1030
597 +1061
598 +1035
599 +1034
600 +1053
601 +1045
602 +1046
603 +1067
604 +1060
605 +1049
606 +1056
607 +1074
608 +1066
609 +1044
610 +1038
611 +1073
612 +1077
613 +1068
614 +1057
615 +1072
616 +1104
617 +1083
618 +1089
619 +1087
620 +1099
621 +1076
622 +1086
623 +1098
624 +1094
625 +1095
626 +1096
627 +1101
628 +1107
629 +1105
630 +1117
631 +1093
632 +1106
633 +1122
634 +1119
635 +1103
636 +1128
637 +1120
638 +1126
639 +1102
640 +1115
641 +1124
642 +1123
643 +1131
644 +1136
645 +1144
646 +1121
647 +1137
648 +1132
649 +1133
650 +1157
651 +1134
652 +1143
653 +1159
654 +1164
655 +1155
656 +1142
657 +1150
658 +1148
659 +1161
660 +1165
661 +1147
662 +1162
663 +1152
664 +1174
665 +1160
666 +1166
667 +1190
668 +1175
669 +1167
670 +1156
671 +1180
672 +1171
673 +1179
674 +1172
675 +1186
676 +1188
677 +1201
678 +1177
679 +1208
680 +1183
681 +1189
682 +1192
683 +1209
684 +1214
685 +1197
686 +1168
687 +1202
688 +1205
689 +1203
690 +1199
691 +1219
692 +1217
693 +1187
694 +1206
695 +1210
696 +1241
697 +1221
698 +1218
699 +1223
700 +1236
701 +1212
702 +1237
703 +1195
704 +1216
705 +1247
706 +1234
707 +1240
708 +1257
709 +1224
710 +1243
711 +1259
712 +1242
713 +1282
714 +1222
715 +1254
716 +1227
717 +1235
718 +1269
719 +1258
720 +1290
721 +1275
722 +1262
723 +1252
724 +1248
725 +1272
726 +1246
727 +1225
728 +1245
729 +1277
730 +1298
731 +1288
732 +1271
733 +1265
734 +1286
735 +1260
736 +1266
737 +1296
738 +1280
739 +1285
740 +1293
741 +1276
742 +1287
743 +1289
744 +1261
745 +1264
746 +1295
747 +1291
748 +1283
749 +1311
750 +1303
751 +1330
752 +1315
753 +1300
754 +1333
755 +1307
756 +1325
757 +1334
758 +1316
759 +1314
760 +1317
761 +1310
762 +1329
763 +1324
764 +1339
765 +1346
766 +1342
767 +1352
768 +1321
769 +1376
770 +1366
771 +1308
772 +1345
773 +1348
774 +1386
775 +1383
776 +1372
777 +1367
778 +1400
779 +1382
780 +1375
781 +1392
782 +1380
783 +1371
784 +1393
785 +1389
786 +1353
787 +1387
788 +1374
789 +1379
790 +1381
791 +1359
792 +1360
793 +1396
794 +1399
795 +1365
796 +1424
797 +1373
798 +1411
799 +1401
800 +1397
801 +1395
802 +1412
803 +1394
804 +1368
805 +1423
806 +1391
807 +1435
808 +1409
809 +1443
810 +1402
811 +1425
812 +1415
813 +1421
814 +1426
815 +1433
816 +1420
817 +1452
818 +1436
819 +1430
820 +1408
821 +1458
822 +1429
823 +1453
824 +1454
825 +1447
826 +1472
827 +1486
828 +1468
829 +1461
830 +1467
831 +1484
832 +1457
833 +1444
834 +1450
835 +1451
836 +1459
837 +1462
838 +1449
839 +1476
840 +1470
841 +1471
842 +1498
843 +1488
844 +1442
845 +1480
846 +1456
847 +1466
848 +1505
849 +1517
850 +1464
851 +1503
852 +1490
853 +1519
854 +1481
855 +1493
856 +1463
857 +1532
858 +1487
859 +1501
860 +1500
861 +1495
862 +1509
863 +1535
864 +1506
865 +1521
866 +1580
867 +1540
868 +1502
869 +1520
870 +1496
871 +1569
872 +1515
873 +1489
874 +1507
875 +1527
876 +1545
877 +1560
878 +1510
879 +1514
880 +1526
881 +1594
882 +1511
883 +1572
884 +1548
885 +1584
886 +1556
887 +1588
888 +1628
889 +1555
890 +1568
891 +1550
892 +1622
893 +1563
894 +1603
895 +1616
896 +1576
897 +1549
898 +1537
899 +1593
900 +1618
901 +1645
902 +1624
903 +1617
904 +1634
905 +1595
906 +1597
907 +1590
908 +1632
909 +1575
910 +1559
911 +1625
912 +1615
913 +1591
914 +1630
915 +1608
916 +1621
917 +1589
918 +1646
919 +1643
920 +1652
921 +1627
922 +1611
923 +1626
924 +1613
925 +1639
926 +1655
927 +1620
928 +1602
929 +1651
930 +1653
931 +1669
932 +1638
933 +1696
934 +1649
935 +1675
936 +1660
937 +1683
938 +1666
939 +1671
940 +1703
941 +1716
942 +1637
943 +1672
944 +1676
945 +1692
946 +1711
947 +1680
948 +1641
949 +1688
950 +1708
951 +1704
952 +1690
953 +1674
954 +1718
955 +1699
956 +1723
957 +1756
958 +1700
959 +1662
960 +1715
961 +1657
962 +1733
963 +1728
964 +1670
965 +1712
966 +1685
967 +1724
968 +1735
969 +1714
970 +1730
971 +1747
972 +1656
973 +1737
974 +1705
975 +1693
976 +1713
977 +1689
978 +1753
979 +1739
980 +1721
981 +1725
982 +1749
983 +1732
984 +1743
985 +1731
986 +1767
987 +1738
988 +1831
989 +1771
990 +1726
991 +1746
992 +1776
993 +1775
994 +1799
995 +1774
996 +1780
997 +1781
998 +1769
999 +1805
1000 +1788
1001 +1801
This diff is collapsed. Click to expand it.
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains a collection of util functions for training and evaluating."""
15 +
16 +import numpy
17 +import tensorflow as tf
18 +from tensorflow import logging
19 +
20 +try:
21 + xrange # Python 2
22 +except NameError:
23 + xrange = range # Python 3
24 +
25 +
26 +def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2):
27 + """Dequantize the feature from the byte format to the float format.
28 +
29 + Args:
30 + feat_vector: the input 1-d vector.
31 + max_quantized_value: the maximum of the quantized value.
32 + min_quantized_value: the minimum of the quantized value.
33 +
34 + Returns:
35 + A float vector which has the same shape as feat_vector.
36 + """
37 + assert max_quantized_value > min_quantized_value
38 + quantized_range = max_quantized_value - min_quantized_value
39 + scalar = quantized_range / 255.0
40 + bias = (quantized_range / 512.0) + min_quantized_value
41 + return feat_vector * scalar + bias
42 +
43 +
44 +def MakeSummary(name, value):
45 + """Creates a tf.Summary proto with the given name and value."""
46 + summary = tf.Summary()
47 + val = summary.value.add()
48 + val.tag = str(name)
49 + val.simple_value = float(value)
50 + return summary
51 +
52 +
53 +def AddGlobalStepSummary(summary_writer,
54 + global_step_val,
55 + global_step_info_dict,
56 + summary_scope="Eval"):
57 + """Add the global_step summary to the Tensorboard.
58 +
59 + Args:
60 + summary_writer: Tensorflow summary_writer.
61 + global_step_val: a int value of the global step.
62 + global_step_info_dict: a dictionary of the evaluation metrics calculated for
63 + a mini-batch.
64 + summary_scope: Train or Eval.
65 +
66 + Returns:
67 + A string of this global_step summary
68 + """
69 + this_hit_at_one = global_step_info_dict["hit_at_one"]
70 + this_perr = global_step_info_dict["perr"]
71 + this_loss = global_step_info_dict["loss"]
72 + examples_per_second = global_step_info_dict.get("examples_per_second", -1)
73 +
74 + summary_writer.add_summary(
75 + MakeSummary("GlobalStep/" + summary_scope + "_Hit@1", this_hit_at_one),
76 + global_step_val)
77 + summary_writer.add_summary(
78 + MakeSummary("GlobalStep/" + summary_scope + "_Perr", this_perr),
79 + global_step_val)
80 + summary_writer.add_summary(
81 + MakeSummary("GlobalStep/" + summary_scope + "_Loss", this_loss),
82 + global_step_val)
83 +
84 + if examples_per_second != -1:
85 + summary_writer.add_summary(
86 + MakeSummary("GlobalStep/" + summary_scope + "_Example_Second",
87 + examples_per_second), global_step_val)
88 +
89 + summary_writer.flush()
90 + info = (
91 + "global_step {0} | Batch Hit@1: {1:.3f} | Batch PERR: {2:.3f} | Batch "
92 + "Loss: {3:.3f} | Examples_per_sec: {4:.3f}").format(
93 + global_step_val, this_hit_at_one, this_perr, this_loss,
94 + examples_per_second)
95 + return info
96 +
97 +
98 +def AddEpochSummary(summary_writer,
99 + global_step_val,
100 + epoch_info_dict,
101 + summary_scope="Eval"):
102 + """Add the epoch summary to the Tensorboard.
103 +
104 + Args:
105 + summary_writer: Tensorflow summary_writer.
106 + global_step_val: a int value of the global step.
107 + epoch_info_dict: a dictionary of the evaluation metrics calculated for the
108 + whole epoch.
109 + summary_scope: Train or Eval.
110 +
111 + Returns:
112 + A string of this global_step summary
113 + """
114 + epoch_id = epoch_info_dict["epoch_id"]
115 + avg_hit_at_one = epoch_info_dict["avg_hit_at_one"]
116 + avg_perr = epoch_info_dict["avg_perr"]
117 + avg_loss = epoch_info_dict["avg_loss"]
118 + aps = epoch_info_dict["aps"]
119 + gap = epoch_info_dict["gap"]
120 + mean_ap = numpy.mean(aps)
121 +
122 + summary_writer.add_summary(
123 + MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one),
124 + global_step_val)
125 + summary_writer.add_summary(
126 + MakeSummary("Epoch/" + summary_scope + "_Avg_Perr", avg_perr),
127 + global_step_val)
128 + summary_writer.add_summary(
129 + MakeSummary("Epoch/" + summary_scope + "_Avg_Loss", avg_loss),
130 + global_step_val)
131 + summary_writer.add_summary(
132 + MakeSummary("Epoch/" + summary_scope + "_MAP", mean_ap), global_step_val)
133 + summary_writer.add_summary(
134 + MakeSummary("Epoch/" + summary_scope + "_GAP", gap), global_step_val)
135 + summary_writer.flush()
136 +
137 + info = ("epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} "
138 + "| MAP: {3:.3f} | GAP: {4:.3f} | Avg_Loss: {5:3f} | num_classes: {6}"
139 + ).format(epoch_id, avg_hit_at_one, avg_perr, mean_ap, gap, avg_loss,
140 + len(aps))
141 + return info
142 +
143 +
144 +def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes):
145 + """Extract the list of feature names and the dimensionality of each feature
146 +
147 + from string of comma separated values.
148 +
149 + Args:
150 + feature_names: string containing comma separated list of feature names
151 + feature_sizes: string containing comma separated list of feature sizes
152 +
153 + Returns:
154 + List of the feature names and list of the dimensionality of each feature.
155 + Elements in the first/second list are strings/integers.
156 + """
157 + list_of_feature_names = [
158 + feature_names.strip() for feature_names in feature_names.split(",")
159 + ]
160 + list_of_feature_sizes = [
161 + int(feature_sizes) for feature_sizes in feature_sizes.split(",")
162 + ]
163 + if len(list_of_feature_names) != len(list_of_feature_sizes):
164 + logging.error("length of the feature names (=" +
165 + str(len(list_of_feature_names)) + ") != length of feature "
166 + "sizes (=" + str(len(list_of_feature_sizes)) + ")")
167 +
168 + return list_of_feature_names, list_of_feature_sizes
169 +
170 +
171 +def clip_gradient_norms(gradients_to_variables, max_norm):
172 + """Clips the gradients by the given value.
173 +
174 + Args:
175 + gradients_to_variables: A list of gradient to variable pairs (tuples).
176 + max_norm: the maximum norm value.
177 +
178 + Returns:
179 + A list of clipped gradient to variable pairs.
180 + """
181 + clipped_grads_and_vars = []
182 + for grad, var in gradients_to_variables:
183 + if grad is not None:
184 + if isinstance(grad, tf.IndexedSlices):
185 + tmp = tf.clip_by_norm(grad.values, max_norm)
186 + grad = tf.IndexedSlices(tmp, grad.indices, grad.dense_shape)
187 + else:
188 + grad = tf.clip_by_norm(grad, max_norm)
189 + clipped_grads_and_vars.append((grad, var))
190 + return clipped_grads_and_vars
191 +
192 +
193 +def combine_gradients(tower_grads):
194 + """Calculate the combined gradient for each shared variable across all towers.
195 +
196 + Note that this function provides a synchronization point across all towers.
197 +
198 + Args:
199 + tower_grads: List of lists of (gradient, variable) tuples. The outer list is
200 + over individual gradients. The inner list is over the gradient calculation
201 + for each tower.
202 +
203 + Returns:
204 + List of pairs of (gradient, variable) where the gradient has been summed
205 + across all towers.
206 + """
207 + filtered_grads = [
208 + [x for x in grad_list if x[0] is not None] for grad_list in tower_grads
209 + ]
210 + final_grads = []
211 + for i in xrange(len(filtered_grads[0])):
212 + grads = [filtered_grads[t][i] for t in xrange(len(filtered_grads))]
213 + grad = tf.stack([x[0] for x in grads], 0)
214 + grad = tf.reduce_sum(grad, 0)
215 + final_grads.append((
216 + grad,
217 + filtered_grads[0][i][1],
218 + ))
219 +
220 + return final_grads
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains model definitions."""
15 +import math
16 +
17 +import models
18 +import tensorflow as tf
19 +import utils
20 +
21 +from tensorflow import flags
22 +import tensorflow.contrib.slim as slim
23 +
24 +FLAGS = flags.FLAGS
25 +flags.DEFINE_integer(
26 + "moe_num_mixtures", 2,
27 + "The number of mixtures (excluding the dummy 'expert') used for MoeModel.")
28 +
29 +
30 +class LogisticModel(models.BaseModel):
31 + """Logistic model with L2 regularization."""
32 +
33 + def create_model(self,
34 + model_input,
35 + vocab_size,
36 + l2_penalty=1e-8,
37 + **unused_params):
38 + """Creates a logistic model.
39 +
40 + Args:
41 + model_input: 'batch' x 'num_features' matrix of input features.
42 + vocab_size: The number of classes in the dataset.
43 +
44 + Returns:
45 + A dictionary with a tensor containing the probability predictions of the
46 + model in the 'predictions' key. The dimensions of the tensor are
47 + batch_size x num_classes.
48 + """
49 + output = slim.fully_connected(
50 + model_input,
51 + vocab_size,
52 + activation_fn=tf.nn.sigmoid,
53 + weights_regularizer=slim.l2_regularizer(l2_penalty))
54 + return {"predictions": output}
55 +
56 +
57 +class MoeModel(models.BaseModel):
58 + """A softmax over a mixture of logistic models (with L2 regularization)."""
59 +
60 + def create_model(self,
61 + model_input,
62 + vocab_size,
63 + num_mixtures=None,
64 + l2_penalty=1e-8,
65 + **unused_params):
66 + """Creates a Mixture of (Logistic) Experts model.
67 +
68 + The model consists of a per-class softmax distribution over a
69 + configurable number of logistic classifiers. One of the classifiers in the
70 + mixture is not trained, and always predicts 0.
71 +
72 + Args:
73 + model_input: 'batch_size' x 'num_features' matrix of input features.
74 + vocab_size: The number of classes in the dataset.
75 + num_mixtures: The number of mixtures (excluding a dummy 'expert' that
76 + always predicts the non-existence of an entity).
77 + l2_penalty: How much to penalize the squared magnitudes of parameter
78 + values.
79 +
80 + Returns:
81 + A dictionary with a tensor containing the probability predictions of the
82 + model in the 'predictions' key. The dimensions of the tensor are
83 + batch_size x num_classes.
84 + """
85 + num_mixtures = num_mixtures or FLAGS.moe_num_mixtures
86 +
87 + gate_activations = slim.fully_connected(
88 + model_input,
89 + vocab_size * (num_mixtures + 1),
90 + activation_fn=None,
91 + biases_initializer=None,
92 + weights_regularizer=slim.l2_regularizer(l2_penalty),
93 + scope="gates")
94 + expert_activations = slim.fully_connected(
95 + model_input,
96 + vocab_size * num_mixtures,
97 + activation_fn=None,
98 + weights_regularizer=slim.l2_regularizer(l2_penalty),
99 + scope="experts")
100 +
101 + gating_distribution = tf.nn.softmax(
102 + tf.reshape(
103 + gate_activations,
104 + [-1, num_mixtures + 1])) # (Batch * #Labels) x (num_mixtures + 1)
105 + expert_distribution = tf.nn.sigmoid(
106 + tf.reshape(expert_activations,
107 + [-1, num_mixtures])) # (Batch * #Labels) x num_mixtures
108 +
109 + final_probabilities_by_class_and_batch = tf.reduce_sum(
110 + gating_distribution[:, :num_mixtures] * expert_distribution, 1)
111 + final_probabilities = tf.reshape(final_probabilities_by_class_and_batch,
112 + [-1, vocab_size])
113 + return {"predictions": final_probabilities}
This diff could not be displayed because it is too large.
1 +> 1%
2 +last 2 versions
3 +not dead
1 +[*.{js,jsx,ts,tsx,vue}]
2 +indent_style = space
3 +indent_size = 2
4 +end_of_line = lf
5 +trim_trailing_whitespace = true
6 +insert_final_newline = true
7 +max_line_length = 200
1 +module.exports = {
2 + root: true,
3 + env: {
4 + browser: true,
5 + es6: true,
6 + node: true,
7 + },
8 + extends: [
9 + 'plugin:vue/essential',
10 + '@vue/airbnb',
11 + ],
12 + parserOptions: {
13 + parser: 'babel-eslint',
14 + },
15 + rules: {
16 + 'no-console': process.env.NODE_ENV === 'production' ? 'warn' : 'off',
17 + 'no-debugger': process.env.NODE_ENV === 'production' ? 'warn' : 'off',
18 + 'linebreak-style': ['error', 'unix'],
19 + 'max-len': ['error', { code: 200 }],
20 + },
21 +};
1 +.DS_Store
2 +node_modules
3 +/dist
4 +
5 +# local env files
6 +.env.local
7 +.env.*.local
8 +
9 +# Log files
10 +npm-debug.log*
11 +yarn-debug.log*
12 +yarn-error.log*
13 +
14 +# Editor directories and files
15 +.idea
16 +.vscode
17 +*.suo
18 +*.ntvs*
19 +*.njsproj
20 +*.sln
21 +*.sw?
1 +# front
2 +
3 +## Project setup
4 +```
5 +yarn install
6 +```
7 +
8 +### Compiles and hot-reloads for development
9 +```
10 +yarn serve
11 +```
12 +
13 +### Compiles and minifies for production
14 +```
15 +yarn build
16 +```
17 +
18 +### Lints and fixes files
19 +```
20 +yarn lint
21 +```
22 +
23 +### Customize configuration
24 +See [Configuration Reference](https://cli.vuejs.org/config/).
1 +module.exports = {
2 + presets: [
3 + '@vue/cli-plugin-babel/preset',
4 + ],
5 +};
This diff could not be displayed because it is too large.
1 +{
2 + "name": "front",
3 + "version": "0.1.0",
4 + "private": true,
5 + "scripts": {
6 + "serve": "vue-cli-service serve",
7 + "build": "vue-cli-service build",
8 + "lint": "vue-cli-service lint"
9 + },
10 + "dependencies": {
11 + "@mdi/font": "^3.6.95",
12 + "axios": "^0.19.2",
13 + "core-js": "^3.6.5",
14 + "moment": "^2.24.0",
15 + "roboto-fontface": "*",
16 + "vue": "^2.6.11",
17 + "vue-router": "^3.1.6",
18 + "vuetify": "^2.2.11",
19 + "vuex": "^3.1.3"
20 + },
21 + "devDependencies": {
22 + "@vue/cli-plugin-babel": "~4.3.0",
23 + "@vue/cli-plugin-eslint": "~4.3.0",
24 + "@vue/cli-plugin-router": "~4.3.0",
25 + "@vue/cli-plugin-vuex": "~4.3.0",
26 + "@vue/cli-service": "~4.3.0",
27 + "@vue/eslint-config-airbnb": "^5.0.2",
28 + "babel-eslint": "^10.1.0",
29 + "eslint": "^6.7.2",
30 + "eslint-plugin-import": "^2.20.2",
31 + "eslint-plugin-vue": "^6.2.2",
32 + "node-sass": "^4.12.0",
33 + "sass": "^1.19.0",
34 + "sass-loader": "^8.0.2",
35 + "vue-cli-plugin-vuetify": "~2.0.5",
36 + "vue-template-compiler": "^2.6.11",
37 + "vuetify-loader": "^1.3.0"
38 + }
39 +}
No preview for this file type
1 +<!DOCTYPE html>
2 +<html lang="en">
3 + <head>
4 + <meta charset="utf-8">
5 + <meta http-equiv="X-UA-Compatible" content="IE=edge">
6 + <meta name="viewport" content="width=device-width,initial-scale=1.0">
7 + <link rel="icon" href="<%= BASE_URL %>favicon.ico">
8 + <title>Profit-Hunter</title>
9 + </head>
10 + <body>
11 + <noscript>
12 + <strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong>
13 + </noscript>
14 + <div id="app"></div>
15 + <!-- built files will be auto injected -->
16 + </body>
17 +</html>
1 +<template>
2 + <v-app>
3 + <v-app-bar app color="#ffffff" elevation="1" hide-on-scroll>
4 + <v-icon size="35" class="mr-1" color="grey700">mdi-youtube</v-icon>
5 + <div style="color: #343a40; font-size: 20px; font-weight: 500;">Youtube Auto Tagger</div>
6 + <v-spacer></v-spacer>
7 + <v-tooltip bottom>
8 + <template v-slot:activator="{ on }">
9 + <v-btn icon v-on="on" color="grey700" @click="clickInfo=true">
10 + <v-icon size="30">mdi-information</v-icon>
11 + </v-btn>
12 + </template>
13 + <span>Service Info</span>
14 + </v-tooltip>
15 + </v-app-bar>
16 + <v-dialog v-model="clickInfo" max-width="600" class="pa-2">
17 + <v-card elevation="0" outlined class="pa-4">
18 + <v-row justify="center" class="mx-0 mb-8 mt-2">
19 + <v-icon color="primary">mdi-power-on</v-icon>
20 + <div
21 + style="text-align: center; font-size: 22px; font-weight: 400; color: #343a40;"
22 + >Information of This Service</div>
23 + <v-icon color="primary">mdi-power-on</v-icon>
24 + </v-row>
25 + <v-btn color="primary" class="mt-2" block>Description of service</v-btn>
26 + <v-btn color="primary" class="my-2" block>Term of this service</v-btn>
27 + <v-btn color="primary" block>Used Opensource</v-btn>
28 + </v-card>
29 + </v-dialog>
30 + <v-content>
31 + <router-view />
32 + </v-content>
33 + <v-footer>
34 + <v-row justify="center">
35 + <v-avatar size="25" tile style="border-radius: 4px">
36 + <v-img src="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1/uploads/99f7d5c73e506d2c5c0072a21f362181/logo.69342704.png"></v-img>
37 + </v-avatar>
38 + <a href="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1">
39 + <div
40 + style="margin-left: 4px; font-size: 16px; color: #5a5a5a; font-weight: 400"
41 + >Profit-Hunter</div>
42 + </a>
43 + </v-row>
44 + </v-footer>
45 + </v-app>
46 +</template>
47 +<script>
48 +export default {
49 + name: 'App',
50 + data() {
51 + return {
52 + clickInfo: false,
53 + };
54 + },
55 +};
56 +</script>
57 +<style lang="scss">
58 +a:hover {
59 + text-decoration: none;
60 +}
61 +a:link {
62 + text-decoration: none;
63 +}
64 +a:visited {
65 + text-decoration: none;
66 +}
67 +</style>
1 +<template>
2 + <div></div>
3 +</template>
4 +
5 +<script>
6 +export default {
7 +
8 +};
9 +</script>
10 +
11 +<style>
12 +
13 +</style>
1 +<template>
2 + <div></div>
3 +</template>
4 +
5 +<script>
6 +export default {
7 +
8 +};
9 +</script>
10 +
11 +<style>
12 +
13 +</style>
1 +<template>
2 + <div></div>
3 +</template>
4 +
5 +<script>
6 +export default {
7 +
8 +};
9 +</script>
10 +
11 +<style>
12 +
13 +</style>
1 +import Vue from 'vue';
2 +import App from './App.vue';
3 +import router from './router';
4 +import store from './store';
5 +import vuetify from './plugins/vuetify';
6 +import 'roboto-fontface/css/roboto/roboto-fontface.css';
7 +import '@mdi/font/css/materialdesignicons.css';
8 +
9 +Vue.config.productionTip = false;
10 +
11 +new Vue({
12 + router,
13 + store,
14 + vuetify,
15 + render: (h) => h(App),
16 +}).$mount('#app');
1 +import Vue from 'vue';
2 +import Vuetify from 'vuetify/lib';
3 +
4 +Vue.use(Vuetify);
5 +
6 +export default new Vuetify({
7 + theme: {
8 + themes: {
9 + light: {
10 + primary: '#343a40',
11 + secondary: '#506980',
12 + accent: '#505B80',
13 + error: '#FF5252',
14 + info: '#2196F3',
15 + blue: '#173f5f',
16 + lightblue: '#72b1e4',
17 + success: '#2779bd',
18 + warning: '#12283a',
19 + grey300: '#eceeef',
20 + grey500: '#aaaaaa',
21 + grey700: '#5a5a5a',
22 + grey900: '#212529',
23 + },
24 + dark: {
25 + primary: '#343a40',
26 + secondary: '#506980',
27 + accent: '#505B80',
28 + error: '#FF5252',
29 + info: '#2196F3',
30 + blue: '#173f5f',
31 + lightblue: '#72b1e4',
32 + success: '#2779bd',
33 + warning: '#12283a',
34 + grey300: '#eceeef',
35 + grey500: '#aaaaaa',
36 + grey700: '#5a5a5a',
37 + grey900: '#212529',
38 + },
39 + },
40 + },
41 +});
1 +import Vue from 'vue';
2 +import VueRouter from 'vue-router';
3 +import axios from 'axios';
4 +import Home from '../views/Home.vue';
5 +
6 +Vue.prototype.$axios = axios;
7 +const apiRootPath = process.env.NODE_ENV !== 'production'
8 + ? 'http://localhost:8000/api/'
9 + : '/api/';
10 +Vue.prototype.$apiRootPath = apiRootPath;
11 +axios.defaults.baseURL = apiRootPath;
12 +
13 +Vue.use(VueRouter);
14 +
15 +const routes = [
16 + {
17 + path: '/',
18 + name: 'Home',
19 + component: Home,
20 + },
21 +];
22 +
23 +const router = new VueRouter({
24 + mode: 'history',
25 + base: process.env.BASE_URL,
26 + routes,
27 +});
28 +
29 +export default router;
1 +import Vue from 'vue';
2 +import Vuex from 'vuex';
3 +
4 +Vue.use(Vuex);
5 +
6 +export default new Vuex.Store({
7 + state: {
8 + },
9 + mutations: {
10 + },
11 + actions: {
12 + },
13 + modules: {
14 + },
15 +});
1 +<template>
2 + <v-sheet>
3 + <v-overlay v-model="loadingProcess">
4 + <v-progress-circular :size="120" width="10" color="primary" indeterminate></v-progress-circular>
5 + <div
6 + style="color: #ffffff; font-size: 22px; margin-top: 20px; margin-left: -40px;"
7 + >Analyzing Your Video...</div>
8 + </v-overlay>
9 + <v-layout justify-center>
10 + <v-flex xs12 sm8 md6 lg4>
11 + <v-row justify="center" class="mx-0 mt-12">
12 + <div style="font-size: 34px; font-weight: 500; color: #343a40;">WELCOME</div>
13 + </v-row>
14 + <v-card elevation="0">
15 + <!-- 데스크톱 화면 설명 요약 -->
16 + <v-row justify="center" class="mx-0 mt-12" v-if="$vuetify.breakpoint.mdAndUp">
17 + <v-flex md7 class="mt-1">
18 + <div
19 + style="font-size: 20px; font-weight: 300; color: #888;"
20 + >This is Video auto tagging Service</div>
21 + <div
22 + style="font-size: 20px; font-weight: 300; color: #888;"
23 + >Designed for Youtube Videos</div>
24 + <div
25 + style="font-size: 20px; font-weight: 300; color: #888;"
26 + >It takes few minutes to analyze your Video!</div>
27 + </v-flex>
28 + <v-flex md5>
29 + <v-card width="240" elevation="0" class="ml-5">
30 + <v-img
31 + width="240"
32 + src="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1/uploads/b70e4a173c2b7d5fa6ab73d48582dd6e/youtubelogoBlack.326653df.png"
33 + ></v-img>
34 + </v-card>
35 + </v-flex>
36 + </v-row>
37 +
38 + <!-- 모바일 화면 설명 요약 -->
39 + <v-card elevation="0" class="mt-8" v-else>
40 + <div
41 + style="font-size: 20px; font-weight: 300; color: #888; text-align: center"
42 + >This is Video auto tagging Service</div>
43 + <div
44 + style="font-size: 20px; font-weight: 300; color: #888; text-align: center"
45 + >Designed for Youtube Videos</div>
46 + <div
47 + style="font-size: 20px; font-weight: 300; color: #888; text-align: center"
48 + >It takes few minutes to analyze your Video!</div>
49 + <v-img
50 + style="margin: auto; margin-top: 20px"
51 + width="180"
52 + src="http://khuhub.khu.ac.kr/2020-1-capstone-design1/PKH_Project1/uploads/b70e4a173c2b7d5fa6ab73d48582dd6e/youtubelogoBlack.326653df.png"
53 + ></v-img>
54 + </v-card>
55 +
56 + <!-- Set Threshold -->
57 + <div
58 + class="mt-10"
59 + style="font-size: 24px; text-align: center; font-weight: 400; color: #5a5a5a;"
60 + >How To start this service</div>
61 + <div
62 + style="font-size: 20px; font-weight: 300; color: #888; text-align: center; margin-bottom: 15px"
63 + >
64 + <div>Set up Threshold of</div>
65 + <div>Recommended Youtube link</div>
66 + </div>
67 + <v-row style="max-width: 300px; margin: auto">
68 + <v-slider v-model="threshold" :thumb-size="20" thumb-label="always" :min="2" :max="15"></v-slider>
69 + </v-row>
70 +
71 + <!-- Upload Video -->
72 + <div
73 + style="font-size: 20px; font-weight: 300; color: #888; text-align: center"
74 + >Then, Just Upload your Video</div>
75 + <v-row justify="center" class="mx-0 mt-2">
76 + <v-card
77 + max-width="500"
78 + outlined
79 + height="120"
80 + class="pa-9"
81 + @dragover.prevent
82 + @dragenter.prevent
83 + @drop.prevent="onDrop"
84 + >
85 + <v-btn
86 + style="text-transform: none"
87 + @click="clickUploadButton"
88 + text
89 + large
90 + color="primary"
91 + >CLICK or DRAG & DROP</v-btn>
92 + <input ref="fileInput" style="display: none" type="file" @change="onFileChange" />
93 + </v-card>
94 + </v-row>
95 +
96 + <!-- 결과 화면 -->
97 + <div
98 + style="font-size: 24px; text-align: center; font-weight: 400; color: #5a5a5a;"
99 + class="mt-10"
100 + >The Results of Analyzed Video</div>
101 + <v-card outlined class="pa-2 mx-5 mt-6" elevation="0" min-height="67">
102 + <div
103 + style="margin-left: 5px; margin-top: -18px; background-color: #fff; width: 110px; text-align: center;font-size: 14px; color: #5a5a5a; font-weight: 500"
104 + >Generated Tags</div>
105 + <v-chip-group column>
106 + <v-chip color="secondary" v-for="(tag, index) in generatedTag" :key="index">{{ tag[0] }} : {{tag[1]}}</v-chip>
107 + </v-chip-group>
108 + </v-card>
109 + <v-card outlined class="pa-3 mx-5 mt-8" elevation="0" min-height="67">
110 + <div
111 + style="margin-left: 5px; margin-top: -22px; margin-bottom: 5px; background-color: #fff; width: 140px; text-align: center;font-size: 14px; color: #5a5a5a; font-weight: 500"
112 + >Related Youtube Link</div>
113 + <v-flex style="margin-bottom: 2px" v-for="(url) in YoutubeUrl" :key="url.id">
114 + <div>
115 + <a style="color: #343a40; font-size: 16px; font-weight: 500" :href="url">{{url}}</a>
116 + </div>
117 + </v-flex>
118 + </v-card>
119 + <div
120 + class="mt-3"
121 + style="font-size: 20px; font-weight: 300; color: #888; text-align: center"
122 + >If the Video is analyzed successfully,</div>
123 + <div
124 + class="mb-5"
125 + style="font-size: 20px; font-weight: 300; color: #888; text-align: center"
126 + >Result Show up in each of Boxes!</div>
127 + </v-card>
128 +
129 + </v-flex>
130 + </v-layout>
131 + </v-sheet>
132 +</template>
133 +<script>
134 +export default {
135 + name: 'Home',
136 + data() {
137 + return {
138 + videoFile: '',
139 + YoutubeUrl: [],
140 + generatedTag: [],
141 + threshold: 5,
142 + successDialog: false,
143 + errorDialog: false,
144 + loadingProcess: false,
145 + };
146 + },
147 + created() {
148 + // this.YoutubeUrl = [];
149 + // this.generatedTag = [];
150 + },
151 + methods: {
152 + loadVideoInfo() {},
153 + uploadVideo(files) {
154 + this.loadingProcess = true;
155 + const formData = new FormData();
156 + formData.append('file', files[0]);
157 + formData.append('threshold', this.threshold);
158 + console.log(files[0]);
159 + this.$axios
160 + .post('/upload', formData, {
161 + headers: { 'Content-Type': 'multipart/form-data' },
162 + })
163 + .then((r) => {
164 + const tag = r.data.tag_result;
165 + const url = r.data.video_result;
166 + url.forEach((element) => {
167 + this.YoutubeUrl.push(element.video_url);
168 + });
169 + this.generatedTag = tag;
170 + this.loadingProcess = false;
171 + this.successDialog = true;
172 + console.log(tag, url);
173 + })
174 + .catch((e) => {
175 + this.errorDialog = true;
176 + console.log(e.message);
177 + });
178 + },
179 + onDrop(event) {
180 + this.uploadVideo(event.dataTransfer.files);
181 + },
182 + clickUploadButton() {
183 + this.$refs.fileInput.click();
184 + },
185 + onFileChange(event) {
186 + this.uploadVideo(event.target.files);
187 + },
188 + },
189 +};
190 +</script>
1 +module.exports = {
2 + transpileDependencies: [
3 + 'vuetify',
4 + ],
5 +};
This diff could not be displayed because it is too large.