yunjey

main file

Showing 1 changed file with 8 additions and 10 deletions
......@@ -5,25 +5,23 @@ from solver import Solver
flags = tf.app.flags
flags.DEFINE_boolean('is_train', False, 'True if train mode, False if test mode')
flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'")
FLAGS = flags.FLAGS
def main(_):
model = DTN(batch_size=100, learning_rate=0.001, image_size=32, output_size=32,
dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64)
solver = Solver(model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/',
log_path='log/', sample_path='sample/', test_model_path='model/dtn-2-1', sample_iter=100)
model = DTN(mode=FLAGS.mode, learning_rate=0.0003)
solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100,
svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model')
if FLAGS.is_train:
if FLAGS.mode == 'pretrain':
solver.pretrain()
elif FLAGS.mode == 'train':
solver.train()
else:
solver.test()
solver.eval()
if __name__ == '__main__':
tf.app.run()
......