2020-1-capstone-design1 / PKH_Project1
Authored by 윤영빈, 2020-05-31 13:11:56 +0900
Commit 1b4cf2a2d2844f9b27cd4bcf3801efcbad630941
1 parent e4adf017

top-k label output
Showing 1 changed file with 40 additions and 14 deletions
web/backend/yt8m/inference_per_segment.py
```diff
@@ -34,6 +34,7 @@ from tensorflow import logging
 from tensorflow.python.lib.io import file_io
 import utils
 from collections import Counter
+import operator

 FLAGS = flags.FLAGS
```
```diff
@@ -81,7 +82,7 @@ if __name__ == "__main__":
       "the model graph and checkpoint will be bundled in this "
       "gzip tar. This file can be uploaded to Kaggle for the "
       "top 10 participants.")
-  flags.DEFINE_integer("top_k", 1, "How many predictions to output per video.")
+  flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")
   # Other flags.
   flags.DEFINE_integer("batch_size", 512,
```
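With TF 1.x-style flags (which this script uses), the new default simply means five label/score pairs per segment unless the caller overrides it on the command line. A minimal, hypothetical sketch of how such a flag is picked up, assuming TensorFlow 1.x's `tf.app.flags` wrapper; this is not taken from the repo:

```python
# Minimal sketch, assuming TensorFlow 1.x's tf.app.flags (a gflags wrapper).
# Run as e.g. `python demo.py --top_k=10` to override the default.
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")

def main(unused_argv):
  print("top_k =", FLAGS.top_k)  # 5 by default, or the --top_k value passed

if __name__ == "__main__":
  tf.app.run()
```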
```diff
@@ -260,6 +261,18 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
     out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))

+    #=========================================
+    # open vocab csv file and store to dictionary
+    #=========================================
+    voca_dict = {}
+    vocabs = open("./vocabulary.csv", 'r')
+    while True:
+      line = vocabs.readline()
+      if not line:
+        break
+      vocab_dict_item = line.split(",")
+      if vocab_dict_item[0] != "Index":
+        voca_dict[vocab_dict_item[0]] = vocab_dict_item[3]
+    vocabs.close()

     try:
       while not coord.should_stop():
         video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
```
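The added block parses vocabulary.csv by hand with `split(",")`, which only works while no quoted field contains a comma. A hedged alternative sketch using the stdlib `csv` module; this is not the commit's code, and it assumes the vocabulary file has header columns named "Index" and "Name" (the column that `split(",")` reaches as index 3):

```python
# Hedged sketch with stdlib csv instead of manual split(",").
# Assumes vocabulary.csv has header columns including "Index" and "Name".
import csv

def load_vocabulary(path="./vocabulary.csv"):
    voca_dict = {}
    with open(path, "r") as f:
        for row in csv.DictReader(f):  # handles quoted fields containing commas
            voca_dict[row["Index"]] = row["Name"]
    return voca_dict
```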
```diff
@@ -308,7 +321,9 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
       segment_id_list = []
       segment_classes = []
       cls_result_arr = []
+      cls_score_dict = {}
       out_file.seek(0, 0)
+      old_seg_name = '0000'
       for line in out_file:
         segment_id, preds = line.decode("utf8").split(",")
         if segment_id == "VideoId":
```
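The two variables added here, `cls_score_dict` and `old_seg_name`, back the per-class score accumulation introduced in the next hunk. As a side note, the explicit membership test used there can be written more compactly with `collections.defaultdict`; a sketch, not the commit's code:

```python
# Hedged sketch: defaultdict removes the `if cls in dict ... else ...` branch
# the next hunk uses to accumulate per-segment class scores.
from collections import defaultdict

cls_score_dict = defaultdict(lambda: defaultdict(float))

def accumulate(segment_id, pred_cls_ids, pred_cls_scores):
    # Sum the confidence score of each predicted class for this segment.
    for cls_id, score in zip(pred_cls_ids, pred_cls_scores):
        cls_score_dict[segment_id][cls_id] += score
```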
```diff
@@ -317,36 +332,48 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
         preds = preds.split(" ")
         pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
+        # =======================================
+        pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)]
+        # =======================================
         segment_id = str(segment_id.split(":")[0])
         if segment_id not in segment_id_list:
           segment_id_list.append(str(segment_id))
           segment_classes.append("")
         index = segment_id_list.index(segment_id)
-        for classes in pred_cls_ids:
-          segment_classes[index] = str(segment_classes[index]) + str(classes) + " "  # append classes from new segment
+        if old_seg_name != segment_id:
+          cls_score_dict[segment_id] = {}
+          old_seg_name = segment_id
+        for classes in range(0, len(pred_cls_ids)):  # pred_cls_ids:
+          segment_classes[index] = str(segment_classes[index]) + str(pred_cls_ids[classes]) + " "  # append classes from new segment
+          if pred_cls_ids[classes] in cls_score_dict[segment_id]:
+            cls_score_dict[segment_id][pred_cls_ids[classes]] = cls_score_dict[segment_id][pred_cls_ids[classes]] + pred_cls_scores[classes]
+          else:
+            cls_score_dict[segment_id][pred_cls_ids[classes]] = pred_cls_scores[classes]
       for segs, item in zip(segment_id_list, segment_classes):
         print('====== R E C O R D ======')
         cls_arr = item.split(" ")[:-1]
         cls_arr = list(map(int, cls_arr))
         cls_arr = sorted(cls_arr)  # sort per class
         result_string = ""
         temp = Counter(cls_arr)
         for item in temp:
           result_string = result_string + str(item) + ":" + str(temp[item]) + ","
+        temp = cls_score_dict[segs]
+        temp = sorted(temp.items(), key=operator.itemgetter(1), reverse=True)  # sort by score value
+        demoninator = float(temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1])
+        # for item in temp:
+        for itemIndex in range(0, top_k):
+          result_string = result_string + str(voca_dict[str(temp[itemIndex][0])]) + ":" + format(temp[itemIndex][1] / demoninator, ".3f") + ","
         cls_result_arr.append(result_string[:-1])
         logging.info(segs + " : " + result_string[:-1])
     # =======================================
     # =======================================
     final_out_file.write("vid_id,seg_classes\n")
     for seg_id, class_indcies in zip(segment_id_list, cls_result_arr):
       final_out_file.write("%s,%s\n" % (seg_id, str(class_indcies)))
     final_out_file.close()
     out_file.close()
```
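Taken together, this hunk sums the scores of every class predicted for a video, ranks classes by accumulated score, normalizes by the mass of the five best (the fixed five-term `demoninator`, which presumes at least five distinct classes per video), and emits the `top_k` as name:score pairs. A self-contained sketch of that selection step; the vocabulary entries and scores below are invented sample data, not repo output:

```python
# Hedged, self-contained sketch of the top-k selection above.
import operator

voca_dict = {"0": "Game", "1": "Vehicle", "2": "Concert", "3": "Food", "4": "Dance", "5": "Animal"}
scores = {0: 4.2, 1: 3.1, 2: 2.5, 3: 1.9, 4: 1.0, 5: 0.4}  # class id -> summed confidence
top_k = 5

ranked = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
denominator = float(sum(score for _, score in ranked[:5]))  # top-5 mass, like `demoninator`
result_string = ",".join(
    "%s:%s" % (voca_dict[str(cls_id)], format(score / denominator, ".3f"))
    for cls_id, score in ranked[:top_k])
print(result_string)  # Game:0.331,Vehicle:0.244,Concert:0.197,Food:0.150,Dance:0.079
```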
```diff
@@ -354,7 +381,6 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
     coord.join(threads)
     sess.close()

 def main(unused_argv):
   logging.set_verbosity(tf.logging.INFO)
   if FLAGS.input_model_tgz:
```