김연수

retinaNet

Showing 86 changed files with 10409 additions and 0 deletions
1 +#see https://github.com/codecov/support/wiki/Codecov-Yaml
2 +codecov:
3 + notify:
4 + require_ci_to_pass: yes
5 +
6 +coverage:
7 + precision: 0 # 2 = xx.xx%, 0 = xx%
8 + round: nearest # how coverage is rounded: down/up/nearest
9 + range: 40...100 # custom range of coverage colors from red -> yellow -> green
10 + status:
11 + # https://codecov.readme.io/v1.0/docs/commit-status
12 + project:
13 + default:
14 + against: auto
15 + target: 90% # specify the target coverage for each commit status
16 + threshold: 20% # allow this little decrease on project
17 + # https://github.com/codecov/support/wiki/Filtering-Branches
18 + # branches: master
19 + if_ci_failed: error
20 + # https://github.com/codecov/support/wiki/Patch-Status
21 + patch:
22 + default:
23 + against: auto
24 + target: 40% # specify the target "X%" coverage to hit
25 + # threshold: 50% # allow this much decrease on patch
26 + changes: false
27 +
28 +parsers:
29 + gcov:
30 + branch_detection:
31 + conditional: true
32 + loop: true
33 + macro: false
34 + method: false
35 + javascript:
36 + enable_partials: false
37 +
38 +comment:
39 + layout: header, diff
40 + require_changes: false
41 + behavior: default # update if exists else create new
42 + # branches: *
\ No newline at end of file
1 +# Byte-compiled / optimized / DLL files
2 +__pycache__/
3 +*.py[cod]
4 +*$py.class
5 +
6 +# Distribution / packaging
7 +.Python
8 +/build/
9 +/dist/
10 +/eggs/
11 +/*-eggs/
12 +.eggs/
13 +/sdist/
14 +/wheels/
15 +/*.egg-info/
16 +.installed.cfg
17 +*.egg
18 +
19 +# Unit test / coverage reports
20 +.coverage
21 +.coverage.*
22 +coverage.xml
23 +*.cover
\ No newline at end of file
1 +[submodule "tests/test-data"]
2 + path = tests/test-data
3 + url = https://github.com/fizyr/keras-retinanet-test-data.git
1 +language: python
2 +
3 +sudo: required
4 +
5 +python:
6 + - '3.6'
7 + - '3.7'
8 +
9 +install:
10 + - pip install -r requirements.txt
11 + - pip install -r tests/requirements.txt
12 +
13 +cache: pip
14 +
15 +script:
16 + - python setup.py check -m -s
17 + - python setup.py build_ext --inplace
18 + - coverage run --source keras_retinanet -m py.test keras_retinanet tests --doctest-modules --forked --flake8
19 +
20 +after_success:
21 + - coverage xml
22 + - coverage report
23 + - codecov
1 +# Contributors
2 +
3 +This is a list of people who contributed patches to keras-retinanet.
4 +
5 +If you feel you should be listed here or if you have any other questions/comments on your listing here,
6 +please create an issue or pull request at https://github.com/fizyr/keras-retinanet/
7 +
8 +* Hans Gaiser <h.gaiser@fizyr.com>
9 +* Maarten de Vries <maarten@de-vri.es>
10 +* Valerio Carpani
11 +* Ashley Williamson
12 +* Yann Henon
13 +* Valeriu Lacatusu
14 +* András Vidosits
15 +* Cristian Gratie
16 +* jjiunlin
17 +* Sorin Panduru
18 +* Rodrigo Meira de Andrade
19 +* Enrico Liscio <e.liscio@fizyr.com>
20 +* Mihai Morariu
21 +* pedroconceicao
22 +* jjiun
23 +* Wudi Fang
24 +* Mike Clark
25 +* hannesedvartsen
26 +* Max Van Sande
27 +* Pierre Dérian
28 +* ori
29 +* mxvs
30 +* mwilder
31 +* Muhammed Kocabas
32 +* Koen Vijverberg
33 +* iver56
34 +* hnsywangxin
35 +* Guillaume Erhard
36 +* Eduardo Ramos
37 +* DiegoAgher
38 +* Alexander Pacha
39 +* Agastya Kalra
40 +* Jiri BOROVEC
41 +* ntsagko
42 +* charlie / tianqi
43 +* jsemric
44 +* Martin Zlocha
45 +* Raghav Bhardwaj
46 +* bw4sz
47 +* Morten Back Nielsen
48 +* dshahrokhian
49 +* Alex / adreo00
50 +* simone.merello
51 +* Matt Wilder
52 +* Jinwoo Baek
53 +* Etienne Meunier
54 +* Denis Dowling
55 +* cclauss
56 +* Andrew Grigorev
57 +* ZFTurbo
58 +* UgoLouche
59 +* Richard Higgins
60 +* Rajat / rajat.goel
61 +* philipp.marquardt
62 +* peacherwu
63 +* Paul / pauldesigaud
64 +* Martin Genet
65 +* Leo / leonardvandriel
66 +* Laurens Hagendoorn
67 +* Julius / juliussimonelli
68 +* HolyGuacamole
69 +* Fausto Morales
70 +* borakrc
71 +* Ben Weinstein
72 +* Anil Karaka
73 +* Andrea Panizza
74 +* Bruno Santos
\ No newline at end of file
1 + Apache License
2 + Version 2.0, January 2004
3 + http://www.apache.org/licenses/
4 +
5 + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 +
7 + 1. Definitions.
8 +
9 + "License" shall mean the terms and conditions for use, reproduction,
10 + and distribution as defined by Sections 1 through 9 of this document.
11 +
12 + "Licensor" shall mean the copyright owner or entity authorized by
13 + the copyright owner that is granting the License.
14 +
15 + "Legal Entity" shall mean the union of the acting entity and all
16 + other entities that control, are controlled by, or are under common
17 + control with that entity. For the purposes of this definition,
18 + "control" means (i) the power, direct or indirect, to cause the
19 + direction or management of such entity, whether by contract or
20 + otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 + outstanding shares, or (iii) beneficial ownership of such entity.
22 +
23 + "You" (or "Your") shall mean an individual or Legal Entity
24 + exercising permissions granted by this License.
25 +
26 + "Source" form shall mean the preferred form for making modifications,
27 + including but not limited to software source code, documentation
28 + source, and configuration files.
29 +
30 + "Object" form shall mean any form resulting from mechanical
31 + transformation or translation of a Source form, including but
32 + not limited to compiled object code, generated documentation,
33 + and conversions to other media types.
34 +
35 + "Work" shall mean the work of authorship, whether in Source or
36 + Object form, made available under the License, as indicated by a
37 + copyright notice that is included in or attached to the work
38 + (an example is provided in the Appendix below).
39 +
40 + "Derivative Works" shall mean any work, whether in Source or Object
41 + form, that is based on (or derived from) the Work and for which the
42 + editorial revisions, annotations, elaborations, or other modifications
43 + represent, as a whole, an original work of authorship. For the purposes
44 + of this License, Derivative Works shall not include works that remain
45 + separable from, or merely link (or bind by name) to the interfaces of,
46 + the Work and Derivative Works thereof.
47 +
48 + "Contribution" shall mean any work of authorship, including
49 + the original version of the Work and any modifications or additions
50 + to that Work or Derivative Works thereof, that is intentionally
51 + submitted to Licensor for inclusion in the Work by the copyright owner
52 + or by an individual or Legal Entity authorized to submit on behalf of
53 + the copyright owner. For the purposes of this definition, "submitted"
54 + means any form of electronic, verbal, or written communication sent
55 + to the Licensor or its representatives, including but not limited to
56 + communication on electronic mailing lists, source code control systems,
57 + and issue tracking systems that are managed by, or on behalf of, the
58 + Licensor for the purpose of discussing and improving the Work, but
59 + excluding communication that is conspicuously marked or otherwise
60 + designated in writing by the copyright owner as "Not a Contribution."
61 +
62 + "Contributor" shall mean Licensor and any individual or Legal Entity
63 + on behalf of whom a Contribution has been received by Licensor and
64 + subsequently incorporated within the Work.
65 +
66 + 2. Grant of Copyright License. Subject to the terms and conditions of
67 + this License, each Contributor hereby grants to You a perpetual,
68 + worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 + copyright license to reproduce, prepare Derivative Works of,
70 + publicly display, publicly perform, sublicense, and distribute the
71 + Work and such Derivative Works in Source or Object form.
72 +
73 + 3. Grant of Patent License. Subject to the terms and conditions of
74 + this License, each Contributor hereby grants to You a perpetual,
75 + worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 + (except as stated in this section) patent license to make, have made,
77 + use, offer to sell, sell, import, and otherwise transfer the Work,
78 + where such license applies only to those patent claims licensable
79 + by such Contributor that are necessarily infringed by their
80 + Contribution(s) alone or by combination of their Contribution(s)
81 + with the Work to which such Contribution(s) was submitted. If You
82 + institute patent litigation against any entity (including a
83 + cross-claim or counterclaim in a lawsuit) alleging that the Work
84 + or a Contribution incorporated within the Work constitutes direct
85 + or contributory patent infringement, then any patent licenses
86 + granted to You under this License for that Work shall terminate
87 + as of the date such litigation is filed.
88 +
89 + 4. Redistribution. You may reproduce and distribute copies of the
90 + Work or Derivative Works thereof in any medium, with or without
91 + modifications, and in Source or Object form, provided that You
92 + meet the following conditions:
93 +
94 + (a) You must give any other recipients of the Work or
95 + Derivative Works a copy of this License; and
96 +
97 + (b) You must cause any modified files to carry prominent notices
98 + stating that You changed the files; and
99 +
100 + (c) You must retain, in the Source form of any Derivative Works
101 + that You distribute, all copyright, patent, trademark, and
102 + attribution notices from the Source form of the Work,
103 + excluding those notices that do not pertain to any part of
104 + the Derivative Works; and
105 +
106 + (d) If the Work includes a "NOTICE" text file as part of its
107 + distribution, then any Derivative Works that You distribute must
108 + include a readable copy of the attribution notices contained
109 + within such NOTICE file, excluding those notices that do not
110 + pertain to any part of the Derivative Works, in at least one
111 + of the following places: within a NOTICE text file distributed
112 + as part of the Derivative Works; within the Source form or
113 + documentation, if provided along with the Derivative Works; or,
114 + within a display generated by the Derivative Works, if and
115 + wherever such third-party notices normally appear. The contents
116 + of the NOTICE file are for informational purposes only and
117 + do not modify the License. You may add Your own attribution
118 + notices within Derivative Works that You distribute, alongside
119 + or as an addendum to the NOTICE text from the Work, provided
120 + that such additional attribution notices cannot be construed
121 + as modifying the License.
122 +
123 + You may add Your own copyright statement to Your modifications and
124 + may provide additional or different license terms and conditions
125 + for use, reproduction, or distribution of Your modifications, or
126 + for any such Derivative Works as a whole, provided Your use,
127 + reproduction, and distribution of the Work otherwise complies with
128 + the conditions stated in this License.
129 +
130 + 5. Submission of Contributions. Unless You explicitly state otherwise,
131 + any Contribution intentionally submitted for inclusion in the Work
132 + by You to the Licensor shall be under the terms and conditions of
133 + this License, without any additional terms or conditions.
134 + Notwithstanding the above, nothing herein shall supersede or modify
135 + the terms of any separate license agreement you may have executed
136 + with Licensor regarding such Contributions.
137 +
138 + 6. Trademarks. This License does not grant permission to use the trade
139 + names, trademarks, service marks, or product names of the Licensor,
140 + except as required for reasonable and customary use in describing the
141 + origin of the Work and reproducing the content of the NOTICE file.
142 +
143 + 7. Disclaimer of Warranty. Unless required by applicable law or
144 + agreed to in writing, Licensor provides the Work (and each
145 + Contributor provides its Contributions) on an "AS IS" BASIS,
146 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 + implied, including, without limitation, any warranties or conditions
148 + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 + PARTICULAR PURPOSE. You are solely responsible for determining the
150 + appropriateness of using or redistributing the Work and assume any
151 + risks associated with Your exercise of permissions under this License.
152 +
153 + 8. Limitation of Liability. In no event and under no legal theory,
154 + whether in tort (including negligence), contract, or otherwise,
155 + unless required by applicable law (such as deliberate and grossly
156 + negligent acts) or agreed to in writing, shall any Contributor be
157 + liable to You for damages, including any direct, indirect, special,
158 + incidental, or consequential damages of any character arising as a
159 + result of this License or out of the use or inability to use the
160 + Work (including but not limited to damages for loss of goodwill,
161 + work stoppage, computer failure or malfunction, or any and all
162 + other commercial damages or losses), even if such Contributor
163 + has been advised of the possibility of such damages.
164 +
165 + 9. Accepting Warranty or Additional Liability. While redistributing
166 + the Work or Derivative Works thereof, You may choose to offer,
167 + and charge a fee for, acceptance of support, warranty, indemnity,
168 + or other liability obligations and/or rights consistent with this
169 + License. However, in accepting such obligations, You may act only
170 + on Your own behalf and on Your sole responsibility, not on behalf
171 + of any other Contributor, and only if You agree to indemnify,
172 + defend, and hold each Contributor harmless for any liability
173 + incurred by, or claims asserted against, such Contributor by reason
174 + of your accepting any such warranty or additional liability.
175 +
176 + END OF TERMS AND CONDITIONS
177 +
178 + APPENDIX: How to apply the Apache License to your work.
179 +
180 + To apply the Apache License to your work, attach the following
181 + boilerplate notice, with the fields enclosed by brackets "{}"
182 + replaced with your own identifying information. (Don't include
183 + the brackets!) The text should be enclosed in the appropriate
184 + comment syntax for the file format. We also recommend that a
185 + file or class name and description of purpose be included on the
186 + same "printed page" as the copyright notice for easier
187 + identification within third-party archives.
188 +
189 + Copyright {yyyy} {name of copyright owner}
190 +
191 + Licensed under the Apache License, Version 2.0 (the "License");
192 + you may not use this file except in compliance with the License.
193 + You may obtain a copy of the License at
194 +
195 + http://www.apache.org/licenses/LICENSE-2.0
196 +
197 + Unless required by applicable law or agreed to in writing, software
198 + distributed under the License is distributed on an "AS IS" BASIS,
199 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 + See the License for the specific language governing permissions and
201 + limitations under the License.
1 +# Keras RetinaNet [![Build Status](https://travis-ci.org/fizyr/keras-retinanet.svg?branch=master)](https://travis-ci.org/fizyr/keras-retinanet) [![DOI](https://zenodo.org/badge/100249425.svg)](https://zenodo.org/badge/latestdoi/100249425)
2 +
3 +Keras implementation of RetinaNet object detection as described in [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
4 +by Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He and Piotr Dollár.
5 +
6 +## :warning: Deprecated
7 +
8 +This repository is deprecated in favor of the [torchvision](https://github.com/pytorch/vision/) module.
9 +This project should work with Keras 2.4 and TensorFlow 2.3.0; newer versions might break support.
10 +For more information, check [here](https://github.com/fizyr/keras-retinanet/issues/1471#issuecomment-704187205).
11 +
12 +## Installation
13 +
14 +1) Clone this repository.
15 +2) In the repository, execute `pip install . --user`.
16 +   Note that due to inconsistencies in how `tensorflow` should be installed,
17 +   this package does not declare a dependency on `tensorflow`, since doing so would try to install it automatically (which, at least on Arch Linux, results in an incorrect installation).
18 +   Please make sure `tensorflow` is installed as per your system's requirements.
19 +3) Alternatively, you can run the code directly from the cloned repository; however, you need to run `python setup.py build_ext --inplace` to compile the Cython code first.
20 +4) Optionally, install `pycocotools` if you want to train / test on the MS COCO dataset by running `pip install --user git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI`.
21 +
22 +## Testing
23 +An example of testing the network can be seen in [this Notebook](https://github.com/delftrobotics/keras-retinanet/blob/master/examples/ResNet50RetinaNet.ipynb).
24 +In general, inference of the network works as follows:
25 +```python
26 +boxes, scores, labels = model.predict_on_batch(inputs)
27 +```
28 +
29 +Here `boxes` is shaped `(None, None, 4)` (for `(x1, y1, x2, y2)`), `scores` is shaped `(None, None)` (classification score) and `labels` is shaped `(None, None)` (label corresponding to the score). In all three outputs, the first dimension represents the batch and the second dimension indexes the detections.
30 +
31 +Loading models can be done in the following manner:
32 +```python
33 +from keras_retinanet.models import load_model
34 +model = load_model('/path/to/model.h5', backbone_name='resnet50')
35 +```
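Putting the above together, a minimal inference sketch (paths are placeholders; the preprocessing helpers are the same ones used in the example script included later in this repository):

```python
import numpy as np

from keras_retinanet.models import load_model
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image

# load an inference (converted) model; path and backbone name are placeholders
model = load_model('/path/to/model.h5', backbone_name='resnet50')

# load and preprocess a single image
image = read_image_bgr('/path/to/image.jpg')
image = preprocess_image(image)
image, scale = resize_image(image)

# run the network on a batch of one image and map the boxes back to the original resolution
boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
boxes /= scale

# detections are sorted by score, so we can stop at the first low-scoring one
for box, score, label in zip(boxes[0], scores[0], labels[0]):
    if score < 0.5:
        break
    print(label, score, box)
```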
36 +
37 +Execution time on an NVIDIA Pascal Titan X is roughly 75 ms for an image of shape `1000x800x3`.
38 +
39 +### Converting a training model to inference model
40 +The training procedure of `keras-retinanet` works with *training models*. These are stripped-down versions of the *inference model* and contain only the layers necessary for training (regression and classification outputs). If you wish to do inference with a model (perform object detection on an image), you need to convert the trained model to an inference model. This is done as follows:
41 +
42 +```shell
43 +# Running directly from the repository:
44 +keras_retinanet/bin/convert_model.py /path/to/training/model.h5 /path/to/save/inference/model.h5
45 +
46 +# Using the installed script:
47 +retinanet-convert-model /path/to/training/model.h5 /path/to/save/inference/model.h5
48 +```
49 +
50 +Most scripts (like `retinanet-evaluate`) also support converting on the fly, using the `--convert-model` argument.
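Conversion can also be done from Python, mirroring what `convert_model.py` does (and the commented hint in the example script); a minimal sketch with placeholder paths:

```python
from keras_retinanet import models

# load the training model and convert it to an inference model in memory
model = models.load_model('/path/to/training/model.h5', backbone_name='resnet50')
model = models.convert_model(model)

# save the inference model for later use
model.save('/path/to/save/inference/model.h5')
```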
51 +
52 +
53 +## Training
54 +`keras-retinanet` can be trained using [this](https://github.com/fizyr/keras-retinanet/blob/master/keras_retinanet/bin/train.py) script.
55 +Note that the train script uses relative imports since it is inside the `keras_retinanet` package.
56 +If you want to adjust the script for your own use outside of this repository,
57 +you will need to switch it to use absolute imports.
58 +
59 +If you installed `keras-retinanet` correctly, the train script will be installed as `retinanet-train`.
60 +However, if you make local modifications to the `keras-retinanet` repository, you should run the script directly from the repository.
61 +That will ensure that your local changes will be used by the train script.
62 +
63 +The default backbone is `resnet50`. You can change this using the `--backbone=xxx` argument when running the training script.
64 +`xxx` can be one of the resnet backbones (`resnet50`, `resnet101`, `resnet152`), mobilenet backbones (`mobilenet128_1.0`, `mobilenet128_0.75`, `mobilenet160_1.0`, etc.), densenet backbones or vgg backbones. The available options are defined by each model in its corresponding Python script (`resnet.py`, `mobilenet.py`, etc.).
65 +
66 +Trained models can't be used directly for inference. To convert a trained model to an inference model, check [here](https://github.com/fizyr/keras-retinanet#converting-a-training-model-to-inference-model).
67 +
68 +### Usage
69 +For training on [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/), run:
70 +```shell
71 +# Running directly from the repository:
72 +keras_retinanet/bin/train.py pascal /path/to/VOCdevkit/VOC2007
73 +
74 +# Using the installed script:
75 +retinanet-train pascal /path/to/VOCdevkit/VOC2007
76 +```
77 +
78 +For training on [MS COCO](http://cocodataset.org/#home), run:
79 +```shell
80 +# Running directly from the repository:
81 +keras_retinanet/bin/train.py coco /path/to/MS/COCO
82 +
83 +# Using the installed script:
84 +retinanet-train coco /path/to/MS/COCO
85 +```
86 +
87 +For training on the Open Images Dataset [OID](https://storage.googleapis.com/openimages/web/index.html)
88 +or taking part in the [OID challenges](https://storage.googleapis.com/openimages/web/challenge.html), run:
89 +```shell
90 +# Running directly from the repository:
91 +keras_retinanet/bin/train.py oid /path/to/OID
92 +
93 +# Using the installed script:
94 +retinanet-train oid /path/to/OID
95 +
96 +# You can also specify a list of labels if you want to train on a subset
97 +# by adding the argument 'labels_filter':
98 +keras_retinanet/bin/train.py oid /path/to/OID --labels-filter=Helmet,Tree
99 +
100 +# You can also specify a parent label if you want to train on a branch
101 +# from the semantic hierarchical tree (i.e. a parent and all its children)
102 +# (https://storage.googleapis.com/openimages/challenge_2018/bbox_labels_500_hierarchy_visualizer/circle.html)
103 +# by adding the argument 'parent-label':
104 +keras_retinanet/bin/train.py oid /path/to/OID --parent-label=Boat
105 +```
106 +
107 +
108 +For training on [KITTI](http://www.cvlibs.net/datasets/kitti/eval_object.php), run:
109 +```shell
110 +# Running directly from the repository:
111 +keras_retinanet/bin/train.py kitti /path/to/KITTI
112 +
113 +# Using the installed script:
114 +retinanet-train kitti /path/to/KITTI
115 +```
116 +
117 +If you want to prepare the dataset you can use the following script:
118 +https://github.com/NVIDIA/DIGITS/blob/master/examples/object-detection/prepare_kitti_data.py
119 +
120 +
121 +For training on a custom dataset, a CSV file can be used as a way to pass the data.
122 +See below for more details on the format of these CSV files.
123 +To train using your CSV, run:
124 +```shell
125 +# Running directly from the repository:
126 +keras_retinanet/bin/train.py csv /path/to/csv/file/containing/annotations /path/to/csv/file/containing/classes
127 +
128 +# Using the installed script:
129 +retinanet-train csv /path/to/csv/file/containing/annotations /path/to/csv/file/containing/classes
130 +```
131 +
132 +In general, the steps to train on your own datasets are:
133 +1) Create a model by calling for instance `keras_retinanet.models.backbone('resnet50').retinanet(num_classes=80)` and compile it.
134 + Empirically, the following compile arguments have been found to work well:
135 +```python
136 +model.compile(
137 + loss={
138 + 'regression' : keras_retinanet.losses.smooth_l1(),
139 + 'classification': keras_retinanet.losses.focal()
140 + },
141 + optimizer=keras.optimizers.Adam(lr=1e-5, clipnorm=0.001)
142 +)
143 +```
144 +2) Create generators for training and testing data (an example is shown in [`keras_retinanet.preprocessing.pascal_voc.PascalVocGenerator`](https://github.com/fizyr/keras-retinanet/blob/master/keras_retinanet/preprocessing/pascal_voc.py)).
145 +3) Use `model.fit_generator` to start training (a combined sketch of these steps is shown below).
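Putting these three steps together, a rough sketch for a Pascal VOC dataset (the generator arguments follow the signature used in `debug.py`; the set name `'trainval'`, the class count and the step count are assumptions you should adjust to your data):

```python
from tensorflow import keras

from keras_retinanet import losses, models
from keras_retinanet.preprocessing.pascal_voc import PascalVocGenerator

# 1) create and compile the model (20 classes is an assumption for Pascal VOC)
model = models.backbone('resnet50').retinanet(num_classes=20)
model.compile(
    loss={
        'regression'    : losses.smooth_l1(),
        'classification': losses.focal()
    },
    optimizer=keras.optimizers.Adam(lr=1e-5, clipnorm=0.001)
)

# 2) create a generator for the training data ('trainval' is an assumed set name)
train_generator = PascalVocGenerator('/path/to/VOCdevkit/VOC2007', 'trainval')

# 3) start training (steps_per_epoch is a placeholder; pick it based on your dataset size)
model.fit_generator(train_generator, steps_per_epoch=10000, epochs=50)
```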
146 +
147 +## Pretrained models
148 +
149 +All models can be downloaded from the [releases page](https://github.com/fizyr/keras-retinanet/releases).
150 +
151 +### MS COCO
152 +
153 +Results using the `cocoapi` are shown below (note: according to the paper, this configuration should achieve a mAP of 0.357).
154 +
155 +```
156 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.350
157 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.537
158 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.374
159 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.191
160 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.383
161 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.472
162 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.306
163 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.491
164 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.533
165 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.345
166 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.577
167 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.681
168 +```
169 +
170 +### Open Images Dataset
171 +There are 3 RetinaNet models based on ResNet50, ResNet101 and ResNet152 trained on all [500 classes](https://github.com/ZFTurbo/Keras-RetinaNet-for-Open-Images-Challenge-2018/blob/master/a00_utils_and_constants.py#L130) of the Open Images Dataset (thanks to @ZFTurbo).
172 +
173 +| Backbone | Image Size (px) | Small validation mAP | LB (Public) |
174 +| --------- | --------------- | -------------------- | ----------- |
175 +| ResNet50 | 768 - 1024 | 0.4594 | 0.4223 |
176 +| ResNet101 | 768 - 1024 | 0.4986 | 0.4520 |
177 +| ResNet152 | 600 - 800 | 0.4991 | 0.4651 |
178 +
179 +For more information, check [@ZFTurbo's](https://github.com/ZFTurbo/Keras-RetinaNet-for-Open-Images-Challenge-2018) repository.
180 +
181 +## CSV datasets
182 +The `CSVGenerator` provides an easy way to define your own datasets.
183 +It uses two CSV files: one file containing annotations and one file containing a class name to ID mapping.
184 +
185 +### Annotations format
186 +The CSV file with annotations should contain one annotation per line.
187 +Images with multiple bounding boxes should use one row per bounding box.
188 +Note that indexing for pixel values starts at 0.
189 +The expected format of each line is:
190 +```
191 +path/to/image.jpg,x1,y1,x2,y2,class_name
192 +```
193 +By default the CSV generator will look for images relative to the directory of the annotations file.
194 +
195 +Some images may not contain any labeled objects.
196 +To add these images to the dataset as negative examples,
197 +add an annotation where `x1`, `y1`, `x2`, `y2` and `class_name` are all empty:
198 +```
199 +path/to/image.jpg,,,,,
200 +```
201 +
202 +A full example:
203 +```
204 +/data/imgs/img_001.jpg,837,346,981,456,cow
205 +/data/imgs/img_002.jpg,215,312,279,391,cat
206 +/data/imgs/img_002.jpg,22,5,89,84,bird
207 +/data/imgs/img_003.jpg,,,,,
208 +```
209 +
210 +This defines a dataset with 3 images.
211 +`img_001.jpg` contains a cow.
212 +`img_002.jpg` contains a cat and a bird.
213 +`img_003.jpg` contains no interesting objects/animals.
214 +
215 +
216 +### Class mapping format
217 +The class name to ID mapping file should contain one mapping per line.
218 +Each line should use the following format:
219 +```
220 +class_name,id
221 +```
222 +
223 +Indexing for classes starts at 0.
224 +Do not include a background class as it is implicit.
225 +
226 +For example:
227 +```
228 +cow,0
229 +cat,1
230 +bird,2
231 +```
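These two files are what the `CSVGenerator` consumes; a minimal sketch (file names are placeholders, and the positional arguments follow the constructor call used in `debug.py`):

```python
from keras_retinanet.preprocessing.csv_generator import CSVGenerator

# annotations.csv and classes.csv follow the formats described above
train_generator = CSVGenerator('annotations.csv', 'classes.csv')

# the generator can then be passed to model.fit_generator as in the Training section
```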
232 +
233 +## Anchor optimization
234 +
235 +In some cases, the default anchor configuration is not suitable for detecting objects in your dataset, for example if your objects are smaller than 32x32 px (the size of the smallest anchors). In that case it may help to modify the anchor configuration; this can be done automatically by following the steps in the [anchor-optimization](https://github.com/martinzlocha/anchor-optimization/) repository. To use the generated configuration, check [here](https://github.com/fizyr/keras-retinanet-test-data/blob/master/config/config.ini) for an example config file and then pass it to `train.py` using the `--config` parameter.
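As a reference for how such a config file is consumed, here is a minimal sketch using the parsing helpers that `convert_model.py` (included in this repository) relies on, applying the custom anchors during model conversion; paths are placeholders:

```python
from keras_retinanet import models
from keras_retinanet.utils.config import read_config_file, parse_anchor_parameters

# read the (optimized) anchor configuration
config = read_config_file('config.ini')
anchor_params = parse_anchor_parameters(config) if 'anchor_parameters' in config else None

# apply the custom anchors when converting a training model to an inference model
model = models.load_model('/path/to/training/model.h5', backbone_name='resnet50')
model = models.convert_model(model, anchor_params=anchor_params)
model.save('/path/to/save/inference/model.h5')
```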
236 +
237 +## Debugging
238 +Creating your own dataset does not always work out of the box. There is a [`debug.py`](https://github.com/fizyr/keras-retinanet/blob/master/keras_retinanet/bin/debug.py) tool to help find the most common mistakes.
239 +
240 +Particularly helpful is the `--show-annotations` flag, which displays your annotations on the images from your dataset. Annotations are colored green when anchors are available and red when no anchors are available; an annotation without available anchors will not contribute to training. It is normal for a small number of annotations to show up red, but if most or all annotations are red there is cause for concern. The most common issues are annotations that are too small or too oddly shaped (stretched out).
241 +
242 +## Results
243 +
244 +### MS COCO
245 +
246 +## Status
247 +Example output images using `keras-retinanet` are shown below.
248 +
249 +<p align="center">
250 + <img src="https://github.com/delftrobotics/keras-retinanet/blob/master/images/coco1.png" alt="Example result of RetinaNet on MS COCO"/>
251 + <img src="https://github.com/delftrobotics/keras-retinanet/blob/master/images/coco2.png" alt="Example result of RetinaNet on MS COCO"/>
252 + <img src="https://github.com/delftrobotics/keras-retinanet/blob/master/images/coco3.png" alt="Example result of RetinaNet on MS COCO"/>
253 +</p>
254 +
255 +### Projects using keras-retinanet
256 +* [Improving Apple Detection and Counting Using RetinaNet](https://github.com/nikostsagk/Apple-detection). This work aims to investigate the apple detection problem through the deployment of the Keras RetinaNet.
257 +* [Improving RetinaNet for CT Lesion Detection with Dense Masks from Weak RECIST Labels](https://arxiv.org/abs/1906.02283). Research project for detecting lesions in CT using keras-retinanet.
258 +* [NudeNet](https://github.com/bedapudi6788/NudeNet). Project that focuses on detecting and censoring of nudity.
259 +* [Individual tree-crown detection in RGB imagery using self-supervised deep learning neural networks](https://www.biorxiv.org/content/10.1101/532952v1). Research project focused on improving the performance of remotely sensed tree surveys.
260 +* [ESRI Object Detection Challenge 2019](https://github.com/kunwar31/ESRI_Object_Detection). Winning implementation of the ESRI Object Detection Challenge 2019.
261 +* [Lunar Rockfall Detector Project](https://ieeexplore.ieee.org/document/8587120). The aim of this project is to [map lunar rockfalls on a global scale](https://www.nature.com/articles/s41467-020-16653-3) using the available > 2 million satellite images.
262 +* [Mars Rockfall Detector Project](https://ieeexplore.ieee.org/document/9103997). The aim of this project is to map rockfalls on Mars.
263 +* [NATO Innovation Challenge](https://medium.com/data-from-the-trenches/object-detection-with-deep-learning-on-aerial-imagery-2465078db8a9). The winning team of the NATO Innovation Challenge used keras-retinanet to detect cars in aerial images ([COWC dataset](https://gdo152.llnl.gov/cowc/)).
264 +* [Microsoft Research for Horovod on Azure](https://blogs.technet.microsoft.com/machinelearning/2018/06/20/how-to-do-distributed-deep-learning-for-object-detection-using-horovod-on-azure/). A research project by Microsoft, using keras-retinanet to distribute training over multiple GPUs using Horovod on Azure.
265 +* [Anno-Mage](https://virajmavani.github.io/saiat/). A tool that helps you annotate images, using input from the keras-retinanet COCO model as suggestions.
266 +* [Telenav.AI](https://github.com/Telenav/Telenav.AI/tree/master/retinanet). For the detection of traffic signs using keras-retinanet.
267 +* [Towards Deep Placental Histology Phenotyping](https://github.com/Nellaker-group/TowardsDeepPhenotyping). This research project uses keras-retinanet for analysing the placenta at a cellular level.
268 +* [4k video example](https://www.youtube.com/watch?v=KYueHEMGRos). This demo shows the use of keras-retinanet on a 4k input video.
269 +* [boring-detector](https://github.com/lexfridman/boring-detector). I suppose not all projects need to solve life's biggest questions. This project detects "The Boring Company" hats in videos.
270 +* [comet.ml](https://towardsdatascience.com/how-i-monitor-and-track-my-machine-learning-experiments-from-anywhere-described-in-13-tweets-ec3d0870af99). Using keras-retinanet in combination with [comet.ml](https://comet.ml) to interactively inspect and compare experiments.
271 +* [Weights and Biases](https://app.wandb.ai/syllogismos/keras-retinanet/reports?view=carey%2FObject%20Detection%20with%20RetinaNet). Trained keras-retinanet from scratch on the COCO dataset with resnet50 and resnet101 backbones.
272 +* [Google Open Images Challenge 2018 15th place solution](https://github.com/ZFTurbo/Keras-RetinaNet-for-Open-Images-Challenge-2018). Pretrained weights for keras-retinanet based on ResNet50, ResNet101 and ResNet152 trained on open images dataset.
273 +* [poke.AI](https://github.com/Raghav-B/poke.AI). An experimental AI that attempts to master the 3rd Generation Pokemon games. Using keras-retinanet for in-game mapping and localization.
274 +* [retinanetjs](https://github.com/faustomorales/retinanetjs). A wrapper to run RetinaNet inference in the browser / Node.js. You can also take a look at the [example app](https://faustomorales.github.io/retinanetjs-example-app/).
275 +* [CRFNet](https://github.com/TUMFTM/CameraRadarFusionNet). This network fuses radar and camera data to perform object detection for autonomous driving applications.
276 +* [LogoDet](https://github.com/notAI-tech/LogoDet). Project for detecting company logos in images.
277 +
278 +
279 +If you have a project based on `keras-retinanet` and would like to have it published here, shoot me a message on Slack.
280 +
281 +### Notes
282 +* This repository requires TensorFlow 2.3.0 or higher.
283 +* This repository is [tested](https://github.com/fizyr/keras-retinanet/blob/master/.travis.yml) using OpenCV 3.4.
284 +* This repository is [tested](https://github.com/fizyr/keras-retinanet/blob/master/.travis.yml) using Python 3.6 and 3.7.
285 +
286 +Contributions to this project are welcome.
287 +
288 +### Discussions
289 +Feel free to join the `#keras-retinanet` [Keras Slack](https://keras-slack-autojoin.herokuapp.com/) channel for discussions and questions.
290 +
291 +## FAQ
292 +* **I get the warning `UserWarning: No training configuration found in save file: the model was not compiled. Compile it manually.`, should I be worried?** This warning can safely be ignored during inference.
293 +* **I get the error `ValueError: not enough values to unpack (expected 3, got 2)` during inference, what should I do?** This happens because you are using a training model to do inference. See https://github.com/fizyr/keras-retinanet#converting-a-training-model-to-inference-model for more information.
294 +* **How do I do transfer learning?** The easiest solution is to use the `--weights` argument when training. Keras will load models even if the number of classes doesn't match (it will simply skip loading the weights when there is a mismatch). Run, for example, `retinanet-train --weights snapshots/some_coco_model.h5 pascal /path/to/pascal` to transfer weights from a COCO model to a PascalVOC training session. If your dataset is small, you can also use the `--freeze-backbone` argument to freeze the backbone layers.
295 +* **How do I change the number / shape of the anchors?** The train tool allows you to pass a configuration file where the anchor parameters can be adjusted. Check [here](https://github.com/fizyr/keras-retinanet-test-data/blob/master/config/config.ini) for an example config file.
296 +* **I get a loss of `0`, what is going on?** This mostly happens when none of the anchors "fit" on your objects, because they are most likely too small or elongated. You can verify this using the [debug](https://github.com/fizyr/keras-retinanet#debugging) tool.
297 +* **I have an older model, can I use it after an update of keras-retinanet?** This depends on what has changed. If it is a change that doesn't affect the weights, you can "update" models by creating a new retinanet model, loading your old weights using `model.load_weights(weights_path, by_name=True)` and saving this model (a sketch of this recipe is shown after this FAQ). If the change has been too significant, you should retrain your model (you can try to load in the weights from your old model when starting training; this might be a better starting point than ImageNet).
298 +* **I get the error `ModuleNotFoundError: No module named 'keras_retinanet.utils.compute_overlap'`, how do I fix this?** Most likely you are running the code from the cloned repository. This is fine, but you need to compile some extensions for this to work (`python setup.py build_ext --inplace`).
299 +* **How do I train on my own dataset?** The steps to train on your dataset are roughly as follows:
300 +* 1. Prepare your dataset in the CSV format (a training and validation split is advised).
301 +* 2. Check that your dataset is correct using `retinanet-debug`.
302 +* 3. Train retinanet, preferably using the pretrained COCO weights (this gives a **far** better starting point, making training much quicker and more accurate). You can optionally perform evaluation of your validation set during training to keep track of how well it performs (advised).
303 +* 4. Convert your training model to an inference model.
304 +* 5. Evaluate your inference model on your test or validation set.
305 +* 6. Profit!
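As a sketch of the model-update recipe from the FAQ above (backbone, class count and paths are placeholders):

```python
from keras_retinanet import models

# build a fresh training model with the same backbone and number of classes as the old one
model = models.backbone('resnet50').retinanet(num_classes=80)

# load the old weights by layer name, as described in the FAQ entry above
model.load_weights('/path/to/old/model.h5', by_name=True)

# save the refreshed model
model.save('/path/to/updated/model.h5')
```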
1 +[anchor_parameters]
2 +# Sizes should correlate to how the network processes an image; it is not advised to change these!
3 +sizes = 64 128 256
4 +# Strides should correlate to how the network strides over an image; it is not advised to change these!
5 +strides = 16 32 64
6 +# The different ratios to use per anchor location.
7 +ratios = 0.5 1 2 3
8 +# The different scaling factors to use per anchor location.
9 +scales = 1 1.2 1.6
10 +
11 +[pyramid_levels]
12 +
13 +levels = 3 4 5
\ No newline at end of file
1 +[.ShellClassInfo]
2 +IconResource=C:\WINDOWS\System32\SHELL32.dll,3
3 +[ViewState]
4 +Mode=
5 +Vid=
6 +FolderType=Generic
This diff could not be displayed because it is too large.
1 +#!/usr/bin/env python
2 +# coding: utf-8
3 +
4 +# Load necessary modules
5 +
6 +import sys
7 +
8 +sys.path.insert(0, "../")
9 +
10 +
11 +# import keras_retinanet
12 +from keras_retinanet import models
13 +from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
14 +from keras_retinanet.utils.visualization import draw_box, draw_caption
15 +from keras_retinanet.utils.colors import label_color
16 +from keras_retinanet.utils.gpu import setup_gpu
17 +
18 +# import miscellaneous modules
19 +import matplotlib.pyplot as plt
20 +import cv2
21 +import os
22 +import numpy as np
23 +import time
24 +
25 +# set tf backend to allow memory to grow, instead of claiming everything
26 +import tensorflow as tf
27 +
28 +# use this to change which GPU to use
29 +gpu = 0
30 +
31 +# set the modified tf session as backend in keras
32 +setup_gpu(gpu)
33 +
34 +
35 +# ## Load RetinaNet model
36 +
37 +# In[ ]:
38 +
39 +
40 +# adjust this to point to your downloaded/trained model
41 +# models can be downloaded here: https://github.com/fizyr/keras-retinanet/releases
42 +model_path = os.path.join("..", "snapshots", "resnet50_coco_best_v2.1.0.h5")
43 +
44 +# load retinanet model
45 +model = models.load_model(model_path, backbone_name="resnet50")
46 +
47 +# if the model is not converted to an inference model, use the line below
48 +# see: https://github.com/fizyr/keras-retinanet#converting-a-training-model-to-inference-model
49 +# model = models.convert_model(model)
50 +
51 +# print(model.summary())
52 +
53 +# load label to names mapping for visualization purposes
54 +labels_to_names = {
55 + 0: "person",
56 + 1: "bicycle",
57 + 2: "car",
58 + 3: "motorcycle",
59 + 4: "airplane",
60 + 5: "bus",
61 + 6: "train",
62 + 7: "truck",
63 + 8: "boat",
64 + 9: "traffic light",
65 + 10: "fire hydrant",
66 + 11: "stop sign",
67 + 12: "parking meter",
68 + 13: "bench",
69 + 14: "bird",
70 + 15: "cat",
71 + 16: "dog",
72 + 17: "horse",
73 + 18: "sheep",
74 + 19: "cow",
75 + 20: "elephant",
76 + 21: "bear",
77 + 22: "zebra",
78 + 23: "giraffe",
79 + 24: "backpack",
80 + 25: "umbrella",
81 + 26: "handbag",
82 + 27: "tie",
83 + 28: "suitcase",
84 + 29: "frisbee",
85 + 30: "skis",
86 + 31: "snowboard",
87 + 32: "sports ball",
88 + 33: "kite",
89 + 34: "baseball bat",
90 + 35: "baseball glove",
91 + 36: "skateboard",
92 + 37: "surfboard",
93 + 38: "tennis racket",
94 + 39: "bottle",
95 + 40: "wine glass",
96 + 41: "cup",
97 + 42: "fork",
98 + 43: "knife",
99 + 44: "spoon",
100 + 45: "bowl",
101 + 46: "banana",
102 + 47: "apple",
103 + 48: "sandwich",
104 + 49: "orange",
105 + 50: "broccoli",
106 + 51: "carrot",
107 + 52: "hot dog",
108 + 53: "pizza",
109 + 54: "donut",
110 + 55: "cake",
111 + 56: "chair",
112 + 57: "couch",
113 + 58: "potted plant",
114 + 59: "bed",
115 + 60: "dining table",
116 + 61: "toilet",
117 + 62: "tv",
118 + 63: "laptop",
119 + 64: "mouse",
120 + 65: "remote",
121 + 66: "keyboard",
122 + 67: "cell phone",
123 + 68: "microwave",
124 + 69: "oven",
125 + 70: "toaster",
126 + 71: "sink",
127 + 72: "refrigerator",
128 + 73: "book",
129 + 74: "clock",
130 + 75: "vase",
131 + 76: "scissors",
132 + 77: "teddy bear",
133 + 78: "hair drier",
134 + 79: "toothbrush",
135 +}
136 +
137 +
138 +# ## Run detection on example
139 +
140 +# In[ ]:
141 +
142 +
143 +# load image
144 +image = read_image_bgr("000000008021.jpg")
145 +
146 +# copy to draw on
147 +draw = image.copy()
148 +draw = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB)
149 +
150 +# preprocess image for network
151 +image = preprocess_image(image)
152 +image, scale = resize_image(image)
153 +
154 +# process image
155 +start = time.time()
156 +boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
157 +print("processing time: ", time.time() - start)
158 +
159 +# correct for image scale
160 +boxes /= scale
161 +
162 +# visualize detections
163 +for box, score, label in zip(boxes[0], scores[0], labels[0]):
164 + # scores are sorted so we can break
165 + if score < 0.5:
166 + break
167 +
168 + color = label_color(label)
169 +
170 + b = box.astype(int)
171 + draw_box(draw, b, color=color)
172 +
173 + caption = "{} {:.3f}".format(labels_to_names[label], score)
174 + draw_caption(draw, b, caption)
175 +
176 +plt.figure(figsize=(15, 15))
177 +plt.axis("off")
178 +plt.imshow(draw)
179 +plt.show()
180 +
181 +
182 +# In[ ]:
183 +
184 +
185 +# In[ ]:
1 +from .backend import * # noqa: F401,F403
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import tensorflow
18 +from tensorflow import keras
19 +
20 +
21 +def bbox_transform_inv(boxes, deltas, mean=None, std=None):
22 + """ Applies deltas (usually regression results) to boxes (usually anchors).
23 +
24 + Before applying the deltas to the boxes, the normalization that was previously applied (in the generator) has to be removed.
25 + The mean and std are the mean and std as applied in the generator. They are unnormalized in this function and then applied to the boxes.
26 +
27 + Args
28 + boxes : np.array of shape (B, N, 4), where B is the batch size, N the number of boxes and 4 values for (x1, y1, x2, y2).
29 + deltas: np.array of same shape as boxes. These deltas (d_x1, d_y1, d_x2, d_y2) are a factor of the width/height.
30 + mean : The mean value used when computing deltas (defaults to [0, 0, 0, 0]).
31 + std : The standard deviation used when computing deltas (defaults to [0.2, 0.2, 0.2, 0.2]).
32 +
33 + Returns
34 + A np.array of the same shape as boxes, but with deltas applied to each box.
35 + The mean and std are used during training to normalize the regression values (networks love normalization).
36 + """
37 + if mean is None:
38 + mean = [0, 0, 0, 0]
39 + if std is None:
40 + std = [0.2, 0.2, 0.2, 0.2]
41 +
42 + width = boxes[:, :, 2] - boxes[:, :, 0]
43 + height = boxes[:, :, 3] - boxes[:, :, 1]
44 +
45 + x1 = boxes[:, :, 0] + (deltas[:, :, 0] * std[0] + mean[0]) * width
46 + y1 = boxes[:, :, 1] + (deltas[:, :, 1] * std[1] + mean[1]) * height
47 + x2 = boxes[:, :, 2] + (deltas[:, :, 2] * std[2] + mean[2]) * width
48 + y2 = boxes[:, :, 3] + (deltas[:, :, 3] * std[3] + mean[3]) * height
49 +
50 + pred_boxes = keras.backend.stack([x1, y1, x2, y2], axis=2)
51 +
52 + return pred_boxes
53 +
54 +
55 +def shift(shape, stride, anchors):
56 + """ Produce shifted anchors based on shape of the map and stride size.
57 +
58 + Args
59 + shape : Shape to shift the anchors over.
60 + stride : Stride to shift the anchors with over the shape.
61 + anchors: The anchors to apply at each location.
62 + """
63 + shift_x = (keras.backend.arange(0, shape[1], dtype=keras.backend.floatx()) + keras.backend.constant(0.5, dtype=keras.backend.floatx())) * stride
64 + shift_y = (keras.backend.arange(0, shape[0], dtype=keras.backend.floatx()) + keras.backend.constant(0.5, dtype=keras.backend.floatx())) * stride
65 +
66 + shift_x, shift_y = tensorflow.meshgrid(shift_x, shift_y)
67 + shift_x = keras.backend.reshape(shift_x, [-1])
68 + shift_y = keras.backend.reshape(shift_y, [-1])
69 +
70 + shifts = keras.backend.stack([
71 + shift_x,
72 + shift_y,
73 + shift_x,
74 + shift_y
75 + ], axis=0)
76 +
77 + shifts = keras.backend.transpose(shifts)
78 + number_of_anchors = keras.backend.shape(anchors)[0]
79 +
80 + k = keras.backend.shape(shifts)[0] # number of base points = feat_h * feat_w
81 +
82 + shifted_anchors = keras.backend.reshape(anchors, [1, number_of_anchors, 4]) + keras.backend.cast(keras.backend.reshape(shifts, [k, 1, 4]), keras.backend.floatx())
83 + shifted_anchors = keras.backend.reshape(shifted_anchors, [k * number_of_anchors, 4])
84 +
85 + return shifted_anchors
86 +
87 +
88 +def map_fn(*args, **kwargs):
89 + """ See https://www.tensorflow.org/api_docs/python/tf/map_fn .
90 + """
91 +
92 + if "shapes" in kwargs:
93 + shapes = kwargs.pop("shapes")
94 + dtype = kwargs.pop("dtype")
95 + sig = [tensorflow.TensorSpec(shapes[i], dtype=t) for i, t in
96 + enumerate(dtype)]
97 +
98 + # Try to use the new feature fn_output_signature in TF 2.3, use fallback if this is not available
99 + try:
100 + return tensorflow.map_fn(*args, **kwargs, fn_output_signature=sig)
101 + except TypeError:
102 + kwargs["dtype"] = dtype
103 +
104 + return tensorflow.map_fn(*args, **kwargs)
105 +
106 +
107 +def resize_images(images, size, method='bilinear', align_corners=False):
108 + """ See https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/image/resize_images .
109 +
110 + Args
111 + method: The method used for interpolation. One of ('bilinear', 'nearest', 'bicubic', 'area').
112 + """
113 + methods = {
114 + 'bilinear': tensorflow.image.ResizeMethod.BILINEAR,
115 + 'nearest' : tensorflow.image.ResizeMethod.NEAREST_NEIGHBOR,
116 + 'bicubic' : tensorflow.image.ResizeMethod.BICUBIC,
117 + 'area' : tensorflow.image.ResizeMethod.AREA,
118 + }
119 + return tensorflow.compat.v1.image.resize_images(images, size, methods[method], align_corners)
1 +#!/usr/bin/env python
2 +
3 +"""
4 +Copyright 2017-2018 Fizyr (https://fizyr.com)
5 +
6 +Licensed under the Apache License, Version 2.0 (the "License");
7 +you may not use this file except in compliance with the License.
8 +You may obtain a copy of the License at
9 +
10 + http://www.apache.org/licenses/LICENSE-2.0
11 +
12 +Unless required by applicable law or agreed to in writing, software
13 +distributed under the License is distributed on an "AS IS" BASIS,
14 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 +See the License for the specific language governing permissions and
16 +limitations under the License.
17 +"""
18 +
19 +import argparse
20 +import os
21 +import sys
22 +
23 +# Allow relative imports when being executed as script.
24 +if __name__ == "__main__" and __package__ is None:
25 + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
26 + import keras_retinanet.bin # noqa: F401
27 + __package__ = "keras_retinanet.bin"
28 +
29 +# Change these to absolute imports if you copy this script outside the keras_retinanet package.
30 +from .. import models
31 +from ..utils.config import read_config_file, parse_anchor_parameters, parse_pyramid_levels
32 +from ..utils.gpu import setup_gpu
33 +from ..utils.tf_version import check_tf_version
34 +
35 +
36 +def parse_args(args):
37 + parser = argparse.ArgumentParser(description='Script for converting a training model to an inference model.')
38 +
39 + parser.add_argument('model_in', help='The model to convert.')
40 + parser.add_argument('model_out', help='Path to save the converted model to.')
41 + parser.add_argument('--backbone', help='The backbone of the model to convert.', default='resnet50')
42 + parser.add_argument('--no-nms', help='Disables non maximum suppression.', dest='nms', action='store_false')
43 + parser.add_argument('--no-class-specific-filter', help='Disables class specific filtering.', dest='class_specific_filter', action='store_false')
44 + parser.add_argument('--config', help='Path to a configuration parameters .ini file.')
45 + parser.add_argument('--nms-threshold', help='Value for non maximum suppression threshold.', type=float, default=0.5)
46 + parser.add_argument('--score-threshold', help='Threshold for prefiltering boxes.', type=float, default=0.05)
47 + parser.add_argument('--max-detections', help='Maximum number of detections to keep.', type=int, default=300)
48 + parser.add_argument('--parallel-iterations', help='Number of batch items to process in parallel.', type=int, default=32)
49 +
50 + return parser.parse_args(args)
51 +
52 +
53 +def main(args=None):
54 + # parse arguments
55 + if args is None:
56 + args = sys.argv[1:]
57 + args = parse_args(args)
58 +
59 + # make sure tensorflow is the minimum required version
60 + check_tf_version()
61 +
62 + # set modified tf session to avoid using the GPUs
63 + setup_gpu('cpu')
64 +
65 + # optionally load config parameters
66 + anchor_parameters = None
67 + pyramid_levels = None
68 + if args.config:
69 + args.config = read_config_file(args.config)
70 + if 'anchor_parameters' in args.config:
71 + anchor_parameters = parse_anchor_parameters(args.config)
72 +
73 + if 'pyramid_levels' in args.config:
74 + pyramid_levels = parse_pyramid_levels(args.config)
75 +
76 + # load the model
77 + model = models.load_model(args.model_in, backbone_name=args.backbone)
78 +
79 + # check if this is indeed a training model
80 + models.check_training_model(model)
81 +
82 + # convert the model
83 + model = models.convert_model(
84 + model,
85 + nms=args.nms,
86 + class_specific_filter=args.class_specific_filter,
87 + anchor_params=anchor_parameters,
88 + pyramid_levels=pyramid_levels,
89 + nms_threshold=args.nms_threshold,
90 + score_threshold=args.score_threshold,
91 + max_detections=args.max_detections,
92 + parallel_iterations=args.parallel_iterations
93 + )
94 +
95 + # save model
96 + model.save(args.model_out)
97 +
98 +
99 +if __name__ == '__main__':
100 + main()
1 +#!/usr/bin/env python
2 +
3 +"""
4 +Copyright 2017-2018 Fizyr (https://fizyr.com)
5 +
6 +Licensed under the Apache License, Version 2.0 (the "License");
7 +you may not use this file except in compliance with the License.
8 +You may obtain a copy of the License at
9 +
10 + http://www.apache.org/licenses/LICENSE-2.0
11 +
12 +Unless required by applicable law or agreed to in writing, software
13 +distributed under the License is distributed on an "AS IS" BASIS,
14 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 +See the License for the specific language governing permissions and
16 +limitations under the License.
17 +"""
18 +
19 +import argparse
20 +import os
21 +import sys
22 +import cv2
23 +
24 +# Set keycodes for changing images
25 +# 81, 83 are left and right arrows on linux in Ascii code (probably not needed)
26 +# 65361, 65363 are left and right arrows in linux
27 +# 2424832, 2555904 are left and right arrows on Windows
28 +# 110, 109 are 'n' and 'm' on mac, windows, linux
29 +# (unfortunately arrow keys not picked up on mac)
30 +leftkeys = (81, 110, 65361, 2424832)
31 +rightkeys = (83, 109, 65363, 2555904)
32 +
33 +# Allow relative imports when being executed as script.
34 +if __name__ == "__main__" and __package__ is None:
35 + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
36 + import keras_retinanet.bin # noqa: F401
37 + __package__ = "keras_retinanet.bin"
38 +
39 +# Change these to absolute imports if you copy this script outside the keras_retinanet package.
40 +from ..preprocessing.pascal_voc import PascalVocGenerator
41 +from ..preprocessing.csv_generator import CSVGenerator
42 +from ..preprocessing.kitti import KittiGenerator
43 +from ..preprocessing.open_images import OpenImagesGenerator
44 +from ..utils.anchors import anchors_for_shape, compute_gt_annotations
45 +from ..utils.config import read_config_file, parse_anchor_parameters, parse_pyramid_levels
46 +from ..utils.image import random_visual_effect_generator
47 +from ..utils.tf_version import check_tf_version
48 +from ..utils.transform import random_transform_generator
49 +from ..utils.visualization import draw_annotations, draw_boxes, draw_caption
50 +
51 +
52 +def create_generator(args):
53 + """ Create the data generators.
54 +
55 + Args:
56 + args: parseargs arguments object.
57 + """
58 + common_args = {
59 + 'config' : args.config,
60 + 'image_min_side' : args.image_min_side,
61 + 'image_max_side' : args.image_max_side,
62 + 'group_method' : args.group_method
63 + }
64 +
65 + # create random transform generator for augmenting training data
66 + transform_generator = random_transform_generator(
67 + min_rotation=-0.1,
68 + max_rotation=0.1,
69 + min_translation=(-0.1, -0.1),
70 + max_translation=(0.1, 0.1),
71 + min_shear=-0.1,
72 + max_shear=0.1,
73 + min_scaling=(0.9, 0.9),
74 + max_scaling=(1.1, 1.1),
75 + flip_x_chance=0.5,
76 + flip_y_chance=0.5,
77 + )
78 +
79 + visual_effect_generator = random_visual_effect_generator(
80 + contrast_range=(0.9, 1.1),
81 + brightness_range=(-.1, .1),
82 + hue_range=(-0.05, 0.05),
83 + saturation_range=(0.95, 1.05)
84 + )
85 +
86 + if args.dataset_type == 'coco':
87 + # import here to prevent unnecessary dependency on cocoapi
88 + from ..preprocessing.coco import CocoGenerator
89 +
90 + generator = CocoGenerator(
91 + args.coco_path,
92 + args.coco_set,
93 + transform_generator=transform_generator,
94 + visual_effect_generator=visual_effect_generator,
95 + **common_args
96 + )
97 + elif args.dataset_type == 'pascal':
98 + generator = PascalVocGenerator(
99 + args.pascal_path,
100 + args.pascal_set,
101 + image_extension=args.image_extension,
102 + transform_generator=transform_generator,
103 + visual_effect_generator=visual_effect_generator,
104 + **common_args
105 + )
106 + elif args.dataset_type == 'csv':
107 + generator = CSVGenerator(
108 + args.annotations,
109 + args.classes,
110 + transform_generator=transform_generator,
111 + visual_effect_generator=visual_effect_generator,
112 + **common_args
113 + )
114 + elif args.dataset_type == 'oid':
115 + generator = OpenImagesGenerator(
116 + args.main_dir,
117 + subset=args.subset,
118 + version=args.version,
119 + labels_filter=args.labels_filter,
120 + parent_label=args.parent_label,
121 + annotation_cache_dir=args.annotation_cache_dir,
122 + transform_generator=transform_generator,
123 + visual_effect_generator=visual_effect_generator,
124 + **common_args
125 + )
126 + elif args.dataset_type == 'kitti':
127 + generator = KittiGenerator(
128 + args.kitti_path,
129 + subset=args.subset,
130 + transform_generator=transform_generator,
131 + visual_effect_generator=visual_effect_generator,
132 + **common_args
133 + )
134 + else:
135 + raise ValueError('Invalid data type received: {}'.format(args.dataset_type))
136 +
137 + return generator
138 +
139 +
140 +def parse_args(args):
141 + """ Parse the arguments.
142 + """
143 + parser = argparse.ArgumentParser(description='Debug script for a RetinaNet network.')
144 + subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type')
145 + subparsers.required = True
146 +
147 + coco_parser = subparsers.add_parser('coco')
148 + coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')
149 + coco_parser.add_argument('--coco-set', help='Name of the set to show (defaults to val2017).', default='val2017')
150 +
151 + pascal_parser = subparsers.add_parser('pascal')
152 + pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).')
153 + pascal_parser.add_argument('--pascal-set', help='Name of the set to show (defaults to test).', default='test')
154 + pascal_parser.add_argument('--image-extension', help='Declares the dataset images\' extension.', default='.jpg')
155 +
156 + kitti_parser = subparsers.add_parser('kitti')
157 + kitti_parser.add_argument('kitti_path', help='Path to dataset directory (ie. /tmp/kitti).')
158 + kitti_parser.add_argument('subset', help='Argument for loading a subset from train/val.')
159 +
160 + def csv_list(string):
161 + return string.split(',')
162 +
163 + oid_parser = subparsers.add_parser('oid')
164 + oid_parser.add_argument('main_dir', help='Path to dataset directory.')
165 + oid_parser.add_argument('subset', help='Argument for loading a subset from train/validation/test.')
166 + oid_parser.add_argument('--version', help='The current dataset version is v4.', default='v4')
167 + oid_parser.add_argument('--labels-filter', help='A list of labels to filter.', type=csv_list, default=None)
168 + oid_parser.add_argument('--annotation-cache-dir', help='Path to store annotation cache.', default='.')
169 + oid_parser.add_argument('--parent-label', help='Use the hierarchy children of this label.', default=None)
170 +
171 + csv_parser = subparsers.add_parser('csv')
172 + csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for evaluation.')
173 + csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')
174 +
175 + parser.add_argument('--no-resize', help='Disable image resizing.', dest='resize', action='store_false')
176 + parser.add_argument('--anchors', help='Show positive anchors on the image.', action='store_true')
177 + parser.add_argument('--display-name', help='Display image name on the bottom left corner.', action='store_true')
178 + parser.add_argument('--show-annotations', help='Show annotations on the image. Green annotations have anchors, red annotations don\'t and therefore don\'t contribute to training.', action='store_true')
179 + parser.add_argument('--random-transform', help='Randomly transform image and annotations.', action='store_true')
180 + parser.add_argument('--image-min-side', help='Rescale the image so the smallest side is min_side.', type=int, default=800)
181 + parser.add_argument('--image-max-side', help='Rescale the image if the largest side is larger than max_side.', type=int, default=1333)
182 + parser.add_argument('--config', help='Path to a configuration parameters .ini file.')
183 + parser.add_argument('--no-gui', help='Do not open a GUI window. Save images to an output directory instead.', action='store_true')
184 + parser.add_argument('--output-dir', help='The output directory to save images to if --no-gui is specified.', default='.')
185 + parser.add_argument('--flatten-output', help='Flatten the folder structure of saved output images into a single folder.', action='store_true')
186 + parser.add_argument('--group-method', help='Determines how images are grouped together', type=str, default='ratio', choices=['none', 'random', 'ratio'])
187 +
188 + return parser.parse_args(args)
189 +
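For reference, a couple of illustrative invocations of this debug script (assuming the upstream layout where it lives at keras_retinanet/bin/debug.py; all paths and file names below are placeholders):

    # Show positive anchors and ground-truth annotations for a CSV dataset:
    #   python keras_retinanet/bin/debug.py --anchors --show-annotations csv annotations.csv classes.csv
    # Save the annotated images to disk instead of opening a window:
    #   python keras_retinanet/bin/debug.py --no-gui --output-dir /tmp/debug --show-annotations csv annotations.csv classes.csv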
190 +
191 +def run(generator, args, anchor_params, pyramid_levels):
192 + """ Main loop.
193 +
194 + Args
195 + generator: The generator to debug.
196 +        args: argparse args object (anchor_params / pyramid_levels may be None to use the defaults).
197 + """
198 + # display images, one at a time
199 + i = 0
200 + while True:
201 + # load the data
202 + image = generator.load_image(i)
203 + annotations = generator.load_annotations(i)
204 +        if len(annotations['labels']) > 0:
205 + # apply random transformations
206 + if args.random_transform:
207 + image, annotations = generator.random_transform_group_entry(image, annotations)
208 + image, annotations = generator.random_visual_effect_group_entry(image, annotations)
209 +
210 + # resize the image and annotations
211 + if args.resize:
212 + image, image_scale = generator.resize_image(image)
213 + annotations['bboxes'] *= image_scale
214 +
215 + anchors = anchors_for_shape(image.shape, anchor_params=anchor_params, pyramid_levels=pyramid_levels)
216 + positive_indices, _, max_indices = compute_gt_annotations(anchors, annotations['bboxes'])
217 +
218 + # draw anchors on the image
219 + if args.anchors:
220 + draw_boxes(image, anchors[positive_indices], (255, 255, 0), thickness=1)
221 +
222 + # draw annotations on the image
223 + if args.show_annotations:
224 + # draw annotations in red
225 + draw_annotations(image, annotations, color=(0, 0, 255), label_to_name=generator.label_to_name)
226 +
227 +                # re-draw, in green, the ground-truth boxes that are matched by at least one positive anchor
228 +                # result is that annotations without anchors stay red and do not contribute to training
229 + draw_boxes(image, annotations['bboxes'][max_indices[positive_indices], :], (0, 255, 0))
230 +
231 + # display name on the image
232 + if args.display_name:
233 + draw_caption(image, [0, image.shape[0]], os.path.basename(generator.image_path(i)))
234 +
235 + # write to file and advance if no-gui selected
236 + if args.no_gui:
237 + output_path = make_output_path(args.output_dir, generator.image_path(i), flatten=args.flatten_output)
238 + os.makedirs(os.path.dirname(output_path), exist_ok=True)
239 + cv2.imwrite(output_path, image)
240 + i += 1
241 + if i == generator.size(): # have written all images
242 + break
243 + else:
244 + continue
245 +
246 + # if we are using the GUI, then show an image
247 + cv2.imshow('Image', image)
248 + key = cv2.waitKeyEx()
249 +
250 + # press right for next image and left for previous (linux or windows, doesn't work for macOS)
251 + # if you run macOS, press "n" or "m" (will also work on linux and windows)
252 +
253 + if key in rightkeys:
254 + i = (i + 1) % generator.size()
255 + if key in leftkeys:
256 + i -= 1
257 + if i < 0:
258 + i = generator.size() - 1
259 +
260 + # press q or Esc to quit
261 + if (key == ord('q')) or (key == 27):
262 + return False
263 +
264 + return True
265 +
266 +
267 +def make_output_path(output_dir, image_path, flatten = False):
268 + """ Compute the output path for a debug image. """
269 +
270 + # If the output hierarchy is flattened to a single folder, throw away all leading folders.
271 + if flatten:
272 + path = os.path.basename(image_path)
273 +
274 + # Otherwise, make sure absolute paths are taken relative to the filesystem root.
275 + else:
276 +        # Make sure to drop drive letters on Windows, otherwise relpath will fail.
277 + _, path = os.path.splitdrive(image_path)
278 + if os.path.isabs(path):
279 + path = os.path.relpath(path, '/')
280 +
281 + # In all cases, append "_debug" to the filename, before the extension.
282 + base, extension = os.path.splitext(path)
283 + path = base + "_debug" + extension
284 +
285 + # Finally, join the whole thing to the output directory.
286 + return os.path.join(output_dir, path)
287 +
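A quick sketch of what make_output_path produces in its two modes (paths are illustrative):

    # make_output_path('/out', '/data/voc/img/0001.jpg')                -> '/out/data/voc/img/0001_debug.jpg'
    # make_output_path('/out', '/data/voc/img/0001.jpg', flatten=True)  -> '/out/0001_debug.jpg'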
288 +
289 +def main(args=None):
290 + # parse arguments
291 + if args is None:
292 + args = sys.argv[1:]
293 + args = parse_args(args)
294 +
295 + # make sure tensorflow is the minimum required version
296 + check_tf_version()
297 +
298 + # create the generator
299 + generator = create_generator(args)
300 +
301 + # optionally load config parameters
302 + if args.config:
303 + args.config = read_config_file(args.config)
304 +
305 + # optionally load anchor parameters
306 + anchor_params = None
307 + if args.config and 'anchor_parameters' in args.config:
308 + anchor_params = parse_anchor_parameters(args.config)
309 +
310 + pyramid_levels = None
311 + if args.config and 'pyramid_levels' in args.config:
312 + pyramid_levels = parse_pyramid_levels(args.config)
313 + # create the display window if necessary
314 + if not args.no_gui:
315 + cv2.namedWindow('Image', cv2.WINDOW_NORMAL)
316 +
317 + run(generator, args, anchor_params=anchor_params, pyramid_levels=pyramid_levels)
318 +
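When --config is given, read_config_file parses an .ini file whose sections feed parse_anchor_parameters and parse_pyramid_levels. A minimal sketch of such a file, shown here as comments (the section names match the code above; the key names and values are illustrative and should be checked against keras_retinanet.utils.config):

    # [anchor_parameters]
    # sizes   = 32 64 128 256 512
    # strides = 8 16 32 64 128
    # ratios  = 0.5 1 2
    # scales  = 1 1.2 1.6
    #
    # [pyramid_levels]
    # levels = 3 4 5 6 7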
319 +
320 +if __name__ == '__main__':
321 + main()
1 +#!/usr/bin/env python
2 +
3 +"""
4 +Copyright 2017-2018 Fizyr (https://fizyr.com)
5 +
6 +Licensed under the Apache License, Version 2.0 (the "License");
7 +you may not use this file except in compliance with the License.
8 +You may obtain a copy of the License at
9 + http://www.apache.org/licenses/LICENSE-2.0
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import argparse
18 +import os
19 +import sys
20 +
21 +# Allow relative imports when being executed as script.
22 +if __name__ == "__main__" and __package__ is None:
23 + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
24 + import keras_retinanet.bin # noqa: F401
25 + __package__ = "keras_retinanet.bin"
26 +
27 +# Change these to absolute imports if you copy this script outside the keras_retinanet package.
28 +from .. import models
29 +from ..preprocessing.csv_generator import CSVGenerator
30 +from ..preprocessing.pascal_voc import PascalVocGenerator
31 +from ..utils.anchors import make_shapes_callback
32 +from ..utils.config import read_config_file, parse_anchor_parameters, parse_pyramid_levels
33 +from ..utils.eval import evaluate
34 +from ..utils.gpu import setup_gpu
35 +from ..utils.tf_version import check_tf_version
36 +
37 +
38 +def create_generator(args, preprocess_image):
39 + """ Create generators for evaluation.
40 + """
41 + common_args = {
42 + 'config' : args.config,
43 + 'image_min_side' : args.image_min_side,
44 + 'image_max_side' : args.image_max_side,
45 + 'no_resize' : args.no_resize,
46 + 'preprocess_image' : preprocess_image,
47 + 'group_method' : args.group_method
48 + }
49 +
50 + if args.dataset_type == 'coco':
51 + # import here to prevent unnecessary dependency on cocoapi
52 + from ..preprocessing.coco import CocoGenerator
53 +
54 + validation_generator = CocoGenerator(
55 + args.coco_path,
56 + 'val2017',
57 + shuffle_groups=False,
58 + **common_args
59 + )
60 + elif args.dataset_type == 'pascal':
61 + validation_generator = PascalVocGenerator(
62 + args.pascal_path,
63 + 'test',
64 + image_extension=args.image_extension,
65 + shuffle_groups=False,
66 + **common_args
67 + )
68 + elif args.dataset_type == 'csv':
69 + validation_generator = CSVGenerator(
70 + args.annotations,
71 + args.classes,
72 + shuffle_groups=False,
73 + **common_args
74 + )
75 + else:
76 + raise ValueError('Invalid data type received: {}'.format(args.dataset_type))
77 +
78 + return validation_generator
79 +
80 +
81 +def parse_args(args):
82 + """ Parse the arguments.
83 + """
84 + parser = argparse.ArgumentParser(description='Evaluation script for a RetinaNet network.')
85 + subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type')
86 + subparsers.required = True
87 +
88 + coco_parser = subparsers.add_parser('coco')
89 + coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')
90 +
91 + pascal_parser = subparsers.add_parser('pascal')
92 + pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).')
93 + pascal_parser.add_argument('--image-extension', help='Declares the dataset images\' extension.', default='.jpg')
94 +
95 + csv_parser = subparsers.add_parser('csv')
96 + csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for evaluation.')
97 + csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')
98 +
99 + parser.add_argument('model', help='Path to RetinaNet model.')
100 + parser.add_argument('--convert-model', help='Convert the model to an inference model (ie. the input is a training model).', action='store_true')
101 + parser.add_argument('--backbone', help='The backbone of the model.', default='resnet50')
102 + parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).')
103 + parser.add_argument('--score-threshold', help='Threshold on score to filter detections with (defaults to 0.05).', default=0.05, type=float)
104 + parser.add_argument('--iou-threshold', help='IoU Threshold to count for a positive detection (defaults to 0.5).', default=0.5, type=float)
105 + parser.add_argument('--max-detections', help='Max Detections per image (defaults to 100).', default=100, type=int)
106 + parser.add_argument('--save-path', help='Path for saving images with detections (doesn\'t work for COCO).')
107 + parser.add_argument('--image-min-side', help='Rescale the image so the smallest side is min_side.', type=int, default=800)
108 + parser.add_argument('--image-max-side', help='Rescale the image if the largest side is larger than max_side.', type=int, default=1333)
109 +    parser.add_argument('--no-resize', help='Don\'t rescale the image.', action='store_true')
110 + parser.add_argument('--config', help='Path to a configuration parameters .ini file (only used with --convert-model).')
111 + parser.add_argument('--group-method', help='Determines how images are grouped together', type=str, default='ratio', choices=['none', 'random', 'ratio'])
112 +
113 + return parser.parse_args(args)
114 +
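For reference, an illustrative invocation of this evaluation script (upstream layout assumed; paths and the snapshot name are placeholders). The positional model path comes after the dataset arguments, and --convert-model is needed when the snapshot is a training model:

    #   python keras_retinanet/bin/evaluate.py --convert-model csv val_annotations.csv classes.csv snapshots/resnet50_csv_50.h5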
115 +
116 +def main(args=None):
117 + # parse arguments
118 + if args is None:
119 + args = sys.argv[1:]
120 + args = parse_args(args)
121 +
122 + # make sure tensorflow is the minimum required version
123 + check_tf_version()
124 +
125 + # optionally choose specific GPU
126 + if args.gpu:
127 + setup_gpu(args.gpu)
128 +
129 + # make save path if it doesn't exist
130 + if args.save_path is not None and not os.path.exists(args.save_path):
131 + os.makedirs(args.save_path)
132 +
133 + # optionally load config parameters
134 + if args.config:
135 + args.config = read_config_file(args.config)
136 +
137 + # create the generator
138 + backbone = models.backbone(args.backbone)
139 + generator = create_generator(args, backbone.preprocess_image)
140 +
141 + # optionally load anchor parameters
142 + anchor_params = None
143 + pyramid_levels = None
144 + if args.config and 'anchor_parameters' in args.config:
145 + anchor_params = parse_anchor_parameters(args.config)
146 + if args.config and 'pyramid_levels' in args.config:
147 + pyramid_levels = parse_pyramid_levels(args.config)
148 +
149 + # load the model
150 + print('Loading model, this may take a second...')
151 + model = models.load_model(args.model, backbone_name=args.backbone)
152 + generator.compute_shapes = make_shapes_callback(model)
153 +
154 + # optionally convert the model
155 + if args.convert_model:
156 + model = models.convert_model(model, anchor_params=anchor_params, pyramid_levels=pyramid_levels)
157 +
158 + # print model summary
159 + # print(model.summary())
160 +
161 + # start evaluation
162 + if args.dataset_type == 'coco':
163 + from ..utils.coco_eval import evaluate_coco
164 + evaluate_coco(generator, model, args.score_threshold)
165 + else:
166 + average_precisions, inference_time = evaluate(
167 + generator,
168 + model,
169 + iou_threshold=args.iou_threshold,
170 + score_threshold=args.score_threshold,
171 + max_detections=args.max_detections,
172 + save_path=args.save_path
173 + )
174 +
175 + # print evaluation
176 + total_instances = []
177 + precisions = []
178 + for label, (average_precision, num_annotations) in average_precisions.items():
179 + print('{:.0f} instances of class'.format(num_annotations),
180 + generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision))
181 + total_instances.append(num_annotations)
182 + precisions.append(average_precision)
183 +
184 + if sum(total_instances) == 0:
185 + print('No test instances found.')
186 + return
187 +
188 + print('Inference time for {:.0f} images: {:.4f}'.format(generator.size(), inference_time))
189 +
190 + print('mAP using the weighted average of precisions among classes: {:.4f}'.format(sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances)))
191 + print('mAP: {:.4f}'.format(sum(precisions) / sum(x > 0 for x in total_instances)))
192 +
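A small worked example of the two figures printed above, with illustrative numbers: two classes with (AP, instances) of (0.80, 50) and (0.60, 10) give

    # weighted mAP = (0.80 * 50 + 0.60 * 10) / (50 + 10) = 46.0 / 60 ≈ 0.767
    # plain mAP    = (0.80 + 0.60) / 2 = 0.70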
193 +
194 +if __name__ == '__main__':
195 + main()
1 +#!/usr/bin/env python
2 +
3 +"""
4 +Copyright 2017-2018 Fizyr (https://fizyr.com)
5 +
6 +Licensed under the Apache License, Version 2.0 (the "License");
7 +you may not use this file except in compliance with the License.
8 +You may obtain a copy of the License at
9 +
10 + http://www.apache.org/licenses/LICENSE-2.0
11 +
12 +Unless required by applicable law or agreed to in writing, software
13 +distributed under the License is distributed on an "AS IS" BASIS,
14 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 +See the License for the specific language governing permissions and
16 +limitations under the License.
17 +"""
18 +
19 +import argparse
20 +import os
21 +import sys
22 +import warnings
23 +
24 +from tensorflow import keras
25 +import tensorflow as tf
26 +
27 +# Allow relative imports when being executed as script.
28 +if __name__ == "__main__" and __package__ is None:
29 + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
30 + import keras_retinanet.bin # noqa: F401
31 +
32 + __package__ = "keras_retinanet.bin"
33 +
34 +# Change these to absolute imports if you copy this script outside the keras_retinanet package.
35 +from .. import layers # noqa: F401
36 +from .. import losses
37 +from .. import models
38 +from ..callbacks import RedirectModel
39 +from ..callbacks.eval import Evaluate
40 +from ..models.retinanet import retinanet_bbox
41 +from ..preprocessing.csv_generator import CSVGenerator
42 +from ..preprocessing.kitti import KittiGenerator
43 +from ..preprocessing.open_images import OpenImagesGenerator
44 +from ..preprocessing.pascal_voc import PascalVocGenerator
45 +from ..utils.anchors import make_shapes_callback
46 +from ..utils.config import (
47 + read_config_file,
48 + parse_anchor_parameters,
49 + parse_pyramid_levels,
50 +)
51 +from ..utils.gpu import setup_gpu
52 +from ..utils.image import random_visual_effect_generator
53 +from ..utils.model import freeze as freeze_model
54 +from ..utils.tf_version import check_tf_version
55 +from ..utils.transform import random_transform_generator
56 +
57 +#######################
58 +
59 +from ..models import submodel
60 +
61 +
62 +def makedirs(path):
63 + # Intended behavior: try to create the directory,
64 +    # pass if the directory already exists, re-raise otherwise.
65 + # Meant for Python 2.7/3.n compatibility.
66 + try:
67 + os.makedirs(path)
68 + except OSError:
69 + if not os.path.isdir(path):
70 + raise
71 +
72 +
73 +def model_with_weights(model, weights, skip_mismatch):
74 + """Load weights for model.
75 +
76 + Args
77 + model : The model to load weights for.
78 + weights : The weights to load.
79 + skip_mismatch : If True, skips layers whose shape of weights doesn't match with the model.
80 + """
81 + if weights is not None:
82 + model.load_weights(weights, by_name=True, skip_mismatch=skip_mismatch)
83 + return model
84 +
85 +
86 +def create_models(
87 + backbone_retinanet,
88 + num_classes,
89 + weights,
90 + multi_gpu=0,
91 + freeze_backbone=False,
92 + lr=1e-5,
93 + optimizer_clipnorm=0.001,
94 + config=None,
95 + submodels=None,
96 +):
97 + """Creates three models (model, training_model, prediction_model).
98 +
99 + Args
100 + backbone_retinanet : A function to call to create a retinanet model with a given backbone.
101 + num_classes : The number of classes to train.
102 + weights : The weights to load into the model.
103 + multi_gpu : The number of GPUs to use for training.
104 + freeze_backbone : If True, disables learning for the backbone.
105 + config : Config parameters, None indicates the default configuration.
106 +
107 + Returns
108 + model : The base model. This is also the model that is saved in snapshots.
109 + training_model : The training model. If multi_gpu=0, this is identical to model.
110 + prediction_model : The model wrapped with utility functions to perform object detection (applies regression values and performs NMS).
111 + """
112 +
113 + modifier = freeze_model if freeze_backbone else None
114 +
115 + # load anchor parameters, or pass None (so that defaults will be used)
116 + anchor_params = None
117 + num_anchors = None
118 + pyramid_levels = None
119 + if config and "anchor_parameters" in config:
120 + anchor_params = parse_anchor_parameters(config)
121 + num_anchors = anchor_params.num_anchors()
122 + if config and "pyramid_levels" in config:
123 + pyramid_levels = parse_pyramid_levels(config)
124 +
125 + # Keras recommends initialising a multi-gpu model on the CPU to ease weight sharing, and to prevent OOM errors.
126 + # optionally wrap in a parallel model
127 + if multi_gpu > 1:
128 + from keras.utils import multi_gpu_model
129 +
130 + with tf.device("/cpu:0"):
131 + model = model_with_weights(
132 + backbone_retinanet(
133 + num_classes,
134 + num_anchors=num_anchors,
135 + modifier=modifier,
136 + pyramid_levels=pyramid_levels,
137 + ),
138 + weights=weights,
139 + skip_mismatch=True,
140 + )
141 + training_model = multi_gpu_model(model, gpus=multi_gpu)
142 + else:
143 + model = model_with_weights(
144 + backbone_retinanet(
145 + num_classes,
146 + num_anchors=num_anchors,
147 + modifier=modifier,
148 + pyramid_levels=pyramid_levels,
149 + submodels=submodels,
150 + ),
151 + weights=weights,
152 + skip_mismatch=True,
153 + )
154 + training_model = model
155 +
156 + # make prediction model
157 + prediction_model = retinanet_bbox(
158 + model=model, anchor_params=anchor_params, pyramid_levels=pyramid_levels
159 + )
160 +
161 + # compile model
162 + training_model.compile(
163 + loss={"regression": losses.smooth_l1(), "classification": losses.focal()},
164 + optimizer=keras.optimizers.Adam(lr=lr, clipnorm=optimizer_clipnorm),
165 + )
166 +
167 + return model, training_model, prediction_model
168 +
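A minimal sketch of how the three returned models are typically wired together (illustrative names; the real call sites are in main() below):

    # model, training_model, prediction_model = create_models(
    #     backbone_retinanet=backbone.retinanet,
    #     num_classes=train_generator.num_classes(),
    #     weights=backbone.download_imagenet(),
    # )
    # training_model.fit_generator(...)      # optimised during training; shares weights with `model`
    # prediction_model.predict_on_batch(x)   # `model` plus bbox decoding and NMS, used for evaluation
    # model.save('snapshot.h5')              # the plain model is what ModelCheckpoint snapshots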
169 +
170 +def create_callbacks(
171 + model, training_model, prediction_model, validation_generator, args
172 +):
173 + """Creates the callbacks to use during training.
174 +
175 + Args
176 + model: The base model.
177 + training_model: The model that is used for training.
178 + prediction_model: The model that should be used for validation.
179 + validation_generator: The generator for creating validation data.
180 +        args: argparse args object.
181 +
182 + Returns:
183 + A list of callbacks used for training.
184 + """
185 + callbacks = []
186 +
187 + tensorboard_callback = None
188 +
189 + if args.tensorboard_dir:
190 + makedirs(args.tensorboard_dir)
191 + update_freq = args.tensorboard_freq
192 + if update_freq not in ["epoch", "batch"]:
193 + update_freq = int(update_freq)
194 + tensorboard_callback = keras.callbacks.TensorBoard(
195 + log_dir=args.tensorboard_dir,
196 + histogram_freq=0,
197 + batch_size=args.batch_size,
198 + write_graph=True,
199 + write_grads=False,
200 + write_images=False,
201 + update_freq=update_freq,
202 + embeddings_freq=0,
203 + embeddings_layer_names=None,
204 + embeddings_metadata=None,
205 + )
206 +
207 + if args.evaluation and validation_generator:
208 + if args.dataset_type == "coco":
209 + from ..callbacks.coco import CocoEval
210 +
211 + # use prediction model for evaluation
212 + evaluation = CocoEval(
213 + validation_generator, tensorboard=tensorboard_callback
214 + )
215 + else:
216 + evaluation = Evaluate(
217 + validation_generator,
218 + tensorboard=tensorboard_callback,
219 + weighted_average=args.weighted_average,
220 + )
221 + evaluation = RedirectModel(evaluation, prediction_model)
222 + callbacks.append(evaluation)
223 +
224 + # save the model
225 + if args.snapshots:
226 + # ensure directory created first; otherwise h5py will error after epoch.
227 + makedirs(args.snapshot_path)
228 + checkpoint = keras.callbacks.ModelCheckpoint(
229 + os.path.join(
230 + args.snapshot_path,
231 + "{backbone}_{dataset_type}_{{epoch:02d}}.h5".format(
232 + backbone=args.backbone, dataset_type=args.dataset_type
233 + ),
234 + ),
235 + verbose=1,
236 + # save_best_only=True,
237 + # monitor="mAP",
238 + # mode='max'
239 + )
240 + checkpoint = RedirectModel(checkpoint, model)
241 + callbacks.append(checkpoint)
242 +
243 + callbacks.append(
244 + keras.callbacks.ReduceLROnPlateau(
245 + monitor="loss",
246 + factor=args.reduce_lr_factor,
247 + patience=args.reduce_lr_patience,
248 + verbose=1,
249 + mode="auto",
250 + min_delta=0.0001,
251 + cooldown=0,
252 + min_lr=0,
253 + )
254 + )
255 +
256 + if args.evaluation and validation_generator:
257 + callbacks.append(
258 + keras.callbacks.EarlyStopping(
259 + monitor="mAP", patience=5, mode="max", min_delta=0.01
260 + )
261 + )
262 +
263 + if args.tensorboard_dir:
264 + callbacks.append(tensorboard_callback)
265 +
266 + return callbacks
267 +
268 +
269 +def create_generators(args, preprocess_image):
270 + """Create generators for training and validation.
271 +
272 + Args
273 +        args             : argparse object containing configuration for generators.
274 + preprocess_image : Function that preprocesses an image for the network.
275 + """
276 + common_args = {
277 + "batch_size": args.batch_size,
278 + "config": args.config,
279 + "image_min_side": args.image_min_side,
280 + "image_max_side": args.image_max_side,
281 + "no_resize": args.no_resize,
282 + "preprocess_image": preprocess_image,
283 + "group_method": args.group_method,
284 + }
285 +
286 + # create random transform generator for augmenting training data
287 + if args.random_transform:
288 + transform_generator = random_transform_generator(
289 + min_rotation=-0.1,
290 + max_rotation=0.1,
291 + min_translation=(-0.1, -0.1),
292 + max_translation=(0.1, 0.1),
293 + min_shear=-0.1,
294 + max_shear=0.1,
295 + min_scaling=(0.9, 0.9),
296 + max_scaling=(1.1, 1.1),
297 + flip_x_chance=0.5,
298 + flip_y_chance=0.5,
299 + )
300 + visual_effect_generator = random_visual_effect_generator(
301 + contrast_range=(0.9, 1.1),
302 + brightness_range=(-0.1, 0.1),
303 + hue_range=(-0.05, 0.05),
304 + saturation_range=(0.95, 1.05),
305 + )
306 + else:
307 + transform_generator = random_transform_generator(flip_x_chance=0.5)
308 + visual_effect_generator = None
309 +
310 + if args.dataset_type == "coco":
311 + # import here to prevent unnecessary dependency on cocoapi
312 + from ..preprocessing.coco import CocoGenerator
313 +
314 + train_generator = CocoGenerator(
315 + args.coco_path,
316 + "train2017",
317 + transform_generator=transform_generator,
318 + visual_effect_generator=visual_effect_generator,
319 + **common_args
320 + )
321 +
322 + validation_generator = CocoGenerator(
323 + args.coco_path, "val2017", shuffle_groups=False, **common_args
324 + )
325 + elif args.dataset_type == "pascal":
326 + train_generator = PascalVocGenerator(
327 + args.pascal_path,
328 + "train",
329 + image_extension=args.image_extension,
330 + transform_generator=transform_generator,
331 + visual_effect_generator=visual_effect_generator,
332 + **common_args
333 + )
334 +
335 + validation_generator = PascalVocGenerator(
336 + args.pascal_path,
337 + "val",
338 + image_extension=args.image_extension,
339 + shuffle_groups=False,
340 + **common_args
341 + )
342 + elif args.dataset_type == "csv":
343 + train_generator = CSVGenerator(
344 + args.annotations,
345 + args.classes,
346 + transform_generator=transform_generator,
347 + visual_effect_generator=visual_effect_generator,
348 + **common_args
349 + )
350 +
351 + if args.val_annotations:
352 + validation_generator = CSVGenerator(
353 + args.val_annotations, args.classes, shuffle_groups=False, **common_args
354 + )
355 + else:
356 + validation_generator = None
357 + elif args.dataset_type == "oid":
358 + train_generator = OpenImagesGenerator(
359 + args.main_dir,
360 + subset="train",
361 + version=args.version,
362 + labels_filter=args.labels_filter,
363 + annotation_cache_dir=args.annotation_cache_dir,
364 + parent_label=args.parent_label,
365 + transform_generator=transform_generator,
366 + visual_effect_generator=visual_effect_generator,
367 + **common_args
368 + )
369 +
370 + validation_generator = OpenImagesGenerator(
371 + args.main_dir,
372 + subset="validation",
373 + version=args.version,
374 + labels_filter=args.labels_filter,
375 + annotation_cache_dir=args.annotation_cache_dir,
376 + parent_label=args.parent_label,
377 + shuffle_groups=False,
378 + **common_args
379 + )
380 + elif args.dataset_type == "kitti":
381 + train_generator = KittiGenerator(
382 + args.kitti_path,
383 + subset="train",
384 + transform_generator=transform_generator,
385 + visual_effect_generator=visual_effect_generator,
386 + **common_args
387 + )
388 +
389 + validation_generator = KittiGenerator(
390 + args.kitti_path, subset="val", shuffle_groups=False, **common_args
391 + )
392 + else:
393 + raise ValueError("Invalid data type received: {}".format(args.dataset_type))
394 +
395 + return train_generator, validation_generator
396 +
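For the csv dataset type, CSVGenerator consumes the simple comma-separated format documented upstream; an illustrative excerpt (file names and values are placeholders):

    # annotations.csv -- one box per line; empty fields mark an image without annotations
    #   img_001.jpg,837,346,981,456,cow
    #   img_002.jpg,215,312,279,391,cat
    #   img_003.jpg,,,,,
    # classes.csv -- class name to contiguous id mapping, starting at 0
    #   cow,0
    #   cat,1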
397 +
398 +def check_args(parsed_args):
399 + """Function to check for inherent contradictions within parsed arguments.
400 + For example, batch_size < num_gpus
401 + Intended to raise errors prior to backend initialisation.
402 +
403 + Args
404 + parsed_args: parser.parse_args()
405 +
406 + Returns
407 + parsed_args
408 + """
409 +
410 + if parsed_args.multi_gpu > 1 and parsed_args.batch_size < parsed_args.multi_gpu:
411 + raise ValueError(
412 + "Batch size ({}) must be equal to or higher than the number of GPUs ({})".format(
413 + parsed_args.batch_size, parsed_args.multi_gpu
414 + )
415 + )
416 +
417 + if parsed_args.multi_gpu > 1 and parsed_args.snapshot:
418 + raise ValueError(
419 + "Multi GPU training ({}) and resuming from snapshots ({}) is not supported.".format(
420 + parsed_args.multi_gpu, parsed_args.snapshot
421 + )
422 + )
423 +
424 + if parsed_args.multi_gpu > 1 and not parsed_args.multi_gpu_force:
425 + raise ValueError(
426 + "Multi-GPU support is experimental, use at own risk! Run with --multi-gpu-force if you wish to continue."
427 + )
428 +
429 + if "resnet" not in parsed_args.backbone:
430 + warnings.warn(
431 + "Using experimental backbone {}. Only resnet50 has been properly tested.".format(
432 + parsed_args.backbone
433 + )
434 + )
435 +
436 + return parsed_args
437 +
438 +
439 +def parse_args(args):
440 + """Parse the arguments."""
441 + parser = argparse.ArgumentParser(
442 + description="Simple training script for training a RetinaNet network."
443 + )
444 + subparsers = parser.add_subparsers(
445 + help="Arguments for specific dataset types.", dest="dataset_type"
446 + )
447 + subparsers.required = True
448 +
449 + coco_parser = subparsers.add_parser("coco")
450 + coco_parser.add_argument(
451 + "coco_path", help="Path to dataset directory (ie. /tmp/COCO)."
452 + )
453 +
454 + pascal_parser = subparsers.add_parser("pascal")
455 + pascal_parser.add_argument(
456 + "pascal_path", help="Path to dataset directory (ie. /tmp/VOCdevkit)."
457 + )
458 + pascal_parser.add_argument(
459 + "--image-extension",
460 + help="Declares the dataset images' extension.",
461 + default=".jpg",
462 + )
463 +
464 + kitti_parser = subparsers.add_parser("kitti")
465 + kitti_parser.add_argument(
466 + "kitti_path", help="Path to dataset directory (ie. /tmp/kitti)."
467 + )
468 +
469 + def csv_list(string):
470 + return string.split(",")
471 +
472 + oid_parser = subparsers.add_parser("oid")
473 + oid_parser.add_argument("main_dir", help="Path to dataset directory.")
474 + oid_parser.add_argument(
475 + "--version", help="The current dataset version is v4.", default="v4"
476 + )
477 + oid_parser.add_argument(
478 + "--labels-filter",
479 + help="A list of labels to filter.",
480 + type=csv_list,
481 + default=None,
482 + )
483 + oid_parser.add_argument(
484 + "--annotation-cache-dir", help="Path to store annotation cache.", default="."
485 + )
486 + oid_parser.add_argument(
487 + "--parent-label", help="Use the hierarchy children of this label.", default=None
488 + )
489 +
490 + csv_parser = subparsers.add_parser("csv")
491 + csv_parser.add_argument(
492 + "annotations", help="Path to CSV file containing annotations for training."
493 + )
494 + csv_parser.add_argument(
495 + "classes", help="Path to a CSV file containing class label mapping."
496 + )
497 + csv_parser.add_argument(
498 + "--val-annotations",
499 + help="Path to CSV file containing annotations for validation (optional).",
500 + )
501 +
502 + group = parser.add_mutually_exclusive_group()
503 + group.add_argument("--snapshot", help="Resume training from a snapshot.")
504 + group.add_argument(
505 + "--imagenet-weights",
506 + help="Initialize the model with pretrained imagenet weights. This is the default behaviour.",
507 + action="store_const",
508 + const=True,
509 + default=True,
510 + )
511 + group.add_argument(
512 + "--weights", help="Initialize the model with weights from a file."
513 + )
514 + group.add_argument(
515 + "--no-weights",
516 + help="Don't initialize the model with any weights.",
517 + dest="imagenet_weights",
518 + action="store_const",
519 + const=False,
520 + )
521 + parser.add_argument(
522 + "--backbone",
523 + help="Backbone model used by retinanet.",
524 + default="resnet50",
525 + type=str,
526 + )
527 + parser.add_argument(
528 + "--batch-size", help="Size of the batches.", default=1, type=int
529 + )
530 + parser.add_argument(
531 + "--gpu", help="Id of the GPU to use (as reported by nvidia-smi)."
532 + )
533 + parser.add_argument(
534 + "--multi-gpu",
535 + help="Number of GPUs to use for parallel processing.",
536 + type=int,
537 + default=0,
538 + )
539 + parser.add_argument(
540 + "--multi-gpu-force",
541 + help="Extra flag needed to enable (experimental) multi-gpu support.",
542 + action="store_true",
543 + )
544 + parser.add_argument(
545 + "--initial-epoch",
546 + help="Epoch from which to begin the train, useful if resuming from snapshot.",
547 + type=int,
548 + default=0,
549 + )
550 + parser.add_argument(
551 + "--epochs", help="Number of epochs to train.", type=int, default=50
552 + )
553 + parser.add_argument(
554 + "--steps", help="Number of steps per epoch.", type=int, default=10000
555 + )
556 + parser.add_argument("--lr", help="Learning rate.", type=float, default=1e-5)
557 + parser.add_argument(
558 + "--optimizer-clipnorm",
559 + help="Clipnorm parameter for optimizer.",
560 + type=float,
561 + default=0.001,
562 + )
563 + parser.add_argument(
564 + "--snapshot-path",
565 + help="Path to store snapshots of models during training (defaults to './snapshots')",
566 + default="./snapshots",
567 + )
568 + parser.add_argument(
569 + "--tensorboard-dir", help="Log directory for Tensorboard output", default=""
570 + ) # default='./logs') => https://github.com/tensorflow/tensorflow/pull/34870
571 + parser.add_argument(
572 + "--tensorboard-freq",
573 + help="Update frequency for Tensorboard output. Values 'epoch', 'batch' or int",
574 + default="epoch",
575 + )
576 + parser.add_argument(
577 + "--no-snapshots",
578 + help="Disable saving snapshots.",
579 + dest="snapshots",
580 + action="store_false",
581 + )
582 + parser.add_argument(
583 + "--no-evaluation",
584 + help="Disable per epoch evaluation.",
585 + dest="evaluation",
586 + action="store_false",
587 + )
588 + parser.add_argument(
589 + "--freeze-backbone",
590 + help="Freeze training of backbone layers.",
591 + action="store_true",
592 + )
593 + parser.add_argument(
594 + "--random-transform",
595 + help="Randomly transform image and annotations.",
596 + action="store_true",
597 + )
598 + parser.add_argument(
599 + "--image-min-side",
600 + help="Rescale the image so the smallest side is min_side.",
601 + type=int,
602 + default=800,
603 + )
604 + parser.add_argument(
605 + "--image-max-side",
606 + help="Rescale the image if the largest side is larger than max_side.",
607 + type=int,
608 + default=1333,
609 + )
610 + parser.add_argument(
611 + "--no-resize", help="Don" "t rescale the image.", action="store_true"
612 + )
613 + parser.add_argument(
614 + "--config", help="Path to a configuration parameters .ini file."
615 + )
616 + parser.add_argument(
617 + "--weighted-average",
618 + help="Compute the mAP using the weighted average of precisions among classes.",
619 + action="store_true",
620 + )
621 + parser.add_argument(
622 + "--compute-val-loss",
623 + help="Compute validation loss during training",
624 + dest="compute_val_loss",
625 + action="store_true",
626 + )
627 + parser.add_argument(
628 + "--reduce-lr-patience",
629 + help="Reduce learning rate after validation loss decreases over reduce_lr_patience epochs",
630 + type=int,
631 + default=2,
632 + )
633 + parser.add_argument(
634 + "--reduce-lr-factor",
635 + help="When learning rate is reduced due to reduce_lr_patience, multiply by reduce_lr_factor",
636 + type=float,
637 + default=0.1,
638 + )
639 + parser.add_argument(
640 + "--group-method",
641 + help="Determines how images are grouped together",
642 + type=str,
643 + default="ratio",
644 + choices=["none", "random", "ratio"],
645 + )
646 +
647 + # Fit generator arguments
648 + parser.add_argument(
649 + "--multiprocessing",
650 + help="Use multiprocessing in fit_generator.",
651 + action="store_true",
652 + )
653 + parser.add_argument(
654 + "--workers", help="Number of generator workers.", type=int, default=1
655 + )
656 + parser.add_argument(
657 + "--max-queue-size",
658 + help="Queue length for multiprocessing workers in fit_generator.",
659 + type=int,
660 + default=10,
661 + )
662 +
663 + return check_args(parser.parse_args(args))
664 +
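For reference, illustrative invocations of this training script (upstream layout assumed; all paths are placeholders):

    # Train on a CSV dataset with an ImageNet-initialised ResNet-50 backbone:
    #   python keras_retinanet/bin/train.py --backbone resnet50 --epochs 20 --steps 500 csv train_annotations.csv classes.csv --val-annotations val_annotations.csv
    # Resume from an earlier snapshot:
    #   python keras_retinanet/bin/train.py --snapshot snapshots/resnet50_csv_10.h5 --initial-epoch 10 csv train_annotations.csv classes.csv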
665 +
666 +def main(args=None):
667 + # parse arguments
668 + if args is None:
669 + args = sys.argv[1:]
670 + args = parse_args(args)
671 +
672 + # create object that stores backbone information
673 + backbone = models.backbone(args.backbone)
674 +
675 + # make sure tensorflow is the minimum required version
676 + check_tf_version()
677 +
678 + # optionally choose specific GPU
679 + if args.gpu is not None:
680 + setup_gpu(args.gpu)
681 +
682 + # optionally load config parameters
683 + if args.config:
684 + args.config = read_config_file(args.config)
685 +
686 + # create the generators
687 + train_generator, validation_generator = create_generators(
688 + args, backbone.preprocess_image
689 + )
690 +
691 + # create the model
692 + if args.snapshot is not None:
693 + print("Loading model, this may take a second...")
694 + model = models.load_model(args.snapshot, backbone_name=args.backbone)
695 + training_model = model
696 + anchor_params = None
697 + pyramid_levels = None
698 + if args.config and "anchor_parameters" in args.config:
699 + anchor_params = parse_anchor_parameters(args.config)
700 + if args.config and "pyramid_levels" in args.config:
701 + pyramid_levels = parse_pyramid_levels(args.config)
702 +
703 + prediction_model = retinanet_bbox(
704 + model=model, anchor_params=anchor_params, pyramid_levels=pyramid_levels
705 + )
706 + else:
707 + weights = args.weights
708 + # default to imagenet if nothing else is specified
709 + if weights is None and args.imagenet_weights:
710 + weights = backbone.download_imagenet()
711 +
712 +        # --- custom multi-task heads: three (regression, classification) submodel pairs for 51, 10 and 16 classes ---
713 + subclass1 = submodel.custom_classification_model(num_classes=51, num_anchors=None, name="classification_submodel1")
714 + subregress1 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel1")
715 +
716 + subclass2 = submodel.custom_classification_model(num_classes=10, num_anchors=None, name="classification_submodel2")
717 + subregress2 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel2")
718 +
719 + subclass3 = submodel.custom_classification_model(num_classes=16, num_anchors=None, name="classification_submodel3")
720 + subregress3 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel3")
721 +
722 + submodels = [
723 + ("regression", subregress1), ("classification", subclass1),
724 + ("regression", subregress2), ("classification", subclass2),
725 + ("regression", subregress3), ("classification", subclass3),
726 + ]
727 +
728 + s1 = submodel.custom_default_submodels(51, None)
729 + s2 = submodel.custom_default_submodels(10, None)
730 + s3 = submodel.custom_default_submodels(16, None)
731 +
732 +        submodels = s1 + s2 + s3  # note: this replaces the hand-built submodels list above
733 +
734 +        # --- end of custom submodel setup ---
735 + print("Creating model, this may take a second...")
736 + model, training_model, prediction_model = create_models(
737 + backbone_retinanet=backbone.retinanet,
738 + num_classes=train_generator.num_classes(),
739 + weights=weights,
740 + multi_gpu=args.multi_gpu,
741 + freeze_backbone=args.freeze_backbone,
742 + lr=args.lr,
743 + optimizer_clipnorm=args.optimizer_clipnorm,
744 + config=args.config,
745 + submodels=submodels,
746 + )
747 +
748 + # print model summary
749 + print(model.summary())
750 +
751 + # this lets the generator compute backbone layer shapes using the actual backbone model
752 + if "vgg" in args.backbone or "densenet" in args.backbone:
753 + train_generator.compute_shapes = make_shapes_callback(model)
754 + if validation_generator:
755 + validation_generator.compute_shapes = train_generator.compute_shapes
756 +
757 + # create the callbacks
758 + callbacks = create_callbacks(
759 + model,
760 + training_model,
761 + prediction_model,
762 + validation_generator,
763 + args,
764 + )
765 +
766 + if not args.compute_val_loss:
767 + validation_generator = None
768 +
769 + # start training
770 + return training_model.fit_generator(
771 + generator=train_generator,
772 + steps_per_epoch=args.steps,
773 + epochs=args.epochs,
774 + verbose=1,
775 + callbacks=callbacks,
776 + workers=args.workers,
777 + use_multiprocessing=args.multiprocessing,
778 + max_queue_size=args.max_queue_size,
779 + validation_data=validation_generator,
780 + initial_epoch=args.initial_epoch,
781 + )
782 +
783 +
784 +if __name__ == "__main__":
785 + main()
1 +#!/usr/bin/env python
2 +
3 +"""
4 +Copyright 2017-2018 Fizyr (https://fizyr.com)
5 +
6 +Licensed under the Apache License, Version 2.0 (the "License");
7 +you may not use this file except in compliance with the License.
8 +You may obtain a copy of the License at
9 +
10 + http://www.apache.org/licenses/LICENSE-2.0
11 +
12 +Unless required by applicable law or agreed to in writing, software
13 +distributed under the License is distributed on an "AS IS" BASIS,
14 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 +See the License for the specific language governing permissions and
16 +limitations under the License.
17 +"""
18 +
19 +import argparse
20 +import os
21 +import sys
22 +import warnings
23 +
24 +from tensorflow import keras
25 +import tensorflow as tf
26 +
28 +
29 +# Allow relative imports when being executed as script.
30 +if __name__ == "__main__" and __package__ is None:
31 + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
32 + import keras_retinanet.bin # noqa: F401
33 +
34 + __package__ = "keras_retinanet.bin"
35 +
36 +# Change these to absolute imports if you copy this script outside the keras_retinanet package.
37 +from .. import layers # noqa: F401
38 +from .. import losses
39 +from .. import models
40 +from ..callbacks import RedirectModel
41 +from ..callbacks.eval import Evaluate
42 +from ..models.retinanet import retinanet_bbox
43 +from ..preprocessing.csv_generator import CSVGenerator
44 +from ..preprocessing.kitti import KittiGenerator
45 +from ..preprocessing.open_images import OpenImagesGenerator
46 +from ..preprocessing.pascal_voc import PascalVocGenerator
47 +from ..utils.anchors import make_shapes_callback
48 +from ..utils.config import (
49 + read_config_file,
50 + parse_anchor_parameters,
51 + parse_pyramid_levels,
52 +)
53 +from ..utils.gpu import setup_gpu
54 +from ..utils.image import random_visual_effect_generator
55 +from ..utils.model import freeze as freeze_model
56 +from ..utils.tf_version import check_tf_version
57 +from ..utils.transform import random_transform_generator
58 +
59 +#######################
60 +
61 +from ..models import submodel
62 +
63 +
64 +def makedirs(path):
65 + # Intended behavior: try to create the directory,
66 +    # pass if the directory already exists, re-raise otherwise.
67 + # Meant for Python 2.7/3.n compatibility.
68 + try:
69 + os.makedirs(path)
70 + except OSError:
71 + if not os.path.isdir(path):
72 + raise
73 +
74 +
75 +def model_with_weights(model, weights, skip_mismatch):
76 + """Load weights for model.
77 +
78 + Args
79 + model : The model to load weights for.
80 + weights : The weights to load.
81 + skip_mismatch : If True, skips layers whose shape of weights doesn't match with the model.
82 + """
83 + if weights is not None:
84 + model.load_weights(weights, by_name=True, skip_mismatch=skip_mismatch)
85 + return model
86 +
87 +
88 +def create_models(
89 + backbone_retinanet,
90 + num_classes,
91 + weights,
92 + multi_gpu=0,
93 + freeze_backbone=False,
94 + lr=1e-5,
95 + optimizer_clipnorm=0.001,
96 + config=None,
97 + submodels=None,
98 +):
99 + """Creates three models (model, training_model, prediction_model).
100 +
101 + Args
102 + backbone_retinanet : A function to call to create a retinanet model with a given backbone.
103 + num_classes : The number of classes to train.
104 + weights : The weights to load into the model.
105 + multi_gpu : The number of GPUs to use for training.
106 + freeze_backbone : If True, disables learning for the backbone.
107 + config : Config parameters, None indicates the default configuration.
108 +
109 + Returns
110 + model : The base model. This is also the model that is saved in snapshots.
111 + training_model : The training model. If multi_gpu=0, this is identical to model.
112 + prediction_model : The model wrapped with utility functions to perform object detection (applies regression values and performs NMS).
113 + """
114 +
115 + modifier = freeze_model if freeze_backbone else None
116 +
117 + # load anchor parameters, or pass None (so that defaults will be used)
118 + anchor_params = None
119 + num_anchors = None
120 + pyramid_levels = None
121 + if config and "anchor_parameters" in config:
122 + anchor_params = parse_anchor_parameters(config)
123 + num_anchors = anchor_params.num_anchors()
124 + if config and "pyramid_levels" in config:
125 + pyramid_levels = parse_pyramid_levels(config)
126 +
127 + # Keras recommends initialising a multi-gpu model on the CPU to ease weight sharing, and to prevent OOM errors.
128 + # optionally wrap in a parallel model
129 + if multi_gpu > 1:
130 + from keras.utils import multi_gpu_model
131 +
132 + with tf.device("/cpu:0"):
133 + model = model_with_weights(
134 + backbone_retinanet(
135 + num_classes,
136 + num_anchors=num_anchors,
137 + modifier=modifier,
138 + pyramid_levels=pyramid_levels,
139 + ),
140 + weights=weights,
141 + skip_mismatch=True,
142 + )
143 + training_model = multi_gpu_model(model, gpus=multi_gpu)
144 + else:
145 + model = model_with_weights(
146 + backbone_retinanet(
147 + num_classes,
148 + num_anchors=num_anchors,
149 + modifier=modifier,
150 + pyramid_levels=pyramid_levels,
151 + submodels=submodels,
152 + ),
153 + weights=weights,
154 + skip_mismatch=True,
155 + )
156 + training_model = model
157 +
158 + # make prediction model
159 + prediction_model = retinanet_bbox(
160 + model=model, anchor_params=anchor_params, pyramid_levels=pyramid_levels
161 + )
162 +
163 + # compile model
164 + training_model.compile(
165 + loss={"regression": losses.smooth_l1(), "classification": losses.focal()},
166 + optimizer=keras.optimizers.Adam(lr=lr, clipnorm=optimizer_clipnorm),
167 + )
168 +
169 + return model, training_model, prediction_model
170 +
171 +
172 +def create_callbacks(
173 + model, training_model, prediction_model, validation_generator, args
174 +):
175 + """Creates the callbacks to use during training.
176 +
177 + Args
178 + model: The base model.
179 + training_model: The model that is used for training.
180 + prediction_model: The model that should be used for validation.
181 + validation_generator: The generator for creating validation data.
182 +        args: argparse args object.
183 +
184 + Returns:
185 + A list of callbacks used for training.
186 + """
187 + callbacks = []
188 +
189 + tensorboard_callback = None
190 +
191 + if args.tensorboard_dir:
192 + makedirs(args.tensorboard_dir)
193 + update_freq = args.tensorboard_freq
194 + if update_freq not in ["epoch", "batch"]:
195 + update_freq = int(update_freq)
196 + tensorboard_callback = keras.callbacks.TensorBoard(
197 + log_dir=args.tensorboard_dir,
198 + histogram_freq=0,
199 + batch_size=args.batch_size,
200 + write_graph=True,
201 + write_grads=False,
202 + write_images=False,
203 + update_freq=update_freq,
204 + embeddings_freq=0,
205 + embeddings_layer_names=None,
206 + embeddings_metadata=None,
207 + )
208 +
209 + if args.evaluation and validation_generator:
210 + if args.dataset_type == "coco":
211 + from ..callbacks.coco import CocoEval
212 +
213 + # use prediction model for evaluation
214 + evaluation = CocoEval(
215 + validation_generator, tensorboard=tensorboard_callback
216 + )
217 + else:
218 + evaluation = Evaluate(
219 + validation_generator,
220 + tensorboard=tensorboard_callback,
221 + weighted_average=args.weighted_average,
222 + )
223 + evaluation = RedirectModel(evaluation, prediction_model)
224 + callbacks.append(evaluation)
225 +
226 + # save the model
227 + if args.snapshots:
228 + # ensure directory created first; otherwise h5py will error after epoch.
229 + makedirs(args.snapshot_path)
230 + checkpoint = keras.callbacks.ModelCheckpoint(
231 + os.path.join(
232 + args.snapshot_path,
233 + "{backbone}_{dataset_type}_{{epoch:02d}}.h5".format(
234 + backbone=args.backbone, dataset_type=args.dataset_type
235 + ),
236 + ),
237 + verbose=1,
238 + # save_best_only=True,
239 + # monitor="mAP",
240 + # mode='max'
241 + )
242 + checkpoint = RedirectModel(checkpoint, model)
243 + callbacks.append(checkpoint)
244 +
245 + callbacks.append(
246 + keras.callbacks.ReduceLROnPlateau(
247 + monitor="loss",
248 + factor=args.reduce_lr_factor,
249 + patience=args.reduce_lr_patience,
250 + verbose=1,
251 + mode="auto",
252 + min_delta=0.0001,
253 + cooldown=0,
254 + min_lr=0,
255 + )
256 + )
257 +
258 + if args.evaluation and validation_generator:
259 + callbacks.append(
260 + keras.callbacks.EarlyStopping(
261 + monitor="mAP", patience=5, mode="max", min_delta=0.01
262 + )
263 + )
264 +
265 + if args.tensorboard_dir:
266 + callbacks.append(tensorboard_callback)
267 +
268 + return callbacks
269 +
270 +
271 +def create_generators(args, preprocess_image):
272 + """Create generators for training and validation.
273 +
274 + Args
275 +        args             : argparse object containing configuration for generators.
276 + preprocess_image : Function that preprocesses an image for the network.
277 + """
278 + common_args = {
279 + "batch_size": args.batch_size,
280 + "config": args.config,
281 + "image_min_side": args.image_min_side,
282 + "image_max_side": args.image_max_side,
283 + "no_resize": args.no_resize,
284 + "preprocess_image": preprocess_image,
285 + "group_method": args.group_method,
286 + }
287 +
288 + # create random transform generator for augmenting training data
289 + if args.random_transform:
290 + transform_generator = random_transform_generator(
291 + min_rotation=-0.1,
292 + max_rotation=0.1,
293 + min_translation=(-0.1, -0.1),
294 + max_translation=(0.1, 0.1),
295 + min_shear=-0.1,
296 + max_shear=0.1,
297 + min_scaling=(0.9, 0.9),
298 + max_scaling=(1.1, 1.1),
299 + flip_x_chance=0.5,
300 + flip_y_chance=0.5,
301 + )
302 + visual_effect_generator = random_visual_effect_generator(
303 + contrast_range=(0.9, 1.1),
304 + brightness_range=(-0.1, 0.1),
305 + hue_range=(-0.05, 0.05),
306 + saturation_range=(0.95, 1.05),
307 + )
308 + else:
309 + transform_generator = random_transform_generator(flip_x_chance=0.5)
310 + visual_effect_generator = None
311 +
312 + if args.dataset_type == "coco":
313 + # import here to prevent unnecessary dependency on cocoapi
314 + from ..preprocessing.coco import CocoGenerator
315 +
316 + train_generator = CocoGenerator(
317 + args.coco_path,
318 + "train2017",
319 + transform_generator=transform_generator,
320 + visual_effect_generator=visual_effect_generator,
321 + **common_args
322 + )
323 +
324 + validation_generator = CocoGenerator(
325 + args.coco_path, "val2017", shuffle_groups=False, **common_args
326 + )
327 + elif args.dataset_type == "pascal":
328 + train_generator = PascalVocGenerator(
329 + args.pascal_path,
330 + "train",
331 + image_extension=args.image_extension,
332 + transform_generator=transform_generator,
333 + visual_effect_generator=visual_effect_generator,
334 + **common_args
335 + )
336 +
337 + validation_generator = PascalVocGenerator(
338 + args.pascal_path,
339 + "val",
340 + image_extension=args.image_extension,
341 + shuffle_groups=False,
342 + **common_args
343 + )
344 + elif args.dataset_type == "csv":
345 + train_generator = CSVGenerator(
346 + args.annotations,
347 + args.classes,
348 + transform_generator=transform_generator,
349 + visual_effect_generator=visual_effect_generator,
350 + **common_args
351 + )
352 +
353 + if args.val_annotations:
354 + validation_generator = CSVGenerator(
355 + args.val_annotations, args.classes, shuffle_groups=False, **common_args
356 + )
357 + else:
358 + validation_generator = None
359 + elif args.dataset_type == "oid":
360 + train_generator = OpenImagesGenerator(
361 + args.main_dir,
362 + subset="train",
363 + version=args.version,
364 + labels_filter=args.labels_filter,
365 + annotation_cache_dir=args.annotation_cache_dir,
366 + parent_label=args.parent_label,
367 + transform_generator=transform_generator,
368 + visual_effect_generator=visual_effect_generator,
369 + **common_args
370 + )
371 +
372 + validation_generator = OpenImagesGenerator(
373 + args.main_dir,
374 + subset="validation",
375 + version=args.version,
376 + labels_filter=args.labels_filter,
377 + annotation_cache_dir=args.annotation_cache_dir,
378 + parent_label=args.parent_label,
379 + shuffle_groups=False,
380 + **common_args
381 + )
382 + elif args.dataset_type == "kitti":
383 + train_generator = KittiGenerator(
384 + args.kitti_path,
385 + subset="train",
386 + transform_generator=transform_generator,
387 + visual_effect_generator=visual_effect_generator,
388 + **common_args
389 + )
390 +
391 + validation_generator = KittiGenerator(
392 + args.kitti_path, subset="val", shuffle_groups=False, **common_args
393 + )
394 + else:
395 + raise ValueError("Invalid data type received: {}".format(args.dataset_type))
396 +
397 + return train_generator, validation_generator
398 +
399 +
400 +def check_args(parsed_args):
401 + """Function to check for inherent contradictions within parsed arguments.
402 + For example, batch_size < num_gpus
403 + Intended to raise errors prior to backend initialisation.
404 +
405 + Args
406 + parsed_args: parser.parse_args()
407 +
408 + Returns
409 + parsed_args
410 + """
411 +
412 + if parsed_args.multi_gpu > 1 and parsed_args.batch_size < parsed_args.multi_gpu:
413 + raise ValueError(
414 + "Batch size ({}) must be equal to or higher than the number of GPUs ({})".format(
415 + parsed_args.batch_size, parsed_args.multi_gpu
416 + )
417 + )
418 +
419 + if parsed_args.multi_gpu > 1 and parsed_args.snapshot:
420 + raise ValueError(
421 + "Multi GPU training ({}) and resuming from snapshots ({}) is not supported.".format(
422 + parsed_args.multi_gpu, parsed_args.snapshot
423 + )
424 + )
425 +
426 + if parsed_args.multi_gpu > 1 and not parsed_args.multi_gpu_force:
427 + raise ValueError(
428 + "Multi-GPU support is experimental, use at own risk! Run with --multi-gpu-force if you wish to continue."
429 + )
430 +
431 + if "resnet" not in parsed_args.backbone:
432 + warnings.warn(
433 + "Using experimental backbone {}. Only resnet50 has been properly tested.".format(
434 + parsed_args.backbone
435 + )
436 + )
437 +
438 + return parsed_args
439 +
440 +
441 +def parse_args(args):
442 + """Parse the arguments."""
443 + parser = argparse.ArgumentParser(
444 + description="Simple training script for training a RetinaNet network."
445 + )
446 + subparsers = parser.add_subparsers(
447 + help="Arguments for specific dataset types.", dest="dataset_type"
448 + )
449 + subparsers.required = True
450 +
451 + coco_parser = subparsers.add_parser("coco")
452 + coco_parser.add_argument(
453 + "coco_path", help="Path to dataset directory (ie. /tmp/COCO)."
454 + )
455 +
456 + pascal_parser = subparsers.add_parser("pascal")
457 + pascal_parser.add_argument(
458 + "pascal_path", help="Path to dataset directory (ie. /tmp/VOCdevkit)."
459 + )
460 + pascal_parser.add_argument(
461 + "--image-extension",
462 + help="Declares the dataset images' extension.",
463 + default=".jpg",
464 + )
465 +
466 + kitti_parser = subparsers.add_parser("kitti")
467 + kitti_parser.add_argument(
468 + "kitti_path", help="Path to dataset directory (ie. /tmp/kitti)."
469 + )
470 +
471 + def csv_list(string):
472 + return string.split(",")
473 +
474 + oid_parser = subparsers.add_parser("oid")
475 + oid_parser.add_argument("main_dir", help="Path to dataset directory.")
476 + oid_parser.add_argument(
477 + "--version", help="The current dataset version is v4.", default="v4"
478 + )
479 + oid_parser.add_argument(
480 + "--labels-filter",
481 + help="A list of labels to filter.",
482 + type=csv_list,
483 + default=None,
484 + )
485 + oid_parser.add_argument(
486 + "--annotation-cache-dir", help="Path to store annotation cache.", default="."
487 + )
488 + oid_parser.add_argument(
489 + "--parent-label", help="Use the hierarchy children of this label.", default=None
490 + )
491 +
492 + csv_parser = subparsers.add_parser("csv")
493 + csv_parser.add_argument(
494 + "annotations", help="Path to CSV file containing annotations for training."
495 + )
496 + csv_parser.add_argument(
497 + "classes", help="Path to a CSV file containing class label mapping."
498 + )
499 + csv_parser.add_argument(
500 + "--val-annotations",
501 + help="Path to CSV file containing annotations for validation (optional).",
502 + )
503 +
504 + group = parser.add_mutually_exclusive_group()
505 + group.add_argument("--snapshot", help="Resume training from a snapshot.")
506 + group.add_argument(
507 + "--imagenet-weights",
508 + help="Initialize the model with pretrained imagenet weights. This is the default behaviour.",
509 + action="store_const",
510 + const=True,
511 + default=True,
512 + )
513 + group.add_argument(
514 + "--weights", help="Initialize the model with weights from a file."
515 + )
516 + group.add_argument(
517 + "--no-weights",
518 + help="Don't initialize the model with any weights.",
519 + dest="imagenet_weights",
520 + action="store_const",
521 + const=False,
522 + )
523 + parser.add_argument(
524 + "--backbone",
525 + help="Backbone model used by retinanet.",
526 + default="resnet50",
527 + type=str,
528 + )
529 + parser.add_argument(
530 + "--batch-size", help="Size of the batches.", default=1, type=int
531 + )
532 + parser.add_argument(
533 + "--gpu", help="Id of the GPU to use (as reported by nvidia-smi)."
534 + )
535 + parser.add_argument(
536 + "--multi-gpu",
537 + help="Number of GPUs to use for parallel processing.",
538 + type=int,
539 + default=0,
540 + )
541 + parser.add_argument(
542 + "--multi-gpu-force",
543 + help="Extra flag needed to enable (experimental) multi-gpu support.",
544 + action="store_true",
545 + )
546 + parser.add_argument(
547 + "--initial-epoch",
548 + help="Epoch from which to begin the train, useful if resuming from snapshot.",
549 + type=int,
550 + default=0,
551 + )
552 + parser.add_argument(
553 + "--epochs", help="Number of epochs to train.", type=int, default=50
554 + )
555 + parser.add_argument(
556 + "--steps", help="Number of steps per epoch.", type=int, default=10000
557 + )
558 + parser.add_argument("--lr", help="Learning rate.", type=float, default=1e-5)
559 + parser.add_argument(
560 + "--optimizer-clipnorm",
561 + help="Clipnorm parameter for optimizer.",
562 + type=float,
563 + default=0.001,
564 + )
565 + parser.add_argument(
566 + "--snapshot-path",
567 + help="Path to store snapshots of models during training (defaults to './snapshots')",
568 + default="./snapshots",
569 + )
570 + parser.add_argument(
571 + "--tensorboard-dir", help="Log directory for Tensorboard output", default=""
572 + ) # default='./logs') => https://github.com/tensorflow/tensorflow/pull/34870
573 + parser.add_argument(
574 + "--tensorboard-freq",
575 + help="Update frequency for Tensorboard output. Values 'epoch', 'batch' or int",
576 + default="epoch",
577 + )
578 + parser.add_argument(
579 + "--no-snapshots",
580 + help="Disable saving snapshots.",
581 + dest="snapshots",
582 + action="store_false",
583 + )
584 + parser.add_argument(
585 + "--no-evaluation",
586 + help="Disable per epoch evaluation.",
587 + dest="evaluation",
588 + action="store_false",
589 + )
590 + parser.add_argument(
591 + "--freeze-backbone",
592 + help="Freeze training of backbone layers.",
593 + action="store_true",
594 + )
595 + parser.add_argument(
596 + "--random-transform",
597 + help="Randomly transform image and annotations.",
598 + action="store_true",
599 + )
600 + parser.add_argument(
601 + "--image-min-side",
602 + help="Rescale the image so the smallest side is min_side.",
603 + type=int,
604 + default=800,
605 + )
606 + parser.add_argument(
607 + "--image-max-side",
608 + help="Rescale the image if the largest side is larger than max_side.",
609 + type=int,
610 + default=1333,
611 + )
612 + parser.add_argument(
613 + "--no-resize", help="Don" "t rescale the image.", action="store_true"
614 + )
615 + parser.add_argument(
616 + "--config", help="Path to a configuration parameters .ini file."
617 + )
618 + parser.add_argument(
619 + "--weighted-average",
620 + help="Compute the mAP using the weighted average of precisions among classes.",
621 + action="store_true",
622 + )
623 + parser.add_argument(
624 + "--compute-val-loss",
625 + help="Compute validation loss during training",
626 + dest="compute_val_loss",
627 + action="store_true",
628 + )
629 + parser.add_argument(
630 + "--reduce-lr-patience",
631 + help="Reduce learning rate after validation loss decreases over reduce_lr_patience epochs",
632 + type=int,
633 + default=2,
634 + )
635 + parser.add_argument(
636 + "--reduce-lr-factor",
637 + help="When learning rate is reduced due to reduce_lr_patience, multiply by reduce_lr_factor",
638 + type=float,
639 + default=0.1,
640 + )
641 + parser.add_argument(
642 + "--group-method",
643 + help="Determines how images are grouped together",
644 + type=str,
645 + default="ratio",
646 + choices=["none", "random", "ratio"],
647 + )
648 +
649 + # Fit generator arguments
650 + parser.add_argument(
651 + "--multiprocessing",
652 + help="Use multiprocessing in fit_generator.",
653 + action="store_true",
654 + )
655 + parser.add_argument(
656 + "--workers", help="Number of generator workers.", type=int, default=1
657 + )
658 + parser.add_argument(
659 + "--max-queue-size",
660 + help="Queue length for multiprocessing workers in fit_generator.",
661 + type=int,
662 + default=10,
663 + )
664 +
665 + return check_args(parser.parse_args(args))
666 +
667 +
668 +def main(args=None):
669 + # parse arguments
670 + if args is None:
671 + args = sys.argv[1:]
672 + args = parse_args(args)
673 +
674 + # create object that stores backbone information
675 + backbone = models.backbone(args.backbone)
676 +
677 + # make sure tensorflow is the minimum required version
678 + check_tf_version()
679 +
680 + # optionally choose specific GPU
681 + if args.gpu is not None:
682 + setup_gpu(args.gpu)
683 +
684 + # optionally load config parameters
685 + if args.config:
686 + args.config = read_config_file(args.config)
687 +
688 + # create the generators
689 + train_generator, validation_generator = create_generators(
690 + args, backbone.preprocess_image
691 + )
692 +
693 + # create the model
694 + if args.snapshot is not None:
695 + print("Loading model, this may take a second...")
696 + model = models.load_model(args.snapshot, backbone_name=args.backbone)
697 + training_model = model
698 + anchor_params = None
699 + pyramid_levels = None
700 + if args.config and "anchor_parameters" in args.config:
701 + anchor_params = parse_anchor_parameters(args.config)
702 + if args.config and "pyramid_levels" in args.config:
703 + pyramid_levels = parse_pyramid_levels(args.config)
704 +
705 + prediction_model = retinanet_bbox(
706 + model=model, anchor_params=anchor_params, pyramid_levels=pyramid_levels
707 + )
708 + else:
709 + weights = args.weights
710 + # default to imagenet if nothing else is specified
711 + if weights is None and args.imagenet_weights:
712 + weights = backbone.download_imagenet()
713 +
714 + #################
715 + # subclass1 = submodel.custom_classification_model(num_classes=51, num_anchors=None, name="classification_submodel1")
716 + # subregress1 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel1")
717 +
718 + # subclass2 = submodel.custom_classification_model(num_classes=10, num_anchors=None, name="classification_submodel2")
719 + # subregress2 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel2")
720 +
721 + # subclass3 = submodel.custom_classification_model(num_classes=16, num_anchors=None, name="classification_submodel3")
722 + # subregress3 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel3")
723 +
724 + # submodels = [
725 + # ("regression", subregress1), ("classification", subclass1),
726 + # ("regression", subregress2), ("classification", subclass2),
727 + # ("regression", subregress3), ("classification", subclass3),
728 + # ]
729 +
730 + # s1 = submodel.custom_default_submodels(51, None)
731 + # s2 = submodel.custom_default_submodels(10, None)
732 + # s3 = submodel.custom_default_submodels(16, None)
733 +
734 + # submodels = s1 + s2 + s3
735 +
736 + #################
737 + print("Creating model, this may take a second...")
738 + model, training_model, prediction_model = create_models(
739 + backbone_retinanet=backbone.retinanet,
740 + num_classes=train_generator.num_classes(),
741 + weights=weights,
742 + multi_gpu=args.multi_gpu,
743 + freeze_backbone=args.freeze_backbone,
744 + lr=args.lr,
745 + optimizer_clipnorm=args.optimizer_clipnorm,
746 + config=args.config,
747 + submodels=submodel.custom_classification_model(76,),
748 + )
749 +
750 + # print model summary
751 + print(model.summary())
752 +
753 + # this lets the generator compute backbone layer shapes using the actual backbone model
754 + if "vgg" in args.backbone or "densenet" in args.backbone:
755 + train_generator.compute_shapes = make_shapes_callback(model)
756 + if validation_generator:
757 + validation_generator.compute_shapes = train_generator.compute_shapes
758 +
759 + # create the callbacks
760 + callbacks = create_callbacks(
761 + model,
762 + training_model,
763 + prediction_model,
764 + validation_generator,
765 + args,
766 + )
767 +
768 + if not args.compute_val_loss:
769 + validation_generator = None
770 +
771 + # start training
772 + return training_model.fit_generator(
773 + generator=train_generator,
774 + steps_per_epoch=args.steps,
775 + epochs=args.epochs,
776 + verbose=1,
777 + callbacks=callbacks,
778 + workers=args.workers,
779 + use_multiprocessing=args.multiprocessing,
780 + max_queue_size=args.max_queue_size,
781 + validation_data=validation_generator,
782 + initial_epoch=args.initial_epoch,
783 + )
784 +
785 +
786 +if __name__ == "__main__":
787 + main()
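For reference, here is a minimal sketch of invoking this training script programmatically on a CSV dataset. The file paths are placeholders and the `keras_retinanet.bin.train` import path is an assumption about where this script lives in the package layout.

```python
# Hedged sketch: train briefly on a CSV dataset (hypothetical paths, assumed module location).
from keras_retinanet.bin import train

train.main([
    "--backbone", "resnet50",
    "--batch-size", "1",
    "--epochs", "1",
    "--steps", "100",
    "--no-snapshots",          # skip writing snapshot files for this quick run
    "csv", "annotations.csv", "classes.csv",
])
```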
1 +from .common import * # noqa: F401,F403
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from tensorflow import keras
18 +from ..utils.coco_eval import evaluate_coco
19 +
20 +
21 +class CocoEval(keras.callbacks.Callback):
22 + """ Performs COCO evaluation on each epoch.
23 + """
24 + def __init__(self, generator, tensorboard=None, threshold=0.05):
25 + """ CocoEval callback intializer.
26 +
27 + Args
28 + generator : The generator used for creating validation data.
29 + tensorboard : If given, the results will be written to tensorboard.
30 + threshold : The score threshold to use.
31 + """
32 + self.generator = generator
33 + self.threshold = threshold
34 + self.tensorboard = tensorboard
35 +
36 + super(CocoEval, self).__init__()
37 +
38 + def on_epoch_end(self, epoch, logs=None):
39 + logs = logs or {}
40 +
41 + coco_tag = ['AP @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
42 + 'AP @[ IoU=0.50 | area= all | maxDets=100 ]',
43 + 'AP @[ IoU=0.75 | area= all | maxDets=100 ]',
44 + 'AP @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
45 + 'AP @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
46 + 'AP @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
47 + 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]',
48 + 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]',
49 + 'AR @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
50 + 'AR @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
51 + 'AR @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
52 + 'AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ]']
53 + coco_eval_stats = evaluate_coco(self.generator, self.model, self.threshold)
54 +
55 + if coco_eval_stats is not None:
56 + for index, result in enumerate(coco_eval_stats):
57 + logs[coco_tag[index]] = result
58 +
59 + if self.tensorboard:
60 + import tensorflow as tf
61 + writer = tf.summary.create_file_writer(self.tensorboard.log_dir)
62 + with writer.as_default():
63 + for index, result in enumerate(coco_eval_stats):
64 + tf.summary.scalar('{}. {}'.format(index + 1, coco_tag[index]), result, step=epoch)
65 + writer.flush()
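A hedged sketch of how this callback could be attached during training; `train_generator`, `coco_generator` (a generator for the COCO validation split) and `training_model` are assumed to come from the training script shown earlier in this diff.

```python
from tensorflow import keras

# The TensorBoard callback supplies the log_dir that CocoEval writes its scalars to.
tensorboard_callback = keras.callbacks.TensorBoard(log_dir="./logs")
coco_callback = CocoEval(coco_generator, tensorboard=tensorboard_callback, threshold=0.05)

training_model.fit(train_generator, epochs=1, callbacks=[tensorboard_callback, coco_callback])
```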
1 +from tensorflow import keras
2 +
3 +
4 +class RedirectModel(keras.callbacks.Callback):
5 + """Callback which wraps another callback, but executed on a different model.
6 +
7 + ```python
8 + model = keras.models.load_model('model.h5')
9 + model_checkpoint = ModelCheckpoint(filepath='snapshot.h5')
10 + parallel_model = multi_gpu_model(model, gpus=2)
11 + parallel_model.fit(X_train, Y_train, callbacks=[RedirectModel(model_checkpoint, model)])
12 + ```
13 +
14 + Args
15 + callback : callback to wrap.
16 + model : model to use when executing callbacks.
17 + """
18 +
19 + def __init__(self,
20 + callback,
21 + model):
22 + super(RedirectModel, self).__init__()
23 +
24 + self.callback = callback
25 + self.redirect_model = model
26 +
27 + def on_epoch_begin(self, epoch, logs=None):
28 + self.callback.on_epoch_begin(epoch, logs=logs)
29 +
30 + def on_epoch_end(self, epoch, logs=None):
31 + self.callback.on_epoch_end(epoch, logs=logs)
32 +
33 + def on_batch_begin(self, batch, logs=None):
34 + self.callback.on_batch_begin(batch, logs=logs)
35 +
36 + def on_batch_end(self, batch, logs=None):
37 + self.callback.on_batch_end(batch, logs=logs)
38 +
39 + def on_train_begin(self, logs=None):
40 + # overwrite the model with our custom model
41 + self.callback.set_model(self.redirect_model)
42 +
43 + self.callback.on_train_begin(logs=logs)
44 +
45 + def on_train_end(self, logs=None):
46 + self.callback.on_train_end(logs=logs)
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from tensorflow import keras
18 +from ..utils.eval import evaluate
19 +
20 +
21 +class Evaluate(keras.callbacks.Callback):
22 + """ Evaluation callback for arbitrary datasets.
23 + """
24 +
25 + def __init__(
26 + self,
27 + generator,
28 + iou_threshold=0.5,
29 + score_threshold=0.05,
30 + max_detections=100,
31 + save_path=None,
32 + tensorboard=None,
33 + weighted_average=False,
34 + verbose=1
35 + ):
36 + """ Evaluate a given dataset using a given model at the end of every epoch during training.
37 +
38 + # Arguments
39 + generator : The generator that represents the dataset to evaluate.
40 +            iou_threshold    : The IoU threshold used to decide whether a detection counts as positive or negative.
41 + score_threshold : The score confidence threshold to use for detections.
42 + max_detections : The maximum number of detections to use per image.
43 + save_path : The path to save images with visualized detections to.
44 + tensorboard : Instance of keras.callbacks.TensorBoard used to log the mAP value.
45 + weighted_average : Compute the mAP using the weighted average of precisions among classes.
46 + verbose : Set the verbosity level, by default this is set to 1.
47 + """
48 + self.generator = generator
49 + self.iou_threshold = iou_threshold
50 + self.score_threshold = score_threshold
51 + self.max_detections = max_detections
52 + self.save_path = save_path
53 + self.tensorboard = tensorboard
54 + self.weighted_average = weighted_average
55 + self.verbose = verbose
56 +
57 + super(Evaluate, self).__init__()
58 +
59 + def on_epoch_end(self, epoch, logs=None):
60 + logs = logs or {}
61 +
62 + # run evaluation
63 + average_precisions, _ = evaluate(
64 + self.generator,
65 + self.model,
66 + iou_threshold=self.iou_threshold,
67 + score_threshold=self.score_threshold,
68 + max_detections=self.max_detections,
69 + save_path=self.save_path
70 + )
71 +
72 + # compute per class average precision
73 + total_instances = []
74 + precisions = []
75 + for label, (average_precision, num_annotations) in average_precisions.items():
76 + if self.verbose == 1:
77 + print('{:.0f} instances of class'.format(num_annotations),
78 + self.generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision))
79 + total_instances.append(num_annotations)
80 + precisions.append(average_precision)
81 + if self.weighted_average:
82 + self.mean_ap = sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances)
83 + else:
84 + self.mean_ap = sum(precisions) / sum(x > 0 for x in total_instances)
85 +
86 + if self.tensorboard:
87 + import tensorflow as tf
88 + writer = tf.summary.create_file_writer(self.tensorboard.log_dir)
89 + with writer.as_default():
90 + tf.summary.scalar("mAP", self.mean_ap, step=epoch)
91 + if self.verbose == 1:
92 + for label, (average_precision, num_annotations) in average_precisions.items():
93 + tf.summary.scalar("AP_" + self.generator.label_to_name(label), average_precision, step=epoch)
94 + writer.flush()
95 +
96 + logs['mAP'] = self.mean_ap
97 +
98 + if self.verbose == 1:
99 + print('mAP: {:.4f}'.format(self.mean_ap))
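A small illustration, with made-up numbers, of how the `weighted_average` flag changes the reported mAP; the formulas are the same ones used in `on_epoch_end` above.

```python
total_instances = [10, 90]   # annotations per class
precisions = [0.50, 0.90]    # per-class average precision

unweighted_map = sum(precisions) / sum(x > 0 for x in total_instances)                         # 0.70
weighted_map = sum(a * b for a, b in zip(total_instances, precisions)) / sum(total_instances)  # 0.86
```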
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from tensorflow import keras
18 +
19 +import math
20 +
21 +
22 +class PriorProbability(keras.initializers.Initializer):
23 + """ Apply a prior probability to the weights.
24 + """
25 +
26 + def __init__(self, probability=0.01):
27 + self.probability = probability
28 +
29 + def get_config(self):
30 + return {
31 + 'probability': self.probability
32 + }
33 +
34 + def __call__(self, shape, dtype=None):
35 + # set bias to -log((1 - p)/p) for foreground
36 + result = keras.backend.ones(shape, dtype=dtype) * -math.log((1 - self.probability) / self.probability)
37 +
38 + return result
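A quick numerical check of the bias value this initializer produces: for the default probability, the classification head starts out predicting roughly 0.01 for every class after the sigmoid, as intended by the focal-loss setup.

```python
import math

p = 0.01
bias = -math.log((1 - p) / p)                          # ≈ -4.595, the value broadcast by PriorProbability
assert abs(1.0 / (1.0 + math.exp(-bias)) - p) < 1e-9   # sigmoid(bias) recovers p
```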
1 +from ._misc import RegressBoxes, UpsampleLike, Anchors, ClipBoxes # noqa: F401
2 +from .filter_detections import FilterDetections # noqa: F401
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import tensorflow
18 +from tensorflow import keras
19 +from .. import backend
20 +from ..utils import anchors as utils_anchors
21 +
22 +import numpy as np
23 +
24 +
25 +class Anchors(keras.layers.Layer):
26 + """ Keras layer for generating achors for a given shape.
27 + """
28 +
29 + def __init__(self, size, stride, ratios=None, scales=None, *args, **kwargs):
30 + """ Initializer for an Anchors layer.
31 +
32 + Args
33 + size: The base size of the anchors to generate.
34 + stride: The stride of the anchors to generate.
35 + ratios: The ratios of the anchors to generate (defaults to AnchorParameters.default.ratios).
36 + scales: The scales of the anchors to generate (defaults to AnchorParameters.default.scales).
37 + """
38 + self.size = size
39 + self.stride = stride
40 + self.ratios = ratios
41 + self.scales = scales
42 +
43 + if ratios is None:
44 + self.ratios = utils_anchors.AnchorParameters.default.ratios
45 + elif isinstance(ratios, list):
46 + self.ratios = np.array(ratios)
47 + if scales is None:
48 + self.scales = utils_anchors.AnchorParameters.default.scales
49 + elif isinstance(scales, list):
50 + self.scales = np.array(scales)
51 +
52 + self.num_anchors = len(self.ratios) * len(self.scales)
53 + self.anchors = utils_anchors.generate_anchors(
54 + base_size=self.size,
55 + ratios=self.ratios,
56 + scales=self.scales,
57 + ).astype(np.float32)
58 +
59 + super(Anchors, self).__init__(*args, **kwargs)
60 +
61 + def call(self, inputs, **kwargs):
62 + features = inputs
63 + features_shape = keras.backend.shape(features)
64 +
65 + # generate proposals from bbox deltas and shifted anchors
66 + if keras.backend.image_data_format() == 'channels_first':
67 + anchors = backend.shift(features_shape[2:4], self.stride, self.anchors)
68 + else:
69 + anchors = backend.shift(features_shape[1:3], self.stride, self.anchors)
70 + anchors = keras.backend.tile(keras.backend.expand_dims(anchors, axis=0), (features_shape[0], 1, 1))
71 +
72 + return anchors
73 +
74 + def compute_output_shape(self, input_shape):
75 + if None not in input_shape[1:]:
76 + if keras.backend.image_data_format() == 'channels_first':
77 + total = np.prod(input_shape[2:4]) * self.num_anchors
78 + else:
79 + total = np.prod(input_shape[1:3]) * self.num_anchors
80 +
81 + return (input_shape[0], total, 4)
82 + else:
83 + return (input_shape[0], None, 4)
84 +
85 + def get_config(self):
86 + config = super(Anchors, self).get_config()
87 + config.update({
88 + 'size' : self.size,
89 + 'stride' : self.stride,
90 + 'ratios' : self.ratios.tolist(),
91 + 'scales' : self.scales.tolist(),
92 + })
93 +
94 + return config
95 +
96 +
97 +class UpsampleLike(keras.layers.Layer):
98 + """ Keras layer for upsampling a Tensor to be the same shape as another Tensor.
99 + """
100 +
101 + def call(self, inputs, **kwargs):
102 + source, target = inputs
103 + target_shape = keras.backend.shape(target)
104 + if keras.backend.image_data_format() == 'channels_first':
105 + source = tensorflow.transpose(source, (0, 2, 3, 1))
106 + output = backend.resize_images(source, (target_shape[2], target_shape[3]), method='nearest')
107 + output = tensorflow.transpose(output, (0, 3, 1, 2))
108 + return output
109 + else:
110 + return backend.resize_images(source, (target_shape[1], target_shape[2]), method='nearest')
111 +
112 + def compute_output_shape(self, input_shape):
113 + if keras.backend.image_data_format() == 'channels_first':
114 + return (input_shape[0][0], input_shape[0][1]) + input_shape[1][2:4]
115 + else:
116 + return (input_shape[0][0],) + input_shape[1][1:3] + (input_shape[0][-1],)
117 +
118 +
119 +class RegressBoxes(keras.layers.Layer):
120 + """ Keras layer for applying regression values to boxes.
121 + """
122 +
123 + def __init__(self, mean=None, std=None, *args, **kwargs):
124 + """ Initializer for the RegressBoxes layer.
125 +
126 + Args
127 + mean: The mean value of the regression values which was used for normalization.
128 +            std: The standard deviation of the regression values which was used for normalization.
129 + """
130 + if mean is None:
131 + mean = np.array([0, 0, 0, 0])
132 + if std is None:
133 + std = np.array([0.2, 0.2, 0.2, 0.2])
134 +
135 + if isinstance(mean, (list, tuple)):
136 + mean = np.array(mean)
137 + elif not isinstance(mean, np.ndarray):
138 + raise ValueError('Expected mean to be a np.ndarray, list or tuple. Received: {}'.format(type(mean)))
139 +
140 + if isinstance(std, (list, tuple)):
141 + std = np.array(std)
142 + elif not isinstance(std, np.ndarray):
143 + raise ValueError('Expected std to be a np.ndarray, list or tuple. Received: {}'.format(type(std)))
144 +
145 + self.mean = mean
146 + self.std = std
147 + super(RegressBoxes, self).__init__(*args, **kwargs)
148 +
149 + def call(self, inputs, **kwargs):
150 + anchors, regression = inputs
151 + return backend.bbox_transform_inv(anchors, regression, mean=self.mean, std=self.std)
152 +
153 + def compute_output_shape(self, input_shape):
154 + return input_shape[0]
155 +
156 + def get_config(self):
157 + config = super(RegressBoxes, self).get_config()
158 + config.update({
159 + 'mean': self.mean.tolist(),
160 + 'std' : self.std.tolist(),
161 + })
162 +
163 + return config
164 +
165 +
166 +class ClipBoxes(keras.layers.Layer):
167 + """ Keras layer to clip box values to lie inside a given shape.
168 + """
169 + def call(self, inputs, **kwargs):
170 + image, boxes = inputs
171 + shape = keras.backend.cast(keras.backend.shape(image), keras.backend.floatx())
172 + if keras.backend.image_data_format() == 'channels_first':
173 + _, _, height, width = tensorflow.unstack(shape, axis=0)
174 + else:
175 + _, height, width, _ = tensorflow.unstack(shape, axis=0)
176 +
177 + x1, y1, x2, y2 = tensorflow.unstack(boxes, axis=-1)
178 + x1 = tensorflow.clip_by_value(x1, 0, width - 1)
179 + y1 = tensorflow.clip_by_value(y1, 0, height - 1)
180 + x2 = tensorflow.clip_by_value(x2, 0, width - 1)
181 + y2 = tensorflow.clip_by_value(y2, 0, height - 1)
182 +
183 + return keras.backend.stack([x1, y1, x2, y2], axis=2)
184 +
185 + def compute_output_shape(self, input_shape):
186 + return input_shape[1]
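A hedged back-of-the-envelope check of the anchor counts these layers work with. The figure of 9 anchors per location assumes the usual default of 3 ratios and 3 scales; the actual defaults live in `utils.anchors.AnchorParameters`, which is outside this file.

```python
# For a P3 feature map of an 800x1216 input (stride 8), assuming 3 ratios x 3 scales:
feature_height, feature_width = 100, 152
anchors_per_location = 3 * 3
total_anchors = feature_height * feature_width * anchors_per_location
# total_anchors == 136800, matching Anchors.compute_output_shape -> (batch, 136800, 4)
```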
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import tensorflow
18 +from tensorflow import keras
19 +from .. import backend
20 +
21 +
22 +def filter_detections(
23 + boxes,
24 + classification,
25 + other = [],
26 + class_specific_filter = True,
27 + nms = True,
28 + score_threshold = 0.05,
29 + max_detections = 300,
30 + nms_threshold = 0.5
31 +):
32 + """ Filter detections using the boxes and classification values.
33 +
34 + Args
35 + boxes : Tensor of shape (num_boxes, 4) containing the boxes in (x1, y1, x2, y2) format.
36 + classification : Tensor of shape (num_boxes, num_classes) containing the classification scores.
37 + other : List of tensors of shape (num_boxes, ...) to filter along with the boxes and classification scores.
38 + class_specific_filter : Whether to perform filtering per class, or take the best scoring class and filter those.
39 + nms : Flag to enable/disable non maximum suppression.
40 + score_threshold : Threshold used to prefilter the boxes with.
41 + max_detections : Maximum number of detections to keep.
42 + nms_threshold : Threshold for the IoU value to determine when a box should be suppressed.
43 +
44 + Returns
45 + A list of [boxes, scores, labels, other[0], other[1], ...].
46 + boxes is shaped (max_detections, 4) and contains the (x1, y1, x2, y2) of the non-suppressed boxes.
47 + scores is shaped (max_detections,) and contains the scores of the predicted class.
48 + labels is shaped (max_detections,) and contains the predicted label.
49 + other[i] is shaped (max_detections, ...) and contains the filtered other[i] data.
50 +        In case there are fewer than max_detections detections, the tensors are padded with -1's.
51 + """
52 + def _filter_detections(scores, labels):
53 + # threshold based on score
54 + indices = tensorflow.where(keras.backend.greater(scores, score_threshold))
55 +
56 + if nms:
57 + filtered_boxes = tensorflow.gather_nd(boxes, indices)
58 + filtered_scores = keras.backend.gather(scores, indices)[:, 0]
59 +
60 + # perform NMS
61 + nms_indices = tensorflow.image.non_max_suppression(filtered_boxes, filtered_scores, max_output_size=max_detections, iou_threshold=nms_threshold)
62 +
63 + # filter indices based on NMS
64 + indices = keras.backend.gather(indices, nms_indices)
65 +
66 + # add indices to list of all indices
67 + labels = tensorflow.gather_nd(labels, indices)
68 + indices = keras.backend.stack([indices[:, 0], labels], axis=1)
69 +
70 + return indices
71 +
72 + if class_specific_filter:
73 + all_indices = []
74 + # perform per class filtering
75 + for c in range(int(classification.shape[1])):
76 + scores = classification[:, c]
77 + labels = c * tensorflow.ones((keras.backend.shape(scores)[0],), dtype='int64')
78 + all_indices.append(_filter_detections(scores, labels))
79 +
80 + # concatenate indices to single tensor
81 + indices = keras.backend.concatenate(all_indices, axis=0)
82 + else:
83 + scores = keras.backend.max(classification, axis = 1)
84 + labels = keras.backend.argmax(classification, axis = 1)
85 + indices = _filter_detections(scores, labels)
86 +
87 + # select top k
88 + scores = tensorflow.gather_nd(classification, indices)
89 + labels = indices[:, 1]
90 + scores, top_indices = tensorflow.nn.top_k(scores, k=keras.backend.minimum(max_detections, keras.backend.shape(scores)[0]))
91 +
92 + # filter input using the final set of indices
93 + indices = keras.backend.gather(indices[:, 0], top_indices)
94 + boxes = keras.backend.gather(boxes, indices)
95 + labels = keras.backend.gather(labels, top_indices)
96 + other_ = [keras.backend.gather(o, indices) for o in other]
97 +
98 + # zero pad the outputs
99 + pad_size = keras.backend.maximum(0, max_detections - keras.backend.shape(scores)[0])
100 + boxes = tensorflow.pad(boxes, [[0, pad_size], [0, 0]], constant_values=-1)
101 + scores = tensorflow.pad(scores, [[0, pad_size]], constant_values=-1)
102 + labels = tensorflow.pad(labels, [[0, pad_size]], constant_values=-1)
103 + labels = keras.backend.cast(labels, 'int32')
104 + other_ = [tensorflow.pad(o, [[0, pad_size]] + [[0, 0] for _ in range(1, len(o.shape))], constant_values=-1) for o in other_]
105 +
106 + # set shapes, since we know what they are
107 + boxes.set_shape([max_detections, 4])
108 + scores.set_shape([max_detections])
109 + labels.set_shape([max_detections])
110 + for o, s in zip(other_, [list(keras.backend.int_shape(o)) for o in other]):
111 + o.set_shape([max_detections] + s[1:])
112 +
113 + return [boxes, scores, labels] + other_
114 +
115 +
116 +class FilterDetections(keras.layers.Layer):
117 + """ Keras layer for filtering detections using score threshold and NMS.
118 + """
119 +
120 + def __init__(
121 + self,
122 + nms = True,
123 + class_specific_filter = True,
124 + nms_threshold = 0.5,
125 + score_threshold = 0.05,
126 + max_detections = 300,
127 + parallel_iterations = 32,
128 + **kwargs
129 + ):
130 + """ Filters detections using score threshold, NMS and selecting the top-k detections.
131 +
132 + Args
133 + nms : Flag to enable/disable NMS.
134 + class_specific_filter : Whether to perform filtering per class, or take the best scoring class and filter those.
135 + nms_threshold : Threshold for the IoU value to determine when a box should be suppressed.
136 + score_threshold : Threshold used to prefilter the boxes with.
137 + max_detections : Maximum number of detections to keep.
138 + parallel_iterations : Number of batch items to process in parallel.
139 + """
140 + self.nms = nms
141 + self.class_specific_filter = class_specific_filter
142 + self.nms_threshold = nms_threshold
143 + self.score_threshold = score_threshold
144 + self.max_detections = max_detections
145 + self.parallel_iterations = parallel_iterations
146 + super(FilterDetections, self).__init__(**kwargs)
147 +
148 + def call(self, inputs, **kwargs):
149 + """ Constructs the NMS graph.
150 +
151 + Args
152 + inputs : List of [boxes, classification, other[0], other[1], ...] tensors.
153 + """
154 + boxes = inputs[0]
155 + classification = inputs[1]
156 + other = inputs[2:]
157 +
158 + # wrap nms with our parameters
159 + def _filter_detections(args):
160 + boxes = args[0]
161 + classification = args[1]
162 + other = args[2]
163 +
164 + return filter_detections(
165 + boxes,
166 + classification,
167 + other,
168 + nms = self.nms,
169 + class_specific_filter = self.class_specific_filter,
170 + score_threshold = self.score_threshold,
171 + max_detections = self.max_detections,
172 + nms_threshold = self.nms_threshold,
173 + )
174 +
175 + # call filter_detections on each batch
176 + dtypes = [keras.backend.floatx(), keras.backend.floatx(), 'int32'] + [o.dtype for o in other]
177 + shapes = [(self.max_detections, 4), (self.max_detections,), (self.max_detections,)]
178 + shapes.extend([(self.max_detections,) + o.shape[2:] for o in other])
179 + outputs = backend.map_fn(
180 + _filter_detections,
181 + elems=[boxes, classification, other],
182 + dtype=dtypes,
183 + shapes=shapes,
184 + parallel_iterations=self.parallel_iterations,
185 + )
186 +
187 + return outputs
188 +
189 + def compute_output_shape(self, input_shape):
190 + """ Computes the output shapes given the input shapes.
191 +
192 + Args
193 + input_shape : List of input shapes [boxes, classification, other[0], other[1], ...].
194 +
195 + Returns
196 + List of tuples representing the output shapes:
197 + [filtered_boxes.shape, filtered_scores.shape, filtered_labels.shape, filtered_other[0].shape, filtered_other[1].shape, ...]
198 + """
199 + return [
200 + (input_shape[0][0], self.max_detections, 4),
201 + (input_shape[1][0], self.max_detections),
202 + (input_shape[1][0], self.max_detections),
203 + ] + [
204 + tuple([input_shape[i][0], self.max_detections] + list(input_shape[i][2:])) for i in range(2, len(input_shape))
205 + ]
206 +
207 + def compute_mask(self, inputs, mask=None):
208 + """ This is required in Keras when there is more than 1 output.
209 + """
210 + return (len(inputs) + 1) * [None]
211 +
212 + def get_config(self):
213 + """ Gets the configuration of this layer.
214 +
215 + Returns
216 + Dictionary containing the parameters of this layer.
217 + """
218 + config = super(FilterDetections, self).get_config()
219 + config.update({
220 + 'nms' : self.nms,
221 + 'class_specific_filter' : self.class_specific_filter,
222 + 'nms_threshold' : self.nms_threshold,
223 + 'score_threshold' : self.score_threshold,
224 + 'max_detections' : self.max_detections,
225 + 'parallel_iterations' : self.parallel_iterations,
226 + })
227 +
228 + return config
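A hedged sketch of how this layer is typically wired into the inference graph; `boxes` and `classification` stand in for the regressed boxes and per-anchor class scores produced by the model heads.

```python
detections = FilterDetections(
    nms=True,
    class_specific_filter=True,
    score_threshold=0.05,
    max_detections=300,
    name="filtered_detections",
)([boxes, classification])
# detections is [filtered_boxes, filtered_scores, filtered_labels], each padded to max_detections
```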
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import tensorflow
18 +from tensorflow import keras
19 +
20 +
21 +def focal(alpha=0.25, gamma=2.0, cutoff=0.5):
22 + """ Create a functor for computing the focal loss.
23 +
24 + Args
25 + alpha: Scale the focal weight with alpha.
26 + gamma: Take the power of the focal weight with gamma.
27 +        cutoff: Positive prediction cutoff for soft targets.
28 +
29 + Returns
30 + A functor that computes the focal loss using the alpha and gamma.
31 + """
32 + def _focal(y_true, y_pred):
33 + """ Compute the focal loss given the target tensor and the predicted tensor.
34 +
35 + As defined in https://arxiv.org/abs/1708.02002
36 +
37 + Args
38 + y_true: Tensor of target data from the generator with shape (B, N, num_classes).
39 + y_pred: Tensor of predicted data from the network with shape (B, N, num_classes).
40 +
41 + Returns
42 + The focal loss of y_pred w.r.t. y_true.
43 + """
44 + labels = y_true[:, :, :-1]
45 + anchor_state = y_true[:, :, -1] # -1 for ignore, 0 for background, 1 for object
46 + classification = y_pred
47 +
48 + # filter out "ignore" anchors
49 + indices = tensorflow.where(keras.backend.not_equal(anchor_state, -1))
50 + labels = tensorflow.gather_nd(labels, indices)
51 + classification = tensorflow.gather_nd(classification, indices)
52 +
53 + # compute the focal loss
54 + alpha_factor = keras.backend.ones_like(labels) * alpha
55 + alpha_factor = tensorflow.where(keras.backend.greater(labels, cutoff), alpha_factor, 1 - alpha_factor)
56 + focal_weight = tensorflow.where(keras.backend.greater(labels, cutoff), 1 - classification, classification)
57 + focal_weight = alpha_factor * focal_weight ** gamma
58 +
59 + cls_loss = focal_weight * keras.backend.binary_crossentropy(labels, classification)
60 +
61 + # compute the normalizer: the number of positive anchors
62 + normalizer = tensorflow.where(keras.backend.equal(anchor_state, 1))
63 + normalizer = keras.backend.cast(keras.backend.shape(normalizer)[0], keras.backend.floatx())
64 + normalizer = keras.backend.maximum(keras.backend.cast_to_floatx(1.0), normalizer)
65 +
66 + return keras.backend.sum(cls_loss) / normalizer
67 +
68 + return _focal
69 +
70 +
71 +def smooth_l1(sigma=3.0):
72 + """ Create a smooth L1 loss functor.
73 +
74 + Args
75 + sigma: This argument defines the point where the loss changes from L2 to L1.
76 +
77 + Returns
78 + A functor for computing the smooth L1 loss given target data and predicted data.
79 + """
80 + sigma_squared = sigma ** 2
81 +
82 + def _smooth_l1(y_true, y_pred):
83 + """ Compute the smooth L1 loss of y_pred w.r.t. y_true.
84 +
85 + Args
86 + y_true: Tensor from the generator of shape (B, N, 5). The last value for each box is the state of the anchor (ignore, negative, positive).
87 + y_pred: Tensor from the network of shape (B, N, 4).
88 +
89 + Returns
90 + The smooth L1 loss of y_pred w.r.t. y_true.
91 + """
92 + # separate target and state
93 + regression = y_pred
94 + regression_target = y_true[:, :, :-1]
95 + anchor_state = y_true[:, :, -1]
96 +
97 + # filter out "ignore" anchors
98 + indices = tensorflow.where(keras.backend.equal(anchor_state, 1))
99 + regression = tensorflow.gather_nd(regression, indices)
100 + regression_target = tensorflow.gather_nd(regression_target, indices)
101 +
102 + # compute smooth L1 loss
103 + # f(x) = 0.5 * (sigma * x)^2 if |x| < 1 / sigma / sigma
104 + # |x| - 0.5 / sigma / sigma otherwise
105 + regression_diff = regression - regression_target
106 + regression_diff = keras.backend.abs(regression_diff)
107 + regression_loss = tensorflow.where(
108 + keras.backend.less(regression_diff, 1.0 / sigma_squared),
109 + 0.5 * sigma_squared * keras.backend.pow(regression_diff, 2),
110 + regression_diff - 0.5 / sigma_squared
111 + )
112 +
113 + # compute the normalizer: the number of positive anchors
114 + normalizer = keras.backend.maximum(1, keras.backend.shape(indices)[0])
115 + normalizer = keras.backend.cast(normalizer, dtype=keras.backend.floatx())
116 + return keras.backend.sum(regression_loss) / normalizer
117 +
118 + return _smooth_l1
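A hedged sketch of how these loss functors are attached to the two output heads of the training model; `training_model` is assumed to exist, and the optimizer settings mirror the `--lr` and `--optimizer-clipnorm` defaults of the training script earlier in this diff.

```python
from tensorflow import keras

training_model.compile(
    loss={"regression": smooth_l1(), "classification": focal()},
    optimizer=keras.optimizers.Adam(learning_rate=1e-5, clipnorm=0.001),
)
```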
1 +from __future__ import print_function
2 +import sys
3 +
4 +
5 +class Backbone(object):
6 + """ This class stores additional information on backbones.
7 + """
8 + def __init__(self, backbone):
9 + # a dictionary mapping custom layer names to the correct classes
10 + from .. import layers
11 + from .. import losses
12 + from .. import initializers
13 + self.custom_objects = {
14 + 'UpsampleLike' : layers.UpsampleLike,
15 + 'PriorProbability' : initializers.PriorProbability,
16 + 'RegressBoxes' : layers.RegressBoxes,
17 + 'FilterDetections' : layers.FilterDetections,
18 + 'Anchors' : layers.Anchors,
19 + 'ClipBoxes' : layers.ClipBoxes,
20 + '_smooth_l1' : losses.smooth_l1(),
21 + '_focal' : losses.focal(),
22 + }
23 +
24 + self.backbone = backbone
25 + self.validate()
26 +
27 + def retinanet(self, *args, **kwargs):
28 + """ Returns a retinanet model using the correct backbone.
29 + """
30 + raise NotImplementedError('retinanet method not implemented.')
31 +
32 + def download_imagenet(self):
33 + """ Downloads ImageNet weights and returns path to weights file.
34 + """
35 + raise NotImplementedError('download_imagenet method not implemented.')
36 +
37 + def validate(self):
38 + """ Checks whether the backbone string is correct.
39 + """
40 + raise NotImplementedError('validate method not implemented.')
41 +
42 + def preprocess_image(self, inputs):
43 + """ Takes as input an image and prepares it for being passed through the network.
44 + Having this function in Backbone allows other backbones to define a specific preprocessing step.
45 + """
46 + raise NotImplementedError('preprocess_image method not implemented.')
47 +
48 +
49 +def backbone(backbone_name):
50 + """ Returns a backbone object for the given backbone.
51 + """
52 + if 'densenet' in backbone_name:
53 + from .densenet import DenseNetBackbone as b
54 + elif 'seresnext' in backbone_name or 'seresnet' in backbone_name or 'senet' in backbone_name:
55 + from .senet import SeBackbone as b
56 + elif 'resnet' in backbone_name:
57 + from .resnet import ResNetBackbone as b
58 + elif 'mobilenet' in backbone_name:
59 + from .mobilenet import MobileNetBackbone as b
60 + elif 'vgg' in backbone_name:
61 + from .vgg import VGGBackbone as b
62 + elif 'EfficientNet' in backbone_name:
63 + from .effnet import EfficientNetBackbone as b
64 + else:
65 +        raise NotImplementedError('Backbone class for \'{}\' not implemented.'.format(backbone_name))
66 +
67 + return b(backbone_name)
68 +
69 +
70 +def load_model(filepath, backbone_name='resnet50'):
71 + """ Loads a retinanet model using the correct custom objects.
72 +
73 + Args
74 + filepath: one of the following:
75 + - string, path to the saved model, or
76 + - h5py.File object from which to load the model
77 + backbone_name : Backbone with which the model was trained.
78 +
79 + Returns
80 + A keras.models.Model object.
81 +
82 + Raises
83 + ImportError: if h5py is not available.
84 + ValueError: In case of an invalid savefile.
85 + """
86 + from tensorflow import keras
87 + return keras.models.load_model(filepath, custom_objects=backbone(backbone_name).custom_objects)
88 +
89 +
90 +def convert_model(model, nms=True, class_specific_filter=True, anchor_params=None, **kwargs):
91 + """ Converts a training model to an inference model.
92 +
93 + Args
94 + model : A retinanet training model.
95 + nms : Boolean, whether to add NMS filtering to the converted model.
96 + class_specific_filter : Whether to use class specific filtering or filter for the best scoring class only.
97 + anchor_params : Anchor parameters object. If omitted, default values are used.
98 + **kwargs : Inference and minimal retinanet model settings.
99 +
100 + Returns
101 + A keras.models.Model object.
102 +
103 + Raises
104 + ImportError: if h5py is not available.
105 + ValueError: In case of an invalid savefile.
106 + """
107 + from .retinanet import retinanet_bbox
108 + return retinanet_bbox(model=model, nms=nms, class_specific_filter=class_specific_filter, anchor_params=anchor_params, **kwargs)
109 +
110 +
111 +def assert_training_model(model):
112 + """ Assert that the model is a training model.
113 + """
114 + assert(all(output in model.output_names for output in ['regression', 'classification'])), \
115 + "Input is not a training model (no 'regression' and 'classification' outputs were found, outputs are: {}).".format(model.output_names)
116 +
117 +
118 +def check_training_model(model):
119 + """ Check that model is a training model and exit otherwise.
120 + """
121 + try:
122 + assert_training_model(model)
123 + except AssertionError as e:
124 + print(e, file=sys.stderr)
125 + sys.exit(1)
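A hedged sketch of the typical load-and-convert flow these helpers enable; the snapshot path is a placeholder.

```python
# Load a training snapshot, verify it is a training model, then convert it for inference.
model = load_model("snapshots/resnet50_csv_05.h5", backbone_name="resnet50")
check_training_model(model)
inference_model = convert_model(model, nms=True, class_specific_filter=True)
```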
1 +"""
2 +Copyright 2018 vidosits (https://github.com/vidosits/)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from tensorflow import keras
18 +
19 +from . import retinanet
20 +from . import Backbone
21 +from ..utils.image import preprocess_image
22 +
23 +
24 +allowed_backbones = {
25 + 'densenet121': ([6, 12, 24, 16], keras.applications.densenet.DenseNet121),
26 + 'densenet169': ([6, 12, 32, 32], keras.applications.densenet.DenseNet169),
27 + 'densenet201': ([6, 12, 48, 32], keras.applications.densenet.DenseNet201),
28 +}
29 +
30 +
31 +class DenseNetBackbone(Backbone):
32 + """ Describes backbone information and provides utility functions.
33 + """
34 +
35 + def retinanet(self, *args, **kwargs):
36 + """ Returns a retinanet model using the correct backbone.
37 + """
38 + return densenet_retinanet(*args, backbone=self.backbone, **kwargs)
39 +
40 + def download_imagenet(self):
41 + """ Download pre-trained weights for the specified backbone name.
42 + This name is in the format {backbone}_weights_tf_dim_ordering_tf_kernels_notop
43 + where backbone is the densenet + number of layers (e.g. densenet121).
44 + For more info check the explanation from the keras densenet script itself:
45 + https://github.com/keras-team/keras/blob/master/keras/applications/densenet.py
46 + """
47 + origin = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/'
48 + file_name = '{}_weights_tf_dim_ordering_tf_kernels_notop.h5'
49 +
50 + # load weights
51 + if keras.backend.image_data_format() == 'channels_first':
52 + raise ValueError('Weights for "channels_first" format are not available.')
53 +
54 + weights_url = origin + file_name.format(self.backbone)
55 + return keras.utils.get_file(file_name.format(self.backbone), weights_url, cache_subdir='models')
56 +
57 + def validate(self):
58 + """ Checks whether the backbone string is correct.
59 + """
60 + backbone = self.backbone.split('_')[0]
61 +
62 + if backbone not in allowed_backbones:
63 + raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones.keys()))
64 +
65 + def preprocess_image(self, inputs):
66 + """ Takes as input an image and prepares it for being passed through the network.
67 + """
68 + return preprocess_image(inputs, mode='tf')
69 +
70 +
71 +def densenet_retinanet(num_classes, backbone='densenet121', inputs=None, modifier=None, **kwargs):
72 + """ Constructs a retinanet model using a densenet backbone.
73 +
74 + Args
75 + num_classes: Number of classes to predict.
76 + backbone: Which backbone to use (one of ('densenet121', 'densenet169', 'densenet201')).
77 + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)).
78 + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example).
79 +
80 + Returns
81 + RetinaNet model with a DenseNet backbone.
82 + """
83 + # choose default input
84 + if inputs is None:
85 + inputs = keras.layers.Input((None, None, 3))
86 +
87 + blocks, creator = allowed_backbones[backbone]
88 + model = creator(input_tensor=inputs, include_top=False, pooling=None, weights=None)
89 +
90 + # get last conv layer from the end of each dense block
91 + layer_outputs = [model.get_layer(name='conv{}_block{}_concat'.format(idx + 2, block_num)).output for idx, block_num in enumerate(blocks)]
92 +
93 + # create the densenet backbone
94 + # layer_outputs contains 4 layers
95 + model = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=model.name)
96 +
97 + # invoke modifier if given
98 + if modifier:
99 + model = modifier(model)
100 +
101 + # create the full model
102 + backbone_layers = {
103 + 'C2': model.outputs[0],
104 + 'C3': model.outputs[1],
105 + 'C4': model.outputs[2],
106 + 'C5': model.outputs[3]
107 + }
108 +
109 + model = retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs)
110 +
111 + return model
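A hedged sketch of building an (untrained) RetinaNet with a DenseNet backbone through the Backbone wrapper defined above.

```python
densenet_backbone = DenseNetBackbone("densenet121")
model = densenet_backbone.retinanet(num_classes=20)   # 20 is an arbitrary example class count
```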
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from tensorflow import keras
18 +
19 +from . import retinanet
20 +from . import Backbone
21 +import efficientnet.keras as efn
22 +
23 +
24 +class EfficientNetBackbone(Backbone):
25 + """ Describes backbone information and provides utility functions.
26 + """
27 +
28 + def __init__(self, backbone):
29 + super(EfficientNetBackbone, self).__init__(backbone)
30 + self.preprocess_image_func = None
31 +
32 + def retinanet(self, *args, **kwargs):
33 + """ Returns a retinanet model using the correct backbone.
34 + """
35 + return effnet_retinanet(*args, backbone=self.backbone, **kwargs)
36 +
37 + def download_imagenet(self):
38 + """ Downloads ImageNet weights and returns path to weights file.
39 + """
40 + from efficientnet.weights import IMAGENET_WEIGHTS_PATH
41 + from efficientnet.weights import IMAGENET_WEIGHTS_HASHES
42 +
43 + model_name = 'efficientnet-b' + self.backbone[-1]
44 + file_name = model_name + '_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5'
45 + file_hash = IMAGENET_WEIGHTS_HASHES[model_name][1]
46 + weights_path = keras.utils.get_file(file_name, IMAGENET_WEIGHTS_PATH + file_name, cache_subdir='models', file_hash=file_hash)
47 + return weights_path
48 +
49 + def validate(self):
50 + """ Checks whether the backbone string is correct.
51 + """
52 + allowed_backbones = ['EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', 'EfficientNetB4',
53 + 'EfficientNetB5', 'EfficientNetB6', 'EfficientNetB7']
54 + backbone = self.backbone.split('_')[0]
55 +
56 + if backbone not in allowed_backbones:
57 + raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones))
58 +
59 + def preprocess_image(self, inputs):
60 + """ Takes as input an image and prepares it for being passed through the network.
61 + """
62 + return efn.preprocess_input(inputs)
63 +
64 +
65 +def effnet_retinanet(num_classes, backbone='EfficientNetB0', inputs=None, modifier=None, **kwargs):
66 + """ Constructs a retinanet model using a resnet backbone.
67 +
68 + Args
69 + num_classes: Number of classes to predict.
70 +        backbone: Which backbone to use (one of 'EfficientNetB0' through 'EfficientNetB7').
71 + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)).
72 + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example).
73 +
74 + Returns
75 +        RetinaNet model with an EfficientNet backbone.
76 + """
77 + # choose default input
78 + if inputs is None:
79 + if keras.backend.image_data_format() == 'channels_first':
80 + inputs = keras.layers.Input(shape=(3, None, None))
81 + else:
82 + # inputs = keras.layers.Input(shape=(224, 224, 3))
83 + inputs = keras.layers.Input(shape=(None, None, 3))
84 +
85 + # get last conv layer from the end of each block [28x28, 14x14, 7x7]
86 + if backbone == 'EfficientNetB0':
87 + model = efn.EfficientNetB0(input_tensor=inputs, include_top=False, weights=None)
88 + elif backbone == 'EfficientNetB1':
89 + model = efn.EfficientNetB1(input_tensor=inputs, include_top=False, weights=None)
90 + elif backbone == 'EfficientNetB2':
91 + model = efn.EfficientNetB2(input_tensor=inputs, include_top=False, weights=None)
92 + elif backbone == 'EfficientNetB3':
93 + model = efn.EfficientNetB3(input_tensor=inputs, include_top=False, weights=None)
94 + elif backbone == 'EfficientNetB4':
95 + model = efn.EfficientNetB4(input_tensor=inputs, include_top=False, weights=None)
96 + elif backbone == 'EfficientNetB5':
97 + model = efn.EfficientNetB5(input_tensor=inputs, include_top=False, weights=None)
98 + elif backbone == 'EfficientNetB6':
99 + model = efn.EfficientNetB6(input_tensor=inputs, include_top=False, weights=None)
100 + elif backbone == 'EfficientNetB7':
101 + model = efn.EfficientNetB7(input_tensor=inputs, include_top=False, weights=None)
102 + else:
103 + raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone))
104 +
105 + layer_outputs = ['block4a_expand_activation', 'block6a_expand_activation', 'top_activation']
106 +
107 + layer_outputs = [
108 + model.get_layer(name=layer_outputs[0]).output, # 28x28
109 + model.get_layer(name=layer_outputs[1]).output, # 14x14
110 + model.get_layer(name=layer_outputs[2]).output, # 7x7
111 + ]
112 +    # create the EfficientNet backbone model
113 + model = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=model.name)
114 +
115 + # invoke modifier if given
116 + if modifier:
117 + model = modifier(model)
118 +
119 + # C2 not provided
120 + backbone_layers = {
121 + 'C3': model.outputs[0],
122 + 'C4': model.outputs[1],
123 + 'C5': model.outputs[2]
124 + }
125 +
126 + # create the full model
127 + return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs)
128 +
129 +
130 +def EfficientNetB0_retinanet(num_classes, inputs=None, **kwargs):
131 + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB0', inputs=inputs, **kwargs)
132 +
133 +
134 +def EfficientNetB1_retinanet(num_classes, inputs=None, **kwargs):
135 + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB1', inputs=inputs, **kwargs)
136 +
137 +
138 +def EfficientNetB2_retinanet(num_classes, inputs=None, **kwargs):
139 + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB2', inputs=inputs, **kwargs)
140 +
141 +
142 +def EfficientNetB3_retinanet(num_classes, inputs=None, **kwargs):
143 + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB3', inputs=inputs, **kwargs)
144 +
145 +
146 +def EfficientNetB4_retinanet(num_classes, inputs=None, **kwargs):
147 + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB4', inputs=inputs, **kwargs)
148 +
149 +
150 +def EfficientNetB5_retinanet(num_classes, inputs=None, **kwargs):
151 + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB5', inputs=inputs, **kwargs)
152 +
153 +
154 +def EfficientNetB6_retinanet(num_classes, inputs=None, **kwargs):
155 + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB6', inputs=inputs, **kwargs)
156 +
157 +
158 +def EfficientNetB7_retinanet(num_classes, inputs=None, **kwargs):
159 + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB7', inputs=inputs, **kwargs)
1 +"""
2 +Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from tensorflow import keras
18 +from ..utils.image import preprocess_image
19 +
20 +from . import retinanet
21 +from . import Backbone
22 +
23 +
24 +class MobileNetBackbone(Backbone):
25 + """ Describes backbone information and provides utility functions.
26 + """
27 +
28 + allowed_backbones = ['mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224']
29 +
30 + def retinanet(self, *args, **kwargs):
31 + """ Returns a retinanet model using the correct backbone.
32 + """
33 + return mobilenet_retinanet(*args, backbone=self.backbone, **kwargs)
34 +
35 + def download_imagenet(self):
36 + """ Download pre-trained weights for the specified backbone name.
37 +        This name is in the format mobilenet{rows}_{alpha}, where 'rows' is the
38 +        input resolution the ImageNet weights were trained at and 'alpha' controls the width of the network.
39 +        For more info check the explanation in the keras mobilenet script itself.
40 + """
41 +
42 + alpha = float(self.backbone.split('_')[1])
43 + rows = int(self.backbone.split('_')[0].replace('mobilenet', ''))
44 +
45 + # load weights
46 + if keras.backend.image_data_format() == 'channels_first':
47 +            raise ValueError('Weights for "channels_first" format '
48 +                             'are not available.')
49 + if alpha == 1.0:
50 + alpha_text = '1_0'
51 + elif alpha == 0.75:
52 + alpha_text = '7_5'
53 + elif alpha == 0.50:
54 + alpha_text = '5_0'
55 + else:
56 + alpha_text = '2_5'
57 +
58 + model_name = 'mobilenet_{}_{}_tf_no_top.h5'.format(alpha_text, rows)
59 + weights_url = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.6/' + model_name
60 + weights_path = keras.utils.get_file(model_name, weights_url, cache_subdir='models')
61 +
62 + return weights_path
63 +
64 + def validate(self):
65 + """ Checks whether the backbone string is correct.
66 + """
67 + backbone = self.backbone.split('_')[0]
68 +
69 + if backbone not in MobileNetBackbone.allowed_backbones:
70 + raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, MobileNetBackbone.allowed_backbones))
71 +
72 + def preprocess_image(self, inputs):
73 + """ Takes as input an image and prepares it for being passed through the network.
74 + """
75 + return preprocess_image(inputs, mode='tf')
76 +
77 +
78 +def mobilenet_retinanet(num_classes, backbone='mobilenet224_1.0', inputs=None, modifier=None, **kwargs):
79 + """ Constructs a retinanet model using a mobilenet backbone.
80 +
81 + Args
82 + num_classes: Number of classes to predict.
83 + backbone: Which backbone to use (one of ('mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224')).
84 + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)).
85 + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example).
86 +
87 + Returns
88 + RetinaNet model with a MobileNet backbone.
89 + """
90 + alpha = float(backbone.split('_')[1])
91 +
92 + # choose default input
93 + if inputs is None:
94 + inputs = keras.layers.Input((None, None, 3))
95 +
96 + backbone = keras.applications.mobilenet.MobileNet(input_tensor=inputs, alpha=alpha, include_top=False, pooling=None, weights=None)
97 +
98 + # create the full model
99 + layer_names = ['conv_pw_5_relu', 'conv_pw_11_relu', 'conv_pw_13_relu']
100 + layer_outputs = [backbone.get_layer(name).output for name in layer_names]
101 + backbone = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=backbone.name)
102 +
103 + # invoke modifier if given
104 + if modifier:
105 + backbone = modifier(backbone)
106 +
107 + # C2 not provided
108 + backbone_layers = {
109 + 'C3': backbone.outputs[0],
110 + 'C4': backbone.outputs[1],
111 + 'C5': backbone.outputs[2]
112 + }
113 +
114 + return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs)
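The download_imagenet method above encodes the backbone string as mobilenet{rows}_{alpha}; a small sketch of that parsing and of the weight-file name it produces:

    backbone = 'mobilenet224_1.0'
    alpha = float(backbone.split('_')[1])                          # 1.0 -> width multiplier
    rows = int(backbone.split('_')[0].replace('mobilenet', ''))    # 224 -> input resolution

    alpha_text = {1.0: '1_0', 0.75: '7_5', 0.5: '5_0'}.get(alpha, '2_5')
    print('mobilenet_{}_{}_tf_no_top.h5'.format(alpha_text, rows))  # mobilenet_1_0_224_tf_no_top.h5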
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from tensorflow import keras
18 +import keras_resnet
19 +import keras_resnet.models
20 +
21 +from . import retinanet
22 +from . import Backbone
23 +from ..utils.image import preprocess_image
24 +
25 +
26 +class ResNetBackbone(Backbone):
27 + """Describes backbone information and provides utility functions."""
28 +
29 + def __init__(self, backbone):
30 + super(ResNetBackbone, self).__init__(backbone)
31 + self.custom_objects.update(keras_resnet.custom_objects)
32 +
33 + def retinanet(self, *args, **kwargs):
34 + """Returns a retinanet model using the correct backbone."""
35 + return resnet_retinanet(*args, backbone=self.backbone, **kwargs)
36 +
37 + def download_imagenet(self):
38 + """Downloads ImageNet weights and returns path to weights file."""
39 + resnet_filename = "ResNet-{}-model.keras.h5"
40 + resnet_resource = (
41 + "https://github.com/fizyr/keras-models/releases/download/v0.0.1/{}".format(
42 + resnet_filename
43 + )
44 + )
45 + depth = int(self.backbone.replace("resnet", ""))
46 +
47 + filename = resnet_filename.format(depth)
48 + resource = resnet_resource.format(depth)
49 + if depth == 50:
50 + checksum = "3e9f4e4f77bbe2c9bec13b53ee1c2319"
51 + elif depth == 101:
52 + checksum = "05dc86924389e5b401a9ea0348a3213c"
53 + elif depth == 152:
54 + checksum = "6ee11ef2b135592f8031058820bb9e71"
55 +
56 + return keras.utils.get_file(
57 + filename, resource, cache_subdir="models", md5_hash=checksum
58 + )
59 +
60 + def validate(self):
61 + """Checks whether the backbone string is correct."""
62 + allowed_backbones = ["resnet50", "resnet101", "resnet152"]
63 + backbone = self.backbone.split("_")[0]
64 +
65 + if backbone not in allowed_backbones:
66 + raise ValueError(
67 + "Backbone ('{}') not in allowed backbones ({}).".format(
68 + backbone, allowed_backbones
69 + )
70 + )
71 +
72 + def preprocess_image(self, inputs):
73 + """Takes as input an image and prepares it for being passed through the network."""
74 + return preprocess_image(inputs, mode="caffe")
75 +
76 +
77 +def resnet_retinanet(
78 + num_classes, backbone="resnet50", inputs=None, modifier=None, **kwargs
79 +):
80 + """Constructs a retinanet model using a resnet backbone.
81 +
82 + Args
83 + num_classes: Number of classes to predict.
84 + backbone: Which backbone to use (one of ('resnet50', 'resnet101', 'resnet152')).
85 + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)).
86 + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example).
87 +
88 + Returns
89 + RetinaNet model with a ResNet backbone.
90 + """
91 + # choose default input
92 + if inputs is None:
93 + if keras.backend.image_data_format() == "channels_first":
94 + inputs = keras.layers.Input(shape=(3, None, None))
95 + else:
96 + inputs = keras.layers.Input(shape=(None, None, 3))
97 +
98 + # create the resnet backbone
99 + if backbone == "resnet50":
100 + resnet = keras_resnet.models.ResNet50(inputs, include_top=False, freeze_bn=True)
101 + elif backbone == "resnet101":
102 + resnet = keras_resnet.models.ResNet101(
103 + inputs, include_top=False, freeze_bn=True
104 + )
105 + elif backbone == "resnet152":
106 + resnet = keras_resnet.models.ResNet152(
107 + inputs, include_top=False, freeze_bn=True
108 + )
109 + else:
110 + raise ValueError("Backbone ('{}') is invalid.".format(backbone))
111 +
112 + # invoke modifier if given
113 + if modifier:
114 + resnet = modifier(resnet)
115 +
116 + # create the full model
117 + # resnet.outputs contains 4 layers
118 + backbone_layers = {
119 + "C2": resnet.outputs[0],
120 + "C3": resnet.outputs[1],
121 + "C4": resnet.outputs[2],
122 + "C5": resnet.outputs[3],
123 + }
124 +
125 + return retinanet.retinanet(
126 + inputs=inputs,
127 + num_classes=num_classes,
128 + backbone_layers=backbone_layers,
129 + **kwargs
130 + )
131 +
132 +
133 +def resnet50_retinanet(num_classes, inputs=None, **kwargs):
134 + return resnet_retinanet(
135 + num_classes=num_classes, backbone="resnet50", inputs=inputs, **kwargs
136 + )
137 +
138 +
139 +def resnet101_retinanet(num_classes, inputs=None, **kwargs):
140 + return resnet_retinanet(
141 + num_classes=num_classes, backbone="resnet101", inputs=inputs, **kwargs
142 + )
143 +
144 +
145 +def resnet152_retinanet(num_classes, inputs=None, **kwargs):
146 + return resnet_retinanet(
147 + num_classes=num_classes, backbone="resnet152", inputs=inputs, **kwargs
148 + )
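The modifier argument is the hook used to alter the backbone before it is wrapped, for example to freeze it. A hedged sketch (freeze_backbone is an illustrative name, not a helper from this repository):

    def freeze_backbone(model):
        # mark every backbone layer as non-trainable before retinanet() adds the FPN and submodels
        for layer in model.layers:
            layer.trainable = False
        return model

    training_model = resnet50_retinanet(num_classes=80, modifier=freeze_backbone)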
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from tensorflow import keras
18 +from .. import initializers
19 +from .. import layers
20 +from ..utils.anchors import AnchorParameters
21 +from . import assert_training_model
22 +
23 +
24 +def default_classification_model(
25 + num_classes,
26 + num_anchors,
27 + pyramid_feature_size=256,
28 + prior_probability=0.01,
29 + classification_feature_size=256,
30 + name="classification_submodel",
31 +):
32 + """Creates the default classification submodel.
33 +
34 + Args
35 + num_classes : Number of classes to predict a score for at each feature level.
36 + num_anchors : Number of anchors to predict classification scores for at each feature level.
37 + pyramid_feature_size : The number of filters to expect from the feature pyramid levels.
38 + classification_feature_size : The number of filters to use in the layers in the classification submodel.
39 + name : The name of the submodel.
40 +
41 + Returns
42 + A keras.models.Model that predicts classes for each anchor.
43 + """
44 + options = {
45 + "kernel_size": 3,
46 + "strides": 1,
47 + "padding": "same",
48 + }
49 +
50 + # set input
51 + if keras.backend.image_data_format() == "channels_first":
52 + inputs = keras.layers.Input(shape=(pyramid_feature_size, None, None))
53 + else:
54 + inputs = keras.layers.Input(shape=(None, None, pyramid_feature_size))
55 +
56 + outputs = inputs
57 +
58 +    # 4 conv layers, as in the RetinaNet classification subnet
59 + for i in range(4):
60 +
61 +        # output of each conv layer
62 + outputs = keras.layers.Conv2D(
63 + filters=classification_feature_size,
64 + activation="relu",
65 + name="pyramid_classification_{}".format(i),
66 + kernel_initializer=keras.initializers.RandomNormal(
67 + mean=0.0, stddev=0.01, seed=None
68 +            ),  # initializer that draws weights from a normal distribution
69 + bias_initializer="zeros",
70 + **options
71 + )(outputs)
72 +
73 +    # the final conv layer uses a different number of filters (num_classes * num_anchors)
74 + outputs = keras.layers.Conv2D(
75 + filters=num_classes * num_anchors,
76 + kernel_initializer=keras.initializers.RandomNormal(
77 + mean=0.0, stddev=0.01, seed=None
78 + ),
79 + bias_initializer=initializers.PriorProbability(probability=prior_probability),
80 + name="pyramid_classification",
81 + **options
82 + )(outputs)
83 +
84 + # reshape output and apply sigmoid
85 + if keras.backend.image_data_format() == "channels_first":
86 + outputs = keras.layers.Permute(
87 + (2, 3, 1), name="pyramid_classification_permute"
88 + )(outputs)
89 +
90 +    # reshape: 2-D feature map -> 1-D (one row per anchor)
91 + outputs = keras.layers.Reshape(
92 + (-1, num_classes), name="pyramid_classification_reshape"
93 + )(outputs)
94 +
95 + # output layer activation : sigmoid
96 + outputs = keras.layers.Activation("sigmoid", name="pyramid_classification_sigmoid")(
97 + outputs
98 + )
99 +
100 + return keras.models.Model(inputs=inputs, outputs=outputs, name=name)
101 +
102 +
103 +def default_regression_model(
104 + num_values,
105 + num_anchors,
106 + pyramid_feature_size=256,
107 + regression_feature_size=256,
108 + name="regression_submodel",
109 +):
110 + """Creates the default regression submodel.
111 +
112 + Args
113 + num_values : Number of values to regress.
114 + num_anchors : Number of anchors to regress for each feature level.
115 + pyramid_feature_size : The number of filters to expect from the feature pyramid levels.
116 + regression_feature_size : The number of filters to use in the layers in the regression submodel.
117 + name : The name of the submodel.
118 +
119 + Returns
120 + A keras.models.Model that predicts regression values for each anchor.
121 + """
122 + # All new conv layers except the final one in the
123 + # RetinaNet (classification) subnets are initialized
124 + # with bias b = 0 and a Gaussian weight fill with stddev = 0.01.
125 + options = {
126 + "kernel_size": 3,
127 + "strides": 1,
128 + "padding": "same",
129 + "kernel_initializer": keras.initializers.RandomNormal(
130 + mean=0.0, stddev=0.01, seed=None
131 + ),
132 + "bias_initializer": "zeros",
133 + }
134 +
135 + if keras.backend.image_data_format() == "channels_first":
136 + inputs = keras.layers.Input(shape=(pyramid_feature_size, None, None))
137 + else:
138 + inputs = keras.layers.Input(shape=(None, None, pyramid_feature_size))
139 + outputs = inputs
140 + for i in range(4):
141 + outputs = keras.layers.Conv2D(
142 + filters=regression_feature_size,
143 + activation="relu",
144 + name="pyramid_regression_{}".format(i),
145 + **options
146 + )(outputs)
147 +
148 + outputs = keras.layers.Conv2D(
149 + num_anchors * num_values, name="pyramid_regression", **options
150 + )(outputs)
151 + if keras.backend.image_data_format() == "channels_first":
152 + outputs = keras.layers.Permute((2, 3, 1), name="pyramid_regression_permute")(
153 + outputs
154 + )
155 + outputs = keras.layers.Reshape((-1, num_values), name="pyramid_regression_reshape")(
156 + outputs
157 + )
158 +
159 + return keras.models.Model(inputs=inputs, outputs=outputs, name=name)
160 +
161 +
162 +def __create_pyramid_features(backbone_layers, pyramid_levels, feature_size=256):
163 + """Creates the FPN layers on top of the backbone features.
164 +
165 + Args
166 + backbone_layers: a dictionary containing feature stages C3, C4, C5 from the backbone. Also contains C2 if provided.
167 + pyramid_levels: Pyramid levels in use.
168 + feature_size : The feature size to use for the resulting feature levels.
169 +
170 + Returns
171 +        output_layers : A dict of feature levels. P3, P4 and P5 are always included; P2, P6 and P7 are included only when the corresponding pyramid levels are in use.
172 + """
173 +
174 + output_layers = {}
175 +
176 + # upsample C5 to get P5 from the FPN paper
177 + P5 = keras.layers.Conv2D(
178 + feature_size, kernel_size=1, strides=1, padding="same", name="C5_reduced"
179 + )(backbone_layers["C5"])
180 + P5_upsampled = layers.UpsampleLike(name="P5_upsampled")([P5, backbone_layers["C4"]])
181 + P5 = keras.layers.Conv2D(
182 + feature_size, kernel_size=3, strides=1, padding="same", name="P5"
183 + )(P5)
184 + output_layers["P5"] = P5
185 +
186 + # add P5 elementwise to C4
187 + P4 = keras.layers.Conv2D(
188 + feature_size, kernel_size=1, strides=1, padding="same", name="C4_reduced"
189 + )(backbone_layers["C4"])
190 + P4 = keras.layers.Add(name="P4_merged")([P5_upsampled, P4])
191 + P4_upsampled = layers.UpsampleLike(name="P4_upsampled")([P4, backbone_layers["C3"]])
192 + P4 = keras.layers.Conv2D(
193 + feature_size, kernel_size=3, strides=1, padding="same", name="P4"
194 + )(P4)
195 + output_layers["P4"] = P4
196 +
197 + # add P4 elementwise to C3
198 + P3 = keras.layers.Conv2D(
199 + feature_size, kernel_size=1, strides=1, padding="same", name="C3_reduced"
200 + )(backbone_layers["C3"])
201 + P3 = keras.layers.Add(name="P3_merged")([P4_upsampled, P3])
202 + if "C2" in backbone_layers and 2 in pyramid_levels:
203 + P3_upsampled = layers.UpsampleLike(name="P3_upsampled")(
204 + [P3, backbone_layers["C2"]]
205 + )
206 + P3 = keras.layers.Conv2D(
207 + feature_size, kernel_size=3, strides=1, padding="same", name="P3"
208 + )(P3)
209 + output_layers["P3"] = P3
210 +
211 + if "C2" in backbone_layers and 2 in pyramid_levels:
212 + P2 = keras.layers.Conv2D(
213 + feature_size, kernel_size=1, strides=1, padding="same", name="C2_reduced"
214 + )(backbone_layers["C2"])
215 + P2 = keras.layers.Add(name="P2_merged")([P3_upsampled, P2])
216 + P2 = keras.layers.Conv2D(
217 + feature_size, kernel_size=3, strides=1, padding="same", name="P2"
218 + )(P2)
219 + output_layers["P2"] = P2
220 +
221 + # "P6 is obtained via a 3x3 stride-2 conv on C5"
222 + if 6 in pyramid_levels:
223 + P6 = keras.layers.Conv2D(
224 + feature_size, kernel_size=3, strides=2, padding="same", name="P6"
225 + )(backbone_layers["C5"])
226 + output_layers["P6"] = P6
227 +
228 + # "P7 is computed by applying ReLU followed by a 3x3 stride-2 conv on P6"
229 + if 7 in pyramid_levels:
230 + if 6 not in pyramid_levels:
231 + raise ValueError("P6 is required to use P7")
232 + P7 = keras.layers.Activation("relu", name="C6_relu")(P6)
233 + P7 = keras.layers.Conv2D(
234 + feature_size, kernel_size=3, strides=2, padding="same", name="P7"
235 + )(P7)
236 + output_layers["P7"] = P7
237 +
238 + return output_layers
239 +
240 +
241 +def default_submodels(num_classes, num_anchors):
242 + """Create a list of default submodels used for object detection.
243 +
244 + The default submodels contains a regression submodel and a classification submodel.
245 +
246 + Args
247 + num_classes : Number of classes to use.
248 + num_anchors : Number of base anchors.
249 +
250 + Returns
251 + A list of tuple, where the first element is the name of the submodel and the second element is the submodel itself.
252 + """
253 + return [
254 + ("regression", default_regression_model(4, num_anchors)),
255 + ("classification", default_classification_model(num_classes, num_anchors)),
256 + ]
257 +
258 +
259 +def __build_model_pyramid(name, model, features):
260 + """Applies a single submodel to each FPN level.
261 +
262 + Args
263 + name : Name of the submodel.
264 + model : The submodel to evaluate.
265 + features : The FPN features.
266 +
267 + Returns
268 + A tensor containing the response from the submodel on the FPN features.
269 + """
270 + return keras.layers.Concatenate(axis=1, name=name)([model(f) for f in features])
271 +
272 +
273 +def __build_pyramid(models, features):
274 + """Applies all submodels to each FPN level.
275 +
276 + Args
277 +        models : List of submodels to run on each pyramid level (by default only regression and classification).
278 + features : The FPN features.
279 +
280 + Returns
281 + A list of tensors, one for each submodel.
282 + """
283 + return [__build_model_pyramid(n, m, features) for n, m in models]
284 +
285 +
286 +def __build_anchors(anchor_parameters, features):
287 + """Builds anchors for the shape of the features from FPN.
288 +
289 + Args
290 +        anchor_parameters : Parameters that determine how anchors are generated.
291 + features : The FPN features.
292 +
293 + Returns
294 + A tensor containing the anchors for the FPN features.
295 +
296 + The shape is:
297 + ```
298 + (batch_size, num_anchors, 4)
299 + ```
300 + """
301 + anchors = [
302 + layers.Anchors(
303 + size=anchor_parameters.sizes[i],
304 + stride=anchor_parameters.strides[i],
305 + ratios=anchor_parameters.ratios,
306 + scales=anchor_parameters.scales,
307 + name="anchors_{}".format(i),
308 + )(f)
309 + for i, f in enumerate(features)
310 + ]
311 +
312 + return keras.layers.Concatenate(axis=1, name="anchors")(anchors)
313 +
314 +
315 +def retinanet(
316 + inputs,
317 + backbone_layers,
318 + num_classes,
319 + num_anchors=None,
320 + create_pyramid_features=__create_pyramid_features,
321 + pyramid_levels=None,
322 + submodels=None,
323 + name="retinanet",
324 +):
325 + """Construct a RetinaNet model on top of a backbone.
326 +
327 + This model is the minimum model necessary for training (with the unfortunate exception of anchors as output).
328 +
329 + Args
330 + inputs : keras.layers.Input (or list of) for the input to the model.
331 + num_classes : Number of classes to classify.
332 + num_anchors : Number of base anchors.
333 + create_pyramid_features : Functor for creating pyramid features given the features C3, C4, C5, and possibly C2 from the backbone.
334 + pyramid_levels : pyramid levels to use.
335 + submodels : Submodels to run on each feature map (default is regression and classification submodels).
336 + name : Name of the model.
337 +
338 + Returns
339 + A keras.models.Model which takes an image as input and outputs generated anchors and the result from each submodel on every pyramid level.
340 +
341 + The order of the outputs is as defined in submodels:
342 + ```
343 + [
344 + regression, classification, other[0], other[1], ...
345 + ]
346 + ```
347 + """
348 +
349 + if num_anchors is None:
350 + num_anchors = AnchorParameters.default.num_anchors()
351 +
352 + if submodels is None:
353 + submodels = default_submodels(num_classes, num_anchors)
354 +
355 + if pyramid_levels is None:
356 + pyramid_levels = [3, 4, 5, 6, 7]
357 +
358 + if 2 in pyramid_levels and "C2" not in backbone_layers:
359 + raise ValueError("C2 not provided by backbone model. Cannot create P2 layers.")
360 +
361 + if 3 not in pyramid_levels or 4 not in pyramid_levels or 5 not in pyramid_levels:
362 + raise ValueError("pyramid levels 3, 4, and 5 required for functionality")
363 +
364 + # compute pyramid features as per https://arxiv.org/abs/1708.02002
365 + features = create_pyramid_features(backbone_layers, pyramid_levels)
366 + feature_list = [features["P{}".format(p)] for p in pyramid_levels]
367 +
368 + # for all pyramid levels, run available submodels
369 + pyramids = __build_pyramid(submodels, feature_list)
370 +
371 + return keras.models.Model(inputs=inputs, outputs=pyramids, name=name)
372 +
373 +
374 +def retinanet_bbox(
375 + model=None,
376 + nms=True,
377 + class_specific_filter=True,
378 + name="retinanet-bbox",
379 + anchor_params=None,
380 + pyramid_levels=None,
381 + nms_threshold=0.5,
382 + score_threshold=0.05,
383 + max_detections=300,
384 + parallel_iterations=32,
385 + **kwargs
386 +):
387 + """Construct a RetinaNet model on top of a backbone and adds convenience functions to output boxes directly.
388 +
389 + This model uses the minimum retinanet model and appends a few layers to compute boxes within the graph.
390 + These layers include applying the regression values to the anchors and performing NMS.
391 +
392 + Args
393 + model : RetinaNet model to append bbox layers to. If None, it will create a RetinaNet model using **kwargs.
394 + nms : Whether to use non-maximum suppression for the filtering step.
395 + class_specific_filter : Whether to use class specific filtering or filter for the best scoring class only.
396 + name : Name of the model.
397 + anchor_params : Struct containing anchor parameters. If None, default values are used.
398 + pyramid_levels : pyramid levels to use.
399 + nms_threshold : Threshold for the IoU value to determine when a box should be suppressed.
400 + score_threshold : Threshold used to prefilter the boxes with.
401 + max_detections : Maximum number of detections to keep.
402 + parallel_iterations : Number of batch items to process in parallel.
403 + **kwargs : Additional kwargs to pass to the minimal retinanet model.
404 +
405 + Returns
406 + A keras.models.Model which takes an image as input and outputs the detections on the image.
407 +
408 + The order is defined as follows:
409 + ```
410 + [
411 + boxes, scores, labels, other[0], other[1], ...
412 + ]
413 + ```
414 + """
415 +
416 + # if no anchor parameters are passed, use default values
417 + if anchor_params is None:
418 + anchor_params = AnchorParameters.default
419 +
420 + # create RetinaNet model
421 + if model is None:
422 + model = retinanet(num_anchors=anchor_params.num_anchors(), **kwargs)
423 + else:
424 + assert_training_model(model)
425 +
426 + if pyramid_levels is None:
427 + pyramid_levels = [3, 4, 5, 6, 7]
428 +
429 + assert len(pyramid_levels) == len(
430 + anchor_params.sizes
431 + ), "number of pyramid levels {} should match number of anchor parameter sizes {}".format(
432 + len(pyramid_levels), len(anchor_params.sizes)
433 + )
434 +
435 + pyramid_layer_names = ["P{}".format(p) for p in pyramid_levels]
436 + # compute the anchors
437 + features = [model.get_layer(p_name).output for p_name in pyramid_layer_names]
438 + anchors = __build_anchors(anchor_params, features)
439 +
440 + # we expect the anchors, regression and classification values as first output
441 + regression = model.outputs[0]
442 + classification = model.outputs[1]
443 +
444 + # "other" can be any additional output from custom submodels, by default this will be []
445 + other = model.outputs[2:]
446 +
447 + # apply predicted regression to anchors
448 + boxes = layers.RegressBoxes(name="boxes")([anchors, regression])
449 + boxes = layers.ClipBoxes(name="clipped_boxes")([model.inputs[0], boxes])
450 +
451 + # filter detections (apply NMS / score threshold / select top-k)
452 + detections = layers.FilterDetections(
453 + nms=nms,
454 + class_specific_filter=class_specific_filter,
455 + name="filtered_detections",
456 + nms_threshold=nms_threshold,
457 + score_threshold=score_threshold,
458 + max_detections=max_detections,
459 + parallel_iterations=parallel_iterations,
460 + )([boxes, classification] + other)
461 +
462 + # construct the model
463 + return keras.models.Model(inputs=model.inputs, outputs=detections, name=name)
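Taken together, retinanet() builds the minimal training graph and retinanet_bbox() wraps it with anchors, box regression, clipping and NMS for inference. A hedged sketch of the full round trip (the ResNet helper, its import path and the class count are assumptions used only for illustration):

    from keras_retinanet.models.resnet import resnet50_retinanet  # assumed module path

    training_model = resnet50_retinanet(num_classes=80)
    # appends Anchors, RegressBoxes, ClipBoxes and FilterDetections on top of the training outputs
    prediction_model = retinanet_bbox(model=training_model)

    # with the default AnchorParameters (3 ratios x 3 scales = 9 anchors per location),
    # a pyramid level of spatial size H x W contributes H * W * 9 rows to the
    # (batch_size, num_anchors, 4) boxes output described in __build_anchors above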
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from tensorflow import keras
18 +
19 +from . import retinanet
20 +from . import Backbone
21 +from classification_models.keras import Classifiers
22 +
23 +
24 +class SeBackbone(Backbone):
25 + """ Describes backbone information and provides utility functions.
26 + """
27 +
28 + def __init__(self, backbone):
29 + super(SeBackbone, self).__init__(backbone)
30 + _, self.preprocess_image_func = Classifiers.get(self.backbone)
31 +
32 + def retinanet(self, *args, **kwargs):
33 + """ Returns a retinanet model using the correct backbone.
34 + """
35 + return senet_retinanet(*args, backbone=self.backbone, **kwargs)
36 +
37 + def download_imagenet(self):
38 + """ Downloads ImageNet weights and returns path to weights file.
39 + """
40 + from classification_models.weights import WEIGHTS_COLLECTION
41 +
42 + weights_path = None
43 + for el in WEIGHTS_COLLECTION:
44 + if el['model'] == self.backbone and not el['include_top']:
45 + weights_path = keras.utils.get_file(el['name'], el['url'], cache_subdir='models', file_hash=el['md5'])
46 +
47 + if weights_path is None:
48 + raise ValueError('Unable to find imagenet weights for backbone {}!'.format(self.backbone))
49 +
50 + return weights_path
51 +
52 + def validate(self):
53 + """ Checks whether the backbone string is correct.
54 + """
55 + allowed_backbones = ['seresnet18', 'seresnet34', 'seresnet50', 'seresnet101', 'seresnet152',
56 + 'seresnext50', 'seresnext101', 'senet154']
57 + backbone = self.backbone.split('_')[0]
58 +
59 + if backbone not in allowed_backbones:
60 + raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones))
61 +
62 + def preprocess_image(self, inputs):
63 + """ Takes as input an image and prepares it for being passed through the network.
64 + """
65 + return self.preprocess_image_func(inputs)
66 +
67 +
68 +def senet_retinanet(num_classes, backbone='seresnext50', inputs=None, modifier=None, **kwargs):
69 + """ Constructs a retinanet model using a resnet backbone.
70 +
71 + Args
72 + num_classes: Number of classes to predict.
73 +        backbone: Which backbone to use (one of ('seresnet18', 'seresnet34', 'seresnet50', 'seresnet101', 'seresnet152', 'seresnext50', 'seresnext101', 'senet154')).
74 + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)).
75 + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example).
76 +
77 + Returns
78 +        RetinaNet model with a SENet-family backbone.
79 + """
80 + # choose default input
81 + if inputs is None:
82 + if keras.backend.image_data_format() == 'channels_first':
83 + inputs = keras.layers.Input(shape=(3, None, None))
84 + else:
85 + # inputs = keras.layers.Input(shape=(224, 224, 3))
86 + inputs = keras.layers.Input(shape=(None, None, 3))
87 +
88 + classifier, _ = Classifiers.get(backbone)
89 + model = classifier(input_tensor=inputs, include_top=False, weights=None)
90 +
91 + # get last conv layer from the end of each block [28x28, 14x14, 7x7]
92 + if backbone == 'seresnet18' or backbone == 'seresnet34':
93 + layer_outputs = ['stage3_unit1_relu1', 'stage4_unit1_relu1', 'relu1']
94 + elif backbone == 'seresnet50':
95 + layer_outputs = ['activation_36', 'activation_66', 'activation_81']
96 + elif backbone == 'seresnet101':
97 + layer_outputs = ['activation_36', 'activation_151', 'activation_166']
98 + elif backbone == 'seresnet152':
99 + layer_outputs = ['activation_56', 'activation_236', 'activation_251']
100 + elif backbone == 'seresnext50':
101 + layer_outputs = ['activation_37', 'activation_67', 'activation_81']
102 + elif backbone == 'seresnext101':
103 + layer_outputs = ['activation_37', 'activation_152', 'activation_166']
104 + elif backbone == 'senet154':
105 + layer_outputs = ['activation_59', 'activation_239', 'activation_253']
106 + else:
107 + raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone))
108 +
109 + layer_outputs = [
110 + model.get_layer(name=layer_outputs[0]).output, # 28x28
111 + model.get_layer(name=layer_outputs[1]).output, # 14x14
112 + model.get_layer(name=layer_outputs[2]).output, # 7x7
113 + ]
114 +    # create the senet backbone
115 + model = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=model.name)
116 +
117 + # invoke modifier if given
118 + if modifier:
119 + model = modifier(model)
120 +
121 + # C2 not provided
122 + backbone_layers = {
123 + 'C3': model.outputs[0],
124 + 'C4': model.outputs[1],
125 + 'C5': model.outputs[2]
126 + }
127 +
128 + # create the full model
129 + return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs)
130 +
131 +
132 +def seresnet18_retinanet(num_classes, inputs=None, **kwargs):
133 + return senet_retinanet(num_classes=num_classes, backbone='seresnet18', inputs=inputs, **kwargs)
134 +
135 +
136 +def seresnet34_retinanet(num_classes, inputs=None, **kwargs):
137 + return senet_retinanet(num_classes=num_classes, backbone='seresnet34', inputs=inputs, **kwargs)
138 +
139 +
140 +def seresnet50_retinanet(num_classes, inputs=None, **kwargs):
141 + return senet_retinanet(num_classes=num_classes, backbone='seresnet50', inputs=inputs, **kwargs)
142 +
143 +
144 +def seresnet101_retinanet(num_classes, inputs=None, **kwargs):
145 + return senet_retinanet(num_classes=num_classes, backbone='seresnet101', inputs=inputs, **kwargs)
146 +
147 +
148 +def seresnet152_retinanet(num_classes, inputs=None, **kwargs):
149 + return senet_retinanet(num_classes=num_classes, backbone='seresnet152', inputs=inputs, **kwargs)
150 +
151 +
152 +def seresnext50_retinanet(num_classes, inputs=None, **kwargs):
153 + return senet_retinanet(num_classes=num_classes, backbone='seresnext50', inputs=inputs, **kwargs)
154 +
155 +
156 +def seresnext101_retinanet(num_classes, inputs=None, **kwargs):
157 + return senet_retinanet(num_classes=num_classes, backbone='seresnext101', inputs=inputs, **kwargs)
158 +
159 +
160 +def senet154_retinanet(num_classes, inputs=None, **kwargs):
161 + return senet_retinanet(num_classes=num_classes, backbone='senet154', inputs=inputs, **kwargs)
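Classifiers.get() from the classification_models (image-classifiers) package returns both the model constructor and its matching preprocessing function, which is what SeBackbone relies on above. A minimal sketch mirroring the construction in senet_retinanet:

    from tensorflow import keras
    from classification_models.keras import Classifiers

    SEResNeXt50, preprocess_input = Classifiers.get('seresnext50')
    inputs = keras.layers.Input(shape=(None, None, 3))
    backbone = SEResNeXt50(input_tensor=inputs, include_top=False, weights=None)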
1 +from tensorflow import keras
2 +from .. import initializers
3 +from .. import layers
4 +from ..utils.anchors import AnchorParameters
5 +from . import assert_training_model
6 +from . import retinanet
7 +
8 +def custom_classification_model(
9 + num_classes,
10 + num_anchors,
11 + pyramid_feature_size=256,
12 + prior_probability=0.01,
13 + classification_feature_size=256,
14 + name='classification_submodel'
15 +):
16 +    options = {"kernel_size": 3, "strides": 1, "padding": "same"}  # shared conv options, as in the default submodel
17 +
18 +    # set input
17 + if keras.backend.image_data_format() == "channels_first":
18 + inputs = keras.layers.Input(shape=(pyramid_feature_size, None, None))
19 + else:
20 + inputs = keras.layers.Input(shape=(None, None, pyramid_feature_size))
21 +
22 + outputs = inputs
23 +
24 +    # 3 conv layers (one fewer than the default classification submodel)
25 + for i in range(3):
26 +
27 +        # output of each conv layer
28 + outputs = keras.layers.Conv2D(
29 + filters=classification_feature_size,
30 + activation="relu",
31 + name="pyramid_classification_{}".format(i),
32 + kernel_initializer=keras.initializers.RandomNormal(
33 + mean=0.0, stddev=0.01, seed=None
34 +            ),  # initializer that draws weights from a normal distribution
35 + bias_initializer="zeros",
36 + **options
37 + )(outputs)
38 +
39 +    # the final conv layer uses a different number of filters (num_classes * num_anchors)
40 + outputs = keras.layers.Conv2D(
41 + filters=num_classes * num_anchors,
42 + kernel_initializer=keras.initializers.RandomNormal(
43 + mean=0.0, stddev=0.01, seed=None
44 + ),
45 + bias_initializer=initializers.PriorProbability(probability=prior_probability),
46 + name="pyramid_classification",
47 + **options
48 + )(outputs)
49 +
50 + # reshape output and apply sigmoid
51 + if keras.backend.image_data_format() == "channels_first":
52 + outputs = keras.layers.Permute(
53 + (2, 3, 1), name="pyramid_classification_permute"
54 + )(outputs)
55 +
56 +    # reshape: 2-D feature map -> 1-D (one row per anchor)
57 + outputs = keras.layers.Reshape(
58 + (-1, num_classes), name="pyramid_classification_reshape"
59 + )(outputs)
60 +
61 + # output layer activation : sigmoid
62 + outputs = keras.layers.Activation("sigmoid", name="pyramid_classification_sigmoid")(
63 + outputs
64 + )
65 +
66 + return keras.models.Model(inputs=inputs, outputs=outputs, name=name)
67 +
68 +def custom_regression_model(num_values, num_anchors, pyramid_feature_size=256, regression_feature_size=256, name='regression_submodel'):
69 + if num_anchors is None:
70 + num_anchors = AnchorParameters.default.num_anchors()
71 + model = retinanet.default_regression_model(num_values, num_anchors, pyramid_feature_size, regression_feature_size, name)
72 + return model
73 +
74 +
75 +def custom_submodels(num_classes, num_anchors):
76 + if num_anchors is None:
77 + num_anchors = AnchorParameters.default.num_anchors()
78 + return [
79 + ("regression", custom_regression_model(4, num_anchors)),
80 + ("classification", custom_classification_model(num_classes, num_anchors)),
81 + ]
82 +
83 +
84 +
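These custom submodels are meant to be passed through the submodels argument of retinanet(), which every backbone helper forwards via **kwargs. A hedged sketch (the ResNet helper and its import path are only an example backbone choice):

    from keras_retinanet.models.resnet import resnet50_retinanet  # assumed module path

    num_classes = 20
    num_anchors = AnchorParameters.default.num_anchors()

    model = resnet50_retinanet(
        num_classes=num_classes,
        submodels=custom_submodels(num_classes, num_anchors),
    )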
1 +"""
2 +Copyright 2017-2018 cgratie (https://github.com/cgratie/)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +
18 +from tensorflow import keras
19 +
20 +from . import retinanet
21 +from . import Backbone
22 +from ..utils.image import preprocess_image
23 +
24 +
25 +class VGGBackbone(Backbone):
26 + """ Describes backbone information and provides utility functions.
27 + """
28 +
29 + def retinanet(self, *args, **kwargs):
30 + """ Returns a retinanet model using the correct backbone.
31 + """
32 + return vgg_retinanet(*args, backbone=self.backbone, **kwargs)
33 +
34 + def download_imagenet(self):
35 + """ Downloads ImageNet weights and returns path to weights file.
36 + Weights can be downloaded at https://github.com/fizyr/keras-models/releases .
37 + """
38 + if self.backbone == 'vgg16':
39 + resource = keras.applications.vgg16.vgg16.WEIGHTS_PATH_NO_TOP
40 + checksum = '6d6bbae143d832006294945121d1f1fc'
41 + elif self.backbone == 'vgg19':
42 + resource = keras.applications.vgg19.vgg19.WEIGHTS_PATH_NO_TOP
43 + checksum = '253f8cb515780f3b799900260a226db6'
44 + else:
45 + raise ValueError("Backbone '{}' not recognized.".format(self.backbone))
46 +
47 + return keras.utils.get_file(
48 + '{}_weights_tf_dim_ordering_tf_kernels_notop.h5'.format(self.backbone),
49 + resource,
50 + cache_subdir='models',
51 + file_hash=checksum
52 + )
53 +
54 + def validate(self):
55 + """ Checks whether the backbone string is correct.
56 + """
57 + allowed_backbones = ['vgg16', 'vgg19']
58 +
59 + if self.backbone not in allowed_backbones:
60 + raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(self.backbone, allowed_backbones))
61 +
62 + def preprocess_image(self, inputs):
63 + """ Takes as input an image and prepares it for being passed through the network.
64 + """
65 + return preprocess_image(inputs, mode='caffe')
66 +
67 +
68 +def vgg_retinanet(num_classes, backbone='vgg16', inputs=None, modifier=None, **kwargs):
69 + """ Constructs a retinanet model using a vgg backbone.
70 +
71 + Args
72 + num_classes: Number of classes to predict.
73 + backbone: Which backbone to use (one of ('vgg16', 'vgg19')).
74 + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)).
75 + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example).
76 +
77 + Returns
78 + RetinaNet model with a VGG backbone.
79 + """
80 + # choose default input
81 + if inputs is None:
82 + inputs = keras.layers.Input(shape=(None, None, 3))
83 +
84 + # create the vgg backbone
85 + if backbone == 'vgg16':
86 + vgg = keras.applications.VGG16(input_tensor=inputs, include_top=False, weights=None)
87 + elif backbone == 'vgg19':
88 + vgg = keras.applications.VGG19(input_tensor=inputs, include_top=False, weights=None)
89 + else:
90 + raise ValueError("Backbone '{}' not recognized.".format(backbone))
91 +
92 + if modifier:
93 + vgg = modifier(vgg)
94 +
95 + # create the full model
96 + layer_names = ["block3_pool", "block4_pool", "block5_pool"]
97 + layer_outputs = [vgg.get_layer(name).output for name in layer_names]
98 +
99 + # C2 not provided
100 + backbone_layers = {
101 + 'C3': layer_outputs[0],
102 + 'C4': layer_outputs[1],
103 + 'C5': layer_outputs[2]
104 + }
105 +
106 + return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs)
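The block3/4/5 pooling outputs chosen above sit at strides 8, 16 and 32 relative to the input, which is what the FPN expects for C3, C4 and C5. A quick sketch to confirm the shapes (the 512x512 input size is chosen only to make the arithmetic visible):

    from tensorflow import keras

    inputs = keras.layers.Input(shape=(512, 512, 3))
    vgg = keras.applications.VGG16(input_tensor=inputs, include_top=False, weights=None)
    for name in ['block3_pool', 'block4_pool', 'block5_pool']:
        print(name, vgg.get_layer(name).output.shape)
    # block3_pool (None, 64, 64, 256), block4_pool (None, 32, 32, 512), block5_pool (None, 16, 16, 512)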
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from ..preprocessing.generator import Generator
18 +from ..utils.image import read_image_bgr
19 +
20 +import os
21 +import numpy as np
22 +
23 +from pycocotools.coco import COCO
24 +
25 +
26 +class CocoGenerator(Generator):
27 + """ Generate data from the COCO dataset.
28 +
29 + See https://github.com/cocodataset/cocoapi/tree/master/PythonAPI for more information.
30 + """
31 +
32 + def __init__(self, data_dir, set_name, **kwargs):
33 + """ Initialize a COCO data generator.
34 +
35 + Args
36 + data_dir: Path to where the COCO dataset is stored.
37 + set_name: Name of the set to parse.
38 + """
39 + self.data_dir = data_dir
40 + self.set_name = set_name
41 + self.coco = COCO(os.path.join(data_dir, 'annotations', 'instances_' + set_name + '.json'))
42 + self.image_ids = self.coco.getImgIds()
43 +
44 + self.load_classes()
45 +
46 + super(CocoGenerator, self).__init__(**kwargs)
47 +
48 + def load_classes(self):
49 + """ Loads the class to label mapping (and inverse) for COCO.
50 + """
51 + # load class names (name -> label)
52 + categories = self.coco.loadCats(self.coco.getCatIds())
53 + categories.sort(key=lambda x: x['id'])
54 +
55 + self.classes = {}
56 + self.coco_labels = {}
57 + self.coco_labels_inverse = {}
58 + for c in categories:
59 + self.coco_labels[len(self.classes)] = c['id']
60 + self.coco_labels_inverse[c['id']] = len(self.classes)
61 + self.classes[c['name']] = len(self.classes)
62 +
63 + # also load the reverse (label -> name)
64 + self.labels = {}
65 + for key, value in self.classes.items():
66 + self.labels[value] = key
67 +
68 + def size(self):
69 + """ Size of the COCO dataset.
70 + """
71 + return len(self.image_ids)
72 +
73 + def num_classes(self):
74 + """ Number of classes in the dataset. For COCO this is 80.
75 + """
76 + return len(self.classes)
77 +
78 + def has_label(self, label):
79 + """ Return True if label is a known label.
80 + """
81 + return label in self.labels
82 +
83 + def has_name(self, name):
84 + """ Returns True if name is a known class.
85 + """
86 + return name in self.classes
87 +
88 + def name_to_label(self, name):
89 + """ Map name to label.
90 + """
91 + return self.classes[name]
92 +
93 + def label_to_name(self, label):
94 + """ Map label to name.
95 + """
96 + return self.labels[label]
97 +
98 + def coco_label_to_label(self, coco_label):
99 + """ Map COCO label to the label as used in the network.
100 + COCO has some gaps in the order of labels. The highest label is 90, but there are 80 classes.
101 + """
102 + return self.coco_labels_inverse[coco_label]
103 +
104 + def coco_label_to_name(self, coco_label):
105 + """ Map COCO label to name.
106 + """
107 + return self.label_to_name(self.coco_label_to_label(coco_label))
108 +
109 + def label_to_coco_label(self, label):
110 + """ Map label as used by the network to labels as used by COCO.
111 + """
112 + return self.coco_labels[label]
113 +
114 + def image_path(self, image_index):
115 + """ Returns the image path for image_index.
116 + """
117 + image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
118 + path = os.path.join(self.data_dir, 'images', self.set_name, image_info['file_name'])
119 + return path
120 +
121 + def image_aspect_ratio(self, image_index):
122 + """ Compute the aspect ratio for an image with image_index.
123 + """
124 + image = self.coco.loadImgs(self.image_ids[image_index])[0]
125 + return float(image['width']) / float(image['height'])
126 +
127 + def load_image(self, image_index):
128 + """ Load an image at the image_index.
129 + """
130 + path = self.image_path(image_index)
131 + return read_image_bgr(path)
132 +
133 + def load_annotations(self, image_index):
134 + """ Load annotations for an image_index.
135 + """
136 + # get ground truth annotations
137 + annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False)
138 + annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))}
139 +
140 + # some images appear to miss annotations (like image with id 257034)
141 + if len(annotations_ids) == 0:
142 + return annotations
143 +
144 + # parse annotations
145 + coco_annotations = self.coco.loadAnns(annotations_ids)
146 + for idx, a in enumerate(coco_annotations):
147 + # some annotations have basically no width / height, skip them
148 + if a['bbox'][2] < 1 or a['bbox'][3] < 1:
149 + continue
150 +
151 + annotations['labels'] = np.concatenate([annotations['labels'], [self.coco_label_to_label(a['category_id'])]], axis=0)
152 + annotations['bboxes'] = np.concatenate([annotations['bboxes'], [[
153 + a['bbox'][0],
154 + a['bbox'][1],
155 + a['bbox'][0] + a['bbox'][2],
156 + a['bbox'][1] + a['bbox'][3],
157 + ]]], axis=0)
158 +
159 + return annotations
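CocoGenerator derives every path from data_dir and set_name, so the dataset is expected to be laid out as sketched below (the root path and the set name 'train2017' are only examples):

    # <data_dir>/annotations/instances_<set_name>.json   -> loaded by pycocotools' COCO
    # <data_dir>/images/<set_name>/<file_name>           -> returned by image_path()
    generator = CocoGenerator('/path/to/coco', 'train2017', batch_size=1)
    print(generator.size(), generator.num_classes())     # number of images, 80 for COCO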
1 +"""
2 +Copyright 2017-2018 yhenon (https://github.com/yhenon/)
3 +Copyright 2017-2018 Fizyr (https://fizyr.com)
4 +
5 +Licensed under the Apache License, Version 2.0 (the "License");
6 +you may not use this file except in compliance with the License.
7 +You may obtain a copy of the License at
8 +
9 + http://www.apache.org/licenses/LICENSE-2.0
10 +
11 +Unless required by applicable law or agreed to in writing, software
12 +distributed under the License is distributed on an "AS IS" BASIS,
13 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 +See the License for the specific language governing permissions and
15 +limitations under the License.
16 +"""
17 +
18 +from .generator import Generator
19 +from ..utils.image import read_image_bgr
20 +
21 +import numpy as np
22 +from PIL import Image
23 +from six import raise_from
24 +
25 +import csv
26 +import sys
27 +import os.path
28 +from collections import OrderedDict
29 +
30 +
31 +def _parse(value, function, fmt):
32 + """
33 + Parse a string into a value, and format a nice ValueError if it fails.
34 +
35 + Returns `function(value)`.
36 +    Any `ValueError` raised is caught and a new `ValueError` is raised
37 + with message `fmt.format(e)`, where `e` is the caught `ValueError`.
38 + """
39 + try:
40 + return function(value)
41 + except ValueError as e:
42 + raise_from(ValueError(fmt.format(e)), None)
43 +
44 +
45 +def _read_classes(csv_reader):
46 + """ Parse the classes file given by csv_reader.
47 + """
48 + result = OrderedDict()
49 + for line, row in enumerate(csv_reader):
50 + line += 1
51 +
52 + try:
53 + class_name, class_id = row
54 + except ValueError:
55 + raise_from(ValueError('line {}: format should be \'class_name,class_id\''.format(line)), None)
56 + class_id = _parse(class_id, int, 'line {}: malformed class ID: {{}}'.format(line))
57 +
58 + if class_name in result:
59 + raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name))
60 + result[class_name] = class_id
61 + return result
62 +
63 +
64 +def _read_annotations(csv_reader, classes):
65 + """ Read annotations from the csv_reader.
66 + """
67 + result = OrderedDict()
68 + for line, row in enumerate(csv_reader):
69 + line += 1
70 +
71 + try:
72 + img_file, x1, y1, x2, y2, class_name = row[:6]
73 + except ValueError:
74 + raise_from(ValueError('line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line)), None)
75 +
76 + if img_file not in result:
77 + result[img_file] = []
78 +
79 + # If a row contains only an image path, it's an image without annotations.
80 + if (x1, y1, x2, y2, class_name) == ('', '', '', '', ''):
81 + continue
82 +
83 + x1 = _parse(x1, int, 'line {}: malformed x1: {{}}'.format(line))
84 + y1 = _parse(y1, int, 'line {}: malformed y1: {{}}'.format(line))
85 + x2 = _parse(x2, int, 'line {}: malformed x2: {{}}'.format(line))
86 + y2 = _parse(y2, int, 'line {}: malformed y2: {{}}'.format(line))
87 +
88 + # Check that the bounding box is valid.
89 + if x2 <= x1:
90 + raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1))
91 + if y2 <= y1:
92 + raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1))
93 +
94 + # check if the current class name is correctly present
95 + if class_name not in classes:
96 + raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(line, class_name, classes))
97 +
98 + result[img_file].append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': class_name})
99 + return result
100 +
101 +
102 +def _open_for_csv(path):
103 + """ Open a file with flags suitable for csv.reader.
104 +
105 +    On Python 2 this means opening in mode 'rb';
106 +    on Python 3 it means mode 'r' with universal newlines.
107 + """
108 + if sys.version_info[0] < 3:
109 + return open(path, 'rb')
110 + else:
111 + return open(path, 'r', newline='')
112 +
113 +
114 +class CSVGenerator(Generator):
115 + """ Generate data for a custom CSV dataset.
116 +
117 + See https://github.com/fizyr/keras-retinanet#csv-datasets for more information.
118 + """
119 +
120 + def __init__(
121 + self,
122 + csv_data_file,
123 + csv_class_file,
124 + base_dir=None,
125 + **kwargs
126 + ):
127 + """ Initialize a CSV data generator.
128 +
129 + Args
130 + csv_data_file: Path to the CSV annotations file.
131 + csv_class_file: Path to the CSV classes file.
132 + base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file).
133 + """
134 + self.image_names = []
135 + self.image_data = {}
136 + self.base_dir = base_dir
137 +
138 + # Take base_dir from annotations file if not explicitly specified.
139 + if self.base_dir is None:
140 + self.base_dir = os.path.dirname(csv_data_file)
141 +
142 + # parse the provided class file
143 + try:
144 + with _open_for_csv(csv_class_file) as file:
145 + self.classes = _read_classes(csv.reader(file, delimiter=','))
146 + except ValueError as e:
147 + raise_from(ValueError('invalid CSV class file: {}: {}'.format(csv_class_file, e)), None)
148 +
149 + self.labels = {}
150 + for key, value in self.classes.items():
151 + self.labels[value] = key
152 +
153 + # csv with img_path, x1, y1, x2, y2, class_name
154 + try:
155 + with _open_for_csv(csv_data_file) as file:
156 + self.image_data = _read_annotations(csv.reader(file, delimiter=','), self.classes)
157 + except ValueError as e:
158 + raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(csv_data_file, e)), None)
159 + self.image_names = list(self.image_data.keys())
160 +
161 + super(CSVGenerator, self).__init__(**kwargs)
162 +
163 + def size(self):
164 + """ Size of the dataset.
165 + """
166 + return len(self.image_names)
167 +
168 + def num_classes(self):
169 + """ Number of classes in the dataset.
170 + """
171 + return max(self.classes.values()) + 1
172 +
173 + def has_label(self, label):
174 + """ Return True if label is a known label.
175 + """
176 + return label in self.labels
177 +
178 + def has_name(self, name):
179 + """ Returns True if name is a known class.
180 + """
181 + return name in self.classes
182 +
183 + def name_to_label(self, name):
184 + """ Map name to label.
185 + """
186 + return self.classes[name]
187 +
188 + def label_to_name(self, label):
189 + """ Map label to name.
190 + """
191 + return self.labels[label]
192 +
193 + def image_path(self, image_index):
194 + """ Returns the image path for image_index.
195 + """
196 + return os.path.join(self.base_dir, self.image_names[image_index])
197 +
198 + def image_aspect_ratio(self, image_index):
199 + """ Compute the aspect ratio for an image with image_index.
200 + """
201 + # PIL is fast for metadata
202 + image = Image.open(self.image_path(image_index))
203 + return float(image.width) / float(image.height)
204 +
205 + def load_image(self, image_index):
206 + """ Load an image at the image_index.
207 + """
208 + return read_image_bgr(self.image_path(image_index))
209 +
210 + def load_annotations(self, image_index):
211 + """ Load annotations for an image_index.
212 + """
213 + path = self.image_names[image_index]
214 + annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))}
215 +
216 + for idx, annot in enumerate(self.image_data[path]):
217 + annotations['labels'] = np.concatenate((annotations['labels'], [self.name_to_label(annot['class'])]))
218 + annotations['bboxes'] = np.concatenate((annotations['bboxes'], [[
219 + float(annot['x1']),
220 + float(annot['y1']),
221 + float(annot['x2']),
222 + float(annot['y2']),
223 + ]]))
224 +
225 + return annotations
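The two CSV files follow the formats validated by _read_classes and _read_annotations; a tiny hedged example (file names, paths and class names are made up), including one image row without annotations:

    # classes.csv
    #   cow,0
    #   cat,1
    # annotations.csv
    #   data/img_001.jpg,837,346,981,456,cow
    #   data/img_002.jpg,215,312,279,391,cat
    #   data/img_003.jpg,,,,,

    generator = CSVGenerator('annotations.csv', 'classes.csv')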
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import numpy as np
18 +import random
19 +import warnings
20 +
21 +from tensorflow import keras
22 +
23 +from ..utils.anchors import (
24 + anchor_targets_bbox,
25 + anchors_for_shape,
26 + guess_shapes
27 +)
28 +from ..utils.config import parse_anchor_parameters, parse_pyramid_levels
29 +from ..utils.image import (
30 + TransformParameters,
31 + adjust_transform_for_image,
32 + apply_transform,
33 + preprocess_image,
34 + resize_image,
35 +)
36 +from ..utils.transform import transform_aabb
37 +
38 +
39 +class Generator(keras.utils.Sequence):
40 + """ Abstract generator class.
41 + """
42 +
43 + def __init__(
44 + self,
45 + transform_generator = None,
46 + visual_effect_generator=None,
47 + batch_size=1,
48 + group_method='ratio', # one of 'none', 'random', 'ratio'
49 + shuffle_groups=True,
50 + image_min_side=800,
51 + image_max_side=1333,
52 + no_resize=False,
53 + transform_parameters=None,
54 + compute_anchor_targets=anchor_targets_bbox,
55 + compute_shapes=guess_shapes,
56 + preprocess_image=preprocess_image,
57 + config=None
58 + ):
59 + """ Initialize Generator object.
60 +
61 + Args
62 + transform_generator : A generator used to randomly transform images and annotations.
63 + batch_size : The size of the batches to generate.
64 + group_method : Determines how images are grouped together (defaults to 'ratio', one of ('none', 'random', 'ratio')).
65 + shuffle_groups : If True, shuffles the groups each epoch.
66 + image_min_side : After resizing the minimum side of an image is equal to image_min_side.
67 + image_max_side : If after resizing the maximum side is larger than image_max_side, scales down further so that the max side is equal to image_max_side.
68 + no_resize : If True, no image/annotation resizing is performed.
69 + transform_parameters : The transform parameters used for data augmentation.
70 + compute_anchor_targets : Function handler for computing the targets of anchors for an image and its annotations.
71 + compute_shapes : Function handler for computing the shapes of the pyramid for a given input.
72 + preprocess_image : Function handler for preprocessing an image (scaling / normalizing) for passing through a network.
73 + """
74 + self.transform_generator = transform_generator
75 + self.visual_effect_generator = visual_effect_generator
76 + self.batch_size = int(batch_size)
77 + self.group_method = group_method
78 + self.shuffle_groups = shuffle_groups
79 + self.image_min_side = image_min_side
80 + self.image_max_side = image_max_side
81 + self.no_resize = no_resize
82 + self.transform_parameters = transform_parameters or TransformParameters()
83 + self.compute_anchor_targets = compute_anchor_targets
84 + self.compute_shapes = compute_shapes
85 + self.preprocess_image = preprocess_image
86 + self.config = config
87 +
88 + # Define groups
89 + self.group_images()
90 +
91 + # Shuffle when initializing
92 + if self.shuffle_groups:
93 + self.on_epoch_end()
94 +
95 + def on_epoch_end(self):
96 + if self.shuffle_groups:
97 + random.shuffle(self.groups)
98 +
99 + def size(self):
100 + """ Size of the dataset.
101 + """
102 + raise NotImplementedError('size method not implemented')
103 +
104 + def num_classes(self):
105 + """ Number of classes in the dataset.
106 + """
107 + raise NotImplementedError('num_classes method not implemented')
108 +
109 + def has_label(self, label):
110 + """ Returns True if label is a known label.
111 + """
112 + raise NotImplementedError('has_label method not implemented')
113 +
114 + def has_name(self, name):
115 + """ Returns True if name is a known class.
116 + """
117 + raise NotImplementedError('has_name method not implemented')
118 +
119 + def name_to_label(self, name):
120 + """ Map name to label.
121 + """
122 + raise NotImplementedError('name_to_label method not implemented')
123 +
124 + def label_to_name(self, label):
125 + """ Map label to name.
126 + """
127 + raise NotImplementedError('label_to_name method not implemented')
128 +
129 + def image_aspect_ratio(self, image_index):
130 + """ Compute the aspect ratio for an image with image_index.
131 + """
132 + raise NotImplementedError('image_aspect_ratio method not implemented')
133 +
134 + def image_path(self, image_index):
135 + """ Get the path to an image.
136 + """
137 + raise NotImplementedError('image_path method not implemented')
138 +
139 + def load_image(self, image_index):
140 + """ Load an image at the image_index.
141 + """
142 + raise NotImplementedError('load_image method not implemented')
143 +
144 + def load_annotations(self, image_index):
145 + """ Load annotations for an image_index.
146 + """
147 + raise NotImplementedError('load_annotations method not implemented')
148 +
149 + def load_annotations_group(self, group):
150 + """ Load annotations for all images in group.
151 + """
152 + annotations_group = [self.load_annotations(image_index) for image_index in group]
153 + for annotations in annotations_group:
154 + assert(isinstance(annotations, dict)), '\'load_annotations\' should return a list of dictionaries, received: {}'.format(type(annotations))
155 + assert('labels' in annotations), '\'load_annotations\' should return a list of dictionaries that contain \'labels\' and \'bboxes\'.'
156 + assert('bboxes' in annotations), '\'load_annotations\' should return a list of dictionaries that contain \'labels\' and \'bboxes\'.'
157 +
158 + return annotations_group
159 +
160 + def filter_annotations(self, image_group, annotations_group, group):
161 +        """ Filter annotations by removing those that are outside of the image bounds or whose width/height is non-positive.
162 + """
163 + # test all annotations
164 + for index, (image, annotations) in enumerate(zip(image_group, annotations_group)):
165 +            # invalid if: x2 <= x1 | y2 <= y1 | x1 < 0 | y1 < 0 | x2 > image.shape[1] | y2 > image.shape[0]
166 + invalid_indices = np.where(
167 + (annotations['bboxes'][:, 2] <= annotations['bboxes'][:, 0]) |
168 + (annotations['bboxes'][:, 3] <= annotations['bboxes'][:, 1]) |
169 + (annotations['bboxes'][:, 0] < 0) |
170 + (annotations['bboxes'][:, 1] < 0) |
171 + (annotations['bboxes'][:, 2] > image.shape[1]) |
172 + (annotations['bboxes'][:, 3] > image.shape[0])
173 + )[0]
174 +
175 + # delete invalid indices
176 + if len(invalid_indices):
177 + warnings.warn('Image {} with id {} (shape {}) contains the following invalid boxes: {}.'.format(
178 + self.image_path(group[index]),
179 + group[index],
180 + image.shape,
181 + annotations['bboxes'][invalid_indices, :]
182 + ))
183 + for k in annotations_group[index].keys():
184 + annotations_group[index][k] = np.delete(annotations[k], invalid_indices, axis=0)
185 + return image_group, annotations_group
186 +
187 + def load_image_group(self, group):
188 + """ Load images for all images in a group.
189 + """
190 + return [self.load_image(image_index) for image_index in group]
191 +
192 + def random_visual_effect_group_entry(self, image, annotations):
193 +        """ Randomly applies a visual effect to the image.
194 + """
195 + visual_effect = next(self.visual_effect_generator)
196 + # apply visual effect
197 + image = visual_effect(image)
198 + return image, annotations
199 +
200 + def random_visual_effect_group(self, image_group, annotations_group):
201 + """ Randomly apply visual effect on each image.
202 + """
203 + assert(len(image_group) == len(annotations_group))
204 +
205 + if self.visual_effect_generator is None:
206 + # do nothing
207 + return image_group, annotations_group
208 +
209 + for index in range(len(image_group)):
210 + # apply effect on a single group entry
211 + image_group[index], annotations_group[index] = self.random_visual_effect_group_entry(
212 + image_group[index], annotations_group[index]
213 + )
214 +
215 + return image_group, annotations_group
216 +
217 + def random_transform_group_entry(self, image, annotations, transform=None):
218 + """ Randomly transforms image and annotation.
219 + """
220 + # randomly transform both image and annotations
221 + if transform is not None or self.transform_generator:
222 + if transform is None:
223 + transform = adjust_transform_for_image(next(self.transform_generator), image, self.transform_parameters.relative_translation)
224 +
225 + # apply transformation to image
226 + image = apply_transform(transform, image, self.transform_parameters)
227 +
228 + # Transform the bounding boxes in the annotations.
229 + annotations['bboxes'] = annotations['bboxes'].copy()
230 + for index in range(annotations['bboxes'].shape[0]):
231 + annotations['bboxes'][index, :] = transform_aabb(transform, annotations['bboxes'][index, :])
232 +
233 + return image, annotations
234 +
235 + def random_transform_group(self, image_group, annotations_group):
236 + """ Randomly transforms each image and its annotations.
237 + """
238 +
239 + assert(len(image_group) == len(annotations_group))
240 +
241 + for index in range(len(image_group)):
242 + # transform a single group entry
243 + image_group[index], annotations_group[index] = self.random_transform_group_entry(image_group[index], annotations_group[index])
244 +
245 + return image_group, annotations_group
246 +
247 + def resize_image(self, image):
248 + """ Resize an image using image_min_side and image_max_side.
249 + """
250 + if self.no_resize:
251 + return image, 1
252 + else:
253 + return resize_image(image, min_side=self.image_min_side, max_side=self.image_max_side)
254 +
255 + def preprocess_group_entry(self, image, annotations):
256 + """ Preprocess image and its annotations.
257 + """
258 + # resize image
259 + image, image_scale = self.resize_image(image)
260 +
261 + # preprocess the image
262 + image = self.preprocess_image(image)
263 +
264 + # apply resizing to annotations too
265 + annotations['bboxes'] *= image_scale
266 +
267 + # convert to the wanted keras floatx
268 + image = keras.backend.cast_to_floatx(image)
269 +
270 + return image, annotations
271 +
272 + def preprocess_group(self, image_group, annotations_group):
273 + """ Preprocess each image and its annotations in its group.
274 + """
275 + assert(len(image_group) == len(annotations_group))
276 +
277 + for index in range(len(image_group)):
278 + # preprocess a single group entry
279 + image_group[index], annotations_group[index] = self.preprocess_group_entry(image_group[index], annotations_group[index])
280 +
281 + return image_group, annotations_group
282 +
283 + def group_images(self):
284 +        """ Order the images according to self.group_method and make groups of self.batch_size.
285 + """
286 + # determine the order of the images
287 + order = list(range(self.size()))
288 + if self.group_method == 'random':
289 + random.shuffle(order)
290 + elif self.group_method == 'ratio':
291 + order.sort(key=lambda x: self.image_aspect_ratio(x))
292 +
293 + # divide into groups, one group = one batch
294 + self.groups = [[order[x % len(order)] for x in range(i, i + self.batch_size)] for i in range(0, len(order), self.batch_size)]
295 +
296 + def compute_inputs(self, image_group):
297 + """ Compute inputs for the network using an image_group.
298 + """
299 + # get the max image shape
300 + max_shape = tuple(max(image.shape[x] for image in image_group) for x in range(3))
301 +
302 + # construct an image batch object
303 + image_batch = np.zeros((self.batch_size,) + max_shape, dtype=keras.backend.floatx())
304 +
305 + # copy all images to the upper left part of the image batch object
306 + for image_index, image in enumerate(image_group):
307 + image_batch[image_index, :image.shape[0], :image.shape[1], :image.shape[2]] = image
308 +
309 + if keras.backend.image_data_format() == 'channels_first':
310 + image_batch = image_batch.transpose((0, 3, 1, 2))
311 +
312 + return image_batch
313 +
314 + def generate_anchors(self, image_shape):
315 + anchor_params = None
316 + pyramid_levels = None
317 + if self.config and 'anchor_parameters' in self.config:
318 + anchor_params = parse_anchor_parameters(self.config)
319 + if self.config and 'pyramid_levels' in self.config:
320 + pyramid_levels = parse_pyramid_levels(self.config)
321 +
322 + return anchors_for_shape(image_shape, anchor_params=anchor_params, pyramid_levels=pyramid_levels, shapes_callback=self.compute_shapes)
323 +
324 + def compute_targets(self, image_group, annotations_group):
325 + """ Compute target outputs for the network using images and their annotations.
326 + """
327 + # get the max image shape
328 + max_shape = tuple(max(image.shape[x] for image in image_group) for x in range(3))
329 + anchors = self.generate_anchors(max_shape)
330 +
331 + batches = self.compute_anchor_targets(
332 + anchors,
333 + image_group,
334 + annotations_group,
335 + self.num_classes()
336 + )
337 +
338 + return list(batches)
339 +
340 + def compute_input_output(self, group):
341 + """ Compute inputs and target outputs for the network.
342 + """
343 + # load images and annotations
344 + image_group = self.load_image_group(group)
345 + annotations_group = self.load_annotations_group(group)
346 +
347 + # check validity of annotations
348 + image_group, annotations_group = self.filter_annotations(image_group, annotations_group, group)
349 +
350 + # randomly apply visual effect
351 + image_group, annotations_group = self.random_visual_effect_group(image_group, annotations_group)
352 +
353 + # randomly transform data
354 + image_group, annotations_group = self.random_transform_group(image_group, annotations_group)
355 +
356 + # perform preprocessing steps
357 + image_group, annotations_group = self.preprocess_group(image_group, annotations_group)
358 +
359 + # compute network inputs
360 + inputs = self.compute_inputs(image_group)
361 +
362 + # compute network targets
363 + targets = self.compute_targets(image_group, annotations_group)
364 +
365 + return inputs, targets
366 +
367 + def __len__(self):
368 + """
369 + Number of batches for generator.
370 + """
371 +
372 + return len(self.groups)
373 +
374 + def __getitem__(self, index):
375 + """
376 + Keras sequence method for generating batches.
377 + """
378 + group = self.groups[index]
379 + inputs, targets = self.compute_input_output(group)
380 +
381 + return inputs, targets
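
For reference, below is a minimal sketch of how this abstract class is meant to be subclassed: an in-memory generator that keeps its images and annotations in Python lists and implements only the methods the base class leaves abstract. The class name, variable names, and the import path are hypothetical (they assume the usual `keras_retinanet.preprocessing.generator` package layout); treat it as an illustration rather than part of the patch.

```python
import numpy as np

# Assumed import path; adjust to wherever this module lives in your checkout.
from keras_retinanet.preprocessing.generator import Generator


class InMemoryGenerator(Generator):
    """ Hypothetical generator that serves images/annotations already held in memory. """

    def __init__(self, images, annotations, classes, **kwargs):
        # images      : list of HxWx3 BGR numpy arrays
        # annotations : list of dicts with 'labels' (N,) and 'bboxes' (N, 4) arrays
        # classes     : dict mapping class name -> label, e.g. {'car': 0}
        self.images      = images
        self.annotations = annotations
        self.classes     = classes
        self.labels      = {v: k for k, v in classes.items()}
        super(InMemoryGenerator, self).__init__(**kwargs)

    def size(self):
        return len(self.images)

    def num_classes(self):
        return max(self.classes.values()) + 1

    def has_label(self, label):
        return label in self.labels

    def has_name(self, name):
        return name in self.classes

    def name_to_label(self, name):
        return self.classes[name]

    def label_to_name(self, label):
        return self.labels[label]

    def image_aspect_ratio(self, image_index):
        height, width = self.images[image_index].shape[:2]
        return float(width) / float(height)

    def image_path(self, image_index):
        return '<in-memory image {}>'.format(image_index)

    def load_image(self, image_index):
        return self.images[image_index]

    def load_annotations(self, image_index):
        return self.annotations[image_index]
```

Indexing such an instance (e.g. `generator[0]`) then yields one `(inputs, targets)` batch, produced by `compute_input_output` above.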
1 +"""
2 +Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import csv
18 +import os.path
19 +
20 +import numpy as np
21 +from PIL import Image
22 +
23 +from .generator import Generator
24 +from ..utils.image import read_image_bgr
25 +
26 +kitti_classes = {
27 + 'Car': 0,
28 + 'Van': 1,
29 + 'Truck': 2,
30 + 'Pedestrian': 3,
31 + 'Person_sitting': 4,
32 + 'Cyclist': 5,
33 + 'Tram': 6,
34 + 'Misc': 7,
35 + 'DontCare': 7
36 +}
37 +
38 +
39 +class KittiGenerator(Generator):
40 + """ Generate data for a KITTI dataset.
41 +
42 + See http://www.cvlibs.net/datasets/kitti/ for more information.
43 + """
44 +
45 + def __init__(
46 + self,
47 + base_dir,
48 + subset='train',
49 + **kwargs
50 + ):
51 + """ Initialize a KITTI data generator.
52 +
53 + Args
54 +            base_dir: Directory containing the KITTI subset directories ('<subset>/images' and '<subset>/labels').
55 + subset: The subset to generate data for (defaults to 'train').
56 + """
57 + self.base_dir = base_dir
58 +
59 + label_dir = os.path.join(self.base_dir, subset, 'labels')
60 + image_dir = os.path.join(self.base_dir, subset, 'images')
61 +
62 + """
63 + 1 type Describes the type of object: 'Car', 'Van', 'Truck',
64 + 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram',
65 + 'Misc' or 'DontCare'
66 + 1 truncated Float from 0 (non-truncated) to 1 (truncated), where
67 + truncated refers to the object leaving image boundaries
68 + 1 occluded Integer (0,1,2,3) indicating occlusion state:
69 + 0 = fully visible, 1 = partly occluded
70 + 2 = largely occluded, 3 = unknown
71 + 1 alpha Observation angle of object, ranging [-pi..pi]
72 + 4 bbox 2D bounding box of object in the image (0-based index):
73 + contains left, top, right, bottom pixel coordinates
74 + 3 dimensions 3D object dimensions: height, width, length (in meters)
75 + 3 location 3D object location x,y,z in camera coordinates (in meters)
76 + 1 rotation_y Rotation ry around Y-axis in camera coordinates [-pi..pi]
77 + """
78 +
79 + self.labels = {}
80 + self.classes = kitti_classes
81 + for name, label in self.classes.items():
82 + self.labels[label] = name
83 +
84 + self.image_data = dict()
85 + self.images = []
86 + for i, fn in enumerate(os.listdir(label_dir)):
87 + label_fp = os.path.join(label_dir, fn)
88 + image_fp = os.path.join(image_dir, fn.replace('.txt', '.png'))
89 +
90 + self.images.append(image_fp)
91 +
92 + fieldnames = ['type', 'truncated', 'occluded', 'alpha', 'left', 'top', 'right', 'bottom', 'dh', 'dw', 'dl',
93 + 'lx', 'ly', 'lz', 'ry']
94 + with open(label_fp, 'r') as csv_file:
95 + reader = csv.DictReader(csv_file, delimiter=' ', fieldnames=fieldnames)
96 + boxes = []
97 + for line, row in enumerate(reader):
98 + label = row['type']
99 + cls_id = kitti_classes[label]
100 +
101 + annotation = {'cls_id': cls_id, 'x1': row['left'], 'x2': row['right'], 'y2': row['bottom'], 'y1': row['top']}
102 + boxes.append(annotation)
103 +
104 + self.image_data[i] = boxes
105 +
106 + super(KittiGenerator, self).__init__(**kwargs)
107 +
108 + def size(self):
109 + """ Size of the dataset.
110 + """
111 + return len(self.images)
112 +
113 + def num_classes(self):
114 + """ Number of classes in the dataset.
115 + """
116 + return max(self.classes.values()) + 1
117 +
118 + def has_label(self, label):
119 + """ Return True if label is a known label.
120 + """
121 + return label in self.labels
122 +
123 + def has_name(self, name):
124 + """ Returns True if name is a known class.
125 + """
126 + return name in self.classes
127 +
128 + def name_to_label(self, name):
129 + """ Map name to label.
130 + """
131 +        return self.classes[name]
132 +
133 + def label_to_name(self, label):
134 + """ Map label to name.
135 + """
136 + return self.labels[label]
137 +
138 + def image_aspect_ratio(self, image_index):
139 + """ Compute the aspect ratio for an image with image_index.
140 + """
141 + # PIL is fast for metadata
142 + image = Image.open(self.images[image_index])
143 + return float(image.width) / float(image.height)
144 +
145 + def image_path(self, image_index):
146 + """ Get the path to an image.
147 + """
148 + return self.images[image_index]
149 +
150 + def load_image(self, image_index):
151 + """ Load an image at the image_index.
152 + """
153 + return read_image_bgr(self.image_path(image_index))
154 +
155 + def load_annotations(self, image_index):
156 + """ Load annotations for an image_index.
157 + """
158 + image_data = self.image_data[image_index]
159 + annotations = {'labels': np.empty((len(image_data),)), 'bboxes': np.empty((len(image_data), 4))}
160 +
161 + for idx, ann in enumerate(image_data):
162 + annotations['bboxes'][idx, 0] = float(ann['x1'])
163 + annotations['bboxes'][idx, 1] = float(ann['y1'])
164 + annotations['bboxes'][idx, 2] = float(ann['x2'])
165 + annotations['bboxes'][idx, 3] = float(ann['y2'])
166 + annotations['labels'][idx] = int(ann['cls_id'])
167 +
168 + return annotations
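
A short usage sketch for the generator defined above. The dataset path is made up; the only assumption grounded in the code is the directory layout used by `__init__` (`<base_dir>/<subset>/images/*.png` next to `<base_dir>/<subset>/labels/*.txt`) and that the remaining keyword arguments are forwarded to the base `Generator`.

```python
from keras_retinanet.preprocessing.kitti import KittiGenerator  # assumed import path

# Hypothetical layout: /data/kitti/train/images/000000.png and /data/kitti/train/labels/000000.txt, ...
train_generator = KittiGenerator(
    '/data/kitti',          # hypothetical base_dir
    subset='train',
    batch_size=2,
    image_min_side=800,
    image_max_side=1333,
)

print(train_generator.size(), 'images,', train_generator.num_classes(), 'classes')
inputs, targets = train_generator[0]   # one padded image batch and its anchor targets
print(inputs.shape)                    # e.g. (2, H, W, 3)
```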
1 +"""
2 +Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import csv
18 +import json
19 +import os
20 +import warnings
21 +
22 +import numpy as np
23 +from PIL import Image
24 +
25 +from .generator import Generator
26 +from ..utils.image import read_image_bgr
27 +
28 +
29 +def load_hierarchy(metadata_dir, version='v4'):
30 + hierarchy = None
31 + if version == 'challenge2018':
32 + hierarchy = 'bbox_labels_500_hierarchy.json'
33 + elif version == 'v4':
34 + hierarchy = 'bbox_labels_600_hierarchy.json'
35 + elif version == 'v3':
36 + hierarchy = 'bbox_labels_600_hierarchy.json'
37 +
38 + hierarchy_json = os.path.join(metadata_dir, hierarchy)
39 + with open(hierarchy_json) as f:
40 + hierarchy_data = json.loads(f.read())
41 +
42 + return hierarchy_data
43 +
44 +
45 +def load_hierarchy_children(hierarchy):
46 + res = [hierarchy['LabelName']]
47 +
48 + if 'Subcategory' in hierarchy:
49 + for subcategory in hierarchy['Subcategory']:
50 + children = load_hierarchy_children(subcategory)
51 +
52 + for c in children:
53 + res.append(c)
54 +
55 + return res
56 +
57 +
58 +def find_hierarchy_parent(hierarchy, parent_cls):
59 + if hierarchy['LabelName'] == parent_cls:
60 + return hierarchy
61 + elif 'Subcategory' in hierarchy:
62 + for child in hierarchy['Subcategory']:
63 + res = find_hierarchy_parent(child, parent_cls)
64 + if res is not None:
65 + return res
66 +
67 + return None
68 +
69 +
70 +def get_labels(metadata_dir, version='v4'):
71 + if version == 'v4' or version == 'challenge2018':
72 + csv_file = 'class-descriptions-boxable.csv' if version == 'v4' else 'challenge-2018-class-descriptions-500.csv'
73 +
74 + boxable_classes_descriptions = os.path.join(metadata_dir, csv_file)
75 + id_to_labels = {}
76 + cls_index = {}
77 +
78 + i = 0
79 + with open(boxable_classes_descriptions) as f:
80 + for row in csv.reader(f):
81 + # make sure the csv row is not empty (usually the last one)
82 + if len(row):
83 + label = row[0]
84 + description = row[1].replace("\"", "").replace("'", "").replace('`', '')
85 +
86 + id_to_labels[i] = description
87 + cls_index[label] = i
88 +
89 + i += 1
90 + else:
91 + trainable_classes_path = os.path.join(metadata_dir, 'classes-bbox-trainable.txt')
92 + description_path = os.path.join(metadata_dir, 'class-descriptions.csv')
93 +
94 + description_table = {}
95 + with open(description_path) as f:
96 + for row in csv.reader(f):
97 + # make sure the csv row is not empty (usually the last one)
98 + if len(row):
99 + description_table[row[0]] = row[1].replace("\"", "").replace("'", "").replace('`', '')
100 +
101 +        with open(trainable_classes_path, 'r') as f:
102 +            trainable_classes = f.read().strip().split('\n')
103 +
104 +        id_to_labels = {i: description_table[c] for i, c in enumerate(trainable_classes)}
105 +        cls_index = {c: i for i, c in enumerate(trainable_classes)}
106 +
107 + return id_to_labels, cls_index
108 +
109 +
110 +def generate_images_annotations_json(main_dir, metadata_dir, subset, cls_index, version='v4'):
111 + validation_image_ids = {}
112 +
113 + if version == 'v4':
114 + annotations_path = os.path.join(metadata_dir, subset, '{}-annotations-bbox.csv'.format(subset))
115 + elif version == 'challenge2018':
116 + validation_image_ids_path = os.path.join(metadata_dir, 'challenge-2018-image-ids-valset-od.csv')
117 +
118 + with open(validation_image_ids_path, 'r') as csv_file:
119 + reader = csv.DictReader(csv_file, fieldnames=['ImageID'])
120 + next(reader)
121 + for line, row in enumerate(reader):
122 + image_id = row['ImageID']
123 + validation_image_ids[image_id] = True
124 +
125 + annotations_path = os.path.join(metadata_dir, 'challenge-2018-train-annotations-bbox.csv')
126 + else:
127 + annotations_path = os.path.join(metadata_dir, subset, 'annotations-human-bbox.csv')
128 +
129 + fieldnames = ['ImageID', 'Source', 'LabelName', 'Confidence',
130 + 'XMin', 'XMax', 'YMin', 'YMax',
131 + 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsDepiction', 'IsInside']
132 +
133 + id_annotations = dict()
134 + with open(annotations_path, 'r') as csv_file:
135 + reader = csv.DictReader(csv_file, fieldnames=fieldnames)
136 + next(reader)
137 +
138 + images_sizes = {}
139 + for line, row in enumerate(reader):
140 + frame = row['ImageID']
141 +
142 + if version == 'challenge2018':
143 + if subset == 'train':
144 + if frame in validation_image_ids:
145 + continue
146 + elif subset == 'validation':
147 + if frame not in validation_image_ids:
148 + continue
149 + else:
150 + raise NotImplementedError('This generator handles only the train and validation subsets')
151 +
152 + class_name = row['LabelName']
153 +
154 + if class_name not in cls_index:
155 + continue
156 +
157 + cls_id = cls_index[class_name]
158 +
159 + if version == 'challenge2018':
160 + # We recommend participants to use the provided subset of the training set as a validation set.
161 + # This is preferable over using the V4 val/test sets, as the training set is more densely annotated.
162 + img_path = os.path.join(main_dir, 'images', 'train', frame + '.jpg')
163 + else:
164 + img_path = os.path.join(main_dir, 'images', subset, frame + '.jpg')
165 +
166 + if frame in images_sizes:
167 + width, height = images_sizes[frame]
168 + else:
169 + try:
170 + with Image.open(img_path) as img:
171 + width, height = img.width, img.height
172 + images_sizes[frame] = (width, height)
173 + except Exception as ex:
174 + if version == 'challenge2018':
175 + raise ex
176 + continue
177 +
178 + x1 = float(row['XMin'])
179 + x2 = float(row['XMax'])
180 + y1 = float(row['YMin'])
181 + y2 = float(row['YMax'])
182 +
183 + x1_int = int(round(x1 * width))
184 + x2_int = int(round(x2 * width))
185 + y1_int = int(round(y1 * height))
186 + y2_int = int(round(y2 * height))
187 +
188 + # Check that the bounding box is valid.
189 + if x2 <= x1:
190 + raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1))
191 + if y2 <= y1:
192 + raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1))
193 +
194 + if y2_int == y1_int:
195 + warnings.warn('filtering line {}: rounding y2 ({}) and y1 ({}) makes them equal'.format(line, y2, y1))
196 + continue
197 +
198 + if x2_int == x1_int:
199 + warnings.warn('filtering line {}: rounding x2 ({}) and x1 ({}) makes them equal'.format(line, x2, x1))
200 + continue
201 +
202 + img_id = row['ImageID']
203 + annotation = {'cls_id': cls_id, 'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2}
204 +
205 + if img_id in id_annotations:
206 + annotations = id_annotations[img_id]
207 + annotations['boxes'].append(annotation)
208 + else:
209 + id_annotations[img_id] = {'w': width, 'h': height, 'boxes': [annotation]}
210 + return id_annotations
211 +
212 +
213 +class OpenImagesGenerator(Generator):
214 + def __init__(
215 + self, main_dir, subset, version='v4',
216 + labels_filter=None, annotation_cache_dir='.',
217 + parent_label=None,
218 + **kwargs
219 + ):
220 + if version == 'challenge2018':
221 + metadata = 'challenge2018'
222 + elif version == 'v4':
223 + metadata = '2018_04'
224 + elif version == 'v3':
225 + metadata = '2017_11'
226 + else:
227 + raise NotImplementedError('There is currently no implementation for versions older than v3')
228 +
229 + if version == 'challenge2018':
230 + self.base_dir = os.path.join(main_dir, 'images', 'train')
231 + else:
232 + self.base_dir = os.path.join(main_dir, 'images', subset)
233 +
234 + metadata_dir = os.path.join(main_dir, metadata)
235 + annotation_cache_json = os.path.join(annotation_cache_dir, subset + '.json')
236 +
237 + self.hierarchy = load_hierarchy(metadata_dir, version=version)
238 + id_to_labels, cls_index = get_labels(metadata_dir, version=version)
239 +
240 + if os.path.exists(annotation_cache_json):
241 + with open(annotation_cache_json, 'r') as f:
242 + self.annotations = json.loads(f.read())
243 + else:
244 + self.annotations = generate_images_annotations_json(main_dir, metadata_dir, subset, cls_index, version=version)
245 + json.dump(self.annotations, open(annotation_cache_json, "w"))
246 +
247 + if labels_filter is not None or parent_label is not None:
248 + self.id_to_labels, self.annotations = self.__filter_data(id_to_labels, cls_index, labels_filter, parent_label)
249 + else:
250 + self.id_to_labels = id_to_labels
251 +
252 + self.id_to_image_id = dict([(i, k) for i, k in enumerate(self.annotations)])
253 +
254 + super(OpenImagesGenerator, self).__init__(**kwargs)
255 +
256 + def __filter_data(self, id_to_labels, cls_index, labels_filter=None, parent_label=None):
257 + """
258 +        Filter the dataset down to a subset of the labels.
259 +        :param labels_filter: List of label names to keep, e.g. ['Helmet', 'Hat', 'Analog television'].
260 +        :param parent_label: If set, keep this label and all of its children in the semantic
261 +                             hierarchy as defined in OID (e.g. 'Animal').
262 +
263 +        :return: A tuple of (filtered id_to_labels, filtered annotations).
264 + """
265 +
266 + children_id_to_labels = {}
267 +
268 + if parent_label is None:
269 +            # no parent label given; keep only the labels listed in labels_filter
270 +
271 + for label in labels_filter:
272 + for i, lb in id_to_labels.items():
273 + if lb == label:
274 + children_id_to_labels[i] = label
275 + break
276 + else:
277 + parent_cls = None
278 + for i, lb in id_to_labels.items():
279 + if lb == parent_label:
280 + parent_id = i
281 + for c, index in cls_index.items():
282 + if index == parent_id:
283 + parent_cls = c
284 + break
285 +
286 + if parent_cls is None:
287 +                raise Exception("Couldn't find label {}".format(parent_label))
288 +
289 + parent_tree = find_hierarchy_parent(self.hierarchy, parent_cls)
290 +
291 + if parent_tree is None:
292 +                raise Exception("Couldn't find parent {} in the semantic hierarchy".format(parent_label))
293 +
294 + children = load_hierarchy_children(parent_tree)
295 +
296 + for cls in children:
297 + index = cls_index[cls]
298 + label = id_to_labels[index]
299 + children_id_to_labels[index] = label
300 +
301 + id_map = dict([(ind, i) for i, ind in enumerate(children_id_to_labels.keys())])
302 +
303 + filtered_annotations = {}
304 + for k in self.annotations:
305 + img_ann = self.annotations[k]
306 +
307 + filtered_boxes = []
308 + for ann in img_ann['boxes']:
309 + cls_id = ann['cls_id']
310 + if cls_id in children_id_to_labels:
311 + ann['cls_id'] = id_map[cls_id]
312 + filtered_boxes.append(ann)
313 +
314 + if len(filtered_boxes) > 0:
315 + filtered_annotations[k] = {'w': img_ann['w'], 'h': img_ann['h'], 'boxes': filtered_boxes}
316 +
317 + children_id_to_labels = dict([(id_map[i], l) for (i, l) in children_id_to_labels.items()])
318 +
319 + return children_id_to_labels, filtered_annotations
320 +
321 + def size(self):
322 + return len(self.annotations)
323 +
324 + def num_classes(self):
325 + return len(self.id_to_labels)
326 +
327 + def has_label(self, label):
328 + """ Return True if label is a known label.
329 + """
330 + return label in self.id_to_labels
331 +
332 + def has_name(self, name):
333 + """ Returns True if name is a known class.
334 + """
335 + raise NotImplementedError()
336 +
337 + def name_to_label(self, name):
338 + raise NotImplementedError()
339 +
340 + def label_to_name(self, label):
341 + return self.id_to_labels[label]
342 +
343 + def image_aspect_ratio(self, image_index):
344 + img_annotations = self.annotations[self.id_to_image_id[image_index]]
345 + height, width = img_annotations['h'], img_annotations['w']
346 + return float(width) / float(height)
347 +
348 + def image_path(self, image_index):
349 + path = os.path.join(self.base_dir, self.id_to_image_id[image_index] + '.jpg')
350 + return path
351 +
352 + def load_image(self, image_index):
353 + return read_image_bgr(self.image_path(image_index))
354 +
355 + def load_annotations(self, image_index):
356 + image_annotations = self.annotations[self.id_to_image_id[image_index]]
357 +
358 + labels = image_annotations['boxes']
359 + height, width = image_annotations['h'], image_annotations['w']
360 +
361 + annotations = {'labels': np.empty((len(labels),)), 'bboxes': np.empty((len(labels), 4))}
362 + for idx, ann in enumerate(labels):
363 + cls_id = ann['cls_id']
364 + x1 = ann['x1'] * width
365 + x2 = ann['x2'] * width
366 + y1 = ann['y1'] * height
367 + y2 = ann['y2'] * height
368 +
369 + annotations['bboxes'][idx, 0] = x1
370 + annotations['bboxes'][idx, 1] = y1
371 + annotations['bboxes'][idx, 2] = x2
372 + annotations['bboxes'][idx, 3] = y2
373 + annotations['labels'][idx] = cls_id
374 +
375 + return annotations
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from ..preprocessing.generator import Generator
18 +from ..utils.image import read_image_bgr
19 +
20 +import os
21 +import numpy as np
22 +from six import raise_from
23 +from PIL import Image
24 +
25 +try:
26 + import xml.etree.cElementTree as ET
27 +except ImportError:
28 + import xml.etree.ElementTree as ET
29 +
30 +voc_classes = {
31 + 'aeroplane' : 0,
32 + 'bicycle' : 1,
33 + 'bird' : 2,
34 + 'boat' : 3,
35 + 'bottle' : 4,
36 + 'bus' : 5,
37 + 'car' : 6,
38 + 'cat' : 7,
39 + 'chair' : 8,
40 + 'cow' : 9,
41 + 'diningtable' : 10,
42 + 'dog' : 11,
43 + 'horse' : 12,
44 + 'motorbike' : 13,
45 + 'person' : 14,
46 + 'pottedplant' : 15,
47 + 'sheep' : 16,
48 + 'sofa' : 17,
49 + 'train' : 18,
50 + 'tvmonitor' : 19
51 +}
52 +
53 +
54 +def _findNode(parent, name, debug_name=None, parse=None):
55 + if debug_name is None:
56 + debug_name = name
57 +
58 + result = parent.find(name)
59 + if result is None:
60 + raise ValueError('missing element \'{}\''.format(debug_name))
61 + if parse is not None:
62 + try:
63 + return parse(result.text)
64 + except ValueError as e:
65 + raise_from(ValueError('illegal value for \'{}\': {}'.format(debug_name, e)), None)
66 + return result
67 +
68 +
69 +class PascalVocGenerator(Generator):
70 + """ Generate data for a Pascal VOC dataset.
71 +
72 + See http://host.robots.ox.ac.uk/pascal/VOC/ for more information.
73 + """
74 +
75 + def __init__(
76 + self,
77 + data_dir,
78 + set_name,
79 + classes=voc_classes,
80 + image_extension='.jpg',
81 + skip_truncated=False,
82 + skip_difficult=False,
83 + **kwargs
84 + ):
85 + """ Initialize a Pascal VOC data generator.
86 +
87 + Args
88 +            data_dir: Root directory of the VOC dataset (containing the 'Annotations', 'ImageSets' and 'JPEGImages' directories).
89 +            set_name: Name of the image set to use, e.g. 'trainval' or 'test'.
90 + """
91 + self.data_dir = data_dir
92 + self.set_name = set_name
93 + self.classes = classes
94 + self.image_names = [line.strip().split(None, 1)[0] for line in open(os.path.join(data_dir, 'ImageSets', 'Main', set_name + '.txt')).readlines()]
95 + self.image_extension = image_extension
96 + self.skip_truncated = skip_truncated
97 + self.skip_difficult = skip_difficult
98 +
99 + self.labels = {}
100 + for key, value in self.classes.items():
101 + self.labels[value] = key
102 +
103 + super(PascalVocGenerator, self).__init__(**kwargs)
104 +
105 + def size(self):
106 + """ Size of the dataset.
107 + """
108 + return len(self.image_names)
109 +
110 + def num_classes(self):
111 + """ Number of classes in the dataset.
112 + """
113 + return len(self.classes)
114 +
115 + def has_label(self, label):
116 + """ Return True if label is a known label.
117 + """
118 + return label in self.labels
119 +
120 + def has_name(self, name):
121 + """ Returns True if name is a known class.
122 + """
123 + return name in self.classes
124 +
125 + def name_to_label(self, name):
126 + """ Map name to label.
127 + """
128 + return self.classes[name]
129 +
130 + def label_to_name(self, label):
131 + """ Map label to name.
132 + """
133 + return self.labels[label]
134 +
135 + def image_aspect_ratio(self, image_index):
136 + """ Compute the aspect ratio for an image with image_index.
137 + """
138 + path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension)
139 + image = Image.open(path)
140 + return float(image.width) / float(image.height)
141 +
142 + def image_path(self, image_index):
143 + """ Get the path to an image.
144 + """
145 + return os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension)
146 +
147 + def load_image(self, image_index):
148 + """ Load an image at the image_index.
149 + """
150 + return read_image_bgr(self.image_path(image_index))
151 +
152 + def __parse_annotation(self, element):
153 + """ Parse an annotation given an XML element.
154 + """
155 + truncated = _findNode(element, 'truncated', parse=int)
156 + difficult = _findNode(element, 'difficult', parse=int)
157 +
158 + class_name = _findNode(element, 'name').text
159 + if class_name not in self.classes:
160 + raise ValueError('class name \'{}\' not found in classes: {}'.format(class_name, list(self.classes.keys())))
161 +
162 + box = np.zeros((4,))
163 + label = self.name_to_label(class_name)
164 +
165 + bndbox = _findNode(element, 'bndbox')
166 + box[0] = _findNode(bndbox, 'xmin', 'bndbox.xmin', parse=float) - 1
167 + box[1] = _findNode(bndbox, 'ymin', 'bndbox.ymin', parse=float) - 1
168 + box[2] = _findNode(bndbox, 'xmax', 'bndbox.xmax', parse=float) - 1
169 + box[3] = _findNode(bndbox, 'ymax', 'bndbox.ymax', parse=float) - 1
170 +
171 + return truncated, difficult, box, label
172 +
173 + def __parse_annotations(self, xml_root):
174 + """ Parse all annotations under the xml_root.
175 + """
176 + annotations = {'labels': np.empty((len(xml_root.findall('object')),)), 'bboxes': np.empty((len(xml_root.findall('object')), 4))}
177 + for i, element in enumerate(xml_root.iter('object')):
178 + try:
179 + truncated, difficult, box, label = self.__parse_annotation(element)
180 + except ValueError as e:
181 + raise_from(ValueError('could not parse object #{}: {}'.format(i, e)), None)
182 +
183 + if truncated and self.skip_truncated:
184 + continue
185 + if difficult and self.skip_difficult:
186 + continue
187 +
188 + annotations['bboxes'][i, :] = box
189 + annotations['labels'][i] = label
190 +
191 + return annotations
192 +
193 + def load_annotations(self, image_index):
194 + """ Load annotations for an image_index.
195 + """
196 + filename = self.image_names[image_index] + '.xml'
197 + try:
198 + tree = ET.parse(os.path.join(self.data_dir, 'Annotations', filename))
199 + return self.__parse_annotations(tree.getroot())
200 + except ET.ParseError as e:
201 + raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None)
202 + except ValueError as e:
203 + raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None)
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import numpy as np
18 +from tensorflow import keras
19 +
20 +from ..utils.compute_overlap import compute_overlap
21 +
22 +
23 +class AnchorParameters:
24 +    """ The parameters that define how anchors are generated.
25 +
26 + Args
27 + sizes : List of sizes to use. Each size corresponds to one feature level.
28 +        strides : List of strides to use. Each stride corresponds to one feature level.
29 + ratios : List of ratios to use per location in a feature map.
30 + scales : List of scales to use per location in a feature map.
31 + """
32 + def __init__(self, sizes, strides, ratios, scales):
33 + self.sizes = sizes
34 + self.strides = strides
35 + self.ratios = ratios
36 + self.scales = scales
37 +
38 + def num_anchors(self):
39 + return len(self.ratios) * len(self.scales)
40 +
41 +
42 +"""
43 +The default anchor parameters.
44 +"""
45 +AnchorParameters.default = AnchorParameters(
46 + sizes = [32, 64, 128, 256, 512],
47 + strides = [8, 16, 32, 64, 128],
48 + ratios = np.array([0.5, 1, 2], keras.backend.floatx()),
49 + scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)], keras.backend.floatx()),
50 +)
51 +
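
As a quick orientation on the defaults just defined: each feature-map location gets `len(ratios) * len(scales) = 9` anchors, and the sizes/strides line up with pyramid levels P3 to P7. A small standalone check of that arithmetic (values copied from the definition above, not imported):

```python
import numpy as np

sizes   = [32, 64, 128, 256, 512]            # one base size per pyramid level P3..P7
strides = [8, 16, 32, 64, 128]               # feature-map stride of P3..P7 w.r.t. the input image
ratios  = np.array([0.5, 1, 2])
scales  = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])

print(len(ratios) * len(scales))             # 9 anchors per feature-map location
print([int(s * scales[-1]) for s in sizes])  # side of the largest square anchor per level: [50, 101, 203, 406, 812]
```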
52 +
53 +def anchor_targets_bbox(
54 + anchors,
55 + image_group,
56 + annotations_group,
57 + num_classes,
58 + negative_overlap=0.4,
59 + positive_overlap=0.5
60 +):
61 + """ Generate anchor targets for bbox detection.
62 +
63 + Args
64 +        anchors: np.array of anchors of shape (N, 4) for (x1, y1, x2, y2).
65 + image_group: List of BGR images.
66 + annotations_group: List of annotation dictionaries with each annotation containing 'labels' and 'bboxes' of an image.
67 + num_classes: Number of classes to predict.
68 + mask_shape: If the image is padded with zeros, mask_shape can be used to mark the relevant part of the image.
69 + negative_overlap: IoU overlap for negative anchors (all anchors with overlap < negative_overlap are negative).
70 +        positive_overlap: IoU overlap for positive anchors (all anchors with overlap > positive_overlap are positive).
71 +
72 + Returns
73 + labels_batch: batch that contains labels & anchor states (np.array of shape (batch_size, N, num_classes + 1),
74 + where N is the number of anchors for an image and the last column defines the anchor state (-1 for ignore, 0 for bg, 1 for fg).
75 + regression_batch: batch that contains bounding-box regression targets for an image & anchor states (np.array of shape (batch_size, N, 4 + 1),
76 + where N is the number of anchors for an image, the first 4 columns define regression targets for (x1, y1, x2, y2) and the
77 + last column defines anchor states (-1 for ignore, 0 for bg, 1 for fg).
78 + """
79 +
80 + assert(len(image_group) == len(annotations_group)), "The length of the images and annotations need to be equal."
81 + assert(len(annotations_group) > 0), "No data received to compute anchor targets for."
82 + for annotations in annotations_group:
83 + assert('bboxes' in annotations), "Annotations should contain bboxes."
84 + assert('labels' in annotations), "Annotations should contain labels."
85 +
86 + batch_size = len(image_group)
87 +
88 + regression_batch = np.zeros((batch_size, anchors.shape[0], 4 + 1), dtype=keras.backend.floatx())
89 + labels_batch = np.zeros((batch_size, anchors.shape[0], num_classes + 1), dtype=keras.backend.floatx())
90 +
91 + # compute labels and regression targets
92 + for index, (image, annotations) in enumerate(zip(image_group, annotations_group)):
93 + if annotations['bboxes'].shape[0]:
94 + # obtain indices of gt annotations with the greatest overlap
95 + positive_indices, ignore_indices, argmax_overlaps_inds = compute_gt_annotations(anchors, annotations['bboxes'], negative_overlap, positive_overlap)
96 +
97 + labels_batch[index, ignore_indices, -1] = -1
98 + labels_batch[index, positive_indices, -1] = 1
99 +
100 + regression_batch[index, ignore_indices, -1] = -1
101 + regression_batch[index, positive_indices, -1] = 1
102 +
103 + # compute target class labels
104 + labels_batch[index, positive_indices, annotations['labels'][argmax_overlaps_inds[positive_indices]].astype(int)] = 1
105 +
106 + regression_batch[index, :, :-1] = bbox_transform(anchors, annotations['bboxes'][argmax_overlaps_inds, :])
107 +
108 + # ignore annotations outside of image
109 + if image.shape:
110 + anchors_centers = np.vstack([(anchors[:, 0] + anchors[:, 2]) / 2, (anchors[:, 1] + anchors[:, 3]) / 2]).T
111 + indices = np.logical_or(anchors_centers[:, 0] >= image.shape[1], anchors_centers[:, 1] >= image.shape[0])
112 +
113 + labels_batch[index, indices, -1] = -1
114 + regression_batch[index, indices, -1] = -1
115 +
116 + return regression_batch, labels_batch
117 +
118 +
119 +def compute_gt_annotations(
120 + anchors,
121 + annotations,
122 + negative_overlap=0.4,
123 + positive_overlap=0.5
124 +):
125 + """ Obtain indices of gt annotations with the greatest overlap.
126 +
127 + Args
128 +        anchors: np.array of anchors of shape (N, 4) for (x1, y1, x2, y2).
129 +        annotations: np.array of ground-truth boxes of shape (K, 4) for (x1, y1, x2, y2).
130 +        negative_overlap: IoU overlap for negative anchors (all anchors with overlap < negative_overlap are negative).
131 +        positive_overlap: IoU overlap for positive anchors (all anchors with overlap > positive_overlap are positive).
132 +
133 +    Returns
134 +        positive_indices: indices of positive anchors
135 +        ignore_indices: indices of ignored anchors
136 +        argmax_overlaps_inds: for each anchor, the index of the ground-truth box with the highest overlap
137 + """
138 +
139 + overlaps = compute_overlap(anchors.astype(np.float64), annotations.astype(np.float64))
140 + argmax_overlaps_inds = np.argmax(overlaps, axis=1)
141 + max_overlaps = overlaps[np.arange(overlaps.shape[0]), argmax_overlaps_inds]
142 +
143 + # assign "dont care" labels
144 + positive_indices = max_overlaps >= positive_overlap
145 + ignore_indices = (max_overlaps > negative_overlap) & ~positive_indices
146 +
147 + return positive_indices, ignore_indices, argmax_overlaps_inds
148 +
149 +
150 +def layer_shapes(image_shape, model):
151 + """Compute layer shapes given input image shape and the model.
152 +
153 + Args
154 + image_shape: The shape of the image.
155 + model: The model to use for computing how the image shape is transformed in the pyramid.
156 +
157 + Returns
158 + A dictionary mapping layer names to image shapes.
159 + """
160 + shape = {
161 + model.layers[0].name: (None,) + image_shape,
162 + }
163 +
164 + for layer in model.layers[1:]:
165 + nodes = layer._inbound_nodes
166 + for node in nodes:
167 + if isinstance(node.inbound_layers, keras.layers.Layer):
168 + inputs = [shape[node.inbound_layers.name]]
169 + else:
170 + inputs = [shape[lr.name] for lr in node.inbound_layers]
171 + if not inputs:
172 + continue
173 + shape[layer.name] = layer.compute_output_shape(inputs[0] if len(inputs) == 1 else inputs)
174 +
175 + return shape
176 +
177 +
178 +def make_shapes_callback(model):
179 + """ Make a function for getting the shape of the pyramid levels.
180 + """
181 + def get_shapes(image_shape, pyramid_levels):
182 + shape = layer_shapes(image_shape, model)
183 + image_shapes = [shape["P{}".format(level)][1:3] for level in pyramid_levels]
184 + return image_shapes
185 +
186 + return get_shapes
187 +
188 +
189 +def guess_shapes(image_shape, pyramid_levels):
190 + """Guess shapes based on pyramid levels.
191 +
192 + Args
193 + image_shape: The shape of the image.
194 + pyramid_levels: A list of what pyramid levels are used.
195 +
196 + Returns
197 + A list of image shapes at each pyramid level.
198 + """
199 + image_shape = np.array(image_shape[:2])
200 + image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in pyramid_levels]
201 + return image_shapes
202 +
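
To make the ceil-division above concrete, here is the arithmetic for a typical 800x1333 input (matching the generator's default `image_min_side`/`image_max_side`) over the default pyramid levels. The snippet is standalone and simply mirrors the expression in `guess_shapes`:

```python
import numpy as np

image_shape    = np.array([800, 1333])
pyramid_levels = [3, 4, 5, 6, 7]

shapes = [tuple((image_shape + 2 ** x - 1) // (2 ** x)) for x in pyramid_levels]
print(shapes)  # [(100, 167), (50, 84), (25, 42), (13, 21), (7, 11)]
```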
203 +
204 +def anchors_for_shape(
205 + image_shape,
206 + pyramid_levels=None,
207 + anchor_params=None,
208 + shapes_callback=None,
209 +):
210 +    """ Generate anchors for a given shape.
211 +
212 + Args
213 + image_shape: The shape of the image.
214 + pyramid_levels: List of ints representing which pyramids to use (defaults to [3, 4, 5, 6, 7]).
215 + anchor_params: Struct containing anchor parameters. If None, default values are used.
216 + shapes_callback: Function to call for getting the shape of the image at different pyramid levels.
217 +
218 + Returns
219 + np.array of shape (N, 4) containing the (x1, y1, x2, y2) coordinates for the anchors.
220 + """
221 +
222 + if pyramid_levels is None:
223 + pyramid_levels = [3, 4, 5, 6, 7]
224 +
225 + if anchor_params is None:
226 + anchor_params = AnchorParameters.default
227 +
228 + if shapes_callback is None:
229 + shapes_callback = guess_shapes
230 + image_shapes = shapes_callback(image_shape, pyramid_levels)
231 +
232 + # compute anchors over all pyramid levels
233 + all_anchors = np.zeros((0, 4))
234 + for idx, p in enumerate(pyramid_levels):
235 + anchors = generate_anchors(
236 + base_size=anchor_params.sizes[idx],
237 + ratios=anchor_params.ratios,
238 + scales=anchor_params.scales
239 + )
240 + shifted_anchors = shift(image_shapes[idx], anchor_params.strides[idx], anchors)
241 + all_anchors = np.append(all_anchors, shifted_anchors, axis=0)
242 +
243 + return all_anchors
244 +
245 +
246 +def shift(shape, stride, anchors):
247 + """ Produce shifted anchors based on shape of the map and stride size.
248 +
249 + Args
250 + shape : Shape to shift the anchors over.
251 + stride : Stride to shift the anchors with over the shape.
252 + anchors: The anchors to apply at each location.
253 + """
254 +
255 + # create a grid starting from half stride from the top left corner
256 + shift_x = (np.arange(0, shape[1]) + 0.5) * stride
257 + shift_y = (np.arange(0, shape[0]) + 0.5) * stride
258 +
259 + shift_x, shift_y = np.meshgrid(shift_x, shift_y)
260 +
261 + shifts = np.vstack((
262 + shift_x.ravel(), shift_y.ravel(),
263 + shift_x.ravel(), shift_y.ravel()
264 + )).transpose()
265 +
266 + # add A anchors (1, A, 4) to
267 + # cell K shifts (K, 1, 4) to get
268 + # shift anchors (K, A, 4)
269 + # reshape to (K*A, 4) shifted anchors
270 + A = anchors.shape[0]
271 + K = shifts.shape[0]
272 + all_anchors = (anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
273 + all_anchors = all_anchors.reshape((K * A, 4))
274 +
275 + return all_anchors
276 +
277 +
278 +def generate_anchors(base_size=16, ratios=None, scales=None):
279 + """
280 + Generate anchor (reference) windows by enumerating aspect ratios X
281 + scales w.r.t. a reference window.
282 + """
283 +
284 + if ratios is None:
285 + ratios = AnchorParameters.default.ratios
286 +
287 + if scales is None:
288 + scales = AnchorParameters.default.scales
289 +
290 + num_anchors = len(ratios) * len(scales)
291 +
292 + # initialize output anchors
293 + anchors = np.zeros((num_anchors, 4))
294 +
295 + # scale base_size
296 + anchors[:, 2:] = base_size * np.tile(scales, (2, len(ratios))).T
297 +
298 + # compute areas of anchors
299 + areas = anchors[:, 2] * anchors[:, 3]
300 +
301 + # correct for ratios
302 + anchors[:, 2] = np.sqrt(areas / np.repeat(ratios, len(scales)))
303 + anchors[:, 3] = anchors[:, 2] * np.repeat(ratios, len(scales))
304 +
305 + # transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2)
306 + anchors[:, 0::2] -= np.tile(anchors[:, 2] * 0.5, (2, 1)).T
307 + anchors[:, 1::2] -= np.tile(anchors[:, 3] * 0.5, (2, 1)).T
308 +
309 + return anchors
310 +
311 +
312 +def bbox_transform(anchors, gt_boxes, mean=None, std=None):
313 + """Compute bounding-box regression targets for an image."""
314 +
315 +    # The mean and std are calculated from the COCO dataset.
316 +    # Bounding-box normalization was first introduced in the Fast R-CNN paper.
317 + # See https://github.com/fizyr/keras-retinanet/issues/1273#issuecomment-585828825 for more details
318 + if mean is None:
319 + mean = np.array([0, 0, 0, 0])
320 + if std is None:
321 + std = np.array([0.2, 0.2, 0.2, 0.2])
322 +
323 + if isinstance(mean, (list, tuple)):
324 + mean = np.array(mean)
325 + elif not isinstance(mean, np.ndarray):
326 + raise ValueError('Expected mean to be a np.ndarray, list or tuple. Received: {}'.format(type(mean)))
327 +
328 + if isinstance(std, (list, tuple)):
329 + std = np.array(std)
330 + elif not isinstance(std, np.ndarray):
331 + raise ValueError('Expected std to be a np.ndarray, list or tuple. Received: {}'.format(type(std)))
332 +
333 + anchor_widths = anchors[:, 2] - anchors[:, 0]
334 + anchor_heights = anchors[:, 3] - anchors[:, 1]
335 +
336 +    # According to a keras-retinanet author, this way of parametrizing the bounding boxes
337 +    # gave marginally better results.
338 + # See https://github.com/fizyr/keras-retinanet/issues/1273#issuecomment-585828825 for more details
339 + targets_dx1 = (gt_boxes[:, 0] - anchors[:, 0]) / anchor_widths
340 + targets_dy1 = (gt_boxes[:, 1] - anchors[:, 1]) / anchor_heights
341 + targets_dx2 = (gt_boxes[:, 2] - anchors[:, 2]) / anchor_widths
342 + targets_dy2 = (gt_boxes[:, 3] - anchors[:, 3]) / anchor_heights
343 +
344 + targets = np.stack((targets_dx1, targets_dy1, targets_dx2, targets_dy2))
345 + targets = targets.T
346 +
347 + targets = (targets - mean) / std
348 +
349 + return targets
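
A small hand-computed example of this parametrization, mirroring `bbox_transform` with its default mean (0) and std (0.2): each target is the offset of a ground-truth corner from the corresponding anchor corner, divided by the anchor width/height and then by 0.2. The numbers below are made up purely for illustration.

```python
import numpy as np

# One anchor and one ground-truth box, both as (x1, y1, x2, y2).
anchor = np.array([[0.0, 0.0, 100.0, 100.0]])
gt_box = np.array([[10.0, 20.0, 90.0, 110.0]])

widths  = anchor[:, 2] - anchor[:, 0]   # 100
heights = anchor[:, 3] - anchor[:, 1]   # 100

targets = np.stack([
    (gt_box[:, 0] - anchor[:, 0]) / widths,    #  0.1
    (gt_box[:, 1] - anchor[:, 1]) / heights,   #  0.2
    (gt_box[:, 2] - anchor[:, 2]) / widths,    # -0.1
    (gt_box[:, 3] - anchor[:, 3]) / heights,   #  0.1
]).T / 0.2

print(targets)  # [[ 0.5  1.  -0.5  0.5]]
```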
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from pycocotools.cocoeval import COCOeval
18 +
19 +from tensorflow import keras
20 +import numpy as np
21 +import json
22 +
23 +import progressbar
24 +assert(callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead."
25 +
26 +
27 +def evaluate_coco(generator, model, threshold=0.05):
28 + """ Use the pycocotools to evaluate a COCO model on a dataset.
29 +
30 + Args
31 + generator : The generator for generating the evaluation data.
32 + model : The model to evaluate.
33 + threshold : The score threshold to use.
34 + """
35 + # start collecting results
36 + results = []
37 + image_ids = []
38 + for index in progressbar.progressbar(range(generator.size()), prefix='COCO evaluation: '):
39 + image = generator.load_image(index)
40 + image = generator.preprocess_image(image)
41 + image, scale = generator.resize_image(image)
42 +
43 + if keras.backend.image_data_format() == 'channels_first':
44 + image = image.transpose((2, 0, 1))
45 +
46 + # run network
47 + boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
48 +
49 + # correct boxes for image scale
50 + boxes /= scale
51 +
52 + # change to (x, y, w, h) (MS COCO standard)
53 + boxes[:, :, 2] -= boxes[:, :, 0]
54 + boxes[:, :, 3] -= boxes[:, :, 1]
55 +
56 + # compute predicted labels and scores
57 + for box, score, label in zip(boxes[0], scores[0], labels[0]):
58 + # scores are sorted, so we can break
59 + if score < threshold:
60 + break
61 +
62 + # append detection for each positively labeled class
63 + image_result = {
64 + 'image_id' : generator.image_ids[index],
65 + 'category_id' : generator.label_to_coco_label(label),
66 + 'score' : float(score),
67 + 'bbox' : box.tolist(),
68 + }
69 +
70 + # append detection to results
71 + results.append(image_result)
72 +
73 + # append image to list of processed images
74 + image_ids.append(generator.image_ids[index])
75 +
76 + if not len(results):
77 + return
78 +
79 + # write output
80 + json.dump(results, open('{}_bbox_results.json'.format(generator.set_name), 'w'), indent=4)
81 + json.dump(image_ids, open('{}_processed_image_ids.json'.format(generator.set_name), 'w'), indent=4)
82 +
83 + # load results in COCO evaluation tool
84 + coco_true = generator.coco
85 + coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(generator.set_name))
86 +
87 + # run COCO evaluation
88 + coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
89 + coco_eval.params.imgIds = image_ids
90 + coco_eval.evaluate()
91 + coco_eval.accumulate()
92 + coco_eval.summarize()
93 + return coco_eval.stats
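
A hedged usage sketch for `evaluate_coco`. It assumes a COCO generator exposing the attributes used above (`coco`, `image_ids`, `set_name`, `label_to_coco_label`) and an inference model that returns `(boxes, scores, labels)`; the import paths, helper names (`load_model`, `convert_model`), and file paths are assumptions about the surrounding package, not part of this patch.

```python
from keras_retinanet.utils.coco_eval import evaluate_coco           # assumed import path
from keras_retinanet.preprocessing.coco import CocoGenerator        # assumed import path
from keras_retinanet.models import load_model, convert_model        # assumed helpers

validation_generator = CocoGenerator('/data/coco', 'val2017')       # hypothetical paths
model = convert_model(load_model('/models/retinanet.h5', backbone_name='resnet50'))

stats = evaluate_coco(validation_generator, model, threshold=0.05)
print(stats)  # the 12 COCO AP/AR numbers; stats[0] is mAP@[.5:.95]
```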
1 +import warnings
2 +
3 +
4 +def label_color(label):
5 + """ Return a color from a set of predefined colors. Contains 80 colors in total.
6 +
7 + Args
8 + label: The label to get the color for.
9 +
10 + Returns
11 +        A list of three values representing an RGB color.
12 +
13 + If no color is defined for a certain label, the color green is returned and a warning is printed.
14 + """
15 + if label < len(colors):
16 + return colors[label]
17 + else:
18 + warnings.warn('Label {} has no color, returning default.'.format(label))
19 + return (0, 255, 0)
20 +
21 +
22 +"""
23 +Generated using:
24 +
25 +```
26 +colors = [list((matplotlib.colors.hsv_to_rgb([x, 1.0, 1.0]) * 255).astype(int)) for x in np.arange(0, 1, 1.0 / 80)]
27 +shuffle(colors)
28 +pprint(colors)
29 +```
30 +"""
31 +colors = [
32 + [31 , 0 , 255] ,
33 + [0 , 159 , 255] ,
34 + [255 , 95 , 0] ,
35 + [255 , 19 , 0] ,
36 + [255 , 0 , 0] ,
37 + [255 , 38 , 0] ,
38 + [0 , 255 , 25] ,
39 + [255 , 0 , 133] ,
40 + [255 , 172 , 0] ,
41 + [108 , 0 , 255] ,
42 + [0 , 82 , 255] ,
43 + [0 , 255 , 6] ,
44 + [255 , 0 , 152] ,
45 + [223 , 0 , 255] ,
46 + [12 , 0 , 255] ,
47 + [0 , 255 , 178] ,
48 + [108 , 255 , 0] ,
49 + [184 , 0 , 255] ,
50 + [255 , 0 , 76] ,
51 + [146 , 255 , 0] ,
52 + [51 , 0 , 255] ,
53 + [0 , 197 , 255] ,
54 + [255 , 248 , 0] ,
55 + [255 , 0 , 19] ,
56 + [255 , 0 , 38] ,
57 + [89 , 255 , 0] ,
58 + [127 , 255 , 0] ,
59 + [255 , 153 , 0] ,
60 + [0 , 255 , 255] ,
61 + [0 , 255 , 216] ,
62 + [0 , 255 , 121] ,
63 + [255 , 0 , 248] ,
64 + [70 , 0 , 255] ,
65 + [0 , 255 , 159] ,
66 + [0 , 216 , 255] ,
67 + [0 , 6 , 255] ,
68 + [0 , 63 , 255] ,
69 + [31 , 255 , 0] ,
70 + [255 , 57 , 0] ,
71 + [255 , 0 , 210] ,
72 + [0 , 255 , 102] ,
73 + [242 , 255 , 0] ,
74 + [255 , 191 , 0] ,
75 + [0 , 255 , 63] ,
76 + [255 , 0 , 95] ,
77 + [146 , 0 , 255] ,
78 + [184 , 255 , 0] ,
79 + [255 , 114 , 0] ,
80 + [0 , 255 , 235] ,
81 + [255 , 229 , 0] ,
82 + [0 , 178 , 255] ,
83 + [255 , 0 , 114] ,
84 + [255 , 0 , 57] ,
85 + [0 , 140 , 255] ,
86 + [0 , 121 , 255] ,
87 + [12 , 255 , 0] ,
88 + [255 , 210 , 0] ,
89 + [0 , 255 , 44] ,
90 + [165 , 255 , 0] ,
91 + [0 , 25 , 255] ,
92 + [0 , 255 , 140] ,
93 + [0 , 101 , 255] ,
94 + [0 , 255 , 82] ,
95 + [223 , 255 , 0] ,
96 + [242 , 0 , 255] ,
97 + [89 , 0 , 255] ,
98 + [165 , 0 , 255] ,
99 + [70 , 255 , 0] ,
100 + [255 , 0 , 172] ,
101 + [255 , 76 , 0] ,
102 + [203 , 255 , 0] ,
103 + [204 , 0 , 255] ,
104 + [255 , 0 , 229] ,
105 + [255 , 133 , 0] ,
106 + [127 , 0 , 255] ,
107 + [0 , 235 , 255] ,
108 + [0 , 255 , 197] ,
109 + [255 , 0 , 191] ,
110 + [0 , 44 , 255] ,
111 + [50 , 255 , 0]
112 +]
This diff could not be displayed because it is too large.
1 +# --------------------------------------------------------
2 +# Fast R-CNN
3 +# Copyright (c) 2015 Microsoft
4 +# Licensed under The MIT License [see LICENSE for details]
5 +# Written by Sergey Karayev
6 +# --------------------------------------------------------
7 +
8 +cimport cython
9 +import numpy as np
10 +cimport numpy as np
11 +
12 +
13 +def compute_overlap(
14 + np.ndarray[double, ndim=2] boxes,
15 + np.ndarray[double, ndim=2] query_boxes
16 +):
17 + """
18 + Args
19 +        boxes: (N, 4) ndarray of float
20 +        query_boxes: (K, 4) ndarray of float
21 +
22 + Returns
23 + overlaps: (N, K) ndarray of overlap between boxes and query_boxes
24 + """
25 + cdef unsigned int N = boxes.shape[0]
26 + cdef unsigned int K = query_boxes.shape[0]
27 + cdef np.ndarray[double, ndim=2] overlaps = np.zeros((N, K), dtype=np.float64)
28 + cdef double iw, ih, box_area
29 + cdef double ua
30 + cdef unsigned int k, n
31 + for k in range(K):
32 + box_area = (
33 + (query_boxes[k, 2] - query_boxes[k, 0]) *
34 + (query_boxes[k, 3] - query_boxes[k, 1])
35 + )
36 + for n in range(N):
37 + iw = (
38 + min(boxes[n, 2], query_boxes[k, 2]) -
39 + max(boxes[n, 0], query_boxes[k, 0])
40 + )
41 + if iw > 0:
42 + ih = (
43 + min(boxes[n, 3], query_boxes[k, 3]) -
44 + max(boxes[n, 1], query_boxes[k, 1])
45 + )
46 + if ih > 0:
47 + ua = np.float64(
48 + (boxes[n, 2] - boxes[n, 0]) *
49 + (boxes[n, 3] - boxes[n, 1]) +
50 + box_area - iw * ih
51 + )
52 + overlaps[n, k] = iw * ih / ua
53 + return overlaps
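A small sanity check for `compute_overlap`, assuming the Cython extension has been compiled (setup.py below registers it as `keras_retinanet.utils.compute_overlap`). Inputs must be float64 arrays of shape (N, 4) and (K, 4).

```python
# Sketch: pairwise IoU between one box and two query boxes.
import numpy as np
from keras_retinanet.utils.compute_overlap import compute_overlap  # built via build_ext

boxes       = np.array([[0., 0., 10., 10.]])                      # (N=1, 4)
query_boxes = np.array([[0., 0., 10., 10.],
                        [5., 5., 15., 15.]])                      # (K=2, 4)

overlaps = compute_overlap(boxes, query_boxes)                    # shape (N, K) = (1, 2)
print(overlaps)  # approx [[1.0, 0.143]]: 25 / (100 + 100 - 25) for the second pair
```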
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import configparser
18 +import numpy as np
19 +from tensorflow import keras
20 +from ..utils.anchors import AnchorParameters
21 +
22 +
23 +def read_config_file(config_path):
24 + config = configparser.ConfigParser()
25 +
26 + with open(config_path, 'r') as file:
27 + config.read_file(file)
28 +
29 + assert 'anchor_parameters' in config, \
30 + "Malformed config file. Verify that it contains the anchor_parameters section."
31 +
32 + config_keys = set(config['anchor_parameters'])
33 + default_keys = set(AnchorParameters.default.__dict__.keys())
34 +
35 + assert config_keys <= default_keys, \
36 + "Malformed config file. These keys are not valid: {}".format(config_keys - default_keys)
37 +
38 + if 'pyramid_levels' in config:
 39 +        assert('levels' in config['pyramid_levels']), "Malformed config file. The pyramid_levels section must contain a 'levels' key."
40 +
41 + return config
42 +
43 +
44 +def parse_anchor_parameters(config):
45 + ratios = np.array(list(map(float, config['anchor_parameters']['ratios'].split(' '))), keras.backend.floatx())
46 + scales = np.array(list(map(float, config['anchor_parameters']['scales'].split(' '))), keras.backend.floatx())
47 + sizes = list(map(int, config['anchor_parameters']['sizes'].split(' ')))
48 + strides = list(map(int, config['anchor_parameters']['strides'].split(' ')))
49 + assert (len(sizes) == len(strides)), "sizes and strides should have an equal number of values"
50 +
51 + return AnchorParameters(sizes, strides, ratios, scales)
52 +
53 +
54 +def parse_pyramid_levels(config):
55 + levels = list(map(int, config['pyramid_levels']['levels'].split(' ')))
56 +
57 + return levels
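A sketch of what a valid anchor configuration looks like and how `parse_anchor_parameters` consumes it. Here the config is built in memory with `configparser`; `read_config_file` does the same from an .ini-style file on disk. The import path and the numeric values are assumptions for illustration only.

```python
import configparser
from keras_retinanet.utils.config import parse_anchor_parameters  # assumed module path

config = configparser.ConfigParser()
config['anchor_parameters'] = {
    'sizes'  : '32 64 128 256 512',   # one size per pyramid level
    'strides': '8 16 32 64 128',      # must have as many values as 'sizes'
    'ratios' : '0.5 1 2',
    'scales' : '1 1.2 1.6',
}

anchor_params = parse_anchor_parameters(config)
print(anchor_params.sizes, anchor_params.strides)
```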
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from .anchors import compute_overlap
18 +from .visualization import draw_detections, draw_annotations
19 +
20 +from tensorflow import keras
21 +import numpy as np
22 +import os
23 +import time
24 +
25 +import cv2
26 +import progressbar
27 +assert(callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead."
28 +
29 +
30 +def _compute_ap(recall, precision):
31 + """ Compute the average precision, given the recall and precision curves.
32 +
33 + Code originally from https://github.com/rbgirshick/py-faster-rcnn.
34 +
35 + # Arguments
36 + recall: The recall curve (list).
37 + precision: The precision curve (list).
38 + # Returns
39 + The average precision as computed in py-faster-rcnn.
40 + """
41 + # correct AP calculation
42 + # first append sentinel values at the end
43 + mrec = np.concatenate(([0.], recall, [1.]))
44 + mpre = np.concatenate(([0.], precision, [0.]))
45 +
46 + # compute the precision envelope
47 + for i in range(mpre.size - 1, 0, -1):
48 + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
49 +
50 + # to calculate area under PR curve, look for points
51 + # where X axis (recall) changes value
52 + i = np.where(mrec[1:] != mrec[:-1])[0]
53 +
54 + # and sum (\Delta recall) * prec
55 + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
56 + return ap
57 +
58 +
59 +def _get_detections(generator, model, score_threshold=0.05, max_detections=100, save_path=None):
60 + """ Get the detections from the model using the generator.
61 +
62 + The result is a list of lists such that the size is:
 63 +        all_detections[num_images][num_classes] = detections[num_detections, 5], with rows (x1, y1, x2, y2, score)
64 +
65 + # Arguments
66 + generator : The generator used to run images through the model.
67 + model : The model to run on the images.
68 + score_threshold : The score confidence threshold to use.
69 + max_detections : The maximum number of detections to use per image.
70 + save_path : The path to save the images with visualized detections to.
71 + # Returns
 72 +        A list of lists containing the detections for each image in the generator, and a list of per-image inference times.
73 + """
 74 +    all_detections = [[None for i in range(generator.num_classes())] for j in range(generator.size())]
75 + all_inferences = [None for i in range(generator.size())]
76 +
77 + for i in progressbar.progressbar(range(generator.size()), prefix='Running network: '):
78 + raw_image = generator.load_image(i)
79 + image, scale = generator.resize_image(raw_image.copy())
80 + image = generator.preprocess_image(image)
81 +
82 + if keras.backend.image_data_format() == 'channels_first':
83 + image = image.transpose((2, 0, 1))
84 +
85 + # run network
86 + start = time.time()
87 + boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))[:3]
88 + inference_time = time.time() - start
89 +
90 + # correct boxes for image scale
91 + boxes /= scale
92 +
93 + # select indices which have a score above the threshold
94 + indices = np.where(scores[0, :] > score_threshold)[0]
95 +
96 + # select those scores
97 + scores = scores[0][indices]
98 +
99 + # find the order with which to sort the scores
100 + scores_sort = np.argsort(-scores)[:max_detections]
101 +
102 + # select detections
103 + image_boxes = boxes[0, indices[scores_sort], :]
104 + image_scores = scores[scores_sort]
105 + image_labels = labels[0, indices[scores_sort]]
106 + image_detections = np.concatenate([image_boxes, np.expand_dims(image_scores, axis=1), np.expand_dims(image_labels, axis=1)], axis=1)
107 +
108 + if save_path is not None:
109 + draw_annotations(raw_image, generator.load_annotations(i), label_to_name=generator.label_to_name)
110 + draw_detections(raw_image, image_boxes, image_scores, image_labels, label_to_name=generator.label_to_name, score_threshold=score_threshold)
111 +
112 + cv2.imwrite(os.path.join(save_path, '{}.png'.format(i)), raw_image)
113 +
114 + # copy detections to all_detections
115 + for label in range(generator.num_classes()):
116 + if not generator.has_label(label):
117 + continue
118 +
119 + all_detections[i][label] = image_detections[image_detections[:, -1] == label, :-1]
120 +
121 + all_inferences[i] = inference_time
122 +
123 + return all_detections, all_inferences
124 +
125 +
126 +def _get_annotations(generator):
127 + """ Get the ground truth annotations from the generator.
128 +
129 + The result is a list of lists such that the size is:
 130 +        all_annotations[num_images][num_classes] = annotations[num_annotations, 4]
131 +
132 + # Arguments
133 + generator : The generator used to retrieve ground truth annotations.
134 + # Returns
135 + A list of lists containing the annotations for each image in the generator.
136 + """
137 + all_annotations = [[None for i in range(generator.num_classes())] for j in range(generator.size())]
138 +
139 + for i in progressbar.progressbar(range(generator.size()), prefix='Parsing annotations: '):
140 + # load the annotations
141 + annotations = generator.load_annotations(i)
142 +
143 + # copy detections to all_annotations
144 + for label in range(generator.num_classes()):
145 + if not generator.has_label(label):
146 + continue
147 +
148 + all_annotations[i][label] = annotations['bboxes'][annotations['labels'] == label, :].copy()
149 +
150 + return all_annotations
151 +
152 +
153 +def evaluate(
154 + generator,
155 + model,
156 + iou_threshold=0.5,
157 + score_threshold=0.05,
158 + max_detections=100,
159 + save_path=None
160 +):
161 + """ Evaluate a given dataset using a given model.
162 +
163 + # Arguments
164 + generator : The generator that represents the dataset to evaluate.
165 + model : The model to evaluate.
 166 +        iou_threshold   : The IoU threshold used to decide whether a detection counts as a true positive.
167 + score_threshold : The score confidence threshold to use for detections.
168 + max_detections : The maximum number of detections to use per image.
169 + save_path : The path to save images with visualized detections to.
170 + # Returns
 171 +        A dict mapping each class label to a tuple (average precision, number of annotations), plus the average inference time per image.
172 + """
173 + # gather all detections and annotations
174 + all_detections, all_inferences = _get_detections(generator, model, score_threshold=score_threshold, max_detections=max_detections, save_path=save_path)
175 + all_annotations = _get_annotations(generator)
176 + average_precisions = {}
177 +
178 + # all_detections = pickle.load(open('all_detections.pkl', 'rb'))
179 + # all_annotations = pickle.load(open('all_annotations.pkl', 'rb'))
180 + # pickle.dump(all_detections, open('all_detections.pkl', 'wb'))
181 + # pickle.dump(all_annotations, open('all_annotations.pkl', 'wb'))
182 +
183 + # process detections and annotations
184 + for label in range(generator.num_classes()):
185 + if not generator.has_label(label):
186 + continue
187 +
188 + false_positives = np.zeros((0,))
189 + true_positives = np.zeros((0,))
190 + scores = np.zeros((0,))
191 + num_annotations = 0.0
192 +
193 + for i in range(generator.size()):
194 + detections = all_detections[i][label]
195 + annotations = all_annotations[i][label]
196 + num_annotations += annotations.shape[0]
197 + detected_annotations = []
198 +
199 + for d in detections:
200 + scores = np.append(scores, d[4])
201 +
202 + if annotations.shape[0] == 0:
203 + false_positives = np.append(false_positives, 1)
204 + true_positives = np.append(true_positives, 0)
205 + continue
206 +
207 + overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)
208 + assigned_annotation = np.argmax(overlaps, axis=1)
209 + max_overlap = overlaps[0, assigned_annotation]
210 +
211 + if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:
212 + false_positives = np.append(false_positives, 0)
213 + true_positives = np.append(true_positives, 1)
214 + detected_annotations.append(assigned_annotation)
215 + else:
216 + false_positives = np.append(false_positives, 1)
217 + true_positives = np.append(true_positives, 0)
218 +
219 + # no annotations -> AP for this class is 0 (is this correct?)
220 + if num_annotations == 0:
221 + average_precisions[label] = 0, 0
222 + continue
223 +
224 + # sort by score
225 + indices = np.argsort(-scores)
226 + false_positives = false_positives[indices]
227 + true_positives = true_positives[indices]
228 +
229 + # compute false positives and true positives
230 + false_positives = np.cumsum(false_positives)
231 + true_positives = np.cumsum(true_positives)
232 +
233 + # compute recall and precision
234 + recall = true_positives / num_annotations
235 + precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
236 +
237 + # compute average precision
238 + average_precision = _compute_ap(recall, precision)
239 + average_precisions[label] = average_precision, num_annotations
240 +
241 + # inference time
242 + inference_time = np.sum(all_inferences) / generator.size()
243 +
244 + return average_precisions, inference_time
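To make the AP interpolation above concrete, here is a worked toy example mirroring the steps of `_compute_ap`: three detections (TP, FP, TP in descending score order) against two ground-truth boxes give recall [0.5, 0.5, 1.0], precision [1.0, 0.5, 2/3], and an interpolated AP of about 0.83.

```python
import numpy as np

recall    = np.array([0.5, 0.5, 1.0])        # cumulative TP / num_annotations
precision = np.array([1.0, 0.5, 2.0 / 3.0])  # cumulative TP / (TP + FP)

# same steps as _compute_ap: pad with sentinels, build the precision envelope, integrate
mrec = np.concatenate(([0.], recall, [1.]))
mpre = np.concatenate(([0.], precision, [0.]))
for i in range(mpre.size - 1, 0, -1):
    mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
i  = np.where(mrec[1:] != mrec[:-1])[0]
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
print(ap)  # 0.8333...
```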
1 +"""
2 +Copyright 2017-2019 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import tensorflow as tf
18 +
19 +
20 +def setup_gpu(gpu_id):
21 + try:
22 + visible_gpu_indices = [int(id) for id in gpu_id.split(',')]
23 + available_gpus = tf.config.list_physical_devices('GPU')
24 + visible_gpus = [gpu for idx, gpu in enumerate(available_gpus) if idx in visible_gpu_indices]
25 +
26 + if visible_gpus:
27 + try:
28 + # Currently, memory growth needs to be the same across GPUs.
29 + for gpu in available_gpus:
30 + tf.config.experimental.set_memory_growth(gpu, True)
31 +
 32 +            # Use only the selected GPUs.
33 + tf.config.set_visible_devices(visible_gpus, 'GPU')
34 + except RuntimeError as e:
35 + # Visible devices must be set before GPUs have been initialized.
36 + print(e)
37 +
38 + logical_gpus = tf.config.list_logical_devices('GPU')
39 + print(len(available_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
40 + else:
41 + tf.config.set_visible_devices([], 'GPU')
42 + except ValueError:
43 + tf.config.set_visible_devices([], 'GPU')
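A usage sketch for `setup_gpu` (assumed to live at `keras_retinanet.utils.gpu`). It should be called once, before any model is built; a non-numeric argument such as 'cpu' falls through to the ValueError branch and hides all GPUs.

```python
from keras_retinanet.utils.gpu import setup_gpu  # assumed module path

setup_gpu('0')       # expose only physical GPU 0, with memory growth enabled
# setup_gpu('0,1')   # expose GPUs 0 and 1
# setup_gpu('cpu')   # hide all GPUs and run on CPU
```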
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from __future__ import division
18 +import numpy as np
19 +import cv2
20 +from PIL import Image
21 +
22 +from .transform import change_transform_origin
23 +
24 +
25 +def read_image_bgr(path):
26 + """ Read an image in BGR format.
27 +
28 + Args
29 + path: Path to the image.
30 + """
31 + # We deliberately don't use cv2.imread here, since it gives no feedback on errors while reading the image.
32 + image = np.ascontiguousarray(Image.open(path).convert('RGB'))
33 + return image[:, :, ::-1]
34 +
35 +
36 +def preprocess_image(x, mode='caffe'):
37 + """ Preprocess an image by subtracting the ImageNet mean.
38 +
39 + Args
40 + x: np.array of shape (None, None, 3) or (3, None, None).
41 + mode: One of "caffe" or "tf".
42 + - caffe: will zero-center each color channel with
43 + respect to the ImageNet dataset, without scaling.
44 + - tf: will scale pixels between -1 and 1, sample-wise.
45 +
46 + Returns
47 + The input with the ImageNet mean subtracted.
48 + """
49 + # mostly identical to "https://github.com/keras-team/keras-applications/blob/master/keras_applications/imagenet_utils.py"
50 + # except for converting RGB -> BGR since we assume BGR already
51 +
 52 +    # always convert to float32 to keep compatibility with opencv
53 + x = x.astype(np.float32)
54 +
55 + if mode == 'tf':
56 + x /= 127.5
57 + x -= 1.
58 + elif mode == 'caffe':
59 + x -= [103.939, 116.779, 123.68]
60 +
61 + return x
62 +
63 +
64 +def adjust_transform_for_image(transform, image, relative_translation):
65 + """ Adjust a transformation for a specific image.
66 +
67 + The translation of the matrix will be scaled with the size of the image.
 68 +    The linear part of the transformation is adjusted so that the origin of the transformation is at the center of the image.
69 + """
70 + height, width, channels = image.shape
71 +
72 + result = transform
73 +
74 + # Scale the translation with the image size if specified.
75 + if relative_translation:
76 + result[0:2, 2] *= [width, height]
77 +
78 + # Move the origin of transformation.
79 + result = change_transform_origin(transform, (0.5 * width, 0.5 * height))
80 +
81 + return result
82 +
83 +
84 +class TransformParameters:
85 + """ Struct holding parameters determining how to apply a transformation to an image.
86 +
87 + Args
88 + fill_mode: One of: 'constant', 'nearest', 'reflect', 'wrap'
89 + interpolation: One of: 'nearest', 'linear', 'cubic', 'area', 'lanczos4'
90 + cval: Fill value to use with fill_mode='constant'
91 + relative_translation: If true (the default), interpret translation as a factor of the image size.
92 + If false, interpret it as absolute pixels.
93 + """
94 + def __init__(
95 + self,
96 + fill_mode = 'nearest',
97 + interpolation = 'linear',
98 + cval = 0,
99 + relative_translation = True,
100 + ):
101 + self.fill_mode = fill_mode
102 + self.cval = cval
103 + self.interpolation = interpolation
104 + self.relative_translation = relative_translation
105 +
106 + def cvBorderMode(self):
107 + if self.fill_mode == 'constant':
108 + return cv2.BORDER_CONSTANT
109 + if self.fill_mode == 'nearest':
110 + return cv2.BORDER_REPLICATE
111 + if self.fill_mode == 'reflect':
112 + return cv2.BORDER_REFLECT_101
113 + if self.fill_mode == 'wrap':
114 + return cv2.BORDER_WRAP
115 +
116 + def cvInterpolation(self):
117 + if self.interpolation == 'nearest':
118 + return cv2.INTER_NEAREST
119 + if self.interpolation == 'linear':
120 + return cv2.INTER_LINEAR
121 + if self.interpolation == 'cubic':
122 + return cv2.INTER_CUBIC
123 + if self.interpolation == 'area':
124 + return cv2.INTER_AREA
125 + if self.interpolation == 'lanczos4':
126 + return cv2.INTER_LANCZOS4
127 +
128 +
129 +def apply_transform(matrix, image, params):
130 + """
131 + Apply a transformation to an image.
132 +
133 + The origin of transformation is at the top left corner of the image.
134 +
135 + The matrix is interpreted such that a point (x, y) on the original image is moved to transform * (x, y) in the generated image.
 136 +    Mathematically speaking, that means that the matrix maps points from the original image space to the transformed image space.
137 +
138 + Args
 139 +        matrix: A homogeneous 3 by 3 matrix representing the transformation to apply.
140 + image: The image to transform.
141 + params: The transform parameters (see TransformParameters)
142 + """
143 + output = cv2.warpAffine(
144 + image,
145 + matrix[:2, :],
146 + dsize = (image.shape[1], image.shape[0]),
147 + flags = params.cvInterpolation(),
148 + borderMode = params.cvBorderMode(),
149 + borderValue = params.cval,
150 + )
151 + return output
152 +
153 +
154 +def compute_resize_scale(image_shape, min_side=800, max_side=1333):
155 + """ Compute an image scale such that the image size is constrained to min_side and max_side.
156 +
157 + Args
158 + min_side: The image's min side will be equal to min_side after resizing.
159 + max_side: If after resizing the image's max side is above max_side, resize until the max side is equal to max_side.
160 +
161 + Returns
162 + A resizing scale.
163 + """
164 + (rows, cols, _) = image_shape
165 +
166 + smallest_side = min(rows, cols)
167 +
168 + # rescale the image so the smallest side is min_side
169 + scale = min_side / smallest_side
170 +
171 + # check if the largest side is now greater than max_side, which can happen
172 + # when images have a large aspect ratio
173 + largest_side = max(rows, cols)
174 + if largest_side * scale > max_side:
175 + scale = max_side / largest_side
176 +
177 + return scale
178 +
179 +
180 +def resize_image(img, min_side=800, max_side=1333):
181 + """ Resize an image such that the size is constrained to min_side and max_side.
182 +
183 + Args
184 + min_side: The image's min side will be equal to min_side after resizing.
185 + max_side: If after resizing the image's max side is above max_side, resize until the max side is equal to max_side.
186 +
187 + Returns
188 + A resized image.
189 + """
190 + # compute scale to resize the image
191 + scale = compute_resize_scale(img.shape, min_side=min_side, max_side=max_side)
192 +
193 + # resize the image with the computed scale
194 + img = cv2.resize(img, None, fx=scale, fy=scale)
195 +
196 + return img, scale
197 +
198 +
199 +def _uniform(val_range):
200 + """ Uniformly sample from the given range.
201 +
202 + Args
203 + val_range: A pair of lower and upper bound.
204 + """
205 + return np.random.uniform(val_range[0], val_range[1])
206 +
207 +
208 +def _check_range(val_range, min_val=None, max_val=None):
209 + """ Check whether the range is a valid range.
210 +
211 + Args
212 + val_range: A pair of lower and upper bound.
213 + min_val: Minimal value for the lower bound.
214 + max_val: Maximal value for the upper bound.
215 + """
216 + if val_range[0] > val_range[1]:
217 + raise ValueError('interval lower bound > upper bound')
218 + if min_val is not None and val_range[0] < min_val:
219 + raise ValueError('invalid interval lower bound')
220 + if max_val is not None and val_range[1] > max_val:
221 + raise ValueError('invalid interval upper bound')
222 +
223 +
224 +def _clip(image):
225 + """
226 + Clip and convert an image to np.uint8.
227 +
228 + Args
229 + image: Image to clip.
230 + """
231 + return np.clip(image, 0, 255).astype(np.uint8)
232 +
233 +
234 +class VisualEffect:
235 + """ Struct holding parameters and applying image color transformation.
236 +
237 + Args
238 + contrast_factor: A factor for adjusting contrast. Should be between 0 and 3.
239 + brightness_delta: Brightness offset between -1 and 1 added to the pixel values.
240 + hue_delta: Hue offset between -1 and 1 added to the hue channel.
241 + saturation_factor: A factor multiplying the saturation values of each pixel.
242 + """
243 +
244 + def __init__(
245 + self,
246 + contrast_factor,
247 + brightness_delta,
248 + hue_delta,
249 + saturation_factor,
250 + ):
251 + self.contrast_factor = contrast_factor
252 + self.brightness_delta = brightness_delta
253 + self.hue_delta = hue_delta
254 + self.saturation_factor = saturation_factor
255 +
256 + def __call__(self, image):
257 + """ Apply a visual effect on the image.
258 +
259 + Args
260 + image: Image to adjust
261 + """
262 +
263 + if self.contrast_factor:
264 + image = adjust_contrast(image, self.contrast_factor)
265 + if self.brightness_delta:
266 + image = adjust_brightness(image, self.brightness_delta)
267 +
268 + if self.hue_delta or self.saturation_factor:
269 +
270 + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
271 +
272 + if self.hue_delta:
273 + image = adjust_hue(image, self.hue_delta)
274 + if self.saturation_factor:
275 + image = adjust_saturation(image, self.saturation_factor)
276 +
277 + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
278 +
279 + return image
280 +
281 +
282 +def random_visual_effect_generator(
283 + contrast_range=(0.9, 1.1),
284 + brightness_range=(-.1, .1),
285 + hue_range=(-0.05, 0.05),
286 + saturation_range=(0.95, 1.05)
287 +):
288 + """ Generate visual effect parameters uniformly sampled from the given intervals.
289 +
290 + Args
 291 +        contrast_range: A factor interval for adjusting contrast. Should be between 0 and 3.
 292 +        brightness_range: An interval between -1 and 1 for the amount added to the pixels.
 293 +        hue_range: An interval between -1 and 1 for the amount added to the hue channel.
 294 +            The values are rotated if they exceed 180.
 295 +        saturation_range: An interval for the factor multiplying the saturation values of each
 296 +            pixel.
297 + """
298 + _check_range(contrast_range, 0)
299 + _check_range(brightness_range, -1, 1)
300 + _check_range(hue_range, -1, 1)
301 + _check_range(saturation_range, 0)
302 +
303 + def _generate():
304 + while True:
305 + yield VisualEffect(
306 + contrast_factor=_uniform(contrast_range),
307 + brightness_delta=_uniform(brightness_range),
308 + hue_delta=_uniform(hue_range),
309 + saturation_factor=_uniform(saturation_range),
310 + )
311 +
312 + return _generate()
313 +
314 +
315 +def adjust_contrast(image, factor):
316 + """ Adjust contrast of an image.
317 +
318 + Args
319 + image: Image to adjust.
320 + factor: A factor for adjusting contrast.
321 + """
322 + mean = image.mean(axis=0).mean(axis=0)
323 + return _clip((image - mean) * factor + mean)
324 +
325 +
326 +def adjust_brightness(image, delta):
327 + """ Adjust brightness of an image
328 +
329 + Args
330 + image: Image to adjust.
331 + delta: Brightness offset between -1 and 1 added to the pixel values.
332 + """
333 + return _clip(image + delta * 255)
334 +
335 +
336 +def adjust_hue(image, delta):
337 + """ Adjust hue of an image.
338 +
339 + Args
340 + image: Image to adjust.
 341 +        delta: The amount, between -1 and 1, added to the hue channel.
342 + The values are rotated if they exceed 180.
343 + """
344 + image[..., 0] = np.mod(image[..., 0] + delta * 180, 180)
345 + return image
346 +
347 +
348 +def adjust_saturation(image, factor):
349 + """ Adjust saturation of an image.
350 +
351 + Args
352 + image: Image to adjust.
 353 +        factor: The factor multiplying the saturation values of each pixel.
354 + """
355 + image[..., 1] = np.clip(image[..., 1] * factor, 0 , 255)
356 + return image
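A sketch of the inference-time preprocessing chain built from the helpers above (module path assumed; 'image.jpg' is a placeholder path).

```python
import numpy as np
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image  # assumed path

image        = read_image_bgr('image.jpg')   # placeholder path; H x W x 3, BGR, uint8
image        = preprocess_image(image)       # float32, ImageNet mean subtracted ('caffe' mode)
image, scale = resize_image(image)           # smallest side -> 800, largest side capped at 1333

batch = np.expand_dims(image, axis=0)        # add batch dimension for model.predict_on_batch
# predicted boxes should be divided by `scale` afterwards to map them back to the
# original image resolution (as done in eval.py above).
```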
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +
18 +def freeze(model):
19 + """ Set all layers in a model to non-trainable.
20 +
21 + The weights for these layers will not be updated during training.
22 +
23 + This function modifies the given model in-place,
24 + but it also returns the modified model to allow easy chaining with other functions.
25 + """
26 + for layer in model.layers:
27 + layer.trainable = False
28 + return model
1 +"""
2 +Copyright 2017-2019 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from __future__ import print_function
18 +
19 +import tensorflow as tf
20 +import sys
21 +
22 +MINIMUM_TF_VERSION = 2, 3, 0
23 +BLACKLISTED_TF_VERSIONS = []
24 +
25 +
26 +def tf_version():
27 + """ Get the Tensorflow version.
28 + Returns
29 + tuple of (major, minor, patch).
30 + """
31 + return tuple(map(int, tf.version.VERSION.split('-')[0].split('.')))
32 +
33 +
34 +def tf_version_ok(minimum_tf_version=MINIMUM_TF_VERSION, blacklisted=BLACKLISTED_TF_VERSIONS):
35 + """ Check if the current Tensorflow version is higher than the minimum version.
36 + """
37 + return tf_version() >= minimum_tf_version and tf_version() not in blacklisted
38 +
39 +
40 +def assert_tf_version(minimum_tf_version=MINIMUM_TF_VERSION, blacklisted=BLACKLISTED_TF_VERSIONS):
41 + """ Assert that the Tensorflow version is up to date.
42 + """
43 + detected = tf.version.VERSION
44 + required = '.'.join(map(str, minimum_tf_version))
45 + assert(tf_version_ok(minimum_tf_version, blacklisted)), 'You are using tensorflow version {}. The minimum required version is {} (blacklisted: {}).'.format(detected, required, blacklisted)
46 +
47 +
48 +def check_tf_version():
49 + """ Check that the Tensorflow version is up to date. If it isn't, print an error message and exit the script.
50 + """
51 + try:
52 + assert_tf_version()
53 + except AssertionError as e:
54 + print(e, file=sys.stderr)
55 + sys.exit(1)
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import numpy as np
18 +
19 +DEFAULT_PRNG = np.random
20 +
21 +
22 +def colvec(*args):
23 + """ Create a numpy array representing a column vector. """
24 + return np.array([args]).T
25 +
26 +
27 +def transform_aabb(transform, aabb):
28 + """ Apply a transformation to an axis aligned bounding box.
29 +
30 + The result is a new AABB in the same coordinate system as the original AABB.
31 + The new AABB contains all corner points of the original AABB after applying the given transformation.
32 +
33 + Args
34 + transform: The transformation to apply.
 35 +        aabb: The axis aligned bounding box as a tuple (x1, y1, x2, y2), where:
 36 +            x1: The minimum x value of the AABB.
 37 +            y1: The minimum y value of the AABB.
 38 +            x2, y2: The maximum x and y values of the AABB.
39 + Returns
40 + The new AABB as tuple (x1, y1, x2, y2)
41 + """
42 + x1, y1, x2, y2 = aabb
43 + # Transform all 4 corners of the AABB.
44 + points = transform.dot([
45 + [x1, x2, x1, x2],
46 + [y1, y2, y2, y1],
47 + [1, 1, 1, 1 ],
48 + ])
49 +
50 + # Extract the min and max corners again.
51 + min_corner = points.min(axis=1)
52 + max_corner = points.max(axis=1)
53 +
54 + return [min_corner[0], min_corner[1], max_corner[0], max_corner[1]]
55 +
56 +
57 +def _random_vector(min, max, prng=DEFAULT_PRNG):
58 + """ Construct a random vector between min and max.
59 + Args
60 + min: the minimum value for each component
61 + max: the maximum value for each component
62 + """
63 + min = np.array(min)
64 + max = np.array(max)
65 + assert min.shape == max.shape
66 + assert len(min.shape) == 1
67 + return prng.uniform(min, max)
68 +
69 +
70 +def rotation(angle):
71 + """ Construct a homogeneous 2D rotation matrix.
72 + Args
73 + angle: the angle in radians
74 + Returns
75 + the rotation matrix as 3 by 3 numpy array
76 + """
77 + return np.array([
78 + [np.cos(angle), -np.sin(angle), 0],
79 + [np.sin(angle), np.cos(angle), 0],
80 + [0, 0, 1]
81 + ])
82 +
83 +
84 +def random_rotation(min, max, prng=DEFAULT_PRNG):
 85 +    """ Construct a random rotation with an angle between min and max.
 86 +    Args
 87 +        min: a scalar for the minimum rotation angle in radians
 88 +        max: a scalar for the maximum rotation angle in radians
89 + prng: the pseudo-random number generator to use.
90 + Returns
91 + a homogeneous 3 by 3 rotation matrix
92 + """
93 + return rotation(prng.uniform(min, max))
94 +
95 +
96 +def translation(translation):
97 + """ Construct a homogeneous 2D translation matrix.
98 + # Arguments
99 + translation: the translation 2D vector
100 + # Returns
101 + the translation matrix as 3 by 3 numpy array
102 + """
103 + return np.array([
104 + [1, 0, translation[0]],
105 + [0, 1, translation[1]],
106 + [0, 0, 1]
107 + ])
108 +
109 +
110 +def random_translation(min, max, prng=DEFAULT_PRNG):
111 + """ Construct a random 2D translation between min and max.
112 + Args
113 + min: a 2D vector with the minimum translation for each dimension
114 + max: a 2D vector with the maximum translation for each dimension
115 + prng: the pseudo-random number generator to use.
116 + Returns
117 + a homogeneous 3 by 3 translation matrix
118 + """
119 + return translation(_random_vector(min, max, prng))
120 +
121 +
122 +def shear(angle):
123 + """ Construct a homogeneous 2D shear matrix.
124 + Args
125 + angle: the shear angle in radians
126 + Returns
127 + the shear matrix as 3 by 3 numpy array
128 + """
129 + return np.array([
130 + [1, -np.sin(angle), 0],
131 + [0, np.cos(angle), 0],
132 + [0, 0, 1]
133 + ])
134 +
135 +
136 +def random_shear(min, max, prng=DEFAULT_PRNG):
 137 +    """ Construct a random 2D shear matrix with shear angle between min and max.
138 + Args
139 + min: the minimum shear angle in radians.
140 + max: the maximum shear angle in radians.
141 + prng: the pseudo-random number generator to use.
142 + Returns
143 + a homogeneous 3 by 3 shear matrix
144 + """
145 + return shear(prng.uniform(min, max))
146 +
147 +
148 +def scaling(factor):
149 + """ Construct a homogeneous 2D scaling matrix.
150 + Args
151 + factor: a 2D vector for X and Y scaling
152 + Returns
153 + the zoom matrix as 3 by 3 numpy array
154 + """
155 + return np.array([
156 + [factor[0], 0, 0],
157 + [0, factor[1], 0],
158 + [0, 0, 1]
159 + ])
160 +
161 +
162 +def random_scaling(min, max, prng=DEFAULT_PRNG):
 163 +    """ Construct a random 2D scaling matrix with factors between min and max.
164 + Args
165 + min: a 2D vector containing the minimum scaling factor for X and Y.
 166 +        max: a 2D vector containing the maximum scaling factor for X and Y.
167 + prng: the pseudo-random number generator to use.
168 + Returns
169 + a homogeneous 3 by 3 scaling matrix
170 + """
171 + return scaling(_random_vector(min, max, prng))
172 +
173 +
174 +def random_flip(flip_x_chance, flip_y_chance, prng=DEFAULT_PRNG):
175 + """ Construct a transformation randomly containing X/Y flips (or not).
176 + Args
177 + flip_x_chance: The chance that the result will contain a flip along the X axis.
178 + flip_y_chance: The chance that the result will contain a flip along the Y axis.
179 + prng: The pseudo-random number generator to use.
180 + Returns
181 + a homogeneous 3 by 3 transformation matrix
182 + """
183 + flip_x = prng.uniform(0, 1) < flip_x_chance
184 + flip_y = prng.uniform(0, 1) < flip_y_chance
185 + # 1 - 2 * bool gives 1 for False and -1 for True.
186 + return scaling((1 - 2 * flip_x, 1 - 2 * flip_y))
187 +
188 +
189 +def change_transform_origin(transform, center):
190 + """ Create a new transform representing the same transformation,
191 + only with the origin of the linear part changed.
192 + Args
193 + transform: the transformation matrix
194 + center: the new origin of the transformation
195 + Returns
196 + translate(center) * transform * translate(-center)
197 + """
198 + center = np.array(center)
199 + return np.linalg.multi_dot([translation(center), transform, translation(-center)])
200 +
201 +
202 +def random_transform(
203 + min_rotation=0,
204 + max_rotation=0,
205 + min_translation=(0, 0),
206 + max_translation=(0, 0),
207 + min_shear=0,
208 + max_shear=0,
209 + min_scaling=(1, 1),
210 + max_scaling=(1, 1),
211 + flip_x_chance=0,
212 + flip_y_chance=0,
213 + prng=DEFAULT_PRNG
214 +):
215 + """ Create a random transformation.
216 +
217 + The transformation consists of the following operations in this order (from left to right):
218 + * rotation
219 + * translation
220 + * shear
221 + * scaling
222 + * flip x (if applied)
223 + * flip y (if applied)
224 +
225 + Note that by default, the data generators in `keras_retinanet.preprocessing.generators` interpret the translation
 226 +    as a factor of the image size. So an X translation of 0.1 would translate the image by 10% of its width.
227 + Set `relative_translation` to `False` in the `TransformParameters` of a data generator to have it interpret
228 + the translation directly as pixel distances instead.
229 +
230 + Args
231 + min_rotation: The minimum rotation in radians for the transform as scalar.
232 + max_rotation: The maximum rotation in radians for the transform as scalar.
233 + min_translation: The minimum translation for the transform as 2D column vector.
234 + max_translation: The maximum translation for the transform as 2D column vector.
235 + min_shear: The minimum shear angle for the transform in radians.
236 + max_shear: The maximum shear angle for the transform in radians.
237 + min_scaling: The minimum scaling for the transform as 2D column vector.
238 + max_scaling: The maximum scaling for the transform as 2D column vector.
239 + flip_x_chance: The chance (0 to 1) that a transform will contain a flip along X direction.
240 + flip_y_chance: The chance (0 to 1) that a transform will contain a flip along Y direction.
241 + prng: The pseudo-random number generator to use.
242 + """
243 + return np.linalg.multi_dot([
244 + random_rotation(min_rotation, max_rotation, prng),
245 + random_translation(min_translation, max_translation, prng),
246 + random_shear(min_shear, max_shear, prng),
247 + random_scaling(min_scaling, max_scaling, prng),
248 + random_flip(flip_x_chance, flip_y_chance, prng)
249 + ])
250 +
251 +
252 +def random_transform_generator(prng=None, **kwargs):
253 + """ Create a random transform generator.
254 +
255 + Uses a dedicated, newly created, properly seeded PRNG by default instead of the global DEFAULT_PRNG.
256 +
257 + The transformation consists of the following operations in this order (from left to right):
258 + * rotation
259 + * translation
260 + * shear
261 + * scaling
262 + * flip x (if applied)
263 + * flip y (if applied)
264 +
265 + Note that by default, the data generators in `keras_retinanet.preprocessing.generators` interpret the translation
 266 +    as a factor of the image size. So an X translation of 0.1 would translate the image by 10% of its width.
267 + Set `relative_translation` to `False` in the `TransformParameters` of a data generator to have it interpret
268 + the translation directly as pixel distances instead.
269 +
270 + Args
271 + min_rotation: The minimum rotation in radians for the transform as scalar.
272 + max_rotation: The maximum rotation in radians for the transform as scalar.
273 + min_translation: The minimum translation for the transform as 2D column vector.
274 + max_translation: The maximum translation for the transform as 2D column vector.
275 + min_shear: The minimum shear angle for the transform in radians.
276 + max_shear: The maximum shear angle for the transform in radians.
277 + min_scaling: The minimum scaling for the transform as 2D column vector.
278 + max_scaling: The maximum scaling for the transform as 2D column vector.
279 + flip_x_chance: The chance (0 to 1) that a transform will contain a flip along X direction.
280 + flip_y_chance: The chance (0 to 1) that a transform will contain a flip along Y direction.
281 + prng: The pseudo-random number generator to use.
282 + """
283 +
284 + if prng is None:
285 + # RandomState automatically seeds using the best available method.
286 + prng = np.random.RandomState()
287 +
288 + while True:
289 + yield random_transform(prng=prng, **kwargs)
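A sketch combining `random_transform_generator` above with the helpers from the image module (`adjust_transform_for_image`, `TransformParameters`, `apply_transform`); module paths are assumed and the blank image is a stand-in for real data.

```python
import numpy as np
from keras_retinanet.utils.transform import random_transform_generator                          # assumed paths
from keras_retinanet.utils.image import TransformParameters, adjust_transform_for_image, apply_transform

transform_generator = random_transform_generator(
    min_rotation=-0.1, max_rotation=0.1,
    min_translation=(-0.1, -0.1), max_translation=(0.1, 0.1),
    flip_x_chance=0.5,
)

image  = np.zeros((480, 640, 3), dtype=np.uint8)   # stand-in image
params = TransformParameters()                     # nearest fill, linear interpolation by default

transform = adjust_transform_for_image(next(transform_generator), image, params.relative_translation)
augmented = apply_transform(transform, image, params)
# annotation boxes would be updated with transform_aabb(transform, box) for each box.
```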
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import cv2
18 +import numpy as np
19 +
20 +from .colors import label_color
21 +
22 +
23 +def draw_box(image, box, color, thickness=2):
24 + """ Draws a box on an image with a given color.
25 +
26 + # Arguments
27 + image : The image to draw on.
28 + box : A list of 4 elements (x1, y1, x2, y2).
29 + color : The color of the box.
30 + thickness : The thickness of the lines to draw a box with.
31 + """
32 + b = np.array(box).astype(int)
33 + cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), color, thickness, cv2.LINE_AA)
34 +
35 +
36 +def draw_caption(image, box, caption):
37 + """ Draws a caption above the box in an image.
38 +
39 + # Arguments
40 + image : The image to draw on.
41 + box : A list of 4 elements (x1, y1, x2, y2).
42 + caption : String containing the text to draw.
43 + """
44 + b = np.array(box).astype(int)
45 + cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), 2)
46 + cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1)
47 +
48 +
49 +def draw_boxes(image, boxes, color, thickness=2):
50 + """ Draws boxes on an image with a given color.
51 +
52 + # Arguments
53 + image : The image to draw on.
54 + boxes : A [N, 4] matrix (x1, y1, x2, y2).
55 + color : The color of the boxes.
56 + thickness : The thickness of the lines to draw boxes with.
57 + """
58 + for b in boxes:
59 + draw_box(image, b, color, thickness=thickness)
60 +
61 +
62 +def draw_detections(image, boxes, scores, labels, color=None, label_to_name=None, score_threshold=0.5):
63 + """ Draws detections in an image.
64 +
65 + # Arguments
66 + image : The image to draw on.
67 + boxes : A [N, 4] matrix (x1, y1, x2, y2).
68 + scores : A list of N classification scores.
69 + labels : A list of N labels.
70 + color : The color of the boxes. By default the color from keras_retinanet.utils.colors.label_color will be used.
71 + label_to_name : (optional) Functor for mapping a label to a name.
72 + score_threshold : Threshold used for determining what detections to draw.
73 + """
74 + selection = np.where(scores > score_threshold)[0]
75 +
76 + for i in selection:
77 + c = color if color is not None else label_color(labels[i])
78 + draw_box(image, boxes[i, :], color=c)
79 +
80 + # draw labels
81 + caption = (label_to_name(labels[i]) if label_to_name else labels[i]) + ': {0:.2f}'.format(scores[i])
82 + draw_caption(image, boxes[i, :], caption)
83 +
84 +
85 +def draw_annotations(image, annotations, color=(0, 255, 0), label_to_name=None):
86 + """ Draws annotations in an image.
87 +
88 + # Arguments
89 + image : The image to draw on.
90 + annotations : A [N, 5] matrix (x1, y1, x2, y2, label) or dictionary containing bboxes (shaped [N, 4]) and labels (shaped [N]).
 91 +        color           : The color of the boxes; defaults to green. Pass None to use the per-label color from keras_retinanet.utils.colors.label_color.
92 + label_to_name : (optional) Functor for mapping a label to a name.
93 + """
94 + if isinstance(annotations, np.ndarray):
95 + annotations = {'bboxes': annotations[:, :4], 'labels': annotations[:, 4]}
96 +
97 + assert('bboxes' in annotations)
98 + assert('labels' in annotations)
99 + assert(annotations['bboxes'].shape[0] == annotations['labels'].shape[0])
100 +
101 + for i in range(annotations['bboxes'].shape[0]):
102 + label = annotations['labels'][i]
103 + c = color if color is not None else label_color(label)
104 + caption = '{}'.format(label_to_name(label) if label_to_name else label)
105 + draw_caption(image, annotations['bboxes'][i], caption)
106 + draw_box(image, annotations['bboxes'][i], color=c)
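A minimal sketch of `draw_detections` on a blank image (module path assumed). Only the first box is drawn, because the second score falls below the default 0.5 threshold; passing `label_to_name` keeps the caption a string.

```python
import numpy as np
from keras_retinanet.utils.visualization import draw_detections  # assumed module path

image  = np.zeros((200, 200, 3), dtype=np.uint8)
boxes  = np.array([[10, 10, 80, 80], [50, 50, 150, 150]])
scores = np.array([0.9, 0.3])   # second detection is below the 0.5 default score_threshold
labels = np.array([0, 1])

draw_detections(image, boxes, scores, labels, label_to_name=lambda l: 'class_{}'.format(l))
# `image` now contains one box with the caption "class_0: 0.90".
```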
1 +cython
2 +keras-resnet==0.2.0
3 +git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI
4 +h5py
5 +keras
6 +matplotlib
7 +numpy>=1.14
8 +opencv-python>=3.3.0
9 +pillow
10 +progressbar2
11 +tensorflow>=2.3.0
1 +# ignore:
2 +# E201 whitespace after '['
3 +# E202 whitespace before ']'
4 +# E203 whitespace before ':'
5 +# E221 multiple spaces before operator
6 +# E241 multiple spaces after ','
7 +# E251 unexpected spaces around keyword / parameter equals
8 +# E501 line too long (85 > 79 characters)
9 +# W504 line break after binary operator
10 +[tool:pytest]
11 +flake8-max-line-length = 100
12 +flake8-ignore = E201 E202 E203 E221 E241 E251 E402 E501 W504
1 +import setuptools
2 +from setuptools.extension import Extension
3 +from distutils.command.build_ext import build_ext as DistUtilsBuildExt
4 +
5 +
6 +class BuildExtension(setuptools.Command):
7 + description = DistUtilsBuildExt.description
8 + user_options = DistUtilsBuildExt.user_options
9 + boolean_options = DistUtilsBuildExt.boolean_options
10 + help_options = DistUtilsBuildExt.help_options
11 +
12 + def __init__(self, *args, **kwargs):
13 + from setuptools.command.build_ext import build_ext as SetupToolsBuildExt
14 +
 15 +        # Bypass __setattr__ to avoid infinite recursion.
16 + self.__dict__['_command'] = SetupToolsBuildExt(*args, **kwargs)
17 +
18 + def __getattr__(self, name):
19 + return getattr(self._command, name)
20 +
21 + def __setattr__(self, name, value):
22 + setattr(self._command, name, value)
23 +
24 + def initialize_options(self, *args, **kwargs):
25 + return self._command.initialize_options(*args, **kwargs)
26 +
27 + def finalize_options(self, *args, **kwargs):
28 + ret = self._command.finalize_options(*args, **kwargs)
29 + import numpy
30 + self.include_dirs.append(numpy.get_include())
31 + return ret
32 +
33 + def run(self, *args, **kwargs):
34 + return self._command.run(*args, **kwargs)
35 +
36 +
37 +extensions = [
38 + Extension(
39 + 'keras_retinanet.utils.compute_overlap',
40 + ['keras_retinanet/utils/compute_overlap.pyx']
41 + ),
42 +]
43 +
44 +
45 +setuptools.setup(
46 + name = 'keras-retinanet',
47 + version = '1.0.0',
48 + description = 'Keras implementation of RetinaNet object detection.',
49 + url = 'https://github.com/fizyr/keras-retinanet',
50 + author = 'Hans Gaiser',
51 + author_email = 'h.gaiser@fizyr.com',
52 + maintainer = 'Hans Gaiser',
53 + maintainer_email = 'h.gaiser@fizyr.com',
54 + cmdclass = {'build_ext': BuildExtension},
55 + packages = setuptools.find_packages(),
56 + install_requires = ['keras-resnet==0.2.0', 'six', 'numpy', 'cython', 'Pillow', 'opencv-python', 'progressbar2'],
57 + entry_points = {
58 + 'console_scripts': [
59 + 'retinanet-train=keras_retinanet.bin.train:main',
60 + 'retinanet-evaluate=keras_retinanet.bin.evaluate:main',
61 + 'retinanet-debug=keras_retinanet.bin.debug:main',
62 + 'retinanet-convert-model=keras_retinanet.bin.convert_model:main',
63 + ],
64 + },
65 + ext_modules = extensions,
66 + setup_requires = ["cython>=0.28", "numpy>=1.14.0"]
67 +)
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import numpy as np
18 +from tensorflow import keras
19 +import keras_retinanet.backend
20 +
21 +
22 +def test_bbox_transform_inv():
23 + boxes = np.array([[
24 + [100, 100, 200, 200],
25 + [100, 100, 300, 300],
26 + [100, 100, 200, 300],
27 + [100, 100, 300, 200],
28 + [80, 120, 200, 200],
29 + [80, 120, 300, 300],
30 + [80, 120, 200, 300],
31 + [80, 120, 300, 200],
32 + ]])
33 + boxes = keras.backend.variable(boxes)
34 +
35 + deltas = np.array([[
36 + [0 , 0 , 0 , 0 ],
37 + [0 , 0.1, 0 , 0 ],
38 + [-0.3, 0 , 0 , 0 ],
39 + [0.2 , 0.2, 0 , 0 ],
40 + [0 , 0 , 0.1 , 0 ],
41 + [0 , 0 , 0 , -0.3],
42 + [0 , 0 , 0.2 , 0.2 ],
43 + [0.1 , 0.2, -0.3, 0.4 ],
44 + ]])
45 + deltas = keras.backend.variable(deltas)
46 +
47 + expected = np.array([[
48 + [100 , 100 , 200 , 200 ],
49 + [100 , 104 , 300 , 300 ],
50 + [ 94 , 100 , 200 , 300 ],
51 + [108 , 104 , 300 , 200 ],
52 + [ 80 , 120 , 202.4 , 200 ],
53 + [ 80 , 120 , 300 , 289.2],
54 + [ 80 , 120 , 204.8 , 307.2],
55 + [ 84.4, 123.2, 286.8 , 206.4]
56 + ]])
57 +
58 + result = keras_retinanet.backend.bbox_transform_inv(boxes, deltas)
59 + result = keras.backend.eval(result)
60 +
61 + np.testing.assert_array_almost_equal(result, expected, decimal=2)
62 +
63 +
64 +def test_shift():
65 + shape = (2, 3)
66 + stride = 8
67 +
68 + anchors = np.array([
69 + [-8, -8, 8, 8],
70 + [-16, -16, 16, 16],
71 + [-12, -12, 12, 12],
72 + [-12, -16, 12, 16],
73 + [-16, -12, 16, 12]
74 + ], dtype=keras.backend.floatx())
75 +
76 + expected = [
77 + # anchors for (0, 0)
78 + [4 - 8, 4 - 8, 4 + 8, 4 + 8],
79 + [4 - 16, 4 - 16, 4 + 16, 4 + 16],
80 + [4 - 12, 4 - 12, 4 + 12, 4 + 12],
81 + [4 - 12, 4 - 16, 4 + 12, 4 + 16],
82 + [4 - 16, 4 - 12, 4 + 16, 4 + 12],
83 +
84 + # anchors for (0, 1)
85 + [12 - 8, 4 - 8, 12 + 8, 4 + 8],
86 + [12 - 16, 4 - 16, 12 + 16, 4 + 16],
87 + [12 - 12, 4 - 12, 12 + 12, 4 + 12],
88 + [12 - 12, 4 - 16, 12 + 12, 4 + 16],
89 + [12 - 16, 4 - 12, 12 + 16, 4 + 12],
90 +
91 + # anchors for (0, 2)
92 + [20 - 8, 4 - 8, 20 + 8, 4 + 8],
93 + [20 - 16, 4 - 16, 20 + 16, 4 + 16],
94 + [20 - 12, 4 - 12, 20 + 12, 4 + 12],
95 + [20 - 12, 4 - 16, 20 + 12, 4 + 16],
96 + [20 - 16, 4 - 12, 20 + 16, 4 + 12],
97 +
98 + # anchors for (1, 0)
99 + [4 - 8, 12 - 8, 4 + 8, 12 + 8],
100 + [4 - 16, 12 - 16, 4 + 16, 12 + 16],
101 + [4 - 12, 12 - 12, 4 + 12, 12 + 12],
102 + [4 - 12, 12 - 16, 4 + 12, 12 + 16],
103 + [4 - 16, 12 - 12, 4 + 16, 12 + 12],
104 +
105 + # anchors for (1, 1)
106 + [12 - 8, 12 - 8, 12 + 8, 12 + 8],
107 + [12 - 16, 12 - 16, 12 + 16, 12 + 16],
108 + [12 - 12, 12 - 12, 12 + 12, 12 + 12],
109 + [12 - 12, 12 - 16, 12 + 12, 12 + 16],
110 + [12 - 16, 12 - 12, 12 + 16, 12 + 12],
111 +
112 + # anchors for (1, 2)
113 + [20 - 8, 12 - 8, 20 + 8, 12 + 8],
114 + [20 - 16, 12 - 16, 20 + 16, 12 + 16],
115 + [20 - 12, 12 - 12, 20 + 12, 12 + 12],
116 + [20 - 12, 12 - 16, 20 + 12, 12 + 16],
117 + [20 - 16, 12 - 12, 20 + 16, 12 + 12],
118 + ]
119 +
120 + result = keras_retinanet.backend.shift(shape, stride, anchors)
121 + result = keras.backend.eval(result)
122 +
123 + np.testing.assert_array_equal(result, expected)
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import keras_retinanet.backend
18 +import keras_retinanet.bin.train
19 +from tensorflow import keras
20 +
21 +import warnings
22 +
23 +import pytest
24 +
25 +
26 +@pytest.fixture(autouse=True)
27 +def clear_session():
28 + # run before test (do nothing)
29 + yield
30 + # run after test, clear keras session
31 + keras.backend.clear_session()
32 +
33 +
34 +def test_coco():
35 + # ignore warnings in this test
36 + warnings.simplefilter('ignore')
37 +
38 + # run training / evaluation
39 + keras_retinanet.bin.train.main([
40 + '--epochs=1',
41 + '--steps=1',
42 + '--no-weights',
43 + '--no-snapshots',
44 + 'coco',
45 + 'tests/test-data/coco',
46 + ])
47 +
48 +
49 +def test_pascal():
50 + # ignore warnings in this test
51 + warnings.simplefilter('ignore')
52 +
53 + # run training / evaluation
54 + keras_retinanet.bin.train.main([
55 + '--epochs=1',
56 + '--steps=1',
57 + '--no-weights',
58 + '--no-snapshots',
59 + 'pascal',
60 + 'tests/test-data/pascal',
61 + ])
62 +
63 +
64 +def test_csv():
65 + # ignore warnings in this test
66 + warnings.simplefilter('ignore')
67 +
68 + # run training / evaluation
69 + keras_retinanet.bin.train.main([
70 + '--epochs=1',
71 + '--steps=1',
72 + '--no-weights',
73 + '--no-snapshots',
74 + 'csv',
75 + 'tests/test-data/csv/annotations.csv',
76 + 'tests/test-data/csv/classes.csv',
77 + ])
78 +
79 +
80 +def test_vgg():
81 + # ignore warnings in this test
82 + warnings.simplefilter('ignore')
83 +
84 + # run training / evaluation
85 + keras_retinanet.bin.train.main([
86 + '--backbone=vgg16',
87 + '--epochs=1',
88 + '--steps=1',
89 + '--no-weights',
90 + '--no-snapshots',
91 + '--freeze-backbone',
92 + 'coco',
93 + 'tests/test-data/coco',
94 + ])
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from tensorflow import keras
18 +import keras_retinanet.backend
19 +import keras_retinanet.layers
20 +
21 +import numpy as np
22 +
23 +
24 +class TestFilterDetections(object):
25 + def test_simple(self):
26 + # create simple FilterDetections layer
27 + filter_detections_layer = keras_retinanet.layers.FilterDetections()
28 +
29 + # create simple input
30 + boxes = np.array([[
31 + [0, 0, 10, 10],
32 + [0, 0, 10, 10], # this will be suppressed
33 + ]], dtype=keras.backend.floatx())
34 + boxes = keras.backend.constant(boxes)
35 +
36 + classification = np.array([[
37 + [0, 0.9], # this will be suppressed
38 + [0, 1],
39 + ]], dtype=keras.backend.floatx())
40 + classification = keras.backend.constant(classification)
41 +
42 + # compute output
43 + actual_boxes, actual_scores, actual_labels = filter_detections_layer.call([boxes, classification])
44 + actual_boxes = keras.backend.eval(actual_boxes)
45 + actual_scores = keras.backend.eval(actual_scores)
46 + actual_labels = keras.backend.eval(actual_labels)
47 +
48 + # define expected output
49 + expected_boxes = -1 * np.ones((1, 300, 4), dtype=keras.backend.floatx())
50 + expected_boxes[0, 0, :] = [0, 0, 10, 10]
51 +
52 + expected_scores = -1 * np.ones((1, 300), dtype=keras.backend.floatx())
53 + expected_scores[0, 0] = 1
54 +
55 + expected_labels = -1 * np.ones((1, 300), dtype=keras.backend.floatx())
56 + expected_labels[0, 0] = 1
57 +
58 + # assert actual and expected are equal
59 + np.testing.assert_array_equal(actual_boxes, expected_boxes)
60 + np.testing.assert_array_equal(actual_scores, expected_scores)
61 + np.testing.assert_array_equal(actual_labels, expected_labels)
62 +
63 + def test_simple_with_other(self):
64 + # create simple FilterDetections layer
65 + filter_detections_layer = keras_retinanet.layers.FilterDetections()
66 +
67 + # create simple input
68 + boxes = np.array([[
69 + [0, 0, 10, 10],
70 + [0, 0, 10, 10], # this will be suppressed
71 + ]], dtype=keras.backend.floatx())
72 + boxes = keras.backend.constant(boxes)
73 +
74 + classification = np.array([[
75 + [0, 0.9], # this will be suppressed
76 + [0, 1],
77 + ]], dtype=keras.backend.floatx())
78 + classification = keras.backend.constant(classification)
79 +
80 + other = []
81 + other.append(np.array([[
82 + [0, 1234], # this will be suppressed
83 + [0, 5678],
84 + ]], dtype=keras.backend.floatx()))
85 + other.append(np.array([[
86 + 5678, # this will be suppressed
87 + 1234,
88 + ]], dtype=keras.backend.floatx()))
89 + other = [keras.backend.constant(o) for o in other]
90 +
91 + # compute output
92 + actual = filter_detections_layer.call([boxes, classification] + other)
93 + actual_boxes = keras.backend.eval(actual[0])
94 + actual_scores = keras.backend.eval(actual[1])
95 + actual_labels = keras.backend.eval(actual[2])
96 + actual_other = [keras.backend.eval(a) for a in actual[3:]]
97 +
98 + # define expected output
99 + expected_boxes = -1 * np.ones((1, 300, 4), dtype=keras.backend.floatx())
100 + expected_boxes[0, 0, :] = [0, 0, 10, 10]
101 +
102 + expected_scores = -1 * np.ones((1, 300), dtype=keras.backend.floatx())
103 + expected_scores[0, 0] = 1
104 +
105 + expected_labels = -1 * np.ones((1, 300), dtype=keras.backend.floatx())
106 + expected_labels[0, 0] = 1
107 +
108 + expected_other = []
109 + expected_other.append(-1 * np.ones((1, 300, 2), dtype=keras.backend.floatx()))
110 + expected_other[-1][0, 0, :] = [0, 5678]
111 + expected_other.append(-1 * np.ones((1, 300), dtype=keras.backend.floatx()))
112 + expected_other[-1][0, 0] = 1234
113 +
114 + # assert actual and expected are equal
115 + np.testing.assert_array_equal(actual_boxes, expected_boxes)
116 + np.testing.assert_array_equal(actual_scores, expected_scores)
117 + np.testing.assert_array_equal(actual_labels, expected_labels)
118 +
119 + for a, e in zip(actual_other, expected_other):
120 + np.testing.assert_array_equal(a, e)
121 +
122 + def test_mini_batch(self):
123 + # create simple FilterDetections layer
124 + filter_detections_layer = keras_retinanet.layers.FilterDetections()
125 +
126 + # create input with batch_size=2
127 + boxes = np.array([
128 + [
129 + [0, 0, 10, 10], # this will be suppressed
130 + [0, 0, 10, 10],
131 + ],
132 + [
133 + [100, 100, 150, 150],
134 + [100, 100, 150, 150], # this will be suppressed
135 + ],
136 + ], dtype=keras.backend.floatx())
137 + boxes = keras.backend.constant(boxes)
138 +
139 + classification = np.array([
140 + [
141 + [0, 0.9], # this will be suppressed
142 + [0, 1],
143 + ],
144 + [
145 + [1, 0],
146 + [0.9, 0], # this will be suppressed
147 + ],
148 + ], dtype=keras.backend.floatx())
149 + classification = keras.backend.constant(classification)
150 +
151 + # compute output
152 + actual_boxes, actual_scores, actual_labels = filter_detections_layer.call([boxes, classification])
153 + actual_boxes = keras.backend.eval(actual_boxes)
154 + actual_scores = keras.backend.eval(actual_scores)
155 + actual_labels = keras.backend.eval(actual_labels)
156 +
157 + # define expected output
158 + expected_boxes = -1 * np.ones((2, 300, 4), dtype=keras.backend.floatx())
159 + expected_boxes[0, 0, :] = [0, 0, 10, 10]
160 + expected_boxes[1, 0, :] = [100, 100, 150, 150]
161 +
162 + expected_scores = -1 * np.ones((2, 300), dtype=keras.backend.floatx())
163 + expected_scores[0, 0] = 1
164 + expected_scores[1, 0] = 1
165 +
166 + expected_labels = -1 * np.ones((2, 300), dtype=keras.backend.floatx())
167 + expected_labels[0, 0] = 1
168 + expected_labels[1, 0] = 0
169 +
170 + # assert actual and expected are equal
171 + np.testing.assert_array_equal(actual_boxes, expected_boxes)
172 + np.testing.assert_array_equal(actual_scores, expected_scores)
173 + np.testing.assert_array_equal(actual_labels, expected_labels)
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from tensorflow import keras
18 +import keras_retinanet.backend
19 +import keras_retinanet.layers
20 +
21 +import numpy as np
22 +
23 +
24 +class TestAnchors(object):
25 + def test_simple(self):
26 + # create simple Anchors layer
27 + anchors_layer = keras_retinanet.layers.Anchors(
28 + size=32,
29 + stride=8,
30 + ratios=np.array([1], keras.backend.floatx()),
31 + scales=np.array([1], keras.backend.floatx()),
32 + )
33 +
34 + # create fake features input (only shape is used anyway)
35 + features = np.zeros((1, 2, 2, 1024), dtype=keras.backend.floatx())
36 + features = keras.backend.variable(features)
37 +
38 + # call the Anchors layer
39 + anchors = anchors_layer.call(features)
40 + anchors = keras.backend.eval(anchors)
41 +
42 + # expected anchor values
43 + expected = np.array([[
44 + [-12, -12, 20, 20],
45 + [-4 , -12, 28, 20],
46 + [-12, -4 , 20, 28],
47 + [-4 , -4 , 28, 28],
48 + ]], dtype=keras.backend.floatx())
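 + # (with size=32, ratio=1 and scale=1 each anchor is a 32x32 box centred on
 + #  the shifted feature-map locations x, y in {stride/2, 3 * stride/2} = {4, 12})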
49 +
50 + # test anchor values
51 + np.testing.assert_array_equal(anchors, expected)
52 +
53 + # anchors should be tiled identically across every image in the mini-batch
54 + def test_mini_batch(self):
55 + # create simple Anchors layer
56 + anchors_layer = keras_retinanet.layers.Anchors(
57 + size=32,
58 + stride=8,
59 + ratios=np.array([1], dtype=keras.backend.floatx()),
60 + scales=np.array([1], dtype=keras.backend.floatx()),
61 + )
62 +
63 + # create fake features input with batch_size=2
64 + features = np.zeros((2, 2, 2, 1024), dtype=keras.backend.floatx())
65 + features = keras.backend.variable(features)
66 +
67 + # call the Anchors layer
68 + anchors = anchors_layer.call(features)
69 + anchors = keras.backend.eval(anchors)
70 +
71 + # expected anchor values
72 + expected = np.array([[
73 + [-12, -12, 20, 20],
74 + [-4 , -12, 28, 20],
75 + [-12, -4 , 20, 28],
76 + [-4 , -4 , 28, 28],
77 + ]], dtype=keras.backend.floatx())
78 + expected = np.tile(expected, (2, 1, 1))
79 +
80 + # test anchor values
81 + np.testing.assert_array_equal(anchors, expected)
82 +
83 +
84 +class TestUpsampleLike(object):
85 + def test_simple(self):
86 + # create simple UpsampleLike layer
87 + upsample_like_layer = keras_retinanet.layers.UpsampleLike()
88 +
89 + # create input source
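 + # (UpsampleLike resizes the source to the target's spatial size, so an
 + #  all-zero source upsampled to 5x5 should reproduce the all-zero target)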
90 + source = np.zeros((1, 2, 2, 1), dtype=keras.backend.floatx())
91 + source = keras.backend.variable(source)
92 + target = np.zeros((1, 5, 5, 1), dtype=keras.backend.floatx())
93 + expected = target
94 + target = keras.backend.variable(target)
95 +
96 + # compute output
97 + actual = upsample_like_layer.call([source, target])
98 + actual = keras.backend.eval(actual)
99 +
100 + np.testing.assert_array_equal(actual, expected)
101 +
102 + def test_mini_batch(self):
103 + # create simple UpsampleLike layer
104 + upsample_like_layer = keras_retinanet.layers.UpsampleLike()
105 +
106 + # create input source
107 + source = np.zeros((2, 2, 2, 1), dtype=keras.backend.floatx())
108 + source = keras.backend.variable(source)
109 +
110 + target = np.zeros((2, 5, 5, 1), dtype=keras.backend.floatx())
111 + expected = target
112 + target = keras.backend.variable(target)
113 +
114 + # compute output
115 + actual = upsample_like_layer.call([source, target])
116 + actual = keras.backend.eval(actual)
117 +
118 + np.testing.assert_array_equal(actual, expected)
119 +
120 +
121 +class TestRegressBoxes(object):
122 + def test_simple(self):
123 + mean = [0, 0, 0, 0]
124 + std = [0.2, 0.2, 0.2, 0.2]
125 +
126 + # create simple RegressBoxes layer
127 + regress_boxes_layer = keras_retinanet.layers.RegressBoxes(mean=mean, std=std)
128 +
129 + # create input
130 + anchors = np.array([[
131 + [0 , 0 , 10 , 10 ],
132 + [50, 50, 100, 100],
133 + [20, 20, 40 , 40 ],
134 + ]], dtype=keras.backend.floatx())
135 + anchors = keras.backend.variable(anchors)
136 +
137 + regression = np.array([[
138 + [0 , 0 , 0 , 0 ],
139 + [0.1, 0.1, 0 , 0 ],
140 + [0 , 0 , 0.1, 0.1],
141 + ]], dtype=keras.backend.floatx())
142 + regression = keras.backend.variable(regression)
143 +
144 + # compute output
145 + actual = regress_boxes_layer.call([anchors, regression])
146 + actual = keras.backend.eval(actual)
147 +
148 + # compute expected output
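 + # (with mean 0, each corner moves by delta * std * anchor width/height,
 + #  e.g. 0.1 * 0.2 * 50 = 1 and 0.1 * 0.2 * 20 = 0.4)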
149 + expected = np.array([[
150 + [0 , 0 , 10 , 10 ],
151 + [51, 51, 100 , 100 ],
152 + [20, 20, 40.4, 40.4],
153 + ]], dtype=keras.backend.floatx())
154 +
155 + np.testing.assert_array_almost_equal(actual, expected, decimal=2)
156 +
157 + # regression should be applied independently to each image in the mini-batch
158 + def test_mini_batch(self):
159 + mean = [0, 0, 0, 0]
160 + std = [0.2, 0.2, 0.2, 0.2]
161 +
162 + # create simple RegressBoxes layer
163 + regress_boxes_layer = keras_retinanet.layers.RegressBoxes(mean=mean, std=std)
164 +
165 + # create input
166 + anchors = np.array([
167 + [
168 + [0 , 0 , 10 , 10 ], # 1
169 + [50, 50, 100, 100], # 2
170 + [20, 20, 40 , 40 ], # 3
171 + ],
172 + [
173 + [20, 20, 40 , 40 ], # 3
174 + [0 , 0 , 10 , 10 ], # 1
175 + [50, 50, 100, 100], # 2
176 + ],
177 + ], dtype=keras.backend.floatx())
178 + anchors = keras.backend.variable(anchors)
179 +
180 + regression = np.array([
181 + [
182 + [0 , 0 , 0 , 0 ], # 1
183 + [0.1, 0.1, 0 , 0 ], # 2
184 + [0 , 0 , 0.1, 0.1], # 3
185 + ],
186 + [
187 + [0 , 0 , 0.1, 0.1], # 3
188 + [0 , 0 , 0 , 0 ], # 1
189 + [0.1, 0.1, 0 , 0 ], # 2
190 + ],
191 + ], dtype=keras.backend.floatx())
192 + regression = keras.backend.variable(regression)
193 +
194 + # compute output
195 + actual = regress_boxes_layer.call([anchors, regression])
196 + actual = keras.backend.eval(actual)
197 +
198 + # compute expected output
199 + expected = np.array([
200 + [
201 + [0 , 0 , 10 , 10 ], # 1
202 + [51, 51, 100 , 100 ], # 2
203 + [20, 20, 40.4, 40.4], # 3
204 + ],
205 + [
206 + [20, 20, 40.4, 40.4], # 3
207 + [0 , 0 , 10 , 10 ], # 1
208 + [51, 51, 100 , 100 ], # 2
209 + ],
210 + ], dtype=keras.backend.floatx())
211 +
212 + np.testing.assert_array_almost_equal(actual, expected, decimal=2)
1 +"""
2 +Copyright 2018 vidosits (https://github.com/vidosits/)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import warnings
18 +import pytest
19 +import numpy as np
20 +from tensorflow import keras
21 +from keras_retinanet import losses
22 +from keras_retinanet.models.densenet import DenseNetBackbone
23 +
24 +parameters = ['densenet121']
25 +
26 +
27 +@pytest.mark.parametrize("backbone", parameters)
28 +def test_backbone(backbone):
29 + # ignore warnings in this test
30 + warnings.simplefilter('ignore')
31 +
32 + num_classes = 10
33 +
34 + inputs = np.zeros((1, 200, 400, 3), dtype=np.float32)
35 + targets = [np.zeros((1, 14814, 5), dtype=np.float32), np.zeros((1, 14814, num_classes + 1))]
36 +
37 + inp = keras.layers.Input(inputs[0].shape)
38 +
39 + densenet_backbone = DenseNetBackbone(backbone)
40 + model = densenet_backbone.retinanet(num_classes=num_classes, inputs=inp)
41 + model.summary()
42 +
43 + # compile model
44 + model.compile(
45 + loss={
46 + 'regression': losses.smooth_l1(),
47 + 'classification': losses.focal()
48 + },
49 + optimizer=keras.optimizers.Adam(lr=1e-5, clipnorm=0.001))
50 +
51 + model.fit(inputs, targets, batch_size=1)
1 +"""
2 +Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import warnings
18 +import pytest
19 +import numpy as np
20 +from tensorflow import keras
21 +from keras_retinanet import losses
22 +from keras_retinanet.models.mobilenet import MobileNetBackbone
23 +
24 +
25 +alphas = ['1.0']
26 +parameters = []
27 +
28 +for backbone in MobileNetBackbone.allowed_backbones:
29 + for alpha in alphas:
30 + parameters.append((backbone, alpha))
31 +
32 +
33 +@pytest.mark.parametrize("backbone, alpha", parameters)
34 +def test_backbone(backbone, alpha):
35 + # ignore warnings in this test
36 + warnings.simplefilter('ignore')
37 +
38 + num_classes = 10
39 +
40 + inputs = np.zeros((1, 1024, 363, 3), dtype=np.float32)
41 + targets = [np.zeros((1, 68760, 5), dtype=np.float32), np.zeros((1, 68760, num_classes + 1))]
42 +
43 + inp = keras.layers.Input(inputs[0].shape)
44 +
45 + mobilenet_backbone = MobileNetBackbone(backbone='{}_{}'.format(backbone, alpha))
46 + training_model = mobilenet_backbone.retinanet(num_classes=num_classes, inputs=inp)
47 + training_model.summary()
48 +
49 + # compile model
50 + training_model.compile(
51 + loss={
52 + 'regression': losses.smooth_l1(),
53 + 'classification': losses.focal()
54 + },
55 + optimizer=keras.optimizers.Adam(lr=1e-5, clipnorm=0.001))
56 +
57 + training_model.fit(inputs, targets, batch_size=1)
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +import csv
18 +import pytest
19 +try:
20 + from io import StringIO
21 +except ImportError:
22 + from StringIO import StringIO
23 +
24 +from keras_retinanet.preprocessing import csv_generator
25 +
26 +
27 +def csv_str(string):
28 + if str == bytes:
29 + string = string.decode('utf-8')
30 + return csv.reader(StringIO(string))
31 +
32 +
33 +def annotation(x1, y1, x2, y2, class_name):
34 + return {'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2, 'class': class_name}
35 +
36 +
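 + # the tests below exercise the two CSV formats used by the generator:
 + #   classes file:     class_name,id                     (e.g. 'a,1')
 + #   annotations file: img_file,x1,y1,x2,y2,class_name   (all fields after the
 + #                     file name may be empty for an image without annotations)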
37 +def test_read_classes():
38 + assert csv_generator._read_classes(csv_str('')) == {}
39 + assert csv_generator._read_classes(csv_str('a,1')) == {'a': 1}
40 + assert csv_generator._read_classes(csv_str('a,1\nb,2')) == {'a': 1, 'b': 2}
41 +
42 +
43 +def test_read_classes_wrong_format():
44 + with pytest.raises(ValueError):
45 + try:
46 + csv_generator._read_classes(csv_str('a,b,c'))
47 + except ValueError as e:
48 + assert str(e).startswith('line 1: format should be')
49 + raise
50 + with pytest.raises(ValueError):
51 + try:
52 + csv_generator._read_classes(csv_str('a,1\nb,c,d'))
53 + except ValueError as e:
54 + assert str(e).startswith('line 2: format should be')
55 + raise
56 +
57 +
58 +def test_read_classes_malformed_class_id():
59 + with pytest.raises(ValueError):
60 + try:
61 + csv_generator._read_classes(csv_str('a,b'))
62 + except ValueError as e:
63 + assert str(e).startswith("line 1: malformed class ID:")
64 + raise
65 +
66 + with pytest.raises(ValueError):
67 + try:
68 + csv_generator._read_classes(csv_str('a,1\nb,c'))
69 + except ValueError as e:
70 + assert str(e).startswith('line 2: malformed class ID:')
71 + raise
72 +
73 +
74 +def test_read_classes_duplicate_name():
75 + with pytest.raises(ValueError):
76 + try:
77 + csv_generator._read_classes(csv_str('a,1\nb,2\na,3'))
78 + except ValueError as e:
79 + assert str(e).startswith('line 3: duplicate class name')
80 + raise
81 +
82 +
83 +def test_read_annotations():
84 + classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10}
85 + annotations = csv_generator._read_annotations(csv_str(
86 + 'a.png,0,1,2,3,a' '\n'
87 + 'b.png,4,5,6,7,b' '\n'
88 + 'c.png,8,9,10,11,c' '\n'
89 + 'd.png,12,13,14,15,d' '\n'
90 + ), classes)
91 + assert annotations == {
92 + 'a.png': [annotation( 0, 1, 2, 3, 'a')],
93 + 'b.png': [annotation( 4, 5, 6, 7, 'b')],
94 + 'c.png': [annotation( 8, 9, 10, 11, 'c')],
95 + 'd.png': [annotation(12, 13, 14, 15, 'd')],
96 + }
97 +
98 +
99 +def test_read_annotations_multiple():
100 + classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10}
101 + annotations = csv_generator._read_annotations(csv_str(
102 + 'a.png,0,1,2,3,a' '\n'
103 + 'b.png,4,5,6,7,b' '\n'
104 + 'a.png,8,9,10,11,c' '\n'
105 + ), classes)
106 + assert annotations == {
107 + 'a.png': [
108 + annotation(0, 1, 2, 3, 'a'),
109 + annotation(8, 9, 10, 11, 'c'),
110 + ],
111 + 'b.png': [annotation(4, 5, 6, 7, 'b')],
112 + }
113 +
114 +
115 +def test_read_annotations_wrong_format():
116 + classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10}
117 + with pytest.raises(ValueError):
118 + try:
119 + csv_generator._read_annotations(csv_str('a.png,1,2,3,a'), classes)
120 + except ValueError as e:
121 + assert str(e).startswith("line 1: format should be")
122 + raise
123 +
124 + with pytest.raises(ValueError):
125 + try:
126 + csv_generator._read_annotations(csv_str(
127 + 'a.png,0,1,2,3,a' '\n'
128 + 'a.png,1,2,3,a' '\n'
129 + ), classes)
130 + except ValueError as e:
131 + assert str(e).startswith("line 2: format should be")
132 + raise
133 +
134 +
135 +def test_read_annotations_wrong_x1():
136 + with pytest.raises(ValueError):
137 + try:
138 + csv_generator._read_annotations(csv_str('a.png,a,0,1,2,a'), {'a': 1})
139 + except ValueError as e:
140 + assert str(e).startswith("line 1: malformed x1:")
141 + raise
142 +
143 +
144 +def test_read_annotations_wrong_y1():
145 + with pytest.raises(ValueError):
146 + try:
147 + csv_generator._read_annotations(csv_str('a.png,0,a,1,2,a'), {'a': 1})
148 + except ValueError as e:
149 + assert str(e).startswith("line 1: malformed y1:")
150 + raise
151 +
152 +
153 +def test_read_annotations_wrong_x2():
154 + with pytest.raises(ValueError):
155 + try:
156 + csv_generator._read_annotations(csv_str('a.png,0,1,a,2,a'), {'a': 1})
157 + except ValueError as e:
158 + assert str(e).startswith("line 1: malformed x2:")
159 + raise
160 +
161 +
162 +def test_read_annotations_wrong_y2():
163 + with pytest.raises(ValueError):
164 + try:
165 + csv_generator._read_annotations(csv_str('a.png,0,1,2,a,a'), {'a': 1})
166 + except ValueError as e:
167 + assert str(e).startswith("line 1: malformed y2:")
168 + raise
169 +
170 +
171 +def test_read_annotations_wrong_class():
172 + with pytest.raises(ValueError):
173 + try:
174 + csv_generator._read_annotations(csv_str('a.png,0,1,2,3,g'), {'a': 1})
175 + except ValueError as e:
176 + assert str(e).startswith("line 1: unknown class name:")
177 + raise
178 +
179 +
180 +def test_read_annotations_invalid_bb_x():
181 + with pytest.raises(ValueError):
182 + try:
183 + csv_generator._read_annotations(csv_str('a.png,1,2,1,3,g'), {'a': 1})
184 + except ValueError as e:
185 + assert str(e).startswith("line 1: x2 (1) must be higher than x1 (1)")
186 + raise
187 + with pytest.raises(ValueError):
188 + try:
189 + csv_generator._read_annotations(csv_str('a.png,9,2,5,3,g'), {'a': 1})
190 + except ValueError as e:
191 + assert str(e).startswith("line 1: x2 (5) must be higher than x1 (9)")
192 + raise
193 +
194 +
195 +def test_read_annotations_invalid_bb_y():
196 + with pytest.raises(ValueError):
197 + try:
198 + csv_generator._read_annotations(csv_str('a.png,1,2,3,2,a'), {'a': 1})
199 + except ValueError as e:
200 + assert str(e).startswith("line 1: y2 (2) must be higher than y1 (2)")
201 + raise
202 + with pytest.raises(ValueError):
203 + try:
204 + csv_generator._read_annotations(csv_str('a.png,1,8,3,5,a'), {'a': 1})
205 + except ValueError as e:
206 + assert str(e).startswith("line 1: y2 (5) must be higher than y1 (8)")
207 + raise
208 +
209 +
210 +def test_read_annotations_empty_image():
211 + # Check that images without annotations are parsed.
212 + assert csv_generator._read_annotations(csv_str('a.png,,,,,\nb.png,,,,,'), {'a': 1}) == {'a.png': [], 'b.png': []}
213 +
214 + # Check that lines without annotations don't clear earlier annotations.
215 + assert csv_generator._read_annotations(csv_str('a.png,0,1,2,3,a\na.png,,,,,'), {'a': 1}) == {'a.png': [annotation(0, 1, 2, 3, 'a')]}
1 +"""
2 +Copyright 2017-2018 Fizyr (https://fizyr.com)
3 +
4 +Licensed under the Apache License, Version 2.0 (the "License");
5 +you may not use this file except in compliance with the License.
6 +You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 +Unless required by applicable law or agreed to in writing, software
11 +distributed under the License is distributed on an "AS IS" BASIS,
12 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +See the License for the specific language governing permissions and
14 +limitations under the License.
15 +"""
16 +
17 +from keras_retinanet.preprocessing.generator import Generator
18 +
19 +import numpy as np
20 +import pytest
21 +
22 +
23 +class SimpleGenerator(Generator):
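 +     # minimal in-memory Generator stub: serves the given bboxes/labels (and an
 +     # optional fixed image) so the grouping/filtering logic can be tested directly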
24 + def __init__(self, bboxes, labels, num_classes=0, image=None):
25 + assert(len(bboxes) == len(labels))
26 + self.bboxes = bboxes
27 + self.labels = labels
28 + self.num_classes_ = num_classes
29 + self.image = image
30 + super(SimpleGenerator, self).__init__(group_method='none', shuffle_groups=False)
31 +
32 + def num_classes(self):
33 + return self.num_classes_
34 +
35 + def load_image(self, image_index):
36 + return self.image
37 +
38 + def image_path(self, image_index):
39 + return ''
40 +
41 + def size(self):
42 + return len(self.bboxes)
43 +
44 + def load_annotations(self, image_index):
45 + annotations = {'labels': self.labels[image_index], 'bboxes': self.bboxes[image_index]}
46 + return annotations
47 +
48 +
49 +class TestLoadAnnotationsGroup(object):
50 + def test_simple(self):
51 + input_bboxes_group = [
52 + np.array([
53 + [ 0, 0, 10, 10],
54 + [150, 150, 350, 350]
55 + ]),
56 + ]
57 + input_labels_group = [
58 + np.array([
59 + 1,
60 + 3
61 + ]),
62 + ]
63 + expected_bboxes_group = input_bboxes_group
64 + expected_labels_group = input_labels_group
65 +
66 + simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)
67 + annotations = simple_generator.load_annotations_group(simple_generator.groups[0])
68 +
69 + assert('bboxes' in annotations[0])
70 + assert('labels' in annotations[0])
71 + np.testing.assert_equal(expected_bboxes_group[0], annotations[0]['bboxes'])
72 + np.testing.assert_equal(expected_labels_group[0], annotations[0]['labels'])
73 +
74 + def test_multiple(self):
75 + input_bboxes_group = [
76 + np.array([
77 + [ 0, 0, 10, 10],
78 + [150, 150, 350, 350]
79 + ]),
80 + np.array([
81 + [0, 0, 50, 50],
82 + ]),
83 + ]
84 + input_labels_group = [
85 + np.array([
86 + 1,
87 + 0
88 + ]),
89 + np.array([
90 + 3
91 + ])
92 + ]
93 + expected_bboxes_group = input_bboxes_group
94 + expected_labels_group = input_labels_group
95 +
96 + simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)
97 + annotations_group_0 = simple_generator.load_annotations_group(simple_generator.groups[0])
98 + annotations_group_1 = simple_generator.load_annotations_group(simple_generator.groups[1])
99 +
100 + assert('bboxes' in annotations_group_0[0])
101 + assert('bboxes' in annotations_group_1[0])
102 + assert('labels' in annotations_group_0[0])
103 + assert('labels' in annotations_group_1[0])
104 + np.testing.assert_equal(expected_bboxes_group[0], annotations_group_0[0]['bboxes'])
105 + np.testing.assert_equal(expected_labels_group[0], annotations_group_0[0]['labels'])
106 + np.testing.assert_equal(expected_bboxes_group[1], annotations_group_1[0]['bboxes'])
107 + np.testing.assert_equal(expected_labels_group[1], annotations_group_1[0]['labels'])
108 +
109 +
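 + # filter_annotations is expected to drop degenerate boxes (x2 <= x1 or y2 <= y1)
 + # and boxes that fall outside the image, emitting a UserWarning when it does so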
110 +class TestFilterAnnotations(object):
111 + def test_simple_filter(self):
112 + input_bboxes_group = [
113 + np.array([
114 + [ 0, 0, 10, 10],
115 + [150, 150, 50, 50]
116 + ]),
117 + ]
118 + input_labels_group = [
119 + np.array([
120 + 3,
121 + 1
122 + ]),
123 + ]
124 +
125 + input_image = np.zeros((500, 500, 3))
126 +
127 + expected_bboxes_group = [
128 + np.array([
129 + [0, 0, 10, 10],
130 + ]),
131 + ]
132 + expected_labels_group = [
133 + np.array([
134 + 3,
135 + ]),
136 + ]
137 +
138 + simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)
139 + annotations = simple_generator.load_annotations_group(simple_generator.groups[0])
140 + # expect a UserWarning
141 + with pytest.warns(UserWarning):
142 + image_group, annotations_group = simple_generator.filter_annotations([input_image], annotations, simple_generator.groups[0])
143 +
144 + np.testing.assert_equal(expected_bboxes_group[0], annotations_group[0]['bboxes'])
145 + np.testing.assert_equal(expected_labels_group[0], annotations_group[0]['labels'])
146 +
147 + def test_multiple_filter(self):
148 + input_bboxes_group = [
149 + np.array([
150 + [ 0, 0, 10, 10],
151 + [150, 150, 50, 50],
152 + [150, 150, 350, 350],
153 + [350, 350, 150, 150],
154 + [ 1, 1, 2, 2],
155 + [ 2, 2, 1, 1]
156 + ]),
157 + np.array([
158 + [0, 0, -1, -1]
159 + ]),
160 + np.array([
161 + [-10, -10, 0, 0],
162 + [-10, -10, -100, -100],
163 + [ 10, 10, 100, 100]
164 + ]),
165 + np.array([
166 + [ 10, 10, 100, 100],
167 + [ 10, 10, 600, 600]
168 + ]),
169 + ]
170 +
171 + input_labels_group = [
172 + np.array([
173 + 6,
174 + 5,
175 + 4,
176 + 3,
177 + 2,
178 + 1
179 + ]),
180 + np.array([
181 + 0
182 + ]),
183 + np.array([
184 + 10,
185 + 11,
186 + 12
187 + ]),
188 + np.array([
189 + 105,
190 + 107
191 + ]),
192 + ]
193 +
194 + input_image = np.zeros((500, 500, 3))
195 +
196 + expected_bboxes_group = [
197 + np.array([
198 + [ 0, 0, 10, 10],
199 + [150, 150, 350, 350],
200 + [ 1, 1, 2, 2]
201 + ]),
202 + np.zeros((0, 4)),
203 + np.array([
204 + [10, 10, 100, 100]
205 + ]),
206 + np.array([
207 + [ 10, 10, 100, 100]
208 + ]),
209 + ]
210 + expected_labels_group = [
211 + np.array([
212 + 6,
213 + 4,
214 + 2
215 + ]),
216 + np.zeros((0,)),
217 + np.array([
218 + 12
219 + ]),
220 + np.array([
221 + 105
222 + ]),
223 + ]
224 +
225 + simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)
226 + # expect a UserWarning
227 + annotations_group_0 = simple_generator.load_annotations_group(simple_generator.groups[0])
228 + with pytest.warns(UserWarning):
229 + image_group, annotations_group_0 = simple_generator.filter_annotations([input_image], annotations_group_0, simple_generator.groups[0])
230 +
231 + annotations_group_1 = simple_generator.load_annotations_group(simple_generator.groups[1])
232 + with pytest.warns(UserWarning):
233 + image_group, annotations_group_1 = simple_generator.filter_annotations([input_image], annotations_group_1, simple_generator.groups[1])
234 +
235 + annotations_group_2 = simple_generator.load_annotations_group(simple_generator.groups[2])
236 + with pytest.warns(UserWarning):
237 + image_group, annotations_group_2 = simple_generator.filter_annotations([input_image], annotations_group_2, simple_generator.groups[2])
238 +
239 + np.testing.assert_equal(expected_bboxes_group[0], annotations_group_0[0]['bboxes'])
240 + np.testing.assert_equal(expected_labels_group[0], annotations_group_0[0]['labels'])
241 +
242 + np.testing.assert_equal(expected_bboxes_group[1], annotations_group_1[0]['bboxes'])
243 + np.testing.assert_equal(expected_labels_group[1], annotations_group_1[0]['labels'])
244 +
245 + np.testing.assert_equal(expected_bboxes_group[2], annotations_group_2[0]['bboxes'])
246 + np.testing.assert_equal(expected_labels_group[2], annotations_group_2[0]['labels'])
247 +
248 + def test_complete(self):
249 + input_bboxes_group = [
250 + np.array([
251 + [ 0, 0, 50, 50],
252 + [150, 150, 50, 50], # invalid bbox
253 + ], dtype=float)
254 + ]
255 +
256 + input_labels_group = [
257 + np.array([
258 + 5, # one object of class 5
259 + 3, # one object of class 3 with an invalid box
260 + ], dtype=float)
261 + ]
262 +
263 + input_image = np.zeros((500, 500, 3), dtype=np.uint8)
264 +
265 + simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group, image=input_image, num_classes=6)
266 + # expect a UserWarning
267 + with pytest.warns(UserWarning):
268 + _, [_, labels_batch] = simple_generator[0]
269 +
270 + # test that only object with class 5 is present in labels_batch
271 + labels = np.unique(np.argmax(labels_batch == 5, axis=2))
272 + assert(len(labels) == 1 and labels[0] == 0), 'Expected only class 0 to be present, but got classes {}'.format(labels)
1 +import os
2 +import pytest
3 +from PIL import Image
4 +from keras_retinanet.utils import image
5 +import numpy as np
6 +
7 +_STUB_IMG_FNAME = 'stub-image.jpg'
8 +
9 +
10 +@pytest.fixture(autouse=True)
11 +def run_around_tests(tmp_path):
12 + """Create a temp image for test"""
13 + rand_img = np.random.randint(0, 255, (3, 3, 3), dtype='uint8')
14 + Image.fromarray(rand_img).save(os.path.join(tmp_path, _STUB_IMG_FNAME))
15 + yield
16 +
17 +
18 +def test_read_image_bgr(tmp_path):
19 + stub_image_path = os.path.join(tmp_path, _STUB_IMG_FNAME)
20 +
21 + original_img = np.asarray(Image.open(
22 + stub_image_path).convert('RGB'))[:, :, ::-1]
23 + loaded_image = image.read_image_bgr(stub_image_path)
24 +
25 + # Assert images are equal
26 + np.testing.assert_array_equal(original_img, loaded_image)
1 +check-manifest
2 +image-classifiers
3 +efficientnet
4 +# pytest
5 +pytest-xdist
6 +pytest-cov
7 +pytest-flake8
8 +# flake8
9 +coverage
10 +codecov
1 +import keras_retinanet.losses
2 +from tensorflow import keras
3 +
4 +import numpy as np
5 +
6 +import pytest
7 +
8 +
9 +def test_smooth_l1():
10 + regression = np.array([
11 + [
12 + [0, 0, 0, 0],
13 + [0, 0, 0, 0],
14 + [0, 0, 0, 0],
15 + [0, 0, 0, 0],
16 + ]
17 + ], dtype=keras.backend.floatx())
18 + regression = keras.backend.variable(regression)
19 +
20 + regression_target = np.array([
21 + [
22 + [0, 0, 0, 1, 1],
23 + [0, 0, 1, 0, 1],
24 + [0, 0, 0.05, 0, 1],
25 + [0, 0, 1, 0, 0],
26 + ]
27 + ], dtype=keras.backend.floatx())
28 + regression_target = keras.backend.variable(regression_target)
29 +
30 + loss = keras_retinanet.losses.smooth_l1()(regression_target, regression)
31 + loss = keras.backend.eval(loss)
32 +
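 + # hand-derived expectation, assuming the default sigma = 3.0 (sigma_squared = 9):
 + #   rows 1 and 2 are positive anchors with |error| = 1 >= 1/9 -> 1 - 0.5/9 each
 + #   row 3 is a positive anchor with |error| = 0.05 < 1/9      -> 0.5 * 9 * 0.05 ** 2
 + #   row 4 has anchor state 0 and is ignored
 + # the sum is normalised by the number of positive anchors (3)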
33 + assert loss == pytest.approx((((1 - 0.5 / 9) * 2 + (0.5 * 9 * 0.05 ** 2)) / 3))
1 +import numpy as np
2 +import configparser
3 +from tensorflow import keras
4 +
5 +from keras_retinanet.utils.anchors import anchors_for_shape, AnchorParameters
6 +from keras_retinanet.utils.config import read_config_file, parse_anchor_parameters
7 +
8 +
9 +def test_config_read():
10 + config = read_config_file('tests/test-data/config/config.ini')
11 + assert 'anchor_parameters' in config
12 + assert 'sizes' in config['anchor_parameters']
13 + assert 'strides' in config['anchor_parameters']
14 + assert 'ratios' in config['anchor_parameters']
15 + assert 'scales' in config['anchor_parameters']
16 + assert config['anchor_parameters']['sizes'] == '32 64 128 256 512'
17 + assert config['anchor_parameters']['strides'] == '8 16 32 64 128'
18 + assert config['anchor_parameters']['ratios'] == '0.5 1 2 3'
19 + assert config['anchor_parameters']['scales'] == '1 1.2 1.6'
20 +
21 +
22 +def create_anchor_params_config():
23 + config = configparser.ConfigParser()
24 + config['anchor_parameters'] = {}
25 + config['anchor_parameters']['sizes'] = '32 64 128 256 512'
26 + config['anchor_parameters']['strides'] = '8 16 32 64 128'
27 + config['anchor_parameters']['ratios'] = '0.5 1'
28 + config['anchor_parameters']['scales'] = '1 1.2 1.6'
29 +
30 + return config
31 +
32 +
33 +def test_parse_anchor_parameters():
34 + config = create_anchor_params_config()
35 + anchor_params_parsed = parse_anchor_parameters(config)
36 +
37 + sizes = [32, 64, 128, 256, 512]
38 + strides = [8, 16, 32, 64, 128]
39 + ratios = np.array([0.5, 1], keras.backend.floatx())
40 + scales = np.array([1, 1.2, 1.6], keras.backend.floatx())
41 +
42 + assert sizes == anchor_params_parsed.sizes
43 + assert strides == anchor_params_parsed.strides
44 + np.testing.assert_equal(ratios, anchor_params_parsed.ratios)
45 + np.testing.assert_equal(scales, anchor_params_parsed.scales)
46 +
47 +
48 +def test_anchors_for_shape_dimensions():
49 + sizes = [32, 64, 128]
50 + strides = [8, 16, 32]
51 + ratios = np.array([0.5, 1, 2, 3], keras.backend.floatx())
52 + scales = np.array([1, 1.2, 1.6], keras.backend.floatx())
53 + anchor_params = AnchorParameters(sizes, strides, ratios, scales)
54 +
55 + pyramid_levels = [3, 4, 5]
56 + image_shape = (64, 64)
57 + all_anchors = anchors_for_shape(image_shape, pyramid_levels=pyramid_levels, anchor_params=anchor_params)
58 +
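 + # a 64x64 image gives 8x8, 4x4 and 2x2 feature maps at levels 3-5 (84 positions),
 + # with 4 ratios * 3 scales = 12 anchors per position, hence 84 * 12 = 1008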
59 + assert all_anchors.shape == (1008, 4)
60 +
61 +
62 +def test_anchors_for_shape_values():
63 + sizes = [12]
64 + strides = [8]
65 + ratios = np.array([1, 2], keras.backend.floatx())
66 + scales = np.array([1, 2], keras.backend.floatx())
67 + anchor_params = AnchorParameters(sizes, strides, ratios, scales)
68 +
69 + pyramid_levels = [3]
70 + image_shape = (16, 16)
71 + all_anchors = anchors_for_shape(image_shape, pyramid_levels=pyramid_levels, anchor_params=anchor_params)
72 +
73 + # using almost_equal for floating point imprecisions
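 + # anchors are checked in generation order: scale varies fastest, then ratio,
 + # then the shift along x, then the shift along y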
74 + np.testing.assert_almost_equal(all_anchors[0, :], [
75 + strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2,
76 + strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2,
77 + strides[0] / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2,
78 + strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2,
79 + ], decimal=6)
80 + np.testing.assert_almost_equal(all_anchors[1, :], [
81 + strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2,
82 + strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2,
83 + strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2,
84 + strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2,
85 + ], decimal=6)
86 + np.testing.assert_almost_equal(all_anchors[2, :], [
87 + strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2,
88 + strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2,
89 + strides[0] / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2,
90 + strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2,
91 + ], decimal=6)
92 + np.testing.assert_almost_equal(all_anchors[3, :], [
93 + strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2,
94 + strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2,
95 + strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2,
96 + strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2,
97 + ], decimal=6)
98 + np.testing.assert_almost_equal(all_anchors[4, :], [
99 + strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2,
100 + strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2,
101 + strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2,
102 + strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2,
103 + ], decimal=6)
104 + np.testing.assert_almost_equal(all_anchors[5, :], [
105 + strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2,
106 + strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2,
107 + strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2,
108 + strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2,
109 + ], decimal=6)
110 + np.testing.assert_almost_equal(all_anchors[6, :], [
111 + strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2,
112 + strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2,
113 + strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2,
114 + strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2,
115 + ], decimal=6)
116 + np.testing.assert_almost_equal(all_anchors[7, :], [
117 + strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2,
118 + strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2,
119 + strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2,
120 + strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2,
121 + ], decimal=6)
122 + np.testing.assert_almost_equal(all_anchors[8, :], [
123 + strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2,
124 + strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2,
125 + strides[0] / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2,
126 + strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2,
127 + ], decimal=6)
128 + np.testing.assert_almost_equal(all_anchors[9, :], [
129 + strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2,
130 + strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2,
131 + strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2,
132 + strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2,
133 + ], decimal=6)
134 + np.testing.assert_almost_equal(all_anchors[10, :], [
135 + strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2,
136 + strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2,
137 + strides[0] / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2,
138 + strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2,
139 + ], decimal=6)
140 + np.testing.assert_almost_equal(all_anchors[11, :], [
141 + strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2,
142 + strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2,
143 + strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2,
144 + strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2,
145 + ], decimal=6)
146 + np.testing.assert_almost_equal(all_anchors[12, :], [
147 + strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2,
148 + strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2,
149 + strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2,
150 + strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2,
151 + ], decimal=6)
152 + np.testing.assert_almost_equal(all_anchors[13, :], [
153 + strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2,
154 + strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2,
155 + strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2,
156 + strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2,
157 + ], decimal=6)
158 + np.testing.assert_almost_equal(all_anchors[14, :], [
159 + strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2,
160 + strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2,
161 + strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2,
162 + strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2,
163 + ], decimal=6)
164 + np.testing.assert_almost_equal(all_anchors[15, :], [
165 + strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2,
166 + strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2,
167 + strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2,
168 + strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2,
169 + ], decimal=6)
1 +import numpy as np
2 +from numpy.testing import assert_almost_equal
3 +from math import pi
4 +
5 +from keras_retinanet.utils.transform import (
6 + colvec,
7 + transform_aabb,
8 + rotation, random_rotation,
9 + translation, random_translation,
10 + scaling, random_scaling,
11 + shear, random_shear,
12 + random_flip,
13 + random_transform,
14 + random_transform_generator,
15 + change_transform_origin,
16 +)
17 +
18 +
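 + # the transforms below are 3x3 homogeneous matrices; test points are built as
 + # colvec(x, y, 1) so they can be multiplied with those matrices directly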
19 +def test_colvec():
20 + assert np.array_equal(colvec(0), np.array([[0]]))
21 + assert np.array_equal(colvec(1, 2, 3), np.array([[1], [2], [3]]))
22 + assert np.array_equal(colvec(-1, -2), np.array([[-1], [-2]]))
23 +
24 +
25 +def test_rotation():
26 + assert_almost_equal(colvec( 1, 0, 1), rotation(0.0 * pi).dot(colvec(1, 0, 1)))
27 + assert_almost_equal(colvec( 0, 1, 1), rotation(0.5 * pi).dot(colvec(1, 0, 1)))
28 + assert_almost_equal(colvec(-1, 0, 1), rotation(1.0 * pi).dot(colvec(1, 0, 1)))
29 + assert_almost_equal(colvec( 0, -1, 1), rotation(1.5 * pi).dot(colvec(1, 0, 1)))
30 + assert_almost_equal(colvec( 1, 0, 1), rotation(2.0 * pi).dot(colvec(1, 0, 1)))
31 +
32 + assert_almost_equal(colvec( 0, 1, 1), rotation(0.0 * pi).dot(colvec(0, 1, 1)))
33 + assert_almost_equal(colvec(-1, 0, 1), rotation(0.5 * pi).dot(colvec(0, 1, 1)))
34 + assert_almost_equal(colvec( 0, -1, 1), rotation(1.0 * pi).dot(colvec(0, 1, 1)))
35 + assert_almost_equal(colvec( 1, 0, 1), rotation(1.5 * pi).dot(colvec(0, 1, 1)))
36 + assert_almost_equal(colvec( 0, 1, 1), rotation(2.0 * pi).dot(colvec(0, 1, 1)))
37 +
38 +
39 +def test_random_rotation():
40 + prng = np.random.RandomState(0)
41 + for i in range(100):
42 + assert_almost_equal(1, np.linalg.det(random_rotation(-i, i, prng)))
43 +
44 +
45 +def test_translation():
46 + assert_almost_equal(colvec( 1, 2, 1), translation(colvec( 0, 0)).dot(colvec(1, 2, 1)))
47 + assert_almost_equal(colvec( 4, 6, 1), translation(colvec( 3, 4)).dot(colvec(1, 2, 1)))
48 + assert_almost_equal(colvec(-2, -2, 1), translation(colvec(-3, -4)).dot(colvec(1, 2, 1)))
49 +
50 +
51 +def assert_is_translation(transform, min, max):
52 + assert transform.shape == (3, 3)
53 + assert np.array_equal(transform[:, 0:2], np.eye(3, 2))
54 + assert transform[2, 2] == 1
55 + assert np.greater_equal(transform[0:2, 2], min).all()
56 + assert np.less( transform[0:2, 2], max).all()
57 +
58 +
59 +def test_random_translation():
60 + prng = np.random.RandomState(0)
61 + min = (-10, -20)
62 + max = (20, 10)
63 + for i in range(100):
64 + assert_is_translation(random_translation(min, max, prng), min, max)
65 +
66 +
67 +def test_shear():
68 + assert_almost_equal(colvec( 1, 2, 1), shear(0.0 * pi).dot(colvec(1, 2, 1)))
69 + assert_almost_equal(colvec(-1, 0, 1), shear(0.5 * pi).dot(colvec(1, 2, 1)))
70 + assert_almost_equal(colvec( 1, -2, 1), shear(1.0 * pi).dot(colvec(1, 2, 1)))
71 + assert_almost_equal(colvec( 3, 0, 1), shear(1.5 * pi).dot(colvec(1, 2, 1)))
72 + assert_almost_equal(colvec( 1, 2, 1), shear(2.0 * pi).dot(colvec(1, 2, 1)))
73 +
74 +
75 +def assert_is_shear(transform):
76 + assert transform.shape == (3, 3)
77 + assert np.array_equal(transform[:, 0], [1, 0, 0])
78 + assert np.array_equal(transform[:, 2], [0, 0, 1])
79 + assert transform[2, 1] == 0
80 + # sin^2 + cos^2 == 1
81 + assert_almost_equal(1, transform[0, 1] ** 2 + transform[1, 1] ** 2)
82 +
83 +
84 +def test_random_shear():
85 + prng = np.random.RandomState(0)
86 + for i in range(100):
87 + assert_is_shear(random_shear(-pi, pi, prng))
88 +
89 +
90 +def test_scaling():
91 + assert_almost_equal(colvec(1.0, 2, 1), scaling(colvec(1.0, 1.0)).dot(colvec(1, 2, 1)))
92 + assert_almost_equal(colvec(0.0, 2, 1), scaling(colvec(0.0, 1.0)).dot(colvec(1, 2, 1)))
93 + assert_almost_equal(colvec(1.0, 0, 1), scaling(colvec(1.0, 0.0)).dot(colvec(1, 2, 1)))
94 + assert_almost_equal(colvec(0.5, 4, 1), scaling(colvec(0.5, 2.0)).dot(colvec(1, 2, 1)))
95 +
96 +
97 +def assert_is_scaling(transform, min, max):
98 + assert transform.shape == (3, 3)
99 + assert np.array_equal(transform[2, :], [0, 0, 1])
100 + assert np.array_equal(transform[:, 2], [0, 0, 1])
101 + assert transform[1, 0] == 0
102 + assert transform[0, 1] == 0
103 + assert np.greater_equal(np.diagonal(transform)[:2], min).all()
104 + assert np.less( np.diagonal(transform)[:2], max).all()
105 +
106 +
107 +def test_random_scaling():
108 + prng = np.random.RandomState(0)
109 + min = (0.1, 0.2)
110 + max = (20, 10)
111 + for i in range(100):
112 + assert_is_scaling(random_scaling(min, max, prng), min, max)
113 +
114 +
115 +def assert_is_flip(transform):
116 + assert transform.shape == (3, 3)
117 + assert np.array_equal(transform[2, :], [0, 0, 1])
118 + assert np.array_equal(transform[:, 2], [0, 0, 1])
119 + assert transform[1, 0] == 0
120 + assert transform[0, 1] == 0
121 + assert abs(transform[0, 0]) == 1
122 + assert abs(transform[1, 1]) == 1
123 +
124 +
125 +def test_random_flip():
126 + prng = np.random.RandomState(0)
127 + for i in range(100):
128 + assert_is_flip(random_flip(0.5, 0.5, prng))
129 +
130 +
131 +def test_random_transform():
132 + prng = np.random.RandomState(0)
133 + for i in range(100):
134 + transform = random_transform(prng=prng)
135 + assert np.array_equal(transform, np.identity(3))
136 +
137 + for i, transform in zip(range(100), random_transform_generator(prng=np.random.RandomState())):
138 + assert np.array_equal(transform, np.identity(3))
139 +
140 +
141 +def test_transform_aabb():
142 + assert np.array_equal([1, 2, 3, 4], transform_aabb(np.identity(3), [1, 2, 3, 4]))
143 + assert_almost_equal([-3, -4, -1, -2], transform_aabb(rotation(pi), [1, 2, 3, 4]))
144 + assert_almost_equal([ 2, 4, 4, 6], transform_aabb(translation([1, 2]), [1, 2, 3, 4]))
145 +
146 +
147 +def test_change_transform_origin():
148 + assert np.array_equal(change_transform_origin(translation([3, 4]), [1, 2]), translation([3, 4]))
149 + assert_almost_equal(colvec(1, 2, 1), change_transform_origin(rotation(pi), [1, 2]).dot(colvec(1, 2, 1)))
150 + assert_almost_equal(colvec(0, 0, 1), change_transform_origin(rotation(pi), [1, 2]).dot(colvec(2, 4, 1)))
151 + assert_almost_equal(colvec(0, 0, 1), change_transform_origin(scaling([0.5, 0.5]), [-2, -4]).dot(colvec(2, 4, 1)))