Showing 2 changed files with 31 additions and 18 deletions.
... | @@ -50,12 +50,12 @@ if __name__ == "__main__": | ... | @@ -50,12 +50,12 @@ if __name__ == "__main__": |
50 | "read 3D batches VS 4D batches.") | 50 | "read 3D batches VS 4D batches.") |
51 | 51 | ||
52 | # Other flags. | 52 | # Other flags. |
53 | - flags.DEFINE_integer("batch_size", 1024, | 53 | + flags.DEFINE_integer("batch_size", 1, |
54 | "How many examples to process per batch.") | 54 | "How many examples to process per batch.") |
55 | flags.DEFINE_integer("num_readers", 8, | 55 | flags.DEFINE_integer("num_readers", 8, |
56 | "How many threads to use for reading input files.") | 56 | "How many threads to use for reading input files.") |
57 | flags.DEFINE_boolean("run_once", False, "Whether to run eval only once.") | 57 | flags.DEFINE_boolean("run_once", False, "Whether to run eval only once.") |
58 | - flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.") | 58 | + flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.") |
59 | 59 | ||
60 | 60 | ||
61 | def find_class_by_name(name, modules): | 61 | def find_class_by_name(name, modules): |
... | @@ -69,18 +69,15 @@ def get_input_evaluation_tensors(reader, | ... | @@ -69,18 +69,15 @@ def get_input_evaluation_tensors(reader, |
69 | batch_size=1024, | 69 | batch_size=1024, |
70 | num_readers=1): | 70 | num_readers=1): |
71 | """Creates the section of the graph which reads the evaluation data. | 71 | """Creates the section of the graph which reads the evaluation data. |
72 | - | ||
73 | Args: | 72 | Args: |
74 | reader: A class which parses the training data. | 73 | reader: A class which parses the training data. |
75 | data_pattern: A 'glob' style path to the data files. | 74 | data_pattern: A 'glob' style path to the data files. |
76 | batch_size: How many examples to process at a time. | 75 | batch_size: How many examples to process at a time. |
77 | num_readers: How many I/O threads to use. | 76 | num_readers: How many I/O threads to use. |
78 | - | ||
79 | Returns: | 77 | Returns: |
80 | A tuple containing the features tensor, labels tensor, and optionally a | 78 | A tuple containing the features tensor, labels tensor, and optionally a |
81 | tensor containing the number of frames per video. The exact dimensions | 79 | tensor containing the number of frames per video. The exact dimensions |
82 | depend on the reader being used. | 80 | depend on the reader being used. |
83 | - | ||
84 | Raises: | 81 | Raises: |
85 | IOError: If no files matching the given pattern were found. | 82 | IOError: If no files matching the given pattern were found. |
86 | """ | 83 | """ |
... | @@ -110,7 +107,6 @@ def build_graph(reader, | ... | @@ -110,7 +107,6 @@ def build_graph(reader, |
110 | batch_size=1024, | 107 | batch_size=1024, |
111 | num_readers=1): | 108 | num_readers=1): |
112 | """Creates the Tensorflow graph for evaluation. | 109 | """Creates the Tensorflow graph for evaluation. |
113 | - | ||
114 | Args: | 110 | Args: |
115 | reader: The data file reader. It should inherit from BaseReader. | 111 | reader: The data file reader. It should inherit from BaseReader. |
116 | model: The core model (e.g. logistic or neural net). It should inherit from | 112 | model: The core model (e.g. logistic or neural net). It should inherit from |
... | @@ -169,14 +165,12 @@ def build_graph(reader, | ... | @@ -169,14 +165,12 @@ def build_graph(reader, |
169 | def evaluation_loop(fetches, saver, summary_writer, evl_metrics, | 165 | def evaluation_loop(fetches, saver, summary_writer, evl_metrics, |
170 | last_global_step_val): | 166 | last_global_step_val): |
171 | """Run the evaluation loop once. | 167 | """Run the evaluation loop once. |
172 | - | ||
173 | Args: | 168 | Args: |
174 | fetches: a dict of tensors to be run within Session. | 169 | fetches: a dict of tensors to be run within Session. |
175 | saver: a tensorflow saver to restore the model. | 170 | saver: a tensorflow saver to restore the model. |
176 | summary_writer: a tensorflow summary_writer | 171 | summary_writer: a tensorflow summary_writer |
177 | evl_metrics: an EvaluationMetrics object. | 172 | evl_metrics: an EvaluationMetrics object. |
178 | last_global_step_val: the global step used in the previous evaluation. | 173 | last_global_step_val: the global step used in the previous evaluation. |
179 | - | ||
180 | Returns: | 174 | Returns: |
181 | The global_step used in the latest model. | 175 | The global_step used in the latest model. |
182 | """ | 176 | """ |
... | @@ -192,17 +186,20 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics, | ... | @@ -192,17 +186,20 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics, |
192 | # Assuming model_checkpoint_path looks something like: | 186 | # Assuming model_checkpoint_path looks something like: |
193 | # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it. | 187 | # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it. |
194 | global_step_val = os.path.basename(latest_checkpoint).split("-")[-1] | 188 | global_step_val = os.path.basename(latest_checkpoint).split("-")[-1] |
195 | - | 189 | + print("COMES HERERERERERER 55555555555555555 : ") |
196 | # Save model | 190 | # Save model |
197 | if FLAGS.segment_labels: | 191 | if FLAGS.segment_labels: |
192 | + print("COMES HERERERERERER 666666666666666666 : 1111") | ||
198 | inference_model_name = "segment_inference_model" | 193 | inference_model_name = "segment_inference_model" |
199 | else: | 194 | else: |
195 | + print("COMES HERERERERERER 666666666666666666 : 22222") | ||
200 | inference_model_name = "inference_model" | 196 | inference_model_name = "inference_model" |
201 | saver.save( | 197 | saver.save( |
202 | sess, | 198 | sess, |
203 | os.path.join(FLAGS.train_dir, "inference_model", | 199 | os.path.join(FLAGS.train_dir, "inference_model", |
204 | inference_model_name)) | 200 | inference_model_name)) |
205 | else: | 201 | else: |
202 | + print("COMES HERERERERERER 666666666666666666 : 3333") | ||
206 | logging.info("No checkpoint file found.") | 203 | logging.info("No checkpoint file found.") |
207 | return global_step_val | 204 | return global_step_val |
208 | 205 | ||
... | @@ -213,6 +210,7 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics, | ... | @@ -213,6 +210,7 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics, |
213 | return global_step_val | 210 | return global_step_val |
214 | 211 | ||
215 | sess.run([tf.local_variables_initializer()]) | 212 | sess.run([tf.local_variables_initializer()]) |
213 | + print("COMES HERERERERERER 777777777777 : ") | ||
216 | 214 | ||
217 | # Start the queue runners. | 215 | # Start the queue runners. |
218 | coord = tf.train.Coordinator() | 216 | coord = tf.train.Coordinator() |
... | @@ -227,15 +225,25 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics, | ... | @@ -227,15 +225,25 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics, |
227 | evl_metrics.clear() | 225 | evl_metrics.clear() |
228 | 226 | ||
229 | examples_processed = 0 | 227 | examples_processed = 0 |
228 | + index = 0 | ||
230 | while not coord.should_stop(): | 229 | while not coord.should_stop(): |
230 | + index = index + 1 | ||
231 | + print("proceeeeeeeeeeDDDDD!!! : " + str(index)) | ||
231 | batch_start_time = time.time() | 232 | batch_start_time = time.time() |
233 | + print("step 1") | ||
232 | output_data_dict = sess.run(fetches) | 234 | output_data_dict = sess.run(fetches) |
235 | + print("step 2") | ||
233 | seconds_per_batch = time.time() - batch_start_time | 236 | seconds_per_batch = time.time() - batch_start_time |
237 | + print("step 3") | ||
234 | labels_val = output_data_dict["labels"] | 238 | labels_val = output_data_dict["labels"] |
239 | + print("step 4") | ||
235 | summary_val = output_data_dict["summary"] | 240 | summary_val = output_data_dict["summary"] |
241 | + print("step 5") | ||
236 | example_per_second = labels_val.shape[0] / seconds_per_batch | 242 | example_per_second = labels_val.shape[0] / seconds_per_batch |
243 | + print("step 6") | ||
237 | examples_processed += labels_val.shape[0] | 244 | examples_processed += labels_val.shape[0] |
238 | - | 245 | + print("step 7") |
246 | + | ||
239 | predictions = output_data_dict["predictions"] | 247 | predictions = output_data_dict["predictions"] |
240 | if FLAGS.segment_labels: | 248 | if FLAGS.segment_labels: |
241 | # This is a workaround to ignore the unrated labels. | 249 | # This is a workaround to ignore the unrated labels. |
... | @@ -354,4 +362,4 @@ def main(unused_argv): | ... | @@ -354,4 +362,4 @@ def main(unused_argv): |
354 | 362 | ||
355 | 363 | ||
356 | if __name__ == "__main__": | 364 | if __name__ == "__main__": |
357 | - tf.compat.v1.app.run() | 365 | + tf.compat.v1.app.run() |
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
... | @@ -37,7 +37,7 @@ FLAGS = flags.FLAGS | ... | @@ -37,7 +37,7 @@ FLAGS = flags.FLAGS |
37 | 37 | ||
38 | if __name__ == "__main__": | 38 | if __name__ == "__main__": |
39 | # Dataset flags. | 39 | # Dataset flags. |
40 | - flags.DEFINE_string("train_dir", "/tmp/yt8m_model/", | 40 | + flags.DEFINE_string("train_dir", "F:/yt8mDataset/savedModel", |
41 | "The directory to save the model files in.") | 41 | "The directory to save the model files in.") |
42 | flags.DEFINE_string( | 42 | flags.DEFINE_string( |
43 | "train_data_pattern", "", | 43 | "train_data_pattern", "", |
... | @@ -75,7 +75,7 @@ if __name__ == "__main__": | ... | @@ -75,7 +75,7 @@ if __name__ == "__main__": |
75 | flags.DEFINE_integer( | 75 | flags.DEFINE_integer( |
76 | "num_gpu", 1, "The maximum number of GPU devices to use for training. " | 76 | "num_gpu", 1, "The maximum number of GPU devices to use for training. " |
77 | "Flag only applies if GPUs are installed") | 77 | "Flag only applies if GPUs are installed") |
78 | - flags.DEFINE_integer("batch_size", 1024, | 78 | + flags.DEFINE_integer("batch_size", 256, |
79 | "How many examples to process per batch for training.") | 79 | "How many examples to process per batch for training.") |
80 | flags.DEFINE_string("label_loss", "CrossEntropyLoss", | 80 | flags.DEFINE_string("label_loss", "CrossEntropyLoss", |
81 | "Which loss function to use for training the model.") | 81 | "Which loss function to use for training the model.") |
... | @@ -94,13 +94,13 @@ if __name__ == "__main__": | ... | @@ -94,13 +94,13 @@ if __name__ == "__main__": |
94 | "Multiply current learning rate by learning_rate_decay " | 94 | "Multiply current learning rate by learning_rate_decay " |
95 | "every learning_rate_decay_examples.") | 95 | "every learning_rate_decay_examples.") |
96 | flags.DEFINE_integer( | 96 | flags.DEFINE_integer( |
97 | - "num_epochs", 1000, "How many passes to make over the dataset before " | 97 | + "num_epochs", 100, "How many passes to make over the dataset before " |
98 | "halting training.") | 98 | "halting training.") |
99 | flags.DEFINE_integer( | 99 | flags.DEFINE_integer( |
100 | "max_steps", None, | 100 | "max_steps", None, |
101 | "The maximum number of iterations of the training loop.") | 101 | "The maximum number of iterations of the training loop.") |
102 | flags.DEFINE_integer( | 102 | flags.DEFINE_integer( |
103 | - "export_model_steps", 1000, | 103 | + "export_model_steps", 1, |
104 | "The period, in number of steps, with which the model " | 104 | "The period, in number of steps, with which the model " |
105 | "is exported for batch prediction.") | 105 | "is exported for batch prediction.") |
106 | 106 | ||
... | @@ -404,6 +404,10 @@ class Trainer(object): | ... | @@ -404,6 +404,10 @@ class Trainer(object): |
404 | Returns: | 404 | Returns: |
405 | A tuple of the training Hit@1 and the training PERR. | 405 | A tuple of the training Hit@1 and the training PERR. |
406 | """ | 406 | """ |
407 | + | ||
408 | + print("=========================") | ||
409 | + print("start now!!!!!") | ||
410 | + | ||
407 | if self.is_master and start_new_model: | 411 | if self.is_master and start_new_model: |
408 | self.remove_training_directory(self.train_dir) | 412 | self.remove_training_directory(self.train_dir) |
409 | 413 | ||
... | @@ -461,6 +465,7 @@ class Trainer(object): | ... | @@ -461,6 +465,7 @@ class Trainer(object): |
461 | save_summaries_secs=120, | 465 | save_summaries_secs=120, |
462 | saver=saver) | 466 | saver=saver) |
463 | 467 | ||
468 | + | ||
464 | logging.info("%s: Starting managed session.", task_as_string(self.task)) | 469 | logging.info("%s: Starting managed session.", task_as_string(self.task)) |
465 | with sv.managed_session(target, config=self.config) as sess: | 470 | with sv.managed_session(target, config=self.config) as sess: |
466 | try: | 471 | try: |
... | @@ -470,8 +475,8 @@ class Trainer(object): | ... | @@ -470,8 +475,8 @@ class Trainer(object): |
470 | _, global_step_val, loss_val, predictions_val, labels_val = sess.run( | 475 | _, global_step_val, loss_val, predictions_val, labels_val = sess.run( |
471 | [train_op, global_step, loss, predictions, labels]) | 476 | [train_op, global_step, loss, predictions, labels]) |
472 | seconds_per_batch = time.time() - batch_start_time | 477 | seconds_per_batch = time.time() - batch_start_time |
473 | - examples_per_second = labels_val.shape[0] / seconds_per_batch | 478 | + examples_per_second = labels_val.shape[0] / seconds_per_batch |
474 | - | 479 | + print("CURRENT STEP IS " + str(global_step_val)) |
475 | if self.max_steps and self.max_steps <= global_step_val: | 480 | if self.max_steps and self.max_steps <= global_step_val: |
476 | self.max_steps_reached = True | 481 | self.max_steps_reached = True |
477 | 482 | ||
... | @@ -640,7 +645,7 @@ class ParameterServer(object): | ... | @@ -640,7 +645,7 @@ class ParameterServer(object): |
640 | 645 | ||
641 | def run(self): | 646 | def run(self): |
642 | """Starts the parameter server.""" | 647 | """Starts the parameter server.""" |
643 | - | 648 | + print("start now=================") |
644 | logging.info("%s: Starting parameter server within cluster %s.", | 649 | logging.info("%s: Starting parameter server within cluster %s.", |
645 | task_as_string(self.task), self.cluster.as_dict()) | 650 | task_as_string(self.task), self.cluster.as_dict()) |
646 | server = start_server(self.cluster, self.task) | 651 | server = start_server(self.cluster, self.task) | ... | ... |
Please register or login to post a comment.