이현규

Add Kaggle solution maker

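Runs segment-level inference on YT8M frame-level test records and writes a
Kaggle-style submission CSV ("vid_id,segment1,...,segment5"), where each entry
is a normalized tag name with its score fraction among the video's top-k
aggregated segment scores.
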
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 +"""Binary for generating predictions over a set of videos."""
15 +
16 +from __future__ import print_function
17 +
18 +import glob
19 +import heapq
20 +import json
21 +import os
22 +import tarfile
23 +import tempfile
24 +import time
25 +import numpy as np
26 +
27 +import readers
28 +from six.moves import urllib
29 +import tensorflow as tf
30 +from tensorflow import app
31 +from tensorflow import flags
32 +from tensorflow import gfile
33 +from tensorflow import logging
34 +from tensorflow.python.lib.io import file_io
35 +import utils
36 +from collections import Counter
37 +import operator

FLAGS = flags.FLAGS

if __name__ == "__main__":
  # Input
  flags.DEFINE_string(
      "train_dir", "./new_model",
      "The directory to load the model files from. We assume "
      "that you have already run eval.py on this, such that "
      "inference_model.* files already exist.")
  flags.DEFINE_string(
      "input_data_pattern",
      "/Volumes/HDD/develop/yt8m/3/frame/test/test*.tfrecord",
      "File glob defining the evaluation dataset in tensorflow.SequenceExample "
      "format. The SequenceExamples are expected to have an 'rgb' byte array "
      "sequence feature as well as a 'labels' int64 context feature.")
  flags.DEFINE_string(
      "input_model_tgz", "",
      "If given, must be path to a .tgz file that was written "
      "by this binary using flag --output_model_tgz. In this "
      "case, the .tgz file will be untarred to "
      "--untar_model_dir and the model will be used for "
      "inference.")
  flags.DEFINE_string(
      "untar_model_dir", "/tmp/yt8m-model",
      "If --input_model_tgz is given, then this directory will "
      "be created and the contents of the .tgz file will be "
      "untarred here.")
  flags.DEFINE_bool(
      "segment_labels", True,
      "If set, then --input_data_pattern must be frame-level features (but "
      "with segment_labels). Otherwise, --input_data_pattern must be "
      "aggregated video-level features. The model must also be set "
      "appropriately (i.e. to read 3D batches VS 4D batches).")
  flags.DEFINE_integer("segment_max_pred", 100000,
                       "Limit total number of segment outputs per entity.")
  flags.DEFINE_string(
      "segment_label_ids_file",
      "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv",
      "The file that contains the segment label ids.")

  # Output
  flags.DEFINE_string("output_file", "/Users/esot3ria/kaggle_solution.csv",
                      "The file to save the predictions to.")
  flags.DEFINE_string(
      "output_model_tgz", "",
      "If given, should be a filename with a .tgz extension, "
      "the model graph and checkpoint will be bundled in this "
      "gzip tar. This file can be uploaded to Kaggle for the "
      "top 10 participants.")
  flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")

  # Other flags.
  flags.DEFINE_integer("batch_size", 512,
                       "How many examples to process per batch.")
  flags.DEFINE_integer("num_readers", 1,
                       "How many threads to use for reading input files.")


def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None):
  """Creates an output line for each video in the submission file."""
  batch_size = len(video_ids)
  for video_index in range(batch_size):
    video_prediction = predictions[video_index]
    if whitelisted_cls_mask is not None:
      # Whitelist classes.
      video_prediction *= whitelisted_cls_mask
    top_indices = np.argpartition(video_prediction, -top_k)[-top_k:]
    line = [(class_index, video_prediction[class_index])
            for class_index in top_indices]
    line = sorted(line, key=lambda p: -p[1])
    yield (video_ids[video_index] + "," +
           " ".join("%i %g" % (label, score) for (label, score) in line) +
           "\n").encode("utf8")


def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1):
  """Creates the section of the graph which reads the input data.

  Args:
    reader: A class which parses the input data.
    data_pattern: A 'glob' style path to the data files.
    batch_size: How many examples to process at a time.
    num_readers: How many I/O threads to use.

  Returns:
    A tuple containing the features tensor, labels tensor, and optionally a
    tensor containing the number of frames per video. The exact dimensions
    depend on the reader being used.

  Raises:
    IOError: If no files matching the given pattern were found.
  """
  with tf.name_scope("input"):
    files = gfile.Glob(data_pattern)
    if not files:
      raise IOError("Unable to find input files. data_pattern='" +
                    data_pattern + "'")
    logging.info("number of input files: " + str(len(files)))
    filename_queue = tf.train.string_input_producer(files,
                                                    num_epochs=1,
                                                    shuffle=False)
    examples_and_labels = [
        reader.prepare_reader(filename_queue) for _ in range(num_readers)
    ]

    input_data_dict = (tf.train.batch_join(examples_and_labels,
                                           batch_size=batch_size,
                                           allow_smaller_final_batch=True,
                                           enqueue_many=True))
    video_id_batch = input_data_dict["video_ids"]
    video_batch = input_data_dict["video_matrix"]
    num_frames_batch = input_data_dict["num_frames"]
    return video_id_batch, video_batch, num_frames_batch


def get_segments(batch_video_mtx, batch_num_frames, segment_size):
  """Get segment-level inputs from frame-level features."""
  video_batch_size = batch_video_mtx.shape[0]
  max_frame = batch_video_mtx.shape[1]
  feature_dim = batch_video_mtx.shape[-1]
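  # Round each video's frame count up to a multiple of segment_size; the mask
  # below selects, per video, exactly the frames covered by padded segments.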
  padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size
  padded_segment_sizes *= segment_size
  segment_mask = (
      0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame)))

  # Segment bags.
  frame_bags = batch_video_mtx.reshape((-1, feature_dim))
  segment_frames = frame_bags[segment_mask.reshape(-1)].reshape(
      (-1, segment_size, feature_dim))

  # Segment num frames.
  segment_start_times = np.arange(0, max_frame, segment_size)
  num_segments = batch_num_frames[:, np.newaxis] - segment_start_times
  num_segment_bags = num_segments.reshape((-1))
  valid_segment_mask = num_segment_bags > 0
  segment_num_frames = num_segment_bags[valid_segment_mask]
  segment_num_frames[segment_num_frames > segment_size] = segment_size

  max_segment_num = (max_frame + segment_size - 1) // segment_size
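  # Pair every valid segment with its (video index, start frame) so that
  # segment predictions can later be keyed as "<video_id>:<start_frame>".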
  video_idxs = np.tile(
      np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num])
  segment_idxs = np.tile(segment_start_times, [video_batch_size, 1])
  idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2))
  video_segment_ids = idx_bags[valid_segment_mask]

  return {
      "video_batch": segment_frames,
      "num_frames_batch": segment_num_frames,
      "video_segment_ids": video_segment_ids
  }


def normalize_tag(tag):
  """Normalizes a tag name, e.g. "Video game (industry)" -> "video-game"."""
  if isinstance(tag, str):
    new_tag = tag.lower()
    # Drop any trailing parenthetical qualifier.
    if new_tag.find(" (") != -1:
      new_tag = new_tag[:new_tag.find(" (")]
    # Replace non-alphabetic characters with spaces (str.replace does not
    # take a regex, so use re.sub), then join the words with "-".
    new_tag = re.sub(r"[^a-z]+", " ", new_tag).strip()
    new_tag = new_tag.replace(" ", "-")
    return new_tag
  else:
    return tag


def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
              top_k):
  """Runs inference and writes predictions to a CSV file."""
  with tf.Session(config=tf.ConfigProto(
      allow_soft_placement=True)) as sess, gfile.Open(out_file_location,
                                                      "w+") as out_file:
    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
        reader, data_pattern, batch_size)
    inference_model_name = ("segment_inference_model"
                            if FLAGS.segment_labels else "inference_model")
    checkpoint_file = os.path.join(train_dir, "inference_model",
                                   inference_model_name)
    if not gfile.Exists(checkpoint_file + ".meta"):
      raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
    meta_graph_location = checkpoint_file + ".meta"
    logging.info("loading meta-graph: " + meta_graph_location)
    if FLAGS.output_model_tgz:
      with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar:
        for model_file in glob.glob(checkpoint_file + ".*"):
          tar.add(model_file, arcname=os.path.basename(model_file))
        tar.add(os.path.join(train_dir, "model_flags.json"),
                arcname="model_flags.json")
      print("Tarred model onto " + FLAGS.output_model_tgz)
    with tf.device("/cpu:0"):
      saver = tf.train.import_meta_graph(meta_graph_location,
                                         clear_devices=True)
    logging.info("restoring variables from " + checkpoint_file)
    saver.restore(sess, checkpoint_file)
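    # These tensors were added to the graph's collections when eval.py
    # exported the inference model.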
    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue.
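    # Initialize local variables, but pin any "train_input" epoch counters to
    # 1 instead of re-initializing them, so the queue yields a single epoch.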
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(
        set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    whitelisted_cls_mask = None
    if FLAGS.segment_labels:
      final_out_file = out_file
      out_file = tempfile.NamedTemporaryFile()
      logging.info(
          "Segment temp prediction output will be written to temp file: %s",
          out_file.name)
      if FLAGS.segment_label_ids_file:
        whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
                                        dtype=np.float32)
        segment_label_ids_file = FLAGS.segment_label_ids_file
        if segment_label_ids_file.startswith("http"):
          logging.info("Retrieving segment ID whitelist files from %s...",
                       segment_label_ids_file)
          segment_label_ids_file, _ = urllib.request.urlretrieve(
              segment_label_ids_file)
        with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
          for line in fobj:
            try:
              cls_id = int(line)
              whitelisted_cls_mask[cls_id] = 1.
            except ValueError:
              # Simply skip the non-integer line.
              continue

    out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))

    # =========================================
    # Load the vocabulary CSV into a dictionary
    # mapping label index (column 0) to tag name (column 3).
    # =========================================
    voca_dict = {}
    with open("./statics/vocabulary.csv", "r") as vocabs:
      for line in vocabs:
        vocab_dict_item = line.split(",")
        if vocab_dict_item[0] != "Index":
          voca_dict[vocab_dict_item[0]] = vocab_dict_item[3]
    try:
      while not coord.should_stop():
        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
            [video_id_batch, video_batch, num_frames_batch])
        if FLAGS.segment_labels:
          results = get_segments(video_batch_val, num_frames_batch_val, 5)
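          # YT8M segments are 5 frames (= 5 seconds of video) long.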
          video_segment_ids = results["video_segment_ids"]
          video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]]
          video_id_batch_val = np.array([
              "%s:%d" % (x.decode("utf8"), y)
              for x, y in zip(video_id_batch_val, video_segment_ids[:, 1])
          ])
          video_batch_val = results["video_batch"]
          num_frames_batch_val = results["num_frames_batch"]
          if input_tensor.get_shape()[1] != video_batch_val.shape[1]:
            raise ValueError("max_frames mismatch. Please re-run the eval.py "
                             "with correct segment_labels settings.")

        predictions_val, = sess.run([predictions_tensor],
                                    feed_dict={
                                        input_tensor: video_batch_val,
                                        num_frames_tensor: num_frames_batch_val
                                    })
        now = time.time()
        num_examples_processed += len(video_batch_val)
        elapsed_time = now - start_time
        logging.info("num examples processed: " + str(num_examples_processed) +
                     " elapsed seconds: " + "{0:.2f}".format(elapsed_time) +
                     " examples/sec: %.2f" %
                     (num_examples_processed / elapsed_time))
        for line in format_lines(video_id_batch_val, predictions_val, top_k,
                                 whitelisted_cls_mask):
          out_file.write(line)
        out_file.flush()

    except tf.errors.OutOfRangeError:
      logging.info("Done with inference. The output file was written to " +
                   out_file.name)
    finally:
      coord.request_stop()

    if FLAGS.segment_labels:
      # Re-read the temp file and aggregate class scores per video.
      logging.info("Post-processing segment predictions...")
      segment_id_list = []
      segment_classes = []
      cls_result_arr = []
      cls_score_dict = {}
      out_file.seek(0, 0)
      old_seg_name = '0000'
      counter = 0
      for line in out_file:
        counter += 1
        if counter % 5000 == 0:
          print(counter, " processed")
        segment_id, preds = line.decode("utf8").split(",")
        if segment_id == "VideoId":
          # Skip the headline.
          continue

        preds = preds.split(" ")
        pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
        pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)]
        # =======================================
        # Strip the ":<start_frame>" suffix to get the video id.
        segment_id = str(segment_id.split(":")[0])
        if segment_id not in segment_id_list:
          segment_id_list.append(str(segment_id))
          segment_classes.append("")

        index = segment_id_list.index(segment_id)

        if old_seg_name != segment_id:
          cls_score_dict[segment_id] = {}
          old_seg_name = segment_id

        for cls_idx in range(len(pred_cls_ids)):
          # Append classes from the new segment.
          segment_classes[index] = (str(segment_classes[index]) +
                                    str(pred_cls_ids[cls_idx]) + " ")
          # Sum scores of the same class over all of the video's segments.
          if pred_cls_ids[cls_idx] in cls_score_dict[segment_id]:
            cls_score_dict[segment_id][pred_cls_ids[cls_idx]] += (
                pred_cls_scores[cls_idx])
          else:
            cls_score_dict[segment_id][pred_cls_ids[cls_idx]] = (
                pred_cls_scores[cls_idx])

      for segs, item in zip(segment_id_list, segment_classes):
        cls_arr = item.split(" ")[:-1]

        cls_arr = list(map(int, cls_arr))
        cls_arr = sorted(cls_arr)  # Sort by class id.

        result_string = ""

        temp = cls_score_dict[segs]
        # Sort the aggregated (class, score) pairs by score, descending.
        temp = sorted(temp.items(), key=operator.itemgetter(1), reverse=True)
        # Normalize the top_k scores so they sum to 1.
        denominator = float(sum(score for _, score in temp[:top_k]))
        for item_index in range(top_k):
          # Normalize tag name.
          segment_tag = str(voca_dict[str(temp[item_index][0])])
          normalized_tag = normalize_tag(segment_tag)
          result_string = (result_string + normalized_tag + ":" +
                           format(temp[item_index][1] / denominator, ".3f") +
                           ",")

        cls_result_arr.append(result_string[:-1])
        logging.info(segs + " : " + result_string[:-1])
      # =======================================
      final_out_file.write("vid_id,segment1,segment2,segment3,segment4,segment5\n")
      for seg_id, class_indices in zip(segment_id_list, cls_result_arr):
        final_out_file.write("%s,%s\n" % (seg_id, str(class_indices)))
      final_out_file.close()

    out_file.close()

    coord.join(threads)
    sess.close()


def main(unused_argv):
  logging.set_verbosity(tf.logging.INFO)
  if FLAGS.input_model_tgz:
    if FLAGS.train_dir:
      raise ValueError("You cannot supply --train_dir if supplying "
                       "--input_model_tgz")
    # Untar.
    if not os.path.exists(FLAGS.untar_model_dir):
      os.makedirs(FLAGS.untar_model_dir)
    tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir)
    FLAGS.train_dir = FLAGS.untar_model_dir

  flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json")
  if not file_io.file_exists(flags_dict_file):
    raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file)
  flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read())

  # Convert feature_names and feature_sizes to lists of values.
  feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
      flags_dict["feature_names"], flags_dict["feature_sizes"])

  if flags_dict["frame_features"]:
    reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
                                            feature_sizes=feature_sizes)
  else:
    reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
                                                 feature_sizes=feature_sizes)

  if not FLAGS.output_file:
    raise ValueError("'output_file' was not specified. "
                     "Unable to continue with inference.")

  if not FLAGS.input_data_pattern:
    raise ValueError("'input_data_pattern' was not specified. "
                     "Unable to continue with inference.")

  inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern,
            FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k)


if __name__ == "__main__":
  app.run()