Showing 4 changed files with 9 additions and 19 deletions
1 | +mkdir -p mnist | ||
1 | mkdir -p svhn | 2 | mkdir -p svhn |
2 | wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat | 3 | wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat |
3 | wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat | 4 | wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat |
5 | +wget -O svhn/extra_32x32.mat http://ufldl.stanford.edu/housenumbers/extra_32x32.mat | ||
4 | 6 | ||
5 | 7 | ... | ... |
... | @@ -3,15 +3,12 @@ from model import DTN | ... | @@ -3,15 +3,12 @@ from model import DTN |
3 | from solver import Solver | 3 | from solver import Solver |
4 | 4 | ||
5 | 5 | ||
6 | - | ||
7 | flags = tf.app.flags | 6 | flags = tf.app.flags |
8 | flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") | 7 | flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") |
9 | - | ||
10 | FLAGS = flags.FLAGS | 8 | FLAGS = flags.FLAGS |
11 | 9 | ||
12 | def main(_): | 10 | def main(_): |
13 | 11 | ||
14 | - | ||
15 | model = DTN(mode=FLAGS.mode, learning_rate=0.0003) | 12 | model = DTN(mode=FLAGS.mode, learning_rate=0.0003) |
16 | solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, | 13 | solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, |
17 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model') | 14 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model') |
... | @@ -25,6 +22,3 @@ def main(_): | ... | @@ -25,6 +22,3 @@ def main(_): |
25 | 22 | ||
26 | if __name__ == '__main__': | 23 | if __name__ == '__main__': |
27 | tf.app.run() | 24 | tf.app.run() |
... | \ No newline at end of file | ... | \ No newline at end of file |
28 | - | ||
29 | - | ||
30 | - | ||
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
... | @@ -33,7 +33,6 @@ class DTN(object): | ... | @@ -33,7 +33,6 @@ class DTN(object): |
33 | if self.mode == 'pretrain': | 33 | if self.mode == 'pretrain': |
34 | net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out') | 34 | net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out') |
35 | net = slim.flatten(net) | 35 | net = slim.flatten(net) |
36 | - | ||
37 | return net | 36 | return net |
38 | 37 | ||
39 | def generator(self, inputs, reuse=False): | 38 | def generator(self, inputs, reuse=False): |
... | @@ -106,7 +105,6 @@ class DTN(object): | ... | @@ -106,7 +105,6 @@ class DTN(object): |
106 | 105 | ||
107 | 106 | ||
108 | # source domain (svhn to mnist) | 107 | # source domain (svhn to mnist) |
109 | - with tf.name_scope('model_for_source_domain'): | ||
110 | self.fx = self.content_extractor(self.src_images) | 108 | self.fx = self.content_extractor(self.src_images) |
111 | self.fake_images = self.generator(self.fx) | 109 | self.fake_images = self.generator(self.fx) |
112 | self.logits = self.discriminator(self.fake_images) | 110 | self.logits = self.discriminator(self.fake_images) |
... | @@ -128,7 +126,6 @@ class DTN(object): | ... | @@ -128,7 +126,6 @@ class DTN(object): |
128 | f_vars = [var for var in t_vars if 'content_extractor' in var.name] | 126 | f_vars = [var for var in t_vars if 'content_extractor' in var.name] |
129 | 127 | ||
130 | # train op | 128 | # train op |
131 | - with tf.name_scope('source_train_op'): | ||
132 | self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars) | 129 | self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars) |
133 | self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars) | 130 | self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars) |
134 | self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars) | 131 | self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars) |
... | @@ -144,7 +141,6 @@ class DTN(object): | ... | @@ -144,7 +141,6 @@ class DTN(object): |
144 | sampled_images_summary]) | 141 | sampled_images_summary]) |
145 | 142 | ||
146 | # target domain (mnist) | 143 | # target domain (mnist) |
147 | - with tf.name_scope('model_for_target_domain'): | ||
148 | self.fx = self.content_extractor(self.trg_images, reuse=True) | 144 | self.fx = self.content_extractor(self.trg_images, reuse=True) |
149 | self.reconst_images = self.generator(self.fx, reuse=True) | 145 | self.reconst_images = self.generator(self.fx, reuse=True) |
150 | self.logits_fake = self.discriminator(self.reconst_images, reuse=True) | 146 | self.logits_fake = self.discriminator(self.reconst_images, reuse=True) |
... | @@ -162,13 +158,7 @@ class DTN(object): | ... | @@ -162,13 +158,7 @@ class DTN(object): |
162 | self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) | 158 | self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) |
163 | self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) | 159 | self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) |
164 | 160 | ||
165 | - t_vars = tf.trainable_variables() | ||
166 | - d_vars = [var for var in t_vars if 'discriminator' in var.name] | ||
167 | - g_vars = [var for var in t_vars if 'generator' in var.name] | ||
168 | - f_vars = [var for var in t_vars if 'content_extractor' in var.name] | ||
169 | - | ||
170 | # train op | 161 | # train op |
171 | - with tf.name_scope('target_train_op'): | ||
172 | self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars) | 162 | self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars) |
173 | self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars) | 163 | self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars) |
174 | 164 | ... | ... |
... | @@ -9,7 +9,7 @@ import scipy.misc | ... | @@ -9,7 +9,7 @@ import scipy.misc |
9 | 9 | ||
10 | class Solver(object): | 10 | class Solver(object): |
11 | 11 | ||
12 | - def __init__(self, model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, | 12 | + def __init__(self, model, batch_size=100, pretrain_iter=10000, train_iter=2000, sample_iter=100, |
13 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', | 13 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', |
14 | model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'): | 14 | model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'): |
15 | self.model = model | 15 | self.model = model |
... | @@ -29,7 +29,12 @@ class Solver(object): | ... | @@ -29,7 +29,12 @@ class Solver(object): |
29 | 29 | ||
30 | def load_svhn(self, image_dir, split='train'): | 30 | def load_svhn(self, image_dir, split='train'): |
31 | print ('loading svhn image dataset..') | 31 | print ('loading svhn image dataset..') |
32 | + | ||
33 | + if self.model.mode == 'pretrain': | ||
34 | + image_file = 'extra_32x32.mat' if split=='train' else 'test_32x32.mat' | ||
35 | + else: | ||
32 | image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat' | 36 | image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat' |
37 | + | ||
33 | image_dir = os.path.join(image_dir, image_file) | 38 | image_dir = os.path.join(image_dir, image_file) |
34 | svhn = scipy.io.loadmat(image_dir) | 39 | svhn = scipy.io.loadmat(image_dir) |
35 | images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1 | 40 | images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1 |
... | @@ -136,10 +141,10 @@ class Solver(object): | ... | @@ -136,10 +141,10 @@ class Solver(object): |
136 | sess.run([model.g_train_op_src], feed_dict) | 141 | sess.run([model.g_train_op_src], feed_dict) |
137 | sess.run([model.g_train_op_src], feed_dict) | 142 | sess.run([model.g_train_op_src], feed_dict) |
138 | sess.run([model.g_train_op_src], feed_dict) | 143 | sess.run([model.g_train_op_src], feed_dict) |
144 | + | ||
139 | if i % 15 == 0: | 145 | if i % 15 == 0: |
140 | sess.run(model.f_train_op_src, feed_dict) | 146 | sess.run(model.f_train_op_src, feed_dict) |
141 | 147 | ||
142 | - | ||
143 | if (step+1) % 10 == 0: | 148 | if (step+1) % 10 == 0: |
144 | summary, dl, gl, fl = sess.run([model.summary_op_src, \ | 149 | summary, dl, gl, fl = sess.run([model.summary_op_src, \ |
145 | model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict) | 150 | model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict) |
... | @@ -169,7 +174,6 @@ class Solver(object): | ... | @@ -169,7 +174,6 @@ class Solver(object): |
169 | saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1) | 174 | saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1) |
170 | print ('model/dtn-%d saved' %(step+1)) | 175 | print ('model/dtn-%d saved' %(step+1)) |
171 | 176 | ||
172 | - | ||
173 | def eval(self): | 177 | def eval(self): |
174 | # build model | 178 | # build model |
175 | model = self.model | 179 | model = self.model | ... | ... |
Please register or log in to post a comment