# Author: yunjey
# Domain Transfer Network model (refactored to use TF-Slim layers).
import tensorflow as tf
import tensorflow.contrib.slim as slim
class DTN(object):
    """Domain Transfer Network (SVHN -> MNIST).

    Three sub-networks are built under separate variable scopes so that
    `build_model` can select trainable variables by scope name:

      * ``content_extractor`` (scope 'content_extractor'): the function f,
        mapping a 32x32x3 image to a (batch, 1, 1, 128) feature tensor.
      * ``generator`` (scope 'generator'): maps f's features to a
        32x32x1 (MNIST-like) image.
      * ``discriminator`` (scope 'discriminator'): scores a 32x32x1 image
        as real/fake (unnormalized logit).

    The graph that `build_model` constructs depends on ``mode``:
    'pretrain' (classify SVHN digits with f), 'train' (adversarial
    domain transfer), or 'eval' (sampling only).
    """

    def __init__(self, mode='train', learning_rate=0.0003):
        """Store configuration only; no graph ops are created here.

        Args:
            mode: one of 'pretrain', 'train', 'eval'. Controls both which
                graph `build_model` builds and the `is_training` flag of
                the batch-norm layers.
            learning_rate: Adam learning rate shared by all optimizers.
        """
        self.mode = mode
        self.learning_rate = learning_rate

    def content_extractor(self, images, reuse=False):
        """Function f: encode images into content features.

        Args:
            images: float tensor of shape (batch, 32, 32, 3) or
                (batch, 32, 32, 1); grayscale inputs are converted to RGB
                so SVHN and MNIST share the same weights.
            reuse: reuse the 'content_extractor' variable scope.

        Returns:
            In 'pretrain' mode, (batch, 10) classification logits;
            otherwise a (batch, 1, 1, 128) tanh-activated feature map.
        """
        if images.get_shape()[3] == 1:
            # For the mnist dataset, replicate the gray scale image 3 times.
            images = tf.image.grayscale_to_rgb(images)

        with tf.variable_scope('content_extractor', reuse=reuse):
            with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None,
                                stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
                # Batch norm supplies the nonlinearity (ReLU) after each conv.
                # f also runs in training mode during 'pretrain'.
                with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
                                    activation_fn=tf.nn.relu,
                                    is_training=(self.mode == 'train' or self.mode == 'pretrain')):
                    net = slim.conv2d(images, 64, [3, 3], scope='conv1')    # (batch_size, 16, 16, 64)
                    net = slim.batch_norm(net, scope='bn1')
                    net = slim.conv2d(net, 128, [3, 3], scope='conv2')      # (batch_size, 8, 8, 128)
                    net = slim.batch_norm(net, scope='bn2')
                    net = slim.conv2d(net, 256, [3, 3], scope='conv3')      # (batch_size, 4, 4, 256)
                    net = slim.batch_norm(net, scope='bn3')
                    net = slim.conv2d(net, 128, [4, 4], padding='VALID', scope='conv4')  # (batch_size, 1, 1, 128)
                    # tanh bounds the content features to [-1, 1].
                    net = slim.batch_norm(net, activation_fn=tf.nn.tanh, scope='bn4')
                    if self.mode == 'pretrain':
                        # Extra 1x1 conv head for the 10-way SVHN digit classifier.
                        net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out')
                        net = slim.flatten(net)
                    return net

    def generator(self, inputs, reuse=False):
        """Generator g: decode content features into a MNIST-style image.

        Args:
            inputs: (batch, 1, 1, 128) feature tensor from `content_extractor`.
            reuse: reuse the 'generator' variable scope.

        Returns:
            (batch, 32, 32, 1) image tensor with tanh output in [-1, 1].
        """
        with tf.variable_scope('generator', reuse=reuse):
            with slim.arg_scope([slim.conv2d_transpose], padding='SAME', activation_fn=None,
                                stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
                with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
                                    activation_fn=tf.nn.relu, is_training=(self.mode == 'train')):
                    net = slim.conv2d_transpose(inputs, 512, [4, 4], padding='VALID', scope='conv_transpose1')  # (batch_size, 4, 4, 512)
                    net = slim.batch_norm(net, scope='bn1')
                    net = slim.conv2d_transpose(net, 256, [3, 3], scope='conv_transpose2')  # (batch_size, 8, 8, 256)
                    net = slim.batch_norm(net, scope='bn2')
                    net = slim.conv2d_transpose(net, 128, [3, 3], scope='conv_transpose3')  # (batch_size, 16, 16, 128)
                    net = slim.batch_norm(net, scope='bn3')
                    # Last layer: tanh activation, no batch norm.
                    net = slim.conv2d_transpose(net, 1, [3, 3], activation_fn=tf.nn.tanh, scope='conv_transpose4')  # (batch_size, 32, 32, 1)
                    return net

    def discriminator(self, images, reuse=False):
        """Discriminator D: score an image as real or generated.

        Args:
            images: (batch, 32, 32, 1) image tensor (target domain).
            reuse: reuse the 'discriminator' variable scope.

        Returns:
            (batch, 1) unnormalized logits (sigmoid applied by the loss).
        """
        with tf.variable_scope('discriminator', reuse=reuse):
            with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None,
                                stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
                with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
                                    activation_fn=tf.nn.relu, is_training=(self.mode == 'train')):
                    # First layer uses a plain ReLU conv (before its batch norm).
                    net = slim.conv2d(images, 128, [3, 3], activation_fn=tf.nn.relu, scope='conv1')  # (batch_size, 16, 16, 128)
                    net = slim.batch_norm(net, scope='bn1')
                    net = slim.conv2d(net, 256, [3, 3], scope='conv2')      # (batch_size, 8, 8, 256)
                    net = slim.batch_norm(net, scope='bn2')
                    net = slim.conv2d(net, 512, [3, 3], scope='conv3')      # (batch_size, 4, 4, 512)
                    net = slim.batch_norm(net, scope='bn3')
                    net = slim.conv2d(net, 1, [4, 4], padding='VALID', scope='conv4')  # (batch_size, 1, 1, 1)
                    net = slim.flatten(net)
                    return net

    def build_model(self):
        """Build the mode-specific graph: placeholders, losses, train and summary ops.

        'pretrain': 10-way SVHN classification of f's output.
        'eval'    : SVHN -> MNIST sampling only (no losses).
        'train'   : full DTN objective — adversarial losses on the source
                    (SVHN) branch plus f-constancy, and adversarial +
                    reconstruction losses on the target (MNIST) branch.
        """
        if self.mode == 'pretrain':
            self.images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')
            self.labels = tf.placeholder(tf.int64, [None], 'svhn_labels')

            # logits and accuracy
            self.logits = self.content_extractor(self.images)
            self.pred = tf.argmax(self.logits, 1)
            self.correct_pred = tf.equal(self.pred, self.labels)
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))

            # loss and train op
            self.loss = slim.losses.sparse_softmax_cross_entropy(self.logits, self.labels)
            self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
            self.train_op = slim.learning.create_train_op(self.loss, self.optimizer)

            # summary op
            loss_summary = tf.summary.scalar('classification_loss', self.loss)
            accuracy_summary = tf.summary.scalar('accuracy', self.accuracy)
            self.summary_op = tf.summary.merge([loss_summary, accuracy_summary])

        elif self.mode == 'eval':
            self.images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')

            # source domain (svhn to mnist)
            self.fx = self.content_extractor(self.images)
            self.sampled_images = self.generator(self.fx)

        elif self.mode == 'train':
            self.src_images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')
            self.trg_images = tf.placeholder(tf.float32, [None, 32, 32, 1], 'mnist_images')

            # source domain (svhn to mnist)
            with tf.name_scope('model_for_source_domain'):
                self.fx = self.content_extractor(self.src_images)
                self.fake_images = self.generator(self.fx)
                self.logits = self.discriminator(self.fake_images)
                # f(g(f(x))) for the f-constancy term.
                self.fgfx = self.content_extractor(self.fake_images, reuse=True)

            # loss
            self.d_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.zeros_like(self.logits))
            self.g_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.ones_like(self.logits))
            # f-constancy (L_CONST), weighted by 15.0.
            self.f_loss_src = tf.reduce_mean(tf.square(self.fx - self.fgfx)) * 15.0

            # optimizer
            self.d_optimizer_src = tf.train.AdamOptimizer(self.learning_rate)
            self.g_optimizer_src = tf.train.AdamOptimizer(self.learning_rate)
            self.f_optimizer_src = tf.train.AdamOptimizer(self.learning_rate)

            # Select variables by the scope each sub-network was built in.
            t_vars = tf.trainable_variables()
            d_vars = [var for var in t_vars if 'discriminator' in var.name]
            g_vars = [var for var in t_vars if 'generator' in var.name]
            f_vars = [var for var in t_vars if 'content_extractor' in var.name]

            # train op
            with tf.name_scope('source_train_op'):
                self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars)
                self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars)
                self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars)

            # summary op
            d_loss_src_summary = tf.summary.scalar('src_d_loss', self.d_loss_src)
            g_loss_src_summary = tf.summary.scalar('src_g_loss', self.g_loss_src)
            f_loss_src_summary = tf.summary.scalar('src_f_loss', self.f_loss_src)
            origin_images_summary = tf.summary.image('src_origin_images', self.src_images)
            sampled_images_summary = tf.summary.image('src_sampled_images', self.fake_images)
            self.summary_op_src = tf.summary.merge([d_loss_src_summary, g_loss_src_summary,
                                                    f_loss_src_summary, origin_images_summary,
                                                    sampled_images_summary])

            # target domain (mnist)
            with tf.name_scope('model_for_target_domain'):
                self.fx = self.content_extractor(self.trg_images, reuse=True)
                self.reconst_images = self.generator(self.fx, reuse=True)
                self.logits_fake = self.discriminator(self.reconst_images, reuse=True)
                self.logits_real = self.discriminator(self.trg_images, reuse=True)

            # loss
            self.d_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.zeros_like(self.logits_fake))
            self.d_loss_real_trg = slim.losses.sigmoid_cross_entropy(self.logits_real, tf.ones_like(self.logits_real))
            self.d_loss_trg = self.d_loss_fake_trg + self.d_loss_real_trg
            self.g_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.ones_like(self.logits_fake))
            # Identity/reconstruction term (L_TID), weighted by 15.0.
            self.g_loss_const_trg = tf.reduce_mean(tf.square(self.trg_images - self.reconst_images)) * 15.0
            self.g_loss_trg = self.g_loss_fake_trg + self.g_loss_const_trg

            # optimizer
            self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)
            self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)

            t_vars = tf.trainable_variables()
            d_vars = [var for var in t_vars if 'discriminator' in var.name]
            g_vars = [var for var in t_vars if 'generator' in var.name]
            f_vars = [var for var in t_vars if 'content_extractor' in var.name]

            # train op
            with tf.name_scope('target_train_op'):
                self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars)
                self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars)

            # summary op
            d_loss_fake_trg_summary = tf.summary.scalar('trg_d_loss_fake', self.d_loss_fake_trg)
            d_loss_real_trg_summary = tf.summary.scalar('trg_d_loss_real', self.d_loss_real_trg)
            d_loss_trg_summary = tf.summary.scalar('trg_d_loss', self.d_loss_trg)
            g_loss_fake_trg_summary = tf.summary.scalar('trg_g_loss_fake', self.g_loss_fake_trg)
            g_loss_const_trg_summary = tf.summary.scalar('trg_g_loss_const', self.g_loss_const_trg)
            g_loss_trg_summary = tf.summary.scalar('trg_g_loss', self.g_loss_trg)
            origin_images_summary = tf.summary.image('trg_origin_images', self.trg_images)
            sampled_images_summary = tf.summary.image('trg_reconstructed_images', self.reconst_images)
            self.summary_op_trg = tf.summary.merge([d_loss_trg_summary, g_loss_trg_summary,
                                                    d_loss_fake_trg_summary, d_loss_real_trg_summary,
                                                    g_loss_fake_trg_summary, g_loss_const_trg_summary,
                                                    origin_images_summary, sampled_images_summary])
            for var in tf.trainable_variables():
                tf.summary.histogram(var.op.name, var)
...\ No newline at end of file ...\ No newline at end of file
......