在 iOS 上使用 TensorFlow Lite 辨識 Flowers

1. 簡介

657431be3173fa86.png

TensorFlow 是一種多用途機器學習架構,在任何位置都能使用 TensorFlow,無論是在雲端的叢集上訓練大型模型,還是在手機等嵌入式系統上執行模型,都能使用 TensorFlow。

本程式碼研究室使用 TensorFlow Lite,在 iOS 裝置上執行圖片辨識模型。

課程內容

  • 如何使用 TFLite 轉換工具改善模型。
  • 如何使用 TFLite 解譯器,在預先建立的 iOS 應用程式中執行此 API。

建構目標

簡易的相機應用程式,可執行 TensorFlow 圖片辨識程式來識別花朵。

必要條件

如果您是在自己的硬體上執行這個程式碼研究室,請確認您已安裝以下內容:

  • Xcode 10 以上版本
  • CocoaPods 1.8.0 以上版本

c45ecd122998622e.png

授權:可免費使用

2. 使用 Colab 訓練花朵辨識工具

本程式碼研究室將使用 Colaboratory 和 Xcode。

開啟 Colab,使用 TensorFlow Lite Model Maker 訓練分類器,透過遷移學習來辨識花朵,並匯出要在行動應用程式中使用的 TFLite 模型。

3. 設定工作目錄

複製 Git 存放區

下列指令會複製 Git 存放區,其中包含本程式碼研究室所需的檔案:

git clone https://github.com/tensorflow/examples.git

現在,將 cd 放入剛建立本機副本的 Xcode 專案根目錄。您會在這程式碼研究室的後續部分工作:

cd examples/lite/examples/image_classification/ios

4. 設定 iOS 應用程式

安裝依附元件

使用 CocoaPods 安裝 iOS 應用程式的依附元件 (包括 TensorFlow Lite)。安裝指令完成後,開啟 ImageClassification.xcworkspace 即可在 Xcode 中開啟專案。

pod install --repo-update
open ImageClassification.xcworkspace

5. 執行應用程式測試

由於 iOS 模擬器無法存取 Mac 的相機,因此應用程式必須在實體裝置上執行,才能使用相機。如要打造 iOS 裝置,您必須註冊 Apple Developer Program,或擁有他人為您佈建的裝置。

如果您希望在模擬器中執行本程式碼研究室,就需要在模擬器中複製圖像,才能貼上模擬器。在模擬工具中處理圖片的步驟如下:

  1. 針對您選擇的模擬工具目標建構應用程式。
  2. 在 iOS 模擬工具中,按下 Cmd+Shift+H 鍵,將應用程式最小化。
  3. 輕觸主畫面底部的 [Safari],然後搜尋所需圖片。
  4. 在 Google 圖片搜尋結果中,輕觸任一搜尋結果並長按該圖片。在彈出的對話方塊中,選取「複製」。
  5. 返回 TFL Classify 應用程式。複製的圖片應會和推論結果一併自動顯示。如果不是,請確認您已複製圖片資料,而非圖片的網址。

測試建構並安裝應用程式

請先執行與存放區隨附的版本,再對應用程式進行變更。從左上方的下拉式選單中選取您的 iOS 裝置:

275753d3a77a0df3.png

然後按 Cmd+R 鍵,或點選 Xcode 中的「Play」f96cf117245c0fa6.png按鈕,在裝置中建構應用程式。應用程式安裝到裝置後,應會自動啟動。

此版本的應用程式使用標準 MobileNet,針對 1000 個 ImageNet 類別預先訓練。如下所示:

d11436f0bb5a75db.jpeg

6. 執行自訂應用程式

預設的應用程式設定會使用標準 MobileNet,將圖片分類為 1000 個 ImageNet 類別之一。

現在,請修改應用程式,讓應用程式使用經過重新訓練的模型,處理 Colab 訓練的自訂圖片類別。

7. 轉換應用程式以執行模型

將模型檔案新增至專案

專案的模型資源位於 Xcode 專案導覽器中的 ImageClassification > Model。如要取代這些項目,請先刪除 Model 群組中的兩個現有檔案。畫面上出現提示時,請選取「丟到垃圾桶」:

cf2f7fefb2e5075f.png

然後將從 Colab 下載的 model.tflitelabels.txt 檔案拖曳至 Model 群組。系統提示時,請確認已選取 Copy items if neededAdd to targets

281d7eb72635bb5f.png

修改應用程式的程式碼

為了讓應用程式正常運作,我們需要更新模型載入邏輯的路徑,使其指向新增的新模型。

開啟 ModelDataHandler.swift (Xcode 導覽工具路徑:ImageClassification -> ModelDataHandler -> ModelDataHandler.swift),並將 line 36 變更為

// before
static let modelInfo: FileInfo = (name: "mobilenet_quant_v1_224", extension: "tflite")

// after
static let modelInfo: FileInfo = (name: "model", extension: "tflite")

請務必儲存所有變更。

8. 執行自訂應用程式

按 Cmd+B 鍵,或按下 Xcode 中的「Play」f96cf117245c0fa6.png按鈕,即可在裝置上建構應用程式。應用程式啟動後應如下所示:

c45ecd122998622e.png

同時按住電源鍵和調高音量按鈕,即可擷取螢幕截圖。

現在試著在網路上搜尋花朵,將相機對準電腦螢幕,看看這些相片是否正確分類。

或者請朋友幫你拍照,看看你是哪種 TensorFlower \uf339 \uf33b \uf337

9. 運作方式

現在您已開始執行應用程式,讓我們來看看 TensorFlow Lite 特定程式碼。

TensorFlowLiteSwift

這個應用程式會透過 CocoaPods 使用 TensorFlowLite Swift 程式庫。Swift 程式庫是 TFLite C API 上的精簡包裝函式,本身是 TFLite C++ 程式庫的包裝函式。

模組 Podfile 檔案中的下列程式碼會將 Pod 全域 CocoaPods 規格存放區的最新版本提取至專案。

Podfile

target 'ImageClassification' do
  use_frameworks!

  # Pods for ImageClassification
   pod 'TensorFlowLiteSwift'
end

使用 TensorFlow Lite Swift API

與 TensorFlow Lite 互動的程式碼全都位於 ModelDataHandler.swift 中。

設定

第一個相關區塊是 ModelDataHandler 的初始化器:

ModelDataHandler.swift

/// A failable initializer for `ModelDataHandler`. A new instance is created if the model and
/// labels files are successfully loaded from the app's main bundle. Default `threadCount` is 1.
init?(modelFileInfo: FileInfo, labelsFileInfo: FileInfo, threadCount: Int = 1) {
  let modelFilename = modelFileInfo.name

  // Construct the path to the model file.
  guard let modelPath = Bundle.main.path(
    forResource: modelFilename,
    ofType: modelFileInfo.extension
  ) else {
    print("Failed to load the model file with name: \(modelFilename).")
    return nil
  }

  // Specify the options for the `Interpreter`.
  self.threadCount = threadCount
  var options = InterpreterOptions()
  options.threadCount = threadCount
  do {
    // Create the `Interpreter`.
    interpreter = try Interpreter(modelPath: modelPath, options: options)
    // Allocate memory for the model's input `Tensor`s.
    try interpreter.allocateTensors()
  } catch let error {
    print("Failed to create the interpreter with error: \(error.localizedDescription)")
    return nil
  }
  // Load the classes listed in the labels file.
  loadLabels(fileInfo: labelsFileInfo)
}

您有幾行內容需要進一步討論。

以下程式碼會建立 TFLite 解譯器:

ModelDataHandler.swift

interpreter = try Interpreter(modelPath: modelPath, options: options)

解譯器負責透過 TensorFlow 圖形傳遞原始資料輸入內容。我們會將解譯器路徑傳遞至磁碟上的模型,然後翻譯器將它載入為 FlatBufferModel

最後一行會載入標籤清單:

loadLabels(fileInfo: labelsFileInfo)

這項操作都是將文字檔中的字串載入記憶體

執行模型

第二個相關區塊是 runModel 方法。該函式會將 CVPixelBuffer 做為輸入內容,執行解譯器,並傳回要在應用程式中顯示的文字。

ModelDataHandler.swift

try interpreter.copy(rgbData, toInputAt: 0)
// ...
try interpreter.invoke()
// ...
outputTensor = try interpreter.output(at: 0)

10. 後續步驟

請參閱下列連結瞭解詳情: