윤영빈

almost done

...@@ -12,8 +12,10 @@ import video_util as videoutil ...@@ -12,8 +12,10 @@ import video_util as videoutil
12 MODEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/model/inference_model/segment_inference_model" 12 MODEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/model/inference_model/segment_inference_model"
13 VOCAB_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/vocabulary.csv" 13 VOCAB_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/vocabulary.csv"
14 VIDEO_TAGS_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/segment_tags.csv" 14 VIDEO_TAGS_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/segment_tags.csv"
15 +VIDEO_IDS_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/videoIds.csv"
15 TAG_VECTOR_MODEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/tag_vectors.model" 16 TAG_VECTOR_MODEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/tag_vectors.model"
16 VIDEO_VECTOR_MODEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/video_vectors.model" 17 VIDEO_VECTOR_MODEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/video_vectors.model"
18 +VIDEO_ID_MODEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/videoId_vectors.model"
17 SEGMENT_LABEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/segment_label_ids.csv" 19 SEGMENT_LABEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/segment_label_ids.csv"
18 20
19 # Define parameters. 21 # Define parameters.
...@@ -228,8 +230,8 @@ def inference_pb(file_path, threshold): ...@@ -228,8 +230,8 @@ def inference_pb(file_path, threshold):
228 230
229 # 5. Create recommend videos info, Combine results. 231 # 5. Create recommend videos info, Combine results.
230 recommend_video_ids = recommender.recommend_videos(tag_result, inputVideoTagResults, TAG_VECTOR_MODEL_PATH, 232 recommend_video_ids = recommender.recommend_videos(tag_result, inputVideoTagResults, TAG_VECTOR_MODEL_PATH,
231 - VIDEO_VECTOR_MODEL_PATH, VIDEO_TOP_K) 233 + VIDEO_VECTOR_MODEL_PATH, VIDEO_ID_MODEL_PATH, VIDEO_TOP_K)
232 - video_result = [videoutil.getVideoInfo(ids, VIDEO_TAGS_PATH, TAG_TOP_K) for ids in recommend_video_ids] 234 + video_result = [videoutil.getVideoInfo(ids, VIDEO_TAGS_PATH, TAG_TOP_K,False) for ids in recommend_video_ids]
233 235
234 inference_result = { 236 inference_result = {
235 "tag_result": tag_result, 237 "tag_result": tag_result,
......
This file is too large to display.
This diff could not be displayed because it is too large.
...@@ -6,7 +6,7 @@ import pandas as pd ...@@ -6,7 +6,7 @@ import pandas as pd
6 import math 6 import math
7 import activation as ac 7 import activation as ac
8 8
9 -def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k): 9 +def recommend_videos(tags, segments, tag_model_path, video_model_path, video_id_model, top_k):
10 # 이 함수에서 모든걸 다 함 10 # 이 함수에서 모든걸 다 함
11 # tags는 label val 로 묶인 문자열 리스트임 11 # tags는 label val 로 묶인 문자열 리스트임
12 # tags의 길이는 segment의 길이 12 # tags의 길이는 segment의 길이
...@@ -17,11 +17,12 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k): ...@@ -17,11 +17,12 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k):
17 #segments는 클래스 확률 클래스 확률... 일케 저장되어 있음 17 #segments는 클래스 확률 클래스 확률... 일케 저장되어 있음
18 tag_vectors = Word2Vec.load(tag_model_path).wv 18 tag_vectors = Word2Vec.load(tag_model_path).wv
19 video_vectors = Word2Vec().wv.load(video_model_path) 19 video_vectors = Word2Vec().wv.load(video_model_path)
20 + video_ids = Word2Vec().wv.load(video_id_model)
20 error_tags = [] 21 error_tags = []
21 maxSimilarSegment = 0 22 maxSimilarSegment = 0
22 maxSimilarity = -1 23 maxSimilarity = -1
23 - 24 + print('prev len',len(segments))
24 - kernel = [np.zeros(100) for i in range(0,5)] 25 + kernel = [np.zeros(100) for i in range(0,9)]
25 tagKernel = [] 26 tagKernel = []
26 #우선은 비교를 뜰 입력 영상의 단일 비디오벡터를 구함 27 #우선은 비교를 뜰 입력 영상의 단일 비디오벡터를 구함
27 video_vector = np.zeros(100) 28 video_vector = np.zeros(100)
...@@ -41,27 +42,39 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k): ...@@ -41,27 +42,39 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k):
41 error_tags.append(tag) 42 error_tags.append(tag)
42 43
43 #각 세그먼트마다 비교를 떠서 인덱스를 저장 44 #각 세그먼트마다 비교를 떠서 인덱스를 저장
44 - currentIndex = 0 45 + midpos = math.floor(len(kernel)/2)
45 - for segment in segments: 46 + for i in range(0,midpos):
46 - segment_vector = np.zeros(100) 47 + segments.insert(0,segments[0])
47 - segTags = [segment[i] for i in range(0,len(segment),2)] 48 + segments.append(segments[len(segments)-1])
48 - segProbs = [float(segment[i]) for i in range(1,len(segment),2)]#ac.softmax([float(segment[i]) for i in range(1,len(segment),2)])
49 49
50 - for tag, weight in zip(segTags,segProbs): 50 + currentIndex = midpos
51 - if tag in tag_vectors.vocab: 51 + for si in range(midpos,len(segments) - midpos - 1):
52 - segment_vector = segment_vector + (tag_vectors[tag] * float(weight)) 52 + similarity = 0
53 - else: 53 + for segi in range(-1,2):
54 - # Pass if tag is unknown 54 + segment = segments[si + segi]
55 - if tag not in error_tags: 55 + segment_vector = np.zeros(100)
56 - error_tags.append(tag) 56 + segTags = [segment[i] for i in range(0,len(segment),2)]
57 - 57 + segProbs = [float(segment[i]) for i in range(1,len(segment),2)]#ac.softmax([float(segment[i]) for i in range(1,len(segment),2)])
58 - #비디오 벡터와 세그먼트 벡터 비교 58 +
59 - similarity = cos_sim(video_vector, segment_vector) #cos_sim(video_vector, segment_vector)# 59 + for tag, weight in zip(segTags,segProbs):
60 + if tag in tag_vectors.vocab:
61 + segment_vector = segment_vector + (tag_vectors[tag] * float(weight))
62 + else:
63 + # Pass if tag is unknown
64 + if tag not in error_tags:
65 + error_tags.append(tag)
66 +
67 + #비디오 벡터와 세그먼트 벡터 비교
68 + #similarity = similarity + cos_sim(video_vector, segment_vector) #cos_sim(video_vector, segment_vector)#
60 69
61 - for currentSegmentTag, videoVectorTag in zip(segTags,videoTagList): 70 + for currentSegmentTag, videoVectorTag,videoVectorTagPred in zip(segTags,videoTagList,tag_preds):
62 - if(currentSegmentTag in tag_vectors.vocab) and (videoVectorTag in tag_vectors.vocab): 71 + if(currentSegmentTag in tag_vectors.vocab) and (videoVectorTag in tag_vectors.vocab):
63 - similarity = similarity + tag_vectors.similarity(currentSegmentTag,videoVectorTag) 72 + prob = float(videoVectorTagPred)
64 - 73 + if videoVectorTag not in segTags:
74 + prob = 0
75 + similarity = similarity + (tag_vectors.similarity(currentSegmentTag,videoVectorTag) * prob)
76 +
77 +
65 if similarity >= maxSimilarity: 78 if similarity >= maxSimilarity:
66 maxSimilarSegment = currentIndex 79 maxSimilarSegment = currentIndex
67 maxSimilarity = similarity 80 maxSimilarity = similarity
...@@ -71,7 +84,7 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k): ...@@ -71,7 +84,7 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k):
71 maxSimilarSegment = len(segments) - int(len(kernel)/2) - 1 84 maxSimilarSegment = len(segments) - int(len(kernel)/2) - 1
72 #세그먼트 인덱스 증가 85 #세그먼트 인덱스 증가
73 currentIndex = currentIndex + 1 86 currentIndex = currentIndex + 1
74 - 87 + print('maxSimilarSegment',maxSimilarSegment,'len',len(segments))
75 #커널 생성 88 #커널 생성
76 for k in range (0,len(kernel)): 89 for k in range (0,len(kernel)):
77 segment = segments[maxSimilarSegment - math.floor(len(kernel)/2) + k] 90 segment = segments[maxSimilarSegment - math.floor(len(kernel)/2) + k]
...@@ -139,11 +152,11 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k): ...@@ -139,11 +152,11 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k):
139 tagList.append([row[i].split(":")[0],row[i].split(":")[1]]) 152 tagList.append([row[i].split(":")[0],row[i].split(":")[1]])
140 segmentTagList.append(tagList) 153 segmentTagList.append(tagList)
141 154
142 - similar_ids = [] 155 + #similar_ids = []
143 - for i in range(0,top_k): 156 + #for i in range(0,top_k):
144 - similar_ids.append(minimunVideoIds[i][0]) 157 + # similar_ids.append(minimunVideoIds[i][0])
145 158
146 - #similar_ids = [x[0] for x in video_vectors.similar_by_vector(video_vector, top_k)] 159 + similar_ids = [x[0] for x in video_ids.similar_by_vector(video_vector, top_k)]
147 print('results =' ,similar_ids) 160 print('results =' ,similar_ids)
148 return similar_ids 161 return similar_ids
149 162
...@@ -199,12 +212,14 @@ def differenceMax(arrs, _kernel, w2v, videoTaglist): ...@@ -199,12 +212,14 @@ def differenceMax(arrs, _kernel, w2v, videoTaglist):
199 convResult = 0 212 convResult = 0
200 processed_vocabNum = 1 213 processed_vocabNum = 1
201 for i in range(0, s): 214 for i in range(0, s):
202 - #if i == midpos: 215 + if(_kernel[i][0] not in arrs[j - midpos + i][0]):# and ((videoTaglist[0] not in arrs[j - midpos + i][0:2])) and ((videoTaglist[1] not in arrs[j - midpos + i][0:5])):
203 - if(_kernel[i][0] not in arrs[j - midpos + i][0:2][0]):# and ((videoTaglist[0] not in arrs[j - midpos + i][0:2])) and ((videoTaglist[1] not in arrs[j - midpos + i][0:5])):
204 continue 216 continue
205 for ind in range(0,5): 217 for ind in range(0,5):
206 if(arrs[j - midpos + i][ind][0] in w2v.vocab) and (_kernel[i][ind] in w2v.vocab): 218 if(arrs[j - midpos + i][ind][0] in w2v.vocab) and (_kernel[i][ind] in w2v.vocab):
207 - convResult = convResult + (w2v.similarity(arrs[j - midpos + i][ind][0],_kernel[i][ind]) * float(arrs[j - midpos + i][ind][1])) 219 + prob = float(arrs[j - midpos + i][ind][1])
220 + if arrs[j - midpos + i][ind][0] not in videoTaglist:
221 + prob = 0
222 + convResult = convResult + (w2v.similarity(arrs[j - midpos + i][ind][0],_kernel[i][ind]) * prob)
208 processed_vocabNum = processed_vocabNum + 1 223 processed_vocabNum = processed_vocabNum + 1
209 224
210 if prevMax < convResult: 225 if prevMax < convResult:
......
...@@ -16,7 +16,7 @@ def getURL(vid_id): ...@@ -16,7 +16,7 @@ def getURL(vid_id):
16 16
17 print(getURL('nzwW')) 17 print(getURL('nzwW'))
18 18
19 -def getVideoInfo(vid_id, video_tags_path, top_k): 19 +def getVideoInfo(vid_id, video_tags_path, top_k, isPerVideo):
20 print("vid_id = ",vid_id) 20 print("vid_id = ",vid_id)
21 video_url = getURL(vid_id[0:4]) 21 video_url = getURL(vid_id[0:4])
22 22
...@@ -31,8 +31,9 @@ def getVideoInfo(vid_id, video_tags_path, top_k): ...@@ -31,8 +31,9 @@ def getVideoInfo(vid_id, video_tags_path, top_k):
31 if video_url == "": 31 if video_url == "":
32 for x in video_tags: 32 for x in video_tags:
33 video_url = video_url + ' ' + x 33 video_url = video_url + ' ' + x
34 - 34 +
35 - video_url = video_url + '\nThe similar point is : ' + str(float(vid_id[5:]) * 5) 35 + if isPerVideo == False:
36 + video_url = video_url + '\nThe similar point is : ' + str(float(vid_id[5:]) * 5)
36 37
37 return { 38 return {
38 "video_url": video_url, 39 "video_url": video_url,
......
...@@ -24,8 +24,8 @@ if __name__ == '__main__': ...@@ -24,8 +24,8 @@ if __name__ == '__main__':
24 video_vectors = Word2Vec().wv # Empty model 24 video_vectors = Word2Vec().wv # Empty model
25 25
26 # Load video recommendation tags. 26 # Load video recommendation tags.
27 - video_tags = pd.read_csv('./segment_tags.csv', encoding='utf-8',error_bad_lines=False) 27 + #video_tags = pd.read_csv('./segment_tags.csv', encoding='utf-8',error_bad_lines=False)
28 - 28 + video_tags = pd.read_csv('./videoIds.csv', encoding='utf-8',error_bad_lines=False)
29 # Define batch variables. 29 # Define batch variables.
30 batch_video_ids = [] 30 batch_video_ids = []
31 batch_video_vectors = [] 31 batch_video_vectors = []
...@@ -64,7 +64,7 @@ if __name__ == '__main__': ...@@ -64,7 +64,7 @@ if __name__ == '__main__':
64 print(error_tags) 64 print(error_tags)
65 print(len(error_tags)) 65 print(len(error_tags))
66 66
67 - video_vectors.save("video_vectors.model") 67 + video_vectors.save("videoId_vectors.model")
68 68
69 # Usage 69 # Usage
70 # video_vectors = Word2Vec().wv.load("video_vectors.model") 70 # video_vectors = Word2Vec().wv.load("video_vectors.model")
......
...@@ -395,6 +395,197 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -395,6 +395,197 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
395 coord.join(threads) 395 coord.join(threads)
396 sess.close() 396 sess.close()
397 397
398 +
399 +def inference2(reader, train_dir, data_pattern, out_file_location, batch_size,
400 + top_k):
401 + """Inference function."""
402 + with tf.Session(config=tf.ConfigProto(
403 + allow_soft_placement=True)) as sess, gfile.Open(out_file_location,
404 + "w+") as out_file:
405 + video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
406 + reader, data_pattern, batch_size)
407 + inference_model_name = "segment_inference_model" if FLAGS.segment_labels else "inference_model"
408 + checkpoint_file = os.path.join(train_dir, "inference_model",
409 + inference_model_name)
410 + if not gfile.Exists(checkpoint_file + ".meta"):
411 + raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
412 + meta_graph_location = checkpoint_file + ".meta"
413 + logging.info("loading meta-graph: " + meta_graph_location)
414 + if FLAGS.output_model_tgz:
415 + with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar:
416 + for model_file in glob.glob(checkpoint_file + ".*"):
417 + tar.add(model_file, arcname=os.path.basename(model_file))
418 + tar.add(os.path.join(train_dir, "model_flags.json"),
419 + arcname="model_flags.json")
420 + print("Tarred model onto " + FLAGS.output_model_tgz)
421 + with tf.device("/cpu:0"):
422 + saver = tf.train.import_meta_graph(meta_graph_location,
423 + clear_devices=True)
424 + logging.info("restoring variables from " + checkpoint_file)
425 + saver.restore(sess, checkpoint_file)
426 + input_tensor = tf.get_collection("input_batch_raw")[0]
427 + num_frames_tensor = tf.get_collection("num_frames")[0]
428 + predictions_tensor = tf.get_collection("predictions")[0]
429 + # Workaround for num_epochs issue.
430 + def set_up_init_ops(variables):
431 + init_op_list = []
432 + for variable in list(variables):
433 + if "train_input" in variable.name:
434 + init_op_list.append(tf.assign(variable, 1))
435 + variables.remove(variable)
436 + init_op_list.append(tf.variables_initializer(variables))
437 + return init_op_list
438 + sess.run(
439 + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))
440 + coord = tf.train.Coordinator()
441 + threads = tf.train.start_queue_runners(sess=sess, coord=coord)
442 + num_examples_processed = 0
443 + start_time = time.time()
444 + whitelisted_cls_mask = None
445 + if FLAGS.segment_labels:
446 + final_out_file = out_file
447 + out_file = tempfile.NamedTemporaryFile()
448 + logging.info(
449 + "Segment temp prediction output will be written to temp file: %s",
450 + out_file.name)
451 + if FLAGS.segment_label_ids_file:
452 + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
453 + dtype=np.float32)
454 + segment_label_ids_file = FLAGS.segment_label_ids_file
455 + if segment_label_ids_file.startswith("http"):
456 + logging.info("Retrieving segment ID whitelist files from %s...",
457 + segment_label_ids_file)
458 + segment_label_ids_file, _ = urllib.request.urlretrieve(
459 + segment_label_ids_file)
460 + with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
461 + for line in fobj:
462 + try:
463 + cls_id = int(line)
464 + whitelisted_cls_mask[cls_id] = 1.
465 + except ValueError:
466 + # Simply skip the non-integer line.
467 + continue
468 + out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))
469 +
470 + #=========================================
471 + #open vocab csv file and store to dictionary
472 + #=========================================
473 + voca_dict = {}
474 + vocabs = codecs.open('./vocabulary.csv', 'r','utf-8')
475 + while True:
476 + line = vocabs.readline()
477 + if not line: break
478 + vocab_dict_item = line.split(",")
479 + if vocab_dict_item[0] != "Index":
480 + voca_dict[vocab_dict_item[0]] = vocab_dict_item[3]
481 + vocabs.close()
482 + try:
483 + while not coord.should_stop():
484 + video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
485 + [video_id_batch, video_batch, num_frames_batch])
486 + if FLAGS.segment_labels:
487 + results = get_segments(video_batch_val, num_frames_batch_val, 5)
488 + video_segment_ids = results["video_segment_ids"]
489 + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]]
490 + video_id_batch_val = np.array([
491 + "%s:%d" % (x.decode("utf8"), y)
492 + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1])
493 + ])
494 + video_batch_val = results["video_batch"]
495 + num_frames_batch_val = results["num_frames_batch"]
496 + if input_tensor.get_shape()[1] != video_batch_val.shape[1]:
497 + raise ValueError("max_frames mismatch. Please re-run the eval.py "
498 + "with correct segment_labels settings.")
499 + predictions_val, = sess.run([predictions_tensor],
500 + feed_dict={
501 + input_tensor: video_batch_val,
502 + num_frames_tensor: num_frames_batch_val
503 + })
504 + now = time.time()
505 + num_examples_processed += len(video_batch_val)
506 + elapsed_time = now - start_time
507 + logging.info("num examples processed: " + str(num_examples_processed) +
508 + " elapsed seconds: " + "{0:.2f}".format(elapsed_time) +
509 + " examples/sec: %.2f" %
510 + (num_examples_processed / elapsed_time))
511 + for line in format_lines(video_id_batch_val, predictions_val, top_k,
512 + whitelisted_cls_mask):
513 + out_file.write(line)
514 + out_file.flush()
515 + except tf.errors.OutOfRangeError:
516 + logging.info("Done with inference. The output file was written to " +
517 + out_file.name)
518 + finally:
519 + coord.request_stop()
520 + if FLAGS.segment_labels:
521 + # Re-read the file and do heap sort.
522 + # Create multiple heaps.
523 + logging.info("Post-processing segment predictions...")
524 + segment_id_list = []
525 + segment_classes = []
526 + cls_result_arr = []
527 + cls_score_dict = {}
528 + out_file.seek(0, 0)
529 + old_seg_name = '0000'
530 + counter = 0
531 + for line in out_file:
532 + counter += 1
533 + if counter / 5000 == 0:
534 + print(counter, " processed")
535 + segment_id, preds = line.decode("utf8").split(",")
536 + if segment_id == "VideoId":
537 + # Skip the headline.
538 + continue
539 + preds = preds.split(" ")
540 + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
541 + pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)]
542 + #=======================================
543 + segment_id = str(segment_id.split(":")[0])
544 + if segment_id not in segment_id_list:
545 + segment_id_list.append(str(segment_id))
546 + segment_classes.append("")
547 + index = segment_id_list.index(segment_id)
548 +
549 + if old_seg_name != segment_id:
550 + cls_score_dict[segment_id] = {}
551 + old_seg_name = segment_id
552 +
553 + for classes in range(0,len(pred_cls_ids)):#pred_cls_ids:
554 + segment_classes[index] = str(segment_classes[index]) + str(pred_cls_ids[classes]) + " " #append classes from new segment
555 + if pred_cls_ids[classes] in cls_score_dict[segment_id]:
556 + cls_score_dict[segment_id][pred_cls_ids[classes]] = cls_score_dict[segment_id][pred_cls_ids[classes]] + pred_cls_scores[classes]
557 + else:
558 + cls_score_dict[segment_id][pred_cls_ids[classes]] = pred_cls_scores[classes]
559 + for segs,item in zip(segment_id_list,segment_classes):
560 + print('====== R E C O R D ======')
561 + cls_arr = item.split(" ")[:-1]
562 +
563 + cls_arr = list(map(int,cls_arr))
564 + cls_arr = sorted(cls_arr) #클래스별로 정렬
565 +
566 + result_string = ""
567 +
568 + temp = cls_score_dict[segs]
569 + temp= sorted(temp.items(), key=operator.itemgetter(1), reverse=True) #밸류값 기준으로 정렬
570 + demoninator = float(temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1])
571 + #for item in temp:
572 + for itemIndex in range(0, top_k):
573 + # Normalize tag name
574 + segment_tag = str(voca_dict[str(temp[itemIndex][0])])
575 + normalized_tag = normalize_tag(segment_tag)
576 + result_string = result_string + normalized_tag + ":" + format(temp[itemIndex][1]/demoninator,".3f") + ","
577 +
578 + cls_result_arr.append(result_string[:-1])
579 + logging.info(segs + " : " + result_string[:-1])
580 + #=======================================
581 + final_out_file.write("vid_id,segment1,segment2,segment3,segment4,segment5\n")
582 + for seg_id, class_indcies in zip(segment_id_list, cls_result_arr):
583 + final_out_file.write("%s,%s\n" %(seg_id, str(class_indcies)))
584 + final_out_file.close()
585 + out_file.close()
586 + coord.join(threads)
587 + sess.close()
588 +
398 def main(unused_argv): 589 def main(unused_argv):
399 logging.set_verbosity(tf.logging.INFO) 590 logging.set_verbosity(tf.logging.INFO)
400 if FLAGS.input_model_tgz: 591 if FLAGS.input_model_tgz:
...@@ -431,7 +622,7 @@ def main(unused_argv): ...@@ -431,7 +622,7 @@ def main(unused_argv):
431 raise ValueError("'input_data_pattern' was not specified. " 622 raise ValueError("'input_data_pattern' was not specified. "
432 "Unable to continue with inference.") 623 "Unable to continue with inference.")
433 624
434 - inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern, 625 + inference2(reader, FLAGS.train_dir, FLAGS.input_data_pattern,
435 FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k) 626 FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k)
436 627
437 def normalize(arrs): 628 def normalize(arrs):
...@@ -443,3 +634,4 @@ def normalize(arrs): ...@@ -443,3 +634,4 @@ def normalize(arrs):
443 634
444 if __name__ == "__main__": 635 if __name__ == "__main__":
445 app.run() 636 app.run()
637 +
......