Showing
16 changed files
with
243 additions
and
190 deletions
... | @@ -13,3 +13,62 @@ | ... | @@ -13,3 +13,62 @@ |
13 | - 2015104213 장수창 | 13 | - 2015104213 장수창 |
14 | 14 | ||
15 | 15 | ||
16 | + | ||
17 | +### 준비 사항 | ||
18 | +tensorflow-android 라이브러리의 최신 버전이 (2020.06.01 기준) 1.13.1입니다. | ||
19 | +따라서 android implementation까지 구현하는 경우에는 | ||
20 | +상위 버전과 호환이 되도록 라이브러리를 빌드하거나, 학습 혹은 pb 파일 생성 또한 tensorflow v1.13.1 이하로 진행하셔야 합니다. | ||
21 | + | ||
22 | +annotation에는 labelImg 툴을 이용하여 xml을 생성하였습니다. | ||
23 | + | ||
24 | +학습에는 TFrecord 형태로 저장된 파일을 사용합니다. | ||
25 | +데이터 하나의 형식은 {data index, image binary, image width, image height, boxes}이며 | ||
26 | +boxes의 형식은 {label1, xmin, ymin, xmax, ymax, label2, xmin, ...}입니다. | ||
27 | +TFRecord 파일 작성은 code/tfrecord_writer.py를 참고하시기 바랍니다. | ||
28 | + | ||
29 | +tfrecord_writer.py에서 입력으로 받는 txt 파일은 | ||
30 | +각 라인마다 {data index, image path, image width, image height, boxes} 형태로 저장되어 있습니다. | ||
31 | +txt 파일 생성은 code/annotation_xml_parser.py를 참고하시기 바랍니다. | ||
32 | + | ||
33 | +이 학습에서는 train/eval/test 데이터셋을 구분하여 사용합니다. | ||
34 | +txt 파일에 대한 데이터셋 분리는 code/dataset_splitter.py를 참고하기시 바랍니다. | ||
35 | + | ||
36 | +annotation_xml_parser.py에서 입력으로 받는 xml 파일은 | ||
37 | +labelImg 툴로 생성된 Pascal VOC format XML 파일을 기준으로 합니다. | ||
38 | + | ||
39 | +학습을 위해서 anchor 파일이 필요합니다. | ||
40 | +anchor 파일 생성에는 code/yolov3/get_kmeans.py를 참고하시기 바랍니다. | ||
41 | +출력된 anchor를 code/yolov3/args.py의 anchor_path에 맞는 위치에 저장하시면 됩니다. | ||
42 | + | ||
43 | +이 학습에서는 pretrained model을 불러와 fine tuning을 이용합니다. | ||
44 | +따라서 pretrained model 파일을 준비해야 합니다. | ||
45 | +pretrained model은 [링크](https://pjreddie.com/media/files/yolov3.weights)에서 다운로드할 수 있습니다. | ||
46 | +이 파일은 darknet weights 파일이므로, tensorflow model로 변환하려면 code/yolov3/convert_weights.py를 참고하시기 바랍니다. | ||
47 | +(git에는 이미 변환된 yolov3.ckpt만이 업로드되어 있습니다. 다른 데이터셋 혹은 다른 용도로 학습을 진행하려면 새로 생성하셔야 합니다.) | ||
48 | + | ||
49 | +학습에는 train.py (train/eval dataset)를, 평가에는 eval.py (test dataset)를 사용하시면 됩니다. | ||
50 | +학습에 사용하는 파일의 경로 및 hyper parameter 설정은 args.py를 참고하시기 바랍니다. | ||
51 | +평가에 대한 경로 설정은 eval.py에서 할 수 있습니다. | ||
52 | + | ||
53 | +data/trained에 임시 테스트용 trained model 파일이 업로드되어 있습니다. | ||
54 | + | ||
55 | + | ||
56 | +android implementation을 하는 경우에는 학습된 모델에 대한 pb 파일을 생성해야 합니다. | ||
57 | +code/pb/pbCreator.py를 참고하시기 바랍니다. (code/yolov3/test_single_image.py를 약간 수정한 파일입니다) | ||
58 | + | ||
59 | +android에서는 freeze된 model만 사용할 수 있습니다. | ||
60 | +code/pb/freeze_pb.py를 참고하시기 바랍니다. | ||
61 | + | ||
62 | +android_App/assets에 pb file을 저장한 후, DetectorActivity.java에서 YOLO_MODEL_FILE의 값을 알맞게 수정하시면 됩니다. | ||
63 | + | ||
64 | +이 학습 코드로 생성된 모델의 input, output node name은 | ||
65 | +각각 input_data, {yolov3/yolov3_head/feature_map_1,yolov3/yolov3_head/feature_map_2,yolov3/yolov3_head/feature_map_3} 입니다. | ||
66 | +모델의 node name 참고에는 Netron 프로그램을 사용하였습니다. | ||
67 | + | ||
68 | + | ||
69 | +#### Reference | ||
70 | +학습 코드는 [링크](https://github.com/wizyoung/YOLOv3_TensorFlow)를 기반으로 작셩하였습니다. | ||
71 | +변경점은 code/yolov3/changes.txt를 참고하시기 바랍니다. | ||
72 | + | ||
73 | +android 코드는 [링크](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android)를 기반으로 작성하였습니다. | ||
74 | + | ... | ... |
No preview for this file type
1 | +<?xml version="1.0" encoding="UTF-8"?> | ||
2 | +<project version="4"> | ||
3 | + <component name="RemoteRepositoriesConfiguration"> | ||
4 | + <remote-repository> | ||
5 | + <option name="id" value="central" /> | ||
6 | + <option name="name" value="Maven Central repository" /> | ||
7 | + <option name="url" value="https://repo1.maven.org/maven2" /> | ||
8 | + </remote-repository> | ||
9 | + <remote-repository> | ||
10 | + <option name="id" value="jboss.community" /> | ||
11 | + <option name="name" value="JBoss Community repository" /> | ||
12 | + <option name="url" value="https://repository.jboss.org/nexus/content/repositories/public/" /> | ||
13 | + </remote-repository> | ||
14 | + <remote-repository> | ||
15 | + <option name="id" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\android\m2repository" /> | ||
16 | + <option name="name" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\android\m2repository" /> | ||
17 | + <option name="url" value="file:/$USER_HOME$/AppData/Local/Android/Sdk/extras/android/m2repository" /> | ||
18 | + </remote-repository> | ||
19 | + <remote-repository> | ||
20 | + <option name="id" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\m2repository" /> | ||
21 | + <option name="name" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\m2repository" /> | ||
22 | + <option name="url" value="file:/$USER_HOME$/AppData/Local/Android/Sdk/extras/m2repository" /> | ||
23 | + </remote-repository> | ||
24 | + <remote-repository> | ||
25 | + <option name="id" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\google\m2repository" /> | ||
26 | + <option name="name" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\google\m2repository" /> | ||
27 | + <option name="url" value="file:/$USER_HOME$/AppData/Local/Android/Sdk/extras/google/m2repository" /> | ||
28 | + </remote-repository> | ||
29 | + <remote-repository> | ||
30 | + <option name="id" value="BintrayJCenter" /> | ||
31 | + <option name="name" value="BintrayJCenter" /> | ||
32 | + <option name="url" value="https://jcenter.bintray.com/" /> | ||
33 | + </remote-repository> | ||
34 | + <remote-repository> | ||
35 | + <option name="id" value="Google" /> | ||
36 | + <option name="name" value="Google" /> | ||
37 | + <option name="url" value="https://dl.google.com/dl/android/maven2/" /> | ||
38 | + </remote-repository> | ||
39 | + </component> | ||
40 | +</project> | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
... | @@ -25,21 +25,10 @@ | ... | @@ -25,21 +25,10 @@ |
25 | <uses-permission android:name="android.permission.RECORD_AUDIO" /> | 25 | <uses-permission android:name="android.permission.RECORD_AUDIO" /> |
26 | 26 | ||
27 | <application android:allowBackup="true" | 27 | <application android:allowBackup="true" |
28 | - android:debuggable="true" | ||
29 | android:label="@string/app_name" | 28 | android:label="@string/app_name" |
30 | android:icon="@drawable/ic_launcher" | 29 | android:icon="@drawable/ic_launcher" |
31 | android:theme="@style/MaterialTheme"> | 30 | android:theme="@style/MaterialTheme"> |
32 | 31 | ||
33 | -<!-- <activity android:name="org.tensorflow.demo.ClassifierActivity"--> | ||
34 | -<!-- android:screenOrientation="portrait"--> | ||
35 | -<!-- android:label="@string/activity_name_classification">--> | ||
36 | -<!-- <intent-filter>--> | ||
37 | -<!-- <action android:name="android.intent.action.MAIN" />--> | ||
38 | -<!-- <category android:name="android.intent.category.LAUNCHER" />--> | ||
39 | -<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />--> | ||
40 | -<!-- </intent-filter>--> | ||
41 | -<!-- </activity>--> | ||
42 | - | ||
43 | <activity android:name="org.tensorflow.demo.DetectorActivity" | 32 | <activity android:name="org.tensorflow.demo.DetectorActivity" |
44 | android:screenOrientation="portrait" | 33 | android:screenOrientation="portrait" |
45 | android:label="@string/activity_name_detection"> | 34 | android:label="@string/activity_name_detection"> |
... | @@ -50,25 +39,38 @@ | ... | @@ -50,25 +39,38 @@ |
50 | </intent-filter> | 39 | </intent-filter> |
51 | </activity> | 40 | </activity> |
52 | 41 | ||
53 | -<!-- <activity android:name="org.tensorflow.demo.StylizeActivity"--> | 42 | + <!-- |
54 | -<!-- android:screenOrientation="portrait"--> | 43 | + <activity android:name="org.tensorflow.demo.ClassifierActivity" |
55 | -<!-- android:label="@string/activity_name_stylize">--> | 44 | + android:screenOrientation="portrait" |
56 | -<!-- <intent-filter>--> | 45 | + android:label="@string/activity_name_classification"> |
57 | -<!-- <action android:name="android.intent.action.MAIN" />--> | 46 | + <intent-filter> |
58 | -<!-- <category android:name="android.intent.category.LAUNCHER" />--> | 47 | + <action android:name="android.intent.action.MAIN" /> |
59 | -<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />--> | 48 | + <category android:name="android.intent.category.LAUNCHER" /> |
60 | -<!-- </intent-filter>--> | 49 | + <category android:name="android.intent.category.LEANBACK_LAUNCHER" /> |
61 | -<!-- </activity>--> | 50 | + </intent-filter> |
51 | + </activity> | ||
52 | + | ||
53 | + <activity android:name="org.tensorflow.demo.StylizeActivity" | ||
54 | + android:screenOrientation="portrait" | ||
55 | + android:label="@string/activity_name_stylize"> | ||
56 | + <intent-filter> | ||
57 | + <action android:name="android.intent.action.MAIN" /> | ||
58 | + <category android:name="android.intent.category.LAUNCHER" /> | ||
59 | + <category android:name="android.intent.category.LEANBACK_LAUNCHER" /> | ||
60 | + </intent-filter> | ||
61 | + </activity> | ||
62 | + | ||
63 | + <activity android:name="org.tensorflow.demo.SpeechActivity" | ||
64 | + android:screenOrientation="portrait" | ||
65 | + android:label="@string/activity_name_speech"> | ||
66 | + <intent-filter> | ||
67 | + <action android:name="android.intent.action.MAIN" /> | ||
68 | + <category android:name="android.intent.category.LAUNCHER" /> | ||
69 | + <category android:name="android.intent.category.LEANBACK_LAUNCHER" /> | ||
70 | + </intent-filter> | ||
71 | + </activity> | ||
72 | + --> | ||
62 | 73 | ||
63 | -<!-- <activity android:name="org.tensorflow.demo.SpeechActivity"--> | ||
64 | -<!-- android:screenOrientation="portrait"--> | ||
65 | -<!-- android:label="@string/activity_name_speech">--> | ||
66 | -<!-- <intent-filter>--> | ||
67 | -<!-- <action android:name="android.intent.action.MAIN" />--> | ||
68 | -<!-- <category android:name="android.intent.category.LAUNCHER" />--> | ||
69 | -<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />--> | ||
70 | -<!-- </intent-filter>--> | ||
71 | -<!-- </activity>--> | ||
72 | </application> | 74 | </application> |
73 | 75 | ||
74 | </manifest> | 76 | </manifest> | ... | ... |
android/android_App/assets/BUILD
deleted
100644 → 0
1 | -package( | ||
2 | - default_visibility = ["//visibility:public"], | ||
3 | - licenses = ["notice"], # Apache 2.0 | ||
4 | -) | ||
5 | - | ||
6 | -# It is necessary to use this filegroup rather than globbing the files in this | ||
7 | -# folder directly the examples/android:tensorflow_demo target due to the fact | ||
8 | -# that assets_dir is necessarily set to "" there (to allow using other | ||
9 | -# arbitrary targets as assets). | ||
10 | -filegroup( | ||
11 | - name = "asset_files", | ||
12 | - srcs = glob( | ||
13 | - ["**/*"], | ||
14 | - exclude = ["BUILD"], | ||
15 | - ), | ||
16 | -) |
android/android_App/assets/yolov3.pb
deleted
100644 → 0
This file is too large to display.
... | @@ -42,7 +42,7 @@ allprojects { | ... | @@ -42,7 +42,7 @@ allprojects { |
42 | } | 42 | } |
43 | 43 | ||
44 | // set to 'bazel', 'cmake', 'makefile', 'none' | 44 | // set to 'bazel', 'cmake', 'makefile', 'none' |
45 | -def nativeBuildSystem = 'none' | 45 | +def nativeBuildSystem = 'cmake' |
46 | 46 | ||
47 | // Controls output directory in APK and CPU type for Bazel builds. | 47 | // Controls output directory in APK and CPU type for Bazel builds. |
48 | // NOTE: Does not affect the Makefile build target API (yet), which currently | 48 | // NOTE: Does not affect the Makefile build target API (yet), which currently | ... | ... |
android/android_App/gradle.properties
0 → 100644
1 | +org.gradle.jvmargs=-Xmx2048m | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
1 | -#Sat Nov 18 15:06:47 CET 2017 | 1 | +#Sat May 30 18:49:07 KST 2020 |
2 | distributionBase=GRADLE_USER_HOME | 2 | distributionBase=GRADLE_USER_HOME |
3 | distributionPath=wrapper/dists | 3 | distributionPath=wrapper/dists |
4 | zipStoreBase=GRADLE_USER_HOME | 4 | zipStoreBase=GRADLE_USER_HOME |
5 | zipStorePath=wrapper/dists | 5 | zipStorePath=wrapper/dists |
6 | -distributionUrl=https\://services.gradle.org/distributions/gradle-4.1-all.zip | 6 | +distributionUrl=https\://services.gradle.org/distributions/gradle-4.10.1-all.zip | ... | ... |
... | @@ -71,11 +71,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable | ... | @@ -71,11 +71,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable |
71 | // Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted e.g. via | 71 | // Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted e.g. via |
72 | // DarkFlow (https://github.com/thtrieu/darkflow). Sample command: | 72 | // DarkFlow (https://github.com/thtrieu/darkflow). Sample command: |
73 | // ./flow --model cfg/tiny-yolo-voc.cfg --load bin/tiny-yolo-voc.weights --savepb --verbalise | 73 | // ./flow --model cfg/tiny-yolo-voc.cfg --load bin/tiny-yolo-voc.weights --savepb --verbalise |
74 | - private static final String YOLO_MODEL_FILE = "file:///android_asset/yolov3.pb"; | 74 | + private static final String YOLO_MODEL_FILE = "file:///android_asset/test_freeze_13.pb"; |
75 | private static final int YOLO_INPUT_SIZE = 416; | 75 | private static final int YOLO_INPUT_SIZE = 416; |
76 | - private static final String YOLO_INPUT_NAME = "input"; | 76 | + private static final String YOLO_INPUT_NAME = "input_data"; |
77 | - private static final String YOLO_OUTPUT_NAMES = "output"; | 77 | + private static final String YOLO_OUTPUT_NAMES = "yolov3/yolov3_head/feature_map_1,yolov3/yolov3_head/feature_map_2,yolov3/yolov3_head/feature_map_3"; |
78 | - private static final int YOLO_BLOCK_SIZE = 32; | 78 | + private static final int YOLO_BLOCK_SIZE = 16; |
79 | 79 | ||
80 | // Which detection model to use: by default uses Tensorflow Object Detection API frozen | 80 | // Which detection model to use: by default uses Tensorflow Object Detection API frozen |
81 | // checkpoints. Optionally use legacy Multibox (trained using an older version of the API) | 81 | // checkpoints. Optionally use legacy Multibox (trained using an older version of the API) |
... | @@ -131,6 +131,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable | ... | @@ -131,6 +131,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable |
131 | int cropSize = TF_OD_API_INPUT_SIZE; | 131 | int cropSize = TF_OD_API_INPUT_SIZE; |
132 | if (MODE == DetectorMode.YOLO) { | 132 | if (MODE == DetectorMode.YOLO) { |
133 | detector = | 133 | detector = |
134 | + | ||
134 | TensorFlowYoloDetector.create( | 135 | TensorFlowYoloDetector.create( |
135 | getAssets(), | 136 | getAssets(), |
136 | YOLO_MODEL_FILE, | 137 | YOLO_MODEL_FILE, | ... | ... |
... | @@ -32,7 +32,7 @@ public class TensorFlowYoloDetector implements Classifier { | ... | @@ -32,7 +32,7 @@ public class TensorFlowYoloDetector implements Classifier { |
32 | private static final Logger LOGGER = new Logger(); | 32 | private static final Logger LOGGER = new Logger(); |
33 | 33 | ||
34 | // Only return this many results with at least this confidence. | 34 | // Only return this many results with at least this confidence. |
35 | - private static final int MAX_RESULTS = 5; | 35 | + private static final int MAX_RESULTS = 10; |
36 | 36 | ||
37 | private static final int NUM_CLASSES = 1; | 37 | private static final int NUM_CLASSES = 1; |
38 | 38 | ||
... | @@ -41,17 +41,14 @@ public class TensorFlowYoloDetector implements Classifier { | ... | @@ -41,17 +41,14 @@ public class TensorFlowYoloDetector implements Classifier { |
41 | // TODO(andrewharp): allow loading anchors and classes | 41 | // TODO(andrewharp): allow loading anchors and classes |
42 | // from files. | 42 | // from files. |
43 | private static final double[] ANCHORS = { | 43 | private static final double[] ANCHORS = { |
44 | - 1.08, 1.19, | 44 | + 35,37, 75,48, 57,87, 116,73, 83,138, 119,110, 154,184, 250,216, 317,362 |
45 | - 3.42, 4.41, | ||
46 | - 6.63, 11.38, | ||
47 | - 9.42, 5.11, | ||
48 | - 16.62, 10.52 | ||
49 | }; | 45 | }; |
50 | 46 | ||
51 | private static final String[] LABELS = { | 47 | private static final String[] LABELS = { |
52 | "dog" | 48 | "dog" |
53 | }; | 49 | }; |
54 | 50 | ||
51 | + | ||
55 | // Config values. | 52 | // Config values. |
56 | private String inputName; | 53 | private String inputName; |
57 | private int inputSize; | 54 | private int inputSize; | ... | ... |
android/freeze_graph.py
deleted
100644 → 0
1 | -# Copyright 2015 Google Inc. All Rights Reserved. | ||
2 | -# | ||
3 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | -# you may not use this file except in compliance with the License. | ||
5 | -# You may obtain a copy of the License at | ||
6 | -# | ||
7 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | -# | ||
9 | -# Unless required by applicable law or agreed to in writing, software | ||
10 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | -# See the License for the specific language governing permissions and | ||
13 | -# limitations under the License. | ||
14 | -# ============================================================================== | ||
15 | -"""Converts checkpoint variables into Const ops in a standalone GraphDef file. | ||
16 | -This script is designed to take a GraphDef proto, a SaverDef proto, and a set of | ||
17 | -variable values stored in a checkpoint file, and output a GraphDef with all of | ||
18 | -the variable ops converted into const ops containing the values of the | ||
19 | -variables. | ||
20 | -It's useful to do this when we need to load a single file in C++, especially in | ||
21 | -environments like mobile or embedded where we may not have access to the | ||
22 | -RestoreTensor ops and file loading calls that they rely on. | ||
23 | -An example of command-line usage is: | ||
24 | -bazel build tensorflow/python/tools:freeze_graph && \ | ||
25 | -bazel-bin/tensorflow/python/tools/freeze_graph \ | ||
26 | ---input_graph=some_graph_def.pb \ | ||
27 | ---input_checkpoint=model.ckpt-8361242 \ | ||
28 | ---output_graph=/tmp/frozen_graph.pb --output_node_names=softmax | ||
29 | -You can also look at freeze_graph_test.py for an example of how to use it. | ||
30 | -""" | ||
31 | -from __future__ import absolute_import | ||
32 | -from __future__ import division | ||
33 | -from __future__ import print_function | ||
34 | - | ||
35 | -import tensorflow as tf | ||
36 | - | ||
37 | -from google.protobuf import text_format | ||
38 | -from tensorflow.python.framework import graph_util | ||
39 | - | ||
40 | - | ||
41 | -FLAGS = tf.app.flags.FLAGS | ||
42 | - | ||
43 | -tf.app.flags.DEFINE_string("input_graph", "", | ||
44 | - """TensorFlow 'GraphDef' file to load.""") | ||
45 | -tf.app.flags.DEFINE_string("input_saver", "", | ||
46 | - """TensorFlow saver file to load.""") | ||
47 | -tf.app.flags.DEFINE_string("input_checkpoint", "", | ||
48 | - """TensorFlow variables file to load.""") | ||
49 | -tf.app.flags.DEFINE_string("output_graph", "", | ||
50 | - """Output 'GraphDef' file name.""") | ||
51 | -tf.app.flags.DEFINE_boolean("input_binary", False, | ||
52 | - """Whether the input files are in binary format.""") | ||
53 | -tf.app.flags.DEFINE_string("output_node_names", "", | ||
54 | - """The name of the output nodes, comma separated.""") | ||
55 | -tf.app.flags.DEFINE_string("restore_op_name", "save/restore_all", | ||
56 | - """The name of the master restore operator.""") | ||
57 | -tf.app.flags.DEFINE_string("filename_tensor_name", "save/Const:0", | ||
58 | - """The name of the tensor holding the save path.""") | ||
59 | -tf.app.flags.DEFINE_boolean("clear_devices", True, | ||
60 | - """Whether to remove device specifications.""") | ||
61 | -tf.app.flags.DEFINE_string("initializer_nodes", "", "comma separated list of " | ||
62 | - "initializer nodes to run before freezing.") | ||
63 | - | ||
64 | - | ||
65 | -def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, | ||
66 | - output_node_names, restore_op_name, filename_tensor_name, | ||
67 | - output_graph, clear_devices, initializer_nodes): | ||
68 | - """Converts all variables in a graph and checkpoint into constants.""" | ||
69 | - | ||
70 | - if not tf.gfile.Exists(input_graph): | ||
71 | - print("Input graph file '" + input_graph + "' does not exist!") | ||
72 | - return -1 | ||
73 | - | ||
74 | - if input_saver and not tf.gfile.Exists(input_saver): | ||
75 | - print("Input saver file '" + input_saver + "' does not exist!") | ||
76 | - return -1 | ||
77 | - | ||
78 | - if not tf.gfile.Glob(input_checkpoint): | ||
79 | - print("Input checkpoint '" + input_checkpoint + "' doesn't exist!") | ||
80 | - return -1 | ||
81 | - | ||
82 | - if not output_node_names: | ||
83 | - print("You need to supply the name of a node to --output_node_names.") | ||
84 | - return -1 | ||
85 | - | ||
86 | - input_graph_def = tf.GraphDef() | ||
87 | - mode = "rb" if input_binary else "r" | ||
88 | - with tf.gfile.FastGFile(input_graph, mode) as f: | ||
89 | - if input_binary: | ||
90 | - input_graph_def.ParseFromString(f.read()) | ||
91 | - else: | ||
92 | - text_format.Merge(f.read(), input_graph_def) | ||
93 | - # Remove all the explicit device specifications for this node. This helps to | ||
94 | - # make the graph more portable. | ||
95 | - if clear_devices: | ||
96 | - for node in input_graph_def.node: | ||
97 | - node.device = "" | ||
98 | - _ = tf.import_graph_def(input_graph_def, name="") | ||
99 | - | ||
100 | - with tf.Session() as sess: | ||
101 | - if input_saver: | ||
102 | - with tf.gfile.FastGFile(input_saver, mode) as f: | ||
103 | - saver_def = tf.train.SaverDef() | ||
104 | - if input_binary: | ||
105 | - saver_def.ParseFromString(f.read()) | ||
106 | - else: | ||
107 | - text_format.Merge(f.read(), saver_def) | ||
108 | - saver = tf.train.Saver(saver_def=saver_def) | ||
109 | - saver.restore(sess, input_checkpoint) | ||
110 | - else: | ||
111 | - sess.run([restore_op_name], {filename_tensor_name: input_checkpoint}) | ||
112 | - if initializer_nodes: | ||
113 | - sess.run(initializer_nodes) | ||
114 | - output_graph_def = graph_util.convert_variables_to_constants( | ||
115 | - sess, input_graph_def, output_node_names.split(",")) | ||
116 | - | ||
117 | - with tf.gfile.GFile(output_graph, "wb") as f: | ||
118 | - f.write(output_graph_def.SerializeToString()) | ||
119 | - print("%d ops in the final graph." % len(output_graph_def.node)) | ||
120 | - | ||
121 | - | ||
122 | -def main(unused_args): | ||
123 | - freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary, | ||
124 | - FLAGS.input_checkpoint, FLAGS.output_node_names, | ||
125 | - FLAGS.restore_op_name, FLAGS.filename_tensor_name, | ||
126 | - FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes) | ||
127 | - | ||
128 | -if __name__ == "__main__": | ||
129 | - tf.app.run() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/pb/freeze_pb.py
0 → 100644
1 | +from tensorflow.python.tools import freeze_graph | ||
2 | + | ||
3 | +ckpt_filepath = '../../data/pb/pb.ckpt' | ||
4 | +pbtxt_filename = 'model.pbtxt' | ||
5 | +pbtxt_filepath = '../../data/pb/model.pbtxt' | ||
6 | +pb_filepath = '../../data/pb/freeze.pb' | ||
7 | + | ||
8 | +freeze_graph.freeze_graph(input_graph=pbtxt_filepath, input_saver='', input_binary=False, input_checkpoint=ckpt_filepath, output_node_names='yolov3/yolov3_head/feature_map_1,yolov3/yolov3_head/feature_map_2,yolov3/yolov3_head/feature_map_3', restore_op_name='save/restore_all', filename_tensor_name='save/Const:0', output_graph=pb_filepath, clear_devices=True, initializer_nodes='') |
code/pb/pbCreator.py
0 → 100644
1 | +from __future__ import division, print_function | ||
2 | + | ||
3 | +import tensorflow as tf | ||
4 | +import numpy as np | ||
5 | +import argparse | ||
6 | +import cv2 | ||
7 | + | ||
8 | +from misc_utils import parse_anchors, read_class_names | ||
9 | +from nms_utils import gpu_nms | ||
10 | +from plot_utils import get_color_table, plot_one_box | ||
11 | +from data_utils import letterbox_resize | ||
12 | + | ||
13 | +from model import yolov3 | ||
14 | + | ||
15 | +parser = argparse.ArgumentParser(description="YOLO-V3 test single image test procedure.") | ||
16 | +parser.add_argument("input_image", type=str, | ||
17 | + help="The path of the input image.") | ||
18 | +parser.add_argument("--anchor_path", type=str, default="../../data/yolo_anchors.txt", | ||
19 | + help="The path of the anchor txt file.") | ||
20 | +parser.add_argument("--new_size", nargs='*', type=int, default=[416, 416], | ||
21 | + help="Resize the input image with `new_size`, size format: [width, height]") | ||
22 | +parser.add_argument("--letterbox_resize", type=lambda x: (str(x).lower() == 'true'), default=True, | ||
23 | + help="Whether to use the letterbox resize.") | ||
24 | +parser.add_argument("--class_name_path", type=str, default="../../data/classes.txt", | ||
25 | + help="The path of the class names.") | ||
26 | +parser.add_argument("--restore_path", type=str, default="../../data/darknet_weights/yolov3.ckpt", | ||
27 | + help="The path of the weights to restore.") | ||
28 | +parser.add_argument("--pb_path", type=str, default="../../data/pb", | ||
29 | + help="The directory of pb files") | ||
30 | +args = parser.parse_args() | ||
31 | + | ||
32 | +args.anchors = parse_anchors(args.anchor_path) | ||
33 | +args.classes = read_class_names(args.class_name_path) | ||
34 | +args.num_class = len(args.classes) | ||
35 | + | ||
36 | +color_table = get_color_table(args.num_class) | ||
37 | + | ||
38 | +img_ori = cv2.imread(args.input_image) | ||
39 | +if args.letterbox_resize: | ||
40 | + img, resize_ratio, dw, dh = letterbox_resize(img_ori, args.new_size[0], args.new_size[1]) | ||
41 | +else: | ||
42 | + height_ori, width_ori = img_ori.shape[:2] | ||
43 | + img = cv2.resize(img_ori, tuple(args.new_size)) | ||
44 | +img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
45 | +img = np.asarray(img, np.float32) | ||
46 | +img = img[np.newaxis, :] / 255. | ||
47 | + | ||
48 | +graph = tf.Graph() | ||
49 | +with tf.Session(graph=graph) as sess: | ||
50 | + input_data = tf.placeholder(tf.float32, [1, args.new_size[1], args.new_size[0], 3], name='input_data') | ||
51 | + yolo_model = yolov3(args.num_class, args.anchors) | ||
52 | + with tf.variable_scope('yolov3'): | ||
53 | + pred_feature_maps = yolo_model.forward(input_data, False) | ||
54 | + pred_boxes, pred_confs, pred_probs = yolo_model.predict(pred_feature_maps) | ||
55 | + | ||
56 | + pred_scores = pred_confs * pred_probs | ||
57 | + | ||
58 | + boxes, scores, labels = gpu_nms(pred_boxes, pred_scores, args.num_class, max_boxes=200, score_thresh=0.3, nms_thresh=0.45) | ||
59 | + | ||
60 | + saver = tf.train.Saver() | ||
61 | + saver.restore(sess, args.restore_path) | ||
62 | + | ||
63 | + boxes_, scores_, labels_ = sess.run([boxes, scores, labels], feed_dict={input_data: img}) | ||
64 | + | ||
65 | + if args.letterbox_resize: | ||
66 | + boxes_[:, [0, 2]] = (boxes_[:, [0, 2]] - dw) / resize_ratio | ||
67 | + boxes_[:, [1, 3]] = (boxes_[:, [1, 3]] - dh) / resize_ratio | ||
68 | + else: | ||
69 | + boxes_[:, [0, 2]] *= (width_ori/float(args.new_size[0])) | ||
70 | + boxes_[:, [1, 3]] *= (height_ori/float(args.new_size[1])) | ||
71 | + | ||
72 | + print("box coords:") | ||
73 | + print(boxes_) | ||
74 | + print('*' * 30) | ||
75 | + print("scores:") | ||
76 | + print(scores_) | ||
77 | + print('*' * 30) | ||
78 | + print("labels:") | ||
79 | + print(labels_) | ||
80 | + | ||
81 | + for i in range(len(boxes_)): | ||
82 | + x0, y0, x1, y1 = boxes_[i] | ||
83 | + plot_one_box(img_ori, [x0, y0, x1, y1], label=args.classes[labels_[i]] + ', {:.2f}%'.format(scores_[i] * 100), color=color_table[labels_[i]]) | ||
84 | + cv2.imshow('Detection result', img_ori) | ||
85 | + cv2.imwrite('detection_result.jpg', img_ori) | ||
86 | + cv2.waitKey(0) | ||
87 | + | ||
88 | + saver.save(sess, args.pb_path+'/pb.ckpt') | ||
89 | + tf.io.write_graph(sess.graph_def, args.pb_path, 'model.pb', as_text=False) | ||
90 | + tf.io.write_graph(sess.graph_def, args.pb_path, 'model.pbtxt', as_text=True) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
... | @@ -15,15 +15,15 @@ from model import yolov3 | ... | @@ -15,15 +15,15 @@ from model import yolov3 |
15 | parser = argparse.ArgumentParser(description="YOLO-V3 test single image test procedure.") | 15 | parser = argparse.ArgumentParser(description="YOLO-V3 test single image test procedure.") |
16 | parser.add_argument("input_image", type=str, | 16 | parser.add_argument("input_image", type=str, |
17 | help="The path of the input image.") | 17 | help="The path of the input image.") |
18 | -parser.add_argument("--anchor_path", type=str, default="./data/yolo_anchors.txt", | 18 | +parser.add_argument("--anchor_path", type=str, default="../../data/yolo_anchors.txt", |
19 | help="The path of the anchor txt file.") | 19 | help="The path of the anchor txt file.") |
20 | parser.add_argument("--new_size", nargs='*', type=int, default=[416, 416], | 20 | parser.add_argument("--new_size", nargs='*', type=int, default=[416, 416], |
21 | help="Resize the input image with `new_size`, size format: [width, height]") | 21 | help="Resize the input image with `new_size`, size format: [width, height]") |
22 | parser.add_argument("--letterbox_resize", type=lambda x: (str(x).lower() == 'true'), default=True, | 22 | parser.add_argument("--letterbox_resize", type=lambda x: (str(x).lower() == 'true'), default=True, |
23 | help="Whether to use the letterbox resize.") | 23 | help="Whether to use the letterbox resize.") |
24 | -parser.add_argument("--class_name_path", type=str, default="./data/coco.names", | 24 | +parser.add_argument("--class_name_path", type=str, default="../../data/classes.txt", |
25 | help="The path of the class names.") | 25 | help="The path of the class names.") |
26 | -parser.add_argument("--restore_path", type=str, default="./data/darknet_weights/yolov3.ckpt", | 26 | +parser.add_argument("--restore_path", type=str, default="../../data/darknet_weights/yolov3.ckpt", |
27 | help="The path of the weights to restore.") | 27 | help="The path of the weights to restore.") |
28 | args = parser.parse_args() | 28 | args = parser.parse_args() |
29 | 29 | ... | ... |
data/trained/yolov3.pb
deleted
100644 → 0
This file is too large to display.
-
Please register or login to post a comment