본 글은 Android - java 기본지식과 tensorflow 기본지식이 있음을 전제로 합니다.
텐서플로우 모델을 안드로이드에서 사용하기 위해서는
0. (모델구현 및 가중치파일 생성)
1. 모델 가중치 변환
2. 안드로이드 설정
3. 안드로이드 텐서플로우 코드 구현
의 과정이 필요합니다.
0번째의 과정은 진행이 되었다 하고 설명을 시작합니다.
<최종 tflite 의 tensorflow 버전은 2.0.0 임을 전제합니다>
1. 텐서플로우 가중치를 tflite 형식으로 변환해야합니다. 변환 방법은 몇가지가 있으며 아래의 그림과 같습니다.
각각의 변환 코드는 https://wikidocs.net/37704 위키글을 참조해주세요. (모델 구현부터 시작하신다면 tensorflow>=2.0.0 으로 구현하시길 추천합니다)
2. 안드로이드의 tensorflow 사용설정은 Moudle의 gradle에 다음처럼 코드를 추가 해주시면 됩니다.
3. 안드로이드 코드 구현은 몇가지가 있습니다 차례대로 설명하면
- 모델 로드 함수 구현(seriealize 된 tflite 파일을 읽어오기 위함)
- 모델 인터프리터 함수 구현(inference를 하기위함)
- 카메라 사용 및 이미지 처리 함수구현(본 글에서는 cats_and_dogs 모델을 기준으로 합니다
cats and dogs?
자세한 설명을 그림으로 첨부합니다.
MainActivity.java
변수선언 & button onClick함수작성, 카메라실행 코드 작성
MainActivity.loadModelFile
MainActivity.getTfliteInterpreter
MainActivity.getInputImage
MainActivity.onActivityResult (override method)
Code Details
MainActivity.java
package com.example.cats_dogs_classifier; import androidx.appcompat.app.AppCompatActivity; import android.app.Activity; import android.content.res.AssetFileDescriptor; import android.graphics.Bitmap; import android.graphics.Matrix; import android.media.ThumbnailUtils; import android.os.Bundle; import android.content.Intent; import android.provider.MediaStore; import android.util.Log; import android.view.View; import android.widget.Button; import android.widget.ImageView; import android.widget.Toast; import org.tensorflow.lite.Interpreter; import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; public class MainActivity extends AppCompatActivity { private Button btnCapture; private ImageView imgCapture; private static final int Image_Capture_Code = 1; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); btnCapture =(Button)findViewById(R.id.btnTakePicture); imgCapture = (ImageView) findViewById(R.id.capturedImage); btnCapture.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { Intent cInt = new Intent(MediaStore.ACTION_IMAGE_CAPTURE); startActivityForResult(cInt,Image_Capture_Code); } }); } @Override protected void onActivityResult(int requestCode, int resultCode, Intent data) { if (requestCode == Image_Capture_Code) { if (resultCode == RESULT_OK) { Bitmap bp = (Bitmap) data.getExtras().get("data"); bp = RotateBitmap(bp, 90); int cx = 128, cy = 128; bp = Bitmap.createScaledBitmap(bp, cx, cy, false); int[] pixels = new int[cx * cy]; bp.getPixels(pixels, 0, cx, 0, 0, cx, cy); ByteBuffer input_img = getInputImage_2(pixels, cx, cy); Interpreter tf_lite = getTfliteInterpreter("cats_and_dogs.tflite"); float[][] pred = new float[1][2]; tf_lite.run(input_img, pred); final String predText = String.format("%f", pred[0][0]) + String.format("%f", pred[0][1]); if(pred[0][0] > pred[0][1]){ Toast toast = Toast.makeText(getApplicationContext(), String.format("고양이 확률 : %f", pred[0][0]), Toast.LENGTH_LONG); toast.show(); }else{ Toast toast = Toast.makeText(getApplicationContext(), String.format("강아지 확률 : %f", pred[0][1]), Toast.LENGTH_LONG); toast.show(); } Log.d("prediction", predText); imgCapture.setImageBitmap(bp); } else if (resultCode == RESULT_CANCELED) { Toast.makeText(this, "Cancelled", Toast.LENGTH_LONG).show(); } } } public static Bitmap RotateBitmap(Bitmap source, float angle) { Matrix matrix = new Matrix(); matrix.postRotate(angle); Bitmap temp = Bitmap.createBitmap(source, 0, 0, source.getWidth(), source.getHeight(), matrix, true); temp = ThumbnailUtils.extractThumbnail(temp, 1080, 1080); return temp; } // 다루기 편한 1차원 배열 사용 private ByteBuffer getInputImage_2(int[] pixels, int cx, int cy) { ByteBuffer input_img = ByteBuffer.allocateDirect(cx * cy * 3 * 4); input_img.order(ByteOrder.nativeOrder()); for (int i = 0; i < cx * cy; i++) { int pixel = pixels[i]; // ARGB : ff4e2a2a input_img.putFloat(((pixel >> 16) & 0xff) / (float) 255); input_img.putFloat(((pixel >> 8) & 0xff) / (float) 255); input_img.putFloat(((pixel >> 0) & 0xff) / (float) 255); } return input_img; } // 모델 파일 인터프리터를 생성하는 공통 함수 // loadModelFile 함수에 예외가 포함되어 있기 때문에 반드시 try, catch 블록이 필요하다. private Interpreter getTfliteInterpreter(String modelPath) { try { return new Interpreter(loadModelFile(MainActivity.this, modelPath)); } catch (Exception e) { e.printStackTrace(); } return null; } // 모델을 읽어오는 함수로, 텐서플로 라이트 홈페이지에 있다. // MappedByteBuffer 바이트 버퍼를 Interpreter 객체에 전달하면 모델 해석을 할 수 있다. private MappedByteBuffer loadModelFile(Activity activity, String modelPath) throws IOException { AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(modelPath); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } }
activity_main.xml
<?xml version="1.0" encoding="utf-8"?> <RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android" android:layout_width="match_parent" android:layout_height="match_parent" android:paddingLeft="10dp" android:orientation="vertical" android:paddingRight="10dp"> <Button android:id="@+id/btnTakePicture" android:layout_width="wrap_content" android:layout_height="wrap_content" android:text="Take a Photo" android:textStyle="bold" android:layout_centerHorizontal="true" android:layout_alignParentBottom="true" /> <ImageView android:layout_width="fill_parent" android:layout_height="fill_parent" android:id="@+id/capturedImage" android:layout_above="@+id/btnTakePicture"/> </RelativeLayout>
댓글
댓글 쓰기