Showing
1 changed file
with
24 additions
and
10 deletions
1 | +import logging | ||
2 | +import os | ||
1 | import numpy as np | 3 | import numpy as np |
2 | import tensorflow as tf | 4 | import tensorflow as tf |
3 | -from tensorflow import logging | ||
4 | from tensorflow import gfile | 5 | from tensorflow import gfile |
5 | import operator | 6 | import operator |
6 | import src.pb_util as pbutil | 7 | import src.pb_util as pbutil |
7 | import src.video_recommender as recommender | 8 | import src.video_recommender as recommender |
8 | import src.video_util as videoutil | 9 | import src.video_util as videoutil |
10 | +import json | ||
11 | +import urllib3 | ||
12 | + | ||
13 | +logging.disable(logging.WARNING) | ||
14 | +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | ||
15 | +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) | ||
9 | 16 | ||
10 | # Old model | 17 | # Old model |
11 | MODEL_PATH = "./model/inference_model/segment_inference_model" | 18 | MODEL_PATH = "./model/inference_model/segment_inference_model" |
... | @@ -13,12 +20,6 @@ TAG_VECTOR_MODEL_PATH = "./model/tag_vectors.model" | ... | @@ -13,12 +20,6 @@ TAG_VECTOR_MODEL_PATH = "./model/tag_vectors.model" |
13 | VIDEO_VECTOR_MODEL_PATH = "./model/video_vectors.model" | 20 | VIDEO_VECTOR_MODEL_PATH = "./model/video_vectors.model" |
14 | VIDEO_TAGS_PATH = "./statics/kaggle_solution_40k.csv" | 21 | VIDEO_TAGS_PATH = "./statics/kaggle_solution_40k.csv" |
15 | 22 | ||
16 | -# New model | ||
17 | -# MODEL_PATH = "./new_model/inference_model/segment_inference_model" | ||
18 | -# TAG_VECTOR_MODEL_PATH = "./new_model/tag_vectors.model" | ||
19 | -# VIDEO_VECTOR_MODEL_PATH = "./new_model/video_vectors.model" | ||
20 | -# VIDEO_TAGS_PATH = "./statics/new_kaggle_solution_40k.csv" | ||
21 | - | ||
22 | # Define static file paths. | 23 | # Define static file paths. |
23 | SEGMENT_LABEL_PATH = "./statics/segment_label_ids.csv" | 24 | SEGMENT_LABEL_PATH = "./statics/segment_label_ids.csv" |
24 | VOCAB_PATH = "./statics/vocabulary.csv" | 25 | VOCAB_PATH = "./statics/vocabulary.csv" |
... | @@ -27,7 +28,6 @@ VOCAB_PATH = "./statics/vocabulary.csv" | ... | @@ -27,7 +28,6 @@ VOCAB_PATH = "./statics/vocabulary.csv" |
27 | TAG_TOP_K = 5 | 28 | TAG_TOP_K = 5 |
28 | VIDEO_TOP_K = 10 | 29 | VIDEO_TOP_K = 10 |
29 | 30 | ||
30 | - | ||
31 | def get_segments(batch_video_mtx, batch_num_frames, segment_size): | 31 | def get_segments(batch_video_mtx, batch_num_frames, segment_size): |
32 | """Get segment-level inputs from frame-level features.""" | 32 | """Get segment-level inputs from frame-level features.""" |
33 | video_batch_size = batch_video_mtx.shape[0] | 33 | video_batch_size = batch_video_mtx.shape[0] |
... | @@ -95,7 +95,9 @@ def normalize_tag(tag): | ... | @@ -95,7 +95,9 @@ def normalize_tag(tag): |
95 | def inference_pb(file_path, threshold): | 95 | def inference_pb(file_path, threshold): |
96 | VIDEO_TOP_K = int(threshold) | 96 | VIDEO_TOP_K = int(threshold) |
97 | inference_result = {} | 97 | inference_result = {} |
98 | - with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: | 98 | + |
99 | + graph = tf.Graph() | ||
100 | + with tf.Session(graph=graph, config=tf.ConfigProto(allow_soft_placement=True)) as sess: | ||
99 | 101 | ||
100 | # 0. Import SequenceExample type target from pb. | 102 | # 0. Import SequenceExample type target from pb. |
101 | target_video = pbutil.convert_pb(file_path) | 103 | target_video = pbutil.convert_pb(file_path) |
... | @@ -222,10 +224,22 @@ def inference_pb(file_path, threshold): | ... | @@ -222,10 +224,22 @@ def inference_pb(file_path, threshold): |
222 | # 6. Dispose instances. | 224 | # 6. Dispose instances. |
223 | sess.close() | 225 | sess.close() |
224 | 226 | ||
227 | + tf.reset_default_graph() | ||
225 | return inference_result | 228 | return inference_result |
226 | 229 | ||
227 | 230 | ||
228 | if __name__ == '__main__': | 231 | if __name__ == '__main__': |
229 | filepath = "./featuremaps/features.pb" | 232 | filepath = "./featuremaps/features.pb" |
230 | result = inference_pb(filepath, 5) | 233 | result = inference_pb(filepath, 5) |
231 | - print(result) | 234 | + print("=============== Old Model ===============") |
235 | + print(result["tag_result"]) | ||
236 | + print(json.dumps(result["video_result"], sort_keys=True, indent=2)) | ||
237 | + | ||
238 | + # New model | ||
239 | + MODEL_PATH = "./new_model/inference_model/segment_inference_model" | ||
240 | + TAG_VECTOR_MODEL_PATH = "./new_model/tag_vectors.model" | ||
241 | + VIDEO_VECTOR_MODEL_PATH = "./new_model/video_vectors.model" | ||
242 | + VIDEO_TAGS_PATH = "./statics/new_kaggle_solution_40k.csv" | ||
243 | + result = inference_pb(filepath, 5) | ||
244 | + print("=============== New Model ===============") | ||
245 | + print(json.dumps(result, sort_keys=True, indent=2)) | ... | ... |
-
Please register or login to post a comment