윤영빈

top-k label ouput

...@@ -34,6 +34,7 @@ from tensorflow import logging ...@@ -34,6 +34,7 @@ from tensorflow import logging
34 from tensorflow.python.lib.io import file_io 34 from tensorflow.python.lib.io import file_io
35 import utils 35 import utils
36 from collections import Counter 36 from collections import Counter
37 +import operator
37 38
38 FLAGS = flags.FLAGS 39 FLAGS = flags.FLAGS
39 40
...@@ -81,7 +82,7 @@ if __name__ == "__main__": ...@@ -81,7 +82,7 @@ if __name__ == "__main__":
81 "the model graph and checkpoint will be bundled in this " 82 "the model graph and checkpoint will be bundled in this "
82 "gzip tar. This file can be uploaded to Kaggle for the " 83 "gzip tar. This file can be uploaded to Kaggle for the "
83 "top 10 participants.") 84 "top 10 participants.")
84 - flags.DEFINE_integer("top_k", 1, "How many predictions to output per video.") 85 + flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")
85 86
86 # Other flags. 87 # Other flags.
87 flags.DEFINE_integer("batch_size", 512, 88 flags.DEFINE_integer("batch_size", 512,
...@@ -260,6 +261,18 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -260,6 +261,18 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
260 261
261 out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) 262 out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))
262 263
264 + #=========================================
265 + #open vocab csv file and store to dictionary
266 + #=========================================
267 + voca_dict = {}
268 + vocabs = open("./vocabulary.csv", 'r')
269 + while True:
270 + line = vocabs.readline()
271 + if not line: break
272 + vocab_dict_item = line.split(",")
273 + if vocab_dict_item[0] != "Index":
274 + voca_dict[vocab_dict_item[0]] = vocab_dict_item[3]
275 + vocabs.close()
263 try: 276 try:
264 while not coord.should_stop(): 277 while not coord.should_stop():
265 video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( 278 video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
...@@ -308,7 +321,9 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -308,7 +321,9 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
308 segment_id_list = [] 321 segment_id_list = []
309 segment_classes = [] 322 segment_classes = []
310 cls_result_arr = [] 323 cls_result_arr = []
324 + cls_score_dict = {}
311 out_file.seek(0, 0) 325 out_file.seek(0, 0)
326 + old_seg_name = '0000'
312 for line in out_file: 327 for line in out_file:
313 segment_id, preds = line.decode("utf8").split(",") 328 segment_id, preds = line.decode("utf8").split(",")
314 if segment_id == "VideoId": 329 if segment_id == "VideoId":
...@@ -317,36 +332,48 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -317,36 +332,48 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
317 332
318 preds = preds.split(" ") 333 preds = preds.split(" ")
319 pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] 334 pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
320 - # ======================================= 335 + pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)]
336 + #=======================================
321 segment_id = str(segment_id.split(":")[0]) 337 segment_id = str(segment_id.split(":")[0])
322 if segment_id not in segment_id_list: 338 if segment_id not in segment_id_list:
323 segment_id_list.append(str(segment_id)) 339 segment_id_list.append(str(segment_id))
324 segment_classes.append("") 340 segment_classes.append("")
325 341
326 index = segment_id_list.index(segment_id) 342 index = segment_id_list.index(segment_id)
327 - for classes in pred_cls_ids:
328 - segment_classes[index] = str(segment_classes[index]) + str(
329 - classes) + " " # append classes from new segment
330 343
331 - for segs, item in zip(segment_id_list, segment_classes): 344 + if old_seg_name != segment_id:
345 + cls_score_dict[segment_id] = {}
346 + old_seg_name = segment_id
347 +
348 + for classes in range(0,len(pred_cls_ids)):#pred_cls_ids:
349 + segment_classes[index] = str(segment_classes[index]) + str(pred_cls_ids[classes]) + " " #append classes from new segment
350 + if pred_cls_ids[classes] in cls_score_dict[segment_id]:
351 + cls_score_dict[segment_id][pred_cls_ids[classes]] = cls_score_dict[segment_id][pred_cls_ids[classes]] + pred_cls_scores[classes]
352 + else:
353 + cls_score_dict[segment_id][pred_cls_ids[classes]] = pred_cls_scores[classes]
354 +
355 + for segs,item in zip(segment_id_list,segment_classes):
332 print('====== R E C O R D ======') 356 print('====== R E C O R D ======')
333 cls_arr = item.split(" ")[:-1] 357 cls_arr = item.split(" ")[:-1]
334 358
335 - cls_arr = list(map(int, cls_arr)) 359 + cls_arr = list(map(int,cls_arr))
336 - cls_arr = sorted(cls_arr) 360 + cls_arr = sorted(cls_arr) #클래스별로 정렬
337 361
338 result_string = "" 362 result_string = ""
339 363
340 - temp = Counter(cls_arr) 364 + temp = cls_score_dict[segs]
341 - for item in temp: 365 + temp= sorted(temp.items(), key=operator.itemgetter(1), reverse=True) #밸류값 기준으로 정렬
342 - result_string = result_string + str(item) + ":" + str(temp[item]) + "," 366 + demoninator = float(temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1])
367 + #for item in temp:
368 + for itemIndex in range(0, top_k):
369 + result_string = result_string + str(voca_dict[str(temp[itemIndex][0])]) + ":" + format(temp[itemIndex][1]/demoninator,".3f") + ","
343 370
344 cls_result_arr.append(result_string[:-1]) 371 cls_result_arr.append(result_string[:-1])
345 logging.info(segs + " : " + result_string[:-1]) 372 logging.info(segs + " : " + result_string[:-1])
346 - # ======================================= 373 + #=======================================
347 final_out_file.write("vid_id,seg_classes\n") 374 final_out_file.write("vid_id,seg_classes\n")
348 for seg_id, class_indcies in zip(segment_id_list, cls_result_arr): 375 for seg_id, class_indcies in zip(segment_id_list, cls_result_arr):
349 - final_out_file.write("%s,%s\n" % (seg_id, str(class_indcies))) 376 + final_out_file.write("%s,%s\n" %(seg_id, str(class_indcies)))
350 final_out_file.close() 377 final_out_file.close()
351 378
352 out_file.close() 379 out_file.close()
...@@ -354,7 +381,6 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -354,7 +381,6 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
354 coord.join(threads) 381 coord.join(threads)
355 sess.close() 382 sess.close()
356 383
357 -
358 def main(unused_argv): 384 def main(unused_argv):
359 logging.set_verbosity(tf.logging.INFO) 385 logging.set_verbosity(tf.logging.INFO)
360 if FLAGS.input_model_tgz: 386 if FLAGS.input_model_tgz:
......