Showing 1 changed file with 366 additions and 340 deletions
@@ -34,366 +34,392 @@ from tensorflow import logging
34 | from tensorflow.python.lib.io import file_io | 34 | from tensorflow.python.lib.io import file_io |
35 | import utils | 35 | import utils |
36 | from collections import Counter | 36 | from collections import Counter |
37 | +import operator | ||
37 | 38 | ||
38 | FLAGS = flags.FLAGS | 39 | FLAGS = flags.FLAGS |
39 | 40 | ||
40 | if __name__ == "__main__": | 41 | if __name__ == "__main__": |
41 | - # Input | 42 | + # Input |
42 | - flags.DEFINE_string( | 43 | + flags.DEFINE_string( |
43 | - "train_dir", "", "The directory to load the model files from. We assume " | 44 | + "train_dir", "", "The directory to load the model files from. We assume " |
44 | - "that you have already run eval.py onto this, such that " | 45 | + "that you have already run eval.py onto this, such that " |
45 | - "inference_model.* files already exist.") | 46 | + "inference_model.* files already exist.") |
46 | - flags.DEFINE_string( | 47 | + flags.DEFINE_string( |
47 | - "input_data_pattern", "", | 48 | + "input_data_pattern", "", |
48 | - "File glob defining the evaluation dataset in tensorflow.SequenceExample " | 49 | + "File glob defining the evaluation dataset in tensorflow.SequenceExample " |
49 | - "format. The SequenceExamples are expected to have an 'rgb' byte array " | 50 | + "format. The SequenceExamples are expected to have an 'rgb' byte array " |
50 | - "sequence feature as well as a 'labels' int64 context feature.") | 51 | + "sequence feature as well as a 'labels' int64 context feature.") |
51 | - flags.DEFINE_string( | 52 | + flags.DEFINE_string( |
52 | - "input_model_tgz", "", | 53 | + "input_model_tgz", "", |
53 | - "If given, must be path to a .tgz file that was written " | 54 | + "If given, must be path to a .tgz file that was written " |
54 | - "by this binary using flag --output_model_tgz. In this " | 55 | + "by this binary using flag --output_model_tgz. In this " |
55 | - "case, the .tgz file will be untarred to " | 56 | + "case, the .tgz file will be untarred to " |
56 | - "--untar_model_dir and the model will be used for " | 57 | + "--untar_model_dir and the model will be used for " |
57 | - "inference.") | 58 | + "inference.") |
58 | - flags.DEFINE_string( | 59 | + flags.DEFINE_string( |
59 | - "untar_model_dir", "/tmp/yt8m-model", | 60 | + "untar_model_dir", "/tmp/yt8m-model", |
60 | - "If --input_model_tgz is given, then this directory will " | 61 | + "If --input_model_tgz is given, then this directory will " |
61 | - "be created and the contents of the .tgz file will be " | 62 | + "be created and the contents of the .tgz file will be " |
62 | - "untarred here.") | 63 | + "untarred here.") |
63 | - flags.DEFINE_bool( | 64 | + flags.DEFINE_bool( |
64 | - "segment_labels", False, | 65 | + "segment_labels", False, |
65 | - "If set, then --input_data_pattern must be frame-level features (but with" | 66 | + "If set, then --input_data_pattern must be frame-level features (but with" |
66 | - " segment_labels). Otherwise, --input_data_pattern must be aggregated " | 67 | + " segment_labels). Otherwise, --input_data_pattern must be aggregated " |
67 | - "video-level features. The model must also be set appropriately (i.e. to " | 68 | + "video-level features. The model must also be set appropriately (i.e. to " |
68 | - "read 3D batches VS 4D batches.") | 69 | + "read 3D batches VS 4D batches.") |
69 | - flags.DEFINE_integer("segment_max_pred", 100000, | 70 | + flags.DEFINE_integer("segment_max_pred", 100000, |
70 | - "Limit total number of segment outputs per entity.") | 71 | + "Limit total number of segment outputs per entity.") |
71 | - flags.DEFINE_string( | 72 | + flags.DEFINE_string( |
72 | - "segment_label_ids_file", | 73 | + "segment_label_ids_file", |
73 | - "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv", | 74 | + "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv", |
74 | - "The file that contains the segment label ids.") | 75 | + "The file that contains the segment label ids.") |
75 | - | 76 | + |
76 | - # Output | 77 | + # Output |
77 | - flags.DEFINE_string("output_file", "", "The file to save the predictions to.") | 78 | + flags.DEFINE_string("output_file", "", "The file to save the predictions to.") |
78 | - flags.DEFINE_string( | 79 | + flags.DEFINE_string( |
79 | - "output_model_tgz", "", | 80 | + "output_model_tgz", "", |
80 | - "If given, should be a filename with a .tgz extension, " | 81 | + "If given, should be a filename with a .tgz extension, " |
81 | - "the model graph and checkpoint will be bundled in this " | 82 | + "the model graph and checkpoint will be bundled in this " |
82 | - "gzip tar. This file can be uploaded to Kaggle for the " | 83 | + "gzip tar. This file can be uploaded to Kaggle for the " |
83 | - "top 10 participants.") | 84 | + "top 10 participants.") |
84 | - flags.DEFINE_integer("top_k", 1, "How many predictions to output per video.") | 85 | + flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.") |
85 | - | 86 | + |
86 | - # Other flags. | 87 | + # Other flags. |
87 | - flags.DEFINE_integer("batch_size", 512, | 88 | + flags.DEFINE_integer("batch_size", 512, |
88 | - "How many examples to process per batch.") | 89 | + "How many examples to process per batch.") |
89 | - flags.DEFINE_integer("num_readers", 1, | 90 | + flags.DEFINE_integer("num_readers", 1, |
90 | - "How many threads to use for reading input files.") | 91 | + "How many threads to use for reading input files.") |
91 | 92 | ||
92 | 93 | ||
93 | def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None): | 94 | def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None): |
94 | - """Create an information line the submission file.""" | 95 | + """Create an information line the submission file.""" |
95 | - batch_size = len(video_ids) | 96 | + batch_size = len(video_ids) |
96 | - for video_index in range(batch_size): | 97 | + for video_index in range(batch_size): |
97 | - video_prediction = predictions[video_index] | 98 | + video_prediction = predictions[video_index] |
98 | - if whitelisted_cls_mask is not None: | 99 | + if whitelisted_cls_mask is not None: |
99 | - # Whitelist classes. | 100 | + # Whitelist classes. |
100 | - video_prediction *= whitelisted_cls_mask | 101 | + video_prediction *= whitelisted_cls_mask |
101 | - top_indices = np.argpartition(video_prediction, -top_k)[-top_k:] | 102 | + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:] |
102 | - line = [(class_index, predictions[video_index][class_index]) | 103 | + line = [(class_index, predictions[video_index][class_index]) |
103 | - for class_index in top_indices] | 104 | + for class_index in top_indices] |
104 | - line = sorted(line, key=lambda p: -p[1]) | 105 | + line = sorted(line, key=lambda p: -p[1]) |
105 | - yield (video_ids[video_index] + "," + | 106 | + yield (video_ids[video_index] + "," + |
106 | - " ".join("%i %g" % (label, score) for (label, score) in line) + | 107 | + " ".join("%i %g" % (label, score) for (label, score) in line) + |
107 | - "\n").encode("utf8") | 108 | + "\n").encode("utf8") |
108 | 109 | ||
109 | 110 | ||
110 | def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1): | 111 | def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1): |
111 | - """Creates the section of the graph which reads the input data. | 112 | + """Creates the section of the graph which reads the input data. |
112 | - | 113 | + |
113 | - Args: | 114 | + Args: |
114 | - reader: A class which parses the input data. | 115 | + reader: A class which parses the input data. |
115 | - data_pattern: A 'glob' style path to the data files. | 116 | + data_pattern: A 'glob' style path to the data files. |
116 | - batch_size: How many examples to process at a time. | 117 | + batch_size: How many examples to process at a time. |
117 | - num_readers: How many I/O threads to use. | 118 | + num_readers: How many I/O threads to use. |
118 | - | 119 | + |
119 | - Returns: | 120 | + Returns: |
120 | - A tuple containing the features tensor, labels tensor, and optionally a | 121 | + A tuple containing the features tensor, labels tensor, and optionally a |
121 | - tensor containing the number of frames per video. The exact dimensions | 122 | + tensor containing the number of frames per video. The exact dimensions |
122 | - depend on the reader being used. | 123 | + depend on the reader being used. |
123 | - | 124 | + |
124 | - Raises: | 125 | + Raises: |
125 | - IOError: If no files matching the given pattern were found. | 126 | + IOError: If no files matching the given pattern were found. |
126 | - """ | 127 | + """ |
127 | - with tf.name_scope("input"): | 128 | + with tf.name_scope("input"): |
128 | - files = gfile.Glob(data_pattern) | 129 | + files = gfile.Glob(data_pattern) |
129 | - if not files: | 130 | + if not files: |
130 | - raise IOError("Unable to find input files. data_pattern='" + | 131 | + raise IOError("Unable to find input files. data_pattern='" + |
131 | - data_pattern + "'") | 132 | + data_pattern + "'") |
132 | - logging.info("number of input files: " + str(len(files))) | 133 | + logging.info("number of input files: " + str(len(files))) |
133 | - filename_queue = tf.train.string_input_producer(files, | 134 | + filename_queue = tf.train.string_input_producer(files, |
134 | - num_epochs=1, | 135 | + num_epochs=1, |
135 | - shuffle=False) | 136 | + shuffle=False) |
136 | - examples_and_labels = [ | 137 | + examples_and_labels = [ |
137 | - reader.prepare_reader(filename_queue) for _ in range(num_readers) | 138 | + reader.prepare_reader(filename_queue) for _ in range(num_readers) |
138 | - ] | 139 | + ] |
139 | - | 140 | + |
140 | - input_data_dict = (tf.train.batch_join(examples_and_labels, | 141 | + input_data_dict = (tf.train.batch_join(examples_and_labels, |
141 | - batch_size=batch_size, | 142 | + batch_size=batch_size, |
142 | - allow_smaller_final_batch=True, | 143 | + allow_smaller_final_batch=True, |
143 | - enqueue_many=True)) | 144 | + enqueue_many=True)) |
144 | - video_id_batch = input_data_dict["video_ids"] | 145 | + video_id_batch = input_data_dict["video_ids"] |
145 | - video_batch = input_data_dict["video_matrix"] | 146 | + video_batch = input_data_dict["video_matrix"] |
146 | - num_frames_batch = input_data_dict["num_frames"] | 147 | + num_frames_batch = input_data_dict["num_frames"] |
147 | - return video_id_batch, video_batch, num_frames_batch | 148 | + return video_id_batch, video_batch, num_frames_batch |
148 | 149 | ||
149 | 150 | ||
150 | def get_segments(batch_video_mtx, batch_num_frames, segment_size): | 151 | def get_segments(batch_video_mtx, batch_num_frames, segment_size): |
151 | - """Get segment-level inputs from frame-level features.""" | 152 | + """Get segment-level inputs from frame-level features.""" |
152 | - video_batch_size = batch_video_mtx.shape[0] | 153 | + video_batch_size = batch_video_mtx.shape[0] |
153 | - max_frame = batch_video_mtx.shape[1] | 154 | + max_frame = batch_video_mtx.shape[1] |
154 | - feature_dim = batch_video_mtx.shape[-1] | 155 | + feature_dim = batch_video_mtx.shape[-1] |
155 | - padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size | 156 | + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size |
156 | - padded_segment_sizes *= segment_size | 157 | + padded_segment_sizes *= segment_size |
157 | - segment_mask = ( | 158 | + segment_mask = ( |
158 | - 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame))) | 159 | + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame))) |
159 | - | 160 | + |
160 | - # Segment bags. | 161 | + # Segment bags. |
161 | - frame_bags = batch_video_mtx.reshape((-1, feature_dim)) | 162 | + frame_bags = batch_video_mtx.reshape((-1, feature_dim)) |
162 | - segment_frames = frame_bags[segment_mask.reshape(-1)].reshape( | 163 | + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape( |
163 | - (-1, segment_size, feature_dim)) | 164 | + (-1, segment_size, feature_dim)) |
164 | - | 165 | + |
165 | - # Segment num frames. | 166 | + # Segment num frames. |
166 | - segment_start_times = np.arange(0, max_frame, segment_size) | 167 | + segment_start_times = np.arange(0, max_frame, segment_size) |
167 | - num_segments = batch_num_frames[:, np.newaxis] - segment_start_times | 168 | + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times |
168 | - num_segment_bags = num_segments.reshape((-1)) | 169 | + num_segment_bags = num_segments.reshape((-1)) |
169 | - valid_segment_mask = num_segment_bags > 0 | 170 | + valid_segment_mask = num_segment_bags > 0 |
170 | - segment_num_frames = num_segment_bags[valid_segment_mask] | 171 | + segment_num_frames = num_segment_bags[valid_segment_mask] |
171 | - segment_num_frames[segment_num_frames > segment_size] = segment_size | 172 | + segment_num_frames[segment_num_frames > segment_size] = segment_size |
172 | - | 173 | + |
173 | - max_segment_num = (max_frame + segment_size - 1) // segment_size | 174 | + max_segment_num = (max_frame + segment_size - 1) // segment_size |
174 | - video_idxs = np.tile( | 175 | + video_idxs = np.tile( |
175 | - np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num]) | 176 | + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num]) |
176 | - segment_idxs = np.tile(segment_start_times, [video_batch_size, 1]) | 177 | + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1]) |
177 | - idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2)) | 178 | + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2)) |
178 | - video_segment_ids = idx_bags[valid_segment_mask] | 179 | + video_segment_ids = idx_bags[valid_segment_mask] |
179 | - | 180 | + |
180 | - return { | 181 | + return { |
181 | - "video_batch": segment_frames, | 182 | + "video_batch": segment_frames, |
182 | - "num_frames_batch": segment_num_frames, | 183 | + "num_frames_batch": segment_num_frames, |
183 | - "video_segment_ids": video_segment_ids | 184 | + "video_segment_ids": video_segment_ids |
184 | - } | 185 | + } |
185 | 186 | ||
186 | 187 | ||
187 | def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | 188 | def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
188 | top_k): | 189 | top_k): |
189 | - """Inference function.""" | 190 | + """Inference function.""" |
190 | - with tf.Session(config=tf.ConfigProto( | 191 | + with tf.Session(config=tf.ConfigProto( |
191 | - allow_soft_placement=True)) as sess, gfile.Open(out_file_location, | 192 | + allow_soft_placement=True)) as sess, gfile.Open(out_file_location, |
192 | - "w+") as out_file: | 193 | + "w+") as out_file: |
193 | - video_id_batch, video_batch, num_frames_batch = get_input_data_tensors( | 194 | + video_id_batch, video_batch, num_frames_batch = get_input_data_tensors( |
194 | - reader, data_pattern, batch_size) | 195 | + reader, data_pattern, batch_size) |
195 | - inference_model_name = "segment_inference_model" if FLAGS.segment_labels else "inference_model" | 196 | + inference_model_name = "segment_inference_model" if FLAGS.segment_labels else "inference_model" |
196 | - checkpoint_file = os.path.join(train_dir, "inference_model", | 197 | + checkpoint_file = os.path.join(train_dir, "inference_model", |
197 | - inference_model_name) | 198 | + inference_model_name) |
198 | - if not gfile.Exists(checkpoint_file + ".meta"): | 199 | + if not gfile.Exists(checkpoint_file + ".meta"): |
199 | - raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file) | 200 | + raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file) |
200 | - meta_graph_location = checkpoint_file + ".meta" | 201 | + meta_graph_location = checkpoint_file + ".meta" |
201 | - logging.info("loading meta-graph: " + meta_graph_location) | 202 | + logging.info("loading meta-graph: " + meta_graph_location) |
202 | - | 203 | + |
203 | - if FLAGS.output_model_tgz: | 204 | + if FLAGS.output_model_tgz: |
204 | - with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar: | 205 | + with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar: |
205 | - for model_file in glob.glob(checkpoint_file + ".*"): | 206 | + for model_file in glob.glob(checkpoint_file + ".*"): |
206 | - tar.add(model_file, arcname=os.path.basename(model_file)) | 207 | + tar.add(model_file, arcname=os.path.basename(model_file)) |
207 | - tar.add(os.path.join(train_dir, "model_flags.json"), | 208 | + tar.add(os.path.join(train_dir, "model_flags.json"), |
208 | - arcname="model_flags.json") | 209 | + arcname="model_flags.json") |
209 | - print("Tarred model onto " + FLAGS.output_model_tgz) | 210 | + print("Tarred model onto " + FLAGS.output_model_tgz) |
210 | - with tf.device("/cpu:0"): | 211 | + with tf.device("/cpu:0"): |
211 | - saver = tf.train.import_meta_graph(meta_graph_location, | 212 | + saver = tf.train.import_meta_graph(meta_graph_location, |
212 | - clear_devices=True) | 213 | + clear_devices=True) |
213 | - logging.info("restoring variables from " + checkpoint_file) | 214 | + logging.info("restoring variables from " + checkpoint_file) |
214 | - saver.restore(sess, checkpoint_file) | 215 | + saver.restore(sess, checkpoint_file) |
215 | - input_tensor = tf.get_collection("input_batch_raw")[0] | 216 | + input_tensor = tf.get_collection("input_batch_raw")[0] |
216 | - num_frames_tensor = tf.get_collection("num_frames")[0] | 217 | + num_frames_tensor = tf.get_collection("num_frames")[0] |
217 | - predictions_tensor = tf.get_collection("predictions")[0] | 218 | + predictions_tensor = tf.get_collection("predictions")[0] |
218 | - | 219 | + |
219 | - # Workaround for num_epochs issue. | 220 | + # Workaround for num_epochs issue. |
220 | - def set_up_init_ops(variables): | 221 | + def set_up_init_ops(variables): |
221 | - init_op_list = [] | 222 | + init_op_list = [] |
222 | - for variable in list(variables): | 223 | + for variable in list(variables): |
223 | - if "train_input" in variable.name: | 224 | + if "train_input" in variable.name: |
224 | - init_op_list.append(tf.assign(variable, 1)) | 225 | + init_op_list.append(tf.assign(variable, 1)) |
225 | - variables.remove(variable) | 226 | + variables.remove(variable) |
226 | - init_op_list.append(tf.variables_initializer(variables)) | 227 | + init_op_list.append(tf.variables_initializer(variables)) |
227 | - return init_op_list | 228 | + return init_op_list |
228 | - | 229 | + |
229 | - sess.run( | 230 | + sess.run( |
230 | - set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) | 231 | + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) |
231 | - | 232 | + |
232 | - coord = tf.train.Coordinator() | 233 | + coord = tf.train.Coordinator() |
233 | - threads = tf.train.start_queue_runners(sess=sess, coord=coord) | 234 | + threads = tf.train.start_queue_runners(sess=sess, coord=coord) |
234 | - num_examples_processed = 0 | 235 | + num_examples_processed = 0 |
235 | - start_time = time.time() | 236 | + start_time = time.time() |
236 | - whitelisted_cls_mask = None | 237 | + whitelisted_cls_mask = None |
238 | + if FLAGS.segment_labels: | ||
239 | + final_out_file = out_file | ||
240 | + out_file = tempfile.NamedTemporaryFile() | ||
241 | + logging.info( | ||
242 | + "Segment temp prediction output will be written to temp file: %s", | ||
243 | + out_file.name) | ||
244 | + if FLAGS.segment_label_ids_file: | ||
245 | + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | ||
246 | + dtype=np.float32) | ||
247 | + segment_label_ids_file = FLAGS.segment_label_ids_file | ||
248 | + if segment_label_ids_file.startswith("http"): | ||
249 | + logging.info("Retrieving segment ID whitelist files from %s...", | ||
250 | + segment_label_ids_file) | ||
251 | + segment_label_ids_file, _ = urllib.request.urlretrieve( | ||
252 | + segment_label_ids_file) | ||
253 | + with tf.io.gfile.GFile(segment_label_ids_file) as fobj: | ||
254 | + for line in fobj: | ||
255 | + try: | ||
256 | + cls_id = int(line) | ||
257 | + whitelisted_cls_mask[cls_id] = 1. | ||
258 | + except ValueError: | ||
259 | + # Simply skip the non-integer line. | ||
260 | + continue | ||
261 | + | ||
262 | + out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) | ||
263 | + | ||
264 | + #========================================= | ||
265 | + # Open the vocabulary CSV file and map each class index to its name. | ||
266 | + #========================================= | ||
267 | + voca_dict = {} | ||
268 | + vocabs = open("./vocabulary.csv", 'r') | ||
269 | + while True: | ||
270 | + line = vocabs.readline() | ||
271 | + if not line: break | ||
272 | + vocab_dict_item = line.split(",") | ||
273 | + if vocab_dict_item[0] != "Index": | ||
274 | + voca_dict[vocab_dict_item[0]] = vocab_dict_item[3] # Column 3 of vocabulary.csv holds the class name. | ||
275 | + vocabs.close() | ||
276 | + try: | ||
277 | + while not coord.should_stop(): | ||
278 | + video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( | ||
279 | + [video_id_batch, video_batch, num_frames_batch]) | ||
237 | if FLAGS.segment_labels: | 280 | if FLAGS.segment_labels: |
238 | - final_out_file = out_file | 281 | + results = get_segments(video_batch_val, num_frames_batch_val, 5) |
239 | - out_file = tempfile.NamedTemporaryFile() | 282 | + video_segment_ids = results["video_segment_ids"] |
240 | - logging.info( | 283 | + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]] |
241 | - "Segment temp prediction output will be written to temp file: %s", | 284 | + video_id_batch_val = np.array([ |
242 | - out_file.name) | 285 | + "%s:%d" % (x.decode("utf8"), y) |
243 | - if FLAGS.segment_label_ids_file: | 286 | + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1]) |
244 | - whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | 287 | + ]) |
245 | - dtype=np.float32) | 288 | + video_batch_val = results["video_batch"] |
246 | - segment_label_ids_file = FLAGS.segment_label_ids_file | 289 | + num_frames_batch_val = results["num_frames_batch"] |
247 | - if segment_label_ids_file.startswith("http"): | 290 | + if input_tensor.get_shape()[1] != video_batch_val.shape[1]: |
248 | - logging.info("Retrieving segment ID whitelist files from %s...", | 291 | + raise ValueError("max_frames mismatch. Please re-run the eval.py " |
249 | - segment_label_ids_file) | 292 | + "with correct segment_labels settings.") |
250 | - segment_label_ids_file, _ = urllib.request.urlretrieve( | 293 | + |
251 | - segment_label_ids_file) | 294 | + predictions_val, = sess.run([predictions_tensor], |
252 | - with tf.io.gfile.GFile(segment_label_ids_file) as fobj: | 295 | + feed_dict={ |
253 | - for line in fobj: | 296 | + input_tensor: video_batch_val, |
254 | - try: | 297 | + num_frames_tensor: num_frames_batch_val |
255 | - cls_id = int(line) | 298 | + }) |
256 | - whitelisted_cls_mask[cls_id] = 1. | 299 | + now = time.time() |
257 | - except ValueError: | 300 | + num_examples_processed += len(video_batch_val) |
258 | - # Simply skip the non-integer line. | 301 | + elapsed_time = now - start_time |
259 | - continue | 302 | + logging.info("num examples processed: " + str(num_examples_processed) + |
260 | - | 303 | + " elapsed seconds: " + "{0:.2f}".format(elapsed_time) + |
261 | - out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) | 304 | + " examples/sec: %.2f" % |
262 | - | 305 | + (num_examples_processed / elapsed_time)) |
263 | - try: | 306 | + for line in format_lines(video_id_batch_val, predictions_val, top_k, |
264 | - while not coord.should_stop(): | 307 | + whitelisted_cls_mask): |
265 | - video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( | 308 | + out_file.write(line) |
266 | - [video_id_batch, video_batch, num_frames_batch]) | 309 | + out_file.flush() |
267 | - if FLAGS.segment_labels: | 310 | + |
268 | - results = get_segments(video_batch_val, num_frames_batch_val, 5) | 311 | + except tf.errors.OutOfRangeError: |
269 | - video_segment_ids = results["video_segment_ids"] | 312 | + logging.info("Done with inference. The output file was written to " + |
270 | - video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]] | 313 | + out_file.name) |
271 | - video_id_batch_val = np.array([ | 314 | + finally: |
272 | - "%s:%d" % (x.decode("utf8"), y) | 315 | + coord.request_stop() |
273 | - for x, y in zip(video_id_batch_val, video_segment_ids[:, 1]) | 316 | + |
274 | - ]) | 317 | + if FLAGS.segment_labels: |
275 | - video_batch_val = results["video_batch"] | 318 | + # Re-read the file and do heap sort. |
276 | - num_frames_batch_val = results["num_frames_batch"] | 319 | + # Create multiple heaps. |
277 | - if input_tensor.get_shape()[1] != video_batch_val.shape[1]: | 320 | + logging.info("Post-processing segment predictions...") |
278 | - raise ValueError("max_frames mismatch. Please re-run the eval.py " | 321 | + segment_id_list = [] |
279 | - "with correct segment_labels settings.") | 322 | + segment_classes = [] |
280 | - | 323 | + cls_result_arr = [] |
281 | - predictions_val, = sess.run([predictions_tensor], | 324 | + cls_score_dict = {} |
282 | - feed_dict={ | 325 | + out_file.seek(0, 0) |
283 | - input_tensor: video_batch_val, | 326 | + old_seg_name = '0000' |
284 | - num_frames_tensor: num_frames_batch_val | 327 | + for line in out_file: |
285 | - }) | 328 | + segment_id, preds = line.decode("utf8").split(",") |
286 | - now = time.time() | 329 | + if segment_id == "VideoId": |
287 | - num_examples_processed += len(video_batch_val) | 330 | + # Skip the headline. |
288 | - elapsed_time = now - start_time | 331 | + continue |
289 | - logging.info("num examples processed: " + str(num_examples_processed) + | 332 | + |
290 | - " elapsed seconds: " + "{0:.2f}".format(elapsed_time) + | 333 | + preds = preds.split(" ") |
291 | - " examples/sec: %.2f" % | 334 | + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] |
292 | - (num_examples_processed / elapsed_time)) | 335 | + pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)] |
293 | - for line in format_lines(video_id_batch_val, predictions_val, top_k, | 336 | + #======================================= |
294 | - whitelisted_cls_mask): | 337 | + segment_id = str(segment_id.split(":")[0]) |
295 | - out_file.write(line) | 338 | + if segment_id not in segment_id_list: |
296 | - out_file.flush() | 339 | + segment_id_list.append(str(segment_id)) |
297 | - | 340 | + segment_classes.append("") |
298 | - except tf.errors.OutOfRangeError: | 341 | + |
299 | - logging.info("Done with inference. The output file was written to " + | 342 | + index = segment_id_list.index(segment_id) |
300 | - out_file.name) | 343 | + |
301 | - finally: | 344 | + if old_seg_name != segment_id: |
302 | - coord.request_stop() | 345 | + cls_score_dict[segment_id] = {} |
303 | - | 346 | + old_seg_name = segment_id |
304 | - if FLAGS.segment_labels: | 347 | + |
305 | - # Re-read the file and do heap sort. | 348 | + for classes in range(0, len(pred_cls_ids)): |
306 | - # Create multiple heaps. | 349 | + segment_classes[index] = str(segment_classes[index]) + str(pred_cls_ids[classes]) + " " # Append classes from the new segment. |
307 | - logging.info("Post-processing segment predictions...") | 350 | + if pred_cls_ids[classes] in cls_score_dict[segment_id]: |
308 | - segment_id_list = [] | 351 | + cls_score_dict[segment_id][pred_cls_ids[classes]] = cls_score_dict[segment_id][pred_cls_ids[classes]] + pred_cls_scores[classes] |
309 | - segment_classes = [] | 352 | + else: |
310 | - cls_result_arr = [] | 353 | + cls_score_dict[segment_id][pred_cls_ids[classes]] = pred_cls_scores[classes] |
311 | - out_file.seek(0, 0) | 354 | + |
312 | - for line in out_file: | 355 | + for segs,item in zip(segment_id_list,segment_classes): |
313 | - segment_id, preds = line.decode("utf8").split(",") | 356 | + print('====== R E C O R D ======') |
314 | - if segment_id == "VideoId": | 357 | + cls_arr = item.split(" ")[:-1] |
315 | - # Skip the headline. | 358 | + |
316 | - continue | 359 | + cls_arr = list(map(int,cls_arr)) |
317 | - | 360 | + cls_arr = sorted(cls_arr) # Sort by class ID. |
318 | - preds = preds.split(" ") | 361 | + |
319 | - pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] | 362 | + result_string = "" |
320 | - # ======================================= | 363 | + |
321 | - segment_id = str(segment_id.split(":")[0]) | 364 | + temp = cls_score_dict[segs] |
322 | - if segment_id not in segment_id_list: | 365 | + temp = sorted(temp.items(), key=operator.itemgetter(1), reverse=True) # Sort by score, descending. |
323 | - segment_id_list.append(str(segment_id)) | 366 | + denominator = float(temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1]) # Sum of the top-5 scores, used for normalization. |
324 | - segment_classes.append("") | 367 | + #for item in temp: |
325 | - | 368 | + for itemIndex in range(0, top_k): |
326 | - index = segment_id_list.index(segment_id) | 369 | + result_string = result_string + str(voca_dict[str(temp[itemIndex][0])]) + ":" + format(temp[itemIndex][1] / denominator, ".3f") + "," |
327 | - for classes in pred_cls_ids: | 370 | + |
328 | - segment_classes[index] = str(segment_classes[index]) + str( | 371 | + cls_result_arr.append(result_string[:-1]) |
329 | - classes) + " " # append classes from new segment | 372 | + logging.info(segs + " : " + result_string[:-1]) |
330 | - | 373 | + #======================================= |
331 | - for segs, item in zip(segment_id_list, segment_classes): | 374 | + final_out_file.write("vid_id,seg_classes\n") |
332 | - print('====== R E C O R D ======') | 375 | + for seg_id, class_indices in zip(segment_id_list, cls_result_arr): |
333 | - cls_arr = item.split(" ")[:-1] | 376 | + final_out_file.write("%s,%s\n" % (seg_id, str(class_indices))) |
334 | - | 377 | + final_out_file.close() |
335 | - cls_arr = list(map(int, cls_arr)) | 378 | + |
336 | - cls_arr = sorted(cls_arr) | 379 | + out_file.close() |
337 | - | 380 | + |
338 | - result_string = "" | 381 | + coord.join(threads) |
339 | - | 382 | + sess.close() |
340 | - temp = Counter(cls_arr) | ||
341 | - for item in temp: | ||
342 | - result_string = result_string + str(item) + ":" + str(temp[item]) + "," | ||
343 | - | ||
344 | - cls_result_arr.append(result_string[:-1]) | ||
345 | - logging.info(segs + " : " + result_string[:-1]) | ||
346 | - # ======================================= | ||
347 | - final_out_file.write("vid_id,seg_classes\n") | ||
348 | - for seg_id, class_indcies in zip(segment_id_list, cls_result_arr): | ||
349 | - final_out_file.write("%s,%s\n" % (seg_id, str(class_indcies))) | ||
350 | - final_out_file.close() | ||
351 | - | ||
352 | - out_file.close() | ||
353 | - | ||
354 | - coord.join(threads) | ||
355 | - sess.close() | ||
356 | - | ||
357 | 383 | ||
358 | def main(unused_argv): | 384 | def main(unused_argv): |
359 | - logging.set_verbosity(tf.logging.INFO) | 385 | + logging.set_verbosity(tf.logging.INFO) |
360 | - if FLAGS.input_model_tgz: | 386 | + if FLAGS.input_model_tgz: |
361 | - if FLAGS.train_dir: | 387 | + if FLAGS.train_dir: |
362 | - raise ValueError("You cannot supply --train_dir if supplying " | 388 | + raise ValueError("You cannot supply --train_dir if supplying " |
363 | - "--input_model_tgz") | 389 | + "--input_model_tgz") |
364 | - # Untar. | 390 | + # Untar. |
365 | - if not os.path.exists(FLAGS.untar_model_dir): | 391 | + if not os.path.exists(FLAGS.untar_model_dir): |
366 | - os.makedirs(FLAGS.untar_model_dir) | 392 | + os.makedirs(FLAGS.untar_model_dir) |
367 | - tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir) | 393 | + tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir) |
368 | - FLAGS.train_dir = FLAGS.untar_model_dir | 394 | + FLAGS.train_dir = FLAGS.untar_model_dir |
369 | - | 395 | + |
370 | - flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json") | 396 | + flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json") |
371 | - if not file_io.file_exists(flags_dict_file): | 397 | + if not file_io.file_exists(flags_dict_file): |
372 | - raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file) | 398 | + raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file) |
373 | - flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read()) | 399 | + flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read()) |
374 | - | 400 | + |
375 | - # convert feature_names and feature_sizes to lists of values | 401 | + # convert feature_names and feature_sizes to lists of values |
376 | - feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( | 402 | + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( |
377 | - flags_dict["feature_names"], flags_dict["feature_sizes"]) | 403 | + flags_dict["feature_names"], flags_dict["feature_sizes"]) |
378 | - | 404 | + |
379 | - if flags_dict["frame_features"]: | 405 | + if flags_dict["frame_features"]: |
380 | - reader = readers.YT8MFrameFeatureReader(feature_names=feature_names, | 406 | + reader = readers.YT8MFrameFeatureReader(feature_names=feature_names, |
381 | - feature_sizes=feature_sizes) | 407 | + feature_sizes=feature_sizes) |
382 | - else: | 408 | + else: |
383 | - reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, | 409 | + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, |
384 | - feature_sizes=feature_sizes) | 410 | + feature_sizes=feature_sizes) |
385 | - | 411 | + |
386 | - if not FLAGS.output_file: | 412 | + if not FLAGS.output_file: |
387 | - raise ValueError("'output_file' was not specified. " | 413 | + raise ValueError("'output_file' was not specified. " |
388 | - "Unable to continue with inference.") | 414 | + "Unable to continue with inference.") |
389 | - | 415 | + |
390 | - if not FLAGS.input_data_pattern: | 416 | + if not FLAGS.input_data_pattern: |
391 | - raise ValueError("'input_data_pattern' was not specified. " | 417 | + raise ValueError("'input_data_pattern' was not specified. " |
392 | - "Unable to continue with inference.") | 418 | + "Unable to continue with inference.") |
393 | - | 419 | + |
394 | - inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern, | 420 | + inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern, |
395 | - FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k) | 421 | + FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k) |
396 | 422 | ||
397 | 423 | ||
398 | if __name__ == "__main__": | 424 | if __name__ == "__main__": |
399 | - app.run() | 425 | + app.run() |
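For readers following the change: the new post-processing drops the old `Counter`-based occurrence counting and instead accumulates each class's scores across a video's segments (`cls_score_dict`), ranks them with `operator.itemgetter`, and normalizes by the sum of the top-5 aggregated scores before writing name:score pairs (class names are read from the fourth column of `vocabulary.csv`). Below is a minimal sketch of that aggregation outside the TF session; the video IDs and scores are made up for illustration.

```python
import operator
from collections import defaultdict

# Hypothetical per-segment predictions, keyed "videoId:startTime" as in the
# diff's segment output. Each value is a list of (class_id, score) pairs.
raw_predictions = {
    "abc123:0": [(3, 0.9), (7, 0.4), (11, 0.2), (20, 0.1), (31, 0.05)],
    "abc123:5": [(3, 0.8), (11, 0.3), (7, 0.2), (20, 0.1), (44, 0.05)],
}

# 1) Sum each class's scores across all segments of the same video,
#    mirroring cls_score_dict in the new code.
video_scores = defaultdict(lambda: defaultdict(float))
for seg_key, preds in raw_predictions.items():
    video_id = seg_key.split(":")[0]
    for cls_id, score in preds:
        video_scores[video_id][cls_id] += score

# 2) Rank classes by aggregated score and normalize by the top-5 sum,
#    as the new code does with operator.itemgetter and `denominator`.
for video_id, cls_scores in video_scores.items():
    ranked = sorted(cls_scores.items(), key=operator.itemgetter(1), reverse=True)
    denominator = sum(score for _, score in ranked[:5])
    pairs = ",".join("%d:%.3f" % (cls_id, score / denominator)
                     for cls_id, score in ranked[:5])
    print("%s,%s" % (video_id, pairs))
```

Note that the committed code indexes `temp[0][1]` through `temp[4][1]` directly, so it assumes at least five distinct classes per video; the sketch's `ranked[:5]` slice is a more forgiving stand-in for that step.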