1. 准备工作
在此 Codelab 中,您将更新在之前的“移动文本分类入门”Codelab 中构建的应用。
前提条件
- 此 Codelab 专为刚接触机器学习的经验丰富的开发者而设计。
- 此 Codelab 是有序开发者在线课程的一部分。如果您尚未完成“构建基本消息样式应用”或“构建垃圾评论机器学习模型”,请停止并立即完成。
您将 [构建或学习]内容
- 您将学习如何在前面的步骤中构建自定义模型,并将其集成到应用中。
所需条件
- Android Studio 或 CocoaPods(适用于 iOS)
2. 打开现有 Android 应用
您可按照 Codelab 1 操作来获取相关代码,也可以克隆此代码库并从 TextClassificationStep1
加载该应用。
git clone https://github.com/googlecodelabs/odml-pathways
您可以在 TextClassificationOnMobile->Android
路径中找到该编号。
您也可以以 TextClassificationStep2
的形式获取 finished 代码。
打开该文件后,您就可以继续执行第 2 步了。
3. 导入模型文件和元数据
在“构建垃圾评论机器学习模型”Codelab 中,您创建了一个 .TFLITE 模型。
您应该已下载模型文件。如果您没有该模型,可以从此 Codelab 的代码库中获取,也可以在此处获取。
通过创建资源目录将其添加到您的项目。
- 使用项目导航器,确保选择顶部的 Android。
- 右键点击 app 文件夹。选择新建 >目录 -
- 在 New Directory 对话框中,选择 src/main/assets。
您会看到应用中现在出现了新的 assets 文件夹。
- 右键点击 assets。
- 在打开的菜单中,您会看到(在 Mac 上)在“访达”中显示。选择它。(Windows 上显示在资源管理器中显示,Ubuntu 则显示在文件中显示)。
Finder 将启动,以显示文件位置(Windows 上的文件资源管理器,Linux 中的文件)。
- 将
labels.txt
、model.tflite
和vocab
文件复制到此目录。
- 返回 Android Studio,您将在 assets 文件夹中看到这些素材资源。
4. 更新 build.gradle 以使用 TensorFlow Lite
如需使用 TensorFlow Lite 以及支持 TensorFlow Lite 的 TensorFlow Lite 任务库,您需要更新 build.gradle
文件。
Android 项目通常有多个,因此请务必找到应用级别 1。在 Android 视图的 Project Explorer 中,您可以在 Gradle Scripts 部分中找到它。正确的应用将使用 .app 标签,如下所示:
您需要对此文件进行两项更改。第一个位于底部的 dependencies 部分。为 TensorFlow Lite 任务库添加一个文本 implementation
,如下所示:
implementation 'org.tensorflow:tensorflow-lite-task-text:0.1.0'
自编写以来,版本号可能已发生变化,因此请务必访问 https://www.tensorflow.org/lite/inference_with_metadata/task_library/nl_classifier 获取最新版本。
此外,任务库的最低 SDK 版本要求为 21。在 android
中找到此设置 >default config
,并将其更改为 21:
现在,您已经有了所有依赖项,是时候开始编码了!
5. 添加辅助类
如需将应用使用模型的推理逻辑与界面分离开,请再创建一个类来处理模型推理。称之为“助手”类。
- 右键点击
MainActivity
代码所在的软件包名称。 - 选择新建 >软件包。
- 屏幕中央会出现一个对话框,要求您输入软件包名称。将其添加到当前软件包名称的末尾。(此处称为帮助程序。)
- 完成此操作后,右键点击 Project Explorer 中的 helpers 文件夹。
- 选择新建 >Java Class,并将其命名为
TextClassificationClient
。您将在下一步中修改此文件。
您的 TextClassificationClient
辅助类将如下所示(但软件包名称可能会有所不同)。
package com.google.devrel.textclassificationstep1.helpers;
public class TextClassificationClient {
}
- 使用以下代码更新文件:
package com.google.devrel.textclassificationstep2.helpers;
import android.content.Context;
import android.util.Log;
import java.io.IOException;
import java.util.List;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.text.nlclassifier.NLClassifier;
public class TextClassificationClient {
private static final String MODEL_PATH = "model.tflite";
private static final String TAG = "CommentSpam";
private final Context context;
NLClassifier classifier;
public TextClassificationClient(Context context) {
this.context = context;
}
public void load() {
try {
classifier = NLClassifier.createFromFile(context, MODEL_PATH);
} catch (IOException e) {
Log.e(TAG, e.getMessage());
}
}
public void unload() {
classifier.close();
classifier = null;
}
public List<Category> classify(String text) {
List<Category> apiResults = classifier.classify(text);
return apiResults;
}
}
此类将为 TensorFlow Lite 解释器提供一个封装容器,用于加载模型并抽象化管理应用与模型之间数据交换的复杂性。
在 load()
方法中,它会从模型路径实例化一个新的 NLClassifier
类型。模型路径就是模型的名称 model.tflite
。NLClassifier
类型是文本任务库的一部分,可帮助您将字符串转换为词元、使用正确的序列长度、将其传递给模型以及解析结果。
(如需详细了解这些内容,请再次参阅“构建垃圾评论机器学习模型”。)
分类在分类方法中执行,您需要向其传递一个字符串,然后该方法会返回 List
。如果您想判断某个字符串是否为垃圾内容,在使用机器学习模型对内容进行分类时,系统通常会返回所有回答以及指定的概率。例如,如果您向它传递了一条看似垃圾内容的消息,就会收到包含 2 个答案的列表;一组的概率是垃圾模型,另一组则预测这不是垃圾信息的概率。“垃圾内容”/“非垃圾内容”是类别,因此返回的 List
将包含这些概率。稍后您将对其进行解析。
现在您已经有了辅助类,请返回 MainActivity
并对其进行更新,以使用此类对文本进行分类。您将在下一步中看到它!
6. 对文本进行分类
在 MainActivity
中,您首先需要导入刚刚创建的帮助程序!
- 在
MainActivity.kt
的顶部以及其他导入内容的顶部,添加以下代码:
import com.google.devrel.textclassificationstep2.helpers.TextClassificationClient
import org.tensorflow.lite.support.label.Category
- 接下来,您需要加载帮助程序。在
onCreate
中,紧跟在setContentView
行之后,添加以下几行代码以实例化并加载帮助程序类:
val client = TextClassificationClient(applicationContext)
client.load()
此时,按钮的 onClickListener
应如下所示:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
txtOutput.text = toSend
}
- 更新后如下所示:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
var results:List<Category> = client.classify(toSend)
val score = results[1].score
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
txtInput.text.clear()
}
这会将功能从仅输出用户的输入更改为首先对其进行分类。
- 在此行中,您将获取用户输入的字符串并将其传递给模型,从而获得结果:
var results:List<Category> = client.classify(toSend)
只有 2 个类别:False
和 True
(TensorFlow 按字母顺序对其进行排序,因此 False 将是第 0 项,True 将是第 1 项)。
- 如需获取值为
True
的概率的分数,可以按如下方式查看 results[1].score:
val score = results[1].score
- 选择了阈值(本例中为 0.8),也就是如果 True 类别的得分高于阈值 (0.8),那么邮件就是垃圾邮件。否则,此邮件就不是垃圾邮件,邮件中可以安全发送:
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
- 在此处查看模型的实际运用。“访问我的博客买东西!”消息已被标记为很可能是垃圾邮件:
反过来说“嘿,有趣的教程,谢谢!”被认定为垃圾信息的可能性极低:
7. 更新 iOS 应用以使用 TensorFlow Lite 模型
您可按照 Codelab 1 操作来获取相应代码,也可以克隆此代码库并从 TextClassificationStep1
加载该应用。您可以在 TextClassificationOnMobile->iOS
路径中找到该编号。
您也可以以 TextClassificationStep2
的形式获取 finished 代码。
在“构建垃圾评论机器学习模型”Codelab 中,您创建了一个非常简单的应用,允许用户在 UITextView
中输入消息,然后将消息传递到输出,无需任何过滤操作。
现在,您将更新该应用,以便在发送之前使用 TensorFlow Lite 模型检测文本中的垃圾评论。只需在输出标签中呈现文本,即可模拟此应用中的发送操作(但真实的应用可能会提供公告板、聊天工具或类似内容)。
首先,您需要第 1 步中的应用,您可以从代码库中克隆该应用。
要集成 TensorFlow Lite,您将使用 CocoaPods。如果您尚未安装这些软件,可以按照 https://cocoapods.org/ 中的说明进行安装。
- 安装 CocoaPods 后,在 TextClassification 应用的
.xcproject
所在的同一目录中创建一个名为 Podfile 的文件。此文件的内容应如下所示:
target 'TextClassificationStep2' do
use_frameworks!
# Pods for NLPClassifier
pod 'TensorFlowLiteSwift'
end
第一行中应该包含应用的名称,而不是“TextClassificationStep2”。
使用终端导航到该目录并运行 pod install
。如果运行成功,您会看到一个名为 Pods 的新目录,以及一个为您创建的新 .xcworkspace
文件。您将来会用到它,而不是 .xcproject
。
如果上传失败,请确保 Podfile 位于 .xcproject
之前所在的同一目录中。错误目录或错误目标名称的 Podfile 通常是主要问题!
8. 添加模型和词汇表文件
使用 TensorFlow Lite Model Maker 创建模型时,您可以输出模型(以 model.tflite
格式)和词汇表(以 vocab.txt
格式)。
- 在“访达”中将其拖放到项目窗口中,即可将其添加到项目中。确保选中添加到定位条件:
完成后,您应该会在项目中看到它们:
- 选择您的项目(在上面的屏幕截图中,蓝色图标 TextClassificationStep2)并查看 Build Phases 标签页,仔细检查是否已将它们添加到 bundle(以便将其部署到设备):
9. 加载词汇
在进行 NLP 分类时,系统会使用编码为向量的字词来训练模型。该模型使用在模型训练过程中学习的一组特定名称和值对字词进行编码。请注意,大多数模型具有不同的词汇表,因此请务必使用训练时生成的模型词汇表。这是您刚刚添加到应用中的 vocab.txt
文件。
您可以在 Xcode 中打开该文件以查看编码。类似“歌曲”的字词编码为 6 且“love”。顺序实际上是频率顺序,因此“I”是数据集中最常见的单词,其次是“check”。
当用户输入字词时,您需要使用该词汇表对其进行编码,然后再将其发送到模型进行分类。
我们来研究一下该代码。首先加载词汇。
- 定义一个类级别变量来存储字典:
var words_dictionary = [String : Int]()
- 然后,在类中创建一个
func
,将词汇表加载到此字典中:
func loadVocab(){
// This func will take the file at vocab.txt and load it into a has table
// called words_dictionary. This will be used to tokenize the words before passing them
// to the model trained by TensorFlow Lite Model Maker
if let filePath = Bundle.main.path(forResource: "vocab", ofType: "txt") {
do {
let dictionary_contents = try String(contentsOfFile: filePath)
let lines = dictionary_contents.split(whereSeparator: \.isNewline)
for line in lines{
let tokens = line.components(separatedBy: " ")
let key = String(tokens[0])
let value = Int(tokens[1])
words_dictionary[key] = value
}
} catch {
print("Error vocab could not be loaded")
}
} else {
print("Error -- vocab file not found")
}
}
- 您可以通过在
viewDidLoad
中调用来运行它:
override func viewDidLoad() {
super.viewDidLoad()
txtInput.delegate = self
loadVocab()
}
10. 将字符串转换为一个词元序列
您的用户将输入的字词作为句子,该句子将成为字符串。句子中的每个字词(如果存在于字典中)都将编码到词汇表中定义的字词的键值中。
NLP 模型通常接受固定的序列长度。使用 ragged tensors
构建的模型存在例外情况,但在大多数情况下,您会发现问题已得到解决。您在创建模型时指定了此长度。确保在 iOS 应用中使用相同的长度。
您之前使用的 TensorFlow Lite Model Maker 的 Colab 默认值为 20,因此也请在此处设置:
let SEQUENCE_LENGTH = 20
添加以下 func
,它会接受字符串并将其转换为小写形式,并删除所有标点符号:
func convert_sentence(sentence: String) -> [Int32]{
// This func will split a sentence into individual words, while stripping punctuation
// If the word is present in the dictionary it's value from the dictionary will be added to
// the sequence. Otherwise we'll continue
// Initialize the sequence to be all 0s, and the length to be determined
// by the const SEQUENCE_LENGTH. This should be the same length as the
// sequences that the model was trained for
var sequence = [Int32](repeating: 0, count: SEQUENCE_LENGTH)
var words : [String] = []
sentence.enumerateSubstrings(
in: sentence.startIndex..<sentence.endIndex,options: .byWords) {
(substring, _, _, _) -> () in words.append(substring!) }
var thisWord = 0
for word in words{
if (thisWord>=SEQUENCE_LENGTH){
break
}
let seekword = word.lowercased()
if let val = words_dictionary[seekword]{
sequence[thisWord]=Int32(val)
thisWord = thisWord + 1
}
}
return sequence
}
请注意,该顺序将是 Int32 序列。这是我们特意选择的,因为在向 TensorFlow Lite 传递值时,您要处理的是低级内存,而 TensorFlow Lite 会将字符串序列中的整数视为 32 位整数。这会让您在将字符串传递给模型时,会稍微容易一些。
11. 分类
若要对句子进行分类,必须先根据句子中的字词将其转换为词元序列。这项操作在第 9 步完成。
现在,您需要将句子传递给模型,让模型对句子进行推理,然后解析结果。
这将使用 TensorFlow Lite 解释器,您需要将其导入:
import TensorFlowLite
从接受您的序列(Int32 类型的数组)的 func
开始:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
return
}
这将从 bundle 加载模型文件,并使用该文件调用解释器。
下一步是将存储在序列中的底层内存复制到名为 myData,
的缓冲区中,以便将其传递给张量。在实现 TensorFlow Lite Pod 和解释器时,您可以访问张量类型。
以如下方式开始编写代码(仍在分类 func
中):
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
如果您在 copyingBufferOf
上遇到错误,请不要担心。稍后,这将作为扩展程序来实现。
现在,可以在解释器上分配张量,将您刚刚创建的数据缓冲区复制到输入张量,然后调用解释器进行推理了:
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
调用完成后,您可以查看解释器的输出以了解结果。
这些是原始值(每个神经元 4 个字节),您必须对其进行读取和转换。由于此特定模型有 2 个输出神经元,因此您需要读取 8 个字节,这些字节将转换为 Float32 进行解析。您要处理的是低级别内存,因此需要 unsafeData
。
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
现在,解析数据以确定垃圾邮件质量是相对容易的。该模型有 2 个输出,第一个输出的概率是相应邮件不是垃圾邮件,第二个输出是垃圾邮件的概率。因此,您可以查看 results[1]
查找垃圾内容值:
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " + String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
为方便起见,以下是完整的方法:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
Return
}
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " +
String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
} catch let error {
print("Failed to invoke the interpreter with error: \(error.localizedDescription)")
}
}
12. 添加 Swift 扩展程序
上述代码使用了数据类型的扩展,以便您将 Int32 数组的原始位复制到 Data
中。该扩展程序的代码如下:
extension Data {
/// Creates a new buffer by copying the buffer pointer of the given array.
///
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
/// data from the resulting buffer has undefined behavior.
/// - Parameter array: An array with elements of type `T`.
init<T>(copyingBufferOf array: [T]) {
self = array.withUnsafeBufferPointer(Data.init)
}
}
在处理低级别内存时,会使用“不安全”数据,上述代码需要您初始化不安全的数据数组。此扩展程序可实现以下目的:
extension Array {
/// Creates a new array from the bytes of the given unsafe data.
///
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
/// `MemoryLayout<Element>.stride`.
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
#if swift(>=5.0)
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
#else
self = unsafeData.withUnsafeBytes {
.init(UnsafeBufferPointer<Element>(
start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride
))
}
#endif // swift(>=5.0)
}
}
13. 运行 iOS 应用
运行并测试应用。
如果一切顺利,您应该会在设备上看到该应用,如下所示:
在发送后,应用会发回检测到垃圾内容的提醒,概率为 0 .99%!
14. 恭喜!
您现在已经创建了一个非常简单的应用,该应用使用根据垃圾博客数据进行训练的模型来过滤垃圾评论文本。
典型开发者生命周期中的下一步是探索如何根据自己社区中的数据来自定义模型。在下一在线课程活动中,您将了解到如何操作。