awebow
Committed by Ma Suhyeon

Key points detection 구현

apply plugin: 'com.android.application'
apply plugin: 'kotlin-android-extensions'
apply plugin: 'kotlin-android'
android {
compileSdkVersion 29
......@@ -17,6 +19,9 @@ android {
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
aaptOptions {
noCompress "tflite"
}
}
dependencies {
......@@ -26,4 +31,11 @@ dependencies {
testImplementation 'junit:junit:4.12'
androidTestImplementation 'androidx.test.ext:junit:1.1.0'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1'
compile "androidx.core:core-ktx:+"
implementation 'org.tensorflow:tensorflow-lite:2.2.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.2.0'
implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version"
}
repositories {
mavenCentral()
}
......
......@@ -7,25 +7,39 @@ import androidx.core.app.ActivityCompat;
import android.Manifest;
import android.content.Context;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.ImageFormat;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Rect;
import android.hardware.camera2.CameraAccessException;
import android.hardware.camera2.CameraCaptureSession;
import android.hardware.camera2.CameraDevice;
import android.hardware.camera2.CameraManager;
import android.hardware.camera2.CaptureRequest;
import android.media.Image;
import android.media.ImageReader;
import android.os.Bundle;
import android.util.Log;
import android.view.SurfaceHolder;
import android.view.SurfaceView;
import android.widget.Toast;
import java.nio.ByteBuffer;
import java.util.Arrays;
public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder.Callback {
public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder.Callback, ImageReader.OnImageAvailableListener {
private static final int REQUEST_CAMERA = 1000;
private SurfaceView surfaceView;
private CameraDevice camera;
private CaptureRequest.Builder previewBuilder;
private Posenet posenet;
private ImageReader imageReader;
private byte[][] yuvBytes = new byte[3][];
@Override
protected void onCreate(Bundle savedInstanceState) {
......@@ -36,6 +50,18 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder
surfaceView.getHolder().addCallback(this);
}
@Override
protected void onStart() {
super.onStart();
posenet = new Posenet(this, "posenet_model.tflite", Device.GPU);
}
@Override
protected void onDestroy() {
super.onDestroy();
posenet.close();
}
// 카메라 활성화
private void openCamera() {
if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
......@@ -99,10 +125,13 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder
// 카메라 Capture 시작
private void startCapture() {
try {
imageReader = ImageReader.newInstance(640, 480, ImageFormat.YUV_420_888, 2);
imageReader.setOnImageAvailableListener(this, null);
previewBuilder = camera.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
previewBuilder.addTarget(surfaceView.getHolder().getSurface());
previewBuilder.addTarget(imageReader.getSurface());
camera.createCaptureSession(Arrays.asList(surfaceView.getHolder().getSurface()), new CameraCaptureSession.StateCallback() {
camera.createCaptureSession(Arrays.asList(imageReader.getSurface()), new CameraCaptureSession.StateCallback() {
@Override
public void onConfigured(@NonNull CameraCaptureSession session) {
try {
......@@ -125,4 +154,110 @@ public class ExerciseActivity extends AppCompatActivity implements SurfaceHolder
e.printStackTrace();
}
}
@Override
public void onImageAvailable(ImageReader reader) {
Image image = reader.acquireLatestImage();
if(image == null)
return;
fillBytes(image.getPlanes(), yuvBytes);
int[] rgbBytes = new int[640 * 480];
ImageUtils.INSTANCE.convertYUV420ToARGB8888(yuvBytes[0], yuvBytes[1], yuvBytes[2],
640, 480,
image.getPlanes()[0].getRowStride(),
image.getPlanes()[1].getRowStride(),
image.getPlanes()[1].getPixelStride(), rgbBytes);
Bitmap imageBitmap = Bitmap.createBitmap(rgbBytes, 640, 480, Bitmap.Config.ARGB_8888);
Matrix rotateMatrix = new Matrix();
rotateMatrix.postRotate(90);
Bitmap rotatedBitmap = Bitmap.createBitmap(imageBitmap,
0, 0, 640, 480, rotateMatrix, true);
image.close();
processImage(rotatedBitmap);
}
private void fillBytes(Image.Plane[] planes, byte[][] yuvBytes) {
// Row stride is the total number of bytes occupied in memory by a row of an image.
// Because of the variable row stride it's not possible to know in
// advance the actual necessary dimensions of the yuv planes.
for (int i = 0; i < planes.length; i++) {
ByteBuffer buffer = planes[i].getBuffer();
if (yuvBytes[i] == null) {
yuvBytes[i] = new byte[buffer.capacity()];
}
buffer.get(yuvBytes[i]);
}
}
private Bitmap cropBitmap(Bitmap bitmap) {
float bitmapRatio = (float) bitmap.getHeight() / bitmap.getWidth();
float modelInputRatio = 257.0f / 257.0f;
Bitmap croppedBitmap = bitmap;
// Acceptable difference between the modelInputRatio and bitmapRatio to skip cropping.
double maxDifference = 1e-5;
// Checks if the bitmap has similar aspect ratio as the required model input.
if(Math.abs(modelInputRatio - bitmapRatio) < maxDifference)
return croppedBitmap;
if(modelInputRatio < bitmapRatio) {
// New image is taller so we are height constrained.
float cropHeight = bitmap.getHeight() - bitmap.getWidth() / modelInputRatio;
croppedBitmap = Bitmap.createBitmap(
bitmap,
0,
(int) cropHeight / 2,
bitmap.getWidth(),
(int) (bitmap.getHeight() - cropHeight)
);
}
else {
float cropWidth = bitmap.getWidth() - bitmap.getHeight() * modelInputRatio;
croppedBitmap = Bitmap.createBitmap(
bitmap,
(int) (cropWidth / 2),
0,
(int) (bitmap.getWidth() - cropWidth),
bitmap.getHeight()
);
}
return croppedBitmap;
}
private void processImage(Bitmap bitmap) {
Log.d("Capture", "Process");
// Crop bitmap.
Bitmap croppedBitmap = cropBitmap(bitmap);
// Created scaled version of bitmap for model input.
Bitmap scaledBitmap = Bitmap.createScaledBitmap(croppedBitmap, 257, 257, true);
// Perform inference.
Person person = posenet.estimateSinglePose(scaledBitmap);
Paint paint = new Paint();
Canvas canvas = surfaceView.getHolder().lockCanvas();
// 이미지 그리기
canvas.drawBitmap(croppedBitmap, new Rect(0, 0, croppedBitmap.getWidth(), croppedBitmap.getHeight()), new Rect(0, 0, canvas.getWidth(), canvas.getWidth()), paint);
// Key points 그리기
paint.setColor(Color.RED);
for(KeyPoint keyPoint : person.getKeyPoints()) {
if(keyPoint.getScore() < 0.7)
continue;
canvas.drawCircle((float) keyPoint.getPosition().getX() / scaledBitmap.getWidth() * canvas.getWidth(), (float) keyPoint.getPosition().getY() / scaledBitmap.getWidth() * canvas.getWidth(), 5, paint);
}
surfaceView.getHolder().unlockCanvasAndPost(canvas);
}
}
......
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.khuhacker.pocketgym
/** Utility class for manipulating images. */
object ImageUtils {
// This value is 2 ^ 18 - 1, and is used to hold the RGB values together before their ranges
// are normalized to eight bits.
private const val MAX_CHANNEL_VALUE = 262143
/** Helper function to convert y,u,v integer values to RGB format */
private fun convertYUVToRGB(y: Int, u: Int, v: Int): Int {
// Adjust and check YUV values
val yNew = if (y - 16 < 0) 0 else y - 16
val uNew = u - 128
val vNew = v - 128
val expandY = 1192 * yNew
var r = expandY + 1634 * vNew
var g = expandY - 833 * vNew - 400 * uNew
var b = expandY + 2066 * uNew
// Clipping RGB values to be inside boundaries [ 0 , MAX_CHANNEL_VALUE ]
val checkBoundaries = { x: Int ->
when {
x > MAX_CHANNEL_VALUE -> MAX_CHANNEL_VALUE
x < 0 -> 0
else -> x
}
}
r = checkBoundaries(r)
g = checkBoundaries(g)
b = checkBoundaries(b)
return -0x1000000 or (r shl 6 and 0xff0000) or (g shr 2 and 0xff00) or (b shr 10 and 0xff)
}
/** Converts YUV420 format image data (ByteArray) into ARGB8888 format with IntArray as output. */
fun convertYUV420ToARGB8888(
yData: ByteArray,
uData: ByteArray,
vData: ByteArray,
width: Int,
height: Int,
yRowStride: Int,
uvRowStride: Int,
uvPixelStride: Int,
out: IntArray
) {
var outputIndex = 0
for (j in 0 until height) {
val positionY = yRowStride * j
val positionUV = uvRowStride * (j shr 1)
for (i in 0 until width) {
val uvOffset = positionUV + (i shr 1) * uvPixelStride
// "0xff and" is used to cut off bits from following value that are higher than
// the low 8 bits
out[outputIndex] = convertYUVToRGB(
0xff and yData[positionY + i].toInt(), 0xff and uData[uvOffset].toInt(),
0xff and vData[uvOffset].toInt()
)
outputIndex += 1
}
}
}
}
/*
* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.khuhacker.pocketgym
import android.content.Context
import android.graphics.Bitmap
import android.os.SystemClock
import android.util.Log
import java.io.FileInputStream
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
import kotlin.math.exp
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.GpuDelegate
enum class BodyPart {
NOSE,
LEFT_EYE,
RIGHT_EYE,
LEFT_EAR,
RIGHT_EAR,
LEFT_SHOULDER,
RIGHT_SHOULDER,
LEFT_ELBOW,
RIGHT_ELBOW,
LEFT_WRIST,
RIGHT_WRIST,
LEFT_HIP,
RIGHT_HIP,
LEFT_KNEE,
RIGHT_KNEE,
LEFT_ANKLE,
RIGHT_ANKLE
}
class Position {
var x: Int = 0
var y: Int = 0
}
class KeyPoint {
var bodyPart: BodyPart = BodyPart.NOSE
var position: Position = Position()
var score: Float = 0.0f
}
class Person {
var keyPoints = listOf<KeyPoint>()
var score: Float = 0.0f
}
enum class Device {
CPU,
NNAPI,
GPU
}
class Posenet(
val context: Context,
val filename: String = "posenet_model.tflite",
val device: Device = Device.GPU
) : AutoCloseable {
var lastInferenceTimeNanos: Long = -1
private set
/** An Interpreter for the TFLite model. */
private var interpreter: Interpreter? = null
private var gpuDelegate: GpuDelegate? = null
private val NUM_LITE_THREADS = 4
private fun getInterpreter(): Interpreter {
if (interpreter != null) {
return interpreter!!
}
val options = Interpreter.Options()
options.setNumThreads(NUM_LITE_THREADS)
when (device) {
Device.CPU -> { }
Device.GPU -> {
gpuDelegate = GpuDelegate()
options.addDelegate(gpuDelegate)
}
Device.NNAPI -> options.setUseNNAPI(true)
}
interpreter = Interpreter(loadModelFile(filename, context), options)
return interpreter!!
}
override fun close() {
interpreter?.close()
interpreter = null
gpuDelegate?.close()
gpuDelegate = null
}
/** Returns value within [0,1]. */
private fun sigmoid(x: Float): Float {
return (1.0f / (1.0f + exp(-x)))
}
/**
* Scale the image to a byteBuffer of [-1,1] values.
*/
private fun initInputArray(bitmap: Bitmap): ByteBuffer {
val bytesPerChannel = 4
val inputChannels = 3
val batchSize = 1
val inputBuffer = ByteBuffer.allocateDirect(
batchSize * bytesPerChannel * bitmap.height * bitmap.width * inputChannels
)
inputBuffer.order(ByteOrder.nativeOrder())
inputBuffer.rewind()
val mean = 128.0f
val std = 128.0f
for (row in 0 until bitmap.height) {
for (col in 0 until bitmap.width) {
val pixelValue = bitmap.getPixel(col, row)
inputBuffer.putFloat(((pixelValue shr 16 and 0xFF) - mean) / std)
inputBuffer.putFloat(((pixelValue shr 8 and 0xFF) - mean) / std)
inputBuffer.putFloat(((pixelValue and 0xFF) - mean) / std)
}
}
return inputBuffer
}
/** Preload and memory map the model file, returning a MappedByteBuffer containing the model. */
private fun loadModelFile(path: String, context: Context): MappedByteBuffer {
val fileDescriptor = context.assets.openFd(path)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
return inputStream.channel.map(
FileChannel.MapMode.READ_ONLY, fileDescriptor.startOffset, fileDescriptor.declaredLength
)
}
/**
* Initializes an outputMap of 1 * x * y * z FloatArrays for the model processing to populate.
*/
private fun initOutputMap(interpreter: Interpreter): HashMap<Int, Any> {
val outputMap = HashMap<Int, Any>()
// 1 * 9 * 9 * 17 contains heatmaps
val heatmapsShape = interpreter.getOutputTensor(0).shape()
outputMap[0] = Array(heatmapsShape[0]) {
Array(heatmapsShape[1]) {
Array(heatmapsShape[2]) { FloatArray(heatmapsShape[3]) }
}
}
// 1 * 9 * 9 * 34 contains offsets
val offsetsShape = interpreter.getOutputTensor(1).shape()
outputMap[1] = Array(offsetsShape[0]) {
Array(offsetsShape[1]) { Array(offsetsShape[2]) { FloatArray(offsetsShape[3]) } }
}
// 1 * 9 * 9 * 32 contains forward displacements
val displacementsFwdShape = interpreter.getOutputTensor(2).shape()
outputMap[2] = Array(offsetsShape[0]) {
Array(displacementsFwdShape[1]) {
Array(displacementsFwdShape[2]) { FloatArray(displacementsFwdShape[3]) }
}
}
// 1 * 9 * 9 * 32 contains backward displacements
val displacementsBwdShape = interpreter.getOutputTensor(3).shape()
outputMap[3] = Array(displacementsBwdShape[0]) {
Array(displacementsBwdShape[1]) {
Array(displacementsBwdShape[2]) { FloatArray(displacementsBwdShape[3]) }
}
}
return outputMap
}
/**
* Estimates the pose for a single person.
* args:
* bitmap: image bitmap of frame that should be processed
* returns:
* person: a Person object containing data about keypoint locations and confidence scores
*/
fun estimateSinglePose(bitmap: Bitmap): Person {
val estimationStartTimeNanos = SystemClock.elapsedRealtimeNanos()
val inputArray = arrayOf(initInputArray(bitmap))
Log.i(
"posenet",
String.format(
"Scaling to [-1,1] took %.2f ms",
1.0f * (SystemClock.elapsedRealtimeNanos() - estimationStartTimeNanos) / 1_000_000
)
)
val outputMap = initOutputMap(getInterpreter())
val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
getInterpreter().runForMultipleInputsOutputs(inputArray, outputMap)
lastInferenceTimeNanos = SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
Log.i(
"posenet",
String.format("Interpreter took %.2f ms", 1.0f * lastInferenceTimeNanos / 1_000_000)
)
val heatmaps = outputMap[0] as Array<Array<Array<FloatArray>>>
val offsets = outputMap[1] as Array<Array<Array<FloatArray>>>
val height = heatmaps[0].size
val width = heatmaps[0][0].size
val numKeypoints = heatmaps[0][0][0].size
// Finds the (row, col) locations of where the keypoints are most likely to be.
val keypointPositions = Array(numKeypoints) { Pair(0, 0) }
for (keypoint in 0 until numKeypoints) {
var maxVal = heatmaps[0][0][0][keypoint]
var maxRow = 0
var maxCol = 0
for (row in 0 until height) {
for (col in 0 until width) {
if (heatmaps[0][row][col][keypoint] > maxVal) {
maxVal = heatmaps[0][row][col][keypoint]
maxRow = row
maxCol = col
}
}
}
keypointPositions[keypoint] = Pair(maxRow, maxCol)
}
// Calculating the x and y coordinates of the keypoints with offset adjustment.
val xCoords = IntArray(numKeypoints)
val yCoords = IntArray(numKeypoints)
val confidenceScores = FloatArray(numKeypoints)
keypointPositions.forEachIndexed { idx, position ->
val positionY = keypointPositions[idx].first
val positionX = keypointPositions[idx].second
yCoords[idx] = (
position.first / (height - 1).toFloat() * bitmap.height +
offsets[0][positionY][positionX][idx]
).toInt()
xCoords[idx] = (
position.second / (width - 1).toFloat() * bitmap.width +
offsets[0][positionY]
[positionX][idx + numKeypoints]
).toInt()
confidenceScores[idx] = sigmoid(heatmaps[0][positionY][positionX][idx])
}
val person = Person()
val keypointList = Array(numKeypoints) { KeyPoint() }
var totalScore = 0.0f
enumValues<BodyPart>().forEachIndexed { idx, it ->
keypointList[idx].bodyPart = it
keypointList[idx].position.x = xCoords[idx]
keypointList[idx].position.y = yCoords[idx]
keypointList[idx].score = confidenceScores[idx]
totalScore += confidenceScores[idx]
}
person.keyPoints = keypointList.toList()
person.score = totalScore / numKeypoints
return person
}
}
// Top-level build file where you can add configuration options common to all sub-projects/modules.
buildscript {
ext.kotlin_version = '1.3.72'
repositories {
google()
jcenter()
......@@ -8,7 +9,8 @@ buildscript {
}
dependencies {
classpath 'com.android.tools.build:gradle:3.5.1'
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
// NOTE: Do not place your application dependencies here; they belong
// in the individual module build.gradle files
}
......