TensorFlow.js – Nhận dạng âm thanh bằng công nghệ học chuyển

1. Giới thiệu

Trong lớp học lập trình này, bạn sẽ xây dựng một mạng nhận dạng âm thanh và dùng mạng này để điều khiển thanh trượt trong trình duyệt bằng cách phát ra âm thanh. Bạn sẽ sử dụng TensorFlow.js, một thư viện máy học linh hoạt và mạnh mẽ để viết JavaScript.

Trước tiên, bạn sẽ tải và chạy một mô hình huấn luyện trước có thể nhận dạng 20 lệnh thoại. Sau đó, bằng cách sử dụng micrô, bạn sẽ xây dựng và huấn luyện một mạng nơron đơn giản giúp nhận dạng âm thanh của bạn và di chuyển thanh trượt sang trái hoặc phải.

Lớp học lập trình này sẽ không trình bày lý thuyết đằng sau các mô hình nhận dạng âm thanh. Nếu bạn muốn biết về điều đó, hãy xem hướng dẫn này.

Chúng tôi cũng đã tạo một bảng thuật ngữ về các thuật ngữ trong công nghệ học máy mà bạn có thể tìm thấy trong lớp học lập trình này.

Kiến thức bạn sẽ học được

  • Cách tải mô hình nhận dạng lệnh thoại đã được huấn luyện trước
  • Cách dùng micrô để đưa ra dự đoán theo thời gian thực
  • Cách huấn luyện và sử dụng mô hình nhận dạng âm thanh tuỳ chỉnh bằng micrô của trình duyệt

Vì vậy, hãy bắt đầu.

2. Yêu cầu

Để hoàn tất lớp học lập trình này, bạn sẽ cần:

  1. Phiên bản Chrome mới nhất hoặc một trình duyệt hiện đại khác.
  2. Một trình chỉnh sửa văn bản, chạy cục bộ trên máy của bạn hoặc trên web thông qua Codepen hoặc Glitch.
  3. Có kiến thức về HTML, CSS, JavaScript và Công cụ của Chrome cho nhà phát triển (hoặc công cụ cho nhà phát triển của trình duyệt mà bạn ưu tiên).
  4. Hiểu biết khái niệm ở mức cao về Mạng nơron. Nếu bạn cần giới thiệu hoặc ôn tập lại, hãy cân nhắc xem video này của 3blue1brown hoặc video về Học sâu bằng JavaScript của Ashi Krishnan.

3. Tải TensorFlow.js và Mô hình âm thanh

Mở index.html trong trình chỉnh sửa rồi thêm nội dung sau:

<html>
  <head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/speech-commands"></script>
  </head>
  <body>
    <div id="console"></div>
    <script src="index.js"></script>
  </body>
</html>

Thẻ <script> đầu tiên nhập thư viện TensorFlow.js và <script> thứ hai nhập mô hình Lệnh thoại đã huấn luyện trước. Thẻ <div id="console"> sẽ được dùng để hiện kết quả của mô hình.

4. Dự đoán theo thời gian thực

Tiếp theo, hãy mở/tạo tệp index.js trong một trình soạn thảo mã và bao gồm mã sau:

let recognizer;

function predictWord() {
 // Array of words that the recognizer is trained to recognize.
 const words = recognizer.wordLabels();
 recognizer.listen(({scores}) => {
   // Turn scores into a list of (score,word) pairs.
   scores = Array.from(scores).map((s, i) => ({score: s, word: words[i]}));
   // Find the most probable word.
   scores.sort((s1, s2) => s2.score - s1.score);
   document.querySelector('#console').textContent = scores[0].word;
 }, {probabilityThreshold: 0.75});
}

async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 predictWord();
}

app();

5. Kiểm tra thông tin dự đoán

Đảm bảo thiết bị của bạn có micrô. Lưu ý rằng tính năng này cũng sẽ hoạt động trên điện thoại di động! Để chạy trang web, hãy mở index.html trong trình duyệt. Nếu đang làm việc từ một tệp cục bộ, bạn phải khởi động một máy chủ web và sử dụng http://localhost:port/ để truy cập vào micrô.

Để khởi động một máy chủ web đơn giản trên cổng 8000:

python -m SimpleHTTPServer

Có thể mất chút thời gian để tải mô hình xuống, vì vậy hãy kiên nhẫn. Ngay khi mô hình này tải, bạn sẽ thấy một từ ở đầu trang. Mô hình này đã được huấn luyện để nhận ra các số từ 0 đến 9 và một vài lệnh bổ sung như "trái", "phải", "có", "không", v.v.

Hãy nói một trong những từ đó. Công cụ đó có hiểu đúng từ của bạn không? Dùng probabilityThreshold để kiểm soát tần suất mô hình kích hoạt – 0,75 có nghĩa là mô hình sẽ kích hoạt khi có hơn 75% tự tin là nghe được một từ nhất định.

Để tìm hiểu thêm về mô hình Lệnh chuyển lời nói và API của mô hình này, hãy xem tệp README.md trên GitHub.

6. Thu thập dữ liệu

Để tạo trải nghiệm thú vị, hãy sử dụng các âm thanh ngắn thay vì toàn bộ từ ngữ để điều khiển thanh trượt!

Bạn sẽ huấn luyện một mô hình để nhận dạng 3 lệnh khác nhau: "Trái", "Phải" và "Noise" (Tiếng ồn) Thao tác này sẽ làm cho thanh trượt di chuyển sang trái hoặc sang phải. Nhận dạng "Tiếng ồn" (không cần hành động) rất quan trọng trong việc phát hiện giọng nói vì chúng ta muốn thanh trượt phản ứng chỉ khi tạo ra âm thanh phù hợp chứ không phải khi chúng ta nói chung và di chuyển xung quanh.

  1. Trước tiên, chúng tôi cần thu thập dữ liệu. Thêm một giao diện người dùng đơn giản vào ứng dụng bằng cách thêm đoạn mã này vào bên trong thẻ <body> trước <div id="console">:
<button id="left" onmousedown="collect(0)" onmouseup="collect(null)">Left</button>
<button id="right" onmousedown="collect(1)" onmouseup="collect(null)">Right</button>
<button id="noise" onmousedown="collect(2)" onmouseup="collect(null)">Noise</button>
  1. Thêm đoạn mã này vào index.js:
// One frame is ~23ms of audio.
const NUM_FRAMES = 3;
let examples = [];

function collect(label) {
 if (recognizer.isListening()) {
   return recognizer.stopListening();
 }
 if (label == null) {
   return;
 }
 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   examples.push({vals, label});
   document.querySelector('#console').textContent =
       `${examples.length} examples collected`;
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

function normalize(x) {
 const mean = -100;
 const std = 10;
 return x.map(x => (x - mean) / std);
}
  1. Xoá predictWord() khỏi app():
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // predictWord() no longer called.
}

Chia nhỏ

Ban đầu, mã này có thể gây choáng ngợp, vì vậy hãy xem chi tiết.

Chúng ta đã thêm 3 nút có nhãn "Trái", "Phải" và "Noise" vào giao diện người dùng vào giao diện người dùng, tương ứng với 3 lệnh mà chúng ta muốn mô hình nhận dạng. Khi nhấn các nút này, chúng ta sẽ gọi hàm collect() mới được thêm vào. Hàm này sẽ tạo các ví dụ huấn luyện cho mô hình của chúng ta.

collect() liên kết label với dữ liệu đầu ra của recognizer.listen(). Vì includeSpectrogram là true, recognizer.listen() cung cấp phổ thô (dữ liệu tần số) trong 1 giây âm thanh, được chia thành 43 khung, do đó, mỗi khung hình dài khoảng 23 mili giây âm thanh:

recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
...
}, {includeSpectrogram: true});

Vì muốn sử dụng âm thanh ngắn thay vì từ ngữ để điều khiển thanh trượt, nên chúng tôi chỉ xem xét 3 khung hình cuối (khoảng 70 mili giây):

let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));

Và để tránh các vấn đề về số, chúng tôi chuẩn hoá dữ liệu để có giá trị trung bình bằng 0 và độ lệch chuẩn là 1. Trong trường hợp này, các giá trị phổ thường là các số âm lớn khoảng -100 và độ lệch 10:

const mean = -100;
const std = 10;
return x.map(x => (x - mean) / std);

Cuối cùng, mỗi ví dụ huấn luyện sẽ có 2 trường:

  • label****: 0, 1 và 2 cho "Trái", "Phải" và "Noise" (Tiếng ồn) .
  • vals****: 696 số chứa thông tin về tần số (quang phổ)

và chúng ta lưu trữ tất cả dữ liệu trong biến examples:

examples.push({vals, label});

7. Kiểm thử tính năng thu thập dữ liệu

Mở index.html trong trình duyệt và bạn sẽ thấy 3 nút tương ứng với 3 lệnh. Nếu đang làm việc từ một tệp cục bộ, bạn phải khởi động máy chủ web và sử dụng http://localhost:port/ để truy cập micrô.

Để khởi động một máy chủ web đơn giản trên cổng 8000:

python -m SimpleHTTPServer

Để thu thập ví dụ cho từng lệnh, hãy phát ra một âm thanh nhất quán nhiều lần (hoặc liên tục) trong khi nhấn và giữ mỗi nút trong 3 đến 4 giây. Bạn nên thu thập khoảng 150 ví dụ cho mỗi nhãn. Ví dụ: chúng ta có thể búng tay khi chỉ báo "Trái", huýt sáo khi nghe "Phải" và thay đổi giữa khoảng lặng và nói để chọn "Tiếng ồn".

Khi bạn thu thập thêm ví dụ, bộ đếm hiển thị trên trang sẽ tăng lên. Bạn cũng có thể kiểm tra dữ liệu bằng cách gọi console.log() trên biến examples trong bảng điều khiển. Tại thời điểm này, mục tiêu là thử nghiệm quy trình thu thập dữ liệu. Sau này, bạn sẽ thu thập lại dữ liệu khi kiểm thử toàn bộ ứng dụng.

8. Huấn luyện người mẫu

  1. Thêm "Chuyến tàu" ngay sau nút "Tiếng ồn" nút ở phần nội dung trong index.html:
<br/><br/>
<button id="train" onclick="train()">Train</button>
  1. Thêm đoạn mã sau vào mã hiện có trong index.js:
const INPUT_SHAPE = [NUM_FRAMES, 232, 1];
let model;

async function train() {
 toggleButtons(false);
 const ys = tf.oneHot(examples.map(e => e.label), 3);
 const xsShape = [examples.length, ...INPUT_SHAPE];
 const xs = tf.tensor(flatten(examples.map(e => e.vals)), xsShape);

 await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });
 tf.dispose([xs, ys]);
 toggleButtons(true);
}

function buildModel() {
 model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));
 const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });
}

function toggleButtons(enable) {
 document.querySelectorAll('button').forEach(b => b.disabled = !enable);
}

function flatten(tensors) {
 const size = tensors[0].length;
 const result = new Float32Array(tensors.length * size);
 tensors.forEach((arr, i) => result.set(arr, i * size));
 return result;
}
  1. Gọi buildModel() khi ứng dụng tải:
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // Add this line.
 buildModel();
}

Tại thời điểm này nếu làm mới ứng dụng, bạn sẽ thấy một "Chuyến tàu" mới . Bạn có thể kiểm tra hoạt động huấn luyện bằng cách thu thập lại dữ liệu rồi nhấp vào "Đào tạo", hoặc bạn có thể đợi đến bước 10 để kiểm tra hoạt động huấn luyện cùng với thông tin dự đoán.

Chia nhỏ

Ở cấp độ cao, chúng ta đang làm 2 việc: buildModel() xác định cấu trúc mô hình và train() huấn luyện mô hình bằng cách sử dụng dữ liệu đã thu thập.

Cấu trúc mô hình

Mô hình này có 4 lớp: một lớp tích chập xử lý dữ liệu âm thanh (biểu thị dưới dạng ảnh quang phổ), một lớp hồ bơi tối đa, một lớp làm phẳng và một lớp dày đặc ánh xạ đến 3 thao tác:

model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));

Hình dạng đầu vào của mô hình là [NUM_FRAMES, 232, 1], trong đó mỗi khung là 23 mili giây của âm thanh chứa 232 số tương ứng với các tần số khác nhau (232 được chọn vì đây là số nhóm tần số cần thiết để thu được giọng nói của người). Trong lớp học lập trình này, chúng ta sẽ sử dụng các mẫu có độ dài 3 khung hình (mẫu khoảng 70 mili giây) vì chúng ta sẽ tạo âm thanh thay vì nói toàn bộ từ để điều khiển thanh trượt.

Chúng ta biên soạn mô hình để chuẩn bị sẵn sàng cho việc huấn luyện:

const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });

Chúng tôi sử dụng trình tối ưu hoá Adam (một trình tối ưu hoá phổ biến dùng trong công nghệ học sâu) và categoricalCrossEntropy để mất dữ liệu (hàm mất tiêu chuẩn) dùng để phân loại. Tóm lại, phương pháp này đo lường mức độ chênh lệch giữa xác suất dự đoán (một xác suất cho mỗi lớp) so với việc có xác suất 100% trong lớp đúng và xác suất 0% cho tất cả các lớp khác. Chúng tôi cũng cung cấp accuracy làm chỉ số để giám sát. Chỉ số này sẽ cho biết tỷ lệ phần trăm số ví dụ mà mô hình đạt được sau mỗi khoảng thời gian huấn luyện.

Đào tạo

Quá trình đào tạo diễn ra 10 lần (khoảng thời gian bắt đầu của hệ thống) trên dữ liệu bằng cách sử dụng kích thước lô là 16 (xử lý 16 ví dụ cùng một lúc) và cho thấy độ chính xác hiện tại trong giao diện người dùng:

await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });

9. Cập nhật thanh trượt theo thời gian thực

Bây giờ, chúng ta có thể huấn luyện mô hình của mình, hãy thêm mã để đưa ra dự đoán theo thời gian thực và di chuyển thanh trượt. Thêm mục này ngay sau "Chuyến tàu" nút trong index.html:

<br/><br/>
<button id="listen" onclick="listen()">Listen</button>
<input type="range" id="output" min="0" max="10" step="0.1">

Và đoạn mã sau trong index.js:

async function moveSlider(labelTensor) {
 const label = (await labelTensor.data())[0];
 document.getElementById('console').textContent = label;
 if (label == 2) {
   return;
 }
 let delta = 0.1;
 const prevValue = +document.getElementById('output').value;
 document.getElementById('output').value =
     prevValue + (label === 0 ? -delta : delta);
}

function listen() {
 if (recognizer.isListening()) {
   recognizer.stopListening();
   toggleButtons(true);
   document.getElementById('listen').textContent = 'Listen';
   return;
 }
 toggleButtons(false);
 document.getElementById('listen').textContent = 'Stop';
 document.getElementById('listen').disabled = false;

 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   const vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   const input = tf.tensor(vals, [1, ...INPUT_SHAPE]);
   const probs = model.predict(input);
   const predLabel = probs.argMax(1);
   await moveSlider(predLabel);
   tf.dispose([input, probs, predLabel]);
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

Chia nhỏ

Dự đoán theo thời gian thực

listen() lắng nghe micrô và đưa ra dự đoán theo thời gian thực. Mã này rất giống với phương thức collect(), phương thức này chuẩn hoá phổ thô và loại bỏ tất cả trừ khung NUM_FRAMES cuối cùng. Điểm khác biệt duy nhất là chúng tôi cũng gọi mô hình đã huấn luyện để nhận thông tin dự đoán:

const probs = model.predict(input);
const predLabel = probs.argMax(1);
await moveSlider(predLabel);

Đầu ra của model.predict(input) là một Tensor có hình dạng [1, numClasses] đại diện cho sự phân bố xác suất trên số lượng lớp. Nói một cách đơn giản hơn, đây chỉ là một tập hợp các giá trị tin cậy cho từng lớp đầu ra có thể có, với tổng bằng 1. Tensor có kích thước bên ngoài là 1 vì đó là kích thước của lô (một ví dụ).

Để chuyển đổi mức phân phối xác suất thành một số nguyên duy nhất đại diện cho lớp có khả năng cao nhất, chúng ta gọi probs.argMax(1) để trả về chỉ mục lớp có xác suất cao nhất. Chúng ta truyền "1" làm tham số trục vì chúng ta muốn tính toán argMax trên kích thước cuối cùng là numClasses.

Cập nhật thanh trượt

moveSlider() giảm giá trị của thanh trượt nếu nhãn là 0 ("Trái") , sẽ tăng giá trị của thanh trượt nếu nhãn là 1 ("Phải") và bỏ qua nếu nhãn là 2 ("Tiếng ồn").

Xử lý tensor

Để dọn dẹp bộ nhớ GPU, chúng ta cần gọi tf.dispose() theo cách thủ công trên Tensors đầu ra. Lựa chọn thay thế cho tf.dispose() thủ công là gói các lệnh gọi hàm trong tf.tidy(), nhưng không thể sử dụng phương pháp này với các hàm không đồng bộ.

   tf.dispose([input, probs, predLabel]);

10. Kiểm thử ứng dụng hoàn thiện

Mở index.html trong trình duyệt của bạn và thu thập dữ liệu như đã làm trong phần trước bằng 3 nút tương ứng với 3 lệnh. Hãy nhớ nhấn và giữ mỗi nút từ 3 đến 4 giây trong khi thu thập dữ liệu.

Sau khi bạn đã thu thập được ví dụ, hãy nhấn nút "Train". Thao tác này sẽ bắt đầu huấn luyện mô hình và bạn sẽ thấy độ chính xác của mô hình tăng trên 90%. Nếu mô hình không đạt được hiệu suất tốt, hãy thử thu thập thêm dữ liệu.

Sau khi khoá đào tạo hoàn tất, hãy nhấn nút "Listen" (Nghe) để đưa ra dự đoán từ micrô và điều khiển thanh trượt!

Xem thêm hướng dẫn tại http://js.tensorflow.org/.