Showing 1 changed file with 150 additions and 112 deletions.
@@ -1,72 +1,57 @@
 import tensorflow as tf
+import tensorflow.contrib.slim as slim
 import numpy as np
+import pickle
 import os
 import scipy.io
-import hickle
 import scipy.misc
-from config import SummaryWriter
 
 
 class Solver(object):
-    """Load dataset and train and test the model"""
 
-    def __init__(self, model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/',
-                 log_path='log/', sample_path='sample/', test_model_path=None, sample_iter=100):
+    def __init__(self, model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100,
+                 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample',
+                 model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'):
         self.model = model
-        self.num_epoch = num_epoch
-        self.mnist_path = mnist_path
-        self.svhn_path = svhn_path
-        self.model_save_path = model_save_path
-        self.log_path = log_path
-        self.sample_path = sample_path
-        self.test_model_path = test_model_path
+        self.batch_size = batch_size
+        self.pretrain_iter = pretrain_iter
+        self.train_iter = train_iter
         self.sample_iter = sample_iter
+        self.svhn_dir = svhn_dir
+        self.mnist_dir = mnist_dir
+        self.log_dir = log_dir
+        self.sample_save_path = sample_save_path
+        self.model_save_path = model_save_path
+        self.pretrained_model = pretrained_model
+        self.test_model = test_model
+        self.config = tf.ConfigProto()
+        self.config.gpu_options.allow_growth=True
 
-        # create directory if not exists
-        if not os.path.exists(log_path):
-            os.makedirs(log_path)
-        if not os.path.exists(model_save_path):
-            os.makedirs(model_save_path)
-        if not os.path.exists(sample_path):
-            os.makedirs(sample_path)
-
-        # construct the dcgan model
-        model.build_model()
-
-
-    def load_svhn(self, image_path, split='train'):
+    def load_svhn(self, image_dir, split='train'):
         print ('loading svhn image dataset..')
-        if split == 'train':
-            svhn = scipy.io.loadmat(os.path.join(image_path, 'train_32x32.mat'))
-        else:
-            svhn = scipy.io.loadmat(os.path.join(image_path, 'test_32x32.mat'))
-
-        images = np.transpose(svhn['X'], [3, 0, 1, 2])
-        images = images / 127.5 - 1
+        image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat'
+        image_dir = os.path.join(image_dir, image_file)
+        svhn = scipy.io.loadmat(image_dir)
+        images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1
+        labels = svhn['y'].reshape(-1)
+        labels[np.where(labels==10)] = 0
         print ('finished loading svhn image dataset..!')
-        return images
+        return images, labels
 
-
-    def load_mnist(self, image_path, split='train'):
+    def load_mnist(self, image_dir, split='train'):
         print ('loading mnist image dataset..')
-        if split == 'train':
-            image_file = os.path.join(image_path, 'train.images.hkl')
-        else:
-            image_file = os.path.join(image_path, 'test.images.hkl')
-
-        try:
-            images = hickle.load(image_file)
-        except:
-            hickle.load(images, image_file)
-
-        images = images / 127.5 - 1
+        image_file = 'train.pkl' if split=='train' else 'test.pkl'
+        image_dir = os.path.join(image_dir, image_file)
+        with open(image_dir, 'rb') as f:
+            mnist = pickle.load(f)
+        images = mnist['X'] / 127.5 - 1
+        labels = mnist['y']
        print ('finished loading mnist image dataset..!')
-        return images
-
+        return images, labels
 
     def merge_images(self, sources, targets, k=10):
         _, h, w, _ = sources.shape
-        row = int(np.sqrt(self.model.batch_size))
+        row = int(np.sqrt(self.batch_size))
         merged = np.zeros([row*h, row*w*2, 3])
 
         for idx, (s, t) in enumerate(zip(sources, targets)):
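The data pipeline above changes in two ways: `load_mnist` now reads an ordinary pickle with keys `'X'` (images) and `'y'` (labels) instead of the old hickle files, and `load_svhn` remaps SVHN's label `10` to `0` so digit classes line up with MNIST. A minimal sketch of producing a compatible `train.pkl` follows; only the dict keys are fixed by the diff, while the array shapes and dtypes below are assumptions for illustration:

```python
import pickle
import numpy as np

# Hypothetical example arrays; only the keys 'X' and 'y' are required
# by load_mnist. Shapes and dtypes here are assumptions.
images = np.random.randint(0, 256, size=(100, 32, 32, 3)).astype(np.float32)
labels = np.random.randint(0, 10, size=(100,))

# load_mnist rescales with X / 127.5 - 1, so 'X' should hold raw
# pixel values in [0, 255].
with open('mnist/train.pkl', 'wb') as f:
    pickle.dump({'X': images, 'y': labels}, f, protocol=pickle.HIGHEST_PROTOCOL)
```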
@@ -74,87 +59,140 @@ class Solver(object):
             j = idx % row
             merged[i*h:(i+1)*h, (j*2)*h:(j*2+1)*h, :] = s
             merged[i*h:(i+1)*h, (j*2+1)*h:(j*2+2)*h, :] = t
-
         return merged
 
+    def pretrain(self):
+        # load svhn dataset
+        train_images, train_labels = self.load_svhn(self.svhn_dir, split='train')
+        test_images, test_labels = self.load_svhn(self.svhn_dir, split='test')
 
-    def train(self):
-        model=self.model
-
-        # load image dataset
-        svhn = self.load_svhn(self.svhn_path)
-        mnist = self.load_mnist(self.mnist_path)
+        # build a graph
+        model = self.model
+        model.build_model()
 
+        with tf.Session(config=self.config) as sess:
+            tf.global_variables_initializer().run()
+            saver = tf.train.Saver()
+            summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph())
+
+            for step in range(self.pretrain_iter+1):
+                i = step % int(train_images.shape[0] / self.batch_size)
+                batch_images = train_images[i*self.batch_size:(i+1)*self.batch_size]
+                batch_labels = train_labels[i*self.batch_size:(i+1)*self.batch_size]
+                feed_dict = {model.images: batch_images, model.labels: batch_labels}
+                sess.run(model.train_op, feed_dict)
+
+                if (step+1) % 10 == 0:
+                    summary, l, acc = sess.run([model.summary_op, model.loss, model.accuracy], feed_dict)
+                    rand_idxs = np.random.permutation(test_images.shape[0])[:self.batch_size]
+                    test_acc, _ = sess.run(fetches=[model.accuracy, model.loss],
+                                           feed_dict={model.images: test_images[rand_idxs],
+                                                      model.labels: test_labels[rand_idxs]})
+                    summary_writer.add_summary(summary, step)
+                    print ('Step: [%d/%d] loss: [%.6f] train acc: [%.2f] test acc [%.2f]' \
+                        %(step+1, self.pretrain_iter, l, acc, test_acc))
+
+                if (step+1) % 1000 == 0:
+                    saver.save(sess, os.path.join(self.model_save_path, 'svhn_model'), global_step=step+1)
+                    print ('svhn_model-%d saved..!' %(step+1))
 
-        num_iter_per_epoch = int(mnist.shape[0] / model.batch_size)
+    def train(self):
+        # load svhn dataset
+        svhn_images, _ = self.load_svhn(self.svhn_dir, split='train')
+        mnist_images, _ = self.load_mnist(self.mnist_dir, split='train')
 
-        config = tf.ConfigProto(allow_soft_placement = True)
-        config.gpu_options.allow_growth = True
-        with tf.Session(config=config) as sess:
-            # initialize parameters
-            try:
-                tf.global_variables_initializer().run()
-            except:
-                tf.initialize_all_variables().run()
+        # build a graph
+        model = self.model
+        model.build_model()
 
-            summary_writer = SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph())
+        # make log directory if not exists
+        if tf.gfile.Exists(self.log_dir):
+            tf.gfile.DeleteRecursively(self.log_dir)
+        tf.gfile.MakeDirs(self.log_dir)
 
-            for e in range(self.num_epoch):
-                for i in range(num_iter_per_epoch):
+        with tf.Session(config=self.config) as sess:
+            # initialize G and D
+            tf.global_variables_initializer().run()
+            # restore variables of F
+            print ('loading pretrained model F..')
+            variables_to_restore = slim.get_model_variables(scope='content_extractor')
+            restorer = tf.train.Saver(variables_to_restore)
+            restorer.restore(sess, self.pretrained_model)
+            summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph())
+            saver = tf.train.Saver()
 
-                    # train model for source domain S
-                    image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size]
-                    feed_dict = {model.images: image_batch}
-                    sess.run(model.d_optimizer_fake, feed_dict)
-                    sess.run(model.g_optimizer, feed_dict)
-                    sess.run(model.g_optimizer, feed_dict)
-                    if i % 3 == 0:
-                        sess.run(model.f_optimizer_const, feed_dict)
-
-                    if i % 10 == 0:
-                        feed_dict = {model.images: image_batch}
-                        summary, d_loss, g_loss = sess.run([model.summary_op, model.d_loss, model.g_loss], feed_dict)
-                        summary_writer.add_summary(summary, e*num_iter_per_epoch + i)
-                        print ('Epoch: [%d] Step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' %(e+1, i+1, num_iter_per_epoch, d_loss, g_loss))
-
-                    # train model for target domain T
-                    image_batch = mnist[i*model.batch_size:(i+1)*model.batch_size]
-                    feed_dict = {model.images: image_batch}
-                    sess.run(model.d_optimizer_real, feed_dict)
-                    sess.run(model.d_optimizer_fake, feed_dict)
-                    sess.run(model.g_optimizer, feed_dict)
-                    sess.run(model.g_optimizer, feed_dict)
-                    sess.run(model.g_optimizer, feed_dict)
-                    sess.run(model.g_optimizer_const, feed_dict)
-                    sess.run(model.g_optimizer_const, feed_dict)
-
-                    if i % 500 == 0:
-                        model.saver.save(sess, os.path.join(self.model_save_path, 'dtn-%d' %(e+1)), global_step=i+1)
-                        print ('model/dtn-%d-%d saved' %(e+1, i+1))
-
-
-    def test(self):
+            print ('start training..!')
+            for step in range(self.train_iter+1):
+
+                i = step % int(svhn_images.shape[0] / self.batch_size)
+                # train the model for source domain S
+                src_images = svhn_images[i*self.batch_size:(i+1)*self.batch_size]
+                feed_dict = {model.src_images: src_images}
+
+                sess.run(model.d_train_op_src, feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                if i % 15 == 0:
+                    sess.run(model.f_train_op_src, feed_dict)
+
+
+                if (step+1) % 10 == 0:
+                    summary, dl, gl, fl = sess.run([model.summary_op_src, \
+                        model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict)
+                    summary_writer.add_summary(summary, step)
+                    print ('[Source] step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f] f_loss: [%.6f]' \
+                        %(step+1, self.train_iter, dl, gl, fl))
+
+                # train the model for target domain T
+                j = step % int(mnist_images.shape[0] / self.batch_size)
+                trg_images = mnist_images[j*self.batch_size:(j+1)*self.batch_size]
+                feed_dict = {model.src_images: src_images, model.trg_images: trg_images}
+                sess.run(model.d_train_op_trg, feed_dict)
+                sess.run(model.d_train_op_trg, feed_dict)
+                sess.run(model.g_train_op_trg, feed_dict)
+                sess.run(model.g_train_op_trg, feed_dict)
+                sess.run(model.g_train_op_trg, feed_dict)
+                sess.run(model.g_train_op_trg, feed_dict)
+
+                if (step+1) % 10 == 0:
+                    summary, dl, gl = sess.run([model.summary_op_trg, \
+                        model.d_loss_trg, model.g_loss_trg], feed_dict)
+                    summary_writer.add_summary(summary, step)
+                    print ('[Target] step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' \
+                        %(step+1, self.train_iter, dl, gl))
+
+                if (step+1) % 200 == 0:
+                    saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1)
+                    print ('model/dtn-%d saved' %(step+1))
+
+
+    def eval(self):
+        # build model
         model = self.model
+        model.build_model()
 
-        # load dataset
-        svhn = self.load_svhn(self.svhn_path)
-        num_iter = int(svhn.shape[0] / model.batch_size)
+        # load svhn dataset
+        svhn_images, _ = self.load_svhn(self.svhn_dir)
 
-        config = tf.ConfigProto(allow_soft_placement = True)
-        config.gpu_options.allow_growth = True
-        with tf.Session(config=config) as sess:
+        with tf.Session(config=self.config) as sess:
             # load trained parameters
+            print ('loading test model..')
             saver = tf.train.Saver()
-            saver.restore(sess, self.test_model_path)
+            saver.restore(sess, self.test_model)
 
+            print ('start sampling..!')
             for i in range(self.sample_iter):
                 # train model for source domain S
-                image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size]
-                feed_dict = {model.images: image_batch}
-                sampled_image_batch = sess.run(model.sampled_images, feed_dict)
+                batch_images = svhn_images[i*self.batch_size:(i+1)*self.batch_size]
+                feed_dict = {model.images: batch_images}
+                sampled_batch_images = sess.run(model.sampled_images, feed_dict)
 
                 # merge and save source images and sampled target images
-                merged = self.merge_images(image_batch, sampled_image_batch)
-                path = os.path.join(self.sample_path, 'sample-%d-to-%d.png' %(i*model.batch_size, (i+1)*model.batch_size))
+                merged = self.merge_images(batch_images, sampled_batch_images)
+                path = os.path.join(self.sample_save_path, 'sample-%d-to-%d.png' %(i*self.batch_size, (i+1)*self.batch_size))
                 scipy.misc.imsave(path, merged)
                 print ('saved %s' %path)
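Since the refactor moves `build_model()` out of `__init__` and into `pretrain`/`train`/`eval`, a driver only has to construct the model and pick a phase. A minimal sketch of such a driver is below; the `DTN` class, its `mode` argument, and the module names are assumptions (the model file is not part of this diff), while the `Solver` method names and default paths come from the code above:

```python
from model import DTN      # hypothetical: the model class is not shown in this diff
from solver import Solver  # hypothetical module name for the file changed above

def main(mode):
    # build_model() is now called inside pretrain/train/eval, so the
    # solver only needs the model object and the directory layout.
    model = DTN(mode=mode)  # 'pretrain', 'train', or 'eval' -- assumed constructor
    solver = Solver(model, svhn_dir='svhn', mnist_dir='mnist',
                    pretrained_model='model/svhn_model-10000',
                    test_model='model/dtn-2000')

    if mode == 'pretrain':
        solver.pretrain()   # trains F on SVHN labels, saves model/svhn_model-*
    elif mode == 'train':
        solver.train()      # adversarial training, saves model/dtn-*
    else:
        solver.eval()       # writes sample-*.png grids to sample/

if __name__ == '__main__':
    main('train')
```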