Showing
1 changed file
with
169 additions
and
182 deletions
1 | import tensorflow as tf | 1 | import tensorflow as tf |
2 | -from ops import * | 2 | +import tensorflow.contrib.slim as slim |
3 | -from config import * | ||
4 | 3 | ||
5 | 4 | ||
6 | class DTN(object): | 5 | class DTN(object): |
7 | - """Domain Transfer Network for unsupervised cross-domain image generation | 6 | + """Domain Transfer Network |
8 | - | ||
9 | - Construct discriminator and generator to prepare for training. | ||
10 | """ | 7 | """ |
11 | - | 8 | + def __init__(self, mode='train', learning_rate=0.0003): |
12 | - def __init__(self, batch_size=100, learning_rate=0.0001, image_size=32, output_size=32, | 9 | + self.mode = mode |
13 | - dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64): | ||
14 | - """ | ||
15 | - Args: | ||
16 | - learning_rate: (optional) learning rate for discriminator and generator | ||
17 | - image_size: (optional) spatial size of input image for discriminator | ||
18 | - output_size: (optional) spatial size of image generated by generator | ||
19 | - dim_color: (optional) dimension of image color; default is 3 for rgb | ||
20 | - dim_fout: (optional) dimension of z (random input vector for generator) | ||
21 | - dim_df: (optional) dimension of discriminator's filter in first convolution layer | ||
22 | - dim_gf: (optional) dimension of generator's filter in last convolution layer | ||
23 | - dim_ff: (optional) dimension of function f's filter in first convolution layer | ||
24 | - """ | ||
25 | - # hyper parameters | ||
26 | - self.batch_size = batch_size | ||
27 | self.learning_rate = learning_rate | 10 | self.learning_rate = learning_rate |
28 | - self.image_size = image_size | ||
29 | - self.output_size = output_size | ||
30 | - self.dim_color = dim_color | ||
31 | - self.dim_fout = dim_fout | ||
32 | - self.dim_df = dim_df | ||
33 | - self.dim_gf = dim_gf | ||
34 | - self.dim_ff = dim_ff | ||
35 | - | ||
36 | - # placeholder | ||
37 | - self.images = tf.placeholder(tf.float32, shape=[batch_size, image_size, image_size, dim_color], name='images') | ||
38 | - #self.z = tf.placeholder(tf.float32, shape=[None, dim_z], name='input_for_generator') | ||
39 | - | ||
40 | - # batch normalization layer for discriminator, generator and funtion f | ||
41 | - self.d_bn1 = batch_norm(name='d_bn1') | ||
42 | - self.d_bn2 = batch_norm(name='d_bn2') | ||
43 | - self.d_bn3 = batch_norm(name='d_bn3') | ||
44 | - self.d_bn4 = batch_norm(name='d_bn4') | ||
45 | - | ||
46 | - self.g_bn1 = batch_norm(name='g_bn1') | ||
47 | - self.g_bn2 = batch_norm(name='g_bn2') | ||
48 | - self.g_bn3 = batch_norm(name='g_bn3') | ||
49 | - self.g_bn4 = batch_norm(name='g_bn4') | ||
50 | - | ||
51 | - self.f_bn1 = batch_norm(name='f_bn1') | ||
52 | - self.f_bn2 = batch_norm(name='f_bn2') | ||
53 | - self.f_bn3 = batch_norm(name='f_bn3') | ||
54 | - self.f_bn4 = batch_norm(name='f_bn4') | ||
55 | - | ||
56 | 11 | ||
12 | + def content_extractor(self, images, reuse=False): | ||
13 | + # images: (batch, 32, 32, 3) or (batch, 32, 32, 1) | ||
14 | + | ||
15 | + if images.get_shape()[3] == 1: | ||
16 | + # For mnist dataset, replicate the gray scale image 3 times. | ||
17 | + images = tf.image.grayscale_to_rgb(images) | ||
18 | + | ||
19 | + with tf.variable_scope('content_extractor', reuse=reuse): | ||
20 | + with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None, | ||
21 | + stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()): | ||
22 | + with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True, | ||
23 | + activation_fn=tf.nn.relu, is_training=(self.mode=='train' or self.mode=='pretrain')): | ||
24 | + | ||
25 | + net = slim.conv2d(images, 64, [3, 3], scope='conv1') # (batch_size, 16, 16, 64) | ||
26 | + net = slim.batch_norm(net, scope='bn1') | ||
27 | + net = slim.conv2d(net, 128, [3, 3], scope='conv2') # (batch_size, 8, 8, 128) | ||
28 | + net = slim.batch_norm(net, scope='bn2') | ||
29 | + net = slim.conv2d(net, 256, [3, 3], scope='conv3') # (batch_size, 4, 4, 256) | ||
30 | + net = slim.batch_norm(net, scope='bn3') | ||
31 | + net = slim.conv2d(net, 128, [4, 4], padding='VALID', scope='conv4') # (batch_size, 1, 1, 128) | ||
32 | + net = slim.batch_norm(net, activation_fn=tf.nn.tanh, scope='bn4') | ||
33 | + if self.mode == 'pretrain': | ||
34 | + net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out') | ||
35 | + net = slim.flatten(net) | ||
36 | + | ||
37 | + return net | ||
38 | + | ||
39 | + def generator(self, inputs, reuse=False): | ||
40 | + # inputs: (batch, 1, 1, 128) | ||
41 | + with tf.variable_scope('generator', reuse=reuse): | ||
42 | + with slim.arg_scope([slim.conv2d_transpose], padding='SAME', activation_fn=None, | ||
43 | + stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()): | ||
44 | + with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True, | ||
45 | + activation_fn=tf.nn.relu, is_training=(self.mode=='train')): | ||
46 | + | ||
47 | + net = slim.conv2d_transpose(inputs, 512, [4, 4], padding='VALID', scope='conv_transpose1') # (batch_size, 4, 4, 512) | ||
48 | + net = slim.batch_norm(net, scope='bn1') | ||
49 | + net = slim.conv2d_transpose(net, 256, [3, 3], scope='conv_transpose2') # (batch_size, 8, 8, 256) | ||
50 | + net = slim.batch_norm(net, scope='bn2') | ||
51 | + net = slim.conv2d_transpose(net, 128, [3, 3], scope='conv_transpose3') # (batch_size, 16, 16, 128) | ||
52 | + net = slim.batch_norm(net, scope='bn3') | ||
53 | + net = slim.conv2d_transpose(net, 1, [3, 3], activation_fn=tf.nn.tanh, scope='conv_transpose4') # (batch_size, 32, 32, 1) | ||
54 | + return net | ||
55 | + | ||
56 | + def discriminator(self, images, reuse=False): | ||
57 | + # images: (batch, 32, 32, 1) | ||
58 | + with tf.variable_scope('discriminator', reuse=reuse): | ||
59 | + with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None, | ||
60 | + stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()): | ||
61 | + with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True, | ||
62 | + activation_fn=tf.nn.relu, is_training=(self.mode=='train')): | ||
63 | + | ||
64 | + net = slim.conv2d(images, 128, [3, 3], activation_fn=tf.nn.relu, scope='conv1') # (batch_size, 16, 16, 128) | ||
65 | + net = slim.batch_norm(net, scope='bn1') | ||
66 | + net = slim.conv2d(net, 256, [3, 3], scope='conv2') # (batch_size, 8, 8, 256) | ||
67 | + net = slim.batch_norm(net, scope='bn2') | ||
68 | + net = slim.conv2d(net, 512, [3, 3], scope='conv3') # (batch_size, 4, 4, 512) | ||
69 | + net = slim.batch_norm(net, scope='bn3') | ||
70 | + net = slim.conv2d(net, 1, [4, 4], padding='VALID', scope='conv4') # (batch_size, 1, 1, 1) | ||
71 | + net = slim.flatten(net) | ||
72 | + return net | ||
73 | + | ||
74 | + def build_model(self): | ||
57 | 75 | ||
58 | - def function_f(self, images, reuse=False, train=True): | 76 | + if self.mode == 'pretrain': |
59 | - """f consistancy | 77 | + self.images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images') |
60 | - | 78 | + self.labels = tf.placeholder(tf.int64, [None], 'svhn_labels') |
61 | - Args: | ||
62 | - images: images for domain S and T, of shape (batch_size, image_size, image_size, dim_color) | ||
63 | 79 | ||
64 | - Returns: | 80 | + # logits and accuracy |
65 | - out: output vectors, of shape (batch_size, dim_f_out) | 81 | + self.logits = self.content_extractor(self.images) |
66 | - """ | 82 | + self.pred = tf.argmax(self.logits, 1) |
67 | - with tf.variable_scope('function_f', reuse=reuse): | 83 | + self.correct_pred = tf.equal(self.pred, self.labels) |
68 | - h1 = lrelu(self.f_bn1(conv2d(images, self.dim_ff, name='f_h1'), train=train)) # (batch_size, 16, 16, 64) | 84 | + self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32)) |
69 | - h2 = lrelu(self.f_bn2(conv2d(h1, self.dim_ff*2, name='f_h2'), train=train)) # (batch_size, 8, 8 128) | ||
70 | - h3 = lrelu(self.f_bn3(conv2d(h2, self.dim_ff*4, name='f_h3'), train=train)) # (batch_size, 4, 4, 256) | ||
71 | - h4 = lrelu(self.f_bn4(conv2d(h3, self.dim_ff*8, name='f_h4'), train=train)) # (batch_size, 2, 2, 512) | ||
72 | 85 | ||
73 | - h4 = tf.reshape(h4, [self.batch_size,-1]) | 86 | + # loss and train op |
74 | - out = linear(h4, self.dim_fout, name='f_out') | 87 | + self.loss = slim.losses.sparse_softmax_cross_entropy(self.logits, self.labels) |
75 | - | 88 | + self.optimizer = tf.train.AdamOptimizer(self.learning_rate) |
76 | - return tf.nn.tanh(out) | 89 | + self.train_op = slim.learning.create_train_op(self.loss, self.optimizer) |
77 | - | ||
78 | - | ||
79 | - def generator(self, z, reuse=False): | ||
80 | - """Generator: Deconvolutional neural network with relu activations. | ||
81 | - | ||
82 | - Last deconv layer does not use batch normalization. | ||
83 | - | ||
84 | - Args: | ||
85 | - z: random input vectors, of shape (batch_size, dim_z) | ||
86 | 90 | ||
87 | - Returns: | 91 | + # summary op |
88 | - out: generated images, of shape (batch_size, image_size, image_size, dim_color) | 92 | + loss_summary = tf.summary.scalar('classification_loss', self.loss) |
89 | - """ | 93 | + accuracy_summary = tf.summary.scalar('accuracy', self.accuracy) |
90 | - if reuse: | 94 | + self.summary_op = tf.summary.merge([loss_summary, accuracy_summary]) |
91 | - train = False | 95 | + |
92 | - else: | 96 | + elif self.mode == 'eval': |
93 | - train = True | 97 | + self.images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images') |
94 | - | 98 | + |
95 | - with tf.variable_scope('generator', reuse=reuse): | 99 | + # source domain (svhn to mnist) |
100 | + self.fx = self.content_extractor(self.images) | ||
101 | + self.sampled_images = self.generator(self.fx) | ||
102 | + | ||
103 | + elif self.mode == 'train': | ||
104 | + self.src_images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images') | ||
105 | + self.trg_images = tf.placeholder(tf.float32, [None, 32, 32, 1], 'mnist_images') | ||
96 | 106 | ||
97 | - # spatial size for convolution | ||
98 | - s = self.output_size | ||
99 | - s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) # 32, 16, 8, 4 | ||
100 | 107 | ||
101 | - # project and reshape z | 108 | + # source domain (svhn to mnist) |
102 | - h1= linear(z, s16*s16*self.dim_gf*8, name='g_h1') # (batch_size, 2*2*512) | 109 | + with tf.name_scope('model_for_source_domain'): |
103 | - h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8]) # (batch_size, 2, 2, 512) | 110 | + self.fx = self.content_extractor(self.src_images) |
104 | - h1 = relu(self.g_bn1(h1, train=train)) | 111 | + self.fake_images = self.generator(self.fx) |
112 | + self.logits = self.discriminator(self.fake_images) | ||
113 | + self.fgfx = self.content_extractor(self.fake_images, reuse=True) | ||
105 | 114 | ||
106 | - h2 = deconv2d(h1, [self.batch_size, s8, s8, self.dim_gf*4], name='g_h2') # (batch_size, 4, 4, 256) | 115 | + # loss |
107 | - h2 = relu(self.g_bn2(h2, train=train)) | 116 | + self.d_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.zeros_like(self.logits)) |
117 | + self.g_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.ones_like(self.logits)) | ||
118 | + self.f_loss_src = tf.reduce_mean(tf.square(self.fx - self.fgfx)) * 15.0 | ||
108 | 119 | ||
109 | - h3 = deconv2d(h2, [self.batch_size, s4, s4, self.dim_gf*2], name='g_h3') # (batch_size, 8, 8, 128) | 120 | + # optimizer |
110 | - h3 = relu(self.g_bn3(h3, train=train)) | 121 | + self.d_optimizer_src = tf.train.AdamOptimizer(self.learning_rate) |
122 | + self.g_optimizer_src = tf.train.AdamOptimizer(self.learning_rate) | ||
123 | + self.f_optimizer_src = tf.train.AdamOptimizer(self.learning_rate) | ||
111 | 124 | ||
112 | - h4 = deconv2d(h3, [self.batch_size, s2, s2, self.dim_gf], name='g_h4') # (batch_size, 16, 16, 64) | 125 | + t_vars = tf.trainable_variables() |
113 | - h4 = relu(self.g_bn4(h4, train=train)) | 126 | + d_vars = [var for var in t_vars if 'discriminator' in var.name] |
127 | + g_vars = [var for var in t_vars if 'generator' in var.name] | ||
128 | + f_vars = [var for var in t_vars if 'content_extractor' in var.name] | ||
114 | 129 | ||
115 | - out = deconv2d(h4, [self.batch_size, s, s, self.dim_color], name='g_out') # (batch_size, 32, 32, dim_color) | 130 | + # train op |
131 | + with tf.name_scope('source_train_op'): | ||
132 | + self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars) | ||
133 | + self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars) | ||
134 | + self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars) | ||
116 | 135 | ||
117 | - return tf.nn.tanh(out) | 136 | + # summary op |
118 | - | 137 | + d_loss_src_summary = tf.summary.scalar('src_d_loss', self.d_loss_src) |
119 | - | 138 | + g_loss_src_summary = tf.summary.scalar('src_g_loss', self.g_loss_src) |
120 | - def discriminator(self, images, reuse=False): | 139 | + f_loss_src_summary = tf.summary.scalar('src_f_loss', self.f_loss_src) |
121 | - """Discrimator: Convolutional neural network with leaky relu activations. | 140 | + origin_images_summary = tf.summary.image('src_origin_images', self.src_images) |
122 | - | 141 | + sampled_images_summary = tf.summary.image('src_sampled_images', self.fake_images) |
123 | - First conv layer does not use batch normalization. | 142 | + self.summary_op_src = tf.summary.merge([d_loss_src_summary, g_loss_src_summary, |
124 | - | 143 | + f_loss_src_summary, origin_images_summary, |
125 | - Args: | 144 | + sampled_images_summary]) |
126 | - images: real or fake images of shape (batch_size, image_size, image_size, dim_color) | ||
127 | - | ||
128 | - Returns: | ||
129 | - out: scores for whether it is a real image or a fake image, of shape (batch_size,) | ||
130 | - """ | ||
131 | - with tf.variable_scope('discriminator', reuse=reuse): | ||
132 | - | ||
133 | - # convolution layer | ||
134 | - h1 = lrelu(self.d_bn1(conv2d(images, self.dim_df, name='d_h1'))) # (batch_size, 16, 16, 64) | ||
135 | - h2 = lrelu(self.d_bn2(conv2d(h1, self.dim_df*2, name='d_h2'))) # (batch_size, 8, 8, 128) | ||
136 | - h3 = lrelu(self.d_bn3(conv2d(h2, self.dim_df*4, name='d_h3'))) # (batch_size, 4, 4, 256) | ||
137 | - h4 = lrelu(self.d_bn4(conv2d(h3, self.dim_df*8, name='d_h4'))) # (batch_size, 2, 2, 512) | ||
138 | - | ||
139 | - # fully connected layer | ||
140 | - h4 = tf.reshape(h4, [self.batch_size, -1]) | ||
141 | - out = linear(h4, 1, name='d_out') # (batch_size,) | ||
142 | - | ||
143 | - return out | ||
144 | - | ||
145 | - | ||
146 | - def build_model(self): | ||
147 | - | ||
148 | - # construct generator and discriminator for training phase | ||
149 | - self.f_x = self.function_f(self.images) | ||
150 | - self.fake_images = self.generator(self.f_x) # (batch_size, 32, 32, 3) | ||
151 | - self.logits_real = self.discriminator(self.images) # (batch_size,) | ||
152 | - self.logits_fake = self.discriminator(self.fake_images, reuse=True) # (batch_size,) | ||
153 | - self.fgf_x = self.function_f(self.fake_images, reuse=True) # (batch_size, dim_f) | ||
154 | - | ||
155 | - # construct generator for test phase (use moving average and variance for batch norm) | ||
156 | - self.f_x = self.function_f(self.images, reuse=True, train=False) | ||
157 | - self.sampled_images = self.generator(self.f_x, reuse=True) # (batch_size, 32, 32, 3) | ||
158 | - | ||
159 | - | ||
160 | - # compute loss | ||
161 | - self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_real, tf.ones_like(self.logits_real))) | ||
162 | - self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.zeros_like(self.logits_fake))) | ||
163 | - self.d_loss = self.d_loss_real + self.d_loss_fake | ||
164 | - self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake))) | ||
165 | - self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images)) # L_TID | ||
166 | - self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) * 0.15 # L_CONST | ||
167 | - | ||
168 | - # divide variables for discriminator and generator | ||
169 | - t_vars = tf.trainable_variables() | ||
170 | - self.d_vars = [var for var in t_vars if 'discriminator' in var.name] | ||
171 | - self.g_vars = [var for var in t_vars if 'generator' in var.name] | ||
172 | - self.f_vars = [var for var in t_vars if 'function_f' in var.name] | ||
173 | - | ||
174 | - # optimizer for discriminator and generator | ||
175 | - with tf.name_scope('optimizer'): | ||
176 | - self.d_optimizer_real = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_real, var_list=self.d_vars) | ||
177 | - self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars) | ||
178 | - self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars) | ||
179 | - self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars) | ||
180 | - self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.g_vars+self.f_vars) | ||
181 | 145 | ||
146 | + # target domain (mnist) | ||
147 | + with tf.name_scope('model_for_target_domain'): | ||
148 | + self.fx = self.content_extractor(self.trg_images, reuse=True) | ||
149 | + self.reconst_images = self.generator(self.fx, reuse=True) | ||
150 | + self.logits_fake = self.discriminator(self.reconst_images, reuse=True) | ||
151 | + self.logits_real = self.discriminator(self.trg_images, reuse=True) | ||
182 | 152 | ||
183 | - # summary ops for tensorboard visualization | 153 | + # loss |
184 | - scalar_summary('d_loss_real', self.d_loss_real) | 154 | + self.d_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.zeros_like(self.logits_fake)) |
185 | - scalar_summary('d_loss_fake', self.d_loss_fake) | 155 | + self.d_loss_real_trg = slim.losses.sigmoid_cross_entropy(self.logits_real, tf.ones_like(self.logits_real)) |
186 | - scalar_summary('d_loss', self.d_loss) | 156 | + self.d_loss_trg = self.d_loss_fake_trg + self.d_loss_real_trg |
187 | - scalar_summary('g_loss', self.g_loss) | 157 | + self.g_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.ones_like(self.logits_fake)) |
188 | - scalar_summary('g_const_loss', self.g_const_loss) | 158 | + self.g_loss_const_trg = tf.reduce_mean(tf.square(self.trg_images - self.reconst_images)) * 15.0 |
189 | - scalar_summary('f_const_loss', self.f_const_loss) | 159 | + self.g_loss_trg = self.g_loss_fake_trg + self.g_loss_const_trg |
190 | - | ||
191 | - try: | ||
192 | - image_summary('original_images', self.images, max_outputs=4) | ||
193 | - image_summary('sampled_images', self.sampled_images, max_outputs=4) | ||
194 | - except: | ||
195 | - image_summary('original_images', self.images, max_images=4) | ||
196 | - image_summary('sampled_images', self.sampled_images, max_images=4) | ||
197 | - | ||
198 | - for var in tf.trainable_variables(): | ||
199 | - histogram_summary(var.op.name, var) | ||
200 | 160 | ||
201 | - self.summary_op = merge_summary() | ||
202 | - | ||
203 | - self.saver = tf.train.Saver() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
161 | + # optimizer | ||
162 | + self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) | ||
163 | + self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) | ||
164 | + | ||
165 | + t_vars = tf.trainable_variables() | ||
166 | + d_vars = [var for var in t_vars if 'discriminator' in var.name] | ||
167 | + g_vars = [var for var in t_vars if 'generator' in var.name] | ||
168 | + f_vars = [var for var in t_vars if 'content_extractor' in var.name] | ||
169 | + | ||
170 | + # train op | ||
171 | + with tf.name_scope('target_train_op'): | ||
172 | + self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars) | ||
173 | + self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars) | ||
174 | + | ||
175 | + # summary op | ||
176 | + d_loss_fake_trg_summary = tf.summary.scalar('trg_d_loss_fake', self.d_loss_fake_trg) | ||
177 | + d_loss_real_trg_summary = tf.summary.scalar('trg_d_loss_real', self.d_loss_real_trg) | ||
178 | + d_loss_trg_summary = tf.summary.scalar('trg_d_loss', self.d_loss_trg) | ||
179 | + g_loss_fake_trg_summary = tf.summary.scalar('trg_g_loss_fake', self.g_loss_fake_trg) | ||
180 | + g_loss_const_trg_summary = tf.summary.scalar('trg_g_loss_const', self.g_loss_const_trg) | ||
181 | + g_loss_trg_summary = tf.summary.scalar('trg_g_loss', self.g_loss_trg) | ||
182 | + origin_images_summary = tf.summary.image('trg_origin_images', self.trg_images) | ||
183 | + sampled_images_summary = tf.summary.image('trg_reconstructed_images', self.reconst_images) | ||
184 | + self.summary_op_trg = tf.summary.merge([d_loss_trg_summary, g_loss_trg_summary, | ||
185 | + d_loss_fake_trg_summary, d_loss_real_trg_summary, | ||
186 | + g_loss_fake_trg_summary, g_loss_const_trg_summary, | ||
187 | + origin_images_summary, sampled_images_summary]) | ||
188 | + for var in tf.trainable_variables(): | ||
189 | + tf.summary.histogram(var.op.name, var) | ||
190 | + | ||
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
-
Please register or login to post a comment