Showing
1 changed file
with
32 additions
and
0 deletions
main.py
0 → 100644
| 1 | +import tensorflow as tf | ||
| 2 | +from model import DTN | ||
| 3 | +from solver import Solver | ||
| 4 | + | ||
| 5 | + | ||
| 6 | + | ||
| 7 | +flags = tf.app.flags | ||
| 8 | +flags.DEFINE_boolean('is_train', False, 'True if train mode, False if test mode') | ||
| 9 | + | ||
| 10 | +FLAGS = flags.FLAGS | ||
| 11 | + | ||
| 12 | +def main(_): | ||
| 13 | + | ||
| 14 | + model = DTN(batch_size=100, learning_rate=0.001, image_size=32, output_size=32, | ||
| 15 | + dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64) | ||
| 16 | + | ||
| 17 | + solver = Solver(model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/', | ||
| 18 | + log_path='log/', sample_path='sample/', test_model_path='model/dtn-2-1', sample_iter=100) | ||
| 19 | + | ||
| 20 | + | ||
| 21 | + if FLAGS.is_train: | ||
| 22 | + solver.train() | ||
| 23 | + else: | ||
| 24 | + solver.test() | ||
| 25 | + | ||
| 26 | + | ||
| 27 | + | ||
| 28 | +if __name__ == '__main__': | ||
| 29 | + tf.app.run() | ||
| 30 | + | ||
| 31 | + | ||
| 32 | + | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment