Toggle navigation
Toggle navigation
This project
Loading...
Sign in
2020-1-capstone-design2
/
2016104124
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
박해연
2020-06-15 23:32:06 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
1fc1ee986c4b784dead941be65ea8709f70d400b
1fc1ee98
1 parent
03a2f9e5
add code
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
178 additions
and
0 deletions
code/pcn_modify/pcn/train.py
code/pcn_modify/pcn/train.py
0 → 100644
View file @
1fc1ee9
# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018
import
argparse
import
datetime
import
importlib
import
models
import
os
import
tensorflow
as
tf
import
time
from
data_util
import
lmdb_dataflow
,
get_queued_data
,
resample_pcd
from
termcolor
import
colored
from
tf_util
import
add_train_summary
from
visu_util
import
plot_pcd_three_views
import
numpy
as
np
def
train
(
args
):
is_training_pl
=
tf
.
placeholder
(
tf
.
bool
,
shape
=
(),
name
=
'is_training'
)
global_step
=
tf
.
Variable
(
0
,
trainable
=
False
,
name
=
'global_step'
)
alpha
=
tf
.
train
.
piecewise_constant
(
global_step
,
[
3000
,
6000
,
15000
],
[
0.01
,
0.1
,
0.5
,
1.0
],
'alpha_op'
)
#beta = tf.train.piecewise_constant(global_step, [6000, 15000, 30000],
# [0.01, 0.1, 0.5, 1.0], 'beta_op')
beta
=
tf
.
constant
(
1.0
)
inputs_pl
=
tf
.
placeholder
(
tf
.
float32
,
(
1
,
None
,
3
),
'inputs'
)
my_inputs_pl
=
tf
.
placeholder
(
tf
.
float32
,(
args
.
batch_size
,
None
,
3
),
'my_inputs'
)
####
npts_pl
=
tf
.
placeholder
(
tf
.
int32
,
(
args
.
batch_size
,),
'num_points'
)
gt_pl
=
tf
.
placeholder
(
tf
.
float32
,
(
args
.
batch_size
,
args
.
num_gt_points
,
3
),
'ground_truths'
)
model_module
=
importlib
.
import_module
(
'.
%
s'
%
args
.
model_type
,
'models'
)
model
=
model_module
.
Model
(
inputs_pl
,
my_inputs_pl
,
npts_pl
,
gt_pl
,
alpha
,
beta
)
add_train_summary
(
'alpha'
,
alpha
)
add_train_summary
(
'beta'
,
beta
)
if
args
.
lr_decay
:
learning_rate
=
tf
.
train
.
exponential_decay
(
args
.
base_lr
,
global_step
,
args
.
lr_decay_steps
,
args
.
lr_decay_rate
,
staircase
=
True
,
name
=
'lr'
)
learning_rate
=
tf
.
maximum
(
learning_rate
,
args
.
lr_clip
)
add_train_summary
(
'learning_rate'
,
learning_rate
)
else
:
learning_rate
=
tf
.
constant
(
args
.
base_lr
,
name
=
'lr'
)
train_summary
=
tf
.
summary
.
merge_all
(
'train_summary'
)
valid_summary
=
tf
.
summary
.
merge_all
(
'valid_summary'
)
trainer
=
tf
.
train
.
AdamOptimizer
(
learning_rate
)
train_op
=
trainer
.
minimize
(
model
.
loss
,
global_step
)
df_train
,
num_train
=
lmdb_dataflow
(
args
.
lmdb_train
,
args
.
batch_size
,
args
.
num_input_points
,
args
.
num_gt_points
,
is_training
=
True
)
train_gen
=
df_train
.
get_data
()
df_valid
,
num_valid
=
lmdb_dataflow
(
args
.
lmdb_valid
,
args
.
batch_size
,
args
.
num_input_points
,
args
.
num_gt_points
,
is_training
=
False
)
valid_gen
=
df_valid
.
get_data
()
config
=
tf
.
ConfigProto
()
config
.
gpu_options
.
allow_growth
=
True
config
.
allow_soft_placement
=
True
sess
=
tf
.
Session
(
config
=
config
)
saver
=
tf
.
train
.
Saver
()
print
(
'#########################################'
)
print
(
args
.
restore
)
if
args
.
restore
:
print
(
'*************************restore******************************'
)
saver
.
restore
(
sess
,
tf
.
train
.
latest_checkpoint
(
args
.
log_dir
))
writer
=
tf
.
summary
.
FileWriter
(
args
.
log_dir
)
else
:
print
(
'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
)
sess
.
run
(
tf
.
global_variables_initializer
())
if
os
.
path
.
exists
(
args
.
log_dir
):
delete_key
=
input
(
colored
(
'
%
s exists. Delete? [y (or enter)/N]'
%
args
.
log_dir
,
'white'
,
'on_red'
))
if
delete_key
==
'y'
or
delete_key
==
""
:
os
.
system
(
'rm -rf
%
s/*'
%
args
.
log_dir
)
os
.
makedirs
(
os
.
path
.
join
(
args
.
log_dir
,
'plots'
))
else
:
os
.
makedirs
(
os
.
path
.
join
(
args
.
log_dir
,
'plots'
))
with
open
(
os
.
path
.
join
(
args
.
log_dir
,
'args.txt'
),
'w'
)
as
log
:
for
arg
in
sorted
(
vars
(
args
)):
log
.
write
(
arg
+
': '
+
str
(
getattr
(
args
,
arg
))
+
'
\n
'
)
# log of arguments
os
.
system
(
'cp models/
%
s.py
%
s'
%
(
args
.
model_type
,
args
.
log_dir
))
# bkp of model def
os
.
system
(
'cp train.py
%
s'
%
args
.
log_dir
)
# bkp of train procedure
writer
=
tf
.
summary
.
FileWriter
(
args
.
log_dir
,
sess
.
graph
)
total_time
=
0
train_start
=
time
.
time
()
init_step
=
sess
.
run
(
global_step
)
for
step
in
range
(
init_step
+
1
,
args
.
max_step
+
1
):
epoch
=
step
*
args
.
batch_size
//
num_train
+
1
ids
,
inputs
,
npts
,
gt
=
next
(
train_gen
)
#split idx arr
split_idx
=
[]
idx
=
0
for
num
in
npts
[:
-
1
]:
idx
+=
num
split_idx
.
append
(
idx
)
#print('split idx')
#print(split_idx)
max_pcd_size
=
np
.
max
(
npts
)
#print(npts)
#print(max_pcd_size)
ea_pcd
=
np
.
split
(
inputs
[
0
],
tuple
(
split_idx
))
inputs_sep
=
np
.
array
([
x
for
x
in
ea_pcd
])
my_inputs
=
np
.
array
([
resample_pcd
(
x
,
max_pcd_size
)
for
x
in
inputs_sep
])
#print(my_inputs.shape)
start
=
time
.
time
()
feed_dict
=
{
inputs_pl
:
inputs
,
my_inputs_pl
:
my_inputs
,
npts_pl
:
npts
,
gt_pl
:
gt
,
is_training_pl
:
True
}
###
_
,
loss
,
summary
=
sess
.
run
([
train_op
,
model
.
loss
,
train_summary
],
feed_dict
=
feed_dict
)
total_time
+=
time
.
time
()
-
start
writer
.
add_summary
(
summary
,
step
)
if
step
%
args
.
steps_per_print
==
0
:
print
(
'epoch
%
d step
%
d loss
%.8
f - time per batch
%.4
f'
%
(
epoch
,
step
,
loss
,
total_time
/
args
.
steps_per_print
))
total_time
=
0
if
step
%
args
.
steps_per_eval
==
0
:
print
(
colored
(
'Testing...'
,
'grey'
,
'on_green'
))
num_eval_steps
=
num_valid
//
args
.
batch_size
total_loss
=
0
total_time
=
0
sess
.
run
(
tf
.
local_variables_initializer
())
for
i
in
range
(
num_eval_steps
):
start
=
time
.
time
()
ids
,
inputs
,
npts
,
gt
=
next
(
valid_gen
)
feed_dict
=
{
inputs_pl
:
inputs
,
my_inputs_pl
:
my_inputs
,
npts_pl
:
npts
,
gt_pl
:
gt
,
is_training_pl
:
False
}
loss
,
_
=
sess
.
run
([
model
.
loss
,
model
.
update
],
feed_dict
=
feed_dict
)
total_loss
+=
loss
total_time
+=
time
.
time
()
-
start
summary
=
sess
.
run
(
valid_summary
,
feed_dict
=
{
is_training_pl
:
False
})
writer
.
add_summary
(
summary
,
step
)
print
(
colored
(
'epoch
%
d step
%
d loss
%.8
f - time per batch
%.4
f'
%
(
epoch
,
step
,
total_loss
/
num_eval_steps
,
total_time
/
num_eval_steps
),
'grey'
,
'on_green'
))
total_time
=
0
if
step
%
args
.
steps_per_visu
==
0
:
all_pcds
=
sess
.
run
(
model
.
visualize_ops
,
feed_dict
=
feed_dict
)
for
i
in
range
(
0
,
args
.
batch_size
,
args
.
visu_freq
):
plot_path
=
os
.
path
.
join
(
args
.
log_dir
,
'plots'
,
'epoch_
%
d_step_
%
d_
%
s.png'
%
(
epoch
,
step
,
ids
[
i
]))
pcds
=
[
x
[
i
]
for
x
in
all_pcds
]
plot_pcd_three_views
(
plot_path
,
pcds
,
model
.
visualize_titles
)
if
step
%
args
.
steps_per_save
==
0
:
saver
.
save
(
sess
,
os
.
path
.
join
(
args
.
log_dir
,
'model'
),
step
)
print
(
colored
(
'Model saved at
%
s'
%
args
.
log_dir
,
'white'
,
'on_blue'
))
print
(
'Total time'
,
datetime
.
timedelta
(
seconds
=
time
.
time
()
-
train_start
))
sess
.
close
()
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--lmdb_train'
,
default
=
'data/shapenet/train.lmdb'
)
parser
.
add_argument
(
'--lmdb_valid'
,
default
=
'data/shapenet/valid.lmdb'
)
parser
.
add_argument
(
'--log_dir'
,
default
=
'log/pcn_emd'
)
parser
.
add_argument
(
'--model_type'
,
default
=
'pcn_emd'
)
parser
.
add_argument
(
'--restore'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
32
)
parser
.
add_argument
(
'--num_input_points'
,
type
=
int
,
default
=
3000
)
parser
.
add_argument
(
'--num_gt_points'
,
type
=
int
,
default
=
16384
)
parser
.
add_argument
(
'--base_lr'
,
type
=
float
,
default
=
0.0001
)
parser
.
add_argument
(
'--lr_decay'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--lr_decay_steps'
,
type
=
int
,
default
=
50000
)
parser
.
add_argument
(
'--lr_decay_rate'
,
type
=
float
,
default
=
0.7
)
parser
.
add_argument
(
'--lr_clip'
,
type
=
float
,
default
=
1e-6
)
parser
.
add_argument
(
'--max_step'
,
type
=
int
,
default
=
300000
)
parser
.
add_argument
(
'--steps_per_print'
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
'--steps_per_eval'
,
type
=
int
,
default
=
1000
)
parser
.
add_argument
(
'--steps_per_visu'
,
type
=
int
,
default
=
3000
)
parser
.
add_argument
(
'--steps_per_save'
,
type
=
int
,
default
=
100000
)
parser
.
add_argument
(
'--visu_freq'
,
type
=
int
,
default
=
5
)
args
=
parser
.
parse_args
()
train
(
args
)
Please
register
or
login
to post a comment