Toggle navigation
Toggle navigation
This project
Loading...
Sign in
최강혁
/
dddd
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Graphs
Network
Create a new issue
Commits
Issue Boards
Authored by
yunjey
2017-01-22 18:04:34 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
bbb63e103bbd2aa99d1e75fa50a2319d3376193d
bbb63e10
1 parent
249b708c
train and eval the model
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
7 deletions
solver.py
solver.py
View file @
bbb63e1
...
...
@@ -9,9 +9,9 @@ import scipy.misc
class
Solver
(
object
):
def
__init__
(
self
,
model
,
batch_size
=
100
,
pretrain_iter
=
1
0000
,
train_iter
=
2000
,
sample_iter
=
100
,
def
__init__
(
self
,
model
,
batch_size
=
100
,
pretrain_iter
=
2
0000
,
train_iter
=
2000
,
sample_iter
=
100
,
svhn_dir
=
'svhn'
,
mnist_dir
=
'mnist'
,
log_dir
=
'logs'
,
sample_save_path
=
'sample'
,
model_save_path
=
'model'
,
pretrained_model
=
'model/svhn_model-
10000'
,
test_model
=
'model/dtn-20
00'
):
model_save_path
=
'model'
,
pretrained_model
=
'model/svhn_model-
20000'
,
test_model
=
'model/dtn-6
00'
):
self
.
model
=
model
self
.
batch_size
=
batch_size
...
...
@@ -111,7 +111,7 @@ class Solver(object):
model
=
self
.
model
model
.
build_model
()
# make
log
directory if not exists
# make directory if not exists
if
tf
.
gfile
.
Exists
(
self
.
log_dir
):
tf
.
gfile
.
DeleteRecursively
(
self
.
log_dir
)
tf
.
gfile
.
MakeDirs
(
self
.
log_dir
)
...
...
@@ -121,13 +121,16 @@ class Solver(object):
tf
.
global_variables_initializer
()
.
run
()
# restore variables of F
print
(
'loading pretrained model F..'
)
variables_to_restore
=
slim
.
get_model_variables
(
scope
=
'content_extractor'
)
restorer
=
tf
.
train
.
Saver
(
variables_to_restore
)
restorer
.
restore
(
sess
,
self
.
pretrained_model
)
#variables_to_restore = slim.get_model_variables(scope='content_extractor')
#restorer = tf.train.Saver(variables_to_restore)
#restorer.restore(sess, self.pretrained_model)
restorer
=
tf
.
train
.
Saver
()
restorer
.
restore
(
sess
,
'model/dtn-1600'
)
summary_writer
=
tf
.
summary
.
FileWriter
(
logdir
=
self
.
log_dir
,
graph
=
tf
.
get_default_graph
())
saver
=
tf
.
train
.
Saver
()
print
(
'start training..!'
)
f_interval
=
15
for
step
in
range
(
self
.
train_iter
+
1
):
i
=
step
%
int
(
svhn_images
.
shape
[
0
]
/
self
.
batch_size
)
...
...
@@ -143,7 +146,10 @@ class Solver(object):
sess
.
run
([
model
.
g_train_op_src
],
feed_dict
)
sess
.
run
([
model
.
g_train_op_src
],
feed_dict
)
if
i
%
15
==
0
:
if
step
>
1600
:
f_interval
=
30
if
i
%
f_interval
==
0
:
sess
.
run
(
model
.
f_train_op_src
,
feed_dict
)
if
(
step
+
1
)
%
10
==
0
:
...
...
Please
register
or
login
to post a comment