Showing
1 changed file
with
163 additions
and
176 deletions
1 | import tensorflow as tf | 1 | import tensorflow as tf |
2 | -from ops import * | 2 | +import tensorflow.contrib.slim as slim |
3 | -from config import * | ||
4 | 3 | ||
5 | 4 | ||
6 | class DTN(object): | 5 | class DTN(object): |
7 | - """Domain Transfer Network for unsupervised cross-domain image generation | 6 | + """Domain Transfer Network |
8 | - | ||
9 | - Construct discriminator and generator to prepare for training. | ||
10 | - """ | ||
11 | - | ||
12 | - def __init__(self, batch_size=100, learning_rate=0.0001, image_size=32, output_size=32, | ||
13 | - dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64): | ||
14 | """ | 7 | """ |
15 | - Args: | 8 | + def __init__(self, mode='train', learning_rate=0.0003): |
16 | - learning_rate: (optional) learning rate for discriminator and generator | 9 | + self.mode = mode |
17 | - image_size: (optional) spatial size of input image for discriminator | ||
18 | - output_size: (optional) spatial size of image generated by generator | ||
19 | - dim_color: (optional) dimension of image color; default is 3 for rgb | ||
20 | - dim_fout: (optional) dimension of z (random input vector for generator) | ||
21 | - dim_df: (optional) dimension of discriminator's filter in first convolution layer | ||
22 | - dim_gf: (optional) dimension of generator's filter in last convolution layer | ||
23 | - dim_ff: (optional) dimension of function f's filter in first convolution layer | ||
24 | - """ | ||
25 | - # hyper parameters | ||
26 | - self.batch_size = batch_size | ||
27 | self.learning_rate = learning_rate | 10 | self.learning_rate = learning_rate |
28 | - self.image_size = image_size | ||
29 | - self.output_size = output_size | ||
30 | - self.dim_color = dim_color | ||
31 | - self.dim_fout = dim_fout | ||
32 | - self.dim_df = dim_df | ||
33 | - self.dim_gf = dim_gf | ||
34 | - self.dim_ff = dim_ff | ||
35 | - | ||
36 | - # placeholder | ||
37 | - self.images = tf.placeholder(tf.float32, shape=[batch_size, image_size, image_size, dim_color], name='images') | ||
38 | - #self.z = tf.placeholder(tf.float32, shape=[None, dim_z], name='input_for_generator') | ||
39 | - | ||
40 | - # batch normalization layer for discriminator, generator and funtion f | ||
41 | - self.d_bn1 = batch_norm(name='d_bn1') | ||
42 | - self.d_bn2 = batch_norm(name='d_bn2') | ||
43 | - self.d_bn3 = batch_norm(name='d_bn3') | ||
44 | - self.d_bn4 = batch_norm(name='d_bn4') | ||
45 | - | ||
46 | - self.g_bn1 = batch_norm(name='g_bn1') | ||
47 | - self.g_bn2 = batch_norm(name='g_bn2') | ||
48 | - self.g_bn3 = batch_norm(name='g_bn3') | ||
49 | - self.g_bn4 = batch_norm(name='g_bn4') | ||
50 | - | ||
51 | - self.f_bn1 = batch_norm(name='f_bn1') | ||
52 | - self.f_bn2 = batch_norm(name='f_bn2') | ||
53 | - self.f_bn3 = batch_norm(name='f_bn3') | ||
54 | - self.f_bn4 = batch_norm(name='f_bn4') | ||
55 | - | ||
56 | - | ||
57 | - | ||
58 | - def function_f(self, images, reuse=False, train=True): | ||
59 | - """f consistancy | ||
60 | - | ||
61 | - Args: | ||
62 | - images: images for domain S and T, of shape (batch_size, image_size, image_size, dim_color) | ||
63 | - | ||
64 | - Returns: | ||
65 | - out: output vectors, of shape (batch_size, dim_f_out) | ||
66 | - """ | ||
67 | - with tf.variable_scope('function_f', reuse=reuse): | ||
68 | - h1 = lrelu(self.f_bn1(conv2d(images, self.dim_ff, name='f_h1'), train=train)) # (batch_size, 16, 16, 64) | ||
69 | - h2 = lrelu(self.f_bn2(conv2d(h1, self.dim_ff*2, name='f_h2'), train=train)) # (batch_size, 8, 8 128) | ||
70 | - h3 = lrelu(self.f_bn3(conv2d(h2, self.dim_ff*4, name='f_h3'), train=train)) # (batch_size, 4, 4, 256) | ||
71 | - h4 = lrelu(self.f_bn4(conv2d(h3, self.dim_ff*8, name='f_h4'), train=train)) # (batch_size, 2, 2, 512) | ||
72 | - | ||
73 | - h4 = tf.reshape(h4, [self.batch_size,-1]) | ||
74 | - out = linear(h4, self.dim_fout, name='f_out') | ||
75 | - | ||
76 | - return tf.nn.tanh(out) | ||
77 | - | ||
78 | - | ||
79 | - def generator(self, z, reuse=False): | ||
80 | - """Generator: Deconvolutional neural network with relu activations. | ||
81 | - | ||
82 | - Last deconv layer does not use batch normalization. | ||
83 | - | ||
84 | - Args: | ||
85 | - z: random input vectors, of shape (batch_size, dim_z) | ||
86 | - | ||
87 | - Returns: | ||
88 | - out: generated images, of shape (batch_size, image_size, image_size, dim_color) | ||
89 | - """ | ||
90 | - if reuse: | ||
91 | - train = False | ||
92 | - else: | ||
93 | - train = True | ||
94 | 11 | ||
12 | + def content_extractor(self, images, reuse=False): | ||
13 | + # images: (batch, 32, 32, 3) or (batch, 32, 32, 1) | ||
14 | + | ||
15 | + if images.get_shape()[3] == 1: | ||
16 | + # For mnist dataset, replicate the gray scale image 3 times. | ||
17 | + images = tf.image.grayscale_to_rgb(images) | ||
18 | + | ||
19 | + with tf.variable_scope('content_extractor', reuse=reuse): | ||
20 | + with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None, | ||
21 | + stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()): | ||
22 | + with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True, | ||
23 | + activation_fn=tf.nn.relu, is_training=(self.mode=='train' or self.mode=='pretrain')): | ||
24 | + | ||
25 | + net = slim.conv2d(images, 64, [3, 3], scope='conv1') # (batch_size, 16, 16, 64) | ||
26 | + net = slim.batch_norm(net, scope='bn1') | ||
27 | + net = slim.conv2d(net, 128, [3, 3], scope='conv2') # (batch_size, 8, 8, 128) | ||
28 | + net = slim.batch_norm(net, scope='bn2') | ||
29 | + net = slim.conv2d(net, 256, [3, 3], scope='conv3') # (batch_size, 4, 4, 256) | ||
30 | + net = slim.batch_norm(net, scope='bn3') | ||
31 | + net = slim.conv2d(net, 128, [4, 4], padding='VALID', scope='conv4') # (batch_size, 1, 1, 128) | ||
32 | + net = slim.batch_norm(net, activation_fn=tf.nn.tanh, scope='bn4') | ||
33 | + if self.mode == 'pretrain': | ||
34 | + net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out') | ||
35 | + net = slim.flatten(net) | ||
36 | + | ||
37 | + return net | ||
38 | + | ||
39 | + def generator(self, inputs, reuse=False): | ||
40 | + # inputs: (batch, 1, 1, 128) | ||
95 | with tf.variable_scope('generator', reuse=reuse): | 41 | with tf.variable_scope('generator', reuse=reuse): |
96 | - | 42 | + with slim.arg_scope([slim.conv2d_transpose], padding='SAME', activation_fn=None, |
97 | - # spatial size for convolution | 43 | + stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()): |
98 | - s = self.output_size | 44 | + with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True, |
99 | - s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) # 32, 16, 8, 4 | 45 | + activation_fn=tf.nn.relu, is_training=(self.mode=='train')): |
100 | - | 46 | + |
101 | - # project and reshape z | 47 | + net = slim.conv2d_transpose(inputs, 512, [4, 4], padding='VALID', scope='conv_transpose1') # (batch_size, 4, 4, 512) |
102 | - h1= linear(z, s16*s16*self.dim_gf*8, name='g_h1') # (batch_size, 2*2*512) | 48 | + net = slim.batch_norm(net, scope='bn1') |
103 | - h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8]) # (batch_size, 2, 2, 512) | 49 | + net = slim.conv2d_transpose(net, 256, [3, 3], scope='conv_transpose2') # (batch_size, 8, 8, 256) |
104 | - h1 = relu(self.g_bn1(h1, train=train)) | 50 | + net = slim.batch_norm(net, scope='bn2') |
105 | - | 51 | + net = slim.conv2d_transpose(net, 128, [3, 3], scope='conv_transpose3') # (batch_size, 16, 16, 128) |
106 | - h2 = deconv2d(h1, [self.batch_size, s8, s8, self.dim_gf*4], name='g_h2') # (batch_size, 4, 4, 256) | 52 | + net = slim.batch_norm(net, scope='bn3') |
107 | - h2 = relu(self.g_bn2(h2, train=train)) | 53 | + net = slim.conv2d_transpose(net, 1, [3, 3], activation_fn=tf.nn.tanh, scope='conv_transpose4') # (batch_size, 32, 32, 1) |
108 | - | 54 | + return net |
109 | - h3 = deconv2d(h2, [self.batch_size, s4, s4, self.dim_gf*2], name='g_h3') # (batch_size, 8, 8, 128) | ||
110 | - h3 = relu(self.g_bn3(h3, train=train)) | ||
111 | - | ||
112 | - h4 = deconv2d(h3, [self.batch_size, s2, s2, self.dim_gf], name='g_h4') # (batch_size, 16, 16, 64) | ||
113 | - h4 = relu(self.g_bn4(h4, train=train)) | ||
114 | - | ||
115 | - out = deconv2d(h4, [self.batch_size, s, s, self.dim_color], name='g_out') # (batch_size, 32, 32, dim_color) | ||
116 | - | ||
117 | - return tf.nn.tanh(out) | ||
118 | - | ||
119 | 55 | ||
120 | def discriminator(self, images, reuse=False): | 56 | def discriminator(self, images, reuse=False): |
121 | - """Discrimator: Convolutional neural network with leaky relu activations. | 57 | + # images: (batch, 32, 32, 1) |
58 | + with tf.variable_scope('discriminator', reuse=reuse): | ||
59 | + with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None, | ||
60 | + stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()): | ||
61 | + with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True, | ||
62 | + activation_fn=tf.nn.relu, is_training=(self.mode=='train')): | ||
63 | + | ||
64 | + net = slim.conv2d(images, 128, [3, 3], activation_fn=tf.nn.relu, scope='conv1') # (batch_size, 16, 16, 128) | ||
65 | + net = slim.batch_norm(net, scope='bn1') | ||
66 | + net = slim.conv2d(net, 256, [3, 3], scope='conv2') # (batch_size, 8, 8, 256) | ||
67 | + net = slim.batch_norm(net, scope='bn2') | ||
68 | + net = slim.conv2d(net, 512, [3, 3], scope='conv3') # (batch_size, 4, 4, 512) | ||
69 | + net = slim.batch_norm(net, scope='bn3') | ||
70 | + net = slim.conv2d(net, 1, [4, 4], padding='VALID', scope='conv4') # (batch_size, 1, 1, 1) | ||
71 | + net = slim.flatten(net) | ||
72 | + return net | ||
122 | 73 | ||
123 | - First conv layer does not use batch normalization. | 74 | + def build_model(self): |
124 | 75 | ||
125 | - Args: | 76 | + if self.mode == 'pretrain': |
126 | - images: real or fake images of shape (batch_size, image_size, image_size, dim_color) | 77 | + self.images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images') |
78 | + self.labels = tf.placeholder(tf.int64, [None], 'svhn_labels') | ||
127 | 79 | ||
128 | - Returns: | 80 | + # logits and accuracy |
129 | - out: scores for whether it is a real image or a fake image, of shape (batch_size,) | 81 | + self.logits = self.content_extractor(self.images) |
130 | - """ | 82 | + self.pred = tf.argmax(self.logits, 1) |
131 | - with tf.variable_scope('discriminator', reuse=reuse): | 83 | + self.correct_pred = tf.equal(self.pred, self.labels) |
84 | + self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32)) | ||
132 | 85 | ||
133 | - # convolution layer | 86 | + # loss and train op |
134 | - h1 = lrelu(self.d_bn1(conv2d(images, self.dim_df, name='d_h1'))) # (batch_size, 16, 16, 64) | 87 | + self.loss = slim.losses.sparse_softmax_cross_entropy(self.logits, self.labels) |
135 | - h2 = lrelu(self.d_bn2(conv2d(h1, self.dim_df*2, name='d_h2'))) # (batch_size, 8, 8, 128) | 88 | + self.optimizer = tf.train.AdamOptimizer(self.learning_rate) |
136 | - h3 = lrelu(self.d_bn3(conv2d(h2, self.dim_df*4, name='d_h3'))) # (batch_size, 4, 4, 256) | 89 | + self.train_op = slim.learning.create_train_op(self.loss, self.optimizer) |
137 | - h4 = lrelu(self.d_bn4(conv2d(h3, self.dim_df*8, name='d_h4'))) # (batch_size, 2, 2, 512) | ||
138 | 90 | ||
139 | - # fully connected layer | 91 | + # summary op |
140 | - h4 = tf.reshape(h4, [self.batch_size, -1]) | 92 | + loss_summary = tf.summary.scalar('classification_loss', self.loss) |
141 | - out = linear(h4, 1, name='d_out') # (batch_size,) | 93 | + accuracy_summary = tf.summary.scalar('accuracy', self.accuracy) |
94 | + self.summary_op = tf.summary.merge([loss_summary, accuracy_summary]) | ||
142 | 95 | ||
143 | - return out | 96 | + elif self.mode == 'eval': |
97 | + self.images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images') | ||
144 | 98 | ||
99 | + # source domain (svhn to mnist) | ||
100 | + self.fx = self.content_extractor(self.images) | ||
101 | + self.sampled_images = self.generator(self.fx) | ||
145 | 102 | ||
146 | - def build_model(self): | 103 | + elif self.mode == 'train': |
104 | + self.src_images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images') | ||
105 | + self.trg_images = tf.placeholder(tf.float32, [None, 32, 32, 1], 'mnist_images') | ||
147 | 106 | ||
148 | - # construct generator and discriminator for training phase | ||
149 | - self.f_x = self.function_f(self.images) | ||
150 | - self.fake_images = self.generator(self.f_x) # (batch_size, 32, 32, 3) | ||
151 | - self.logits_real = self.discriminator(self.images) # (batch_size,) | ||
152 | - self.logits_fake = self.discriminator(self.fake_images, reuse=True) # (batch_size,) | ||
153 | - self.fgf_x = self.function_f(self.fake_images, reuse=True) # (batch_size, dim_f) | ||
154 | 107 | ||
155 | - # construct generator for test phase (use moving average and variance for batch norm) | 108 | + # source domain (svhn to mnist) |
156 | - self.f_x = self.function_f(self.images, reuse=True, train=False) | 109 | + with tf.name_scope('model_for_source_domain'): |
157 | - self.sampled_images = self.generator(self.f_x, reuse=True) # (batch_size, 32, 32, 3) | 110 | + self.fx = self.content_extractor(self.src_images) |
111 | + self.fake_images = self.generator(self.fx) | ||
112 | + self.logits = self.discriminator(self.fake_images) | ||
113 | + self.fgfx = self.content_extractor(self.fake_images, reuse=True) | ||
158 | 114 | ||
115 | + # loss | ||
116 | + self.d_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.zeros_like(self.logits)) | ||
117 | + self.g_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.ones_like(self.logits)) | ||
118 | + self.f_loss_src = tf.reduce_mean(tf.square(self.fx - self.fgfx)) * 15.0 | ||
159 | 119 | ||
160 | - # compute loss | 120 | + # optimizer |
161 | - self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_real, tf.ones_like(self.logits_real))) | 121 | + self.d_optimizer_src = tf.train.AdamOptimizer(self.learning_rate) |
162 | - self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.zeros_like(self.logits_fake))) | 122 | + self.g_optimizer_src = tf.train.AdamOptimizer(self.learning_rate) |
163 | - self.d_loss = self.d_loss_real + self.d_loss_fake | 123 | + self.f_optimizer_src = tf.train.AdamOptimizer(self.learning_rate) |
164 | - self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake))) | ||
165 | - self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images)) # L_TID | ||
166 | - self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) * 0.15 # L_CONST | ||
167 | 124 | ||
168 | - # divide variables for discriminator and generator | ||
169 | t_vars = tf.trainable_variables() | 125 | t_vars = tf.trainable_variables() |
170 | - self.d_vars = [var for var in t_vars if 'discriminator' in var.name] | 126 | + d_vars = [var for var in t_vars if 'discriminator' in var.name] |
171 | - self.g_vars = [var for var in t_vars if 'generator' in var.name] | 127 | + g_vars = [var for var in t_vars if 'generator' in var.name] |
172 | - self.f_vars = [var for var in t_vars if 'function_f' in var.name] | 128 | + f_vars = [var for var in t_vars if 'content_extractor' in var.name] |
173 | - | 129 | + |
174 | - # optimizer for discriminator and generator | 130 | + # train op |
175 | - with tf.name_scope('optimizer'): | 131 | + with tf.name_scope('source_train_op'): |
176 | - self.d_optimizer_real = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_real, var_list=self.d_vars) | 132 | + self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars) |
177 | - self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars) | 133 | + self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars) |
178 | - self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars) | 134 | + self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars) |
179 | - self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars) | 135 | + |
180 | - self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.g_vars+self.f_vars) | 136 | + # summary op |
181 | - | 137 | + d_loss_src_summary = tf.summary.scalar('src_d_loss', self.d_loss_src) |
182 | - | 138 | + g_loss_src_summary = tf.summary.scalar('src_g_loss', self.g_loss_src) |
183 | - # summary ops for tensorboard visualization | 139 | + f_loss_src_summary = tf.summary.scalar('src_f_loss', self.f_loss_src) |
184 | - scalar_summary('d_loss_real', self.d_loss_real) | 140 | + origin_images_summary = tf.summary.image('src_origin_images', self.src_images) |
185 | - scalar_summary('d_loss_fake', self.d_loss_fake) | 141 | + sampled_images_summary = tf.summary.image('src_sampled_images', self.fake_images) |
186 | - scalar_summary('d_loss', self.d_loss) | 142 | + self.summary_op_src = tf.summary.merge([d_loss_src_summary, g_loss_src_summary, |
187 | - scalar_summary('g_loss', self.g_loss) | 143 | + f_loss_src_summary, origin_images_summary, |
188 | - scalar_summary('g_const_loss', self.g_const_loss) | 144 | + sampled_images_summary]) |
189 | - scalar_summary('f_const_loss', self.f_const_loss) | 145 | + |
190 | - | 146 | + # target domain (mnist) |
191 | - try: | 147 | + with tf.name_scope('model_for_target_domain'): |
192 | - image_summary('original_images', self.images, max_outputs=4) | 148 | + self.fx = self.content_extractor(self.trg_images, reuse=True) |
193 | - image_summary('sampled_images', self.sampled_images, max_outputs=4) | 149 | + self.reconst_images = self.generator(self.fx, reuse=True) |
194 | - except: | 150 | + self.logits_fake = self.discriminator(self.reconst_images, reuse=True) |
195 | - image_summary('original_images', self.images, max_images=4) | 151 | + self.logits_real = self.discriminator(self.trg_images, reuse=True) |
196 | - image_summary('sampled_images', self.sampled_images, max_images=4) | 152 | + |
153 | + # loss | ||
154 | + self.d_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.zeros_like(self.logits_fake)) | ||
155 | + self.d_loss_real_trg = slim.losses.sigmoid_cross_entropy(self.logits_real, tf.ones_like(self.logits_real)) | ||
156 | + self.d_loss_trg = self.d_loss_fake_trg + self.d_loss_real_trg | ||
157 | + self.g_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.ones_like(self.logits_fake)) | ||
158 | + self.g_loss_const_trg = tf.reduce_mean(tf.square(self.trg_images - self.reconst_images)) * 15.0 | ||
159 | + self.g_loss_trg = self.g_loss_fake_trg + self.g_loss_const_trg | ||
160 | + | ||
161 | + # optimizer | ||
162 | + self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) | ||
163 | + self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) | ||
197 | 164 | ||
165 | + t_vars = tf.trainable_variables() | ||
166 | + d_vars = [var for var in t_vars if 'discriminator' in var.name] | ||
167 | + g_vars = [var for var in t_vars if 'generator' in var.name] | ||
168 | + f_vars = [var for var in t_vars if 'content_extractor' in var.name] | ||
169 | + | ||
170 | + # train op | ||
171 | + with tf.name_scope('target_train_op'): | ||
172 | + self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars) | ||
173 | + self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars) | ||
174 | + | ||
175 | + # summary op | ||
176 | + d_loss_fake_trg_summary = tf.summary.scalar('trg_d_loss_fake', self.d_loss_fake_trg) | ||
177 | + d_loss_real_trg_summary = tf.summary.scalar('trg_d_loss_real', self.d_loss_real_trg) | ||
178 | + d_loss_trg_summary = tf.summary.scalar('trg_d_loss', self.d_loss_trg) | ||
179 | + g_loss_fake_trg_summary = tf.summary.scalar('trg_g_loss_fake', self.g_loss_fake_trg) | ||
180 | + g_loss_const_trg_summary = tf.summary.scalar('trg_g_loss_const', self.g_loss_const_trg) | ||
181 | + g_loss_trg_summary = tf.summary.scalar('trg_g_loss', self.g_loss_trg) | ||
182 | + origin_images_summary = tf.summary.image('trg_origin_images', self.trg_images) | ||
183 | + sampled_images_summary = tf.summary.image('trg_reconstructed_images', self.reconst_images) | ||
184 | + self.summary_op_trg = tf.summary.merge([d_loss_trg_summary, g_loss_trg_summary, | ||
185 | + d_loss_fake_trg_summary, d_loss_real_trg_summary, | ||
186 | + g_loss_fake_trg_summary, g_loss_const_trg_summary, | ||
187 | + origin_images_summary, sampled_images_summary]) | ||
198 | for var in tf.trainable_variables(): | 188 | for var in tf.trainable_variables(): |
199 | - histogram_summary(var.op.name, var) | 189 | + tf.summary.histogram(var.op.name, var) |
200 | - | ||
201 | - self.summary_op = merge_summary() | ||
202 | 190 | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
203 | - self.saver = tf.train.Saver() | ||
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
-
Please register or login to post a comment