Showing
1 changed file
with
32 additions
and
0 deletions
main.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +from model import DTN | ||
3 | +from solver import Solver | ||
4 | + | ||
5 | + | ||
6 | + | ||
7 | +flags = tf.app.flags | ||
8 | +flags.DEFINE_boolean('is_train', False, 'True if train mode, False if test mode') | ||
9 | + | ||
10 | +FLAGS = flags.FLAGS | ||
11 | + | ||
12 | +def main(_): | ||
13 | + | ||
14 | + model = DTN(batch_size=100, learning_rate=0.001, image_size=32, output_size=32, | ||
15 | + dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64) | ||
16 | + | ||
17 | + solver = Solver(model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/', | ||
18 | + log_path='log/', sample_path='sample/', test_model_path='model/dtn-2-1', sample_iter=100) | ||
19 | + | ||
20 | + | ||
21 | + if FLAGS.is_train: | ||
22 | + solver.train() | ||
23 | + else: | ||
24 | + solver.test() | ||
25 | + | ||
26 | + | ||
27 | + | ||
28 | +if __name__ == '__main__': | ||
29 | + tf.app.run() | ||
30 | + | ||
31 | + | ||
32 | + | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment