Showing
1 changed file
with
178 additions
and
0 deletions
code/pcn_modify/pcn/train.py
0 → 100644
1 | +# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018 | ||
2 | + | ||
3 | +import argparse | ||
4 | +import datetime | ||
5 | +import importlib | ||
6 | +import models | ||
7 | +import os | ||
8 | +import tensorflow as tf | ||
9 | +import time | ||
10 | +from data_util import lmdb_dataflow, get_queued_data, resample_pcd | ||
11 | +from termcolor import colored | ||
12 | +from tf_util import add_train_summary | ||
13 | +from visu_util import plot_pcd_three_views | ||
14 | +import numpy as np | ||
15 | + | ||
16 | +def train(args): | ||
17 | + is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training') | ||
18 | + global_step = tf.Variable(0, trainable=False, name='global_step') | ||
19 | + alpha = tf.train.piecewise_constant(global_step, [3000, 6000, 15000], | ||
20 | + [0.01, 0.1, 0.5, 1.0], 'alpha_op') | ||
21 | + #beta = tf.train.piecewise_constant(global_step, [6000, 15000, 30000], | ||
22 | + # [0.01, 0.1, 0.5, 1.0], 'beta_op') | ||
23 | + beta = tf.constant(1.0) | ||
24 | + inputs_pl = tf.placeholder(tf.float32, (1, None, 3), 'inputs') | ||
25 | + my_inputs_pl = tf.placeholder(tf.float32,(args.batch_size,None,3),'my_inputs')#### | ||
26 | + npts_pl = tf.placeholder(tf.int32, (args.batch_size,), 'num_points') | ||
27 | + gt_pl = tf.placeholder(tf.float32, (args.batch_size, args.num_gt_points, 3), 'ground_truths') | ||
28 | + | ||
29 | + model_module = importlib.import_module('.%s' % args.model_type, 'models') | ||
30 | + model = model_module.Model(inputs_pl,my_inputs_pl, npts_pl, gt_pl, alpha, beta) | ||
31 | + add_train_summary('alpha', alpha) | ||
32 | + add_train_summary('beta',beta) | ||
33 | + | ||
34 | + if args.lr_decay: | ||
35 | + learning_rate = tf.train.exponential_decay(args.base_lr, global_step, | ||
36 | + args.lr_decay_steps, args.lr_decay_rate, | ||
37 | + staircase=True, name='lr') | ||
38 | + learning_rate = tf.maximum(learning_rate, args.lr_clip) | ||
39 | + add_train_summary('learning_rate', learning_rate) | ||
40 | + else: | ||
41 | + learning_rate = tf.constant(args.base_lr, name='lr') | ||
42 | + train_summary = tf.summary.merge_all('train_summary') | ||
43 | + valid_summary = tf.summary.merge_all('valid_summary') | ||
44 | + | ||
45 | + trainer = tf.train.AdamOptimizer(learning_rate) | ||
46 | + train_op = trainer.minimize(model.loss, global_step) | ||
47 | + | ||
48 | + df_train, num_train = lmdb_dataflow( | ||
49 | + args.lmdb_train, args.batch_size, args.num_input_points, args.num_gt_points, is_training=True) | ||
50 | + train_gen = df_train.get_data() | ||
51 | + df_valid, num_valid = lmdb_dataflow( | ||
52 | + args.lmdb_valid, args.batch_size, args.num_input_points, args.num_gt_points, is_training=False) | ||
53 | + valid_gen = df_valid.get_data() | ||
54 | + | ||
55 | + config = tf.ConfigProto() | ||
56 | + config.gpu_options.allow_growth = True | ||
57 | + config.allow_soft_placement = True | ||
58 | + sess = tf.Session(config=config) | ||
59 | + saver = tf.train.Saver() | ||
60 | + | ||
61 | + print('#########################################') | ||
62 | + print(args.restore) | ||
63 | + if args.restore: | ||
64 | + print('*************************restore******************************') | ||
65 | + saver.restore(sess, tf.train.latest_checkpoint(args.log_dir)) | ||
66 | + writer = tf.summary.FileWriter(args.log_dir) | ||
67 | + else: | ||
68 | + print('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx') | ||
69 | + sess.run(tf.global_variables_initializer()) | ||
70 | + if os.path.exists(args.log_dir): | ||
71 | + delete_key = input(colored('%s exists. Delete? [y (or enter)/N]' | ||
72 | + % args.log_dir, 'white', 'on_red')) | ||
73 | + if delete_key == 'y' or delete_key == "": | ||
74 | + os.system('rm -rf %s/*' % args.log_dir) | ||
75 | + os.makedirs(os.path.join(args.log_dir, 'plots')) | ||
76 | + else: | ||
77 | + os.makedirs(os.path.join(args.log_dir, 'plots')) | ||
78 | + with open(os.path.join(args.log_dir, 'args.txt'), 'w') as log: | ||
79 | + for arg in sorted(vars(args)): | ||
80 | + log.write(arg + ': ' + str(getattr(args, arg)) + '\n') # log of arguments | ||
81 | + os.system('cp models/%s.py %s' % (args.model_type, args.log_dir)) # bkp of model def | ||
82 | + os.system('cp train.py %s' % args.log_dir) # bkp of train procedure | ||
83 | + writer = tf.summary.FileWriter(args.log_dir, sess.graph) | ||
84 | + | ||
85 | + total_time = 0 | ||
86 | + train_start = time.time() | ||
87 | + init_step = sess.run(global_step) | ||
88 | + for step in range(init_step+1, args.max_step+1): | ||
89 | + epoch = step * args.batch_size // num_train + 1 | ||
90 | + ids, inputs, npts, gt = next(train_gen) | ||
91 | + | ||
92 | + #split idx arr | ||
93 | + split_idx=[] | ||
94 | + idx=0 | ||
95 | + for num in npts[:-1]: | ||
96 | + idx+=num | ||
97 | + split_idx.append(idx) | ||
98 | + #print('split idx') | ||
99 | + #print(split_idx) | ||
100 | + | ||
101 | + max_pcd_size = np.max(npts) | ||
102 | + #print(npts) | ||
103 | + #print(max_pcd_size) | ||
104 | + | ||
105 | + ea_pcd = np.split(inputs[0],tuple(split_idx)) | ||
106 | + inputs_sep = np.array([x for x in ea_pcd]) | ||
107 | + my_inputs = np.array([resample_pcd(x,max_pcd_size) for x in inputs_sep]) | ||
108 | + | ||
109 | + #print(my_inputs.shape) | ||
110 | + | ||
111 | + | ||
112 | + start = time.time() | ||
113 | + feed_dict = {inputs_pl: inputs, my_inputs_pl:my_inputs, npts_pl: npts, gt_pl: gt, is_training_pl: True}### | ||
114 | + _, loss, summary = sess.run([train_op, model.loss, train_summary], feed_dict=feed_dict) | ||
115 | + total_time += time.time() - start | ||
116 | + writer.add_summary(summary, step) | ||
117 | + if step % args.steps_per_print == 0: | ||
118 | + print('epoch %d step %d loss %.8f - time per batch %.4f' % | ||
119 | + (epoch, step, loss, total_time / args.steps_per_print)) | ||
120 | + total_time = 0 | ||
121 | + if step % args.steps_per_eval == 0: | ||
122 | + print(colored('Testing...', 'grey', 'on_green')) | ||
123 | + num_eval_steps = num_valid // args.batch_size | ||
124 | + total_loss = 0 | ||
125 | + total_time = 0 | ||
126 | + sess.run(tf.local_variables_initializer()) | ||
127 | + for i in range(num_eval_steps): | ||
128 | + start = time.time() | ||
129 | + ids, inputs, npts, gt = next(valid_gen) | ||
130 | + feed_dict = {inputs_pl: inputs,my_inputs_pl:my_inputs, npts_pl: npts, gt_pl: gt, is_training_pl: False} | ||
131 | + loss, _ = sess.run([model.loss, model.update], feed_dict=feed_dict) | ||
132 | + total_loss += loss | ||
133 | + total_time += time.time() - start | ||
134 | + summary = sess.run(valid_summary, feed_dict={is_training_pl: False}) | ||
135 | + writer.add_summary(summary, step) | ||
136 | + print(colored('epoch %d step %d loss %.8f - time per batch %.4f' % | ||
137 | + (epoch, step, total_loss / num_eval_steps, total_time / num_eval_steps), | ||
138 | + 'grey', 'on_green')) | ||
139 | + total_time = 0 | ||
140 | + if step % args.steps_per_visu == 0: | ||
141 | + all_pcds = sess.run(model.visualize_ops, feed_dict=feed_dict) | ||
142 | + for i in range(0, args.batch_size, args.visu_freq): | ||
143 | + plot_path = os.path.join(args.log_dir, 'plots', | ||
144 | + 'epoch_%d_step_%d_%s.png' % (epoch, step, ids[i])) | ||
145 | + pcds = [x[i] for x in all_pcds] | ||
146 | + plot_pcd_three_views(plot_path, pcds, model.visualize_titles) | ||
147 | + if step % args.steps_per_save == 0: | ||
148 | + saver.save(sess, os.path.join(args.log_dir, 'model'), step) | ||
149 | + print(colored('Model saved at %s' % args.log_dir, 'white', 'on_blue')) | ||
150 | + | ||
151 | + print('Total time', datetime.timedelta(seconds=time.time() - train_start)) | ||
152 | + sess.close() | ||
153 | + | ||
154 | + | ||
155 | +if __name__ == '__main__': | ||
156 | + parser = argparse.ArgumentParser() | ||
157 | + parser.add_argument('--lmdb_train', default='data/shapenet/train.lmdb') | ||
158 | + parser.add_argument('--lmdb_valid', default='data/shapenet/valid.lmdb') | ||
159 | + parser.add_argument('--log_dir', default='log/pcn_emd') | ||
160 | + parser.add_argument('--model_type', default='pcn_emd') | ||
161 | + parser.add_argument('--restore', action='store_true') | ||
162 | + parser.add_argument('--batch_size', type=int, default=32) | ||
163 | + parser.add_argument('--num_input_points', type=int, default=3000) | ||
164 | + parser.add_argument('--num_gt_points', type=int, default=16384) | ||
165 | + parser.add_argument('--base_lr', type=float, default=0.0001) | ||
166 | + parser.add_argument('--lr_decay', action='store_true') | ||
167 | + parser.add_argument('--lr_decay_steps', type=int, default=50000) | ||
168 | + parser.add_argument('--lr_decay_rate', type=float, default=0.7) | ||
169 | + parser.add_argument('--lr_clip', type=float, default=1e-6) | ||
170 | + parser.add_argument('--max_step', type=int, default=300000) | ||
171 | + parser.add_argument('--steps_per_print', type=int, default=100) | ||
172 | + parser.add_argument('--steps_per_eval', type=int, default=1000) | ||
173 | + parser.add_argument('--steps_per_visu', type=int, default=3000) | ||
174 | + parser.add_argument('--steps_per_save', type=int, default=100000) | ||
175 | + parser.add_argument('--visu_freq', type=int, default=5) | ||
176 | + args = parser.parse_args() | ||
177 | + | ||
178 | + train(args) |
-
Please register or login to post a comment