Showing 8 changed files with 249 additions and 39 deletions
@@ -12,8 +12,10 @@ import video_util as videoutil
 MODEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/model/inference_model/segment_inference_model"
 VOCAB_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/vocabulary.csv"
 VIDEO_TAGS_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/segment_tags.csv"
+VIDEO_IDS_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/videoIds.csv"
 TAG_VECTOR_MODEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/tag_vectors.model"
 VIDEO_VECTOR_MODEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/video_vectors.model"
+VIDEO_ID_MODEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/esot3ria/videoId_vectors.model"
 SEGMENT_LABEL_PATH = "/mnt/e/khuhub/2015104192/web/backend/yt8m/segment_label_ids.csv"
 
 # Define parameters.
@@ -228,8 +230,8 @@ def inference_pb(file_path, threshold):
 
     # 5. Create recommended videos info, combine results.
     recommend_video_ids = recommender.recommend_videos(tag_result, inputVideoTagResults, TAG_VECTOR_MODEL_PATH,
-                                                       VIDEO_VECTOR_MODEL_PATH, VIDEO_TOP_K)
-    video_result = [videoutil.getVideoInfo(ids, VIDEO_TAGS_PATH, TAG_TOP_K) for ids in recommend_video_ids]
+                                                       VIDEO_VECTOR_MODEL_PATH, VIDEO_ID_MODEL_PATH, VIDEO_TOP_K)
+    video_result = [videoutil.getVideoInfo(ids, VIDEO_TAGS_PATH, TAG_TOP_K, False) for ids in recommend_video_ids]
 
     inference_result = {
         "tag_result": tag_result,
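
Note: the new VIDEO_ID_MODEL_PATH is threaded into recommend_videos and loaded as a gensim model. A minimal sketch of the three model loads this implies, assuming gensim 3.x (illustrative only; it mirrors the load calls in recommender.py below):

    from gensim.models import Word2Vec, KeyedVectors

    tag_vectors = Word2Vec.load(TAG_VECTOR_MODEL_PATH).wv       # full Word2Vec save
    video_vectors = KeyedVectors.load(VIDEO_VECTOR_MODEL_PATH)  # saved KeyedVectors
    video_ids = KeyedVectors.load(VIDEO_ID_MODEL_PATH)          # new id-level KeyedVectors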
(two model files changed: file too large to display / no preview for this file type)

web/backend/yt8m/esot3ria/videoIds.csv (new file, mode 100644)
(diff could not be displayed because it is too large)
@@ -6,7 +6,7 @@ import pandas as pd
 import math
 import activation as ac
 
-def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k):
+def recommend_videos(tags, segments, tag_model_path, video_model_path, video_id_model, top_k):
     # This function does all of the work.
     # tags is a list of strings, each a "label value" pair.
     # len(tags) equals the number of segments.
@@ -17,11 +17,12 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k):
     # segments are stored as class, prob, class, prob, ...
     tag_vectors = Word2Vec.load(tag_model_path).wv
     video_vectors = Word2Vec().wv.load(video_model_path)
+    video_ids = Word2Vec().wv.load(video_id_model)
     error_tags = []
     maxSimilarSegment = 0
     maxSimilarity = -1
-
-    kernel = [np.zeros(100) for i in range(0,5)]
+    print('prev len', len(segments))
+    kernel = [np.zeros(100) for i in range(0,9)]
     tagKernel = []
     # First, build a single video vector for the input video to compare against.
     video_vector = np.zeros(100)
@@ -41,27 +42,39 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k):
                 error_tags.append(tag)
 
     # Compare every segment and record the index of the best match.
-    currentIndex = 0
-    for segment in segments:
-        segment_vector = np.zeros(100)
-        segTags = [segment[i] for i in range(0,len(segment),2)]
-        segProbs = [float(segment[i]) for i in range(1,len(segment),2)]  #ac.softmax([float(segment[i]) for i in range(1,len(segment),2)])
-
-        for tag, weight in zip(segTags,segProbs):
-            if tag in tag_vectors.vocab:
-                segment_vector = segment_vector + (tag_vectors[tag] * float(weight))
-            else:
-                # Pass if tag is unknown
-                if tag not in error_tags:
-                    error_tags.append(tag)
-
-        # Compare the video vector against the segment vector.
-        similarity = cos_sim(video_vector, segment_vector)
+    midpos = math.floor(len(kernel)/2)
+    # Pad both ends by repeating the first/last segment so the window fits everywhere.
+    for i in range(0,midpos):
+        segments.insert(0,segments[0])
+        segments.append(segments[len(segments)-1])
 
+    currentIndex = midpos
+    for si in range(midpos,len(segments) - midpos - 1):
+        similarity = 0
+        for segi in range(-1,2):
+            segment = segments[si + segi]
+            segment_vector = np.zeros(100)
+            segTags = [segment[i] for i in range(0,len(segment),2)]
+            segProbs = [float(segment[i]) for i in range(1,len(segment),2)]  #ac.softmax([float(segment[i]) for i in range(1,len(segment),2)])
+
+            for tag, weight in zip(segTags,segProbs):
+                if tag in tag_vectors.vocab:
+                    segment_vector = segment_vector + (tag_vectors[tag] * float(weight))
+                else:
+                    # Pass if tag is unknown
+                    if tag not in error_tags:
+                        error_tags.append(tag)
+
+            # Compare the video vector against the segment vector.
+            #similarity = similarity + cos_sim(video_vector, segment_vector)
 
-        for currentSegmentTag, videoVectorTag in zip(segTags,videoTagList):
-            if(currentSegmentTag in tag_vectors.vocab) and (videoVectorTag in tag_vectors.vocab):
-                similarity = similarity + tag_vectors.similarity(currentSegmentTag,videoVectorTag)
-
+            for currentSegmentTag, videoVectorTag, videoVectorTagPred in zip(segTags,videoTagList,tag_preds):
+                if(currentSegmentTag in tag_vectors.vocab) and (videoVectorTag in tag_vectors.vocab):
+                    prob = float(videoVectorTagPred)
+                    if videoVectorTag not in segTags:
+                        prob = 0
+                    similarity = similarity + (tag_vectors.similarity(currentSegmentTag,videoVectorTag) * prob)
+
         if similarity >= maxSimilarity:
             maxSimilarSegment = currentIndex
             maxSimilarity = similarity
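
Note: the rewritten loop pads the segment list on both ends, then scores each position over a three-segment window, so edge segments see the same window width as interior ones. A self-contained sketch of that pad-then-scan pattern, with a stand-in scoring function (the real score combines tag similarities and probabilities as above):

    import math

    def best_window_index(segments, kernel_size, score_fn):
        """Index (in original coordinates) of the segment whose neighborhood scores highest."""
        mid = math.floor(kernel_size / 2)
        # Repeat the first/last segment so a centered window never runs off the ends.
        padded = [segments[0]] * mid + list(segments) + [segments[-1]] * mid
        best_idx, best_score = mid, float("-inf")
        for si in range(mid, len(padded) - mid):
            score = sum(score_fn(padded[si + d]) for d in (-1, 0, 1))  # 3-wide window
            if score >= best_score:
                best_idx, best_score = si, score
        return best_idx - mid  # map back to unpadded coordinates

    # Toy example: pick the densest 3-neighborhood.
    print(best_window_index([0.1, 0.2, 0.9, 0.8, 0.1], 9, lambda s: s))  # -> 2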
@@ -71,7 +84,7 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k):
             maxSimilarSegment = len(segments) - int(len(kernel)/2) - 1
         # Advance the segment index.
         currentIndex = currentIndex + 1
-
+    print('maxSimilarSegment', maxSimilarSegment, 'len', len(segments))
     # Build the kernel.
     for k in range(0,len(kernel)):
         segment = segments[maxSimilarSegment - math.floor(len(kernel)/2) + k]
@@ -139,11 +152,11 @@ def recommend_videos(tags, segments, tag_model_path, video_model_path, top_k):
             tagList.append([row[i].split(":")[0], row[i].split(":")[1]])
         segmentTagList.append(tagList)
 
-    similar_ids = []
-    for i in range(0,top_k):
-        similar_ids.append(minimunVideoIds[i][0])
+    #similar_ids = []
+    #for i in range(0,top_k):
+    #    similar_ids.append(minimunVideoIds[i][0])
 
-    #similar_ids = [x[0] for x in video_vectors.similar_by_vector(video_vector, top_k)]
+    similar_ids = [x[0] for x in video_ids.similar_by_vector(video_vector, top_k)]
     print('results =', similar_ids)
     return similar_ids
 
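Note: the hand-ranked minimunVideoIds lookup is now commented out in favor of gensim's nearest-neighbor search over the new videoId vectors. In gensim 3.x, similar_by_vector(vector, topn) returns the topn entries whose stored vectors are most cosine-similar to the query; a sketch with an illustrative query:

    import numpy as np
    from gensim.models import KeyedVectors

    video_ids = KeyedVectors.load("videoId_vectors.model")
    query = np.random.rand(100).astype(np.float32)     # stand-in for the computed video_vector
    neighbors = video_ids.similar_by_vector(query, 5)  # [(video_id, cosine), ...]
    similar_ids = [vid for vid, _ in neighbors]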
@@ -199,12 +212,14 @@ def differenceMax(arrs, _kernel, w2v, videoTaglist):
         convResult = 0
         processed_vocabNum = 1
         for i in range(0, s):
-            #if i == midpos:
-            if(_kernel[i][0] not in arrs[j - midpos + i][0:2][0]):# and ((videoTaglist[0] not in arrs[j - midpos + i][0:2])) and ((videoTaglist[1] not in arrs[j - midpos + i][0:5])):
+            if(_kernel[i][0] not in arrs[j - midpos + i][0]):# and ((videoTaglist[0] not in arrs[j - midpos + i][0:2])) and ((videoTaglist[1] not in arrs[j - midpos + i][0:5])):
                 continue
             for ind in range(0,5):
                 if(arrs[j - midpos + i][ind][0] in w2v.vocab) and (_kernel[i][ind] in w2v.vocab):
-                    convResult = convResult + (w2v.similarity(arrs[j - midpos + i][ind][0],_kernel[i][ind]) * float(arrs[j - midpos + i][ind][1]))
+                    prob = float(arrs[j - midpos + i][ind][1])
+                    if arrs[j - midpos + i][ind][0] not in videoTaglist:
+                        prob = 0
+                    convResult = convResult + (w2v.similarity(arrs[j - midpos + i][ind][0],_kernel[i][ind]) * prob)
                     processed_vocabNum = processed_vocabNum + 1
 
         if prevMax < convResult:
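
Note: differenceMax gets the same gating rule as the segment loop above: a tag's similarity contribution is weighted by its probability only when that tag also appears in the query video's tag list; otherwise the weight drops to zero. The rule in isolation, with a placeholder similarity function:

    def gated_contribution(candidate_tag, candidate_prob, kernel_tag,
                           video_tag_list, similarity_fn):
        # Zero the probability weight for tags absent from the query video's tags.
        prob = float(candidate_prob) if candidate_tag in video_tag_list else 0.0
        return similarity_fn(candidate_tag, kernel_tag) * prob

    sim = lambda a, b: 1.0 if a == b else 0.5  # toy similarity
    print(gated_contribution("Cooking", 0.8, "Food", ["Cooking", "Recipe"], sim))  # 0.4
    print(gated_contribution("Guitar", 0.8, "Food", ["Cooking", "Recipe"], sim))   # 0.0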
@@ -16,7 +16,7 @@ def getURL(vid_id):
 
 print(getURL('nzwW'))
 
-def getVideoInfo(vid_id, video_tags_path, top_k):
+def getVideoInfo(vid_id, video_tags_path, top_k, isPerVideo):
     print("vid_id = ", vid_id)
     video_url = getURL(vid_id[0:4])
 
@@ -31,8 +31,9 @@ def getVideoInfo(vid_id, video_tags_path, top_k):
     if video_url == "":
         for x in video_tags:
             video_url = video_url + ' ' + x
 
-    video_url = video_url + '\nThe similar point is : ' + str(float(vid_id[5:]) * 5)
+    if isPerVideo == False:
+        video_url = video_url + '\nThe similar point is : ' + str(float(vid_id[5:]) * 5)
 
     return {
         "video_url": video_url,
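
Note: the new isPerVideo flag suppresses the "similar point" timestamp for whole-video results; the slicing implies ids shaped like "<4-char video id><separator><segment index>", where the segment index times 5 gives the second offset. Hypothetical calls (the id value is made up):

    hit = "nzwW:7"  # video "nzwW", segment 7 -> similar point at 35.0 seconds
    segment_info = getVideoInfo(hit, VIDEO_TAGS_PATH, TAG_TOP_K, False)  # appends timestamp line
    video_info = getVideoInfo(hit, VIDEO_TAGS_PATH, TAG_TOP_K, True)     # URL and tags only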
@@ -24,8 +24,8 @@ if __name__ == '__main__':
     video_vectors = Word2Vec().wv  # Empty model
 
     # Load video recommendation tags.
-    video_tags = pd.read_csv('./segment_tags.csv', encoding='utf-8', error_bad_lines=False)
-
+    #video_tags = pd.read_csv('./segment_tags.csv', encoding='utf-8', error_bad_lines=False)
+    video_tags = pd.read_csv('./videoIds.csv', encoding='utf-8', error_bad_lines=False)
     # Define batch variables.
     batch_video_ids = []
     batch_video_vectors = []
@@ -64,7 +64,7 @@ if __name__ == '__main__':
     print(error_tags)
     print(len(error_tags))
 
-    video_vectors.save("video_vectors.model")
+    video_vectors.save("videoId_vectors.model")
 
 # Usage
 # video_vectors = Word2Vec().wv.load("video_vectors.model")
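
Note: the builder now reads videoIds.csv and saves under a new name, so the segment-level and id-level vector stores coexist. A minimal sketch of the build-and-save round trip it implies, assuming gensim 3.x (where an empty Word2Vec().wv accepts add()) and assuming the first CSV column holds the video id:

    import numpy as np
    import pandas as pd
    from gensim.models import Word2Vec, KeyedVectors

    video_vectors = Word2Vec().wv  # empty KeyedVectors container, as in the script
    video_tags = pd.read_csv('./videoIds.csv', encoding='utf-8', error_bad_lines=False)

    for _, row in video_tags.iterrows():
        vec = np.random.rand(100).astype(np.float32)  # stand-in for the real tag-blended vector
        video_vectors.add([str(row.iloc[0])], [vec])  # gensim 3.x KeyedVectors.add

    video_vectors.save("videoId_vectors.model")
    reloaded = KeyedVectors.load("videoId_vectors.model")  # usage, per the script's comment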
@@ -395,6 +395,197 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
     coord.join(threads)
     sess.close()
 
+
+def inference2(reader, train_dir, data_pattern, out_file_location, batch_size,
+               top_k):
+  """Inference function that also aggregates segment predictions per video."""
+  with tf.Session(config=tf.ConfigProto(
+      allow_soft_placement=True)) as sess, gfile.Open(out_file_location,
+                                                      "w+") as out_file:
+    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
+        reader, data_pattern, batch_size)
+    inference_model_name = "segment_inference_model" if FLAGS.segment_labels else "inference_model"
+    checkpoint_file = os.path.join(train_dir, "inference_model",
+                                   inference_model_name)
+    if not gfile.Exists(checkpoint_file + ".meta"):
+      raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
+    meta_graph_location = checkpoint_file + ".meta"
+    logging.info("loading meta-graph: " + meta_graph_location)
+    if FLAGS.output_model_tgz:
+      with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar:
+        for model_file in glob.glob(checkpoint_file + ".*"):
+          tar.add(model_file, arcname=os.path.basename(model_file))
+        tar.add(os.path.join(train_dir, "model_flags.json"),
+                arcname="model_flags.json")
+      print("Tarred model onto " + FLAGS.output_model_tgz)
+    with tf.device("/cpu:0"):
+      saver = tf.train.import_meta_graph(meta_graph_location,
+                                         clear_devices=True)
+    logging.info("restoring variables from " + checkpoint_file)
+    saver.restore(sess, checkpoint_file)
+    input_tensor = tf.get_collection("input_batch_raw")[0]
+    num_frames_tensor = tf.get_collection("num_frames")[0]
+    predictions_tensor = tf.get_collection("predictions")[0]
+
+    # Workaround for num_epochs issue.
+    def set_up_init_ops(variables):
+      init_op_list = []
+      for variable in list(variables):
+        if "train_input" in variable.name:
+          init_op_list.append(tf.assign(variable, 1))
+          variables.remove(variable)
+      init_op_list.append(tf.variables_initializer(variables))
+      return init_op_list
+
+    sess.run(
+        set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))
+
+    coord = tf.train.Coordinator()
+    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+    num_examples_processed = 0
+    start_time = time.time()
+    whitelisted_cls_mask = None
+    if FLAGS.segment_labels:
+      final_out_file = out_file
+      out_file = tempfile.NamedTemporaryFile()
+      logging.info(
+          "Segment temp prediction output will be written to temp file: %s",
+          out_file.name)
+      if FLAGS.segment_label_ids_file:
+        whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
+                                        dtype=np.float32)
+        segment_label_ids_file = FLAGS.segment_label_ids_file
+        if segment_label_ids_file.startswith("http"):
+          logging.info("Retrieving segment ID whitelist files from %s...",
+                       segment_label_ids_file)
+          segment_label_ids_file, _ = urllib.request.urlretrieve(
+              segment_label_ids_file)
+        with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
+          for line in fobj:
+            try:
+              cls_id = int(line)
+              whitelisted_cls_mask[cls_id] = 1.
+            except ValueError:
+              # Simply skip the non-integer line.
+              continue
+
+    out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))
+
+    # =========================================
+    # Open the vocabulary csv file and store it in a dictionary.
+    # =========================================
+    voca_dict = {}
+    vocabs = codecs.open('./vocabulary.csv', 'r', 'utf-8')
+    while True:
+      line = vocabs.readline()
+      if not line: break
+      vocab_dict_item = line.split(",")
+      if vocab_dict_item[0] != "Index":
+        voca_dict[vocab_dict_item[0]] = vocab_dict_item[3]
+    vocabs.close()
+
+    try:
+      while not coord.should_stop():
+        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
+            [video_id_batch, video_batch, num_frames_batch])
+        if FLAGS.segment_labels:
+          results = get_segments(video_batch_val, num_frames_batch_val, 5)
+          video_segment_ids = results["video_segment_ids"]
+          video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]]
+          video_id_batch_val = np.array([
+              "%s:%d" % (x.decode("utf8"), y)
+              for x, y in zip(video_id_batch_val, video_segment_ids[:, 1])
+          ])
+          video_batch_val = results["video_batch"]
+          num_frames_batch_val = results["num_frames_batch"]
+          if input_tensor.get_shape()[1] != video_batch_val.shape[1]:
+            raise ValueError("max_frames mismatch. Please re-run the eval.py "
+                             "with correct segment_labels settings.")
+        predictions_val, = sess.run([predictions_tensor],
+                                    feed_dict={
+                                        input_tensor: video_batch_val,
+                                        num_frames_tensor: num_frames_batch_val
+                                    })
+        now = time.time()
+        num_examples_processed += len(video_batch_val)
+        elapsed_time = now - start_time
+        logging.info("num examples processed: " + str(num_examples_processed) +
+                     " elapsed seconds: " + "{0:.2f}".format(elapsed_time) +
+                     " examples/sec: %.2f" %
+                     (num_examples_processed / elapsed_time))
+        for line in format_lines(video_id_batch_val, predictions_val, top_k,
+                                 whitelisted_cls_mask):
+          out_file.write(line)
+        out_file.flush()
+    except tf.errors.OutOfRangeError:
+      logging.info("Done with inference. The output file was written to " +
+                   out_file.name)
+    finally:
+      coord.request_stop()
+
+      if FLAGS.segment_labels:
+        # Re-read the temp file and merge per-segment predictions.
+        logging.info("Post-processing segment predictions...")
+        segment_id_list = []
+        segment_classes = []
+        cls_result_arr = []
+        cls_score_dict = {}
+        out_file.seek(0, 0)
+        old_seg_name = '0000'
+        counter = 0
+        for line in out_file:
+          counter += 1
+          if counter % 5000 == 0:  # progress log every 5000 lines
+            print(counter, " processed")
+          segment_id, preds = line.decode("utf8").split(",")
+          if segment_id == "VideoId":
+            # Skip the headline.
+            continue
+          preds = preds.split(" ")
+          pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
+          pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)]
+          # =======================================
+          segment_id = str(segment_id.split(":")[0])
+          if segment_id not in segment_id_list:
+            segment_id_list.append(str(segment_id))
+            segment_classes.append("")
+          index = segment_id_list.index(segment_id)
+
+          if old_seg_name != segment_id:
+            cls_score_dict[segment_id] = {}
+            old_seg_name = segment_id
+
+          for classes in range(0, len(pred_cls_ids)):
+            # Append the classes from the new segment.
+            segment_classes[index] = str(segment_classes[index]) + str(pred_cls_ids[classes]) + " "
+            if pred_cls_ids[classes] in cls_score_dict[segment_id]:
+              cls_score_dict[segment_id][pred_cls_ids[classes]] = cls_score_dict[segment_id][pred_cls_ids[classes]] + pred_cls_scores[classes]
+            else:
+              cls_score_dict[segment_id][pred_cls_ids[classes]] = pred_cls_scores[classes]
+
+        for segs, item in zip(segment_id_list, segment_classes):
+          print('====== R E C O R D ======')
+          cls_arr = item.split(" ")[:-1]
+          cls_arr = list(map(int, cls_arr))
+          cls_arr = sorted(cls_arr)  # sort by class id
+
+          result_string = ""
+
+          temp = cls_score_dict[segs]
+          temp = sorted(temp.items(), key=operator.itemgetter(1), reverse=True)  # sort by score
+          denominator = float(temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1])
+          for itemIndex in range(0, top_k):
+            # Normalize the tag name.
+            segment_tag = str(voca_dict[str(temp[itemIndex][0])])
+            normalized_tag = normalize_tag(segment_tag)
+            result_string = result_string + normalized_tag + ":" + format(temp[itemIndex][1] / denominator, ".3f") + ","
+
+          cls_result_arr.append(result_string[:-1])
+          logging.info(segs + " : " + result_string[:-1])
+        # =======================================
+        final_out_file.write("vid_id,segment1,segment2,segment3,segment4,segment5\n")
+        for seg_id, class_indices in zip(segment_id_list, cls_result_arr):
+          final_out_file.write("%s,%s\n" % (seg_id, str(class_indices)))
+        final_out_file.close()
+      out_file.close()
+    coord.join(threads)
+    sess.close()
+
 def main(unused_argv):
   logging.set_verbosity(tf.logging.INFO)
   if FLAGS.input_model_tgz:
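
Note: most of inference2 duplicates inference; the new piece is the post-processing pass that folds per-segment predictions back into one row per video: class scores are summed across a video's segments, sorted, and each kept entry is renormalized by the sum of the five best. That aggregation in isolation, detached from the TF session (the line format matches the temp file: "videoId:segIdx,cls score cls score ..."):

    import operator

    def aggregate_segment_predictions(lines, voca_dict, top_k=5):
        cls_score_dict = {}
        for line in lines:
            segment_id, preds = line.split(",")
            vid = segment_id.split(":")[0]
            toks = preds.split(" ")
            scores = cls_score_dict.setdefault(vid, {})
            for cls, score in zip(toks[0::2], toks[1::2]):
                scores[int(cls)] = scores.get(int(cls), 0.0) + float(score)
        results = {}
        for vid, scores in cls_score_dict.items():
            top = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
            denominator = sum(s for _, s in top[:5])  # sum of the (at most) five best scores
            results[vid] = ",".join("%s:%.3f" % (voca_dict[str(c)], s / denominator)
                                    for c, s in top[:top_k])
        return results

    # Toy example; the vocabulary entries are made up.
    print(aggregate_segment_predictions(
        ["abcd:0,3 0.9 7 0.5", "abcd:1,3 0.4 9 0.3"],
        {"3": "Food", "7": "Game", "9": "Music"}, top_k=2))
    # -> {'abcd': 'Food:0.619,Game:0.238'}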
@@ -431,7 +622,7 @@ def main(unused_argv):
     raise ValueError("'input_data_pattern' was not specified. "
                      "Unable to continue with inference.")
 
-  inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern,
+  inference2(reader, FLAGS.train_dir, FLAGS.input_data_pattern,
             FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k)
 
 def normalize(arrs):
@@ -443,3 +634,4 @@ def normalize(arrs):
 
 if __name__ == "__main__":
   app.run()
+