Compila una app para Android de clasificación de dígitos escritos a mano con MediaPipe Tasks

1. Introducción

¿Qué es MediaPipe?

Las MediaPipe Solutions te permiten aplicar soluciones de aprendizaje automático (AA) en tus apps. Proporcionan un framework para que configures canalizaciones de procesamiento compiladas previamente que arrojan resultados inmediatos, pertinentes y útiles para los usuarios. Incluso puedes personalizar estas soluciones con MediaPipe Model Maker para actualizar los modelos predeterminados.

La clasificación de imágenes es una de las muchas tareas de visión de AA que ofrecen las MediaPipe Solutions. MediaPipe Tasks está disponible para Android, iOS, Python (incluida Raspberry Pi) y la Web.

En este codelab, comenzarás con una app para Android que te permite dibujar dígitos numéricos en la pantalla. Luego, agregarás una función que clasifica esos dígitos dibujados como un solo valor de 0 a 9.

Qué aprenderás

  • Cómo incorporar una tarea de clasificación de imágenes en una app para Android con MediaPipe Tasks.

Requisitos

  • Una versión instalada de Android Studio (este codelab se escribió y probó con Android Studio Giraffe)
  • Un dispositivo o emulador de Android para ejecutar la app
  • Conocimientos básicos sobre el desarrollo de Android (no es "Hola, mundo", pero no está muy lejos)

2. Agrega MediaPipe Tasks a la app para Android

Descarga la app de inicio para Android

Este codelab comenzará con una muestra precompilada que te permite dibujar en la pantalla. Puedes encontrar esa app de inicio en el repositorio oficial de muestras de MediaPipe aquí. Para clonar el repositorio o descargar el archivo zip, haz clic en Code > Download ZIP.

Importa la app a Android Studio

  1. Abre Android Studio.
  2. En la pantalla Welcome to Android Studio, selecciona Open en la esquina superior derecha.

a0b5b070b802e4ea.png

  1. Navega hasta donde clonaste o descargaste el repositorio y abre el directorio codelabs/digitclassifier/android/start.
  2. Para verificar que todo se haya abierto correctamente, haz clic en la flecha verde run ( 7e15a9c9e1620fe7.png) en la parte superior derecha de Android Studio.
  3. Deberías ver que la app se abre con una pantalla negra en la que puedes dibujar, así como un botón Clear para restablecer esa pantalla. Si bien puedes dibujar en esa pantalla, no hace mucho más, por lo que comenzaremos a solucionarlo ahora.

11a0f6fe021fdc92.jpeg

Modelo

Cuando ejecutes la app por primera vez, es posible que observes que se descarga un archivo llamado mnist.tflite y se almacena en el directorio assets de la app. Para simplificar, ya tomamos un modelo conocido, MNIST, que clasifica dígitos, y lo agregamos a la app mediante el uso de la secuencia de comandos download_models.gradle en el proyecto. Si decides entrenar tu propio modelo personalizado, como uno para letras escritas a mano, quitarás el archivo download_models.gradle , borrarás la referencia a él en el archivo build.gradle a nivel de la app y cambiarás el nombre del modelo más adelante en el código (específicamente en el archivo DigitClassifierHelper.kt).

Actualiza build.gradle

Antes de comenzar a usar MediaPipe Tasks, debes importar la biblioteca.

  1. Abre el archivo build.gradle que se encuentra en el módulo app y, luego, desplázate hacia abajo hasta el bloque dependencies.
  2. Deberías ver un comentario en la parte inferior de ese bloque que dice // STEP 1 Dependency Import.
  3. Reemplaza esa línea por la siguiente implementación.
implementation("com.google.mediapipe:tasks-vision:latest.release")
  1. Haz clic en el botón Sync Now que aparece en el banner en la parte superior de Android Studio para descargar esta dependencia.

3. Crea un auxiliar de clasificación de dígitos de MediaPipe Tasks

En el siguiente paso, completarás una clase que hará el trabajo pesado para la clasificación de aprendizaje automático. Abre DigitClassifierHelper.kt y comencemos.

  1. Busca el comentario en la parte superior de la clase que dice // STEP 2 Create listener.
  2. Reemplaza esa línea por el siguiente código. Esto creará un objeto de escucha que se usará para pasar los resultados de la clase DigitClassifierHelper a donde sea que se escuchen esos resultados (en este caso, será tu clase DigitCanvasFragment , pero llegaremos allí pronto).
// STEP 2 Create listener

interface DigitClassifierListener {
    fun onError(error: String)
    fun onResults(
        results: ImageClassifierResult,
        inferenceTime: Long
    )
}
  1. También deberás aceptar un DigitClassifierListener como parámetro opcional para la clase:
class DigitClassifierHelper(
    val context: Context,
    val digitClassifierListener: DigitClassifierListener?
) {
  1. Si bajas a la línea que dice // STEP 3 define classifier, agrega la siguiente línea para crear un marcador de posición para el ImageClassifier que se usará para esta app:

// STEP 3 define classifier

private var digitClassifier: ImageClassifier? = null
  1. Agrega la siguiente función donde veas el comentario // STEP 4 set up classifier:
// STEP 4 set up classifier
private fun setupDigitClassifier() {

    val baseOptionsBuilder = BaseOptions.builder()
        .setModelAssetPath("mnist.tflite")

    // Describe additional options
    val optionsBuilder = ImageClassifierOptions.builder()
        .setRunningMode(RunningMode.IMAGE)
        .setBaseOptions(baseOptionsBuilder.build())

    try {
        digitClassifier =
            ImageClassifier.createFromOptions(
                context,
                optionsBuilder.build()
            )
    } catch (e: IllegalStateException) {
        digitClassifierListener?.onError(
            "Image classifier failed to initialize. See error logs for " +
                    "details"
        )
        Log.e(TAG, "MediaPipe failed to load model with error: " + e.message)
    }
}

Hay algunas cosas que suceden en la sección anterior, así que veamos partes más pequeñas para comprender realmente lo que sucede.

val baseOptionsBuilder = BaseOptions.builder()
    .setModelAssetPath("mnist.tflite")

// Describe additional options
val optionsBuilder = ImageClassifierOptions.builder()
    .setRunningMode(RunningMode.IMAGE)
    .setBaseOptions(baseOptionsBuilder.build())

Este bloque definirá los parámetros que usa ImageClassifier. Esto incluye el modelo almacenado en tu app (mnist.tflite) en BaseOptions y el RunningMode en ImageClassifierOptions, que en este caso es IMAGE, pero VIDEO y LIVE_STREAM son opciones adicionales disponibles. Otros parámetros disponibles son MaxResults, que limita el modelo para que muestre una cantidad máxima de resultados, y ScoreThreshold, que establece la confianza mínima que debe tener el modelo en un resultado antes de mostrarlo.

try {
    digitClassifier =
        ImageClassifier.createFromOptions(
            context,
            optionsBuilder.build()
        )
} catch (e: IllegalStateException) {
    digitClassifierListener?.onError(
        "Image classifier failed to initialize. See error logs for " +
                "details"
    )
    Log.e(TAG, "MediaPipe failed to load model with error: " + e.message)
}

Después de crear las opciones de configuración, puedes crear tu nuevo ImageClassifier pasando un contexto y las opciones. Si algo sale mal con ese proceso de inicialización, se mostrará un error a través de tu DigitClassifierListener.

  1. Como queremos inicializar ImageClassifier antes de que se use, puedes agregar un bloque init para llamar a setupDigitClassifier().
init {
    setupDigitClassifier()
}
  1. Por último, desplázate hacia abajo hasta el comentario que dice // STEP 5 create classify function y agrega el siguiente código. Esta función aceptará un Bitmap, que en este caso es el dígito dibujado, lo convertirá en un objeto de imagen de MediaPipe (MPImage) y, luego, clasificará esa imagen con ImageClassifier, además de registrar cuánto tiempo lleva la inferencia, antes de mostrar esos resultados a través de DigitClassifierListener.
// STEP 5 create classify function
fun classify(image: Bitmap) {
    if (digitClassifier == null) {
        setupDigitClassifier()
    }

    // Convert the input Bitmap object to an MPImage object to run inference.
    // Rotating shouldn't be necessary because the text is being extracted from
    // a view that should always be correctly positioned.
    val mpImage = BitmapImageBuilder(image).build()

    // Inference time is the difference between the system time at the start and finish of the
    // process
    val startTime = SystemClock.uptimeMillis()

    // Run image classification using MediaPipe Image Classifier API
    digitClassifier?.classify(mpImage)?.also { classificationResults ->
        val inferenceTimeMs = SystemClock.uptimeMillis() - startTime
        digitClassifierListener?.onResults(classificationResults, inferenceTimeMs)
    }
}

Y eso es todo para el archivo auxiliar. En la siguiente sección, completarás los pasos finales para comenzar a clasificar los números dibujados.

4. Ejecuta la inferencia con MediaPipe Tasks

Para comenzar esta sección, abre la clase DigitCanvasFragment en Android Studio, que es donde se realizará todo el trabajo.

  1. En la parte inferior de este archivo, deberías ver un comentario que dice // STEP 6 Set up listener. Aquí agregarás las funciones onResults() y onError() asociadas con el objeto de escucha.
// STEP 6 Set up listener
override fun onError(error: String) {
    activity?.runOnUiThread {
        Toast.makeText(requireActivity(), error, Toast.LENGTH_SHORT).show()
        fragmentDigitCanvasBinding.tvResults.text = ""
    }
}

override fun onResults(
    results: ImageClassifierResult,
    inferenceTime: Long
) {
    activity?.runOnUiThread {
        fragmentDigitCanvasBinding.tvResults.text = results
            .classificationResult()
            .classifications().get(0)
            .categories().get(0)
            .categoryName()

        fragmentDigitCanvasBinding.tvInferenceTime.text = requireActivity()
            .getString(R.string.inference_time, inferenceTime.toString())
    }
}

onResults() es particularmente importante, ya que mostrará los resultados recibidos de ImageClassifier. Como esta devolución de llamada se activa desde un subproceso en segundo plano, también deberás ejecutar las actualizaciones de la IU en el subproceso de IU de Android.

  1. A medida que agregas funciones nuevas desde una interfaz en el paso anterior, también deberás agregar la declaración de implementación en la parte superior de la clase.
class DigitCanvasFragment : Fragment(), DigitClassifierHelper.DigitClassifierListener
  1. Hacia la parte superior de la clase, deberías ver un comentario que dice // STEP 7a Initialize classifier. Aquí es donde colocarás la declaración para DigitClassifierHelper.
// STEP 7a Initialize classifier.
private lateinit var digitClassifierHelper: DigitClassifierHelper
  1. Si bajas a // STEP 7b Initialize classifier, puedes inicializar digitClassifierHelper dentro de la función onViewCreated().
// STEP 7b Initialize classifier
// Initialize the digit classifier helper, which does all of the
// ML work. This uses the default values for the classifier.
digitClassifierHelper = DigitClassifierHelper(
    context = requireContext(), digitClassifierListener = this
)
  1. Para los últimos pasos, busca el comentario // STEP 8a*: classify* y agrega el siguiente código para llamar a una función nueva que agregarás en un momento. Este bloque de código activará la clasificación cuando levantes el dedo del área de dibujo de la app.
// STEP 8a: classify
classifyDrawing()
  1. Por último, busca el comentario // STEP 8b classify para agregar la nueva función classifyDrawing(). Esto extraerá un mapa de bits del lienzo y, luego, lo pasará a DigitClassifierHelper para realizar la clasificación y recibir los resultados en la función de interfaz onResults().
// STEP 8b classify
private fun classifyDrawing() {
    val bitmap = fragmentDigitCanvasBinding.digitCanvas.getBitmap()
    digitClassifierHelper.classify(bitmap)
}

5. Implementa y prueba la app

Después de todo eso, deberías tener una app en funcionamiento que pueda clasificar los dígitos dibujados en la pantalla. Continúa y, luego, implementa la app en Android Emulator o en un dispositivo Android físico para probarla.

  1. Haz clic en Run ( 7e15a9c9e1620fe7.png) en la barra de herramientas de Android Studio para ejecutar la app.
  2. Dibuja cualquier dígito en el panel de dibujo y verifica si la app puede reconocerlo. Debería mostrar el dígito que el modelo cree que se dibujó, así como el tiempo que tardó en predecir ese dígito.

7f37187f8f919638.gif

6. ¡Felicitaciones!

¡Lo lograste! En este codelab, aprendiste a agregar clasificación de imágenes a una app para Android y, específicamente, a clasificar dígitos dibujados a mano con el modelo MNIST.

Próximos pasos

  • Ahora que puedes clasificar dígitos, es posible que quieras entrenar tu propio modelo para clasificar letras dibujadas, animales o una cantidad infinita de otros elementos. Puedes encontrar la documentación para entrenar un nuevo modelo de clasificación de imágenes con MediaPipe Model Maker en la página developers.google.com/mediapipe.
  • Obtén información sobre las otras MediaPipe Tasks que están disponibles para Android, incluidas la detección de puntos de referencia faciales, el reconocimiento de gestos y la clasificación de audio.

Estamos ansiosos por ver todas las cosas interesantes que crees.