윤영빈

changed saving method

...@@ -92,7 +92,6 @@ class VideoFileUploadView(APIView): ...@@ -92,7 +92,6 @@ class VideoFileUploadView(APIView):
92 return Response(result, status=status.HTTP_201_CREATED) 92 return Response(result, status=status.HTTP_201_CREATED)
93 else: 93 else:
94 return Response(file_serializer.errors, status=status.HTTP_400_BAD_REQUEST) 94 return Response(file_serializer.errors, status=status.HTTP_400_BAD_REQUEST)
95 -
96 95
97 class VideoFileList(APIView): 96 class VideoFileList(APIView):
98 97
......
1 +{"model": "NetVLADModelLF", "feature_sizes": "1024,128", "feature_names": "rgb,audio", "frame_features": true, "label_loss": "CrossEntropyLoss"}
...\ No newline at end of file ...\ No newline at end of file
...@@ -33,17 +33,17 @@ FLAGS = flags.FLAGS ...@@ -33,17 +33,17 @@ FLAGS = flags.FLAGS
33 if __name__ == "__main__": 33 if __name__ == "__main__":
34 # Dataset flags. 34 # Dataset flags.
35 flags.DEFINE_string( 35 flags.DEFINE_string(
36 - "train_dir", "F:/yt8mDataset/savedModel", 36 + "train_dir", "F:/yt8mDataset/savedModelaa2",
37 "The directory to load the model files from. " 37 "The directory to load the model files from. "
38 "The tensorboard metrics files are also saved to this " 38 "The tensorboard metrics files are also saved to this "
39 "directory.") 39 "directory.")
40 flags.DEFINE_string( 40 flags.DEFINE_string(
41 - "eval_data_pattern", "", 41 + "eval_data_pattern", "F:/yt8mDataset/train2/train*.tfrecord",
42 "File glob defining the evaluation dataset in tensorflow.SequenceExample " 42 "File glob defining the evaluation dataset in tensorflow.SequenceExample "
43 "format. The SequenceExamples are expected to have an 'rgb' byte array " 43 "format. The SequenceExamples are expected to have an 'rgb' byte array "
44 "sequence feature as well as a 'labels' int64 context feature.") 44 "sequence feature as well as a 'labels' int64 context feature.")
45 flags.DEFINE_bool( 45 flags.DEFINE_bool(
46 - "segment_labels", False, 46 + "segment_labels", True,
47 "If set, then --eval_data_pattern must be frame-level features (but with" 47 "If set, then --eval_data_pattern must be frame-level features (but with"
48 " segment_labels). Otherwise, --eval_data_pattern must be aggregated " 48 " segment_labels). Otherwise, --eval_data_pattern must be aggregated "
49 "video-level features. The model must also be set appropriately (i.e. to " 49 "video-level features. The model must also be set appropriately (i.e. to "
...@@ -54,7 +54,7 @@ if __name__ == "__main__": ...@@ -54,7 +54,7 @@ if __name__ == "__main__":
54 "How many examples to process per batch.") 54 "How many examples to process per batch.")
55 flags.DEFINE_integer("num_readers", 8, 55 flags.DEFINE_integer("num_readers", 8,
56 "How many threads to use for reading input files.") 56 "How many threads to use for reading input files.")
57 - flags.DEFINE_boolean("run_once", False, "Whether to run eval only once.") 57 + flags.DEFINE_boolean("run_once", True, "Whether to run eval only once.")
58 flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.") 58 flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")
59 59
60 60
......
...@@ -46,7 +46,7 @@ flags.DEFINE_string( ...@@ -46,7 +46,7 @@ flags.DEFINE_string(
46 "Some Frame-Level models can be decomposed into a " 46 "Some Frame-Level models can be decomposed into a "
47 "generalized pooling operation followed by a " 47 "generalized pooling operation followed by a "
48 "classifier layer") 48 "classifier layer")
49 -flags.DEFINE_integer("lstm_cells", 512, "Number of LSTM cells.") 49 +flags.DEFINE_integer("lstm_cells", 1024, "Number of LSTM cells.")
50 flags.DEFINE_integer("lstm_layers", 2, "Number of LSTM layers.") 50 flags.DEFINE_integer("lstm_layers", 2, "Number of LSTM layers.")
51 51
52 52
......
...@@ -23,7 +23,7 @@ import tarfile ...@@ -23,7 +23,7 @@ import tarfile
23 import tempfile 23 import tempfile
24 import time 24 import time
25 import numpy as np 25 import numpy as np
26 - 26 +import ssl
27 import readers 27 import readers
28 from six.moves import urllib 28 from six.moves import urllib
29 import tensorflow as tf 29 import tensorflow as tf
...@@ -39,11 +39,11 @@ FLAGS = flags.FLAGS ...@@ -39,11 +39,11 @@ FLAGS = flags.FLAGS
39 if __name__ == "__main__": 39 if __name__ == "__main__":
40 # Input 40 # Input
41 flags.DEFINE_string( 41 flags.DEFINE_string(
42 - "train_dir", "", "The directory to load the model files from. We assume " 42 + "train_dir", "/mnt/f/yt8mDataset/savedModelaa", "The directory to load the model files from. We assume "
43 "that you have already run eval.py onto this, such that " 43 "that you have already run eval.py onto this, such that "
44 "inference_model.* files already exist.") 44 "inference_model.* files already exist.")
45 flags.DEFINE_string( 45 flags.DEFINE_string(
46 - "input_data_pattern", "", 46 + "input_data_pattern", "/mnt/f/yt8mDataset/train2/train*.tfrecord",
47 "File glob defining the evaluation dataset in tensorflow.SequenceExample " 47 "File glob defining the evaluation dataset in tensorflow.SequenceExample "
48 "format. The SequenceExamples are expected to have an 'rgb' byte array " 48 "format. The SequenceExamples are expected to have an 'rgb' byte array "
49 "sequence feature as well as a 'labels' int64 context feature.") 49 "sequence feature as well as a 'labels' int64 context feature.")
...@@ -60,7 +60,7 @@ if __name__ == "__main__": ...@@ -60,7 +60,7 @@ if __name__ == "__main__":
60 "be created and the contents of the .tgz file will be " 60 "be created and the contents of the .tgz file will be "
61 "untarred here.") 61 "untarred here.")
62 flags.DEFINE_bool( 62 flags.DEFINE_bool(
63 - "segment_labels", False, 63 + "segment_labels", True,
64 "If set, then --input_data_pattern must be frame-level features (but with" 64 "If set, then --input_data_pattern must be frame-level features (but with"
65 " segment_labels). Otherwise, --input_data_pattern must be aggregated " 65 " segment_labels). Otherwise, --input_data_pattern must be aggregated "
66 "video-level features. The model must also be set appropriately (i.e. to " 66 "video-level features. The model must also be set appropriately (i.e. to "
...@@ -69,18 +69,18 @@ if __name__ == "__main__": ...@@ -69,18 +69,18 @@ if __name__ == "__main__":
69 "Limit total number of segment outputs per entity.") 69 "Limit total number of segment outputs per entity.")
70 flags.DEFINE_string( 70 flags.DEFINE_string(
71 "segment_label_ids_file", 71 "segment_label_ids_file",
72 - "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv", 72 + "/mnt/e/khuhub/2015104192/web/backend/yt8m/vocabulary.csv",
73 "The file that contains the segment label ids.") 73 "The file that contains the segment label ids.")
74 74
75 # Output 75 # Output
76 - flags.DEFINE_string("output_file", "", "The file to save the predictions to.") 76 + flags.DEFINE_string("output_file", "/mnt/f/yt8mDataset/kaggle_solution_validation_temp.csv", "The file to save the predictions to.")
77 flags.DEFINE_string( 77 flags.DEFINE_string(
78 "output_model_tgz", "", 78 "output_model_tgz", "",
79 "If given, should be a filename with a .tgz extension, " 79 "If given, should be a filename with a .tgz extension, "
80 "the model graph and checkpoint will be bundled in this " 80 "the model graph and checkpoint will be bundled in this "
81 "gzip tar. This file can be uploaded to Kaggle for the " 81 "gzip tar. This file can be uploaded to Kaggle for the "
82 "top 10 participants.") 82 "top 10 participants.")
83 - flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.") 83 + flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")
84 84
85 # Other flags. 85 # Other flags.
86 flags.DEFINE_integer("batch_size", 512, 86 flags.DEFINE_integer("batch_size", 512,
...@@ -108,18 +108,15 @@ def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None): ...@@ -108,18 +108,15 @@ def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None):
108 108
109 def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1): 109 def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1):
110 """Creates the section of the graph which reads the input data. 110 """Creates the section of the graph which reads the input data.
111 -
112 Args: 111 Args:
113 reader: A class which parses the input data. 112 reader: A class which parses the input data.
114 data_pattern: A 'glob' style path to the data files. 113 data_pattern: A 'glob' style path to the data files.
115 batch_size: How many examples to process at a time. 114 batch_size: How many examples to process at a time.
116 num_readers: How many I/O threads to use. 115 num_readers: How many I/O threads to use.
117 -
118 Returns: 116 Returns:
119 A tuple containing the features tensor, labels tensor, and optionally a 117 A tuple containing the features tensor, labels tensor, and optionally a
120 tensor containing the number of frames per video. The exact dimensions 118 tensor containing the number of frames per video. The exact dimensions
121 depend on the reader being used. 119 depend on the reader being used.
122 -
123 Raises: 120 Raises:
124 IOError: If no files matching the given pattern were found. 121 IOError: If no files matching the given pattern were found.
125 """ 122 """
...@@ -243,12 +240,18 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -243,12 +240,18 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
243 whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), 240 whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
244 dtype=np.float32) 241 dtype=np.float32)
245 segment_label_ids_file = FLAGS.segment_label_ids_file 242 segment_label_ids_file = FLAGS.segment_label_ids_file
243 + """
244 + if segment_label_ids_file.startswith("http"):
245 + logging.info("Retrieving segment ID whitelist files from %s...",
246 + segment_label_ids_file)
247 + segment_label_ids_file, _ = urllib.request.urlretrieve(segment_label_ids_file)
248 + """
246 if segment_label_ids_file.startswith("http"): 249 if segment_label_ids_file.startswith("http"):
247 logging.info("Retrieving segment ID whitelist files from %s...", 250 logging.info("Retrieving segment ID whitelist files from %s...",
248 segment_label_ids_file) 251 segment_label_ids_file)
249 segment_label_ids_file, _ = urllib.request.urlretrieve( 252 segment_label_ids_file, _ = urllib.request.urlretrieve(
250 segment_label_ids_file) 253 segment_label_ids_file)
251 - with tf.io.gfile.GFile(segment_label_ids_file) as fobj: 254 + with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
252 for line in fobj: 255 for line in fobj:
253 try: 256 try:
254 cls_id = int(line) 257 cls_id = int(line)
...@@ -307,6 +310,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -307,6 +310,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
307 heaps = {} 310 heaps = {}
308 out_file.seek(0, 0) 311 out_file.seek(0, 0)
309 for line in out_file: 312 for line in out_file:
313 + print(line)
310 segment_id, preds = line.decode("utf8").split(",") 314 segment_id, preds = line.decode("utf8").split(",")
311 if segment_id == "VideoId": 315 if segment_id == "VideoId":
312 # Skip the headline. 316 # Skip the headline.
...@@ -382,4 +386,4 @@ def main(unused_argv): ...@@ -382,4 +386,4 @@ def main(unused_argv):
382 386
383 387
384 if __name__ == "__main__": 388 if __name__ == "__main__":
385 - app.run() 389 + app.run()
...\ No newline at end of file ...\ No newline at end of file
......
...@@ -41,11 +41,11 @@ FLAGS = flags.FLAGS ...@@ -41,11 +41,11 @@ FLAGS = flags.FLAGS
41 if __name__ == "__main__": 41 if __name__ == "__main__":
42 # Input 42 # Input
43 flags.DEFINE_string( 43 flags.DEFINE_string(
44 - "train_dir", "", "The directory to load the model files from. We assume " 44 + "train_dir", "E:/savedModel", "The directory to load the model files from. We assume "
45 "that you have already run eval.py onto this, such that " 45 "that you have already run eval.py onto this, such that "
46 "inference_model.* files already exist.") 46 "inference_model.* files already exist.")
47 flags.DEFINE_string( 47 flags.DEFINE_string(
48 - "input_data_pattern", "", 48 + "input_data_pattern", "F:/yt8mDataset/test/train*.tfrecord",
49 "File glob defining the evaluation dataset in tensorflow.SequenceExample " 49 "File glob defining the evaluation dataset in tensorflow.SequenceExample "
50 "format. The SequenceExamples are expected to have an 'rgb' byte array " 50 "format. The SequenceExamples are expected to have an 'rgb' byte array "
51 "sequence feature as well as a 'labels' int64 context feature.") 51 "sequence feature as well as a 'labels' int64 context feature.")
...@@ -62,7 +62,7 @@ if __name__ == "__main__": ...@@ -62,7 +62,7 @@ if __name__ == "__main__":
62 "be created and the contents of the .tgz file will be " 62 "be created and the contents of the .tgz file will be "
63 "untarred here.") 63 "untarred here.")
64 flags.DEFINE_bool( 64 flags.DEFINE_bool(
65 - "segment_labels", False, 65 + "segment_labels", True,
66 "If set, then --input_data_pattern must be frame-level features (but with" 66 "If set, then --input_data_pattern must be frame-level features (but with"
67 " segment_labels). Otherwise, --input_data_pattern must be aggregated " 67 " segment_labels). Otherwise, --input_data_pattern must be aggregated "
68 "video-level features. The model must also be set appropriately (i.e. to " 68 "video-level features. The model must also be set appropriately (i.e. to "
...@@ -75,7 +75,7 @@ if __name__ == "__main__": ...@@ -75,7 +75,7 @@ if __name__ == "__main__":
75 "The file that contains the segment label ids.") 75 "The file that contains the segment label ids.")
76 76
77 # Output 77 # Output
78 - flags.DEFINE_string("output_file", "", "The file to save the predictions to.") 78 + flags.DEFINE_string("output_file", "F:/yt8mDataset/kaggle_solution_validation_temp.csv", "The file to save the predictions to.")
79 flags.DEFINE_string( 79 flags.DEFINE_string(
80 "output_model_tgz", "", 80 "output_model_tgz", "",
81 "If given, should be a filename with a .tgz extension, " 81 "If given, should be a filename with a .tgz extension, "
...@@ -85,7 +85,7 @@ if __name__ == "__main__": ...@@ -85,7 +85,7 @@ if __name__ == "__main__":
85 flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.") 85 flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")
86 86
87 # Other flags. 87 # Other flags.
88 - flags.DEFINE_integer("batch_size", 32, 88 + flags.DEFINE_integer("batch_size", 16,
89 "How many examples to process per batch.") 89 "How many examples to process per batch.")
90 flags.DEFINE_integer("num_readers", 1, 90 flags.DEFINE_integer("num_readers", 1,
91 "How many threads to use for reading input files.") 91 "How many threads to use for reading input files.")
...@@ -269,6 +269,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -269,6 +269,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
269 except ValueError: 269 except ValueError:
270 # Simply skip the non-integer line. 270 # Simply skip the non-integer line.
271 continue 271 continue
272 + fobj.close()
272 273
273 out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) 274 out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))
274 275
...@@ -286,8 +287,10 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -286,8 +287,10 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
286 vocabs.close() 287 vocabs.close()
287 try: 288 try:
288 while not coord.should_stop(): 289 while not coord.should_stop():
290 + print("CAME IN")
289 video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( 291 video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
290 [video_id_batch, video_batch, num_frames_batch]) 292 [video_id_batch, video_batch, num_frames_batch])
293 + print("SESS OUT")
291 if FLAGS.segment_labels: 294 if FLAGS.segment_labels:
292 results = get_segments(video_batch_val, num_frames_batch_val, 5) 295 results = get_segments(video_batch_val, num_frames_batch_val, 5)
293 video_segment_ids = results["video_segment_ids"] 296 video_segment_ids = results["video_segment_ids"]
...@@ -333,6 +336,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -333,6 +336,7 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
333 segment_classes = [] 336 segment_classes = []
334 cls_result_arr = [] 337 cls_result_arr = []
335 cls_score_dict = {} 338 cls_score_dict = {}
339 + resultList = []
336 out_file.seek(0, 0) 340 out_file.seek(0, 0)
337 old_seg_name = '0000' 341 old_seg_name = '0000'
338 counter = 0 342 counter = 0
...@@ -344,12 +348,30 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -344,12 +348,30 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
344 if segment_id == "VideoId": 348 if segment_id == "VideoId":
345 # Skip the headline. 349 # Skip the headline.
346 continue 350 continue
347 - 351 + print(line)
348 preds = preds.split(" ") 352 preds = preds.split(" ")
353 + """
349 pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] 354 pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
350 pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)] 355 pred_cls_scores = [float(preds[idx]) for idx in range(1, len(preds), 2)]
351 - #======================================= 356 + """
352 - segment_id = str(segment_id.split(":")[0]) 357 + """
358 + denom = 1.0
359 + for i in range(0,top_k):
360 + denom = denom + float(preds[(i*2) + 1])
361 + print("DENOM = ",denom)
362 + for i in range(0,top_k):
363 + preds[(i*2) + 1] = float(preds[(i*2) + 1])/denom
364 + """
365 +
366 + segment_id = "{0}_{1}".format(str(segment_id.split(":")[0]),str(int(segment_id.split(":")[1])/5))
367 + resultList.append("{0},{1},{2},{3},{4},{5}".format(segment_id,
368 + preds[0]+","+str(preds[1]),
369 + preds[2]+","+str(preds[3]),
370 + preds[4]+","+str(preds[5]),
371 + preds[6]+","+str(preds[7]),
372 + preds[8]+","+str(preds[9])))
373 + #=======================================
374 + """
353 if segment_id not in segment_id_list: 375 if segment_id not in segment_id_list:
354 segment_id_list.append(str(segment_id)) 376 segment_id_list.append(str(segment_id))
355 segment_classes.append("") 377 segment_classes.append("")
...@@ -372,12 +394,12 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -372,12 +394,12 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
372 cls_arr = item.split(" ")[:-1] 394 cls_arr = item.split(" ")[:-1]
373 395
374 cls_arr = list(map(int,cls_arr)) 396 cls_arr = list(map(int,cls_arr))
375 - cls_arr = sorted(cls_arr) #클래스별로 정렬 397 + cls_arr = sorted(cls_arr) #sort by class
376 398
377 result_string = "" 399 result_string = ""
378 400
379 temp = cls_score_dict[segs] 401 temp = cls_score_dict[segs]
380 - temp= sorted(temp.items(), key=operator.itemgetter(1), reverse=True) #밸류값 기준으로 정렬 402 + temp= sorted(temp.items(), key=operator.itemgetter(1), reverse=True) #sort by value
381 demoninator = float(temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1]) 403 demoninator = float(temp[0][1] + temp[1][1] + temp[2][1] + temp[3][1] + temp[4][1])
382 #for item in temp: 404 #for item in temp:
383 for itemIndex in range(0, top_k): 405 for itemIndex in range(0, top_k):
...@@ -387,11 +409,16 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size, ...@@ -387,11 +409,16 @@ def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
387 result_string = result_string + normalized_tag + ":" + format(temp[itemIndex][1]/demoninator,".3f") + "," 409 result_string = result_string + normalized_tag + ":" + format(temp[itemIndex][1]/demoninator,".3f") + ","
388 410
389 cls_result_arr.append(result_string[:-1]) 411 cls_result_arr.append(result_string[:-1])
390 - logging.info(segs + " : " + result_string[:-1]) 412 + #logging.info(segs + " : " + result_string[:-1])
413 +
391 #======================================= 414 #=======================================
392 final_out_file.write("vid_id,segment1,segment2,segment3,segment4,segment5\n") 415 final_out_file.write("vid_id,segment1,segment2,segment3,segment4,segment5\n")
393 for seg_id, class_indcies in zip(segment_id_list, cls_result_arr): 416 for seg_id, class_indcies in zip(segment_id_list, cls_result_arr):
394 final_out_file.write("%s,%s\n" %(seg_id, str(class_indcies))) 417 final_out_file.write("%s,%s\n" %(seg_id, str(class_indcies)))
418 + """
419 + final_out_file.write("vid_id,segment1,pred1,segment2,pred2,segment3,pred3,segment4,pred4,segment5,pred5\n")
420 + for resultInfo in resultList:
421 + final_out_file.write(resultInfo)
395 final_out_file.close() 422 final_out_file.close()
396 423
397 out_file.close() 424 out_file.close()
...@@ -410,7 +437,7 @@ def main(unused_argv): ...@@ -410,7 +437,7 @@ def main(unused_argv):
410 os.makedirs(FLAGS.untar_model_dir) 437 os.makedirs(FLAGS.untar_model_dir)
411 tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir) 438 tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir)
412 FLAGS.train_dir = FLAGS.untar_model_dir 439 FLAGS.train_dir = FLAGS.untar_model_dir
413 - 440 + print("TRAIN DIR ", FLAGS.input_data_pattern)
414 flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json") 441 flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json")
415 if not file_io.file_exists(flags_dict_file): 442 if not file_io.file_exists(flags_dict_file):
416 raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file) 443 raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file)
......
...@@ -75,7 +75,7 @@ if __name__ == "__main__": ...@@ -75,7 +75,7 @@ if __name__ == "__main__":
75 flags.DEFINE_integer( 75 flags.DEFINE_integer(
76 "num_gpu", 1, "The maximum number of GPU devices to use for training. " 76 "num_gpu", 1, "The maximum number of GPU devices to use for training. "
77 "Flag only applies if GPUs are installed") 77 "Flag only applies if GPUs are installed")
78 - flags.DEFINE_integer("batch_size", 64, 78 + flags.DEFINE_integer("batch_size", 16,
79 "How many examples to process per batch for training.") 79 "How many examples to process per batch for training.")
80 flags.DEFINE_string("label_loss", "CrossEntropyLoss", 80 flags.DEFINE_string("label_loss", "CrossEntropyLoss",
81 "Which loss function to use for training the model.") 81 "Which loss function to use for training the model.")
...@@ -94,13 +94,13 @@ if __name__ == "__main__": ...@@ -94,13 +94,13 @@ if __name__ == "__main__":
94 "Multiply current learning rate by learning_rate_decay " 94 "Multiply current learning rate by learning_rate_decay "
95 "every learning_rate_decay_examples.") 95 "every learning_rate_decay_examples.")
96 flags.DEFINE_integer( 96 flags.DEFINE_integer(
97 - "num_epochs", 5, "How many passes to make over the dataset before " 97 + "num_epochs", 100, "How many passes to make over the dataset before "
98 "halting training.") 98 "halting training.")
99 flags.DEFINE_integer( 99 flags.DEFINE_integer(
100 - "max_steps", None, 100 + "max_steps", 100,
101 "The maximum number of iterations of the training loop.") 101 "The maximum number of iterations of the training loop.")
102 flags.DEFINE_integer( 102 flags.DEFINE_integer(
103 - "export_model_steps", 5, 103 + "export_model_steps", 10,
104 "The period, in number of steps, with which the model " 104 "The period, in number of steps, with which the model "
105 "is exported for batch prediction.") 105 "is exported for batch prediction.")
106 106
......
...@@ -154,6 +154,7 @@ def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes): ...@@ -154,6 +154,7 @@ def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes):
154 List of the feature names and list of the dimensionality of each feature. 154 List of the feature names and list of the dimensionality of each feature.
155 Elements in the first/second list are strings/integers. 155 Elements in the first/second list are strings/integers.
156 """ 156 """
157 + feature_sizes = str(feature_sizes)
157 list_of_feature_names = [ 158 list_of_feature_names = [
158 feature_names.strip() for feature_names in feature_names.split(",") 159 feature_names.strip() for feature_names in feature_names.split(",")
159 ] 160 ]
......