# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018

import argparse
import datetime
import importlib
import models
import numpy as np
import os
import tensorflow as tf
import time
from data_util import lmdb_dataflow, get_queued_data, resample_pcd
from termcolor import colored
from tf_util import add_train_summary
from visu_util import plot_pcd_three_views
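

# Helper for building the dense my_inputs batch; assumes resample_pcd (from
# data_util) pads or subsamples a cloud to an exact point count.
def batch_resampled_inputs(inputs, npts):
    """Split the flat (1, sum(npts), 3) array of concatenated partial clouds
    into one cloud per batch element, then resample each to the size of the
    largest so they stack into a dense (batch_size, max(npts), 3) array."""
    split_idx = np.cumsum(npts)[:-1]   # offsets where each cloud ends
    clouds = np.split(inputs[0], split_idx)
    max_pcd_size = np.max(npts)
    return np.stack([resample_pcd(pcd, max_pcd_size) for pcd in clouds])
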

def train(args):
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # Loss-weight schedules: alpha ramps from 0.01 to 1.0 over training;
    # beta is fixed at 1.0 (a piecewise schedule is kept below, commented out).
    alpha = tf.train.piecewise_constant(global_step, [3000, 6000, 15000],
                                        [0.01, 0.1, 0.5, 1.0], 'alpha_op')
    # beta = tf.train.piecewise_constant(global_step, [6000, 15000, 30000],
    #                                    [0.01, 0.1, 0.5, 1.0], 'beta_op')
    beta = tf.constant(1.0)
    inputs_pl = tf.placeholder(tf.float32, (1, None, 3), 'inputs')
    # Dense copy of the inputs: one resampled cloud per batch element.
    my_inputs_pl = tf.placeholder(tf.float32, (args.batch_size, None, 3), 'my_inputs')
    npts_pl = tf.placeholder(tf.int32, (args.batch_size,), 'num_points')
    gt_pl = tf.placeholder(tf.float32, (args.batch_size, args.num_gt_points, 3), 'ground_truths')

    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs_pl, my_inputs_pl, npts_pl, gt_pl, alpha, beta)
    add_train_summary('alpha', alpha)
    add_train_summary('beta', beta)

    if args.lr_decay:
        learning_rate = tf.train.exponential_decay(args.base_lr, global_step,
                                                   args.lr_decay_steps, args.lr_decay_rate,
                                                   staircase=True, name='lr')
        learning_rate = tf.maximum(learning_rate, args.lr_clip)
        add_train_summary('learning_rate', learning_rate)
    else:
        learning_rate = tf.constant(args.base_lr, name='lr')
    train_summary = tf.summary.merge_all('train_summary')
    valid_summary = tf.summary.merge_all('valid_summary')

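    # Adam optimizer; minimize() also increments global_step on each update.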
    trainer = tf.train.AdamOptimizer(learning_rate)
    train_op = trainer.minimize(model.loss, global_step)

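    # Each dataflow batch is (ids, inputs, npts, gt): inputs packs all partial
    # clouds into one (1, sum(npts), 3) array, npts holds the per-cloud counts.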
    df_train, num_train = lmdb_dataflow(
        args.lmdb_train, args.batch_size, args.num_input_points, args.num_gt_points, is_training=True)
    train_gen = df_train.get_data()
    df_valid, num_valid = lmdb_dataflow(
        args.lmdb_valid, args.batch_size, args.num_input_points, args.num_gt_points, is_training=False)
    valid_gen = df_valid.get_data()

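    # Grow GPU memory on demand and let ops fall back to other devices.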
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    saver = tf.train.Saver()

    if args.restore:
        print(colored('Restoring from %s' % args.log_dir, 'white', 'on_blue'))
        saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))
        writer = tf.summary.FileWriter(args.log_dir)
    else:
        sess.run(tf.global_variables_initializer())
        if os.path.exists(args.log_dir):
            delete_key = input(colored('%s exists. Delete? [y (or enter)/N]'
                                       % args.log_dir, 'white', 'on_red'))
            if delete_key == 'y' or delete_key == "":
                os.system('rm -rf %s/*' % args.log_dir)
                os.makedirs(os.path.join(args.log_dir, 'plots'))
        else:
            os.makedirs(os.path.join(args.log_dir, 'plots'))
        with open(os.path.join(args.log_dir, 'args.txt'), 'w') as log:
            for arg in sorted(vars(args)):
                log.write(arg + ': ' + str(getattr(args, arg)) + '\n')     # log of arguments
        os.system('cp models/%s.py %s' % (args.model_type, args.log_dir))  # bkp of model def
        os.system('cp train.py %s' % args.log_dir)                         # bkp of train procedure
        writer = tf.summary.FileWriter(args.log_dir, sess.graph)

    total_time = 0
    train_start = time.time()
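    # Resume the step count from the checkpoint's global_step (0 on a fresh run).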
    init_step = sess.run(global_step)
    for step in range(init_step + 1, args.max_step + 1):
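        # Approximate 1-based epoch index from the number of samples consumed.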
        epoch = step * args.batch_size // num_train + 1
        ids, inputs, npts, gt = next(train_gen)

        # Dense per-cloud batch for my_inputs_pl (see batch_resampled_inputs).
        my_inputs = batch_resampled_inputs(inputs, npts)

        start = time.time()
        feed_dict = {inputs_pl: inputs, my_inputs_pl: my_inputs,
                     npts_pl: npts, gt_pl: gt, is_training_pl: True}
        _, loss, summary = sess.run([train_op, model.loss, train_summary], feed_dict=feed_dict)
        total_time += time.time() - start
        writer.add_summary(summary, step)
        if step % args.steps_per_print == 0:
            print('epoch %d  step %d  loss %.8f - time per batch %.4f' %
                  (epoch, step, loss, total_time / args.steps_per_print))
            total_time = 0
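        # Periodic validation: reset the local variables behind model.update's
        # streaming metrics, then average the loss over the validation set.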
        if step % args.steps_per_eval == 0:
            print(colored('Testing...', 'grey', 'on_green'))
            num_eval_steps = num_valid // args.batch_size
            total_loss = 0
            total_time = 0
            sess.run(tf.local_variables_initializer())
            for i in range(num_eval_steps):
                start = time.time()
                ids, inputs, npts, gt = next(valid_gen)
                # Rebuild the dense batch from this validation batch so that
                # my_inputs matches the inputs being evaluated.
                my_inputs = batch_resampled_inputs(inputs, npts)
                feed_dict = {inputs_pl: inputs, my_inputs_pl: my_inputs,
                             npts_pl: npts, gt_pl: gt, is_training_pl: False}
                loss, _ = sess.run([model.loss, model.update], feed_dict=feed_dict)
                total_loss += loss
                total_time += time.time() - start
            summary = sess.run(valid_summary, feed_dict={is_training_pl: False})
            writer.add_summary(summary, step)
            print(colored('epoch %d  step %d  loss %.8f - time per batch %.4f' %
                          (epoch, step, total_loss / num_eval_steps, total_time / num_eval_steps),
                          'grey', 'on_green'))
            total_time = 0
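            # Visualize completions from the last validation batch (this
            # branch is only reached when step is also an eval step).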
            if step % args.steps_per_visu == 0:
                all_pcds = sess.run(model.visualize_ops, feed_dict=feed_dict)
                for i in range(0, args.batch_size, args.visu_freq):
                    plot_path = os.path.join(args.log_dir, 'plots',
                                             'epoch_%d_step_%d_%s.png' % (epoch, step, ids[i]))
                    pcds = [x[i] for x in all_pcds]
                    plot_pcd_three_views(plot_path, pcds, model.visualize_titles)
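        # Periodically checkpoint the session to log_dir.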
        if step % args.steps_per_save == 0:
            saver.save(sess, os.path.join(args.log_dir, 'model'), step)
            print(colored('Model saved at %s' % args.log_dir, 'white', 'on_blue'))

    print('Total time', datetime.timedelta(seconds=time.time() - train_start))
    sess.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lmdb_train', default='data/shapenet/train.lmdb')
    parser.add_argument('--lmdb_valid', default='data/shapenet/valid.lmdb')
    parser.add_argument('--log_dir', default='log/pcn_emd')
    parser.add_argument('--model_type', default='pcn_emd')
    parser.add_argument('--restore', action='store_true')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_input_points', type=int, default=3000)
    parser.add_argument('--num_gt_points', type=int, default=16384)
    parser.add_argument('--base_lr', type=float, default=0.0001)
    parser.add_argument('--lr_decay', action='store_true')
    parser.add_argument('--lr_decay_steps', type=int, default=50000)
    parser.add_argument('--lr_decay_rate', type=float, default=0.7)
    parser.add_argument('--lr_clip', type=float, default=1e-6)
    parser.add_argument('--max_step', type=int, default=300000)
    parser.add_argument('--steps_per_print', type=int, default=100)
    parser.add_argument('--steps_per_eval', type=int, default=1000)
    parser.add_argument('--steps_per_visu', type=int, default=3000)
    parser.add_argument('--steps_per_save', type=int, default=100000)
    parser.add_argument('--visu_freq', type=int, default=5)
    args = parser.parse_args()

    train(args)