Yoonjunhyeon

model과 백엔드 연결

...@@ -17,6 +17,9 @@ import subprocess ...@@ -17,6 +17,9 @@ import subprocess
17 import shlex 17 import shlex
18 import json 18 import json
19 # Create your views here. 19 # Create your views here.
20 +import sys
21 +sys.path.insert(0, "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria")
22 +import inference_pb
20 23
21 def with_ffprobe(filename): 24 def with_ffprobe(filename):
22 25
...@@ -75,11 +78,14 @@ class VideoFileUploadView(APIView): ...@@ -75,11 +78,14 @@ class VideoFileUploadView(APIView):
75 # 동영상 길이 출력 78 # 동영상 길이 출력
76 runTime = with_ffprobe('/'+file_serializer.data['file_save_name']) 79 runTime = with_ffprobe('/'+file_serializer.data['file_save_name'])
77 print(runTime) 80 print(runTime)
78 - 81 + print(threshold)
79 process = subprocess.Popen(['./runMediaPipe.sh %s %s' %(file_serializer.data['file_save_name'],runTime,)], shell = True) 82 process = subprocess.Popen(['./runMediaPipe.sh %s %s' %(file_serializer.data['file_save_name'],runTime,)], shell = True)
80 process.wait() 83 process.wait()
81 84
82 - return Response(True, status=status.HTTP_201_CREATED) 85 +
86 + result = inference_pb.inference_pb('/tmp/mediapipe/features.pb', threshold)
87 +
88 + return Response(result, status=status.HTTP_201_CREATED)
83 else: 89 else:
84 return Response(file_serializer.errors, status=status.HTTP_400_BAD_REQUEST) 90 return Response(file_serializer.errors, status=status.HTTP_400_BAD_REQUEST)
85 91
......
...@@ -3,16 +3,17 @@ import tensorflow as tf ...@@ -3,16 +3,17 @@ import tensorflow as tf
3 from tensorflow import logging 3 from tensorflow import logging
4 from tensorflow import gfile 4 from tensorflow import gfile
5 import operator 5 import operator
6 -import esot3ria.pb_util as pbutil 6 +import pb_util as pbutil
7 -import esot3ria.video_recommender as recommender 7 +import video_recommender as recommender
8 -import esot3ria.video_util as videoutil 8 +import video_util as videoutil
9 9
10 # Define file paths. 10 # Define file paths.
11 -MODEL_PATH = "/Users/esot3ria/PycharmProjects/yt8m/models/frame/refined_model/inference_model/segment_inference_model" 11 +MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/model/inference_model/segment_inference_model"
12 -VOCAB_PATH = "../vocabulary.csv" 12 +VOCAB_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/vocabulary.csv"
13 -VIDEO_TAGS_PATH = "./kaggle_solution_40k.csv" 13 +VIDEO_TAGS_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/kaggle_solution_40k.csv"
14 -TAG_VECTOR_MODEL_PATH = "./tag_vectors.model" 14 +TAG_VECTOR_MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/tag_vectors.model"
15 -VIDEO_VECTOR_MODEL_PATH = "./video_vectors.model" 15 +VIDEO_VECTOR_MODEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/esot3ria/video_vectors.model"
16 +SEGMENT_LABEL_PATH = "/home/jun/documents/univ/PKH_Project1/web/backend/yt8m/segment_label_ids.csv"
16 17
17 # Define parameters. 18 # Define parameters.
18 TAG_TOP_K = 5 19 TAG_TOP_K = 5
...@@ -83,7 +84,8 @@ def normalize_tag(tag): ...@@ -83,7 +84,8 @@ def normalize_tag(tag):
83 return tag 84 return tag
84 85
85 86
86 -def inference_pb(file_path): 87 +def inference_pb(file_path, threshold):
88 + VIDEO_TOP_K = int(threshold)
87 inference_result = {} 89 inference_result = {}
88 with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 90 with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
89 91
...@@ -135,8 +137,7 @@ def inference_pb(file_path): ...@@ -135,8 +137,7 @@ def inference_pb(file_path):
135 137
136 whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), 138 whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
137 dtype=np.float32) 139 dtype=np.float32)
138 - segment_label_ids_file = '../segment_label_ids.csv' 140 + with tf.io.gfile.GFile(SEGMENT_LABEL_PATH) as fobj:
139 - with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
140 for line in fobj: 141 for line in fobj:
141 try: 142 try:
142 cls_id = int(line) 143 cls_id = int(line)
...@@ -217,6 +218,6 @@ def inference_pb(file_path): ...@@ -217,6 +218,6 @@ def inference_pb(file_path):
217 218
218 219
219 if __name__ == '__main__': 220 if __name__ == '__main__':
220 - filepath = "../featuremaps/features(yorusika).pb" 221 + filepath = "/tmp/mediapipe/features.pb"
221 result = inference_pb(filepath) 222 result = inference_pb(filepath)
222 print(result) 223 print(result)
......
1 +model_checkpoint_path: "/root/volume/youtube-8m/saved_model/inference_model/segment_inference_model"
2 +all_model_checkpoint_paths: "/root/volume/youtube-8m/saved_model/inference_model/segment_inference_model"
1 +{"model": "FrameLevelLogisticModel", "feature_sizes": "1024,128", "feature_names": "rgb,audio", "frame_features": true, "label_loss": "CrossEntropyLoss"}
...\ No newline at end of file ...\ No newline at end of file