Showing
1 changed file
with
10 additions
and
2 deletions
| ... | @@ -4,13 +4,21 @@ from solver import Solver | ... | @@ -4,13 +4,21 @@ from solver import Solver |
| 4 | 4 | ||
| 5 | flags = tf.app.flags | 5 | flags = tf.app.flags |
| 6 | flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") | 6 | flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") |
| 7 | +flags.DEFINE_string('model_save_path', 'model', "directory for saving the model") | ||
| 8 | +flags.DEFINE_string('sample_save_path', 'sample', "directory for saving the sampled images") | ||
| 7 | FLAGS = flags.FLAGS | 9 | FLAGS = flags.FLAGS |
| 8 | 10 | ||
| 9 | def main(_): | 11 | def main(_): |
| 10 | 12 | ||
| 11 | model = DTN(mode=FLAGS.mode, learning_rate=0.0003) | 13 | model = DTN(mode=FLAGS.mode, learning_rate=0.0003) |
| 12 | - solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, | 14 | + solver = Solver(model, batch_size=100, pretrain_iter=20000, train_iter=2000, sample_iter=100, |
| 13 | - svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model') | 15 | + svhn_dir='svhn', mnist_dir='mnist', model_save_path=FLAGS.model_save_path, sample_save_path=FLAGS.sample_save_path) |
| 16 | + | ||
| 17 | + # create directories if not exist | ||
| 18 | + if not tf.gfile.Exists(FLAGS.model_save_path): | ||
| 19 | + tf.gfile.MakeDirs(FLAGS.model_save_path) | ||
| 20 | + if not tf.gfile.Exists(FLAGS.sample_save_path): | ||
| 21 | + tf.gfile.MakeDirs(FLAGS.sample_save_path) | ||
| 14 | 22 | ||
| 15 | if FLAGS.mode == 'pretrain': | 23 | if FLAGS.mode == 'pretrain': |
| 16 | solver.pretrain() | 24 | solver.pretrain() | ... | ... |
-
Please register or login to post a comment