Showing
1 changed file
with
13 additions
and
7 deletions
| ... | @@ -9,9 +9,9 @@ import scipy.misc | ... | @@ -9,9 +9,9 @@ import scipy.misc |
| 9 | 9 | ||
| 10 | class Solver(object): | 10 | class Solver(object): |
| 11 | 11 | ||
| 12 | - def __init__(self, model, batch_size=100, pretrain_iter=10000, train_iter=2000, sample_iter=100, | 12 | + def __init__(self, model, batch_size=100, pretrain_iter=20000, train_iter=2000, sample_iter=100, |
| 13 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', | 13 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', |
| 14 | - model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'): | 14 | + model_save_path='model', pretrained_model='model/svhn_model-20000', test_model='model/dtn-600'): |
| 15 | 15 | ||
| 16 | self.model = model | 16 | self.model = model |
| 17 | self.batch_size = batch_size | 17 | self.batch_size = batch_size |
| ... | @@ -111,7 +111,7 @@ class Solver(object): | ... | @@ -111,7 +111,7 @@ class Solver(object): |
| 111 | model = self.model | 111 | model = self.model |
| 112 | model.build_model() | 112 | model.build_model() |
| 113 | 113 | ||
| 114 | - # make log directory if not exists | 114 | + # make directory if not exists |
| 115 | if tf.gfile.Exists(self.log_dir): | 115 | if tf.gfile.Exists(self.log_dir): |
| 116 | tf.gfile.DeleteRecursively(self.log_dir) | 116 | tf.gfile.DeleteRecursively(self.log_dir) |
| 117 | tf.gfile.MakeDirs(self.log_dir) | 117 | tf.gfile.MakeDirs(self.log_dir) |
| ... | @@ -121,13 +121,16 @@ class Solver(object): | ... | @@ -121,13 +121,16 @@ class Solver(object): |
| 121 | tf.global_variables_initializer().run() | 121 | tf.global_variables_initializer().run() |
| 122 | # restore variables of F | 122 | # restore variables of F |
| 123 | print ('loading pretrained model F..') | 123 | print ('loading pretrained model F..') |
| 124 | - variables_to_restore = slim.get_model_variables(scope='content_extractor') | 124 | + #variables_to_restore = slim.get_model_variables(scope='content_extractor') |
| 125 | - restorer = tf.train.Saver(variables_to_restore) | 125 | + #restorer = tf.train.Saver(variables_to_restore) |
| 126 | - restorer.restore(sess, self.pretrained_model) | 126 | + #restorer.restore(sess, self.pretrained_model) |
| 127 | + restorer = tf.train.Saver() | ||
| 128 | + restorer.restore(sess, 'model/dtn-1600') | ||
| 127 | summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph()) | 129 | summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph()) |
| 128 | saver = tf.train.Saver() | 130 | saver = tf.train.Saver() |
| 129 | 131 | ||
| 130 | print ('start training..!') | 132 | print ('start training..!') |
| 133 | + f_interval = 15 | ||
| 131 | for step in range(self.train_iter+1): | 134 | for step in range(self.train_iter+1): |
| 132 | 135 | ||
| 133 | i = step % int(svhn_images.shape[0] / self.batch_size) | 136 | i = step % int(svhn_images.shape[0] / self.batch_size) |
| ... | @@ -143,7 +146,10 @@ class Solver(object): | ... | @@ -143,7 +146,10 @@ class Solver(object): |
| 143 | sess.run([model.g_train_op_src], feed_dict) | 146 | sess.run([model.g_train_op_src], feed_dict) |
| 144 | sess.run([model.g_train_op_src], feed_dict) | 147 | sess.run([model.g_train_op_src], feed_dict) |
| 145 | 148 | ||
| 146 | - if i % 15 == 0: | 149 | + if step > 1600: |
| 150 | + f_interval = 30 | ||
| 151 | + | ||
| 152 | + if i % f_interval == 0: | ||
| 147 | sess.run(model.f_train_op_src, feed_dict) | 153 | sess.run(model.f_train_op_src, feed_dict) |
| 148 | 154 | ||
| 149 | if (step+1) % 10 == 0: | 155 | if (step+1) % 10 == 0: | ... | ... |
-
Please register or login to post a comment