윤영빈

top-k label output

@@ -34,366 +34,392 @@ from tensorflow import logging
 from tensorflow.python.lib.io import file_io
 import utils
 from collections import Counter
+import operator

 FLAGS = flags.FLAGS

 if __name__ == "__main__":
   # Input
   flags.DEFINE_string(
       "train_dir", "", "The directory to load the model files from. We assume "
       "that you have already run eval.py onto this, such that "
       "inference_model.* files already exist.")
   flags.DEFINE_string(
       "input_data_pattern", "",
       "File glob defining the evaluation dataset in tensorflow.SequenceExample "
       "format. The SequenceExamples are expected to have an 'rgb' byte array "
       "sequence feature as well as a 'labels' int64 context feature.")
   flags.DEFINE_string(
       "input_model_tgz", "",
       "If given, must be path to a .tgz file that was written "
       "by this binary using flag --output_model_tgz. In this "
       "case, the .tgz file will be untarred to "
       "--untar_model_dir and the model will be used for "
       "inference.")
   flags.DEFINE_string(
       "untar_model_dir", "/tmp/yt8m-model",
       "If --input_model_tgz is given, then this directory will "
       "be created and the contents of the .tgz file will be "
       "untarred here.")
   flags.DEFINE_bool(
       "segment_labels", False,
       "If set, then --input_data_pattern must be frame-level features (but with"
       " segment_labels). Otherwise, --input_data_pattern must be aggregated "
       "video-level features. The model must also be set appropriately (i.e. to "
       "read 3D batches VS 4D batches).")
   flags.DEFINE_integer("segment_max_pred", 100000,
                        "Limit total number of segment outputs per entity.")
   flags.DEFINE_string(
       "segment_label_ids_file",
       "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv",
       "The file that contains the segment label ids.")

   # Output
   flags.DEFINE_string("output_file", "", "The file to save the predictions to.")
   flags.DEFINE_string(
       "output_model_tgz", "",
       "If given, should be a filename with a .tgz extension, "
       "the model graph and checkpoint will be bundled in this "
       "gzip tar. This file can be uploaded to Kaggle for the "
       "top 10 participants.")
84 - flags.DEFINE_integer("top_k", 1, "How many predictions to output per video.") 85 + flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")

   # Other flags.
   flags.DEFINE_integer("batch_size", 512,
                        "How many examples to process per batch.")
   flags.DEFINE_integer("num_readers", 1,
                        "How many threads to use for reading input files.")


 def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None):
94 - """Create an information line the submission file.""" 95 + """Create an information line the submission file."""
95 - batch_size = len(video_ids) 96 + batch_size = len(video_ids)
96 - for video_index in range(batch_size): 97 + for video_index in range(batch_size):
97 - video_prediction = predictions[video_index] 98 + video_prediction = predictions[video_index]
98 - if whitelisted_cls_mask is not None: 99 + if whitelisted_cls_mask is not None:
99 - # Whitelist classes. 100 + # Whitelist classes.
100 - video_prediction *= whitelisted_cls_mask 101 + video_prediction *= whitelisted_cls_mask
101 - top_indices = np.argpartition(video_prediction, -top_k)[-top_k:] 102 + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:]
102 - line = [(class_index, predictions[video_index][class_index]) 103 + line = [(class_index, predictions[video_index][class_index])
103 - for class_index in top_indices] 104 + for class_index in top_indices]
104 - line = sorted(line, key=lambda p: -p[1]) 105 + line = sorted(line, key=lambda p: -p[1])
105 - yield (video_ids[video_index] + "," + 106 + yield (video_ids[video_index] + "," +
106 - " ".join("%i %g" % (label, score) for (label, score) in line) + 107 + " ".join("%i %g" % (label, score) for (label, score) in line) +
107 - "\n").encode("utf8") 108 + "\n").encode("utf8")
108 109
109 110
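For reference, a minimal standalone sketch (illustrative, not part of this patch) of the top-k selection that format_lines relies on: np.argpartition places the k largest scores in the last k slots in O(n) without fully sorting, and only those k entries are then sorted by score.

    import numpy as np

    def top_k_pairs(scores, k):
      # Partition so the k largest scores occupy the last k slots (O(n)).
      top_indices = np.argpartition(scores, -k)[-k:]
      # Fully sort only those k entries, by descending score.
      return sorted(((i, float(scores[i])) for i in top_indices),
                    key=lambda p: -p[1])

    scores = np.array([0.1, 0.7, 0.3, 0.9, 0.2])
    print(top_k_pairs(scores, 2))  # [(3, 0.9), (1, 0.7)]
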
 def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1):
   """Creates the section of the graph which reads the input data.

   Args:
     reader: A class which parses the input data.
     data_pattern: A 'glob' style path to the data files.
     batch_size: How many examples to process at a time.
     num_readers: How many I/O threads to use.

   Returns:
     A tuple containing the features tensor, labels tensor, and optionally a
     tensor containing the number of frames per video. The exact dimensions
     depend on the reader being used.

   Raises:
     IOError: If no files matching the given pattern were found.
   """
127 - with tf.name_scope("input"): 128 + with tf.name_scope("input"):
128 - files = gfile.Glob(data_pattern) 129 + files = gfile.Glob(data_pattern)
129 - if not files: 130 + if not files:
130 - raise IOError("Unable to find input files. data_pattern='" + 131 + raise IOError("Unable to find input files. data_pattern='" +
131 - data_pattern + "'") 132 + data_pattern + "'")
132 - logging.info("number of input files: " + str(len(files))) 133 + logging.info("number of input files: " + str(len(files)))
133 - filename_queue = tf.train.string_input_producer(files, 134 + filename_queue = tf.train.string_input_producer(files,
134 - num_epochs=1, 135 + num_epochs=1,
135 - shuffle=False) 136 + shuffle=False)
136 - examples_and_labels = [ 137 + examples_and_labels = [
137 - reader.prepare_reader(filename_queue) for _ in range(num_readers) 138 + reader.prepare_reader(filename_queue) for _ in range(num_readers)
138 - ] 139 + ]
139 - 140 +
140 - input_data_dict = (tf.train.batch_join(examples_and_labels, 141 + input_data_dict = (tf.train.batch_join(examples_and_labels,
141 - batch_size=batch_size, 142 + batch_size=batch_size,
142 - allow_smaller_final_batch=True, 143 + allow_smaller_final_batch=True,
143 - enqueue_many=True)) 144 + enqueue_many=True))
144 - video_id_batch = input_data_dict["video_ids"] 145 + video_id_batch = input_data_dict["video_ids"]
145 - video_batch = input_data_dict["video_matrix"] 146 + video_batch = input_data_dict["video_matrix"]
146 - num_frames_batch = input_data_dict["num_frames"] 147 + num_frames_batch = input_data_dict["num_frames"]
147 - return video_id_batch, video_batch, num_frames_batch 148 + return video_id_batch, video_batch, num_frames_batch
148 149
149 150
 def get_segments(batch_video_mtx, batch_num_frames, segment_size):
   """Get segment-level inputs from frame-level features."""
   video_batch_size = batch_video_mtx.shape[0]
   max_frame = batch_video_mtx.shape[1]
   feature_dim = batch_video_mtx.shape[-1]
   padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size
   padded_segment_sizes *= segment_size
   segment_mask = (
       0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame)))

   # Segment bags.
   frame_bags = batch_video_mtx.reshape((-1, feature_dim))
   segment_frames = frame_bags[segment_mask.reshape(-1)].reshape(
       (-1, segment_size, feature_dim))

   # Segment num frames.
   segment_start_times = np.arange(0, max_frame, segment_size)
   num_segments = batch_num_frames[:, np.newaxis] - segment_start_times
   num_segment_bags = num_segments.reshape((-1))
   valid_segment_mask = num_segment_bags > 0
   segment_num_frames = num_segment_bags[valid_segment_mask]
   segment_num_frames[segment_num_frames > segment_size] = segment_size

   max_segment_num = (max_frame + segment_size - 1) // segment_size
   video_idxs = np.tile(
       np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num])
   segment_idxs = np.tile(segment_start_times, [video_batch_size, 1])
   idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2))
   video_segment_ids = idx_bags[valid_segment_mask]

   return {
       "video_batch": segment_frames,
       "num_frames_batch": segment_num_frames,
       "video_segment_ids": video_segment_ids
   }

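As a quick worked illustration (again not part of the patch) of the padding arithmetic above: with segment_size = 5, a 7-frame video is padded up to ceil(7 / 5) * 5 = 10 frames, i.e. two segment slots, and the per-segment frame counts are clipped to the segment size.

    import numpy as np

    segment_size = 5
    batch_num_frames = np.array([7])
    # Round the frame count up to a multiple of segment_size.
    padded = ((batch_num_frames + segment_size - 1) // segment_size) * segment_size
    print(padded)  # [10] -> this video contributes two 5-frame segment slots
    # Frames remaining at each segment start time (0 and 5), clipped to size 5.
    num_frames = (batch_num_frames[:, np.newaxis] -
                  np.arange(0, padded[0], segment_size)).reshape(-1)
    num_frames[num_frames > segment_size] = segment_size
    print(num_frames)  # [5 2] -> the second segment holds 2 real frames
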
 def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
               top_k):
   """Inference function."""
   with tf.Session(config=tf.ConfigProto(
       allow_soft_placement=True)) as sess, gfile.Open(out_file_location,
                                                       "w+") as out_file:
     video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
         reader, data_pattern, batch_size)
     inference_model_name = "segment_inference_model" if FLAGS.segment_labels else "inference_model"
     checkpoint_file = os.path.join(train_dir, "inference_model",
                                    inference_model_name)
     if not gfile.Exists(checkpoint_file + ".meta"):
       raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
     meta_graph_location = checkpoint_file + ".meta"
     logging.info("loading meta-graph: " + meta_graph_location)

     if FLAGS.output_model_tgz:
       with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar:
         for model_file in glob.glob(checkpoint_file + ".*"):
           tar.add(model_file, arcname=os.path.basename(model_file))
         tar.add(os.path.join(train_dir, "model_flags.json"),
                 arcname="model_flags.json")
       print("Tarred model onto " + FLAGS.output_model_tgz)
     with tf.device("/cpu:0"):
       saver = tf.train.import_meta_graph(meta_graph_location,
                                          clear_devices=True)
     logging.info("restoring variables from " + checkpoint_file)
     saver.restore(sess, checkpoint_file)
     input_tensor = tf.get_collection("input_batch_raw")[0]
     num_frames_tensor = tf.get_collection("num_frames")[0]
     predictions_tensor = tf.get_collection("predictions")[0]

     # Workaround for num_epochs issue.
     def set_up_init_ops(variables):
       init_op_list = []
       for variable in list(variables):
         if "train_input" in variable.name:
           init_op_list.append(tf.assign(variable, 1))
           variables.remove(variable)
       init_op_list.append(tf.variables_initializer(variables))
       return init_op_list

     sess.run(
         set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

     coord = tf.train.Coordinator()
     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
     num_examples_processed = 0
     start_time = time.time()
     whitelisted_cls_mask = None
+    if FLAGS.segment_labels:
+      final_out_file = out_file
+      out_file = tempfile.NamedTemporaryFile()
+      logging.info(
+          "Segment temp prediction output will be written to temp file: %s",
+          out_file.name)
+      if FLAGS.segment_label_ids_file:
+        whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
+                                        dtype=np.float32)
+        segment_label_ids_file = FLAGS.segment_label_ids_file
+        if segment_label_ids_file.startswith("http"):
+          logging.info("Retrieving segment ID whitelist files from %s...",
+                       segment_label_ids_file)
+          segment_label_ids_file, _ = urllib.request.urlretrieve(
+              segment_label_ids_file)
+        with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
+          for line in fobj:
+            try:
+              cls_id = int(line)
+              whitelisted_cls_mask[cls_id] = 1.
+            except ValueError:
+              # Simply skip the non-integer line.
+              continue
+
+    out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))
+
+    # =========================================
+    # Open the vocabulary CSV file and store it in a dictionary.
+    # =========================================
+    voca_dict = {}
+    vocabs = open("./vocabulary.csv", "r")
+    while True:
+      line = vocabs.readline()
+      if not line: break
+      vocab_dict_item = line.split(",")
+      if vocab_dict_item[0] != "Index":
+        voca_dict[vocab_dict_item[0]] = vocab_dict_item[3]
+    vocabs.close()
+    try:
+      while not coord.should_stop():
+        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
+            [video_id_batch, video_batch, num_frames_batch])
         if FLAGS.segment_labels:
-      final_out_file = out_file
-      out_file = tempfile.NamedTemporaryFile()
-      logging.info(
-          "Segment temp prediction output will be written to temp file: %s",
-          out_file.name)
-      if FLAGS.segment_label_ids_file:
-        whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
-                                        dtype=np.float32)
-        segment_label_ids_file = FLAGS.segment_label_ids_file
-        if segment_label_ids_file.startswith("http"):
-          logging.info("Retrieving segment ID whitelist files from %s...",
-                       segment_label_ids_file)
-          segment_label_ids_file, _ = urllib.request.urlretrieve(
-              segment_label_ids_file)
-        with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
-          for line in fobj:
-            try:
-              cls_id = int(line)
-              whitelisted_cls_mask[cls_id] = 1.
-            except ValueError:
-              # Simply skip the non-integer line.
-              continue
-
-    out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))
-
-    try:
-      while not coord.should_stop():
-        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
-            [video_id_batch, video_batch, num_frames_batch])
-        if FLAGS.segment_labels:
-          results = get_segments(video_batch_val, num_frames_batch_val, 5)
-          video_segment_ids = results["video_segment_ids"]
-          video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]]
-          video_id_batch_val = np.array([
-              "%s:%d" % (x.decode("utf8"), y)
-              for x, y in zip(video_id_batch_val, video_segment_ids[:, 1])
-          ])
-          video_batch_val = results["video_batch"]
-          num_frames_batch_val = results["num_frames_batch"]
-          if input_tensor.get_shape()[1] != video_batch_val.shape[1]:
-            raise ValueError("max_frames mismatch. Please re-run the eval.py "
-                             "with correct segment_labels settings.")
-
-        predictions_val, = sess.run([predictions_tensor],
-                                    feed_dict={
-                                        input_tensor: video_batch_val,
-                                        num_frames_tensor: num_frames_batch_val
-                                    })
-        now = time.time()
-        num_examples_processed += len(video_batch_val)
-        elapsed_time = now - start_time
-        logging.info("num examples processed: " + str(num_examples_processed) +
-                     " elapsed seconds: " + "{0:.2f}".format(elapsed_time) +
-                     " examples/sec: %.2f" %
-                     (num_examples_processed / elapsed_time))
-        for line in format_lines(video_id_batch_val, predictions_val, top_k,
-                                 whitelisted_cls_mask):
-          out_file.write(line)
-        out_file.flush()
-
-    except tf.errors.OutOfRangeError:
-      logging.info("Done with inference. The output file was written to " +
-                   out_file.name)
-    finally:
-      coord.request_stop()
-
-    if FLAGS.segment_labels:
-      # Re-read the file and do heap sort.
-      # Create multiple heaps.
-      logging.info("Post-processing segment predictions...")
-      segment_id_list = []
-      segment_classes = []
-      cls_result_arr = []
-      out_file.seek(0, 0)
-      for line in out_file:
-        segment_id, preds = line.decode("utf8").split(",")
-        if segment_id == "VideoId":
-          # Skip the headline.
-          continue
-
-        preds = preds.split(" ")
-        pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
-        # =======================================
-        segment_id = str(segment_id.split(":")[0])
-        if segment_id not in segment_id_list:
-          segment_id_list.append(str(segment_id))
-          segment_classes.append("")
-
-        index = segment_id_list.index(segment_id)
-        for classes in pred_cls_ids:
-          segment_classes[index] = str(segment_classes[index]) + str(
-              classes) + " "  # append classes from new segment
-
-      for segs, item in zip(segment_id_list, segment_classes):
-        print('====== R E C O R D ======')
-        cls_arr = item.split(" ")[:-1]
-
-        cls_arr = list(map(int, cls_arr))
-        cls_arr = sorted(cls_arr)
-
-        result_string = ""
-
-        temp = Counter(cls_arr)
-        for item in temp:
-          result_string = result_string + str(item) + ":" + str(temp[item]) + ","
-
-        cls_result_arr.append(result_string[:-1])
-        logging.info(segs + " : " + result_string[:-1])
-      # =======================================
-      final_out_file.write("vid_id,seg_classes\n")
-      for seg_id, class_indcies in zip(segment_id_list, cls_result_arr):
-        final_out_file.write("%s,%s\n" % (seg_id, str(class_indcies)))
-      final_out_file.close()
-
-    out_file.close()
-
-    coord.join(threads)
-    sess.close()
-
+          results = get_segments(video_batch_val, num_frames_batch_val, 5)
+          video_segment_ids = results["video_segment_ids"]
+          video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]]
+          video_id_batch_val = np.array([
+              "%s:%d" % (x.decode("utf8"), y)
+              for x, y in zip(video_id_batch_val, video_segment_ids[:, 1])
+          ])
+          video_batch_val = results["video_batch"]
+          num_frames_batch_val = results["num_frames_batch"]
+          if input_tensor.get_shape()[1] != video_batch_val.shape[1]:
+            raise ValueError("max_frames mismatch. Please re-run the eval.py "
+                             "with correct segment_labels settings.")
+
+        predictions_val, = sess.run([predictions_tensor],
+                                    feed_dict={
+                                        input_tensor: video_batch_val,
+                                        num_frames_tensor: num_frames_batch_val
+                                    })
+        now = time.time()
+        num_examples_processed += len(video_batch_val)
+        elapsed_time = now - start_time
+        logging.info("num examples processed: " + str(num_examples_processed) +
+                     " elapsed seconds: " + "{0:.2f}".format(elapsed_time) +
+                     " examples/sec: %.2f" %
+                     (num_examples_processed / elapsed_time))
+        for line in format_lines(video_id_batch_val, predictions_val, top_k,
+                                 whitelisted_cls_mask):
+          out_file.write(line)
+        out_file.flush()
+
+    except tf.errors.OutOfRangeError:
+      logging.info("Done with inference. The output file was written to " +
+                   out_file.name)
+    finally:
+      coord.request_stop()
+
+    if FLAGS.segment_labels:
+      # Re-read the file and do heap sort.
+      # Create multiple heaps.
+      logging.info("Post-processing segment predictions...")
+      segment_id_list = []
+      segment_classes = []
+      cls_result_arr = []
+      cls_score_dict = {}
+      out_file.seek(0, 0)
+      old_seg_name = '0000'
+      for line in out_file:
+        segment_id, preds = line.decode("utf8").split(",")
+        if segment_id == "VideoId":
+          # Skip the headline.
+          continue
+
+        preds = preds.split(" ")
+        pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
+        pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)]
+        # =======================================
+        segment_id = str(segment_id.split(":")[0])
+        if segment_id not in segment_id_list:
+          segment_id_list.append(str(segment_id))
+          segment_classes.append("")
+
+        index = segment_id_list.index(segment_id)
+
+        if old_seg_name != segment_id:
+          cls_score_dict[segment_id] = {}
+          old_seg_name = segment_id
+
+        for classes in range(0, len(pred_cls_ids)):
+          # Append this segment's classes to the video's class list.
+          segment_classes[index] = (str(segment_classes[index]) +
+                                    str(pred_cls_ids[classes]) + " ")
+          if pred_cls_ids[classes] in cls_score_dict[segment_id]:
+            cls_score_dict[segment_id][pred_cls_ids[classes]] += (
+                pred_cls_scores[classes])
+          else:
+            cls_score_dict[segment_id][pred_cls_ids[classes]] = (
+                pred_cls_scores[classes])
+
+      for segs, item in zip(segment_id_list, segment_classes):
+        print('====== R E C O R D ======')
+        cls_arr = item.split(" ")[:-1]
+
+        cls_arr = list(map(int, cls_arr))
+        cls_arr = sorted(cls_arr)  # Sort by class id.
+
+        result_string = ""
+
+        temp = cls_score_dict[segs]
+        temp = sorted(temp.items(), key=operator.itemgetter(1),
+                      reverse=True)  # Sort by accumulated score, descending.
+        denominator = float(temp[0][1] + temp[1][1] + temp[2][1] +
+                            temp[3][1] + temp[4][1])
+        for itemIndex in range(0, top_k):
+          result_string = (result_string +
+                           str(voca_dict[str(temp[itemIndex][0])]) + ":" +
+                           format(temp[itemIndex][1] / denominator, ".3f") +
+                           ",")
+
+        cls_result_arr.append(result_string[:-1])
+        logging.info(segs + " : " + result_string[:-1])
+      # =======================================
+      final_out_file.write("vid_id,seg_classes\n")
+      for seg_id, class_indcies in zip(segment_id_list, cls_result_arr):
+        final_out_file.write("%s,%s\n" % (seg_id, str(class_indcies)))
+      final_out_file.close()
+
+    out_file.close()
+
+    coord.join(threads)
+    sess.close()
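To make the new post-processing easier to follow, here is a minimal standalone sketch (file name and sample values are illustrative) of what the added code computes per video: accumulate each class's scores across the video's segments, rank classes by accumulated score, normalize by the sum of the top scores, and map class indices to names via vocabulary.csv. Note the patch always uses the top-5 sum as the denominator (temp[0][1] + ... + temp[4][1]), so the printed scores only sum to 1 when --top_k is 5.

    import csv
    import operator

    def load_vocab(path="./vocabulary.csv"):
      # Column 0 is "Index", column 3 is "Name", as in the patch's voca_dict.
      # csv.reader also copes with quoted fields, which a bare line.split(",")
      # would mis-parse if a label name contains a comma.
      vocab = {}
      with open(path, newline="") as f:
        for row in csv.reader(f):
          if row and row[0] != "Index":
            vocab[row[0]] = row[3]
      return vocab

    # Segment-level (class_id, score) predictions for one video.
    segment_preds = [[(3, 0.9), (7, 0.5)],   # segment 1
                     [(3, 0.8), (1, 0.4)]]   # segment 2
    scores = {}
    for preds in segment_preds:
      for cls_id, score in preds:
        scores[cls_id] = scores.get(cls_id, 0.0) + score

    # Rank by accumulated score, then normalize the top-k by their sum.
    ranked = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
    top_k = 2  # the patch uses 5
    denominator = float(sum(s for _, s in ranked[:top_k]))
    print([(c, round(s / denominator, 3)) for c, s in ranked[:top_k]])
    # [(3, 0.773), (7, 0.227)]
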


 def main(unused_argv):
   logging.set_verbosity(tf.logging.INFO)
   if FLAGS.input_model_tgz:
     if FLAGS.train_dir:
       raise ValueError("You cannot supply --train_dir if supplying "
                        "--input_model_tgz")
     # Untar.
     if not os.path.exists(FLAGS.untar_model_dir):
       os.makedirs(FLAGS.untar_model_dir)
     tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir)
     FLAGS.train_dir = FLAGS.untar_model_dir

   flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json")
   if not file_io.file_exists(flags_dict_file):
     raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file)
   flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read())

   # convert feature_names and feature_sizes to lists of values
   feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
       flags_dict["feature_names"], flags_dict["feature_sizes"])

   if flags_dict["frame_features"]:
     reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
                                             feature_sizes=feature_sizes)
   else:
     reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
                                                  feature_sizes=feature_sizes)

   if not FLAGS.output_file:
     raise ValueError("'output_file' was not specified. "
                      "Unable to continue with inference.")

   if not FLAGS.input_data_pattern:
     raise ValueError("'input_data_pattern' was not specified. "
                      "Unable to continue with inference.")

   inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern,
             FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k)


 if __name__ == "__main__":
   app.run()
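
For context, a typical invocation after this change might look like the following (paths are placeholders; --top_k now defaults to 5 and is shown only for emphasis):

    python inference.py \
      --train_dir=/path/to/trained_model \
      --input_data_pattern=/path/to/test*.tfrecord \
      --output_file=segment_predictions.csv \
      --segment_labels \
      --top_k=5

With --segment_labels set, per-segment predictions go to a temporary file first, and the final output_file receives one "vid_id,seg_classes" row per video with label names and normalized scores.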