awebow
Committed by Ma Suhyeon

Key points detection 구현

1 apply plugin: 'com.android.application' 1 apply plugin: 'com.android.application'
2 +apply plugin: 'kotlin-android-extensions'
3 +apply plugin: 'kotlin-android'
2 4
3 android { 5 android {
4 compileSdkVersion 29 6 compileSdkVersion 29
...@@ -17,6 +19,9 @@ android { ...@@ -17,6 +19,9 @@ android {
17 proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' 19 proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
18 } 20 }
19 } 21 }
22 + aaptOptions {
23 + noCompress "tflite"
24 + }
20 } 25 }
21 26
22 dependencies { 27 dependencies {
...@@ -26,4 +31,11 @@ dependencies { ...@@ -26,4 +31,11 @@ dependencies {
26 testImplementation 'junit:junit:4.12' 31 testImplementation 'junit:junit:4.12'
27 androidTestImplementation 'androidx.test.ext:junit:1.1.0' 32 androidTestImplementation 'androidx.test.ext:junit:1.1.0'
28 androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1' 33 androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1'
34 + compile "androidx.core:core-ktx:+"
35 + implementation 'org.tensorflow:tensorflow-lite:2.2.0'
36 + implementation 'org.tensorflow:tensorflow-lite-gpu:2.2.0'
37 + implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version"
38 +}
39 +repositories {
40 + mavenCentral()
29 } 41 }
......
...@@ -7,25 +7,39 @@ import androidx.core.app.ActivityCompat; ...@@ -7,25 +7,39 @@ import androidx.core.app.ActivityCompat;
7 import android.Manifest; 7 import android.Manifest;
8 import android.content.Context; 8 import android.content.Context;
9 import android.content.pm.PackageManager; 9 import android.content.pm.PackageManager;
10 +import android.graphics.Bitmap;
11 +import android.graphics.Canvas;
12 +import android.graphics.Color;
13 +import android.graphics.ImageFormat;
14 +import android.graphics.Matrix;
15 +import android.graphics.Paint;
16 +import android.graphics.Rect;
10 import android.hardware.camera2.CameraAccessException; 17 import android.hardware.camera2.CameraAccessException;
11 import android.hardware.camera2.CameraCaptureSession; 18 import android.hardware.camera2.CameraCaptureSession;
12 import android.hardware.camera2.CameraDevice; 19 import android.hardware.camera2.CameraDevice;
13 import android.hardware.camera2.CameraManager; 20 import android.hardware.camera2.CameraManager;
14 import android.hardware.camera2.CaptureRequest; 21 import android.hardware.camera2.CaptureRequest;
22 +import android.media.Image;
23 +import android.media.ImageReader;
15 import android.os.Bundle; 24 import android.os.Bundle;
25 +import android.util.Log;
16 import android.view.SurfaceHolder; 26 import android.view.SurfaceHolder;
17 import android.view.SurfaceView; 27 import android.view.SurfaceView;
18 import android.widget.Toast; 28 import android.widget.Toast;
19 29
30 +import java.nio.ByteBuffer;
20 import java.util.Arrays; 31 import java.util.Arrays;
21 32
22 -public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder.Callback { 33 +public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder.Callback, ImageReader.OnImageAvailableListener {
23 34
24 private static final int REQUEST_CAMERA = 1000; 35 private static final int REQUEST_CAMERA = 1000;
25 36
26 private SurfaceView surfaceView; 37 private SurfaceView surfaceView;
27 private CameraDevice camera; 38 private CameraDevice camera;
28 private CaptureRequest.Builder previewBuilder; 39 private CaptureRequest.Builder previewBuilder;
40 + private Posenet posenet;
41 + private ImageReader imageReader;
42 + private byte[][] yuvBytes = new byte[3][];
29 43
30 @Override 44 @Override
31 protected void onCreate(Bundle savedInstanceState) { 45 protected void onCreate(Bundle savedInstanceState) {
...@@ -36,6 +50,18 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder ...@@ -36,6 +50,18 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder
36 surfaceView.getHolder().addCallback(this); 50 surfaceView.getHolder().addCallback(this);
37 } 51 }
38 52
53 + @Override
54 + protected void onStart() {
55 + super.onStart();
56 + posenet = new Posenet(this, "posenet_model.tflite", Device.GPU);
57 + }
58 +
59 + @Override
60 + protected void onDestroy() {
61 + super.onDestroy();
62 + posenet.close();
63 + }
64 +
39 // 카메라 활성화 65 // 카메라 활성화
40 private void openCamera() { 66 private void openCamera() {
41 if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) { 67 if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
...@@ -99,10 +125,13 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder ...@@ -99,10 +125,13 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder
99 // 카메라 Capture 시작 125 // 카메라 Capture 시작
100 private void startCapture() { 126 private void startCapture() {
101 try { 127 try {
128 + imageReader = ImageReader.newInstance(640, 480, ImageFormat.YUV_420_888, 2);
129 + imageReader.setOnImageAvailableListener(this, null);
130 +
102 previewBuilder = camera.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); 131 previewBuilder = camera.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
103 - previewBuilder.addTarget(surfaceView.getHolder().getSurface()); 132 + previewBuilder.addTarget(imageReader.getSurface());
104 133
105 - camera.createCaptureSession(Arrays.asList(surfaceView.getHolder().getSurface()), new CameraCaptureSession.StateCallback() { 134 + camera.createCaptureSession(Arrays.asList(imageReader.getSurface()), new CameraCaptureSession.StateCallback() {
106 @Override 135 @Override
107 public void onConfigured(@NonNull CameraCaptureSession session) { 136 public void onConfigured(@NonNull CameraCaptureSession session) {
108 try { 137 try {
...@@ -125,4 +154,110 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder ...@@ -125,4 +154,110 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder
125 e.printStackTrace(); 154 e.printStackTrace();
126 } 155 }
127 } 156 }
157 +
158 + @Override
159 + public void onImageAvailable(ImageReader reader) {
160 + Image image = reader.acquireLatestImage();
161 + if(image == null)
162 + return;
163 +
164 + fillBytes(image.getPlanes(), yuvBytes);
165 + int[] rgbBytes = new int[640 * 480];
166 +
167 + ImageUtils.INSTANCE.convertYUV420ToARGB8888(yuvBytes[0], yuvBytes[1], yuvBytes[2],
168 + 640, 480,
169 + image.getPlanes()[0].getRowStride(),
170 + image.getPlanes()[1].getRowStride(),
171 + image.getPlanes()[1].getPixelStride(), rgbBytes);
172 +
173 + Bitmap imageBitmap = Bitmap.createBitmap(rgbBytes, 640, 480, Bitmap.Config.ARGB_8888);
174 +
175 + Matrix rotateMatrix = new Matrix();
176 + rotateMatrix.postRotate(90);
177 +
178 + Bitmap rotatedBitmap = Bitmap.createBitmap(imageBitmap,
179 + 0, 0, 640, 480, rotateMatrix, true);
180 + image.close();
181 +
182 + processImage(rotatedBitmap);
183 + }
184 +
185 + private void fillBytes(Image.Plane[] planes, byte[][] yuvBytes) {
186 + // Row stride is the total number of bytes occupied in memory by a row of an image.
187 + // Because of the variable row stride it's not possible to know in
188 + // advance the actual necessary dimensions of the yuv planes.
189 + for (int i = 0; i < planes.length; i++) {
190 + ByteBuffer buffer = planes[i].getBuffer();
191 + if (yuvBytes[i] == null) {
192 + yuvBytes[i] = new byte[buffer.capacity()];
193 + }
194 + buffer.get(yuvBytes[i]);
195 + }
196 + }
197 +
198 + private Bitmap cropBitmap(Bitmap bitmap) {
199 + float bitmapRatio = (float) bitmap.getHeight() / bitmap.getWidth();
200 + float modelInputRatio = 257.0f / 257.0f;
201 + Bitmap croppedBitmap = bitmap;
202 +
203 + // Acceptable difference between the modelInputRatio and bitmapRatio to skip cropping.
204 + double maxDifference = 1e-5;
205 +
206 + // Checks if the bitmap has similar aspect ratio as the required model input.
207 + if(Math.abs(modelInputRatio - bitmapRatio) < maxDifference)
208 + return croppedBitmap;
209 +
210 + if(modelInputRatio < bitmapRatio) {
211 + // New image is taller so we are height constrained.
212 + float cropHeight = bitmap.getHeight() - bitmap.getWidth() / modelInputRatio;
213 + croppedBitmap = Bitmap.createBitmap(
214 + bitmap,
215 + 0,
216 + (int) cropHeight / 2,
217 + bitmap.getWidth(),
218 + (int) (bitmap.getHeight() - cropHeight)
219 + );
220 + }
221 + else {
222 + float cropWidth = bitmap.getWidth() - bitmap.getHeight() * modelInputRatio;
223 + croppedBitmap = Bitmap.createBitmap(
224 + bitmap,
225 + (int) (cropWidth / 2),
226 + 0,
227 + (int) (bitmap.getWidth() - cropWidth),
228 + bitmap.getHeight()
229 + );
230 + }
231 + return croppedBitmap;
232 + }
233 +
234 + private void processImage(Bitmap bitmap) {
235 + Log.d("Capture", "Process");
236 +
237 + // Crop bitmap.
238 + Bitmap croppedBitmap = cropBitmap(bitmap);
239 +
240 + // Created scaled version of bitmap for model input.
241 + Bitmap scaledBitmap = Bitmap.createScaledBitmap(croppedBitmap, 257, 257, true);
242 +
243 + // Perform inference.
244 + Person person = posenet.estimateSinglePose(scaledBitmap);
245 +
246 + Paint paint = new Paint();
247 + Canvas canvas = surfaceView.getHolder().lockCanvas();
248 +
249 + // 이미지 그리기
250 + canvas.drawBitmap(croppedBitmap, new Rect(0, 0, croppedBitmap.getWidth(), croppedBitmap.getHeight()), new Rect(0, 0, canvas.getWidth(), canvas.getWidth()), paint);
251 +
252 + // Key points 그리기
253 + paint.setColor(Color.RED);
254 + for(KeyPoint keyPoint : person.getKeyPoints()) {
255 + if(keyPoint.getScore() < 0.7)
256 + continue;
257 +
258 + canvas.drawCircle((float) keyPoint.getPosition().getX() / scaledBitmap.getWidth() * canvas.getWidth(), (float) keyPoint.getPosition().getY() / scaledBitmap.getWidth() * canvas.getWidth(), 5, paint);
259 + }
260 +
261 + surfaceView.getHolder().unlockCanvasAndPost(canvas);
262 + }
128 } 263 }
......
1 +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package com.khuhacker.pocketgym
17 +
18 +/** Utility class for manipulating images. */
19 +object ImageUtils {
20 + // This value is 2 ^ 18 - 1, and is used to hold the RGB values together before their ranges
21 + // are normalized to eight bits.
22 + private const val MAX_CHANNEL_VALUE = 262143
23 +
24 + /** Helper function to convert y,u,v integer values to RGB format */
25 + private fun convertYUVToRGB(y: Int, u: Int, v: Int): Int {
26 + // Adjust and check YUV values
27 + val yNew = if (y - 16 < 0) 0 else y - 16
28 + val uNew = u - 128
29 + val vNew = v - 128
30 + val expandY = 1192 * yNew
31 + var r = expandY + 1634 * vNew
32 + var g = expandY - 833 * vNew - 400 * uNew
33 + var b = expandY + 2066 * uNew
34 +
35 + // Clipping RGB values to be inside boundaries [ 0 , MAX_CHANNEL_VALUE ]
36 + val checkBoundaries = { x: Int ->
37 + when {
38 + x > MAX_CHANNEL_VALUE -> MAX_CHANNEL_VALUE
39 + x < 0 -> 0
40 + else -> x
41 + }
42 + }
43 + r = checkBoundaries(r)
44 + g = checkBoundaries(g)
45 + b = checkBoundaries(b)
46 + return -0x1000000 or (r shl 6 and 0xff0000) or (g shr 2 and 0xff00) or (b shr 10 and 0xff)
47 + }
48 +
49 + /** Converts YUV420 format image data (ByteArray) into ARGB8888 format with IntArray as output. */
50 + fun convertYUV420ToARGB8888(
51 + yData: ByteArray,
52 + uData: ByteArray,
53 + vData: ByteArray,
54 + width: Int,
55 + height: Int,
56 + yRowStride: Int,
57 + uvRowStride: Int,
58 + uvPixelStride: Int,
59 + out: IntArray
60 + ) {
61 + var outputIndex = 0
62 + for (j in 0 until height) {
63 + val positionY = yRowStride * j
64 + val positionUV = uvRowStride * (j shr 1)
65 +
66 + for (i in 0 until width) {
67 + val uvOffset = positionUV + (i shr 1) * uvPixelStride
68 +
69 + // "0xff and" is used to cut off bits from following value that are higher than
70 + // the low 8 bits
71 + out[outputIndex] = convertYUVToRGB(
72 + 0xff and yData[positionY + i].toInt(), 0xff and uData[uvOffset].toInt(),
73 + 0xff and vData[uvOffset].toInt()
74 + )
75 + outputIndex += 1
76 + }
77 + }
78 + }
79 +}
1 +/*
2 + * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3 + *
4 + * Licensed under the Apache License, Version 2.0 (the "License");
5 + * you may not use this file except in compliance with the License.
6 + * You may obtain a copy of the License at
7 + *
8 + * http://www.apache.org/licenses/LICENSE-2.0
9 + *
10 + * Unless required by applicable law or agreed to in writing, software
11 + * distributed under the License is distributed on an "AS IS" BASIS,
12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + * See the License for the specific language governing permissions and
14 + * limitations under the License.
15 + */
16 +package com.khuhacker.pocketgym
17 +
18 +import android.content.Context
19 +import android.graphics.Bitmap
20 +import android.os.SystemClock
21 +import android.util.Log
22 +import java.io.FileInputStream
23 +import java.nio.ByteBuffer
24 +import java.nio.ByteOrder
25 +import java.nio.MappedByteBuffer
26 +import java.nio.channels.FileChannel
27 +import kotlin.math.exp
28 +import org.tensorflow.lite.Interpreter
29 +import org.tensorflow.lite.gpu.GpuDelegate
30 +
31 +enum class BodyPart {
32 + NOSE,
33 + LEFT_EYE,
34 + RIGHT_EYE,
35 + LEFT_EAR,
36 + RIGHT_EAR,
37 + LEFT_SHOULDER,
38 + RIGHT_SHOULDER,
39 + LEFT_ELBOW,
40 + RIGHT_ELBOW,
41 + LEFT_WRIST,
42 + RIGHT_WRIST,
43 + LEFT_HIP,
44 + RIGHT_HIP,
45 + LEFT_KNEE,
46 + RIGHT_KNEE,
47 + LEFT_ANKLE,
48 + RIGHT_ANKLE
49 +}
50 +
51 +class Position {
52 + var x: Int = 0
53 + var y: Int = 0
54 +}
55 +
56 +class KeyPoint {
57 + var bodyPart: BodyPart = BodyPart.NOSE
58 + var position: Position = Position()
59 + var score: Float = 0.0f
60 +}
61 +
62 +class Person {
63 + var keyPoints = listOf<KeyPoint>()
64 + var score: Float = 0.0f
65 +}
66 +
67 +enum class Device {
68 + CPU,
69 + NNAPI,
70 + GPU
71 +}
72 +
73 +class Posenet(
74 + val context: Context,
75 + val filename: String = "posenet_model.tflite",
76 + val device: Device = Device.GPU
77 +) : AutoCloseable {
78 + var lastInferenceTimeNanos: Long = -1
79 + private set
80 +
81 + /** An Interpreter for the TFLite model. */
82 + private var interpreter: Interpreter? = null
83 + private var gpuDelegate: GpuDelegate? = null
84 + private val NUM_LITE_THREADS = 4
85 +
86 + private fun getInterpreter(): Interpreter {
87 + if (interpreter != null) {
88 + return interpreter!!
89 + }
90 + val options = Interpreter.Options()
91 + options.setNumThreads(NUM_LITE_THREADS)
92 + when (device) {
93 + Device.CPU -> { }
94 + Device.GPU -> {
95 + gpuDelegate = GpuDelegate()
96 + options.addDelegate(gpuDelegate)
97 + }
98 + Device.NNAPI -> options.setUseNNAPI(true)
99 + }
100 + interpreter = Interpreter(loadModelFile(filename, context), options)
101 + return interpreter!!
102 + }
103 +
104 + override fun close() {
105 + interpreter?.close()
106 + interpreter = null
107 + gpuDelegate?.close()
108 + gpuDelegate = null
109 + }
110 +
111 + /** Returns value within [0,1]. */
112 + private fun sigmoid(x: Float): Float {
113 + return (1.0f / (1.0f + exp(-x)))
114 + }
115 +
116 + /**
117 + * Scale the image to a byteBuffer of [-1,1] values.
118 + */
119 + private fun initInputArray(bitmap: Bitmap): ByteBuffer {
120 + val bytesPerChannel = 4
121 + val inputChannels = 3
122 + val batchSize = 1
123 + val inputBuffer = ByteBuffer.allocateDirect(
124 + batchSize * bytesPerChannel * bitmap.height * bitmap.width * inputChannels
125 + )
126 + inputBuffer.order(ByteOrder.nativeOrder())
127 + inputBuffer.rewind()
128 +
129 + val mean = 128.0f
130 + val std = 128.0f
131 + for (row in 0 until bitmap.height) {
132 + for (col in 0 until bitmap.width) {
133 + val pixelValue = bitmap.getPixel(col, row)
134 + inputBuffer.putFloat(((pixelValue shr 16 and 0xFF) - mean) / std)
135 + inputBuffer.putFloat(((pixelValue shr 8 and 0xFF) - mean) / std)
136 + inputBuffer.putFloat(((pixelValue and 0xFF) - mean) / std)
137 + }
138 + }
139 + return inputBuffer
140 + }
141 +
142 + /** Preload and memory map the model file, returning a MappedByteBuffer containing the model. */
143 + private fun loadModelFile(path: String, context: Context): MappedByteBuffer {
144 + val fileDescriptor = context.assets.openFd(path)
145 + val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
146 + return inputStream.channel.map(
147 + FileChannel.MapMode.READ_ONLY, fileDescriptor.startOffset, fileDescriptor.declaredLength
148 + )
149 + }
150 +
151 + /**
152 + * Initializes an outputMap of 1 * x * y * z FloatArrays for the model processing to populate.
153 + */
154 + private fun initOutputMap(interpreter: Interpreter): HashMap<Int, Any> {
155 + val outputMap = HashMap<Int, Any>()
156 +
157 + // 1 * 9 * 9 * 17 contains heatmaps
158 + val heatmapsShape = interpreter.getOutputTensor(0).shape()
159 + outputMap[0] = Array(heatmapsShape[0]) {
160 + Array(heatmapsShape[1]) {
161 + Array(heatmapsShape[2]) { FloatArray(heatmapsShape[3]) }
162 + }
163 + }
164 +
165 + // 1 * 9 * 9 * 34 contains offsets
166 + val offsetsShape = interpreter.getOutputTensor(1).shape()
167 + outputMap[1] = Array(offsetsShape[0]) {
168 + Array(offsetsShape[1]) { Array(offsetsShape[2]) { FloatArray(offsetsShape[3]) } }
169 + }
170 +
171 + // 1 * 9 * 9 * 32 contains forward displacements
172 + val displacementsFwdShape = interpreter.getOutputTensor(2).shape()
173 + outputMap[2] = Array(offsetsShape[0]) {
174 + Array(displacementsFwdShape[1]) {
175 + Array(displacementsFwdShape[2]) { FloatArray(displacementsFwdShape[3]) }
176 + }
177 + }
178 +
179 + // 1 * 9 * 9 * 32 contains backward displacements
180 + val displacementsBwdShape = interpreter.getOutputTensor(3).shape()
181 + outputMap[3] = Array(displacementsBwdShape[0]) {
182 + Array(displacementsBwdShape[1]) {
183 + Array(displacementsBwdShape[2]) { FloatArray(displacementsBwdShape[3]) }
184 + }
185 + }
186 +
187 + return outputMap
188 + }
189 +
190 + /**
191 + * Estimates the pose for a single person.
192 + * args:
193 + * bitmap: image bitmap of frame that should be processed
194 + * returns:
195 + * person: a Person object containing data about keypoint locations and confidence scores
196 + */
197 + fun estimateSinglePose(bitmap: Bitmap): Person {
198 + val estimationStartTimeNanos = SystemClock.elapsedRealtimeNanos()
199 + val inputArray = arrayOf(initInputArray(bitmap))
200 + Log.i(
201 + "posenet",
202 + String.format(
203 + "Scaling to [-1,1] took %.2f ms",
204 + 1.0f * (SystemClock.elapsedRealtimeNanos() - estimationStartTimeNanos) / 1_000_000
205 + )
206 + )
207 +
208 + val outputMap = initOutputMap(getInterpreter())
209 +
210 + val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
211 + getInterpreter().runForMultipleInputsOutputs(inputArray, outputMap)
212 + lastInferenceTimeNanos = SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
213 + Log.i(
214 + "posenet",
215 + String.format("Interpreter took %.2f ms", 1.0f * lastInferenceTimeNanos / 1_000_000)
216 + )
217 +
218 + val heatmaps = outputMap[0] as Array<Array<Array<FloatArray>>>
219 + val offsets = outputMap[1] as Array<Array<Array<FloatArray>>>
220 +
221 + val height = heatmaps[0].size
222 + val width = heatmaps[0][0].size
223 + val numKeypoints = heatmaps[0][0][0].size
224 +
225 + // Finds the (row, col) locations of where the keypoints are most likely to be.
226 + val keypointPositions = Array(numKeypoints) { Pair(0, 0) }
227 + for (keypoint in 0 until numKeypoints) {
228 + var maxVal = heatmaps[0][0][0][keypoint]
229 + var maxRow = 0
230 + var maxCol = 0
231 + for (row in 0 until height) {
232 + for (col in 0 until width) {
233 + if (heatmaps[0][row][col][keypoint] > maxVal) {
234 + maxVal = heatmaps[0][row][col][keypoint]
235 + maxRow = row
236 + maxCol = col
237 + }
238 + }
239 + }
240 + keypointPositions[keypoint] = Pair(maxRow, maxCol)
241 + }
242 +
243 + // Calculating the x and y coordinates of the keypoints with offset adjustment.
244 + val xCoords = IntArray(numKeypoints)
245 + val yCoords = IntArray(numKeypoints)
246 + val confidenceScores = FloatArray(numKeypoints)
247 + keypointPositions.forEachIndexed { idx, position ->
248 + val positionY = keypointPositions[idx].first
249 + val positionX = keypointPositions[idx].second
250 + yCoords[idx] = (
251 + position.first / (height - 1).toFloat() * bitmap.height +
252 + offsets[0][positionY][positionX][idx]
253 + ).toInt()
254 + xCoords[idx] = (
255 + position.second / (width - 1).toFloat() * bitmap.width +
256 + offsets[0][positionY]
257 + [positionX][idx + numKeypoints]
258 + ).toInt()
259 + confidenceScores[idx] = sigmoid(heatmaps[0][positionY][positionX][idx])
260 + }
261 +
262 + val person = Person()
263 + val keypointList = Array(numKeypoints) { KeyPoint() }
264 + var totalScore = 0.0f
265 + enumValues<BodyPart>().forEachIndexed { idx, it ->
266 + keypointList[idx].bodyPart = it
267 + keypointList[idx].position.x = xCoords[idx]
268 + keypointList[idx].position.y = yCoords[idx]
269 + keypointList[idx].score = confidenceScores[idx]
270 + totalScore += confidenceScores[idx]
271 + }
272 +
273 + person.keyPoints = keypointList.toList()
274 + person.score = totalScore / numKeypoints
275 +
276 + return person
277 + }
278 +}
1 // Top-level build file where you can add configuration options common to all sub-projects/modules. 1 // Top-level build file where you can add configuration options common to all sub-projects/modules.
2 2
3 buildscript { 3 buildscript {
4 + ext.kotlin_version = '1.3.72'
4 repositories { 5 repositories {
5 google() 6 google()
6 jcenter() 7 jcenter()
...@@ -8,6 +9,7 @@ buildscript { ...@@ -8,6 +9,7 @@ buildscript {
8 } 9 }
9 dependencies { 10 dependencies {
10 classpath 'com.android.tools.build:gradle:3.5.1' 11 classpath 'com.android.tools.build:gradle:3.5.1'
12 + classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
11 13
12 // NOTE: Do not place your application dependencies here; they belong 14 // NOTE: Do not place your application dependencies here; they belong
13 // in the individual module build.gradle files 15 // in the individual module build.gradle files
......