이현규

Make inference_pb.py, complete the ML recommend module

@@ -2,7 +2,22 @@ import numpy as np
 import tensorflow as tf
 from tensorflow import logging
 from tensorflow import gfile
-import esot3ria.pbutil as pbutil
+import operator
+import esot3ria.pb_util as pbutil
+import esot3ria.video_recommender as recommender
+import esot3ria.video_util as videoutil
+
+# Define file paths.
+MODEL_PATH = "/Users/esot3ria/PycharmProjects/yt8m/models/frame/" \
+             "refined_model/inference_model/segment_inference_model"
+VOCAB_PATH = "../vocabulary.csv"
+VIDEO_TAGS_PATH = "./kaggle_solution_40k.csv"
+TAG_VECTOR_MODEL_PATH = "./tag_vectors.model"
+VIDEO_VECTOR_MODEL_PATH = "./video_vectors.model"
+
+# Define parameters.
+TAG_TOP_K = 5
+VIDEO_TOP_K = 10


 def get_segments(batch_video_mtx, batch_num_frames, segment_size):
@@ -42,7 +57,7 @@ def get_segments(batch_video_mtx, batch_num_frames, segment_size):
   }


-def format_prediction(video_ids, predictions, top_k, whitelisted_cls_mask=None):
+def format_predictions(video_ids, predictions, top_k, whitelisted_cls_mask=None):
   batch_size = len(video_ids)
   for video_index in range(batch_size):
     video_prediction = predictions[video_index]
@@ -53,15 +68,26 @@ def format_prediction(video_ids, predictions, top_k, whitelisted_cls_mask=None):
     line = [(class_index, predictions[video_index][class_index])
             for class_index in top_indices]
     line = sorted(line, key=lambda p: -p[1])
-    return (video_ids[video_index] + "," +
+    yield (video_ids[video_index] + "," +
            " ".join("%i %g" % (label, score) for (label, score) in line) +
            "\n").encode("utf8")


-def inference_pb(file_path, model_path):
+def normalize_tag(tag):
+  if isinstance(tag, str):
+    new_tag = tag.lower().replace('[^a-zA-Z]', ' ')
+    if new_tag.find(" (") != -1:
+      new_tag = new_tag[:new_tag.find(" (")]
+    new_tag = new_tag.replace(" ", "-")
+    return new_tag
+  else:
+    return tag
+
+
+def inference_pb(file_path):
+  inference_result = {}
   with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

-    # 200527 Esot3riA
     # 0. Import SequenceExample type target from pb.
     target_video = pbutil.convert_pb(file_path)

@@ -80,18 +106,17 @@ def inference_pb(file_path, model_path):
       video_batch_val[i] = np.concatenate([video_batch_rgb, video_batch_audio], axis=0)
     video_batch_val = np.array([video_batch_val])
     num_frames_batch_val = np.array([n_frames])
-    # 200527 Esot3riA End

-    # Restore checkpoint and meta-graph file
-    if not gfile.Exists(model_path + ".meta"):
-      raise IOError("Cannot find %s. Did you run eval.py?" % model_path)
-    meta_graph_location = model_path + ".meta"
+    # Restore checkpoint and meta-graph file.
+    if not gfile.Exists(MODEL_PATH + ".meta"):
+      raise IOError("Cannot find %s. Did you run eval.py?" % MODEL_PATH)
+    meta_graph_location = MODEL_PATH + ".meta"
     logging.info("loading meta-graph: " + meta_graph_location)

     with tf.device("/cpu:0"):
       saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
-    logging.info("restoring variables from " + model_path)
-    saver.restore(sess, model_path)
+    logging.info("restoring variables from " + MODEL_PATH)
+    saver.restore(sess, MODEL_PATH)
     input_tensor = tf.get_collection("input_batch_raw")[0]
     num_frames_tensor = tf.get_collection("num_frames")[0]
     predictions_tensor = tf.get_collection("predictions")[0]
@@ -109,8 +134,6 @@ def inference_pb(file_path, model_path):
     sess.run(
         set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

-    coord = tf.train.Coordinator()
-    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
     whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
                                     dtype=np.float32)
     segment_label_ids_file = '../segment_label_ids.csv'
@@ -123,7 +146,6 @@ def inference_pb(file_path, model_path):
           # Simply skip the non-integer line.
           continue

-    # 200527 Esot3riA
     # 2. Make segment features.
     results = get_segments(video_batch_val, num_frames_batch_val, 5)
     video_segment_ids = results["video_segment_ids"]
@@ -143,22 +165,59 @@ def inference_pb(file_path, model_path):
           input_tensor: video_batch_val,
           num_frames_tensor: num_frames_batch_val
         })
-    logging.info(predictions_val)
-    logging.info("profit :D")
-
-    # result = format_prediction(video_id_batch_val, predictions_val, 10, whitelisted_cls_mask)
-    # Results:
-    # 1. List of tags (5) + each tag's similarity (dict format)
-    # 2. Links to related videos => find related videos with the model, take a user-input threshold (20%~80%), and output 5 related videos + their relevance.
-

+    # 3. Make vocabularies.
+    voca_dict = {}
+    vocabs = open(VOCAB_PATH, 'r')
+    while True:
+      line = vocabs.readline()
+      if not line: break
+      vocab_dict_item = line.split(",")
+      if vocab_dict_item[0] != "Index":
+        voca_dict[vocab_dict_item[0]] = vocab_dict_item[3]
+    vocabs.close()
+
+    # 4. Make combined scores.
+    combined_scores = {}
+    for line in format_predictions(video_id_batch_val, predictions_val, TAG_TOP_K, whitelisted_cls_mask):
+      segment_id, preds = line.decode("utf8").split(",")
+      preds = preds.split(" ")
+      pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
+      pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)]
+      for i in range(len(pred_cls_ids)):
+        if pred_cls_ids[i] in combined_scores:
+          combined_scores[pred_cls_ids[i]] += pred_cls_scores[i]
+        else:
+          combined_scores[pred_cls_ids[i]] = pred_cls_scores[i]
+
+    combined_scores = sorted(combined_scores.items(), key=operator.itemgetter(1), reverse=True)
+    demoninator = float(combined_scores[0][1] + combined_scores[1][1]
+                        + combined_scores[2][1] + combined_scores[3][1] + combined_scores[4][1])
+
+    tag_result = []
+    for itemIndex in range(TAG_TOP_K):
+      segment_tag = str(voca_dict[str(combined_scores[itemIndex][0])])
+      normalized_tag = normalize_tag(segment_tag)
+      tag_percentage = format(combined_scores[itemIndex][1] / demoninator, ".3f")
+      tag_result.append((normalized_tag, tag_percentage))
+
+    # 5. Create recommended video info and combine results.
+    recommend_video_ids = recommender.recommend_videos(tag_result, TAG_VECTOR_MODEL_PATH,
+                                                       VIDEO_VECTOR_MODEL_PATH, VIDEO_TOP_K)
+    video_result = [videoutil.getVideoInfo(ids, VIDEO_TAGS_PATH, TAG_TOP_K) for ids in recommend_video_ids]
+
+    inference_result = {
+      "tag_result": tag_result,
+      "video_result": video_result
+    }
+
+    # 6. Dispose instances.
+    sess.close()
+
+    return inference_result


 if __name__ == '__main__':
-  logging.set_verbosity(tf.logging.INFO)
-
-  file_path = '/tmp/mediapipe/features.pb'
-  model_path = '/Users/esot3ria/PycharmProjects/yt8m/models/frame' \
-               '/sample_model/inference_model/segment_inference_model'
-
-  inference_pb(file_path, model_path)
+  filepath = "features.pb"
+  result = inference_pb(filepath)
+  print(result)
......
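For orientation, the dict returned by the new inference_pb() has the shape sketched below. Only the keys, the (tag, percentage) tuple layout, and the list lengths follow from the diff; the concrete tags, scores, and video id are hypothetical placeholders.

# Hypothetical shape of the value returned by inference_pb("features.pb") -- values are placeholders.
expected_result_shape = {
    "tag_result": [
        # TAG_TOP_K = 5 pairs of (normalized tag, percentage string produced by format(..., ".3f")).
        ("mobile-phone", "0.361"),
        # ...
    ],
    "video_result": [
        # VIDEO_TOP_K = 10 entries built by videoutil.getVideoInfo().
        {
            "video_url": "https://www.youtube.com/watch?v=<hypothetical-id>",
            "video_tags": ["mobile-phone"]  # top-k tags stored in kaggle_solution_40k.csv
        },
        # ...
    ]
}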
+from gensim.models import Word2Vec
+import numpy as np
+
+def recommend_videos(tags, tag_model_path, video_model_path, top_k):
+    tag_vectors = Word2Vec.load(tag_model_path).wv
+    video_vectors = Word2Vec().wv.load(video_model_path)
+    error_tags = []
+
+    video_vector = np.zeros(100)
+    for (tag, weight) in tags:
+        if tag in tag_vectors.vocab:
+            video_vector = video_vector + (tag_vectors[tag] * float(weight))
+        else:
+            # Pass if tag is unknown
+            if tag not in error_tags:
+                error_tags.append(tag)
+
+    similar_ids = [x[0] for x in video_vectors.similar_by_vector(video_vector, top_k)]
+    return similar_ids
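A minimal usage sketch of recommend_videos() on its own, assuming the Word2Vec model files referenced in inference_pb.py ("./tag_vectors.model", "./video_vectors.model") exist; the (tag, weight) pairs below are hypothetical and mirror the tag_result tuples that inference_pb() passes in.

from esot3ria.video_recommender import recommend_videos

# Hypothetical tags in the same shape as tag_result: (normalized tag, weight string).
tags = [("mobile-phone", "0.361"), ("game", "0.214")]

# Model paths match TAG_VECTOR_MODEL_PATH / VIDEO_VECTOR_MODEL_PATH in inference_pb.py.
similar_ids = recommend_videos(tags, "./tag_vectors.model", "./video_vectors.model", top_k=10)
print(similar_ids)  # list of up to 10 video ids usable with video_util.getVideoInfo()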
 import requests
+import pandas as pd

 base_URL = 'https://data.yt8m.org/2/j/i/'
 youtube_url = 'https://www.youtube.com/watch?v='

+
 def getURL(vid_id):
     URL = base_URL + vid_id[:-2] + '/' + vid_id + '.js'
     response = requests.get(URL, verify = False)
     if response.status_code == 200:
         return youtube_url + response.text[10:-3]
-

-# example usage : getURL('nXSc');
\ No newline at end of file
+
+def getVideoInfo(vid_id, video_tags_path, top_k):
+    video_url = getURL(vid_id)
+
+    entire_video_tags = pd.read_csv(video_tags_path)
+    video_tags_info = entire_video_tags.loc[entire_video_tags["vid_id"] == vid_id]
+    video_tags = []
+    for i in range(1, top_k + 1):
+        video_tag_tuple = video_tags_info["segment" + str(i)].values[0]  # ex: "mobile-phone:0.361"
+        video_tags.append(video_tag_tuple.split(":")[0])
+
+    return {
+        "video_url": video_url,
+        "video_tags": video_tags
+    }
......
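Likewise, a small sketch of using the video_util helpers directly, assuming kaggle_solution_40k.csv contains the vid_id and segment1..segmentN columns that getVideoInfo() reads and that the id is present in it; 'nXSc' comes from the example in the removed comment.

from esot3ria.video_util import getURL, getVideoInfo

# Resolve a YT8M pseudo id to a YouTube watch URL (needs network access to data.yt8m.org).
print(getURL('nXSc'))

# Fetch the URL plus the top 5 stored tags for the same id from the tag CSV.
info = getVideoInfo('nXSc', "./kaggle_solution_40k.csv", top_k=5)
print(info["video_url"], info["video_tags"])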