yunjey

.

1 +mkdir -p mnist
1 mkdir -p svhn 2 mkdir -p svhn
2 wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat 3 wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat
3 wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat 4 wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat
5 +wget -O svhn/extra_32x32.mat http://ufldl.stanford.edu/housenumbers/extra_32x32.mat
4 6
5 7
......
...@@ -3,15 +3,12 @@ from model import DTN ...@@ -3,15 +3,12 @@ from model import DTN
3 from solver import Solver 3 from solver import Solver
4 4
5 5
6 -
7 flags = tf.app.flags 6 flags = tf.app.flags
8 flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") 7 flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'")
9 -
10 FLAGS = flags.FLAGS 8 FLAGS = flags.FLAGS
11 9
12 def main(_): 10 def main(_):
13 11
14 -
15 model = DTN(mode=FLAGS.mode, learning_rate=0.0003) 12 model = DTN(mode=FLAGS.mode, learning_rate=0.0003)
16 solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, 13 solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100,
17 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model') 14 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model')
...@@ -25,6 +22,3 @@ def main(_): ...@@ -25,6 +22,3 @@ def main(_):
25 22
26 if __name__ == '__main__': 23 if __name__ == '__main__':
27 tf.app.run() 24 tf.app.run()
...\ No newline at end of file ...\ No newline at end of file
28 -
29 -
30 -
...\ No newline at end of file ...\ No newline at end of file
......
...@@ -33,7 +33,6 @@ class DTN(object): ...@@ -33,7 +33,6 @@ class DTN(object):
33 if self.mode == 'pretrain': 33 if self.mode == 'pretrain':
34 net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out') 34 net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out')
35 net = slim.flatten(net) 35 net = slim.flatten(net)
36 -
37 return net 36 return net
38 37
39 def generator(self, inputs, reuse=False): 38 def generator(self, inputs, reuse=False):
...@@ -106,7 +105,6 @@ class DTN(object): ...@@ -106,7 +105,6 @@ class DTN(object):
106 105
107 106
108 # source domain (svhn to mnist) 107 # source domain (svhn to mnist)
109 - with tf.name_scope('model_for_source_domain'):
110 self.fx = self.content_extractor(self.src_images) 108 self.fx = self.content_extractor(self.src_images)
111 self.fake_images = self.generator(self.fx) 109 self.fake_images = self.generator(self.fx)
112 self.logits = self.discriminator(self.fake_images) 110 self.logits = self.discriminator(self.fake_images)
...@@ -128,7 +126,6 @@ class DTN(object): ...@@ -128,7 +126,6 @@ class DTN(object):
128 f_vars = [var for var in t_vars if 'content_extractor' in var.name] 126 f_vars = [var for var in t_vars if 'content_extractor' in var.name]
129 127
130 # train op 128 # train op
131 - with tf.name_scope('source_train_op'):
132 self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars) 129 self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars)
133 self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars) 130 self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars)
134 self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars) 131 self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars)
...@@ -144,7 +141,6 @@ class DTN(object): ...@@ -144,7 +141,6 @@ class DTN(object):
144 sampled_images_summary]) 141 sampled_images_summary])
145 142
146 # target domain (mnist) 143 # target domain (mnist)
147 - with tf.name_scope('model_for_target_domain'):
148 self.fx = self.content_extractor(self.trg_images, reuse=True) 144 self.fx = self.content_extractor(self.trg_images, reuse=True)
149 self.reconst_images = self.generator(self.fx, reuse=True) 145 self.reconst_images = self.generator(self.fx, reuse=True)
150 self.logits_fake = self.discriminator(self.reconst_images, reuse=True) 146 self.logits_fake = self.discriminator(self.reconst_images, reuse=True)
...@@ -162,13 +158,7 @@ class DTN(object): ...@@ -162,13 +158,7 @@ class DTN(object):
162 self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) 158 self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)
163 self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) 159 self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)
164 160
165 - t_vars = tf.trainable_variables()
166 - d_vars = [var for var in t_vars if 'discriminator' in var.name]
167 - g_vars = [var for var in t_vars if 'generator' in var.name]
168 - f_vars = [var for var in t_vars if 'content_extractor' in var.name]
169 -
170 # train op 161 # train op
171 - with tf.name_scope('target_train_op'):
172 self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars) 162 self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars)
173 self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars) 163 self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars)
174 164
......
...@@ -9,7 +9,7 @@ import scipy.misc ...@@ -9,7 +9,7 @@ import scipy.misc
9 9
10 class Solver(object): 10 class Solver(object):
11 11
12 - def __init__(self, model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, 12 + def __init__(self, model, batch_size=100, pretrain_iter=10000, train_iter=2000, sample_iter=100,
13 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', 13 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample',
14 model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'): 14 model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'):
15 self.model = model 15 self.model = model
...@@ -29,7 +29,12 @@ class Solver(object): ...@@ -29,7 +29,12 @@ class Solver(object):
29 29
30 def load_svhn(self, image_dir, split='train'): 30 def load_svhn(self, image_dir, split='train'):
31 print ('loading svhn image dataset..') 31 print ('loading svhn image dataset..')
32 +
33 + if self.model.mode == 'pretrain':
34 + image_file = 'extra_32x32.mat' if split=='train' else 'test_32x32.mat'
35 + else:
32 image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat' 36 image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat'
37 +
33 image_dir = os.path.join(image_dir, image_file) 38 image_dir = os.path.join(image_dir, image_file)
34 svhn = scipy.io.loadmat(image_dir) 39 svhn = scipy.io.loadmat(image_dir)
35 images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1 40 images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1
...@@ -136,10 +141,10 @@ class Solver(object): ...@@ -136,10 +141,10 @@ class Solver(object):
136 sess.run([model.g_train_op_src], feed_dict) 141 sess.run([model.g_train_op_src], feed_dict)
137 sess.run([model.g_train_op_src], feed_dict) 142 sess.run([model.g_train_op_src], feed_dict)
138 sess.run([model.g_train_op_src], feed_dict) 143 sess.run([model.g_train_op_src], feed_dict)
144 +
139 if i % 15 == 0: 145 if i % 15 == 0:
140 sess.run(model.f_train_op_src, feed_dict) 146 sess.run(model.f_train_op_src, feed_dict)
141 147
142 -
143 if (step+1) % 10 == 0: 148 if (step+1) % 10 == 0:
144 summary, dl, gl, fl = sess.run([model.summary_op_src, \ 149 summary, dl, gl, fl = sess.run([model.summary_op_src, \
145 model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict) 150 model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict)
...@@ -169,7 +174,6 @@ class Solver(object): ...@@ -169,7 +174,6 @@ class Solver(object):
169 saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1) 174 saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1)
170 print ('model/dtn-%d saved' %(step+1)) 175 print ('model/dtn-%d saved' %(step+1))
171 176
172 -
173 def eval(self): 177 def eval(self):
174 # build model 178 # build model
175 model = self.model 179 model = self.model
......