박해연

add code

1 +# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018
2 +
3 +import argparse
4 +import datetime
5 +import importlib
6 +import models
7 +import os
8 +import tensorflow as tf
9 +import time
10 +from data_util import lmdb_dataflow, get_queued_data, resample_pcd
11 +from termcolor import colored
12 +from tf_util import add_train_summary
13 +from visu_util import plot_pcd_three_views
14 +import numpy as np
15 +
16 +def train(args):
17 + is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
18 + global_step = tf.Variable(0, trainable=False, name='global_step')
19 + alpha = tf.train.piecewise_constant(global_step, [3000, 6000, 15000],
20 + [0.01, 0.1, 0.5, 1.0], 'alpha_op')
21 + #beta = tf.train.piecewise_constant(global_step, [6000, 15000, 30000],
22 + # [0.01, 0.1, 0.5, 1.0], 'beta_op')
23 + beta = tf.constant(1.0)
24 + inputs_pl = tf.placeholder(tf.float32, (1, None, 3), 'inputs')
25 + my_inputs_pl = tf.placeholder(tf.float32,(args.batch_size,None,3),'my_inputs')####
26 + npts_pl = tf.placeholder(tf.int32, (args.batch_size,), 'num_points')
27 + gt_pl = tf.placeholder(tf.float32, (args.batch_size, args.num_gt_points, 3), 'ground_truths')
28 +
29 + model_module = importlib.import_module('.%s' % args.model_type, 'models')
30 + model = model_module.Model(inputs_pl,my_inputs_pl, npts_pl, gt_pl, alpha, beta)
31 + add_train_summary('alpha', alpha)
32 + add_train_summary('beta',beta)
33 +
34 + if args.lr_decay:
35 + learning_rate = tf.train.exponential_decay(args.base_lr, global_step,
36 + args.lr_decay_steps, args.lr_decay_rate,
37 + staircase=True, name='lr')
38 + learning_rate = tf.maximum(learning_rate, args.lr_clip)
39 + add_train_summary('learning_rate', learning_rate)
40 + else:
41 + learning_rate = tf.constant(args.base_lr, name='lr')
42 + train_summary = tf.summary.merge_all('train_summary')
43 + valid_summary = tf.summary.merge_all('valid_summary')
44 +
45 + trainer = tf.train.AdamOptimizer(learning_rate)
46 + train_op = trainer.minimize(model.loss, global_step)
47 +
48 + df_train, num_train = lmdb_dataflow(
49 + args.lmdb_train, args.batch_size, args.num_input_points, args.num_gt_points, is_training=True)
50 + train_gen = df_train.get_data()
51 + df_valid, num_valid = lmdb_dataflow(
52 + args.lmdb_valid, args.batch_size, args.num_input_points, args.num_gt_points, is_training=False)
53 + valid_gen = df_valid.get_data()
54 +
55 + config = tf.ConfigProto()
56 + config.gpu_options.allow_growth = True
57 + config.allow_soft_placement = True
58 + sess = tf.Session(config=config)
59 + saver = tf.train.Saver()
60 +
61 + print('#########################################')
62 + print(args.restore)
63 + if args.restore:
64 + print('*************************restore******************************')
65 + saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))
66 + writer = tf.summary.FileWriter(args.log_dir)
67 + else:
68 + print('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
69 + sess.run(tf.global_variables_initializer())
70 + if os.path.exists(args.log_dir):
71 + delete_key = input(colored('%s exists. Delete? [y (or enter)/N]'
72 + % args.log_dir, 'white', 'on_red'))
73 + if delete_key == 'y' or delete_key == "":
74 + os.system('rm -rf %s/*' % args.log_dir)
75 + os.makedirs(os.path.join(args.log_dir, 'plots'))
76 + else:
77 + os.makedirs(os.path.join(args.log_dir, 'plots'))
78 + with open(os.path.join(args.log_dir, 'args.txt'), 'w') as log:
79 + for arg in sorted(vars(args)):
80 + log.write(arg + ': ' + str(getattr(args, arg)) + '\n') # log of arguments
81 + os.system('cp models/%s.py %s' % (args.model_type, args.log_dir)) # bkp of model def
82 + os.system('cp train.py %s' % args.log_dir) # bkp of train procedure
83 + writer = tf.summary.FileWriter(args.log_dir, sess.graph)
84 +
85 + total_time = 0
86 + train_start = time.time()
87 + init_step = sess.run(global_step)
88 + for step in range(init_step+1, args.max_step+1):
89 + epoch = step * args.batch_size // num_train + 1
90 + ids, inputs, npts, gt = next(train_gen)
91 +
92 + #split idx arr
93 + split_idx=[]
94 + idx=0
95 + for num in npts[:-1]:
96 + idx+=num
97 + split_idx.append(idx)
98 + #print('split idx')
99 + #print(split_idx)
100 +
101 + max_pcd_size = np.max(npts)
102 + #print(npts)
103 + #print(max_pcd_size)
104 +
105 + ea_pcd = np.split(inputs[0],tuple(split_idx))
106 + inputs_sep = np.array([x for x in ea_pcd])
107 + my_inputs = np.array([resample_pcd(x,max_pcd_size) for x in inputs_sep])
108 +
109 + #print(my_inputs.shape)
110 +
111 +
112 + start = time.time()
113 + feed_dict = {inputs_pl: inputs, my_inputs_pl:my_inputs, npts_pl: npts, gt_pl: gt, is_training_pl: True}###
114 + _, loss, summary = sess.run([train_op, model.loss, train_summary], feed_dict=feed_dict)
115 + total_time += time.time() - start
116 + writer.add_summary(summary, step)
117 + if step % args.steps_per_print == 0:
118 + print('epoch %d step %d loss %.8f - time per batch %.4f' %
119 + (epoch, step, loss, total_time / args.steps_per_print))
120 + total_time = 0
121 + if step % args.steps_per_eval == 0:
122 + print(colored('Testing...', 'grey', 'on_green'))
123 + num_eval_steps = num_valid // args.batch_size
124 + total_loss = 0
125 + total_time = 0
126 + sess.run(tf.local_variables_initializer())
127 + for i in range(num_eval_steps):
128 + start = time.time()
129 + ids, inputs, npts, gt = next(valid_gen)
130 + feed_dict = {inputs_pl: inputs,my_inputs_pl:my_inputs, npts_pl: npts, gt_pl: gt, is_training_pl: False}
131 + loss, _ = sess.run([model.loss, model.update], feed_dict=feed_dict)
132 + total_loss += loss
133 + total_time += time.time() - start
134 + summary = sess.run(valid_summary, feed_dict={is_training_pl: False})
135 + writer.add_summary(summary, step)
136 + print(colored('epoch %d step %d loss %.8f - time per batch %.4f' %
137 + (epoch, step, total_loss / num_eval_steps, total_time / num_eval_steps),
138 + 'grey', 'on_green'))
139 + total_time = 0
140 + if step % args.steps_per_visu == 0:
141 + all_pcds = sess.run(model.visualize_ops, feed_dict=feed_dict)
142 + for i in range(0, args.batch_size, args.visu_freq):
143 + plot_path = os.path.join(args.log_dir, 'plots',
144 + 'epoch_%d_step_%d_%s.png' % (epoch, step, ids[i]))
145 + pcds = [x[i] for x in all_pcds]
146 + plot_pcd_three_views(plot_path, pcds, model.visualize_titles)
147 + if step % args.steps_per_save == 0:
148 + saver.save(sess, os.path.join(args.log_dir, 'model'), step)
149 + print(colored('Model saved at %s' % args.log_dir, 'white', 'on_blue'))
150 +
151 + print('Total time', datetime.timedelta(seconds=time.time() - train_start))
152 + sess.close()
153 +
154 +
155 +if __name__ == '__main__':
156 + parser = argparse.ArgumentParser()
157 + parser.add_argument('--lmdb_train', default='data/shapenet/train.lmdb')
158 + parser.add_argument('--lmdb_valid', default='data/shapenet/valid.lmdb')
159 + parser.add_argument('--log_dir', default='log/pcn_emd')
160 + parser.add_argument('--model_type', default='pcn_emd')
161 + parser.add_argument('--restore', action='store_true')
162 + parser.add_argument('--batch_size', type=int, default=32)
163 + parser.add_argument('--num_input_points', type=int, default=3000)
164 + parser.add_argument('--num_gt_points', type=int, default=16384)
165 + parser.add_argument('--base_lr', type=float, default=0.0001)
166 + parser.add_argument('--lr_decay', action='store_true')
167 + parser.add_argument('--lr_decay_steps', type=int, default=50000)
168 + parser.add_argument('--lr_decay_rate', type=float, default=0.7)
169 + parser.add_argument('--lr_clip', type=float, default=1e-6)
170 + parser.add_argument('--max_step', type=int, default=300000)
171 + parser.add_argument('--steps_per_print', type=int, default=100)
172 + parser.add_argument('--steps_per_eval', type=int, default=1000)
173 + parser.add_argument('--steps_per_visu', type=int, default=3000)
174 + parser.add_argument('--steps_per_save', type=int, default=100000)
175 + parser.add_argument('--visu_freq', type=int, default=5)
176 + args = parser.parse_args()
177 +
178 + train(args)