yunjey

domain transfer network

Showing 1 changed file with 163 additions and 176 deletions
import tensorflow as tf
from ops import *
from config import *
import tensorflow.contrib.slim as slim
class DTN(object):
"""Domain Transfer Network for unsupervised cross-domain image generation
Construct discriminator and generator to prepare for training.
"""
def __init__(self, batch_size=100, learning_rate=0.0001, image_size=32, output_size=32,
dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64):
"""Domain Transfer Network
"""
Args:
learning_rate: (optional) learning rate for discriminator and generator
image_size: (optional) spatial size of input image for discriminator
output_size: (optional) spatial size of image generated by generator
dim_color: (optional) dimension of image color; default is 3 for rgb
dim_fout: (optional) dimension of z (random input vector for generator)
dim_df: (optional) dimension of discriminator's filter in first convolution layer
dim_gf: (optional) dimension of generator's filter in last convolution layer
dim_ff: (optional) dimension of function f's filter in first convolution layer
"""
# hyper parameters
self.batch_size = batch_size
def __init__(self, mode='train', learning_rate=0.0003):
self.mode = mode
self.learning_rate = learning_rate
self.image_size = image_size
self.output_size = output_size
self.dim_color = dim_color
self.dim_fout = dim_fout
self.dim_df = dim_df
self.dim_gf = dim_gf
self.dim_ff = dim_ff
# placeholder
self.images = tf.placeholder(tf.float32, shape=[batch_size, image_size, image_size, dim_color], name='images')
#self.z = tf.placeholder(tf.float32, shape=[None, dim_z], name='input_for_generator')
# batch normalization layer for discriminator, generator and funtion f
self.d_bn1 = batch_norm(name='d_bn1')
self.d_bn2 = batch_norm(name='d_bn2')
self.d_bn3 = batch_norm(name='d_bn3')
self.d_bn4 = batch_norm(name='d_bn4')
self.g_bn1 = batch_norm(name='g_bn1')
self.g_bn2 = batch_norm(name='g_bn2')
self.g_bn3 = batch_norm(name='g_bn3')
self.g_bn4 = batch_norm(name='g_bn4')
self.f_bn1 = batch_norm(name='f_bn1')
self.f_bn2 = batch_norm(name='f_bn2')
self.f_bn3 = batch_norm(name='f_bn3')
self.f_bn4 = batch_norm(name='f_bn4')
def function_f(self, images, reuse=False, train=True):
"""f consistancy
Args:
images: images for domain S and T, of shape (batch_size, image_size, image_size, dim_color)
Returns:
out: output vectors, of shape (batch_size, dim_f_out)
"""
with tf.variable_scope('function_f', reuse=reuse):
h1 = lrelu(self.f_bn1(conv2d(images, self.dim_ff, name='f_h1'), train=train)) # (batch_size, 16, 16, 64)
h2 = lrelu(self.f_bn2(conv2d(h1, self.dim_ff*2, name='f_h2'), train=train)) # (batch_size, 8, 8 128)
h3 = lrelu(self.f_bn3(conv2d(h2, self.dim_ff*4, name='f_h3'), train=train)) # (batch_size, 4, 4, 256)
h4 = lrelu(self.f_bn4(conv2d(h3, self.dim_ff*8, name='f_h4'), train=train)) # (batch_size, 2, 2, 512)
h4 = tf.reshape(h4, [self.batch_size,-1])
out = linear(h4, self.dim_fout, name='f_out')
return tf.nn.tanh(out)
def generator(self, z, reuse=False):
"""Generator: Deconvolutional neural network with relu activations.
Last deconv layer does not use batch normalization.
Args:
z: random input vectors, of shape (batch_size, dim_z)
Returns:
out: generated images, of shape (batch_size, image_size, image_size, dim_color)
"""
if reuse:
train = False
else:
train = True
def content_extractor(self, images, reuse=False):
# images: (batch, 32, 32, 3) or (batch, 32, 32, 1)
if images.get_shape()[3] == 1:
# For mnist dataset, replicate the gray scale image 3 times.
images = tf.image.grayscale_to_rgb(images)
with tf.variable_scope('content_extractor', reuse=reuse):
with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None,
stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
activation_fn=tf.nn.relu, is_training=(self.mode=='train' or self.mode=='pretrain')):
net = slim.conv2d(images, 64, [3, 3], scope='conv1') # (batch_size, 16, 16, 64)
net = slim.batch_norm(net, scope='bn1')
net = slim.conv2d(net, 128, [3, 3], scope='conv2') # (batch_size, 8, 8, 128)
net = slim.batch_norm(net, scope='bn2')
net = slim.conv2d(net, 256, [3, 3], scope='conv3') # (batch_size, 4, 4, 256)
net = slim.batch_norm(net, scope='bn3')
net = slim.conv2d(net, 128, [4, 4], padding='VALID', scope='conv4') # (batch_size, 1, 1, 128)
net = slim.batch_norm(net, activation_fn=tf.nn.tanh, scope='bn4')
if self.mode == 'pretrain':
net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out')
net = slim.flatten(net)
return net
def generator(self, inputs, reuse=False):
# inputs: (batch, 1, 1, 128)
with tf.variable_scope('generator', reuse=reuse):
# spatial size for convolution
s = self.output_size
s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) # 32, 16, 8, 4
# project and reshape z
h1= linear(z, s16*s16*self.dim_gf*8, name='g_h1') # (batch_size, 2*2*512)
h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8]) # (batch_size, 2, 2, 512)
h1 = relu(self.g_bn1(h1, train=train))
h2 = deconv2d(h1, [self.batch_size, s8, s8, self.dim_gf*4], name='g_h2') # (batch_size, 4, 4, 256)
h2 = relu(self.g_bn2(h2, train=train))
h3 = deconv2d(h2, [self.batch_size, s4, s4, self.dim_gf*2], name='g_h3') # (batch_size, 8, 8, 128)
h3 = relu(self.g_bn3(h3, train=train))
h4 = deconv2d(h3, [self.batch_size, s2, s2, self.dim_gf], name='g_h4') # (batch_size, 16, 16, 64)
h4 = relu(self.g_bn4(h4, train=train))
out = deconv2d(h4, [self.batch_size, s, s, self.dim_color], name='g_out') # (batch_size, 32, 32, dim_color)
return tf.nn.tanh(out)
with slim.arg_scope([slim.conv2d_transpose], padding='SAME', activation_fn=None,
stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
activation_fn=tf.nn.relu, is_training=(self.mode=='train')):
net = slim.conv2d_transpose(inputs, 512, [4, 4], padding='VALID', scope='conv_transpose1') # (batch_size, 4, 4, 512)
net = slim.batch_norm(net, scope='bn1')
net = slim.conv2d_transpose(net, 256, [3, 3], scope='conv_transpose2') # (batch_size, 8, 8, 256)
net = slim.batch_norm(net, scope='bn2')
net = slim.conv2d_transpose(net, 128, [3, 3], scope='conv_transpose3') # (batch_size, 16, 16, 128)
net = slim.batch_norm(net, scope='bn3')
net = slim.conv2d_transpose(net, 1, [3, 3], activation_fn=tf.nn.tanh, scope='conv_transpose4') # (batch_size, 32, 32, 1)
return net
def discriminator(self, images, reuse=False):
"""Discrimator: Convolutional neural network with leaky relu activations.
# images: (batch, 32, 32, 1)
with tf.variable_scope('discriminator', reuse=reuse):
with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None,
stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
activation_fn=tf.nn.relu, is_training=(self.mode=='train')):
net = slim.conv2d(images, 128, [3, 3], activation_fn=tf.nn.relu, scope='conv1') # (batch_size, 16, 16, 128)
net = slim.batch_norm(net, scope='bn1')
net = slim.conv2d(net, 256, [3, 3], scope='conv2') # (batch_size, 8, 8, 256)
net = slim.batch_norm(net, scope='bn2')
net = slim.conv2d(net, 512, [3, 3], scope='conv3') # (batch_size, 4, 4, 512)
net = slim.batch_norm(net, scope='bn3')
net = slim.conv2d(net, 1, [4, 4], padding='VALID', scope='conv4') # (batch_size, 1, 1, 1)
net = slim.flatten(net)
return net
First conv layer does not use batch normalization.
def build_model(self):
Args:
images: real or fake images of shape (batch_size, image_size, image_size, dim_color)
if self.mode == 'pretrain':
self.images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')
self.labels = tf.placeholder(tf.int64, [None], 'svhn_labels')
Returns:
out: scores for whether it is a real image or a fake image, of shape (batch_size,)
"""
with tf.variable_scope('discriminator', reuse=reuse):
# logits and accuracy
self.logits = self.content_extractor(self.images)
self.pred = tf.argmax(self.logits, 1)
self.correct_pred = tf.equal(self.pred, self.labels)
self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))
# convolution layer
h1 = lrelu(self.d_bn1(conv2d(images, self.dim_df, name='d_h1'))) # (batch_size, 16, 16, 64)
h2 = lrelu(self.d_bn2(conv2d(h1, self.dim_df*2, name='d_h2'))) # (batch_size, 8, 8, 128)
h3 = lrelu(self.d_bn3(conv2d(h2, self.dim_df*4, name='d_h3'))) # (batch_size, 4, 4, 256)
h4 = lrelu(self.d_bn4(conv2d(h3, self.dim_df*8, name='d_h4'))) # (batch_size, 2, 2, 512)
# loss and train op
self.loss = slim.losses.sparse_softmax_cross_entropy(self.logits, self.labels)
self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
self.train_op = slim.learning.create_train_op(self.loss, self.optimizer)
# fully connected layer
h4 = tf.reshape(h4, [self.batch_size, -1])
out = linear(h4, 1, name='d_out') # (batch_size,)
# summary op
loss_summary = tf.summary.scalar('classification_loss', self.loss)
accuracy_summary = tf.summary.scalar('accuracy', self.accuracy)
self.summary_op = tf.summary.merge([loss_summary, accuracy_summary])
return out
elif self.mode == 'eval':
self.images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')
# source domain (svhn to mnist)
self.fx = self.content_extractor(self.images)
self.sampled_images = self.generator(self.fx)
def build_model(self):
elif self.mode == 'train':
self.src_images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')
self.trg_images = tf.placeholder(tf.float32, [None, 32, 32, 1], 'mnist_images')
# construct generator and discriminator for training phase
self.f_x = self.function_f(self.images)
self.fake_images = self.generator(self.f_x) # (batch_size, 32, 32, 3)
self.logits_real = self.discriminator(self.images) # (batch_size,)
self.logits_fake = self.discriminator(self.fake_images, reuse=True) # (batch_size,)
self.fgf_x = self.function_f(self.fake_images, reuse=True) # (batch_size, dim_f)
# construct generator for test phase (use moving average and variance for batch norm)
self.f_x = self.function_f(self.images, reuse=True, train=False)
self.sampled_images = self.generator(self.f_x, reuse=True) # (batch_size, 32, 32, 3)
# source domain (svhn to mnist)
with tf.name_scope('model_for_source_domain'):
self.fx = self.content_extractor(self.src_images)
self.fake_images = self.generator(self.fx)
self.logits = self.discriminator(self.fake_images)
self.fgfx = self.content_extractor(self.fake_images, reuse=True)
# loss
self.d_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.zeros_like(self.logits))
self.g_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.ones_like(self.logits))
self.f_loss_src = tf.reduce_mean(tf.square(self.fx - self.fgfx)) * 15.0
# compute loss
self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_real, tf.ones_like(self.logits_real)))
self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.zeros_like(self.logits_fake)))
self.d_loss = self.d_loss_real + self.d_loss_fake
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake)))
self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images)) # L_TID
self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) * 0.15 # L_CONST
# optimizer
self.d_optimizer_src = tf.train.AdamOptimizer(self.learning_rate)
self.g_optimizer_src = tf.train.AdamOptimizer(self.learning_rate)
self.f_optimizer_src = tf.train.AdamOptimizer(self.learning_rate)
# divide variables for discriminator and generator
t_vars = tf.trainable_variables()
self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
self.g_vars = [var for var in t_vars if 'generator' in var.name]
self.f_vars = [var for var in t_vars if 'function_f' in var.name]
# optimizer for discriminator and generator
with tf.name_scope('optimizer'):
self.d_optimizer_real = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_real, var_list=self.d_vars)
self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars)
self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars)
self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars)
self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.g_vars+self.f_vars)
# summary ops for tensorboard visualization
scalar_summary('d_loss_real', self.d_loss_real)
scalar_summary('d_loss_fake', self.d_loss_fake)
scalar_summary('d_loss', self.d_loss)
scalar_summary('g_loss', self.g_loss)
scalar_summary('g_const_loss', self.g_const_loss)
scalar_summary('f_const_loss', self.f_const_loss)
try:
image_summary('original_images', self.images, max_outputs=4)
image_summary('sampled_images', self.sampled_images, max_outputs=4)
except:
image_summary('original_images', self.images, max_images=4)
image_summary('sampled_images', self.sampled_images, max_images=4)
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]
f_vars = [var for var in t_vars if 'content_extractor' in var.name]
# train op
with tf.name_scope('source_train_op'):
self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars)
self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars)
self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars)
# summary op
d_loss_src_summary = tf.summary.scalar('src_d_loss', self.d_loss_src)
g_loss_src_summary = tf.summary.scalar('src_g_loss', self.g_loss_src)
f_loss_src_summary = tf.summary.scalar('src_f_loss', self.f_loss_src)
origin_images_summary = tf.summary.image('src_origin_images', self.src_images)
sampled_images_summary = tf.summary.image('src_sampled_images', self.fake_images)
self.summary_op_src = tf.summary.merge([d_loss_src_summary, g_loss_src_summary,
f_loss_src_summary, origin_images_summary,
sampled_images_summary])
# target domain (mnist)
with tf.name_scope('model_for_target_domain'):
self.fx = self.content_extractor(self.trg_images, reuse=True)
self.reconst_images = self.generator(self.fx, reuse=True)
self.logits_fake = self.discriminator(self.reconst_images, reuse=True)
self.logits_real = self.discriminator(self.trg_images, reuse=True)
# loss
self.d_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.zeros_like(self.logits_fake))
self.d_loss_real_trg = slim.losses.sigmoid_cross_entropy(self.logits_real, tf.ones_like(self.logits_real))
self.d_loss_trg = self.d_loss_fake_trg + self.d_loss_real_trg
self.g_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.ones_like(self.logits_fake))
self.g_loss_const_trg = tf.reduce_mean(tf.square(self.trg_images - self.reconst_images)) * 15.0
self.g_loss_trg = self.g_loss_fake_trg + self.g_loss_const_trg
# optimizer
self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)
self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]
f_vars = [var for var in t_vars if 'content_extractor' in var.name]
# train op
with tf.name_scope('target_train_op'):
self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars)
self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars)
# summary op
d_loss_fake_trg_summary = tf.summary.scalar('trg_d_loss_fake', self.d_loss_fake_trg)
d_loss_real_trg_summary = tf.summary.scalar('trg_d_loss_real', self.d_loss_real_trg)
d_loss_trg_summary = tf.summary.scalar('trg_d_loss', self.d_loss_trg)
g_loss_fake_trg_summary = tf.summary.scalar('trg_g_loss_fake', self.g_loss_fake_trg)
g_loss_const_trg_summary = tf.summary.scalar('trg_g_loss_const', self.g_loss_const_trg)
g_loss_trg_summary = tf.summary.scalar('trg_g_loss', self.g_loss_trg)
origin_images_summary = tf.summary.image('trg_origin_images', self.trg_images)
sampled_images_summary = tf.summary.image('trg_reconstructed_images', self.reconst_images)
self.summary_op_trg = tf.summary.merge([d_loss_trg_summary, g_loss_trg_summary,
d_loss_fake_trg_summary, d_loss_real_trg_summary,
g_loss_fake_trg_summary, g_loss_const_trg_summary,
origin_images_summary, sampled_images_summary])
for var in tf.trainable_variables():
histogram_summary(var.op.name, var)
self.summary_op = merge_summary()
tf.summary.histogram(var.op.name, var)
\ No newline at end of file
self.saver = tf.train.Saver()
\ No newline at end of file
......