Showing
4 changed files
with
36 additions
and
6 deletions
code/yolov3/convert_weights.py
0 → 100644
1 | +from __future__ import division, print_function | ||
2 | + | ||
3 | +import os | ||
4 | +import sys | ||
5 | +import tensorflow as tf | ||
6 | +import numpy as np | ||
7 | + | ||
8 | +from model import yolov3 | ||
9 | +from misc_utils import parse_anchors, load_weights | ||
10 | + | ||
11 | +img_size = 416 | ||
12 | +weight_path = '../../data/darknet_weights/yolov3.weights' | ||
13 | +save_path = '../../data/darknet_weights/yolov3.ckpt' | ||
14 | +anchors = parse_anchors('../../data/yolo_anchors.txt') | ||
15 | + | ||
16 | +model = yolov3(80, anchors) | ||
17 | +with tf.Session() as sess: | ||
18 | + inputs = tf.placeholder(tf.float32, [1, img_size, img_size, 3]) | ||
19 | + | ||
20 | + with tf.variable_scope('yolov3'): | ||
21 | + feature_map = model.forward(inputs) | ||
22 | + | ||
23 | + saver = tf.train.Saver(var_list=tf.global_variables(scope='yolov3')) | ||
24 | + | ||
25 | + load_ops = load_weights(tf.global_variables(scope='yolov3'), weight_path) | ||
26 | + sess.run(load_ops) | ||
27 | + saver.save(sess, save_path=save_path) | ||
28 | + print('TensorFlow model checkpoint has been saved to {}'.format(save_path)) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
... | @@ -97,9 +97,9 @@ saver_to_restore = tf.train.Saver() | ... | @@ -97,9 +97,9 @@ saver_to_restore = tf.train.Saver() |
97 | 97 | ||
98 | with tf.Session() as sess: | 98 | with tf.Session() as sess: |
99 | sess.run([tf.global_variables_initializer()]) | 99 | sess.run([tf.global_variables_initializer()]) |
100 | - if os.path.exists(args.restore_path): | 100 | + try: |
101 | saver_to_restore.restore(sess, args.restore_path) | 101 | saver_to_restore.restore(sess, args.restore_path) |
102 | - else: | 102 | + except: |
103 | raise ValueError('there is no model to evaluate. You should move/create the checkpoint file to restore path') | 103 | raise ValueError('there is no model to evaluate. You should move/create the checkpoint file to restore path') |
104 | 104 | ||
105 | print('\nStart evaluation...\n') | 105 | print('\nStart evaluation...\n') | ... | ... |
... | @@ -102,8 +102,12 @@ else: | ... | @@ -102,8 +102,12 @@ else: |
102 | with tf.Session() as sess: | 102 | with tf.Session() as sess: |
103 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) | 103 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) |
104 | 104 | ||
105 | - if os.path.exists(args.restore_path): | 105 | + try: |
106 | - saver_to_restore.restore(sess, args.restore_path) | 106 | + saver_to_restore.restore(sess, restore_path) |
107 | + print("Restoring parameters...") | ||
108 | + except: | ||
109 | + print("*** Failed to restore parameters!!! You would need pretrained weights ***") | ||
110 | + | ||
107 | 111 | ||
108 | print('\nStart training...: Total epoches =', args.total_epoches, '\n') | 112 | print('\nStart training...: Total epoches =', args.total_epoches, '\n') |
109 | 113 | ||
... | @@ -184,7 +188,6 @@ with tf.Session() as sess: | ... | @@ -184,7 +188,6 @@ with tf.Session() as sess: |
184 | best_mAP = mAP | 188 | best_mAP = mAP |
185 | saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( | 189 | saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( |
186 | epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) | 190 | epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) |
187 | - saver_to_restore.save(sess, restore_path) | ||
188 | 191 | ||
189 | ## all epoches end | 192 | ## all epoches end |
190 | sess.run(val_init_op) | 193 | sess.run(val_init_op) |
... | @@ -227,4 +230,3 @@ with tf.Session() as sess: | ... | @@ -227,4 +230,3 @@ with tf.Session() as sess: |
227 | best_mAP = mAP | 230 | best_mAP = mAP |
228 | saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( | 231 | saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( |
229 | epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) | 232 | epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) |
... | \ No newline at end of file | ... | \ No newline at end of file |
230 | - saver_to_restore.save(sess, restore_path) | ||
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
This diff could not be displayed because it is too large.
-
Please register or login to post a comment