안드로이드 텐서플로우 구현하기! (카메라 이용)


본 글은 Android - java 기본지식과 tensorflow 기본지식이 있음을 전제로 합니다.




텐서플로우 모델을 안드로이드에서 사용하기 위해서는

0. (모델구현 및 가중치파일 생성)
1.  모델 가중치 변환
2. 안드로이드 설정
3. 안드로이드 텐서플로우 코드 구현
의 과정이 필요합니다.
0번째의 과정은 진행이 되었다 하고 설명을 시작합니다.

<최종 tflite 의 tensorflow 버전은 2.0.0 임을 전제합니다>



1. 텐서플로우 가중치를 tflite 형식으로 변환해야합니다. 변환 방법은 몇가지가 있으며 아래의 그림과 같습니다.

각각의 변환 코드는 https://wikidocs.net/37704 위키글을 참조해주세요. (모델 구현부터 시작하신다면 tensorflow>=2.0.0 으로 구현하시길 추천합니다)



2. 안드로이드의 tensorflow 사용설정은 Moudle의 gradle에 다음처럼 코드를 추가 해주시면 됩니다.



3. 안드로이드 코드 구현은 몇가지가 있습니다 차례대로 설명하면
- 모델 로드 함수 구현(seriealize 된 tflite 파일을 읽어오기 위함)
- 모델 인터프리터 함수 구현(inference를 하기위함)
- 카메라 사용 및 이미지 처리 함수구현(본 글에서는 cats_and_dogs 모델을 기준으로 합니다
  cats and dogs?


자세한 설명을 그림으로 첨부합니다.
MainActivity.java

변수선언  &  button onClick함수작성, 카메라실행 코드 작성


MainActivity.loadModelFile

MainActivity.getTfliteInterpreter

MainActivity.getInputImage

MainActivity.onActivityResult (override method)





Code Details




MainActivity.java
package com.example.cats_dogs_classifier;

import androidx.appcompat.app.AppCompatActivity;

import android.app.Activity;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.media.ThumbnailUtils;
import android.os.Bundle;

import android.content.Intent;
import android.provider.MediaStore;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.Toast;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

public class MainActivity extends AppCompatActivity {
    private Button btnCapture;
    private ImageView imgCapture;
    private static final int Image_Capture_Code = 1;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        btnCapture =(Button)findViewById(R.id.btnTakePicture);
        imgCapture = (ImageView) findViewById(R.id.capturedImage);
        btnCapture.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                Intent cInt = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
                startActivityForResult(cInt,Image_Capture_Code);
            }
        });
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        if (requestCode == Image_Capture_Code) {
            if (resultCode == RESULT_OK) {
                Bitmap bp = (Bitmap) data.getExtras().get("data");
                bp = RotateBitmap(bp, 90);
                int cx = 128, cy = 128;
                bp = Bitmap.createScaledBitmap(bp, cx, cy, false);
                int[] pixels = new int[cx * cy];
                bp.getPixels(pixels, 0, cx, 0, 0, cx, cy);

                ByteBuffer input_img = getInputImage_2(pixels, cx, cy);
                Interpreter tf_lite = getTfliteInterpreter("cats_and_dogs.tflite");

                float[][] pred = new float[1][2];
                tf_lite.run(input_img, pred);

                final String predText = String.format("%f", pred[0][0]) + String.format("%f", pred[0][1]);
                if(pred[0][0] > pred[0][1]){
                    Toast toast = Toast.makeText(getApplicationContext(), String.format("고양이 확률 : %f", pred[0][0]), Toast.LENGTH_LONG); toast.show();
                }else{
                    Toast toast = Toast.makeText(getApplicationContext(), String.format("강아지 확률 : %f", pred[0][1]), Toast.LENGTH_LONG); toast.show();
                }



                Log.d("prediction", predText);


                imgCapture.setImageBitmap(bp);



            } else if (resultCode == RESULT_CANCELED) {
                Toast.makeText(this, "Cancelled", Toast.LENGTH_LONG).show();
            }
        }
    }

    public static Bitmap RotateBitmap(Bitmap source, float angle)
    {
        Matrix matrix = new Matrix();
        matrix.postRotate(angle);
        Bitmap temp = Bitmap.createBitmap(source, 0, 0, source.getWidth(), source.getHeight(), matrix, true);
        temp = ThumbnailUtils.extractThumbnail(temp, 1080, 1080);

        return temp;
    }
    // 다루기 편한 1차원 배열 사용
    private ByteBuffer getInputImage_2(int[] pixels, int cx, int cy) {
        ByteBuffer input_img = ByteBuffer.allocateDirect(cx * cy * 3 * 4);
        input_img.order(ByteOrder.nativeOrder());

        for (int i = 0; i < cx * cy; i++) {
            int pixel = pixels[i];        // ARGB : ff4e2a2a

            input_img.putFloat(((pixel >> 16) & 0xff) / (float) 255);
            input_img.putFloat(((pixel >> 8) & 0xff) / (float) 255);
            input_img.putFloat(((pixel >> 0) & 0xff) / (float) 255);
        }

        return input_img;
    }


    // 모델 파일 인터프리터를 생성하는 공통 함수
    // loadModelFile 함수에 예외가 포함되어 있기 때문에 반드시 try, catch 블록이 필요하다.
    private Interpreter getTfliteInterpreter(String modelPath) {
        try {
            return new Interpreter(loadModelFile(MainActivity.this, modelPath));
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    // 모델을 읽어오는 함수로, 텐서플로 라이트 홈페이지에 있다.
    // MappedByteBuffer 바이트 버퍼를 Interpreter 객체에 전달하면 모델 해석을 할 수 있다.
    private MappedByteBuffer loadModelFile(Activity activity, String modelPath) throws IOException {
        AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(modelPath);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }


}








activity_main.xml

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:paddingLeft="10dp"
    android:orientation="vertical"
    android:paddingRight="10dp">
    <Button
        android:id="@+id/btnTakePicture"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Take a Photo"
        android:textStyle="bold"
        android:layout_centerHorizontal="true"
        android:layout_alignParentBottom="true" />
    <ImageView
        android:layout_width="fill_parent"
        android:layout_height="fill_parent"
        android:id="@+id/capturedImage"
        android:layout_above="@+id/btnTakePicture"/>
</RelativeLayout>

댓글