Showing 1 changed file with 448 additions and 0 deletions

esot3ria/inference_kaggle_solution.py (new file, mode 100644)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Binary for generating predictions over a set of videos."""

from __future__ import print_function

import csv
import glob
import json
import operator
import os
import re
import tarfile
import tempfile
import time

import numpy as np

import readers
from six.moves import urllib
import tensorflow as tf
from tensorflow import app
from tensorflow import flags
from tensorflow import gfile
from tensorflow import logging
from tensorflow.python.lib.io import file_io
import utils

FLAGS = flags.FLAGS

if __name__ == "__main__":
  # Input
  flags.DEFINE_string(
      "train_dir", "./new_model",
      "The directory to load the model files from. We assume that you have "
      "already run eval.py on it, such that inference_model.* files already "
      "exist.")
  flags.DEFINE_string(
      "input_data_pattern", "/Volumes/HDD/develop/yt8m/3/frame/test/test*.tfrecord",
      "File glob defining the evaluation dataset in tensorflow.SequenceExample "
      "format. The SequenceExamples are expected to have an 'rgb' byte array "
      "sequence feature as well as a 'labels' int64 context feature.")
  flags.DEFINE_string(
      "input_model_tgz", "",
      "If given, must be a path to a .tgz file that was written "
      "by this binary using flag --output_model_tgz. In this "
      "case, the .tgz file will be untarred to "
      "--untar_model_dir and the model will be used for "
      "inference.")
  flags.DEFINE_string(
      "untar_model_dir", "/tmp/yt8m-model",
      "If --input_model_tgz is given, then this directory will "
      "be created and the contents of the .tgz file will be "
      "untarred here.")
  flags.DEFINE_bool(
      "segment_labels", True,
      "If set, then --input_data_pattern must be frame-level features (but "
      "with segment_labels). Otherwise, --input_data_pattern must be "
      "aggregated video-level features. The model must also be set "
      "appropriately (i.e. to read 3D batches VS 4D batches).")
  flags.DEFINE_integer("segment_max_pred", 100000,
                       "Limit total number of segment outputs per entity.")
  flags.DEFINE_string(
      "segment_label_ids_file",
      "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv",
      "The file that contains the segment label ids.")

  # Output
  flags.DEFINE_string("output_file", "/Users/esot3ria/kaggle_solution.csv",
                      "The file to save the predictions to.")
  flags.DEFINE_string(
      "output_model_tgz", "",
      "If given, should be a filename with a .tgz extension, "
      "the model graph and checkpoint will be bundled in this "
      "gzip tar. This file can be uploaded to Kaggle for the "
      "top 10 participants.")
  flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")

  # Other flags.
  flags.DEFINE_integer("batch_size", 512,
                       "How many examples to process per batch.")
  flags.DEFINE_integer("num_readers", 1,
                       "How many threads to use for reading input files.")


def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None):
  """Creates an output line per video for the submission file."""
  batch_size = len(video_ids)
  for video_index in range(batch_size):
    video_prediction = predictions[video_index]
    if whitelisted_cls_mask is not None:
      # Whitelist classes.
      video_prediction *= whitelisted_cls_mask
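    # np.argpartition finds the top_k scores in O(n) without a full sort; the
    # indices it returns are unordered, so they are sorted by score below.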
    top_indices = np.argpartition(video_prediction, -top_k)[-top_k:]
    line = [(class_index, predictions[video_index][class_index])
            for class_index in top_indices]
    line = sorted(line, key=lambda p: -p[1])
    yield (video_ids[video_index] + "," +
           " ".join("%i %g" % (label, score) for (label, score) in line) +
           "\n").encode("utf8")


def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1):
  """Creates the section of the graph which reads the input data.

  Args:
    reader: A class which parses the input data.
    data_pattern: A 'glob' style path to the data files.
    batch_size: How many examples to process at a time.
    num_readers: How many I/O threads to use.

  Returns:
    A tuple containing the features tensor, labels tensor, and optionally a
    tensor containing the number of frames per video. The exact dimensions
    depend on the reader being used.

  Raises:
    IOError: If no files matching the given pattern were found.
  """
  with tf.name_scope("input"):
    files = gfile.Glob(data_pattern)
    if not files:
      raise IOError("Unable to find input files. data_pattern='" +
                    data_pattern + "'")
    logging.info("number of input files: " + str(len(files)))
    filename_queue = tf.train.string_input_producer(files,
                                                    num_epochs=1,
                                                    shuffle=False)
    examples_and_labels = [
        reader.prepare_reader(filename_queue) for _ in range(num_readers)
    ]

    input_data_dict = (tf.train.batch_join(examples_and_labels,
                                           batch_size=batch_size,
                                           allow_smaller_final_batch=True,
                                           enqueue_many=True))
    video_id_batch = input_data_dict["video_ids"]
    video_batch = input_data_dict["video_matrix"]
    num_frames_batch = input_data_dict["num_frames"]
    return video_id_batch, video_batch, num_frames_batch


def get_segments(batch_video_mtx, batch_num_frames, segment_size):
  """Get segment-level inputs from frame-level features."""
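  # Each video is split into contiguous, non-overlapping segments of
  # segment_size frames starting at frames 0, segment_size, 2*segment_size,
  # etc. The frame count is rounded up to a multiple of segment_size, so the
  # final segment may hold fewer valid frames. For example, a 12-frame video
  # with segment_size=5 yields segments starting at frames 0, 5, and 10,
  # where the last segment has only 2 valid frames.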
  video_batch_size = batch_video_mtx.shape[0]
  max_frame = batch_video_mtx.shape[1]
  feature_dim = batch_video_mtx.shape[-1]
  padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size
  padded_segment_sizes *= segment_size
  segment_mask = (
      0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame)))

  # Segment bags.
  frame_bags = batch_video_mtx.reshape((-1, feature_dim))
  segment_frames = frame_bags[segment_mask.reshape(-1)].reshape(
      (-1, segment_size, feature_dim))

  # Segment num frames.
  segment_start_times = np.arange(0, max_frame, segment_size)
  num_segments = batch_num_frames[:, np.newaxis] - segment_start_times
  num_segment_bags = num_segments.reshape((-1))
  valid_segment_mask = num_segment_bags > 0
  segment_num_frames = num_segment_bags[valid_segment_mask]
  segment_num_frames[segment_num_frames > segment_size] = segment_size

  max_segment_num = (max_frame + segment_size - 1) // segment_size
  video_idxs = np.tile(
      np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num])
  segment_idxs = np.tile(segment_start_times, [video_batch_size, 1])
  idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2))
  video_segment_ids = idx_bags[valid_segment_mask]

  return {
      "video_batch": segment_frames,
      "num_frames_batch": segment_num_frames,
      "video_segment_ids": video_segment_ids
  }


def normalize_tag(tag):
  """Lowercases a tag, drops any parenthesized suffix, and hyphenates it."""
  if isinstance(tag, str):
    new_tag = tag.lower()
    # Drop a parenthesized qualifier, e.g. "golf (sport)" -> "golf".
    if new_tag.find(" (") != -1:
      new_tag = new_tag[:new_tag.find(" (")]
    # Replace non-alphabetic characters with spaces. Note str.replace treats
    # its pattern as a literal string, so re.sub is required here.
    new_tag = re.sub(r"[^a-z]", " ", new_tag)
    new_tag = new_tag.replace(" ", "-")
    return new_tag
  else:
    return tag


def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
              top_k):
  """Inference function."""
  with tf.Session(config=tf.ConfigProto(
      allow_soft_placement=True)) as sess, gfile.Open(out_file_location,
                                                      "w+") as out_file:
    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
        reader, data_pattern, batch_size)
    inference_model_name = ("segment_inference_model"
                            if FLAGS.segment_labels else "inference_model")
    checkpoint_file = os.path.join(train_dir, "inference_model",
                                   inference_model_name)
    if not gfile.Exists(checkpoint_file + ".meta"):
      raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
    meta_graph_location = checkpoint_file + ".meta"
    logging.info("loading meta-graph: " + meta_graph_location)

    if FLAGS.output_model_tgz:
      with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar:
        for model_file in glob.glob(checkpoint_file + ".*"):
          tar.add(model_file, arcname=os.path.basename(model_file))
        tar.add(os.path.join(train_dir, "model_flags.json"),
                arcname="model_flags.json")
      print("Tarred model onto " + FLAGS.output_model_tgz)
    with tf.device("/cpu:0"):
      saver = tf.train.import_meta_graph(meta_graph_location,
                                         clear_devices=True)
    logging.info("restoring variables from " + checkpoint_file)
    saver.restore(sess, checkpoint_file)
    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue.
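    # Restoring the meta-graph pulls in local variables from the training
    # input pipeline; the helper below pins any "train_input" locals to 1 and
    # initializes the remaining locals (including the num_epochs counter of
    # the string_input_producer above), so the queue makes exactly one pass.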
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(
        set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    whitelisted_cls_mask = None
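    # In segment mode this runs as two passes: raw segment predictions are
    # first written to a temporary file, then aggregated per video into the
    # final CSV once the queue is exhausted (see the post-processing below).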
    if FLAGS.segment_labels:
      final_out_file = out_file
      out_file = tempfile.NamedTemporaryFile()
      logging.info(
          "Segment temp prediction output will be written to temp file: %s",
          out_file.name)
      if FLAGS.segment_label_ids_file:
        whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
                                        dtype=np.float32)
        segment_label_ids_file = FLAGS.segment_label_ids_file
        if segment_label_ids_file.startswith("http"):
          logging.info("Retrieving segment ID whitelist files from %s...",
                       segment_label_ids_file)
          segment_label_ids_file, _ = urllib.request.urlretrieve(
              segment_label_ids_file)
        with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
          for line in fobj:
            try:
              cls_id = int(line)
              whitelisted_cls_mask[cls_id] = 1.
            except ValueError:
              # Simply skip the non-integer line.
              continue

    out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))

    # =========================================
    # Load the vocabulary CSV into a dictionary
    # mapping label index to tag name.
    # =========================================
    voca_dict = {}
    with open("./statics/vocabulary.csv", "r") as vocab_file:
      # Use the csv module so quoted fields containing commas parse correctly.
      for row in csv.reader(vocab_file):
        if row and row[0] != "Index":
          voca_dict[row[0]] = row[3]
    try:
      while not coord.should_stop():
        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
            [video_id_batch, video_batch, num_frames_batch])
        if FLAGS.segment_labels:
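          # Split each frame-level example into 5-frame segments and re-key
          # the ids as "<video_id>:<segment_start_frame>".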
          results = get_segments(video_batch_val, num_frames_batch_val, 5)
          video_segment_ids = results["video_segment_ids"]
          video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]]
          video_id_batch_val = np.array([
              "%s:%d" % (x.decode("utf8"), y)
              for x, y in zip(video_id_batch_val, video_segment_ids[:, 1])
          ])
          video_batch_val = results["video_batch"]
          num_frames_batch_val = results["num_frames_batch"]
          if input_tensor.get_shape()[1] != video_batch_val.shape[1]:
            raise ValueError("max_frames mismatch. Please re-run the eval.py "
                             "with correct segment_labels settings.")

        predictions_val, = sess.run([predictions_tensor],
                                    feed_dict={
                                        input_tensor: video_batch_val,
                                        num_frames_tensor: num_frames_batch_val
                                    })
        now = time.time()
        num_examples_processed += len(video_batch_val)
        elapsed_time = now - start_time
        logging.info("num examples processed: " + str(num_examples_processed) +
                     " elapsed seconds: " + "{0:.2f}".format(elapsed_time) +
                     " examples/sec: %.2f" %
                     (num_examples_processed / elapsed_time))
        for line in format_lines(video_id_batch_val, predictions_val, top_k,
                                 whitelisted_cls_mask):
          out_file.write(line)
        out_file.flush()

    except tf.errors.OutOfRangeError:
      logging.info("Done with inference. The output file was written to " +
                   out_file.name)
    finally:
      coord.request_stop()

    if FLAGS.segment_labels:
      # Re-read the temp file and aggregate the per-segment scores of each
      # class into a single score per video.
      logging.info("Post-processing segment predictions...")
      segment_id_list = []
      segment_classes = []
      cls_result_arr = []
      cls_score_dict = {}
      out_file.seek(0, 0)
      old_seg_name = '0000'
      counter = 0
      for line in out_file:
        counter += 1
        if counter % 5000 == 0:
          print(counter, " processed")
        segment_id, preds = line.decode("utf8").split(",")
        if segment_id == "VideoId":
          # Skip the header line.
          continue

        # Each prediction line alternates class ids and scores.
        preds = preds.split(" ")
        pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
        pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)]
        # =======================================
        # Keep only the video id; drop the ":<segment_start>" suffix.
        segment_id = str(segment_id.split(":")[0])
        if segment_id not in segment_id_list:
          segment_id_list.append(str(segment_id))
          segment_classes.append("")

        index = segment_id_list.index(segment_id)

        if old_seg_name != segment_id:
          cls_score_dict[segment_id] = {}
          old_seg_name = segment_id

        for cls_pos in range(len(pred_cls_ids)):
          # Append the classes predicted for this segment to the video's list.
          segment_classes[index] = (str(segment_classes[index]) +
                                    str(pred_cls_ids[cls_pos]) + " ")
          cls_id = pred_cls_ids[cls_pos]
          if cls_id in cls_score_dict[segment_id]:
            cls_score_dict[segment_id][cls_id] += pred_cls_scores[cls_pos]
          else:
            cls_score_dict[segment_id][cls_id] = pred_cls_scores[cls_pos]

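      # Second pass over the aggregates: for each video, keep the top_k
      # classes, normalize their scores to sum to 1, and map class ids to
      # human-readable tags via the vocabulary.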
      for segs, item in zip(segment_id_list, segment_classes):
        cls_arr = item.split(" ")[:-1]

        cls_arr = list(map(int, cls_arr))
        cls_arr = sorted(cls_arr)  # Sort by class id (not used below).

        result_string = ""

        temp = cls_score_dict[segs]
        # Sort by aggregated score, descending.
        temp = sorted(temp.items(), key=operator.itemgetter(1), reverse=True)
        denominator = float(sum(score for _, score in temp[:top_k]))
        for item_index in range(top_k):
          # Normalize the tag name and its score share.
          segment_tag = str(voca_dict[str(temp[item_index][0])])
          normalized_tag = normalize_tag(segment_tag)
          result_string = (result_string + normalized_tag + ":" +
                           format(temp[item_index][1] / denominator, ".3f") +
                           ",")

        cls_result_arr.append(result_string[:-1])
        logging.info(segs + " : " + result_string[:-1])
      # =======================================
      final_out_file.write("vid_id,segment1,segment2,segment3,segment4,segment5\n")
      for seg_id, class_indices in zip(segment_id_list, cls_result_arr):
        final_out_file.write("%s,%s\n" % (seg_id, str(class_indices)))
      final_out_file.close()

    out_file.close()

    coord.join(threads)
    sess.close()


def main(unused_argv):
  logging.set_verbosity(tf.logging.INFO)
  if FLAGS.input_model_tgz:
    # NOTE: --train_dir defaults to "./new_model", so --input_model_tgz can
    # only be used together with an explicit --train_dir="".
    if FLAGS.train_dir:
      raise ValueError("You cannot supply --train_dir if supplying "
                       "--input_model_tgz")
    # Untar.
    if not os.path.exists(FLAGS.untar_model_dir):
      os.makedirs(FLAGS.untar_model_dir)
    tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir)
    FLAGS.train_dir = FLAGS.untar_model_dir

  flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json")
  if not file_io.file_exists(flags_dict_file):
    raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file)
  flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read())

  # Convert feature_names and feature_sizes to lists of values.
  feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
      flags_dict["feature_names"], flags_dict["feature_sizes"])

  if flags_dict["frame_features"]:
    reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
                                            feature_sizes=feature_sizes)
  else:
    reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
                                                 feature_sizes=feature_sizes)

  if not FLAGS.output_file:
    raise ValueError("'output_file' was not specified. "
                     "Unable to continue with inference.")

  if not FLAGS.input_data_pattern:
    raise ValueError("'input_data_pattern' was not specified. "
                     "Unable to continue with inference.")

  inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern,
            FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k)


if __name__ == "__main__":
  app.run()