Showing
2 changed files
with
13 additions
and
12 deletions
... | @@ -10,6 +10,7 @@ from data_utils import get_batch_data | ... | @@ -10,6 +10,7 @@ from data_utils import get_batch_data |
10 | from misc_utils import parse_anchors, read_class_names, AverageMeter | 10 | from misc_utils import parse_anchors, read_class_names, AverageMeter |
11 | from eval_utils import evaluate_on_cpu, evaluate_on_gpu, get_preds_gpu, voc_eval, parse_gt_rec | 11 | from eval_utils import evaluate_on_cpu, evaluate_on_gpu, get_preds_gpu, voc_eval, parse_gt_rec |
12 | from nms_utils import gpu_nms | 12 | from nms_utils import gpu_nms |
13 | +from tfrecord_utils import TFRecordIterator | ||
13 | 14 | ||
14 | from model import yolov3 | 15 | from model import yolov3 |
15 | 16 | ... | ... |
... | @@ -57,7 +57,7 @@ for y in y_true: | ... | @@ -57,7 +57,7 @@ for y in y_true: |
57 | 57 | ||
58 | 58 | ||
59 | ### Model definition | 59 | ### Model definition |
60 | -yolo_model = yolov3(class_num, anchors, use_label_smooth, use_focal_loss, batch_norm_decay, weight_decay, use_static_shape=False) | 60 | +yolo_model = yolov3(args.class_num, args.anchors, args.use_label_smooth, args.use_focal_loss, args.batch_norm_decay, args.weight_decay, use_static_shape=False) |
61 | 61 | ||
62 | with tf.variable_scope('yolov3'): | 62 | with tf.variable_scope('yolov3'): |
63 | pred_feature_maps = yolo_model.forward(image, is_training=is_training) | 63 | pred_feature_maps = yolo_model.forward(image, is_training=is_training) |
... | @@ -67,14 +67,14 @@ y_pred = yolo_model.predict(pred_feature_maps) | ... | @@ -67,14 +67,14 @@ y_pred = yolo_model.predict(pred_feature_maps) |
67 | 67 | ||
68 | l2_loss = tf.losses.get_regularization_loss() | 68 | l2_loss = tf.losses.get_regularization_loss() |
69 | 69 | ||
70 | -saver_to_restore = tf.train.Saver(var_list=tf.contrib.framework.get_variables_to_restore(include=restore_include, exclude=restore_exclude)) | 70 | +saver_to_restore = tf.train.Saver(var_list=tf.contrib.framework.get_variables_to_restore(include=args.restore_include, exclude=args.restore_exclude)) |
71 | -update_vars = tf.contrib.framework.get_variables_to_restore(include=update_part) | 71 | +update_vars = tf.contrib.framework.get_variables_to_restore(include=args.update_part) |
72 | 72 | ||
73 | 73 | ||
74 | global_step = tf.Variable(float(args.global_step), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES]) | 74 | global_step = tf.Variable(float(args.global_step), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES]) |
75 | -if use_warm_up: | 75 | +if args.use_warm_up: |
76 | - learning_rate = tf.cond(tf.less(global_step, train_batch_num * warm_up_epoch), | 76 | + learning_rate = tf.cond(tf.less(global_step, args.train_batch_num * args.warm_up_epoch), |
77 | - lambda: learning_rate_init * global_step / (train_batch_num * warm_up_epoch), | 77 | + lambda: args.learning_rate_init * args.global_step / (args.train_batch_num * args.warm_up_epoch), |
78 | lambda: config_learning_rate(args, global_step - args.train_batch_num * args.warm_up_epoch)) | 78 | lambda: config_learning_rate(args, global_step - args.train_batch_num * args.warm_up_epoch)) |
79 | else: | 79 | else: |
80 | learning_rate = config_learning_rate(args, global_step) | 80 | learning_rate = config_learning_rate(args, global_step) |
... | @@ -196,7 +196,7 @@ with tf.Session() as sess: | ... | @@ -196,7 +196,7 @@ with tf.Session() as sess: |
196 | 196 | ||
197 | val_preds = [] | 197 | val_preds = [] |
198 | 198 | ||
199 | - for j in trange(val_img_cnt): | 199 | + for j in trange(args.val_img_cnt): |
200 | __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss], | 200 | __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss], |
201 | feed_dict={is_training: False}) | 201 | feed_dict={is_training: False}) |
202 | pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred) | 202 | pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred) |
... | @@ -209,12 +209,12 @@ with tf.Session() as sess: | ... | @@ -209,12 +209,12 @@ with tf.Session() as sess: |
209 | 209 | ||
210 | # calc mAP | 210 | # calc mAP |
211 | rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter() | 211 | rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter() |
212 | - gt_dict = parse_gt_rec(val_file, 'GZIP', img_size, letterbox_resize) | 212 | + gt_dict = parse_gt_rec(args.val_file, 'GZIP', args.img_size, args.letterbox_resize) |
213 | 213 | ||
214 | info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\n'.format(epoch, __global_step, __lr) | 214 | info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\n'.format(epoch, __global_step, __lr) |
215 | 215 | ||
216 | - for ii in range(class_num): | 216 | + for ii in range(args.class_num): |
217 | - npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=eval_threshold, use_07_metric=use_voc_07_metric) | 217 | + npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=args.eval_threshold, use_07_metric=args.use_voc_07_metric) |
218 | info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\n'.format(ii, rec, prec, ap) | 218 | info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\n'.format(ii, rec, prec, ap) |
219 | rec_total.update(rec, npos) | 219 | rec_total.update(rec, npos) |
220 | prec_total.update(prec, nd) | 220 | prec_total.update(prec, nd) |
... | @@ -226,7 +226,7 @@ with tf.Session() as sess: | ... | @@ -226,7 +226,7 @@ with tf.Session() as sess: |
226 | val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average) | 226 | val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average) |
227 | print(info) | 227 | print(info) |
228 | 228 | ||
229 | - if save_optimizer and mAP > best_mAP: | 229 | + if args.save_optimizer and mAP > best_mAP: |
230 | best_mAP = mAP | 230 | best_mAP = mAP |
231 | - saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( | 231 | + saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( |
232 | epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) | 232 | epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) |
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
-
Please register or login to post a comment