이현규

Import Youtube-8M pretrained model

1 +# Lint as: python3
2 +import numpy as np
3 +import tensorflow as tf
4 +from tensorflow import app
5 +from tensorflow import flags
6 +
7 +FLAGS = flags.FLAGS
8 +
9 +
10 +def main(unused_argv):
11 + # Get the input tensor names to be replaced.
12 + tf.reset_default_graph()
13 + meta_graph_location = FLAGS.checkpoint_file + ".meta"
14 + tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
15 +
16 + input_tensor_name = tf.get_collection("input_batch_raw")[0].name
17 + num_frames_tensor_name = tf.get_collection("num_frames")[0].name
18 +
19 + # Create output graph.
20 + saver = tf.train.Saver()
21 + tf.reset_default_graph()
22 +
23 + input_feature_placeholder = tf.placeholder(
24 + tf.float32, shape=(None, None, 1152))
25 + num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1))
26 +
27 + saver = tf.train.import_meta_graph(
28 + meta_graph_location,
29 + input_map={
30 + input_tensor_name: input_feature_placeholder,
31 + num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1)
32 + },
33 + clear_devices=True)
34 + predictions_tensor = tf.get_collection("predictions")[0]
35 +
36 + with tf.Session() as sess:
37 + print("restoring variables from " + FLAGS.checkpoint_file)
38 + saver.restore(sess, FLAGS.checkpoint_file)
39 + tf.saved_model.simple_save(
40 + sess,
41 + FLAGS.output_dir,
42 + inputs={'rgb_and_audio': input_feature_placeholder,
43 + 'num_frames': num_frames_placeholder},
44 + outputs={'predictions': predictions_tensor})
45 +
46 + # Try running inference.
47 + predictions = sess.run(
48 + [predictions_tensor],
49 + feed_dict={
50 + input_feature_placeholder: np.zeros((3, 7, 1152), dtype=np.float32),
51 + num_frames_placeholder: np.array([[7]], dtype=np.int32)})
52 + print('Test inference:', predictions)
53 +
54 + print('Model saved to ', FLAGS.output_dir)
55 +
56 +
57 +if __name__ == '__main__':
58 + flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.')
59 + flags.DEFINE_string('output_dir', None, 'SavedModel output directory.')
60 + app.run(main)
...\ No newline at end of file ...\ No newline at end of file
...@@ -8,7 +8,8 @@ import esot3ria.video_recommender as recommender ...@@ -8,7 +8,8 @@ import esot3ria.video_recommender as recommender
8 import esot3ria.video_util as videoutil 8 import esot3ria.video_util as videoutil
9 9
10 # Define model paths. 10 # Define model paths.
11 -MODEL_PATH = "./model/inference_model/segment_inference_model" 11 +# MODEL_PATH = "./model/inference_model/segment_inference_model"
12 +MODEL_PATH = "./pretrained_model/variables/variables"
12 TAG_VECTOR_MODEL_PATH = "./model/tag_vectors.model" 13 TAG_VECTOR_MODEL_PATH = "./model/tag_vectors.model"
13 VIDEO_VECTOR_MODEL_PATH = "./model/video_vectors.model" 14 VIDEO_VECTOR_MODEL_PATH = "./model/video_vectors.model"
14 15
......
No preview for this file type