Showing
2 changed files
with
23 additions
and
14 deletions
... | @@ -10,17 +10,15 @@ import src.video_util as videoutil | ... | @@ -10,17 +10,15 @@ import src.video_util as videoutil |
10 | import json | 10 | import json |
11 | import urllib3 | 11 | import urllib3 |
12 | 12 | ||
13 | -# Erase logs | ||
14 | logging.disable(logging.WARNING) | 13 | logging.disable(logging.WARNING) |
15 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | 14 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" |
16 | urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) | 15 | urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) |
17 | 16 | ||
18 | -# Define model paths. | 17 | +# Old model |
19 | -MODEL_PATH = "./new_model/inference_model/segment_inference_model" | 18 | +MODEL_PATH = "./model/inference_model/segment_inference_model" |
20 | -# TAG_VECTOR_MODEL_PATH = "./new_model/twitter100_tag_vectors.gz" | 19 | +TAG_VECTOR_MODEL_PATH = "./model/tag_vectors.model" |
21 | -TAG_VECTOR_MODEL_PATH = "glove-wiki-gigaword-100" | 20 | +VIDEO_VECTOR_MODEL_PATH = "./model/video_vectors.model" |
22 | -VIDEO_VECTOR_MODEL_PATH = "./new_model/gigaword100_video_vectors.model" | 21 | +VIDEO_TAGS_PATH = "./statics/kaggle_solution_40k.csv" |
23 | -VIDEO_TAGS_PATH = "./statics/new_kaggle_solution_40k.csv" | ||
24 | 22 | ||
25 | # Define static file paths. | 23 | # Define static file paths. |
26 | SEGMENT_LABEL_PATH = "./statics/segment_label_ids.csv" | 24 | SEGMENT_LABEL_PATH = "./statics/segment_label_ids.csv" |
... | @@ -28,11 +26,10 @@ VOCAB_PATH = "./statics/vocabulary.csv" | ... | @@ -28,11 +26,10 @@ VOCAB_PATH = "./statics/vocabulary.csv" |
28 | 26 | ||
29 | # Define parameters. | 27 | # Define parameters. |
30 | TAG_TOP_K = 5 | 28 | TAG_TOP_K = 5 |
31 | -VIDEO_TOP_K = 5 | 29 | +VIDEO_TOP_K = 10 |
32 | 30 | ||
33 | # Target featuremap. | 31 | # Target featuremap. |
34 | -FEATUREMAP_PATH = "./featuremaps/toy-3-features.pb" | 32 | +FEATUREMAP_PATH = "./featuremaps/concert-1-features.pb" |
35 | - | ||
36 | 33 | ||
37 | def get_segments(batch_video_mtx, batch_num_frames, segment_size): | 34 | def get_segments(batch_video_mtx, batch_num_frames, segment_size): |
38 | """Get segment-level inputs from frame-level features.""" | 35 | """Get segment-level inputs from frame-level features.""" |
... | @@ -235,5 +232,17 @@ def inference_pb(file_path, threshold): | ... | @@ -235,5 +232,17 @@ def inference_pb(file_path, threshold): |
235 | 232 | ||
236 | 233 | ||
237 | if __name__ == '__main__': | 234 | if __name__ == '__main__': |
238 | - result = inference_pb(FEATUREMAP_PATH, VIDEO_TOP_K) | 235 | + result = inference_pb(FEATUREMAP_PATH, 5) |
239 | - print(json.dumps(result, sort_keys=True, indent=2)) | 236 | + print("=============== Old Model ===============") |
237 | + print(result["tag_result"]) | ||
238 | + print(json.dumps(result["video_result"], sort_keys=True, indent=2)) | ||
239 | + | ||
240 | + # New model | ||
241 | + MODEL_PATH = "./new_model/inference_model/segment_inference_model" | ||
242 | + TAG_VECTOR_MODEL_PATH = "./new_model/tag_vectors.model" | ||
243 | + VIDEO_VECTOR_MODEL_PATH = "./new_model/video_vectors.model" | ||
244 | + VIDEO_TAGS_PATH = "./statics/new_kaggle_solution_40k.csv" | ||
245 | + result = inference_pb(FEATUREMAP_PATH, 5) | ||
246 | + print("=============== New Model ===============") | ||
247 | + print(result["tag_result"]) | ||
248 | + print(json.dumps(result["video_result"], sort_keys=True, indent=2)) | ... | ... |
... | @@ -4,9 +4,9 @@ import numpy as np | ... | @@ -4,9 +4,9 @@ import numpy as np |
4 | 4 | ||
5 | 5 | ||
6 | def recommend_videos(tags, tag_model_path, video_model_path, top_k): | 6 | def recommend_videos(tags, tag_model_path, video_model_path, top_k): |
7 | - # tag_vectors = Word2Vec.load(tag_model_path).wv | 7 | + tag_vectors = Word2Vec.load(tag_model_path).wv |
8 | # tag_vectors = KeyedVectors.load_word2vec_format(tag_model_path, binary=True) | 8 | # tag_vectors = KeyedVectors.load_word2vec_format(tag_model_path, binary=True) |
9 | - tag_vectors = api.load(tag_model_path) | 9 | + # tag_vectors = api.load(tag_model_path) |
10 | video_vectors = Word2Vec().wv.load(video_model_path) | 10 | video_vectors = Word2Vec().wv.load(video_model_path) |
11 | error_tags = [] | 11 | error_tags = [] |
12 | 12 | ... | ... |
-
Please register or login to post a comment