김성주

fix/update for pretrained transfer learning

1 +from __future__ import division, print_function
2 +
3 +import os
4 +import sys
5 +import tensorflow as tf
6 +import numpy as np
7 +
8 +from model import yolov3
9 +from misc_utils import parse_anchors, load_weights
10 +
11 +img_size = 416
12 +weight_path = '../../data/darknet_weights/yolov3.weights'
13 +save_path = '../../data/darknet_weights/yolov3.ckpt'
14 +anchors = parse_anchors('../../data/yolo_anchors.txt')
15 +
16 +model = yolov3(80, anchors)
17 +with tf.Session() as sess:
18 + inputs = tf.placeholder(tf.float32, [1, img_size, img_size, 3])
19 +
20 + with tf.variable_scope('yolov3'):
21 + feature_map = model.forward(inputs)
22 +
23 + saver = tf.train.Saver(var_list=tf.global_variables(scope='yolov3'))
24 +
25 + load_ops = load_weights(tf.global_variables(scope='yolov3'), weight_path)
26 + sess.run(load_ops)
27 + saver.save(sess, save_path=save_path)
28 + print('TensorFlow model checkpoint has been saved to {}'.format(save_path))
...\ No newline at end of file ...\ No newline at end of file
...@@ -97,9 +97,9 @@ saver_to_restore = tf.train.Saver() ...@@ -97,9 +97,9 @@ saver_to_restore = tf.train.Saver()
97 97
98 with tf.Session() as sess: 98 with tf.Session() as sess:
99 sess.run([tf.global_variables_initializer()]) 99 sess.run([tf.global_variables_initializer()])
100 - if os.path.exists(args.restore_path): 100 + try:
101 saver_to_restore.restore(sess, args.restore_path) 101 saver_to_restore.restore(sess, args.restore_path)
102 - else: 102 + except:
103 raise ValueError('there is no model to evaluate. You should move/create the checkpoint file to restore path') 103 raise ValueError('there is no model to evaluate. You should move/create the checkpoint file to restore path')
104 104
105 print('\nStart evaluation...\n') 105 print('\nStart evaluation...\n')
......
...@@ -102,8 +102,12 @@ else: ...@@ -102,8 +102,12 @@ else:
102 with tf.Session() as sess: 102 with tf.Session() as sess:
103 sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) 103 sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
104 104
105 - if os.path.exists(args.restore_path): 105 + try:
106 - saver_to_restore.restore(sess, args.restore_path) 106 + saver_to_restore.restore(sess, restore_path)
107 + print("Restoring parameters...")
108 + except:
109 + print("*** Failed to restore parameters!!! You would need pretrained weights ***")
110 +
107 111
108 print('\nStart training...: Total epoches =', args.total_epoches, '\n') 112 print('\nStart training...: Total epoches =', args.total_epoches, '\n')
109 113
...@@ -184,7 +188,6 @@ with tf.Session() as sess: ...@@ -184,7 +188,6 @@ with tf.Session() as sess:
184 best_mAP = mAP 188 best_mAP = mAP
185 saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( 189 saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(
186 epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) 190 epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))
187 - saver_to_restore.save(sess, restore_path)
188 191
189 ## all epoches end 192 ## all epoches end
190 sess.run(val_init_op) 193 sess.run(val_init_op)
...@@ -227,4 +230,3 @@ with tf.Session() as sess: ...@@ -227,4 +230,3 @@ with tf.Session() as sess:
227 best_mAP = mAP 230 best_mAP = mAP
228 saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( 231 saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(
229 epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) 232 epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))
...\ No newline at end of file ...\ No newline at end of file
230 - saver_to_restore.save(sess, restore_path)
...\ No newline at end of file ...\ No newline at end of file
......
This diff could not be displayed because it is too large.