윤영빈

adjusted code: lower the eval batch_size (1024 → 1) and top_k (20 → 5), point train_dir at F:/yt8mDataset/savedModel, reduce the training batch_size to 256, num_epochs to 100 and export_model_steps to 1, and add temporary debug prints to the evaluation and training loops.

@@ -50,12 +50,12 @@ if __name__ == "__main__":
       "read 3D batches VS 4D batches.")

   # Other flags.
-  flags.DEFINE_integer("batch_size", 1024,
+  flags.DEFINE_integer("batch_size", 1,
                        "How many examples to process per batch.")
   flags.DEFINE_integer("num_readers", 8,
                        "How many threads to use for reading input files.")
   flags.DEFINE_boolean("run_once", False, "Whether to run eval only once.")
-  flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.")
+  flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")


 def find_class_by_name(name, modules):
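Note on the flag changes above: batch_size and top_k are absl-style flags (TensorFlow 1.x's flags module wraps absl), so the smaller values could also be supplied per run on the command line instead of being hard-coded into the eval script. A minimal, standalone sketch of that pattern (the file name sketch.py and the main body are illustrative, not part of the repo):

# sketch.py - standalone illustration of overriding flag defaults at run time.
from absl import app, flags

FLAGS = flags.FLAGS

# Same defaults as the upstream eval script; leave them alone and override per run.
flags.DEFINE_integer("batch_size", 1024, "How many examples to process per batch.")
flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.")


def main(unused_argv):
  print("batch_size=%d top_k=%d" % (FLAGS.batch_size, FLAGS.top_k))


if __name__ == "__main__":
  # e.g.  python sketch.py --batch_size=1 --top_k=5
  app.run(main)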
@@ -69,18 +69,15 @@ def get_input_evaluation_tensors(reader,
                                  batch_size=1024,
                                  num_readers=1):
   """Creates the section of the graph which reads the evaluation data.
-
   Args:
     reader: A class which parses the training data.
     data_pattern: A 'glob' style path to the data files.
     batch_size: How many examples to process at a time.
     num_readers: How many I/O threads to use.
-
   Returns:
     A tuple containing the features tensor, labels tensor, and optionally a
     tensor containing the number of frames per video. The exact dimensions
     depend on the reader being used.
-
   Raises:
     IOError: If no files matching the given pattern were found.
   """
@@ -110,7 +107,6 @@ def build_graph(reader,
                 batch_size=1024,
                 num_readers=1):
   """Creates the Tensorflow graph for evaluation.
-
   Args:
     reader: The data file reader. It should inherit from BaseReader.
     model: The core model (e.g. logistic or neural net). It should inherit from
@@ -169,14 +165,12 @@ def build_graph(reader,
 def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
                     last_global_step_val):
   """Run the evaluation loop once.
-
   Args:
     fetches: a dict of tensors to be run within Session.
     saver: a tensorflow saver to restore the model.
     summary_writer: a tensorflow summary_writer
     evl_metrics: an EvaluationMetrics object.
     last_global_step_val: the global step used in the previous evaluation.
-
   Returns:
     The global_step used in the latest model.
   """
@@ -192,17 +186,20 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
       # Assuming model_checkpoint_path looks something like:
       # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
       global_step_val = os.path.basename(latest_checkpoint).split("-")[-1]
-
+      print("COMES HERERERERERER 55555555555555555 : ")
       # Save model
       if FLAGS.segment_labels:
+        print("COMES HERERERERERER 666666666666666666 : 1111")
         inference_model_name = "segment_inference_model"
       else:
+        print("COMES HERERERERERER 666666666666666666 : 22222")
         inference_model_name = "inference_model"
       saver.save(
           sess,
           os.path.join(FLAGS.train_dir, "inference_model",
                        inference_model_name))
     else:
+      print("COMES HERERERERERER 666666666666666666 : 3333")
       logging.info("No checkpoint file found.")
       return global_step_val

@@ -213,6 +210,7 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
       return global_step_val

     sess.run([tf.local_variables_initializer()])
+    print("COMES HERERERERERER 777777777777 : ")

     # Start the queue runners.
     coord = tf.train.Coordinator()
@@ -227,15 +225,25 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
       evl_metrics.clear()

       examples_processed = 0
+      index = 0
       while not coord.should_stop():
+        index = index + 1
+        print("proceeeeeeeeeeDDDDD!!! : " + str(index))
         batch_start_time = time.time()
+        print("step 1")
         output_data_dict = sess.run(fetches)
+        print("step 2")
         seconds_per_batch = time.time() - batch_start_time
+        print("step 3")
         labels_val = output_data_dict["labels"]
+        print("step 4")
         summary_val = output_data_dict["summary"]
+        print("step 5")
         example_per_second = labels_val.shape[0] / seconds_per_batch
+        print("step 6")
         examples_processed += labels_val.shape[0]
-
+        print("step 7")
+
         predictions = output_data_dict["predictions"]
         if FLAGS.segment_labels:
           # This is a workaround to ignore the unrated labels.
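The "step 1" ... "step 7" prints above trace how far each evaluation batch gets before the loop stalls. A hedged alternative is to route the same trace through Python's standard logging module, so every message carries a timestamp and the whole trace can be silenced by verbosity level instead of by deleting lines. Standalone sketch; trace_step and the commented usage are illustrative, not existing repo code:

# Standalone sketch of timestamped step tracing.
import logging
import time

logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s %(levelname)s %(message)s")


def trace_step(label, fn, *args, **kwargs):
  """Run fn(*args, **kwargs), log the label and wall-clock time, return the result."""
  start = time.time()
  result = fn(*args, **kwargs)
  logging.info("eval trace: %s took %.3f s", label, time.time() - start)
  return result


# Hypothetical use inside the loop:
#   output_data_dict = trace_step("sess.run(fetches)", sess.run, fetches)
#   labels_val = output_data_dict["labels"]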
@@ -354,4 +362,4 @@ def main(unused_argv):


 if __name__ == "__main__":
-  tf.compat.v1.app.run()
\ No newline at end of file
+  tf.compat.v1.app.run()
\ No newline at end of file
(The remaining hunks below are from the training script in the same commit.)
@@ -37,7 +37,7 @@ FLAGS = flags.FLAGS

 if __name__ == "__main__":
   # Dataset flags.
-  flags.DEFINE_string("train_dir", "/tmp/yt8m_model/",
+  flags.DEFINE_string("train_dir", "F:/yt8mDataset/savedModel",
                       "The directory to save the model files in.")
   flags.DEFINE_string(
       "train_data_pattern", "",
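train_dir now points at a Windows drive; forward slashes are valid in Windows paths for both Python's os module and TensorFlow's file I/O. A small standalone sketch that makes sure the directory exists before training writes checkpoints into it (only the path is taken from the diff, the check itself is illustrative):

# Standalone sketch: make sure the training directory exists up front.
import os

train_dir = "F:/yt8mDataset/savedModel"  # value from the adjusted flag
os.makedirs(train_dir, exist_ok=True)    # no-op if the directory already exists
print("checkpoints and exported models will go to:", os.path.abspath(train_dir))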
@@ -75,7 +75,7 @@ if __name__ == "__main__":
   flags.DEFINE_integer(
       "num_gpu", 1, "The maximum number of GPU devices to use for training. "
       "Flag only applies if GPUs are installed")
-  flags.DEFINE_integer("batch_size", 1024,
+  flags.DEFINE_integer("batch_size", 256,
                        "How many examples to process per batch for training.")
   flags.DEFINE_string("label_loss", "CrossEntropyLoss",
                       "Which loss function to use for training the model.")
@@ -94,13 +94,13 @@ if __name__ == "__main__":
       "Multiply current learning rate by learning_rate_decay "
       "every learning_rate_decay_examples.")
   flags.DEFINE_integer(
-      "num_epochs", 1000, "How many passes to make over the dataset before "
+      "num_epochs", 100, "How many passes to make over the dataset before "
       "halting training.")
   flags.DEFINE_integer(
       "max_steps", None,
       "The maximum number of iterations of the training loop.")
   flags.DEFINE_integer(
-      "export_model_steps", 1000,
+      "export_model_steps", 1,
       "The period, in number of steps, with which the model "
       "is exported for batch prediction.")

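Dropping export_model_steps from 1000 to 1 makes the trainer export a model on essentially every step, which is handy for debugging the export path but adds heavy disk I/O per step; for real runs the value should go back up. A minimal sketch of the kind of modulo gate such a period flag drives (illustrative, not the repo's exact logic):

# Standalone sketch of a periodic-export gate driven by export_model_steps.
def should_export(global_step, export_model_steps):
  """Return True when the current step lands on an export boundary."""
  return export_model_steps > 0 and global_step % export_model_steps == 0


# With export_model_steps=1 every step exports; with 1000 only steps 1000, 2000, ...
assert should_export(7, 1)
assert not should_export(7, 1000)
assert should_export(2000, 1000)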
@@ -404,6 +404,10 @@ class Trainer(object):
     Returns:
       A tuple of the training Hit@1 and the training PERR.
     """
+
+    print("=========================")
+    print("start now!!!!!")
+
     if self.is_master and start_new_model:
       self.remove_training_directory(self.train_dir)

@@ -461,6 +465,7 @@ class Trainer(object):
         save_summaries_secs=120,
         saver=saver)

+
     logging.info("%s: Starting managed session.", task_as_string(self.task))
     with sv.managed_session(target, config=self.config) as sess:
       try:
@@ -470,8 +475,8 @@ class Trainer(object):
         _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
             [train_op, global_step, loss, predictions, labels])
         seconds_per_batch = time.time() - batch_start_time
-        examples_per_second = labels_val.shape[0] / seconds_per_batch
-
+        examples_per_second = labels_val.shape[0] / seconds_per_batch
+        print("CURRENT STEP IS " + str(global_step_val))
         if self.max_steps and self.max_steps <= global_step_val:
           self.max_steps_reached = True

@@ -640,7 +645,7 @@ class ParameterServer(object):

   def run(self):
     """Starts the parameter server."""
-
+    print("start now=================")
     logging.info("%s: Starting parameter server within cluster %s.",
                  task_as_string(self.task), self.cluster.as_dict())
     server = start_server(self.cluster, self.task)