Showing 86 changed files with 10409 additions and 0 deletions
retinaNet/.codecov.yml
0 → 100644
#see https://github.com/codecov/support/wiki/Codecov-Yaml
codecov:
  notify:
    require_ci_to_pass: yes

coverage:
  precision: 0  # 2 = xx.xx%, 0 = xx%
  round: nearest  # how coverage is rounded: down/up/nearest
  range: 40...100  # custom range of coverage colors from red -> yellow -> green
  status:
    # https://codecov.readme.io/v1.0/docs/commit-status
    project:
      default:
        against: auto
        target: 90%  # specify the target coverage for each commit status
        threshold: 20%  # allow this little decrease on project
        # https://github.com/codecov/support/wiki/Filtering-Branches
        # branches: master
        if_ci_failed: error
    # https://github.com/codecov/support/wiki/Patch-Status
    patch:
      default:
        against: auto
        target: 40%  # specify the target "X%" coverage to hit
        # threshold: 50%  # allow this much decrease on patch
    changes: false

parsers:
  gcov:
    branch_detection:
      conditional: true
      loop: true
      macro: false
      method: false
  javascript:
    enable_partials: false

comment:
  layout: header, diff
  require_changes: false
  behavior: default  # update if exists else create new
  # branches: *
retinaNet/.gitignore
0 → 100644
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
.Python
/build/
/dist/
/eggs/
/*-eggs/
.eggs/
/sdist/
/wheels/
/*.egg-info/
.installed.cfg
*.egg

# Unit test / coverage reports
.coverage
.coverage.*
coverage.xml
*.cover
retinaNet/.gitmodules
0 → 100644
retinaNet/.travis.yml
0 → 100644
language: python

sudo: required

python:
  - '3.6'
  - '3.7'

install:
  - pip install -r requirements.txt
  - pip install -r tests/requirements.txt

cache: pip

script:
  - python setup.py check -m -s
  - python setup.py build_ext --inplace
  - coverage run --source keras_retinanet -m py.test keras_retinanet tests --doctest-modules --forked --flake8

after_success:
  - coverage xml
  - coverage report
  - codecov
retinaNet/CONTRIBUTORS.md
0 → 100644
# Contributors

This is a list of people who contributed patches to keras-retinanet.

If you feel you should be listed here or if you have any other questions/comments on your listing here,
please create an issue or pull request at https://github.com/fizyr/keras-retinanet/

* Hans Gaiser <h.gaiser@fizyr.com>
* Maarten de Vries <maarten@de-vri.es>
* Valerio Carpani
* Ashley Williamson
* Yann Henon
* Valeriu Lacatusu
* András Vidosits
* Cristian Gratie
* jjiunlin
* Sorin Panduru
* Rodrigo Meira de Andrade
* Enrico Liscio <e.liscio@fizyr.com>
* Mihai Morariu
* pedroconceicao
* jjiun
* Wudi Fang
* Mike Clark
* hannesedvartsen
* Max Van Sande
* Pierre Dérian
* ori
* mxvs
* mwilder
* Muhammed Kocabas
* Koen Vijverberg
* iver56
* hnsywangxin
* Guillaume Erhard
* Eduardo Ramos
* DiegoAgher
* Alexander Pacha
* Agastya Kalra
* Jiri BOROVEC
* ntsagko
* charlie / tianqi
* jsemric
* Martin Zlocha
* Raghav Bhardwaj
* bw4sz
* Morten Back Nielsen
* dshahrokhian
* Alex / adreo00
* simone.merello
* Matt Wilder
* Jinwoo Baek
* Etienne Meunier
* Denis Dowling
* cclauss
* Andrew Grigorev
* ZFTurbo
* UgoLouche
* Richard Higgins
* Rajat / rajat.goel
* philipp.marquardt
* peacherwu
* Paul / pauldesigaud
* Martin Genet
* Leo / leonardvandriel
* Laurens Hagendoorn
* Julius / juliussimonelli
* HolyGuacamole
* Fausto Morales
* borakrc
* Ben Weinstein
* Anil Karaka
* Andrea Panizza
* Bruno Santos
retinaNet/LICENSE
0 → 100644
1 | + Apache License | ||
2 | + Version 2.0, January 2004 | ||
3 | + http://www.apache.org/licenses/ | ||
4 | + | ||
5 | + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | ||
6 | + | ||
7 | + 1. Definitions. | ||
8 | + | ||
9 | + "License" shall mean the terms and conditions for use, reproduction, | ||
10 | + and distribution as defined by Sections 1 through 9 of this document. | ||
11 | + | ||
12 | + "Licensor" shall mean the copyright owner or entity authorized by | ||
13 | + the copyright owner that is granting the License. | ||
14 | + | ||
15 | + "Legal Entity" shall mean the union of the acting entity and all | ||
16 | + other entities that control, are controlled by, or are under common | ||
17 | + control with that entity. For the purposes of this definition, | ||
18 | + "control" means (i) the power, direct or indirect, to cause the | ||
19 | + direction or management of such entity, whether by contract or | ||
20 | + otherwise, or (ii) ownership of fifty percent (50%) or more of the | ||
21 | + outstanding shares, or (iii) beneficial ownership of such entity. | ||
22 | + | ||
23 | + "You" (or "Your") shall mean an individual or Legal Entity | ||
24 | + exercising permissions granted by this License. | ||
25 | + | ||
26 | + "Source" form shall mean the preferred form for making modifications, | ||
27 | + including but not limited to software source code, documentation | ||
28 | + source, and configuration files. | ||
29 | + | ||
30 | + "Object" form shall mean any form resulting from mechanical | ||
31 | + transformation or translation of a Source form, including but | ||
32 | + not limited to compiled object code, generated documentation, | ||
33 | + and conversions to other media types. | ||
34 | + | ||
35 | + "Work" shall mean the work of authorship, whether in Source or | ||
36 | + Object form, made available under the License, as indicated by a | ||
37 | + copyright notice that is included in or attached to the work | ||
38 | + (an example is provided in the Appendix below). | ||
39 | + | ||
40 | + "Derivative Works" shall mean any work, whether in Source or Object | ||
41 | + form, that is based on (or derived from) the Work and for which the | ||
42 | + editorial revisions, annotations, elaborations, or other modifications | ||
43 | + represent, as a whole, an original work of authorship. For the purposes | ||
44 | + of this License, Derivative Works shall not include works that remain | ||
45 | + separable from, or merely link (or bind by name) to the interfaces of, | ||
46 | + the Work and Derivative Works thereof. | ||
47 | + | ||
48 | + "Contribution" shall mean any work of authorship, including | ||
49 | + the original version of the Work and any modifications or additions | ||
50 | + to that Work or Derivative Works thereof, that is intentionally | ||
51 | + submitted to Licensor for inclusion in the Work by the copyright owner | ||
52 | + or by an individual or Legal Entity authorized to submit on behalf of | ||
53 | + the copyright owner. For the purposes of this definition, "submitted" | ||
54 | + means any form of electronic, verbal, or written communication sent | ||
55 | + to the Licensor or its representatives, including but not limited to | ||
56 | + communication on electronic mailing lists, source code control systems, | ||
57 | + and issue tracking systems that are managed by, or on behalf of, the | ||
58 | + Licensor for the purpose of discussing and improving the Work, but | ||
59 | + excluding communication that is conspicuously marked or otherwise | ||
60 | + designated in writing by the copyright owner as "Not a Contribution." | ||
61 | + | ||
62 | + "Contributor" shall mean Licensor and any individual or Legal Entity | ||
63 | + on behalf of whom a Contribution has been received by Licensor and | ||
64 | + subsequently incorporated within the Work. | ||
65 | + | ||
66 | + 2. Grant of Copyright License. Subject to the terms and conditions of | ||
67 | + this License, each Contributor hereby grants to You a perpetual, | ||
68 | + worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||
69 | + copyright license to reproduce, prepare Derivative Works of, | ||
70 | + publicly display, publicly perform, sublicense, and distribute the | ||
71 | + Work and such Derivative Works in Source or Object form. | ||
72 | + | ||
73 | + 3. Grant of Patent License. Subject to the terms and conditions of | ||
74 | + this License, each Contributor hereby grants to You a perpetual, | ||
75 | + worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||
76 | + (except as stated in this section) patent license to make, have made, | ||
77 | + use, offer to sell, sell, import, and otherwise transfer the Work, | ||
78 | + where such license applies only to those patent claims licensable | ||
79 | + by such Contributor that are necessarily infringed by their | ||
80 | + Contribution(s) alone or by combination of their Contribution(s) | ||
81 | + with the Work to which such Contribution(s) was submitted. If You | ||
82 | + institute patent litigation against any entity (including a | ||
83 | + cross-claim or counterclaim in a lawsuit) alleging that the Work | ||
84 | + or a Contribution incorporated within the Work constitutes direct | ||
85 | + or contributory patent infringement, then any patent licenses | ||
86 | + granted to You under this License for that Work shall terminate | ||
87 | + as of the date such litigation is filed. | ||
88 | + | ||
89 | + 4. Redistribution. You may reproduce and distribute copies of the | ||
90 | + Work or Derivative Works thereof in any medium, with or without | ||
91 | + modifications, and in Source or Object form, provided that You | ||
92 | + meet the following conditions: | ||
93 | + | ||
94 | + (a) You must give any other recipients of the Work or | ||
95 | + Derivative Works a copy of this License; and | ||
96 | + | ||
97 | + (b) You must cause any modified files to carry prominent notices | ||
98 | + stating that You changed the files; and | ||
99 | + | ||
100 | + (c) You must retain, in the Source form of any Derivative Works | ||
101 | + that You distribute, all copyright, patent, trademark, and | ||
102 | + attribution notices from the Source form of the Work, | ||
103 | + excluding those notices that do not pertain to any part of | ||
104 | + the Derivative Works; and | ||
105 | + | ||
106 | + (d) If the Work includes a "NOTICE" text file as part of its | ||
107 | + distribution, then any Derivative Works that You distribute must | ||
108 | + include a readable copy of the attribution notices contained | ||
109 | + within such NOTICE file, excluding those notices that do not | ||
110 | + pertain to any part of the Derivative Works, in at least one | ||
111 | + of the following places: within a NOTICE text file distributed | ||
112 | + as part of the Derivative Works; within the Source form or | ||
113 | + documentation, if provided along with the Derivative Works; or, | ||
114 | + within a display generated by the Derivative Works, if and | ||
115 | + wherever such third-party notices normally appear. The contents | ||
116 | + of the NOTICE file are for informational purposes only and | ||
117 | + do not modify the License. You may add Your own attribution | ||
118 | + notices within Derivative Works that You distribute, alongside | ||
119 | + or as an addendum to the NOTICE text from the Work, provided | ||
120 | + that such additional attribution notices cannot be construed | ||
121 | + as modifying the License. | ||
122 | + | ||
123 | + You may add Your own copyright statement to Your modifications and | ||
124 | + may provide additional or different license terms and conditions | ||
125 | + for use, reproduction, or distribution of Your modifications, or | ||
126 | + for any such Derivative Works as a whole, provided Your use, | ||
127 | + reproduction, and distribution of the Work otherwise complies with | ||
128 | + the conditions stated in this License. | ||
129 | + | ||
130 | + 5. Submission of Contributions. Unless You explicitly state otherwise, | ||
131 | + any Contribution intentionally submitted for inclusion in the Work | ||
132 | + by You to the Licensor shall be under the terms and conditions of | ||
133 | + this License, without any additional terms or conditions. | ||
134 | + Notwithstanding the above, nothing herein shall supersede or modify | ||
135 | + the terms of any separate license agreement you may have executed | ||
136 | + with Licensor regarding such Contributions. | ||
137 | + | ||
138 | + 6. Trademarks. This License does not grant permission to use the trade | ||
139 | + names, trademarks, service marks, or product names of the Licensor, | ||
140 | + except as required for reasonable and customary use in describing the | ||
141 | + origin of the Work and reproducing the content of the NOTICE file. | ||
142 | + | ||
143 | + 7. Disclaimer of Warranty. Unless required by applicable law or | ||
144 | + agreed to in writing, Licensor provides the Work (and each | ||
145 | + Contributor provides its Contributions) on an "AS IS" BASIS, | ||
146 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||
147 | + implied, including, without limitation, any warranties or conditions | ||
148 | + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | ||
149 | + PARTICULAR PURPOSE. You are solely responsible for determining the | ||
150 | + appropriateness of using or redistributing the Work and assume any | ||
151 | + risks associated with Your exercise of permissions under this License. | ||
152 | + | ||
153 | + 8. Limitation of Liability. In no event and under no legal theory, | ||
154 | + whether in tort (including negligence), contract, or otherwise, | ||
155 | + unless required by applicable law (such as deliberate and grossly | ||
156 | + negligent acts) or agreed to in writing, shall any Contributor be | ||
157 | + liable to You for damages, including any direct, indirect, special, | ||
158 | + incidental, or consequential damages of any character arising as a | ||
159 | + result of this License or out of the use or inability to use the | ||
160 | + Work (including but not limited to damages for loss of goodwill, | ||
161 | + work stoppage, computer failure or malfunction, or any and all | ||
162 | + other commercial damages or losses), even if such Contributor | ||
163 | + has been advised of the possibility of such damages. | ||
164 | + | ||
165 | + 9. Accepting Warranty or Additional Liability. While redistributing | ||
166 | + the Work or Derivative Works thereof, You may choose to offer, | ||
167 | + and charge a fee for, acceptance of support, warranty, indemnity, | ||
168 | + or other liability obligations and/or rights consistent with this | ||
169 | + License. However, in accepting such obligations, You may act only | ||
170 | + on Your own behalf and on Your sole responsibility, not on behalf | ||
171 | + of any other Contributor, and only if You agree to indemnify, | ||
172 | + defend, and hold each Contributor harmless for any liability | ||
173 | + incurred by, or claims asserted against, such Contributor by reason | ||
174 | + of your accepting any such warranty or additional liability. | ||
175 | + | ||
176 | + END OF TERMS AND CONDITIONS | ||
177 | + | ||
178 | + APPENDIX: How to apply the Apache License to your work. | ||
179 | + | ||
180 | + To apply the Apache License to your work, attach the following | ||
181 | + boilerplate notice, with the fields enclosed by brackets "{}" | ||
182 | + replaced with your own identifying information. (Don't include | ||
183 | + the brackets!) The text should be enclosed in the appropriate | ||
184 | + comment syntax for the file format. We also recommend that a | ||
185 | + file or class name and description of purpose be included on the | ||
186 | + same "printed page" as the copyright notice for easier | ||
187 | + identification within third-party archives. | ||
188 | + | ||
189 | + Copyright {yyyy} {name of copyright owner} | ||
190 | + | ||
191 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
192 | + you may not use this file except in compliance with the License. | ||
193 | + You may obtain a copy of the License at | ||
194 | + | ||
195 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
196 | + | ||
197 | + Unless required by applicable law or agreed to in writing, software | ||
198 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
199 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
200 | + See the License for the specific language governing permissions and | ||
201 | + limitations under the License. |
retinaNet/README.md
0 → 100644
# Keras RetinaNet [![Build Status](https://travis-ci.org/fizyr/keras-retinanet.svg?branch=master)](https://travis-ci.org/fizyr/keras-retinanet) [![DOI](https://zenodo.org/badge/100249425.svg)](https://zenodo.org/badge/latestdoi/100249425)

Keras implementation of RetinaNet object detection as described in [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
by Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He and Piotr Dollár.

## :warning: Deprecated

This repository is deprecated in favor of the [torchvision](https://github.com/pytorch/vision/) module.
This project should work with Keras 2.4 and TensorFlow 2.3.0; newer versions might break support.
For more information, check [here](https://github.com/fizyr/keras-retinanet/issues/1471#issuecomment-704187205).

## Installation

1) Clone this repository.
2) In the repository, execute `pip install . --user`.
   Note that due to inconsistencies in how `tensorflow` should be installed,
   this package does not define a dependency on `tensorflow`, as pip would then try to install it (which, at least on Arch Linux, results in an incorrect installation).
   Please make sure `tensorflow` is installed as per your system's requirements.
3) Alternatively, you can run the code directly from the cloned repository; however, you first need to run `python setup.py build_ext --inplace` to compile the Cython code.
4) Optionally, install `pycocotools` if you want to train / test on the MS COCO dataset by running `pip install --user git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI`.

## Testing
An example of testing the network can be seen in [this Notebook](https://github.com/delftrobotics/keras-retinanet/blob/master/examples/ResNet50RetinaNet.ipynb).
In general, inference with the network works as follows:
```python
boxes, scores, labels = model.predict_on_batch(inputs)
```

Here `boxes` is shaped `(None, None, 4)` (for `(x1, y1, x2, y2)`), `scores` is shaped `(None, None)` (classification score) and `labels` is shaped `(None, None)` (label corresponding to the score). In all three outputs, the first dimension indexes the batch and the second dimension indexes the list of detections.
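A slightly fuller sketch of consuming these outputs, along the lines of the example script in `examples/resnet50_retinanet.py` (the `image`, `scale` and `labels_to_names` variables come from the preprocessing and label-mapping steps shown there; the 0.5 score threshold is just an illustrative choice):

```python
import numpy as np

boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
boxes /= scale  # undo the resizing applied by resize_image

for box, score, label in zip(boxes[0], scores[0], labels[0]):
    # detections are sorted by score, so we can stop at the first weak one
    if score < 0.5:
        break
    print(labels_to_names[label], score, box.astype(int))
```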

Loading models can be done in the following manner:
```python
from keras_retinanet.models import load_model
model = load_model('/path/to/model.h5', backbone_name='resnet50')
```

Execution time on NVIDIA Pascal Titan X is roughly 75 ms for an image of shape `1000x800x3`.

### Converting a training model to inference model
The training procedure of `keras-retinanet` works with *training models*. These are stripped-down versions of the corresponding *inference model*, containing only the layers necessary for training (the regression and classification outputs). If you wish to run inference with a model (perform object detection on an image), you need to convert the trained model to an inference model. This is done as follows:

```shell
# Running directly from the repository:
keras_retinanet/bin/convert_model.py /path/to/training/model.h5 /path/to/save/inference/model.h5

# Using the installed script:
retinanet-convert-model /path/to/training/model.h5 /path/to/save/inference/model.h5
```

Most scripts (like `retinanet-evaluate`) also support converting on the fly, using the `--convert-model` argument.
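The same conversion can also be done from Python. A minimal sketch mirroring what `keras_retinanet/bin/convert_model.py` does (paths are placeholders, and the default conversion options are assumed):

```python
from keras_retinanet import models

# load the training model, convert it to a prediction model (adds bbox decoding,
# filtering and NMS), then save the result
model = models.load_model('/path/to/training/model.h5', backbone_name='resnet50')
model = models.convert_model(model)
model.save('/path/to/save/inference/model.h5')
```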


## Training
`keras-retinanet` can be trained using [this](https://github.com/fizyr/keras-retinanet/blob/master/keras_retinanet/bin/train.py) script.
Note that the train script uses relative imports since it is inside the `keras_retinanet` package.
If you want to adjust the script for your own use outside of this repository,
you will need to switch it to use absolute imports.

If you installed `keras-retinanet` correctly, the train script will be installed as `retinanet-train`.
However, if you make local modifications to the `keras-retinanet` repository, you should run the script directly from the repository.
That ensures that your local changes will be used by the train script.

The default backbone is `resnet50`. You can change this using the `--backbone=xxx` argument when running the training script.
`xxx` can be one of the backbones in the resnet models (`resnet50`, `resnet101`, `resnet152`), mobilenet models (`mobilenet128_1.0`, `mobilenet128_0.75`, `mobilenet160_1.0`, etc.), densenet models or vgg models. The different options are defined by each model in their corresponding Python scripts (`resnet.py`, `mobilenet.py`, etc.).

Trained models can't be used directly for inference. To convert a trained model to an inference model, check [here](https://github.com/fizyr/keras-retinanet#converting-a-training-model-to-inference-model).

### Usage
For training on [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/), run:
```shell
# Running directly from the repository:
keras_retinanet/bin/train.py pascal /path/to/VOCdevkit/VOC2007

# Using the installed script:
retinanet-train pascal /path/to/VOCdevkit/VOC2007
```

For training on [MS COCO](http://cocodataset.org/#home), run:
```shell
# Running directly from the repository:
keras_retinanet/bin/train.py coco /path/to/MS/COCO

# Using the installed script:
retinanet-train coco /path/to/MS/COCO
```

For training on the Open Images Dataset [OID](https://storage.googleapis.com/openimages/web/index.html)
or taking part in the [OID challenges](https://storage.googleapis.com/openimages/web/challenge.html), run:
```shell
# Running directly from the repository:
keras_retinanet/bin/train.py oid /path/to/OID

# Using the installed script:
retinanet-train oid /path/to/OID

# You can also specify a list of labels if you want to train on a subset
# by adding the argument '--labels-filter':
keras_retinanet/bin/train.py oid /path/to/OID --labels-filter=Helmet,Tree

# You can also specify a parent label if you want to train on a branch
# from the semantic hierarchical tree (i.e. a parent and all its children), see
# https://storage.googleapis.com/openimages/challenge_2018/bbox_labels_500_hierarchy_visualizer/circle.html
# This is done by adding the argument '--parent-label':
keras_retinanet/bin/train.py oid /path/to/OID --parent-label=Boat
```


For training on [KITTI](http://www.cvlibs.net/datasets/kitti/eval_object.php), run:
```shell
# Running directly from the repository:
keras_retinanet/bin/train.py kitti /path/to/KITTI

# Using the installed script:
retinanet-train kitti /path/to/KITTI

# If you want to prepare the dataset, you can use the following script:
# https://github.com/NVIDIA/DIGITS/blob/master/examples/object-detection/prepare_kitti_data.py
```


For training on a custom dataset, a CSV file can be used to pass the data.
See below for more details on the format of these CSV files.
To train using your CSV, run:
```shell
# Running directly from the repository:
keras_retinanet/bin/train.py csv /path/to/csv/file/containing/annotations /path/to/csv/file/containing/classes

# Using the installed script:
retinanet-train csv /path/to/csv/file/containing/annotations /path/to/csv/file/containing/classes
```

In general, the steps to train on your own dataset are:
1) Create a model by calling, for instance, `keras_retinanet.models.backbone('resnet50').retinanet(num_classes=80)` and compile it.
   Empirically, the following compile arguments have been found to work well:
```python
model.compile(
    loss={
        'regression'    : keras_retinanet.losses.smooth_l1(),
        'classification': keras_retinanet.losses.focal()
    },
    optimizer=keras.optimizers.Adam(lr=1e-5, clipnorm=0.001)
)
```
2) Create generators for training and testing data (an example is shown in [`keras_retinanet.preprocessing.pascal_voc.PascalVocGenerator`](https://github.com/fizyr/keras-retinanet/blob/master/keras_retinanet/preprocessing/pascal_voc.py)).
3) Use `model.fit_generator` to start training; a rough end-to-end sketch is shown below.
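Putting those steps together (this assumes the `PascalVocGenerator` constructor takes the dataset directory and split name as its first arguments; paths, class count and hyperparameters are placeholders):

```python
from tensorflow import keras

import keras_retinanet.losses
from keras_retinanet import models
from keras_retinanet.preprocessing.pascal_voc import PascalVocGenerator

# build and compile a training model with a ResNet50 backbone (20 classes for Pascal VOC)
model = models.backbone('resnet50').retinanet(num_classes=20)
model.compile(
    loss={
        'regression'    : keras_retinanet.losses.smooth_l1(),
        'classification': keras_retinanet.losses.focal()
    },
    optimizer=keras.optimizers.Adam(lr=1e-5, clipnorm=0.001)
)

# create a generator for the training split and start training
train_generator = PascalVocGenerator('/path/to/VOCdevkit/VOC2007', 'trainval')
model.fit_generator(train_generator, steps_per_epoch=1000, epochs=50)
```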

## Pretrained models

All models can be downloaded from the [releases page](https://github.com/fizyr/keras-retinanet/releases).

### MS COCO

Results using the `cocoapi` are shown below (note: according to the paper, this configuration should achieve a mAP of 0.357).

```
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.350
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.537
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.374
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.191
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.383
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.472
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.306
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.491
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.533
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.345
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.577
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.681
```

### Open Images Dataset
There are 3 RetinaNet models based on ResNet50, ResNet101 and ResNet152, trained on all [500 classes](https://github.com/ZFTurbo/Keras-RetinaNet-for-Open-Images-Challenge-2018/blob/master/a00_utils_and_constants.py#L130) of the Open Images Dataset (thanks to @ZFTurbo).

| Backbone  | Image Size (px) | Small validation mAP | LB (Public) |
| --------- | --------------- | -------------------- | ----------- |
| ResNet50  | 768 - 1024      | 0.4594               | 0.4223      |
| ResNet101 | 768 - 1024      | 0.4986               | 0.4520      |
| ResNet152 | 600 - 800       | 0.4991               | 0.4651      |

For more information, check [@ZFTurbo's](https://github.com/ZFTurbo/Keras-RetinaNet-for-Open-Images-Challenge-2018) repository.

## CSV datasets
The `CSVGenerator` provides an easy way to define your own datasets.
It uses two CSV files: one file containing annotations and one file containing a class name to ID mapping.

### Annotations format
The CSV file with annotations should contain one annotation per line.
Images with multiple bounding boxes should use one row per bounding box.
Note that indexing for pixel values starts at 0.
The expected format of each line is:
```
path/to/image.jpg,x1,y1,x2,y2,class_name
```
By default the CSV generator will look for images relative to the directory of the annotations file.

Some images may not contain any labeled objects.
To add these images to the dataset as negative examples,
add an annotation where `x1`, `y1`, `x2`, `y2` and `class_name` are all empty:
```
path/to/image.jpg,,,,,
```

A full example:
```
/data/imgs/img_001.jpg,837,346,981,456,cow
/data/imgs/img_002.jpg,215,312,279,391,cat
/data/imgs/img_002.jpg,22,5,89,84,bird
/data/imgs/img_003.jpg,,,,,
```

This defines a dataset with 3 images.
`img_001.jpg` contains a cow.
`img_002.jpg` contains a cat and a bird.
`img_003.jpg` contains no interesting objects/animals.


### Class mapping format
The class name to ID mapping file should contain one mapping per line.
Each line should use the following format:
```
class_name,id
```

Indexing for classes starts at 0.
Do not include a background class as it is implicit.

For example:
```
cow,0
cat,1
bird,2
```
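Once both files exist, they can be used through the `CSVGenerator`; a small sketch (the constructor is assumed to take the annotations file and the class-mapping file as its first two arguments, and the file names are placeholders):

```python
from keras_retinanet.preprocessing.csv_generator import CSVGenerator

# images are resolved relative to the directory of the annotations file by default
train_generator = CSVGenerator('annotations.csv', 'classes.csv')

# train as in the steps above, e.g. model.fit_generator(train_generator, steps_per_epoch=1000, epochs=50)
```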

## Anchor optimization

In some cases, the default anchor configuration is not suitable for detecting objects in your dataset, for example if your objects are smaller than 32x32 px (the size of the smallest anchors). In that case it may be necessary to modify the anchor configuration; this can be done automatically by following the steps in the [anchor-optimization](https://github.com/martinzlocha/anchor-optimization/) repository. To use the generated configuration, check [here](https://github.com/fizyr/keras-retinanet-test-data/blob/master/config/config.ini) for an example config file and then pass it to `train.py` using the `--config` parameter.
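For a quick sanity check of what such a configuration implies, the number of anchors per feature-map location is simply the number of ratios times the number of scales (the values below assume the commonly used RetinaNet defaults, plus the `config/config1.ini` example added in this repository):

```python
# Assumed defaults: ratios 0.5/1/2 and scales 2**0, 2**(1/3), 2**(2/3) -> 9 anchors per location,
# with the smallest anchor being roughly 32 px (hence the 32x32 px remark above).
default_anchors_per_location = len([0.5, 1, 2]) * len([2 ** 0, 2 ** (1 / 3), 2 ** (2 / 3)])  # 9

# config/config1.ini uses ratios '0.5 1 2 3' and scales '1 1.2 1.6' -> 12 anchors per location.
custom_anchors_per_location = len([0.5, 1, 2, 3]) * len([1, 1.2, 1.6])  # 12
```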

## Debugging
Creating your own dataset does not always work out of the box. There is a [`debug.py`](https://github.com/fizyr/keras-retinanet/blob/master/keras_retinanet/bin/debug.py) tool to help find the most common mistakes.

Particularly helpful is the `--annotations` flag, which displays your annotations on the images from your dataset. Annotations are colored green when there are anchors available and red when there are no anchors available. If an annotation has no anchors available, it will not contribute to training. It is normal for a small number of annotations to show up in red, but if most or all annotations are red there is cause for concern. The most common issues are annotations that are too small or too oddly shaped (stretched out).

## Results

### MS COCO

## Status
Example output images using `keras-retinanet` are shown below.

<p align="center">
  <img src="https://github.com/delftrobotics/keras-retinanet/blob/master/images/coco1.png" alt="Example result of RetinaNet on MS COCO"/>
  <img src="https://github.com/delftrobotics/keras-retinanet/blob/master/images/coco2.png" alt="Example result of RetinaNet on MS COCO"/>
  <img src="https://github.com/delftrobotics/keras-retinanet/blob/master/images/coco3.png" alt="Example result of RetinaNet on MS COCO"/>
</p>

### Projects using keras-retinanet
* [Improving Apple Detection and Counting Using RetinaNet](https://github.com/nikostsagk/Apple-detection). This work investigates the apple detection problem using Keras RetinaNet.
* [Improving RetinaNet for CT Lesion Detection with Dense Masks from Weak RECIST Labels](https://arxiv.org/abs/1906.02283). Research project for detecting lesions in CT using keras-retinanet.
* [NudeNet](https://github.com/bedapudi6788/NudeNet). Project that focuses on detecting and censoring nudity.
* [Individual tree-crown detection in RGB imagery using self-supervised deep learning neural networks](https://www.biorxiv.org/content/10.1101/532952v1). Research project focused on improving the performance of remotely sensed tree surveys.
* [ESRI Object Detection Challenge 2019](https://github.com/kunwar31/ESRI_Object_Detection). Winning implementation of the ESRI Object Detection Challenge 2019.
* [Lunar Rockfall Detector Project](https://ieeexplore.ieee.org/document/8587120). The aim of this project is to [map lunar rockfalls on a global scale](https://www.nature.com/articles/s41467-020-16653-3) using the more than 2 million available satellite images.
* [Mars Rockfall Detector Project](https://ieeexplore.ieee.org/document/9103997). The aim of this project is to map rockfalls on Mars.
* [NATO Innovation Challenge](https://medium.com/data-from-the-trenches/object-detection-with-deep-learning-on-aerial-imagery-2465078db8a9). The winning team of the NATO Innovation Challenge used keras-retinanet to detect cars in aerial images ([COWC dataset](https://gdo152.llnl.gov/cowc/)).
* [Microsoft Research for Horovod on Azure](https://blogs.technet.microsoft.com/machinelearning/2018/06/20/how-to-do-distributed-deep-learning-for-object-detection-using-horovod-on-azure/). A research project by Microsoft, using keras-retinanet to distribute training over multiple GPUs using Horovod on Azure.
* [Anno-Mage](https://virajmavani.github.io/saiat/). A tool that helps you annotate images, using input from the keras-retinanet COCO model as suggestions.
* [Telenav.AI](https://github.com/Telenav/Telenav.AI/tree/master/retinanet). For the detection of traffic signs using keras-retinanet.
* [Towards Deep Placental Histology Phenotyping](https://github.com/Nellaker-group/TowardsDeepPhenotyping). This research project uses keras-retinanet for analysing the placenta at a cellular level.
* [4k video example](https://www.youtube.com/watch?v=KYueHEMGRos). This demo shows the use of keras-retinanet on a 4k input video.
* [boring-detector](https://github.com/lexfridman/boring-detector). I suppose not all projects need to solve life's biggest questions. This project detects "The Boring Company" hats in videos.
* [comet.ml](https://towardsdatascience.com/how-i-monitor-and-track-my-machine-learning-experiments-from-anywhere-described-in-13-tweets-ec3d0870af99). Using keras-retinanet in combination with [comet.ml](https://comet.ml) to interactively inspect and compare experiments.
* [Weights and Biases](https://app.wandb.ai/syllogismos/keras-retinanet/reports?view=carey%2FObject%20Detection%20with%20RetinaNet). Trained keras-retinanet on the COCO dataset from scratch with ResNet50 and ResNet101 backbones.
* [Google Open Images Challenge 2018 15th place solution](https://github.com/ZFTurbo/Keras-RetinaNet-for-Open-Images-Challenge-2018). Pretrained weights for keras-retinanet based on ResNet50, ResNet101 and ResNet152, trained on the Open Images Dataset.
* [poke.AI](https://github.com/Raghav-B/poke.AI). An experimental AI that attempts to master the 3rd Generation Pokemon games. Uses keras-retinanet for in-game mapping and localization.
* [retinanetjs](https://github.com/faustomorales/retinanetjs). A wrapper to run RetinaNet inference in the browser / Node.js. You can also take a look at the [example app](https://faustomorales.github.io/retinanetjs-example-app/).
* [CRFNet](https://github.com/TUMFTM/CameraRadarFusionNet). This network fuses radar and camera data to perform object detection for autonomous driving applications.
* [LogoDet](https://github.com/notAI-tech/LogoDet). Project for detecting company logos in images.


If you have a project based on `keras-retinanet` and would like to have it published here, shoot me a message on Slack.

### Notes
* This repository requires TensorFlow 2.3.0 or higher.
* This repository is [tested](https://github.com/fizyr/keras-retinanet/blob/master/.travis.yml) using OpenCV 3.4.
* This repository is [tested](https://github.com/fizyr/keras-retinanet/blob/master/.travis.yml) using Python 3.6 and 3.7.

Contributions to this project are welcome.

### Discussions
Feel free to join the `#keras-retinanet` [Keras Slack](https://keras-slack-autojoin.herokuapp.com/) channel for discussions and questions.

## FAQ
* **I get the warning `UserWarning: No training configuration found in save file: the model was not compiled. Compile it manually.`, should I be worried?** This warning can safely be ignored during inference.
* **I get the error `ValueError: not enough values to unpack (expected 3, got 2)` during inference, what should I do?** This happens because you are using a training model for inference. See https://github.com/fizyr/keras-retinanet#converting-a-training-model-to-inference-model for more information.
* **How do I do transfer learning?** The easiest solution is to use the `--weights` argument when training. Keras will load models even if the number of classes doesn't match (it simply skips loading weights where there is a mismatch). Run for example `retinanet-train --weights snapshots/some_coco_model.h5 pascal /path/to/pascal` to transfer weights from a COCO model to a PascalVOC training session. If your dataset is small, you can also use the `--freeze-backbone` argument to freeze the backbone layers.
* **How do I change the number / shape of the anchors?** The train tool allows you to pass a configuration file in which the anchor parameters can be adjusted. Check [here](https://github.com/fizyr/keras-retinanet-test-data/blob/master/config/config.ini) for an example config file.
* **I get a loss of `0`, what is going on?** This mostly happens when none of the anchors "fit" your objects, most likely because the objects are too small or too elongated. You can verify this using the [debug](https://github.com/fizyr/keras-retinanet#debugging) tool.
* **I have an older model, can I use it after an update of keras-retinanet?** This depends on what has changed. If it is a change that doesn't affect the weights, then you can "update" the model by creating a new retinanet model, loading your old weights using `model.load_weights(weights_path, by_name=True)` and saving this model (see the sketch below). If the change has been too significant, you should retrain your model (you can try to load the weights from your old model when starting training; this might be a better starting position than ImageNet).
* **I get the error `ModuleNotFoundError: No module named 'keras_retinanet.utils.compute_overlap'`, how do I fix this?** Most likely you are running the code from the cloned repository. This is fine, but you need to compile some extensions for this to work (`python setup.py build_ext --inplace`).
* **How do I train on my own dataset?** The steps to train on your own dataset are roughly as follows:
* 1. Prepare your dataset in the CSV format (a training and validation split is advised).
* 2. Check that your dataset is correct using `retinanet-debug`.
* 3. Train retinanet, preferably using the pretrained COCO weights (this gives a **far** better starting point, making training much quicker and more accurate). You can optionally perform evaluation on your validation set during training to keep track of how well the model performs (advised).
* 4. Convert your training model to an inference model.
* 5. Evaluate your inference model on your test or validation set.
* 6. Profit!
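For the "older model" question above, a minimal sketch of the update procedure (the backbone, class count and file names are placeholders and should match your original training setup):

```python
from keras_retinanet import models

# build a fresh training model with the current code base ...
model = models.backbone('resnet50').retinanet(num_classes=80)

# ... port the old weights over by layer name, then save the refreshed model
model.load_weights('/path/to/old/model.h5', by_name=True)
model.save('/path/to/updated/model.h5')
```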
retinaNet/config/config1.ini
0 → 100644
[anchor_parameters]
# Sizes should correlate to how the network processes an image, it is not advised to change these!
sizes = 64 128 256
# Strides should correlate to how the network strides over an image, it is not advised to change these!
strides = 16 32 64
# The different ratios to use per anchor location.
ratios = 0.5 1 2 3
# The different scaling factors to use per anchor location.
scales = 1 1.2 1.6

[pyramid_levels]

levels = 3 4 5
retinaNet/desktop.ini
0 → 100644
retinaNet/examples/000000008021.jpg
0 → 100644
176 KB
retinaNet/examples/ResNet50RetinaNet.ipynb
0 → 100644
This diff could not be displayed because it is too large.
retinaNet/examples/resnet50_retinanet.py
0 → 100644
1 | +#!/usr/bin/env python | ||
2 | +# coding: utf-8 | ||
3 | + | ||
4 | +# Load necessary modules | ||
5 | + | ||
6 | +import sys | ||
7 | + | ||
8 | +sys.path.insert(0, "../") | ||
9 | + | ||
10 | + | ||
11 | +# import keras_retinanet | ||
12 | +from keras_retinanet import models | ||
13 | +from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image | ||
14 | +from keras_retinanet.utils.visualization import draw_box, draw_caption | ||
15 | +from keras_retinanet.utils.colors import label_color | ||
16 | +from keras_retinanet.utils.gpu import setup_gpu | ||
17 | + | ||
18 | +# import miscellaneous modules | ||
19 | +import matplotlib.pyplot as plt | ||
20 | +import cv2 | ||
21 | +import os | ||
22 | +import numpy as np | ||
23 | +import time | ||
24 | + | ||
25 | +# set tf backend to allow memory to grow, instead of claiming everything | ||
26 | +import tensorflow as tf | ||
27 | + | ||
28 | +# use this to change which GPU to use | ||
29 | +gpu = 0 | ||
30 | + | ||
31 | +# set the modified tf session as backend in keras | ||
32 | +setup_gpu(gpu) | ||
33 | + | ||
34 | + | ||
35 | +# ## Load RetinaNet model | ||
36 | + | ||
37 | +# In[ ]: | ||
38 | + | ||
39 | + | ||
40 | +# adjust this to point to your downloaded/trained model | ||
41 | +# models can be downloaded here: https://github.com/fizyr/keras-retinanet/releases | ||
42 | +model_path = os.path.join("..", "snapshots", "resnet50_coco_best_v2.1.0.h5") | ||
43 | + | ||
44 | +# load retinanet model | ||
45 | +model = models.load_model(model_path, backbone_name="resnet50") | ||
46 | + | ||
47 | +# if the model is not converted to an inference model, use the line below | ||
48 | +# see: https://github.com/fizyr/keras-retinanet#converting-a-training-model-to-inference-model | ||
49 | +# model = models.convert_model(model) | ||
50 | + | ||
51 | +# print(model.summary()) | ||
52 | + | ||
53 | +# load label to names mapping for visualization purposes | ||
54 | +labels_to_names = { | ||
55 | + 0: "person", | ||
56 | + 1: "bicycle", | ||
57 | + 2: "car", | ||
58 | + 3: "motorcycle", | ||
59 | + 4: "airplane", | ||
60 | + 5: "bus", | ||
61 | + 6: "train", | ||
62 | + 7: "truck", | ||
63 | + 8: "boat", | ||
64 | + 9: "traffic light", | ||
65 | + 10: "fire hydrant", | ||
66 | + 11: "stop sign", | ||
67 | + 12: "parking meter", | ||
68 | + 13: "bench", | ||
69 | + 14: "bird", | ||
70 | + 15: "cat", | ||
71 | + 16: "dog", | ||
72 | + 17: "horse", | ||
73 | + 18: "sheep", | ||
74 | + 19: "cow", | ||
75 | + 20: "elephant", | ||
76 | + 21: "bear", | ||
77 | + 22: "zebra", | ||
78 | + 23: "giraffe", | ||
79 | + 24: "backpack", | ||
80 | + 25: "umbrella", | ||
81 | + 26: "handbag", | ||
82 | + 27: "tie", | ||
83 | + 28: "suitcase", | ||
84 | + 29: "frisbee", | ||
85 | + 30: "skis", | ||
86 | + 31: "snowboard", | ||
87 | + 32: "sports ball", | ||
88 | + 33: "kite", | ||
89 | + 34: "baseball bat", | ||
90 | + 35: "baseball glove", | ||
91 | + 36: "skateboard", | ||
92 | + 37: "surfboard", | ||
93 | + 38: "tennis racket", | ||
94 | + 39: "bottle", | ||
95 | + 40: "wine glass", | ||
96 | + 41: "cup", | ||
97 | + 42: "fork", | ||
98 | + 43: "knife", | ||
99 | + 44: "spoon", | ||
100 | + 45: "bowl", | ||
101 | + 46: "banana", | ||
102 | + 47: "apple", | ||
103 | + 48: "sandwich", | ||
104 | + 49: "orange", | ||
105 | + 50: "broccoli", | ||
106 | + 51: "carrot", | ||
107 | + 52: "hot dog", | ||
108 | + 53: "pizza", | ||
109 | + 54: "donut", | ||
110 | + 55: "cake", | ||
111 | + 56: "chair", | ||
112 | + 57: "couch", | ||
113 | + 58: "potted plant", | ||
114 | + 59: "bed", | ||
115 | + 60: "dining table", | ||
116 | + 61: "toilet", | ||
117 | + 62: "tv", | ||
118 | + 63: "laptop", | ||
119 | + 64: "mouse", | ||
120 | + 65: "remote", | ||
121 | + 66: "keyboard", | ||
122 | + 67: "cell phone", | ||
123 | + 68: "microwave", | ||
124 | + 69: "oven", | ||
125 | + 70: "toaster", | ||
126 | + 71: "sink", | ||
127 | + 72: "refrigerator", | ||
128 | + 73: "book", | ||
129 | + 74: "clock", | ||
130 | + 75: "vase", | ||
131 | + 76: "scissors", | ||
132 | + 77: "teddy bear", | ||
133 | + 78: "hair drier", | ||
134 | + 79: "toothbrush", | ||
135 | +} | ||
136 | + | ||
137 | + | ||
138 | +# ## Run detection on example | ||
139 | + | ||
140 | +# In[ ]: | ||
141 | + | ||
142 | + | ||
143 | +# load image | ||
144 | +image = read_image_bgr("000000008021.jpg") | ||
145 | + | ||
146 | +# copy to draw on | ||
147 | +draw = image.copy() | ||
148 | +draw = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB) | ||
149 | + | ||
150 | +# preprocess image for network | ||
151 | +image = preprocess_image(image) | ||
152 | +image, scale = resize_image(image) | ||
153 | + | ||
154 | +# process image | ||
155 | +start = time.time() | ||
156 | +boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0)) | ||
157 | +print("processing time: ", time.time() - start) | ||
158 | + | ||
159 | +# correct for image scale | ||
160 | +boxes /= scale | ||
161 | + | ||
162 | +# visualize detections | ||
163 | +for box, score, label in zip(boxes[0], scores[0], labels[0]): | ||
164 | + # scores are sorted so we can break | ||
165 | + if score < 0.5: | ||
166 | + break | ||
167 | + | ||
168 | + color = label_color(label) | ||
169 | + | ||
170 | + b = box.astype(int) | ||
171 | + draw_box(draw, b, color=color) | ||
172 | + | ||
173 | + caption = "{} {:.3f}".format(labels_to_names[label], score) | ||
174 | + draw_caption(draw, b, caption) | ||
175 | + | ||
176 | +plt.figure(figsize=(15, 15)) | ||
177 | +plt.axis("off") | ||
178 | +plt.imshow(draw) | ||
179 | +plt.show() | ||
180 | + | ||
181 | + | ||
182 | +# In[ ]: | ||
183 | + | ||
184 | + | ||
185 | +# In[ ]: |
retinaNet/images/coco1.png
0 → 100644
269 KB
retinaNet/images/coco2.png
0 → 100644
491 KB
retinaNet/images/coco3.png
0 → 100644
468 KB
retinaNet/keras_retinanet/__init__.py
0 → 100644
File mode changed
retinaNet/keras_retinanet/backend/__init__.py
0 → 100644
1 | +from .backend import * # noqa: F401,F403 |
retinaNet/keras_retinanet/backend/backend.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import tensorflow | ||
18 | +from tensorflow import keras | ||
19 | + | ||
20 | + | ||
21 | +def bbox_transform_inv(boxes, deltas, mean=None, std=None): | ||
22 | + """ Applies deltas (usually regression results) to boxes (usually anchors). | ||
23 | + | ||
24 | + Before applying the deltas to the boxes, the normalization that was previously applied (in the generator) has to be removed. | ||
25 | + The mean and std are the mean and std as applied in the generator. They are unnormalized in this function and then applied to the boxes. | ||
26 | + | ||
27 | + Args | ||
28 | + boxes : np.array of shape (B, N, 4), where B is the batch size, N the number of boxes and 4 values for (x1, y1, x2, y2). | ||
29 | + deltas: np.array of same shape as boxes. These deltas (d_x1, d_y1, d_x2, d_y2) are a factor of the width/height. | ||
30 | + mean : The mean value used when computing deltas (defaults to [0, 0, 0, 0]). | ||
31 | + std : The standard deviation used when computing deltas (defaults to [0.2, 0.2, 0.2, 0.2]). | ||
32 | + | ||
33 | + Returns | ||
34 | + A np.array of the same shape as boxes, but with deltas applied to each box. | ||
35 | + The mean and std are used during training to normalize the regression values (networks love normalization). | ||
36 | + """ | ||
37 | + if mean is None: | ||
38 | + mean = [0, 0, 0, 0] | ||
39 | + if std is None: | ||
40 | + std = [0.2, 0.2, 0.2, 0.2] | ||
41 | + | ||
42 | + width = boxes[:, :, 2] - boxes[:, :, 0] | ||
43 | + height = boxes[:, :, 3] - boxes[:, :, 1] | ||
44 | + | ||
45 | + x1 = boxes[:, :, 0] + (deltas[:, :, 0] * std[0] + mean[0]) * width | ||
46 | + y1 = boxes[:, :, 1] + (deltas[:, :, 1] * std[1] + mean[1]) * height | ||
47 | + x2 = boxes[:, :, 2] + (deltas[:, :, 2] * std[2] + mean[2]) * width | ||
48 | + y2 = boxes[:, :, 3] + (deltas[:, :, 3] * std[3] + mean[3]) * height | ||
49 | + | ||
50 | + pred_boxes = keras.backend.stack([x1, y1, x2, y2], axis=2) | ||
51 | + | ||
52 | + return pred_boxes | ||
53 | + | ||
54 | + | ||
55 | +def shift(shape, stride, anchors): | ||
56 | + """ Produce shifted anchors based on shape of the map and stride size. | ||
57 | + | ||
58 | + Args | ||
59 | + shape : Shape to shift the anchors over. | ||
60 | + stride : Stride to shift the anchors with over the shape. | ||
61 | + anchors: The anchors to apply at each location. | ||
62 | + """ | ||
63 | + shift_x = (keras.backend.arange(0, shape[1], dtype=keras.backend.floatx()) + keras.backend.constant(0.5, dtype=keras.backend.floatx())) * stride | ||
64 | + shift_y = (keras.backend.arange(0, shape[0], dtype=keras.backend.floatx()) + keras.backend.constant(0.5, dtype=keras.backend.floatx())) * stride | ||
65 | + | ||
66 | + shift_x, shift_y = tensorflow.meshgrid(shift_x, shift_y) | ||
67 | + shift_x = keras.backend.reshape(shift_x, [-1]) | ||
68 | + shift_y = keras.backend.reshape(shift_y, [-1]) | ||
69 | + | ||
70 | + shifts = keras.backend.stack([ | ||
71 | + shift_x, | ||
72 | + shift_y, | ||
73 | + shift_x, | ||
74 | + shift_y | ||
75 | + ], axis=0) | ||
76 | + | ||
77 | + shifts = keras.backend.transpose(shifts) | ||
78 | + number_of_anchors = keras.backend.shape(anchors)[0] | ||
79 | + | ||
80 | + k = keras.backend.shape(shifts)[0] # number of base points = feat_h * feat_w | ||
81 | + | ||
82 | + shifted_anchors = keras.backend.reshape(anchors, [1, number_of_anchors, 4]) + keras.backend.cast(keras.backend.reshape(shifts, [k, 1, 4]), keras.backend.floatx()) | ||
83 | + shifted_anchors = keras.backend.reshape(shifted_anchors, [k * number_of_anchors, 4]) | ||
84 | + | ||
85 | + return shifted_anchors | ||
86 | + | ||
87 | + | ||
88 | +def map_fn(*args, **kwargs): | ||
89 | + """ See https://www.tensorflow.org/api_docs/python/tf/map_fn . | ||
90 | + """ | ||
91 | + | ||
92 | + if "shapes" in kwargs: | ||
93 | + shapes = kwargs.pop("shapes") | ||
94 | + dtype = kwargs.pop("dtype") | ||
95 | + sig = [tensorflow.TensorSpec(shapes[i], dtype=t) for i, t in | ||
96 | + enumerate(dtype)] | ||
97 | + | ||
98 | + # Try to use the new feature fn_output_signature in TF 2.3, use fallback if this is not available | ||
99 | + try: | ||
100 | + return tensorflow.map_fn(*args, **kwargs, fn_output_signature=sig) | ||
101 | + except TypeError: | ||
102 | + kwargs["dtype"] = dtype | ||
103 | + | ||
104 | + return tensorflow.map_fn(*args, **kwargs) | ||
105 | + | ||
106 | + | ||
107 | +def resize_images(images, size, method='bilinear', align_corners=False): | ||
108 | + """ See https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/image/resize_images . | ||
109 | + | ||
110 | + Args | ||
111 | + method: The method used for interpolation. One of ('bilinear', 'nearest', 'bicubic', 'area'). | ||
112 | + """ | ||
113 | + methods = { | ||
114 | + 'bilinear': tensorflow.image.ResizeMethod.BILINEAR, | ||
115 | + 'nearest' : tensorflow.image.ResizeMethod.NEAREST_NEIGHBOR, | ||
116 | + 'bicubic' : tensorflow.image.ResizeMethod.BICUBIC, | ||
117 | + 'area' : tensorflow.image.ResizeMethod.AREA, | ||
118 | + } | ||
119 | + return tensorflow.compat.v1.image.resize_images(images, size, methods[method], align_corners) |
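A minimal NumPy sketch of the delta arithmetic implemented by `bbox_transform_inv` above, using the default mean/std (the box and delta values are made up, just to make the computation concrete):

```python
import numpy as np

# one anchor box (x1, y1, x2, y2) and one regression delta, batch size 1
boxes  = np.array([[[10.0, 10.0, 50.0, 90.0]]])   # width 40, height 80
deltas = np.array([[[0.5, 0.0, 0.5, 0.0]]])
mean, std = [0, 0, 0, 0], [0.2, 0.2, 0.2, 0.2]

width  = boxes[..., 2] - boxes[..., 0]
height = boxes[..., 3] - boxes[..., 1]
x1 = boxes[..., 0] + (deltas[..., 0] * std[0] + mean[0]) * width    # 10 + 0.1 * 40 = 14
y1 = boxes[..., 1] + (deltas[..., 1] * std[1] + mean[1]) * height   # 10
x2 = boxes[..., 2] + (deltas[..., 2] * std[2] + mean[2]) * width    # 50 + 0.1 * 40 = 54
y2 = boxes[..., 3] + (deltas[..., 3] * std[3] + mean[3]) * height   # 90
pred_boxes = np.stack([x1, y1, x2, y2], axis=2)                     # [[[14., 10., 54., 90.]]]
```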
retinaNet/keras_retinanet/bin/__init__.py
0 → 100644
File mode changed
retinaNet/keras_retinanet/bin/convert_model.py
0 → 100644
1 | +#!/usr/bin/env python | ||
2 | + | ||
3 | +""" | ||
4 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
5 | + | ||
6 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
7 | +you may not use this file except in compliance with the License. | ||
8 | +You may obtain a copy of the License at | ||
9 | + | ||
10 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
11 | + | ||
12 | +Unless required by applicable law or agreed to in writing, software | ||
13 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
14 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
15 | +See the License for the specific language governing permissions and | ||
16 | +limitations under the License. | ||
17 | +""" | ||
18 | + | ||
19 | +import argparse | ||
20 | +import os | ||
21 | +import sys | ||
22 | + | ||
23 | +# Allow relative imports when being executed as script. | ||
24 | +if __name__ == "__main__" and __package__ is None: | ||
25 | + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) | ||
26 | + import keras_retinanet.bin # noqa: F401 | ||
27 | + __package__ = "keras_retinanet.bin" | ||
28 | + | ||
29 | +# Change these to absolute imports if you copy this script outside the keras_retinanet package. | ||
30 | +from .. import models | ||
31 | +from ..utils.config import read_config_file, parse_anchor_parameters, parse_pyramid_levels | ||
32 | +from ..utils.gpu import setup_gpu | ||
33 | +from ..utils.tf_version import check_tf_version | ||
34 | + | ||
35 | + | ||
36 | +def parse_args(args): | ||
37 | + parser = argparse.ArgumentParser(description='Script for converting a training model to an inference model.') | ||
38 | + | ||
39 | + parser.add_argument('model_in', help='The model to convert.') | ||
40 | + parser.add_argument('model_out', help='Path to save the converted model to.') | ||
41 | + parser.add_argument('--backbone', help='The backbone of the model to convert.', default='resnet50') | ||
42 | + parser.add_argument('--no-nms', help='Disables non maximum suppression.', dest='nms', action='store_false') | ||
43 | + parser.add_argument('--no-class-specific-filter', help='Disables class specific filtering.', dest='class_specific_filter', action='store_false') | ||
44 | + parser.add_argument('--config', help='Path to a configuration parameters .ini file.') | ||
45 | + parser.add_argument('--nms-threshold', help='Value for non maximum suppression threshold.', type=float, default=0.5) | ||
46 | + parser.add_argument('--score-threshold', help='Threshold for prefiltering boxes.', type=float, default=0.05) | ||
47 | + parser.add_argument('--max-detections', help='Maximum number of detections to keep.', type=int, default=300) | ||
48 | + parser.add_argument('--parallel-iterations', help='Number of batch items to process in parallel.', type=int, default=32) | ||
49 | + | ||
50 | + return parser.parse_args(args) | ||
51 | + | ||
52 | + | ||
53 | +def main(args=None): | ||
54 | + # parse arguments | ||
55 | + if args is None: | ||
56 | + args = sys.argv[1:] | ||
57 | + args = parse_args(args) | ||
58 | + | ||
59 | + # make sure tensorflow is the minimum required version | ||
60 | + check_tf_version() | ||
61 | + | ||
62 | + # set modified tf session to avoid using the GPUs | ||
63 | + setup_gpu('cpu') | ||
64 | + | ||
65 | + # optionally load config parameters | ||
66 | + anchor_parameters = None | ||
67 | + pyramid_levels = None | ||
68 | + if args.config: | ||
69 | + args.config = read_config_file(args.config) | ||
70 | + if 'anchor_parameters' in args.config: | ||
71 | + anchor_parameters = parse_anchor_parameters(args.config) | ||
72 | + | ||
73 | + if 'pyramid_levels' in args.config: | ||
74 | + pyramid_levels = parse_pyramid_levels(args.config) | ||
75 | + | ||
76 | + # load the model | ||
77 | + model = models.load_model(args.model_in, backbone_name=args.backbone) | ||
78 | + | ||
79 | + # check if this is indeed a training model | ||
80 | + models.check_training_model(model) | ||
81 | + | ||
82 | + # convert the model | ||
83 | + model = models.convert_model( | ||
84 | + model, | ||
85 | + nms=args.nms, | ||
86 | + class_specific_filter=args.class_specific_filter, | ||
87 | + anchor_params=anchor_parameters, | ||
88 | + pyramid_levels=pyramid_levels, | ||
89 | + nms_threshold=args.nms_threshold, | ||
90 | + score_threshold=args.score_threshold, | ||
91 | + max_detections=args.max_detections, | ||
92 | + parallel_iterations=args.parallel_iterations | ||
93 | + ) | ||
94 | + | ||
95 | + # save model | ||
96 | + model.save(args.model_out) | ||
97 | + | ||
98 | + | ||
99 | +if __name__ == '__main__': | ||
100 | + main() |
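Since `main()` above accepts an argv-style list, a programmatic call could look like the sketch below; the snapshot and output paths are placeholders and the options shown simply restate the defaults.

```python
# Hypothetical invocation of the conversion entry point defined above.
main([
    "snapshots/resnet50_csv_50.h5",    # model_in: a training snapshot (placeholder path)
    "converted/resnet50_csv_50.h5",    # model_out: where the inference model is saved (placeholder path)
    "--backbone", "resnet50",
    "--nms-threshold", "0.5",
    "--score-threshold", "0.05",
])
```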
retinaNet/keras_retinanet/bin/debug.py
0 → 100644
1 | +#!/usr/bin/env python | ||
2 | + | ||
3 | +""" | ||
4 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
5 | + | ||
6 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
7 | +you may not use this file except in compliance with the License. | ||
8 | +You may obtain a copy of the License at | ||
9 | + | ||
10 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
11 | + | ||
12 | +Unless required by applicable law or agreed to in writing, software | ||
13 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
14 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
15 | +See the License for the specific language governing permissions and | ||
16 | +limitations under the License. | ||
17 | +""" | ||
18 | + | ||
19 | +import argparse | ||
20 | +import os | ||
21 | +import sys | ||
22 | +import cv2 | ||
23 | + | ||
24 | +# Set keycodes for changing images | ||
25 | +# 81, 83 are left and right arrows on Linux as ASCII codes (probably not needed) | ||
26 | +# 65361, 65363 are left and right arrows on Linux | ||
27 | +# 2424832, 2555904 are left and right arrows on Windows | ||
28 | +# 110, 109 are 'n' and 'm' on macOS, Windows and Linux | ||
29 | +# (unfortunately arrow keys are not picked up on macOS) | ||
30 | +leftkeys = (81, 110, 65361, 2424832) | ||
31 | +rightkeys = (83, 109, 65363, 2555904) | ||
32 | + | ||
33 | +# Allow relative imports when being executed as script. | ||
34 | +if __name__ == "__main__" and __package__ is None: | ||
35 | + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) | ||
36 | + import keras_retinanet.bin # noqa: F401 | ||
37 | + __package__ = "keras_retinanet.bin" | ||
38 | + | ||
39 | +# Change these to absolute imports if you copy this script outside the keras_retinanet package. | ||
40 | +from ..preprocessing.pascal_voc import PascalVocGenerator | ||
41 | +from ..preprocessing.csv_generator import CSVGenerator | ||
42 | +from ..preprocessing.kitti import KittiGenerator | ||
43 | +from ..preprocessing.open_images import OpenImagesGenerator | ||
44 | +from ..utils.anchors import anchors_for_shape, compute_gt_annotations | ||
45 | +from ..utils.config import read_config_file, parse_anchor_parameters, parse_pyramid_levels | ||
46 | +from ..utils.image import random_visual_effect_generator | ||
47 | +from ..utils.tf_version import check_tf_version | ||
48 | +from ..utils.transform import random_transform_generator | ||
49 | +from ..utils.visualization import draw_annotations, draw_boxes, draw_caption | ||
50 | + | ||
51 | + | ||
52 | +def create_generator(args): | ||
53 | + """ Create the data generators. | ||
54 | + | ||
55 | + Args: | ||
56 | + args: parseargs arguments object. | ||
57 | + """ | ||
58 | + common_args = { | ||
59 | + 'config' : args.config, | ||
60 | + 'image_min_side' : args.image_min_side, | ||
61 | + 'image_max_side' : args.image_max_side, | ||
62 | + 'group_method' : args.group_method | ||
63 | + } | ||
64 | + | ||
65 | + # create random transform generator for augmenting training data | ||
66 | + transform_generator = random_transform_generator( | ||
67 | + min_rotation=-0.1, | ||
68 | + max_rotation=0.1, | ||
69 | + min_translation=(-0.1, -0.1), | ||
70 | + max_translation=(0.1, 0.1), | ||
71 | + min_shear=-0.1, | ||
72 | + max_shear=0.1, | ||
73 | + min_scaling=(0.9, 0.9), | ||
74 | + max_scaling=(1.1, 1.1), | ||
75 | + flip_x_chance=0.5, | ||
76 | + flip_y_chance=0.5, | ||
77 | + ) | ||
78 | + | ||
79 | + visual_effect_generator = random_visual_effect_generator( | ||
80 | + contrast_range=(0.9, 1.1), | ||
81 | + brightness_range=(-.1, .1), | ||
82 | + hue_range=(-0.05, 0.05), | ||
83 | + saturation_range=(0.95, 1.05) | ||
84 | + ) | ||
85 | + | ||
86 | + if args.dataset_type == 'coco': | ||
87 | + # import here to prevent unnecessary dependency on cocoapi | ||
88 | + from ..preprocessing.coco import CocoGenerator | ||
89 | + | ||
90 | + generator = CocoGenerator( | ||
91 | + args.coco_path, | ||
92 | + args.coco_set, | ||
93 | + transform_generator=transform_generator, | ||
94 | + visual_effect_generator=visual_effect_generator, | ||
95 | + **common_args | ||
96 | + ) | ||
97 | + elif args.dataset_type == 'pascal': | ||
98 | + generator = PascalVocGenerator( | ||
99 | + args.pascal_path, | ||
100 | + args.pascal_set, | ||
101 | + image_extension=args.image_extension, | ||
102 | + transform_generator=transform_generator, | ||
103 | + visual_effect_generator=visual_effect_generator, | ||
104 | + **common_args | ||
105 | + ) | ||
106 | + elif args.dataset_type == 'csv': | ||
107 | + generator = CSVGenerator( | ||
108 | + args.annotations, | ||
109 | + args.classes, | ||
110 | + transform_generator=transform_generator, | ||
111 | + visual_effect_generator=visual_effect_generator, | ||
112 | + **common_args | ||
113 | + ) | ||
114 | + elif args.dataset_type == 'oid': | ||
115 | + generator = OpenImagesGenerator( | ||
116 | + args.main_dir, | ||
117 | + subset=args.subset, | ||
118 | + version=args.version, | ||
119 | + labels_filter=args.labels_filter, | ||
120 | + parent_label=args.parent_label, | ||
121 | + annotation_cache_dir=args.annotation_cache_dir, | ||
122 | + transform_generator=transform_generator, | ||
123 | + visual_effect_generator=visual_effect_generator, | ||
124 | + **common_args | ||
125 | + ) | ||
126 | + elif args.dataset_type == 'kitti': | ||
127 | + generator = KittiGenerator( | ||
128 | + args.kitti_path, | ||
129 | + subset=args.subset, | ||
130 | + transform_generator=transform_generator, | ||
131 | + visual_effect_generator=visual_effect_generator, | ||
132 | + **common_args | ||
133 | + ) | ||
134 | + else: | ||
135 | + raise ValueError('Invalid data type received: {}'.format(args.dataset_type)) | ||
136 | + | ||
137 | + return generator | ||
138 | + | ||
139 | + | ||
140 | +def parse_args(args): | ||
141 | + """ Parse the arguments. | ||
142 | + """ | ||
143 | + parser = argparse.ArgumentParser(description='Debug script for a RetinaNet network.') | ||
144 | + subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type') | ||
145 | + subparsers.required = True | ||
146 | + | ||
147 | + coco_parser = subparsers.add_parser('coco') | ||
148 | + coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).') | ||
149 | + coco_parser.add_argument('--coco-set', help='Name of the set to show (defaults to val2017).', default='val2017') | ||
150 | + | ||
151 | + pascal_parser = subparsers.add_parser('pascal') | ||
152 | + pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).') | ||
153 | + pascal_parser.add_argument('--pascal-set', help='Name of the set to show (defaults to test).', default='test') | ||
154 | + pascal_parser.add_argument('--image-extension', help='Declares the dataset images\' extension.', default='.jpg') | ||
155 | + | ||
156 | + kitti_parser = subparsers.add_parser('kitti') | ||
157 | + kitti_parser.add_argument('kitti_path', help='Path to dataset directory (ie. /tmp/kitti).') | ||
158 | + kitti_parser.add_argument('subset', help='Argument for loading a subset from train/val.') | ||
159 | + | ||
160 | + def csv_list(string): | ||
161 | + return string.split(',') | ||
162 | + | ||
163 | + oid_parser = subparsers.add_parser('oid') | ||
164 | + oid_parser.add_argument('main_dir', help='Path to dataset directory.') | ||
165 | + oid_parser.add_argument('subset', help='Argument for loading a subset from train/validation/test.') | ||
166 | + oid_parser.add_argument('--version', help='The current dataset version is v4.', default='v4') | ||
167 | + oid_parser.add_argument('--labels-filter', help='A list of labels to filter.', type=csv_list, default=None) | ||
168 | + oid_parser.add_argument('--annotation-cache-dir', help='Path to store annotation cache.', default='.') | ||
169 | + oid_parser.add_argument('--parent-label', help='Use the hierarchy children of this label.', default=None) | ||
170 | + | ||
171 | + csv_parser = subparsers.add_parser('csv') | ||
172 | + csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for evaluation.') | ||
173 | + csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.') | ||
174 | + | ||
175 | + parser.add_argument('--no-resize', help='Disable image resizing.', dest='resize', action='store_false') | ||
176 | + parser.add_argument('--anchors', help='Show positive anchors on the image.', action='store_true') | ||
177 | + parser.add_argument('--display-name', help='Display image name on the bottom left corner.', action='store_true') | ||
178 | + parser.add_argument('--show-annotations', help='Show annotations on the image. Green annotations have anchors, red annotations don\'t and therefore don\'t contribute to training.', action='store_true') | ||
179 | + parser.add_argument('--random-transform', help='Randomly transform image and annotations.', action='store_true') | ||
180 | + parser.add_argument('--image-min-side', help='Rescale the image so the smallest side is min_side.', type=int, default=800) | ||
181 | + parser.add_argument('--image-max-side', help='Rescale the image if the largest side is larger than max_side.', type=int, default=1333) | ||
182 | + parser.add_argument('--config', help='Path to a configuration parameters .ini file.') | ||
183 | + parser.add_argument('--no-gui', help='Do not open a GUI window. Save images to an output directory instead.', action='store_true') | ||
184 | + parser.add_argument('--output-dir', help='The output directory to save images to if --no-gui is specified.', default='.') | ||
185 | + parser.add_argument('--flatten-output', help='Flatten the folder structure of saved output images into a single folder.', action='store_true') | ||
186 | + parser.add_argument('--group-method', help='Determines how images are grouped together', type=str, default='ratio', choices=['none', 'random', 'ratio']) | ||
187 | + | ||
188 | + return parser.parse_args(args) | ||
189 | + | ||
190 | + | ||
191 | +def run(generator, args, anchor_params, pyramid_levels): | ||
192 | + """ Main loop. | ||
193 | + | ||
194 | + Args | ||
195 | + generator: The generator to debug. | ||
196 | + args: parseargs args object. | ||
197 | + """ | ||
198 | + # display images, one at a time | ||
199 | + i = 0 | ||
200 | + while True: | ||
201 | + # load the data | ||
202 | + image = generator.load_image(i) | ||
203 | + annotations = generator.load_annotations(i) | ||
204 | + if len(annotations['labels']) > 0: | ||
205 | + # apply random transformations | ||
206 | + if args.random_transform: | ||
207 | + image, annotations = generator.random_transform_group_entry(image, annotations) | ||
208 | + image, annotations = generator.random_visual_effect_group_entry(image, annotations) | ||
209 | + | ||
210 | + # resize the image and annotations | ||
211 | + if args.resize: | ||
212 | + image, image_scale = generator.resize_image(image) | ||
213 | + annotations['bboxes'] *= image_scale | ||
214 | + | ||
215 | + anchors = anchors_for_shape(image.shape, anchor_params=anchor_params, pyramid_levels=pyramid_levels) | ||
216 | + positive_indices, _, max_indices = compute_gt_annotations(anchors, annotations['bboxes']) | ||
217 | + | ||
218 | + # draw anchors on the image | ||
219 | + if args.anchors: | ||
220 | + draw_boxes(image, anchors[positive_indices], (255, 255, 0), thickness=1) | ||
221 | + | ||
222 | + # draw annotations on the image | ||
223 | + if args.show_annotations: | ||
224 | + # draw annotations in red | ||
225 | + draw_annotations(image, annotations, color=(0, 0, 255), label_to_name=generator.label_to_name) | ||
226 | + | ||
227 | + # draw regressed anchors in green to override most red annotations | ||
228 | + # result is that annotations without anchors are red, with anchors are green | ||
229 | + draw_boxes(image, annotations['bboxes'][max_indices[positive_indices], :], (0, 255, 0)) | ||
230 | + | ||
231 | + # display name on the image | ||
232 | + if args.display_name: | ||
233 | + draw_caption(image, [0, image.shape[0]], os.path.basename(generator.image_path(i))) | ||
234 | + | ||
235 | + # write to file and advance if no-gui selected | ||
236 | + if args.no_gui: | ||
237 | + output_path = make_output_path(args.output_dir, generator.image_path(i), flatten=args.flatten_output) | ||
238 | + os.makedirs(os.path.dirname(output_path), exist_ok=True) | ||
239 | + cv2.imwrite(output_path, image) | ||
240 | + i += 1 | ||
241 | + if i == generator.size(): # have written all images | ||
242 | + break | ||
243 | + else: | ||
244 | + continue | ||
245 | + | ||
246 | + # if we are using the GUI, then show an image | ||
247 | + cv2.imshow('Image', image) | ||
248 | + key = cv2.waitKeyEx() | ||
249 | + | ||
250 | + # press right for next image and left for previous (Linux or Windows; doesn't work on macOS) | ||
251 | + # if you run macOS, press "n" or "m" (these also work on Linux and Windows) | ||
252 | + | ||
253 | + if key in rightkeys: | ||
254 | + i = (i + 1) % generator.size() | ||
255 | + if key in leftkeys: | ||
256 | + i -= 1 | ||
257 | + if i < 0: | ||
258 | + i = generator.size() - 1 | ||
259 | + | ||
260 | + # press q or Esc to quit | ||
261 | + if (key == ord('q')) or (key == 27): | ||
262 | + return False | ||
263 | + | ||
264 | + return True | ||
265 | + | ||
266 | + | ||
267 | +def make_output_path(output_dir, image_path, flatten=False): | ||
268 | + """ Compute the output path for a debug image. """ | ||
269 | + | ||
270 | + # If the output hierarchy is flattened to a single folder, throw away all leading folders. | ||
271 | + if flatten: | ||
272 | + path = os.path.basename(image_path) | ||
273 | + | ||
274 | + # Otherwise, make sure absolute paths are taken relative to the filesystem root. | ||
275 | + else: | ||
276 | + # Make sure to drop drive letters on Windows, otherwise relpath will fail. | ||
277 | + _, path = os.path.splitdrive(image_path) | ||
278 | + if os.path.isabs(path): | ||
279 | + path = os.path.relpath(path, '/') | ||
280 | + | ||
281 | + # In all cases, append "_debug" to the filename, before the extension. | ||
282 | + base, extension = os.path.splitext(path) | ||
283 | + path = base + "_debug" + extension | ||
284 | + | ||
285 | + # Finally, join the whole thing to the output directory. | ||
286 | + return os.path.join(output_dir, path) | ||
287 | + | ||
288 | + | ||
289 | +def main(args=None): | ||
290 | + # parse arguments | ||
291 | + if args is None: | ||
292 | + args = sys.argv[1:] | ||
293 | + args = parse_args(args) | ||
294 | + | ||
295 | + # make sure tensorflow is the minimum required version | ||
296 | + check_tf_version() | ||
297 | + | ||
298 | + # create the generator | ||
299 | + generator = create_generator(args) | ||
300 | + | ||
301 | + # optionally load config parameters | ||
302 | + if args.config: | ||
303 | + args.config = read_config_file(args.config) | ||
304 | + | ||
305 | + # optionally load anchor parameters | ||
306 | + anchor_params = None | ||
307 | + if args.config and 'anchor_parameters' in args.config: | ||
308 | + anchor_params = parse_anchor_parameters(args.config) | ||
309 | + | ||
310 | + pyramid_levels = None | ||
311 | + if args.config and 'pyramid_levels' in args.config: | ||
312 | + pyramid_levels = parse_pyramid_levels(args.config) | ||
313 | + # create the display window if necessary | ||
314 | + if not args.no_gui: | ||
315 | + cv2.namedWindow('Image', cv2.WINDOW_NORMAL) | ||
316 | + | ||
317 | + run(generator, args, anchor_params=anchor_params, pyramid_levels=pyramid_levels) | ||
318 | + | ||
319 | + | ||
320 | +if __name__ == '__main__': | ||
321 | + main() |
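A quick sketch of what `make_output_path` above returns for POSIX-style paths; the directories and file name are chosen only for illustration.

```python
# Hypothetical inputs; both calls append "_debug" before the extension.
print(make_output_path("out", "/data/voc/img_001.jpg"))
# -> out/data/voc/img_001_debug.jpg   (absolute path taken relative to the root)

print(make_output_path("out", "/data/voc/img_001.jpg", flatten=True))
# -> out/img_001_debug.jpg            (leading folders dropped)
```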
retinaNet/keras_retinanet/bin/evaluate.py
0 → 100644
1 | +#!/usr/bin/env python | ||
2 | + | ||
3 | +""" | ||
4 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
5 | + | ||
6 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
7 | +you may not use this file except in compliance with the License. | ||
8 | +You may obtain a copy of the License at | ||
9 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import argparse | ||
18 | +import os | ||
19 | +import sys | ||
20 | + | ||
21 | +# Allow relative imports when being executed as script. | ||
22 | +if __name__ == "__main__" and __package__ is None: | ||
23 | + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) | ||
24 | + import keras_retinanet.bin # noqa: F401 | ||
25 | + __package__ = "keras_retinanet.bin" | ||
26 | + | ||
27 | +# Change these to absolute imports if you copy this script outside the keras_retinanet package. | ||
28 | +from .. import models | ||
29 | +from ..preprocessing.csv_generator import CSVGenerator | ||
30 | +from ..preprocessing.pascal_voc import PascalVocGenerator | ||
31 | +from ..utils.anchors import make_shapes_callback | ||
32 | +from ..utils.config import read_config_file, parse_anchor_parameters, parse_pyramid_levels | ||
33 | +from ..utils.eval import evaluate | ||
34 | +from ..utils.gpu import setup_gpu | ||
35 | +from ..utils.tf_version import check_tf_version | ||
36 | + | ||
37 | + | ||
38 | +def create_generator(args, preprocess_image): | ||
39 | + """ Create generators for evaluation. | ||
40 | + """ | ||
41 | + common_args = { | ||
42 | + 'config' : args.config, | ||
43 | + 'image_min_side' : args.image_min_side, | ||
44 | + 'image_max_side' : args.image_max_side, | ||
45 | + 'no_resize' : args.no_resize, | ||
46 | + 'preprocess_image' : preprocess_image, | ||
47 | + 'group_method' : args.group_method | ||
48 | + } | ||
49 | + | ||
50 | + if args.dataset_type == 'coco': | ||
51 | + # import here to prevent unnecessary dependency on cocoapi | ||
52 | + from ..preprocessing.coco import CocoGenerator | ||
53 | + | ||
54 | + validation_generator = CocoGenerator( | ||
55 | + args.coco_path, | ||
56 | + 'val2017', | ||
57 | + shuffle_groups=False, | ||
58 | + **common_args | ||
59 | + ) | ||
60 | + elif args.dataset_type == 'pascal': | ||
61 | + validation_generator = PascalVocGenerator( | ||
62 | + args.pascal_path, | ||
63 | + 'test', | ||
64 | + image_extension=args.image_extension, | ||
65 | + shuffle_groups=False, | ||
66 | + **common_args | ||
67 | + ) | ||
68 | + elif args.dataset_type == 'csv': | ||
69 | + validation_generator = CSVGenerator( | ||
70 | + args.annotations, | ||
71 | + args.classes, | ||
72 | + shuffle_groups=False, | ||
73 | + **common_args | ||
74 | + ) | ||
75 | + else: | ||
76 | + raise ValueError('Invalid data type received: {}'.format(args.dataset_type)) | ||
77 | + | ||
78 | + return validation_generator | ||
79 | + | ||
80 | + | ||
81 | +def parse_args(args): | ||
82 | + """ Parse the arguments. | ||
83 | + """ | ||
84 | + parser = argparse.ArgumentParser(description='Evaluation script for a RetinaNet network.') | ||
85 | + subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type') | ||
86 | + subparsers.required = True | ||
87 | + | ||
88 | + coco_parser = subparsers.add_parser('coco') | ||
89 | + coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).') | ||
90 | + | ||
91 | + pascal_parser = subparsers.add_parser('pascal') | ||
92 | + pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).') | ||
93 | + pascal_parser.add_argument('--image-extension', help='Declares the dataset images\' extension.', default='.jpg') | ||
94 | + | ||
95 | + csv_parser = subparsers.add_parser('csv') | ||
96 | + csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for evaluation.') | ||
97 | + csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.') | ||
98 | + | ||
99 | + parser.add_argument('model', help='Path to RetinaNet model.') | ||
100 | + parser.add_argument('--convert-model', help='Convert the model to an inference model (ie. the input is a training model).', action='store_true') | ||
101 | + parser.add_argument('--backbone', help='The backbone of the model.', default='resnet50') | ||
102 | + parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).') | ||
103 | + parser.add_argument('--score-threshold', help='Threshold on score to filter detections with (defaults to 0.05).', default=0.05, type=float) | ||
104 | + parser.add_argument('--iou-threshold', help='IoU Threshold to count for a positive detection (defaults to 0.5).', default=0.5, type=float) | ||
105 | + parser.add_argument('--max-detections', help='Max Detections per image (defaults to 100).', default=100, type=int) | ||
106 | + parser.add_argument('--save-path', help='Path for saving images with detections (doesn\'t work for COCO).') | ||
107 | + parser.add_argument('--image-min-side', help='Rescale the image so the smallest side is min_side.', type=int, default=800) | ||
108 | + parser.add_argument('--image-max-side', help='Rescale the image if the largest side is larger than max_side.', type=int, default=1333) | ||
109 | + parser.add_argument('--no-resize', help='Don\'t rescale the image.', action='store_true') | ||
110 | + parser.add_argument('--config', help='Path to a configuration parameters .ini file (only used with --convert-model).') | ||
111 | + parser.add_argument('--group-method', help='Determines how images are grouped together', type=str, default='ratio', choices=['none', 'random', 'ratio']) | ||
112 | + | ||
113 | + return parser.parse_args(args) | ||
114 | + | ||
115 | + | ||
116 | +def main(args=None): | ||
117 | + # parse arguments | ||
118 | + if args is None: | ||
119 | + args = sys.argv[1:] | ||
120 | + args = parse_args(args) | ||
121 | + | ||
122 | + # make sure tensorflow is the minimum required version | ||
123 | + check_tf_version() | ||
124 | + | ||
125 | + # optionally choose specific GPU | ||
126 | + if args.gpu: | ||
127 | + setup_gpu(args.gpu) | ||
128 | + | ||
129 | + # make save path if it doesn't exist | ||
130 | + if args.save_path is not None and not os.path.exists(args.save_path): | ||
131 | + os.makedirs(args.save_path) | ||
132 | + | ||
133 | + # optionally load config parameters | ||
134 | + if args.config: | ||
135 | + args.config = read_config_file(args.config) | ||
136 | + | ||
137 | + # create the generator | ||
138 | + backbone = models.backbone(args.backbone) | ||
139 | + generator = create_generator(args, backbone.preprocess_image) | ||
140 | + | ||
141 | + # optionally load anchor parameters | ||
142 | + anchor_params = None | ||
143 | + pyramid_levels = None | ||
144 | + if args.config and 'anchor_parameters' in args.config: | ||
145 | + anchor_params = parse_anchor_parameters(args.config) | ||
146 | + if args.config and 'pyramid_levels' in args.config: | ||
147 | + pyramid_levels = parse_pyramid_levels(args.config) | ||
148 | + | ||
149 | + # load the model | ||
150 | + print('Loading model, this may take a second...') | ||
151 | + model = models.load_model(args.model, backbone_name=args.backbone) | ||
152 | + generator.compute_shapes = make_shapes_callback(model) | ||
153 | + | ||
154 | + # optionally convert the model | ||
155 | + if args.convert_model: | ||
156 | + model = models.convert_model(model, anchor_params=anchor_params, pyramid_levels=pyramid_levels) | ||
157 | + | ||
158 | + # print model summary | ||
159 | + # print(model.summary()) | ||
160 | + | ||
161 | + # start evaluation | ||
162 | + if args.dataset_type == 'coco': | ||
163 | + from ..utils.coco_eval import evaluate_coco | ||
164 | + evaluate_coco(generator, model, args.score_threshold) | ||
165 | + else: | ||
166 | + average_precisions, inference_time = evaluate( | ||
167 | + generator, | ||
168 | + model, | ||
169 | + iou_threshold=args.iou_threshold, | ||
170 | + score_threshold=args.score_threshold, | ||
171 | + max_detections=args.max_detections, | ||
172 | + save_path=args.save_path | ||
173 | + ) | ||
174 | + | ||
175 | + # print evaluation | ||
176 | + total_instances = [] | ||
177 | + precisions = [] | ||
178 | + for label, (average_precision, num_annotations) in average_precisions.items(): | ||
179 | + print('{:.0f} instances of class'.format(num_annotations), | ||
180 | + generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision)) | ||
181 | + total_instances.append(num_annotations) | ||
182 | + precisions.append(average_precision) | ||
183 | + | ||
184 | + if sum(total_instances) == 0: | ||
185 | + print('No test instances found.') | ||
186 | + return | ||
187 | + | ||
188 | + print('Inference time for {:.0f} images: {:.4f}'.format(generator.size(), inference_time)) | ||
189 | + | ||
190 | + print('mAP using the weighted average of precisions among classes: {:.4f}'.format(sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances))) | ||
191 | + print('mAP: {:.4f}'.format(sum(precisions) / sum(x > 0 for x in total_instances))) | ||
192 | + | ||
193 | + | ||
194 | +if __name__ == '__main__': | ||
195 | + main() |
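A tiny worked example of the two summary numbers printed at the end of `main()` above, using made-up per-class results: class 0 with AP 0.5 over 10 annotations and class 1 with AP 0.9 over 30 annotations.

```python
total_instances = [10, 30]
precisions = [0.5, 0.9]

# weighted by the number of annotations per class
weighted_map = sum(a * b for a, b in zip(total_instances, precisions)) / sum(total_instances)
# unweighted mean over classes that have at least one annotation
plain_map = sum(precisions) / sum(x > 0 for x in total_instances)

print(weighted_map)  # 0.8 -> classes with more annotations dominate
print(plain_map)     # 0.7 -> every represented class counts equally
```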
retinaNet/keras_retinanet/bin/train.py
0 → 100644
1 | +#!/usr/bin/env python | ||
2 | + | ||
3 | +""" | ||
4 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
5 | + | ||
6 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
7 | +you may not use this file except in compliance with the License. | ||
8 | +You may obtain a copy of the License at | ||
9 | + | ||
10 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
11 | + | ||
12 | +Unless required by applicable law or agreed to in writing, software | ||
13 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
14 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
15 | +See the License for the specific language governing permissions and | ||
16 | +limitations under the License. | ||
17 | +""" | ||
18 | + | ||
19 | +import argparse | ||
20 | +import os | ||
21 | +import sys | ||
22 | +import warnings | ||
23 | + | ||
24 | +from tensorflow import keras | ||
25 | +import tensorflow as tf | ||
26 | + | ||
27 | +# Allow relative imports when being executed as script. | ||
28 | +if __name__ == "__main__" and __package__ is None: | ||
29 | + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) | ||
30 | + import keras_retinanet.bin # noqa: F401 | ||
31 | + | ||
32 | + __package__ = "keras_retinanet.bin" | ||
33 | + | ||
34 | +# Change these to absolute imports if you copy this script outside the keras_retinanet package. | ||
35 | +from .. import layers # noqa: F401 | ||
36 | +from .. import losses | ||
37 | +from .. import models | ||
38 | +from ..callbacks import RedirectModel | ||
39 | +from ..callbacks.eval import Evaluate | ||
40 | +from ..models.retinanet import retinanet_bbox | ||
41 | +from ..preprocessing.csv_generator import CSVGenerator | ||
42 | +from ..preprocessing.kitti import KittiGenerator | ||
43 | +from ..preprocessing.open_images import OpenImagesGenerator | ||
44 | +from ..preprocessing.pascal_voc import PascalVocGenerator | ||
45 | +from ..utils.anchors import make_shapes_callback | ||
46 | +from ..utils.config import ( | ||
47 | + read_config_file, | ||
48 | + parse_anchor_parameters, | ||
49 | + parse_pyramid_levels, | ||
50 | +) | ||
51 | +from ..utils.gpu import setup_gpu | ||
52 | +from ..utils.image import random_visual_effect_generator | ||
53 | +from ..utils.model import freeze as freeze_model | ||
54 | +from ..utils.tf_version import check_tf_version | ||
55 | +from ..utils.transform import random_transform_generator | ||
56 | + | ||
57 | +####################### | ||
58 | + | ||
59 | +from ..models import submodel | ||
60 | + | ||
61 | + | ||
62 | +def makedirs(path): | ||
63 | + # Intended behavior: try to create the directory, | ||
64 | + # pass if the directory exists already, fails otherwise. | ||
65 | + # Meant for Python 2.7/3.n compatibility. | ||
66 | + try: | ||
67 | + os.makedirs(path) | ||
68 | + except OSError: | ||
69 | + if not os.path.isdir(path): | ||
70 | + raise | ||
71 | + | ||
72 | + | ||
73 | +def model_with_weights(model, weights, skip_mismatch): | ||
74 | + """Load weights for model. | ||
75 | + | ||
76 | + Args | ||
77 | + model : The model to load weights for. | ||
78 | + weights : The weights to load. | ||
79 | + skip_mismatch : If True, skips layers whose shape of weights doesn't match with the model. | ||
80 | + """ | ||
81 | + if weights is not None: | ||
82 | + model.load_weights(weights, by_name=True, skip_mismatch=skip_mismatch) | ||
83 | + return model | ||
84 | + | ||
85 | + | ||
86 | +def create_models( | ||
87 | + backbone_retinanet, | ||
88 | + num_classes, | ||
89 | + weights, | ||
90 | + multi_gpu=0, | ||
91 | + freeze_backbone=False, | ||
92 | + lr=1e-5, | ||
93 | + optimizer_clipnorm=0.001, | ||
94 | + config=None, | ||
95 | + submodels=None, | ||
96 | +): | ||
97 | + """Creates three models (model, training_model, prediction_model). | ||
98 | + | ||
99 | + Args | ||
100 | + backbone_retinanet : A function to call to create a retinanet model with a given backbone. | ||
101 | + num_classes : The number of classes to train. | ||
102 | + weights : The weights to load into the model. | ||
103 | + multi_gpu : The number of GPUs to use for training. | ||
104 | + freeze_backbone : If True, disables learning for the backbone. | ||
105 | + config : Config parameters, None indicates the default configuration. | ||
106 | + | ||
107 | + Returns | ||
108 | + model : The base model. This is also the model that is saved in snapshots. | ||
109 | + training_model : The training model. If multi_gpu=0, this is identical to model. | ||
110 | + prediction_model : The model wrapped with utility functions to perform object detection (applies regression values and performs NMS). | ||
111 | + """ | ||
112 | + | ||
113 | + modifier = freeze_model if freeze_backbone else None | ||
114 | + | ||
115 | + # load anchor parameters, or pass None (so that defaults will be used) | ||
116 | + anchor_params = None | ||
117 | + num_anchors = None | ||
118 | + pyramid_levels = None | ||
119 | + if config and "anchor_parameters" in config: | ||
120 | + anchor_params = parse_anchor_parameters(config) | ||
121 | + num_anchors = anchor_params.num_anchors() | ||
122 | + if config and "pyramid_levels" in config: | ||
123 | + pyramid_levels = parse_pyramid_levels(config) | ||
124 | + | ||
125 | + # Keras recommends initialising a multi-gpu model on the CPU to ease weight sharing, and to prevent OOM errors. | ||
126 | + # optionally wrap in a parallel model | ||
127 | + if multi_gpu > 1: | ||
128 | + from keras.utils import multi_gpu_model | ||
129 | + | ||
130 | + with tf.device("/cpu:0"): | ||
131 | + model = model_with_weights( | ||
132 | + backbone_retinanet( | ||
133 | + num_classes, | ||
134 | + num_anchors=num_anchors, | ||
135 | + modifier=modifier, | ||
136 | + pyramid_levels=pyramid_levels, | ||
137 | + ), | ||
138 | + weights=weights, | ||
139 | + skip_mismatch=True, | ||
140 | + ) | ||
141 | + training_model = multi_gpu_model(model, gpus=multi_gpu) | ||
142 | + else: | ||
143 | + model = model_with_weights( | ||
144 | + backbone_retinanet( | ||
145 | + num_classes, | ||
146 | + num_anchors=num_anchors, | ||
147 | + modifier=modifier, | ||
148 | + pyramid_levels=pyramid_levels, | ||
149 | + submodels=submodels, | ||
150 | + ), | ||
151 | + weights=weights, | ||
152 | + skip_mismatch=True, | ||
153 | + ) | ||
154 | + training_model = model | ||
155 | + | ||
156 | + # make prediction model | ||
157 | + prediction_model = retinanet_bbox( | ||
158 | + model=model, anchor_params=anchor_params, pyramid_levels=pyramid_levels | ||
159 | + ) | ||
160 | + | ||
161 | + # compile model | ||
162 | + training_model.compile( | ||
163 | + loss={"regression": losses.smooth_l1(), "classification": losses.focal()}, | ||
164 | + optimizer=keras.optimizers.Adam(lr=lr, clipnorm=optimizer_clipnorm), | ||
165 | + ) | ||
166 | + | ||
167 | + return model, training_model, prediction_model | ||
168 | + | ||
169 | + | ||
170 | +def create_callbacks( | ||
171 | + model, training_model, prediction_model, validation_generator, args | ||
172 | +): | ||
173 | + """Creates the callbacks to use during training. | ||
174 | + | ||
175 | + Args | ||
176 | + model: The base model. | ||
177 | + training_model: The model that is used for training. | ||
178 | + prediction_model: The model that should be used for validation. | ||
179 | + validation_generator: The generator for creating validation data. | ||
180 | + args: parseargs args object. | ||
181 | + | ||
182 | + Returns: | ||
183 | + A list of callbacks used for training. | ||
184 | + """ | ||
185 | + callbacks = [] | ||
186 | + | ||
187 | + tensorboard_callback = None | ||
188 | + | ||
189 | + if args.tensorboard_dir: | ||
190 | + makedirs(args.tensorboard_dir) | ||
191 | + update_freq = args.tensorboard_freq | ||
192 | + if update_freq not in ["epoch", "batch"]: | ||
193 | + update_freq = int(update_freq) | ||
194 | + tensorboard_callback = keras.callbacks.TensorBoard( | ||
195 | + log_dir=args.tensorboard_dir, | ||
196 | + histogram_freq=0, | ||
197 | + batch_size=args.batch_size, | ||
198 | + write_graph=True, | ||
199 | + write_grads=False, | ||
200 | + write_images=False, | ||
201 | + update_freq=update_freq, | ||
202 | + embeddings_freq=0, | ||
203 | + embeddings_layer_names=None, | ||
204 | + embeddings_metadata=None, | ||
205 | + ) | ||
206 | + | ||
207 | + if args.evaluation and validation_generator: | ||
208 | + if args.dataset_type == "coco": | ||
209 | + from ..callbacks.coco import CocoEval | ||
210 | + | ||
211 | + # use prediction model for evaluation | ||
212 | + evaluation = CocoEval( | ||
213 | + validation_generator, tensorboard=tensorboard_callback | ||
214 | + ) | ||
215 | + else: | ||
216 | + evaluation = Evaluate( | ||
217 | + validation_generator, | ||
218 | + tensorboard=tensorboard_callback, | ||
219 | + weighted_average=args.weighted_average, | ||
220 | + ) | ||
221 | + evaluation = RedirectModel(evaluation, prediction_model) | ||
222 | + callbacks.append(evaluation) | ||
223 | + | ||
224 | + # save the model | ||
225 | + if args.snapshots: | ||
226 | + # ensure directory created first; otherwise h5py will error after epoch. | ||
227 | + makedirs(args.snapshot_path) | ||
228 | + checkpoint = keras.callbacks.ModelCheckpoint( | ||
229 | + os.path.join( | ||
230 | + args.snapshot_path, | ||
231 | + "{backbone}_{dataset_type}_{{epoch:02d}}.h5".format( | ||
232 | + backbone=args.backbone, dataset_type=args.dataset_type | ||
233 | + ), | ||
234 | + ), | ||
235 | + verbose=1, | ||
236 | + # save_best_only=True, | ||
237 | + # monitor="mAP", | ||
238 | + # mode='max' | ||
239 | + ) | ||
240 | + checkpoint = RedirectModel(checkpoint, model) | ||
241 | + callbacks.append(checkpoint) | ||
242 | + | ||
243 | + callbacks.append( | ||
244 | + keras.callbacks.ReduceLROnPlateau( | ||
245 | + monitor="loss", | ||
246 | + factor=args.reduce_lr_factor, | ||
247 | + patience=args.reduce_lr_patience, | ||
248 | + verbose=1, | ||
249 | + mode="auto", | ||
250 | + min_delta=0.0001, | ||
251 | + cooldown=0, | ||
252 | + min_lr=0, | ||
253 | + ) | ||
254 | + ) | ||
255 | + | ||
256 | + if args.evaluation and validation_generator: | ||
257 | + callbacks.append( | ||
258 | + keras.callbacks.EarlyStopping( | ||
259 | + monitor="mAP", patience=5, mode="max", min_delta=0.01 | ||
260 | + ) | ||
261 | + ) | ||
262 | + | ||
263 | + if args.tensorboard_dir: | ||
264 | + callbacks.append(tensorboard_callback) | ||
265 | + | ||
266 | + return callbacks | ||
267 | + | ||
268 | + | ||
269 | +def create_generators(args, preprocess_image): | ||
270 | + """Create generators for training and validation. | ||
271 | + | ||
272 | + Args | ||
273 | + args : parseargs object containing configuration for generators. | ||
274 | + preprocess_image : Function that preprocesses an image for the network. | ||
275 | + """ | ||
276 | + common_args = { | ||
277 | + "batch_size": args.batch_size, | ||
278 | + "config": args.config, | ||
279 | + "image_min_side": args.image_min_side, | ||
280 | + "image_max_side": args.image_max_side, | ||
281 | + "no_resize": args.no_resize, | ||
282 | + "preprocess_image": preprocess_image, | ||
283 | + "group_method": args.group_method, | ||
284 | + } | ||
285 | + | ||
286 | + # create random transform generator for augmenting training data | ||
287 | + if args.random_transform: | ||
288 | + transform_generator = random_transform_generator( | ||
289 | + min_rotation=-0.1, | ||
290 | + max_rotation=0.1, | ||
291 | + min_translation=(-0.1, -0.1), | ||
292 | + max_translation=(0.1, 0.1), | ||
293 | + min_shear=-0.1, | ||
294 | + max_shear=0.1, | ||
295 | + min_scaling=(0.9, 0.9), | ||
296 | + max_scaling=(1.1, 1.1), | ||
297 | + flip_x_chance=0.5, | ||
298 | + flip_y_chance=0.5, | ||
299 | + ) | ||
300 | + visual_effect_generator = random_visual_effect_generator( | ||
301 | + contrast_range=(0.9, 1.1), | ||
302 | + brightness_range=(-0.1, 0.1), | ||
303 | + hue_range=(-0.05, 0.05), | ||
304 | + saturation_range=(0.95, 1.05), | ||
305 | + ) | ||
306 | + else: | ||
307 | + transform_generator = random_transform_generator(flip_x_chance=0.5) | ||
308 | + visual_effect_generator = None | ||
309 | + | ||
310 | + if args.dataset_type == "coco": | ||
311 | + # import here to prevent unnecessary dependency on cocoapi | ||
312 | + from ..preprocessing.coco import CocoGenerator | ||
313 | + | ||
314 | + train_generator = CocoGenerator( | ||
315 | + args.coco_path, | ||
316 | + "train2017", | ||
317 | + transform_generator=transform_generator, | ||
318 | + visual_effect_generator=visual_effect_generator, | ||
319 | + **common_args | ||
320 | + ) | ||
321 | + | ||
322 | + validation_generator = CocoGenerator( | ||
323 | + args.coco_path, "val2017", shuffle_groups=False, **common_args | ||
324 | + ) | ||
325 | + elif args.dataset_type == "pascal": | ||
326 | + train_generator = PascalVocGenerator( | ||
327 | + args.pascal_path, | ||
328 | + "train", | ||
329 | + image_extension=args.image_extension, | ||
330 | + transform_generator=transform_generator, | ||
331 | + visual_effect_generator=visual_effect_generator, | ||
332 | + **common_args | ||
333 | + ) | ||
334 | + | ||
335 | + validation_generator = PascalVocGenerator( | ||
336 | + args.pascal_path, | ||
337 | + "val", | ||
338 | + image_extension=args.image_extension, | ||
339 | + shuffle_groups=False, | ||
340 | + **common_args | ||
341 | + ) | ||
342 | + elif args.dataset_type == "csv": | ||
343 | + train_generator = CSVGenerator( | ||
344 | + args.annotations, | ||
345 | + args.classes, | ||
346 | + transform_generator=transform_generator, | ||
347 | + visual_effect_generator=visual_effect_generator, | ||
348 | + **common_args | ||
349 | + ) | ||
350 | + | ||
351 | + if args.val_annotations: | ||
352 | + validation_generator = CSVGenerator( | ||
353 | + args.val_annotations, args.classes, shuffle_groups=False, **common_args | ||
354 | + ) | ||
355 | + else: | ||
356 | + validation_generator = None | ||
357 | + elif args.dataset_type == "oid": | ||
358 | + train_generator = OpenImagesGenerator( | ||
359 | + args.main_dir, | ||
360 | + subset="train", | ||
361 | + version=args.version, | ||
362 | + labels_filter=args.labels_filter, | ||
363 | + annotation_cache_dir=args.annotation_cache_dir, | ||
364 | + parent_label=args.parent_label, | ||
365 | + transform_generator=transform_generator, | ||
366 | + visual_effect_generator=visual_effect_generator, | ||
367 | + **common_args | ||
368 | + ) | ||
369 | + | ||
370 | + validation_generator = OpenImagesGenerator( | ||
371 | + args.main_dir, | ||
372 | + subset="validation", | ||
373 | + version=args.version, | ||
374 | + labels_filter=args.labels_filter, | ||
375 | + annotation_cache_dir=args.annotation_cache_dir, | ||
376 | + parent_label=args.parent_label, | ||
377 | + shuffle_groups=False, | ||
378 | + **common_args | ||
379 | + ) | ||
380 | + elif args.dataset_type == "kitti": | ||
381 | + train_generator = KittiGenerator( | ||
382 | + args.kitti_path, | ||
383 | + subset="train", | ||
384 | + transform_generator=transform_generator, | ||
385 | + visual_effect_generator=visual_effect_generator, | ||
386 | + **common_args | ||
387 | + ) | ||
388 | + | ||
389 | + validation_generator = KittiGenerator( | ||
390 | + args.kitti_path, subset="val", shuffle_groups=False, **common_args | ||
391 | + ) | ||
392 | + else: | ||
393 | + raise ValueError("Invalid data type received: {}".format(args.dataset_type)) | ||
394 | + | ||
395 | + return train_generator, validation_generator | ||
396 | + | ||
397 | + | ||
398 | +def check_args(parsed_args): | ||
399 | + """Function to check for inherent contradictions within parsed arguments. | ||
400 | + For example, batch_size < num_gpus | ||
401 | + Intended to raise errors prior to backend initialisation. | ||
402 | + | ||
403 | + Args | ||
404 | + parsed_args: parser.parse_args() | ||
405 | + | ||
406 | + Returns | ||
407 | + parsed_args | ||
408 | + """ | ||
409 | + | ||
410 | + if parsed_args.multi_gpu > 1 and parsed_args.batch_size < parsed_args.multi_gpu: | ||
411 | + raise ValueError( | ||
412 | + "Batch size ({}) must be equal to or higher than the number of GPUs ({})".format( | ||
413 | + parsed_args.batch_size, parsed_args.multi_gpu | ||
414 | + ) | ||
415 | + ) | ||
416 | + | ||
417 | + if parsed_args.multi_gpu > 1 and parsed_args.snapshot: | ||
418 | + raise ValueError( | ||
419 | + "Multi GPU training ({}) and resuming from snapshots ({}) is not supported.".format( | ||
420 | + parsed_args.multi_gpu, parsed_args.snapshot | ||
421 | + ) | ||
422 | + ) | ||
423 | + | ||
424 | + if parsed_args.multi_gpu > 1 and not parsed_args.multi_gpu_force: | ||
425 | + raise ValueError( | ||
426 | + "Multi-GPU support is experimental, use at own risk! Run with --multi-gpu-force if you wish to continue." | ||
427 | + ) | ||
428 | + | ||
429 | + if "resnet" not in parsed_args.backbone: | ||
430 | + warnings.warn( | ||
431 | + "Using experimental backbone {}. Only resnet50 has been properly tested.".format( | ||
432 | + parsed_args.backbone | ||
433 | + ) | ||
434 | + ) | ||
435 | + | ||
436 | + return parsed_args | ||
437 | + | ||
438 | + | ||
439 | +def parse_args(args): | ||
440 | + """Parse the arguments.""" | ||
441 | + parser = argparse.ArgumentParser( | ||
442 | + description="Simple training script for training a RetinaNet network." | ||
443 | + ) | ||
444 | + subparsers = parser.add_subparsers( | ||
445 | + help="Arguments for specific dataset types.", dest="dataset_type" | ||
446 | + ) | ||
447 | + subparsers.required = True | ||
448 | + | ||
449 | + coco_parser = subparsers.add_parser("coco") | ||
450 | + coco_parser.add_argument( | ||
451 | + "coco_path", help="Path to dataset directory (ie. /tmp/COCO)." | ||
452 | + ) | ||
453 | + | ||
454 | + pascal_parser = subparsers.add_parser("pascal") | ||
455 | + pascal_parser.add_argument( | ||
456 | + "pascal_path", help="Path to dataset directory (ie. /tmp/VOCdevkit)." | ||
457 | + ) | ||
458 | + pascal_parser.add_argument( | ||
459 | + "--image-extension", | ||
460 | + help="Declares the dataset images' extension.", | ||
461 | + default=".jpg", | ||
462 | + ) | ||
463 | + | ||
464 | + kitti_parser = subparsers.add_parser("kitti") | ||
465 | + kitti_parser.add_argument( | ||
466 | + "kitti_path", help="Path to dataset directory (ie. /tmp/kitti)." | ||
467 | + ) | ||
468 | + | ||
469 | + def csv_list(string): | ||
470 | + return string.split(",") | ||
471 | + | ||
472 | + oid_parser = subparsers.add_parser("oid") | ||
473 | + oid_parser.add_argument("main_dir", help="Path to dataset directory.") | ||
474 | + oid_parser.add_argument( | ||
475 | + "--version", help="The current dataset version is v4.", default="v4" | ||
476 | + ) | ||
477 | + oid_parser.add_argument( | ||
478 | + "--labels-filter", | ||
479 | + help="A list of labels to filter.", | ||
480 | + type=csv_list, | ||
481 | + default=None, | ||
482 | + ) | ||
483 | + oid_parser.add_argument( | ||
484 | + "--annotation-cache-dir", help="Path to store annotation cache.", default="." | ||
485 | + ) | ||
486 | + oid_parser.add_argument( | ||
487 | + "--parent-label", help="Use the hierarchy children of this label.", default=None | ||
488 | + ) | ||
489 | + | ||
490 | + csv_parser = subparsers.add_parser("csv") | ||
491 | + csv_parser.add_argument( | ||
492 | + "annotations", help="Path to CSV file containing annotations for training." | ||
493 | + ) | ||
494 | + csv_parser.add_argument( | ||
495 | + "classes", help="Path to a CSV file containing class label mapping." | ||
496 | + ) | ||
497 | + csv_parser.add_argument( | ||
498 | + "--val-annotations", | ||
499 | + help="Path to CSV file containing annotations for validation (optional).", | ||
500 | + ) | ||
501 | + | ||
502 | + group = parser.add_mutually_exclusive_group() | ||
503 | + group.add_argument("--snapshot", help="Resume training from a snapshot.") | ||
504 | + group.add_argument( | ||
505 | + "--imagenet-weights", | ||
506 | + help="Initialize the model with pretrained imagenet weights. This is the default behaviour.", | ||
507 | + action="store_const", | ||
508 | + const=True, | ||
509 | + default=True, | ||
510 | + ) | ||
511 | + group.add_argument( | ||
512 | + "--weights", help="Initialize the model with weights from a file." | ||
513 | + ) | ||
514 | + group.add_argument( | ||
515 | + "--no-weights", | ||
516 | + help="Don't initialize the model with any weights.", | ||
517 | + dest="imagenet_weights", | ||
518 | + action="store_const", | ||
519 | + const=False, | ||
520 | + ) | ||
521 | + parser.add_argument( | ||
522 | + "--backbone", | ||
523 | + help="Backbone model used by retinanet.", | ||
524 | + default="resnet50", | ||
525 | + type=str, | ||
526 | + ) | ||
527 | + parser.add_argument( | ||
528 | + "--batch-size", help="Size of the batches.", default=1, type=int | ||
529 | + ) | ||
530 | + parser.add_argument( | ||
531 | + "--gpu", help="Id of the GPU to use (as reported by nvidia-smi)." | ||
532 | + ) | ||
533 | + parser.add_argument( | ||
534 | + "--multi-gpu", | ||
535 | + help="Number of GPUs to use for parallel processing.", | ||
536 | + type=int, | ||
537 | + default=0, | ||
538 | + ) | ||
539 | + parser.add_argument( | ||
540 | + "--multi-gpu-force", | ||
541 | + help="Extra flag needed to enable (experimental) multi-gpu support.", | ||
542 | + action="store_true", | ||
543 | + ) | ||
544 | + parser.add_argument( | ||
545 | + "--initial-epoch", | ||
546 | + help="Epoch from which to begin the train, useful if resuming from snapshot.", | ||
547 | + type=int, | ||
548 | + default=0, | ||
549 | + ) | ||
550 | + parser.add_argument( | ||
551 | + "--epochs", help="Number of epochs to train.", type=int, default=50 | ||
552 | + ) | ||
553 | + parser.add_argument( | ||
554 | + "--steps", help="Number of steps per epoch.", type=int, default=10000 | ||
555 | + ) | ||
556 | + parser.add_argument("--lr", help="Learning rate.", type=float, default=1e-5) | ||
557 | + parser.add_argument( | ||
558 | + "--optimizer-clipnorm", | ||
559 | + help="Clipnorm parameter for optimizer.", | ||
560 | + type=float, | ||
561 | + default=0.001, | ||
562 | + ) | ||
563 | + parser.add_argument( | ||
564 | + "--snapshot-path", | ||
565 | + help="Path to store snapshots of models during training (defaults to './snapshots')", | ||
566 | + default="./snapshots", | ||
567 | + ) | ||
568 | + parser.add_argument( | ||
569 | + "--tensorboard-dir", help="Log directory for Tensorboard output", default="" | ||
570 | + ) # default='./logs') => https://github.com/tensorflow/tensorflow/pull/34870 | ||
571 | + parser.add_argument( | ||
572 | + "--tensorboard-freq", | ||
573 | + help="Update frequency for Tensorboard output. Values 'epoch', 'batch' or int", | ||
574 | + default="epoch", | ||
575 | + ) | ||
576 | + parser.add_argument( | ||
577 | + "--no-snapshots", | ||
578 | + help="Disable saving snapshots.", | ||
579 | + dest="snapshots", | ||
580 | + action="store_false", | ||
581 | + ) | ||
582 | + parser.add_argument( | ||
583 | + "--no-evaluation", | ||
584 | + help="Disable per epoch evaluation.", | ||
585 | + dest="evaluation", | ||
586 | + action="store_false", | ||
587 | + ) | ||
588 | + parser.add_argument( | ||
589 | + "--freeze-backbone", | ||
590 | + help="Freeze training of backbone layers.", | ||
591 | + action="store_true", | ||
592 | + ) | ||
593 | + parser.add_argument( | ||
594 | + "--random-transform", | ||
595 | + help="Randomly transform image and annotations.", | ||
596 | + action="store_true", | ||
597 | + ) | ||
598 | + parser.add_argument( | ||
599 | + "--image-min-side", | ||
600 | + help="Rescale the image so the smallest side is min_side.", | ||
601 | + type=int, | ||
602 | + default=800, | ||
603 | + ) | ||
604 | + parser.add_argument( | ||
605 | + "--image-max-side", | ||
606 | + help="Rescale the image if the largest side is larger than max_side.", | ||
607 | + type=int, | ||
608 | + default=1333, | ||
609 | + ) | ||
610 | + parser.add_argument( | ||
611 | + "--no-resize", help="Don" "t rescale the image.", action="store_true" | ||
612 | + ) | ||
613 | + parser.add_argument( | ||
614 | + "--config", help="Path to a configuration parameters .ini file." | ||
615 | + ) | ||
616 | + parser.add_argument( | ||
617 | + "--weighted-average", | ||
618 | + help="Compute the mAP using the weighted average of precisions among classes.", | ||
619 | + action="store_true", | ||
620 | + ) | ||
621 | + parser.add_argument( | ||
622 | + "--compute-val-loss", | ||
623 | + help="Compute validation loss during training", | ||
624 | + dest="compute_val_loss", | ||
625 | + action="store_true", | ||
626 | + ) | ||
627 | + parser.add_argument( | ||
628 | + "--reduce-lr-patience", | ||
629 | + help="Reduce learning rate after validation loss decreases over reduce_lr_patience epochs", | ||
630 | + type=int, | ||
631 | + default=2, | ||
632 | + ) | ||
633 | + parser.add_argument( | ||
634 | + "--reduce-lr-factor", | ||
635 | + help="When learning rate is reduced due to reduce_lr_patience, multiply by reduce_lr_factor", | ||
636 | + type=float, | ||
637 | + default=0.1, | ||
638 | + ) | ||
639 | + parser.add_argument( | ||
640 | + "--group-method", | ||
641 | + help="Determines how images are grouped together", | ||
642 | + type=str, | ||
643 | + default="ratio", | ||
644 | + choices=["none", "random", "ratio"], | ||
645 | + ) | ||
646 | + | ||
647 | + # Fit generator arguments | ||
648 | + parser.add_argument( | ||
649 | + "--multiprocessing", | ||
650 | + help="Use multiprocessing in fit_generator.", | ||
651 | + action="store_true", | ||
652 | + ) | ||
653 | + parser.add_argument( | ||
654 | + "--workers", help="Number of generator workers.", type=int, default=1 | ||
655 | + ) | ||
656 | + parser.add_argument( | ||
657 | + "--max-queue-size", | ||
658 | + help="Queue length for multiprocessing workers in fit_generator.", | ||
659 | + type=int, | ||
660 | + default=10, | ||
661 | + ) | ||
662 | + | ||
663 | + return check_args(parser.parse_args(args)) | ||
664 | + | ||
665 | + | ||
666 | +def main(args=None): | ||
667 | + # parse arguments | ||
668 | + if args is None: | ||
669 | + args = sys.argv[1:] | ||
670 | + args = parse_args(args) | ||
671 | + | ||
672 | + # create object that stores backbone information | ||
673 | + backbone = models.backbone(args.backbone) | ||
674 | + | ||
675 | + # make sure tensorflow is the minimum required version | ||
676 | + check_tf_version() | ||
677 | + | ||
678 | + # optionally choose specific GPU | ||
679 | + if args.gpu is not None: | ||
680 | + setup_gpu(args.gpu) | ||
681 | + | ||
682 | + # optionally load config parameters | ||
683 | + if args.config: | ||
684 | + args.config = read_config_file(args.config) | ||
685 | + | ||
686 | + # create the generators | ||
687 | + train_generator, validation_generator = create_generators( | ||
688 | + args, backbone.preprocess_image | ||
689 | + ) | ||
690 | + | ||
691 | + # create the model | ||
692 | + if args.snapshot is not None: | ||
693 | + print("Loading model, this may take a second...") | ||
694 | + model = models.load_model(args.snapshot, backbone_name=args.backbone) | ||
695 | + training_model = model | ||
696 | + anchor_params = None | ||
697 | + pyramid_levels = None | ||
698 | + if args.config and "anchor_parameters" in args.config: | ||
699 | + anchor_params = parse_anchor_parameters(args.config) | ||
700 | + if args.config and "pyramid_levels" in args.config: | ||
701 | + pyramid_levels = parse_pyramid_levels(args.config) | ||
702 | + | ||
703 | + prediction_model = retinanet_bbox( | ||
704 | + model=model, anchor_params=anchor_params, pyramid_levels=pyramid_levels | ||
705 | + ) | ||
706 | + else: | ||
707 | + weights = args.weights | ||
708 | + # default to imagenet if nothing else is specified | ||
709 | + if weights is None and args.imagenet_weights: | ||
710 | + weights = backbone.download_imagenet() | ||
711 | + | ||
712 | + ################ | ||
713 | + subclass1 = submodel.custom_classification_model(num_classes=51, num_anchors=None, name="classification_submodel1") | ||
714 | + subregress1 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel1") | ||
715 | + | ||
716 | + subclass2 = submodel.custom_classification_model(num_classes=10, num_anchors=None, name="classification_submodel2") | ||
717 | + subregress2 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel2") | ||
718 | + | ||
719 | + subclass3 = submodel.custom_classification_model(num_classes=16, num_anchors=None, name="classification_submodel3") | ||
720 | + subregress3 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel3") | ||
721 | + | ||
722 | + submodels = [ | ||
723 | + ("regression", subregress1), ("classification", subclass1), | ||
724 | + ("regression", subregress2), ("classification", subclass2), | ||
725 | + ("regression", subregress3), ("classification", subclass3), | ||
726 | + ] | ||
727 | + | ||
728 | + s1 = submodel.custom_default_submodels(51, None) | ||
729 | + s2 = submodel.custom_default_submodels(10, None) | ||
730 | + s3 = submodel.custom_default_submodels(16, None) | ||
731 | + | ||
732 | + submodels = s1 + s2 + s3  # note: this replaces the manually assembled submodels list above | ||
733 | + | ||
734 | + ################# | ||
735 | + print("Creating model, this may take a second...") | ||
736 | + model, training_model, prediction_model = create_models( | ||
737 | + backbone_retinanet=backbone.retinanet, | ||
738 | + num_classes=train_generator.num_classes(), | ||
739 | + weights=weights, | ||
740 | + multi_gpu=args.multi_gpu, | ||
741 | + freeze_backbone=args.freeze_backbone, | ||
742 | + lr=args.lr, | ||
743 | + optimizer_clipnorm=args.optimizer_clipnorm, | ||
744 | + config=args.config, | ||
745 | + submodels=submodels, | ||
746 | + ) | ||
747 | + | ||
748 | + # print model summary | ||
749 | + print(model.summary()) | ||
750 | + | ||
751 | + # this lets the generator compute backbone layer shapes using the actual backbone model | ||
752 | + if "vgg" in args.backbone or "densenet" in args.backbone: | ||
753 | + train_generator.compute_shapes = make_shapes_callback(model) | ||
754 | + if validation_generator: | ||
755 | + validation_generator.compute_shapes = train_generator.compute_shapes | ||
756 | + | ||
757 | + # create the callbacks | ||
758 | + callbacks = create_callbacks( | ||
759 | + model, | ||
760 | + training_model, | ||
761 | + prediction_model, | ||
762 | + validation_generator, | ||
763 | + args, | ||
764 | + ) | ||
765 | + | ||
766 | + if not args.compute_val_loss: | ||
767 | + validation_generator = None | ||
768 | + | ||
769 | + # start training | ||
770 | + return training_model.fit_generator( | ||
771 | + generator=train_generator, | ||
772 | + steps_per_epoch=args.steps, | ||
773 | + epochs=args.epochs, | ||
774 | + verbose=1, | ||
775 | + callbacks=callbacks, | ||
776 | + workers=args.workers, | ||
777 | + use_multiprocessing=args.multiprocessing, | ||
778 | + max_queue_size=args.max_queue_size, | ||
779 | + validation_data=validation_generator, | ||
780 | + initial_epoch=args.initial_epoch, | ||
781 | + ) | ||
782 | + | ||
783 | + | ||
784 | +if __name__ == "__main__": | ||
785 | + main() |
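For reference, a minimal sketch of how the command line defined by `parse_args` above fits together. It assumes this file is importable as `keras_retinanet.bin.train` (matching the repo layout in this diff); the CSV paths are placeholders and are not opened by `parse_args` itself, so the snippet runs without any dataset present.

```python
# Hypothetical equivalent of:
#   python keras_retinanet/bin/train.py --backbone resnet50 --batch-size 2 \
#          --epochs 20 --steps 500 csv annotations.csv classes.csv
from keras_retinanet.bin.train import parse_args

args = parse_args([
    "--backbone", "resnet50",   # the only backbone check_args treats as non-experimental
    "--batch-size", "2",
    "--epochs", "20",
    "--steps", "500",
    "csv",                      # dataset-type subcommand
    "annotations.csv",          # placeholder: training annotations CSV
    "classes.csv",              # placeholder: class-name to id mapping CSV
])
print(args.dataset_type, args.backbone, args.batch_size)  # -> csv resnet50 2
```

Passing the same list to `main()` additionally builds the generators and starts training, which does require the CSV files to exist.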
retinaNet/keras_retinanet/bin/train2.py
0 → 100644
1 | +#!/usr/bin/env python | ||
2 | + | ||
3 | +""" | ||
4 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
5 | + | ||
6 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
7 | +you may not use this file except in compliance with the License. | ||
8 | +You may obtain a copy of the License at | ||
9 | + | ||
10 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
11 | + | ||
12 | +Unless required by applicable law or agreed to in writing, software | ||
13 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
14 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
15 | +See the License for the specific language governing permissions and | ||
16 | +limitations under the License. | ||
17 | +""" | ||
18 | + | ||
19 | +import argparse | ||
20 | +import os | ||
21 | +import sys | ||
22 | +import warnings | ||
23 | + | ||
24 | +from tensorflow import keras | ||
25 | +import tensorflow as tf | ||
26 | + | ||
27 | +# submodel is imported further down with a relative import, once the package path has been set up. | ||
28 | + | ||
29 | +# Allow relative imports when being executed as script. | ||
30 | +if __name__ == "__main__" and __package__ is None: | ||
31 | + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) | ||
32 | + import keras_retinanet.bin # noqa: F401 | ||
33 | + | ||
34 | + __package__ = "keras_retinanet.bin" | ||
35 | + | ||
36 | +# Change these to absolute imports if you copy this script outside the keras_retinanet package. | ||
37 | +from .. import layers # noqa: F401 | ||
38 | +from .. import losses | ||
39 | +from .. import models | ||
40 | +from ..callbacks import RedirectModel | ||
41 | +from ..callbacks.eval import Evaluate | ||
42 | +from ..models.retinanet import retinanet_bbox | ||
43 | +from ..preprocessing.csv_generator import CSVGenerator | ||
44 | +from ..preprocessing.kitti import KittiGenerator | ||
45 | +from ..preprocessing.open_images import OpenImagesGenerator | ||
46 | +from ..preprocessing.pascal_voc import PascalVocGenerator | ||
47 | +from ..utils.anchors import make_shapes_callback | ||
48 | +from ..utils.config import ( | ||
49 | + read_config_file, | ||
50 | + parse_anchor_parameters, | ||
51 | + parse_pyramid_levels, | ||
52 | +) | ||
53 | +from ..utils.gpu import setup_gpu | ||
54 | +from ..utils.image import random_visual_effect_generator | ||
55 | +from ..utils.model import freeze as freeze_model | ||
56 | +from ..utils.tf_version import check_tf_version | ||
57 | +from ..utils.transform import random_transform_generator | ||
58 | + | ||
59 | +####################### | ||
60 | + | ||
61 | +from ..models import submodel | ||
62 | + | ||
63 | + | ||
64 | +def makedirs(path): | ||
65 | + # Intended behavior: try to create the directory, | ||
66 | +    # pass if the directory already exists, fail otherwise. | ||
67 | + # Meant for Python 2.7/3.n compatibility. | ||
68 | + try: | ||
69 | + os.makedirs(path) | ||
70 | + except OSError: | ||
71 | + if not os.path.isdir(path): | ||
72 | + raise | ||
73 | + | ||
74 | + | ||
75 | +def model_with_weights(model, weights, skip_mismatch): | ||
76 | + """Load weights for model. | ||
77 | + | ||
78 | + Args | ||
79 | + model : The model to load weights for. | ||
80 | + weights : The weights to load. | ||
81 | + skip_mismatch : If True, skip layers whose weight shapes do not match those in the model. | ||
82 | + """ | ||
83 | + if weights is not None: | ||
84 | + model.load_weights(weights, by_name=True, skip_mismatch=skip_mismatch) | ||
85 | + return model | ||
86 | + | ||
87 | + | ||
88 | +def create_models( | ||
89 | + backbone_retinanet, | ||
90 | + num_classes, | ||
91 | + weights, | ||
92 | + multi_gpu=0, | ||
93 | + freeze_backbone=False, | ||
94 | + lr=1e-5, | ||
95 | + optimizer_clipnorm=0.001, | ||
96 | + config=None, | ||
97 | + submodels=None, | ||
98 | +): | ||
99 | + """Creates three models (model, training_model, prediction_model). | ||
100 | + | ||
101 | + Args | ||
102 | + backbone_retinanet : A function to call to create a retinanet model with a given backbone. | ||
103 | + num_classes : The number of classes to train. | ||
104 | + weights : The weights to load into the model. | ||
105 | + multi_gpu : The number of GPUs to use for training. | ||
106 | + freeze_backbone : If True, disables learning for the backbone. | ||
107 | + config : Config parameters, None indicates the default configuration. | ||
108 | + | ||
109 | + Returns | ||
110 | + model : The base model. This is also the model that is saved in snapshots. | ||
111 | + training_model : The training model. If multi_gpu=0, this is identical to model. | ||
112 | + prediction_model : The model wrapped with utility functions to perform object detection (applies regression values and performs NMS). | ||
113 | + """ | ||
114 | + | ||
115 | + modifier = freeze_model if freeze_backbone else None | ||
116 | + | ||
117 | + # load anchor parameters, or pass None (so that defaults will be used) | ||
118 | + anchor_params = None | ||
119 | + num_anchors = None | ||
120 | + pyramid_levels = None | ||
121 | + if config and "anchor_parameters" in config: | ||
122 | + anchor_params = parse_anchor_parameters(config) | ||
123 | + num_anchors = anchor_params.num_anchors() | ||
124 | + if config and "pyramid_levels" in config: | ||
125 | + pyramid_levels = parse_pyramid_levels(config) | ||
126 | + | ||
127 | + # Keras recommends initialising a multi-gpu model on the CPU to ease weight sharing, and to prevent OOM errors. | ||
128 | + # optionally wrap in a parallel model | ||
129 | + if multi_gpu > 1: | ||
130 | + from keras.utils import multi_gpu_model | ||
131 | + | ||
132 | + with tf.device("/cpu:0"): | ||
133 | + model = model_with_weights( | ||
134 | + backbone_retinanet( | ||
135 | + num_classes, | ||
136 | + num_anchors=num_anchors, | ||
137 | + modifier=modifier, | ||
138 | + pyramid_levels=pyramid_levels, | ||
139 | + ), | ||
140 | + weights=weights, | ||
141 | + skip_mismatch=True, | ||
142 | + ) | ||
143 | + training_model = multi_gpu_model(model, gpus=multi_gpu) | ||
144 | + else: | ||
145 | + model = model_with_weights( | ||
146 | + backbone_retinanet( | ||
147 | + num_classes, | ||
148 | + num_anchors=num_anchors, | ||
149 | + modifier=modifier, | ||
150 | + pyramid_levels=pyramid_levels, | ||
151 | + submodels=submodels, | ||
152 | + ), | ||
153 | + weights=weights, | ||
154 | + skip_mismatch=True, | ||
155 | + ) | ||
156 | + training_model = model | ||
157 | + | ||
158 | + # make prediction model | ||
159 | + prediction_model = retinanet_bbox( | ||
160 | + model=model, anchor_params=anchor_params, pyramid_levels=pyramid_levels | ||
161 | + ) | ||
162 | + | ||
163 | + # compile model | ||
164 | + training_model.compile( | ||
165 | + loss={"regression": losses.smooth_l1(), "classification": losses.focal()}, | ||
166 | + optimizer=keras.optimizers.Adam(lr=lr, clipnorm=optimizer_clipnorm), | ||
167 | + ) | ||
168 | + | ||
169 | + return model, training_model, prediction_model | ||
170 | + | ||
171 | + | ||
172 | +def create_callbacks( | ||
173 | + model, training_model, prediction_model, validation_generator, args | ||
174 | +): | ||
175 | + """Creates the callbacks to use during training. | ||
176 | + | ||
177 | + Args | ||
178 | + model: The base model. | ||
179 | + training_model: The model that is used for training. | ||
180 | + prediction_model: The model that should be used for validation. | ||
181 | + validation_generator: The generator for creating validation data. | ||
182 | + args: parseargs args object. | ||
183 | + | ||
184 | + Returns: | ||
185 | + A list of callbacks used for training. | ||
186 | + """ | ||
187 | + callbacks = [] | ||
188 | + | ||
189 | + tensorboard_callback = None | ||
190 | + | ||
191 | + if args.tensorboard_dir: | ||
192 | + makedirs(args.tensorboard_dir) | ||
193 | + update_freq = args.tensorboard_freq | ||
194 | + if update_freq not in ["epoch", "batch"]: | ||
195 | + update_freq = int(update_freq) | ||
196 | + tensorboard_callback = keras.callbacks.TensorBoard( | ||
197 | + log_dir=args.tensorboard_dir, | ||
198 | + histogram_freq=0, | ||
199 | + batch_size=args.batch_size, | ||
200 | + write_graph=True, | ||
201 | + write_grads=False, | ||
202 | + write_images=False, | ||
203 | + update_freq=update_freq, | ||
204 | + embeddings_freq=0, | ||
205 | + embeddings_layer_names=None, | ||
206 | + embeddings_metadata=None, | ||
207 | + ) | ||
208 | + | ||
209 | + if args.evaluation and validation_generator: | ||
210 | + if args.dataset_type == "coco": | ||
211 | + from ..callbacks.coco import CocoEval | ||
212 | + | ||
213 | + # use prediction model for evaluation | ||
214 | + evaluation = CocoEval( | ||
215 | + validation_generator, tensorboard=tensorboard_callback | ||
216 | + ) | ||
217 | + else: | ||
218 | + evaluation = Evaluate( | ||
219 | + validation_generator, | ||
220 | + tensorboard=tensorboard_callback, | ||
221 | + weighted_average=args.weighted_average, | ||
222 | + ) | ||
223 | + evaluation = RedirectModel(evaluation, prediction_model) | ||
224 | + callbacks.append(evaluation) | ||
225 | + | ||
226 | + # save the model | ||
227 | + if args.snapshots: | ||
228 | + # ensure directory created first; otherwise h5py will error after epoch. | ||
229 | + makedirs(args.snapshot_path) | ||
230 | + checkpoint = keras.callbacks.ModelCheckpoint( | ||
231 | + os.path.join( | ||
232 | + args.snapshot_path, | ||
233 | + "{backbone}_{dataset_type}_{{epoch:02d}}.h5".format( | ||
234 | + backbone=args.backbone, dataset_type=args.dataset_type | ||
235 | + ), | ||
236 | + ), | ||
237 | + verbose=1, | ||
238 | + # save_best_only=True, | ||
239 | + # monitor="mAP", | ||
240 | + # mode='max' | ||
241 | + ) | ||
242 | + checkpoint = RedirectModel(checkpoint, model) | ||
243 | + callbacks.append(checkpoint) | ||
244 | + | ||
245 | + callbacks.append( | ||
246 | + keras.callbacks.ReduceLROnPlateau( | ||
247 | + monitor="loss", | ||
248 | + factor=args.reduce_lr_factor, | ||
249 | + patience=args.reduce_lr_patience, | ||
250 | + verbose=1, | ||
251 | + mode="auto", | ||
252 | + min_delta=0.0001, | ||
253 | + cooldown=0, | ||
254 | + min_lr=0, | ||
255 | + ) | ||
256 | + ) | ||
257 | + | ||
258 | + if args.evaluation and validation_generator: | ||
259 | + callbacks.append( | ||
260 | + keras.callbacks.EarlyStopping( | ||
261 | + monitor="mAP", patience=5, mode="max", min_delta=0.01 | ||
262 | + ) | ||
263 | + ) | ||
264 | + | ||
265 | + if args.tensorboard_dir: | ||
266 | + callbacks.append(tensorboard_callback) | ||
267 | + | ||
268 | + return callbacks | ||
269 | + | ||
270 | + | ||
271 | +def create_generators(args, preprocess_image): | ||
272 | + """Create generators for training and validation. | ||
273 | + | ||
274 | + Args | ||
275 | + args : parseargs object containing configuration for generators. | ||
276 | + preprocess_image : Function that preprocesses an image for the network. | ||
277 | + """ | ||
278 | + common_args = { | ||
279 | + "batch_size": args.batch_size, | ||
280 | + "config": args.config, | ||
281 | + "image_min_side": args.image_min_side, | ||
282 | + "image_max_side": args.image_max_side, | ||
283 | + "no_resize": args.no_resize, | ||
284 | + "preprocess_image": preprocess_image, | ||
285 | + "group_method": args.group_method, | ||
286 | + } | ||
287 | + | ||
288 | + # create random transform generator for augmenting training data | ||
289 | + if args.random_transform: | ||
290 | + transform_generator = random_transform_generator( | ||
291 | + min_rotation=-0.1, | ||
292 | + max_rotation=0.1, | ||
293 | + min_translation=(-0.1, -0.1), | ||
294 | + max_translation=(0.1, 0.1), | ||
295 | + min_shear=-0.1, | ||
296 | + max_shear=0.1, | ||
297 | + min_scaling=(0.9, 0.9), | ||
298 | + max_scaling=(1.1, 1.1), | ||
299 | + flip_x_chance=0.5, | ||
300 | + flip_y_chance=0.5, | ||
301 | + ) | ||
302 | + visual_effect_generator = random_visual_effect_generator( | ||
303 | + contrast_range=(0.9, 1.1), | ||
304 | + brightness_range=(-0.1, 0.1), | ||
305 | + hue_range=(-0.05, 0.05), | ||
306 | + saturation_range=(0.95, 1.05), | ||
307 | + ) | ||
308 | + else: | ||
309 | + transform_generator = random_transform_generator(flip_x_chance=0.5) | ||
310 | + visual_effect_generator = None | ||
311 | + | ||
312 | + if args.dataset_type == "coco": | ||
313 | + # import here to prevent unnecessary dependency on cocoapi | ||
314 | + from ..preprocessing.coco import CocoGenerator | ||
315 | + | ||
316 | + train_generator = CocoGenerator( | ||
317 | + args.coco_path, | ||
318 | + "train2017", | ||
319 | + transform_generator=transform_generator, | ||
320 | + visual_effect_generator=visual_effect_generator, | ||
321 | + **common_args | ||
322 | + ) | ||
323 | + | ||
324 | + validation_generator = CocoGenerator( | ||
325 | + args.coco_path, "val2017", shuffle_groups=False, **common_args | ||
326 | + ) | ||
327 | + elif args.dataset_type == "pascal": | ||
328 | + train_generator = PascalVocGenerator( | ||
329 | + args.pascal_path, | ||
330 | + "train", | ||
331 | + image_extension=args.image_extension, | ||
332 | + transform_generator=transform_generator, | ||
333 | + visual_effect_generator=visual_effect_generator, | ||
334 | + **common_args | ||
335 | + ) | ||
336 | + | ||
337 | + validation_generator = PascalVocGenerator( | ||
338 | + args.pascal_path, | ||
339 | + "val", | ||
340 | + image_extension=args.image_extension, | ||
341 | + shuffle_groups=False, | ||
342 | + **common_args | ||
343 | + ) | ||
344 | + elif args.dataset_type == "csv": | ||
345 | + train_generator = CSVGenerator( | ||
346 | + args.annotations, | ||
347 | + args.classes, | ||
348 | + transform_generator=transform_generator, | ||
349 | + visual_effect_generator=visual_effect_generator, | ||
350 | + **common_args | ||
351 | + ) | ||
352 | + | ||
353 | + if args.val_annotations: | ||
354 | + validation_generator = CSVGenerator( | ||
355 | + args.val_annotations, args.classes, shuffle_groups=False, **common_args | ||
356 | + ) | ||
357 | + else: | ||
358 | + validation_generator = None | ||
359 | + elif args.dataset_type == "oid": | ||
360 | + train_generator = OpenImagesGenerator( | ||
361 | + args.main_dir, | ||
362 | + subset="train", | ||
363 | + version=args.version, | ||
364 | + labels_filter=args.labels_filter, | ||
365 | + annotation_cache_dir=args.annotation_cache_dir, | ||
366 | + parent_label=args.parent_label, | ||
367 | + transform_generator=transform_generator, | ||
368 | + visual_effect_generator=visual_effect_generator, | ||
369 | + **common_args | ||
370 | + ) | ||
371 | + | ||
372 | + validation_generator = OpenImagesGenerator( | ||
373 | + args.main_dir, | ||
374 | + subset="validation", | ||
375 | + version=args.version, | ||
376 | + labels_filter=args.labels_filter, | ||
377 | + annotation_cache_dir=args.annotation_cache_dir, | ||
378 | + parent_label=args.parent_label, | ||
379 | + shuffle_groups=False, | ||
380 | + **common_args | ||
381 | + ) | ||
382 | + elif args.dataset_type == "kitti": | ||
383 | + train_generator = KittiGenerator( | ||
384 | + args.kitti_path, | ||
385 | + subset="train", | ||
386 | + transform_generator=transform_generator, | ||
387 | + visual_effect_generator=visual_effect_generator, | ||
388 | + **common_args | ||
389 | + ) | ||
390 | + | ||
391 | + validation_generator = KittiGenerator( | ||
392 | + args.kitti_path, subset="val", shuffle_groups=False, **common_args | ||
393 | + ) | ||
394 | + else: | ||
395 | + raise ValueError("Invalid data type received: {}".format(args.dataset_type)) | ||
396 | + | ||
397 | + return train_generator, validation_generator | ||
398 | + | ||
399 | + | ||
400 | +def check_args(parsed_args): | ||
401 | + """Function to check for inherent contradictions within parsed arguments. | ||
402 | + For example, batch_size < num_gpus | ||
403 | + Intended to raise errors prior to backend initialisation. | ||
404 | + | ||
405 | + Args | ||
406 | + parsed_args: parser.parse_args() | ||
407 | + | ||
408 | + Returns | ||
409 | + parsed_args | ||
410 | + """ | ||
411 | + | ||
412 | + if parsed_args.multi_gpu > 1 and parsed_args.batch_size < parsed_args.multi_gpu: | ||
413 | + raise ValueError( | ||
414 | + "Batch size ({}) must be equal to or higher than the number of GPUs ({})".format( | ||
415 | + parsed_args.batch_size, parsed_args.multi_gpu | ||
416 | + ) | ||
417 | + ) | ||
418 | + | ||
419 | + if parsed_args.multi_gpu > 1 and parsed_args.snapshot: | ||
420 | + raise ValueError( | ||
421 | + "Multi GPU training ({}) and resuming from snapshots ({}) is not supported.".format( | ||
422 | + parsed_args.multi_gpu, parsed_args.snapshot | ||
423 | + ) | ||
424 | + ) | ||
425 | + | ||
426 | + if parsed_args.multi_gpu > 1 and not parsed_args.multi_gpu_force: | ||
427 | + raise ValueError( | ||
428 | + "Multi-GPU support is experimental, use at own risk! Run with --multi-gpu-force if you wish to continue." | ||
429 | + ) | ||
430 | + | ||
431 | + if "resnet" not in parsed_args.backbone: | ||
432 | + warnings.warn( | ||
433 | + "Using experimental backbone {}. Only resnet50 has been properly tested.".format( | ||
434 | + parsed_args.backbone | ||
435 | + ) | ||
436 | + ) | ||
437 | + | ||
438 | + return parsed_args | ||
439 | + | ||
440 | + | ||
441 | +def parse_args(args): | ||
442 | + """Parse the arguments.""" | ||
443 | + parser = argparse.ArgumentParser( | ||
444 | + description="Simple training script for training a RetinaNet network." | ||
445 | + ) | ||
446 | + subparsers = parser.add_subparsers( | ||
447 | + help="Arguments for specific dataset types.", dest="dataset_type" | ||
448 | + ) | ||
449 | + subparsers.required = True | ||
450 | + | ||
451 | + coco_parser = subparsers.add_parser("coco") | ||
452 | + coco_parser.add_argument( | ||
453 | + "coco_path", help="Path to dataset directory (e.g. /tmp/COCO)." | ||
454 | + ) | ||
455 | + | ||
456 | + pascal_parser = subparsers.add_parser("pascal") | ||
457 | + pascal_parser.add_argument( | ||
458 | + "pascal_path", help="Path to dataset directory (e.g. /tmp/VOCdevkit)." | ||
459 | + ) | ||
460 | + pascal_parser.add_argument( | ||
461 | + "--image-extension", | ||
462 | + help="Declares the dataset images' extension.", | ||
463 | + default=".jpg", | ||
464 | + ) | ||
465 | + | ||
466 | + kitti_parser = subparsers.add_parser("kitti") | ||
467 | + kitti_parser.add_argument( | ||
468 | + "kitti_path", help="Path to dataset directory (e.g. /tmp/kitti)." | ||
469 | + ) | ||
470 | + | ||
471 | + def csv_list(string): | ||
472 | + return string.split(",") | ||
473 | + | ||
474 | + oid_parser = subparsers.add_parser("oid") | ||
475 | + oid_parser.add_argument("main_dir", help="Path to dataset directory.") | ||
476 | + oid_parser.add_argument( | ||
477 | + "--version", help="The current dataset version is v4.", default="v4" | ||
478 | + ) | ||
479 | + oid_parser.add_argument( | ||
480 | + "--labels-filter", | ||
481 | + help="A list of labels to filter.", | ||
482 | + type=csv_list, | ||
483 | + default=None, | ||
484 | + ) | ||
485 | + oid_parser.add_argument( | ||
486 | + "--annotation-cache-dir", help="Path to store annotation cache.", default="." | ||
487 | + ) | ||
488 | + oid_parser.add_argument( | ||
489 | + "--parent-label", help="Use the hierarchy children of this label.", default=None | ||
490 | + ) | ||
491 | + | ||
492 | + csv_parser = subparsers.add_parser("csv") | ||
493 | + csv_parser.add_argument( | ||
494 | + "annotations", help="Path to CSV file containing annotations for training." | ||
495 | + ) | ||
496 | + csv_parser.add_argument( | ||
497 | + "classes", help="Path to a CSV file containing class label mapping." | ||
498 | + ) | ||
499 | + csv_parser.add_argument( | ||
500 | + "--val-annotations", | ||
501 | + help="Path to CSV file containing annotations for validation (optional).", | ||
502 | + ) | ||
503 | + | ||
504 | + group = parser.add_mutually_exclusive_group() | ||
505 | + group.add_argument("--snapshot", help="Resume training from a snapshot.") | ||
506 | + group.add_argument( | ||
507 | + "--imagenet-weights", | ||
508 | + help="Initialize the model with pretrained imagenet weights. This is the default behaviour.", | ||
509 | + action="store_const", | ||
510 | + const=True, | ||
511 | + default=True, | ||
512 | + ) | ||
513 | + group.add_argument( | ||
514 | + "--weights", help="Initialize the model with weights from a file." | ||
515 | + ) | ||
516 | + group.add_argument( | ||
517 | + "--no-weights", | ||
518 | + help="Don't initialize the model with any weights.", | ||
519 | + dest="imagenet_weights", | ||
520 | + action="store_const", | ||
521 | + const=False, | ||
522 | + ) | ||
523 | + parser.add_argument( | ||
524 | + "--backbone", | ||
525 | + help="Backbone model used by retinanet.", | ||
526 | + default="resnet50", | ||
527 | + type=str, | ||
528 | + ) | ||
529 | + parser.add_argument( | ||
530 | + "--batch-size", help="Size of the batches.", default=1, type=int | ||
531 | + ) | ||
532 | + parser.add_argument( | ||
533 | + "--gpu", help="Id of the GPU to use (as reported by nvidia-smi)." | ||
534 | + ) | ||
535 | + parser.add_argument( | ||
536 | + "--multi-gpu", | ||
537 | + help="Number of GPUs to use for parallel processing.", | ||
538 | + type=int, | ||
539 | + default=0, | ||
540 | + ) | ||
541 | + parser.add_argument( | ||
542 | + "--multi-gpu-force", | ||
543 | + help="Extra flag needed to enable (experimental) multi-gpu support.", | ||
544 | + action="store_true", | ||
545 | + ) | ||
546 | + parser.add_argument( | ||
547 | + "--initial-epoch", | ||
548 | + help="Epoch from which to begin the train, useful if resuming from snapshot.", | ||
549 | + type=int, | ||
550 | + default=0, | ||
551 | + ) | ||
552 | + parser.add_argument( | ||
553 | + "--epochs", help="Number of epochs to train.", type=int, default=50 | ||
554 | + ) | ||
555 | + parser.add_argument( | ||
556 | + "--steps", help="Number of steps per epoch.", type=int, default=10000 | ||
557 | + ) | ||
558 | + parser.add_argument("--lr", help="Learning rate.", type=float, default=1e-5) | ||
559 | + parser.add_argument( | ||
560 | + "--optimizer-clipnorm", | ||
561 | + help="Clipnorm parameter for optimizer.", | ||
562 | + type=float, | ||
563 | + default=0.001, | ||
564 | + ) | ||
565 | + parser.add_argument( | ||
566 | + "--snapshot-path", | ||
567 | + help="Path to store snapshots of models during training (defaults to './snapshots')", | ||
568 | + default="./snapshots", | ||
569 | + ) | ||
570 | + parser.add_argument( | ||
571 | + "--tensorboard-dir", help="Log directory for Tensorboard output", default="" | ||
572 | + ) # default='./logs') => https://github.com/tensorflow/tensorflow/pull/34870 | ||
573 | + parser.add_argument( | ||
574 | + "--tensorboard-freq", | ||
575 | + help="Update frequency for Tensorboard output. Values 'epoch', 'batch' or int", | ||
576 | + default="epoch", | ||
577 | + ) | ||
578 | + parser.add_argument( | ||
579 | + "--no-snapshots", | ||
580 | + help="Disable saving snapshots.", | ||
581 | + dest="snapshots", | ||
582 | + action="store_false", | ||
583 | + ) | ||
584 | + parser.add_argument( | ||
585 | + "--no-evaluation", | ||
586 | + help="Disable per epoch evaluation.", | ||
587 | + dest="evaluation", | ||
588 | + action="store_false", | ||
589 | + ) | ||
590 | + parser.add_argument( | ||
591 | + "--freeze-backbone", | ||
592 | + help="Freeze training of backbone layers.", | ||
593 | + action="store_true", | ||
594 | + ) | ||
595 | + parser.add_argument( | ||
596 | + "--random-transform", | ||
597 | + help="Randomly transform image and annotations.", | ||
598 | + action="store_true", | ||
599 | + ) | ||
600 | + parser.add_argument( | ||
601 | + "--image-min-side", | ||
602 | + help="Rescale the image so the smallest side is min_side.", | ||
603 | + type=int, | ||
604 | + default=800, | ||
605 | + ) | ||
606 | + parser.add_argument( | ||
607 | + "--image-max-side", | ||
608 | + help="Rescale the image if the largest side is larger than max_side.", | ||
609 | + type=int, | ||
610 | + default=1333, | ||
611 | + ) | ||
612 | + parser.add_argument( | ||
613 | + "--no-resize", help="Don't rescale the image.", action="store_true" | ||
614 | + ) | ||
615 | + parser.add_argument( | ||
616 | + "--config", help="Path to a configuration parameters .ini file." | ||
617 | + ) | ||
618 | + parser.add_argument( | ||
619 | + "--weighted-average", | ||
620 | + help="Compute the mAP using the weighted average of precisions among classes.", | ||
621 | + action="store_true", | ||
622 | + ) | ||
623 | + parser.add_argument( | ||
624 | + "--compute-val-loss", | ||
625 | + help="Compute validation loss during training", | ||
626 | + dest="compute_val_loss", | ||
627 | + action="store_true", | ||
628 | + ) | ||
629 | + parser.add_argument( | ||
630 | + "--reduce-lr-patience", | ||
631 | + help="Reduce the learning rate when the loss stops improving for reduce_lr_patience epochs.", | ||
632 | + type=int, | ||
633 | + default=2, | ||
634 | + ) | ||
635 | + parser.add_argument( | ||
636 | + "--reduce-lr-factor", | ||
637 | + help="Factor by which to multiply the learning rate when it is reduced (see --reduce-lr-patience).", | ||
638 | + type=float, | ||
639 | + default=0.1, | ||
640 | + ) | ||
641 | + parser.add_argument( | ||
642 | + "--group-method", | ||
643 | + help="Determines how images are grouped into batches: 'none', 'random', or 'ratio' (group by aspect ratio).", | ||
644 | + type=str, | ||
645 | + default="ratio", | ||
646 | + choices=["none", "random", "ratio"], | ||
647 | + ) | ||
648 | + | ||
649 | + # Fit generator arguments | ||
650 | + parser.add_argument( | ||
651 | + "--multiprocessing", | ||
652 | + help="Use multiprocessing in fit_generator.", | ||
653 | + action="store_true", | ||
654 | + ) | ||
655 | + parser.add_argument( | ||
656 | + "--workers", help="Number of generator workers.", type=int, default=1 | ||
657 | + ) | ||
658 | + parser.add_argument( | ||
659 | + "--max-queue-size", | ||
660 | + help="Queue length for multiprocessing workers in fit_generator.", | ||
661 | + type=int, | ||
662 | + default=10, | ||
663 | + ) | ||
664 | + | ||
665 | + return check_args(parser.parse_args(args)) | ||
666 | + | ||
667 | + | ||
668 | +def main(args=None): | ||
669 | + # parse arguments | ||
670 | + if args is None: | ||
671 | + args = sys.argv[1:] | ||
672 | + args = parse_args(args) | ||
673 | + | ||
674 | + # create object that stores backbone information | ||
675 | + backbone = models.backbone(args.backbone) | ||
676 | + | ||
677 | + # make sure tensorflow is the minimum required version | ||
678 | + check_tf_version() | ||
679 | + | ||
680 | + # optionally choose specific GPU | ||
681 | + if args.gpu is not None: | ||
682 | + setup_gpu(args.gpu) | ||
683 | + | ||
684 | + # optionally load config parameters | ||
685 | + if args.config: | ||
686 | + args.config = read_config_file(args.config) | ||
687 | + | ||
688 | + # create the generators | ||
689 | + train_generator, validation_generator = create_generators( | ||
690 | + args, backbone.preprocess_image | ||
691 | + ) | ||
692 | + | ||
693 | + # create the model | ||
694 | + if args.snapshot is not None: | ||
695 | + print("Loading model, this may take a second...") | ||
696 | + model = models.load_model(args.snapshot, backbone_name=args.backbone) | ||
697 | + training_model = model | ||
698 | + anchor_params = None | ||
699 | + pyramid_levels = None | ||
700 | + if args.config and "anchor_parameters" in args.config: | ||
701 | + anchor_params = parse_anchor_parameters(args.config) | ||
702 | + if args.config and "pyramid_levels" in args.config: | ||
703 | + pyramid_levels = parse_pyramid_levels(args.config) | ||
704 | + | ||
705 | + prediction_model = retinanet_bbox( | ||
706 | + model=model, anchor_params=anchor_params, pyramid_levels=pyramid_levels | ||
707 | + ) | ||
708 | + else: | ||
709 | + weights = args.weights | ||
710 | + # default to imagenet if nothing else is specified | ||
711 | + if weights is None and args.imagenet_weights: | ||
712 | + weights = backbone.download_imagenet() | ||
713 | + | ||
714 | + ################# | ||
715 | + # subclass1 = submodel.custom_classification_model(num_classes=51, num_anchors=None, name="classification_submodel1") | ||
716 | + # subregress1 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel1") | ||
717 | + | ||
718 | + # subclass2 = submodel.custom_classification_model(num_classes=10, num_anchors=None, name="classification_submodel2") | ||
719 | + # subregress2 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel2") | ||
720 | + | ||
721 | + # subclass3 = submodel.custom_classification_model(num_classes=16, num_anchors=None, name="classification_submodel3") | ||
722 | + # subregress3 = submodel.custom_regression_model(num_values=4, num_anchors=None, name="regression_submodel3") | ||
723 | + | ||
724 | + # submodels = [ | ||
725 | + # ("regression", subregress1), ("classification", subclass1), | ||
726 | + # ("regression", subregress2), ("classification", subclass2), | ||
727 | + # ("regression", subregress3), ("classification", subclass3), | ||
728 | + # ] | ||
729 | + | ||
730 | + # s1 = submodel.custom_default_submodels(51, None) | ||
731 | + # s2 = submodel.custom_default_submodels(10, None) | ||
732 | + # s3 = submodel.custom_default_submodels(16, None) | ||
733 | + | ||
734 | + # submodels = s1 + s2 + s3 | ||
735 | + | ||
736 | + ################# | ||
737 | + print("Creating model, this may take a second...") | ||
738 | + model, training_model, prediction_model = create_models( | ||
739 | + backbone_retinanet=backbone.retinanet, | ||
740 | + num_classes=train_generator.num_classes(), | ||
741 | + weights=weights, | ||
742 | + multi_gpu=args.multi_gpu, | ||
743 | + freeze_backbone=args.freeze_backbone, | ||
744 | + lr=args.lr, | ||
745 | + optimizer_clipnorm=args.optimizer_clipnorm, | ||
746 | + config=args.config, | ||
747 | + submodels=submodel.custom_classification_model(76,),  # passes a single classification submodel here; train.py instead builds a list of (name, submodel) pairs | ||
748 | + ) | ||
749 | + | ||
750 | + # print model summary | ||
751 | + print(model.summary()) | ||
752 | + | ||
753 | + # this lets the generator compute backbone layer shapes using the actual backbone model | ||
754 | + if "vgg" in args.backbone or "densenet" in args.backbone: | ||
755 | + train_generator.compute_shapes = make_shapes_callback(model) | ||
756 | + if validation_generator: | ||
757 | + validation_generator.compute_shapes = train_generator.compute_shapes | ||
758 | + | ||
759 | + # create the callbacks | ||
760 | + callbacks = create_callbacks( | ||
761 | + model, | ||
762 | + training_model, | ||
763 | + prediction_model, | ||
764 | + validation_generator, | ||
765 | + args, | ||
766 | + ) | ||
767 | + | ||
768 | + if not args.compute_val_loss: | ||
769 | + validation_generator = None | ||
770 | + | ||
771 | + # start training | ||
772 | + return training_model.fit_generator( | ||
773 | + generator=train_generator, | ||
774 | + steps_per_epoch=args.steps, | ||
775 | + epochs=args.epochs, | ||
776 | + verbose=1, | ||
777 | + callbacks=callbacks, | ||
778 | + workers=args.workers, | ||
779 | + use_multiprocessing=args.multiprocessing, | ||
780 | + max_queue_size=args.max_queue_size, | ||
781 | + validation_data=validation_generator, | ||
782 | + initial_epoch=args.initial_epoch, | ||
783 | + ) | ||
784 | + | ||
785 | + | ||
786 | +if __name__ == "__main__": | ||
787 | + main() |
retinaNet/keras_retinanet/callbacks/__init__.py
0 → 100644
1 | +from .common import * # noqa: F401,F403 |
retinaNet/keras_retinanet/callbacks/coco.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from tensorflow import keras | ||
18 | +from ..utils.coco_eval import evaluate_coco | ||
19 | + | ||
20 | + | ||
21 | +class CocoEval(keras.callbacks.Callback): | ||
22 | + """ Performs COCO evaluation on each epoch. | ||
23 | + """ | ||
24 | + def __init__(self, generator, tensorboard=None, threshold=0.05): | ||
25 | + """ CocoEval callback initializer. | ||
26 | + | ||
27 | + Args | ||
28 | + generator : The generator used for creating validation data. | ||
29 | + tensorboard : If given, the results will be written to tensorboard. | ||
30 | + threshold : The score threshold to use. | ||
31 | + """ | ||
32 | + self.generator = generator | ||
33 | + self.threshold = threshold | ||
34 | + self.tensorboard = tensorboard | ||
35 | + | ||
36 | + super(CocoEval, self).__init__() | ||
37 | + | ||
38 | + def on_epoch_end(self, epoch, logs=None): | ||
39 | + logs = logs or {} | ||
40 | + | ||
41 | + coco_tag = ['AP @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', | ||
42 | + 'AP @[ IoU=0.50 | area= all | maxDets=100 ]', | ||
43 | + 'AP @[ IoU=0.75 | area= all | maxDets=100 ]', | ||
44 | + 'AP @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', | ||
45 | + 'AP @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', | ||
46 | + 'AP @[ IoU=0.50:0.95 | area= large | maxDets=100 ]', | ||
47 | + 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]', | ||
48 | + 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]', | ||
49 | + 'AR @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', | ||
50 | + 'AR @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', | ||
51 | + 'AR @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', | ||
52 | + 'AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ]'] | ||
53 | + coco_eval_stats = evaluate_coco(self.generator, self.model, self.threshold) | ||
54 | + | ||
55 | + if coco_eval_stats is not None: | ||
56 | + for index, result in enumerate(coco_eval_stats): | ||
57 | + logs[coco_tag[index]] = result | ||
58 | + | ||
59 | + if self.tensorboard: | ||
60 | + import tensorflow as tf | ||
61 | + writer = tf.summary.create_file_writer(self.tensorboard.log_dir) | ||
62 | + with writer.as_default(): | ||
63 | + for index, result in enumerate(coco_eval_stats): | ||
64 | + tf.summary.scalar('{}. {}'.format(index + 1, coco_tag[index]), result, step=epoch) | ||
65 | + writer.flush() |
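A hedged sketch of how this callback is typically wired up, mirroring `create_callbacks` in the training scripts above. Here `validation_generator`, `prediction_model`, `training_model` and `train_generator` are assumed to already exist (e.g. as built in `main`), and the COCO path additionally requires pycocotools to be installed.

```python
# Evaluate COCO metrics at the end of every epoch against the prediction model
# (which applies box regression and NMS), not against the raw training model.
from keras_retinanet.callbacks import RedirectModel
from keras_retinanet.callbacks.coco import CocoEval

coco_eval = CocoEval(validation_generator, threshold=0.05)   # generator assumed to exist
callbacks = [RedirectModel(coco_eval, prediction_model)]

training_model.fit_generator(
    generator=train_generator,
    steps_per_epoch=100,
    epochs=1,
    callbacks=callbacks,
)
```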
retinaNet/keras_retinanet/callbacks/common.py
0 → 100644
1 | +from tensorflow import keras | ||
2 | + | ||
3 | + | ||
4 | +class RedirectModel(keras.callbacks.Callback): | ||
5 | + """Callback which wraps another callback, but executed on a different model. | ||
6 | + | ||
7 | + ```python | ||
8 | + model = keras.models.load_model('model.h5') | ||
9 | + model_checkpoint = ModelCheckpoint(filepath='snapshot.h5') | ||
10 | + parallel_model = multi_gpu_model(model, gpus=2) | ||
11 | + parallel_model.fit(X_train, Y_train, callbacks=[RedirectModel(model_checkpoint, model)]) | ||
12 | + ``` | ||
13 | + | ||
14 | + Args | ||
15 | + callback : callback to wrap. | ||
16 | + model : model to use when executing callbacks. | ||
17 | + """ | ||
18 | + | ||
19 | + def __init__(self, | ||
20 | + callback, | ||
21 | + model): | ||
22 | + super(RedirectModel, self).__init__() | ||
23 | + | ||
24 | + self.callback = callback | ||
25 | + self.redirect_model = model | ||
26 | + | ||
27 | + def on_epoch_begin(self, epoch, logs=None): | ||
28 | + self.callback.on_epoch_begin(epoch, logs=logs) | ||
29 | + | ||
30 | + def on_epoch_end(self, epoch, logs=None): | ||
31 | + self.callback.on_epoch_end(epoch, logs=logs) | ||
32 | + | ||
33 | + def on_batch_begin(self, batch, logs=None): | ||
34 | + self.callback.on_batch_begin(batch, logs=logs) | ||
35 | + | ||
36 | + def on_batch_end(self, batch, logs=None): | ||
37 | + self.callback.on_batch_end(batch, logs=logs) | ||
38 | + | ||
39 | + def on_train_begin(self, logs=None): | ||
40 | + # overwrite the model with our custom model | ||
41 | + self.callback.set_model(self.redirect_model) | ||
42 | + | ||
43 | + self.callback.on_train_begin(logs=logs) | ||
44 | + | ||
45 | + def on_train_end(self, logs=None): | ||
46 | + self.callback.on_train_end(logs=logs) |
retinaNet/keras_retinanet/callbacks/eval.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from tensorflow import keras | ||
18 | +from ..utils.eval import evaluate | ||
19 | + | ||
20 | + | ||
21 | +class Evaluate(keras.callbacks.Callback): | ||
22 | + """ Evaluation callback for arbitrary datasets. | ||
23 | + """ | ||
24 | + | ||
25 | + def __init__( | ||
26 | + self, | ||
27 | + generator, | ||
28 | + iou_threshold=0.5, | ||
29 | + score_threshold=0.05, | ||
30 | + max_detections=100, | ||
31 | + save_path=None, | ||
32 | + tensorboard=None, | ||
33 | + weighted_average=False, | ||
34 | + verbose=1 | ||
35 | + ): | ||
36 | + """ Evaluate a given dataset using a given model at the end of every epoch during training. | ||
37 | + | ||
38 | + # Arguments | ||
39 | + generator : The generator that represents the dataset to evaluate. | ||
40 | + iou_threshold : The threshold used to consider when a detection is positive or negative. | ||
41 | + score_threshold : The score confidence threshold to use for detections. | ||
42 | + max_detections : The maximum number of detections to use per image. | ||
43 | + save_path : The path to save images with visualized detections to. | ||
44 | + tensorboard : Instance of keras.callbacks.TensorBoard used to log the mAP value. | ||
45 | + weighted_average : Compute the mAP using the weighted average of precisions among classes. | ||
46 | + verbose : Set the verbosity level, by default this is set to 1. | ||
47 | + """ | ||
48 | + self.generator = generator | ||
49 | + self.iou_threshold = iou_threshold | ||
50 | + self.score_threshold = score_threshold | ||
51 | + self.max_detections = max_detections | ||
52 | + self.save_path = save_path | ||
53 | + self.tensorboard = tensorboard | ||
54 | + self.weighted_average = weighted_average | ||
55 | + self.verbose = verbose | ||
56 | + | ||
57 | + super(Evaluate, self).__init__() | ||
58 | + | ||
59 | + def on_epoch_end(self, epoch, logs=None): | ||
60 | + logs = logs or {} | ||
61 | + | ||
62 | + # run evaluation | ||
63 | + average_precisions, _ = evaluate( | ||
64 | + self.generator, | ||
65 | + self.model, | ||
66 | + iou_threshold=self.iou_threshold, | ||
67 | + score_threshold=self.score_threshold, | ||
68 | + max_detections=self.max_detections, | ||
69 | + save_path=self.save_path | ||
70 | + ) | ||
71 | + | ||
72 | + # compute per class average precision | ||
73 | + total_instances = [] | ||
74 | + precisions = [] | ||
75 | + for label, (average_precision, num_annotations) in average_precisions.items(): | ||
76 | + if self.verbose == 1: | ||
77 | + print('{:.0f} instances of class'.format(num_annotations), | ||
78 | + self.generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision)) | ||
79 | + total_instances.append(num_annotations) | ||
80 | + precisions.append(average_precision) | ||
81 | + if self.weighted_average: | ||
82 | + self.mean_ap = sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances) | ||
83 | + else: | ||
84 | + self.mean_ap = sum(precisions) / sum(x > 0 for x in total_instances) | ||
85 | + | ||
86 | + if self.tensorboard: | ||
87 | + import tensorflow as tf | ||
88 | + writer = tf.summary.create_file_writer(self.tensorboard.log_dir) | ||
89 | + with writer.as_default(): | ||
90 | + tf.summary.scalar("mAP", self.mean_ap, step=epoch) | ||
91 | + if self.verbose == 1: | ||
92 | + for label, (average_precision, num_annotations) in average_precisions.items(): | ||
93 | + tf.summary.scalar("AP_" + self.generator.label_to_name(label), average_precision, step=epoch) | ||
94 | + writer.flush() | ||
95 | + | ||
96 | + logs['mAP'] = self.mean_ap | ||
97 | + | ||
98 | + if self.verbose == 1: | ||
99 | + print('mAP: {:.4f}'.format(self.mean_ap)) |
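A tiny worked example (with made-up numbers) of the two mAP variants computed in `on_epoch_end` above, to make the `weighted_average` flag concrete.

```python
# label -> (average precision, number of annotations); values are illustrative only
average_precisions = {0: (0.80, 10), 1: (0.50, 30), 2: (0.00, 0)}

total_instances = [n for _, n in average_precisions.values()]
precisions = [ap for ap, _ in average_precisions.values()]

# weighted_average=True: weight each class AP by its annotation count
weighted = sum(a * b for a, b in zip(total_instances, precisions)) / sum(total_instances)
# weighted_average=False: plain mean over classes that have at least one annotation
unweighted = sum(precisions) / sum(x > 0 for x in total_instances)

print(weighted)    # (0.8*10 + 0.5*30) / 40 = 0.575
print(unweighted)  # (0.8 + 0.5) / 2        = 0.65
```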
retinaNet/keras_retinanet/initializers.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from tensorflow import keras | ||
18 | + | ||
19 | +import math | ||
20 | + | ||
21 | + | ||
22 | +class PriorProbability(keras.initializers.Initializer): | ||
23 | + """ Apply a prior probability to the weights. | ||
24 | + """ | ||
25 | + | ||
26 | + def __init__(self, probability=0.01): | ||
27 | + self.probability = probability | ||
28 | + | ||
29 | + def get_config(self): | ||
30 | + return { | ||
31 | + 'probability': self.probability | ||
32 | + } | ||
33 | + | ||
34 | + def __call__(self, shape, dtype=None): | ||
35 | + # set bias to -log((1 - p)/p) for foreground | ||
36 | + result = keras.backend.ones(shape, dtype=dtype) * -math.log((1 - self.probability) / self.probability) | ||
37 | + | ||
38 | + return result |
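A quick numeric check of what `PriorProbability` produces: initialising the classification bias to -log((1 - p) / p) makes a sigmoid output start near the prior probability p (0.01 by default).

```python
import math

p = 0.01
bias = -math.log((1 - p) / p)
print(bias)                            # ~ -4.595
print(1.0 / (1.0 + math.exp(-bias)))   # sigmoid(bias) = 0.01, i.e. the prior
```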
retinaNet/keras_retinanet/layers/__init__.py
0 → 100644
retinaNet/keras_retinanet/layers/_misc.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import tensorflow | ||
18 | +from tensorflow import keras | ||
19 | +from .. import backend | ||
20 | +from ..utils import anchors as utils_anchors | ||
21 | + | ||
22 | +import numpy as np | ||
23 | + | ||
24 | + | ||
25 | +class Anchors(keras.layers.Layer): | ||
26 | + """ Keras layer for generating anchors for a given shape. | ||
27 | + """ | ||
28 | + | ||
29 | + def __init__(self, size, stride, ratios=None, scales=None, *args, **kwargs): | ||
30 | + """ Initializer for an Anchors layer. | ||
31 | + | ||
32 | + Args | ||
33 | + size: The base size of the anchors to generate. | ||
34 | + stride: The stride of the anchors to generate. | ||
35 | + ratios: The ratios of the anchors to generate (defaults to AnchorParameters.default.ratios). | ||
36 | + scales: The scales of the anchors to generate (defaults to AnchorParameters.default.scales). | ||
37 | + """ | ||
38 | + self.size = size | ||
39 | + self.stride = stride | ||
40 | + self.ratios = ratios | ||
41 | + self.scales = scales | ||
42 | + | ||
43 | + if ratios is None: | ||
44 | + self.ratios = utils_anchors.AnchorParameters.default.ratios | ||
45 | + elif isinstance(ratios, list): | ||
46 | + self.ratios = np.array(ratios) | ||
47 | + if scales is None: | ||
48 | + self.scales = utils_anchors.AnchorParameters.default.scales | ||
49 | + elif isinstance(scales, list): | ||
50 | + self.scales = np.array(scales) | ||
51 | + | ||
52 | + self.num_anchors = len(self.ratios) * len(self.scales) | ||
53 | + self.anchors = utils_anchors.generate_anchors( | ||
54 | + base_size=self.size, | ||
55 | + ratios=self.ratios, | ||
56 | + scales=self.scales, | ||
57 | + ).astype(np.float32) | ||
58 | + | ||
59 | + super(Anchors, self).__init__(*args, **kwargs) | ||
60 | + | ||
61 | + def call(self, inputs, **kwargs): | ||
62 | + features = inputs | ||
63 | + features_shape = keras.backend.shape(features) | ||
64 | + | ||
65 | + # generate proposals from bbox deltas and shifted anchors | ||
66 | + if keras.backend.image_data_format() == 'channels_first': | ||
67 | + anchors = backend.shift(features_shape[2:4], self.stride, self.anchors) | ||
68 | + else: | ||
69 | + anchors = backend.shift(features_shape[1:3], self.stride, self.anchors) | ||
70 | + anchors = keras.backend.tile(keras.backend.expand_dims(anchors, axis=0), (features_shape[0], 1, 1)) | ||
71 | + | ||
72 | + return anchors | ||
73 | + | ||
74 | + def compute_output_shape(self, input_shape): | ||
75 | + if None not in input_shape[1:]: | ||
76 | + if keras.backend.image_data_format() == 'channels_first': | ||
77 | + total = np.prod(input_shape[2:4]) * self.num_anchors | ||
78 | + else: | ||
79 | + total = np.prod(input_shape[1:3]) * self.num_anchors | ||
80 | + | ||
81 | + return (input_shape[0], total, 4) | ||
82 | + else: | ||
83 | + return (input_shape[0], None, 4) | ||
84 | + | ||
85 | + def get_config(self): | ||
86 | + config = super(Anchors, self).get_config() | ||
87 | + config.update({ | ||
88 | + 'size' : self.size, | ||
89 | + 'stride' : self.stride, | ||
90 | + 'ratios' : self.ratios.tolist(), | ||
91 | + 'scales' : self.scales.tolist(), | ||
92 | + }) | ||
93 | + | ||
94 | + return config | ||
95 | + | ||
96 | + | ||
97 | +class UpsampleLike(keras.layers.Layer): | ||
98 | + """ Keras layer for upsampling a Tensor to be the same shape as another Tensor. | ||
99 | + """ | ||
100 | + | ||
101 | + def call(self, inputs, **kwargs): | ||
102 | + source, target = inputs | ||
103 | + target_shape = keras.backend.shape(target) | ||
104 | + if keras.backend.image_data_format() == 'channels_first': | ||
105 | + source = tensorflow.transpose(source, (0, 2, 3, 1)) | ||
106 | + output = backend.resize_images(source, (target_shape[2], target_shape[3]), method='nearest') | ||
107 | + output = tensorflow.transpose(output, (0, 3, 1, 2)) | ||
108 | + return output | ||
109 | + else: | ||
110 | + return backend.resize_images(source, (target_shape[1], target_shape[2]), method='nearest') | ||
111 | + | ||
112 | + def compute_output_shape(self, input_shape): | ||
113 | + if keras.backend.image_data_format() == 'channels_first': | ||
114 | + return (input_shape[0][0], input_shape[0][1]) + input_shape[1][2:4] | ||
115 | + else: | ||
116 | + return (input_shape[0][0],) + input_shape[1][1:3] + (input_shape[0][-1],) | ||
117 | + | ||
118 | + | ||
119 | +class RegressBoxes(keras.layers.Layer): | ||
120 | + """ Keras layer for applying regression values to boxes. | ||
121 | + """ | ||
122 | + | ||
123 | + def __init__(self, mean=None, std=None, *args, **kwargs): | ||
124 | + """ Initializer for the RegressBoxes layer. | ||
125 | + | ||
126 | + Args | ||
127 | + mean: The mean value of the regression values which was used for normalization. | ||
128 | + std: The standard deviation of the regression values which was used for normalization. | ||
129 | + """ | ||
130 | + if mean is None: | ||
131 | + mean = np.array([0, 0, 0, 0]) | ||
132 | + if std is None: | ||
133 | + std = np.array([0.2, 0.2, 0.2, 0.2]) | ||
134 | + | ||
135 | + if isinstance(mean, (list, tuple)): | ||
136 | + mean = np.array(mean) | ||
137 | + elif not isinstance(mean, np.ndarray): | ||
138 | + raise ValueError('Expected mean to be a np.ndarray, list or tuple. Received: {}'.format(type(mean))) | ||
139 | + | ||
140 | + if isinstance(std, (list, tuple)): | ||
141 | + std = np.array(std) | ||
142 | + elif not isinstance(std, np.ndarray): | ||
143 | + raise ValueError('Expected std to be a np.ndarray, list or tuple. Received: {}'.format(type(std))) | ||
144 | + | ||
145 | + self.mean = mean | ||
146 | + self.std = std | ||
147 | + super(RegressBoxes, self).__init__(*args, **kwargs) | ||
148 | + | ||
149 | + def call(self, inputs, **kwargs): | ||
150 | + anchors, regression = inputs | ||
151 | + return backend.bbox_transform_inv(anchors, regression, mean=self.mean, std=self.std) | ||
152 | + | ||
153 | + def compute_output_shape(self, input_shape): | ||
154 | + return input_shape[0] | ||
155 | + | ||
156 | + def get_config(self): | ||
157 | + config = super(RegressBoxes, self).get_config() | ||
158 | + config.update({ | ||
159 | + 'mean': self.mean.tolist(), | ||
160 | + 'std' : self.std.tolist(), | ||
161 | + }) | ||
162 | + | ||
163 | + return config | ||
164 | + | ||
165 | + | ||
166 | +class ClipBoxes(keras.layers.Layer): | ||
167 | + """ Keras layer to clip box values to lie inside a given shape. | ||
168 | + """ | ||
169 | + def call(self, inputs, **kwargs): | ||
170 | + image, boxes = inputs | ||
171 | + shape = keras.backend.cast(keras.backend.shape(image), keras.backend.floatx()) | ||
172 | + if keras.backend.image_data_format() == 'channels_first': | ||
173 | + _, _, height, width = tensorflow.unstack(shape, axis=0) | ||
174 | + else: | ||
175 | + _, height, width, _ = tensorflow.unstack(shape, axis=0) | ||
176 | + | ||
177 | + x1, y1, x2, y2 = tensorflow.unstack(boxes, axis=-1) | ||
178 | + x1 = tensorflow.clip_by_value(x1, 0, width - 1) | ||
179 | + y1 = tensorflow.clip_by_value(y1, 0, height - 1) | ||
180 | + x2 = tensorflow.clip_by_value(x2, 0, width - 1) | ||
181 | + y2 = tensorflow.clip_by_value(y2, 0, height - 1) | ||
182 | + | ||
183 | + return keras.backend.stack([x1, y1, x2, y2], axis=2) | ||
184 | + | ||
185 | + def compute_output_shape(self, input_shape): | ||
186 | + return input_shape[1] |
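A NumPy sketch of the clipping performed by `ClipBoxes` above for a single `channels_last` image; the image size and box coordinates are made up.

```python
import numpy as np

height, width = 480.0, 640.0
boxes = np.array([[-10.0, 5.0, 700.0, 500.0]])   # (x1, y1, x2, y2), partly outside the image

x1 = np.clip(boxes[:, 0], 0, width - 1)
y1 = np.clip(boxes[:, 1], 0, height - 1)
x2 = np.clip(boxes[:, 2], 0, width - 1)
y2 = np.clip(boxes[:, 3], 0, height - 1)

print(np.stack([x1, y1, x2, y2], axis=1))        # [[  0.   5. 639. 479.]]
```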
retinaNet/keras_retinanet/layers/filter_detections.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import tensorflow | ||
18 | +from tensorflow import keras | ||
19 | +from .. import backend | ||
20 | + | ||
21 | + | ||
22 | +def filter_detections( | ||
23 | + boxes, | ||
24 | + classification, | ||
25 | + other = [], | ||
26 | + class_specific_filter = True, | ||
27 | + nms = True, | ||
28 | + score_threshold = 0.05, | ||
29 | + max_detections = 300, | ||
30 | + nms_threshold = 0.5 | ||
31 | +): | ||
32 | + """ Filter detections using the boxes and classification values. | ||
33 | + | ||
34 | + Args | ||
35 | + boxes : Tensor of shape (num_boxes, 4) containing the boxes in (x1, y1, x2, y2) format. | ||
36 | + classification : Tensor of shape (num_boxes, num_classes) containing the classification scores. | ||
37 | + other : List of tensors of shape (num_boxes, ...) to filter along with the boxes and classification scores. | ||
38 | + class_specific_filter : Whether to perform filtering per class, or take the best scoring class and filter those. | ||
39 | + nms : Flag to enable/disable non maximum suppression. | ||
40 | + score_threshold : Threshold used to prefilter the boxes with. | ||
41 | + max_detections : Maximum number of detections to keep. | ||
42 | + nms_threshold : Threshold for the IoU value to determine when a box should be suppressed. | ||
43 | + | ||
44 | + Returns | ||
45 | + A list of [boxes, scores, labels, other[0], other[1], ...]. | ||
46 | + boxes is shaped (max_detections, 4) and contains the (x1, y1, x2, y2) of the non-suppressed boxes. | ||
47 | + scores is shaped (max_detections,) and contains the scores of the predicted class. | ||
48 | + labels is shaped (max_detections,) and contains the predicted label. | ||
49 | + other[i] is shaped (max_detections, ...) and contains the filtered other[i] data. | ||
50 | + In case there are less than max_detections detections, the tensors are padded with -1's. | ||
51 | + """ | ||
52 | + def _filter_detections(scores, labels): | ||
53 | + # threshold based on score | ||
54 | + indices = tensorflow.where(keras.backend.greater(scores, score_threshold)) | ||
55 | + | ||
56 | + if nms: | ||
57 | + filtered_boxes = tensorflow.gather_nd(boxes, indices) | ||
58 | + filtered_scores = keras.backend.gather(scores, indices)[:, 0] | ||
59 | + | ||
60 | + # perform NMS | ||
61 | + nms_indices = tensorflow.image.non_max_suppression(filtered_boxes, filtered_scores, max_output_size=max_detections, iou_threshold=nms_threshold) | ||
62 | + | ||
63 | + # filter indices based on NMS | ||
64 | + indices = keras.backend.gather(indices, nms_indices) | ||
65 | + | ||
66 | + # add indices to list of all indices | ||
67 | + labels = tensorflow.gather_nd(labels, indices) | ||
68 | + indices = keras.backend.stack([indices[:, 0], labels], axis=1) | ||
69 | + | ||
70 | + return indices | ||
71 | + | ||
72 | + if class_specific_filter: | ||
73 | + all_indices = [] | ||
74 | + # perform per class filtering | ||
75 | + for c in range(int(classification.shape[1])): | ||
76 | + scores = classification[:, c] | ||
77 | + labels = c * tensorflow.ones((keras.backend.shape(scores)[0],), dtype='int64') | ||
78 | + all_indices.append(_filter_detections(scores, labels)) | ||
79 | + | ||
80 | + # concatenate indices to single tensor | ||
81 | + indices = keras.backend.concatenate(all_indices, axis=0) | ||
82 | + else: | ||
83 | + scores = keras.backend.max(classification, axis = 1) | ||
84 | + labels = keras.backend.argmax(classification, axis = 1) | ||
85 | + indices = _filter_detections(scores, labels) | ||
86 | + | ||
87 | + # select top k | ||
88 | + scores = tensorflow.gather_nd(classification, indices) | ||
89 | + labels = indices[:, 1] | ||
90 | + scores, top_indices = tensorflow.nn.top_k(scores, k=keras.backend.minimum(max_detections, keras.backend.shape(scores)[0])) | ||
91 | + | ||
92 | + # filter input using the final set of indices | ||
93 | + indices = keras.backend.gather(indices[:, 0], top_indices) | ||
94 | + boxes = keras.backend.gather(boxes, indices) | ||
95 | + labels = keras.backend.gather(labels, top_indices) | ||
96 | + other_ = [keras.backend.gather(o, indices) for o in other] | ||
97 | + | ||
98 | + # pad the outputs to max_detections entries with -1 | ||
99 | + pad_size = keras.backend.maximum(0, max_detections - keras.backend.shape(scores)[0]) | ||
100 | + boxes = tensorflow.pad(boxes, [[0, pad_size], [0, 0]], constant_values=-1) | ||
101 | + scores = tensorflow.pad(scores, [[0, pad_size]], constant_values=-1) | ||
102 | + labels = tensorflow.pad(labels, [[0, pad_size]], constant_values=-1) | ||
103 | + labels = keras.backend.cast(labels, 'int32') | ||
104 | + other_ = [tensorflow.pad(o, [[0, pad_size]] + [[0, 0] for _ in range(1, len(o.shape))], constant_values=-1) for o in other_] | ||
105 | + | ||
106 | + # set shapes, since we know what they are | ||
107 | + boxes.set_shape([max_detections, 4]) | ||
108 | + scores.set_shape([max_detections]) | ||
109 | + labels.set_shape([max_detections]) | ||
110 | + for o, s in zip(other_, [list(keras.backend.int_shape(o)) for o in other]): | ||
111 | + o.set_shape([max_detections] + s[1:]) | ||
112 | + | ||
113 | + return [boxes, scores, labels] + other_ | ||
114 | + | ||
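To make the docstring above concrete, here is a minimal, hedged usage sketch of filter_detections on a single (unbatched) image with hypothetical box and score values; it assumes TensorFlow 2.x eager execution and that the function defined above is in scope.

```python
import tensorflow as tf

# Hypothetical per-image predictions: 3 boxes, 2 classes (values are illustrative only).
boxes = tf.constant([[ 0.,  0., 10., 10.],
                     [ 1.,  1., 11., 11.],
                     [50., 50., 60., 60.]])
classification = tf.constant([[0.90, 0.02],
                              [0.80, 0.01],
                              [0.03, 0.75]])

# filter_detections operates on one image; the FilterDetections layer below maps it over a batch.
filtered_boxes, scores, labels = filter_detections(boxes, classification, max_detections=5)
print(filtered_boxes.shape, scores.shape, labels.shape)  # (5, 4) (5,) (5,)
```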
115 | + | ||
116 | +class FilterDetections(keras.layers.Layer): | ||
117 | + """ Keras layer for filtering detections using score threshold and NMS. | ||
118 | + """ | ||
119 | + | ||
120 | + def __init__( | ||
121 | + self, | ||
122 | + nms = True, | ||
123 | + class_specific_filter = True, | ||
124 | + nms_threshold = 0.5, | ||
125 | + score_threshold = 0.05, | ||
126 | + max_detections = 300, | ||
127 | + parallel_iterations = 32, | ||
128 | + **kwargs | ||
129 | + ): | ||
130 | + """ Filters detections using score threshold, NMS and selecting the top-k detections. | ||
131 | + | ||
132 | + Args | ||
133 | + nms : Flag to enable/disable NMS. | ||
134 | + class_specific_filter : Whether to perform filtering per class, or take the best scoring class and filter those. | ||
135 | + nms_threshold : Threshold for the IoU value to determine when a box should be suppressed. | ||
136 | + score_threshold : Threshold used to prefilter the boxes with. | ||
137 | + max_detections : Maximum number of detections to keep. | ||
138 | + parallel_iterations : Number of batch items to process in parallel. | ||
139 | + """ | ||
140 | + self.nms = nms | ||
141 | + self.class_specific_filter = class_specific_filter | ||
142 | + self.nms_threshold = nms_threshold | ||
143 | + self.score_threshold = score_threshold | ||
144 | + self.max_detections = max_detections | ||
145 | + self.parallel_iterations = parallel_iterations | ||
146 | + super(FilterDetections, self).__init__(**kwargs) | ||
147 | + | ||
148 | + def call(self, inputs, **kwargs): | ||
149 | + """ Constructs the NMS graph. | ||
150 | + | ||
151 | + Args | ||
152 | + inputs : List of [boxes, classification, other[0], other[1], ...] tensors. | ||
153 | + """ | ||
154 | + boxes = inputs[0] | ||
155 | + classification = inputs[1] | ||
156 | + other = inputs[2:] | ||
157 | + | ||
158 | + # wrap nms with our parameters | ||
159 | + def _filter_detections(args): | ||
160 | + boxes = args[0] | ||
161 | + classification = args[1] | ||
162 | + other = args[2] | ||
163 | + | ||
164 | + return filter_detections( | ||
165 | + boxes, | ||
166 | + classification, | ||
167 | + other, | ||
168 | + nms = self.nms, | ||
169 | + class_specific_filter = self.class_specific_filter, | ||
170 | + score_threshold = self.score_threshold, | ||
171 | + max_detections = self.max_detections, | ||
172 | + nms_threshold = self.nms_threshold, | ||
173 | + ) | ||
174 | + | ||
175 | + # call filter_detections on each batch | ||
176 | + dtypes = [keras.backend.floatx(), keras.backend.floatx(), 'int32'] + [o.dtype for o in other] | ||
177 | + shapes = [(self.max_detections, 4), (self.max_detections,), (self.max_detections,)] | ||
178 | + shapes.extend([(self.max_detections,) + o.shape[2:] for o in other]) | ||
179 | + outputs = backend.map_fn( | ||
180 | + _filter_detections, | ||
181 | + elems=[boxes, classification, other], | ||
182 | + dtype=dtypes, | ||
183 | + shapes=shapes, | ||
184 | + parallel_iterations=self.parallel_iterations, | ||
185 | + ) | ||
186 | + | ||
187 | + return outputs | ||
188 | + | ||
189 | + def compute_output_shape(self, input_shape): | ||
190 | + """ Computes the output shapes given the input shapes. | ||
191 | + | ||
192 | + Args | ||
193 | + input_shape : List of input shapes [boxes, classification, other[0], other[1], ...]. | ||
194 | + | ||
195 | + Returns | ||
196 | + List of tuples representing the output shapes: | ||
197 | + [filtered_boxes.shape, filtered_scores.shape, filtered_labels.shape, filtered_other[0].shape, filtered_other[1].shape, ...] | ||
198 | + """ | ||
199 | + return [ | ||
200 | + (input_shape[0][0], self.max_detections, 4), | ||
201 | + (input_shape[1][0], self.max_detections), | ||
202 | + (input_shape[1][0], self.max_detections), | ||
203 | + ] + [ | ||
204 | + tuple([input_shape[i][0], self.max_detections] + list(input_shape[i][2:])) for i in range(2, len(input_shape)) | ||
205 | + ] | ||
206 | + | ||
207 | + def compute_mask(self, inputs, mask=None): | ||
208 | + """ This is required in Keras when there is more than 1 output. | ||
209 | + """ | ||
210 | + return (len(inputs) + 1) * [None] | ||
211 | + | ||
212 | + def get_config(self): | ||
213 | + """ Gets the configuration of this layer. | ||
214 | + | ||
215 | + Returns | ||
216 | + Dictionary containing the parameters of this layer. | ||
217 | + """ | ||
218 | + config = super(FilterDetections, self).get_config() | ||
219 | + config.update({ | ||
220 | + 'nms' : self.nms, | ||
221 | + 'class_specific_filter' : self.class_specific_filter, | ||
222 | + 'nms_threshold' : self.nms_threshold, | ||
223 | + 'score_threshold' : self.score_threshold, | ||
224 | + 'max_detections' : self.max_detections, | ||
225 | + 'parallel_iterations' : self.parallel_iterations, | ||
226 | + }) | ||
227 | + | ||
228 | + return config |
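As a rough illustration of how this layer is meant to be wired up (retinanet_bbox later in this change does the real wiring), the following hedged sketch applies FilterDetections to batched, symbolic box and classification tensors. The three-class input and the layer settings are hypothetical, and the sketch assumes the package's backend.map_fn wrapper imported at the top of this file behaves as used in call above.

```python
from tensorflow import keras

# Symbolic batched inputs; the number of boxes per image is unknown at graph time.
boxes = keras.layers.Input(shape=(None, 4), name='boxes')
classification = keras.layers.Input(shape=(None, 3), name='classification')  # e.g. 3 classes

detections = FilterDetections(
    nms=True,
    score_threshold=0.05,
    max_detections=300,
    name='filtered_detections'
)([boxes, classification])

# detections is [filtered_boxes, scores, labels], each padded to max_detections with -1.
model = keras.models.Model(inputs=[boxes, classification], outputs=detections)
```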
retinaNet/keras_retinanet/losses.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import tensorflow | ||
18 | +from tensorflow import keras | ||
19 | + | ||
20 | + | ||
21 | +def focal(alpha=0.25, gamma=2.0, cutoff=0.5): | ||
22 | + """ Create a functor for computing the focal loss. | ||
23 | + | ||
24 | + Args | ||
25 | + alpha: Scale the focal weight with alpha. | ||
26 | + gamma: Take the power of the focal weight with gamma. | ||
27 | + cutoff: Positive prediction cutoff for soft targets | ||
28 | + | ||
29 | + Returns | ||
30 | + A functor that computes the focal loss using the alpha and gamma. | ||
31 | + """ | ||
32 | + def _focal(y_true, y_pred): | ||
33 | + """ Compute the focal loss given the target tensor and the predicted tensor. | ||
34 | + | ||
35 | + As defined in https://arxiv.org/abs/1708.02002 | ||
36 | + | ||
37 | + Args | ||
38 | + y_true: Tensor of target data from the generator with shape (B, N, num_classes). | ||
39 | + y_pred: Tensor of predicted data from the network with shape (B, N, num_classes). | ||
40 | + | ||
41 | + Returns | ||
42 | + The focal loss of y_pred w.r.t. y_true. | ||
43 | + """ | ||
44 | + labels = y_true[:, :, :-1] | ||
45 | + anchor_state = y_true[:, :, -1] # -1 for ignore, 0 for background, 1 for object | ||
46 | + classification = y_pred | ||
47 | + | ||
48 | + # filter out "ignore" anchors | ||
49 | + indices = tensorflow.where(keras.backend.not_equal(anchor_state, -1)) | ||
50 | + labels = tensorflow.gather_nd(labels, indices) | ||
51 | + classification = tensorflow.gather_nd(classification, indices) | ||
52 | + | ||
53 | + # compute the focal loss | ||
54 | + alpha_factor = keras.backend.ones_like(labels) * alpha | ||
55 | + alpha_factor = tensorflow.where(keras.backend.greater(labels, cutoff), alpha_factor, 1 - alpha_factor) | ||
56 | + focal_weight = tensorflow.where(keras.backend.greater(labels, cutoff), 1 - classification, classification) | ||
57 | + focal_weight = alpha_factor * focal_weight ** gamma | ||
58 | + | ||
59 | + cls_loss = focal_weight * keras.backend.binary_crossentropy(labels, classification) | ||
60 | + | ||
61 | + # compute the normalizer: the number of positive anchors | ||
62 | + normalizer = tensorflow.where(keras.backend.equal(anchor_state, 1)) | ||
63 | + normalizer = keras.backend.cast(keras.backend.shape(normalizer)[0], keras.backend.floatx()) | ||
64 | + normalizer = keras.backend.maximum(keras.backend.cast_to_floatx(1.0), normalizer) | ||
65 | + | ||
66 | + return keras.backend.sum(cls_loss) / normalizer | ||
67 | + | ||
68 | + return _focal | ||
69 | + | ||
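A small, hedged numerical sketch of the expected target layout may help: y_true carries the one-hot labels plus a trailing anchor-state channel, and ignored anchors (state -1) do not contribute to the loss. All values below are illustrative.

```python
import numpy as np
from tensorflow import keras

# One image, three anchors, two classes; the last channel is the anchor state.
y_true = np.array([[[1., 0.,  1.],     # positive anchor assigned to class 0
                    [0., 0.,  0.],     # background anchor
                    [0., 0., -1.]]],   # ignored anchor, filtered out of the loss
                  dtype=np.float32)
y_pred = np.array([[[0.7, 0.1],
                    [0.2, 0.1],
                    [0.5, 0.5]]], dtype=np.float32)

loss_fn = focal(alpha=0.25, gamma=2.0)
loss = loss_fn(keras.backend.constant(y_true), keras.backend.constant(y_pred))
print(float(loss))  # summed focal loss divided by the single positive anchor
```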
70 | + | ||
71 | +def smooth_l1(sigma=3.0): | ||
72 | + """ Create a smooth L1 loss functor. | ||
73 | + | ||
74 | + Args | ||
75 | + sigma: This argument defines the point where the loss changes from L2 to L1. | ||
76 | + | ||
77 | + Returns | ||
78 | + A functor for computing the smooth L1 loss given target data and predicted data. | ||
79 | + """ | ||
80 | + sigma_squared = sigma ** 2 | ||
81 | + | ||
82 | + def _smooth_l1(y_true, y_pred): | ||
83 | + """ Compute the smooth L1 loss of y_pred w.r.t. y_true. | ||
84 | + | ||
85 | + Args | ||
86 | + y_true: Tensor from the generator of shape (B, N, 5). The last value for each box is the state of the anchor (ignore, negative, positive). | ||
87 | + y_pred: Tensor from the network of shape (B, N, 4). | ||
88 | + | ||
89 | + Returns | ||
90 | + The smooth L1 loss of y_pred w.r.t. y_true. | ||
91 | + """ | ||
92 | + # separate target and state | ||
93 | + regression = y_pred | ||
94 | + regression_target = y_true[:, :, :-1] | ||
95 | + anchor_state = y_true[:, :, -1] | ||
96 | + | ||
97 | + # filter out "ignore" anchors | ||
98 | + indices = tensorflow.where(keras.backend.equal(anchor_state, 1)) | ||
99 | + regression = tensorflow.gather_nd(regression, indices) | ||
100 | + regression_target = tensorflow.gather_nd(regression_target, indices) | ||
101 | + | ||
102 | + # compute smooth L1 loss | ||
103 | + # f(x) = 0.5 * (sigma * x)^2 if |x| < 1 / sigma / sigma | ||
104 | + # |x| - 0.5 / sigma / sigma otherwise | ||
105 | + regression_diff = regression - regression_target | ||
106 | + regression_diff = keras.backend.abs(regression_diff) | ||
107 | + regression_loss = tensorflow.where( | ||
108 | + keras.backend.less(regression_diff, 1.0 / sigma_squared), | ||
109 | + 0.5 * sigma_squared * keras.backend.pow(regression_diff, 2), | ||
110 | + regression_diff - 0.5 / sigma_squared | ||
111 | + ) | ||
112 | + | ||
113 | + # compute the normalizer: the number of positive anchors | ||
114 | + normalizer = keras.backend.maximum(1, keras.backend.shape(indices)[0]) | ||
115 | + normalizer = keras.backend.cast(normalizer, dtype=keras.backend.floatx()) | ||
116 | + return keras.backend.sum(regression_loss) / normalizer | ||
117 | + | ||
118 | + return _smooth_l1 |
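Worked numbers can make the piecewise definition above tangible. In this hedged sketch (illustrative values only), only the positive anchor contributes; its single non-zero difference of 0.5 exceeds 1/sigma^2 = 1/9, so the loss is (0.5 - 0.5/9) divided by the one positive anchor.

```python
import numpy as np
from tensorflow import keras

# One image, two anchors, four regression values each plus the anchor state.
y_true = np.array([[[0.1, 0.2, 0.3, 0.4, 1.],    # positive anchor, used in the loss
                    [0.0, 0.0, 0.0, 0.0, 0.]]],  # negative anchor, ignored by smooth L1
                  dtype=np.float32)
y_pred = np.array([[[0.1, 0.2, 0.3, 0.9],
                    [9.0, 9.0, 9.0, 9.0]]], dtype=np.float32)

loss_fn = smooth_l1(sigma=3.0)
loss = loss_fn(keras.backend.constant(y_true), keras.backend.constant(y_pred))
print(float(loss))  # approximately 0.5 - 0.5/9 = 0.4444
```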
retinaNet/keras_retinanet/models/__init__.py
0 → 100644
1 | +from __future__ import print_function | ||
2 | +import sys | ||
3 | + | ||
4 | + | ||
5 | +class Backbone(object): | ||
6 | + """ This class stores additional information on backbones. | ||
7 | + """ | ||
8 | + def __init__(self, backbone): | ||
9 | + # a dictionary mapping custom layer names to the correct classes | ||
10 | + from .. import layers | ||
11 | + from .. import losses | ||
12 | + from .. import initializers | ||
13 | + self.custom_objects = { | ||
14 | + 'UpsampleLike' : layers.UpsampleLike, | ||
15 | + 'PriorProbability' : initializers.PriorProbability, | ||
16 | + 'RegressBoxes' : layers.RegressBoxes, | ||
17 | + 'FilterDetections' : layers.FilterDetections, | ||
18 | + 'Anchors' : layers.Anchors, | ||
19 | + 'ClipBoxes' : layers.ClipBoxes, | ||
20 | + '_smooth_l1' : losses.smooth_l1(), | ||
21 | + '_focal' : losses.focal(), | ||
22 | + } | ||
23 | + | ||
24 | + self.backbone = backbone | ||
25 | + self.validate() | ||
26 | + | ||
27 | + def retinanet(self, *args, **kwargs): | ||
28 | + """ Returns a retinanet model using the correct backbone. | ||
29 | + """ | ||
30 | + raise NotImplementedError('retinanet method not implemented.') | ||
31 | + | ||
32 | + def download_imagenet(self): | ||
33 | + """ Downloads ImageNet weights and returns path to weights file. | ||
34 | + """ | ||
35 | + raise NotImplementedError('download_imagenet method not implemented.') | ||
36 | + | ||
37 | + def validate(self): | ||
38 | + """ Checks whether the backbone string is correct. | ||
39 | + """ | ||
40 | + raise NotImplementedError('validate method not implemented.') | ||
41 | + | ||
42 | + def preprocess_image(self, inputs): | ||
43 | + """ Takes as input an image and prepares it for being passed through the network. | ||
44 | + Having this function in Backbone allows other backbones to define a specific preprocessing step. | ||
45 | + """ | ||
46 | + raise NotImplementedError('preprocess_image method not implemented.') | ||
47 | + | ||
48 | + | ||
49 | +def backbone(backbone_name): | ||
50 | + """ Returns a backbone object for the given backbone. | ||
51 | + """ | ||
52 | + if 'densenet' in backbone_name: | ||
53 | + from .densenet import DenseNetBackbone as b | ||
54 | + elif 'seresnext' in backbone_name or 'seresnet' in backbone_name or 'senet' in backbone_name: | ||
55 | + from .senet import SeBackbone as b | ||
56 | + elif 'resnet' in backbone_name: | ||
57 | + from .resnet import ResNetBackbone as b | ||
58 | + elif 'mobilenet' in backbone_name: | ||
59 | + from .mobilenet import MobileNetBackbone as b | ||
60 | + elif 'vgg' in backbone_name: | ||
61 | + from .vgg import VGGBackbone as b | ||
62 | + elif 'EfficientNet' in backbone_name: | ||
63 | + from .effnet import EfficientNetBackbone as b | ||
64 | + else: | ||
65 | + raise NotImplementedError('Backbone class for \'{}\' not implemented.'.format(backbone_name)) | ||
66 | + | ||
67 | + return b(backbone_name) | ||
68 | + | ||
69 | + | ||
70 | +def load_model(filepath, backbone_name='resnet50'): | ||
71 | + """ Loads a retinanet model using the correct custom objects. | ||
72 | + | ||
73 | + Args | ||
74 | + filepath: one of the following: | ||
75 | + - string, path to the saved model, or | ||
76 | + - h5py.File object from which to load the model | ||
77 | + backbone_name : Backbone with which the model was trained. | ||
78 | + | ||
79 | + Returns | ||
80 | + A keras.models.Model object. | ||
81 | + | ||
82 | + Raises | ||
83 | + ImportError: if h5py is not available. | ||
84 | + ValueError: In case of an invalid savefile. | ||
85 | + """ | ||
86 | + from tensorflow import keras | ||
87 | + return keras.models.load_model(filepath, custom_objects=backbone(backbone_name).custom_objects) | ||
88 | + | ||
89 | + | ||
90 | +def convert_model(model, nms=True, class_specific_filter=True, anchor_params=None, **kwargs): | ||
91 | + """ Converts a training model to an inference model. | ||
92 | + | ||
93 | + Args | ||
94 | + model : A retinanet training model. | ||
95 | + nms : Boolean, whether to add NMS filtering to the converted model. | ||
96 | + class_specific_filter : Whether to use class specific filtering or filter for the best scoring class only. | ||
97 | + anchor_params : Anchor parameters object. If omitted, default values are used. | ||
98 | + **kwargs : Inference and minimal retinanet model settings. | ||
99 | + | ||
100 | + Returns | ||
101 | + A keras.models.Model object. | ||
102 | + | ||
103 | + Raises | ||
104 | + ImportError: if h5py is not available. | ||
105 | + ValueError: In case of an invalid savefile. | ||
106 | + """ | ||
107 | + from .retinanet import retinanet_bbox | ||
108 | + return retinanet_bbox(model=model, nms=nms, class_specific_filter=class_specific_filter, anchor_params=anchor_params, **kwargs) | ||
109 | + | ||
110 | + | ||
111 | +def assert_training_model(model): | ||
112 | + """ Assert that the model is a training model. | ||
113 | + """ | ||
114 | + assert(all(output in model.output_names for output in ['regression', 'classification'])), \ | ||
115 | + "Input is not a training model (no 'regression' and 'classification' outputs were found, outputs are: {}).".format(model.output_names) | ||
116 | + | ||
117 | + | ||
118 | +def check_training_model(model): | ||
119 | + """ Check that model is a training model and exit otherwise. | ||
120 | + """ | ||
121 | + try: | ||
122 | + assert_training_model(model) | ||
123 | + except AssertionError as e: | ||
124 | + print(e, file=sys.stderr) | ||
125 | + sys.exit(1) |
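Putting these helpers together, a hedged end-to-end sketch of the intended workflow might look as follows; the class count and the snapshot path are hypothetical.

```python
# Build a training model for 20 hypothetical classes on a ResNet-50 backbone.
b = backbone('resnet50')
training_model = b.retinanet(num_classes=20)
check_training_model(training_model)

# Convert the training model into an inference model that outputs boxes, scores and labels.
inference_model = convert_model(training_model)

# A previously saved model can be restored with the matching custom objects
# ('snapshot.h5' is a placeholder path):
# restored = load_model('snapshot.h5', backbone_name='resnet50')
```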
retinaNet/keras_retinanet/models/densenet.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2018 vidosits (https://github.com/vidosits/) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from tensorflow import keras | ||
18 | + | ||
19 | +from . import retinanet | ||
20 | +from . import Backbone | ||
21 | +from ..utils.image import preprocess_image | ||
22 | + | ||
23 | + | ||
24 | +allowed_backbones = { | ||
25 | + 'densenet121': ([6, 12, 24, 16], keras.applications.densenet.DenseNet121), | ||
26 | + 'densenet169': ([6, 12, 32, 32], keras.applications.densenet.DenseNet169), | ||
27 | + 'densenet201': ([6, 12, 48, 32], keras.applications.densenet.DenseNet201), | ||
28 | +} | ||
29 | + | ||
30 | + | ||
31 | +class DenseNetBackbone(Backbone): | ||
32 | + """ Describes backbone information and provides utility functions. | ||
33 | + """ | ||
34 | + | ||
35 | + def retinanet(self, *args, **kwargs): | ||
36 | + """ Returns a retinanet model using the correct backbone. | ||
37 | + """ | ||
38 | + return densenet_retinanet(*args, backbone=self.backbone, **kwargs) | ||
39 | + | ||
40 | + def download_imagenet(self): | ||
41 | + """ Download pre-trained weights for the specified backbone name. | ||
42 | + This name is in the format {backbone}_weights_tf_dim_ordering_tf_kernels_notop | ||
43 | + where backbone is the densenet + number of layers (e.g. densenet121). | ||
44 | + For more info check the explanation from the keras densenet script itself: | ||
45 | + https://github.com/keras-team/keras/blob/master/keras/applications/densenet.py | ||
46 | + """ | ||
47 | + origin = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/' | ||
48 | + file_name = '{}_weights_tf_dim_ordering_tf_kernels_notop.h5' | ||
49 | + | ||
50 | + # load weights | ||
51 | + if keras.backend.image_data_format() == 'channels_first': | ||
52 | + raise ValueError('Weights for "channels_first" format are not available.') | ||
53 | + | ||
54 | + weights_url = origin + file_name.format(self.backbone) | ||
55 | + return keras.utils.get_file(file_name.format(self.backbone), weights_url, cache_subdir='models') | ||
56 | + | ||
57 | + def validate(self): | ||
58 | + """ Checks whether the backbone string is correct. | ||
59 | + """ | ||
60 | + backbone = self.backbone.split('_')[0] | ||
61 | + | ||
62 | + if backbone not in allowed_backbones: | ||
63 | + raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones.keys())) | ||
64 | + | ||
65 | + def preprocess_image(self, inputs): | ||
66 | + """ Takes as input an image and prepares it for being passed through the network. | ||
67 | + """ | ||
68 | + return preprocess_image(inputs, mode='tf') | ||
69 | + | ||
70 | + | ||
71 | +def densenet_retinanet(num_classes, backbone='densenet121', inputs=None, modifier=None, **kwargs): | ||
72 | + """ Constructs a retinanet model using a densenet backbone. | ||
73 | + | ||
74 | + Args | ||
75 | + num_classes: Number of classes to predict. | ||
76 | + backbone: Which backbone to use (one of ('densenet121', 'densenet169', 'densenet201')). | ||
77 | + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). | ||
78 | + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). | ||
79 | + | ||
80 | + Returns | ||
81 | + RetinaNet model with a DenseNet backbone. | ||
82 | + """ | ||
83 | + # choose default input | ||
84 | + if inputs is None: | ||
85 | + inputs = keras.layers.Input((None, None, 3)) | ||
86 | + | ||
87 | + blocks, creator = allowed_backbones[backbone] | ||
88 | + model = creator(input_tensor=inputs, include_top=False, pooling=None, weights=None) | ||
89 | + | ||
90 | + # get last conv layer from the end of each dense block | ||
91 | + layer_outputs = [model.get_layer(name='conv{}_block{}_concat'.format(idx + 2, block_num)).output for idx, block_num in enumerate(blocks)] | ||
92 | + | ||
93 | + # create the densenet backbone | ||
94 | + # layer_outputs contains 4 layers | ||
95 | + model = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=model.name) | ||
96 | + | ||
97 | + # invoke modifier if given | ||
98 | + if modifier: | ||
99 | + model = modifier(model) | ||
100 | + | ||
101 | + # create the full model | ||
102 | + backbone_layers = { | ||
103 | + 'C2': model.outputs[0], | ||
104 | + 'C3': model.outputs[1], | ||
105 | + 'C4': model.outputs[2], | ||
106 | + 'C5': model.outputs[3] | ||
107 | + } | ||
108 | + | ||
109 | + model = retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs) | ||
110 | + | ||
111 | + return model |
retinaNet/keras_retinanet/models/effnet.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from tensorflow import keras | ||
18 | + | ||
19 | +from . import retinanet | ||
20 | +from . import Backbone | ||
21 | +import efficientnet.keras as efn | ||
22 | + | ||
23 | + | ||
24 | +class EfficientNetBackbone(Backbone): | ||
25 | + """ Describes backbone information and provides utility functions. | ||
26 | + """ | ||
27 | + | ||
28 | + def __init__(self, backbone): | ||
29 | + super(EfficientNetBackbone, self).__init__(backbone) | ||
30 | + self.preprocess_image_func = None | ||
31 | + | ||
32 | + def retinanet(self, *args, **kwargs): | ||
33 | + """ Returns a retinanet model using the correct backbone. | ||
34 | + """ | ||
35 | + return effnet_retinanet(*args, backbone=self.backbone, **kwargs) | ||
36 | + | ||
37 | + def download_imagenet(self): | ||
38 | + """ Downloads ImageNet weights and returns path to weights file. | ||
39 | + """ | ||
40 | + from efficientnet.weights import IMAGENET_WEIGHTS_PATH | ||
41 | + from efficientnet.weights import IMAGENET_WEIGHTS_HASHES | ||
42 | + | ||
43 | + model_name = 'efficientnet-b' + self.backbone[-1] | ||
44 | + file_name = model_name + '_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5' | ||
45 | + file_hash = IMAGENET_WEIGHTS_HASHES[model_name][1] | ||
46 | + weights_path = keras.utils.get_file(file_name, IMAGENET_WEIGHTS_PATH + file_name, cache_subdir='models', file_hash=file_hash) | ||
47 | + return weights_path | ||
48 | + | ||
49 | + def validate(self): | ||
50 | + """ Checks whether the backbone string is correct. | ||
51 | + """ | ||
52 | + allowed_backbones = ['EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', 'EfficientNetB4', | ||
53 | + 'EfficientNetB5', 'EfficientNetB6', 'EfficientNetB7'] | ||
54 | + backbone = self.backbone.split('_')[0] | ||
55 | + | ||
56 | + if backbone not in allowed_backbones: | ||
57 | + raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones)) | ||
58 | + | ||
59 | + def preprocess_image(self, inputs): | ||
60 | + """ Takes as input an image and prepares it for being passed through the network. | ||
61 | + """ | ||
62 | + return efn.preprocess_input(inputs) | ||
63 | + | ||
64 | + | ||
65 | +def effnet_retinanet(num_classes, backbone='EfficientNetB0', inputs=None, modifier=None, **kwargs): | ||
66 | + """ Constructs a retinanet model using a resnet backbone. | ||
67 | + | ||
68 | + Args | ||
69 | + num_classes: Number of classes to predict. | ||
70 | + backbone: Which backbone to use (one of 'EfficientNetB0' through 'EfficientNetB7'). | ||
71 | + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). | ||
72 | + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). | ||
73 | + | ||
74 | + Returns | ||
75 | + RetinaNet model with an EfficientNet backbone. | ||
76 | + """ | ||
77 | + # choose default input | ||
78 | + if inputs is None: | ||
79 | + if keras.backend.image_data_format() == 'channels_first': | ||
80 | + inputs = keras.layers.Input(shape=(3, None, None)) | ||
81 | + else: | ||
82 | + # inputs = keras.layers.Input(shape=(224, 224, 3)) | ||
83 | + inputs = keras.layers.Input(shape=(None, None, 3)) | ||
84 | + | ||
85 | + # get last conv layer from the end of each block [28x28, 14x14, 7x7] | ||
86 | + if backbone == 'EfficientNetB0': | ||
87 | + model = efn.EfficientNetB0(input_tensor=inputs, include_top=False, weights=None) | ||
88 | + elif backbone == 'EfficientNetB1': | ||
89 | + model = efn.EfficientNetB1(input_tensor=inputs, include_top=False, weights=None) | ||
90 | + elif backbone == 'EfficientNetB2': | ||
91 | + model = efn.EfficientNetB2(input_tensor=inputs, include_top=False, weights=None) | ||
92 | + elif backbone == 'EfficientNetB3': | ||
93 | + model = efn.EfficientNetB3(input_tensor=inputs, include_top=False, weights=None) | ||
94 | + elif backbone == 'EfficientNetB4': | ||
95 | + model = efn.EfficientNetB4(input_tensor=inputs, include_top=False, weights=None) | ||
96 | + elif backbone == 'EfficientNetB5': | ||
97 | + model = efn.EfficientNetB5(input_tensor=inputs, include_top=False, weights=None) | ||
98 | + elif backbone == 'EfficientNetB6': | ||
99 | + model = efn.EfficientNetB6(input_tensor=inputs, include_top=False, weights=None) | ||
100 | + elif backbone == 'EfficientNetB7': | ||
101 | + model = efn.EfficientNetB7(input_tensor=inputs, include_top=False, weights=None) | ||
102 | + else: | ||
103 | + raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone)) | ||
104 | + | ||
105 | + layer_outputs = ['block4a_expand_activation', 'block6a_expand_activation', 'top_activation'] | ||
106 | + | ||
107 | + layer_outputs = [ | ||
108 | + model.get_layer(name=layer_outputs[0]).output, # 28x28 | ||
109 | + model.get_layer(name=layer_outputs[1]).output, # 14x14 | ||
110 | + model.get_layer(name=layer_outputs[2]).output, # 7x7 | ||
111 | + ] | ||
112 | + # create the efficientnet backbone model | ||
113 | + model = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=model.name) | ||
114 | + | ||
115 | + # invoke modifier if given | ||
116 | + if modifier: | ||
117 | + model = modifier(model) | ||
118 | + | ||
119 | + # C2 not provided | ||
120 | + backbone_layers = { | ||
121 | + 'C3': model.outputs[0], | ||
122 | + 'C4': model.outputs[1], | ||
123 | + 'C5': model.outputs[2] | ||
124 | + } | ||
125 | + | ||
126 | + # create the full model | ||
127 | + return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs) | ||
128 | + | ||
129 | + | ||
130 | +def EfficientNetB0_retinanet(num_classes, inputs=None, **kwargs): | ||
131 | + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB0', inputs=inputs, **kwargs) | ||
132 | + | ||
133 | + | ||
134 | +def EfficientNetB1_retinanet(num_classes, inputs=None, **kwargs): | ||
135 | + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB1', inputs=inputs, **kwargs) | ||
136 | + | ||
137 | + | ||
138 | +def EfficientNetB2_retinanet(num_classes, inputs=None, **kwargs): | ||
139 | + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB2', inputs=inputs, **kwargs) | ||
140 | + | ||
141 | + | ||
142 | +def EfficientNetB3_retinanet(num_classes, inputs=None, **kwargs): | ||
143 | + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB3', inputs=inputs, **kwargs) | ||
144 | + | ||
145 | + | ||
146 | +def EfficientNetB4_retinanet(num_classes, inputs=None, **kwargs): | ||
147 | + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB4', inputs=inputs, **kwargs) | ||
148 | + | ||
149 | + | ||
150 | +def EfficientNetB5_retinanet(num_classes, inputs=None, **kwargs): | ||
151 | + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB5', inputs=inputs, **kwargs) | ||
152 | + | ||
153 | + | ||
154 | +def EfficientNetB6_retinanet(num_classes, inputs=None, **kwargs): | ||
155 | + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB6', inputs=inputs, **kwargs) | ||
156 | + | ||
157 | + | ||
158 | +def EfficientNetB7_retinanet(num_classes, inputs=None, **kwargs): | ||
159 | + return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB7', inputs=inputs, **kwargs) |
1 | +""" | ||
2 | +Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from tensorflow import keras | ||
18 | +from ..utils.image import preprocess_image | ||
19 | + | ||
20 | +from . import retinanet | ||
21 | +from . import Backbone | ||
22 | + | ||
23 | + | ||
24 | +class MobileNetBackbone(Backbone): | ||
25 | + """ Describes backbone information and provides utility functions. | ||
26 | + """ | ||
27 | + | ||
28 | + allowed_backbones = ['mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224'] | ||
29 | + | ||
30 | + def retinanet(self, *args, **kwargs): | ||
31 | + """ Returns a retinanet model using the correct backbone. | ||
32 | + """ | ||
33 | + return mobilenet_retinanet(*args, backbone=self.backbone, **kwargs) | ||
34 | + | ||
35 | + def download_imagenet(self): | ||
36 | + """ Download pre-trained weights for the specified backbone name. | ||
37 | + This name is in the format mobilenet{rows}_{alpha} where rows is the | ||
38 | + imagenet shape dimension and 'alpha' controls the width of the network. | ||
39 | + For more info check the explanation from the keras mobilenet script itself. | ||
40 | + """ | ||
41 | + | ||
42 | + alpha = float(self.backbone.split('_')[1]) | ||
43 | + rows = int(self.backbone.split('_')[0].replace('mobilenet', '')) | ||
44 | + | ||
45 | + # load weights | ||
46 | + if keras.backend.image_data_format() == 'channels_first': | ||
47 | + raise ValueError('Weights for "channels_first" format ' | ||
48 | + 'are not available.') | ||
49 | + if alpha == 1.0: | ||
50 | + alpha_text = '1_0' | ||
51 | + elif alpha == 0.75: | ||
52 | + alpha_text = '7_5' | ||
53 | + elif alpha == 0.50: | ||
54 | + alpha_text = '5_0' | ||
55 | + else: | ||
56 | + alpha_text = '2_5' | ||
57 | + | ||
58 | + model_name = 'mobilenet_{}_{}_tf_no_top.h5'.format(alpha_text, rows) | ||
59 | + weights_url = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.6/' + model_name | ||
60 | + weights_path = keras.utils.get_file(model_name, weights_url, cache_subdir='models') | ||
61 | + | ||
62 | + return weights_path | ||
63 | + | ||
64 | + def validate(self): | ||
65 | + """ Checks whether the backbone string is correct. | ||
66 | + """ | ||
67 | + backbone = self.backbone.split('_')[0] | ||
68 | + | ||
69 | + if backbone not in MobileNetBackbone.allowed_backbones: | ||
70 | + raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, MobileNetBackbone.allowed_backbones)) | ||
71 | + | ||
72 | + def preprocess_image(self, inputs): | ||
73 | + """ Takes as input an image and prepares it for being passed through the network. | ||
74 | + """ | ||
75 | + return preprocess_image(inputs, mode='tf') | ||
76 | + | ||
77 | + | ||
78 | +def mobilenet_retinanet(num_classes, backbone='mobilenet224_1.0', inputs=None, modifier=None, **kwargs): | ||
79 | + """ Constructs a retinanet model using a mobilenet backbone. | ||
80 | + | ||
81 | + Args | ||
82 | + num_classes: Number of classes to predict. | ||
83 | + backbone: Which backbone to use (one of ('mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224')). | ||
84 | + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). | ||
85 | + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). | ||
86 | + | ||
87 | + Returns | ||
88 | + RetinaNet model with a MobileNet backbone. | ||
89 | + """ | ||
90 | + alpha = float(backbone.split('_')[1]) | ||
91 | + | ||
92 | + # choose default input | ||
93 | + if inputs is None: | ||
94 | + inputs = keras.layers.Input((None, None, 3)) | ||
95 | + | ||
96 | + backbone = keras.applications.mobilenet.MobileNet(input_tensor=inputs, alpha=alpha, include_top=False, pooling=None, weights=None) | ||
97 | + | ||
98 | + # create the full model | ||
99 | + layer_names = ['conv_pw_5_relu', 'conv_pw_11_relu', 'conv_pw_13_relu'] | ||
100 | + layer_outputs = [backbone.get_layer(name).output for name in layer_names] | ||
101 | + backbone = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=backbone.name) | ||
102 | + | ||
103 | + # invoke modifier if given | ||
104 | + if modifier: | ||
105 | + backbone = modifier(backbone) | ||
106 | + | ||
107 | + # C2 not provided | ||
108 | + backbone_layers = { | ||
109 | + 'C3': backbone.outputs[0], | ||
110 | + 'C4': backbone.outputs[1], | ||
111 | + 'C5': backbone.outputs[2] | ||
112 | + } | ||
113 | + | ||
114 | + return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs) |
retinaNet/keras_retinanet/models/resnet.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from tensorflow import keras | ||
18 | +import keras_resnet | ||
19 | +import keras_resnet.models | ||
20 | + | ||
21 | +from . import retinanet | ||
22 | +from . import Backbone | ||
23 | +from ..utils.image import preprocess_image | ||
24 | + | ||
25 | + | ||
26 | +class ResNetBackbone(Backbone): | ||
27 | + """Describes backbone information and provides utility functions.""" | ||
28 | + | ||
29 | + def __init__(self, backbone): | ||
30 | + super(ResNetBackbone, self).__init__(backbone) | ||
31 | + self.custom_objects.update(keras_resnet.custom_objects) | ||
32 | + | ||
33 | + def retinanet(self, *args, **kwargs): | ||
34 | + """Returns a retinanet model using the correct backbone.""" | ||
35 | + return resnet_retinanet(*args, backbone=self.backbone, **kwargs) | ||
36 | + | ||
37 | + def download_imagenet(self): | ||
38 | + """Downloads ImageNet weights and returns path to weights file.""" | ||
39 | + resnet_filename = "ResNet-{}-model.keras.h5" | ||
40 | + resnet_resource = ( | ||
41 | + "https://github.com/fizyr/keras-models/releases/download/v0.0.1/{}".format( | ||
42 | + resnet_filename | ||
43 | + ) | ||
44 | + ) | ||
45 | + depth = int(self.backbone.replace("resnet", "")) | ||
46 | + | ||
47 | + filename = resnet_filename.format(depth) | ||
48 | + resource = resnet_resource.format(depth) | ||
49 | + if depth == 50: | ||
50 | + checksum = "3e9f4e4f77bbe2c9bec13b53ee1c2319" | ||
51 | + elif depth == 101: | ||
52 | + checksum = "05dc86924389e5b401a9ea0348a3213c" | ||
53 | + elif depth == 152: | ||
54 | + checksum = "6ee11ef2b135592f8031058820bb9e71" | ||
55 | + | ||
56 | + return keras.utils.get_file( | ||
57 | + filename, resource, cache_subdir="models", md5_hash=checksum | ||
58 | + ) | ||
59 | + | ||
60 | + def validate(self): | ||
61 | + """Checks whether the backbone string is correct.""" | ||
62 | + allowed_backbones = ["resnet50", "resnet101", "resnet152"] | ||
63 | + backbone = self.backbone.split("_")[0] | ||
64 | + | ||
65 | + if backbone not in allowed_backbones: | ||
66 | + raise ValueError( | ||
67 | + "Backbone ('{}') not in allowed backbones ({}).".format( | ||
68 | + backbone, allowed_backbones | ||
69 | + ) | ||
70 | + ) | ||
71 | + | ||
72 | + def preprocess_image(self, inputs): | ||
73 | + """Takes as input an image and prepares it for being passed through the network.""" | ||
74 | + return preprocess_image(inputs, mode="caffe") | ||
75 | + | ||
76 | + | ||
77 | +def resnet_retinanet( | ||
78 | + num_classes, backbone="resnet50", inputs=None, modifier=None, **kwargs | ||
79 | +): | ||
80 | + """Constructs a retinanet model using a resnet backbone. | ||
81 | + | ||
82 | + Args | ||
83 | + num_classes: Number of classes to predict. | ||
84 | + backbone: Which backbone to use (one of ('resnet50', 'resnet101', 'resnet152')). | ||
85 | + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). | ||
86 | + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). | ||
87 | + | ||
88 | + Returns | ||
89 | + RetinaNet model with a ResNet backbone. | ||
90 | + """ | ||
91 | + # choose default input | ||
92 | + if inputs is None: | ||
93 | + if keras.backend.image_data_format() == "channels_first": | ||
94 | + inputs = keras.layers.Input(shape=(3, None, None)) | ||
95 | + else: | ||
96 | + inputs = keras.layers.Input(shape=(None, None, 3)) | ||
97 | + | ||
98 | + # create the resnet backbone | ||
99 | + if backbone == "resnet50": | ||
100 | + resnet = keras_resnet.models.ResNet50(inputs, include_top=False, freeze_bn=True) | ||
101 | + elif backbone == "resnet101": | ||
102 | + resnet = keras_resnet.models.ResNet101( | ||
103 | + inputs, include_top=False, freeze_bn=True | ||
104 | + ) | ||
105 | + elif backbone == "resnet152": | ||
106 | + resnet = keras_resnet.models.ResNet152( | ||
107 | + inputs, include_top=False, freeze_bn=True | ||
108 | + ) | ||
109 | + else: | ||
110 | + raise ValueError("Backbone ('{}') is invalid.".format(backbone)) | ||
111 | + | ||
112 | + # invoke modifier if given | ||
113 | + if modifier: | ||
114 | + resnet = modifier(resnet) | ||
115 | + | ||
116 | + # create the full model | ||
117 | + # resnet.outputs contains 4 layers | ||
118 | + backbone_layers = { | ||
119 | + "C2": resnet.outputs[0], | ||
120 | + "C3": resnet.outputs[1], | ||
121 | + "C4": resnet.outputs[2], | ||
122 | + "C5": resnet.outputs[3], | ||
123 | + } | ||
124 | + | ||
125 | + return retinanet.retinanet( | ||
126 | + inputs=inputs, | ||
127 | + num_classes=num_classes, | ||
128 | + backbone_layers=backbone_layers, | ||
129 | + **kwargs | ||
130 | + ) | ||
131 | + | ||
132 | + | ||
133 | +def resnet50_retinanet(num_classes, inputs=None, **kwargs): | ||
134 | + return resnet_retinanet( | ||
135 | + num_classes=num_classes, backbone="resnet50", inputs=inputs, **kwargs | ||
136 | + ) | ||
137 | + | ||
138 | + | ||
139 | +def resnet101_retinanet(num_classes, inputs=None, **kwargs): | ||
140 | + return resnet_retinanet( | ||
141 | + num_classes=num_classes, backbone="resnet101", inputs=inputs, **kwargs | ||
142 | + ) | ||
143 | + | ||
144 | + | ||
145 | +def resnet152_retinanet(num_classes, inputs=None, **kwargs): | ||
146 | + return resnet_retinanet( | ||
147 | + num_classes=num_classes, backbone="resnet152", inputs=inputs, **kwargs | ||
148 | + ) |
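The modifier argument documented above is the hook for customizing the backbone before the FPN and submodels are attached; a common use is freezing the backbone. This is a hedged sketch only, with a hypothetical freeze_backbone helper and an illustrative class count.

```python
def freeze_backbone(model):
    # Hypothetical modifier: mark every backbone layer as non-trainable.
    for layer in model.layers:
        layer.trainable = False
    return model

# ResNet-50 RetinaNet for 20 illustrative classes with a frozen backbone.
model = resnet50_retinanet(num_classes=20, modifier=freeze_backbone)
```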
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from tensorflow import keras | ||
18 | +from .. import initializers | ||
19 | +from .. import layers | ||
20 | +from ..utils.anchors import AnchorParameters | ||
21 | +from . import assert_training_model | ||
22 | + | ||
23 | + | ||
24 | +def default_classification_model( | ||
25 | + num_classes, | ||
26 | + num_anchors, | ||
27 | + pyramid_feature_size=256, | ||
28 | + prior_probability=0.01, | ||
29 | + classification_feature_size=256, | ||
30 | + name="classification_submodel", | ||
31 | +): | ||
32 | + """Creates the default classification submodel. | ||
33 | + | ||
34 | + Args | ||
35 | + num_classes : Number of classes to predict a score for at each feature level. | ||
36 | + num_anchors : Number of anchors to predict classification scores for at each feature level. | ||
37 | + pyramid_feature_size : The number of filters to expect from the feature pyramid levels. | ||
38 | + classification_feature_size : The number of filters to use in the layers in the classification submodel. | ||
39 | + name : The name of the submodel. | ||
40 | + | ||
41 | + Returns | ||
42 | + A keras.models.Model that predicts classes for each anchor. | ||
43 | + """ | ||
44 | + options = { | ||
45 | + "kernel_size": 3, | ||
46 | + "strides": 1, | ||
47 | + "padding": "same", | ||
48 | + } | ||
49 | + | ||
50 | + # set input | ||
51 | + if keras.backend.image_data_format() == "channels_first": | ||
52 | + inputs = keras.layers.Input(shape=(pyramid_feature_size, None, None)) | ||
53 | + else: | ||
54 | + inputs = keras.layers.Input(shape=(None, None, pyramid_feature_size)) | ||
55 | + | ||
56 | + outputs = inputs | ||
57 | + | ||
58 | + # four intermediate conv layers | ||
59 | + for i in range(4): | ||
60 | + | ||
61 | + # output of each intermediate conv layer | ||
62 | + outputs = keras.layers.Conv2D( | ||
63 | + filters=classification_feature_size, | ||
64 | + activation="relu", | ||
65 | + name="pyramid_classification_{}".format(i), | ||
66 | + kernel_initializer=keras.initializers.RandomNormal( | ||
67 | + mean=0.0, stddev=0.01, seed=None | ||
68 | + ), # initializer drawing weights from a normal distribution | ||
69 | + bias_initializer="zeros", | ||
70 | + **options | ||
71 | + )(outputs) | ||
72 | + | ||
73 | + # the final layer is a separate conv layer with num_classes * num_anchors filters | ||
74 | + outputs = keras.layers.Conv2D( | ||
75 | + filters=num_classes * num_anchors, | ||
76 | + kernel_initializer=keras.initializers.RandomNormal( | ||
77 | + mean=0.0, stddev=0.01, seed=None | ||
78 | + ), | ||
79 | + bias_initializer=initializers.PriorProbability(probability=prior_probability), | ||
80 | + name="pyramid_classification", | ||
81 | + **options | ||
82 | + )(outputs) | ||
83 | + | ||
84 | + # reshape output and apply sigmoid | ||
85 | + if keras.backend.image_data_format() == "channels_first": | ||
86 | + outputs = keras.layers.Permute( | ||
87 | + (2, 3, 1), name="pyramid_classification_permute" | ||
88 | + )(outputs) | ||
89 | + | ||
90 | + # reshape: flatten the spatial dimensions into a single anchor axis | ||
91 | + outputs = keras.layers.Reshape( | ||
92 | + (-1, num_classes), name="pyramid_classification_reshape" | ||
93 | + )(outputs) | ||
94 | + | ||
95 | + # output activation: sigmoid | ||
96 | + outputs = keras.layers.Activation("sigmoid", name="pyramid_classification_sigmoid")( | ||
97 | + outputs | ||
98 | + ) | ||
99 | + | ||
100 | + return keras.models.Model(inputs=inputs, outputs=outputs, name=name) | ||
101 | + | ||
102 | + | ||
103 | +def default_regression_model( | ||
104 | + num_values, | ||
105 | + num_anchors, | ||
106 | + pyramid_feature_size=256, | ||
107 | + regression_feature_size=256, | ||
108 | + name="regression_submodel", | ||
109 | +): | ||
110 | + """Creates the default regression submodel. | ||
111 | + | ||
112 | + Args | ||
113 | + num_values : Number of values to regress. | ||
114 | + num_anchors : Number of anchors to regress for each feature level. | ||
115 | + pyramid_feature_size : The number of filters to expect from the feature pyramid levels. | ||
116 | + regression_feature_size : The number of filters to use in the layers in the regression submodel. | ||
117 | + name : The name of the submodel. | ||
118 | + | ||
119 | + Returns | ||
120 | + A keras.models.Model that predicts regression values for each anchor. | ||
121 | + """ | ||
122 | + # All new conv layers except the final one in the | ||
123 | + # RetinaNet (classification) subnets are initialized | ||
124 | + # with bias b = 0 and a Gaussian weight fill with stddev = 0.01. | ||
125 | + options = { | ||
126 | + "kernel_size": 3, | ||
127 | + "strides": 1, | ||
128 | + "padding": "same", | ||
129 | + "kernel_initializer": keras.initializers.RandomNormal( | ||
130 | + mean=0.0, stddev=0.01, seed=None | ||
131 | + ), | ||
132 | + "bias_initializer": "zeros", | ||
133 | + } | ||
134 | + | ||
135 | + if keras.backend.image_data_format() == "channels_first": | ||
136 | + inputs = keras.layers.Input(shape=(pyramid_feature_size, None, None)) | ||
137 | + else: | ||
138 | + inputs = keras.layers.Input(shape=(None, None, pyramid_feature_size)) | ||
139 | + outputs = inputs | ||
140 | + for i in range(4): | ||
141 | + outputs = keras.layers.Conv2D( | ||
142 | + filters=regression_feature_size, | ||
143 | + activation="relu", | ||
144 | + name="pyramid_regression_{}".format(i), | ||
145 | + **options | ||
146 | + )(outputs) | ||
147 | + | ||
148 | + outputs = keras.layers.Conv2D( | ||
149 | + num_anchors * num_values, name="pyramid_regression", **options | ||
150 | + )(outputs) | ||
151 | + if keras.backend.image_data_format() == "channels_first": | ||
152 | + outputs = keras.layers.Permute((2, 3, 1), name="pyramid_regression_permute")( | ||
153 | + outputs | ||
154 | + ) | ||
155 | + outputs = keras.layers.Reshape((-1, num_values), name="pyramid_regression_reshape")( | ||
156 | + outputs | ||
157 | + ) | ||
158 | + | ||
159 | + return keras.models.Model(inputs=inputs, outputs=outputs, name=name) | ||
160 | + | ||
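To see what these submodels produce, here is a hedged shape sketch: with the usual 3 ratios x 3 scales there are 9 anchors per feature-map location, so a hypothetical 10x10 pyramid level yields 900 anchor rows. The feature tensor below is a stand-in, not a real pyramid level.

```python
import tensorflow as tf

num_anchors = 9  # 3 ratios x 3 scales, the usual default
cls_model = default_classification_model(num_classes=20, num_anchors=num_anchors)
reg_model = default_regression_model(num_values=4, num_anchors=num_anchors)

feature = tf.zeros((1, 10, 10, 256))   # hypothetical P-level feature map
print(cls_model(feature).shape)        # (1, 900, 20)
print(reg_model(feature).shape)        # (1, 900, 4)
```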
161 | + | ||
162 | +def __create_pyramid_features(backbone_layers, pyramid_levels, feature_size=256): | ||
163 | + """Creates the FPN layers on top of the backbone features. | ||
164 | + | ||
165 | + Args | ||
166 | + backbone_layers: a dictionary containing feature stages C3, C4, C5 from the backbone. Also contains C2 if provided. | ||
167 | + pyramid_levels: Pyramid levels in use. | ||
168 | + feature_size : The feature size to use for the resulting feature levels. | ||
169 | + | ||
170 | + Returns | ||
171 | + output_layers : A dict of feature levels. P3, P4 and P5 are always included; P2, P6 and P7 are included when requested via pyramid_levels. | ||
172 | + """ | ||
173 | + | ||
174 | + output_layers = {} | ||
175 | + | ||
176 | + # upsample C5 to get P5 from the FPN paper | ||
177 | + P5 = keras.layers.Conv2D( | ||
178 | + feature_size, kernel_size=1, strides=1, padding="same", name="C5_reduced" | ||
179 | + )(backbone_layers["C5"]) | ||
180 | + P5_upsampled = layers.UpsampleLike(name="P5_upsampled")([P5, backbone_layers["C4"]]) | ||
181 | + P5 = keras.layers.Conv2D( | ||
182 | + feature_size, kernel_size=3, strides=1, padding="same", name="P5" | ||
183 | + )(P5) | ||
184 | + output_layers["P5"] = P5 | ||
185 | + | ||
186 | + # add P5 elementwise to C4 | ||
187 | + P4 = keras.layers.Conv2D( | ||
188 | + feature_size, kernel_size=1, strides=1, padding="same", name="C4_reduced" | ||
189 | + )(backbone_layers["C4"]) | ||
190 | + P4 = keras.layers.Add(name="P4_merged")([P5_upsampled, P4]) | ||
191 | + P4_upsampled = layers.UpsampleLike(name="P4_upsampled")([P4, backbone_layers["C3"]]) | ||
192 | + P4 = keras.layers.Conv2D( | ||
193 | + feature_size, kernel_size=3, strides=1, padding="same", name="P4" | ||
194 | + )(P4) | ||
195 | + output_layers["P4"] = P4 | ||
196 | + | ||
197 | + # add P4 elementwise to C3 | ||
198 | + P3 = keras.layers.Conv2D( | ||
199 | + feature_size, kernel_size=1, strides=1, padding="same", name="C3_reduced" | ||
200 | + )(backbone_layers["C3"]) | ||
201 | + P3 = keras.layers.Add(name="P3_merged")([P4_upsampled, P3]) | ||
202 | + if "C2" in backbone_layers and 2 in pyramid_levels: | ||
203 | + P3_upsampled = layers.UpsampleLike(name="P3_upsampled")( | ||
204 | + [P3, backbone_layers["C2"]] | ||
205 | + ) | ||
206 | + P3 = keras.layers.Conv2D( | ||
207 | + feature_size, kernel_size=3, strides=1, padding="same", name="P3" | ||
208 | + )(P3) | ||
209 | + output_layers["P3"] = P3 | ||
210 | + | ||
211 | + if "C2" in backbone_layers and 2 in pyramid_levels: | ||
212 | + P2 = keras.layers.Conv2D( | ||
213 | + feature_size, kernel_size=1, strides=1, padding="same", name="C2_reduced" | ||
214 | + )(backbone_layers["C2"]) | ||
215 | + P2 = keras.layers.Add(name="P2_merged")([P3_upsampled, P2]) | ||
216 | + P2 = keras.layers.Conv2D( | ||
217 | + feature_size, kernel_size=3, strides=1, padding="same", name="P2" | ||
218 | + )(P2) | ||
219 | + output_layers["P2"] = P2 | ||
220 | + | ||
221 | + # "P6 is obtained via a 3x3 stride-2 conv on C5" | ||
222 | + if 6 in pyramid_levels: | ||
223 | + P6 = keras.layers.Conv2D( | ||
224 | + feature_size, kernel_size=3, strides=2, padding="same", name="P6" | ||
225 | + )(backbone_layers["C5"]) | ||
226 | + output_layers["P6"] = P6 | ||
227 | + | ||
228 | + # "P7 is computed by applying ReLU followed by a 3x3 stride-2 conv on P6" | ||
229 | + if 7 in pyramid_levels: | ||
230 | + if 6 not in pyramid_levels: | ||
231 | + raise ValueError("P6 is required to use P7") | ||
232 | + P7 = keras.layers.Activation("relu", name="C6_relu")(P6) | ||
233 | + P7 = keras.layers.Conv2D( | ||
234 | + feature_size, kernel_size=3, strides=2, padding="same", name="P7" | ||
235 | + )(P7) | ||
236 | + output_layers["P7"] = P7 | ||
237 | + | ||
238 | + return output_layers | ||
239 | + | ||
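A quick, hedged sanity check of which pyramid levels come out of this function: with hypothetical C3 to C5 feature maps (strides 8/16/32 of a 512x512 input) and the default pyramid levels, the returned dict contains P3 through P7. The channel counts below are illustrative ResNet-like values, and the sketch assumes this module's UpsampleLike layer is in scope.

```python
from tensorflow import keras

# Hypothetical backbone outputs standing in for C3, C4 and C5.
C3 = keras.layers.Input(shape=(64, 64, 512))
C4 = keras.layers.Input(shape=(32, 32, 1024))
C5 = keras.layers.Input(shape=(16, 16, 2048))

features = __create_pyramid_features({'C3': C3, 'C4': C4, 'C5': C5},
                                     pyramid_levels=[3, 4, 5, 6, 7])
print(sorted(features.keys()))  # ['P3', 'P4', 'P5', 'P6', 'P7']
```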
240 | + | ||
241 | +def default_submodels(num_classes, num_anchors): | ||
242 | + """Create a list of default submodels used for object detection. | ||
243 | + | ||
244 | + The default submodels are a regression submodel and a classification submodel. | ||
245 | + | ||
246 | + Args | ||
247 | + num_classes : Number of classes to use. | ||
248 | + num_anchors : Number of base anchors. | ||
249 | + | ||
250 | + Returns | ||
251 | + A list of tuples, where the first element is the name of the submodel and the second element is the submodel itself. | ||
252 | + """ | ||
253 | + return [ | ||
254 | + ("regression", default_regression_model(4, num_anchors)), | ||
255 | + ("classification", default_classification_model(num_classes, num_anchors)), | ||
256 | + ] | ||
257 | + | ||
258 | + | ||
259 | +def __build_model_pyramid(name, model, features): | ||
260 | + """Applies a single submodel to each FPN level. | ||
261 | + | ||
262 | + Args | ||
263 | + name : Name of the submodel. | ||
264 | + model : The submodel to evaluate. | ||
265 | + features : The FPN features. | ||
266 | + | ||
267 | + Returns | ||
268 | + A tensor containing the response from the submodel on the FPN features. | ||
269 | + """ | ||
270 | + return keras.layers.Concatenate(axis=1, name=name)([model(f) for f in features]) | ||
271 | + | ||
272 | + | ||
273 | +def __build_pyramid(models, features): | ||
274 | + """Applies all submodels to each FPN level. | ||
275 | + | ||
276 | + Args | ||
277 | + models : List of submodels to run on each pyramid level (by default only regression and classification). | ||
278 | + features : The FPN features. | ||
279 | + | ||
280 | + Returns | ||
281 | + A list of tensors, one for each submodel. | ||
282 | + """ | ||
283 | + return [__build_model_pyramid(n, m, features) for n, m in models] | ||
284 | + | ||
285 | + | ||
286 | +def __build_anchors(anchor_parameters, features): | ||
287 | + """Builds anchors for the shape of the features from FPN. | ||
288 | + | ||
289 | + Args | ||
290 | + anchor_parameters : Parameters that determine how anchors are generated. | ||
291 | + features : The FPN features. | ||
292 | + | ||
293 | + Returns | ||
294 | + A tensor containing the anchors for the FPN features. | ||
295 | + | ||
296 | + The shape is: | ||
297 | + ``` | ||
298 | + (batch_size, num_anchors, 4) | ||
299 | + ``` | ||
300 | + """ | ||
301 | + anchors = [ | ||
302 | + layers.Anchors( | ||
303 | + size=anchor_parameters.sizes[i], | ||
304 | + stride=anchor_parameters.strides[i], | ||
305 | + ratios=anchor_parameters.ratios, | ||
306 | + scales=anchor_parameters.scales, | ||
307 | + name="anchors_{}".format(i), | ||
308 | + )(f) | ||
309 | + for i, f in enumerate(features) | ||
310 | + ] | ||
311 | + | ||
312 | + return keras.layers.Concatenate(axis=1, name="anchors")(anchors) | ||
313 | + | ||
314 | + | ||
315 | +def retinanet( | ||
316 | + inputs, | ||
317 | + backbone_layers, | ||
318 | + num_classes, | ||
319 | + num_anchors=None, | ||
320 | + create_pyramid_features=__create_pyramid_features, | ||
321 | + pyramid_levels=None, | ||
322 | + submodels=None, | ||
323 | + name="retinanet", | ||
324 | +): | ||
325 | + """Construct a RetinaNet model on top of a backbone. | ||
326 | + | ||
327 | + This model is the minimum model necessary for training. | ||
328 | + | ||
329 | + Args | ||
330 | + inputs : keras.layers.Input (or list of) for the input to the model. | ||
331 | + num_classes : Number of classes to classify. | ||
332 | + num_anchors : Number of base anchors. | ||
333 | + create_pyramid_features : Functor for creating pyramid features given the features C3, C4, C5, and possibly C2 from the backbone. | ||
334 | + pyramid_levels : pyramid levels to use. | ||
335 | + submodels : Submodels to run on each feature map (default is regression and classification submodels). | ||
336 | + name : Name of the model. | ||
337 | + | ||
338 | + Returns | ||
339 | + A keras.models.Model which takes an image as input and outputs the result from each submodel on every pyramid level. | ||
340 | + | ||
341 | + The order of the outputs is as defined in submodels: | ||
342 | + ``` | ||
343 | + [ | ||
344 | + regression, classification, other[0], other[1], ... | ||
345 | + ] | ||
346 | + ``` | ||
347 | + """ | ||
348 | + | ||
349 | + if num_anchors is None: | ||
350 | + num_anchors = AnchorParameters.default.num_anchors() | ||
351 | + | ||
352 | + if submodels is None: | ||
353 | + submodels = default_submodels(num_classes, num_anchors) | ||
354 | + | ||
355 | + if pyramid_levels is None: | ||
356 | + pyramid_levels = [3, 4, 5, 6, 7] | ||
357 | + | ||
358 | + if 2 in pyramid_levels and "C2" not in backbone_layers: | ||
359 | + raise ValueError("C2 not provided by backbone model. Cannot create P2 layers.") | ||
360 | + | ||
361 | + if 3 not in pyramid_levels or 4 not in pyramid_levels or 5 not in pyramid_levels: | ||
362 | + raise ValueError("pyramid levels 3, 4, and 5 required for functionality") | ||
363 | + | ||
364 | + # compute pyramid features as per https://arxiv.org/abs/1708.02002 | ||
365 | + features = create_pyramid_features(backbone_layers, pyramid_levels) | ||
366 | + feature_list = [features["P{}".format(p)] for p in pyramid_levels] | ||
367 | + | ||
368 | + # for all pyramid levels, run available submodels | ||
369 | + pyramids = __build_pyramid(submodels, feature_list) | ||
370 | + | ||
371 | + return keras.models.Model(inputs=inputs, outputs=pyramids, name=name) | ||
372 | + | ||
373 | + | ||
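A minimal training-model sketch for `retinanet`. Note that `backbone_layers` must be a dictionary with at least the keys 'C3', 'C4' and 'C5'; the ResNet50 backbone and its layer names below are assumptions taken from `tf.keras.applications` (TF 2.x naming), not from this repository's own backbone modules.

```python
from tensorflow import keras

inputs = keras.layers.Input(shape=(None, None, 3))
resnet = keras.applications.ResNet50(input_tensor=inputs, include_top=False, weights=None)

# map the backbone outputs to the C3/C4/C5 dictionary expected by retinanet()
backbone_layers = {
    'C3': resnet.get_layer('conv3_block4_out').output,
    'C4': resnet.get_layer('conv4_block6_out').output,
    'C5': resnet.get_layer('conv5_block3_out').output,
}

training_model = retinanet(inputs=inputs, backbone_layers=backbone_layers, num_classes=80)
# outputs: [regression (batch, num_anchors, 4), classification (batch, num_anchors, num_classes)]
```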
374 | +def retinanet_bbox( | ||
375 | + model=None, | ||
376 | + nms=True, | ||
377 | + class_specific_filter=True, | ||
378 | + name="retinanet-bbox", | ||
379 | + anchor_params=None, | ||
380 | + pyramid_levels=None, | ||
381 | + nms_threshold=0.5, | ||
382 | + score_threshold=0.05, | ||
383 | + max_detections=300, | ||
384 | + parallel_iterations=32, | ||
385 | + **kwargs | ||
386 | +): | ||
387 | + """Construct a RetinaNet model on top of a backbone and adds convenience functions to output boxes directly. | ||
388 | + | ||
389 | + This model uses the minimum retinanet model and appends a few layers to compute boxes within the graph. | ||
390 | + These layers include applying the regression values to the anchors and performing NMS. | ||
391 | + | ||
392 | + Args | ||
393 | + model : RetinaNet model to append bbox layers to. If None, it will create a RetinaNet model using **kwargs. | ||
394 | + nms : Whether to use non-maximum suppression for the filtering step. | ||
395 | + class_specific_filter : Whether to use class specific filtering or filter for the best scoring class only. | ||
396 | + name : Name of the model. | ||
397 | + anchor_params : Struct containing anchor parameters. If None, default values are used. | ||
398 | + pyramid_levels : pyramid levels to use. | ||
399 | + nms_threshold : Threshold for the IoU value to determine when a box should be suppressed. | ||
400 | + score_threshold : Threshold used to prefilter the boxes with. | ||
401 | + max_detections : Maximum number of detections to keep. | ||
402 | + parallel_iterations : Number of batch items to process in parallel. | ||
403 | + **kwargs : Additional kwargs to pass to the minimal retinanet model. | ||
404 | + | ||
405 | + Returns | ||
406 | + A keras.models.Model which takes an image as input and outputs the detections on the image. | ||
407 | + | ||
408 | + The order is defined as follows: | ||
409 | + ``` | ||
410 | + [ | ||
411 | + boxes, scores, labels, other[0], other[1], ... | ||
412 | + ] | ||
413 | + ``` | ||
414 | + """ | ||
415 | + | ||
416 | + # if no anchor parameters are passed, use default values | ||
417 | + if anchor_params is None: | ||
418 | + anchor_params = AnchorParameters.default | ||
419 | + | ||
420 | + # create RetinaNet model | ||
421 | + if model is None: | ||
422 | + model = retinanet(num_anchors=anchor_params.num_anchors(), **kwargs) | ||
423 | + else: | ||
424 | + assert_training_model(model) | ||
425 | + | ||
426 | + if pyramid_levels is None: | ||
427 | + pyramid_levels = [3, 4, 5, 6, 7] | ||
428 | + | ||
429 | + assert len(pyramid_levels) == len( | ||
430 | + anchor_params.sizes | ||
431 | + ), "number of pyramid levels {} should match number of anchor parameter sizes {}".format( | ||
432 | + len(pyramid_levels), len(anchor_params.sizes) | ||
433 | + ) | ||
434 | + | ||
435 | + pyramid_layer_names = ["P{}".format(p) for p in pyramid_levels] | ||
436 | + # compute the anchors | ||
437 | + features = [model.get_layer(p_name).output for p_name in pyramid_layer_names] | ||
438 | + anchors = __build_anchors(anchor_params, features) | ||
439 | + | ||
440 | + # we expect the anchors, regression and classification values as first output | ||
441 | + regression = model.outputs[0] | ||
442 | + classification = model.outputs[1] | ||
443 | + | ||
444 | + # "other" can be any additional output from custom submodels, by default this will be [] | ||
445 | + other = model.outputs[2:] | ||
446 | + | ||
447 | + # apply predicted regression to anchors | ||
448 | + boxes = layers.RegressBoxes(name="boxes")([anchors, regression]) | ||
449 | + boxes = layers.ClipBoxes(name="clipped_boxes")([model.inputs[0], boxes]) | ||
450 | + | ||
451 | + # filter detections (apply NMS / score threshold / select top-k) | ||
452 | + detections = layers.FilterDetections( | ||
453 | + nms=nms, | ||
454 | + class_specific_filter=class_specific_filter, | ||
455 | + name="filtered_detections", | ||
456 | + nms_threshold=nms_threshold, | ||
457 | + score_threshold=score_threshold, | ||
458 | + max_detections=max_detections, | ||
459 | + parallel_iterations=parallel_iterations, | ||
460 | + )([boxes, classification] + other) | ||
461 | + | ||
462 | + # construct the model | ||
463 | + return keras.models.Model(inputs=model.inputs, outputs=detections, name=name) |
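A short inference sketch: `retinanet_bbox` wraps an existing training model so that box decoding, clipping and NMS happen inside the graph. Here `training_model` and `preprocessed_batch` are assumed to come from the training sketch above and from the preprocessing utilities in this package.

```python
inference_model = retinanet_bbox(model=training_model, nms_threshold=0.5, score_threshold=0.05)

# per image: boxes (max_detections, 4), scores (max_detections,), labels (max_detections,)
boxes, scores, labels = inference_model.predict(preprocessed_batch)[:3]
# unused detection slots are padded with -1 by the FilterDetections layer
```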
retinaNet/keras_retinanet/models/senet.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from tensorflow import keras | ||
18 | + | ||
19 | +from . import retinanet | ||
20 | +from . import Backbone | ||
21 | +from classification_models.keras import Classifiers | ||
22 | + | ||
23 | + | ||
24 | +class SeBackbone(Backbone): | ||
25 | + """ Describes backbone information and provides utility functions. | ||
26 | + """ | ||
27 | + | ||
28 | + def __init__(self, backbone): | ||
29 | + super(SeBackbone, self).__init__(backbone) | ||
30 | + _, self.preprocess_image_func = Classifiers.get(self.backbone) | ||
31 | + | ||
32 | + def retinanet(self, *args, **kwargs): | ||
33 | + """ Returns a retinanet model using the correct backbone. | ||
34 | + """ | ||
35 | + return senet_retinanet(*args, backbone=self.backbone, **kwargs) | ||
36 | + | ||
37 | + def download_imagenet(self): | ||
38 | + """ Downloads ImageNet weights and returns path to weights file. | ||
39 | + """ | ||
40 | + from classification_models.weights import WEIGHTS_COLLECTION | ||
41 | + | ||
42 | + weights_path = None | ||
43 | + for el in WEIGHTS_COLLECTION: | ||
44 | + if el['model'] == self.backbone and not el['include_top']: | ||
45 | + weights_path = keras.utils.get_file(el['name'], el['url'], cache_subdir='models', file_hash=el['md5']) | ||
46 | + | ||
47 | + if weights_path is None: | ||
48 | + raise ValueError('Unable to find imagenet weights for backbone {}!'.format(self.backbone)) | ||
49 | + | ||
50 | + return weights_path | ||
51 | + | ||
52 | + def validate(self): | ||
53 | + """ Checks whether the backbone string is correct. | ||
54 | + """ | ||
55 | + allowed_backbones = ['seresnet18', 'seresnet34', 'seresnet50', 'seresnet101', 'seresnet152', | ||
56 | + 'seresnext50', 'seresnext101', 'senet154'] | ||
57 | + backbone = self.backbone.split('_')[0] | ||
58 | + | ||
59 | + if backbone not in allowed_backbones: | ||
60 | + raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones)) | ||
61 | + | ||
62 | + def preprocess_image(self, inputs): | ||
63 | + """ Takes as input an image and prepares it for being passed through the network. | ||
64 | + """ | ||
65 | + return self.preprocess_image_func(inputs) | ||
66 | + | ||
67 | + | ||
68 | +def senet_retinanet(num_classes, backbone='seresnext50', inputs=None, modifier=None, **kwargs): | ||
69 | + """ Constructs a retinanet model using a resnet backbone. | ||
70 | + | ||
71 | + Args | ||
72 | + num_classes: Number of classes to predict. | ||
73 | + backbone: Which backbone to use (one of ('seresnet18', 'seresnet34', 'seresnet50', 'seresnet101', 'seresnet152', 'seresnext50', 'seresnext101', 'senet154')). | ||
74 | + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). | ||
75 | + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). | ||
76 | + | ||
77 | + Returns | ||
78 | + RetinaNet model with a SENet backbone. | ||
79 | + """ | ||
80 | + # choose default input | ||
81 | + if inputs is None: | ||
82 | + if keras.backend.image_data_format() == 'channels_first': | ||
83 | + inputs = keras.layers.Input(shape=(3, None, None)) | ||
84 | + else: | ||
85 | + # inputs = keras.layers.Input(shape=(224, 224, 3)) | ||
86 | + inputs = keras.layers.Input(shape=(None, None, 3)) | ||
87 | + | ||
88 | + classifier, _ = Classifiers.get(backbone) | ||
89 | + model = classifier(input_tensor=inputs, include_top=False, weights=None) | ||
90 | + | ||
91 | + # get last conv layer from the end of each block [28x28, 14x14, 7x7] | ||
92 | + if backbone == 'seresnet18' or backbone == 'seresnet34': | ||
93 | + layer_outputs = ['stage3_unit1_relu1', 'stage4_unit1_relu1', 'relu1'] | ||
94 | + elif backbone == 'seresnet50': | ||
95 | + layer_outputs = ['activation_36', 'activation_66', 'activation_81'] | ||
96 | + elif backbone == 'seresnet101': | ||
97 | + layer_outputs = ['activation_36', 'activation_151', 'activation_166'] | ||
98 | + elif backbone == 'seresnet152': | ||
99 | + layer_outputs = ['activation_56', 'activation_236', 'activation_251'] | ||
100 | + elif backbone == 'seresnext50': | ||
101 | + layer_outputs = ['activation_37', 'activation_67', 'activation_81'] | ||
102 | + elif backbone == 'seresnext101': | ||
103 | + layer_outputs = ['activation_37', 'activation_152', 'activation_166'] | ||
104 | + elif backbone == 'senet154': | ||
105 | + layer_outputs = ['activation_59', 'activation_239', 'activation_253'] | ||
106 | + else: | ||
107 | + raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone)) | ||
108 | + | ||
109 | + layer_outputs = [ | ||
110 | + model.get_layer(name=layer_outputs[0]).output, # 28x28 | ||
111 | + model.get_layer(name=layer_outputs[1]).output, # 14x14 | ||
112 | + model.get_layer(name=layer_outputs[2]).output, # 7x7 | ||
113 | + ] | ||
114 | + # create the senet backbone model | ||
115 | + model = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=model.name) | ||
116 | + | ||
117 | + # invoke modifier if given | ||
118 | + if modifier: | ||
119 | + model = modifier(model) | ||
120 | + | ||
121 | + # C2 not provided | ||
122 | + backbone_layers = { | ||
123 | + 'C3': model.outputs[0], | ||
124 | + 'C4': model.outputs[1], | ||
125 | + 'C5': model.outputs[2] | ||
126 | + } | ||
127 | + | ||
128 | + # create the full model | ||
129 | + return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs) | ||
130 | + | ||
131 | + | ||
132 | +def seresnet18_retinanet(num_classes, inputs=None, **kwargs): | ||
133 | + return senet_retinanet(num_classes=num_classes, backbone='seresnet18', inputs=inputs, **kwargs) | ||
134 | + | ||
135 | + | ||
136 | +def seresnet34_retinanet(num_classes, inputs=None, **kwargs): | ||
137 | + return senet_retinanet(num_classes=num_classes, backbone='seresnet34', inputs=inputs, **kwargs) | ||
138 | + | ||
139 | + | ||
140 | +def seresnet50_retinanet(num_classes, inputs=None, **kwargs): | ||
141 | + return senet_retinanet(num_classes=num_classes, backbone='seresnet50', inputs=inputs, **kwargs) | ||
142 | + | ||
143 | + | ||
144 | +def seresnet101_retinanet(num_classes, inputs=None, **kwargs): | ||
145 | + return senet_retinanet(num_classes=num_classes, backbone='seresnet101', inputs=inputs, **kwargs) | ||
146 | + | ||
147 | + | ||
148 | +def seresnet152_retinanet(num_classes, inputs=None, **kwargs): | ||
149 | + return senet_retinanet(num_classes=num_classes, backbone='seresnet152', inputs=inputs, **kwargs) | ||
150 | + | ||
151 | + | ||
152 | +def seresnext50_retinanet(num_classes, inputs=None, **kwargs): | ||
153 | + return senet_retinanet(num_classes=num_classes, backbone='seresnext50', inputs=inputs, **kwargs) | ||
154 | + | ||
155 | + | ||
156 | +def seresnext101_retinanet(num_classes, inputs=None, **kwargs): | ||
157 | + return senet_retinanet(num_classes=num_classes, backbone='seresnext101', inputs=inputs, **kwargs) | ||
158 | + | ||
159 | + | ||
160 | +def senet154_retinanet(num_classes, inputs=None, **kwargs): | ||
161 | + return senet_retinanet(num_classes=num_classes, backbone='senet154', inputs=inputs, **kwargs) |
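A usage sketch for the backbone wrapper above, assuming the `classification_models` package it imports is installed:

```python
backbone = SeBackbone('seresnext50')
model = backbone.retinanet(num_classes=80)   # builds the RetinaNet training model

weights = backbone.download_imagenet()       # path to the ImageNet weights from WEIGHTS_COLLECTION
model.load_weights(weights, by_name=True, skip_mismatch=True)
```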
retinaNet/keras_retinanet/models/submodel.py
0 → 100644
1 | +from tensorflow import keras | ||
2 | +from .. import initializers | ||
3 | +from .. import layers | ||
4 | +from ..utils.anchors import AnchorParameters | ||
5 | +from . import assert_training_model | ||
6 | +from . import retinanet | ||
7 | + | ||
8 | +def custom_classification_model( | ||
9 | + num_classes, | ||
10 | + num_anchors, | ||
11 | + pyramid_feature_size=256, | ||
12 | + prior_probability=0.01, | ||
13 | + classification_feature_size=256, | ||
14 | + name='classification_submodel' | ||
15 | +): | ||
16 | + # set input | ||
17 | + if keras.backend.image_data_format() == "channels_first": | ||
18 | + inputs = keras.layers.Input(shape=(pyramid_feature_size, None, None)) | ||
19 | + else: | ||
20 | + inputs = keras.layers.Input(shape=(None, None, pyramid_feature_size)) | ||
21 | + | ||
22 | + outputs = inputs | ||
23 | + | ||
24 | + # 3 conv layers | ||
25 | + for i in range(3): | ||
26 | + | ||
27 | + # output of each conv layer | ||
28 | + outputs = keras.layers.Conv2D( | ||
29 | + filters=classification_feature_size, | ||
30 | + activation="relu", | ||
31 | + name="pyramid_classification_{}".format(i), | ||
32 | + kernel_initializer=keras.initializers.RandomNormal( | ||
33 | + mean=0.0, stddev=0.01, seed=None | ||
34 | + ), # initializer that samples weights from a normal distribution | ||
35 | + bias_initializer="zeros", | ||
36 | + kernel_size=3, strides=1, padding="same", # conv options; 'options' was undefined here, values assumed to match the default submodel | ||
37 | + )(outputs) | ||
38 | + | ||
39 | + # the final layer is a separate conv layer with a different number of filters | ||
40 | + outputs = keras.layers.Conv2D( | ||
41 | + filters=num_classes * num_anchors, | ||
42 | + kernel_initializer=keras.initializers.RandomNormal( | ||
43 | + mean=0.0, stddev=0.01, seed=None | ||
44 | + ), | ||
45 | + bias_initializer=initializers.PriorProbability(probability=prior_probability), | ||
46 | + name="pyramid_classification", | ||
47 | + kernel_size=3, strides=1, padding="same", # conv options; 'options' was undefined here, values assumed to match the default submodel | ||
48 | + )(outputs) | ||
49 | + | ||
50 | + # reshape output and apply sigmoid | ||
51 | + if keras.backend.image_data_format() == "channels_first": | ||
52 | + outputs = keras.layers.Permute( | ||
53 | + (2, 3, 1), name="pyramid_classification_permute" | ||
54 | + )(outputs) | ||
55 | + | ||
56 | + # reshape: 2D feature map -> 1D list of per-anchor predictions | ||
57 | + outputs = keras.layers.Reshape( | ||
58 | + (-1, num_classes), name="pyramid_classification_reshape" | ||
59 | + )(outputs) | ||
60 | + | ||
61 | + # output layer activation : sigmoid | ||
62 | + outputs = keras.layers.Activation("sigmoid", name="pyramid_classification_sigmoid")( | ||
63 | + outputs | ||
64 | + ) | ||
65 | + | ||
66 | + return keras.models.Model(inputs=inputs, outputs=outputs, name=name) | ||
67 | + | ||
68 | +def custom_regression_model(num_values, num_anchors, pyramid_feature_size=256, regression_feature_size=256, name='regression_submodel'): | ||
69 | + if num_anchors is None: | ||
70 | + num_anchors = AnchorParameters.default.num_anchors() | ||
71 | + model = retinanet.default_regression_model(num_values, num_anchors, pyramid_feature_size, regression_feature_size, name) | ||
72 | + return model | ||
73 | + | ||
74 | + | ||
75 | +def custom_submodels(num_classes, num_anchors): | ||
76 | + if num_anchors is None: | ||
77 | + num_anchors = AnchorParameters.default.num_anchors() | ||
78 | + return [ | ||
79 | + ("regression", custom_regression_model(4, num_anchors)), | ||
80 | + ("classification", custom_classification_model(num_classes, num_anchors)), | ||
81 | + ] | ||
82 | + | ||
83 | + | ||
84 | + |
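A wiring sketch for the custom submodels above: they are passed to the `retinanet` constructor through its `submodels` argument, exactly like `default_submodels`. `inputs` and `backbone_layers` are placeholders for whatever backbone is being used.

```python
num_classes = 10
submodels = custom_submodels(num_classes=num_classes, num_anchors=None)  # None -> default anchor count

model = retinanet.retinanet(
    inputs=inputs,                    # assumed keras Input, e.g. shape (None, None, 3)
    backbone_layers=backbone_layers,  # assumed dict with 'C3', 'C4', 'C5' backbone outputs
    num_classes=num_classes,
    submodels=submodels,
)
```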
retinaNet/keras_retinanet/models/vgg.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 cgratie (https://github.com/cgratie/) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | + | ||
18 | +from tensorflow import keras | ||
19 | + | ||
20 | +from . import retinanet | ||
21 | +from . import Backbone | ||
22 | +from ..utils.image import preprocess_image | ||
23 | + | ||
24 | + | ||
25 | +class VGGBackbone(Backbone): | ||
26 | + """ Describes backbone information and provides utility functions. | ||
27 | + """ | ||
28 | + | ||
29 | + def retinanet(self, *args, **kwargs): | ||
30 | + """ Returns a retinanet model using the correct backbone. | ||
31 | + """ | ||
32 | + return vgg_retinanet(*args, backbone=self.backbone, **kwargs) | ||
33 | + | ||
34 | + def download_imagenet(self): | ||
35 | + """ Downloads ImageNet weights and returns path to weights file. | ||
36 | + Weights can be downloaded at https://github.com/fizyr/keras-models/releases . | ||
37 | + """ | ||
38 | + if self.backbone == 'vgg16': | ||
39 | + resource = keras.applications.vgg16.vgg16.WEIGHTS_PATH_NO_TOP | ||
40 | + checksum = '6d6bbae143d832006294945121d1f1fc' | ||
41 | + elif self.backbone == 'vgg19': | ||
42 | + resource = keras.applications.vgg19.vgg19.WEIGHTS_PATH_NO_TOP | ||
43 | + checksum = '253f8cb515780f3b799900260a226db6' | ||
44 | + else: | ||
45 | + raise ValueError("Backbone '{}' not recognized.".format(self.backbone)) | ||
46 | + | ||
47 | + return keras.utils.get_file( | ||
48 | + '{}_weights_tf_dim_ordering_tf_kernels_notop.h5'.format(self.backbone), | ||
49 | + resource, | ||
50 | + cache_subdir='models', | ||
51 | + file_hash=checksum | ||
52 | + ) | ||
53 | + | ||
54 | + def validate(self): | ||
55 | + """ Checks whether the backbone string is correct. | ||
56 | + """ | ||
57 | + allowed_backbones = ['vgg16', 'vgg19'] | ||
58 | + | ||
59 | + if self.backbone not in allowed_backbones: | ||
60 | + raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(self.backbone, allowed_backbones)) | ||
61 | + | ||
62 | + def preprocess_image(self, inputs): | ||
63 | + """ Takes as input an image and prepares it for being passed through the network. | ||
64 | + """ | ||
65 | + return preprocess_image(inputs, mode='caffe') | ||
66 | + | ||
67 | + | ||
68 | +def vgg_retinanet(num_classes, backbone='vgg16', inputs=None, modifier=None, **kwargs): | ||
69 | + """ Constructs a retinanet model using a vgg backbone. | ||
70 | + | ||
71 | + Args | ||
72 | + num_classes: Number of classes to predict. | ||
73 | + backbone: Which backbone to use (one of ('vgg16', 'vgg19')). | ||
74 | + inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). | ||
75 | + modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). | ||
76 | + | ||
77 | + Returns | ||
78 | + RetinaNet model with a VGG backbone. | ||
79 | + """ | ||
80 | + # choose default input | ||
81 | + if inputs is None: | ||
82 | + inputs = keras.layers.Input(shape=(None, None, 3)) | ||
83 | + | ||
84 | + # create the vgg backbone | ||
85 | + if backbone == 'vgg16': | ||
86 | + vgg = keras.applications.VGG16(input_tensor=inputs, include_top=False, weights=None) | ||
87 | + elif backbone == 'vgg19': | ||
88 | + vgg = keras.applications.VGG19(input_tensor=inputs, include_top=False, weights=None) | ||
89 | + else: | ||
90 | + raise ValueError("Backbone '{}' not recognized.".format(backbone)) | ||
91 | + | ||
92 | + if modifier: | ||
93 | + vgg = modifier(vgg) | ||
94 | + | ||
95 | + # create the full model | ||
96 | + layer_names = ["block3_pool", "block4_pool", "block5_pool"] | ||
97 | + layer_outputs = [vgg.get_layer(name).output for name in layer_names] | ||
98 | + | ||
99 | + # C2 not provided | ||
100 | + backbone_layers = { | ||
101 | + 'C3': layer_outputs[0], | ||
102 | + 'C4': layer_outputs[1], | ||
103 | + 'C5': layer_outputs[2] | ||
104 | + } | ||
105 | + | ||
106 | + return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs) |
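A short sketch of a VGG16-backed model with a frozen backbone; the `freeze` helper is hypothetical and only illustrates what the `modifier` hook is for.

```python
def freeze(model):
    # freeze all backbone layers so only the FPN and submodels are trained
    for layer in model.layers:
        layer.trainable = False
    return model

model = vgg_retinanet(num_classes=80, backbone='vgg16', modifier=freeze)
```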
File mode changed
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from ..preprocessing.generator import Generator | ||
18 | +from ..utils.image import read_image_bgr | ||
19 | + | ||
20 | +import os | ||
21 | +import numpy as np | ||
22 | + | ||
23 | +from pycocotools.coco import COCO | ||
24 | + | ||
25 | + | ||
26 | +class CocoGenerator(Generator): | ||
27 | + """ Generate data from the COCO dataset. | ||
28 | + | ||
29 | + See https://github.com/cocodataset/cocoapi/tree/master/PythonAPI for more information. | ||
30 | + """ | ||
31 | + | ||
32 | + def __init__(self, data_dir, set_name, **kwargs): | ||
33 | + """ Initialize a COCO data generator. | ||
34 | + | ||
35 | + Args | ||
36 | + data_dir: Path to where the COCO dataset is stored. | ||
37 | + set_name: Name of the set to parse. | ||
38 | + """ | ||
39 | + self.data_dir = data_dir | ||
40 | + self.set_name = set_name | ||
41 | + self.coco = COCO(os.path.join(data_dir, 'annotations', 'instances_' + set_name + '.json')) | ||
42 | + self.image_ids = self.coco.getImgIds() | ||
43 | + | ||
44 | + self.load_classes() | ||
45 | + | ||
46 | + super(CocoGenerator, self).__init__(**kwargs) | ||
47 | + | ||
48 | + def load_classes(self): | ||
49 | + """ Loads the class to label mapping (and inverse) for COCO. | ||
50 | + """ | ||
51 | + # load class names (name -> label) | ||
52 | + categories = self.coco.loadCats(self.coco.getCatIds()) | ||
53 | + categories.sort(key=lambda x: x['id']) | ||
54 | + | ||
55 | + self.classes = {} | ||
56 | + self.coco_labels = {} | ||
57 | + self.coco_labels_inverse = {} | ||
58 | + for c in categories: | ||
59 | + self.coco_labels[len(self.classes)] = c['id'] | ||
60 | + self.coco_labels_inverse[c['id']] = len(self.classes) | ||
61 | + self.classes[c['name']] = len(self.classes) | ||
62 | + | ||
63 | + # also load the reverse (label -> name) | ||
64 | + self.labels = {} | ||
65 | + for key, value in self.classes.items(): | ||
66 | + self.labels[value] = key | ||
67 | + | ||
68 | + def size(self): | ||
69 | + """ Size of the COCO dataset. | ||
70 | + """ | ||
71 | + return len(self.image_ids) | ||
72 | + | ||
73 | + def num_classes(self): | ||
74 | + """ Number of classes in the dataset. For COCO this is 80. | ||
75 | + """ | ||
76 | + return len(self.classes) | ||
77 | + | ||
78 | + def has_label(self, label): | ||
79 | + """ Return True if label is a known label. | ||
80 | + """ | ||
81 | + return label in self.labels | ||
82 | + | ||
83 | + def has_name(self, name): | ||
84 | + """ Returns True if name is a known class. | ||
85 | + """ | ||
86 | + return name in self.classes | ||
87 | + | ||
88 | + def name_to_label(self, name): | ||
89 | + """ Map name to label. | ||
90 | + """ | ||
91 | + return self.classes[name] | ||
92 | + | ||
93 | + def label_to_name(self, label): | ||
94 | + """ Map label to name. | ||
95 | + """ | ||
96 | + return self.labels[label] | ||
97 | + | ||
98 | + def coco_label_to_label(self, coco_label): | ||
99 | + """ Map COCO label to the label as used in the network. | ||
100 | + COCO has some gaps in the order of labels. The highest label is 90, but there are 80 classes. | ||
101 | + """ | ||
102 | + return self.coco_labels_inverse[coco_label] | ||
103 | + | ||
104 | + def coco_label_to_name(self, coco_label): | ||
105 | + """ Map COCO label to name. | ||
106 | + """ | ||
107 | + return self.label_to_name(self.coco_label_to_label(coco_label)) | ||
108 | + | ||
109 | + def label_to_coco_label(self, label): | ||
110 | + """ Map label as used by the network to labels as used by COCO. | ||
111 | + """ | ||
112 | + return self.coco_labels[label] | ||
113 | + | ||
114 | + def image_path(self, image_index): | ||
115 | + """ Returns the image path for image_index. | ||
116 | + """ | ||
117 | + image_info = self.coco.loadImgs(self.image_ids[image_index])[0] | ||
118 | + path = os.path.join(self.data_dir, 'images', self.set_name, image_info['file_name']) | ||
119 | + return path | ||
120 | + | ||
121 | + def image_aspect_ratio(self, image_index): | ||
122 | + """ Compute the aspect ratio for an image with image_index. | ||
123 | + """ | ||
124 | + image = self.coco.loadImgs(self.image_ids[image_index])[0] | ||
125 | + return float(image['width']) / float(image['height']) | ||
126 | + | ||
127 | + def load_image(self, image_index): | ||
128 | + """ Load an image at the image_index. | ||
129 | + """ | ||
130 | + path = self.image_path(image_index) | ||
131 | + return read_image_bgr(path) | ||
132 | + | ||
133 | + def load_annotations(self, image_index): | ||
134 | + """ Load annotations for an image_index. | ||
135 | + """ | ||
136 | + # get ground truth annotations | ||
137 | + annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False) | ||
138 | + annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))} | ||
139 | + | ||
140 | + # some images appear to miss annotations (like image with id 257034) | ||
141 | + if len(annotations_ids) == 0: | ||
142 | + return annotations | ||
143 | + | ||
144 | + # parse annotations | ||
145 | + coco_annotations = self.coco.loadAnns(annotations_ids) | ||
146 | + for idx, a in enumerate(coco_annotations): | ||
147 | + # some annotations have basically no width / height, skip them | ||
148 | + if a['bbox'][2] < 1 or a['bbox'][3] < 1: | ||
149 | + continue | ||
150 | + | ||
151 | + annotations['labels'] = np.concatenate([annotations['labels'], [self.coco_label_to_label(a['category_id'])]], axis=0) | ||
152 | + annotations['bboxes'] = np.concatenate([annotations['bboxes'], [[ | ||
153 | + a['bbox'][0], | ||
154 | + a['bbox'][1], | ||
155 | + a['bbox'][0] + a['bbox'][2], | ||
156 | + a['bbox'][1] + a['bbox'][3], | ||
157 | + ]]], axis=0) | ||
158 | + | ||
159 | + return annotations |
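A usage sketch for this generator, assuming the directory layout the code above expects (`<data_dir>/annotations/instances_<set_name>.json` and `<data_dir>/images/<set_name>/...`):

```python
generator = CocoGenerator('/path/to/COCO', 'train2017', batch_size=2)

inputs, targets = generator[0]   # keras.utils.Sequence interface: one padded batch
print(inputs.shape)              # (2, max_H, max_W, 3) for this batch
print(generator.num_classes())   # 80 for the standard COCO categories
```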
1 | +""" | ||
2 | +Copyright 2017-2018 yhenon (https://github.com/yhenon/) | ||
3 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
4 | + | ||
5 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
6 | +you may not use this file except in compliance with the License. | ||
7 | +You may obtain a copy of the License at | ||
8 | + | ||
9 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
10 | + | ||
11 | +Unless required by applicable law or agreed to in writing, software | ||
12 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
13 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
14 | +See the License for the specific language governing permissions and | ||
15 | +limitations under the License. | ||
16 | +""" | ||
17 | + | ||
18 | +from .generator import Generator | ||
19 | +from ..utils.image import read_image_bgr | ||
20 | + | ||
21 | +import numpy as np | ||
22 | +from PIL import Image | ||
23 | +from six import raise_from | ||
24 | + | ||
25 | +import csv | ||
26 | +import sys | ||
27 | +import os.path | ||
28 | +from collections import OrderedDict | ||
29 | + | ||
30 | + | ||
31 | +def _parse(value, function, fmt): | ||
32 | + """ | ||
33 | + Parse a string into a value, and format a nice ValueError if it fails. | ||
34 | + | ||
35 | + Returns `function(value)`. | ||
36 | + Any `ValueError` raised is caught and a new `ValueError` is raised | ||
37 | + with message `fmt.format(e)`, where `e` is the caught `ValueError`. | ||
38 | + """ | ||
39 | + try: | ||
40 | + return function(value) | ||
41 | + except ValueError as e: | ||
42 | + raise_from(ValueError(fmt.format(e)), None) | ||
43 | + | ||
44 | + | ||
45 | +def _read_classes(csv_reader): | ||
46 | + """ Parse the classes file given by csv_reader. | ||
47 | + """ | ||
48 | + result = OrderedDict() | ||
49 | + for line, row in enumerate(csv_reader): | ||
50 | + line += 1 | ||
51 | + | ||
52 | + try: | ||
53 | + class_name, class_id = row | ||
54 | + except ValueError: | ||
55 | + raise_from(ValueError('line {}: format should be \'class_name,class_id\''.format(line)), None) | ||
56 | + class_id = _parse(class_id, int, 'line {}: malformed class ID: {{}}'.format(line)) | ||
57 | + | ||
58 | + if class_name in result: | ||
59 | + raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name)) | ||
60 | + result[class_name] = class_id | ||
61 | + return result | ||
62 | + | ||
63 | + | ||
64 | +def _read_annotations(csv_reader, classes): | ||
65 | + """ Read annotations from the csv_reader. | ||
66 | + """ | ||
67 | + result = OrderedDict() | ||
68 | + for line, row in enumerate(csv_reader): | ||
69 | + line += 1 | ||
70 | + | ||
71 | + try: | ||
72 | + img_file, x1, y1, x2, y2, class_name = row[:6] | ||
73 | + except ValueError: | ||
74 | + raise_from(ValueError('line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line)), None) | ||
75 | + | ||
76 | + if img_file not in result: | ||
77 | + result[img_file] = [] | ||
78 | + | ||
79 | + # If a row contains only an image path, it's an image without annotations. | ||
80 | + if (x1, y1, x2, y2, class_name) == ('', '', '', '', ''): | ||
81 | + continue | ||
82 | + | ||
83 | + x1 = _parse(x1, int, 'line {}: malformed x1: {{}}'.format(line)) | ||
84 | + y1 = _parse(y1, int, 'line {}: malformed y1: {{}}'.format(line)) | ||
85 | + x2 = _parse(x2, int, 'line {}: malformed x2: {{}}'.format(line)) | ||
86 | + y2 = _parse(y2, int, 'line {}: malformed y2: {{}}'.format(line)) | ||
87 | + | ||
88 | + # Check that the bounding box is valid. | ||
89 | + if x2 <= x1: | ||
90 | + raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1)) | ||
91 | + if y2 <= y1: | ||
92 | + raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1)) | ||
93 | + | ||
94 | + # check if the current class name is correctly present | ||
95 | + if class_name not in classes: | ||
96 | + raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(line, class_name, classes)) | ||
97 | + | ||
98 | + result[img_file].append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': class_name}) | ||
99 | + return result | ||
100 | + | ||
101 | + | ||
102 | +def _open_for_csv(path): | ||
103 | + """ Open a file with flags suitable for csv.reader. | ||
104 | + | ||
105 | + For python2 this means opening with mode 'rb', | ||
106 | + for python3 this means mode 'r' with "universal newlines". | ||
107 | + """ | ||
108 | + if sys.version_info[0] < 3: | ||
109 | + return open(path, 'rb') | ||
110 | + else: | ||
111 | + return open(path, 'r', newline='') | ||
112 | + | ||
113 | + | ||
114 | +class CSVGenerator(Generator): | ||
115 | + """ Generate data for a custom CSV dataset. | ||
116 | + | ||
117 | + See https://github.com/fizyr/keras-retinanet#csv-datasets for more information. | ||
118 | + """ | ||
119 | + | ||
120 | + def __init__( | ||
121 | + self, | ||
122 | + csv_data_file, | ||
123 | + csv_class_file, | ||
124 | + base_dir=None, | ||
125 | + **kwargs | ||
126 | + ): | ||
127 | + """ Initialize a CSV data generator. | ||
128 | + | ||
129 | + Args | ||
130 | + csv_data_file: Path to the CSV annotations file. | ||
131 | + csv_class_file: Path to the CSV classes file. | ||
132 | + base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file). | ||
133 | + """ | ||
134 | + self.image_names = [] | ||
135 | + self.image_data = {} | ||
136 | + self.base_dir = base_dir | ||
137 | + | ||
138 | + # Take base_dir from annotations file if not explicitly specified. | ||
139 | + if self.base_dir is None: | ||
140 | + self.base_dir = os.path.dirname(csv_data_file) | ||
141 | + | ||
142 | + # parse the provided class file | ||
143 | + try: | ||
144 | + with _open_for_csv(csv_class_file) as file: | ||
145 | + self.classes = _read_classes(csv.reader(file, delimiter=',')) | ||
146 | + except ValueError as e: | ||
147 | + raise_from(ValueError('invalid CSV class file: {}: {}'.format(csv_class_file, e)), None) | ||
148 | + | ||
149 | + self.labels = {} | ||
150 | + for key, value in self.classes.items(): | ||
151 | + self.labels[value] = key | ||
152 | + | ||
153 | + # csv with img_path, x1, y1, x2, y2, class_name | ||
154 | + try: | ||
155 | + with _open_for_csv(csv_data_file) as file: | ||
156 | + self.image_data = _read_annotations(csv.reader(file, delimiter=','), self.classes) | ||
157 | + except ValueError as e: | ||
158 | + raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(csv_data_file, e)), None) | ||
159 | + self.image_names = list(self.image_data.keys()) | ||
160 | + | ||
161 | + super(CSVGenerator, self).__init__(**kwargs) | ||
162 | + | ||
163 | + def size(self): | ||
164 | + """ Size of the dataset. | ||
165 | + """ | ||
166 | + return len(self.image_names) | ||
167 | + | ||
168 | + def num_classes(self): | ||
169 | + """ Number of classes in the dataset. | ||
170 | + """ | ||
171 | + return max(self.classes.values()) + 1 | ||
172 | + | ||
173 | + def has_label(self, label): | ||
174 | + """ Return True if label is a known label. | ||
175 | + """ | ||
176 | + return label in self.labels | ||
177 | + | ||
178 | + def has_name(self, name): | ||
179 | + """ Returns True if name is a known class. | ||
180 | + """ | ||
181 | + return name in self.classes | ||
182 | + | ||
183 | + def name_to_label(self, name): | ||
184 | + """ Map name to label. | ||
185 | + """ | ||
186 | + return self.classes[name] | ||
187 | + | ||
188 | + def label_to_name(self, label): | ||
189 | + """ Map label to name. | ||
190 | + """ | ||
191 | + return self.labels[label] | ||
192 | + | ||
193 | + def image_path(self, image_index): | ||
194 | + """ Returns the image path for image_index. | ||
195 | + """ | ||
196 | + return os.path.join(self.base_dir, self.image_names[image_index]) | ||
197 | + | ||
198 | + def image_aspect_ratio(self, image_index): | ||
199 | + """ Compute the aspect ratio for an image with image_index. | ||
200 | + """ | ||
201 | + # PIL is fast for metadata | ||
202 | + image = Image.open(self.image_path(image_index)) | ||
203 | + return float(image.width) / float(image.height) | ||
204 | + | ||
205 | + def load_image(self, image_index): | ||
206 | + """ Load an image at the image_index. | ||
207 | + """ | ||
208 | + return read_image_bgr(self.image_path(image_index)) | ||
209 | + | ||
210 | + def load_annotations(self, image_index): | ||
211 | + """ Load annotations for an image_index. | ||
212 | + """ | ||
213 | + path = self.image_names[image_index] | ||
214 | + annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))} | ||
215 | + | ||
216 | + for idx, annot in enumerate(self.image_data[path]): | ||
217 | + annotations['labels'] = np.concatenate((annotations['labels'], [self.name_to_label(annot['class'])])) | ||
218 | + annotations['bboxes'] = np.concatenate((annotations['bboxes'], [[ | ||
219 | + float(annot['x1']), | ||
220 | + float(annot['y1']), | ||
221 | + float(annot['x2']), | ||
222 | + float(annot['y2']), | ||
223 | + ]])) | ||
224 | + | ||
225 | + return annotations |
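A sketch of the expected CSV files and generator construction, using the `img_file,x1,y1,x2,y2,class_name` layout parsed above (a row with empty box fields marks an image without annotations):

```python
# annotations.csv:
#   images/img_001.jpg,837,346,981,456,cow
#   images/img_002.jpg,,,,,
#
# classes.csv:
#   cow,0

generator = CSVGenerator('annotations.csv', 'classes.csv', batch_size=1)
print(generator.size(), generator.num_classes())
```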
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import numpy as np | ||
18 | +import random | ||
19 | +import warnings | ||
20 | + | ||
21 | +from tensorflow import keras | ||
22 | + | ||
23 | +from ..utils.anchors import ( | ||
24 | + anchor_targets_bbox, | ||
25 | + anchors_for_shape, | ||
26 | + guess_shapes | ||
27 | +) | ||
28 | +from ..utils.config import parse_anchor_parameters, parse_pyramid_levels | ||
29 | +from ..utils.image import ( | ||
30 | + TransformParameters, | ||
31 | + adjust_transform_for_image, | ||
32 | + apply_transform, | ||
33 | + preprocess_image, | ||
34 | + resize_image, | ||
35 | +) | ||
36 | +from ..utils.transform import transform_aabb | ||
37 | + | ||
38 | + | ||
39 | +class Generator(keras.utils.Sequence): | ||
40 | + """ Abstract generator class. | ||
41 | + """ | ||
42 | + | ||
43 | + def __init__( | ||
44 | + self, | ||
45 | + transform_generator = None, | ||
46 | + visual_effect_generator=None, | ||
47 | + batch_size=1, | ||
48 | + group_method='ratio', # one of 'none', 'random', 'ratio' | ||
49 | + shuffle_groups=True, | ||
50 | + image_min_side=800, | ||
51 | + image_max_side=1333, | ||
52 | + no_resize=False, | ||
53 | + transform_parameters=None, | ||
54 | + compute_anchor_targets=anchor_targets_bbox, | ||
55 | + compute_shapes=guess_shapes, | ||
56 | + preprocess_image=preprocess_image, | ||
57 | + config=None | ||
58 | + ): | ||
59 | + """ Initialize Generator object. | ||
60 | + | ||
61 | + Args | ||
62 | + transform_generator : A generator used to randomly transform images and annotations. | ||
63 | + batch_size : The size of the batches to generate. | ||
64 | + group_method : Determines how images are grouped together (defaults to 'ratio', one of ('none', 'random', 'ratio')). | ||
65 | + shuffle_groups : If True, shuffles the groups each epoch. | ||
66 | + image_min_side : After resizing the minimum side of an image is equal to image_min_side. | ||
67 | + image_max_side : If after resizing the maximum side is larger than image_max_side, scales down further so that the max side is equal to image_max_side. | ||
68 | + no_resize : If True, no image/annotation resizing is performed. | ||
69 | + transform_parameters : The transform parameters used for data augmentation. | ||
70 | + compute_anchor_targets : Function handler for computing the targets of anchors for an image and its annotations. | ||
71 | + compute_shapes : Function handler for computing the shapes of the pyramid for a given input. | ||
72 | + preprocess_image : Function handler for preprocessing an image (scaling / normalizing) for passing through a network. | ||
73 | + """ | ||
74 | + self.transform_generator = transform_generator | ||
75 | + self.visual_effect_generator = visual_effect_generator | ||
76 | + self.batch_size = int(batch_size) | ||
77 | + self.group_method = group_method | ||
78 | + self.shuffle_groups = shuffle_groups | ||
79 | + self.image_min_side = image_min_side | ||
80 | + self.image_max_side = image_max_side | ||
81 | + self.no_resize = no_resize | ||
82 | + self.transform_parameters = transform_parameters or TransformParameters() | ||
83 | + self.compute_anchor_targets = compute_anchor_targets | ||
84 | + self.compute_shapes = compute_shapes | ||
85 | + self.preprocess_image = preprocess_image | ||
86 | + self.config = config | ||
87 | + | ||
88 | + # Define groups | ||
89 | + self.group_images() | ||
90 | + | ||
91 | + # Shuffle when initializing | ||
92 | + if self.shuffle_groups: | ||
93 | + self.on_epoch_end() | ||
94 | + | ||
95 | + def on_epoch_end(self): | ||
96 | + if self.shuffle_groups: | ||
97 | + random.shuffle(self.groups) | ||
98 | + | ||
99 | + def size(self): | ||
100 | + """ Size of the dataset. | ||
101 | + """ | ||
102 | + raise NotImplementedError('size method not implemented') | ||
103 | + | ||
104 | + def num_classes(self): | ||
105 | + """ Number of classes in the dataset. | ||
106 | + """ | ||
107 | + raise NotImplementedError('num_classes method not implemented') | ||
108 | + | ||
109 | + def has_label(self, label): | ||
110 | + """ Returns True if label is a known label. | ||
111 | + """ | ||
112 | + raise NotImplementedError('has_label method not implemented') | ||
113 | + | ||
114 | + def has_name(self, name): | ||
115 | + """ Returns True if name is a known class. | ||
116 | + """ | ||
117 | + raise NotImplementedError('has_name method not implemented') | ||
118 | + | ||
119 | + def name_to_label(self, name): | ||
120 | + """ Map name to label. | ||
121 | + """ | ||
122 | + raise NotImplementedError('name_to_label method not implemented') | ||
123 | + | ||
124 | + def label_to_name(self, label): | ||
125 | + """ Map label to name. | ||
126 | + """ | ||
127 | + raise NotImplementedError('label_to_name method not implemented') | ||
128 | + | ||
129 | + def image_aspect_ratio(self, image_index): | ||
130 | + """ Compute the aspect ratio for an image with image_index. | ||
131 | + """ | ||
132 | + raise NotImplementedError('image_aspect_ratio method not implemented') | ||
133 | + | ||
134 | + def image_path(self, image_index): | ||
135 | + """ Get the path to an image. | ||
136 | + """ | ||
137 | + raise NotImplementedError('image_path method not implemented') | ||
138 | + | ||
139 | + def load_image(self, image_index): | ||
140 | + """ Load an image at the image_index. | ||
141 | + """ | ||
142 | + raise NotImplementedError('load_image method not implemented') | ||
143 | + | ||
144 | + def load_annotations(self, image_index): | ||
145 | + """ Load annotations for an image_index. | ||
146 | + """ | ||
147 | + raise NotImplementedError('load_annotations method not implemented') | ||
148 | + | ||
149 | + def load_annotations_group(self, group): | ||
150 | + """ Load annotations for all images in group. | ||
151 | + """ | ||
152 | + annotations_group = [self.load_annotations(image_index) for image_index in group] | ||
153 | + for annotations in annotations_group: | ||
154 | + assert(isinstance(annotations, dict)), '\'load_annotations\' should return a list of dictionaries, received: {}'.format(type(annotations)) | ||
155 | + assert('labels' in annotations), '\'load_annotations\' should return a list of dictionaries that contain \'labels\' and \'bboxes\'.' | ||
156 | + assert('bboxes' in annotations), '\'load_annotations\' should return a list of dictionaries that contain \'labels\' and \'bboxes\'.' | ||
157 | + | ||
158 | + return annotations_group | ||
159 | + | ||
160 | + def filter_annotations(self, image_group, annotations_group, group): | ||
161 | + """ Filter annotations by removing those that are outside of the image bounds or whose width/height < 0. | ||
162 | + """ | ||
163 | + # test all annotations | ||
164 | + for index, (image, annotations) in enumerate(zip(image_group, annotations_group)): | ||
165 | + # test x2 < x1 | y2 < y1 | x1 < 0 | y1 < 0 | x2 <= 0 | y2 <= 0 | x2 >= image.shape[1] | y2 >= image.shape[0] | ||
166 | + invalid_indices = np.where( | ||
167 | + (annotations['bboxes'][:, 2] <= annotations['bboxes'][:, 0]) | | ||
168 | + (annotations['bboxes'][:, 3] <= annotations['bboxes'][:, 1]) | | ||
169 | + (annotations['bboxes'][:, 0] < 0) | | ||
170 | + (annotations['bboxes'][:, 1] < 0) | | ||
171 | + (annotations['bboxes'][:, 2] > image.shape[1]) | | ||
172 | + (annotations['bboxes'][:, 3] > image.shape[0]) | ||
173 | + )[0] | ||
174 | + | ||
175 | + # delete invalid indices | ||
176 | + if len(invalid_indices): | ||
177 | + warnings.warn('Image {} with id {} (shape {}) contains the following invalid boxes: {}.'.format( | ||
178 | + self.image_path(group[index]), | ||
179 | + group[index], | ||
180 | + image.shape, | ||
181 | + annotations['bboxes'][invalid_indices, :] | ||
182 | + )) | ||
183 | + for k in annotations_group[index].keys(): | ||
184 | + annotations_group[index][k] = np.delete(annotations[k], invalid_indices, axis=0) | ||
185 | + return image_group, annotations_group | ||
186 | + | ||
187 | + def load_image_group(self, group): | ||
188 | + """ Load images for all images in a group. | ||
189 | + """ | ||
190 | + return [self.load_image(image_index) for image_index in group] | ||
191 | + | ||
192 | + def random_visual_effect_group_entry(self, image, annotations): | ||
193 | + """ Randomly transforms image and annotation. | ||
194 | + """ | ||
195 | + visual_effect = next(self.visual_effect_generator) | ||
196 | + # apply visual effect | ||
197 | + image = visual_effect(image) | ||
198 | + return image, annotations | ||
199 | + | ||
200 | + def random_visual_effect_group(self, image_group, annotations_group): | ||
201 | + """ Randomly apply visual effect on each image. | ||
202 | + """ | ||
203 | + assert(len(image_group) == len(annotations_group)) | ||
204 | + | ||
205 | + if self.visual_effect_generator is None: | ||
206 | + # do nothing | ||
207 | + return image_group, annotations_group | ||
208 | + | ||
209 | + for index in range(len(image_group)): | ||
210 | + # apply effect on a single group entry | ||
211 | + image_group[index], annotations_group[index] = self.random_visual_effect_group_entry( | ||
212 | + image_group[index], annotations_group[index] | ||
213 | + ) | ||
214 | + | ||
215 | + return image_group, annotations_group | ||
216 | + | ||
217 | + def random_transform_group_entry(self, image, annotations, transform=None): | ||
218 | + """ Randomly transforms image and annotation. | ||
219 | + """ | ||
220 | + # randomly transform both image and annotations | ||
221 | + if transform is not None or self.transform_generator: | ||
222 | + if transform is None: | ||
223 | + transform = adjust_transform_for_image(next(self.transform_generator), image, self.transform_parameters.relative_translation) | ||
224 | + | ||
225 | + # apply transformation to image | ||
226 | + image = apply_transform(transform, image, self.transform_parameters) | ||
227 | + | ||
228 | + # Transform the bounding boxes in the annotations. | ||
229 | + annotations['bboxes'] = annotations['bboxes'].copy() | ||
230 | + for index in range(annotations['bboxes'].shape[0]): | ||
231 | + annotations['bboxes'][index, :] = transform_aabb(transform, annotations['bboxes'][index, :]) | ||
232 | + | ||
233 | + return image, annotations | ||
234 | + | ||
235 | + def random_transform_group(self, image_group, annotations_group): | ||
236 | + """ Randomly transforms each image and its annotations. | ||
237 | + """ | ||
238 | + | ||
239 | + assert(len(image_group) == len(annotations_group)) | ||
240 | + | ||
241 | + for index in range(len(image_group)): | ||
242 | + # transform a single group entry | ||
243 | + image_group[index], annotations_group[index] = self.random_transform_group_entry(image_group[index], annotations_group[index]) | ||
244 | + | ||
245 | + return image_group, annotations_group | ||
246 | + | ||
247 | + def resize_image(self, image): | ||
248 | + """ Resize an image using image_min_side and image_max_side. | ||
249 | + """ | ||
250 | + if self.no_resize: | ||
251 | + return image, 1 | ||
252 | + else: | ||
253 | + return resize_image(image, min_side=self.image_min_side, max_side=self.image_max_side) | ||
254 | + | ||
255 | + def preprocess_group_entry(self, image, annotations): | ||
256 | + """ Preprocess image and its annotations. | ||
257 | + """ | ||
258 | + # resize image | ||
259 | + image, image_scale = self.resize_image(image) | ||
260 | + | ||
261 | + # preprocess the image | ||
262 | + image = self.preprocess_image(image) | ||
263 | + | ||
264 | + # apply resizing to annotations too | ||
265 | + annotations['bboxes'] *= image_scale | ||
266 | + | ||
267 | + # convert to the wanted keras floatx | ||
268 | + image = keras.backend.cast_to_floatx(image) | ||
269 | + | ||
270 | + return image, annotations | ||
271 | + | ||
272 | + def preprocess_group(self, image_group, annotations_group): | ||
273 | + """ Preprocess each image and its annotations in its group. | ||
274 | + """ | ||
275 | + assert(len(image_group) == len(annotations_group)) | ||
276 | + | ||
277 | + for index in range(len(image_group)): | ||
278 | + # preprocess a single group entry | ||
279 | + image_group[index], annotations_group[index] = self.preprocess_group_entry(image_group[index], annotations_group[index]) | ||
280 | + | ||
281 | + return image_group, annotations_group | ||
282 | + | ||
283 | + def group_images(self): | ||
284 | + """ Order the images according to self.order and makes groups of self.batch_size. | ||
285 | + """ | ||
286 | + # determine the order of the images | ||
287 | + order = list(range(self.size())) | ||
288 | + if self.group_method == 'random': | ||
289 | + random.shuffle(order) | ||
290 | + elif self.group_method == 'ratio': | ||
291 | + order.sort(key=lambda x: self.image_aspect_ratio(x)) | ||
292 | + | ||
293 | + # divide into groups, one group = one batch | ||
294 | + self.groups = [[order[x % len(order)] for x in range(i, i + self.batch_size)] for i in range(0, len(order), self.batch_size)] | ||
295 | + | ||
296 | + def compute_inputs(self, image_group): | ||
297 | + """ Compute inputs for the network using an image_group. | ||
298 | + """ | ||
299 | + # get the max image shape | ||
300 | + max_shape = tuple(max(image.shape[x] for image in image_group) for x in range(3)) | ||
301 | + | ||
302 | + # construct an image batch object | ||
303 | + image_batch = np.zeros((self.batch_size,) + max_shape, dtype=keras.backend.floatx()) | ||
304 | + | ||
305 | + # copy all images to the upper left part of the image batch object | ||
306 | + for image_index, image in enumerate(image_group): | ||
307 | + image_batch[image_index, :image.shape[0], :image.shape[1], :image.shape[2]] = image | ||
308 | + | ||
309 | + if keras.backend.image_data_format() == 'channels_first': | ||
310 | + image_batch = image_batch.transpose((0, 3, 1, 2)) | ||
311 | + | ||
312 | + return image_batch | ||
313 | + | ||
314 | + def generate_anchors(self, image_shape): | ||
315 | + anchor_params = None | ||
316 | + pyramid_levels = None | ||
317 | + if self.config and 'anchor_parameters' in self.config: | ||
318 | + anchor_params = parse_anchor_parameters(self.config) | ||
319 | + if self.config and 'pyramid_levels' in self.config: | ||
320 | + pyramid_levels = parse_pyramid_levels(self.config) | ||
321 | + | ||
322 | + return anchors_for_shape(image_shape, anchor_params=anchor_params, pyramid_levels=pyramid_levels, shapes_callback=self.compute_shapes) | ||
323 | + | ||
324 | + def compute_targets(self, image_group, annotations_group): | ||
325 | + """ Compute target outputs for the network using images and their annotations. | ||
326 | + """ | ||
327 | + # get the max image shape | ||
328 | + max_shape = tuple(max(image.shape[x] for image in image_group) for x in range(3)) | ||
329 | + anchors = self.generate_anchors(max_shape) | ||
330 | + | ||
331 | + batches = self.compute_anchor_targets( | ||
332 | + anchors, | ||
333 | + image_group, | ||
334 | + annotations_group, | ||
335 | + self.num_classes() | ||
336 | + ) | ||
337 | + | ||
338 | + return list(batches) | ||
339 | + | ||
340 | + def compute_input_output(self, group): | ||
341 | + """ Compute inputs and target outputs for the network. | ||
342 | + """ | ||
343 | + # load images and annotations | ||
344 | + image_group = self.load_image_group(group) | ||
345 | + annotations_group = self.load_annotations_group(group) | ||
346 | + | ||
347 | + # check validity of annotations | ||
348 | + image_group, annotations_group = self.filter_annotations(image_group, annotations_group, group) | ||
349 | + | ||
350 | + # randomly apply visual effect | ||
351 | + image_group, annotations_group = self.random_visual_effect_group(image_group, annotations_group) | ||
352 | + | ||
353 | + # randomly transform data | ||
354 | + image_group, annotations_group = self.random_transform_group(image_group, annotations_group) | ||
355 | + | ||
356 | + # perform preprocessing steps | ||
357 | + image_group, annotations_group = self.preprocess_group(image_group, annotations_group) | ||
358 | + | ||
359 | + # compute network inputs | ||
360 | + inputs = self.compute_inputs(image_group) | ||
361 | + | ||
362 | + # compute network targets | ||
363 | + targets = self.compute_targets(image_group, annotations_group) | ||
364 | + | ||
365 | + return inputs, targets | ||
366 | + | ||
367 | + def __len__(self): | ||
368 | + """ | ||
369 | + Number of batches for generator. | ||
370 | + """ | ||
371 | + | ||
372 | + return len(self.groups) | ||
373 | + | ||
374 | + def __getitem__(self, index): | ||
375 | + """ | ||
376 | + Keras sequence method for generating batches. | ||
377 | + """ | ||
378 | + group = self.groups[index] | ||
379 | + inputs, targets = self.compute_input_output(group) | ||
380 | + | ||
381 | + return inputs, targets |
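A minimal subclass sketch showing which methods a concrete generator has to provide; the in-memory dataset here is purely illustrative (the real implementations are the COCO/CSV/KITTI generators in this package).

```python
class ListGenerator(Generator):
    """ Hypothetical generator that serves images and annotations from Python lists. """

    def __init__(self, images, annotations, classes, **kwargs):
        self.images = images              # list of HxWx3 numpy arrays (BGR)
        self.annotations = annotations    # list of {'labels': ..., 'bboxes': ...} dicts
        self.classes = classes            # name -> label mapping
        self.labels = {v: k for k, v in classes.items()}
        super(ListGenerator, self).__init__(**kwargs)

    def size(self):
        return len(self.images)

    def num_classes(self):
        return max(self.classes.values()) + 1

    def has_label(self, label):
        return label in self.labels

    def has_name(self, name):
        return name in self.classes

    def name_to_label(self, name):
        return self.classes[name]

    def label_to_name(self, label):
        return self.labels[label]

    def image_path(self, image_index):
        return 'in-memory-{}'.format(image_index)

    def image_aspect_ratio(self, image_index):
        image = self.images[image_index]
        return float(image.shape[1]) / float(image.shape[0])

    def load_image(self, image_index):
        return self.images[image_index]

    def load_annotations(self, image_index):
        return self.annotations[image_index]
```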
1 | +""" | ||
2 | +Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import csv | ||
18 | +import os.path | ||
19 | + | ||
20 | +import numpy as np | ||
21 | +from PIL import Image | ||
22 | + | ||
23 | +from .generator import Generator | ||
24 | +from ..utils.image import read_image_bgr | ||
25 | + | ||
26 | +kitti_classes = { | ||
27 | + 'Car': 0, | ||
28 | + 'Van': 1, | ||
29 | + 'Truck': 2, | ||
30 | + 'Pedestrian': 3, | ||
31 | + 'Person_sitting': 4, | ||
32 | + 'Cyclist': 5, | ||
33 | + 'Tram': 6, | ||
34 | + 'Misc': 7, | ||
35 | + 'DontCare': 7 | ||
36 | +} | ||
37 | + | ||
38 | + | ||
39 | +class KittiGenerator(Generator): | ||
40 | + """ Generate data for a KITTI dataset. | ||
41 | + | ||
42 | + See http://www.cvlibs.net/datasets/kitti/ for more information. | ||
43 | + """ | ||
44 | + | ||
45 | + def __init__( | ||
46 | + self, | ||
47 | + base_dir, | ||
48 | + subset='train', | ||
49 | + **kwargs | ||
50 | + ): | ||
51 | + """ Initialize a KITTI data generator. | ||
52 | + | ||
53 | + Args | ||
54 | +            base_dir: Root directory of the KITTI dataset; '<base_dir>/<subset>/images' and '<base_dir>/<subset>/labels' are expected to exist. | ||
55 | + subset: The subset to generate data for (defaults to 'train'). | ||
56 | + """ | ||
57 | + self.base_dir = base_dir | ||
58 | + | ||
59 | + label_dir = os.path.join(self.base_dir, subset, 'labels') | ||
60 | + image_dir = os.path.join(self.base_dir, subset, 'images') | ||
61 | + | ||
62 | + """ | ||
63 | + 1 type Describes the type of object: 'Car', 'Van', 'Truck', | ||
64 | + 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram', | ||
65 | + 'Misc' or 'DontCare' | ||
66 | + 1 truncated Float from 0 (non-truncated) to 1 (truncated), where | ||
67 | + truncated refers to the object leaving image boundaries | ||
68 | + 1 occluded Integer (0,1,2,3) indicating occlusion state: | ||
69 | + 0 = fully visible, 1 = partly occluded | ||
70 | + 2 = largely occluded, 3 = unknown | ||
71 | + 1 alpha Observation angle of object, ranging [-pi..pi] | ||
72 | + 4 bbox 2D bounding box of object in the image (0-based index): | ||
73 | + contains left, top, right, bottom pixel coordinates | ||
74 | + 3 dimensions 3D object dimensions: height, width, length (in meters) | ||
75 | + 3 location 3D object location x,y,z in camera coordinates (in meters) | ||
76 | + 1 rotation_y Rotation ry around Y-axis in camera coordinates [-pi..pi] | ||
77 | + """ | ||
78 | + | ||
79 | + self.labels = {} | ||
80 | + self.classes = kitti_classes | ||
81 | + for name, label in self.classes.items(): | ||
82 | + self.labels[label] = name | ||
83 | + | ||
84 | + self.image_data = dict() | ||
85 | + self.images = [] | ||
86 | + for i, fn in enumerate(os.listdir(label_dir)): | ||
87 | + label_fp = os.path.join(label_dir, fn) | ||
88 | + image_fp = os.path.join(image_dir, fn.replace('.txt', '.png')) | ||
89 | + | ||
90 | + self.images.append(image_fp) | ||
91 | + | ||
92 | + fieldnames = ['type', 'truncated', 'occluded', 'alpha', 'left', 'top', 'right', 'bottom', 'dh', 'dw', 'dl', | ||
93 | + 'lx', 'ly', 'lz', 'ry'] | ||
94 | + with open(label_fp, 'r') as csv_file: | ||
95 | + reader = csv.DictReader(csv_file, delimiter=' ', fieldnames=fieldnames) | ||
96 | + boxes = [] | ||
97 | + for line, row in enumerate(reader): | ||
98 | + label = row['type'] | ||
99 | + cls_id = kitti_classes[label] | ||
100 | + | ||
101 | + annotation = {'cls_id': cls_id, 'x1': row['left'], 'x2': row['right'], 'y2': row['bottom'], 'y1': row['top']} | ||
102 | + boxes.append(annotation) | ||
103 | + | ||
104 | + self.image_data[i] = boxes | ||
105 | + | ||
106 | + super(KittiGenerator, self).__init__(**kwargs) | ||
107 | + | ||
108 | + def size(self): | ||
109 | + """ Size of the dataset. | ||
110 | + """ | ||
111 | + return len(self.images) | ||
112 | + | ||
113 | + def num_classes(self): | ||
114 | + """ Number of classes in the dataset. | ||
115 | + """ | ||
116 | + return max(self.classes.values()) + 1 | ||
117 | + | ||
118 | + def has_label(self, label): | ||
119 | + """ Return True if label is a known label. | ||
120 | + """ | ||
121 | + return label in self.labels | ||
122 | + | ||
123 | + def has_name(self, name): | ||
124 | + """ Returns True if name is a known class. | ||
125 | + """ | ||
126 | + return name in self.classes | ||
127 | + | ||
128 | + def name_to_label(self, name): | ||
129 | + """ Map name to label. | ||
130 | + """ | ||
131 | +        return self.classes[name] | ||
132 | + | ||
133 | + def label_to_name(self, label): | ||
134 | + """ Map label to name. | ||
135 | + """ | ||
136 | + return self.labels[label] | ||
137 | + | ||
138 | + def image_aspect_ratio(self, image_index): | ||
139 | + """ Compute the aspect ratio for an image with image_index. | ||
140 | + """ | ||
141 | + # PIL is fast for metadata | ||
142 | + image = Image.open(self.images[image_index]) | ||
143 | + return float(image.width) / float(image.height) | ||
144 | + | ||
145 | + def image_path(self, image_index): | ||
146 | + """ Get the path to an image. | ||
147 | + """ | ||
148 | + return self.images[image_index] | ||
149 | + | ||
150 | + def load_image(self, image_index): | ||
151 | + """ Load an image at the image_index. | ||
152 | + """ | ||
153 | + return read_image_bgr(self.image_path(image_index)) | ||
154 | + | ||
155 | + def load_annotations(self, image_index): | ||
156 | + """ Load annotations for an image_index. | ||
157 | + """ | ||
158 | + image_data = self.image_data[image_index] | ||
159 | + annotations = {'labels': np.empty((len(image_data),)), 'bboxes': np.empty((len(image_data), 4))} | ||
160 | + | ||
161 | + for idx, ann in enumerate(image_data): | ||
162 | + annotations['bboxes'][idx, 0] = float(ann['x1']) | ||
163 | + annotations['bboxes'][idx, 1] = float(ann['y1']) | ||
164 | + annotations['bboxes'][idx, 2] = float(ann['x2']) | ||
165 | + annotations['bboxes'][idx, 3] = float(ann['y2']) | ||
166 | + annotations['labels'][idx] = int(ann['cls_id']) | ||
167 | + | ||
168 | + return annotations |
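
For reference, this is roughly how a single KITTI label line flows through the `csv.DictReader` parsing in the constructor above; the numeric values are invented and nothing is read from disk:

```python
# One KITTI label line parsed the same way as in KittiGenerator.__init__ (values invented).
import csv
import io

fieldnames = ['type', 'truncated', 'occluded', 'alpha', 'left', 'top', 'right', 'bottom',
              'dh', 'dw', 'dl', 'lx', 'ly', 'lz', 'ry']
line = 'Car 0.00 0 1.55 614.24 181.78 727.31 284.77 1.57 1.73 4.15 1.00 1.75 13.22 1.62\n'

reader = csv.DictReader(io.StringIO(line), delimiter=' ', fieldnames=fieldnames)
row = next(reader)

annotation = {'cls_id': 0,  # kitti_classes['Car']
              'x1': row['left'], 'y1': row['top'], 'x2': row['right'], 'y2': row['bottom']}
print(annotation)  # the coordinates are still strings here; load_annotations() casts them to float
```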
1 | +""" | ||
2 | +Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import csv | ||
18 | +import json | ||
19 | +import os | ||
20 | +import warnings | ||
21 | + | ||
22 | +import numpy as np | ||
23 | +from PIL import Image | ||
24 | + | ||
25 | +from .generator import Generator | ||
26 | +from ..utils.image import read_image_bgr | ||
27 | + | ||
28 | + | ||
29 | +def load_hierarchy(metadata_dir, version='v4'): | ||
30 | + hierarchy = None | ||
31 | + if version == 'challenge2018': | ||
32 | + hierarchy = 'bbox_labels_500_hierarchy.json' | ||
33 | + elif version == 'v4': | ||
34 | + hierarchy = 'bbox_labels_600_hierarchy.json' | ||
35 | + elif version == 'v3': | ||
36 | + hierarchy = 'bbox_labels_600_hierarchy.json' | ||
37 | + | ||
38 | + hierarchy_json = os.path.join(metadata_dir, hierarchy) | ||
39 | + with open(hierarchy_json) as f: | ||
40 | + hierarchy_data = json.loads(f.read()) | ||
41 | + | ||
42 | + return hierarchy_data | ||
43 | + | ||
44 | + | ||
45 | +def load_hierarchy_children(hierarchy): | ||
46 | + res = [hierarchy['LabelName']] | ||
47 | + | ||
48 | + if 'Subcategory' in hierarchy: | ||
49 | + for subcategory in hierarchy['Subcategory']: | ||
50 | + children = load_hierarchy_children(subcategory) | ||
51 | + | ||
52 | + for c in children: | ||
53 | + res.append(c) | ||
54 | + | ||
55 | + return res | ||
56 | + | ||
57 | + | ||
58 | +def find_hierarchy_parent(hierarchy, parent_cls): | ||
59 | + if hierarchy['LabelName'] == parent_cls: | ||
60 | + return hierarchy | ||
61 | + elif 'Subcategory' in hierarchy: | ||
62 | + for child in hierarchy['Subcategory']: | ||
63 | + res = find_hierarchy_parent(child, parent_cls) | ||
64 | + if res is not None: | ||
65 | + return res | ||
66 | + | ||
67 | + return None | ||
68 | + | ||
69 | + | ||
70 | +def get_labels(metadata_dir, version='v4'): | ||
71 | + if version == 'v4' or version == 'challenge2018': | ||
72 | + csv_file = 'class-descriptions-boxable.csv' if version == 'v4' else 'challenge-2018-class-descriptions-500.csv' | ||
73 | + | ||
74 | + boxable_classes_descriptions = os.path.join(metadata_dir, csv_file) | ||
75 | + id_to_labels = {} | ||
76 | + cls_index = {} | ||
77 | + | ||
78 | + i = 0 | ||
79 | + with open(boxable_classes_descriptions) as f: | ||
80 | + for row in csv.reader(f): | ||
81 | + # make sure the csv row is not empty (usually the last one) | ||
82 | + if len(row): | ||
83 | + label = row[0] | ||
84 | + description = row[1].replace("\"", "").replace("'", "").replace('`', '') | ||
85 | + | ||
86 | + id_to_labels[i] = description | ||
87 | + cls_index[label] = i | ||
88 | + | ||
89 | + i += 1 | ||
90 | + else: | ||
91 | + trainable_classes_path = os.path.join(metadata_dir, 'classes-bbox-trainable.txt') | ||
92 | + description_path = os.path.join(metadata_dir, 'class-descriptions.csv') | ||
93 | + | ||
94 | + description_table = {} | ||
95 | + with open(description_path) as f: | ||
96 | + for row in csv.reader(f): | ||
97 | + # make sure the csv row is not empty (usually the last one) | ||
98 | + if len(row): | ||
99 | + description_table[row[0]] = row[1].replace("\"", "").replace("'", "").replace('`', '') | ||
100 | + | ||
101 | +        with open(trainable_classes_path, 'r') as f: | ||
102 | + trainable_classes = f.read().split('\n') | ||
103 | + | ||
104 | + id_to_labels = dict([(i, description_table[c]) for i, c in enumerate(trainable_classes)]) | ||
105 | + cls_index = dict([(c, i) for i, c in enumerate(trainable_classes)]) | ||
106 | + | ||
107 | + return id_to_labels, cls_index | ||
108 | + | ||
109 | + | ||
110 | +def generate_images_annotations_json(main_dir, metadata_dir, subset, cls_index, version='v4'): | ||
111 | + validation_image_ids = {} | ||
112 | + | ||
113 | + if version == 'v4': | ||
114 | + annotations_path = os.path.join(metadata_dir, subset, '{}-annotations-bbox.csv'.format(subset)) | ||
115 | + elif version == 'challenge2018': | ||
116 | + validation_image_ids_path = os.path.join(metadata_dir, 'challenge-2018-image-ids-valset-od.csv') | ||
117 | + | ||
118 | + with open(validation_image_ids_path, 'r') as csv_file: | ||
119 | + reader = csv.DictReader(csv_file, fieldnames=['ImageID']) | ||
120 | + next(reader) | ||
121 | + for line, row in enumerate(reader): | ||
122 | + image_id = row['ImageID'] | ||
123 | + validation_image_ids[image_id] = True | ||
124 | + | ||
125 | + annotations_path = os.path.join(metadata_dir, 'challenge-2018-train-annotations-bbox.csv') | ||
126 | + else: | ||
127 | + annotations_path = os.path.join(metadata_dir, subset, 'annotations-human-bbox.csv') | ||
128 | + | ||
129 | + fieldnames = ['ImageID', 'Source', 'LabelName', 'Confidence', | ||
130 | + 'XMin', 'XMax', 'YMin', 'YMax', | ||
131 | + 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsDepiction', 'IsInside'] | ||
132 | + | ||
133 | + id_annotations = dict() | ||
134 | + with open(annotations_path, 'r') as csv_file: | ||
135 | + reader = csv.DictReader(csv_file, fieldnames=fieldnames) | ||
136 | + next(reader) | ||
137 | + | ||
138 | + images_sizes = {} | ||
139 | + for line, row in enumerate(reader): | ||
140 | + frame = row['ImageID'] | ||
141 | + | ||
142 | + if version == 'challenge2018': | ||
143 | + if subset == 'train': | ||
144 | + if frame in validation_image_ids: | ||
145 | + continue | ||
146 | + elif subset == 'validation': | ||
147 | + if frame not in validation_image_ids: | ||
148 | + continue | ||
149 | + else: | ||
150 | + raise NotImplementedError('This generator handles only the train and validation subsets') | ||
151 | + | ||
152 | + class_name = row['LabelName'] | ||
153 | + | ||
154 | + if class_name not in cls_index: | ||
155 | + continue | ||
156 | + | ||
157 | + cls_id = cls_index[class_name] | ||
158 | + | ||
159 | + if version == 'challenge2018': | ||
160 | + # We recommend participants to use the provided subset of the training set as a validation set. | ||
161 | + # This is preferable over using the V4 val/test sets, as the training set is more densely annotated. | ||
162 | + img_path = os.path.join(main_dir, 'images', 'train', frame + '.jpg') | ||
163 | + else: | ||
164 | + img_path = os.path.join(main_dir, 'images', subset, frame + '.jpg') | ||
165 | + | ||
166 | + if frame in images_sizes: | ||
167 | + width, height = images_sizes[frame] | ||
168 | + else: | ||
169 | + try: | ||
170 | + with Image.open(img_path) as img: | ||
171 | + width, height = img.width, img.height | ||
172 | + images_sizes[frame] = (width, height) | ||
173 | + except Exception as ex: | ||
174 | + if version == 'challenge2018': | ||
175 | + raise ex | ||
176 | + continue | ||
177 | + | ||
178 | + x1 = float(row['XMin']) | ||
179 | + x2 = float(row['XMax']) | ||
180 | + y1 = float(row['YMin']) | ||
181 | + y2 = float(row['YMax']) | ||
182 | + | ||
183 | + x1_int = int(round(x1 * width)) | ||
184 | + x2_int = int(round(x2 * width)) | ||
185 | + y1_int = int(round(y1 * height)) | ||
186 | + y2_int = int(round(y2 * height)) | ||
187 | + | ||
188 | + # Check that the bounding box is valid. | ||
189 | + if x2 <= x1: | ||
190 | + raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1)) | ||
191 | + if y2 <= y1: | ||
192 | + raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1)) | ||
193 | + | ||
194 | + if y2_int == y1_int: | ||
195 | + warnings.warn('filtering line {}: rounding y2 ({}) and y1 ({}) makes them equal'.format(line, y2, y1)) | ||
196 | + continue | ||
197 | + | ||
198 | + if x2_int == x1_int: | ||
199 | + warnings.warn('filtering line {}: rounding x2 ({}) and x1 ({}) makes them equal'.format(line, x2, x1)) | ||
200 | + continue | ||
201 | + | ||
202 | + img_id = row['ImageID'] | ||
203 | + annotation = {'cls_id': cls_id, 'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2} | ||
204 | + | ||
205 | + if img_id in id_annotations: | ||
206 | + annotations = id_annotations[img_id] | ||
207 | + annotations['boxes'].append(annotation) | ||
208 | + else: | ||
209 | + id_annotations[img_id] = {'w': width, 'h': height, 'boxes': [annotation]} | ||
210 | + return id_annotations | ||
211 | + | ||
212 | + | ||
213 | +class OpenImagesGenerator(Generator): | ||
214 | + def __init__( | ||
215 | + self, main_dir, subset, version='v4', | ||
216 | + labels_filter=None, annotation_cache_dir='.', | ||
217 | + parent_label=None, | ||
218 | + **kwargs | ||
219 | + ): | ||
220 | + if version == 'challenge2018': | ||
221 | + metadata = 'challenge2018' | ||
222 | + elif version == 'v4': | ||
223 | + metadata = '2018_04' | ||
224 | + elif version == 'v3': | ||
225 | + metadata = '2017_11' | ||
226 | + else: | ||
227 | + raise NotImplementedError('There is currently no implementation for versions older than v3') | ||
228 | + | ||
229 | + if version == 'challenge2018': | ||
230 | + self.base_dir = os.path.join(main_dir, 'images', 'train') | ||
231 | + else: | ||
232 | + self.base_dir = os.path.join(main_dir, 'images', subset) | ||
233 | + | ||
234 | + metadata_dir = os.path.join(main_dir, metadata) | ||
235 | + annotation_cache_json = os.path.join(annotation_cache_dir, subset + '.json') | ||
236 | + | ||
237 | + self.hierarchy = load_hierarchy(metadata_dir, version=version) | ||
238 | + id_to_labels, cls_index = get_labels(metadata_dir, version=version) | ||
239 | + | ||
240 | + if os.path.exists(annotation_cache_json): | ||
241 | + with open(annotation_cache_json, 'r') as f: | ||
242 | + self.annotations = json.loads(f.read()) | ||
243 | + else: | ||
244 | + self.annotations = generate_images_annotations_json(main_dir, metadata_dir, subset, cls_index, version=version) | ||
245 | +            with open(annotation_cache_json, 'w') as f: | ||
246 | +                json.dump(self.annotations, f) | ||
246 | + | ||
247 | + if labels_filter is not None or parent_label is not None: | ||
248 | + self.id_to_labels, self.annotations = self.__filter_data(id_to_labels, cls_index, labels_filter, parent_label) | ||
249 | + else: | ||
250 | + self.id_to_labels = id_to_labels | ||
251 | + | ||
252 | + self.id_to_image_id = dict([(i, k) for i, k in enumerate(self.annotations)]) | ||
253 | + | ||
254 | + super(OpenImagesGenerator, self).__init__(**kwargs) | ||
255 | + | ||
256 | + def __filter_data(self, id_to_labels, cls_index, labels_filter=None, parent_label=None): | ||
257 | +        """ | ||
258 | +        Restrict the dataset to a subset of the labels. | ||
259 | +        :param labels_filter: list of label names to keep, e.g. ['Helmet', 'Hat', 'Analog television'] | ||
260 | +        :param parent_label: if set, keep this label as well as all of its children in the | ||
261 | +            OID semantic hierarchy (e.g. 'Animal' also keeps every animal subclass) | ||
262 | +        :return: (filtered id_to_labels mapping, filtered annotations) | ||
263 | +        """ | ||
265 | + | ||
266 | + children_id_to_labels = {} | ||
267 | + | ||
268 | + if parent_label is None: | ||
269 | +            # no hierarchy expansion requested: keep only the labels listed in labels_filter | ||
270 | + | ||
271 | + for label in labels_filter: | ||
272 | + for i, lb in id_to_labels.items(): | ||
273 | + if lb == label: | ||
274 | + children_id_to_labels[i] = label | ||
275 | + break | ||
276 | + else: | ||
277 | + parent_cls = None | ||
278 | + for i, lb in id_to_labels.items(): | ||
279 | + if lb == parent_label: | ||
280 | + parent_id = i | ||
281 | + for c, index in cls_index.items(): | ||
282 | + if index == parent_id: | ||
283 | + parent_cls = c | ||
284 | + break | ||
285 | + | ||
286 | + if parent_cls is None: | ||
287 | +                raise Exception('Could not find label {}'.format(parent_label)) | ||
288 | + | ||
289 | + parent_tree = find_hierarchy_parent(self.hierarchy, parent_cls) | ||
290 | + | ||
291 | + if parent_tree is None: | ||
292 | +                raise Exception('Could not find parent {} in the semantic hierarchy'.format(parent_label)) | ||
293 | + | ||
294 | + children = load_hierarchy_children(parent_tree) | ||
295 | + | ||
296 | + for cls in children: | ||
297 | + index = cls_index[cls] | ||
298 | + label = id_to_labels[index] | ||
299 | + children_id_to_labels[index] = label | ||
300 | + | ||
301 | + id_map = dict([(ind, i) for i, ind in enumerate(children_id_to_labels.keys())]) | ||
302 | + | ||
303 | + filtered_annotations = {} | ||
304 | + for k in self.annotations: | ||
305 | + img_ann = self.annotations[k] | ||
306 | + | ||
307 | + filtered_boxes = [] | ||
308 | + for ann in img_ann['boxes']: | ||
309 | + cls_id = ann['cls_id'] | ||
310 | + if cls_id in children_id_to_labels: | ||
311 | + ann['cls_id'] = id_map[cls_id] | ||
312 | + filtered_boxes.append(ann) | ||
313 | + | ||
314 | + if len(filtered_boxes) > 0: | ||
315 | + filtered_annotations[k] = {'w': img_ann['w'], 'h': img_ann['h'], 'boxes': filtered_boxes} | ||
316 | + | ||
317 | + children_id_to_labels = dict([(id_map[i], l) for (i, l) in children_id_to_labels.items()]) | ||
318 | + | ||
319 | + return children_id_to_labels, filtered_annotations | ||
320 | + | ||
321 | + def size(self): | ||
322 | + return len(self.annotations) | ||
323 | + | ||
324 | + def num_classes(self): | ||
325 | + return len(self.id_to_labels) | ||
326 | + | ||
327 | + def has_label(self, label): | ||
328 | + """ Return True if label is a known label. | ||
329 | + """ | ||
330 | + return label in self.id_to_labels | ||
331 | + | ||
332 | + def has_name(self, name): | ||
333 | + """ Returns True if name is a known class. | ||
334 | + """ | ||
335 | + raise NotImplementedError() | ||
336 | + | ||
337 | + def name_to_label(self, name): | ||
338 | + raise NotImplementedError() | ||
339 | + | ||
340 | + def label_to_name(self, label): | ||
341 | + return self.id_to_labels[label] | ||
342 | + | ||
343 | + def image_aspect_ratio(self, image_index): | ||
344 | + img_annotations = self.annotations[self.id_to_image_id[image_index]] | ||
345 | + height, width = img_annotations['h'], img_annotations['w'] | ||
346 | + return float(width) / float(height) | ||
347 | + | ||
348 | + def image_path(self, image_index): | ||
349 | + path = os.path.join(self.base_dir, self.id_to_image_id[image_index] + '.jpg') | ||
350 | + return path | ||
351 | + | ||
352 | + def load_image(self, image_index): | ||
353 | + return read_image_bgr(self.image_path(image_index)) | ||
354 | + | ||
355 | + def load_annotations(self, image_index): | ||
356 | + image_annotations = self.annotations[self.id_to_image_id[image_index]] | ||
357 | + | ||
358 | + labels = image_annotations['boxes'] | ||
359 | + height, width = image_annotations['h'], image_annotations['w'] | ||
360 | + | ||
361 | + annotations = {'labels': np.empty((len(labels),)), 'bboxes': np.empty((len(labels), 4))} | ||
362 | + for idx, ann in enumerate(labels): | ||
363 | + cls_id = ann['cls_id'] | ||
364 | + x1 = ann['x1'] * width | ||
365 | + x2 = ann['x2'] * width | ||
366 | + y1 = ann['y1'] * height | ||
367 | + y2 = ann['y2'] * height | ||
368 | + | ||
369 | + annotations['bboxes'][idx, 0] = x1 | ||
370 | + annotations['bboxes'][idx, 1] = y1 | ||
371 | + annotations['bboxes'][idx, 2] = x2 | ||
372 | + annotations['bboxes'][idx, 3] = y2 | ||
373 | + annotations['labels'][idx] = cls_id | ||
374 | + | ||
375 | + return annotations |
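
Open Images annotations store box corners normalised to [0, 1]; `load_annotations()` above scales them back to pixels. A quick check of that arithmetic on a hypothetical 1024x768 image:

```python
# Scaling a normalised Open Images box to pixel coordinates, as load_annotations() does.
width, height = 1024, 768
ann = {'cls_id': 3, 'x1': 0.25, 'y1': 0.10, 'x2': 0.75, 'y2': 0.60}

bbox = [ann['x1'] * width, ann['y1'] * height, ann['x2'] * width, ann['y2'] * height]
print(bbox)  # [256.0, 76.8, 768.0, 460.8] -> (x1, y1, x2, y2) in pixels
```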
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from ..preprocessing.generator import Generator | ||
18 | +from ..utils.image import read_image_bgr | ||
19 | + | ||
20 | +import os | ||
21 | +import numpy as np | ||
22 | +from six import raise_from | ||
23 | +from PIL import Image | ||
24 | + | ||
25 | +try: | ||
26 | + import xml.etree.cElementTree as ET | ||
27 | +except ImportError: | ||
28 | + import xml.etree.ElementTree as ET | ||
29 | + | ||
30 | +voc_classes = { | ||
31 | + 'aeroplane' : 0, | ||
32 | + 'bicycle' : 1, | ||
33 | + 'bird' : 2, | ||
34 | + 'boat' : 3, | ||
35 | + 'bottle' : 4, | ||
36 | + 'bus' : 5, | ||
37 | + 'car' : 6, | ||
38 | + 'cat' : 7, | ||
39 | + 'chair' : 8, | ||
40 | + 'cow' : 9, | ||
41 | + 'diningtable' : 10, | ||
42 | + 'dog' : 11, | ||
43 | + 'horse' : 12, | ||
44 | + 'motorbike' : 13, | ||
45 | + 'person' : 14, | ||
46 | + 'pottedplant' : 15, | ||
47 | + 'sheep' : 16, | ||
48 | + 'sofa' : 17, | ||
49 | + 'train' : 18, | ||
50 | + 'tvmonitor' : 19 | ||
51 | +} | ||
52 | + | ||
53 | + | ||
54 | +def _findNode(parent, name, debug_name=None, parse=None): | ||
55 | + if debug_name is None: | ||
56 | + debug_name = name | ||
57 | + | ||
58 | + result = parent.find(name) | ||
59 | + if result is None: | ||
60 | + raise ValueError('missing element \'{}\''.format(debug_name)) | ||
61 | + if parse is not None: | ||
62 | + try: | ||
63 | + return parse(result.text) | ||
64 | + except ValueError as e: | ||
65 | + raise_from(ValueError('illegal value for \'{}\': {}'.format(debug_name, e)), None) | ||
66 | + return result | ||
67 | + | ||
68 | + | ||
69 | +class PascalVocGenerator(Generator): | ||
70 | + """ Generate data for a Pascal VOC dataset. | ||
71 | + | ||
72 | + See http://host.robots.ox.ac.uk/pascal/VOC/ for more information. | ||
73 | + """ | ||
74 | + | ||
75 | + def __init__( | ||
76 | + self, | ||
77 | + data_dir, | ||
78 | + set_name, | ||
79 | + classes=voc_classes, | ||
80 | + image_extension='.jpg', | ||
81 | + skip_truncated=False, | ||
82 | + skip_difficult=False, | ||
83 | + **kwargs | ||
84 | + ): | ||
85 | + """ Initialize a Pascal VOC data generator. | ||
86 | + | ||
87 | + Args | ||
88 | +            data_dir: Root directory of the VOC dataset (containing 'Annotations', 'ImageSets' and 'JPEGImages'). | ||
89 | +            set_name: Name of the image set to load, e.g. 'trainval' or 'test'. | ||
90 | + """ | ||
91 | + self.data_dir = data_dir | ||
92 | + self.set_name = set_name | ||
93 | + self.classes = classes | ||
94 | + self.image_names = [line.strip().split(None, 1)[0] for line in open(os.path.join(data_dir, 'ImageSets', 'Main', set_name + '.txt')).readlines()] | ||
95 | + self.image_extension = image_extension | ||
96 | + self.skip_truncated = skip_truncated | ||
97 | + self.skip_difficult = skip_difficult | ||
98 | + | ||
99 | + self.labels = {} | ||
100 | + for key, value in self.classes.items(): | ||
101 | + self.labels[value] = key | ||
102 | + | ||
103 | + super(PascalVocGenerator, self).__init__(**kwargs) | ||
104 | + | ||
105 | + def size(self): | ||
106 | + """ Size of the dataset. | ||
107 | + """ | ||
108 | + return len(self.image_names) | ||
109 | + | ||
110 | + def num_classes(self): | ||
111 | + """ Number of classes in the dataset. | ||
112 | + """ | ||
113 | + return len(self.classes) | ||
114 | + | ||
115 | + def has_label(self, label): | ||
116 | + """ Return True if label is a known label. | ||
117 | + """ | ||
118 | + return label in self.labels | ||
119 | + | ||
120 | + def has_name(self, name): | ||
121 | + """ Returns True if name is a known class. | ||
122 | + """ | ||
123 | + return name in self.classes | ||
124 | + | ||
125 | + def name_to_label(self, name): | ||
126 | + """ Map name to label. | ||
127 | + """ | ||
128 | + return self.classes[name] | ||
129 | + | ||
130 | + def label_to_name(self, label): | ||
131 | + """ Map label to name. | ||
132 | + """ | ||
133 | + return self.labels[label] | ||
134 | + | ||
135 | + def image_aspect_ratio(self, image_index): | ||
136 | + """ Compute the aspect ratio for an image with image_index. | ||
137 | + """ | ||
138 | + path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) | ||
139 | + image = Image.open(path) | ||
140 | + return float(image.width) / float(image.height) | ||
141 | + | ||
142 | + def image_path(self, image_index): | ||
143 | + """ Get the path to an image. | ||
144 | + """ | ||
145 | + return os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) | ||
146 | + | ||
147 | + def load_image(self, image_index): | ||
148 | + """ Load an image at the image_index. | ||
149 | + """ | ||
150 | + return read_image_bgr(self.image_path(image_index)) | ||
151 | + | ||
152 | + def __parse_annotation(self, element): | ||
153 | + """ Parse an annotation given an XML element. | ||
154 | + """ | ||
155 | + truncated = _findNode(element, 'truncated', parse=int) | ||
156 | + difficult = _findNode(element, 'difficult', parse=int) | ||
157 | + | ||
158 | + class_name = _findNode(element, 'name').text | ||
159 | + if class_name not in self.classes: | ||
160 | + raise ValueError('class name \'{}\' not found in classes: {}'.format(class_name, list(self.classes.keys()))) | ||
161 | + | ||
162 | + box = np.zeros((4,)) | ||
163 | + label = self.name_to_label(class_name) | ||
164 | + | ||
165 | + bndbox = _findNode(element, 'bndbox') | ||
166 | + box[0] = _findNode(bndbox, 'xmin', 'bndbox.xmin', parse=float) - 1 | ||
167 | + box[1] = _findNode(bndbox, 'ymin', 'bndbox.ymin', parse=float) - 1 | ||
168 | + box[2] = _findNode(bndbox, 'xmax', 'bndbox.xmax', parse=float) - 1 | ||
169 | + box[3] = _findNode(bndbox, 'ymax', 'bndbox.ymax', parse=float) - 1 | ||
170 | + | ||
171 | + return truncated, difficult, box, label | ||
172 | + | ||
173 | + def __parse_annotations(self, xml_root): | ||
174 | + """ Parse all annotations under the xml_root. | ||
175 | + """ | ||
176 | +        # collect in lists so skipped (truncated/difficult) objects leave no uninitialised rows | ||
177 | +        labels, bboxes = [], [] | ||
178 | +        for i, element in enumerate(xml_root.iter('object')): | ||
179 | +            try: | ||
180 | +                truncated, difficult, box, label = self.__parse_annotation(element) | ||
181 | +            except ValueError as e: | ||
182 | +                raise_from(ValueError('could not parse object #{}: {}'.format(i, e)), None) | ||
183 | + | ||
184 | +            if truncated and self.skip_truncated: | ||
185 | +                continue | ||
186 | +            if difficult and self.skip_difficult: | ||
187 | +                continue | ||
188 | + | ||
189 | +            bboxes.append(box) | ||
190 | +            labels.append(label) | ||
191 | +        return {'labels': np.asarray(labels), 'bboxes': np.asarray(bboxes).reshape((-1, 4))} | ||
192 | + | ||
193 | + def load_annotations(self, image_index): | ||
194 | + """ Load annotations for an image_index. | ||
195 | + """ | ||
196 | + filename = self.image_names[image_index] + '.xml' | ||
197 | + try: | ||
198 | + tree = ET.parse(os.path.join(self.data_dir, 'Annotations', filename)) | ||
199 | + return self.__parse_annotations(tree.getroot()) | ||
200 | + except ET.ParseError as e: | ||
201 | + raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None) | ||
202 | + except ValueError as e: | ||
203 | + raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None) |
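
The `__parse_annotation` helper above reads one `<object>` element per box and shifts the 1-based VOC pixel coordinates to 0-based. The same parsing on a hand-written XML snippet (the element below is invented for illustration):

```python
import xml.etree.ElementTree as ET

xml_snippet = """
<object>
  <name>dog</name>
  <truncated>0</truncated>
  <difficult>0</difficult>
  <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
</object>
"""

element = ET.fromstring(xml_snippet)
bndbox = element.find('bndbox')
# subtract 1 to convert the 1-based VOC coordinates to 0-based, as __parse_annotation does
box = [float(bndbox.find(tag).text) - 1 for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
print(element.find('name').text, box)  # dog [47.0, 239.0, 194.0, 370.0]
```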
retinaNet/keras_retinanet/utils/__init__.py
0 → 100644
File mode changed
retinaNet/keras_retinanet/utils/anchors.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import numpy as np | ||
18 | +from tensorflow import keras | ||
19 | + | ||
20 | +from ..utils.compute_overlap import compute_overlap | ||
21 | + | ||
22 | + | ||
23 | +class AnchorParameters: | ||
24 | +    """ The parameters that define how anchors are generated. | ||
25 | + | ||
26 | + Args | ||
27 | + sizes : List of sizes to use. Each size corresponds to one feature level. | ||
28 | +        strides : List of strides to use. Each stride corresponds to one feature level. | ||
29 | + ratios : List of ratios to use per location in a feature map. | ||
30 | + scales : List of scales to use per location in a feature map. | ||
31 | + """ | ||
32 | + def __init__(self, sizes, strides, ratios, scales): | ||
33 | + self.sizes = sizes | ||
34 | + self.strides = strides | ||
35 | + self.ratios = ratios | ||
36 | + self.scales = scales | ||
37 | + | ||
38 | + def num_anchors(self): | ||
39 | + return len(self.ratios) * len(self.scales) | ||
40 | + | ||
41 | + | ||
42 | +""" | ||
43 | +The default anchor parameters. | ||
44 | +""" | ||
45 | +AnchorParameters.default = AnchorParameters( | ||
46 | + sizes = [32, 64, 128, 256, 512], | ||
47 | + strides = [8, 16, 32, 64, 128], | ||
48 | + ratios = np.array([0.5, 1, 2], keras.backend.floatx()), | ||
49 | + scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)], keras.backend.floatx()), | ||
50 | +) | ||
51 | + | ||
52 | + | ||
53 | +def anchor_targets_bbox( | ||
54 | + anchors, | ||
55 | + image_group, | ||
56 | + annotations_group, | ||
57 | + num_classes, | ||
58 | + negative_overlap=0.4, | ||
59 | + positive_overlap=0.5 | ||
60 | +): | ||
61 | + """ Generate anchor targets for bbox detection. | ||
62 | + | ||
63 | + Args | ||
64 | + anchors: np.array of annotations of shape (N, 4) for (x1, y1, x2, y2). | ||
65 | + image_group: List of BGR images. | ||
66 | + annotations_group: List of annotation dictionaries with each annotation containing 'labels' and 'bboxes' of an image. | ||
67 | + num_classes: Number of classes to predict. | ||
68 | +        negative_overlap: IoU overlap for negative anchors (all anchors with overlap < negative_overlap are negative). | ||
69 | +        positive_overlap: IoU overlap for positive anchors (all anchors with overlap > positive_overlap are positive). | ||
71 | + | ||
72 | + Returns | ||
73 | + labels_batch: batch that contains labels & anchor states (np.array of shape (batch_size, N, num_classes + 1), | ||
74 | + where N is the number of anchors for an image and the last column defines the anchor state (-1 for ignore, 0 for bg, 1 for fg). | ||
75 | + regression_batch: batch that contains bounding-box regression targets for an image & anchor states (np.array of shape (batch_size, N, 4 + 1), | ||
76 | + where N is the number of anchors for an image, the first 4 columns define regression targets for (x1, y1, x2, y2) and the | ||
77 | + last column defines anchor states (-1 for ignore, 0 for bg, 1 for fg). | ||
78 | + """ | ||
79 | + | ||
80 | + assert(len(image_group) == len(annotations_group)), "The length of the images and annotations need to be equal." | ||
81 | + assert(len(annotations_group) > 0), "No data received to compute anchor targets for." | ||
82 | + for annotations in annotations_group: | ||
83 | + assert('bboxes' in annotations), "Annotations should contain bboxes." | ||
84 | + assert('labels' in annotations), "Annotations should contain labels." | ||
85 | + | ||
86 | + batch_size = len(image_group) | ||
87 | + | ||
88 | + regression_batch = np.zeros((batch_size, anchors.shape[0], 4 + 1), dtype=keras.backend.floatx()) | ||
89 | + labels_batch = np.zeros((batch_size, anchors.shape[0], num_classes + 1), dtype=keras.backend.floatx()) | ||
90 | + | ||
91 | + # compute labels and regression targets | ||
92 | + for index, (image, annotations) in enumerate(zip(image_group, annotations_group)): | ||
93 | + if annotations['bboxes'].shape[0]: | ||
94 | + # obtain indices of gt annotations with the greatest overlap | ||
95 | + positive_indices, ignore_indices, argmax_overlaps_inds = compute_gt_annotations(anchors, annotations['bboxes'], negative_overlap, positive_overlap) | ||
96 | + | ||
97 | + labels_batch[index, ignore_indices, -1] = -1 | ||
98 | + labels_batch[index, positive_indices, -1] = 1 | ||
99 | + | ||
100 | + regression_batch[index, ignore_indices, -1] = -1 | ||
101 | + regression_batch[index, positive_indices, -1] = 1 | ||
102 | + | ||
103 | + # compute target class labels | ||
104 | + labels_batch[index, positive_indices, annotations['labels'][argmax_overlaps_inds[positive_indices]].astype(int)] = 1 | ||
105 | + | ||
106 | + regression_batch[index, :, :-1] = bbox_transform(anchors, annotations['bboxes'][argmax_overlaps_inds, :]) | ||
107 | + | ||
108 | + # ignore annotations outside of image | ||
109 | + if image.shape: | ||
110 | + anchors_centers = np.vstack([(anchors[:, 0] + anchors[:, 2]) / 2, (anchors[:, 1] + anchors[:, 3]) / 2]).T | ||
111 | + indices = np.logical_or(anchors_centers[:, 0] >= image.shape[1], anchors_centers[:, 1] >= image.shape[0]) | ||
112 | + | ||
113 | + labels_batch[index, indices, -1] = -1 | ||
114 | + regression_batch[index, indices, -1] = -1 | ||
115 | + | ||
116 | + return regression_batch, labels_batch | ||
117 | + | ||
118 | + | ||
119 | +def compute_gt_annotations( | ||
120 | + anchors, | ||
121 | + annotations, | ||
122 | + negative_overlap=0.4, | ||
123 | + positive_overlap=0.5 | ||
124 | +): | ||
125 | + """ Obtain indices of gt annotations with the greatest overlap. | ||
126 | + | ||
127 | + Args | ||
128 | + anchors: np.array of annotations of shape (N, 4) for (x1, y1, x2, y2). | ||
129 | +        annotations: np.array of ground-truth boxes of shape (K, 4) for (x1, y1, x2, y2). | ||
130 | + negative_overlap: IoU overlap for negative anchors (all anchors with overlap < negative_overlap are negative). | ||
131 | +        positive_overlap: IoU overlap for positive anchors (all anchors with overlap > positive_overlap are positive). | ||
132 | + | ||
133 | + Returns | ||
134 | + positive_indices: indices of positive anchors | ||
135 | + ignore_indices: indices of ignored anchors | ||
136 | +        argmax_overlaps_inds: for each anchor, the index of the ground-truth box with the highest overlap | ||
137 | + """ | ||
138 | + | ||
139 | + overlaps = compute_overlap(anchors.astype(np.float64), annotations.astype(np.float64)) | ||
140 | + argmax_overlaps_inds = np.argmax(overlaps, axis=1) | ||
141 | + max_overlaps = overlaps[np.arange(overlaps.shape[0]), argmax_overlaps_inds] | ||
142 | + | ||
143 | + # assign "dont care" labels | ||
144 | + positive_indices = max_overlaps >= positive_overlap | ||
145 | + ignore_indices = (max_overlaps > negative_overlap) & ~positive_indices | ||
146 | + | ||
147 | + return positive_indices, ignore_indices, argmax_overlaps_inds | ||
148 | + | ||
149 | + | ||
150 | +def layer_shapes(image_shape, model): | ||
151 | + """Compute layer shapes given input image shape and the model. | ||
152 | + | ||
153 | + Args | ||
154 | + image_shape: The shape of the image. | ||
155 | + model: The model to use for computing how the image shape is transformed in the pyramid. | ||
156 | + | ||
157 | + Returns | ||
158 | + A dictionary mapping layer names to image shapes. | ||
159 | + """ | ||
160 | + shape = { | ||
161 | + model.layers[0].name: (None,) + image_shape, | ||
162 | + } | ||
163 | + | ||
164 | + for layer in model.layers[1:]: | ||
165 | + nodes = layer._inbound_nodes | ||
166 | + for node in nodes: | ||
167 | + if isinstance(node.inbound_layers, keras.layers.Layer): | ||
168 | + inputs = [shape[node.inbound_layers.name]] | ||
169 | + else: | ||
170 | + inputs = [shape[lr.name] for lr in node.inbound_layers] | ||
171 | + if not inputs: | ||
172 | + continue | ||
173 | + shape[layer.name] = layer.compute_output_shape(inputs[0] if len(inputs) == 1 else inputs) | ||
174 | + | ||
175 | + return shape | ||
176 | + | ||
177 | + | ||
178 | +def make_shapes_callback(model): | ||
179 | + """ Make a function for getting the shape of the pyramid levels. | ||
180 | + """ | ||
181 | + def get_shapes(image_shape, pyramid_levels): | ||
182 | + shape = layer_shapes(image_shape, model) | ||
183 | + image_shapes = [shape["P{}".format(level)][1:3] for level in pyramid_levels] | ||
184 | + return image_shapes | ||
185 | + | ||
186 | + return get_shapes | ||
187 | + | ||
188 | + | ||
189 | +def guess_shapes(image_shape, pyramid_levels): | ||
190 | + """Guess shapes based on pyramid levels. | ||
191 | + | ||
192 | + Args | ||
193 | + image_shape: The shape of the image. | ||
194 | + pyramid_levels: A list of what pyramid levels are used. | ||
195 | + | ||
196 | + Returns | ||
197 | + A list of image shapes at each pyramid level. | ||
198 | + """ | ||
199 | + image_shape = np.array(image_shape[:2]) | ||
200 | + image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in pyramid_levels] | ||
201 | + return image_shapes | ||
202 | + | ||
203 | + | ||
204 | +def anchors_for_shape( | ||
205 | + image_shape, | ||
206 | + pyramid_levels=None, | ||
207 | + anchor_params=None, | ||
208 | + shapes_callback=None, | ||
209 | +): | ||
210 | +    """ Generate anchors for a given shape. | ||
211 | + | ||
212 | + Args | ||
213 | + image_shape: The shape of the image. | ||
214 | + pyramid_levels: List of ints representing which pyramids to use (defaults to [3, 4, 5, 6, 7]). | ||
215 | + anchor_params: Struct containing anchor parameters. If None, default values are used. | ||
216 | + shapes_callback: Function to call for getting the shape of the image at different pyramid levels. | ||
217 | + | ||
218 | + Returns | ||
219 | + np.array of shape (N, 4) containing the (x1, y1, x2, y2) coordinates for the anchors. | ||
220 | + """ | ||
221 | + | ||
222 | + if pyramid_levels is None: | ||
223 | + pyramid_levels = [3, 4, 5, 6, 7] | ||
224 | + | ||
225 | + if anchor_params is None: | ||
226 | + anchor_params = AnchorParameters.default | ||
227 | + | ||
228 | + if shapes_callback is None: | ||
229 | + shapes_callback = guess_shapes | ||
230 | + image_shapes = shapes_callback(image_shape, pyramid_levels) | ||
231 | + | ||
232 | + # compute anchors over all pyramid levels | ||
233 | + all_anchors = np.zeros((0, 4)) | ||
234 | + for idx, p in enumerate(pyramid_levels): | ||
235 | + anchors = generate_anchors( | ||
236 | + base_size=anchor_params.sizes[idx], | ||
237 | + ratios=anchor_params.ratios, | ||
238 | + scales=anchor_params.scales | ||
239 | + ) | ||
240 | + shifted_anchors = shift(image_shapes[idx], anchor_params.strides[idx], anchors) | ||
241 | + all_anchors = np.append(all_anchors, shifted_anchors, axis=0) | ||
242 | + | ||
243 | + return all_anchors | ||
244 | + | ||
245 | + | ||
246 | +def shift(shape, stride, anchors): | ||
247 | + """ Produce shifted anchors based on shape of the map and stride size. | ||
248 | + | ||
249 | + Args | ||
250 | + shape : Shape to shift the anchors over. | ||
251 | + stride : Stride to shift the anchors with over the shape. | ||
252 | + anchors: The anchors to apply at each location. | ||
253 | + """ | ||
254 | + | ||
255 | + # create a grid starting from half stride from the top left corner | ||
256 | + shift_x = (np.arange(0, shape[1]) + 0.5) * stride | ||
257 | + shift_y = (np.arange(0, shape[0]) + 0.5) * stride | ||
258 | + | ||
259 | + shift_x, shift_y = np.meshgrid(shift_x, shift_y) | ||
260 | + | ||
261 | + shifts = np.vstack(( | ||
262 | + shift_x.ravel(), shift_y.ravel(), | ||
263 | + shift_x.ravel(), shift_y.ravel() | ||
264 | + )).transpose() | ||
265 | + | ||
266 | + # add A anchors (1, A, 4) to | ||
267 | + # cell K shifts (K, 1, 4) to get | ||
268 | + # shift anchors (K, A, 4) | ||
269 | + # reshape to (K*A, 4) shifted anchors | ||
270 | + A = anchors.shape[0] | ||
271 | + K = shifts.shape[0] | ||
272 | + all_anchors = (anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2))) | ||
273 | + all_anchors = all_anchors.reshape((K * A, 4)) | ||
274 | + | ||
275 | + return all_anchors | ||
276 | + | ||
277 | + | ||
278 | +def generate_anchors(base_size=16, ratios=None, scales=None): | ||
279 | + """ | ||
280 | + Generate anchor (reference) windows by enumerating aspect ratios X | ||
281 | + scales w.r.t. a reference window. | ||
282 | + """ | ||
283 | + | ||
284 | + if ratios is None: | ||
285 | + ratios = AnchorParameters.default.ratios | ||
286 | + | ||
287 | + if scales is None: | ||
288 | + scales = AnchorParameters.default.scales | ||
289 | + | ||
290 | + num_anchors = len(ratios) * len(scales) | ||
291 | + | ||
292 | + # initialize output anchors | ||
293 | + anchors = np.zeros((num_anchors, 4)) | ||
294 | + | ||
295 | + # scale base_size | ||
296 | + anchors[:, 2:] = base_size * np.tile(scales, (2, len(ratios))).T | ||
297 | + | ||
298 | + # compute areas of anchors | ||
299 | + areas = anchors[:, 2] * anchors[:, 3] | ||
300 | + | ||
301 | + # correct for ratios | ||
302 | + anchors[:, 2] = np.sqrt(areas / np.repeat(ratios, len(scales))) | ||
303 | + anchors[:, 3] = anchors[:, 2] * np.repeat(ratios, len(scales)) | ||
304 | + | ||
305 | + # transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2) | ||
306 | + anchors[:, 0::2] -= np.tile(anchors[:, 2] * 0.5, (2, 1)).T | ||
307 | + anchors[:, 1::2] -= np.tile(anchors[:, 3] * 0.5, (2, 1)).T | ||
308 | + | ||
309 | + return anchors | ||
310 | + | ||
311 | + | ||
312 | +def bbox_transform(anchors, gt_boxes, mean=None, std=None): | ||
313 | + """Compute bounding-box regression targets for an image.""" | ||
314 | + | ||
315 | +    # The default mean and std below were estimated on the COCO dataset. | ||
316 | +    # Bounding box normalization was first introduced in the Fast R-CNN paper. | ||
317 | +    # See https://github.com/fizyr/keras-retinanet/issues/1273#issuecomment-585828825 for more details. | ||
318 | + if mean is None: | ||
319 | + mean = np.array([0, 0, 0, 0]) | ||
320 | + if std is None: | ||
321 | + std = np.array([0.2, 0.2, 0.2, 0.2]) | ||
322 | + | ||
323 | + if isinstance(mean, (list, tuple)): | ||
324 | + mean = np.array(mean) | ||
325 | + elif not isinstance(mean, np.ndarray): | ||
326 | + raise ValueError('Expected mean to be a np.ndarray, list or tuple. Received: {}'.format(type(mean))) | ||
327 | + | ||
328 | + if isinstance(std, (list, tuple)): | ||
329 | + std = np.array(std) | ||
330 | + elif not isinstance(std, np.ndarray): | ||
331 | + raise ValueError('Expected std to be a np.ndarray, list or tuple. Received: {}'.format(type(std))) | ||
332 | + | ||
333 | + anchor_widths = anchors[:, 2] - anchors[:, 0] | ||
334 | + anchor_heights = anchors[:, 3] - anchors[:, 1] | ||
335 | + | ||
336 | + # According to the information provided by a keras-retinanet author, they got marginally better results using | ||
337 | + # the following way of bounding box parametrization. | ||
338 | + # See https://github.com/fizyr/keras-retinanet/issues/1273#issuecomment-585828825 for more details | ||
339 | + targets_dx1 = (gt_boxes[:, 0] - anchors[:, 0]) / anchor_widths | ||
340 | + targets_dy1 = (gt_boxes[:, 1] - anchors[:, 1]) / anchor_heights | ||
341 | + targets_dx2 = (gt_boxes[:, 2] - anchors[:, 2]) / anchor_widths | ||
342 | + targets_dy2 = (gt_boxes[:, 3] - anchors[:, 3]) / anchor_heights | ||
343 | + | ||
344 | + targets = np.stack((targets_dx1, targets_dy1, targets_dx2, targets_dy2)) | ||
345 | + targets = targets.T | ||
346 | + | ||
347 | + targets = (targets - mean) / std | ||
348 | + | ||
349 | + return targets |
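
A short smoke test of the helpers above, working through the default parameters. It assumes `keras_retinanet` is importable and that its Cython `compute_overlap` extension has been built, since `anchors.py` imports it at module level:

```python
import numpy as np
from keras_retinanet.utils.anchors import anchors_for_shape, bbox_transform, generate_anchors

# nine base anchors per location: 3 aspect ratios (h/w = 0.5, 1, 2) x 3 scales
base = generate_anchors(base_size=32)
print(np.round((base[:, 3] - base[:, 1]) / (base[:, 2] - base[:, 0]), 2))
# [0.5 0.5 0.5 1.  1.  1.  2.  2.  2. ]

# anchors over the default P3-P7 pyramid for an 800x800x3 input
all_anchors = anchors_for_shape((800, 800, 3))
print(all_anchors.shape)  # (120087, 4)

# regression targets for one anchor/ground-truth pair (default mean 0, std 0.2)
targets = bbox_transform(np.array([[100., 100., 200., 200.]]),
                         np.array([[110., 90., 210., 190.]]))
print(targets)  # [[ 0.5 -0.5  0.5 -0.5]]
```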
retinaNet/keras_retinanet/utils/coco_eval.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from pycocotools.cocoeval import COCOeval | ||
18 | + | ||
19 | +from tensorflow import keras | ||
20 | +import numpy as np | ||
21 | +import json | ||
22 | + | ||
23 | +import progressbar | ||
24 | +assert(callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead." | ||
25 | + | ||
26 | + | ||
27 | +def evaluate_coco(generator, model, threshold=0.05): | ||
28 | + """ Use the pycocotools to evaluate a COCO model on a dataset. | ||
29 | + | ||
30 | + Args | ||
31 | + generator : The generator for generating the evaluation data. | ||
32 | + model : The model to evaluate. | ||
33 | + threshold : The score threshold to use. | ||
34 | + """ | ||
35 | + # start collecting results | ||
36 | + results = [] | ||
37 | + image_ids = [] | ||
38 | + for index in progressbar.progressbar(range(generator.size()), prefix='COCO evaluation: '): | ||
39 | + image = generator.load_image(index) | ||
40 | + image = generator.preprocess_image(image) | ||
41 | + image, scale = generator.resize_image(image) | ||
42 | + | ||
43 | + if keras.backend.image_data_format() == 'channels_first': | ||
44 | + image = image.transpose((2, 0, 1)) | ||
45 | + | ||
46 | + # run network | ||
47 | + boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0)) | ||
48 | + | ||
49 | + # correct boxes for image scale | ||
50 | + boxes /= scale | ||
51 | + | ||
52 | + # change to (x, y, w, h) (MS COCO standard) | ||
53 | + boxes[:, :, 2] -= boxes[:, :, 0] | ||
54 | + boxes[:, :, 3] -= boxes[:, :, 1] | ||
55 | + | ||
56 | + # compute predicted labels and scores | ||
57 | + for box, score, label in zip(boxes[0], scores[0], labels[0]): | ||
58 | + # scores are sorted, so we can break | ||
59 | + if score < threshold: | ||
60 | + break | ||
61 | + | ||
62 | + # append detection for each positively labeled class | ||
63 | + image_result = { | ||
64 | + 'image_id' : generator.image_ids[index], | ||
65 | + 'category_id' : generator.label_to_coco_label(label), | ||
66 | + 'score' : float(score), | ||
67 | + 'bbox' : box.tolist(), | ||
68 | + } | ||
69 | + | ||
70 | + # append detection to results | ||
71 | + results.append(image_result) | ||
72 | + | ||
73 | + # append image to list of processed images | ||
74 | + image_ids.append(generator.image_ids[index]) | ||
75 | + | ||
76 | + if not len(results): | ||
77 | + return | ||
78 | + | ||
79 | + # write output | ||
80 | + json.dump(results, open('{}_bbox_results.json'.format(generator.set_name), 'w'), indent=4) | ||
81 | + json.dump(image_ids, open('{}_processed_image_ids.json'.format(generator.set_name), 'w'), indent=4) | ||
82 | + | ||
83 | + # load results in COCO evaluation tool | ||
84 | + coco_true = generator.coco | ||
85 | + coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(generator.set_name)) | ||
86 | + | ||
87 | + # run COCO evaluation | ||
88 | + coco_eval = COCOeval(coco_true, coco_pred, 'bbox') | ||
89 | + coco_eval.params.imgIds = image_ids | ||
90 | + coco_eval.evaluate() | ||
91 | + coco_eval.accumulate() | ||
92 | + coco_eval.summarize() | ||
93 | + return coco_eval.stats |
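
A hedged usage sketch of `evaluate_coco()`. The `CocoGenerator` and `models.load_model` names, the backbone, and the paths are assumptions about the rest of this repository and about an already trained inference snapshot; only `evaluate_coco` itself is defined in the file above:

```python
from keras_retinanet import models                             # assumed package layout
from keras_retinanet.preprocessing.coco import CocoGenerator   # assumed generator module
from keras_retinanet.utils.coco_eval import evaluate_coco

generator = CocoGenerator('/path/to/coco', 'val2017')          # hypothetical dataset location
model = models.load_model('/path/to/inference-model.h5', backbone_name='resnet50')

stats = evaluate_coco(generator, model, threshold=0.05)        # pycocotools AP/AR summary array
```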
retinaNet/keras_retinanet/utils/colors.py
0 → 100644
1 | +import warnings | ||
2 | + | ||
3 | + | ||
4 | +def label_color(label): | ||
5 | + """ Return a color from a set of predefined colors. Contains 80 colors in total. | ||
6 | + | ||
7 | + Args | ||
8 | + label: The label to get the color for. | ||
9 | + | ||
10 | + Returns | ||
11 | + A list of three values representing a RGB color. | ||
12 | + | ||
13 | + If no color is defined for a certain label, the color green is returned and a warning is printed. | ||
14 | + """ | ||
15 | + if label < len(colors): | ||
16 | + return colors[label] | ||
17 | + else: | ||
18 | + warnings.warn('Label {} has no color, returning default.'.format(label)) | ||
19 | + return (0, 255, 0) | ||
20 | + | ||
21 | + | ||
22 | +""" | ||
23 | +Generated using: | ||
24 | + | ||
25 | +``` | ||
26 | +colors = [list((matplotlib.colors.hsv_to_rgb([x, 1.0, 1.0]) * 255).astype(int)) for x in np.arange(0, 1, 1.0 / 80)] | ||
27 | +shuffle(colors) | ||
28 | +pprint(colors) | ||
29 | +``` | ||
30 | +""" | ||
31 | +colors = [ | ||
32 | + [31 , 0 , 255] , | ||
33 | + [0 , 159 , 255] , | ||
34 | + [255 , 95 , 0] , | ||
35 | + [255 , 19 , 0] , | ||
36 | + [255 , 0 , 0] , | ||
37 | + [255 , 38 , 0] , | ||
38 | + [0 , 255 , 25] , | ||
39 | + [255 , 0 , 133] , | ||
40 | + [255 , 172 , 0] , | ||
41 | + [108 , 0 , 255] , | ||
42 | + [0 , 82 , 255] , | ||
43 | + [0 , 255 , 6] , | ||
44 | + [255 , 0 , 152] , | ||
45 | + [223 , 0 , 255] , | ||
46 | + [12 , 0 , 255] , | ||
47 | + [0 , 255 , 178] , | ||
48 | + [108 , 255 , 0] , | ||
49 | + [184 , 0 , 255] , | ||
50 | + [255 , 0 , 76] , | ||
51 | + [146 , 255 , 0] , | ||
52 | + [51 , 0 , 255] , | ||
53 | + [0 , 197 , 255] , | ||
54 | + [255 , 248 , 0] , | ||
55 | + [255 , 0 , 19] , | ||
56 | + [255 , 0 , 38] , | ||
57 | + [89 , 255 , 0] , | ||
58 | + [127 , 255 , 0] , | ||
59 | + [255 , 153 , 0] , | ||
60 | + [0 , 255 , 255] , | ||
61 | + [0 , 255 , 216] , | ||
62 | + [0 , 255 , 121] , | ||
63 | + [255 , 0 , 248] , | ||
64 | + [70 , 0 , 255] , | ||
65 | + [0 , 255 , 159] , | ||
66 | + [0 , 216 , 255] , | ||
67 | + [0 , 6 , 255] , | ||
68 | + [0 , 63 , 255] , | ||
69 | + [31 , 255 , 0] , | ||
70 | + [255 , 57 , 0] , | ||
71 | + [255 , 0 , 210] , | ||
72 | + [0 , 255 , 102] , | ||
73 | + [242 , 255 , 0] , | ||
74 | + [255 , 191 , 0] , | ||
75 | + [0 , 255 , 63] , | ||
76 | + [255 , 0 , 95] , | ||
77 | + [146 , 0 , 255] , | ||
78 | + [184 , 255 , 0] , | ||
79 | + [255 , 114 , 0] , | ||
80 | + [0 , 255 , 235] , | ||
81 | + [255 , 229 , 0] , | ||
82 | + [0 , 178 , 255] , | ||
83 | + [255 , 0 , 114] , | ||
84 | + [255 , 0 , 57] , | ||
85 | + [0 , 140 , 255] , | ||
86 | + [0 , 121 , 255] , | ||
87 | + [12 , 255 , 0] , | ||
88 | + [255 , 210 , 0] , | ||
89 | + [0 , 255 , 44] , | ||
90 | + [165 , 255 , 0] , | ||
91 | + [0 , 25 , 255] , | ||
92 | + [0 , 255 , 140] , | ||
93 | + [0 , 101 , 255] , | ||
94 | + [0 , 255 , 82] , | ||
95 | + [223 , 255 , 0] , | ||
96 | + [242 , 0 , 255] , | ||
97 | + [89 , 0 , 255] , | ||
98 | + [165 , 0 , 255] , | ||
99 | + [70 , 255 , 0] , | ||
100 | + [255 , 0 , 172] , | ||
101 | + [255 , 76 , 0] , | ||
102 | + [203 , 255 , 0] , | ||
103 | + [204 , 0 , 255] , | ||
104 | + [255 , 0 , 229] , | ||
105 | + [255 , 133 , 0] , | ||
106 | + [127 , 0 , 255] , | ||
107 | + [0 , 235 , 255] , | ||
108 | + [0 , 255 , 197] , | ||
109 | + [255 , 0 , 191] , | ||
110 | + [0 , 44 , 255] , | ||
111 | + [50 , 255 , 0] | ||
112 | +] |
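
A quick check of `label_color()` behaviour, assuming the package is importable: labels inside the 80-entry palette map to a fixed colour, out-of-range labels fall back to green with a warning.

```python
from keras_retinanet.utils.colors import label_color

print(label_color(0))    # [31, 0, 255]
print(label_color(200))  # (0, 255, 0), plus a 'Label 200 has no color' warning
```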
This diff could not be displayed because it is too large.
No preview for this file type
1 | +# -------------------------------------------------------- | ||
2 | +# Fast R-CNN | ||
3 | +# Copyright (c) 2015 Microsoft | ||
4 | +# Licensed under The MIT License [see LICENSE for details] | ||
5 | +# Written by Sergey Karayev | ||
6 | +# -------------------------------------------------------- | ||
7 | + | ||
8 | +cimport cython | ||
9 | +import numpy as np | ||
10 | +cimport numpy as np | ||
11 | + | ||
12 | + | ||
13 | +def compute_overlap( | ||
14 | + np.ndarray[double, ndim=2] boxes, | ||
15 | + np.ndarray[double, ndim=2] query_boxes | ||
16 | +): | ||
17 | + """ | ||
18 | + Args | ||
19 | + a: (N, 4) ndarray of float | ||
20 | + b: (K, 4) ndarray of float | ||
21 | + | ||
22 | + Returns | ||
23 | + overlaps: (N, K) ndarray of overlap between boxes and query_boxes | ||
24 | + """ | ||
25 | + cdef unsigned int N = boxes.shape[0] | ||
26 | + cdef unsigned int K = query_boxes.shape[0] | ||
27 | + cdef np.ndarray[double, ndim=2] overlaps = np.zeros((N, K), dtype=np.float64) | ||
28 | + cdef double iw, ih, box_area | ||
29 | + cdef double ua | ||
30 | + cdef unsigned int k, n | ||
31 | + for k in range(K): | ||
32 | + box_area = ( | ||
33 | + (query_boxes[k, 2] - query_boxes[k, 0]) * | ||
34 | + (query_boxes[k, 3] - query_boxes[k, 1]) | ||
35 | + ) | ||
36 | + for n in range(N): | ||
37 | + iw = ( | ||
38 | + min(boxes[n, 2], query_boxes[k, 2]) - | ||
39 | + max(boxes[n, 0], query_boxes[k, 0]) | ||
40 | + ) | ||
41 | + if iw > 0: | ||
42 | + ih = ( | ||
43 | + min(boxes[n, 3], query_boxes[k, 3]) - | ||
44 | + max(boxes[n, 1], query_boxes[k, 1]) | ||
45 | + ) | ||
46 | + if ih > 0: | ||
47 | + ua = np.float64( | ||
48 | + (boxes[n, 2] - boxes[n, 0]) * | ||
49 | + (boxes[n, 3] - boxes[n, 1]) + | ||
50 | + box_area - iw * ih | ||
51 | + ) | ||
52 | + overlaps[n, k] = iw * ih / ua | ||
53 | + return overlaps |
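
For readers who do not want to compile the extension, here is a plain-NumPy equivalent of `compute_overlap()`; it is a reference sketch for checking results, not a performance substitute:

```python
import numpy as np


def compute_overlap_np(boxes, query_boxes):
    """boxes: (N, 4), query_boxes: (K, 4); returns the (N, K) IoU matrix."""
    iw = (np.minimum(boxes[:, None, 2], query_boxes[None, :, 2]) -
          np.maximum(boxes[:, None, 0], query_boxes[None, :, 0])).clip(min=0)
    ih = (np.minimum(boxes[:, None, 3], query_boxes[None, :, 3]) -
          np.maximum(boxes[:, None, 1], query_boxes[None, :, 1])).clip(min=0)
    area_boxes = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    area_query = (query_boxes[:, 2] - query_boxes[:, 0]) * (query_boxes[:, 3] - query_boxes[:, 1])
    union = area_boxes[:, None] + area_query[None, :] - iw * ih
    return iw * ih / union


boxes = np.array([[0., 0., 10., 10.]])
query = np.array([[5., 5., 15., 15.]])
print(compute_overlap_np(boxes, query))  # [[0.14285714]] -> 25 / (100 + 100 - 25)
```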
retinaNet/keras_retinanet/utils/config.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import configparser | ||
18 | +import numpy as np | ||
19 | +from tensorflow import keras | ||
20 | +from ..utils.anchors import AnchorParameters | ||
21 | + | ||
22 | + | ||
23 | +def read_config_file(config_path): | ||
24 | + config = configparser.ConfigParser() | ||
25 | + | ||
26 | + with open(config_path, 'r') as file: | ||
27 | + config.read_file(file) | ||
28 | + | ||
29 | + assert 'anchor_parameters' in config, \ | ||
30 | + "Malformed config file. Verify that it contains the anchor_parameters section." | ||
31 | + | ||
32 | + config_keys = set(config['anchor_parameters']) | ||
33 | + default_keys = set(AnchorParameters.default.__dict__.keys()) | ||
34 | + | ||
35 | + assert config_keys <= default_keys, \ | ||
36 | + "Malformed config file. These keys are not valid: {}".format(config_keys - default_keys) | ||
37 | + | ||
38 | + if 'pyramid_levels' in config: | ||
39 | +        assert('levels' in config['pyramid_levels']), "Malformed config file. The pyramid_levels section must contain a 'levels' key." | ||
40 | + | ||
41 | + return config | ||
42 | + | ||
43 | + | ||
44 | +def parse_anchor_parameters(config): | ||
45 | + ratios = np.array(list(map(float, config['anchor_parameters']['ratios'].split(' '))), keras.backend.floatx()) | ||
46 | + scales = np.array(list(map(float, config['anchor_parameters']['scales'].split(' '))), keras.backend.floatx()) | ||
47 | + sizes = list(map(int, config['anchor_parameters']['sizes'].split(' '))) | ||
48 | + strides = list(map(int, config['anchor_parameters']['strides'].split(' '))) | ||
49 | + assert (len(sizes) == len(strides)), "sizes and strides should have an equal number of values" | ||
50 | + | ||
51 | + return AnchorParameters(sizes, strides, ratios, scales) | ||
52 | + | ||
53 | + | ||
54 | +def parse_pyramid_levels(config): | ||
55 | + levels = list(map(int, config['pyramid_levels']['levels'].split(' '))) | ||
56 | + | ||
57 | + return levels |
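A sketch of the kind of .ini content these helpers expect (the numeric values below are illustrative assumptions, not defaults taken from the source; read_config_file does the same parsing from a file path and additionally validates the keys against AnchorParameters.default):

import configparser

example = (
    "[anchor_parameters]\n"
    "sizes   = 32 64 128 256 512\n"
    "strides = 8 16 32 64 128\n"
    "ratios  = 0.5 1 2\n"
    "scales  = 1 1.2 1.6\n"
    "\n"
    "[pyramid_levels]\n"
    "levels = 3 4 5 6 7\n"
)
config = configparser.ConfigParser()
config.read_string(example)

# With the module above importable:
# anchor_params = parse_anchor_parameters(config)   # -> AnchorParameters(sizes, strides, ratios, scales)
# levels        = parse_pyramid_levels(config)      # -> [3, 4, 5, 6, 7]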
retinaNet/keras_retinanet/utils/eval.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from .anchors import compute_overlap | ||
18 | +from .visualization import draw_detections, draw_annotations | ||
19 | + | ||
20 | +from tensorflow import keras | ||
21 | +import numpy as np | ||
22 | +import os | ||
23 | +import time | ||
24 | + | ||
25 | +import cv2 | ||
26 | +import progressbar | ||
27 | +assert(callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead." | ||
28 | + | ||
29 | + | ||
30 | +def _compute_ap(recall, precision): | ||
31 | + """ Compute the average precision, given the recall and precision curves. | ||
32 | + | ||
33 | + Code originally from https://github.com/rbgirshick/py-faster-rcnn. | ||
34 | + | ||
35 | + # Arguments | ||
36 | + recall: The recall curve (list). | ||
37 | + precision: The precision curve (list). | ||
38 | + # Returns | ||
39 | + The average precision as computed in py-faster-rcnn. | ||
40 | + """ | ||
41 | + # correct AP calculation | ||
42 | + # first append sentinel values at the end | ||
43 | + mrec = np.concatenate(([0.], recall, [1.])) | ||
44 | + mpre = np.concatenate(([0.], precision, [0.])) | ||
45 | + | ||
46 | + # compute the precision envelope | ||
47 | + for i in range(mpre.size - 1, 0, -1): | ||
48 | + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) | ||
49 | + | ||
50 | + # to calculate area under PR curve, look for points | ||
51 | + # where X axis (recall) changes value | ||
52 | + i = np.where(mrec[1:] != mrec[:-1])[0] | ||
53 | + | ||
54 | + # and sum (\Delta recall) * prec | ||
55 | + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) | ||
56 | + return ap | ||
57 | + | ||
58 | + | ||
59 | +def _get_detections(generator, model, score_threshold=0.05, max_detections=100, save_path=None): | ||
60 | + """ Get the detections from the model using the generator. | ||
61 | + | ||
62 | + The result is a list of lists such that the size is: | ||
63 | + all_detections[num_images][num_classes] = detections[num_detections, 5], with columns (x1, y1, x2, y2, score) | ||
64 | + | ||
65 | + # Arguments | ||
66 | + generator : The generator used to run images through the model. | ||
67 | + model : The model to run on the images. | ||
68 | + score_threshold : The score confidence threshold to use. | ||
69 | + max_detections : The maximum number of detections to use per image. | ||
70 | + save_path : The path to save the images with visualized detections to. | ||
71 | + # Returns | ||
72 | + A list of lists containing the detections for each image in the generator, and a list of per-image inference times. | ||
73 | + """ | ||
74 | + all_detections = [[None for i in range(generator.num_classes())] for j in range(generator.size())] | ||
75 | + all_inferences = [None for i in range(generator.size())] | ||
76 | + | ||
77 | + for i in progressbar.progressbar(range(generator.size()), prefix='Running network: '): | ||
78 | + raw_image = generator.load_image(i) | ||
79 | + image, scale = generator.resize_image(raw_image.copy()) | ||
80 | + image = generator.preprocess_image(image) | ||
81 | + | ||
82 | + if keras.backend.image_data_format() == 'channels_first': | ||
83 | + image = image.transpose((2, 0, 1)) | ||
84 | + | ||
85 | + # run network | ||
86 | + start = time.time() | ||
87 | + boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))[:3] | ||
88 | + inference_time = time.time() - start | ||
89 | + | ||
90 | + # correct boxes for image scale | ||
91 | + boxes /= scale | ||
92 | + | ||
93 | + # select indices which have a score above the threshold | ||
94 | + indices = np.where(scores[0, :] > score_threshold)[0] | ||
95 | + | ||
96 | + # select those scores | ||
97 | + scores = scores[0][indices] | ||
98 | + | ||
99 | + # find the order with which to sort the scores | ||
100 | + scores_sort = np.argsort(-scores)[:max_detections] | ||
101 | + | ||
102 | + # select detections | ||
103 | + image_boxes = boxes[0, indices[scores_sort], :] | ||
104 | + image_scores = scores[scores_sort] | ||
105 | + image_labels = labels[0, indices[scores_sort]] | ||
106 | + image_detections = np.concatenate([image_boxes, np.expand_dims(image_scores, axis=1), np.expand_dims(image_labels, axis=1)], axis=1) | ||
107 | + | ||
108 | + if save_path is not None: | ||
109 | + draw_annotations(raw_image, generator.load_annotations(i), label_to_name=generator.label_to_name) | ||
110 | + draw_detections(raw_image, image_boxes, image_scores, image_labels, label_to_name=generator.label_to_name, score_threshold=score_threshold) | ||
111 | + | ||
112 | + cv2.imwrite(os.path.join(save_path, '{}.png'.format(i)), raw_image) | ||
113 | + | ||
114 | + # copy detections to all_detections | ||
115 | + for label in range(generator.num_classes()): | ||
116 | + if not generator.has_label(label): | ||
117 | + continue | ||
118 | + | ||
119 | + all_detections[i][label] = image_detections[image_detections[:, -1] == label, :-1] | ||
120 | + | ||
121 | + all_inferences[i] = inference_time | ||
122 | + | ||
123 | + return all_detections, all_inferences | ||
124 | + | ||
125 | + | ||
126 | +def _get_annotations(generator): | ||
127 | + """ Get the ground truth annotations from the generator. | ||
128 | + | ||
129 | + The result is a list of lists such that the size is: | ||
130 | + all_annotations[num_images][num_classes] = annotations[num_annotations, 4], with columns (x1, y1, x2, y2) | ||
131 | + | ||
132 | + # Arguments | ||
133 | + generator : The generator used to retrieve ground truth annotations. | ||
134 | + # Returns | ||
135 | + A list of lists containing the annotations for each image in the generator. | ||
136 | + """ | ||
137 | + all_annotations = [[None for i in range(generator.num_classes())] for j in range(generator.size())] | ||
138 | + | ||
139 | + for i in progressbar.progressbar(range(generator.size()), prefix='Parsing annotations: '): | ||
140 | + # load the annotations | ||
141 | + annotations = generator.load_annotations(i) | ||
142 | + | ||
143 | + # copy detections to all_annotations | ||
144 | + for label in range(generator.num_classes()): | ||
145 | + if not generator.has_label(label): | ||
146 | + continue | ||
147 | + | ||
148 | + all_annotations[i][label] = annotations['bboxes'][annotations['labels'] == label, :].copy() | ||
149 | + | ||
150 | + return all_annotations | ||
151 | + | ||
152 | + | ||
153 | +def evaluate( | ||
154 | + generator, | ||
155 | + model, | ||
156 | + iou_threshold=0.5, | ||
157 | + score_threshold=0.05, | ||
158 | + max_detections=100, | ||
159 | + save_path=None | ||
160 | +): | ||
161 | + """ Evaluate a given dataset using a given model. | ||
162 | + | ||
163 | + # Arguments | ||
164 | + generator : The generator that represents the dataset to evaluate. | ||
165 | + model : The model to evaluate. | ||
166 | + iou_threshold : The threshold used to consider when a detection is positive or negative. | ||
167 | + score_threshold : The score confidence threshold to use for detections. | ||
168 | + max_detections : The maximum number of detections to use per image. | ||
169 | + save_path : The path to save images with visualized detections to. | ||
170 | + # Returns | ||
171 | + A dict mapping class labels to (average precision, number of annotations) tuples, and the average inference time per image. | ||
172 | + """ | ||
173 | + # gather all detections and annotations | ||
174 | + all_detections, all_inferences = _get_detections(generator, model, score_threshold=score_threshold, max_detections=max_detections, save_path=save_path) | ||
175 | + all_annotations = _get_annotations(generator) | ||
176 | + average_precisions = {} | ||
177 | + | ||
178 | + # all_detections = pickle.load(open('all_detections.pkl', 'rb')) | ||
179 | + # all_annotations = pickle.load(open('all_annotations.pkl', 'rb')) | ||
180 | + # pickle.dump(all_detections, open('all_detections.pkl', 'wb')) | ||
181 | + # pickle.dump(all_annotations, open('all_annotations.pkl', 'wb')) | ||
182 | + | ||
183 | + # process detections and annotations | ||
184 | + for label in range(generator.num_classes()): | ||
185 | + if not generator.has_label(label): | ||
186 | + continue | ||
187 | + | ||
188 | + false_positives = np.zeros((0,)) | ||
189 | + true_positives = np.zeros((0,)) | ||
190 | + scores = np.zeros((0,)) | ||
191 | + num_annotations = 0.0 | ||
192 | + | ||
193 | + for i in range(generator.size()): | ||
194 | + detections = all_detections[i][label] | ||
195 | + annotations = all_annotations[i][label] | ||
196 | + num_annotations += annotations.shape[0] | ||
197 | + detected_annotations = [] | ||
198 | + | ||
199 | + for d in detections: | ||
200 | + scores = np.append(scores, d[4]) | ||
201 | + | ||
202 | + if annotations.shape[0] == 0: | ||
203 | + false_positives = np.append(false_positives, 1) | ||
204 | + true_positives = np.append(true_positives, 0) | ||
205 | + continue | ||
206 | + | ||
207 | + overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations) | ||
208 | + assigned_annotation = np.argmax(overlaps, axis=1) | ||
209 | + max_overlap = overlaps[0, assigned_annotation] | ||
210 | + | ||
211 | + if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations: | ||
212 | + false_positives = np.append(false_positives, 0) | ||
213 | + true_positives = np.append(true_positives, 1) | ||
214 | + detected_annotations.append(assigned_annotation) | ||
215 | + else: | ||
216 | + false_positives = np.append(false_positives, 1) | ||
217 | + true_positives = np.append(true_positives, 0) | ||
218 | + | ||
219 | + # no annotations -> AP for this class is 0 (is this correct?) | ||
220 | + if num_annotations == 0: | ||
221 | + average_precisions[label] = 0, 0 | ||
222 | + continue | ||
223 | + | ||
224 | + # sort by score | ||
225 | + indices = np.argsort(-scores) | ||
226 | + false_positives = false_positives[indices] | ||
227 | + true_positives = true_positives[indices] | ||
228 | + | ||
229 | + # compute false positives and true positives | ||
230 | + false_positives = np.cumsum(false_positives) | ||
231 | + true_positives = np.cumsum(true_positives) | ||
232 | + | ||
233 | + # compute recall and precision | ||
234 | + recall = true_positives / num_annotations | ||
235 | + precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps) | ||
236 | + | ||
237 | + # compute average precision | ||
238 | + average_precision = _compute_ap(recall, precision) | ||
239 | + average_precisions[label] = average_precision, num_annotations | ||
240 | + | ||
241 | + # inference time | ||
242 | + inference_time = np.sum(all_inferences) / generator.size() | ||
243 | + | ||
244 | + return average_precisions, inference_time |
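A small worked example for _compute_ap (a sketch with made-up numbers): with two ground-truth boxes and three detections ranked by score, of which the second is a false positive, the cumulative recall/precision and the resulting AP look like this:

import numpy as np
from keras_retinanet.utils.eval import _compute_ap

recall    = np.array([0.5, 0.5, 1.0])        # cumulative TP / num_annotations
precision = np.array([1.0, 0.5, 2.0 / 3.0])  # cumulative TP / (TP + FP)

# The precision envelope becomes [1.0, 2/3, 2/3]; AP sums precision over the
# recall steps: 0.5 * 1.0 + 0.5 * (2/3) = ~0.8333.
print(_compute_ap(recall, precision))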
retinaNet/keras_retinanet/utils/gpu.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2019 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import tensorflow as tf | ||
18 | + | ||
19 | + | ||
20 | +def setup_gpu(gpu_id): | ||
21 | + try: | ||
22 | + visible_gpu_indices = [int(id) for id in gpu_id.split(',')] | ||
23 | + available_gpus = tf.config.list_physical_devices('GPU') | ||
24 | + visible_gpus = [gpu for idx, gpu in enumerate(available_gpus) if idx in visible_gpu_indices] | ||
25 | + | ||
26 | + if visible_gpus: | ||
27 | + try: | ||
28 | + # Currently, memory growth needs to be the same across GPUs. | ||
29 | + for gpu in available_gpus: | ||
30 | + tf.config.experimental.set_memory_growth(gpu, True) | ||
31 | + | ||
32 | + # Use only the selected GPUs. | ||
33 | + tf.config.set_visible_devices(visible_gpus, 'GPU') | ||
34 | + except RuntimeError as e: | ||
35 | + # Visible devices must be set before GPUs have been initialized. | ||
36 | + print(e) | ||
37 | + | ||
38 | + logical_gpus = tf.config.list_logical_devices('GPU') | ||
39 | + print(len(available_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") | ||
40 | + else: | ||
41 | + tf.config.set_visible_devices([], 'GPU') | ||
42 | + except ValueError: | ||
43 | + tf.config.set_visible_devices([], 'GPU') |
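A minimal usage sketch for setup_gpu; note that any gpu_id that does not parse as a comma-separated list of integers (for example 'cpu') falls through to the ValueError branch and hides all GPUs:

from keras_retinanet.utils.gpu import setup_gpu

setup_gpu('0')       # expose only GPU 0, with memory growth enabled
# setup_gpu('0,1')   # expose GPUs 0 and 1
# setup_gpu('cpu')   # not an integer -> ValueError branch -> run on CPU only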
retinaNet/keras_retinanet/utils/image.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from __future__ import division | ||
18 | +import numpy as np | ||
19 | +import cv2 | ||
20 | +from PIL import Image | ||
21 | + | ||
22 | +from .transform import change_transform_origin | ||
23 | + | ||
24 | + | ||
25 | +def read_image_bgr(path): | ||
26 | + """ Read an image in BGR format. | ||
27 | + | ||
28 | + Args | ||
29 | + path: Path to the image. | ||
30 | + """ | ||
31 | + # We deliberately don't use cv2.imread here, since it gives no feedback on errors while reading the image. | ||
32 | + image = np.ascontiguousarray(Image.open(path).convert('RGB')) | ||
33 | + return image[:, :, ::-1] | ||
34 | + | ||
35 | + | ||
36 | +def preprocess_image(x, mode='caffe'): | ||
37 | + """ Preprocess an image by subtracting the ImageNet mean. | ||
38 | + | ||
39 | + Args | ||
40 | + x: np.array of shape (None, None, 3) or (3, None, None). | ||
41 | + mode: One of "caffe" or "tf". | ||
42 | + - caffe: will zero-center each color channel with | ||
43 | + respect to the ImageNet dataset, without scaling. | ||
44 | + - tf: will scale pixels between -1 and 1, sample-wise. | ||
45 | + | ||
46 | + Returns | ||
47 | + The input with the ImageNet mean subtracted. | ||
48 | + """ | ||
49 | + # mostly identical to "https://github.com/keras-team/keras-applications/blob/master/keras_applications/imagenet_utils.py" | ||
50 | + # except for converting RGB -> BGR since we assume BGR already | ||
51 | + | ||
52 | + # always convert to float32 to keep compatibility with OpenCV | ||
53 | + x = x.astype(np.float32) | ||
54 | + | ||
55 | + if mode == 'tf': | ||
56 | + x /= 127.5 | ||
57 | + x -= 1. | ||
58 | + elif mode == 'caffe': | ||
59 | + x -= [103.939, 116.779, 123.68] | ||
60 | + | ||
61 | + return x | ||
62 | + | ||
63 | + | ||
64 | +def adjust_transform_for_image(transform, image, relative_translation): | ||
65 | + """ Adjust a transformation for a specific image. | ||
66 | + | ||
67 | + The translation of the matrix will be scaled with the size of the image. | ||
68 | + The linear part of the transformation will be adjusted so that the origin of the transformation is at the center of the image. | ||
69 | + """ | ||
70 | + height, width, channels = image.shape | ||
71 | + | ||
72 | + result = transform | ||
73 | + | ||
74 | + # Scale the translation with the image size if specified. | ||
75 | + if relative_translation: | ||
76 | + result[0:2, 2] *= [width, height] | ||
77 | + | ||
78 | + # Move the origin of transformation. | ||
79 | + result = change_transform_origin(transform, (0.5 * width, 0.5 * height)) | ||
80 | + | ||
81 | + return result | ||
82 | + | ||
83 | + | ||
84 | +class TransformParameters: | ||
85 | + """ Struct holding parameters determining how to apply a transformation to an image. | ||
86 | + | ||
87 | + Args | ||
88 | + fill_mode: One of: 'constant', 'nearest', 'reflect', 'wrap' | ||
89 | + interpolation: One of: 'nearest', 'linear', 'cubic', 'area', 'lanczos4' | ||
90 | + cval: Fill value to use with fill_mode='constant' | ||
91 | + relative_translation: If true (the default), interpret translation as a factor of the image size. | ||
92 | + If false, interpret it as absolute pixels. | ||
93 | + """ | ||
94 | + def __init__( | ||
95 | + self, | ||
96 | + fill_mode = 'nearest', | ||
97 | + interpolation = 'linear', | ||
98 | + cval = 0, | ||
99 | + relative_translation = True, | ||
100 | + ): | ||
101 | + self.fill_mode = fill_mode | ||
102 | + self.cval = cval | ||
103 | + self.interpolation = interpolation | ||
104 | + self.relative_translation = relative_translation | ||
105 | + | ||
106 | + def cvBorderMode(self): | ||
107 | + if self.fill_mode == 'constant': | ||
108 | + return cv2.BORDER_CONSTANT | ||
109 | + if self.fill_mode == 'nearest': | ||
110 | + return cv2.BORDER_REPLICATE | ||
111 | + if self.fill_mode == 'reflect': | ||
112 | + return cv2.BORDER_REFLECT_101 | ||
113 | + if self.fill_mode == 'wrap': | ||
114 | + return cv2.BORDER_WRAP | ||
115 | + | ||
116 | + def cvInterpolation(self): | ||
117 | + if self.interpolation == 'nearest': | ||
118 | + return cv2.INTER_NEAREST | ||
119 | + if self.interpolation == 'linear': | ||
120 | + return cv2.INTER_LINEAR | ||
121 | + if self.interpolation == 'cubic': | ||
122 | + return cv2.INTER_CUBIC | ||
123 | + if self.interpolation == 'area': | ||
124 | + return cv2.INTER_AREA | ||
125 | + if self.interpolation == 'lanczos4': | ||
126 | + return cv2.INTER_LANCZOS4 | ||
127 | + | ||
128 | + | ||
129 | +def apply_transform(matrix, image, params): | ||
130 | + """ | ||
131 | + Apply a transformation to an image. | ||
132 | + | ||
133 | + The origin of transformation is at the top left corner of the image. | ||
134 | + | ||
135 | + The matrix is interpreted such that a point (x, y) on the original image is moved to transform * (x, y) in the generated image. | ||
136 | + Mathematically speaking, that means the matrix maps points from the original image space to the transformed image space. | ||
137 | + | ||
138 | + Args | ||
139 | + matrix: A homogeneous 3 by 3 matrix representing the transformation to apply. | ||
140 | + image: The image to transform. | ||
141 | + params: The transform parameters (see TransformParameters) | ||
142 | + """ | ||
143 | + output = cv2.warpAffine( | ||
144 | + image, | ||
145 | + matrix[:2, :], | ||
146 | + dsize = (image.shape[1], image.shape[0]), | ||
147 | + flags = params.cvInterpolation(), | ||
148 | + borderMode = params.cvBorderMode(), | ||
149 | + borderValue = params.cval, | ||
150 | + ) | ||
151 | + return output | ||
152 | + | ||
153 | + | ||
154 | +def compute_resize_scale(image_shape, min_side=800, max_side=1333): | ||
155 | + """ Compute an image scale such that the image size is constrained to min_side and max_side. | ||
156 | + | ||
157 | + Args | ||
158 | + min_side: The image's min side will be equal to min_side after resizing. | ||
159 | + max_side: If after resizing the image's max side is above max_side, resize until the max side is equal to max_side. | ||
160 | + | ||
161 | + Returns | ||
162 | + A resizing scale. | ||
163 | + """ | ||
164 | + (rows, cols, _) = image_shape | ||
165 | + | ||
166 | + smallest_side = min(rows, cols) | ||
167 | + | ||
168 | + # rescale the image so the smallest side is min_side | ||
169 | + scale = min_side / smallest_side | ||
170 | + | ||
171 | + # check if the largest side is now greater than max_side, which can happen | ||
172 | + # when images have a large aspect ratio | ||
173 | + largest_side = max(rows, cols) | ||
174 | + if largest_side * scale > max_side: | ||
175 | + scale = max_side / largest_side | ||
176 | + | ||
177 | + return scale | ||
178 | + | ||
179 | + | ||
180 | +def resize_image(img, min_side=800, max_side=1333): | ||
181 | + """ Resize an image such that the size is constrained to min_side and max_side. | ||
182 | + | ||
183 | + Args | ||
184 | + min_side: The image's min side will be equal to min_side after resizing. | ||
185 | + max_side: If after resizing the image's max side is above max_side, resize until the max side is equal to max_side. | ||
186 | + | ||
187 | + Returns | ||
188 | + A resized image. | ||
189 | + """ | ||
190 | + # compute scale to resize the image | ||
191 | + scale = compute_resize_scale(img.shape, min_side=min_side, max_side=max_side) | ||
192 | + | ||
193 | + # resize the image with the computed scale | ||
194 | + img = cv2.resize(img, None, fx=scale, fy=scale) | ||
195 | + | ||
196 | + return img, scale | ||
197 | + | ||
198 | + | ||
199 | +def _uniform(val_range): | ||
200 | + """ Uniformly sample from the given range. | ||
201 | + | ||
202 | + Args | ||
203 | + val_range: A pair of lower and upper bound. | ||
204 | + """ | ||
205 | + return np.random.uniform(val_range[0], val_range[1]) | ||
206 | + | ||
207 | + | ||
208 | +def _check_range(val_range, min_val=None, max_val=None): | ||
209 | + """ Check whether the range is a valid range. | ||
210 | + | ||
211 | + Args | ||
212 | + val_range: A pair of lower and upper bound. | ||
213 | + min_val: Minimal value for the lower bound. | ||
214 | + max_val: Maximal value for the upper bound. | ||
215 | + """ | ||
216 | + if val_range[0] > val_range[1]: | ||
217 | + raise ValueError('interval lower bound > upper bound') | ||
218 | + if min_val is not None and val_range[0] < min_val: | ||
219 | + raise ValueError('invalid interval lower bound') | ||
220 | + if max_val is not None and val_range[1] > max_val: | ||
221 | + raise ValueError('invalid interval upper bound') | ||
222 | + | ||
223 | + | ||
224 | +def _clip(image): | ||
225 | + """ | ||
226 | + Clip and convert an image to np.uint8. | ||
227 | + | ||
228 | + Args | ||
229 | + image: Image to clip. | ||
230 | + """ | ||
231 | + return np.clip(image, 0, 255).astype(np.uint8) | ||
232 | + | ||
233 | + | ||
234 | +class VisualEffect: | ||
235 | + """ Struct holding parameters and applying image color transformation. | ||
236 | + | ||
237 | + Args | ||
238 | + contrast_factor: A factor for adjusting contrast. Should be between 0 and 3. | ||
239 | + brightness_delta: Brightness offset between -1 and 1 added to the pixel values. | ||
240 | + hue_delta: Hue offset between -1 and 1 added to the hue channel. | ||
241 | + saturation_factor: A factor multiplying the saturation values of each pixel. | ||
242 | + """ | ||
243 | + | ||
244 | + def __init__( | ||
245 | + self, | ||
246 | + contrast_factor, | ||
247 | + brightness_delta, | ||
248 | + hue_delta, | ||
249 | + saturation_factor, | ||
250 | + ): | ||
251 | + self.contrast_factor = contrast_factor | ||
252 | + self.brightness_delta = brightness_delta | ||
253 | + self.hue_delta = hue_delta | ||
254 | + self.saturation_factor = saturation_factor | ||
255 | + | ||
256 | + def __call__(self, image): | ||
257 | + """ Apply a visual effect on the image. | ||
258 | + | ||
259 | + Args | ||
260 | + image: Image to adjust | ||
261 | + """ | ||
262 | + | ||
263 | + if self.contrast_factor: | ||
264 | + image = adjust_contrast(image, self.contrast_factor) | ||
265 | + if self.brightness_delta: | ||
266 | + image = adjust_brightness(image, self.brightness_delta) | ||
267 | + | ||
268 | + if self.hue_delta or self.saturation_factor: | ||
269 | + | ||
270 | + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) | ||
271 | + | ||
272 | + if self.hue_delta: | ||
273 | + image = adjust_hue(image, self.hue_delta) | ||
274 | + if self.saturation_factor: | ||
275 | + image = adjust_saturation(image, self.saturation_factor) | ||
276 | + | ||
277 | + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) | ||
278 | + | ||
279 | + return image | ||
280 | + | ||
281 | + | ||
282 | +def random_visual_effect_generator( | ||
283 | + contrast_range=(0.9, 1.1), | ||
284 | + brightness_range=(-.1, .1), | ||
285 | + hue_range=(-0.05, 0.05), | ||
286 | + saturation_range=(0.95, 1.05) | ||
287 | +): | ||
288 | + """ Generate visual effect parameters uniformly sampled from the given intervals. | ||
289 | + | ||
290 | + Args | ||
291 | + contrast_range: A factor interval for adjusting contrast. Should be between 0 and 3. | ||
292 | + brightness_range: An interval between -1 and 1 for the amount added to the pixels. | ||
293 | + hue_range: An interval between -1 and 1 for the amount added to the hue channel. | ||
294 | + The values are rotated if they exceed 180. | ||
295 | + saturation_range: An interval for the factor multiplying the saturation values of each | ||
296 | + pixel. | ||
297 | + """ | ||
298 | + _check_range(contrast_range, 0) | ||
299 | + _check_range(brightness_range, -1, 1) | ||
300 | + _check_range(hue_range, -1, 1) | ||
301 | + _check_range(saturation_range, 0) | ||
302 | + | ||
303 | + def _generate(): | ||
304 | + while True: | ||
305 | + yield VisualEffect( | ||
306 | + contrast_factor=_uniform(contrast_range), | ||
307 | + brightness_delta=_uniform(brightness_range), | ||
308 | + hue_delta=_uniform(hue_range), | ||
309 | + saturation_factor=_uniform(saturation_range), | ||
310 | + ) | ||
311 | + | ||
312 | + return _generate() | ||
313 | + | ||
314 | + | ||
315 | +def adjust_contrast(image, factor): | ||
316 | + """ Adjust contrast of an image. | ||
317 | + | ||
318 | + Args | ||
319 | + image: Image to adjust. | ||
320 | + factor: A factor for adjusting contrast. | ||
321 | + """ | ||
322 | + mean = image.mean(axis=0).mean(axis=0) | ||
323 | + return _clip((image - mean) * factor + mean) | ||
324 | + | ||
325 | + | ||
326 | +def adjust_brightness(image, delta): | ||
327 | + """ Adjust brightness of an image | ||
328 | + | ||
329 | + Args | ||
330 | + image: Image to adjust. | ||
331 | + delta: Brightness offset between -1 and 1 added to the pixel values. | ||
332 | + """ | ||
333 | + return _clip(image + delta * 255) | ||
334 | + | ||
335 | + | ||
336 | +def adjust_hue(image, delta): | ||
337 | + """ Adjust hue of an image. | ||
338 | + | ||
339 | + Args | ||
340 | + image: Image to adjust. | ||
341 | + delta: An interval between -1 and 1 for the amount added to the hue channel. | ||
342 | + The values are rotated if they exceed 180. | ||
343 | + """ | ||
344 | + image[..., 0] = np.mod(image[..., 0] + delta * 180, 180) | ||
345 | + return image | ||
346 | + | ||
347 | + | ||
348 | +def adjust_saturation(image, factor): | ||
349 | + """ Adjust saturation of an image. | ||
350 | + | ||
351 | + Args | ||
352 | + image: Image to adjust. | ||
353 | + factor: An interval for the factor multiplying the saturation values of each pixel. | ||
354 | + """ | ||
355 | + image[..., 1] = np.clip(image[..., 1] * factor, 0, 255) | ||
356 | + return image |
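A sketch of how these pieces are typically combined before inference ('image.jpg' is a placeholder path; the augmentation step is optional and samples from the default ranges above):

from keras_retinanet.utils.image import (
    read_image_bgr, preprocess_image, resize_image, random_visual_effect_generator
)

image = read_image_bgr('image.jpg')          # uint8, BGR, shape (H, W, 3)

effects = random_visual_effect_generator()   # optional color augmentation
image = next(effects)(image)

image = preprocess_image(image)              # float32, ImageNet mean subtracted ('caffe' mode)
image, scale = resize_image(image)           # smallest side -> 800, largest side capped at 1333
# predicted boxes should later be divided by `scale` to map back to the original image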
retinaNet/keras_retinanet/utils/model.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | + | ||
18 | +def freeze(model): | ||
19 | + """ Set all layers in a model to non-trainable. | ||
20 | + | ||
21 | + The weights for these layers will not be updated during training. | ||
22 | + | ||
23 | + This function modifies the given model in-place, | ||
24 | + but it also returns the modified model to allow easy chaining with other functions. | ||
25 | + """ | ||
26 | + for layer in model.layers: | ||
27 | + layer.trainable = False | ||
28 | + return model |
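A short chaining sketch for freeze (the toy model is only there to make the snippet self-contained; any keras.Model works):

from tensorflow import keras
from keras_retinanet.utils.model import freeze

model = keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])
frozen = freeze(model)
assert frozen is model and not model.layers[0].trainable  # modified in place, returned for chaining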
1 | +""" | ||
2 | +Copyright 2017-2019 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from __future__ import print_function | ||
18 | + | ||
19 | +import tensorflow as tf | ||
20 | +import sys | ||
21 | + | ||
22 | +MINIMUM_TF_VERSION = 2, 3, 0 | ||
23 | +BLACKLISTED_TF_VERSIONS = [] | ||
24 | + | ||
25 | + | ||
26 | +def tf_version(): | ||
27 | + """ Get the Tensorflow version. | ||
28 | + Returns | ||
29 | + tuple of (major, minor, patch). | ||
30 | + """ | ||
31 | + return tuple(map(int, tf.version.VERSION.split('-')[0].split('.'))) | ||
32 | + | ||
33 | + | ||
34 | +def tf_version_ok(minimum_tf_version=MINIMUM_TF_VERSION, blacklisted=BLACKLISTED_TF_VERSIONS): | ||
35 | + """ Check if the current Tensorflow version is higher than the minimum version. | ||
36 | + """ | ||
37 | + return tf_version() >= minimum_tf_version and tf_version() not in blacklisted | ||
38 | + | ||
39 | + | ||
40 | +def assert_tf_version(minimum_tf_version=MINIMUM_TF_VERSION, blacklisted=BLACKLISTED_TF_VERSIONS): | ||
41 | + """ Assert that the Tensorflow version is up to date. | ||
42 | + """ | ||
43 | + detected = tf.version.VERSION | ||
44 | + required = '.'.join(map(str, minimum_tf_version)) | ||
45 | + assert(tf_version_ok(minimum_tf_version, blacklisted)), 'You are using tensorflow version {}. The minimum required version is {} (blacklisted: {}).'.format(detected, required, blacklisted) | ||
46 | + | ||
47 | + | ||
48 | +def check_tf_version(): | ||
49 | + """ Check that the Tensorflow version is up to date. If it isn't, print an error message and exit the script. | ||
50 | + """ | ||
51 | + try: | ||
52 | + assert_tf_version() | ||
53 | + except AssertionError as e: | ||
54 | + print(e, file=sys.stderr) | ||
55 | + sys.exit(1) |
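A usage sketch for the version guard above (the import path keras_retinanet.utils.tf_version is assumed here, since the file header is not shown):

from keras_retinanet.utils import tf_version  # assumed module path

print(tf_version.tf_version())   # e.g. (2, 4, 1)
tf_version.check_tf_version()    # prints to stderr and exits if TF is older than 2.3.0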
retinaNet/keras_retinanet/utils/transform.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import numpy as np | ||
18 | + | ||
19 | +DEFAULT_PRNG = np.random | ||
20 | + | ||
21 | + | ||
22 | +def colvec(*args): | ||
23 | + """ Create a numpy array representing a column vector. """ | ||
24 | + return np.array([args]).T | ||
25 | + | ||
26 | + | ||
27 | +def transform_aabb(transform, aabb): | ||
28 | + """ Apply a transformation to an axis aligned bounding box. | ||
29 | + | ||
30 | + The result is a new AABB in the same coordinate system as the original AABB. | ||
31 | + The new AABB contains all corner points of the original AABB after applying the given transformation. | ||
32 | + | ||
33 | + Args | ||
34 | + transform: The transformation to apply. | ||
35 | + x1: The minimum x value of the AABB. | ||
36 | + y1: The minimum y value of the AABB. | ||
37 | + x2: The maximum x value of the AABB. | ||
38 | + y2: The maximum y value of the AABB. | ||
39 | + Returns | ||
40 | + The new AABB as tuple (x1, y1, x2, y2) | ||
41 | + """ | ||
42 | + x1, y1, x2, y2 = aabb | ||
43 | + # Transform all 4 corners of the AABB. | ||
44 | + points = transform.dot([ | ||
45 | + [x1, x2, x1, x2], | ||
46 | + [y1, y2, y2, y1], | ||
47 | + [1, 1, 1, 1 ], | ||
48 | + ]) | ||
49 | + | ||
50 | + # Extract the min and max corners again. | ||
51 | + min_corner = points.min(axis=1) | ||
52 | + max_corner = points.max(axis=1) | ||
53 | + | ||
54 | + return [min_corner[0], min_corner[1], max_corner[0], max_corner[1]] | ||
55 | + | ||
56 | + | ||
57 | +def _random_vector(min, max, prng=DEFAULT_PRNG): | ||
58 | + """ Construct a random vector between min and max. | ||
59 | + Args | ||
60 | + min: the minimum value for each component | ||
61 | + max: the maximum value for each component | ||
62 | + """ | ||
63 | + min = np.array(min) | ||
64 | + max = np.array(max) | ||
65 | + assert min.shape == max.shape | ||
66 | + assert len(min.shape) == 1 | ||
67 | + return prng.uniform(min, max) | ||
68 | + | ||
69 | + | ||
70 | +def rotation(angle): | ||
71 | + """ Construct a homogeneous 2D rotation matrix. | ||
72 | + Args | ||
73 | + angle: the angle in radians | ||
74 | + Returns | ||
75 | + the rotation matrix as 3 by 3 numpy array | ||
76 | + """ | ||
77 | + return np.array([ | ||
78 | + [np.cos(angle), -np.sin(angle), 0], | ||
79 | + [np.sin(angle), np.cos(angle), 0], | ||
80 | + [0, 0, 1] | ||
81 | + ]) | ||
82 | + | ||
83 | + | ||
84 | +def random_rotation(min, max, prng=DEFAULT_PRNG): | ||
85 | + """ Construct a random rotation between -max and max. | ||
86 | + Args | ||
87 | + min: a scalar for the minimum absolute angle in radians | ||
88 | + max: a scalar for the maximum absolute angle in radians | ||
89 | + prng: the pseudo-random number generator to use. | ||
90 | + Returns | ||
91 | + a homogeneous 3 by 3 rotation matrix | ||
92 | + """ | ||
93 | + return rotation(prng.uniform(min, max)) | ||
94 | + | ||
95 | + | ||
96 | +def translation(translation): | ||
97 | + """ Construct a homogeneous 2D translation matrix. | ||
98 | + # Arguments | ||
99 | + translation: the translation 2D vector | ||
100 | + # Returns | ||
101 | + the translation matrix as 3 by 3 numpy array | ||
102 | + """ | ||
103 | + return np.array([ | ||
104 | + [1, 0, translation[0]], | ||
105 | + [0, 1, translation[1]], | ||
106 | + [0, 0, 1] | ||
107 | + ]) | ||
108 | + | ||
109 | + | ||
110 | +def random_translation(min, max, prng=DEFAULT_PRNG): | ||
111 | + """ Construct a random 2D translation between min and max. | ||
112 | + Args | ||
113 | + min: a 2D vector with the minimum translation for each dimension | ||
114 | + max: a 2D vector with the maximum translation for each dimension | ||
115 | + prng: the pseudo-random number generator to use. | ||
116 | + Returns | ||
117 | + a homogeneous 3 by 3 translation matrix | ||
118 | + """ | ||
119 | + return translation(_random_vector(min, max, prng)) | ||
120 | + | ||
121 | + | ||
122 | +def shear(angle): | ||
123 | + """ Construct a homogeneous 2D shear matrix. | ||
124 | + Args | ||
125 | + angle: the shear angle in radians | ||
126 | + Returns | ||
127 | + the shear matrix as 3 by 3 numpy array | ||
128 | + """ | ||
129 | + return np.array([ | ||
130 | + [1, -np.sin(angle), 0], | ||
131 | + [0, np.cos(angle), 0], | ||
132 | + [0, 0, 1] | ||
133 | + ]) | ||
134 | + | ||
135 | + | ||
136 | +def random_shear(min, max, prng=DEFAULT_PRNG): | ||
137 | + """ Construct a random 2D shear matrix with shear angle between -max and max. | ||
138 | + Args | ||
139 | + min: the minimum shear angle in radians. | ||
140 | + max: the maximum shear angle in radians. | ||
141 | + prng: the pseudo-random number generator to use. | ||
142 | + Returns | ||
143 | + a homogeneous 3 by 3 shear matrix | ||
144 | + """ | ||
145 | + return shear(prng.uniform(min, max)) | ||
146 | + | ||
147 | + | ||
148 | +def scaling(factor): | ||
149 | + """ Construct a homogeneous 2D scaling matrix. | ||
150 | + Args | ||
151 | + factor: a 2D vector for X and Y scaling | ||
152 | + Returns | ||
153 | + the zoom matrix as 3 by 3 numpy array | ||
154 | + """ | ||
155 | + return np.array([ | ||
156 | + [factor[0], 0, 0], | ||
157 | + [0, factor[1], 0], | ||
158 | + [0, 0, 1] | ||
159 | + ]) | ||
160 | + | ||
161 | + | ||
162 | +def random_scaling(min, max, prng=DEFAULT_PRNG): | ||
163 | + """ Construct a random 2D scale matrix between -max and max. | ||
164 | + Args | ||
165 | + min: a 2D vector containing the minimum scaling factor for X and Y. | ||
166 | + max: a 2D vector containing the maximum scaling factor for X and Y. | ||
167 | + prng: the pseudo-random number generator to use. | ||
168 | + Returns | ||
169 | + a homogeneous 3 by 3 scaling matrix | ||
170 | + """ | ||
171 | + return scaling(_random_vector(min, max, prng)) | ||
172 | + | ||
173 | + | ||
174 | +def random_flip(flip_x_chance, flip_y_chance, prng=DEFAULT_PRNG): | ||
175 | + """ Construct a transformation randomly containing X/Y flips (or not). | ||
176 | + Args | ||
177 | + flip_x_chance: The chance that the result will contain a flip along the X axis. | ||
178 | + flip_y_chance: The chance that the result will contain a flip along the Y axis. | ||
179 | + prng: The pseudo-random number generator to use. | ||
180 | + Returns | ||
181 | + a homogeneous 3 by 3 transformation matrix | ||
182 | + """ | ||
183 | + flip_x = prng.uniform(0, 1) < flip_x_chance | ||
184 | + flip_y = prng.uniform(0, 1) < flip_y_chance | ||
185 | + # 1 - 2 * bool gives 1 for False and -1 for True. | ||
186 | + return scaling((1 - 2 * flip_x, 1 - 2 * flip_y)) | ||
187 | + | ||
188 | + | ||
189 | +def change_transform_origin(transform, center): | ||
190 | + """ Create a new transform representing the same transformation, | ||
191 | + only with the origin of the linear part changed. | ||
192 | + Args | ||
193 | + transform: the transformation matrix | ||
194 | + center: the new origin of the transformation | ||
195 | + Returns | ||
196 | + translate(center) * transform * translate(-center) | ||
197 | + """ | ||
198 | + center = np.array(center) | ||
199 | + return np.linalg.multi_dot([translation(center), transform, translation(-center)]) | ||
200 | + | ||
201 | + | ||
202 | +def random_transform( | ||
203 | + min_rotation=0, | ||
204 | + max_rotation=0, | ||
205 | + min_translation=(0, 0), | ||
206 | + max_translation=(0, 0), | ||
207 | + min_shear=0, | ||
208 | + max_shear=0, | ||
209 | + min_scaling=(1, 1), | ||
210 | + max_scaling=(1, 1), | ||
211 | + flip_x_chance=0, | ||
212 | + flip_y_chance=0, | ||
213 | + prng=DEFAULT_PRNG | ||
214 | +): | ||
215 | + """ Create a random transformation. | ||
216 | + | ||
217 | + The transformation consists of the following operations in this order (from left to right): | ||
218 | + * rotation | ||
219 | + * translation | ||
220 | + * shear | ||
221 | + * scaling | ||
222 | + * flip x (if applied) | ||
223 | + * flip y (if applied) | ||
224 | + | ||
225 | + Note that by default, the data generators in `keras_retinanet.preprocessing.generators` interpret the translation | ||
226 | + as a factor of the image size. So an X translation of 0.1 would translate the image by 10% of its width. | ||
227 | + Set `relative_translation` to `False` in the `TransformParameters` of a data generator to have it interpret | ||
228 | + the translation directly as pixel distances instead. | ||
229 | + | ||
230 | + Args | ||
231 | + min_rotation: The minimum rotation in radians for the transform as scalar. | ||
232 | + max_rotation: The maximum rotation in radians for the transform as scalar. | ||
233 | + min_translation: The minimum translation for the transform as 2D column vector. | ||
234 | + max_translation: The maximum translation for the transform as 2D column vector. | ||
235 | + min_shear: The minimum shear angle for the transform in radians. | ||
236 | + max_shear: The maximum shear angle for the transform in radians. | ||
237 | + min_scaling: The minimum scaling for the transform as 2D column vector. | ||
238 | + max_scaling: The maximum scaling for the transform as 2D column vector. | ||
239 | + flip_x_chance: The chance (0 to 1) that a transform will contain a flip along X direction. | ||
240 | + flip_y_chance: The chance (0 to 1) that a transform will contain a flip along Y direction. | ||
241 | + prng: The pseudo-random number generator to use. | ||
242 | + """ | ||
243 | + return np.linalg.multi_dot([ | ||
244 | + random_rotation(min_rotation, max_rotation, prng), | ||
245 | + random_translation(min_translation, max_translation, prng), | ||
246 | + random_shear(min_shear, max_shear, prng), | ||
247 | + random_scaling(min_scaling, max_scaling, prng), | ||
248 | + random_flip(flip_x_chance, flip_y_chance, prng) | ||
249 | + ]) | ||
250 | + | ||
251 | + | ||
252 | +def random_transform_generator(prng=None, **kwargs): | ||
253 | + """ Create a random transform generator. | ||
254 | + | ||
255 | + Uses a dedicated, newly created, properly seeded PRNG by default instead of the global DEFAULT_PRNG. | ||
256 | + | ||
257 | + The transformation consists of the following operations in this order (from left to right): | ||
258 | + * rotation | ||
259 | + * translation | ||
260 | + * shear | ||
261 | + * scaling | ||
262 | + * flip x (if applied) | ||
263 | + * flip y (if applied) | ||
264 | + | ||
265 | + Note that by default, the data generators in `keras_retinanet.preprocessing.generators` interpret the translation | ||
266 | + as a factor of the image size. So an X translation of 0.1 would translate the image by 10% of its width. | ||
267 | + Set `relative_translation` to `False` in the `TransformParameters` of a data generator to have it interpret | ||
268 | + the translation directly as pixel distances instead. | ||
269 | + | ||
270 | + Args | ||
271 | + min_rotation: The minimum rotation in radians for the transform as scalar. | ||
272 | + max_rotation: The maximum rotation in radians for the transform as scalar. | ||
273 | + min_translation: The minimum translation for the transform as 2D column vector. | ||
274 | + max_translation: The maximum translation for the transform as 2D column vector. | ||
275 | + min_shear: The minimum shear angle for the transform in radians. | ||
276 | + max_shear: The maximum shear angle for the transform in radians. | ||
277 | + min_scaling: The minimum scaling for the transform as 2D column vector. | ||
278 | + max_scaling: The maximum scaling for the transform as 2D column vector. | ||
279 | + flip_x_chance: The chance (0 to 1) that a transform will contain a flip along X direction. | ||
280 | + flip_y_chance: The chance (0 to 1) that a transform will contain a flip along Y direction. | ||
281 | + prng: The pseudo-random number generator to use. | ||
282 | + """ | ||
283 | + | ||
284 | + if prng is None: | ||
285 | + # RandomState automatically seeds using the best available method. | ||
286 | + prng = np.random.RandomState() | ||
287 | + | ||
288 | + while True: | ||
289 | + yield random_transform(prng=prng, **kwargs) |
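A sketch of wiring these transforms into the image utilities above (the parameter values are illustrative, not defaults):

import numpy as np
from keras_retinanet.utils.transform import random_transform_generator, transform_aabb
from keras_retinanet.utils.image import adjust_transform_for_image, apply_transform, TransformParameters

transform_generator = random_transform_generator(
    min_rotation=-0.1, max_rotation=0.1,
    min_scaling=(0.9, 0.9), max_scaling=(1.1, 1.1),
    flip_x_chance=0.5,
)

image  = np.zeros((480, 640, 3), dtype=np.uint8)   # dummy image
params = TransformParameters()

transform = adjust_transform_for_image(next(transform_generator), image, params.relative_translation)
augmented = apply_transform(transform, image, params)
new_box   = transform_aabb(transform, [100, 100, 200, 200])  # axis-aligned box after the transform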
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import cv2 | ||
18 | +import numpy as np | ||
19 | + | ||
20 | +from .colors import label_color | ||
21 | + | ||
22 | + | ||
23 | +def draw_box(image, box, color, thickness=2): | ||
24 | + """ Draws a box on an image with a given color. | ||
25 | + | ||
26 | + # Arguments | ||
27 | + image : The image to draw on. | ||
28 | + box : A list of 4 elements (x1, y1, x2, y2). | ||
29 | + color : The color of the box. | ||
30 | + thickness : The thickness of the lines to draw a box with. | ||
31 | + """ | ||
32 | + b = np.array(box).astype(int) | ||
33 | + cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), color, thickness, cv2.LINE_AA) | ||
34 | + | ||
35 | + | ||
36 | +def draw_caption(image, box, caption): | ||
37 | + """ Draws a caption above the box in an image. | ||
38 | + | ||
39 | + # Arguments | ||
40 | + image : The image to draw on. | ||
41 | + box : A list of 4 elements (x1, y1, x2, y2). | ||
42 | + caption : String containing the text to draw. | ||
43 | + """ | ||
44 | + b = np.array(box).astype(int) | ||
45 | + cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), 2) | ||
46 | + cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1) | ||
47 | + | ||
48 | + | ||
49 | +def draw_boxes(image, boxes, color, thickness=2): | ||
50 | + """ Draws boxes on an image with a given color. | ||
51 | + | ||
52 | + # Arguments | ||
53 | + image : The image to draw on. | ||
54 | + boxes : A [N, 4] matrix (x1, y1, x2, y2). | ||
55 | + color : The color of the boxes. | ||
56 | + thickness : The thickness of the lines to draw boxes with. | ||
57 | + """ | ||
58 | + for b in boxes: | ||
59 | + draw_box(image, b, color, thickness=thickness) | ||
60 | + | ||
61 | + | ||
62 | +def draw_detections(image, boxes, scores, labels, color=None, label_to_name=None, score_threshold=0.5): | ||
63 | + """ Draws detections in an image. | ||
64 | + | ||
65 | + # Arguments | ||
66 | + image : The image to draw on. | ||
67 | + boxes : A [N, 4] matrix (x1, y1, x2, y2). | ||
68 | + scores : A list of N classification scores. | ||
69 | + labels : A list of N labels. | ||
70 | + color : The color of the boxes. By default the color from keras_retinanet.utils.colors.label_color will be used. | ||
71 | + label_to_name : (optional) Functor for mapping a label to a name. | ||
72 | + score_threshold : Threshold used for determining what detections to draw. | ||
73 | + """ | ||
74 | + selection = np.where(scores > score_threshold)[0] | ||
75 | + | ||
76 | + for i in selection: | ||
77 | + c = color if color is not None else label_color(labels[i]) | ||
78 | + draw_box(image, boxes[i, :], color=c) | ||
79 | + | ||
80 | + # draw labels | ||
81 | + caption = (label_to_name(labels[i]) if label_to_name else labels[i]) + ': {0:.2f}'.format(scores[i]) | ||
82 | + draw_caption(image, boxes[i, :], caption) | ||
83 | + | ||
84 | + | ||
85 | +def draw_annotations(image, annotations, color=(0, 255, 0), label_to_name=None): | ||
86 | + """ Draws annotations in an image. | ||
87 | + | ||
88 | + # Arguments | ||
89 | + image : The image to draw on. | ||
90 | + annotations : A [N, 5] matrix (x1, y1, x2, y2, label) or dictionary containing bboxes (shaped [N, 4]) and labels (shaped [N]). | ||
91 | + color : The color of the boxes (green by default). Pass None to use the per-label color from keras_retinanet.utils.colors.label_color. | ||
92 | + label_to_name : (optional) Functor for mapping a label to a name. | ||
93 | + """ | ||
94 | + if isinstance(annotations, np.ndarray): | ||
95 | + annotations = {'bboxes': annotations[:, :4], 'labels': annotations[:, 4]} | ||
96 | + | ||
97 | + assert('bboxes' in annotations) | ||
98 | + assert('labels' in annotations) | ||
99 | + assert(annotations['bboxes'].shape[0] == annotations['labels'].shape[0]) | ||
100 | + | ||
101 | + for i in range(annotations['bboxes'].shape[0]): | ||
102 | + label = annotations['labels'][i] | ||
103 | + c = color if color is not None else label_color(label) | ||
104 | + caption = '{}'.format(label_to_name(label) if label_to_name else label) | ||
105 | + draw_caption(image, annotations['bboxes'][i], caption) | ||
106 | + draw_box(image, annotations['bboxes'][i], color=c) |
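A small usage sketch with dummy data (only detections above the default 0.5 score threshold are drawn):

import numpy as np
import cv2
from keras_retinanet.utils.visualization import draw_detections

image  = np.zeros((480, 640, 3), dtype=np.uint8)
boxes  = np.array([[ 50,  60, 200, 220], [300, 100, 420, 260]])
scores = np.array([0.9, 0.3])     # the second detection is below the threshold and skipped
labels = np.array([0, 1])

draw_detections(image, boxes, scores, labels, label_to_name=lambda l: 'class_{}'.format(l))
cv2.imwrite('detections.png', image)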
retinaNet/requirements.txt
0 → 100644
retinaNet/setup.cfg
0 → 100644
1 | +# ignore: | ||
2 | +# E201 whitespace after '[' | ||
3 | +# E202 whitespace before ']' | ||
4 | +# E203 whitespace before ':' | ||
5 | +# E221 multiple spaces before operator | ||
6 | +# E241 multiple spaces after ',' | ||
7 | +# E251 unexpected spaces around keyword / parameter equals | ||
8 | +# E501 line too long (85 > 79 characters) | ||
9 | +# W504 line break after binary operator | ||
10 | +[tool:pytest] | ||
11 | +flake8-max-line-length = 100 | ||
12 | +flake8-ignore = E201 E202 E203 E221 E241 E251 E402 E501 W504 |
retinaNet/setup.py
0 → 100644
1 | +import setuptools | ||
2 | +from setuptools.extension import Extension | ||
3 | +from distutils.command.build_ext import build_ext as DistUtilsBuildExt | ||
4 | + | ||
5 | + | ||
6 | +class BuildExtension(setuptools.Command): | ||
7 | + description = DistUtilsBuildExt.description | ||
8 | + user_options = DistUtilsBuildExt.user_options | ||
9 | + boolean_options = DistUtilsBuildExt.boolean_options | ||
10 | + help_options = DistUtilsBuildExt.help_options | ||
11 | + | ||
12 | + def __init__(self, *args, **kwargs): | ||
13 | + from setuptools.command.build_ext import build_ext as SetupToolsBuildExt | ||
14 | + | ||
15 | + # Bypass __setattr__ to avoid infinite recursion. | ||
16 | + self.__dict__['_command'] = SetupToolsBuildExt(*args, **kwargs) | ||
17 | + | ||
18 | + def __getattr__(self, name): | ||
19 | + return getattr(self._command, name) | ||
20 | + | ||
21 | + def __setattr__(self, name, value): | ||
22 | + setattr(self._command, name, value) | ||
23 | + | ||
24 | + def initialize_options(self, *args, **kwargs): | ||
25 | + return self._command.initialize_options(*args, **kwargs) | ||
26 | + | ||
27 | + def finalize_options(self, *args, **kwargs): | ||
28 | + ret = self._command.finalize_options(*args, **kwargs) | ||
29 | + import numpy | ||
30 | + self.include_dirs.append(numpy.get_include()) | ||
31 | + return ret | ||
32 | + | ||
33 | + def run(self, *args, **kwargs): | ||
34 | + return self._command.run(*args, **kwargs) | ||
35 | + | ||
36 | + | ||
37 | +extensions = [ | ||
38 | + Extension( | ||
39 | + 'keras_retinanet.utils.compute_overlap', | ||
40 | + ['keras_retinanet/utils/compute_overlap.pyx'] | ||
41 | + ), | ||
42 | +] | ||
43 | + | ||
44 | + | ||
45 | +setuptools.setup( | ||
46 | + name = 'keras-retinanet', | ||
47 | + version = '1.0.0', | ||
48 | + description = 'Keras implementation of RetinaNet object detection.', | ||
49 | + url = 'https://github.com/fizyr/keras-retinanet', | ||
50 | + author = 'Hans Gaiser', | ||
51 | + author_email = 'h.gaiser@fizyr.com', | ||
52 | + maintainer = 'Hans Gaiser', | ||
53 | + maintainer_email = 'h.gaiser@fizyr.com', | ||
54 | + cmdclass = {'build_ext': BuildExtension}, | ||
55 | + packages = setuptools.find_packages(), | ||
56 | + install_requires = ['keras-resnet==0.2.0', 'six', 'numpy', 'cython', 'Pillow', 'opencv-python', 'progressbar2'], | ||
57 | + entry_points = { | ||
58 | + 'console_scripts': [ | ||
59 | + 'retinanet-train=keras_retinanet.bin.train:main', | ||
60 | + 'retinanet-evaluate=keras_retinanet.bin.evaluate:main', | ||
61 | + 'retinanet-debug=keras_retinanet.bin.debug:main', | ||
62 | + 'retinanet-convert-model=keras_retinanet.bin.convert_model:main', | ||
63 | + ], | ||
64 | + }, | ||
65 | + ext_modules = extensions, | ||
66 | + setup_requires = ["cython>=0.28", "numpy>=1.14.0"] | ||
67 | +) |
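Usage note (a sketch assuming a standard setuptools workflow): the compute_overlap Cython extension declared above needs to be compiled before the package can be used from a source checkout, for example with "pip install ." or "python setup.py build_ext --inplace" run from the repository root.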
retinaNet/tests/__init__.py
0 → 100644
File mode changed
retinaNet/tests/backend/__init__.py
0 → 100644
File mode changed
retinaNet/tests/backend/test_common.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import numpy as np | ||
18 | +from tensorflow import keras | ||
19 | +import keras_retinanet.backend | ||
20 | + | ||
21 | + | ||
22 | +def test_bbox_transform_inv(): | ||
23 | + boxes = np.array([[ | ||
24 | + [100, 100, 200, 200], | ||
25 | + [100, 100, 300, 300], | ||
26 | + [100, 100, 200, 300], | ||
27 | + [100, 100, 300, 200], | ||
28 | + [80, 120, 200, 200], | ||
29 | + [80, 120, 300, 300], | ||
30 | + [80, 120, 200, 300], | ||
31 | + [80, 120, 300, 200], | ||
32 | + ]]) | ||
33 | + boxes = keras.backend.variable(boxes) | ||
34 | + | ||
35 | + deltas = np.array([[ | ||
36 | + [0 , 0 , 0 , 0 ], | ||
37 | + [0 , 0.1, 0 , 0 ], | ||
38 | + [-0.3, 0 , 0 , 0 ], | ||
39 | + [0.2 , 0.2, 0 , 0 ], | ||
40 | + [0 , 0 , 0.1 , 0 ], | ||
41 | + [0 , 0 , 0 , -0.3], | ||
42 | + [0 , 0 , 0.2 , 0.2 ], | ||
43 | + [0.1 , 0.2, -0.3, 0.4 ], | ||
44 | + ]]) | ||
45 | + deltas = keras.backend.variable(deltas) | ||
46 | + | ||
47 | + expected = np.array([[ | ||
48 | + [100 , 100 , 200 , 200 ], | ||
49 | + [100 , 104 , 300 , 300 ], | ||
50 | + [ 94 , 100 , 200 , 300 ], | ||
51 | + [108 , 104 , 300 , 200 ], | ||
52 | + [ 80 , 120 , 202.4 , 200 ], | ||
53 | + [ 80 , 120 , 300 , 289.2], | ||
54 | + [ 80 , 120 , 204.8 , 307.2], | ||
55 | + [ 84.4, 123.2, 286.8 , 206.4] | ||
56 | + ]]) | ||
57 | + | ||
58 | + result = keras_retinanet.backend.bbox_transform_inv(boxes, deltas) | ||
59 | + result = keras.backend.eval(result) | ||
60 | + | ||
61 | + np.testing.assert_array_almost_equal(result, expected, decimal=2) | ||
62 | + | ||
63 | + | ||
64 | +def test_shift(): | ||
65 | + shape = (2, 3) | ||
66 | + stride = 8 | ||
67 | + | ||
68 | + anchors = np.array([ | ||
69 | + [-8, -8, 8, 8], | ||
70 | + [-16, -16, 16, 16], | ||
71 | + [-12, -12, 12, 12], | ||
72 | + [-12, -16, 12, 16], | ||
73 | + [-16, -12, 16, 12] | ||
74 | + ], dtype=keras.backend.floatx()) | ||
75 | + | ||
76 | + expected = [ | ||
77 | + # anchors for (0, 0) | ||
78 | + [4 - 8, 4 - 8, 4 + 8, 4 + 8], | ||
79 | + [4 - 16, 4 - 16, 4 + 16, 4 + 16], | ||
80 | + [4 - 12, 4 - 12, 4 + 12, 4 + 12], | ||
81 | + [4 - 12, 4 - 16, 4 + 12, 4 + 16], | ||
82 | + [4 - 16, 4 - 12, 4 + 16, 4 + 12], | ||
83 | + | ||
84 | + # anchors for (0, 1) | ||
85 | + [12 - 8, 4 - 8, 12 + 8, 4 + 8], | ||
86 | + [12 - 16, 4 - 16, 12 + 16, 4 + 16], | ||
87 | + [12 - 12, 4 - 12, 12 + 12, 4 + 12], | ||
88 | + [12 - 12, 4 - 16, 12 + 12, 4 + 16], | ||
89 | + [12 - 16, 4 - 12, 12 + 16, 4 + 12], | ||
90 | + | ||
91 | + # anchors for (0, 2) | ||
92 | + [20 - 8, 4 - 8, 20 + 8, 4 + 8], | ||
93 | + [20 - 16, 4 - 16, 20 + 16, 4 + 16], | ||
94 | + [20 - 12, 4 - 12, 20 + 12, 4 + 12], | ||
95 | + [20 - 12, 4 - 16, 20 + 12, 4 + 16], | ||
96 | + [20 - 16, 4 - 12, 20 + 16, 4 + 12], | ||
97 | + | ||
98 | + # anchors for (1, 0) | ||
99 | + [4 - 8, 12 - 8, 4 + 8, 12 + 8], | ||
100 | + [4 - 16, 12 - 16, 4 + 16, 12 + 16], | ||
101 | + [4 - 12, 12 - 12, 4 + 12, 12 + 12], | ||
102 | + [4 - 12, 12 - 16, 4 + 12, 12 + 16], | ||
103 | + [4 - 16, 12 - 12, 4 + 16, 12 + 12], | ||
104 | + | ||
105 | + # anchors for (1, 1) | ||
106 | + [12 - 8, 12 - 8, 12 + 8, 12 + 8], | ||
107 | + [12 - 16, 12 - 16, 12 + 16, 12 + 16], | ||
108 | + [12 - 12, 12 - 12, 12 + 12, 12 + 12], | ||
109 | + [12 - 12, 12 - 16, 12 + 12, 12 + 16], | ||
110 | + [12 - 16, 12 - 12, 12 + 16, 12 + 12], | ||
111 | + | ||
112 | + # anchors for (1, 2) | ||
113 | + [20 - 8, 12 - 8, 20 + 8, 12 + 8], | ||
114 | + [20 - 16, 12 - 16, 20 + 16, 12 + 16], | ||
115 | + [20 - 12, 12 - 12, 20 + 12, 12 + 12], | ||
116 | + [20 - 12, 12 - 16, 20 + 12, 12 + 16], | ||
117 | + [20 - 16, 12 - 12, 20 + 16, 12 + 12], | ||
118 | + ] | ||
119 | + | ||
120 | + result = keras_retinanet.backend.shift(shape, stride, anchors) | ||
121 | + result = keras.backend.eval(result) | ||
122 | + | ||
123 | + np.testing.assert_array_equal(result, expected) |
retinaNet/tests/bin/test_train.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import keras_retinanet.backend | ||
18 | +import keras_retinanet.bin.train | ||
19 | +from tensorflow import keras | ||
20 | + | ||
21 | +import warnings | ||
22 | + | ||
23 | +import pytest | ||
24 | + | ||
25 | + | ||
26 | +@pytest.fixture(autouse=True) | ||
27 | +def clear_session(): | ||
28 | + # run before test (do nothing) | ||
29 | + yield | ||
30 | + # run after test, clear keras session | ||
31 | + keras.backend.clear_session() | ||
32 | + | ||
33 | + | ||
34 | +def test_coco(): | ||
35 | + # ignore warnings in this test | ||
36 | + warnings.simplefilter('ignore') | ||
37 | + | ||
38 | + # run training / evaluation | ||
39 | + keras_retinanet.bin.train.main([ | ||
40 | + '--epochs=1', | ||
41 | + '--steps=1', | ||
42 | + '--no-weights', | ||
43 | + '--no-snapshots', | ||
44 | + 'coco', | ||
45 | + 'tests/test-data/coco', | ||
46 | + ]) | ||
47 | + | ||
48 | + | ||
49 | +def test_pascal(): | ||
50 | + # ignore warnings in this test | ||
51 | + warnings.simplefilter('ignore') | ||
52 | + | ||
53 | + # run training / evaluation | ||
54 | + keras_retinanet.bin.train.main([ | ||
55 | + '--epochs=1', | ||
56 | + '--steps=1', | ||
57 | + '--no-weights', | ||
58 | + '--no-snapshots', | ||
59 | + 'pascal', | ||
60 | + 'tests/test-data/pascal', | ||
61 | + ]) | ||
62 | + | ||
63 | + | ||
64 | +def test_csv(): | ||
65 | + # ignore warnings in this test | ||
66 | + warnings.simplefilter('ignore') | ||
67 | + | ||
68 | + # run training / evaluation | ||
69 | + keras_retinanet.bin.train.main([ | ||
70 | + '--epochs=1', | ||
71 | + '--steps=1', | ||
72 | + '--no-weights', | ||
73 | + '--no-snapshots', | ||
74 | + 'csv', | ||
75 | + 'tests/test-data/csv/annotations.csv', | ||
76 | + 'tests/test-data/csv/classes.csv', | ||
77 | + ]) | ||
78 | + | ||
79 | + | ||
80 | +def test_vgg(): | ||
81 | + # ignore warnings in this test | ||
82 | + warnings.simplefilter('ignore') | ||
83 | + | ||
84 | + # run training / evaluation | ||
85 | + keras_retinanet.bin.train.main([ | ||
86 | + '--backbone=vgg16', | ||
87 | + '--epochs=1', | ||
88 | + '--steps=1', | ||
89 | + '--no-weights', | ||
90 | + '--no-snapshots', | ||
91 | + '--freeze-backbone', | ||
92 | + 'coco', | ||
93 | + 'tests/test-data/coco', | ||
94 | + ]) |
retinaNet/tests/layers/__init__.py
0 → 100644
File mode changed
retinaNet/tests/layers/test_filter_detections.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from tensorflow import keras | ||
18 | +import keras_retinanet.backend | ||
19 | +import keras_retinanet.layers | ||
20 | + | ||
21 | +import numpy as np | ||
22 | + | ||
23 | + | ||
24 | +class TestFilterDetections(object): | ||
25 | + def test_simple(self): | ||
26 | + # create simple FilterDetections layer | ||
27 | + filter_detections_layer = keras_retinanet.layers.FilterDetections() | ||
28 | + | ||
29 | + # create simple input | ||
30 | + boxes = np.array([[ | ||
31 | + [0, 0, 10, 10], | ||
32 | + [0, 0, 10, 10], # this will be suppressed | ||
33 | + ]], dtype=keras.backend.floatx()) | ||
34 | + boxes = keras.backend.constant(boxes) | ||
35 | + | ||
36 | + classification = np.array([[ | ||
37 | + [0, 0.9], # this will be suppressed | ||
38 | + [0, 1], | ||
39 | + ]], dtype=keras.backend.floatx()) | ||
40 | + classification = keras.backend.constant(classification) | ||
41 | + | ||
42 | + # compute output | ||
43 | + actual_boxes, actual_scores, actual_labels = filter_detections_layer.call([boxes, classification]) | ||
44 | + actual_boxes = keras.backend.eval(actual_boxes) | ||
45 | + actual_scores = keras.backend.eval(actual_scores) | ||
46 | + actual_labels = keras.backend.eval(actual_labels) | ||
47 | + | ||
48 | + # define expected output | ||
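| + # FilterDetections pads its outputs to max_detections (300 by default) and fills unused slots with -1 | ||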
49 | + expected_boxes = -1 * np.ones((1, 300, 4), dtype=keras.backend.floatx()) | ||
50 | + expected_boxes[0, 0, :] = [0, 0, 10, 10] | ||
51 | + | ||
52 | + expected_scores = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) | ||
53 | + expected_scores[0, 0] = 1 | ||
54 | + | ||
55 | + expected_labels = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) | ||
56 | + expected_labels[0, 0] = 1 | ||
57 | + | ||
58 | + # assert actual and expected are equal | ||
59 | + np.testing.assert_array_equal(actual_boxes, expected_boxes) | ||
60 | + np.testing.assert_array_equal(actual_scores, expected_scores) | ||
61 | + np.testing.assert_array_equal(actual_labels, expected_labels) | ||
62 | + | ||
63 | + def test_simple_with_other(self): | ||
64 | + # create simple FilterDetections layer | ||
65 | + filter_detections_layer = keras_retinanet.layers.FilterDetections() | ||
66 | + | ||
67 | + # create simple input | ||
68 | + boxes = np.array([[ | ||
69 | + [0, 0, 10, 10], | ||
70 | + [0, 0, 10, 10], # this will be suppressed | ||
71 | + ]], dtype=keras.backend.floatx()) | ||
72 | + boxes = keras.backend.constant(boxes) | ||
73 | + | ||
74 | + classification = np.array([[ | ||
75 | + [0, 0.9], # this will be suppressed | ||
76 | + [0, 1], | ||
77 | + ]], dtype=keras.backend.floatx()) | ||
78 | + classification = keras.backend.constant(classification) | ||
79 | + | ||
80 | + other = [] | ||
81 | + other.append(np.array([[ | ||
82 | + [0, 1234], # this will be suppressed | ||
83 | + [0, 5678], | ||
84 | + ]], dtype=keras.backend.floatx())) | ||
85 | + other.append(np.array([[ | ||
86 | + 5678, # this will be suppressed | ||
87 | + 1234, | ||
88 | + ]], dtype=keras.backend.floatx())) | ||
89 | + other = [keras.backend.constant(o) for o in other] | ||
90 | + | ||
91 | + # compute output | ||
92 | + actual = filter_detections_layer.call([boxes, classification] + other) | ||
93 | + actual_boxes = keras.backend.eval(actual[0]) | ||
94 | + actual_scores = keras.backend.eval(actual[1]) | ||
95 | + actual_labels = keras.backend.eval(actual[2]) | ||
96 | + actual_other = [keras.backend.eval(a) for a in actual[3:]] | ||
97 | + | ||
98 | + # define expected output | ||
99 | + expected_boxes = -1 * np.ones((1, 300, 4), dtype=keras.backend.floatx()) | ||
100 | + expected_boxes[0, 0, :] = [0, 0, 10, 10] | ||
101 | + | ||
102 | + expected_scores = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) | ||
103 | + expected_scores[0, 0] = 1 | ||
104 | + | ||
105 | + expected_labels = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) | ||
106 | + expected_labels[0, 0] = 1 | ||
107 | + | ||
108 | + expected_other = [] | ||
109 | + expected_other.append(-1 * np.ones((1, 300, 2), dtype=keras.backend.floatx())) | ||
110 | + expected_other[-1][0, 0, :] = [0, 5678] | ||
111 | + expected_other.append(-1 * np.ones((1, 300), dtype=keras.backend.floatx())) | ||
112 | + expected_other[-1][0, 0] = 1234 | ||
113 | + | ||
114 | + # assert actual and expected are equal | ||
115 | + np.testing.assert_array_equal(actual_boxes, expected_boxes) | ||
116 | + np.testing.assert_array_equal(actual_scores, expected_scores) | ||
117 | + np.testing.assert_array_equal(actual_labels, expected_labels) | ||
118 | + | ||
119 | + for a, e in zip(actual_other, expected_other): | ||
120 | + np.testing.assert_array_equal(a, e) | ||
121 | + | ||
122 | + def test_mini_batch(self): | ||
123 | + # create simple FilterDetections layer | ||
124 | + filter_detections_layer = keras_retinanet.layers.FilterDetections() | ||
125 | + | ||
126 | + # create input with batch_size=2 | ||
127 | + boxes = np.array([ | ||
128 | + [ | ||
129 | + [0, 0, 10, 10], # this will be suppressed | ||
130 | + [0, 0, 10, 10], | ||
131 | + ], | ||
132 | + [ | ||
133 | + [100, 100, 150, 150], | ||
134 | + [100, 100, 150, 150], # this will be suppressed | ||
135 | + ], | ||
136 | + ], dtype=keras.backend.floatx()) | ||
137 | + boxes = keras.backend.constant(boxes) | ||
138 | + | ||
139 | + classification = np.array([ | ||
140 | + [ | ||
141 | + [0, 0.9], # this will be suppressed | ||
142 | + [0, 1], | ||
143 | + ], | ||
144 | + [ | ||
145 | + [1, 0], | ||
146 | + [0.9, 0], # this will be suppressed | ||
147 | + ], | ||
148 | + ], dtype=keras.backend.floatx()) | ||
149 | + classification = keras.backend.constant(classification) | ||
150 | + | ||
151 | + # compute output | ||
152 | + actual_boxes, actual_scores, actual_labels = filter_detections_layer.call([boxes, classification]) | ||
153 | + actual_boxes = keras.backend.eval(actual_boxes) | ||
154 | + actual_scores = keras.backend.eval(actual_scores) | ||
155 | + actual_labels = keras.backend.eval(actual_labels) | ||
156 | + | ||
157 | + # define expected output | ||
158 | + expected_boxes = -1 * np.ones((2, 300, 4), dtype=keras.backend.floatx()) | ||
159 | + expected_boxes[0, 0, :] = [0, 0, 10, 10] | ||
160 | + expected_boxes[1, 0, :] = [100, 100, 150, 150] | ||
161 | + | ||
162 | + expected_scores = -1 * np.ones((2, 300), dtype=keras.backend.floatx()) | ||
163 | + expected_scores[0, 0] = 1 | ||
164 | + expected_scores[1, 0] = 1 | ||
165 | + | ||
166 | + expected_labels = -1 * np.ones((2, 300), dtype=keras.backend.floatx()) | ||
167 | + expected_labels[0, 0] = 1 | ||
168 | + expected_labels[1, 0] = 0 | ||
169 | + | ||
170 | + # assert actual and expected are equal | ||
171 | + np.testing.assert_array_equal(actual_boxes, expected_boxes) | ||
172 | + np.testing.assert_array_equal(actual_scores, expected_scores) | ||
173 | + np.testing.assert_array_equal(actual_labels, expected_labels) |
retinaNet/tests/layers/test_misc.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from tensorflow import keras | ||
18 | +import keras_retinanet.backend | ||
19 | +import keras_retinanet.layers | ||
20 | + | ||
21 | +import numpy as np | ||
22 | + | ||
23 | + | ||
24 | +class TestAnchors(object): | ||
25 | + def test_simple(self): | ||
26 | + # create simple Anchors layer | ||
27 | + anchors_layer = keras_retinanet.layers.Anchors( | ||
28 | + size=32, | ||
29 | + stride=8, | ||
30 | + ratios=np.array([1], keras.backend.floatx()), | ||
31 | + scales=np.array([1], keras.backend.floatx()), | ||
32 | + ) | ||
33 | + | ||
34 | + # create fake features input (only shape is used anyway) | ||
35 | + features = np.zeros((1, 2, 2, 1024), dtype=keras.backend.floatx()) | ||
36 | + features = keras.backend.variable(features) | ||
37 | + | ||
38 | + # call the Anchors layer | ||
39 | + anchors = anchors_layer.call(features) | ||
40 | + anchors = keras.backend.eval(anchors) | ||
41 | + | ||
42 | + # expected anchor values | ||
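| + # size 32 with ratio 1 and scale 1 on a 2x2 feature map (stride 8) gives 32x32 anchors centred at (4, 4), (12, 4), (4, 12) and (12, 12) | ||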
43 | + expected = np.array([[ | ||
44 | + [-12, -12, 20, 20], | ||
45 | + [-4 , -12, 28, 20], | ||
46 | + [-12, -4 , 20, 28], | ||
47 | + [-4 , -4 , 28, 28], | ||
48 | + ]], dtype=keras.backend.floatx()) | ||
49 | + | ||
50 | + # test anchor values | ||
51 | + np.testing.assert_array_equal(anchors, expected) | ||
52 | + | ||
53 | + # same check with a batch size larger than one | ||
54 | + def test_mini_batch(self): | ||
55 | + # create simple Anchors layer | ||
56 | + anchors_layer = keras_retinanet.layers.Anchors( | ||
57 | + size=32, | ||
58 | + stride=8, | ||
59 | + ratios=np.array([1], dtype=keras.backend.floatx()), | ||
60 | + scales=np.array([1], dtype=keras.backend.floatx()), | ||
61 | + ) | ||
62 | + | ||
63 | + # create fake features input with batch_size=2 | ||
64 | + features = np.zeros((2, 2, 2, 1024), dtype=keras.backend.floatx()) | ||
65 | + features = keras.backend.variable(features) | ||
66 | + | ||
67 | + # call the Anchors layer | ||
68 | + anchors = anchors_layer.call(features) | ||
69 | + anchors = keras.backend.eval(anchors) | ||
70 | + | ||
71 | + # expected anchor values | ||
72 | + expected = np.array([[ | ||
73 | + [-12, -12, 20, 20], | ||
74 | + [-4 , -12, 28, 20], | ||
75 | + [-12, -4 , 20, 28], | ||
76 | + [-4 , -4 , 28, 28], | ||
77 | + ]], dtype=keras.backend.floatx()) | ||
78 | + expected = np.tile(expected, (2, 1, 1)) | ||
79 | + | ||
80 | + # test anchor values | ||
81 | + np.testing.assert_array_equal(anchors, expected) | ||
82 | + | ||
83 | + | ||
84 | +class TestUpsampleLike(object): | ||
85 | + def test_simple(self): | ||
86 | + # create simple UpsampleLike layer | ||
87 | + upsample_like_layer = keras_retinanet.layers.UpsampleLike() | ||
88 | + | ||
89 | + # create input source | ||
90 | + source = np.zeros((1, 2, 2, 1), dtype=keras.backend.floatx()) | ||
91 | + source = keras.backend.variable(source) | ||
92 | + target = np.zeros((1, 5, 5, 1), dtype=keras.backend.floatx()) | ||
93 | + expected = target | ||
94 | + target = keras.backend.variable(target) | ||
95 | + | ||
96 | + # compute output | ||
97 | + actual = upsample_like_layer.call([source, target]) | ||
98 | + actual = keras.backend.eval(actual) | ||
99 | + | ||
100 | + np.testing.assert_array_equal(actual, expected) | ||
101 | + | ||
102 | + def test_mini_batch(self): | ||
103 | + # create simple UpsampleLike layer | ||
104 | + upsample_like_layer = keras_retinanet.layers.UpsampleLike() | ||
105 | + | ||
106 | + # create input source | ||
107 | + source = np.zeros((2, 2, 2, 1), dtype=keras.backend.floatx()) | ||
108 | + source = keras.backend.variable(source) | ||
109 | + | ||
110 | + target = np.zeros((2, 5, 5, 1), dtype=keras.backend.floatx()) | ||
111 | + expected = target | ||
112 | + target = keras.backend.variable(target) | ||
113 | + | ||
114 | + # compute output | ||
115 | + actual = upsample_like_layer.call([source, target]) | ||
116 | + actual = keras.backend.eval(actual) | ||
117 | + | ||
118 | + np.testing.assert_array_equal(actual, expected) | ||
119 | + | ||
120 | + | ||
121 | +class TestRegressBoxes(object): | ||
122 | + def test_simple(self): | ||
123 | + mean = [0, 0, 0, 0] | ||
124 | + std = [0.2, 0.2, 0.2, 0.2] | ||
125 | + | ||
126 | + # create simple RegressBoxes layer | ||
127 | + regress_boxes_layer = keras_retinanet.layers.RegressBoxes(mean=mean, std=std) | ||
128 | + | ||
129 | + # create input | ||
130 | + anchors = np.array([[ | ||
131 | + [0 , 0 , 10 , 10 ], | ||
132 | + [50, 50, 100, 100], | ||
133 | + [20, 20, 40 , 40 ], | ||
134 | + ]], dtype=keras.backend.floatx()) | ||
135 | + anchors = keras.backend.variable(anchors) | ||
136 | + | ||
137 | + regression = np.array([[ | ||
138 | + [0 , 0 , 0 , 0 ], | ||
139 | + [0.1, 0.1, 0 , 0 ], | ||
140 | + [0 , 0 , 0.1, 0.1], | ||
141 | + ]], dtype=keras.backend.floatx()) | ||
142 | + regression = keras.backend.variable(regression) | ||
143 | + | ||
144 | + # compute output | ||
145 | + actual = regress_boxes_layer.call([anchors, regression]) | ||
146 | + actual = keras.backend.eval(actual) | ||
147 | + | ||
148 | + # compute expected output | ||
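| + # deltas scale with the anchor size and std 0.2: 50 + 0.1 * 0.2 * 50 = 51 and 40 + 0.1 * 0.2 * 20 = 40.4 | ||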
149 | + expected = np.array([[ | ||
150 | + [0 , 0 , 10 , 10 ], | ||
151 | + [51, 51, 100 , 100 ], | ||
152 | + [20, 20, 40.4, 40.4], | ||
153 | + ]], dtype=keras.backend.floatx()) | ||
154 | + | ||
155 | + np.testing.assert_array_almost_equal(actual, expected, decimal=2) | ||
156 | + | ||
157 | + # same check with a batch size larger than one | ||
158 | + def test_mini_batch(self): | ||
159 | + mean = [0, 0, 0, 0] | ||
160 | + std = [0.2, 0.2, 0.2, 0.2] | ||
161 | + | ||
162 | + # create simple RegressBoxes layer | ||
163 | + regress_boxes_layer = keras_retinanet.layers.RegressBoxes(mean=mean, std=std) | ||
164 | + | ||
165 | + # create input | ||
166 | + anchors = np.array([ | ||
167 | + [ | ||
168 | + [0 , 0 , 10 , 10 ], # 1 | ||
169 | + [50, 50, 100, 100], # 2 | ||
170 | + [20, 20, 40 , 40 ], # 3 | ||
171 | + ], | ||
172 | + [ | ||
173 | + [20, 20, 40 , 40 ], # 3 | ||
174 | + [0 , 0 , 10 , 10 ], # 1 | ||
175 | + [50, 50, 100, 100], # 2 | ||
176 | + ], | ||
177 | + ], dtype=keras.backend.floatx()) | ||
178 | + anchors = keras.backend.variable(anchors) | ||
179 | + | ||
180 | + regression = np.array([ | ||
181 | + [ | ||
182 | + [0 , 0 , 0 , 0 ], # 1 | ||
183 | + [0.1, 0.1, 0 , 0 ], # 2 | ||
184 | + [0 , 0 , 0.1, 0.1], # 3 | ||
185 | + ], | ||
186 | + [ | ||
187 | + [0 , 0 , 0.1, 0.1], # 3 | ||
188 | + [0 , 0 , 0 , 0 ], # 1 | ||
189 | + [0.1, 0.1, 0 , 0 ], # 2 | ||
190 | + ], | ||
191 | + ], dtype=keras.backend.floatx()) | ||
192 | + regression = keras.backend.variable(regression) | ||
193 | + | ||
194 | + # compute output | ||
195 | + actual = regress_boxes_layer.call([anchors, regression]) | ||
196 | + actual = keras.backend.eval(actual) | ||
197 | + | ||
198 | + # compute expected output | ||
199 | + expected = np.array([ | ||
200 | + [ | ||
201 | + [0 , 0 , 10 , 10 ], # 1 | ||
202 | + [51, 51, 100 , 100 ], # 2 | ||
203 | + [20, 20, 40.4, 40.4], # 3 | ||
204 | + ], | ||
205 | + [ | ||
206 | + [20, 20, 40.4, 40.4], # 3 | ||
207 | + [0 , 0 , 10 , 10 ], # 1 | ||
208 | + [51, 51, 100 , 100 ], # 2 | ||
209 | + ], | ||
210 | + ], dtype=keras.backend.floatx()) | ||
211 | + | ||
212 | + np.testing.assert_array_almost_equal(actual, expected, decimal=2) |
retinaNet/tests/models/__init__.py
0 → 100644
File mode changed
retinaNet/tests/models/test_densenet.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2018 vidosits (https://github.com/vidosits/) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import warnings | ||
18 | +import pytest | ||
19 | +import numpy as np | ||
20 | +from tensorflow import keras | ||
21 | +from keras_retinanet import losses | ||
22 | +from keras_retinanet.models.densenet import DenseNetBackbone | ||
23 | + | ||
24 | +parameters = ['densenet121'] | ||
25 | + | ||
26 | + | ||
27 | +@pytest.mark.parametrize("backbone", parameters) | ||
28 | +def test_backbone(backbone): | ||
29 | + # ignore warnings in this test | ||
30 | + warnings.simplefilter('ignore') | ||
31 | + | ||
32 | + num_classes = 10 | ||
33 | + | ||
34 | + inputs = np.zeros((1, 200, 400, 3), dtype=np.float32) | ||
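| + # regression targets hold 4 box deltas plus an anchor-state column; classification targets hold one column per class plus the anchor state | ||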
35 | + targets = [np.zeros((1, 14814, 5), dtype=np.float32), np.zeros((1, 14814, num_classes + 1))] | ||
36 | + | ||
37 | + inp = keras.layers.Input(inputs[0].shape) | ||
38 | + | ||
39 | + densenet_backbone = DenseNetBackbone(backbone) | ||
40 | + model = densenet_backbone.retinanet(num_classes=num_classes, inputs=inp) | ||
41 | + model.summary() | ||
42 | + | ||
43 | + # compile model | ||
44 | + model.compile( | ||
45 | + loss={ | ||
46 | + 'regression': losses.smooth_l1(), | ||
47 | + 'classification': losses.focal() | ||
48 | + }, | ||
49 | + optimizer=keras.optimizers.Adam(lr=1e-5, clipnorm=0.001)) | ||
50 | + | ||
51 | + model.fit(inputs, targets, batch_size=1) |
retinaNet/tests/models/test_mobilenet.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import warnings | ||
18 | +import pytest | ||
19 | +import numpy as np | ||
20 | +from tensorflow import keras | ||
21 | +from keras_retinanet import losses | ||
22 | +from keras_retinanet.models.mobilenet import MobileNetBackbone | ||
23 | + | ||
24 | + | ||
25 | +alphas = ['1.0'] | ||
26 | +parameters = [] | ||
27 | + | ||
28 | +for backbone in MobileNetBackbone.allowed_backbones: | ||
29 | + for alpha in alphas: | ||
30 | + parameters.append((backbone, alpha)) | ||
31 | + | ||
32 | + | ||
33 | +@pytest.mark.parametrize("backbone, alpha", parameters) | ||
34 | +def test_backbone(backbone, alpha): | ||
35 | + # ignore warnings in this test | ||
36 | + warnings.simplefilter('ignore') | ||
37 | + | ||
38 | + num_classes = 10 | ||
39 | + | ||
40 | + inputs = np.zeros((1, 1024, 363, 3), dtype=np.float32) | ||
41 | + targets = [np.zeros((1, 68760, 5), dtype=np.float32), np.zeros((1, 68760, num_classes + 1))] | ||
42 | + | ||
43 | + inp = keras.layers.Input(inputs[0].shape) | ||
44 | + | ||
45 | + mobilenet_backbone = MobileNetBackbone(backbone='{}_{}'.format(backbone, alpha)) | ||
46 | + training_model = mobilenet_backbone.retinanet(num_classes=num_classes, inputs=inp) | ||
47 | + training_model.summary() | ||
48 | + | ||
49 | + # compile model | ||
50 | + training_model.compile( | ||
51 | + loss={ | ||
52 | + 'regression': losses.smooth_l1(), | ||
53 | + 'classification': losses.focal() | ||
54 | + }, | ||
55 | + optimizer=keras.optimizers.Adam(lr=1e-5, clipnorm=0.001)) | ||
56 | + | ||
57 | + training_model.fit(inputs, targets, batch_size=1) |
retinaNet/tests/preprocessing/__init__.py
0 → 100644
File mode changed
retinaNet/tests/preprocessing/test_csv_generator.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +import csv | ||
18 | +import pytest | ||
19 | +try: | ||
20 | + from io import StringIO | ||
21 | +except ImportError: | ||
22 | + from StringIO import StringIO | ||
23 | + | ||
24 | +from keras_retinanet.preprocessing import csv_generator | ||
25 | + | ||
26 | + | ||
27 | +def csv_str(string): | ||
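| + # on Python 2 (where str is bytes) decode to text, since io.StringIO only accepts unicode | ||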
28 | + if str == bytes: | ||
29 | + string = string.decode('utf-8') | ||
30 | + return csv.reader(StringIO(string)) | ||
31 | + | ||
32 | + | ||
33 | +def annotation(x1, y1, x2, y2, class_name): | ||
34 | + return {'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2, 'class': class_name} | ||
35 | + | ||
36 | + | ||
37 | +def test_read_classes(): | ||
38 | + assert csv_generator._read_classes(csv_str('')) == {} | ||
39 | + assert csv_generator._read_classes(csv_str('a,1')) == {'a': 1} | ||
40 | + assert csv_generator._read_classes(csv_str('a,1\nb,2')) == {'a': 1, 'b': 2} | ||
41 | + | ||
42 | + | ||
43 | +def test_read_classes_wrong_format(): | ||
44 | + with pytest.raises(ValueError): | ||
45 | + try: | ||
46 | + csv_generator._read_classes(csv_str('a,b,c')) | ||
47 | + except ValueError as e: | ||
48 | + assert str(e).startswith('line 1: format should be') | ||
49 | + raise | ||
50 | + with pytest.raises(ValueError): | ||
51 | + try: | ||
52 | + csv_generator._read_classes(csv_str('a,1\nb,c,d')) | ||
53 | + except ValueError as e: | ||
54 | + assert str(e).startswith('line 2: format should be') | ||
55 | + raise | ||
56 | + | ||
57 | + | ||
58 | +def test_read_classes_malformed_class_id(): | ||
59 | + with pytest.raises(ValueError): | ||
60 | + try: | ||
61 | + csv_generator._read_classes(csv_str('a,b')) | ||
62 | + except ValueError as e: | ||
63 | + assert str(e).startswith("line 1: malformed class ID:") | ||
64 | + raise | ||
65 | + | ||
66 | + with pytest.raises(ValueError): | ||
67 | + try: | ||
68 | + csv_generator._read_classes(csv_str('a,1\nb,c')) | ||
69 | + except ValueError as e: | ||
70 | + assert str(e).startswith('line 2: malformed class ID:') | ||
71 | + raise | ||
72 | + | ||
73 | + | ||
74 | +def test_read_classes_duplicate_name(): | ||
75 | + with pytest.raises(ValueError): | ||
76 | + try: | ||
77 | + csv_generator._read_classes(csv_str('a,1\nb,2\na,3')) | ||
78 | + except ValueError as e: | ||
79 | + assert str(e).startswith('line 3: duplicate class name') | ||
80 | + raise | ||
81 | + | ||
82 | + | ||
83 | +def test_read_annotations(): | ||
84 | + classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10} | ||
85 | + annotations = csv_generator._read_annotations(csv_str( | ||
86 | + 'a.png,0,1,2,3,a' '\n' | ||
87 | + 'b.png,4,5,6,7,b' '\n' | ||
88 | + 'c.png,8,9,10,11,c' '\n' | ||
89 | + 'd.png,12,13,14,15,d' '\n' | ||
90 | + ), classes) | ||
91 | + assert annotations == { | ||
92 | + 'a.png': [annotation( 0, 1, 2, 3, 'a')], | ||
93 | + 'b.png': [annotation( 4, 5, 6, 7, 'b')], | ||
94 | + 'c.png': [annotation( 8, 9, 10, 11, 'c')], | ||
95 | + 'd.png': [annotation(12, 13, 14, 15, 'd')], | ||
96 | + } | ||
97 | + | ||
98 | + | ||
99 | +def test_read_annotations_multiple(): | ||
100 | + classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10} | ||
101 | + annotations = csv_generator._read_annotations(csv_str( | ||
102 | + 'a.png,0,1,2,3,a' '\n' | ||
103 | + 'b.png,4,5,6,7,b' '\n' | ||
104 | + 'a.png,8,9,10,11,c' '\n' | ||
105 | + ), classes) | ||
106 | + assert annotations == { | ||
107 | + 'a.png': [ | ||
108 | + annotation(0, 1, 2, 3, 'a'), | ||
109 | + annotation(8, 9, 10, 11, 'c'), | ||
110 | + ], | ||
111 | + 'b.png': [annotation(4, 5, 6, 7, 'b')], | ||
112 | + } | ||
113 | + | ||
114 | + | ||
115 | +def test_read_annotations_wrong_format(): | ||
116 | + classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10} | ||
117 | + with pytest.raises(ValueError): | ||
118 | + try: | ||
119 | + csv_generator._read_annotations(csv_str('a.png,1,2,3,a'), classes) | ||
120 | + except ValueError as e: | ||
121 | + assert str(e).startswith("line 1: format should be") | ||
122 | + raise | ||
123 | + | ||
124 | + with pytest.raises(ValueError): | ||
125 | + try: | ||
126 | + csv_generator._read_annotations(csv_str( | ||
127 | + 'a.png,0,1,2,3,a' '\n' | ||
128 | + 'a.png,1,2,3,a' '\n' | ||
129 | + ), classes) | ||
130 | + except ValueError as e: | ||
131 | + assert str(e).startswith("line 2: format should be") | ||
132 | + raise | ||
133 | + | ||
134 | + | ||
135 | +def test_read_annotations_wrong_x1(): | ||
136 | + with pytest.raises(ValueError): | ||
137 | + try: | ||
138 | + csv_generator._read_annotations(csv_str('a.png,a,0,1,2,a'), {'a': 1}) | ||
139 | + except ValueError as e: | ||
140 | + assert str(e).startswith("line 1: malformed x1:") | ||
141 | + raise | ||
142 | + | ||
143 | + | ||
144 | +def test_read_annotations_wrong_y1(): | ||
145 | + with pytest.raises(ValueError): | ||
146 | + try: | ||
147 | + csv_generator._read_annotations(csv_str('a.png,0,a,1,2,a'), {'a': 1}) | ||
148 | + except ValueError as e: | ||
149 | + assert str(e).startswith("line 1: malformed y1:") | ||
150 | + raise | ||
151 | + | ||
152 | + | ||
153 | +def test_read_annotations_wrong_x2(): | ||
154 | + with pytest.raises(ValueError): | ||
155 | + try: | ||
156 | + csv_generator._read_annotations(csv_str('a.png,0,1,a,2,a'), {'a': 1}) | ||
157 | + except ValueError as e: | ||
158 | + assert str(e).startswith("line 1: malformed x2:") | ||
159 | + raise | ||
160 | + | ||
161 | + | ||
162 | +def test_read_annotations_wrong_y2(): | ||
163 | + with pytest.raises(ValueError): | ||
164 | + try: | ||
165 | + csv_generator._read_annotations(csv_str('a.png,0,1,2,a,a'), {'a': 1}) | ||
166 | + except ValueError as e: | ||
167 | + assert str(e).startswith("line 1: malformed y2:") | ||
168 | + raise | ||
169 | + | ||
170 | + | ||
171 | +def test_read_annotations_wrong_class(): | ||
172 | + with pytest.raises(ValueError): | ||
173 | + try: | ||
174 | + csv_generator._read_annotations(csv_str('a.png,0,1,2,3,g'), {'a': 1}) | ||
175 | + except ValueError as e: | ||
176 | + assert str(e).startswith("line 1: unknown class name:") | ||
177 | + raise | ||
178 | + | ||
179 | + | ||
180 | +def test_read_annotations_invalid_bb_x(): | ||
181 | + with pytest.raises(ValueError): | ||
182 | + try: | ||
183 | + csv_generator._read_annotations(csv_str('a.png,1,2,1,3,g'), {'a': 1}) | ||
184 | + except ValueError as e: | ||
185 | + assert str(e).startswith("line 1: x2 (1) must be higher than x1 (1)") | ||
186 | + raise | ||
187 | + with pytest.raises(ValueError): | ||
188 | + try: | ||
189 | + csv_generator._read_annotations(csv_str('a.png,9,2,5,3,g'), {'a': 1}) | ||
190 | + except ValueError as e: | ||
191 | + assert str(e).startswith("line 1: x2 (5) must be higher than x1 (9)") | ||
192 | + raise | ||
193 | + | ||
194 | + | ||
195 | +def test_read_annotations_invalid_bb_y(): | ||
196 | + with pytest.raises(ValueError): | ||
197 | + try: | ||
198 | + csv_generator._read_annotations(csv_str('a.png,1,2,3,2,a'), {'a': 1}) | ||
199 | + except ValueError as e: | ||
200 | + assert str(e).startswith("line 1: y2 (2) must be higher than y1 (2)") | ||
201 | + raise | ||
202 | + with pytest.raises(ValueError): | ||
203 | + try: | ||
204 | + csv_generator._read_annotations(csv_str('a.png,1,8,3,5,a'), {'a': 1}) | ||
205 | + except ValueError as e: | ||
206 | + assert str(e).startswith("line 1: y2 (5) must be higher than y1 (8)") | ||
207 | + raise | ||
208 | + | ||
209 | + | ||
210 | +def test_read_annotations_empty_image(): | ||
211 | + # Check that images without annotations are parsed. | ||
212 | + assert csv_generator._read_annotations(csv_str('a.png,,,,,\nb.png,,,,,'), {'a': 1}) == {'a.png': [], 'b.png': []} | ||
213 | + | ||
214 | + # Check that lines without annotations don't clear earlier annotations. | ||
215 | + assert csv_generator._read_annotations(csv_str('a.png,0,1,2,3,a\na.png,,,,,'), {'a': 1}) == {'a.png': [annotation(0, 1, 2, 3, 'a')]} |
retinaNet/tests/preprocessing/test_generator.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2017-2018 Fizyr (https://fizyr.com) | ||
3 | + | ||
4 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +you may not use this file except in compliance with the License. | ||
6 | +You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | +Unless required by applicable law or agreed to in writing, software | ||
11 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +See the License for the specific language governing permissions and | ||
14 | +limitations under the License. | ||
15 | +""" | ||
16 | + | ||
17 | +from keras_retinanet.preprocessing.generator import Generator | ||
18 | + | ||
19 | +import numpy as np | ||
20 | +import pytest | ||
21 | + | ||
22 | + | ||
23 | +class SimpleGenerator(Generator): | ||
24 | + def __init__(self, bboxes, labels, num_classes=0, image=None): | ||
25 | + assert(len(bboxes) == len(labels)) | ||
26 | + self.bboxes = bboxes | ||
27 | + self.labels = labels | ||
28 | + self.num_classes_ = num_classes | ||
29 | + self.image = image | ||
30 | + super(SimpleGenerator, self).__init__(group_method='none', shuffle_groups=False) | ||
31 | + | ||
32 | + def num_classes(self): | ||
33 | + return self.num_classes_ | ||
34 | + | ||
35 | + def load_image(self, image_index): | ||
36 | + return self.image | ||
37 | + | ||
38 | + def image_path(self, image_index): | ||
39 | + return '' | ||
40 | + | ||
41 | + def size(self): | ||
42 | + return len(self.bboxes) | ||
43 | + | ||
44 | + def load_annotations(self, image_index): | ||
45 | + annotations = {'labels': self.labels[image_index], 'bboxes': self.bboxes[image_index]} | ||
46 | + return annotations | ||
47 | + | ||
48 | + | ||
49 | +class TestLoadAnnotationsGroup(object): | ||
50 | + def test_simple(self): | ||
51 | + input_bboxes_group = [ | ||
52 | + np.array([ | ||
53 | + [ 0, 0, 10, 10], | ||
54 | + [150, 150, 350, 350] | ||
55 | + ]), | ||
56 | + ] | ||
57 | + input_labels_group = [ | ||
58 | + np.array([ | ||
59 | + 1, | ||
60 | + 3 | ||
61 | + ]), | ||
62 | + ] | ||
63 | + expected_bboxes_group = input_bboxes_group | ||
64 | + expected_labels_group = input_labels_group | ||
65 | + | ||
66 | + simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group) | ||
67 | + annotations = simple_generator.load_annotations_group(simple_generator.groups[0]) | ||
68 | + | ||
69 | + assert('bboxes' in annotations[0]) | ||
70 | + assert('labels' in annotations[0]) | ||
71 | + np.testing.assert_equal(expected_bboxes_group[0], annotations[0]['bboxes']) | ||
72 | + np.testing.assert_equal(expected_labels_group[0], annotations[0]['labels']) | ||
73 | + | ||
74 | + def test_multiple(self): | ||
75 | + input_bboxes_group = [ | ||
76 | + np.array([ | ||
77 | + [ 0, 0, 10, 10], | ||
78 | + [150, 150, 350, 350] | ||
79 | + ]), | ||
80 | + np.array([ | ||
81 | + [0, 0, 50, 50], | ||
82 | + ]), | ||
83 | + ] | ||
84 | + input_labels_group = [ | ||
85 | + np.array([ | ||
86 | + 1, | ||
87 | + 0 | ||
88 | + ]), | ||
89 | + np.array([ | ||
90 | + 3 | ||
91 | + ]) | ||
92 | + ] | ||
93 | + expected_bboxes_group = input_bboxes_group | ||
94 | + expected_labels_group = input_labels_group | ||
95 | + | ||
96 | + simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group) | ||
97 | + annotations_group_0 = simple_generator.load_annotations_group(simple_generator.groups[0]) | ||
98 | + annotations_group_1 = simple_generator.load_annotations_group(simple_generator.groups[1]) | ||
99 | + | ||
100 | + assert('bboxes' in annotations_group_0[0]) | ||
101 | + assert('bboxes' in annotations_group_1[0]) | ||
102 | + assert('labels' in annotations_group_0[0]) | ||
103 | + assert('labels' in annotations_group_1[0]) | ||
104 | + np.testing.assert_equal(expected_bboxes_group[0], annotations_group_0[0]['bboxes']) | ||
105 | + np.testing.assert_equal(expected_labels_group[0], annotations_group_0[0]['labels']) | ||
106 | + np.testing.assert_equal(expected_bboxes_group[1], annotations_group_1[0]['bboxes']) | ||
107 | + np.testing.assert_equal(expected_labels_group[1], annotations_group_1[0]['labels']) | ||
108 | + | ||
109 | + | ||
110 | +class TestFilterAnnotations(object): | ||
111 | + def test_simple_filter(self): | ||
112 | + input_bboxes_group = [ | ||
113 | + np.array([ | ||
114 | + [ 0, 0, 10, 10], | ||
115 | + [150, 150, 50, 50] | ||
116 | + ]), | ||
117 | + ] | ||
118 | + input_labels_group = [ | ||
119 | + np.array([ | ||
120 | + 3, | ||
121 | + 1 | ||
122 | + ]), | ||
123 | + ] | ||
124 | + | ||
125 | + input_image = np.zeros((500, 500, 3)) | ||
126 | + | ||
127 | + expected_bboxes_group = [ | ||
128 | + np.array([ | ||
129 | + [0, 0, 10, 10], | ||
130 | + ]), | ||
131 | + ] | ||
132 | + expected_labels_group = [ | ||
133 | + np.array([ | ||
134 | + 3, | ||
135 | + ]), | ||
136 | + ] | ||
137 | + | ||
138 | + simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group) | ||
139 | + annotations = simple_generator.load_annotations_group(simple_generator.groups[0]) | ||
140 | + # expect a UserWarning | ||
141 | + with pytest.warns(UserWarning): | ||
142 | + image_group, annotations_group = simple_generator.filter_annotations([input_image], annotations, simple_generator.groups[0]) | ||
143 | + | ||
144 | + np.testing.assert_equal(expected_bboxes_group[0], annotations_group[0]['bboxes']) | ||
145 | + np.testing.assert_equal(expected_labels_group[0], annotations_group[0]['labels']) | ||
146 | + | ||
147 | + def test_multiple_filter(self): | ||
148 | + input_bboxes_group = [ | ||
149 | + np.array([ | ||
150 | + [ 0, 0, 10, 10], | ||
151 | + [150, 150, 50, 50], | ||
152 | + [150, 150, 350, 350], | ||
153 | + [350, 350, 150, 150], | ||
154 | + [ 1, 1, 2, 2], | ||
155 | + [ 2, 2, 1, 1] | ||
156 | + ]), | ||
157 | + np.array([ | ||
158 | + [0, 0, -1, -1] | ||
159 | + ]), | ||
160 | + np.array([ | ||
161 | + [-10, -10, 0, 0], | ||
162 | + [-10, -10, -100, -100], | ||
163 | + [ 10, 10, 100, 100] | ||
164 | + ]), | ||
165 | + np.array([ | ||
166 | + [ 10, 10, 100, 100], | ||
167 | + [ 10, 10, 600, 600] | ||
168 | + ]), | ||
169 | + ] | ||
170 | + | ||
171 | + input_labels_group = [ | ||
172 | + np.array([ | ||
173 | + 6, | ||
174 | + 5, | ||
175 | + 4, | ||
176 | + 3, | ||
177 | + 2, | ||
178 | + 1 | ||
179 | + ]), | ||
180 | + np.array([ | ||
181 | + 0 | ||
182 | + ]), | ||
183 | + np.array([ | ||
184 | + 10, | ||
185 | + 11, | ||
186 | + 12 | ||
187 | + ]), | ||
188 | + np.array([ | ||
189 | + 105, | ||
190 | + 107 | ||
191 | + ]), | ||
192 | + ] | ||
193 | + | ||
194 | + input_image = np.zeros((500, 500, 3)) | ||
195 | + | ||
196 | + expected_bboxes_group = [ | ||
197 | + np.array([ | ||
198 | + [ 0, 0, 10, 10], | ||
199 | + [150, 150, 350, 350], | ||
200 | + [ 1, 1, 2, 2] | ||
201 | + ]), | ||
202 | + np.zeros((0, 4)), | ||
203 | + np.array([ | ||
204 | + [10, 10, 100, 100] | ||
205 | + ]), | ||
206 | + np.array([ | ||
207 | + [ 10, 10, 100, 100] | ||
208 | + ]), | ||
209 | + ] | ||
210 | + expected_labels_group = [ | ||
211 | + np.array([ | ||
212 | + 6, | ||
213 | + 4, | ||
214 | + 2 | ||
215 | + ]), | ||
216 | + np.zeros((0,)), | ||
217 | + np.array([ | ||
218 | + 12 | ||
219 | + ]), | ||
220 | + np.array([ | ||
221 | + 105 | ||
222 | + ]), | ||
223 | + ] | ||
224 | + | ||
225 | + simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group) | ||
226 | + # expect a UserWarning | ||
227 | + annotations_group_0 = simple_generator.load_annotations_group(simple_generator.groups[0]) | ||
228 | + with pytest.warns(UserWarning): | ||
229 | + image_group, annotations_group_0 = simple_generator.filter_annotations([input_image], annotations_group_0, simple_generator.groups[0]) | ||
230 | + | ||
231 | + annotations_group_1 = simple_generator.load_annotations_group(simple_generator.groups[1]) | ||
232 | + with pytest.warns(UserWarning): | ||
233 | + image_group, annotations_group_1 = simple_generator.filter_annotations([input_image], annotations_group_1, simple_generator.groups[1]) | ||
234 | + | ||
235 | + annotations_group_2 = simple_generator.load_annotations_group(simple_generator.groups[2]) | ||
236 | + with pytest.warns(UserWarning): | ||
237 | + image_group, annotations_group_2 = simple_generator.filter_annotations([input_image], annotations_group_2, simple_generator.groups[2]) | ||
238 | + | ||
239 | + np.testing.assert_equal(expected_bboxes_group[0], annotations_group_0[0]['bboxes']) | ||
240 | + np.testing.assert_equal(expected_labels_group[0], annotations_group_0[0]['labels']) | ||
241 | + | ||
242 | + np.testing.assert_equal(expected_bboxes_group[1], annotations_group_1[0]['bboxes']) | ||
243 | + np.testing.assert_equal(expected_labels_group[1], annotations_group_1[0]['labels']) | ||
244 | + | ||
245 | + np.testing.assert_equal(expected_bboxes_group[2], annotations_group_2[0]['bboxes']) | ||
246 | + np.testing.assert_equal(expected_labels_group[2], annotations_group_2[0]['labels']) | ||
247 | + | ||
248 | + def test_complete(self): | ||
249 | + input_bboxes_group = [ | ||
250 | + np.array([ | ||
251 | + [ 0, 0, 50, 50], | ||
252 | + [150, 150, 50, 50], # invalid bbox | ||
253 | + ], dtype=float) | ||
254 | + ] | ||
255 | + | ||
256 | + input_labels_group = [ | ||
257 | + np.array([ | ||
258 | + 5, # one object of class 5 | ||
259 | + 3, # one object of class 3 with an invalid box | ||
260 | + ], dtype=float) | ||
261 | + ] | ||
262 | + | ||
263 | + input_image = np.zeros((500, 500, 3), dtype=np.uint8) | ||
264 | + | ||
265 | + simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group, image=input_image, num_classes=6) | ||
266 | + # expect a UserWarning | ||
267 | + with pytest.warns(UserWarning): | ||
268 | + _, [_, labels_batch] = simple_generator[0] | ||
269 | + | ||
270 | + # test that only the object with class 5 is present in labels_batch | ||
271 | + labels = np.unique(np.argmax(labels_batch == 5, axis=2)) | ||
272 | + assert(len(labels) == 1 and labels[0] == 0), 'Expected only class 0 to be present, but got classes {}'.format(labels) |
retinaNet/tests/preprocessing/test_image.py
0 → 100644
1 | +import os | ||
2 | +import pytest | ||
3 | +from PIL import Image | ||
4 | +from keras_retinanet.utils import image | ||
5 | +import numpy as np | ||
6 | + | ||
7 | +_STUB_IMG_FNAME = 'stub-image.jpg' | ||
8 | + | ||
9 | + | ||
10 | +@pytest.fixture(autouse=True) | ||
11 | +def run_around_tests(tmp_path): | ||
12 | + """Create a temp image for test""" | ||
13 | + rand_img = np.random.randint(0, 255, (3, 3, 3), dtype='uint8') | ||
14 | + Image.fromarray(rand_img).save(os.path.join(tmp_path, _STUB_IMG_FNAME)) | ||
15 | + yield | ||
16 | + | ||
17 | + | ||
18 | +def test_read_image_bgr(tmp_path): | ||
19 | + stub_image_path = os.path.join(tmp_path, _STUB_IMG_FNAME) | ||
20 | + | ||
21 | + original_img = np.asarray(Image.open( | ||
22 | + stub_image_path).convert('RGB'))[:, :, ::-1] | ||
23 | + loaded_image = image.read_image_bgr(stub_image_path) | ||
24 | + | ||
25 | + # Assert images are equal | ||
26 | + np.testing.assert_array_equal(original_img, loaded_image) |
retinaNet/tests/requirements.txt
0 → 100644
retinaNet/tests/test_losses.py
0 → 100644
1 | +import keras_retinanet.losses | ||
2 | +from tensorflow import keras | ||
3 | + | ||
4 | +import numpy as np | ||
5 | + | ||
6 | +import pytest | ||
7 | + | ||
8 | + | ||
9 | +def test_smooth_l1(): | ||
10 | + regression = np.array([ | ||
11 | + [ | ||
12 | + [0, 0, 0, 0], | ||
13 | + [0, 0, 0, 0], | ||
14 | + [0, 0, 0, 0], | ||
15 | + [0, 0, 0, 0], | ||
16 | + ] | ||
17 | + ], dtype=keras.backend.floatx()) | ||
18 | + regression = keras.backend.variable(regression) | ||
19 | + | ||
20 | + regression_target = np.array([ | ||
21 | + [ | ||
22 | + [0, 0, 0, 1, 1], | ||
23 | + [0, 0, 1, 0, 1], | ||
24 | + [0, 0, 0.05, 0, 1], | ||
25 | + [0, 0, 1, 0, 0], | ||
26 | + ] | ||
27 | + ], dtype=keras.backend.floatx()) | ||
28 | + regression_target = keras.backend.variable(regression_target) | ||
29 | + | ||
30 | + loss = keras_retinanet.losses.smooth_l1()(regression_target, regression) | ||
31 | + loss = keras.backend.eval(loss) | ||
32 | + | ||
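| + # three anchors have state 1 (last target column): two with |diff| = 1 take the linear branch (1 - 0.5 / sigma^2, sigma^2 = 9), | ||
| + # one with |diff| = 0.05 takes the quadratic branch (0.5 * sigma^2 * 0.05^2); the sum is normalised by the 3 positive anchors | ||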
33 | + assert loss == pytest.approx((((1 - 0.5 / 9) * 2 + (0.5 * 9 * 0.05 ** 2)) / 3)) |
retinaNet/tests/utils/__init__.py
0 → 100644
File mode changed
retinaNet/tests/utils/test_anchors.py
0 → 100644
1 | +import numpy as np | ||
2 | +import configparser | ||
3 | +from tensorflow import keras | ||
4 | + | ||
5 | +from keras_retinanet.utils.anchors import anchors_for_shape, AnchorParameters | ||
6 | +from keras_retinanet.utils.config import read_config_file, parse_anchor_parameters | ||
7 | + | ||
8 | + | ||
9 | +def test_config_read(): | ||
10 | + config = read_config_file('tests/test-data/config/config.ini') | ||
11 | + assert 'anchor_parameters' in config | ||
12 | + assert 'sizes' in config['anchor_parameters'] | ||
13 | + assert 'strides' in config['anchor_parameters'] | ||
14 | + assert 'ratios' in config['anchor_parameters'] | ||
15 | + assert 'scales' in config['anchor_parameters'] | ||
16 | + assert config['anchor_parameters']['sizes'] == '32 64 128 256 512' | ||
17 | + assert config['anchor_parameters']['strides'] == '8 16 32 64 128' | ||
18 | + assert config['anchor_parameters']['ratios'] == '0.5 1 2 3' | ||
19 | + assert config['anchor_parameters']['scales'] == '1 1.2 1.6' | ||
20 | + | ||
21 | + | ||
22 | +def create_anchor_params_config(): | ||
23 | + config = configparser.ConfigParser() | ||
24 | + config['anchor_parameters'] = {} | ||
25 | + config['anchor_parameters']['sizes'] = '32 64 128 256 512' | ||
26 | + config['anchor_parameters']['strides'] = '8 16 32 64 128' | ||
27 | + config['anchor_parameters']['ratios'] = '0.5 1' | ||
28 | + config['anchor_parameters']['scales'] = '1 1.2 1.6' | ||
29 | + | ||
30 | + return config | ||
31 | + | ||
32 | + | ||
33 | +def test_parse_anchor_parameters(): | ||
34 | + config = create_anchor_params_config() | ||
35 | + anchor_params_parsed = parse_anchor_parameters(config) | ||
36 | + | ||
37 | + sizes = [32, 64, 128, 256, 512] | ||
38 | + strides = [8, 16, 32, 64, 128] | ||
39 | + ratios = np.array([0.5, 1], keras.backend.floatx()) | ||
40 | + scales = np.array([1, 1.2, 1.6], keras.backend.floatx()) | ||
41 | + | ||
42 | + assert sizes == anchor_params_parsed.sizes | ||
43 | + assert strides == anchor_params_parsed.strides | ||
44 | + np.testing.assert_equal(ratios, anchor_params_parsed.ratios) | ||
45 | + np.testing.assert_equal(scales, anchor_params_parsed.scales) | ||
46 | + | ||
47 | + | ||
48 | +def test_anchors_for_shape_dimensions(): | ||
49 | + sizes = [32, 64, 128] | ||
50 | + strides = [8, 16, 32] | ||
51 | + ratios = np.array([0.5, 1, 2, 3], keras.backend.floatx()) | ||
52 | + scales = np.array([1, 1.2, 1.6], keras.backend.floatx()) | ||
53 | + anchor_params = AnchorParameters(sizes, strides, ratios, scales) | ||
54 | + | ||
55 | + pyramid_levels = [3, 4, 5] | ||
56 | + image_shape = (64, 64) | ||
57 | + all_anchors = anchors_for_shape(image_shape, pyramid_levels=pyramid_levels, anchor_params=anchor_params) | ||
58 | + | ||
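| + # levels P3-P5 on a 64x64 image give 8*8 + 4*4 + 2*2 = 84 locations, each with 4 ratios * 3 scales = 12 anchors -> 1008 | ||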
59 | + assert all_anchors.shape == (1008, 4) | ||
60 | + | ||
61 | + | ||
62 | +def test_anchors_for_shape_values(): | ||
63 | + sizes = [12] | ||
64 | + strides = [8] | ||
65 | + ratios = np.array([1, 2], keras.backend.floatx()) | ||
66 | + scales = np.array([1, 2], keras.backend.floatx()) | ||
67 | + anchor_params = AnchorParameters(sizes, strides, ratios, scales) | ||
68 | + | ||
69 | + pyramid_levels = [3] | ||
70 | + image_shape = (16, 16) | ||
71 | + all_anchors = anchors_for_shape(image_shape, pyramid_levels=pyramid_levels, anchor_params=anchor_params) | ||
72 | + | ||
73 | + # using almost_equal for floating point imprecisions | ||
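| + # anchors are centred at ((col + 0.5) * stride, (row + 0.5) * stride) with width size * scale / sqrt(ratio) and height size * scale * sqrt(ratio); scales vary fastest, then ratios, then grid position | ||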
74 | + np.testing.assert_almost_equal(all_anchors[0, :], [ | ||
75 | + strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, | ||
76 | + strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, | ||
77 | + strides[0] / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, | ||
78 | + strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, | ||
79 | + ], decimal=6) | ||
80 | + np.testing.assert_almost_equal(all_anchors[1, :], [ | ||
81 | + strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, | ||
82 | + strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, | ||
83 | + strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, | ||
84 | + strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, | ||
85 | + ], decimal=6) | ||
86 | + np.testing.assert_almost_equal(all_anchors[2, :], [ | ||
87 | + strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, | ||
88 | + strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, | ||
89 | + strides[0] / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, | ||
90 | + strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, | ||
91 | + ], decimal=6) | ||
92 | + np.testing.assert_almost_equal(all_anchors[3, :], [ | ||
93 | + strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, | ||
94 | + strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, | ||
95 | + strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, | ||
96 | + strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, | ||
97 | + ], decimal=6) | ||
98 | + np.testing.assert_almost_equal(all_anchors[4, :], [ | ||
99 | + strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, | ||
100 | + strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, | ||
101 | + strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, | ||
102 | + strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, | ||
103 | + ], decimal=6) | ||
104 | + np.testing.assert_almost_equal(all_anchors[5, :], [ | ||
105 | + strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, | ||
106 | + strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, | ||
107 | + strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, | ||
108 | + strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, | ||
109 | + ], decimal=6) | ||
110 | + np.testing.assert_almost_equal(all_anchors[6, :], [ | ||
111 | + strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, | ||
112 | + strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, | ||
113 | + strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, | ||
114 | + strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, | ||
115 | + ], decimal=6) | ||
116 | + np.testing.assert_almost_equal(all_anchors[7, :], [ | ||
117 | + strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, | ||
118 | + strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, | ||
119 | + strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, | ||
120 | + strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, | ||
121 | + ], decimal=6) | ||
122 | + np.testing.assert_almost_equal(all_anchors[8, :], [ | ||
123 | + strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, | ||
124 | + strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, | ||
125 | + strides[0] / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, | ||
126 | + strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, | ||
127 | + ], decimal=6) | ||
128 | + np.testing.assert_almost_equal(all_anchors[9, :], [ | ||
129 | + strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, | ||
130 | + strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, | ||
131 | + strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, | ||
132 | + strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, | ||
133 | + ], decimal=6) | ||
134 | + np.testing.assert_almost_equal(all_anchors[10, :], [ | ||
135 | + strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, | ||
136 | + strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, | ||
137 | + strides[0] / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, | ||
138 | + strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, | ||
139 | + ], decimal=6) | ||
140 | + np.testing.assert_almost_equal(all_anchors[11, :], [ | ||
141 | + strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, | ||
142 | + strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, | ||
143 | + strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, | ||
144 | + strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, | ||
145 | + ], decimal=6) | ||
146 | + np.testing.assert_almost_equal(all_anchors[12, :], [ | ||
147 | + strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, | ||
148 | + strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, | ||
149 | + strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, | ||
150 | + strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, | ||
151 | + ], decimal=6) | ||
152 | + np.testing.assert_almost_equal(all_anchors[13, :], [ | ||
153 | + strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, | ||
154 | + strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, | ||
155 | + strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, | ||
156 | + strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, | ||
157 | + ], decimal=6) | ||
158 | + np.testing.assert_almost_equal(all_anchors[14, :], [ | ||
159 | + strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, | ||
160 | + strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, | ||
161 | + strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, | ||
162 | + strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, | ||
163 | + ], decimal=6) | ||
164 | + np.testing.assert_almost_equal(all_anchors[15, :], [ | ||
165 | + strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, | ||
166 | + strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, | ||
167 | + strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, | ||
168 | + strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, | ||
169 | + ], decimal=6) |
retinaNet/tests/utils/test_transform.py
0 → 100644
1 | +import numpy as np | ||
2 | +from numpy.testing import assert_almost_equal | ||
3 | +from math import pi | ||
4 | + | ||
5 | +from keras_retinanet.utils.transform import ( | ||
6 | + colvec, | ||
7 | + transform_aabb, | ||
8 | + rotation, random_rotation, | ||
9 | + translation, random_translation, | ||
10 | + scaling, random_scaling, | ||
11 | + shear, random_shear, | ||
12 | + random_flip, | ||
13 | + random_transform, | ||
14 | + random_transform_generator, | ||
15 | + change_transform_origin, | ||
16 | +) | ||
17 | + | ||
18 | + | ||
19 | +def test_colvec(): | ||
20 | + assert np.array_equal(colvec(0), np.array([[0]])) | ||
21 | + assert np.array_equal(colvec(1, 2, 3), np.array([[1], [2], [3]])) | ||
22 | + assert np.array_equal(colvec(-1, -2), np.array([[-1], [-2]])) | ||
23 | + | ||
24 | + | ||
25 | +def test_rotation(): | ||
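| + # points are column vectors in homogeneous coordinates (x, y, 1) | ||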
26 | + assert_almost_equal(colvec( 1, 0, 1), rotation(0.0 * pi).dot(colvec(1, 0, 1))) | ||
27 | + assert_almost_equal(colvec( 0, 1, 1), rotation(0.5 * pi).dot(colvec(1, 0, 1))) | ||
28 | + assert_almost_equal(colvec(-1, 0, 1), rotation(1.0 * pi).dot(colvec(1, 0, 1))) | ||
29 | + assert_almost_equal(colvec( 0, -1, 1), rotation(1.5 * pi).dot(colvec(1, 0, 1))) | ||
30 | + assert_almost_equal(colvec( 1, 0, 1), rotation(2.0 * pi).dot(colvec(1, 0, 1))) | ||
31 | + | ||
32 | + assert_almost_equal(colvec( 0, 1, 1), rotation(0.0 * pi).dot(colvec(0, 1, 1))) | ||
33 | + assert_almost_equal(colvec(-1, 0, 1), rotation(0.5 * pi).dot(colvec(0, 1, 1))) | ||
34 | + assert_almost_equal(colvec( 0, -1, 1), rotation(1.0 * pi).dot(colvec(0, 1, 1))) | ||
35 | + assert_almost_equal(colvec( 1, 0, 1), rotation(1.5 * pi).dot(colvec(0, 1, 1))) | ||
36 | + assert_almost_equal(colvec( 0, 1, 1), rotation(2.0 * pi).dot(colvec(0, 1, 1))) | ||
37 | + | ||
38 | + | ||
39 | +def test_random_rotation(): | ||
40 | + prng = np.random.RandomState(0) | ||
41 | + for i in range(100): | ||
42 | + assert_almost_equal(1, np.linalg.det(random_rotation(-i, i, prng))) | ||
43 | + | ||
44 | + | ||
45 | +def test_translation(): | ||
46 | + assert_almost_equal(colvec( 1, 2, 1), translation(colvec( 0, 0)).dot(colvec(1, 2, 1))) | ||
47 | + assert_almost_equal(colvec( 4, 6, 1), translation(colvec( 3, 4)).dot(colvec(1, 2, 1))) | ||
48 | + assert_almost_equal(colvec(-2, -2, 1), translation(colvec(-3, -4)).dot(colvec(1, 2, 1))) | ||
49 | + | ||
50 | + | ||
51 | +def assert_is_translation(transform, min, max): | ||
52 | + assert transform.shape == (3, 3) | ||
53 | + assert np.array_equal(transform[:, 0:2], np.eye(3, 2)) | ||
54 | + assert transform[2, 2] == 1 | ||
55 | + assert np.greater_equal(transform[0:2, 2], min).all() | ||
56 | + assert np.less( transform[0:2, 2], max).all() | ||
57 | + | ||
58 | + | ||
59 | +def test_random_translation(): | ||
60 | + prng = np.random.RandomState(0) | ||
61 | + min = (-10, -20) | ||
62 | + max = (20, 10) | ||
63 | + for i in range(100): | ||
64 | + assert_is_translation(random_translation(min, max, prng), min, max) | ||
65 | + | ||
66 | + | ||
67 | +def test_shear(): | ||
68 | + assert_almost_equal(colvec( 1, 2, 1), shear(0.0 * pi).dot(colvec(1, 2, 1))) | ||
69 | + assert_almost_equal(colvec(-1, 0, 1), shear(0.5 * pi).dot(colvec(1, 2, 1))) | ||
70 | + assert_almost_equal(colvec( 1, -2, 1), shear(1.0 * pi).dot(colvec(1, 2, 1))) | ||
71 | + assert_almost_equal(colvec( 3, 0, 1), shear(1.5 * pi).dot(colvec(1, 2, 1))) | ||
72 | + assert_almost_equal(colvec( 1, 2, 1), shear(2.0 * pi).dot(colvec(1, 2, 1))) | ||
73 | + | ||
74 | + | ||
75 | +def assert_is_shear(transform): | ||
76 | + assert transform.shape == (3, 3) | ||
77 | + assert np.array_equal(transform[:, 0], [1, 0, 0]) | ||
78 | + assert np.array_equal(transform[:, 2], [0, 0, 1]) | ||
79 | + assert transform[2, 1] == 0 | ||
80 | + # sin^2 + cos^2 == 1 | ||
81 | + assert_almost_equal(1, transform[0, 1] ** 2 + transform[1, 1] ** 2) | ||
82 | + | ||
83 | + | ||
84 | +def test_random_shear(): | ||
85 | + prng = np.random.RandomState(0) | ||
86 | + for i in range(100): | ||
87 | + assert_is_shear(random_shear(-pi, pi, prng)) | ||
88 | + | ||
89 | + | ||
90 | +def test_scaling(): | ||
91 | + assert_almost_equal(colvec(1.0, 2, 1), scaling(colvec(1.0, 1.0)).dot(colvec(1, 2, 1))) | ||
92 | + assert_almost_equal(colvec(0.0, 2, 1), scaling(colvec(0.0, 1.0)).dot(colvec(1, 2, 1))) | ||
93 | + assert_almost_equal(colvec(1.0, 0, 1), scaling(colvec(1.0, 0.0)).dot(colvec(1, 2, 1))) | ||
94 | + assert_almost_equal(colvec(0.5, 4, 1), scaling(colvec(0.5, 2.0)).dot(colvec(1, 2, 1))) | ||
95 | + | ||
96 | + | ||
97 | +def assert_is_scaling(transform, min, max): | ||
98 | + assert transform.shape == (3, 3) | ||
99 | + assert np.array_equal(transform[2, :], [0, 0, 1]) | ||
100 | + assert np.array_equal(transform[:, 2], [0, 0, 1]) | ||
101 | + assert transform[1, 0] == 0 | ||
102 | + assert transform[0, 1] == 0 | ||
103 | + assert np.greater_equal(np.diagonal(transform)[:2], min).all() | ||
104 | + assert np.less( np.diagonal(transform)[:2], max).all() | ||
105 | + | ||
106 | + | ||
107 | +def test_random_scaling(): | ||
108 | + prng = np.random.RandomState(0) | ||
109 | + min = (0.1, 0.2) | ||
110 | + max = (20, 10) | ||
111 | + for i in range(100): | ||
112 | + assert_is_scaling(random_scaling(min, max, prng), min, max) | ||
113 | + | ||
114 | + | ||
115 | +def assert_is_flip(transform): | ||
116 | + assert transform.shape == (3, 3) | ||
117 | + assert np.array_equal(transform[2, :], [0, 0, 1]) | ||
118 | + assert np.array_equal(transform[:, 2], [0, 0, 1]) | ||
119 | + assert transform[1, 0] == 0 | ||
120 | + assert transform[0, 1] == 0 | ||
121 | + assert abs(transform[0, 0]) == 1 | ||
122 | + assert abs(transform[1, 1]) == 1 | ||
123 | + | ||
124 | + | ||
125 | +def test_random_flip(): | ||
126 | + prng = np.random.RandomState(0) | ||
127 | + for i in range(100): | ||
128 | + assert_is_flip(random_flip(0.5, 0.5, prng)) | ||
129 | + | ||
130 | + | ||
131 | +def test_random_transform(): | ||
132 | + prng = np.random.RandomState(0) | ||
133 | + for i in range(100): | ||
134 | + transform = random_transform(prng=prng) | ||
135 | + assert np.array_equal(transform, np.identity(3)) | ||
136 | + | ||
137 | + for i, transform in zip(range(100), random_transform_generator(prng=np.random.RandomState())): | ||
138 | + assert np.array_equal(transform, np.identity(3)) | ||
139 | + | ||
140 | + | ||
141 | +def test_transform_aabb(): | ||
142 | + assert np.array_equal([1, 2, 3, 4], transform_aabb(np.identity(3), [1, 2, 3, 4])) | ||
143 | + assert_almost_equal([-3, -4, -1, -2], transform_aabb(rotation(pi), [1, 2, 3, 4])) | ||
144 | + assert_almost_equal([ 2, 4, 4, 6], transform_aabb(translation([1, 2]), [1, 2, 3, 4])) | ||
145 | + | ||
146 | + | ||
147 | +def test_change_transform_origin(): | ||
148 | + assert np.array_equal(change_transform_origin(translation([3, 4]), [1, 2]), translation([3, 4])) | ||
149 | + assert_almost_equal(colvec(1, 2, 1), change_transform_origin(rotation(pi), [1, 2]).dot(colvec(1, 2, 1))) | ||
150 | + assert_almost_equal(colvec(0, 0, 1), change_transform_origin(rotation(pi), [1, 2]).dot(colvec(2, 4, 1))) | ||
151 | + assert_almost_equal(colvec(0, 0, 1), change_transform_origin(scaling([0.5, 0.5]), [-2, -4]).dot(colvec(2, 4, 1))) |
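The last two tests, `test_transform_aabb` and `test_change_transform_origin`, rely on two helpers whose behaviour is easy to misread from the assertions alone. The following is a minimal, self-contained sketch of the semantics those assertions imply, not the library's actual implementation; the `_sketch` suffixed helpers and `translation_3x3` are made-up names for illustration.

```python
import numpy as np

def translation_3x3(t):
    """3x3 homogeneous translation matrix (stand-in for the library's translation())."""
    m = np.eye(3)
    m[0:2, 2] = t
    return m

def change_transform_origin_sketch(transform, center):
    """Apply `transform` around `center` instead of the image origin:
    translate the centre to the origin, transform, translate back."""
    center = np.asarray(center)
    return translation_3x3(center).dot(transform).dot(translation_3x3(-center))

def transform_aabb_sketch(transform, aabb):
    """Transform all four corners of an axis-aligned box, then re-wrap them
    in a new axis-aligned box by taking the min/max over the corners."""
    x1, y1, x2, y2 = aabb
    corners = np.array([[x1, x2, x1, x2],
                        [y1, y2, y2, y1],
                        [1,  1,  1,  1 ]], dtype=float)
    transformed = transform.dot(corners)
    return [transformed[0].min(), transformed[1].min(),
            transformed[0].max(), transformed[1].max()]

# Rotating by pi around (1, 2) leaves that point fixed, as in
# test_change_transform_origin above.
rot_pi = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]], dtype=float)
print(change_transform_origin_sketch(rot_pi, [1, 2]).dot([1, 2, 1]))  # -> [1, 2, 1]
print(transform_aabb_sketch(rot_pi, [1, 2, 3, 4]))                    # -> [-3, -4, -1, -2]
```

Because the box is re-wrapped axis-aligned after the transform, a rotation by `pi` maps `[1, 2, 3, 4]` to `[-3, -4, -1, -2]`, which is exactly the value asserted in `test_transform_aabb`.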