본 글은 Android - java 기본지식과 tensorflow 기본지식이 있음을 전제로 합니다.
텐서플로우 모델을 안드로이드에서 사용하기 위해서는
0. (모델구현 및 가중치파일 생성)
1. 모델 가중치 변환
2. 안드로이드 설정
3. 안드로이드 텐서플로우 코드 구현
의 과정이 필요합니다.
0번째의 과정은 진행이 되었다 하고 설명을 시작합니다.
<최종 tflite 의 tensorflow 버전은 2.0.0 임을 전제합니다>
1. 텐서플로우 가중치를 tflite 형식으로 변환해야합니다. 변환 방법은 몇가지가 있으며 아래의 그림과 같습니다.
각각의 변환 코드는 https://wikidocs.net/37704 위키글을 참조해주세요. (모델 구현부터 시작하신다면 tensorflow>=2.0.0 으로 구현하시길 추천합니다)
2. 안드로이드의 tensorflow 사용설정은 Moudle의 gradle에 다음처럼 코드를 추가 해주시면 됩니다.
3. 안드로이드 코드 구현은 몇가지가 있습니다 차례대로 설명하면
- 모델 로드 함수 구현(seriealize 된 tflite 파일을 읽어오기 위함)
- 모델 인터프리터 함수 구현(inference를 하기위함)
- 카메라 사용 및 이미지 처리 함수구현(본 글에서는 cats_and_dogs 모델을 기준으로 합니다
cats and dogs?
자세한 설명을 그림으로 첨부합니다.
MainActivity.java
변수선언 & button onClick함수작성, 카메라실행 코드 작성

MainActivity.loadModelFile
MainActivity.getTfliteInterpreter
MainActivity.getInputImage
MainActivity.onActivityResult (override method)
Code Details
MainActivity.java
package com.example.cats_dogs_classifier;
import androidx.appcompat.app.AppCompatActivity;
import android.app.Activity;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.media.ThumbnailUtils;
import android.os.Bundle;
import android.content.Intent;
import android.provider.MediaStore;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.Toast;
import org.tensorflow.lite.Interpreter;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
public class MainActivity extends AppCompatActivity {
private Button btnCapture;
private ImageView imgCapture;
private static final int Image_Capture_Code = 1;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
btnCapture =(Button)findViewById(R.id.btnTakePicture);
imgCapture = (ImageView) findViewById(R.id.capturedImage);
btnCapture.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
Intent cInt = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
startActivityForResult(cInt,Image_Capture_Code);
}
});
}
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
if (requestCode == Image_Capture_Code) {
if (resultCode == RESULT_OK) {
Bitmap bp = (Bitmap) data.getExtras().get("data");
bp = RotateBitmap(bp, 90);
int cx = 128, cy = 128;
bp = Bitmap.createScaledBitmap(bp, cx, cy, false);
int[] pixels = new int[cx * cy];
bp.getPixels(pixels, 0, cx, 0, 0, cx, cy);
ByteBuffer input_img = getInputImage_2(pixels, cx, cy);
Interpreter tf_lite = getTfliteInterpreter("cats_and_dogs.tflite");
float[][] pred = new float[1][2];
tf_lite.run(input_img, pred);
final String predText = String.format("%f", pred[0][0]) + String.format("%f", pred[0][1]);
if(pred[0][0] > pred[0][1]){
Toast toast = Toast.makeText(getApplicationContext(), String.format("고양이 확률 : %f", pred[0][0]), Toast.LENGTH_LONG); toast.show();
}else{
Toast toast = Toast.makeText(getApplicationContext(), String.format("강아지 확률 : %f", pred[0][1]), Toast.LENGTH_LONG); toast.show();
}
Log.d("prediction", predText);
imgCapture.setImageBitmap(bp);
} else if (resultCode == RESULT_CANCELED) {
Toast.makeText(this, "Cancelled", Toast.LENGTH_LONG).show();
}
}
}
public static Bitmap RotateBitmap(Bitmap source, float angle)
{
Matrix matrix = new Matrix();
matrix.postRotate(angle);
Bitmap temp = Bitmap.createBitmap(source, 0, 0, source.getWidth(), source.getHeight(), matrix, true);
temp = ThumbnailUtils.extractThumbnail(temp, 1080, 1080);
return temp;
}
// 다루기 편한 1차원 배열 사용
private ByteBuffer getInputImage_2(int[] pixels, int cx, int cy) {
ByteBuffer input_img = ByteBuffer.allocateDirect(cx * cy * 3 * 4);
input_img.order(ByteOrder.nativeOrder());
for (int i = 0; i < cx * cy; i++) {
int pixel = pixels[i]; // ARGB : ff4e2a2a
input_img.putFloat(((pixel >> 16) & 0xff) / (float) 255);
input_img.putFloat(((pixel >> 8) & 0xff) / (float) 255);
input_img.putFloat(((pixel >> 0) & 0xff) / (float) 255);
}
return input_img;
}
// 모델 파일 인터프리터를 생성하는 공통 함수
// loadModelFile 함수에 예외가 포함되어 있기 때문에 반드시 try, catch 블록이 필요하다.
private Interpreter getTfliteInterpreter(String modelPath) {
try {
return new Interpreter(loadModelFile(MainActivity.this, modelPath));
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
// 모델을 읽어오는 함수로, 텐서플로 라이트 홈페이지에 있다.
// MappedByteBuffer 바이트 버퍼를 Interpreter 객체에 전달하면 모델 해석을 할 수 있다.
private MappedByteBuffer loadModelFile(Activity activity, String modelPath) throws IOException {
AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(modelPath);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
}
activity_main.xml
<?xml version="1.0" encoding="utf-8"?> <RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android" android:layout_width="match_parent" android:layout_height="match_parent" android:paddingLeft="10dp" android:orientation="vertical" android:paddingRight="10dp"> <Button android:id="@+id/btnTakePicture" android:layout_width="wrap_content" android:layout_height="wrap_content" android:text="Take a Photo" android:textStyle="bold" android:layout_centerHorizontal="true" android:layout_alignParentBottom="true" /> <ImageView android:layout_width="fill_parent" android:layout_height="fill_parent" android:id="@+id/capturedImage" android:layout_above="@+id/btnTakePicture"/> </RelativeLayout>







댓글
댓글 쓰기