윤영빈

Adjusted eval/train configuration for local debugging: reduced eval batch_size 1024→1 and top_k 20→5; reduced training batch_size 1024→256, num_epochs 1000→100, and export_model_steps 1000→1; pointed train_dir at F:/yt8mDataset/savedModel; added temporary debug print statements in the evaluation and training loops.

......@@ -50,12 +50,12 @@ if __name__ == "__main__":
"read 3D batches VS 4D batches.")
# Other flags.
flags.DEFINE_integer("batch_size", 1024,
flags.DEFINE_integer("batch_size", 1,
"How many examples to process per batch.")
flags.DEFINE_integer("num_readers", 8,
"How many threads to use for reading input files.")
flags.DEFINE_boolean("run_once", False, "Whether to run eval only once.")
flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.")
flags.DEFINE_integer("top_k", 5, "How many predictions to output per video.")
def find_class_by_name(name, modules):
......@@ -69,18 +69,15 @@ def get_input_evaluation_tensors(reader,
batch_size=1024,
num_readers=1):
"""Creates the section of the graph which reads the evaluation data.
Args:
reader: A class which parses the training data.
data_pattern: A 'glob' style path to the data files.
batch_size: How many examples to process at a time.
num_readers: How many I/O threads to use.
Returns:
A tuple containing the features tensor, labels tensor, and optionally a
tensor containing the number of frames per video. The exact dimensions
depend on the reader being used.
Raises:
IOError: If no files matching the given pattern were found.
"""
......@@ -110,7 +107,6 @@ def build_graph(reader,
batch_size=1024,
num_readers=1):
"""Creates the Tensorflow graph for evaluation.
Args:
reader: The data file reader. It should inherit from BaseReader.
model: The core model (e.g. logistic or neural net). It should inherit from
......@@ -169,14 +165,12 @@ def build_graph(reader,
def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
last_global_step_val):
"""Run the evaluation loop once.
Args:
fetches: a dict of tensors to be run within Session.
saver: a tensorflow saver to restore the model.
summary_writer: a tensorflow summary_writer
evl_metrics: an EvaluationMetrics object.
last_global_step_val: the global step used in the previous evaluation.
Returns:
The global_step used in the latest model.
"""
......@@ -192,17 +186,20 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
# Assuming model_checkpoint_path looks something like:
# /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
global_step_val = os.path.basename(latest_checkpoint).split("-")[-1]
print("COMES HERERERERERER 55555555555555555 : ")
# Save model
if FLAGS.segment_labels:
print("COMES HERERERERERER 666666666666666666 : 1111")
inference_model_name = "segment_inference_model"
else:
print("COMES HERERERERERER 666666666666666666 : 22222")
inference_model_name = "inference_model"
saver.save(
sess,
os.path.join(FLAGS.train_dir, "inference_model",
inference_model_name))
else:
print("COMES HERERERERERER 666666666666666666 : 3333")
logging.info("No checkpoint file found.")
return global_step_val
......@@ -213,6 +210,7 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
return global_step_val
sess.run([tf.local_variables_initializer()])
print("COMES HERERERERERER 777777777777 : ")
# Start the queue runners.
coord = tf.train.Coordinator()
......@@ -227,15 +225,25 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
evl_metrics.clear()
examples_processed = 0
index = 0
while not coord.should_stop():
index = index + 1
print("proceeeeeeeeeeDDDDD!!! : " + str(index))
batch_start_time = time.time()
print("step 1")
output_data_dict = sess.run(fetches)
print("step 2")
seconds_per_batch = time.time() - batch_start_time
print("step 3")
labels_val = output_data_dict["labels"]
print("step 4")
summary_val = output_data_dict["summary"]
print("step 5")
example_per_second = labels_val.shape[0] / seconds_per_batch
print("step 6")
examples_processed += labels_val.shape[0]
print("step 7")
predictions = output_data_dict["predictions"]
if FLAGS.segment_labels:
# This is a workaround to ignore the unrated labels.
......@@ -354,4 +362,4 @@ def main(unused_argv):
if __name__ == "__main__":
tf.compat.v1.app.run()
tf.compat.v1.app.run()
\ No newline at end of file
......
......@@ -37,7 +37,7 @@ FLAGS = flags.FLAGS
if __name__ == "__main__":
# Dataset flags.
flags.DEFINE_string("train_dir", "/tmp/yt8m_model/",
flags.DEFINE_string("train_dir", "F:/yt8mDataset/savedModel",
"The directory to save the model files in.")
flags.DEFINE_string(
"train_data_pattern", "",
......@@ -75,7 +75,7 @@ if __name__ == "__main__":
flags.DEFINE_integer(
"num_gpu", 1, "The maximum number of GPU devices to use for training. "
"Flag only applies if GPUs are installed")
flags.DEFINE_integer("batch_size", 1024,
flags.DEFINE_integer("batch_size", 256,
"How many examples to process per batch for training.")
flags.DEFINE_string("label_loss", "CrossEntropyLoss",
"Which loss function to use for training the model.")
......@@ -94,13 +94,13 @@ if __name__ == "__main__":
"Multiply current learning rate by learning_rate_decay "
"every learning_rate_decay_examples.")
flags.DEFINE_integer(
"num_epochs", 1000, "How many passes to make over the dataset before "
"num_epochs", 100, "How many passes to make over the dataset before "
"halting training.")
flags.DEFINE_integer(
"max_steps", None,
"The maximum number of iterations of the training loop.")
flags.DEFINE_integer(
"export_model_steps", 1000,
"export_model_steps", 1,
"The period, in number of steps, with which the model "
"is exported for batch prediction.")
......@@ -404,6 +404,10 @@ class Trainer(object):
Returns:
A tuple of the training Hit@1 and the training PERR.
"""
print("=========================")
print("start now!!!!!")
if self.is_master and start_new_model:
self.remove_training_directory(self.train_dir)
......@@ -461,6 +465,7 @@ class Trainer(object):
save_summaries_secs=120,
saver=saver)
logging.info("%s: Starting managed session.", task_as_string(self.task))
with sv.managed_session(target, config=self.config) as sess:
try:
......@@ -470,8 +475,8 @@ class Trainer(object):
_, global_step_val, loss_val, predictions_val, labels_val = sess.run(
[train_op, global_step, loss, predictions, labels])
seconds_per_batch = time.time() - batch_start_time
examples_per_second = labels_val.shape[0] / seconds_per_batch
examples_per_second = labels_val.shape[0] / seconds_per_batch
print("CURRENT STEP IS " + str(global_step_val))
if self.max_steps and self.max_steps <= global_step_val:
self.max_steps_reached = True
......@@ -640,7 +645,7 @@ class ParameterServer(object):
def run(self):
"""Starts the parameter server."""
print("start now=================")
logging.info("%s: Starting parameter server within cluster %s.",
task_as_string(self.task), self.cluster.as_dict())
server = start_server(self.cluster, self.task)
......