이현규

Create model generator

@@ -58,12 +58,12 @@ def format_prediction(video_ids, predictions, top_k, whitelisted_cls_mask=None):
            "\n").encode("utf8")


-def inference_pb(filename):
+def inference_pb(file_path, model_path):
   with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

     # 200527 Esot3riA
     # 0. Import SequenceExample type target from pb.
-    target_video = pbutil.convert_pb(filename)
+    target_video = pbutil.convert_pb(file_path)

     # 1. Load video features from pb.
     video_id_batch_val = np.array([b'video'])
@@ -83,18 +83,15 @@ def inference_pb(filename):
     # 200527 Esot3riA End

     # Restore checkpoint and meta-graph file
-    checkpoint_file = '/Users/esot3ria/PycharmProjects/yt8m/models/frame' \
-                      '/sample_model/inference_model/segment_inference_model'
-    if not gfile.Exists(checkpoint_file + ".meta"):
-      raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
-    meta_graph_location = checkpoint_file + ".meta"
+    if not gfile.Exists(model_path + ".meta"):
+      raise IOError("Cannot find %s. Did you run eval.py?" % model_path)
+    meta_graph_location = model_path + ".meta"
     logging.info("loading meta-graph: " + meta_graph_location)

     with tf.device("/cpu:0"):
-      saver = tf.train.import_meta_graph(meta_graph_location,
-                                         clear_devices=True)
-      logging.info("restoring variables from " + checkpoint_file)
-      saver.restore(sess, checkpoint_file)
+      saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
+      logging.info("restoring variables from " + model_path)
+      saver.restore(sess, model_path)
     input_tensor = tf.get_collection("input_batch_raw")[0]
     num_frames_tensor = tf.get_collection("num_frames")[0]
     predictions_tensor = tf.get_collection("predictions")[0]
@@ -150,10 +147,18 @@ def inference_pb(filename):
     logging.info("profit :D")

     # result = format_prediction(video_id_batch_val, predictions_val, 10, whitelisted_cls_mask)
+    # Planned output (a sketch follows after this diff):
+    # 1. List of 5 tags + the similarity score of each tag (dict format)
+    # 2. Links to related videos => find related videos with the model, take a user-input threshold (20%~80%), and print the top 5 related videos with their relevance scores.
+
+
 
 
 if __name__ == '__main__':
   logging.set_verbosity(tf.logging.INFO)

-  filename = 'features.pb'
-  inference_pb(filename)
+  file_path = 'features.pb'
+  model_path = '/Users/esot3ria/PycharmProjects/yt8m/models/frame' \
+               '/sample_model/inference_model/segment_inference_model'
+
+  inference_pb(file_path, model_path)
...
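The new comments above only describe the planned result shape. Below is a minimal sketch of that shape in Python; the function names (format_top_tags, filter_related_videos), the candidate list, and the 0.2-0.8 range check are assumptions for illustration and are not part of this commit.

# Hypothetical helpers sketching the result format described in the diff comments.
import numpy as np

def format_top_tags(predictions, vocab_names, top_k=5):
    # predictions: 1-D array of per-class scores for one video.
    # Returns a dict of the top-k tag names mapped to their similarity scores.
    top_indices = np.argsort(predictions)[::-1][:top_k]
    return {vocab_names[i]: float(predictions[i]) for i in top_indices}

def filter_related_videos(candidates, threshold, top_k=5):
    # candidates: list of (video_link, relevance) pairs computed elsewhere.
    # threshold: user input, expected to lie between 20% and 80%.
    if not 0.2 <= threshold <= 0.8:
        raise ValueError("threshold must be between 0.2 and 0.8")
    kept = [(link, rel) for link, rel in candidates if rel >= threshold]
    kept.sort(key=lambda pair: pair[1], reverse=True)
    return kept[:top_k]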
This diff could not be displayed because it is too large.
1 +import nltk
2 +import gensim
3 +import pandas as pd
4 +
5 +# Load files.
6 +nltk.download('stopwords')
7 +vocab = pd.read_csv('vocabulary.csv')
8 +
9 +# Lowercase the corpus and names, strip non-alphabetic characters, and drop the " (...)" suffix from names.
10 +vocab['WikiDescription'] = vocab['WikiDescription'].str.lower().str.replace('[^a-zA-Z]', ' ')
11 +vocab['Name'] = vocab['Name'].str.lower()
12 +for i in range(len(vocab['Name'])):
13 + name = vocab['Name'][i]
14 + if isinstance(name, str) and name.find(" (") != -1:
15 + vocab['Name'][i] = name[:name.find(" (")]
16 +
17 +# Combine multi-word names in the descriptions (e.g. mobile phone -> mobile-phone).
18 +for name in vocab['Name']:
19 + if isinstance(name, str) and name.find(" ") != -1:
20 + combined_name = name.replace(" ", "-")
21 +        for i in range(len(vocab['WikiDescription'])):
22 + if isinstance(vocab['WikiDescription'][i], str):
23 + vocab['WikiDescription'][i] = vocab['WikiDescription'][i].replace(name, combined_name)
24 +
25 +
26 +# Remove stopwords from corpus.
27 +stop_re = '\\b'+'\\b|\\b'.join(nltk.corpus.stopwords.words('english'))+'\\b'
28 +vocab['WikiDescription'] = vocab['WikiDescription'].str.replace(stop_re, '')
29 +vocab['WikiDescription'] = vocab['WikiDescription'].str.split()
30 +
31 +# Collect token lists (dropping NaN descriptions) and learn phrase bigrams.
32 +tokenlist = [x for x in vocab['WikiDescription'] if str(x) != 'nan']
33 +phrases = gensim.models.phrases.Phrases(tokenlist)
34 +phraser = gensim.models.phrases.Phraser(phrases)
35 +vocab_phrased = phraser[tokenlist]
36 +
37 +# Vectorize tags.
38 +w2v = gensim.models.word2vec.Word2Vec(sentences=tokenlist, workers=2, min_count=1)
39 +w2v.save('tags_word2vec.model')
40 +
41 +word_vectors = w2v.wv
42 +vocabs = word_vectors.vocab.keys()
43 +word_vectors_list = [word_vectors[v] for v in vocabs]
\ No newline at end of file
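A short usage sketch for the embeddings saved above, assuming gensim 3.x (the same API the script uses); the query word 'guitar' is only an example and must exist in the trained vocabulary.

# Load the saved tag embeddings and look up the most similar tags.
import gensim

w2v = gensim.models.word2vec.Word2Vec.load('tags_word2vec.model')
query = 'guitar'  # example tag; replace with any word in w2v.wv.vocab
if query in w2v.wv.vocab:
    for tag, similarity in w2v.wv.most_similar(query, topn=5):
        print(tag, similarity)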