Showing
1 changed file
with
8 additions
and
10 deletions
... | @@ -5,25 +5,23 @@ from solver import Solver | ... | @@ -5,25 +5,23 @@ from solver import Solver |
5 | 5 | ||
6 | 6 | ||
7 | flags = tf.app.flags | 7 | flags = tf.app.flags |
8 | -flags.DEFINE_boolean('is_train', False, 'True if train mode, False if test mode') | 8 | +flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") |
9 | 9 | ||
10 | FLAGS = flags.FLAGS | 10 | FLAGS = flags.FLAGS |
11 | 11 | ||
12 | def main(_): | 12 | def main(_): |
13 | 13 | ||
14 | - model = DTN(batch_size=100, learning_rate=0.001, image_size=32, output_size=32, | ||
15 | - dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64) | ||
16 | 14 | ||
17 | - solver = Solver(model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/', | 15 | + model = DTN(mode=FLAGS.mode, learning_rate=0.0003) |
18 | - log_path='log/', sample_path='sample/', test_model_path='model/dtn-2-1', sample_iter=100) | 16 | + solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, |
17 | + svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model') | ||
19 | 18 | ||
20 | - | 19 | + if FLAGS.mode == 'pretrain': |
21 | - if FLAGS.is_train: | 20 | + solver.pretrain() |
21 | + elif FLAGS.mode == 'train': | ||
22 | solver.train() | 22 | solver.train() |
23 | else: | 23 | else: |
24 | - solver.test() | 24 | + solver.eval() |
25 | - | ||
26 | - | ||
27 | 25 | ||
28 | if __name__ == '__main__': | 26 | if __name__ == '__main__': |
29 | tf.app.run() | 27 | tf.app.run() | ... | ... |
-
Please register or login to post a comment