TensorFlow.js - 使用迁移学习进行音频识别

1. 简介

在此 Codelab 中,您将构建一个音频识别网络,并使用该网络通过发出声音来控制浏览器中的滑块。您将使用 TensorFlow.js,后者是采用 JavaScript 的一种功能强大且灵活的机器学习库。

首先,您将加载并运行一个 预训练模型,该模型可以识别 20 个语音命令。然后,您将使用麦克风构建并训练一个简单的神经网络,该网络可以识别您的声音并使滑块向左或向右移动。

此 Codelab 不会 介绍音频识别模型背后的理论。如果您对此感兴趣,请查看 本教程

我们还创建了一个 机器学习术语表,其中包含您在此 Codelab 中会遇到的术语。

学习内容

  • 如何加载经过预先训练的语音命令识别模型
  • 如何使用麦克风进行实时预测
  • 如何使用浏览器麦克风训练和使用自定义音频识别模型

下面我们开始步入正题

2. 要求

要完成本 Codelab,您需要:

  1. 最新版本的 Chrome 或其他现代浏览器。
  2. 文本编辑器,可在本地计算机上运行,也可通过 CodepenGlitch 等工具在网络上运行。
  3. 了解 HTML、CSS、JavaScript 和 Chrome 开发者工具(或您的首选浏览器开发者工具)。
  4. 大致了解神经网络的概念。如果您需要了解简介或回顾内容,请考虑观看 这部由 3blue1brown 制作的视频,或 Ashi Krishnan 这部有关使用 JavaScript 构建深度学习应用的 视频

3. 加载 TensorFlow.js 和音频模型

在编辑器中打开 index.html 并添加以下内容:

<html>
  <head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/speech-commands"></script>
  </head>
  <body>
    <div id="console"></div>
    <script src="index.js"></script>
  </body>
</html>

第一个 <script> 标记用于导入 TensorFlow.js 库,第二个 <script> 标记用于导入经过预先训练的 语音命令模型<div id="console"> 标记将用于显示模型的输出。

4. 实时预测

接下来,在代码编辑器中打开/创建文件 index.js ,并添加以下代码:

let recognizer;

function predictWord() {
 // Array of words that the recognizer is trained to recognize.
 const words = recognizer.wordLabels();
 recognizer.listen(({scores}) => {
   // Turn scores into a list of (score,word) pairs.
   scores = Array.from(scores).map((s, i) => ({score: s, word: words[i]}));
   // Find the most probable word.
   scores.sort((s1, s2) => s2.score - s1.score);
   document.querySelector('#console').textContent = scores[0].word;
 }, {probabilityThreshold: 0.75});
}

async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 predictWord();
}

app();

5. 测试预测

确保您的设备有麦克风。值得注意的是,此功能也适用于手机!如需运行网页,请在浏览器中打开 index.html 。如果您使用的是本地文件,则必须启动网络服务器并使用 http://localhost:port/ 才能访问麦克风。

如需在端口 8000 上启动简单的网络服务器,请执行以下操作:

python -m SimpleHTTPServer

下载模型可能需要一些时间,请耐心等待。模型加载完毕后,您应该会在页面顶部看到一个字词。该模型经过训练,可以识别数字 0 到 9 以及一些其他命令,例如“left”“right”“yes”“no”等。

说出其中一个字词。它是否能正确识别您的字词?您可以尝试使用 probabilityThreshold,该变量用于控制模型触发的频率 - 0.75 表示当模型确信听到给定字词的概率超过 75% 时,模型会触发。

如需详细了解语音命令模型及其 API,请参阅 Github 上的 README.md

6. 收集数据

为了增加趣味性,我们使用短声音而不是完整的字词来控制滑块!

您将训练一个模型来识别 3 个不同的命令:“Left”“Right”和“Noise”,这些命令将使滑块向左或向右移动。识别“Noise”(无需执行任何操作)在语音检测中至关重要,因为我们希望滑块仅在我们发出正确的声音时做出反应,而不是在我们说话和四处走动时做出反应。

  1. 首先,我们需要收集数据。在 <div id="console"> 之前的 <body> 标记内添加以下内容,为应用添加一个简单的界面:
<button id="left" onmousedown="collect(0)" onmouseup="collect(null)">Left</button>
<button id="right" onmousedown="collect(1)" onmouseup="collect(null)">Right</button>
<button id="noise" onmousedown="collect(2)" onmouseup="collect(null)">Noise</button>
  1. 将以下内容添加到 index.js
// One frame is ~23ms of audio.
const NUM_FRAMES = 3;
let examples = [];

function collect(label) {
 if (recognizer.isListening()) {
   return recognizer.stopListening();
 }
 if (label == null) {
   return;
 }
 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   examples.push({vals, label});
   document.querySelector('#console').textContent =
       `${examples.length} examples collected`;
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

function normalize(x) {
 const mean = -100;
 const std = 10;
 return x.map(x => (x - mean) / std);
}
  1. app() 中移除 predictWord()
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // predictWord() no longer called.
}

细分

这段代码乍一看可能让人不知所措,下面我们来细分一下。

我们在界面中添加了三个按钮,分别标记为“Left”“Right”和“Noise”,对应于我们希望模型识别的三个命令。按下这些按钮会调用我们新添加的 collect() 函数,该函数会为我们的模型创建训练示例。

collect()labelrecognizer.listen() 的输出相关联。由于 includeSpectrogram 为 true, recognizer.listen() 会提供 1 秒音频的原始频谱图(频率数据),分为 43 帧,因此每帧音频约为 23 毫秒:

recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
...
}, {includeSpectrogram: true});

由于我们希望使用短声音而不是字词来控制滑块,因此我们仅考虑最后 3 帧(约 70 毫秒):

let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));

为了避免数值问题,我们将数据归一化,使其平均值为 0,标准差为 1。在这种情况下,频谱图值通常是大约 -100 的较大负数,标准差为 10:

const mean = -100;
const std = 10;
return x.map(x => (x - mean) / std);

最后,每个训练示例将包含 2 个字段:

  • label****: 0、1 和 2 分别对应“Left”“Right”和“Noise”。
  • vals****: 696 个数字,用于保存频率信息(频谱图)

我们将所有数据存储在 examples 变量中:

examples.push({vals, label});

7. 测试数据收集

在浏览器中打开 index.html ,您应该会看到 3 个按钮,分别对应于 3 个命令。如果您使用的是本地文件,则必须启动网络服务器并使用 http://localhost:port/ 才能访问麦克风。

如需在端口 8000 上启动简单的网络服务器,请执行以下操作:

python -m SimpleHTTPServer

如需为每个命令收集示例,请在按住 每个按钮 3-4 秒的同时,重复(或持续)发出一致的声音。您应该为每个标签收集约 150 个示例。例如,我们可以用弹指声表示“Left”,用口哨声表示“Right”,用静音和说话交替表示“Noise”。

随着您收集更多示例,页面上显示的计数器应该会增加。您还可以通过在控制台中对 examples 变量调用 console.log() 来检查数据。此时的目标是测试数据收集过程。稍后,您将在测试整个应用时重新收集数据。

8. 训练模型

  1. index.html 的正文中,紧挨着 Noise 按钮添加一个 Train 按钮:
<br/><br/>
<button id="train" onclick="train()">Train</button>
  1. 将以下内容添加到 index.js 中的现有代码:
const INPUT_SHAPE = [NUM_FRAMES, 232, 1];
let model;

async function train() {
 toggleButtons(false);
 const ys = tf.oneHot(examples.map(e => e.label), 3);
 const xsShape = [examples.length, ...INPUT_SHAPE];
 const xs = tf.tensor(flatten(examples.map(e => e.vals)), xsShape);

 await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });
 tf.dispose([xs, ys]);
 toggleButtons(true);
}

function buildModel() {
 model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));
 const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });
}

function toggleButtons(enable) {
 document.querySelectorAll('button').forEach(b => b.disabled = !enable);
}

function flatten(tensors) {
 const size = tensors[0].length;
 const result = new Float32Array(tensors.length * size);
 tensors.forEach((arr, i) => result.set(arr, i * size));
 return result;
}
  1. 在应用加载时调用 buildModel()
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // Add this line.
 buildModel();
}

此时,如果您刷新应用,您会看到一个新的“Train”按钮。您可以重新收集数据并点击“Train”来测试训练,也可以等到第 10 步再测试训练和预测。

细分

从宏观层面来看,我们做了两件事:buildModel() 定义了模型架构,train() 使用收集的数据训练模型。

模型架构

该模型有 4 层:一个处理音频数据(以频谱图表示)的卷积层、一个最大池化层、一个展平层和一个映射到 3 个操作的密集层:

model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));

模型的输入形状为 [NUM_FRAMES, 232, 1],其中每帧音频为 23 毫秒,包含 232 个数字,这些数字对应于不同的频率(之所以选择 232,是因为这是捕获人声所需的频率桶数量)。在此 Codelab 中,我们使用的是 3 帧长的样本(约 70 毫秒的样本),因为我们发出的是声音而不是说出完整的字词来控制滑块。

我们编译模型以使其做好训练准备:

const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });

我们使用 Adam 优化器(深度学习中常用的优化器)和 categoricalCrossEntropy 作为损失函数(用于分类的标准损失函数)。简而言之,它衡量的是预测概率(每个类别一个概率)与真实类别中 100% 的概率以及所有其他类别中 0% 的概率之间的差距。我们还提供 accuracy 作为要监控的指标,该指标将给出模型在每个训练周期后正确识别的示例百分比。

训练

训练使用 16 的批次大小(一次处理 16 个示例)对数据进行 10 次(周期),并在界面中显示当前准确率:

await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });

9. 实时更新滑块

现在我们可以训练模型了,接下来添加代码以进行实时预测并移动滑块。在 index.html 中,紧挨着 "Train" 按钮添加以下内容:

<br/><br/>
<button id="listen" onclick="listen()">Listen</button>
<input type="range" id="output" min="0" max="10" step="0.1">

并在 index.js 中添加以下内容:

async function moveSlider(labelTensor) {
 const label = (await labelTensor.data())[0];
 document.getElementById('console').textContent = label;
 if (label == 2) {
   return;
 }
 let delta = 0.1;
 const prevValue = +document.getElementById('output').value;
 document.getElementById('output').value =
     prevValue + (label === 0 ? -delta : delta);
}

function listen() {
 if (recognizer.isListening()) {
   recognizer.stopListening();
   toggleButtons(true);
   document.getElementById('listen').textContent = 'Listen';
   return;
 }
 toggleButtons(false);
 document.getElementById('listen').textContent = 'Stop';
 document.getElementById('listen').disabled = false;

 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   const vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   const input = tf.tensor(vals, [1, ...INPUT_SHAPE]);
   const probs = model.predict(input);
   const predLabel = probs.argMax(1);
   await moveSlider(predLabel);
   tf.dispose([input, probs, predLabel]);
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

细分

实时预测

listen() 会监听麦克风并进行实时预测。该代码与 collect() 方法非常相似,后者会对原始频谱图进行归一化,并舍弃除最后 NUM_FRAMES 帧之外的所有帧。唯一的区别是我们还会调用经过训练的模型来获取预测:

const probs = model.predict(input);
const predLabel = probs.argMax(1);
await moveSlider(predLabel);

model.predict(input) 的输出是一个形状为 [1, numClasses] 的张量,表示类别数量的概率分布。更简单地说,这只是一组置信度,分别对应于每个可能的输出类别,总和为 1。该张量的外部维度为 1,因为这是批次的大小(单个示例)。

如需将概率分布转换为表示最可能类别的单个整数,我们调用 probs.argMax(1),该函数会返回概率最高的类别的索引。我们将“1”作为轴参数传递,因为我们希望计算最后一个维度 numClassesargMax

更新滑块

如果标签为 0(“Left”),moveSlider() 会减小滑块的值;如果标签为 1(“Right”),则会增大滑块的值;如果标签为 2(“Noise”),则会忽略滑块的值。

处置张量

为了清理 GPU 内存,我们需要手动对输出张量调用 tf.dispose()。手动 tf.dispose() 的替代方法是将函数调用封装在 tf.tidy() 中,但此方法不能与异步函数一起使用。

   tf.dispose([input, probs, predLabel]);

10. 测试最终应用

在浏览器中打开 index.html ,并使用与 3 个命令对应的 3 个按钮收集数据,就像您在上一部分中所做的那样。请务必在收集数据时按住 每个按钮 3-4 秒。

收集示例后,按 “Train” 按钮。这将开始训练模型,您应该会看到模型的准确率超过 90%。如果您没有获得良好的模型性能,请尝试收集更多数据。

训练完成后,按 “Listen” 按钮,通过麦克风进行预测并控制滑块!

如需查看更多教程,请访问 http://js.tensorflow.org/。