yunjey

train and eval the model

Showing 1 changed file with 150 additions and 112 deletions
@@ -1,72 +1,57 @@
 import tensorflow as tf
+import tensorflow.contrib.slim as slim
 import numpy as np
+import pickle
 import os
 import scipy.io
-import hickle
 import scipy.misc
-from config import SummaryWriter
 
 
 class Solver(object):
-    """Load dataset and train and test the model"""
 
-    def __init__(self, model, num_epoch=10, mnist_path='mnist/', svhn_path='svhn/', model_save_path='model/',
-                 log_path='log/', sample_path='sample/', test_model_path=None, sample_iter=100):
+    def __init__(self, model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100,
+                 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample',
+                 model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'):
         self.model = model
-        self.num_epoch = num_epoch
-        self.mnist_path = mnist_path
-        self.svhn_path = svhn_path
-        self.model_save_path = model_save_path
-        self.log_path = log_path
-        self.sample_path = sample_path
-        self.test_model_path = test_model_path
+        self.batch_size = batch_size
+        self.pretrain_iter = pretrain_iter
+        self.train_iter = train_iter
         self.sample_iter = sample_iter
+        self.svhn_dir = svhn_dir
+        self.mnist_dir = mnist_dir
+        self.log_dir = log_dir
+        self.sample_save_path = sample_save_path
+        self.model_save_path = model_save_path
+        self.pretrained_model = pretrained_model
+        self.test_model = test_model
+        self.config = tf.ConfigProto()
+        self.config.gpu_options.allow_growth = True
 
-        # create directory if not exists
-        if not os.path.exists(log_path):
-            os.makedirs(log_path)
-        if not os.path.exists(model_save_path):
-            os.makedirs(model_save_path)
-        if not os.path.exists(sample_path):
-            os.makedirs(sample_path)
-
-        # construct the dcgan model
-        model.build_model()
-
-
-    def load_svhn(self, image_path, split='train'):
+    def load_svhn(self, image_dir, split='train'):
         print ('loading svhn image dataset..')
-        if split == 'train':
-            svhn = scipy.io.loadmat(os.path.join(image_path, 'train_32x32.mat'))
-        else:
-            svhn = scipy.io.loadmat(os.path.join(image_path, 'test_32x32.mat'))
-
-        images = np.transpose(svhn['X'], [3, 0, 1, 2])
-        images = images / 127.5 - 1
+        image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat'
+        image_dir = os.path.join(image_dir, image_file)
+        svhn = scipy.io.loadmat(image_dir)
+        images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1
+        labels = svhn['y'].reshape(-1)
+        labels[np.where(labels==10)] = 0
         print ('finished loading svhn image dataset..!')
-        return images
+        return images, labels
 
-
-    def load_mnist(self, image_path, split='train'):
+    def load_mnist(self, image_dir, split='train'):
         print ('loading mnist image dataset..')
-        if split == 'train':
-            image_file = os.path.join(image_path, 'train.images.hkl')
-        else:
-            image_file = os.path.join(image_path, 'test.images.hkl')
-
-        try:
-            images = hickle.load(image_file)
-        except:
-            hickle.load(images, image_file)
-
-        images = images / 127.5 - 1
+        image_file = 'train.pkl' if split=='train' else 'test.pkl'
+        image_dir = os.path.join(image_dir, image_file)
+        with open(image_dir, 'rb') as f:
+            mnist = pickle.load(f)
+        images = mnist['X'] / 127.5 - 1
+        labels = mnist['y']
        print ('finished loading mnist image dataset..!')
-        return images
-
+        return images, labels
 
     def merge_images(self, sources, targets, k=10):
         _, h, w, _ = sources.shape
-        row = int(np.sqrt(self.model.batch_size))
+        row = int(np.sqrt(self.batch_size))
         merged = np.zeros([row*h, row*w*2, 3])
 
         for idx, (s, t) in enumerate(zip(sources, targets)):
@@ -74,87 +59,140 @@ class Solver(object):
             j = idx % row
             merged[i*h:(i+1)*h, (j*2)*h:(j*2+1)*h, :] = s
             merged[i*h:(i+1)*h, (j*2+1)*h:(j*2+2)*h, :] = t
-
         return merged
 
+    def pretrain(self):
+        # load svhn dataset
+        train_images, train_labels = self.load_svhn(self.svhn_dir, split='train')
+        test_images, test_labels = self.load_svhn(self.svhn_dir, split='test')
 
-    def train(self):
-        model=self.model
-
-        # load image dataset
-        svhn = self.load_svhn(self.svhn_path)
-        mnist = self.load_mnist(self.mnist_path)
+        # build a graph
+        model = self.model
+        model.build_model()
 
+        with tf.Session(config=self.config) as sess:
+            tf.global_variables_initializer().run()
+            saver = tf.train.Saver()
+            summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph())
+
+            for step in range(self.pretrain_iter+1):
+                i = step % int(train_images.shape[0] / self.batch_size)
+                batch_images = train_images[i*self.batch_size:(i+1)*self.batch_size]
+                batch_labels = train_labels[i*self.batch_size:(i+1)*self.batch_size]
+                feed_dict = {model.images: batch_images, model.labels: batch_labels}
+                sess.run(model.train_op, feed_dict)
+
+                if (step+1) % 10 == 0:
+                    summary, l, acc = sess.run([model.summary_op, model.loss, model.accuracy], feed_dict)
+                    rand_idxs = np.random.permutation(test_images.shape[0])[:self.batch_size]
+                    test_acc, _ = sess.run(fetches=[model.accuracy, model.loss],
+                                           feed_dict={model.images: test_images[rand_idxs],
+                                                      model.labels: test_labels[rand_idxs]})
+                    summary_writer.add_summary(summary, step)
+                    print ('Step: [%d/%d] loss: [%.6f] train acc: [%.2f] test acc [%.2f]' \
+                               %(step+1, self.pretrain_iter, l, acc, test_acc))
+
+                if (step+1) % 1000 == 0:
+                    saver.save(sess, os.path.join(self.model_save_path, 'svhn_model'), global_step=step+1)
+                    print ('svhn_model-%d saved..!' %(step+1))
 
-        num_iter_per_epoch = int(mnist.shape[0] / model.batch_size)
+    def train(self):
+        # load svhn dataset
+        svhn_images, _ = self.load_svhn(self.svhn_dir, split='train')
+        mnist_images, _ = self.load_mnist(self.mnist_dir, split='train')
 
-        config = tf.ConfigProto(allow_soft_placement = True)
-        config.gpu_options.allow_growth = True
-        with tf.Session(config=config) as sess:
-            # initialize parameters
-            try:
-                tf.global_variables_initializer().run()
-            except:
-                tf.initialize_all_variables().run()
+        # build a graph
+        model = self.model
+        model.build_model()
 
-            summary_writer = SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph())
+        # make log directory if not exists
+        if tf.gfile.Exists(self.log_dir):
+            tf.gfile.DeleteRecursively(self.log_dir)
+        tf.gfile.MakeDirs(self.log_dir)
 
-            for e in range(self.num_epoch):
-                for i in range(num_iter_per_epoch):
+        with tf.Session(config=self.config) as sess:
+            # initialize G and D
+            tf.global_variables_initializer().run()
+            # restore variables of F
+            print ('loading pretrained model F..')
+            variables_to_restore = slim.get_model_variables(scope='content_extractor')
+            restorer = tf.train.Saver(variables_to_restore)
+            restorer.restore(sess, self.pretrained_model)
+            summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph())
+            saver = tf.train.Saver()
 
-                    # train model for source domain S
-                    image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size]
-                    feed_dict = {model.images: image_batch}
-                    sess.run(model.d_optimizer_fake, feed_dict)
-                    sess.run(model.g_optimizer, feed_dict)
-                    sess.run(model.g_optimizer, feed_dict)
-                    if i % 3 == 0:
-                        sess.run(model.f_optimizer_const, feed_dict)
-
-                    if i % 10 == 0:
-                        feed_dict = {model.images: image_batch}
-                        summary, d_loss, g_loss = sess.run([model.summary_op, model.d_loss, model.g_loss], feed_dict)
-                        summary_writer.add_summary(summary, e*num_iter_per_epoch + i)
-                        print ('Epoch: [%d] Step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' %(e+1, i+1, num_iter_per_epoch, d_loss, g_loss))
-
-                    # train model for target domain T
-                    image_batch = mnist[i*model.batch_size:(i+1)*model.batch_size]
-                    feed_dict = {model.images: image_batch}
-                    sess.run(model.d_optimizer_real, feed_dict)
-                    sess.run(model.d_optimizer_fake, feed_dict)
-                    sess.run(model.g_optimizer, feed_dict)
-                    sess.run(model.g_optimizer, feed_dict)
-                    sess.run(model.g_optimizer, feed_dict)
-                    sess.run(model.g_optimizer_const, feed_dict)
-                    sess.run(model.g_optimizer_const, feed_dict)
-
-                    if i % 500 == 0:
-                        model.saver.save(sess, os.path.join(self.model_save_path, 'dtn-%d' %(e+1)), global_step=i+1)
-                        print ('model/dtn-%d-%d saved' %(e+1, i+1))
-
-
-    def test(self):
+            print ('start training..!')
+            for step in range(self.train_iter+1):
+
+                i = step % int(svhn_images.shape[0] / self.batch_size)
+                # train the model for source domain S
+                src_images = svhn_images[i*self.batch_size:(i+1)*self.batch_size]
+                feed_dict = {model.src_images: src_images}
+
+                sess.run(model.d_train_op_src, feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                sess.run([model.g_train_op_src], feed_dict)
+                if i % 15 == 0:
+                    sess.run(model.f_train_op_src, feed_dict)
+
+
+                if (step+1) % 10 == 0:
+                    summary, dl, gl, fl = sess.run([model.summary_op_src, \
+                        model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict)
+                    summary_writer.add_summary(summary, step)
+                    print ('[Source] step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f] f_loss: [%.6f]' \
+                               %(step+1, self.train_iter, dl, gl, fl))
+
+                # train the model for target domain T
+                j = step % int(mnist_images.shape[0] / self.batch_size)
+                trg_images = mnist_images[j*self.batch_size:(j+1)*self.batch_size]
+                feed_dict = {model.src_images: src_images, model.trg_images: trg_images}
+                sess.run(model.d_train_op_trg, feed_dict)
+                sess.run(model.d_train_op_trg, feed_dict)
+                sess.run(model.g_train_op_trg, feed_dict)
+                sess.run(model.g_train_op_trg, feed_dict)
+                sess.run(model.g_train_op_trg, feed_dict)
+                sess.run(model.g_train_op_trg, feed_dict)
+
+                if (step+1) % 10 == 0:
+                    summary, dl, gl = sess.run([model.summary_op_trg, \
+                        model.d_loss_trg, model.g_loss_trg], feed_dict)
+                    summary_writer.add_summary(summary, step)
+                    print ('[Target] step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' \
+                               %(step+1, self.train_iter, dl, gl))
+
+                if (step+1) % 200 == 0:
+                    saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1)
+                    print ('model/dtn-%d saved' %(step+1))
+
+
+    def eval(self):
+        # build model
         model = self.model
+        model.build_model()
 
-        # load dataset
-        svhn = self.load_svhn(self.svhn_path)
-        num_iter = int(svhn.shape[0] / model.batch_size)
+        # load svhn dataset
+        svhn_images, _ = self.load_svhn(self.svhn_dir)
 
-        config = tf.ConfigProto(allow_soft_placement = True)
-        config.gpu_options.allow_growth = True
-        with tf.Session(config=config) as sess:
+        with tf.Session(config=self.config) as sess:
             # load trained parameters
+            print ('loading test model..')
             saver = tf.train.Saver()
-            saver.restore(sess, self.test_model_path)
+            saver.restore(sess, self.test_model)
 
+            print ('start sampling..!')
             for i in range(self.sample_iter):
                 # train model for source domain S
-                image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size]
-                feed_dict = {model.images: image_batch}
-                sampled_image_batch = sess.run(model.sampled_images, feed_dict)
+                batch_images = svhn_images[i*self.batch_size:(i+1)*self.batch_size]
+                feed_dict = {model.images: batch_images}
+                sampled_batch_images = sess.run(model.sampled_images, feed_dict)
 
                 # merge and save source images and sampled target images
-                merged = self.merge_images(image_batch, sampled_image_batch)
-                path = os.path.join(self.sample_path, 'sample-%d-to-%d.png' %(i*model.batch_size, (i+1)*model.batch_size))
+                merged = self.merge_images(batch_images, sampled_batch_images)
+                path = os.path.join(self.sample_save_path, 'sample-%d-to-%d.png' %(i*self.batch_size, (i+1)*self.batch_size))
                 scipy.misc.imsave(path, merged)
                 print ('saved %s' %path)
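A note on the data this commit expects: load_mnist() above reads train.pkl / test.pkl as a pickled dict holding raw pixel values under 'X' (scaled by 127.5 into [-1, 1]) and labels under 'y'. The diff itself does not show how those files are produced; below is a minimal sketch, where the 32x32 resizing, the single-channel shape, and the use of the TF 1.x tutorial loader are assumptions, not part of this commit.

# Sketch only: build mnist/train.pkl in the layout load_mnist() assumes.
# Assumed: 'X' holds uint8 images of shape [N, 32, 32, 1], 'y' integer labels.
import os
import pickle
import numpy as np
import scipy.misc
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('mnist_raw')     # downloads raw MNIST if missing
images = (mnist.train.images.reshape(-1, 28, 28) * 255).astype(np.uint8)
resized = np.stack([scipy.misc.imresize(im, (32, 32)) for im in images])
data = {'X': resized[:, :, :, None],               # assumed shape [N, 32, 32, 1]
        'y': mnist.train.labels.astype(np.int64)}  # integer class labels

if not os.path.exists('mnist'):
    os.makedirs('mnist')
with open('mnist/train.pkl', 'wb') as f:
    pickle.dump(data, f)

The SVHN side needs no such step: load_svhn() reads the official train_32x32.mat / test_32x32.mat files directly, only remapping label 10 (SVHN's encoding of the digit 0) back to 0.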
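For orientation, the three entry points are meant to run as separate phases: pretrain() fits the content extractor F on labeled SVHN, train() restores F from pretrained_model and runs the adversarial source/target updates, and eval() restores test_model and writes sample grids to sample_save_path. A minimal driver sketch follows, assuming a model class DTN whose placeholders and ops match what the solver feeds and runs; the class name and constructor signature here are hypothetical, not taken from this diff.

# Sketch only: drive the three phases in sequence.
# DTN and its mode argument are assumed, not defined in this commit.
import tensorflow as tf
from model import DTN
from solver import Solver

for mode in ['pretrain', 'train', 'eval']:
    tf.reset_default_graph()             # each phase calls model.build_model() on a fresh graph
    model = DTN(mode=mode)               # assumed constructor
    solver = Solver(model, batch_size=100)
    getattr(solver, mode)()              # solver.pretrain(), solver.train(), solver.eval()

One thing to watch: unlike the previous version, the constructor no longer creates the output directories, so model/ and sample/ must exist before pretrain() saves a checkpoint or eval() writes images; only the train() phase recreates its log directory via tf.gfile.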