Showing
5 changed files
with
62 additions
and
1 deletions
esot3ria/export_model_mediapipe.py
0 → 100644
1 | +# Lint as: python3 | ||
2 | +import numpy as np | ||
3 | +import tensorflow as tf | ||
4 | +from tensorflow import app | ||
5 | +from tensorflow import flags | ||
6 | + | ||
7 | +FLAGS = flags.FLAGS | ||
8 | + | ||
9 | + | ||
10 | +def main(unused_argv): | ||
11 | + # Get the input tensor names to be replaced. | ||
12 | + tf.reset_default_graph() | ||
13 | + meta_graph_location = FLAGS.checkpoint_file + ".meta" | ||
14 | + tf.train.import_meta_graph(meta_graph_location, clear_devices=True) | ||
15 | + | ||
16 | + input_tensor_name = tf.get_collection("input_batch_raw")[0].name | ||
17 | + num_frames_tensor_name = tf.get_collection("num_frames")[0].name | ||
18 | + | ||
19 | + # Create output graph. | ||
20 | + saver = tf.train.Saver() | ||
21 | + tf.reset_default_graph() | ||
22 | + | ||
23 | + input_feature_placeholder = tf.placeholder( | ||
24 | + tf.float32, shape=(None, None, 1152)) | ||
25 | + num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1)) | ||
26 | + | ||
27 | + saver = tf.train.import_meta_graph( | ||
28 | + meta_graph_location, | ||
29 | + input_map={ | ||
30 | + input_tensor_name: input_feature_placeholder, | ||
31 | + num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1) | ||
32 | + }, | ||
33 | + clear_devices=True) | ||
34 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
35 | + | ||
36 | + with tf.Session() as sess: | ||
37 | + print("restoring variables from " + FLAGS.checkpoint_file) | ||
38 | + saver.restore(sess, FLAGS.checkpoint_file) | ||
39 | + tf.saved_model.simple_save( | ||
40 | + sess, | ||
41 | + FLAGS.output_dir, | ||
42 | + inputs={'rgb_and_audio': input_feature_placeholder, | ||
43 | + 'num_frames': num_frames_placeholder}, | ||
44 | + outputs={'predictions': predictions_tensor}) | ||
45 | + | ||
46 | + # Try running inference. | ||
47 | + predictions = sess.run( | ||
48 | + [predictions_tensor], | ||
49 | + feed_dict={ | ||
50 | + input_feature_placeholder: np.zeros((3, 7, 1152), dtype=np.float32), | ||
51 | + num_frames_placeholder: np.array([[7]], dtype=np.int32)}) | ||
52 | + print('Test inference:', predictions) | ||
53 | + | ||
54 | + print('Model saved to ', FLAGS.output_dir) | ||
55 | + | ||
56 | + | ||
57 | +if __name__ == '__main__': | ||
58 | + flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.') | ||
59 | + flags.DEFINE_string('output_dir', None, 'SavedModel output directory.') | ||
60 | + app.run(main) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
... | @@ -8,7 +8,8 @@ import esot3ria.video_recommender as recommender | ... | @@ -8,7 +8,8 @@ import esot3ria.video_recommender as recommender |
8 | import esot3ria.video_util as videoutil | 8 | import esot3ria.video_util as videoutil |
9 | 9 | ||
10 | # Define model paths. | 10 | # Define model paths. |
11 | -MODEL_PATH = "./model/inference_model/segment_inference_model" | 11 | +# MODEL_PATH = "./model/inference_model/segment_inference_model" |
12 | +MODEL_PATH = "./pretrained_model/variables/variables" | ||
12 | TAG_VECTOR_MODEL_PATH = "./model/tag_vectors.model" | 13 | TAG_VECTOR_MODEL_PATH = "./model/tag_vectors.model" |
13 | VIDEO_VECTOR_MODEL_PATH = "./model/video_vectors.model" | 14 | VIDEO_VECTOR_MODEL_PATH = "./model/video_vectors.model" |
14 | 15 | ... | ... |
esot3ria/pretrained_model/saved_model.pb
0 → 100644
No preview for this file type
This file is too large to display.
No preview for this file type
-
Please register or login to post a comment