Showing
1 changed file
with
13 additions
and
7 deletions
... | @@ -9,9 +9,9 @@ import scipy.misc | ... | @@ -9,9 +9,9 @@ import scipy.misc |
9 | 9 | ||
10 | class Solver(object): | 10 | class Solver(object): |
11 | 11 | ||
12 | - def __init__(self, model, batch_size=100, pretrain_iter=10000, train_iter=2000, sample_iter=100, | 12 | + def __init__(self, model, batch_size=100, pretrain_iter=20000, train_iter=2000, sample_iter=100, |
13 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', | 13 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', |
14 | - model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'): | 14 | + model_save_path='model', pretrained_model='model/svhn_model-20000', test_model='model/dtn-600'): |
15 | 15 | ||
16 | self.model = model | 16 | self.model = model |
17 | self.batch_size = batch_size | 17 | self.batch_size = batch_size |
... | @@ -111,7 +111,7 @@ class Solver(object): | ... | @@ -111,7 +111,7 @@ class Solver(object): |
111 | model = self.model | 111 | model = self.model |
112 | model.build_model() | 112 | model.build_model() |
113 | 113 | ||
114 | - # make log directory if not exists | 114 | + # make directory if not exists |
115 | if tf.gfile.Exists(self.log_dir): | 115 | if tf.gfile.Exists(self.log_dir): |
116 | tf.gfile.DeleteRecursively(self.log_dir) | 116 | tf.gfile.DeleteRecursively(self.log_dir) |
117 | tf.gfile.MakeDirs(self.log_dir) | 117 | tf.gfile.MakeDirs(self.log_dir) |
... | @@ -121,13 +121,16 @@ class Solver(object): | ... | @@ -121,13 +121,16 @@ class Solver(object): |
121 | tf.global_variables_initializer().run() | 121 | tf.global_variables_initializer().run() |
122 | # restore variables of F | 122 | # restore variables of F |
123 | print ('loading pretrained model F..') | 123 | print ('loading pretrained model F..') |
124 | - variables_to_restore = slim.get_model_variables(scope='content_extractor') | 124 | + #variables_to_restore = slim.get_model_variables(scope='content_extractor') |
125 | - restorer = tf.train.Saver(variables_to_restore) | 125 | + #restorer = tf.train.Saver(variables_to_restore) |
126 | - restorer.restore(sess, self.pretrained_model) | 126 | + #restorer.restore(sess, self.pretrained_model) |
127 | + restorer = tf.train.Saver() | ||
128 | + restorer.restore(sess, 'model/dtn-1600') | ||
127 | summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph()) | 129 | summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph()) |
128 | saver = tf.train.Saver() | 130 | saver = tf.train.Saver() |
129 | 131 | ||
130 | print ('start training..!') | 132 | print ('start training..!') |
133 | + f_interval = 15 | ||
131 | for step in range(self.train_iter+1): | 134 | for step in range(self.train_iter+1): |
132 | 135 | ||
133 | i = step % int(svhn_images.shape[0] / self.batch_size) | 136 | i = step % int(svhn_images.shape[0] / self.batch_size) |
... | @@ -143,7 +146,10 @@ class Solver(object): | ... | @@ -143,7 +146,10 @@ class Solver(object): |
143 | sess.run([model.g_train_op_src], feed_dict) | 146 | sess.run([model.g_train_op_src], feed_dict) |
144 | sess.run([model.g_train_op_src], feed_dict) | 147 | sess.run([model.g_train_op_src], feed_dict) |
145 | 148 | ||
146 | - if i % 15 == 0: | 149 | + if step > 1600: |
150 | + f_interval = 30 | ||
151 | + | ||
152 | + if i % f_interval == 0: | ||
147 | sess.run(model.f_train_op_src, feed_dict) | 153 | sess.run(model.f_train_op_src, feed_dict) |
148 | 154 | ||
149 | if (step+1) % 10 == 0: | 155 | if (step+1) % 10 == 0: | ... | ... |
-
Please register or login to post a comment