yunjey

train and eval the model

import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
import pickle
import os
import scipy.io
import scipy.misc

class Solver(object):
"""Load dataset and train and test the model"""

    def __init__(self, model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100,
                 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample',
                 model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'):
        self.model = model
        self.batch_size = batch_size
        self.pretrain_iter = pretrain_iter
        self.train_iter = train_iter
        self.sample_iter = sample_iter
        self.svhn_dir = svhn_dir
        self.mnist_dir = mnist_dir
        self.log_dir = log_dir
        self.sample_save_path = sample_save_path
        self.model_save_path = model_save_path
        self.pretrained_model = pretrained_model
        self.test_model = test_model
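        # let TensorFlow grow GPU memory usage on demand instead of grabbing it all up front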
        self.config = tf.ConfigProto()
        self.config.gpu_options.allow_growth = True

    def load_svhn(self, image_dir, split='train'):
        print ('loading svhn image dataset..')
        image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat'
        image_dir = os.path.join(image_dir, image_file)
        svhn = scipy.io.loadmat(image_dir)
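        # reorder to [N, H, W, C] and scale pixel values from [0, 255] to [-1, 1]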
        images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1
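        # svhn stores the digit 0 as class 10; remap it to label 0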
        labels = svhn['y'].reshape(-1)
        labels[np.where(labels==10)] = 0
        print ('finished loading svhn image dataset..!')
        return images, labels

    def load_mnist(self, image_dir, split='train'):
        print ('loading mnist image dataset..')
        image_file = 'train.pkl' if split=='train' else 'test.pkl'
        image_dir = os.path.join(image_dir, image_file)
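        # each pickle file holds a dict with the images under 'X' and the labels under 'y'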
        with open(image_dir, 'rb') as f:
            mnist = pickle.load(f)
        images = mnist['X'] / 127.5 - 1
        labels = mnist['y']
        print ('finished loading mnist image dataset..!')
        return images, labels

    def merge_images(self, sources, targets, k=10):
        _, h, w, _ = sources.shape
        row = int(np.sqrt(self.batch_size))
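        # place each source image and its translated counterpart side by side in a row x row grid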
        merged = np.zeros([row*h, row*w*2, 3])

        for idx, (s, t) in enumerate(zip(sources, targets)):
            i = idx // row
            j = idx % row
            merged[i*h:(i+1)*h, (j*2)*w:(j*2+1)*w, :] = s
            merged[i*h:(i+1)*h, (j*2+1)*w:(j*2+2)*w, :] = t
        return merged

    def pretrain(self):
        # load svhn dataset
        train_images, train_labels = self.load_svhn(self.svhn_dir, split='train')
        test_images, test_labels = self.load_svhn(self.svhn_dir, split='test')

        # build a graph
        model = self.model
        model.build_model()

        with tf.Session(config=self.config) as sess:
            tf.global_variables_initializer().run()
            saver = tf.train.Saver()
            summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph())

            for step in range(self.pretrain_iter+1):
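                # cycle through the training set in fixed order, one batch per step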
                i = step % int(train_images.shape[0] / self.batch_size)
                batch_images = train_images[i*self.batch_size:(i+1)*self.batch_size]
                batch_labels = train_labels[i*self.batch_size:(i+1)*self.batch_size]
                feed_dict = {model.images: batch_images, model.labels: batch_labels}
                sess.run(model.train_op, feed_dict)

                if (step+1) % 10 == 0:
                    summary, l, acc = sess.run([model.summary_op, model.loss, model.accuracy], feed_dict)
                    rand_idxs = np.random.permutation(test_images.shape[0])[:self.batch_size]
                    test_acc, _ = sess.run(fetches=[model.accuracy, model.loss],
                                           feed_dict={model.images: test_images[rand_idxs],
                                                      model.labels: test_labels[rand_idxs]})
                    summary_writer.add_summary(summary, step)
                    print ('Step: [%d/%d] loss: [%.6f] train acc: [%.2f] test acc [%.2f]' \
                               %(step+1, self.pretrain_iter, l, acc, test_acc))

                if (step+1) % 1000 == 0:
                    saver.save(sess, os.path.join(self.model_save_path, 'svhn_model'), global_step=step+1)
                    print ('svhn_model-%d saved..!' %(step+1))

    def train(self):
        # load svhn and mnist datasets
        svhn_images, _ = self.load_svhn(self.svhn_dir, split='train')
        mnist_images, _ = self.load_mnist(self.mnist_dir, split='train')

        # build a graph
        model = self.model
        model.build_model()

        # clear any previous run's summaries and recreate the log directory
        if tf.gfile.Exists(self.log_dir):
            tf.gfile.DeleteRecursively(self.log_dir)
        tf.gfile.MakeDirs(self.log_dir)

        with tf.Session(config=self.config) as sess:
            # initialize G and D
            tf.global_variables_initializer().run()
            # restore variables of F
            print ('loading pretrained model F..')
            variables_to_restore = slim.get_model_variables(scope='content_extractor')
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, self.pretrained_model)
            summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph())
            saver = tf.train.Saver()
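            # note: restorer reloads only F's pretrained weights; saver checkpoints the whole model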

            print ('start training..!')
            for step in range(self.train_iter+1):

                i = step % int(svhn_images.shape[0] / self.batch_size)
                # train the model for source domain S
                src_images = svhn_images[i*self.batch_size:(i+1)*self.batch_size]
                feed_dict = {model.src_images: src_images}
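                # schedule: one discriminator update followed by six generator updates per step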
                sess.run(model.d_train_op_src, feed_dict)
                for _ in range(6):
                    sess.run(model.g_train_op_src, feed_dict)
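                # every 15 source batches, also update the content extractor F on its own loss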
                if i % 15 == 0:
                    sess.run(model.f_train_op_src, feed_dict)

                if (step+1) % 10 == 0:
                    summary, dl, gl, fl = sess.run([model.summary_op_src,
                                                    model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict)
                    summary_writer.add_summary(summary, step)
                    print ('[Source] step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f] f_loss: [%.6f]' \
                               %(step+1, self.train_iter, dl, gl, fl))

                # train the model for target domain T
                j = step % int(mnist_images.shape[0] / self.batch_size)
                trg_images = mnist_images[j*self.batch_size:(j+1)*self.batch_size]
                feed_dict = {model.src_images: src_images, model.trg_images: trg_images}
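                # schedule: two discriminator updates and four generator updates per step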
                for _ in range(2):
                    sess.run(model.d_train_op_trg, feed_dict)
                for _ in range(4):
                    sess.run(model.g_train_op_trg, feed_dict)

                if (step+1) % 10 == 0:
                    summary, dl, gl = sess.run([model.summary_op_trg,
                                                model.d_loss_trg, model.g_loss_trg], feed_dict)
                    summary_writer.add_summary(summary, step)
                    print ('[Target] step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' \
                               %(step+1, self.train_iter, dl, gl))

                if (step+1) % 200 == 0:
                    saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1)
                    print ('model/dtn-%d saved' %(step+1))

    def eval(self):
        # build model
        model = self.model
        model.build_model()

        # load svhn dataset
        svhn_images, _ = self.load_svhn(self.svhn_dir)

        with tf.Session(config=self.config) as sess:
            # load trained parameters
            print ('loading test model..')
            saver = tf.train.Saver()
            saver.restore(sess, self.test_model)

            print ('start sampling..!')
            for i in range(self.sample_iter):
                # translate a batch of source (svhn) images into the target domain
                batch_images = svhn_images[i*self.batch_size:(i+1)*self.batch_size]
                feed_dict = {model.images: batch_images}
                sampled_batch_images = sess.run(model.sampled_images, feed_dict)

                # merge and save source images and sampled target images
                merged = self.merge_images(batch_images, sampled_batch_images)
                path = os.path.join(self.sample_save_path, 'sample-%d-to-%d.png' %(i*self.batch_size, (i+1)*self.batch_size))
                scipy.misc.imsave(path, merged)
                print ('saved %s' %path)
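
For reference, a minimal driver might look like the sketch below. The DTN class, its mode argument, and the module names model/solver are assumptions for illustration; only the Solver interface above comes from this file.

# usage sketch -- hypothetical driver; the DTN class and its 'mode' argument are assumed,
# only the Solver interface above is taken from this file
from model import DTN
from solver import Solver

if __name__ == '__main__':
    model = DTN(mode='train')                                   # assumed constructor
    solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000)
    # run one phase per invocation, since each Solver method builds its own graph
    # solver.pretrain()   # pretrain the content extractor F on labeled svhn
    solver.train()        # adversarial domain-transfer training
    # solver.eval()       # sample translated images with the saved test model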