yunjey

main file

Showing 1 changed file with 8 additions and 10 deletions
...@@ -5,25 +5,23 @@ from solver import Solver ...@@ -5,25 +5,23 @@ from solver import Solver
5 5
6 6
7 flags = tf.app.flags 7 flags = tf.app.flags
8 -flags.DEFINE_boolean('is_train', False, 'True if train mode, False if test mode') 8 +flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'")
9 9
10 FLAGS = flags.FLAGS 10 FLAGS = flags.FLAGS
11 11
12 def main(_): 12 def main(_):
13 13
14 - model = DTN(batch_size=100, learning_rate=0.001, image_size=32, output_size=32,
15 - dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64)
16 14
17 - solver = Solver(model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/', 15 + model = DTN(mode=FLAGS.mode, learning_rate=0.0003)
18 - log_path='log/', sample_path='sample/', test_model_path='model/dtn-2-1', sample_iter=100) 16 + solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100,
17 + svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model')
19 18
20 - 19 + if FLAGS.mode == 'pretrain':
21 - if FLAGS.is_train: 20 + solver.pretrain()
21 + elif FLAGS.mode == 'train':
22 solver.train() 22 solver.train()
23 else: 23 else:
24 - solver.test() 24 + solver.eval()
25 -
26 -
27 25
28 if __name__ == '__main__': 26 if __name__ == '__main__':
29 tf.app.run() 27 tf.app.run()
......