Showing 1 changed file with 40 additions and 14 deletions
... | @@ -34,6 +34,7 @@ from tensorflow import logging | ... | @@ -34,6 +34,7 @@ from tensorflow import logging |
34 | from tensorflow.python.lib.io import file_io | 34 | from tensorflow.python.lib.io import file_io |
35 | import utils | 35 | import utils |
36 | from collections import Counter | 36 | from collections import Counter |
37 | +import operator | ||
37 | 38 | ||
38 | FLAGS = flags.FLAGS | 39 | FLAGS = flags.FLAGS |
39 | 40 | ||
... | @@ -81,7 +82,7 @@ if __name__ == "__main__": | ... | @@ -81,7 +82,7 @@ if __name__ == "__main__": |
81 | "the model graph and checkpoint will be bundled in this " | 82 | "the model graph and checkpoint will be bundled in this " |
82 | "gzip tar. This file can be uploaded to Kaggle for the " | 83 | "gzip tar. This file can be uploaded to Kaggle for the " |
83 | "top 10 participants.") | 84 | "top 10 participants.") |
84 | - flags.DEFINE_integer("top_k", 1, "How many predictions to output per video.") | 85 | + flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.") |
85 | 86 | ||
86 | # Other flags. | 87 | # Other flags. |
87 | flags.DEFINE_integer("batch_size", 512, | 88 | flags.DEFINE_integer("batch_size", 512, |
... | @@ -260,6 +261,18 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -260,6 +261,18 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
260 | 261 | ||
261 | out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) | 262 | out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) |
262 | 263 | ||
264 | + #========================================= | ||
265 | + # open the vocab csv file and store it in a dictionary | ||
266 | + #========================================= | ||
267 | + voca_dict = {} | ||
268 | + vocabs = open("./vocabulary.csv", 'r') | ||
269 | + while True: | ||
270 | + line = vocabs.readline() | ||
271 | + if not line: break | ||
272 | + vocab_dict_item = line.split(",") | ||
273 | + if vocab_dict_item[0] != "Index": | ||
274 | + voca_dict[vocab_dict_item[0]] = vocab_dict_item[3] | ||
275 | + vocabs.close() | ||
263 | try: | 276 | try: |
264 | while not coord.should_stop(): | 277 | while not coord.should_stop(): |
265 | video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( | 278 | video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( |
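The vocabulary loader added in this hunk walks the file with readline() and a bare split(","), which silently breaks if a label name ever contains a quoted comma, and it skips the header by comparing column 0 against "Index". A minimal sketch of the same Index-to-Name mapping built with the standard csv module (assuming, as the diff does, that the class index is in column 0 and the label name in column 3 of vocabulary.csv):

    import csv

    def load_vocab(path="./vocabulary.csv"):
        # Map class index (column 0) to label name (column 3).
        # csv.reader copes with quoted fields that a bare
        # split(",") would cut in half.
        voca_dict = {}
        with open(path, "r") as f:
            reader = csv.reader(f)
            next(reader)  # drop the "Index,..." header row
            for row in reader:
                voca_dict[row[0]] = row[3]
        return voca_dict

The with-block also guarantees the file is closed even if a row is malformed, which the explicit vocabs.close() in the diff does not.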
... | @@ -308,7 +321,9 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -308,7 +321,9 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
308 | segment_id_list = [] | 321 | segment_id_list = [] |
309 | segment_classes = [] | 322 | segment_classes = [] |
310 | cls_result_arr = [] | 323 | cls_result_arr = [] |
324 | + cls_score_dict = {} | ||
311 | out_file.seek(0, 0) | 325 | out_file.seek(0, 0) |
326 | + old_seg_name = '0000' | ||
312 | for line in out_file: | 327 | for line in out_file: |
313 | segment_id, preds = line.decode("utf8").split(",") | 328 | segment_id, preds = line.decode("utf8").split(",") |
314 | if segment_id == "VideoId": | 329 | if segment_id == "VideoId": |
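The cls_score_dict and old_seg_name additions in this hunk set up per-video score accumulation: old_seg_name is a sentinel used in the next hunk to start a fresh inner dict whenever the video id changes. That only works because all segments of a video happen to be contiguous in the intermediate file; if they ever interleaved, the reset would wipe scores already accumulated. A small self-contained sketch of an order-independent alternative using collections.defaultdict (toy ids, not from the diff):

    from collections import defaultdict

    # video_id -> {class_id: summed score}; the inner dict is created
    # on first access, so no sentinel or ordering assumption is needed.
    cls_score_dict = defaultdict(lambda: defaultdict(float))

    # Toy rows in the "video:segment" id format the diff parses;
    # vidA's segments are deliberately non-contiguous.
    rows = [("vidA:0", 7, 0.9), ("vidB:0", 3, 0.8), ("vidA:1", 7, 0.6)]
    for seg_id, cls, score in rows:
        video_id = seg_id.split(":")[0]
        cls_score_dict[video_id][cls] += score

    print(dict(cls_score_dict["vidA"]))  # {7: 1.5}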
... | @@ -317,36 +332,48 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -317,36 +332,48 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
317 | 332 | ||
318 | preds = preds.split(" ") | 333 | preds = preds.split(" ") |
319 | pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] | 334 | pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] |
320 | - # ======================================= | 335 | + pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)] |
336 | + #======================================= | ||
321 | segment_id = str(segment_id.split(":")[0]) | 337 | segment_id = str(segment_id.split(":")[0]) |
322 | if segment_id not in segment_id_list: | 338 | if segment_id not in segment_id_list: |
323 | segment_id_list.append(str(segment_id)) | 339 | segment_id_list.append(str(segment_id)) |
324 | segment_classes.append("") | 340 | segment_classes.append("") |
325 | 341 | ||
326 | index = segment_id_list.index(segment_id) | 342 | index = segment_id_list.index(segment_id) |
327 | - for classes in pred_cls_ids: | ||
328 | - segment_classes[index] = str(segment_classes[index]) + str( | ||
329 | - classes) + " " # append classes from new segment | ||
330 | 343 | ||
331 | - for segs, item in zip(segment_id_list, segment_classes): | 344 | + if old_seg_name != segment_id: |
345 | + cls_score_dict[segment_id] = {} | ||
346 | + old_seg_name = segment_id | ||
347 | + | ||
348 | + for classes in range(0, len(pred_cls_ids)): | ||
349 | + segment_classes[index] = str(segment_classes[index]) + str(pred_cls_ids[classes]) + " " #append classes from new segment | ||
350 | + if pred_cls_ids[classes] in cls_score_dict[segment_id]: | ||
351 | + cls_score_dict[segment_id][pred_cls_ids[classes]] = cls_score_dict[segment_id][pred_cls_ids[classes]] + pred_cls_scores[classes] | ||
352 | + else: | ||
353 | + cls_score_dict[segment_id][pred_cls_ids[classes]] = pred_cls_scores[classes] | ||
354 | + | ||
355 | + for segs,item in zip(segment_id_list,segment_classes): | ||
332 | print('====== R E C O R D ======') | 356 | print('====== R E C O R D ======') |
333 | cls_arr = item.split(" ")[:-1] | 357 | cls_arr = item.split(" ")[:-1] |
334 | 358 | ||
335 | - cls_arr = list(map(int, cls_arr)) | 359 | + cls_arr = list(map(int,cls_arr)) |
336 | - cls_arr = sorted(cls_arr) | 360 | + cls_arr = sorted(cls_arr)  # sort by class id |
337 | 361 | ||
338 | result_string = "" | 362 | result_string = "" |
339 | 363 | ||
340 | - temp = Counter(cls_arr) | 364 | + temp = cls_score_dict[segs] |
341 | - for item in temp: | 365 | + temp = sorted(temp.items(), key=operator.itemgetter(1), reverse=True)  # sort by value, descending |
342 | - result_string = result_string + str(item) + ":" + str(temp[item]) + "," | 366 | + denominator = float(temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1]) |
367 | + #for item in temp: | ||
368 | + for itemIndex in range(0, top_k): | ||
369 | + result_string = result_string + str(voca_dict[str(temp[itemIndex][0])]) + ":" + format(temp[itemIndex][1]/denominator, ".3f") + "," | ||
343 | 370 | ||
344 | cls_result_arr.append(result_string[:-1]) | 371 | cls_result_arr.append(result_string[:-1]) |
345 | logging.info(segs + " : " + result_string[:-1]) | 372 | logging.info(segs + " : " + result_string[:-1]) |
346 | - # ======================================= | 373 | + #======================================= |
347 | final_out_file.write("vid_id,seg_classes\n") | 374 | final_out_file.write("vid_id,seg_classes\n") |
348 | for seg_id, class_indcies in zip(segment_id_list, cls_result_arr): | 375 | for seg_id, class_indcies in zip(segment_id_list, cls_result_arr): |
349 | - final_out_file.write("%s,%s\n" % (seg_id, str(class_indcies))) | 376 | + final_out_file.write("%s,%s\n" %(seg_id, str(class_indcies))) |
350 | final_out_file.close() | 377 | final_out_file.close() |
351 | 378 | ||
352 | out_file.close() | 379 | out_file.close() |
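Taken together, the main hunk above sums each class's confidence across all segments of a video, sorts the totals with operator.itemgetter, normalizes by the sum of the top five scores, and writes top_k "Name:confidence" pairs resolved through voca_dict. Two caveats worth flagging: the hard-coded temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1] sum raises IndexError for any video with fewer than five distinct predicted classes, and it stays fixed at five even when --top_k changes. A hedged sketch of the same aggregation step without those assumptions (the helper name is illustrative, not from the diff):

    def top_k_pairs(cls_score_dict, voca_dict, top_k=5):
        # For each video id, build "Name:confidence" pairs where the
        # confidence is the accumulated class score normalized by the
        # sum of the top_k scores (or fewer, if the video predicted
        # fewer than top_k distinct classes).
        results = {}
        for video_id, scores in cls_score_dict.items():
            ranked = sorted(scores.items(), key=lambda kv: kv[1],
                            reverse=True)[:top_k]
            denominator = sum(s for _, s in ranked) or 1.0
            results[video_id] = ",".join(
                "%s:%.3f" % (voca_dict[str(cls)], s / denominator)
                for cls, s in ranked)
        return results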
... | @@ -354,7 +381,6 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -354,7 +381,6 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
354 | coord.join(threads) | 381 | coord.join(threads) |
355 | sess.close() | 382 | sess.close() |
356 | 383 | ||
357 | - | ||
358 | def main(unused_argv): | 384 | def main(unused_argv): |
359 | logging.set_verbosity(tf.logging.INFO) | 385 | logging.set_verbosity(tf.logging.INFO) |
360 | if FLAGS.input_model_tgz: | 386 | if FLAGS.input_model_tgz: | ... | ... |