Showing 8 changed files with 63 additions and 31 deletions
... | @@ -93,7 +93,6 @@ class VideoFileUploadView(APIView): | ... | @@ -93,7 +93,6 @@ class VideoFileUploadView(APIView): |
93 | else: | 93 | else: |
94 | return Response(file_serializer.errors, status=status.HTTP_400_BAD_REQUEST) | 94 | return Response(file_serializer.errors, status=status.HTTP_400_BAD_REQUEST) |
95 | 95 | ||
96 | - | ||
97 | class VideoFileList(APIView): | 96 | class VideoFileList(APIView): |
98 | 97 | ||
99 | def get_object(self, pk): | 98 | def get_object(self, pk): | ... | ... |
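Context for this hunk: the removed line is only stray whitespace between two DRF class-based views. For readers unfamiliar with the get_object pattern the hunk tails off into, a minimal sketch follows, assuming a VideoFile model and the stock Http404 convention (neither appears in this diff, so treat both as assumptions):

    from django.http import Http404
    from rest_framework.views import APIView

    class VideoFileList(APIView):

        def get_object(self, pk):
            # Fetch one VideoFile row by primary key, or surface a 404.
            try:
                return VideoFile.objects.get(pk=pk)  # VideoFile model is assumed
            except VideoFile.DoesNotExist:
                raise Http404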
1 | +{"model": "NetVLADModelLF", "feature_sizes": "1024,128", "feature_names": "rgb,audio", "frame_features": true, "label_loss": "CrossEntropyLoss"} | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
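This new JSON matches the model_flags.json that eval.py writes into train_dir and that inference.py reloads (see the flags_dict_file check in the last inference.py hunk). A sketch of how such a file is consumed, following the file_io pattern visible there; the exact field handling is an assumption:

    import json
    from tensorflow.python.lib.io import file_io

    # Rebuild the reader configuration saved at training time.
    flags_dict = json.loads(file_io.FileIO("model_flags.json", mode="r").read())
    feature_names = flags_dict["feature_names"]    # "rgb,audio"
    feature_sizes = flags_dict["feature_sizes"]    # "1024,128"
    frame_features = flags_dict["frame_features"]  # True => frame-level reader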
... | @@ -33,17 +33,17 @@ FLAGS = flags.FLAGS | ... | @@ -33,17 +33,17 @@ FLAGS = flags.FLAGS |
33 | if __name__ == "__main__": | 33 | if __name__ == "__main__": |
34 | # Dataset flags. | 34 | # Dataset flags. |
35 | flags.DEFINE_string( | 35 | flags.DEFINE_string( |
36 | - "train_dir", "F:/yt8mDataset/savedModel", | 36 | + "train_dir", "F:/yt8mDataset/savedModelaa2", |
37 | "The directory to load the model files from. " | 37 | "The directory to load the model files from. " |
38 | "The tensorboard metrics files are also saved to this " | 38 | "The tensorboard metrics files are also saved to this " |
39 | "directory.") | 39 | "directory.") |
40 | flags.DEFINE_string( | 40 | flags.DEFINE_string( |
41 | - "eval_data_pattern", "", | 41 | + "eval_data_pattern", "F:/yt8mDataset/train2/train*.tfrecord", |
42 | "File glob defining the evaluation dataset in tensorflow.SequenceExample " | 42 | "File glob defining the evaluation dataset in tensorflow.SequenceExample " |
43 | "format. The SequenceExamples are expected to have an 'rgb' byte array " | 43 | "format. The SequenceExamples are expected to have an 'rgb' byte array " |
44 | "sequence feature as well as a 'labels' int64 context feature.") | 44 | "sequence feature as well as a 'labels' int64 context feature.") |
45 | flags.DEFINE_bool( | 45 | flags.DEFINE_bool( |
46 | - "segment_labels", False, | 46 | + "segment_labels", True, |
47 | "If set, then --eval_data_pattern must be frame-level features (but with" | 47 | "If set, then --eval_data_pattern must be frame-level features (but with" |
48 | " segment_labels). Otherwise, --eval_data_pattern must be aggregated " | 48 | " segment_labels). Otherwise, --eval_data_pattern must be aggregated " |
49 | "video-level features. The model must also be set appropriately (i.e. to " | 49 | "video-level features. The model must also be set appropriately (i.e. to " |
... | @@ -54,7 +54,7 @@ if __name__ == "__main__": | ... | @@ -54,7 +54,7 @@ if __name__ == "__main__": |
54 | "How many examples to process per batch.") | 54 | "How many examples to process per batch.") |
55 | flags.DEFINE_integer("num_readers", 8, | 55 | flags.DEFINE_integer("num_readers", 8, |
56 | "How many threads to use for reading input files.") | 56 | "How many threads to use for reading input files.") |
57 | - flags.DEFINE_boolean("run_once", False, "Whether to run eval only once.") | 57 | + flags.DEFINE_boolean("run_once", True, "Whether to run eval only once.") |
58 | flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.") | 58 | flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.") |
59 | 59 | ||
60 | 60 | ... | ... |
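Taken together, these eval.py edits switch the script to a single segment-level evaluation pass: run_once=True skips the checkpoint-polling loop, segment_labels=True requires frame-level features, and the data pattern now points at concrete tfrecords. The same effect is available without editing the defaults in source, e.g. (paths copied from the hunk):

    python eval.py --train_dir=F:/yt8mDataset/savedModelaa2 --eval_data_pattern=F:/yt8mDataset/train2/train*.tfrecord --segment_labels --run_once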
... | @@ -46,7 +46,7 @@ flags.DEFINE_string( | ... | @@ -46,7 +46,7 @@ flags.DEFINE_string( |
46 | "Some Frame-Level models can be decomposed into a " | 46 | "Some Frame-Level models can be decomposed into a " |
47 | "generalized pooling operation followed by a " | 47 | "generalized pooling operation followed by a " |
48 | "classifier layer") | 48 | "classifier layer") |
49 | -flags.DEFINE_integer("lstm_cells", 512, "Number of LSTM cells.") | 49 | +flags.DEFINE_integer("lstm_cells", 1024, "Number of LSTM cells.") |
50 | flags.DEFINE_integer("lstm_layers", 2, "Number of LSTM layers.") | 50 | flags.DEFINE_integer("lstm_layers", 2, "Number of LSTM layers.") |
51 | 51 | ||
52 | 52 | ... | ... |
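Raising lstm_cells from 512 to 1024 doubles the hidden-state width of each LSTM layer (and roughly quadruples the per-layer recurrent weights). In this repo the two flags feed a stacked RNN along these lines, written against the TF 1.x contrib API the codebase uses; a sketch, not the exact model code:

    import tensorflow as tf

    lstm_size = 1024       # FLAGS.lstm_cells after this change
    number_of_layers = 2   # FLAGS.lstm_layers

    # Two stacked LSTM layers, each with a 1024-wide cell state.
    stacked_lstm = tf.contrib.rnn.MultiRNNCell([
        tf.contrib.rnn.BasicLSTMCell(lstm_size, forget_bias=1.0)
        for _ in range(number_of_layers)
    ])
    # tf.nn.dynamic_rnn(stacked_lstm, model_input, sequence_length=num_frames, ...)
    # and the final state feeds the video-level classifier.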
... | @@ -23,7 +23,7 @@ import tarfile | ... | @@ -23,7 +23,7 @@ import tarfile |
23 | import tempfile | 23 | import tempfile |
24 | import time | 24 | import time |
25 | import numpy as np | 25 | import numpy as np |
26 | - | 26 | +import ssl |
27 | import readers | 27 | import readers |
28 | from six.moves import urllib | 28 | from six.moves import urllib |
29 | import tensorflow as tf | 29 | import tensorflow as tf |
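import ssl is added here but never referenced in the visible hunks. In this script the usual motive is HTTPS certificate failures when urlretrieve fetches the segment-label CSV (that download call is commented out further down). If that is the intent, the standard workaround is the one below; this is an assumption, since the diff never shows the module being used:

    import ssl
    # Disable certificate verification for urllib downloads (use with care;
    # acceptable for a one-off fetch of a public CSV, not for production).
    ssl._create_default_https_context = ssl._create_unverified_context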
... | @@ -39,11 +39,11 @@ FLAGS = flags.FLAGS | ... | @@ -39,11 +39,11 @@ FLAGS = flags.FLAGS |
39 | if __name__ == "__main__": | 39 | if __name__ == "__main__": |
40 | # Input | 40 | # Input |
41 | flags.DEFINE_string( | 41 | flags.DEFINE_string( |
42 | - "train_dir", "", "The directory to load the model files from. We assume " | 42 | + "train_dir", "/mnt/f/yt8mDataset/savedModelaa", "The directory to load the model files from. We assume " |
43 | "that you have already run eval.py onto this, such that " | 43 | "that you have already run eval.py onto this, such that " |
44 | "inference_model.* files already exist.") | 44 | "inference_model.* files already exist.") |
45 | flags.DEFINE_string( | 45 | flags.DEFINE_string( |
46 | - "input_data_pattern", "", | 46 | + "input_data_pattern", "/mnt/f/yt8mDataset/train2/train*.tfrecord", |
47 | "File glob defining the evaluation dataset in tensorflow.SequenceExample " | 47 | "File glob defining the evaluation dataset in tensorflow.SequenceExample " |
48 | "format. The SequenceExamples are expected to have an 'rgb' byte array " | 48 | "format. The SequenceExamples are expected to have an 'rgb' byte array " |
49 | "sequence feature as well as a 'labels' int64 context feature.") | 49 | "sequence feature as well as a 'labels' int64 context feature.") |
... | @@ -60,7 +60,7 @@ if __name__ == "__main__": | ... | @@ -60,7 +60,7 @@ if __name__ == "__main__": |
60 | "be created and the contents of the .tgz file will be " | 60 | "be created and the contents of the .tgz file will be " |
61 | "untarred here.") | 61 | "untarred here.") |
62 | flags.DEFINE_bool( | 62 | flags.DEFINE_bool( |
63 | - "segment_labels", False, | 63 | + "segment_labels", True, |
64 | "If set, then --input_data_pattern must be frame-level features (but with" | 64 | "If set, then --input_data_pattern must be frame-level features (but with" |
65 | " segment_labels). Otherwise, --input_data_pattern must be aggregated " | 65 | " segment_labels). Otherwise, --input_data_pattern must be aggregated " |
66 | "video-level features. The model must also be set appropriately (i.e. to " | 66 | "video-level features. The model must also be set appropriately (i.e. to " |
... | @@ -69,18 +69,18 @@ if __name__ == "__main__": | ... | @@ -69,18 +69,18 @@ if __name__ == "__main__": |
69 | "Limit total number of segment outputs per entity.") | 69 | "Limit total number of segment outputs per entity.") |
70 | flags.DEFINE_string( | 70 | flags.DEFINE_string( |
71 | "segment_label_ids_file", | 71 | "segment_label_ids_file", |
72 | - "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv", | 72 | + "/mnt/e/khuhub/2015104192/web/backend/yt8m/vocabulary.csv", |
73 | "The file that contains the segment label ids.") | 73 | "The file that contains the segment label ids.") |
74 | 74 | ||
75 | # Output | 75 | # Output |
76 | - flags.DEFINE_string("output_file", "", "The file to save the predictions to.") | 76 | + flags.DEFINE_string("output_file", "/mnt/f/yt8mDataset/kaggle_solution_validation_temp.csv", "The file to save the predictions to.") |
77 | flags.DEFINE_string( | 77 | flags.DEFINE_string( |
78 | "output_model_tgz", "", | 78 | "output_model_tgz", "", |
79 | "If given, should be a filename with a .tgz extension, " | 79 | "If given, should be a filename with a .tgz extension, " |
80 | "the model graph and checkpoint will be bundled in this " | 80 | "the model graph and checkpoint will be bundled in this " |
81 | "gzip tar. This file can be uploaded to Kaggle for the " | 81 | "gzip tar. This file can be uploaded to Kaggle for the " |
82 | "top 10 participants.") | 82 | "top 10 participants.") |
83 | - flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.") | 83 | + flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.") |
84 | 84 | ||
85 | # Other flags. | 85 | # Other flags. |
86 | flags.DEFINE_integer("batch_size", 512, | 86 | flags.DEFINE_integer("batch_size", 512, |
... | @@ -108,18 +108,15 @@ def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None): | ... | @@ -108,18 +108,15 @@ def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None): |
108 | 108 | ||
109 | def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1): | 109 | def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1): |
110 | """Creates the section of the graph which reads the input data. | 110 | """Creates the section of the graph which reads the input data. |
111 | - | ||
112 | Args: | 111 | Args: |
113 | reader: A class which parses the input data. | 112 | reader: A class which parses the input data. |
114 | data_pattern: A 'glob' style path to the data files. | 113 | data_pattern: A 'glob' style path to the data files. |
115 | batch_size: How many examples to process at a time. | 114 | batch_size: How many examples to process at a time. |
116 | num_readers: How many I/O threads to use. | 115 | num_readers: How many I/O threads to use. |
117 | - | ||
118 | Returns: | 116 | Returns: |
119 | A tuple containing the features tensor, labels tensor, and optionally a | 117 | A tuple containing the features tensor, labels tensor, and optionally a |
120 | tensor containing the number of frames per video. The exact dimensions | 118 | tensor containing the number of frames per video. The exact dimensions |
121 | depend on the reader being used. | 119 | depend on the reader being used. |
122 | - | ||
123 | Raises: | 120 | Raises: |
124 | IOError: If no files matching the given pattern were found. | 121 | IOError: If no files matching the given pattern were found. |
125 | """ | 122 | """ |
... | @@ -243,6 +240,12 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -243,6 +240,12 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
243 | whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | 240 | whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), |
244 | dtype=np.float32) | 241 | dtype=np.float32) |
245 | segment_label_ids_file = FLAGS.segment_label_ids_file | 242 | segment_label_ids_file = FLAGS.segment_label_ids_file |
243 | + """ | ||
244 | + if segment_label_ids_file.startswith("http"): | ||
245 | + logging.info("Retrieving segment ID whitelist files from %s...", | ||
246 | + segment_label_ids_file) | ||
247 | + segment_label_ids_file, _ = urllib.request.urlretrieve(segment_label_ids_file) | ||
248 | + """ | ||
246 | if segment_label_ids_file.startswith("http"): | 249 | if segment_label_ids_file.startswith("http"): |
247 | logging.info("Retrieving segment ID whitelist files from %s...", | 250 | logging.info("Retrieving segment ID whitelist files from %s...", |
248 | segment_label_ids_file) | 251 | segment_label_ids_file) |
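Whether downloaded or read from the local vocabulary.csv now set as the default, the whitelist file is consumed one class id per line to fill whitelisted_cls_mask; the fobj.close()/ValueError lines visible in the second inference.py confirm the shape of that loop. A self-contained sketch (num_classes and the path are stand-ins):

    import numpy as np

    num_classes = 1000  # stand-in; the real value comes from the predictions tensor
    whitelisted_cls_mask = np.zeros((num_classes,), dtype=np.float32)
    with open("vocabulary.csv") as fobj:
        for line in fobj:
            try:
                # One whitelisted class id per line; non-integer lines are skipped.
                whitelisted_cls_mask[int(line)] = 1.0
            except ValueError:
                continue

Caveat: if the new default really is a multi-column vocabulary.csv rather than a one-id-per-line whitelist, int(line) will raise on every row and the mask will stay all-zero.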
... | @@ -307,6 +310,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -307,6 +310,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
307 | heaps = {} | 310 | heaps = {} |
308 | out_file.seek(0, 0) | 311 | out_file.seek(0, 0) |
309 | for line in out_file: | 312 | for line in out_file: |
313 | + print(line) | ||
310 | segment_id, preds = line.decode("utf8").split(",") | 314 | segment_id, preds = line.decode("utf8").split(",") |
311 | if segment_id == "VideoId": | 315 | if segment_id == "VideoId": |
312 | # Skip the headline. | 316 | # Skip the headline. | ... | ... |
... | @@ -41,11 +41,11 @@ FLAGS = flags.FLAGS | ... | @@ -41,11 +41,11 @@ FLAGS = flags.FLAGS |
41 | if __name__ == "__main__": | 41 | if __name__ == "__main__": |
42 | # Input | 42 | # Input |
43 | flags.DEFINE_string( | 43 | flags.DEFINE_string( |
44 | - "train_dir", "", "The directory to load the model files from. We assume " | 44 | + "train_dir", "E:/savedModel", "The directory to load the model files from. We assume " |
45 | "that you have already run eval.py onto this, such that " | 45 | "that you have already run eval.py onto this, such that " |
46 | "inference_model.* files already exist.") | 46 | "inference_model.* files already exist.") |
47 | flags.DEFINE_string( | 47 | flags.DEFINE_string( |
48 | - "input_data_pattern", "", | 48 | + "input_data_pattern", "F:/yt8mDataset/test/train*.tfrecord", |
49 | "File glob defining the evaluation dataset in tensorflow.SequenceExample " | 49 | "File glob defining the evaluation dataset in tensorflow.SequenceExample " |
50 | "format. The SequenceExamples are expected to have an 'rgb' byte array " | 50 | "format. The SequenceExamples are expected to have an 'rgb' byte array " |
51 | "sequence feature as well as a 'labels' int64 context feature.") | 51 | "sequence feature as well as a 'labels' int64 context feature.") |
... | @@ -62,7 +62,7 @@ if __name__ == "__main__": | ... | @@ -62,7 +62,7 @@ if __name__ == "__main__": |
62 | "be created and the contents of the .tgz file will be " | 62 | "be created and the contents of the .tgz file will be " |
63 | "untarred here.") | 63 | "untarred here.") |
64 | flags.DEFINE_bool( | 64 | flags.DEFINE_bool( |
65 | - "segment_labels", False, | 65 | + "segment_labels", True, |
66 | "If set, then --input_data_pattern must be frame-level features (but with" | 66 | "If set, then --input_data_pattern must be frame-level features (but with" |
67 | " segment_labels). Otherwise, --input_data_pattern must be aggregated " | 67 | " segment_labels). Otherwise, --input_data_pattern must be aggregated " |
68 | "video-level features. The model must also be set appropriately (i.e. to " | 68 | "video-level features. The model must also be set appropriately (i.e. to " |
... | @@ -75,7 +75,7 @@ if __name__ == "__main__": | ... | @@ -75,7 +75,7 @@ if __name__ == "__main__": |
75 | "The file that contains the segment label ids.") | 75 | "The file that contains the segment label ids.") |
76 | 76 | ||
77 | # Output | 77 | # Output |
78 | - flags.DEFINE_string("output_file", "", "The file to save the predictions to.") | 78 | + flags.DEFINE_string("output_file", "F:/yt8mDataset/kaggle_solution_validation_temp.csv", "The file to save the predictions to.") |
79 | flags.DEFINE_string( | 79 | flags.DEFINE_string( |
80 | "output_model_tgz", "", | 80 | "output_model_tgz", "", |
81 | "If given, should be a filename with a .tgz extension, " | 81 | "If given, should be a filename with a .tgz extension, " |
... | @@ -85,7 +85,7 @@ if __name__ == "__main__": | ... | @@ -85,7 +85,7 @@ if __name__ == "__main__": |
85 | flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.") | 85 | flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.") |
86 | 86 | ||
87 | # Other flags. | 87 | # Other flags. |
88 | - flags.DEFINE_integer("batch_size", 32, | 88 | + flags.DEFINE_integer("batch_size", 16, |
89 | "How many examples to process per batch.") | 89 | "How many examples to process per batch.") |
90 | flags.DEFINE_integer("num_readers", 1, | 90 | flags.DEFINE_integer("num_readers", 1, |
91 | "How many threads to use for reading input files.") | 91 | "How many threads to use for reading input files.") |
... | @@ -269,6 +269,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -269,6 +269,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
269 | except ValueError: | 269 | except ValueError: |
270 | # Simply skip the non-integer line. | 270 | # Simply skip the non-integer line. |
271 | continue | 271 | continue |
272 | + fobj.close() | ||
272 | 273 | ||
273 | out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) | 274 | out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) |
274 | 275 | ||
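The added fobj.close() plugs a leaked file handle, but it only runs if the loop finishes normally. The idiomatic version closes the handle even when the body raises; a sketch of the same loop restructured with a context manager:

    with open(segment_label_ids_file) as fobj:
        for line in fobj:
            try:
                whitelisted_cls_mask[int(line)] = 1.0
            except ValueError:
                # Simply skip the non-integer line.
                continue
    # fobj is closed here automatically; no explicit close() needed.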
... | @@ -286,8 +287,10 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -286,8 +287,10 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
286 | vocabs.close() | 287 | vocabs.close() |
287 | try: | 288 | try: |
288 | while not coord.should_stop(): | 289 | while not coord.should_stop(): |
290 | + print("CAME IN") | ||
289 | video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( | 291 | video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( |
290 | [video_id_batch, video_batch, num_frames_batch]) | 292 | [video_id_batch, video_batch, num_frames_batch]) |
293 | + print("SESS OUT") | ||
291 | if FLAGS.segment_labels: | 294 | if FLAGS.segment_labels: |
292 | results = get_segments(video_batch_val, num_frames_batch_val, 5) | 295 | results = get_segments(video_batch_val, num_frames_batch_val, 5) |
293 | video_segment_ids = results["video_segment_ids"] | 296 | video_segment_ids = results["video_segment_ids"] |
... | @@ -333,6 +336,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -333,6 +336,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
333 | segment_classes = [] | 336 | segment_classes = [] |
334 | cls_result_arr = [] | 337 | cls_result_arr = [] |
335 | cls_score_dict = {} | 338 | cls_score_dict = {} |
339 | + resultList = [] | ||
336 | out_file.seek(0, 0) | 340 | out_file.seek(0, 0) |
337 | old_seg_name = '0000' | 341 | old_seg_name = '0000' |
338 | counter = 0 | 342 | counter = 0 |
... | @@ -344,12 +348,30 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -344,12 +348,30 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
344 | if segment_id == "VideoId": | 348 | if segment_id == "VideoId": |
345 | # Skip the headline. | 349 | # Skip the headline. |
346 | continue | 350 | continue |
347 | - | 351 | + print(line) |
348 | preds = preds.strip().split(" ") | 352 | preds = preds.strip().split(" ") |
353 | + """ | ||
349 | pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] | 354 | pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] |
350 | pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)] | 355 | pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)] |
356 | + """ | ||
357 | + """ | ||
358 | + denom = 1.0 | ||
359 | + for i in range(0,top_k): | ||
360 | + denom = denom + float(preds[(i*2) + 1]) | ||
361 | + print("DENOM = ",denom) | ||
362 | + for i in range(0,top_k): | ||
363 | + preds[(i*2) + 1] = float(preds[(i*2) + 1])/denom | ||
364 | + """ | ||
365 | + | ||
366 | + segment_id = "{0}_{1}".format(str(segment_id.split(":")[0]), str(int(segment_id.split(":")[1]) // 5)) ||
367 | + resultList.append("{0},{1},{2},{3},{4},{5}".format(segment_id, | ||
368 | + preds[0]+","+str(preds[1]), | ||
369 | + preds[2]+","+str(preds[3]), | ||
370 | + preds[4]+","+str(preds[5]), | ||
371 | + preds[6]+","+str(preds[7]), | ||
372 | + preds[8]+","+str(preds[9]))) | ||
351 | #======================================= | 373 | #======================================= |
352 | - segment_id = str(segment_id.split(":")[0]) | 374 | + """ |
353 | if segment_id not in segment_id_list: | 375 | if segment_id not in segment_id_list: |
354 | segment_id_list.append(str(segment_id)) | 376 | segment_id_list.append(str(segment_id)) |
355 | segment_classes.append("") | 377 | segment_classes.append("") |
... | @@ -372,12 +394,12 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -372,12 +394,12 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
372 | cls_arr = item.split(" ")[:-1] | 394 | cls_arr = item.split(" ")[:-1] |
373 | 395 | ||
374 | cls_arr = list(map(int,cls_arr)) | 396 | cls_arr = list(map(int,cls_arr)) |
375 | - cls_arr = sorted(cls_arr) #클래스별로 정렬 | 397 | + cls_arr = sorted(cls_arr) #sort by class |
376 | 398 | ||
377 | result_string = "" | 399 | result_string = "" |
378 | 400 | ||
379 | temp = cls_score_dict[segs] | 401 | temp = cls_score_dict[segs] |
380 | - temp= sorted(temp.items(), key=operator.itemgetter(1), reverse=True) #밸류값 기준으로 정렬 | 402 | + temp= sorted(temp.items(), key=operator.itemgetter(1), reverse=True) #sort by value |
381 | denominator = float(temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1]) | 403 | denominator = float(temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1]) |
382 | #for item in temp: | 404 | #for item in temp: |
383 | for itemIndex in range(0, top_k): | 405 | for itemIndex in range(0, top_k): |
... | @@ -387,11 +409,16 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ... | @@ -387,11 +409,16 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, |
387 | result_string = result_string + normalized_tag + ":" + format(temp[itemIndex][1]/denominator,".3f") + "," | 409 | result_string = result_string + normalized_tag + ":" + format(temp[itemIndex][1]/denominator,".3f") + "," |
388 | 410 | ||
389 | cls_result_arr.append(result_string[:-1]) | 411 | cls_result_arr.append(result_string[:-1]) |
390 | - logging.info(segs + " : " + result_string[:-1]) | 412 | + #logging.info(segs + " : " + result_string[:-1]) |
413 | + | ||
391 | #======================================= | 414 | #======================================= |
392 | final_out_file.write("vid_id,segment1,segment2,segment3,segment4,segment5\n") | 415 | final_out_file.write("vid_id,segment1,segment2,segment3,segment4,segment5\n") |
393 | for seg_id, class_indices in zip(segment_id_list, cls_result_arr): | 416 | for seg_id, class_indices in zip(segment_id_list, cls_result_arr): |
394 | final_out_file.write("%s,%s\n" %(seg_id, str(class_indices))) | 417 | final_out_file.write("%s,%s\n" %(seg_id, str(class_indices))) |
418 | + """ | ||
419 | + final_out_file.write("vid_id,segment1,pred1,segment2,pred2,segment3,pred3,segment4,pred4,segment5,pred5\n") | ||
420 | + for resultInfo in resultList: | ||
421 | + final_out_file.write(resultInfo + "\n") ||
395 | final_out_file.close() | 422 | final_out_file.close() |
396 | 423 | ||
397 | out_file.close() | 424 | out_file.close() |
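Net effect of this rewrite: the per-video aggregation path is fenced off behind a docstring, and each line of the intermediate file becomes one output row keyed by video id plus segment index. A worked example of the transformation, assuming 5-second segments and top_k=5 (values invented for illustration):

    # Intermediate line (VideoId,LabelConfidencePairs):
    #   abc123:10,42 0.91 7 0.05 3 0.02 19 0.01 88 0.01
    # segment_id -> "abc123_2"            (start time 10 // 5)
    # Output row -> abc123_2,42,0.91,7,0.05,3,0.02,19,0.01,88,0.01
    # matching the header written above: vid_id,segment1,pred1,...,segment5,pred5
    # (note "segmentN" here is actually the Nth class id and "predN" its score)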
... | @@ -410,7 +437,7 @@ def main(unused_argv): | ... | @@ -410,7 +437,7 @@ def main(unused_argv): |
410 | os.makedirs(FLAGS.untar_model_dir) | 437 | os.makedirs(FLAGS.untar_model_dir) |
411 | tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir) | 438 | tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir) |
412 | FLAGS.train_dir = FLAGS.untar_model_dir | 439 | FLAGS.train_dir = FLAGS.untar_model_dir |
413 | - | 440 | + print("INPUT DATA PATTERN ", FLAGS.input_data_pattern) |
414 | flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json") | 441 | flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json") |
415 | if not file_io.file_exists(flags_dict_file): | 442 | if not file_io.file_exists(flags_dict_file): |
416 | raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file) | 443 | raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file) | ... | ... |
... | @@ -75,7 +75,7 @@ if __name__ == "__main__": | ... | @@ -75,7 +75,7 @@ if __name__ == "__main__": |
75 | flags.DEFINE_integer( | 75 | flags.DEFINE_integer( |
76 | "num_gpu", 1, "The maximum number of GPU devices to use for training. " | 76 | "num_gpu", 1, "The maximum number of GPU devices to use for training. " |
77 | "Flag only applies if GPUs are installed") | 77 | "Flag only applies if GPUs are installed") |
78 | - flags.DEFINE_integer("batch_size", 64, | 78 | + flags.DEFINE_integer("batch_size", 16, |
79 | "How many examples to process per batch for training.") | 79 | "How many examples to process per batch for training.") |
80 | flags.DEFINE_string("label_loss", "CrossEntropyLoss", | 80 | flags.DEFINE_string("label_loss", "CrossEntropyLoss", |
81 | "Which loss function to use for training the model.") | 81 | "Which loss function to use for training the model.") |
... | @@ -94,13 +94,13 @@ if __name__ == "__main__": | ... | @@ -94,13 +94,13 @@ if __name__ == "__main__": |
94 | "Multiply current learning rate by learning_rate_decay " | 94 | "Multiply current learning rate by learning_rate_decay " |
95 | "every learning_rate_decay_examples.") | 95 | "every learning_rate_decay_examples.") |
96 | flags.DEFINE_integer( | 96 | flags.DEFINE_integer( |
97 | - "num_epochs", 5, "How many passes to make over the dataset before " | 97 | + "num_epochs", 100, "How many passes to make over the dataset before " |
98 | "halting training.") | 98 | "halting training.") |
99 | flags.DEFINE_integer( | 99 | flags.DEFINE_integer( |
100 | - "max_steps", None, | 100 | + "max_steps", 100, |
101 | "The maximum number of iterations of the training loop.") | 101 | "The maximum number of iterations of the training loop.") |
102 | flags.DEFINE_integer( | 102 | flags.DEFINE_integer( |
103 | - "export_model_steps", 5, | 103 | + "export_model_steps", 10, |
104 | "The period, in number of steps, with which the model " | 104 | "The period, in number of steps, with which the model " |
105 | "is exported for batch prediction.") | 105 | "is exported for batch prediction.") |
106 | 106 | ... | ... |
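Note how these train.py defaults interact: max_steps now caps the run at 100 batches, so num_epochs=100 is effectively unreachable, and export_model_steps=10 yields about ten exported models along the way. Back-of-envelope, with the dataset size as a stand-in:

    num_train_examples = 100_000                   # stand-in value for illustration
    steps_per_epoch = num_train_examples // 16     # new batch_size
    total_steps = min(100, 100 * steps_per_epoch)  # max_steps vs num_epochs
    models_exported = total_steps // 10            # export_model_steps
    print(total_steps, models_exported)            # -> 100 10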
... | @@ -154,6 +154,7 @@ def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes): | ... | @@ -154,6 +154,7 @@ def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes): |
154 | List of the feature names and list of the dimensionality of each feature. | 154 | List of the feature names and list of the dimensionality of each feature. |
155 | Elements in the first/second list are strings/integers. | 155 | Elements in the first/second list are strings/integers. |
156 | """ | 156 | """ |
157 | + feature_sizes = str(feature_sizes) | ||
157 | list_of_feature_names = [ | 158 | list_of_feature_names = [ |
158 | feature_names.strip() for feature_names in feature_names.split(",") | 159 | feature_names.strip() for feature_names in feature_names.split(",") |
159 | ] | 160 | ] | ... | ... |
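The str() coercion guards against feature_sizes arriving as a non-string (for instance a bare integer deserialized from model_flags.json), which would crash the .split(",") below with AttributeError. The full parse the function performs, as a standalone sketch:

    def parse_feature_spec(feature_names, feature_sizes):
        # Standalone mirror of GetListOfFeatureNamesAndSizes' parsing.
        names = [name.strip() for name in str(feature_names).split(",")]
        sizes = [int(size) for size in str(feature_sizes).split(",")]
        assert len(names) == len(sizes), "feature names and sizes must align"
        return names, sizes

    # parse_feature_spec("rgb,audio", "1024,128") -> (["rgb", "audio"], [1024, 128])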