1. 准备工作
在此 Codelab 中,您将更新在上一个“移动文本分类入门”Codelab 中构建的应用。
前提条件
- 此 Codelab 专为刚接触机器学习的经验丰富的开发者而设计。
- 此 Codelab 是按顺序编排的在线课程的一部分。如果您尚未完成“构建基本的消息式应用”或“构建评论垃圾内容机器学习模型”课程,请先完成这两门课程,然后再继续。
您将 [构建或学习] 的内容
- 您将学习如何在前面的步骤中构建自定义模型,并将其集成到应用中。
所需条件
- Android Studio,或适用于 iOS 的 CocoaPods
2. 打开现有 Android 应用
您可以按照 Codelab 1 中的说明获取此代码,也可以克隆此代码库并从 TextClassificationStep1
加载该应用。
git clone https://github.com/googlecodelabs/odml-pathways
您可以在 TextClassificationOnMobile->Android
路径中找到它。
您还可以以 TextClassificationStep2
形式查看 finished 代码。
打开该应用后,您就可以继续执行第 2 步了。
3. 导入模型文件和元数据
在“构建垃圾评论机器学习模型”Codelab 中,您创建了一个 .TFLITE 模型。
您应该已下载模型文件。如果您没有该模型,可以从此 Codelab 的代码库中获取,也可以在此处获取。
创建 assets 目录将其添加到项目中。
- 使用项目导航器,确保在顶部选择了 Android。
- 右键点击 app 文件夹。依次选择 New > Directory。
- 在 New Directory 对话框中,选择 src/main/assets。
您会看到应用中现在有一个新的 assets 文件夹。
- 点击右键素材资源。
- 在随即打开的菜单中,您会看到(在 Mac 上)在 Finder 中显示。选择它。(在 Windows 上,此项操作为 Show in Explorer,在 Ubuntu 上为 Show in Files。)
系统会启动 Finder 来显示文件位置(在 Windows 上为 File Explorer,在 Linux 上为 Files)。
- 将
labels.txt
、model.tflite
和vocab
文件复制到此目录。
- 返回 Android Studio,您将在 assets 文件夹中看到这些素材资源。
4. 更新 build.gradle 以使用 TensorFlow Lite
如需使用 TensorFlow Lite 及其支持的 TensorFlow Lite 任务库,您需要更新 build.gradle
文件。
Android 项目通常有多个级别,因此请务必找到 app 级别。在 Android 视图的 Project Explorer 中,在 Gradle Scripts 部分中找到该文件。正确的文件将带有 .app 标签,如下所示:
您需要对此文件进行两项更改。第一种方法是在底部的依赖项部分中进行。为 TensorFlow Lite 任务库添加文本 implementation
,如下所示:
implementation 'org.tensorflow:tensorflow-lite-task-text:0.1.0'
自编写以来,版本号可能已发生变化,因此请务必查看 https://www.tensorflow.org/lite/inference_with_metadata/task_library/nl_classifier 了解最新信息。
任务库的最低 SDK 版本也要求为 21。在 android
> default config
中找到此设置,然后将其更改为 21:
现在,您已经有了所有依赖项,是时候开始编码了!
5. 添加辅助类
为了将推理逻辑(即您的应用使用模型的位置)与界面分离开,请创建另一个类来处理模型推理。将其称为“辅助”类。
- 右键点击
MainActivity
代码所在的软件包名称。 - 依次选择 New > Package。
- 您会在屏幕中央看到一个对话框,要求您输入软件包名称。将其添加到当前软件包名称的末尾。(此处称为帮助程序。)
- 完成此操作后,右键点击 Project Explorer 中的 helpers 文件夹。
- 依次选择 New > Java Class,并将其命名为
TextClassificationClient
。您将在下一步中修改该文件。
您的 TextClassificationClient
辅助类将如下所示(但您的软件包名称可能有所不同)。
package com.google.devrel.textclassificationstep1.helpers;
public class TextClassificationClient {
}
- 使用以下代码更新文件:
package com.google.devrel.textclassificationstep2.helpers;
import android.content.Context;
import android.util.Log;
import java.io.IOException;
import java.util.List;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.text.nlclassifier.NLClassifier;
public class TextClassificationClient {
private static final String MODEL_PATH = "model.tflite";
private static final String TAG = "CommentSpam";
private final Context context;
NLClassifier classifier;
public TextClassificationClient(Context context) {
this.context = context;
}
public void load() {
try {
classifier = NLClassifier.createFromFile(context, MODEL_PATH);
} catch (IOException e) {
Log.e(TAG, e.getMessage());
}
}
public void unload() {
classifier.close();
classifier = null;
}
public List<Category> classify(String text) {
List<Category> apiResults = classifier.classify(text);
return apiResults;
}
}
此类将为 TensorFlow Lite 解释器提供封装容器,用于加载模型并抽象化管理应用与模型之间数据交换的复杂性。
在 load()
方法中,它会从模型路径实例化一个新的 NLClassifier
类型。模型路径只是模型名称 model.tflite
。NLClassifier
类型属于文本任务库,可帮助您使用正确的序列长度将字符串转换为令牌,将其传递给模型并解析结果。
(如需详细了解这些参数,请参阅“构建垃圾评论机器学习模型”)。
分类在分类方法中执行,您需要向其传递一个字符串,然后该方法将返回 List
。如果您想判断某个字符串是否为垃圾内容,在使用机器学习模型对内容时进行分类时,系统通常会返回所有回答以及指定的概率。例如,如果您向它传递了一条看似垃圾消息的消息,您会返回 2 个答案列表;一个表示垃圾消息的概率,另一个表示垃圾消息的概率。“垃圾内容/非垃圾内容”是类别,因此返回的 List
将包含这些概率。您稍后会解析该内容。
现在,您已经有了辅助类,请返回 MainActivity
并对其进行更新,以便使用该类对文本进行分类。您将在下一步中看到!
6. 对文本进行分类
在 MainActivity
中,您首先需要导入刚刚创建的帮助程序!
- 在
MainActivity.kt
顶部,除了其他导入内容以外,添加以下内容:
import com.google.devrel.textclassificationstep2.helpers.TextClassificationClient
import org.tensorflow.lite.support.label.Category
- 接下来,您需要加载帮助程序。在
onCreate
中,紧跟在setContentView
行之后,添加以下几行代码以实例化并加载辅助类:
val client = TextClassificationClient(applicationContext)
client.load()
目前,按钮的 onClickListener
应如下所示:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
txtOutput.text = toSend
}
- 更新后如下所示:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
var results:List<Category> = client.classify(toSend)
val score = results[1].score
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
txtInput.text.clear()
}
这会将功能从仅输出用户输入更改为先对其进行分类。
- 通过这行代码,您将获取用户输入的字符串并将其传递给模型,以获取结果:
var results:List<Category> = client.classify(toSend)
只有 2 个类别:False
和 True
(TensorFlow 会按字母顺序对其进行排序,因此 False 将是项 0,True 将是项 1。)
- 如需获取值为
True
的概率的分数,可以按如下方式查看 results[1].score:
val score = results[1].score
- 选择了阈值(本例中为 0.8),也就是如果 True 类别的得分高于阈值 (0.8),那么邮件就是垃圾邮件。否则,该邮件不是垃圾邮件,可以放心地发送:
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
- 点击此处可查看该模型的实际运作情况。系统已将消息“Visit my blog to buy stuff!” 标记为极有可能是垃圾内容:
反之,“Hey, fun tutorial, thanks!” 被认为不太可能是垃圾内容:
7. 更新您的 iOS 应用以使用 TensorFlow Lite 模型
您可以按照 Codelab 1 中的说明获取此代码,也可以克隆此代码库并从 TextClassificationStep1
加载该应用。您可以在 TextClassificationOnMobile->iOS
路径中找到它。
您还可以将完成代码作为 TextClassificationStep2
使用。
在“构建垃圾评论机器学习模型”Codelab 中,您创建了一个非常简单的应用,让用户能够在 UITextView
中输入消息,并将其传递到输出,而无需进行任何过滤。
现在,您将更新该应用,以便在发送之前使用 TensorFlow Lite 模型检测文本中的垃圾评论。只需在输出标签中呈现文本,即可在该应用中模拟发送操作(但真实应用可能具有公告板、聊天或类似功能)。
首先,您需要第 1 步中的应用,您可以从代码库中克隆该应用。
要集成 TensorFlow Lite,您将使用 CocoaPods。如果您尚未安装这些工具,可以按照 https://cocoapods.org/ 上的说明进行安装。
- 安装 CocoaPods 后,在 TextClassification 应用的
.xcproject
所在的同一目录中创建一个名为 Podfile 的文件。此文件的内容应如下所示:
target 'TextClassificationStep2' do
use_frameworks!
# Pods for NLPClassifier
pod 'TensorFlowLiteSwift'
end
第一行中应该包含应用的名称,而不是“TextClassificationStep2”。
使用终端,导航到该目录并运行 pod install
。如果成功,系统会为您创建一个名为 Pods 的新目录,以及一个新的 .xcworkspace
文件。以后,您将使用它,而不是 .xcproject
。
如果失败,请确保您的 Podfile 位于 .xcproject
所在的目录中。导致此问题的主要原因通常是 podfile 位于错误的目录中或目标名称有误!
8. 添加模型和词汇表文件
使用 TensorFlow Lite Model Maker 创建模型时,您可以输出模型(作为 model.tflite
)和词汇表(作为 vocab.txt
)。
- 将它们从 Finder 拖放到项目窗口中,以将其添加到项目中。确保选中添加到目标:
完成后,您应该会在项目中看到它们:
- 请选中您的项目(在上面的屏幕截图中,它是蓝色图标 TextClassificationStep2),然后查看 Build Phases(构建阶段)标签页,仔细检查它们是否已添加到软件包中(以便部署到设备):
9. 加载词汇
在进行 NLP 分类时,系统会使用编码为向量的字词来训练模型。该模型使用一组特定的名称和值对字词进行编码,这些名称和值是在模型训练过程中学习到的。请注意,大多数模型具有不同的词汇表,因此请务必使用训练时生成的模型词汇表。这是您刚刚添加到应用中的 vocab.txt
文件。
您可以在 Xcode 中打开该文件,查看编码。“song”等字词的编码为 6,而“love”的编码为 12。该顺序实际上是频率顺序,因此“I”是数据集中最常见的字词,其次是“check”。
当用户输入字词时,您需要先使用此词汇对其进行编码,然后再将其发送给模型进行分类。
我们来研究一下该代码。首先加载词汇。
- 定义一个类级别变量来存储字典:
var words_dictionary = [String : Int]()
- 然后,在该类中创建
func
以将词汇加载到此字典中:
func loadVocab(){
// This func will take the file at vocab.txt and load it into a has table
// called words_dictionary. This will be used to tokenize the words before passing them
// to the model trained by TensorFlow Lite Model Maker
if let filePath = Bundle.main.path(forResource: "vocab", ofType: "txt") {
do {
let dictionary_contents = try String(contentsOfFile: filePath)
let lines = dictionary_contents.split(whereSeparator: \.isNewline)
for line in lines{
let tokens = line.components(separatedBy: " ")
let key = String(tokens[0])
let value = Int(tokens[1])
words_dictionary[key] = value
}
} catch {
print("Error vocab could not be loaded")
}
} else {
print("Error -- vocab file not found")
}
}
- 您可以通过从
viewDidLoad
内调用它来运行此脚本:
override func viewDidLoad() {
super.viewDidLoad()
txtInput.delegate = self
loadVocab()
}
10. 将字符串转换为一系列令牌
用户会输入字词作为句子,该句子将转换为字符串。句子中的每个字词(如果在字典中)都将编码为词汇表中定义的该字词的键值对。
NLP 模型通常接受固定的序列长度。使用 ragged tensors
构建的模型存在例外情况,但在大多数情况下,您会发现问题已得到解决。您在创建模型时指定了此长度。请务必在 iOS 应用中使用相同的长度。
您之前在 Colab 中使用的 TensorFlow Lite Model Maker 的默认值为 20,因此请在此处也进行相应设置:
let SEQUENCE_LENGTH = 20
添加以下 func
,它会接受字符串,将其转换为小写,并移除所有标点符号:
func convert_sentence(sentence: String) -> [Int32]{
// This func will split a sentence into individual words, while stripping punctuation
// If the word is present in the dictionary it's value from the dictionary will be added to
// the sequence. Otherwise we'll continue
// Initialize the sequence to be all 0s, and the length to be determined
// by the const SEQUENCE_LENGTH. This should be the same length as the
// sequences that the model was trained for
var sequence = [Int32](repeating: 0, count: SEQUENCE_LENGTH)
var words : [String] = []
sentence.enumerateSubstrings(
in: sentence.startIndex..<sentence.endIndex,options: .byWords) {
(substring, _, _, _) -> () in words.append(substring!) }
var thisWord = 0
for word in words{
if (thisWord>=SEQUENCE_LENGTH){
break
}
let seekword = word.lowercased()
if let val = words_dictionary[seekword]{
sequence[thisWord]=Int32(val)
thisWord = thisWord + 1
}
}
return sequence
}
请注意,序列将是 Int32。之所以选择这种方式,是因为在将值传递给 TensorFlow Lite 时,您需要处理低级内存,而 TensorFlow Lite 会将字符串序列中的整数视为 32 位整数。这样,在将字符串传递给模型时,您将会轻松一些。
11. 进行分类
如需对句子进行分类,必须先根据句子中的字词将其转换为一系列令牌。这已在第 9 步中完成。
现在,您将获取句子并将其传递给模型,让模型对句子进行推理,然后解析结果。
这将使用 TensorFlow Lite 解释器,您需要导入该解释器:
import TensorFlowLite
首先,创建一个接受序列(即 Int32 类型的数组)的 func
:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
return
}
这将从 bundle 加载模型文件,并使用该文件调用解释器。
下一步是将存储在序列中的底层内存复制到名为 myData,
的缓冲区中,以便将其传递给张量。实现 TensorFlow Lite pod 和解释器后,您可以访问 Tensor 类型。
以如下方式开始编写代码(仍在分类 func
中):
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
如果您在 copyingBufferOf
上收到错误消息,请不要担心。我们将稍后以扩展程序的形式实现此功能。
现在,您需要在解释器上分配张量,将刚刚创建的数据缓冲区复制到输入张量,然后调用解释器进行推理:
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
调用完成后,您可以查看解释器的输出以了解结果。
这些将是原始值(每个神经元 4 字节),您必须读取并转换这些值。由于此特定模型有 2 个输出神经元,因此您需要读取 8 个字节,这些字节将转换为 Float32 以进行解析。您在处理低级内存,因此使用了 unsafeData
。
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
现在,您可以相对轻松地解析数据以确定垃圾内容质量。该模型有 2 个输出,第一个输出表示相应邮件不是垃圾邮件的概率,第二个输出是垃圾邮件的概率。因此,您可以查看 results[1]
以查找垃圾邮件的值:
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " + String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
为方便起见,以下是完整的方法:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
Return
}
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " +
String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
} catch let error {
print("Failed to invoke the interpreter with error: \(error.localizedDescription)")
}
}
12. 添加 Swift 扩展程序
上面的代码使用了 Data 类型的扩展,以便您将 Int32 数组的原始位复制到 Data
。该扩展程序的代码如下:
extension Data {
/// Creates a new buffer by copying the buffer pointer of the given array.
///
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
/// data from the resulting buffer has undefined behavior.
/// - Parameter array: An array with elements of type `T`.
init<T>(copyingBufferOf array: [T]) {
self = array.withUnsafeBufferPointer(Data.init)
}
}
处理低级内存时,您需要使用“不安全”数据,并且上述代码需要您初始化一组不安全数据。借助此扩展程序,您可以:
extension Array {
/// Creates a new array from the bytes of the given unsafe data.
///
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
/// `MemoryLayout<Element>.stride`.
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
#if swift(>=5.0)
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
#else
self = unsafeData.withUnsafeBytes {
.init(UnsafeBufferPointer<Element>(
start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride
))
}
#endif // swift(>=5.0)
}
}
13. 运行 iOS 应用
运行并测试应用。
如果一切顺利,您应该会在设备上看到如下所示的应用:
如果发送了“Buy my book to learn online trading!” 消息,应用会回发“检测到垃圾内容”提醒,概率为 0 .99%!
14. 恭喜!
您现在已经创建了一个非常简单的应用,该应用使用根据垃圾博客数据进行训练的模型来过滤垃圾评论文本。
典型开发者生命周期中的下一步是探索如何根据在自己社区中找到的数据来自定义模型。您可以在下一个课程活动中了解具体操作方法。