#!/usr/bin/env python
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import os
import sys
import cv2
# Key codes that step backwards (leftkeys) or forwards (rightkeys) through the images.
#   81      / 83      : left / right arrow on linux in Ascii code (probably not needed)
#   65361   / 65363   : left / right arrow on linux
#   2424832 / 2555904 : left / right arrow on Windows
#   110     / 109     : 'n' / 'm' on mac, windows and linux
#                       (unfortunately the arrow keys are not picked up on mac)
leftkeys = (
    81,
    110,
    65361,
    2424832,
)
rightkeys = (
    83,
    109,
    65363,
    2555904,
)
# Allow relative imports when being executed as script.
# When this file is run directly it has no package context, so the relative
# imports further down would fail unless the package is set up by hand.
if __package__ is None and __name__ == "__main__":
    # make the directory two levels above this file importable
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
    import keras_retinanet.bin  # noqa: F401
    # declare the package this script belongs to so `from ..x import y` resolves
    __package__ = "keras_retinanet.bin"
# Change these to absolute imports if you copy this script outside the keras_retinanet package.
from ..preprocessing.pascal_voc import PascalVocGenerator
from ..preprocessing.csv_generator import CSVGenerator
from ..preprocessing.kitti import KittiGenerator
from ..preprocessing.open_images import OpenImagesGenerator
from ..utils.anchors import anchors_for_shape, compute_gt_annotations
from ..utils.config import read_config_file, parse_anchor_parameters, parse_pyramid_levels
from ..utils.image import random_visual_effect_generator
from ..utils.tf_version import check_tf_version
from ..utils.transform import random_transform_generator
from ..utils.visualization import draw_annotations, draw_boxes, draw_caption
def create_generator(args):
    """ Build the data generator selected by `args.dataset_type`.

    Args:
        args: parseargs arguments object. `dataset_type` selects the generator
            class; the remaining attributes supply its constructor arguments.

    Returns:
        The generator instance for the requested dataset type.

    Raises:
        ValueError: if `args.dataset_type` is not one of
            'coco', 'pascal', 'csv', 'oid' or 'kitti'.
    """
    # keyword arguments handed to every generator, whatever the dataset type
    shared_kwargs = dict(
        config=args.config,
        image_min_side=args.image_min_side,
        image_max_side=args.image_max_side,
        group_method=args.group_method,
    )

    # random transform generator (rotation, translation, shear, scaling, flips)
    shared_kwargs['transform_generator'] = random_transform_generator(
        min_rotation=-0.1,
        max_rotation=0.1,
        min_translation=(-0.1, -0.1),
        max_translation=(0.1, 0.1),
        min_shear=-0.1,
        max_shear=0.1,
        min_scaling=(0.9, 0.9),
        max_scaling=(1.1, 1.1),
        flip_x_chance=0.5,
        flip_y_chance=0.5,
    )

    # random visual effect generator (contrast, brightness, hue, saturation)
    shared_kwargs['visual_effect_generator'] = random_visual_effect_generator(
        contrast_range=(0.9, 1.1),
        brightness_range=(-.1, .1),
        hue_range=(-0.05, 0.05),
        saturation_range=(0.95, 1.05)
    )

    dataset_type = args.dataset_type

    if dataset_type == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        from ..preprocessing.coco import CocoGenerator

        return CocoGenerator(args.coco_path, args.coco_set, **shared_kwargs)

    if dataset_type == 'pascal':
        return PascalVocGenerator(
            args.pascal_path,
            args.pascal_set,
            image_extension=args.image_extension,
            **shared_kwargs
        )

    if dataset_type == 'csv':
        return CSVGenerator(args.annotations, args.classes, **shared_kwargs)

    if dataset_type == 'oid':
        return OpenImagesGenerator(
            args.main_dir,
            subset=args.subset,
            version=args.version,
            labels_filter=args.labels_filter,
            parent_label=args.parent_label,
            annotation_cache_dir=args.annotation_cache_dir,
            **shared_kwargs
        )

    if dataset_type == 'kitti':
        return KittiGenerator(args.kitti_path, subset=args.subset, **shared_kwargs)

    raise ValueError('Invalid data type received: {}'.format(args.dataset_type))
def parse_args(args):
    """ Parse the command line arguments.

    Args:
        args: list of argument strings, without the program name.

    Returns:
        The argparse namespace. `dataset_type` holds the chosen sub-command
        ('coco', 'pascal', 'kitti', 'oid' or 'csv'); the general options below
        are registered on the top-level parser.
    """
    def csv_list(string):
        # turn "a,b,c" into ['a', 'b', 'c']
        return string.split(',')

    root = argparse.ArgumentParser(description='Debug script for a RetinaNet network.')

    # one sub-command per dataset type; picking one is mandatory
    datasets = root.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type')
    datasets.required = True

    # --- coco ---
    coco = datasets.add_parser('coco')
    coco.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')
    coco.add_argument('--coco-set', help='Name of the set to show (defaults to val2017).', default='val2017')

    # --- pascal ---
    pascal = datasets.add_parser('pascal')
    pascal.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).')
    pascal.add_argument('--pascal-set', help='Name of the set to show (defaults to test).', default='test')
    pascal.add_argument('--image-extension', help="Declares the dataset images' extension.", default='.jpg')

    # --- kitti ---
    kitti = datasets.add_parser('kitti')
    kitti.add_argument('kitti_path', help='Path to dataset directory (ie. /tmp/kitti).')
    kitti.add_argument('subset', help='Argument for loading a subset from train/val.')

    # --- open images ---
    oid = datasets.add_parser('oid')
    oid.add_argument('main_dir', help='Path to dataset directory.')
    oid.add_argument('subset', help='Argument for loading a subset from train/validation/test.')
    oid.add_argument('--version', help='The current dataset version is v4.', default='v4')
    oid.add_argument('--labels-filter', help='A list of labels to filter.', type=csv_list, default=None)
    oid.add_argument('--annotation-cache-dir', help='Path to store annotation cache.', default='.')
    oid.add_argument('--parent-label', help='Use the hierarchy children of this label.', default=None)

    # --- csv ---
    csv = datasets.add_parser('csv')
    csv.add_argument('annotations', help='Path to CSV file containing annotations for evaluation.')
    csv.add_argument('classes', help='Path to a CSV file containing class label mapping.')

    # --- options shared by all dataset types ---
    root.add_argument('--no-resize', help='Disable image resizing.', dest='resize', action='store_false')
    root.add_argument('--anchors', help='Show positive anchors on the image.', action='store_true')
    root.add_argument('--display-name', help='Display image name on the bottom left corner.', action='store_true')
    root.add_argument(
        '--show-annotations',
        help="Show annotations on the image. Green annotations have anchors, red annotations don't and therefore don't contribute to training.",
        action='store_true',
    )
    root.add_argument('--random-transform', help='Randomly transform image and annotations.', action='store_true')
    root.add_argument('--image-min-side', help='Rescale the image so the smallest side is min_side.', type=int, default=800)
    root.add_argument('--image-max-side', help='Rescale the image if the largest side is larger than max_side.', type=int, default=1333)
    root.add_argument('--config', help='Path to a configuration parameters .ini file.')
    root.add_argument('--no-gui', help='Do not open a GUI window. Save images to an output directory instead.', action='store_true')
    root.add_argument('--output-dir', help='The output directory to save images to if --no-gui is specified.', default='.')
    root.add_argument('--flatten-output', help='Flatten the folder structure of saved output images into a single folder.', action='store_true')
    root.add_argument('--group-method', help='Determines how images are grouped together', type=str, default='ratio', choices=['none', 'random', 'ratio'])

    return root.parse_args(args)
def _draw_debug_overlays(generator, index, image, annotations, args, anchor_params, pyramid_levels):
    """ Transform, resize and decorate a single annotated image as requested by `args`.

    Args
        generator: The generator the image was loaded from.
        index: Index of the image inside the generator.
        image: The loaded image.
        annotations: The loaded annotations (read via the 'bboxes' key here).
        args: parseargs args object.
        anchor_params: Anchor parameters forwarded to anchors_for_shape (may be None).
        pyramid_levels: Pyramid levels forwarded to anchors_for_shape (may be None).

    Returns
        The image to show or save (the transform/resize steps return a new image).
    """
    # apply random transformations
    if args.random_transform:
        image, annotations = generator.random_transform_group_entry(image, annotations)
        image, annotations = generator.random_visual_effect_group_entry(image, annotations)

    # resize the image and scale the boxes by the same factor
    if args.resize:
        image, image_scale = generator.resize_image(image)
        annotations['bboxes'] *= image_scale

    anchors = anchors_for_shape(image.shape, anchor_params=anchor_params, pyramid_levels=pyramid_levels)
    positive_indices, _, max_indices = compute_gt_annotations(anchors, annotations['bboxes'])

    # draw anchors on the image
    if args.anchors:
        draw_boxes(image, anchors[positive_indices], (255, 255, 0), thickness=1)

    # draw annotations on the image
    if args.show_annotations:
        # draw annotations in red
        draw_annotations(image, annotations, color=(0, 0, 255), label_to_name=generator.label_to_name)

        # draw regressed anchors in green to override most red annotations
        # result is that annotations without anchors are red, with anchors are green
        draw_boxes(image, annotations['bboxes'][max_indices[positive_indices], :], (0, 255, 0))

    # display name on the image
    if args.display_name:
        draw_caption(image, [0, image.shape[0]], os.path.basename(generator.image_path(index)))

    return image


def run(generator, args, anchor_params, pyramid_levels):
    """ Main loop.

    With --no-gui every image is written once to the output directory and the
    loop ends after the last one. Otherwise images are shown one at a time and
    the user steps through them with the keyboard.

    Args
        generator: The generator to debug.
        args: parseargs args object.
        anchor_params: Anchor parameters forwarded to anchors_for_shape (may be None).
        pyramid_levels: Pyramid levels forwarded to anchors_for_shape (may be None).

    Returns
        False if the user quit the GUI with 'q' or Esc, True otherwise.
    """
    # an empty dataset has nothing to load; bail out instead of indexing image 0
    if generator.size() == 0:
        return True

    # display images, one at a time
    i = 0
    while True:
        # load the data
        image = generator.load_image(i)
        annotations = generator.load_annotations(i)

        # images without annotations are shown / saved untouched
        if len(annotations['labels']) > 0:
            image = _draw_debug_overlays(generator, i, image, annotations, args, anchor_params, pyramid_levels)

        # write to file and advance if no-gui selected
        if args.no_gui:
            output_path = make_output_path(args.output_dir, generator.image_path(i), flatten=args.flatten_output)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            cv2.imwrite(output_path, image)
            i += 1
            if i == generator.size():  # have written all images
                break
            continue

        # if we are using the GUI, then show an image
        cv2.imshow('Image', image)
        key = cv2.waitKeyEx()

        # press right for next image and left for previous (linux or windows, doesn't work for macOS)
        # if you run macOS, press "n" or "m" (will also work on linux and windows)
        if key in rightkeys:
            i = (i + 1) % generator.size()
        if key in leftkeys:
            i -= 1
            if i < 0:
                i = generator.size() - 1

        # press q or Esc to quit
        if (key == ord('q')) or (key == 27):
            return False

    return True
def make_output_path(output_dir, image_path, flatten=False):
    """ Compute the output path for a debug image.

    Args
        output_dir: Directory the debug images are written to.
        image_path: Path of the source image.
        flatten: If True keep only the file name of `image_path`;
            otherwise keep its folder structure below `output_dir`.

    Returns
        `output_dir` joined with the (relative) image path, with "_debug"
        inserted before the file extension.
    """
    if not flatten:
        # Drop drive letters on Windows, otherwise relpath will fail.
        relative = os.path.splitdrive(image_path)[1]
        # Absolute paths are taken relative to the filesystem root.
        if os.path.isabs(relative):
            relative = os.path.relpath(relative, '/')
    else:
        # Flattened output: throw away all leading folders.
        relative = os.path.basename(image_path)

    # In all cases, append "_debug" to the filename, before the extension.
    stem, suffix = os.path.splitext(relative)

    # Finally, join the whole thing to the output directory.
    return os.path.join(output_dir, stem + "_debug" + suffix)
def main(args=None):
    """ Entry point of the debug script.

    Args
        args: Optional list of command line argument strings.
            Defaults to `sys.argv[1:]` when None.
    """
    # parse arguments
    if args is None:
        args = sys.argv[1:]
    args = parse_args(args)

    # make sure tensorflow is the minimum required version
    check_tf_version()

    # optionally load config parameters
    # NOTE: this happens before the generator is created, so that the generator
    # receives the parsed configuration (the same object the anchor / pyramid
    # parsing below works on) instead of the raw path string from the command line.
    if args.config:
        args.config = read_config_file(args.config)

    # create the generator
    generator = create_generator(args)

    # optionally load anchor parameters
    anchor_params = None
    if args.config and 'anchor_parameters' in args.config:
        anchor_params = parse_anchor_parameters(args.config)

    # optionally load pyramid levels
    pyramid_levels = None
    if args.config and 'pyramid_levels' in args.config:
        pyramid_levels = parse_pyramid_levels(args.config)

    # create the display window if necessary
    if not args.no_gui:
        cv2.namedWindow('Image', cv2.WINDOW_NORMAL)

    run(generator, args, anchor_params=anchor_params, pyramid_levels=pyramid_levels)
# Run the debug tool only when this file is executed as a script.
if __name__ == "__main__":
    main()