Showing
2 changed files
with
47 additions
and
46 deletions
... | @@ -10,6 +10,7 @@ from data_utils import get_batch_data | ... | @@ -10,6 +10,7 @@ from data_utils import get_batch_data |
10 | from misc_utils import parse_anchors, read_class_names, AverageMeter | 10 | from misc_utils import parse_anchors, read_class_names, AverageMeter |
11 | from eval_utils import evaluate_on_cpu, evaluate_on_gpu, get_preds_gpu, voc_eval, parse_gt_rec | 11 | from eval_utils import evaluate_on_cpu, evaluate_on_gpu, get_preds_gpu, voc_eval, parse_gt_rec |
12 | from nms_utils import gpu_nms | 12 | from nms_utils import gpu_nms |
13 | +from tfrecord_utils import TFRecordIterator | ||
13 | 14 | ||
14 | from model import yolov3 | 15 | from model import yolov3 |
15 | 16 | ... | ... |
... | @@ -57,7 +57,7 @@ for y in y_true: | ... | @@ -57,7 +57,7 @@ for y in y_true: |
57 | 57 | ||
58 | 58 | ||
59 | ### Model definition | 59 | ### Model definition |
60 | -yolo_model = yolov3(class_num, anchors, use_label_smooth, use_focal_loss, batch_norm_decay, weight_decay, use_static_shape=False) | 60 | +yolo_model = yolov3(args.class_num, args.anchors, args.use_label_smooth, args.use_focal_loss, args.batch_norm_decay, args.weight_decay, use_static_shape=False) |
61 | 61 | ||
62 | with tf.variable_scope('yolov3'): | 62 | with tf.variable_scope('yolov3'): |
63 | pred_feature_maps = yolo_model.forward(image, is_training=is_training) | 63 | pred_feature_maps = yolo_model.forward(image, is_training=is_training) |
... | @@ -67,14 +67,14 @@ y_pred = yolo_model.predict(pred_feature_maps) | ... | @@ -67,14 +67,14 @@ y_pred = yolo_model.predict(pred_feature_maps) |
67 | 67 | ||
68 | l2_loss = tf.losses.get_regularization_loss() | 68 | l2_loss = tf.losses.get_regularization_loss() |
69 | 69 | ||
70 | -saver_to_restore = tf.train.Saver(var_list=tf.contrib.framework.get_variables_to_restore(include=restore_include, exclude=restore_exclude)) | 70 | +saver_to_restore = tf.train.Saver(var_list=tf.contrib.framework.get_variables_to_restore(include=args.restore_include, exclude=args.restore_exclude)) |
71 | -update_vars = tf.contrib.framework.get_variables_to_restore(include=update_part) | 71 | +update_vars = tf.contrib.framework.get_variables_to_restore(include=args.update_part) |
72 | 72 | ||
73 | 73 | ||
74 | global_step = tf.Variable(float(args.global_step), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES]) | 74 | global_step = tf.Variable(float(args.global_step), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES]) |
75 | -if use_warm_up: | 75 | +if args.use_warm_up: |
76 | - learning_rate = tf.cond(tf.less(global_step, train_batch_num * warm_up_epoch), | 76 | + learning_rate = tf.cond(tf.less(global_step, args.train_batch_num * args.warm_up_epoch), |
77 | - lambda: learning_rate_init * global_step / (train_batch_num * warm_up_epoch), | 77 | + lambda: args.learning_rate_init * args.global_step / (args.train_batch_num * args.warm_up_epoch), |
78 | lambda: config_learning_rate(args, global_step - args.train_batch_num * args.warm_up_epoch)) | 78 | lambda: config_learning_rate(args, global_step - args.train_batch_num * args.warm_up_epoch)) |
79 | else: | 79 | else: |
80 | learning_rate = config_learning_rate(args, global_step) | 80 | learning_rate = config_learning_rate(args, global_step) |
... | @@ -190,43 +190,43 @@ with tf.Session() as sess: | ... | @@ -190,43 +190,43 @@ with tf.Session() as sess: |
190 | epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) | 190 | epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) |
191 | 191 | ||
192 | ## all epoches end | 192 | ## all epoches end |
193 | - sess.run(val_init_op) | ||
194 | - | ||
195 | - val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||
196 | - | ||
197 | - val_preds = [] | ||
198 | - | ||
199 | - for j in trange(val_img_cnt): | ||
200 | - __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss], | ||
201 | - feed_dict={is_training: False}) | ||
202 | - pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred) | ||
203 | - val_preds.extend(pred_content) | ||
204 | - val_loss_total.update(__loss[0]) | ||
205 | - val_loss_xy.update(__loss[1]) | ||
206 | - val_loss_wh.update(__loss[2]) | ||
207 | - val_loss_conf.update(__loss[3]) | ||
208 | - val_loss_class.update(__loss[4]) | ||
209 | - | ||
210 | - # calc mAP | ||
211 | - rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter() | ||
212 | - gt_dict = parse_gt_rec(val_file, 'GZIP', img_size, letterbox_resize) | ||
213 | - | ||
214 | - info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\n'.format(epoch, __global_step, __lr) | ||
215 | - | ||
216 | - for ii in range(class_num): | ||
217 | - npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=eval_threshold, use_07_metric=use_voc_07_metric) | ||
218 | - info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\n'.format(ii, rec, prec, ap) | ||
219 | - rec_total.update(rec, npos) | ||
220 | - prec_total.update(prec, nd) | ||
221 | - ap_total.update(ap, 1) | ||
222 | - | ||
223 | - mAP = ap_total.average | ||
224 | - info += 'EVAL: Recall: {:.4f}, Precison: {:.4f}, mAP: {:.4f}\n'.format(rec_total.average, prec_total.average, mAP) | ||
225 | - info += 'EVAL: loss: total: {:.2f}, xy: {:.2f}, wh: {:.2f}, conf: {:.2f}, class: {:.2f}\n'.format( | ||
226 | - val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average) | ||
227 | - print(info) | ||
228 | - | ||
229 | - if save_optimizer and mAP > best_mAP: | ||
230 | - best_mAP = mAP | ||
231 | - saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( | ||
232 | - epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
193 | + sess.run(val_init_op) | ||
194 | + | ||
195 | + val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||
196 | + | ||
197 | + val_preds = [] | ||
198 | + | ||
199 | + for j in trange(args.val_img_cnt): | ||
200 | + __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss], | ||
201 | + feed_dict={is_training: False}) | ||
202 | + pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred) | ||
203 | + val_preds.extend(pred_content) | ||
204 | + val_loss_total.update(__loss[0]) | ||
205 | + val_loss_xy.update(__loss[1]) | ||
206 | + val_loss_wh.update(__loss[2]) | ||
207 | + val_loss_conf.update(__loss[3]) | ||
208 | + val_loss_class.update(__loss[4]) | ||
209 | + | ||
210 | + # calc mAP | ||
211 | + rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter() | ||
212 | + gt_dict = parse_gt_rec(args.val_file, 'GZIP', args.img_size, args.letterbox_resize) | ||
213 | + | ||
214 | + info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\n'.format(epoch, __global_step, __lr) | ||
215 | + | ||
216 | + for ii in range(args.class_num): | ||
217 | + npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=args.eval_threshold, use_07_metric=args.use_voc_07_metric) | ||
218 | + info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\n'.format(ii, rec, prec, ap) | ||
219 | + rec_total.update(rec, npos) | ||
220 | + prec_total.update(prec, nd) | ||
221 | + ap_total.update(ap, 1) | ||
222 | + | ||
223 | + mAP = ap_total.average | ||
224 | + info += 'EVAL: Recall: {:.4f}, Precison: {:.4f}, mAP: {:.4f}\n'.format(rec_total.average, prec_total.average, mAP) | ||
225 | + info += 'EVAL: loss: total: {:.2f}, xy: {:.2f}, wh: {:.2f}, conf: {:.2f}, class: {:.2f}\n'.format( | ||
226 | + val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average) | ||
227 | + print(info) | ||
228 | + | ||
229 | + if args.save_optimizer and mAP > best_mAP: | ||
230 | + best_mAP = mAP | ||
231 | + saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( | ||
232 | + epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) | ||
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
-
Please register or login to post a comment