Showing 6 changed files with 509 additions and 3 deletions
app/build.gradle:

 apply plugin: 'com.android.application'
+apply plugin: 'kotlin-android'
+apply plugin: 'kotlin-android-extensions'

 android {
     compileSdkVersion 29
@@ -17,6 +19,9 @@ android {
             proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
         }
     }
+    aaptOptions {
+        noCompress "tflite"  // keep the model uncompressed so it can be memory-mapped
+    }
 }

 dependencies {
@@ -26,4 +31,11 @@ dependencies {
     testImplementation 'junit:junit:4.12'
     androidTestImplementation 'androidx.test.ext:junit:1.1.0'
     androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1'
+    implementation "androidx.core:core-ktx:+"
+    implementation 'org.tensorflow:tensorflow-lite:2.2.0'
+    implementation 'org.tensorflow:tensorflow-lite-gpu:2.2.0'
+    implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version"
+}
+repositories {
+    mavenCentral()
 }
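The aaptOptions block above matters because Posenet.kt (later in this diff) memory-maps the model straight out of the APK via an AssetFileDescriptor, and openFd() fails for assets that AAPT has compressed, so ".tflite" files must be stored uncompressed. A minimal alternative sketch, not part of this commit (the function name is illustrative): read the model into a direct ByteBuffer instead, which works whether or not the asset is compressed, since Interpreter also accepts a direct ByteBuffer.

import android.content.Context
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Illustrative fallback loader: copies the asset's bytes into a direct buffer
// instead of memory-mapping it, so it does not depend on noCompress "tflite".
fun loadModelBytes(context: Context, path: String): ByteBuffer {
    val bytes = context.assets.open(path).use { it.readBytes() }
    return ByteBuffer.allocateDirect(bytes.size)
        .order(ByteOrder.nativeOrder())
        .apply { put(bytes); rewind() }
}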
(The next changed file is too large to display; it is likely the binary TFLite model asset.)
ExerciseActivity.java:

@@ -7,25 +7,39 @@ import androidx.core.app.ActivityCompat;
 import android.Manifest;
 import android.content.Context;
 import android.content.pm.PackageManager;
+import android.graphics.Bitmap;
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.ImageFormat;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.Rect;
 import android.hardware.camera2.CameraAccessException;
 import android.hardware.camera2.CameraCaptureSession;
 import android.hardware.camera2.CameraDevice;
 import android.hardware.camera2.CameraManager;
 import android.hardware.camera2.CaptureRequest;
+import android.media.Image;
+import android.media.ImageReader;
 import android.os.Bundle;
+import android.util.Log;
 import android.view.SurfaceHolder;
 import android.view.SurfaceView;
 import android.widget.Toast;

+import java.nio.ByteBuffer;
 import java.util.Arrays;

-public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder.Callback {
+public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder.Callback, ImageReader.OnImageAvailableListener {

     private static final int REQUEST_CAMERA = 1000;

     private SurfaceView surfaceView;
     private CameraDevice camera;
     private CaptureRequest.Builder previewBuilder;
+    private Posenet posenet;
+    private ImageReader imageReader;
+    private byte[][] yuvBytes = new byte[3][];

     @Override
     protected void onCreate(Bundle savedInstanceState) {
@@ -36,6 +50,18 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder
         surfaceView.getHolder().addCallback(this);
     }

+    @Override
+    protected void onStart() {
+        super.onStart();
+        posenet = new Posenet(this, "posenet_model.tflite", Device.GPU);
+    }
+
+    @Override
+    protected void onDestroy() {
+        super.onDestroy();
+        posenet.close();
+    }
+
     // Activate the camera
     private void openCamera() {
         if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
@@ -99,10 +125,13 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder
     // Start camera capture
     private void startCapture() {
         try {
+            imageReader = ImageReader.newInstance(640, 480, ImageFormat.YUV_420_888, 2);
+            // A null handler delivers frames on the current (UI) thread's looper.
+            imageReader.setOnImageAvailableListener(this, null);
+
             previewBuilder = camera.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
-            previewBuilder.addTarget(surfaceView.getHolder().getSurface());
+            previewBuilder.addTarget(imageReader.getSurface());

-            camera.createCaptureSession(Arrays.asList(surfaceView.getHolder().getSurface()), new CameraCaptureSession.StateCallback() {
+            camera.createCaptureSession(Arrays.asList(imageReader.getSurface()), new CameraCaptureSession.StateCallback() {
                 @Override
                 public void onConfigured(@NonNull CameraCaptureSession session) {
                     try {
@@ -125,4 +154,110 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder
             e.printStackTrace();
         }
     }
+
+    @Override
+    public void onImageAvailable(ImageReader reader) {
+        Image image = reader.acquireLatestImage();
+        if (image == null)
+            return;
+
+        // Copy the Y, U and V planes out of the Image, then convert to ARGB.
+        fillBytes(image.getPlanes(), yuvBytes);
+        int[] rgbBytes = new int[640 * 480];
+
+        ImageUtils.INSTANCE.convertYUV420ToARGB8888(yuvBytes[0], yuvBytes[1], yuvBytes[2],
+                640, 480,
+                image.getPlanes()[0].getRowStride(),
+                image.getPlanes()[1].getRowStride(),
+                image.getPlanes()[1].getPixelStride(), rgbBytes);
+
+        Bitmap imageBitmap = Bitmap.createBitmap(rgbBytes, 640, 480, Bitmap.Config.ARGB_8888);
+
+        // Rotate the landscape sensor frame into portrait orientation.
+        Matrix rotateMatrix = new Matrix();
+        rotateMatrix.postRotate(90);
+
+        Bitmap rotatedBitmap = Bitmap.createBitmap(imageBitmap,
+                0, 0, 640, 480, rotateMatrix, true);
+        image.close();
+
+        processImage(rotatedBitmap);
+    }
+
+    private void fillBytes(Image.Plane[] planes, byte[][] yuvBytes) {
+        // Row stride is the total number of bytes occupied in memory by a row of an image.
+        // Because of the variable row stride, it is not possible to know the actual
+        // necessary dimensions of the YUV planes in advance.
+        for (int i = 0; i < planes.length; i++) {
+            ByteBuffer buffer = planes[i].getBuffer();
+            if (yuvBytes[i] == null) {
+                yuvBytes[i] = new byte[buffer.capacity()];
+            }
+            buffer.get(yuvBytes[i]);
+        }
+    }
+
+    private Bitmap cropBitmap(Bitmap bitmap) {
+        float bitmapRatio = (float) bitmap.getHeight() / bitmap.getWidth();
+        float modelInputRatio = 257.0f / 257.0f;  // The model input is square.
+        Bitmap croppedBitmap = bitmap;
+
+        // Acceptable difference between modelInputRatio and bitmapRatio to skip cropping.
+        double maxDifference = 1e-5;
+
+        // Checks whether the bitmap's aspect ratio is already close enough to the model input's.
+        if (Math.abs(modelInputRatio - bitmapRatio) < maxDifference)
+            return croppedBitmap;
+
+        if (modelInputRatio < bitmapRatio) {
+            // New image is taller, so we are height constrained.
+            float cropHeight = bitmap.getHeight() - bitmap.getWidth() / modelInputRatio;
+            croppedBitmap = Bitmap.createBitmap(
+                    bitmap,
+                    0,
+                    (int) (cropHeight / 2),
+                    bitmap.getWidth(),
+                    (int) (bitmap.getHeight() - cropHeight)
+            );
+        } else {
+            // New image is wider, so we are width constrained.
+            float cropWidth = bitmap.getWidth() - bitmap.getHeight() * modelInputRatio;
+            croppedBitmap = Bitmap.createBitmap(
+                    bitmap,
+                    (int) (cropWidth / 2),
+                    0,
+                    (int) (bitmap.getWidth() - cropWidth),
+                    bitmap.getHeight()
+            );
+        }
+        return croppedBitmap;
+    }
+
+    private void processImage(Bitmap bitmap) {
+        Log.d("Capture", "Process");
+
+        // Crop the bitmap to the model's square aspect ratio.
+        Bitmap croppedBitmap = cropBitmap(bitmap);
+
+        // Create a scaled version of the bitmap for the model input.
+        Bitmap scaledBitmap = Bitmap.createScaledBitmap(croppedBitmap, 257, 257, true);
+
+        // Perform inference.
+        Person person = posenet.estimateSinglePose(scaledBitmap);
+
+        Paint paint = new Paint();
+        Canvas canvas = surfaceView.getHolder().lockCanvas();
+        if (canvas == null)  // The surface may not be ready yet.
+            return;
+
+        // Draw the image. The crop is square, so a square destination preserves the aspect ratio.
+        canvas.drawBitmap(croppedBitmap,
+                new Rect(0, 0, croppedBitmap.getWidth(), croppedBitmap.getHeight()),
+                new Rect(0, 0, canvas.getWidth(), canvas.getWidth()), paint);
+
+        // Draw the key points.
+        paint.setColor(Color.RED);
+        for (KeyPoint keyPoint : person.getKeyPoints()) {
+            if (keyPoint.getScore() < 0.7)
+                continue;
+
+            // Map model-input coordinates (257x257) onto the drawn (square) canvas region.
+            float x = (float) keyPoint.getPosition().getX() / scaledBitmap.getWidth() * canvas.getWidth();
+            float y = (float) keyPoint.getPosition().getY() / scaledBitmap.getHeight() * canvas.getWidth();
+            canvas.drawCircle(x, y, 5, paint);
+        }
+
+        surfaceView.getHolder().unlockCanvasAndPost(canvas);
+    }
 }
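One caveat worth noting: with a null handler, ImageReader delivers frames on the looper of the thread that registered the listener, so the YUV conversion, inference, and drawing above all run on the UI thread. A minimal sketch of how this could be moved to a background thread (an assumption about a possible restructuring, not code from this commit):

import android.os.Handler
import android.os.HandlerThread

// Dedicated thread for camera frames; pass its handler instead of null, e.g.
// imageReader.setOnImageAvailableListener(listener, inferenceHandler)
val inferenceThread = HandlerThread("inference").apply { start() }
val inferenceHandler = Handler(inferenceThread.looper)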
ImageUtils.kt (new file):

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package com.khuhacker.pocketgym

/** Utility class for manipulating images. */
object ImageUtils {
    // This value is 2^18 - 1, and is used to hold the RGB values together before their ranges
    // are normalized to eight bits.
    private const val MAX_CHANNEL_VALUE = 262143

    /** Helper function to convert y, u, v integer values to RGB format. */
    private fun convertYUVToRGB(y: Int, u: Int, v: Int): Int {
        // Adjust and check YUV values.
        val yNew = if (y - 16 < 0) 0 else y - 16
        val uNew = u - 128
        val vNew = v - 128
        val expandY = 1192 * yNew
        var r = expandY + 1634 * vNew
        var g = expandY - 833 * vNew - 400 * uNew
        var b = expandY + 2066 * uNew

        // Clip RGB values to the range [0, MAX_CHANNEL_VALUE].
        val checkBoundaries = { x: Int ->
            when {
                x > MAX_CHANNEL_VALUE -> MAX_CHANNEL_VALUE
                x < 0 -> 0
                else -> x
            }
        }
        r = checkBoundaries(r)
        g = checkBoundaries(g)
        b = checkBoundaries(b)
        return -0x1000000 or (r shl 6 and 0xff0000) or (g shr 2 and 0xff00) or (b shr 10 and 0xff)
    }

    /** Converts YUV420 format image data (ByteArray) into ARGB8888 format with IntArray as output. */
    fun convertYUV420ToARGB8888(
        yData: ByteArray,
        uData: ByteArray,
        vData: ByteArray,
        width: Int,
        height: Int,
        yRowStride: Int,
        uvRowStride: Int,
        uvPixelStride: Int,
        out: IntArray
    ) {
        var outputIndex = 0
        for (j in 0 until height) {
            val positionY = yRowStride * j
            val positionUV = uvRowStride * (j shr 1)

            for (i in 0 until width) {
                val uvOffset = positionUV + (i shr 1) * uvPixelStride

                // "0xff and" masks each signed byte down to its low 8 bits,
                // yielding an unsigned value in [0, 255].
                out[outputIndex] = convertYUVToRGB(
                    0xff and yData[positionY + i].toInt(), 0xff and uData[uvOffset].toInt(),
                    0xff and vData[uvOffset].toInt()
                )
                outputIndex += 1
            }
        }
    }
}
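The integer coefficients in convertYUVToRGB are the BT.601 "studio swing" YUV-to-RGB matrix scaled by 1024: 1.164 × 1024 ≈ 1192, 1.596 × 1024 ≈ 1634, 0.813 × 1024 ≈ 833, 0.391 × 1024 ≈ 400, and 2.018 × 1024 ≈ 2066. Each channel therefore lives in roughly an 18-bit range (255 × 1024 < 2^18), which is why MAX_CHANNEL_VALUE is 2^18 - 1 and the final shifts (shl 6, shr 2, shr 10) drop each clipped 18-bit channel into its 8-bit ARGB slot. A small, illustrative sanity check, not part of the commit: full-scale luma with neutral chroma clips to opaque white.

fun main() {
    // 2x2 luma plane at full scale; one chroma block at the neutral value 128.
    val y = ByteArray(4) { 255.toByte() }
    val uv = ByteArray(2) { 128.toByte() }
    val out = IntArray(4)
    ImageUtils.convertYUV420ToARGB8888(
        y, uv, uv, width = 2, height = 2,
        yRowStride = 2, uvRowStride = 1, uvPixelStride = 1, out = out
    )
    check(out.all { it == 0xFFFFFFFF.toInt() })  // every pixel is opaque white
}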
Posenet.kt (new file):

/*
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.khuhacker.pocketgym

import android.content.Context
import android.graphics.Bitmap
import android.os.SystemClock
import android.util.Log
import java.io.FileInputStream
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
import kotlin.math.exp
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.GpuDelegate

enum class BodyPart {
    NOSE,
    LEFT_EYE,
    RIGHT_EYE,
    LEFT_EAR,
    RIGHT_EAR,
    LEFT_SHOULDER,
    RIGHT_SHOULDER,
    LEFT_ELBOW,
    RIGHT_ELBOW,
    LEFT_WRIST,
    RIGHT_WRIST,
    LEFT_HIP,
    RIGHT_HIP,
    LEFT_KNEE,
    RIGHT_KNEE,
    LEFT_ANKLE,
    RIGHT_ANKLE
}

class Position {
    var x: Int = 0
    var y: Int = 0
}

class KeyPoint {
    var bodyPart: BodyPart = BodyPart.NOSE
    var position: Position = Position()
    var score: Float = 0.0f
}

class Person {
    var keyPoints = listOf<KeyPoint>()
    var score: Float = 0.0f
}

enum class Device {
    CPU,
    NNAPI,
    GPU
}

class Posenet(
    val context: Context,
    val filename: String = "posenet_model.tflite",
    val device: Device = Device.GPU
) : AutoCloseable {
    var lastInferenceTimeNanos: Long = -1
        private set

    /** An Interpreter for the TFLite model. */
    private var interpreter: Interpreter? = null
    private var gpuDelegate: GpuDelegate? = null
    private val NUM_LITE_THREADS = 4

    private fun getInterpreter(): Interpreter {
        if (interpreter != null) {
            return interpreter!!
        }
        val options = Interpreter.Options()
        options.setNumThreads(NUM_LITE_THREADS)
        when (device) {
            Device.CPU -> { }
            Device.GPU -> {
                gpuDelegate = GpuDelegate()
                options.addDelegate(gpuDelegate)
            }
            Device.NNAPI -> options.setUseNNAPI(true)
        }
        interpreter = Interpreter(loadModelFile(filename, context), options)
        return interpreter!!
    }

    override fun close() {
        interpreter?.close()
        interpreter = null
        gpuDelegate?.close()
        gpuDelegate = null
    }

    /** Returns a value within [0,1]. */
    private fun sigmoid(x: Float): Float {
        return (1.0f / (1.0f + exp(-x)))
    }

    /** Scales the image to a ByteBuffer of [-1,1] values. */
    private fun initInputArray(bitmap: Bitmap): ByteBuffer {
        val bytesPerChannel = 4
        val inputChannels = 3
        val batchSize = 1
        val inputBuffer = ByteBuffer.allocateDirect(
            batchSize * bytesPerChannel * bitmap.height * bitmap.width * inputChannels
        )
        inputBuffer.order(ByteOrder.nativeOrder())
        inputBuffer.rewind()

        val mean = 128.0f
        val std = 128.0f
        for (row in 0 until bitmap.height) {
            for (col in 0 until bitmap.width) {
                val pixelValue = bitmap.getPixel(col, row)
                inputBuffer.putFloat(((pixelValue shr 16 and 0xFF) - mean) / std)
                inputBuffer.putFloat(((pixelValue shr 8 and 0xFF) - mean) / std)
                inputBuffer.putFloat(((pixelValue and 0xFF) - mean) / std)
            }
        }
        return inputBuffer
    }

    /** Preloads and memory-maps the model file, returning a MappedByteBuffer containing the model. */
    private fun loadModelFile(path: String, context: Context): MappedByteBuffer {
        val fileDescriptor = context.assets.openFd(path)
        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        return inputStream.channel.map(
            FileChannel.MapMode.READ_ONLY, fileDescriptor.startOffset, fileDescriptor.declaredLength
        )
    }

    /**
     * Initializes an outputMap of 1 * x * y * z FloatArrays for the model processing to populate.
     */
    private fun initOutputMap(interpreter: Interpreter): HashMap<Int, Any> {
        val outputMap = HashMap<Int, Any>()

        // 1 * 9 * 9 * 17 contains heatmaps
        val heatmapsShape = interpreter.getOutputTensor(0).shape()
        outputMap[0] = Array(heatmapsShape[0]) {
            Array(heatmapsShape[1]) {
                Array(heatmapsShape[2]) { FloatArray(heatmapsShape[3]) }
            }
        }

        // 1 * 9 * 9 * 34 contains offsets
        val offsetsShape = interpreter.getOutputTensor(1).shape()
        outputMap[1] = Array(offsetsShape[0]) {
            Array(offsetsShape[1]) { Array(offsetsShape[2]) { FloatArray(offsetsShape[3]) } }
        }

        // 1 * 9 * 9 * 32 contains forward displacements
        val displacementsFwdShape = interpreter.getOutputTensor(2).shape()
        outputMap[2] = Array(displacementsFwdShape[0]) {
            Array(displacementsFwdShape[1]) {
                Array(displacementsFwdShape[2]) { FloatArray(displacementsFwdShape[3]) }
            }
        }

        // 1 * 9 * 9 * 32 contains backward displacements
        val displacementsBwdShape = interpreter.getOutputTensor(3).shape()
        outputMap[3] = Array(displacementsBwdShape[0]) {
            Array(displacementsBwdShape[1]) {
                Array(displacementsBwdShape[2]) { FloatArray(displacementsBwdShape[3]) }
            }
        }

        return outputMap
    }

    /**
     * Estimates the pose for a single person.
     * args:
     *     bitmap: image bitmap of frame that should be processed
     * returns:
     *     person: a Person object containing data about keypoint locations and confidence scores
     */
    fun estimateSinglePose(bitmap: Bitmap): Person {
        val estimationStartTimeNanos = SystemClock.elapsedRealtimeNanos()
        val inputArray = arrayOf(initInputArray(bitmap))
        Log.i(
            "posenet",
            String.format(
                "Scaling to [-1,1] took %.2f ms",
                1.0f * (SystemClock.elapsedRealtimeNanos() - estimationStartTimeNanos) / 1_000_000
            )
        )

        val outputMap = initOutputMap(getInterpreter())

        val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
        getInterpreter().runForMultipleInputsOutputs(inputArray, outputMap)
        lastInferenceTimeNanos = SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
        Log.i(
            "posenet",
            String.format("Interpreter took %.2f ms", 1.0f * lastInferenceTimeNanos / 1_000_000)
        )

        val heatmaps = outputMap[0] as Array<Array<Array<FloatArray>>>
        val offsets = outputMap[1] as Array<Array<Array<FloatArray>>>

        val height = heatmaps[0].size
        val width = heatmaps[0][0].size
        val numKeypoints = heatmaps[0][0][0].size

        // Finds the (row, col) locations where the keypoints are most likely to be.
        val keypointPositions = Array(numKeypoints) { Pair(0, 0) }
        for (keypoint in 0 until numKeypoints) {
            var maxVal = heatmaps[0][0][0][keypoint]
            var maxRow = 0
            var maxCol = 0
            for (row in 0 until height) {
                for (col in 0 until width) {
                    if (heatmaps[0][row][col][keypoint] > maxVal) {
                        maxVal = heatmaps[0][row][col][keypoint]
                        maxRow = row
                        maxCol = col
                    }
                }
            }
            keypointPositions[keypoint] = Pair(maxRow, maxCol)
        }

        // Calculates the x and y coordinates of the keypoints with offset adjustment.
        val xCoords = IntArray(numKeypoints)
        val yCoords = IntArray(numKeypoints)
        val confidenceScores = FloatArray(numKeypoints)
        keypointPositions.forEachIndexed { idx, position ->
            val positionY = position.first
            val positionX = position.second
            yCoords[idx] = (
                positionY / (height - 1).toFloat() * bitmap.height +
                    offsets[0][positionY][positionX][idx]
                ).toInt()
            xCoords[idx] = (
                positionX / (width - 1).toFloat() * bitmap.width +
                    offsets[0][positionY][positionX][idx + numKeypoints]
                ).toInt()
            confidenceScores[idx] = sigmoid(heatmaps[0][positionY][positionX][idx])
        }

        val person = Person()
        val keypointList = Array(numKeypoints) { KeyPoint() }
        var totalScore = 0.0f
        enumValues<BodyPart>().forEachIndexed { idx, bodyPart ->
            keypointList[idx].bodyPart = bodyPart
            keypointList[idx].position.x = xCoords[idx]
            keypointList[idx].position.y = yCoords[idx]
            keypointList[idx].score = confidenceScores[idx]
            totalScore += confidenceScores[idx]
        }

        person.keyPoints = keypointList.toList()
        person.score = totalScore / numKeypoints

        return person
    }
}
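For reference, a minimal usage sketch of this class from Kotlin (illustrative; it mirrors what ExerciseActivity does from Java, and context / scaledBitmap are assumed to be an Activity context and a 257×257 ARGB_8888 bitmap):

val posenet = Posenet(context, "posenet_model.tflite", Device.GPU)
val person = posenet.estimateSinglePose(scaledBitmap)
person.keyPoints
    .filter { it.score > 0.7f }
    .forEach {
        Log.d("posenet", "${it.bodyPart} at (${it.position.x}, ${it.position.y}), score ${it.score}")
    }
posenet.close()  // releases the interpreter and the GPU delegate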
build.gradle (project root):

 // Top-level build file where you can add configuration options common to all sub-projects/modules.

 buildscript {
+    ext.kotlin_version = '1.3.72'
     repositories {
         google()
         jcenter()
@@ -8,6 +9,7 @@ buildscript {
     }
     dependencies {
         classpath 'com.android.tools.build:gradle:3.5.1'
+        classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"

         // NOTE: Do not place your application dependencies here; they belong
         // in the individual module build.gradle files