yunjey

domain transfer network

Showing 1 changed file with 163 additions and 176 deletions
 import tensorflow as tf
-from ops import *
-from config import *
+import tensorflow.contrib.slim as slim


 class DTN(object):
7 - """Domain Transfer Network for unsupervised cross-domain image generation 6 + """Domain Transfer Network
8 -
9 - Construct discriminator and generator to prepare for training.
10 - """
11 -
12 - def __init__(self, batch_size=100, learning_rate=0.0001, image_size=32, output_size=32,
13 - dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64):
14 """ 7 """
-        Args:
-            learning_rate: (optional) learning rate for discriminator and generator
-            image_size: (optional) spatial size of input image for discriminator
-            output_size: (optional) spatial size of image generated by generator
-            dim_color: (optional) dimension of image color; default is 3 for rgb
-            dim_fout: (optional) dimension of z (random input vector for generator)
-            dim_df: (optional) dimension of discriminator's filter in first convolution layer
-            dim_gf: (optional) dimension of generator's filter in last convolution layer
-            dim_ff: (optional) dimension of function f's filter in first convolution layer
-        """
-        # hyper parameters
-        self.batch_size = batch_size
+    def __init__(self, mode='train', learning_rate=0.0003):
+        self.mode = mode
         self.learning_rate = learning_rate
-        self.image_size = image_size
-        self.output_size = output_size
-        self.dim_color = dim_color
-        self.dim_fout = dim_fout
-        self.dim_df = dim_df
-        self.dim_gf = dim_gf
-        self.dim_ff = dim_ff
-
-        # placeholder
-        self.images = tf.placeholder(tf.float32, shape=[batch_size, image_size, image_size, dim_color], name='images')
-        #self.z = tf.placeholder(tf.float32, shape=[None, dim_z], name='input_for_generator')
-
-        # batch normalization layers for the discriminator, generator and function f
-        self.d_bn1 = batch_norm(name='d_bn1')
-        self.d_bn2 = batch_norm(name='d_bn2')
-        self.d_bn3 = batch_norm(name='d_bn3')
-        self.d_bn4 = batch_norm(name='d_bn4')
-
-        self.g_bn1 = batch_norm(name='g_bn1')
-        self.g_bn2 = batch_norm(name='g_bn2')
-        self.g_bn3 = batch_norm(name='g_bn3')
-        self.g_bn4 = batch_norm(name='g_bn4')
-
-        self.f_bn1 = batch_norm(name='f_bn1')
-        self.f_bn2 = batch_norm(name='f_bn2')
-        self.f_bn3 = batch_norm(name='f_bn3')
-        self.f_bn4 = batch_norm(name='f_bn4')
-
-
-
-    def function_f(self, images, reuse=False, train=True):
-        """f consistency
-
-        Args:
-            images: images for domain S and T, of shape (batch_size, image_size, image_size, dim_color)
-
-        Returns:
-            out: output vectors, of shape (batch_size, dim_fout)
-        """
-        with tf.variable_scope('function_f', reuse=reuse):
-            h1 = lrelu(self.f_bn1(conv2d(images, self.dim_ff, name='f_h1'), train=train))     # (batch_size, 16, 16, 64)
-            h2 = lrelu(self.f_bn2(conv2d(h1, self.dim_ff*2, name='f_h2'), train=train))       # (batch_size, 8, 8, 128)
-            h3 = lrelu(self.f_bn3(conv2d(h2, self.dim_ff*4, name='f_h3'), train=train))       # (batch_size, 4, 4, 256)
-            h4 = lrelu(self.f_bn4(conv2d(h3, self.dim_ff*8, name='f_h4'), train=train))       # (batch_size, 2, 2, 512)
-
-            h4 = tf.reshape(h4, [self.batch_size, -1])
-            out = linear(h4, self.dim_fout, name='f_out')
-
-            return tf.nn.tanh(out)
-
-
-    def generator(self, z, reuse=False):
-        """Generator: Deconvolutional neural network with relu activations.
-
-        Last deconv layer does not use batch normalization.
-
-        Args:
-            z: random input vectors, of shape (batch_size, dim_z)
-
-        Returns:
-            out: generated images, of shape (batch_size, image_size, image_size, dim_color)
-        """
-        if reuse:
-            train = False
-        else:
-            train = True

+    def content_extractor(self, images, reuse=False):
+        # images: (batch, 32, 32, 3) or (batch, 32, 32, 1)
+
+        if images.get_shape()[3] == 1:
+            # For mnist dataset, replicate the gray scale image 3 times.
+            images = tf.image.grayscale_to_rgb(images)
+
+        with tf.variable_scope('content_extractor', reuse=reuse):
+            with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None,
+                                stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
+                with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
+                                    activation_fn=tf.nn.relu, is_training=(self.mode=='train' or self.mode=='pretrain')):
+
+                    net = slim.conv2d(images, 64, [3, 3], scope='conv1')                 # (batch_size, 16, 16, 64)
+                    net = slim.batch_norm(net, scope='bn1')
+                    net = slim.conv2d(net, 128, [3, 3], scope='conv2')                   # (batch_size, 8, 8, 128)
+                    net = slim.batch_norm(net, scope='bn2')
+                    net = slim.conv2d(net, 256, [3, 3], scope='conv3')                   # (batch_size, 4, 4, 256)
+                    net = slim.batch_norm(net, scope='bn3')
+                    net = slim.conv2d(net, 128, [4, 4], padding='VALID', scope='conv4')  # (batch_size, 1, 1, 128)
+                    net = slim.batch_norm(net, activation_fn=tf.nn.tanh, scope='bn4')
+                    if self.mode == 'pretrain':
+                        net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out')
+                        net = slim.flatten(net)
+
+                    return net
+
+    def generator(self, inputs, reuse=False):
+        # inputs: (batch, 1, 1, 128)
         with tf.variable_scope('generator', reuse=reuse):
-
-            # spatial size for convolution
-            s = self.output_size
-            s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16)    # 16, 8, 4, 2
-
-            # project and reshape z
-            h1 = linear(z, s16*s16*self.dim_gf*8, name='g_h1')    # (batch_size, 2*2*512)
-            h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8])    # (batch_size, 2, 2, 512)
-            h1 = relu(self.g_bn1(h1, train=train))
-
-            h2 = deconv2d(h1, [self.batch_size, s8, s8, self.dim_gf*4], name='g_h2')    # (batch_size, 4, 4, 256)
-            h2 = relu(self.g_bn2(h2, train=train))
-
-            h3 = deconv2d(h2, [self.batch_size, s4, s4, self.dim_gf*2], name='g_h3')    # (batch_size, 8, 8, 128)
-            h3 = relu(self.g_bn3(h3, train=train))
-
-            h4 = deconv2d(h3, [self.batch_size, s2, s2, self.dim_gf], name='g_h4')      # (batch_size, 16, 16, 64)
-            h4 = relu(self.g_bn4(h4, train=train))
-
-            out = deconv2d(h4, [self.batch_size, s, s, self.dim_color], name='g_out')   # (batch_size, 32, 32, dim_color)
-
-            return tf.nn.tanh(out)
-
+            with slim.arg_scope([slim.conv2d_transpose], padding='SAME', activation_fn=None,
+                                stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
+                with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
+                                    activation_fn=tf.nn.relu, is_training=(self.mode=='train')):
+
+                    net = slim.conv2d_transpose(inputs, 512, [4, 4], padding='VALID', scope='conv_transpose1')      # (batch_size, 4, 4, 512)
+                    net = slim.batch_norm(net, scope='bn1')
+                    net = slim.conv2d_transpose(net, 256, [3, 3], scope='conv_transpose2')                          # (batch_size, 8, 8, 256)
+                    net = slim.batch_norm(net, scope='bn2')
+                    net = slim.conv2d_transpose(net, 128, [3, 3], scope='conv_transpose3')                          # (batch_size, 16, 16, 128)
+                    net = slim.batch_norm(net, scope='bn3')
+                    net = slim.conv2d_transpose(net, 1, [3, 3], activation_fn=tf.nn.tanh, scope='conv_transpose4')  # (batch_size, 32, 32, 1)
+                    return net

     def discriminator(self, images, reuse=False):
-        """Discriminator: Convolutional neural network with leaky relu activations.
+        # images: (batch, 32, 32, 1)
+        with tf.variable_scope('discriminator', reuse=reuse):
+            with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None,
+                                stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
+                with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
+                                    activation_fn=tf.nn.relu, is_training=(self.mode=='train')):
+
+                    net = slim.conv2d(images, 128, [3, 3], activation_fn=tf.nn.relu, scope='conv1')  # (batch_size, 16, 16, 128)
+                    net = slim.batch_norm(net, scope='bn1')
+                    net = slim.conv2d(net, 256, [3, 3], scope='conv2')                 # (batch_size, 8, 8, 256)
+                    net = slim.batch_norm(net, scope='bn2')
+                    net = slim.conv2d(net, 512, [3, 3], scope='conv3')                 # (batch_size, 4, 4, 512)
+                    net = slim.batch_norm(net, scope='bn3')
+                    net = slim.conv2d(net, 1, [4, 4], padding='VALID', scope='conv4')  # (batch_size, 1, 1, 1)
+                    net = slim.flatten(net)
+                    return net

-        First conv layer does not use batch normalization.
+    def build_model(self):

-        Args:
-            images: real or fake images of shape (batch_size, image_size, image_size, dim_color)
+        if self.mode == 'pretrain':
+            self.images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')
+            self.labels = tf.placeholder(tf.int64, [None], 'svhn_labels')

-        Returns:
-            out: scores for whether it is a real image or a fake image, of shape (batch_size,)
-        """
-        with tf.variable_scope('discriminator', reuse=reuse):
+            # logits and accuracy
+            self.logits = self.content_extractor(self.images)
+            self.pred = tf.argmax(self.logits, 1)
+            self.correct_pred = tf.equal(self.pred, self.labels)
+            self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))

-            # convolution layer
-            h1 = lrelu(self.d_bn1(conv2d(images, self.dim_df, name='d_h1')))      # (batch_size, 16, 16, 64)
-            h2 = lrelu(self.d_bn2(conv2d(h1, self.dim_df*2, name='d_h2')))        # (batch_size, 8, 8, 128)
-            h3 = lrelu(self.d_bn3(conv2d(h2, self.dim_df*4, name='d_h3')))        # (batch_size, 4, 4, 256)
-            h4 = lrelu(self.d_bn4(conv2d(h3, self.dim_df*8, name='d_h4')))        # (batch_size, 2, 2, 512)
+            # loss and train op
+            self.loss = slim.losses.sparse_softmax_cross_entropy(self.logits, self.labels)
+            self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
+            self.train_op = slim.learning.create_train_op(self.loss, self.optimizer)

-            # fully connected layer
-            h4 = tf.reshape(h4, [self.batch_size, -1])
-            out = linear(h4, 1, name='d_out')      # (batch_size,)
+            # summary op
+            loss_summary = tf.summary.scalar('classification_loss', self.loss)
+            accuracy_summary = tf.summary.scalar('accuracy', self.accuracy)
+            self.summary_op = tf.summary.merge([loss_summary, accuracy_summary])

-            return out
+        elif self.mode == 'eval':
+            self.images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')

+            # source domain (svhn to mnist)
+            self.fx = self.content_extractor(self.images)
+            self.sampled_images = self.generator(self.fx)

-    def build_model(self):
+        elif self.mode == 'train':
+            self.src_images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')
+            self.trg_images = tf.placeholder(tf.float32, [None, 32, 32, 1], 'mnist_images')

-        # construct generator and discriminator for training phase
-        self.f_x = self.function_f(self.images)
-        self.fake_images = self.generator(self.f_x)                            # (batch_size, 32, 32, 3)
-        self.logits_real = self.discriminator(self.images)                     # (batch_size,)
-        self.logits_fake = self.discriminator(self.fake_images, reuse=True)    # (batch_size,)
-        self.fgf_x = self.function_f(self.fake_images, reuse=True)             # (batch_size, dim_f)

-        # construct generator for test phase (use moving average and variance for batch norm)
-        self.f_x = self.function_f(self.images, reuse=True, train=False)
-        self.sampled_images = self.generator(self.f_x, reuse=True)             # (batch_size, 32, 32, 3)
+            # source domain (svhn to mnist)
+            with tf.name_scope('model_for_source_domain'):
+                self.fx = self.content_extractor(self.src_images)
+                self.fake_images = self.generator(self.fx)
+                self.logits = self.discriminator(self.fake_images)
+                self.fgfx = self.content_extractor(self.fake_images, reuse=True)

+            # loss
+            self.d_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.zeros_like(self.logits))
+            self.g_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.ones_like(self.logits))
+            self.f_loss_src = tf.reduce_mean(tf.square(self.fx - self.fgfx)) * 15.0

-        # compute loss
-        self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_real, tf.ones_like(self.logits_real)))
-        self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.zeros_like(self.logits_fake)))
-        self.d_loss = self.d_loss_real + self.d_loss_fake
-        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake)))
-        self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images))    # L_TID
-        self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) * 0.15      # L_CONST
+            # optimizer
+            self.d_optimizer_src = tf.train.AdamOptimizer(self.learning_rate)
+            self.g_optimizer_src = tf.train.AdamOptimizer(self.learning_rate)
+            self.f_optimizer_src = tf.train.AdamOptimizer(self.learning_rate)

-        # divide variables for discriminator and generator
             t_vars = tf.trainable_variables()
-        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
-        self.g_vars = [var for var in t_vars if 'generator' in var.name]
-        self.f_vars = [var for var in t_vars if 'function_f' in var.name]
-
-        # optimizer for discriminator and generator
-        with tf.name_scope('optimizer'):
-            self.d_optimizer_real = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_real, var_list=self.d_vars)
-            self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars)
-            self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars)
-            self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars)
-            self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.g_vars+self.f_vars)
-
-
-        # summary ops for tensorboard visualization
-        scalar_summary('d_loss_real', self.d_loss_real)
-        scalar_summary('d_loss_fake', self.d_loss_fake)
-        scalar_summary('d_loss', self.d_loss)
-        scalar_summary('g_loss', self.g_loss)
-        scalar_summary('g_const_loss', self.g_const_loss)
-        scalar_summary('f_const_loss', self.f_const_loss)
-
-        try:
-            image_summary('original_images', self.images, max_outputs=4)
-            image_summary('sampled_images', self.sampled_images, max_outputs=4)
-        except:
-            image_summary('original_images', self.images, max_images=4)
-            image_summary('sampled_images', self.sampled_images, max_images=4)
+            d_vars = [var for var in t_vars if 'discriminator' in var.name]
+            g_vars = [var for var in t_vars if 'generator' in var.name]
+            f_vars = [var for var in t_vars if 'content_extractor' in var.name]
+
+            # train op
+            with tf.name_scope('source_train_op'):
+                self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars)
+                self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars)
+                self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars)
+
+            # summary op
+            d_loss_src_summary = tf.summary.scalar('src_d_loss', self.d_loss_src)
+            g_loss_src_summary = tf.summary.scalar('src_g_loss', self.g_loss_src)
+            f_loss_src_summary = tf.summary.scalar('src_f_loss', self.f_loss_src)
+            origin_images_summary = tf.summary.image('src_origin_images', self.src_images)
+            sampled_images_summary = tf.summary.image('src_sampled_images', self.fake_images)
+            self.summary_op_src = tf.summary.merge([d_loss_src_summary, g_loss_src_summary,
+                                                    f_loss_src_summary, origin_images_summary,
+                                                    sampled_images_summary])
+
+            # target domain (mnist)
+            with tf.name_scope('model_for_target_domain'):
+                self.fx = self.content_extractor(self.trg_images, reuse=True)
+                self.reconst_images = self.generator(self.fx, reuse=True)
+                self.logits_fake = self.discriminator(self.reconst_images, reuse=True)
+                self.logits_real = self.discriminator(self.trg_images, reuse=True)
+
+            # loss
+            self.d_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.zeros_like(self.logits_fake))
+            self.d_loss_real_trg = slim.losses.sigmoid_cross_entropy(self.logits_real, tf.ones_like(self.logits_real))
+            self.d_loss_trg = self.d_loss_fake_trg + self.d_loss_real_trg
+            self.g_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.ones_like(self.logits_fake))
+            self.g_loss_const_trg = tf.reduce_mean(tf.square(self.trg_images - self.reconst_images)) * 15.0
+            self.g_loss_trg = self.g_loss_fake_trg + self.g_loss_const_trg
+
+            # optimizer
+            self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)
+            self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)

+            t_vars = tf.trainable_variables()
+            d_vars = [var for var in t_vars if 'discriminator' in var.name]
+            g_vars = [var for var in t_vars if 'generator' in var.name]
+            f_vars = [var for var in t_vars if 'content_extractor' in var.name]
+
+            # train op
+            with tf.name_scope('target_train_op'):
+                self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars)
+                self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars)
+
+            # summary op
+            d_loss_fake_trg_summary = tf.summary.scalar('trg_d_loss_fake', self.d_loss_fake_trg)
+            d_loss_real_trg_summary = tf.summary.scalar('trg_d_loss_real', self.d_loss_real_trg)
+            d_loss_trg_summary = tf.summary.scalar('trg_d_loss', self.d_loss_trg)
+            g_loss_fake_trg_summary = tf.summary.scalar('trg_g_loss_fake', self.g_loss_fake_trg)
+            g_loss_const_trg_summary = tf.summary.scalar('trg_g_loss_const', self.g_loss_const_trg)
+            g_loss_trg_summary = tf.summary.scalar('trg_g_loss', self.g_loss_trg)
+            origin_images_summary = tf.summary.image('trg_origin_images', self.trg_images)
+            sampled_images_summary = tf.summary.image('trg_reconstructed_images', self.reconst_images)
+            self.summary_op_trg = tf.summary.merge([d_loss_trg_summary, g_loss_trg_summary,
+                                                    d_loss_fake_trg_summary, d_loss_real_trg_summary,
+                                                    g_loss_fake_trg_summary, g_loss_const_trg_summary,
+                                                    origin_images_summary, sampled_images_summary])
             for var in tf.trainable_variables():
-            histogram_summary(var.op.name, var)
-
-        self.summary_op = merge_summary()
+                tf.summary.histogram(var.op.name, var)

-        self.saver = tf.train.Saver()
\ No newline at end of file
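
Usage sketch (not part of the diff). A minimal, hedged example of how the rewritten class might be driven, assuming this file is saved as model.py. The random arrays are stand-ins for real SVHN/MNIST batches, and the alternating update order simply exercises the train ops defined above; the repo's actual data pipeline and training loop are not shown in this diff:

    import numpy as np
    import tensorflow as tf
    from model import DTN

    # Stand-in batches; a real pipeline would feed SVHN (RGB) and MNIST
    # (grayscale, resized to 32x32) images scaled to [-1, 1] to match tanh.
    svhn_batch = np.random.uniform(-1, 1, (100, 32, 32, 3)).astype(np.float32)
    mnist_batch = np.random.uniform(-1, 1, (100, 32, 32, 1)).astype(np.float32)

    model = DTN(mode='train', learning_rate=0.0003)
    model.build_model()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Source domain: alternate discriminator, generator, and
        # f-constancy (content extractor) updates on SVHN batches.
        sess.run(model.d_train_op_src, feed_dict={model.src_images: svhn_batch})
        sess.run(model.g_train_op_src, feed_dict={model.src_images: svhn_batch})
        sess.run(model.f_train_op_src, feed_dict={model.src_images: svhn_batch})

        # Target domain: adversarial + reconstruction (L_TID) updates on MNIST batches.
        sess.run(model.d_train_op_trg, feed_dict={model.trg_images: mnist_batch})
        sess.run(model.g_train_op_trg, feed_dict={model.trg_images: mnist_batch})

In 'pretrain' mode the same class instead exposes train_op, loss, and accuracy for the SVHN digit classifier that initializes the content extractor f; in 'eval' mode it exposes sampled_images for SVHN-to-MNIST transfer.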