Showing
5 changed files
with
125 additions
and
32 deletions
No preview for this file type
... | @@ -2,7 +2,22 @@ import numpy as np | ... | @@ -2,7 +2,22 @@ import numpy as np |
2 | import tensorflow as tf | 2 | import tensorflow as tf |
3 | from tensorflow import logging | 3 | from tensorflow import logging |
4 | from tensorflow import gfile | 4 | from tensorflow import gfile |
5 | -import esot3ria.pbutil as pbutil | 5 | +import operator |
6 | +import esot3ria.pb_util as pbutil | ||
7 | +import esot3ria.video_recommender as recommender | ||
8 | +import esot3ria.video_util as videoutil | ||
9 | + | ||
10 | +# Define file paths. | ||
11 | +MODEL_PATH = "/Users/esot3ria/PycharmProjects/yt8m/models/frame/" \ | ||
12 | + "refined_model/inference_model/segment_inference_model" | ||
13 | +VOCAB_PATH = "../vocabulary.csv" | ||
14 | +VIDEO_TAGS_PATH = "./kaggle_solution_40k.csv" | ||
15 | +TAG_VECTOR_MODEL_PATH = "./tag_vectors.model" | ||
16 | +VIDEO_VECTOR_MODEL_PATH = "./video_vectors.model" | ||
17 | + | ||
18 | +# Define parameters. | ||
19 | +TAG_TOP_K = 5 | ||
20 | +VIDEO_TOP_K = 10 | ||
6 | 21 | ||
7 | 22 | ||
8 | def get_segments(batch_video_mtx, batch_num_frames, segment_size): | 23 | def get_segments(batch_video_mtx, batch_num_frames, segment_size): |
... | @@ -42,7 +57,7 @@ def get_segments(batch_video_mtx, batch_num_frames, segment_size): | ... | @@ -42,7 +57,7 @@ def get_segments(batch_video_mtx, batch_num_frames, segment_size): |
42 | } | 57 | } |
43 | 58 | ||
44 | 59 | ||
45 | -def format_prediction(video_ids, predictions, top_k, whitelisted_cls_mask=None): | 60 | +def format_predictions(video_ids, predictions, top_k, whitelisted_cls_mask=None): |
46 | batch_size = len(video_ids) | 61 | batch_size = len(video_ids) |
47 | for video_index in range(batch_size): | 62 | for video_index in range(batch_size): |
48 | video_prediction = predictions[video_index] | 63 | video_prediction = predictions[video_index] |
... | @@ -53,15 +68,26 @@ def format_prediction(video_ids, predictions, top_k, whitelisted_cls_mask=None): | ... | @@ -53,15 +68,26 @@ def format_prediction(video_ids, predictions, top_k, whitelisted_cls_mask=None): |
53 | line = [(class_index, predictions[video_index][class_index]) | 68 | line = [(class_index, predictions[video_index][class_index]) |
54 | for class_index in top_indices] | 69 | for class_index in top_indices] |
55 | line = sorted(line, key=lambda p: -p[1]) | 70 | line = sorted(line, key=lambda p: -p[1]) |
56 | - return (video_ids[video_index] + "," + | 71 | + yield (video_ids[video_index] + "," + |
57 | " ".join("%i %g" % (label, score) for (label, score) in line) + | 72 | " ".join("%i %g" % (label, score) for (label, score) in line) + |
58 | "\n").encode("utf8") | 73 | "\n").encode("utf8") |
59 | 74 | ||
60 | 75 | ||
61 | -def inference_pb(file_path, model_path): | 76 | +def normalize_tag(tag): |
77 | + if isinstance(tag, str): | ||
78 | + new_tag = tag.lower().replace('[^a-zA-Z]', ' ') | ||
79 | + if new_tag.find(" (") != -1: | ||
80 | + new_tag = new_tag[:new_tag.find(" (")] | ||
81 | + new_tag = new_tag.replace(" ", "-") | ||
82 | + return new_tag | ||
83 | + else: | ||
84 | + return tag | ||
85 | + | ||
86 | + | ||
87 | +def inference_pb(file_path): | ||
88 | + inference_result = {} | ||
62 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: | 89 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: |
63 | 90 | ||
64 | - # 200527 Esot3riA | ||
65 | # 0. Import SequenceExample type target from pb. | 91 | # 0. Import SequenceExample type target from pb. |
66 | target_video = pbutil.convert_pb(file_path) | 92 | target_video = pbutil.convert_pb(file_path) |
67 | 93 | ||
... | @@ -80,18 +106,17 @@ def inference_pb(file_path, model_path): | ... | @@ -80,18 +106,17 @@ def inference_pb(file_path, model_path): |
80 | video_batch_val[i] = np.concatenate([video_batch_rgb, video_batch_audio], axis=0) | 106 | video_batch_val[i] = np.concatenate([video_batch_rgb, video_batch_audio], axis=0) |
81 | video_batch_val = np.array([video_batch_val]) | 107 | video_batch_val = np.array([video_batch_val]) |
82 | num_frames_batch_val = np.array([n_frames]) | 108 | num_frames_batch_val = np.array([n_frames]) |
83 | - # 200527 Esot3riA End | ||
84 | 109 | ||
85 | - # Restore checkpoint and meta-graph file | 110 | + # Restore checkpoint and meta-graph file. |
86 | - if not gfile.Exists(model_path + ".meta"): | 111 | + if not gfile.Exists(MODEL_PATH + ".meta"): |
87 | - raise IOError("Cannot find %s. Did you run eval.py?" % model_path) | 112 | + raise IOError("Cannot find %s. Did you run eval.py?" % MODEL_PATH) |
88 | - meta_graph_location = model_path + ".meta" | 113 | + meta_graph_location = MODEL_PATH + ".meta" |
89 | logging.info("loading meta-graph: " + meta_graph_location) | 114 | logging.info("loading meta-graph: " + meta_graph_location) |
90 | 115 | ||
91 | with tf.device("/cpu:0"): | 116 | with tf.device("/cpu:0"): |
92 | saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True) | 117 | saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True) |
93 | - logging.info("restoring variables from " + model_path) | 118 | + logging.info("restoring variables from " + MODEL_PATH) |
94 | - saver.restore(sess, model_path) | 119 | + saver.restore(sess, MODEL_PATH) |
95 | input_tensor = tf.get_collection("input_batch_raw")[0] | 120 | input_tensor = tf.get_collection("input_batch_raw")[0] |
96 | num_frames_tensor = tf.get_collection("num_frames")[0] | 121 | num_frames_tensor = tf.get_collection("num_frames")[0] |
97 | predictions_tensor = tf.get_collection("predictions")[0] | 122 | predictions_tensor = tf.get_collection("predictions")[0] |
... | @@ -109,8 +134,6 @@ def inference_pb(file_path, model_path): | ... | @@ -109,8 +134,6 @@ def inference_pb(file_path, model_path): |
109 | sess.run( | 134 | sess.run( |
110 | set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) | 135 | set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) |
111 | 136 | ||
112 | - coord = tf.train.Coordinator() | ||
113 | - threads = tf.train.start_queue_runners(sess=sess, coord=coord) | ||
114 | whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | 137 | whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), |
115 | dtype=np.float32) | 138 | dtype=np.float32) |
116 | segment_label_ids_file = '../segment_label_ids.csv' | 139 | segment_label_ids_file = '../segment_label_ids.csv' |
... | @@ -123,7 +146,6 @@ def inference_pb(file_path, model_path): | ... | @@ -123,7 +146,6 @@ def inference_pb(file_path, model_path): |
123 | # Simply skip the non-integer line. | 146 | # Simply skip the non-integer line. |
124 | continue | 147 | continue |
125 | 148 | ||
126 | - # 200527 Esot3riA | ||
127 | # 2. Make segment features. | 149 | # 2. Make segment features. |
128 | results = get_segments(video_batch_val, num_frames_batch_val, 5) | 150 | results = get_segments(video_batch_val, num_frames_batch_val, 5) |
129 | video_segment_ids = results["video_segment_ids"] | 151 | video_segment_ids = results["video_segment_ids"] |
... | @@ -143,22 +165,59 @@ def inference_pb(file_path, model_path): | ... | @@ -143,22 +165,59 @@ def inference_pb(file_path, model_path): |
143 | input_tensor: video_batch_val, | 165 | input_tensor: video_batch_val, |
144 | num_frames_tensor: num_frames_batch_val | 166 | num_frames_tensor: num_frames_batch_val |
145 | }) | 167 | }) |
146 | - logging.info(predictions_val) | ||
147 | - logging.info("profit :D") | ||
148 | - | ||
149 | - # result = format_prediction(video_id_batch_val, predictions_val, 10, whitelisted_cls_mask) | ||
150 | - # 결과값 | ||
151 | - # 1. Tag 목록들(5개) + 각 Tag의 유사도(dict format) | ||
152 | - # 2. 연관된 영상들의 링크 => 모델에서 연관영상 찾아서, 유저 인풋(Threshold) 받아서 (20%~80%) 연관영상 + 연관도 5개 출력. | ||
153 | - | ||
154 | 168 | ||
169 | + # 3. Make vocabularies. | ||
170 | + voca_dict = {} | ||
171 | + vocabs = open(VOCAB_PATH, 'r') | ||
172 | + while True: | ||
173 | + line = vocabs.readline() | ||
174 | + if not line: break | ||
175 | + vocab_dict_item = line.split(",") | ||
176 | + if vocab_dict_item[0] != "Index": | ||
177 | + voca_dict[vocab_dict_item[0]] = vocab_dict_item[3] | ||
178 | + vocabs.close() | ||
179 | + | ||
180 | + # 4. Make combined scores. | ||
181 | + combined_scores = {} | ||
182 | + for line in format_predictions(video_id_batch_val, predictions_val, TAG_TOP_K, whitelisted_cls_mask): | ||
183 | + segment_id, preds = line.decode("utf8").split(",") | ||
184 | + preds = preds.split(" ") | ||
185 | + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] | ||
186 | + pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)] | ||
187 | + for i in range(len(pred_cls_ids)): | ||
188 | + if pred_cls_ids[i] in combined_scores: | ||
189 | + combined_scores[pred_cls_ids[i]] += pred_cls_scores[i] | ||
190 | + else: | ||
191 | + combined_scores[pred_cls_ids[i]] = pred_cls_scores[i] | ||
192 | + | ||
193 | + combined_scores = sorted(combined_scores.items(), key=operator.itemgetter(1), reverse=True) | ||
194 | + demoninator = float(combined_scores[0][1] + combined_scores[1][1] | ||
195 | + + combined_scores[2][1] + combined_scores[3][1] + combined_scores[4][1]) | ||
196 | + | ||
197 | + tag_result = [] | ||
198 | + for itemIndex in range(TAG_TOP_K): | ||
199 | + segment_tag = str(voca_dict[str(combined_scores[itemIndex][0])]) | ||
200 | + normalized_tag = normalize_tag(segment_tag) | ||
201 | + tag_percentage = format(combined_scores[itemIndex][1] / demoninator, ".3f") | ||
202 | + tag_result.append((normalized_tag, tag_percentage)) | ||
203 | + | ||
204 | + # 5. Create recommend videos info, Combine results. | ||
205 | + recommend_video_ids = recommender.recommend_videos(tag_result, TAG_VECTOR_MODEL_PATH, | ||
206 | + VIDEO_VECTOR_MODEL_PATH, VIDEO_TOP_K) | ||
207 | + video_result = [videoutil.getVideoInfo(ids, VIDEO_TAGS_PATH, TAG_TOP_K) for ids in recommend_video_ids] | ||
208 | + | ||
209 | + inference_result = { | ||
210 | + "tag_result": tag_result, | ||
211 | + "video_result": video_result | ||
212 | + } | ||
213 | + | ||
214 | + # 6. Dispose instances. | ||
215 | + sess.close() | ||
216 | + | ||
217 | + return inference_result | ||
155 | 218 | ||
156 | 219 | ||
157 | if __name__ == '__main__': | 220 | if __name__ == '__main__': |
158 | - logging.set_verbosity(tf.logging.INFO) | 221 | + filepath = "features.pb" |
159 | - | 222 | + result = inference_pb(filepath) |
160 | - file_path = '/tmp/mediapipe/features.pb' | 223 | + print(result) |
161 | - model_path = '/Users/esot3ria/PycharmProjects/yt8m/models/frame' \ | ||
162 | - '/sample_model/inference_model/segment_inference_model' | ||
163 | - | ||
164 | - inference_pb(file_path, model_path) | ... | ... |
1 | +from gensim.models import Word2Vec | ||
2 | +import numpy as np | ||
3 | + | ||
4 | +def recommend_videos(tags, tag_model_path, video_model_path, top_k): | ||
5 | + tag_vectors = Word2Vec.load(tag_model_path).wv | ||
6 | + video_vectors = Word2Vec().wv.load(video_model_path) | ||
7 | + error_tags = [] | ||
8 | + | ||
9 | + video_vector = np.zeros(100) | ||
10 | + for (tag, weight) in tags: | ||
11 | + if tag in tag_vectors.vocab: | ||
12 | + video_vector = video_vector + (tag_vectors[tag] * float(weight)) | ||
13 | + else: | ||
14 | + # Pass if tag is unknown | ||
15 | + if tag not in error_tags: | ||
16 | + error_tags.append(tag) | ||
17 | + | ||
18 | + similar_ids = [x[0] for x in video_vectors.similar_by_vector(video_vector, top_k)] | ||
19 | + return similar_ids |
1 | import requests | 1 | import requests |
2 | +import pandas as pd | ||
2 | 3 | ||
3 | base_URL = 'https://data.yt8m.org/2/j/i/' | 4 | base_URL = 'https://data.yt8m.org/2/j/i/' |
4 | youtube_url = 'https://www.youtube.com/watch?v=' | 5 | youtube_url = 'https://www.youtube.com/watch?v=' |
5 | 6 | ||
7 | + | ||
6 | def getURL(vid_id): | 8 | def getURL(vid_id): |
7 | URL = base_URL + vid_id[:-2] + '/' + vid_id + '.js' | 9 | URL = base_URL + vid_id[:-2] + '/' + vid_id + '.js' |
8 | response = requests.get(URL, verify = False) | 10 | response = requests.get(URL, verify = False) |
9 | if response.status_code == 200: | 11 | if response.status_code == 200: |
10 | return youtube_url + response.text[10:-3] | 12 | return youtube_url + response.text[10:-3] |
11 | - | ||
12 | 13 | ||
13 | -# example usage : getURL('nXSc'); | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
14 | + | ||
15 | +def getVideoInfo(vid_id, video_tags_path, top_k): | ||
16 | + video_url = getURL(vid_id) | ||
17 | + | ||
18 | + entire_video_tags = pd.read_csv(video_tags_path) | ||
19 | + video_tags_info = entire_video_tags.loc[entire_video_tags["vid_id"] == vid_id] | ||
20 | + video_tags = [] | ||
21 | + for i in range(1, top_k + 1): | ||
22 | + video_tag_tuple = video_tags_info["segment" + str(i)].values[0] # ex: "mobile-phone:0.361" | ||
23 | + video_tags.append(video_tag_tuple.split(":")[0]) | ||
24 | + | ||
25 | + return { | ||
26 | + "video_url": video_url, | ||
27 | + "video_tags": video_tags | ||
28 | + } | ... | ... |
-
Please register or login to post a comment