在Android上部署TF目标检测模型

在移动设备上部署机器学习模型是ML即将开始的新阶段。

目标检测模型,已经与语音识别、图像分类等模型一起应用于移动设备。这些模型通常运行在支持GPU的计算机上,部署在移动设备上时也有大量用例。

为了演示如何将ML模型,特别是对象检测模型引入Android的端到端示例,我们将使用Victor Dibia的手检测模型进行演示,该模型来自victordibia/handtracking repo。

https://github.com/victordibia/handtracking

该模型可以从图像中检测人手,并使用TensorFlow对象检测API制作。我们将使用Victor Dibia的repo中经过训练的模型,并将其转换为TensorFlow Lite(TFLite)格式,该格式可用于在Android(甚至iOS、Raspberry Pi)上运行该模型。

接下来,我们将转到Android应用程序,创建运行模型所需的所有必要类/方法,并在实时摄像机上显示其预测(边界框)。

让我们开始吧!

目录

  • 将模型转换为TFLite

    • 1.设置TF目标检测API

    • 2.将检查点转换为图

    • 3.将图转换为TFLite缓冲区

  • 在Android中集成TFLite模型

    • 1.添加CameraX、Coroutines 和TF Lite的依赖项

    • 2.初始化CameraX和ImageAnalysis.Analyzer

    • 3.实现手部检测模型

    • 4.绘制边界框


将模型检查点转换为TFLite

我们的第一步是将Victor Dibia的repo中提供的经过训练的模型检查点转换为TensorFlow Lite格式。TensorFlow Lite提供了一个在Android、iOS和微控制器设备上运行TensorFlow模型的高效网关。为了运行转换脚本,我们需要在机器上设置TensorFlow对象检测API。你也可以使用这个Colab笔记本来执行所有转换。

建议你使用Colab笔记本(尤其是Windows)。

https://github.com/shubham0204/Google_Colab_Notebooks/blob/main/Hand_Tracking_Model_TFLite_Conversion.ipynb

1.设置TF目标检测API

TensorFlow对象检测API提供了许多预训练的对象检测模型,这些模型可以在自定义数据集上进行微调,并直接部署到移动、web或云中。我们只需要帮助我们将模型检查点转换为TF Lite缓冲区的转换脚本。

手部检测模型本身是使用TF OD API和TensorFlow 1.x制作的。所以,首先我们需要安装TensorFlow 1.x或TF 1.15.0(1.x系列中的最新版本),然后克隆包含TF OD API的tensorflow/models repo。

# Installing TF 1.15.0
!pip install tensorflow==1.15.0

# Cloning the tensorflow/models repo
!git clone https://github.com/tensorflow/models

# Installing the TF OD API
%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf1/setup.py .
python -m pip install .

此外,我们将克隆Victor Dibia仓库,以获得模型检查点

!git clone https://github.com/victordibia/handtracking

2.将检查点转换为图

现在在models/research/object_detection目录中,你将看到一个Python脚本export_tflite_ssd_graph.py,我们将使用它将模型检查点转换为与TFLite兼容的图。检查点可以在handtracking/model-checkpoint目录中找到。ssd代表‘Single Shot Detector’,这是手部检测模型的体系结构,而mobilenet代表mobilenet(v1或v2)的主干体系结构,这是一种专门用于移动设备的CNN体系结构。

d3eb1e0cc4182df03d1f883216395fbb.png

导出的TFLite图包含固定的输入和输出节点。我们可以在export_tflite_ssd_graph中找到这些节点(或张量)的名称和形状。使用该脚本,我们将把模型检查点转换为一个与TFLite兼容的图,给出三个参数,

  1. pipeline_config_path:指向的路径。包含所用SSD Lite模型配置的配置文件。

  2. trained_checkpoint_prefix:我们希望转换模型检查点的前缀。

  3. max_detections:要预测的边界框的数量。

!python models/research/object_detection/export_tflite_ssd_graph.py \
    --pipeline_config_path handtracking/model-checkpoint/ssdlitemobilenetv2/data_ssdlite.config \
    --trained_checkpoint_prefix handtracking/model-checkpoint/ssdlitemobilenetv2/out_model.ckpt-19040 \
    --output_directory outputs \
    --max_detections 10

脚本执行后,我们剩下两个文件,tflite_graph.pb和tflite_graph.pbtxt是与TFLite兼容的图形。

3.将图转换为TFLite缓冲区

现在,我们将使用第二个脚本(或者更准确地说,一个实用程序)将步骤2中生成的图转换为TFLite缓冲区(.tflite)。如TensorFlow 2.x排除了Session和Placeholder的使用,我们不能在这里将图转换为TFLite。这就是我们安装TensorFlow 1.x的原因之一。

我们将使用tflite_convert实用程序将图转换为tflite缓冲区。我们也可以使用tf.lite.TFLiteConverter API,但我们现在将继续使用命令行实用程序。

!tflite_convert \
--graph_def_file=/content/outputs/tflite_graph.pb \
--output_file=/content/outputs/model.tflite \
--output_format=TFLITE \
--input_arrays=normalized_input_image_tensor \
--input_shapes=1,300,300,3 \
--inference_type=FLOAT \
--output_arrays="TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3" \
--allow_custom_ops

执行完成后,你将看到一个outputs目录中的model.tflite文件。为了检查输入/输出形状,我们将使用tf.lite.Interpreter加载TFLite模型,调用.get_input_details()和.get_output_details()分别获取输入和输出详细信息。

提示:使用pprint可以获得漂亮的输出。

import tensorflow as tf
import pprint

interpreter = tf.lite.Interpreter( '/content/outputs/model.tflite' )
interpreter.allocate_tensors()

pprint.pprint( interpreter.get_input_details())
pprint.pprint( interpreter.get_output_details() )

在Android中集成TFLite模型

一旦我们有了TFLite模型及其输入和输出形状的所有细节,我们就可以在Android应用程序中运行它了。在Android Studio中创建一个新项目,或者可以自由地克隆GitHub repo!

1.添加CameraX、Coroutines 和TF Lite的依赖项

当我们要在实时摄像头上检测手的时候,我们需要在Android应用程序中添加CameraX依赖项。类似地,为了运行TFLite模型,我们需要tensorflow lite依赖项和Kotlin Coroutines依赖项,它们帮助我们异步运行模型。在应用程序级构建中.gradle文件,我们将添加以下依赖项,

plugins {
    ...
}

android {
    
    ...
    
    aaptOptions {
        noCompress "tflite"
    }
    
    ...
    
}

dependencies {
    ...

    // CameraX dependencies
    implementation "androidx.camera:camera-camera2:1.0.1"
    implementation "androidx.camera:camera-lifecycle:1.0.1"
    implementation "androidx.camera:camera-view:1.0.0-alpha28"
    implementation "androidx.camera:camera-extensions:1.0.0-alpha28"

    // TensorFlow Lite dependencies
    implementation 'org.tensorflow:tensorflow-lite:2.4.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.4.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'

    // Kotlin Coroutines
    implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.4.1'
    implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-android:1.4.1'

    ...
}

确保添加了aaptOptions{ noCompress "tflite" },这样系统就不会压缩模型,从而使应用程序的大小变小。现在,为了在我们的应用程序中放置TFLite模型,我们将在app/src/main下创建一个assets文件夹。将TFLite文件(.tflite)粘贴到此文件夹中。

99bba545b63e42d8f70d4b19edef8a4b.png

2.初始化CameraX和ImageAnalysis.Analyzer

我们将使用CameraX软件包中的PreviewView向用户显示实时摄像头提要。在它上面,我们将放置一个名为BoundingBoxOverlay的图,用于在摄影机提要上绘制边界框。我不会在这里讨论实现,但你可以从源代码中学到,

https://proandroiddev.com/realtime-selfie-segmentation-in-android-with-mlkit-38637c8502ba

当我们要预测实时帧数据的边界框时,我们还需要ImageAnalysis.Analyzer对象,该对象从实时摄影机提要返回每一帧。请参阅FrameAnalyzer中的这段代码。

// Image Analyser for performing hand detection on camera frames.
class FrameAnalyser(
    private val handDetectionModel: HandDetectionModel ,
    private val boundingBoxOverlay: BoundingBoxOverlay ) : ImageAnalysis.Analyzer {

    private var frameBitmap : Bitmap? = null
    private var isFrameProcessing = false


    override fun analyze(image: ImageProxy) {
        // If a frame is being processed, drop the current frame.
        if ( isFrameProcessing ) {
            image.close()
            return
        }
        isFrameProcessing = true

        // Get the `Bitmap` of the current frame ( with corrected rotation ).
        frameBitmap = BitmapUtils.imageToBitmap( image.image!! , image.imageInfo.rotationDegrees )
        image.close()

        // Configure frameHeight and frameWidth for output2overlay transformation matrix.
        if ( !boundingBoxOverlay.areDimsInit ) {
            Logger.logInfo( "Passing dims to overlay..." )
            boundingBoxOverlay.frameHeight = frameBitmap!!.height
            boundingBoxOverlay.frameWidth = frameBitmap!!.width
        }

        CoroutineScope( Dispatchers.Main ).launch {
            runModel( frameBitmap!! )
        }
    }


    private suspend fun runModel( inputImage : Bitmap ) = withContext( Dispatchers.Default ) {
       ...
    }

}

BitmapUtils包含一些有用的静态方法来操作Bitmap。isFrameProcessing是一个布尔变量,用于确定是否必须删除传入帧或将其传递给模型。正如你所观察到的,我们在一个CoroutineScope中运行模型,因此当模型产生推断时,你不会观察到任何滞后。

3.实现手部检测模型

接下来,我们将创建一个名为HandDetectionModel的类,该类将处理所有TFLite操作,并返回给定图像(作为位图)的预测。

// Helper class for Hand detection TFLite model
class HandDetectionModel( context: Context ) {

    // I/O details for the hand detection model.
    // Refer to the comments of this script ->
    // https://github.com/tensorflow/models/blob/master/research/object_detection/export_tflite_ssd_graph.py
    // For quantization, use the tflite_convert utility as described in the conversion notebook ( README ).
    private val modelInputImageDim = 300
    private val isQuantized = false
    private val maxDetections = 10
    private val boundingBoxesTensorShape = intArrayOf( 1 , maxDetections , 4 ) // [ 1 , 10 , 4 ]
    private val confidenceScoresTensorShape = intArrayOf( 1 , maxDetections ) // [ 1 , 10 ]
    private val classesTensorShape = intArrayOf( 1 , maxDetections ) // [ 1 , 10 ]
    private val numBoxesTensorShape = intArrayOf( 1 ) // [ 1 , ]
    // Input tensor processor for quantized and non-quantized versions of the model.
    private val inputImageProcessorQuantized = ImageProcessor.Builder()
        .add( ResizeOp( modelInputImageDim , modelInputImageDim , ResizeOp.ResizeMethod.BILINEAR ) )
        .add( CastOp( DataType.FLOAT32 ) )
        .build()
    private val inputImageProcessorNonQuantized = ImageProcessor.Builder()
        .add( ResizeOp( modelInputImageDim , modelInputImageDim , ResizeOp.ResizeMethod.BILINEAR ) )
        .add( NormalizeOp( 128.5f , 128.5f ) )
        .build()

    // See app/src/main/assets for the TFLite model.
    private val modelName = "model.tflite"
    private val numThreads = 4
    private var interpreter : Interpreter
    // Confidence threshold for NMS
    private val outputConfidenceThreshold = 0.9f
  
  ...

我们将在上述片段中分别理解每个术语,

  1. modelImageInputDim是我们模型的输入图像的大小。我们的模型可以拍摄300*300的图像。

  2. maxDetections代表我们的模型所做预测的最大数量。它决定了boundingBoxesTensorShape、confidenceScoresTensorShape、classesTensorShape和numTensorShape的形状。

  3. outputConfidenceThreshold用于过滤模型做出的预测。这不是NMS,但我们只接受分数大于此阈值的框。

  4. inputImageProcessorQuantized和inputImageProcessorNonQuantized是TensorOperator的实例,它将给定图像的大小调整为modelImageInputDim*modelInputImageDim。对于量化模型,我们用平均值和标准偏差都等于127.5的标准化给定图像。

现在,我们将实现一个run()方法,它将获取位图图像,并以List的形式输出边界框。Prediction是一个包含预测数据的类,如置信度得分和边界框坐标。

// Store the width and height of the input frames as they will be used for future transformations.
inputFrameWidth = inputImage.width
inputFrameHeight = inputImage.height

var tensorImage = TensorImage.fromBitmap( inputImage )
tensorImage = if ( isQuantized ) {
    inputImageProcessorQuantized.process( tensorImage )
}
else {
    inputImageProcessorNonQuantized.process( tensorImage )
}

val confidenceScores = TensorBuffer.createFixedSize( confidenceScoresTensorShape , DataType.FLOAT32 )
val boundingBoxes = TensorBuffer.createFixedSize( boundingBoxesTensorShape , DataType.FLOAT32 )
val classes = TensorBuffer.createFixedSize( classesTensorShape , DataType.FLOAT32 )
val numBoxes = TensorBuffer.createFixedSize( numBoxesTensorShape , DataType.FLOAT32 )
val outputs = mapOf(
    0 to boundingBoxes.buffer ,
    1 to classes.buffer ,
    2 to confidenceScores.buffer ,
    3 to numBoxes.buffer
)

val t1 = System.currentTimeMillis()
interpreter.runForMultipleInputsOutputs( arrayOf(tensorImage.buffer), outputs )
Logger.logInfo( "Model inference time -> ${System.currentTimeMillis() - t1} ms." )

return processOutputs( confidenceScores , boundingBoxes )

confidenceScores , boundingBoxes , classesnumBoxes是保存模型输出的四个张量。processOutputs方法将过滤边界框,并仅返回置信度得分大于阈值的框。

private fun processOutputs( scores : TensorBuffer ,
                            boundingBoxes : TensorBuffer ) : List<Prediction> {
    // Flattened version of array of shape [ 1 , maxDetections ] ( size = maxDetections )
    val scoresFloatArray = scores.floatArray
    // Flattened version of array of shape [ 1 , maxDetections , 4 ] ( size = maxDetections * 4 )
    val boxesFloatArray = boundingBoxes.floatArray
    val predictions = ArrayList<Prediction>()
    for ( i in boxesFloatArray.indices step 4 ) {
        // Store predictions which have a confidence > threshold
        if ( scoresFloatArray[ i / 4 ] >= filterThreshold ) {
            predictions.add(
                Prediction(
                    getRect( boxesFloatArray.sliceArray( i..i+3 )) ,
                    scoresFloatArray[ i / 4 ]
                )
            )
        }
    }
    return predictions.toList()
}


// Transform the normalized bounding box coordinates relative to the input frames.
private fun getRect( coordinates : FloatArray ) : Rect {
    return Rect(
        max( (coordinates[ 1 ] * inputFrameWidth).toInt() , 1 ),
        max( (coordinates[ 0 ] * inputFrameHeight).toInt() , 1 ),
        min( (coordinates[ 3 ] * inputFrameWidth).toInt() , inputFrameWidth ),
        min( (coordinates[ 2 ] * inputFrameHeight).toInt() , inputFrameHeight )
    )
}

4.绘制边界框

一旦我们收到了边界框,我们会像使用OpenCV一样,将它们绘制。我们将创建一个新的类BoundingBoxOverlay,并将其添加到activity_main.xml。

class BoundingBoxOverlay(context : Context, attributeSet : AttributeSet)
    : SurfaceView( context , attributeSet ) , SurfaceHolder.Callback {

    // Variables used to compute output2overlay transformation matrix
    // These are assigned in FrameAnalyser.kt
    var areDimsInit = false
    var frameHeight = 0
    var frameWidth = 0

    // This var is assigned in FrameAnalyser.kt
    var handBoundingBoxes: List<Prediction>? = null

    // This var is assigned in MainActivity.kt
    var isFrontCameraOn = false

    private var output2OverlayTransform: Matrix = Matrix()
    private val boxPaint = Paint().apply {
        color = Color.YELLOW
        style = Paint.Style.STROKE
        strokeWidth = 16f
    }
    private val textPaint = Paint().apply {
        strokeWidth = 2.0f
        textSize = 32f
        color = Color.YELLOW
    }

    private val displayMetrics = DisplayMetrics()

    ...

    override fun onDraw(canvas: Canvas?) {
        if ( handBoundingBoxes == null ) {
            return
        }
        if (!areDimsInit) {
            ...
        }
        else {
            for ( prediction in handBoundingBoxes!! ) {
                val rect = prediction.boundingBox.toRectF()
                output2OverlayTransform.mapRect( rect )
                canvas?.drawRoundRect( rect , 16f, 16f, boxPaint )
                canvas?.drawText(
                    prediction.confidence.toString(),
                    rect.centerX(),
                    rect.centerY(),
                    textPaint
                )
            }
        }
    }
    
}

就这些!我们刚刚在Android应用程序中实现了一个手检测器!你可以在查看所有代码后运行该应用程序。

2e48e4a0e576c1678fe42df5f459e90e.png

感谢阅读!

☆ END ☆

如果看到这里,说明你喜欢这篇文章,请转发、点赞。微信搜索「uncle_pn」,欢迎添加小编微信「 woshicver」,每日朋友圈更新一篇高质量博文。

扫描二维码添加小编↓

5c53263414e2b29b881005b2c1774b97.png

版权声明:本文为CSDN博主「woshicver」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/woshicver/article/details/123143988

woshicver

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐