윤영빈

top-k label output

@@ -34,6 +34,7 @@ from tensorflow import logging
from tensorflow.python.lib.io import file_io
import utils
from collections import Counter
import operator
FLAGS = flags.FLAGS
@@ -81,7 +82,7 @@ if __name__ == "__main__":
      "the model graph and checkpoint will be bundled in this "
      "gzip tar. This file can be uploaded to Kaggle for the "
      "top 10 participants.")
  flags.DEFINE_integer("top_k", 1, "How many predictions to output per video.")
  flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")

  # Other flags.
  flags.DEFINE_integer("batch_size", 512,
@@ -260,6 +261,18 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
    out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))

    #=========================================
    # open the vocabulary csv file and store it in a dictionary
    #=========================================
    voca_dict = {}
    vocabs = open("./vocabulary.csv", 'r')
    while True:
      line = vocabs.readline()
      if not line: break
      vocab_dict_item = line.split(",")
      if vocab_dict_item[0] != "Index":  # skip the header row
        voca_dict[vocab_dict_item[0]] = vocab_dict_item[3]
    vocabs.close()

    try:
      while not coord.should_stop():
        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
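The vocabulary load in this hunk splits each line on raw commas, which holds up only as long as no quoted field in vocabulary.csv contains a comma. A minimal sketch of a more defensive load using Python's csv module, assuming the same column layout the patch relies on (column 0 carries the "Index" header, column 3 the label name):

import csv

voca_dict = {}
with open("./vocabulary.csv", "r") as vocabs:
    for row in csv.reader(vocabs):
        # csv.reader honors quoting, so commas inside quoted fields
        # do not shift the column indices the way line.split(",") would
        if row and row[0] != "Index":
            voca_dict[row[0]] = row[3]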
@@ -308,7 +321,9 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
      segment_id_list = []
      segment_classes = []
      cls_result_arr = []
      cls_score_dict = {}
      out_file.seek(0, 0)
      old_seg_name = '0000'
      for line in out_file:
        segment_id, preds = line.decode("utf8").split(",")
        if segment_id == "VideoId":
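For reference, each line read back from out_file is in the intermediate "VideoId,LabelConfidencePairs" format, with class ids and confidences alternating; the even/odd list comprehensions in the next hunk rely on that layout. A standalone illustration with made-up values (the ids and scores below are hypothetical, not from the patch):

# hypothetical intermediate line: video "abc123", segment starting at 45,
# followed by alternating (class id, confidence) pairs
line = "abc123:45,12 0.870 4 0.610 933 0.090"
segment_id, preds = line.split(",")
tokens = preds.split(" ")
pred_cls_ids = [int(tokens[idx]) for idx in range(0, len(tokens), 2)]       # [12, 4, 933]
pred_cls_scores = [float(tokens[idx]) for idx in range(1, len(tokens), 2)]  # [0.87, 0.61, 0.09]
video_id = segment_id.split(":")[0]                                         # "abc123"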
@@ -317,36 +332,48 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
        preds = preds.split(" ")
        pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
        # =======================================
        pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)]
        #=======================================
        segment_id = str(segment_id.split(":")[0])
        if segment_id not in segment_id_list:
          segment_id_list.append(str(segment_id))
          segment_classes.append("")
        index = segment_id_list.index(segment_id)
        for classes in pred_cls_ids:
          segment_classes[index] = str(segment_classes[index]) + str(
              classes) + " "  # append classes from new segment
      for segs, item in zip(segment_id_list, segment_classes):
        if old_seg_name != segment_id:
          cls_score_dict[segment_id] = {}
          old_seg_name = segment_id
        for classes in range(0, len(pred_cls_ids)):
          segment_classes[index] = str(segment_classes[index]) + str(pred_cls_ids[classes]) + " "  # append classes from new segment
          if pred_cls_ids[classes] in cls_score_dict[segment_id]:
            cls_score_dict[segment_id][pred_cls_ids[classes]] = cls_score_dict[segment_id][pred_cls_ids[classes]] + pred_cls_scores[classes]
          else:
            cls_score_dict[segment_id][pred_cls_ids[classes]] = pred_cls_scores[classes]
      for segs,item in zip(segment_id_list,segment_classes):
        print('====== R E C O R D ======')
        cls_arr = item.split(" ")[:-1]
        cls_arr = list(map(int, cls_arr))
        cls_arr = sorted(cls_arr)
        cls_arr = list(map(int,cls_arr))
        cls_arr = sorted(cls_arr)  # sort by class id
        result_string = ""
        temp = Counter(cls_arr)
        for item in temp:
          result_string = result_string + str(item) + ":" + str(temp[item]) + ","
        temp = cls_score_dict[segs]
        temp = sorted(temp.items(), key=operator.itemgetter(1), reverse=True)  # sort by accumulated score, descending
        denominator = float(temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1])
        for itemIndex in range(0, top_k):
          result_string = result_string + str(voca_dict[str(temp[itemIndex][0])]) + ":" + format(temp[itemIndex][1] / denominator, ".3f") + ","
        cls_result_arr.append(result_string[:-1])
        logging.info(segs + " : " + result_string[:-1])
      # =======================================
      #=======================================
      final_out_file.write("vid_id,seg_classes\n")
      for seg_id, class_indices in zip(segment_id_list, cls_result_arr):
        final_out_file.write("%s,%s\n" % (seg_id, str(class_indices)))
        final_out_file.write("%s,%s\n" %(seg_id, str(class_indices)))
      final_out_file.close()
      out_file.close()
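The denominator above sums exactly five scores, so it silently assumes top_k == 5 and at least five distinct classes per video. A sketch that derives the denominator from top_k itself (the helper name top_k_label_string is mine, not part of the patch):

import operator

def top_k_label_string(cls_scores, voca_dict, top_k):
    # Rank the accumulated (class id -> score) entries by score, descending,
    # then normalize each of the top_k survivors by their combined mass.
    ranked = sorted(cls_scores.items(), key=operator.itemgetter(1), reverse=True)[:top_k]
    denominator = float(sum(score for _, score in ranked))
    return ",".join("%s:%.3f" % (voca_dict[str(cls)], score / denominator)
                    for cls, score in ranked)

With the structures built above, the result line for one video could then be produced as top_k_label_string(cls_score_dict[segs], voca_dict, top_k).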
@@ -354,7 +381,6 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
    coord.join(threads)
    sess.close()

def main(unused_argv):
  logging.set_verbosity(tf.logging.INFO)
  if FLAGS.input_model_tgz:
......