在 iOS 上使用 TensorFlow Lite 识别花卉

1. 简介

657431be3173fa86

TensorFlow 是一个多用途机器学习框架。TensorFlow 有很多用途,例如跨云端集群训练大型模型,以及在手机等嵌入式系统本地运行模型。

此 Codelab 使用 TensorFlow Lite 在 iOS 设备上运行图片识别模型。

学习内容

  • 如何使用 TFLite 转换器优化您的模型。
  • 如何使用 TFLite 解释器在预制的 iOS 应用中运行此测试。

构建内容

一个简单的相机应用,该应用会运行 TensorFlow 图像识别程序来识别花卉。

前提条件

如果您要在自己的硬件上完成此 Codelab,请确保安装了以下软件:

  • Xcode 10 或更高版本
  • CocoaPods 1.8.0 或更高版本

c45ecd122998622e.png

许可:可随意使用

2. 使用 Colab 训练花卉识别器

此 Codelab 将使用 Colaboratory 和 Xcode。

打开 Colab,它使用 TensorFlow Lite Model Maker 训练分类器,以使用迁移学习识别花卉,并导出 TFLite 模型以便在移动应用中使用。

3. 设置工作目录

克隆 Git 代码库

以下命令将克隆包含此 Codelab 所需文件的 Git 代码库:

git clone https://github.com/tensorflow/examples.git

现在,使用 cd 进入您刚刚创建的克隆的 Xcode 项目根目录。您将在此基础上完成此 Codelab 的其余部分:

cd examples/lite/examples/image_classification/ios

4. 设置 iOS 应用

安装依赖项

使用 CocoaPods 安装 iOS 应用的依赖项(包括 TensorFlow Lite)。安装命令完成后,打开 ImageClassification.xcworkspace 以在 Xcode 中打开项目。

pod install --repo-update
open ImageClassification.xcworkspace

5. 试运行应用

要使用相机,应用必须在真实设备上运行,因为 iOS 模拟器无法访问 Mac 的相机。要在 iOS 设备上构建应用,您必须加入 Apple Developer Program 或能够使用他人为您配置的设备。

如果您想在模拟器中运行此 Codelab,则需要在模拟器本身中将 Safari 中的图片复制并粘贴到粘贴板中。以下是在模拟器中处理图片的步骤:

  1. 针对您选择的模拟器目标构建应用。
  2. 在 iOS 模拟器中,按 Cmd + Shift + H 可将应用最小化。
  3. 点按主屏幕底部的 Safari,然后搜索图片。
  4. 在 Google 图片搜索结果中,点按相应结果并长按相应图片。在弹出的对话框中,选择“复制”。
  5. 返回 TFL Classify 应用。复制的图片应与推理结果一起自动显示。如果没有,请确保您复制的是图片数据本身,而不是图片的网址。

测试 build 并安装应用

在对应用进行任何更改之前,我们先运行与代码库一起提供的版本。从左上角的下拉菜单中选择您的 iOS 设备:

275753d3a77a0df3.png

然后按 Cmd+R 或按 Xcode 中的“Play”(播放)f96cf117245c0fa6.png 按钮,将该应用构建到您的设备。应用安装至您的设备后,应该就会自动启动。

此版本的应用使用基于 1000 个 ImageNet 类别进行预训练的标准 MobileNet。代码应如下所示:

d11436f0bb5a75db.jpeg

6. 运行自定义应用

默认应用设置使用标准 MobileNet 将图片归为 1000 个 ImageNet 类别中的一个。

现在,我们来修改应用,以便应用将重新训练的模型用于 Colab 中训练的自定义图片类别。

7. 转换应用以运行模型

将模型文件添加到项目中

项目的模型资源位于 Xcode 项目导航器的 ImageClassification > Model 中。如需替换它们,请先删除 Model 组中的两个现有文件。出现提示时,选择“移到废纸篓”:

cf2f7fefb2e5075f.png

然后,将您从 Colab 下载的 model.tflitelabels.txt 文件拖动到“模型”组中。出现提示时,请确保同时选中 Copy items if neededAdd to targets

281d7eb72635bb5f

修改应用的代码

为了让应用正常运行,我们需要更新模型加载逻辑的路径,使其指向我们添加的新模型。

打开 ModelDataHandler.swift(Xcode 导航器路径:ImageClassification -> ModelDataHandler -> ModelDataHandler.swift),将 line 36 更改为

// before
static let modelInfo: FileInfo = (name: "mobilenet_quant_v1_224", extension: "tflite")

// after
static let modelInfo: FileInfo = (name: "model", extension: "tflite")

请务必保存所有更改。

8. 运行自定义应用

按 Cmd+B 或按 Xcode 中的 Play f96cf117245c0fa6.png 按钮,在您的设备上构建应用程序。应用启动后,显示的内容应如下所示:

c45ecd122998622e.png

您可以同时按住电源按钮和音量调高按钮来截屏。

现在试着在网上搜索花卉,将相机对准计算机屏幕,看看这些照片分类是否正确。

或者让朋友帮忙拍张照片,看看您是哪类 TensorFlower \\uf339 \\uf33b \\uf337

9. 它是如何运作的?

现在,您已经运行了该应用,下面我们来看看特定于 TensorFlow Lite 的代码。

TensorFlowLiteSwift

此应用通过 CocoaPods 使用 TensorFlowLite Swift 库。Swift 库是 TFLite C API 的瘦封装容器,而 TFLite C API 本身是 TFLite C++ 库的封装容器。

模块的 Podfile 文件中的以下行会将最新版本的 Pod 全局 CocoaPods 规范代码库提取到项目中。

Podfile

target 'ImageClassification' do
  use_frameworks!

  # Pods for ImageClassification
   pod 'TensorFlowLiteSwift'
end

使用 TensorFlow Lite Swift API

与 TensorFlow Lite 交互的代码全部包含在 ModelDataHandler.swift 中。

设置

涉及的第一个块是 ModelDataHandler 的初始化程序:

ModelDataHandler.swift

/// A failable initializer for `ModelDataHandler`. A new instance is created if the model and
/// labels files are successfully loaded from the app's main bundle. Default `threadCount` is 1.
init?(modelFileInfo: FileInfo, labelsFileInfo: FileInfo, threadCount: Int = 1) {
  let modelFilename = modelFileInfo.name

  // Construct the path to the model file.
  guard let modelPath = Bundle.main.path(
    forResource: modelFilename,
    ofType: modelFileInfo.extension
  ) else {
    print("Failed to load the model file with name: \(modelFilename).")
    return nil
  }

  // Specify the options for the `Interpreter`.
  self.threadCount = threadCount
  var options = InterpreterOptions()
  options.threadCount = threadCount
  do {
    // Create the `Interpreter`.
    interpreter = try Interpreter(modelPath: modelPath, options: options)
    // Allocate memory for the model's input `Tensor`s.
    try interpreter.allocateTensors()
  } catch let error {
    print("Failed to create the interpreter with error: \(error.localizedDescription)")
    return nil
  }
  // Load the classes listed in the labels file.
  loadLabels(fileInfo: labelsFileInfo)
}

以下几行代码应该更详细地加以讨论。

以下行会创建 TFLite 解释器:

ModelDataHandler.swift

interpreter = try Interpreter(modelPath: modelPath, options: options)

解释器负责通过 TensorFlow 图传递原始数据输入。我们将模型在磁盘上的路径传递给该解释器,然后该解释器将其加载为 FlatBufferModel

最后一行用于加载标签列表:

loadLabels(fileInfo: labelsFileInfo)

上述操作只是将文本文件中的字符串加载到内存中。

运行模型

涉及的第二个代码块是 runModel 方法。它接受 CVPixelBuffer 作为输入,运行解释器并返回要在应用中输出的文本。

ModelDataHandler.swift

try interpreter.copy(rgbData, toInputAt: 0)
// ...
try interpreter.invoke()
// ...
outputTensor = try interpreter.output(at: 0)

10. 后续步骤

如需了解详情,请参阅以下链接: