Showing
4 changed files
with
412 additions
and
0 deletions
model.py
0 → 100644
| 1 | +import tensorflow as tf | ||
| 2 | +from ops import * | ||
| 3 | + | ||
| 4 | +class DTN(object): | ||
| 5 | + """Domain Transfer Network for unsupervised cross-domain image generation | ||
| 6 | + | ||
| 7 | + Construct discriminator and generator to prepare for training. | ||
| 8 | + """ | ||
| 9 | + | ||
| 10 | + def __init__(self, batch_size=100, learning_rate=0.0002, image_size=32, output_size=32, | ||
| 11 | + dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64): | ||
| 12 | + """ | ||
| 13 | + Args: | ||
| 14 | + learning_rate: (optional) learning rate for discriminator and generator | ||
| 15 | + image_size: (optional) spatial size of input image for discriminator | ||
| 16 | + output_size: (optional) spatial size of image generated by generator | ||
| 17 | + dim_color: (optional) dimension of image color; default is 3 for rgb | ||
| 18 | + dim_fout: (optional) dimension of z (random input vector for generator) | ||
| 19 | + dim_df: (optional) dimension of discriminator's filter in first convolution layer | ||
| 20 | + dim_gf: (optional) dimension of generator's filter in last convolution layer | ||
| 21 | + dim_ff: (optional) dimension of function f's filter in first convolution layer | ||
| 22 | + """ | ||
| 23 | + # hyper parameters | ||
| 24 | + self.batch_size = batch_size | ||
| 25 | + self.learning_rate = learning_rate | ||
| 26 | + self.image_size = image_size | ||
| 27 | + self.output_size = output_size | ||
| 28 | + self.dim_color = dim_color | ||
| 29 | + self.dim_fout = dim_fout | ||
| 30 | + self.dim_df = dim_df | ||
| 31 | + self.dim_gf = dim_gf | ||
| 32 | + self.dim_ff = dim_ff | ||
| 33 | + | ||
| 34 | + # placeholder | ||
| 35 | + self.images = tf.placeholder(tf.float32, shape=[batch_size, image_size, image_size, dim_color], name='images') | ||
| 36 | + #self.z = tf.placeholder(tf.float32, shape=[None, dim_z], name='input_for_generator') | ||
| 37 | + | ||
| 38 | + # batch normalization layer for discriminator, generator and funtion f | ||
| 39 | + self.d_bn1 = batch_norm(name='d_bn1') | ||
| 40 | + self.d_bn2 = batch_norm(name='d_bn2') | ||
| 41 | + self.d_bn3 = batch_norm(name='d_bn3') | ||
| 42 | + | ||
| 43 | + self.g_bn1 = batch_norm(name='g_bn1') | ||
| 44 | + self.g_bn2 = batch_norm(name='g_bn2') | ||
| 45 | + self.g_bn3 = batch_norm(name='g_bn3') | ||
| 46 | + self.g_bn4 = batch_norm(name='g_bn4') | ||
| 47 | + | ||
| 48 | + self.f_bn1 = batch_norm(name='f_bn1') | ||
| 49 | + self.f_bn2 = batch_norm(name='f_bn2') | ||
| 50 | + self.f_bn3 = batch_norm(name='f_bn3') | ||
| 51 | + self.f_bn4 = batch_norm(name='f_bn4') | ||
| 52 | + | ||
| 53 | + | ||
| 54 | + | ||
| 55 | + def function_f(self, images, reuse=False): | ||
| 56 | + """f consistancy | ||
| 57 | + | ||
| 58 | + Args: | ||
| 59 | + images: images for domain S and T, of shape (batch_size, image_size, image_size, dim_color) | ||
| 60 | + | ||
| 61 | + Returns: | ||
| 62 | + out: output vectors, of shape (batch_size, dim_f_out) | ||
| 63 | + """ | ||
| 64 | + with tf.variable_scope('function_f', reuse=reuse): | ||
| 65 | + h1 = lrelu(conv2d(images, self.dim_ff, name='f_h1')) # (batch_size, 16, 16, 64) | ||
| 66 | + h2 = lrelu(self.d_bn1(conv2d(h1, self.dim_ff*2, name='f_h2'))) # (batch_size, 8, 8 128) | ||
| 67 | + h3 = lrelu(self.d_bn2(conv2d(h2, self.dim_ff*4, name='f_h3'))) # (batch_size, 4, 4, 256) | ||
| 68 | + h4 = lrelu(self.d_bn3(conv2d(h3, self.dim_ff*8, name='f_h4'))) # (batch_size, 2, 2, 512) | ||
| 69 | + | ||
| 70 | + h4 = tf.reshape(h4, [self.batch_size,-1]) | ||
| 71 | + out = linear(h4, self.dim_fout, name='f_out') | ||
| 72 | + | ||
| 73 | + return tf.nn.tanh(out) | ||
| 74 | + | ||
| 75 | + | ||
| 76 | + def generator(self, z, reuse=False): | ||
| 77 | + """Generator: Deconvolutional neural network with relu activations. | ||
| 78 | + | ||
| 79 | + Last deconv layer does not use batch normalization. | ||
| 80 | + | ||
| 81 | + Args: | ||
| 82 | + z: random input vectors, of shape (batch_size, dim_z) | ||
| 83 | + | ||
| 84 | + Returns: | ||
| 85 | + out: generated images, of shape (batch_size, image_size, image_size, dim_color) | ||
| 86 | + """ | ||
| 87 | + if reuse: | ||
| 88 | + train = False | ||
| 89 | + else: | ||
| 90 | + train = True | ||
| 91 | + | ||
| 92 | + with tf.variable_scope('generator', reuse=reuse): | ||
| 93 | + | ||
| 94 | + # spatial size for convolution | ||
| 95 | + s = self.output_size | ||
| 96 | + s2, s4, s8, s16 = s/2, s/4, s/8, s/16 # 32, 16, 8, 4 | ||
| 97 | + | ||
| 98 | + # project and reshape z | ||
| 99 | + h1= linear(z, s16*s16*self.dim_gf*8, name='g_h1') # (batch_size, 2*2*512) | ||
| 100 | + h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8]) # (batch_size, 2, 2, 512) | ||
| 101 | + h1 = relu(self.g_bn1(h1, train=train)) | ||
| 102 | + | ||
| 103 | + h2 = deconv2d(h1, [self.batch_size, s8, s8, self.dim_gf*4], name='g_h2') # (batch_size, 4, 4, 256) | ||
| 104 | + h2 = relu(self.g_bn2(h2, train=train)) | ||
| 105 | + | ||
| 106 | + h3 = deconv2d(h2, [self.batch_size, s4, s4, self.dim_gf*2], name='g_h3') # (batch_size, 8, 8, 128) | ||
| 107 | + h3 = relu(self.g_bn3(h3, train=train)) | ||
| 108 | + | ||
| 109 | + h4 = deconv2d(h3, [self.batch_size, s2, s2, self.dim_gf], name='g_h4') # (batch_size, 16, 16, 64) | ||
| 110 | + h4 = relu(self.g_bn4(h4, train=train)) | ||
| 111 | + | ||
| 112 | + out = deconv2d(h4, [self.batch_size, s, s, self.dim_color], name='g_out') # (batch_size, 32, 32, dim_color) | ||
| 113 | + | ||
| 114 | + return tf.nn.tanh(out) | ||
| 115 | + | ||
| 116 | + | ||
| 117 | + def discriminator(self, images, reuse=False): | ||
| 118 | + """Discrimator: Convolutional neural network with leaky relu activations. | ||
| 119 | + | ||
| 120 | + First conv layer does not use batch normalization. | ||
| 121 | + | ||
| 122 | + Args: | ||
| 123 | + images: real or fake images of shape (batch_size, image_size, image_size, dim_color) | ||
| 124 | + | ||
| 125 | + Returns: | ||
| 126 | + out: scores for whether it is a real image or a fake image, of shape (batch_size,) | ||
| 127 | + """ | ||
| 128 | + with tf.variable_scope('discriminator', reuse=reuse): | ||
| 129 | + | ||
| 130 | + # convolution layer | ||
| 131 | + h1 = lrelu(conv2d(images, self.dim_df, name='d_h1')) # (batch_size, 16, 16, 64) | ||
| 132 | + h2 = lrelu(self.d_bn1(conv2d(h1, self.dim_df*2, name='d_h2'))) # (batch_size, 8, 8, 128) | ||
| 133 | + h3 = lrelu(self.d_bn2(conv2d(h2, self.dim_df*4, name='d_h3'))) # (batch_size, 4, 4, 256) | ||
| 134 | + h4 = lrelu(self.d_bn3(conv2d(h3, self.dim_df*8, name='d_h4'))) # (batch_size, 2, 2, 512) | ||
| 135 | + | ||
| 136 | + # fully connected layer | ||
| 137 | + h4 = tf.reshape(h4, [self.batch_size, -1]) | ||
| 138 | + out = linear(h4, 1, name='d_out') # (batch_size,) | ||
| 139 | + | ||
| 140 | + return out | ||
| 141 | + | ||
| 142 | + | ||
| 143 | + def build_model(self): | ||
| 144 | + | ||
| 145 | + # construct generator and discriminator for training phase | ||
| 146 | + self.f_x = self.function_f(self.images) | ||
| 147 | + self.fake_images = self.generator(self.f_x) # (batch_size, 32, 32, 3) | ||
| 148 | + self.logits_real = self.discriminator(self.images) # (batch_size,) | ||
| 149 | + self.logits_fake = self.discriminator(self.fake_images, reuse=True) # (batch_size,) | ||
| 150 | + self.fgf_x = self.function_f(self.fake_images, reuse=True) # (batch_size, dim_f) | ||
| 151 | + | ||
| 152 | + # construct generator for test phase | ||
| 153 | + self.sampled_images = self.generator(self.f_x, reuse=True) # (batch_size, 32, 32, 3) | ||
| 154 | + | ||
| 155 | + | ||
| 156 | + # compute loss | ||
| 157 | + self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_real, tf.ones_like(self.logits_real))) | ||
| 158 | + self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.zeros_like(self.logits_fake))) | ||
| 159 | + self.d_loss = self.d_loss_real + self.d_loss_fake | ||
| 160 | + self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake))) | ||
| 161 | + self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images)) # L_TID | ||
| 162 | + self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) # L_CONST | ||
| 163 | + | ||
| 164 | + # divide variables for discriminator and generator | ||
| 165 | + t_vars = tf.trainable_variables() | ||
| 166 | + self.d_vars = [var for var in t_vars if 'discriminator' in var.name] | ||
| 167 | + self.g_vars = [var for var in t_vars if 'generator' in var.name] | ||
| 168 | + self.f_vars = [var for var in t_vars if 'function_f' in var.name] | ||
| 169 | + | ||
| 170 | + # optimizer for discriminator and generator | ||
| 171 | + with tf.name_scope('optimizer'): | ||
| 172 | + self.d_optimizer_real = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_real, var_list=self.d_vars) | ||
| 173 | + self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars) | ||
| 174 | + self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars) | ||
| 175 | + self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars) | ||
| 176 | + self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.f_vars+self.g_vars) | ||
| 177 | + | ||
| 178 | + | ||
| 179 | + # summary ops for tensorboard visualization | ||
| 180 | + tf.scalar_summary('d_loss_real', self.d_loss_real) | ||
| 181 | + tf.scalar_summary('d_loss_fake', self.d_loss_fake) | ||
| 182 | + tf.scalar_summary('d_loss', self.d_loss) | ||
| 183 | + tf.scalar_summary('g_loss', self.g_loss) | ||
| 184 | + tf.scalar_summary('g_const_loss', self.g_const_loss) | ||
| 185 | + tf.scalar_summary('f_const_loss', self.f_const_loss) | ||
| 186 | + tf.image_summary('original_images', self.images, max_images=6) | ||
| 187 | + tf.image_summary('sampled_images', self.sampled_images, max_images=6) | ||
| 188 | + | ||
| 189 | + for var in tf.trainable_variables(): | ||
| 190 | + tf.histogram_summary(var.op.name, var) | ||
| 191 | + | ||
| 192 | + self.summary_op = tf.merge_all_summaries() | ||
| 193 | + | ||
| 194 | + self.saver = tf.train.Saver() | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ops.py
0 → 100644
| 1 | +import tensorflow as tf | ||
| 2 | + | ||
| 3 | + | ||
| 4 | +class batch_norm(object): | ||
| 5 | + """Computes batch normalization operation | ||
| 6 | + | ||
| 7 | + Args: | ||
| 8 | + x: input tensor of shape (batch_size, width, height, channels_in) or (batch_size, dim_in) | ||
| 9 | + train: True or False; At train mode, it normalizes the input with mini-batch statistics | ||
| 10 | + At test mode, it normalizes the input with the moving averages and variances | ||
| 11 | + | ||
| 12 | + Returns: | ||
| 13 | + out: batch normalized output of the same shape with x | ||
| 14 | + """ | ||
| 15 | + def __init__(self, name): | ||
| 16 | + self.name = name | ||
| 17 | + | ||
| 18 | + def __call__(self, x, train=True): | ||
| 19 | + out = tf.contrib.layers.batch_norm(x, decay=0.99, center=True, scale=True, activation_fn=None, | ||
| 20 | + updates_collections=None, is_training=train, scope=self.name) | ||
| 21 | + return out | ||
| 22 | + | ||
| 23 | + | ||
| 24 | +def conv2d(x, channel_out, k_w=5, k_h=5, s_w=2, s_h=2, name=None): | ||
| 25 | + """Computes convolution operation | ||
| 26 | + | ||
| 27 | + Args: | ||
| 28 | + x: input tensor of shape (batch_size, width_in, heigth_in, channel_in) | ||
| 29 | + channel_out: number of channel for output tensor | ||
| 30 | + k_w: kernel width size; default is 5 | ||
| 31 | + k_h: kernel height size; default is 5 | ||
| 32 | + s_w: stride size for width; default is 2 | ||
| 33 | + s_h: stride size for heigth; default is 2 | ||
| 34 | + | ||
| 35 | + Returns: | ||
| 36 | + out: output tensor of shape (batch_size, width_out, height_out, channel_out) | ||
| 37 | + """ | ||
| 38 | + channel_in = x.get_shape()[-1] | ||
| 39 | + | ||
| 40 | + with tf.variable_scope(name): | ||
| 41 | + w = tf.get_variable('w', shape=[k_w, k_h, channel_in, channel_out], | ||
| 42 | + initializer=tf.contrib.layers.xavier_initializer()) | ||
| 43 | + b = tf.get_variable('b', shape=[channel_out], initializer=tf.constant_initializer(0.0)) | ||
| 44 | + | ||
| 45 | + out = tf.nn.conv2d(x, w, strides=[1, s_w, s_h, 1], padding='SAME') + b | ||
| 46 | + | ||
| 47 | + return out | ||
| 48 | + | ||
| 49 | + | ||
| 50 | +def deconv2d(x, output_shape, k_w=5, k_h=5, s_w=2, s_h=2, name=None): | ||
| 51 | + """Computes deconvolution operation | ||
| 52 | + | ||
| 53 | + Args: | ||
| 54 | + x: input tensor of shape (batch_size, width_in, height_in, channel_in) | ||
| 55 | + output_shape: list corresponding to [batch_size, width_out, height_out, channel_out] | ||
| 56 | + k_w: kernel width size; default is 5 | ||
| 57 | + k_h: kernel height size; default is 5 | ||
| 58 | + s_w: stride size for width; default is 2 | ||
| 59 | + s_h: stride size for heigth; default is 2 | ||
| 60 | + | ||
| 61 | + Returns: | ||
| 62 | + out: output tensor of shape (batch_size, width_out, hegith_out, channel_out) | ||
| 63 | + """ | ||
| 64 | + channel_in = x.get_shape()[-1] | ||
| 65 | + channel_out = output_shape[-1] | ||
| 66 | + | ||
| 67 | + | ||
| 68 | + with tf.variable_scope(name): | ||
| 69 | + w = tf.get_variable('w', shape=[k_w, k_h, channel_out, channel_in], | ||
| 70 | + initializer=tf.contrib.layers.xavier_initializer()) | ||
| 71 | + b = tf.get_variable('b', shape=[channel_out], initializer=tf.constant_initializer(0.0)) | ||
| 72 | + | ||
| 73 | + out = tf.nn.conv2d_transpose(x, filter=w, output_shape=output_shape, strides=[1, s_w, s_h, 1]) + b | ||
| 74 | + | ||
| 75 | + return out | ||
| 76 | + | ||
| 77 | +def linear(x, dim_out, name=None): | ||
| 78 | + """Computes linear transform (fully-connected layer) | ||
| 79 | + | ||
| 80 | + Args: | ||
| 81 | + x: input tensor of shape (batch_size, dim_in) | ||
| 82 | + dim_out: dimension for output tensor | ||
| 83 | + | ||
| 84 | + Returns: | ||
| 85 | + out: output tensor of shape (batch_size, dim_out) | ||
| 86 | + """ | ||
| 87 | + dim_in = x.get_shape()[-1] | ||
| 88 | + | ||
| 89 | + with tf.variable_scope(name): | ||
| 90 | + w = tf.get_variable('w', shape=[dim_in, dim_out], initializer=tf.contrib.layers.xavier_initializer()) | ||
| 91 | + b = tf.get_variable('b', shape=[dim_out], initializer=tf.constant_initializer(0.0)) | ||
| 92 | + | ||
| 93 | + out = tf.matmul(x, w) + b | ||
| 94 | + | ||
| 95 | + return out | ||
| 96 | + | ||
| 97 | + | ||
| 98 | +def relu(x): | ||
| 99 | + return tf.nn.relu(x) | ||
| 100 | + | ||
| 101 | + | ||
| 102 | +def lrelu(x, leak=0.2): | ||
| 103 | + return tf.maximum(x, leak*x) | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
solver.py
0 → 100644
| 1 | +import tensorflow as tf | ||
| 2 | +import numpy as np | ||
| 3 | +import os | ||
| 4 | +import scipy.io | ||
| 5 | +import hickle | ||
| 6 | +from scipy import ndimage | ||
| 7 | + | ||
| 8 | + | ||
| 9 | +class Solver(object): | ||
| 10 | + """Load dataset and train DCGAN""" | ||
| 11 | + | ||
| 12 | + def __init__(self, model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/', log_path='log/'): | ||
| 13 | + self.model = model | ||
| 14 | + self.num_epoch = num_epoch | ||
| 15 | + self.mnist_path = mnist_path | ||
| 16 | + self.svhn_path = svhn_path | ||
| 17 | + self.model_save_path = model_save_path | ||
| 18 | + self.log_path = log_path | ||
| 19 | + | ||
| 20 | + # create directory if not exists | ||
| 21 | + if not os.path.exists(log_path): | ||
| 22 | + os.makedirs(log_path) | ||
| 23 | + if not os.path.exists(model_save_path): | ||
| 24 | + os.makedirs(model_save_path) | ||
| 25 | + | ||
| 26 | + # construct the dcgan model | ||
| 27 | + model.build_model() | ||
| 28 | + | ||
| 29 | + # load dataset | ||
| 30 | + self.svhn = self.load_svhn(self.svhn_path) | ||
| 31 | + self.mnist = self.load_mnist(self.mnist_path) | ||
| 32 | + | ||
| 33 | + | ||
| 34 | + def load_svhn(self, image_path, split='train'): | ||
| 35 | + print ('loading svhn image dataset..') | ||
| 36 | + if split == 'train': | ||
| 37 | + svhn = scipy.io.loadmat(os.path.join(image_path, 'train_32x32.mat')) | ||
| 38 | + else: | ||
| 39 | + svhn = scipy.io.loadmat(os.path.join(image_path, 'test_32x32.mat')) | ||
| 40 | + | ||
| 41 | + images = np.transpose(svhn['X'], [3, 0, 1, 2]) | ||
| 42 | + images = images / 127.5 - 1 | ||
| 43 | + print ('finished loading svhn image dataset..!') | ||
| 44 | + return images | ||
| 45 | + | ||
| 46 | + | ||
| 47 | + def load_mnist(self, image_path, split='train'): | ||
| 48 | + print ('loading mnist image dataset..') | ||
| 49 | + if split == 'train': | ||
| 50 | + image_file = os.path.join(image_path, 'train.images.hkl') | ||
| 51 | + else: | ||
| 52 | + image_file = os.path.join(image_path, 'test.images.hkl') | ||
| 53 | + | ||
| 54 | + images = hickle.load(image_file) | ||
| 55 | + images = images / 127.5 - 1 | ||
| 56 | + print ('finished loading mnist image dataset..!') | ||
| 57 | + return images | ||
| 58 | + | ||
| 59 | + | ||
| 60 | + def train(self): | ||
| 61 | + model=self.model | ||
| 62 | + | ||
| 63 | + #load image dataset | ||
| 64 | + svhn = self.svhn | ||
| 65 | + mnist = self.mnist | ||
| 66 | + | ||
| 67 | + num_iter_per_epoch = int(mnist.shape[0] / model.batch_size) | ||
| 68 | + | ||
| 69 | + config = tf.ConfigProto(allow_soft_placement = True) | ||
| 70 | + config.gpu_options.allow_growth = True | ||
| 71 | + with tf.Session(config=config) as sess: | ||
| 72 | + # initialize parameters | ||
| 73 | + tf.initialize_all_variables().run() | ||
| 74 | + summary_writer = tf.train.SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph()) | ||
| 75 | + | ||
| 76 | + for e in range(self.num_epoch): | ||
| 77 | + for i in range(num_iter_per_epoch): | ||
| 78 | + | ||
| 79 | + # train model for domain S | ||
| 80 | + image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size] | ||
| 81 | + feed_dict = {model.images: image_batch} | ||
| 82 | + sess.run(model.d_optimizer_fake, feed_dict) | ||
| 83 | + sess.run(model.f_optimizer_const, feed_dict) | ||
| 84 | + sess.run(model.g_optimizer, feed_dict) | ||
| 85 | + | ||
| 86 | + if i % 10 == 0: | ||
| 87 | + feed_dict = {model.images: image_batch} | ||
| 88 | + summary, d_loss, g_loss = sess.run([model.summary_op, model.d_loss, model.g_loss], feed_dict) | ||
| 89 | + summary_writer.add_summary(summary, e*num_iter_per_epoch + i) | ||
| 90 | + print ('Epoch: [%d] Step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' %(e+1, i+1, num_iter_per_epoch, d_loss, g_loss)) | ||
| 91 | + | ||
| 92 | + # train model for domain T | ||
| 93 | + image_batch = mnist[i*model.batch_size:(i+1)*model.batch_size] | ||
| 94 | + feed_dict = {model.images: image_batch} | ||
| 95 | + sess.run(model.d_optimizer_real, feed_dict) | ||
| 96 | + sess.run(model.d_optimizer_fake, feed_dict) | ||
| 97 | + sess.run(model.g_optimizer, feed_dict) | ||
| 98 | + sess.run(model.g_optimizer_const, feed_dict) | ||
| 99 | + | ||
| 100 | + | ||
| 101 | + | ||
| 102 | + if i % 500 == 0: | ||
| 103 | + model.saver.save(sess, os.path.join(self.model_save_path, 'dcgan-%d' %(e+1)), global_step=i+1) | ||
| 104 | + print ('model/dcgan-%d-%d saved' %(e+1, i+1)) | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
train.py
0 → 100644
| 1 | +from model import DTN | ||
| 2 | +from solver import Solver | ||
| 3 | + | ||
| 4 | +def main(): | ||
| 5 | + model = DTN() | ||
| 6 | + solver = Solver(model, num_epoch=10, svhn_path='svhn/', model_save_path='model/', log_path='log/') | ||
| 7 | + solver.train() | ||
| 8 | + | ||
| 9 | + | ||
| 10 | +if __name__ == "__main__": | ||
| 11 | + main() | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment