TensorFlow.js – rozpoznawanie dźwięku przy użyciu uczenia się transferu

1. Wprowadzenie

W ramach tego ćwiczenia w programie utworzysz sieć rozpoznawania dźwięku i użyjesz jej do sterowania suwakiem w przeglądarce i wydawaniem dźwięków. Będziesz używać TensorFlow.js – wydajnej i elastycznej biblioteki systemów uczących się dla JavaScriptu.

Najpierw wczytasz i uruchomisz wytrenowany model, który rozpoznaje 20 poleceń głosowych. Następnie z pomocą mikrofonu zbudujesz i wytrenujesz prostą sieć neuronowa, która rozpozna Twoje dźwięki i przesunie suwak w lewo lub w prawo.

To ćwiczenie w Codelabs nie będzie opisywać teorii stojącej za modelami rozpoznawania dźwięku. Więcej informacji znajdziesz w tym samouczku.

Opracowaliśmy też glosariusz z pojęciami związanymi z systemami uczącymi się, które znajdziesz w tym ćwiczeniu z programowania.

Czego się nauczysz

  • Jak wczytać wytrenowany model rozpoznawania poleceń głosowych
  • Jak generować prognozy w czasie rzeczywistym za pomocą mikrofonu
  • Jak wytrenować własny model rozpoznawania dźwięku i używać go za pomocą mikrofonu w przeglądarce

Zaczynamy.

2. Wymagania

Do ukończenia tego ćwiczenia w programowaniu będziesz potrzebować:

  1. Mieć najnowszą wersję Chrome lub innej nowoczesnej przeglądarki.
  2. Edytor tekstu działający lokalnie na komputerze lub w internecie za pomocą programu takiego jak Codepen lub Glitch.
  3. znajomość języków HTML, CSS, JavaScript i Narzędzi deweloperskich w Chrome (lub narzędzi deweloperskich preferowanych przez Ciebie w przeglądarce).
  4. Ogólne pojęcie teoretycznego na temat sieci neuronowych. Jeśli potrzebujesz wprowadzenia lub przypomnienia, możesz obejrzeć ten film 3blue1brown lub film o Deep Learning in JavaScript autorstwa Ashiego Krishnana.

3. Wczytaj TensorFlow.js i model audio

Otwórz plik index.html w edytorze i dodaj tę treść:

<html>
  <head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/speech-commands"></script>
  </head>
  <body>
    <div id="console"></div>
    <script src="index.js"></script>
  </body>
</html>

Pierwszy tag <script> importuje bibliotekę TensorFlow.js, a drugi <script> – wytrenowany model Speech Commands. Tag <div id="console"> będzie używany do wyświetlania danych wyjściowych modelu.

4. Prognozowanie w czasie rzeczywistym

Następnie otwórz lub utwórz plik index.js w edytorze kodu i dodaj do niego ten kod:

let recognizer;

function predictWord() {
 // Array of words that the recognizer is trained to recognize.
 const words = recognizer.wordLabels();
 recognizer.listen(({scores}) => {
   // Turn scores into a list of (score,word) pairs.
   scores = Array.from(scores).map((s, i) => ({score: s, word: words[i]}));
   // Find the most probable word.
   scores.sort((s1, s2) => s2.score - s1.score);
   document.querySelector('#console').textContent = scores[0].word;
 }, {probabilityThreshold: 0.75});
}

async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 predictWord();
}

app();

5. Testowanie prognozy

Upewnij się, że urządzenie ma mikrofon. Funkcja działa również na telefonach komórkowych. Aby uruchomić stronę internetową, otwórz index.html w przeglądarce. Jeśli pracujesz z pliku lokalnego, aby uzyskać dostęp do mikrofonu, musisz uruchomić serwer WWW i użyć usługi http://localhost:port/.

Aby uruchomić prosty serwer WWW na porcie 8000:

python -m SimpleHTTPServer

Pobranie modelu może trochę potrwać, dlatego prosimy o cierpliwość. Po wczytaniu modelu na górze strony powinno pojawić się słowo. Model został wytrenowany do rozpoznawania liczb od 0 do 9 oraz kilku dodatkowych poleceń, takich jak „left”, „right”, „yes”, „no” itp.

Powiedz jedno z tych słów. Czy poprawnie coś mówi? Użyj funkcji probabilityThreshold, która określa, jak często model jest uruchamiany – wartość 0,75 oznacza, że model uruchomi się, gdy będzie ponad 75% pewności, że słyszy dane słowo.

Więcej informacji o modelu Speech Commands i jego interfejsie API znajdziesz w pliku README.md w GitHubie.

6. Zbieranie danych

Aby ułatwić korzystanie z suwaka, użyjmy krótkich dźwięków, a nie całych słów.

Wytrenujesz model do rozpoznawania 3 różnych poleceń: „Left” i „Right”. i „Szum” co spowoduje przesunięcie suwaka w lewo lub w prawo. Rozpoznawanie szumu (nie trzeba nic robić) ma kluczowe znaczenie w przypadku wykrywania mowy, ponieważ chcemy, aby suwak reagował tylko wtedy, gdy generujemy właściwy dźwięk, a nie wtedy, gdy mówimy i się poruszamy.

  1. Najpierw musimy zebrać dane. Dodaj prosty interfejs użytkownika do aplikacji, dodając ten element w tagu <body> przed <div id="console">:
<button id="left" onmousedown="collect(0)" onmouseup="collect(null)">Left</button>
<button id="right" onmousedown="collect(1)" onmouseup="collect(null)">Right</button>
<button id="noise" onmousedown="collect(2)" onmouseup="collect(null)">Noise</button>
  1. Dodaj do index.js:
// One frame is ~23ms of audio.
const NUM_FRAMES = 3;
let examples = [];

function collect(label) {
 if (recognizer.isListening()) {
   return recognizer.stopListening();
 }
 if (label == null) {
   return;
 }
 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   examples.push({vals, label});
   document.querySelector('#console').textContent =
       `${examples.length} examples collected`;
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

function normalize(x) {
 const mean = -100;
 const std = 10;
 return x.map(x => (x - mean) / std);
}
  1. Usuń predictWord() z app():
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // predictWord() no longer called.
}

Jak to działa

Początkowo ten kod może być przytłaczający, dlatego warto go omówić.

Do interfejsu dodaliśmy 3 przyciski: „Lewo”, „W prawo” i „Szum”, odpowiadające 3 poleceniom, które nasz model ma rozpoznawać. Naciśnięcie tych przycisków wywołuje naszą nowo dodaną funkcję collect(), która tworzy przykłady treningowe dla naszego modelu.

Funkcja collect() wiąże element label z danymi wyjściowymi recognizer.listen(). Ponieważ includeSpectrogram ma wartość prawda, recognizer.listen() podaje nieprzetworzony spektrogram (dane o częstotliwości) dla 1 sekundy dźwięku podzielonego na 43 klatki, co daje ok. 23 ms dźwięku:

recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
...
}, {includeSpectrogram: true});

Ponieważ do sterowania suwakiem chcemy używać krótkich dźwięków zamiast słów, bierzemy pod uwagę tylko ostatnie 3 klatki (ok. 70 ms):

let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));

Aby uniknąć problemów liczbowych, normalizujemy dane, aby ich średnia wynosiła 0, a odchylenie standardowe 1. W tym przypadku wartości spektrogramu są zwykle dużymi liczbami ujemnymi w zakresie -100 i odchyleniem wynoszącym 10:

const mean = -100;
const std = 10;
return x.map(x => (x - mean) / std);

Każdy przykład treningowy będzie miał 2 pola:

  • label**** 0, 1 i 2 – w przypadku opcji „Left” i „Right”. i „Szum” .
  • vals****: 696 cyfr z informacjami o częstotliwości (spektrogram).

a wszystkie dane przechowujemy w zmiennej examples:

examples.push({vals, label});

7. Testowe zbieranie danych

Otwórz plik index.html w przeglądarce. Zobaczysz 3 przyciski odpowiadające 3 poleceniom. Jeśli pracujesz z pliku lokalnego, to aby uzyskać dostęp do mikrofonu, musisz uruchomić serwer internetowy i użyć narzędzia http://localhost:port/.

Aby uruchomić prosty serwer WWW na porcie 8000:

python -m SimpleHTTPServer

Aby zebrać przykłady dla każdego polecenia, powtarzaj konsekwentnie (lub stale) dźwięk, naciskając i przytrzymując każdy przycisk przez 3–4 sekundy. Dla każdej etykiety powinno być zbierz ok. 150 przykładów. Możemy na przykład strzelać palcami, aby pokazać „w lewo”, gwizdować na „prawo” i na zmianę przełączać się między trybem ciszy i wypowiedzią „Szum”.

W miarę zbierania kolejnych przykładów licznik widoczny na stronie powinien rosnąć. Dane możesz też sprawdzać przez wywołanie pola Console.log() w zmiennej examples w konsoli. Na tym etapie celem jest przetestowanie procesu gromadzenia danych. Później możesz ponownie zebrać dane, gdy będziesz testować całą aplikację.

8. Trenuj model

  1. Dodaj „Pociąg” za przyciskiem „Szum” w sekcji index.html:.
<br/><br/>
<button id="train" onclick="train()">Train</button>
  1. Dodaj do istniejącego kodu w pliku index.js:
const INPUT_SHAPE = [NUM_FRAMES, 232, 1];
let model;

async function train() {
 toggleButtons(false);
 const ys = tf.oneHot(examples.map(e => e.label), 3);
 const xsShape = [examples.length, ...INPUT_SHAPE];
 const xs = tf.tensor(flatten(examples.map(e => e.vals)), xsShape);

 await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });
 tf.dispose([xs, ys]);
 toggleButtons(true);
}

function buildModel() {
 model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));
 const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });
}

function toggleButtons(enable) {
 document.querySelectorAll('button').forEach(b => b.disabled = !enable);
}

function flatten(tensors) {
 const size = tensors[0].length;
 const result = new Float32Array(tensors.length * size);
 tensors.forEach((arr, i) => result.set(arr, i * size));
 return result;
}
  1. Wywołaj funkcję buildModel() po załadowaniu aplikacji:
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // Add this line.
 buildModel();
}

Gdy odświeżysz aplikację, zobaczysz nowy „Pociąg”. Przycisk Możesz przetestować trenowanie, ponownie zbierając dane i klikając „Trenuj”. Możesz też poczekać do kroku 10, aby przetestować trenowanie wraz z prognozą.

Jak działa

Ogólnie rzecz biorąc, wykonujemy 2 czynności: buildModel() definiuje architekturę modelu, a train() trenuje model na podstawie zebranych danych.

Architektura modelu

Model ma 4 warstwy: splotową warstwę, która przetwarza dane dźwiękowe (reprezentowana jako spektrogram), maksymalną warstwę puli, warstwę spłaszczoną i gęstą warstwę, która odpowiada 3 czynnościom:

model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));

Kształt danych wejściowych modelu to [NUM_FRAMES, 232, 1], gdzie każda klatka to 23 ms ścieżki dźwiękowej z 232 cyframi, które odpowiadają różnym częstotliwościom (wybrano 232, ponieważ jest to liczba grup częstotliwości wymaganych do przechwycenia ludzkiego głosu). W tym ćwiczeniu w programowaniu używamy próbek o długości 3 klatek (około 70 ms), ponieważ w celu sterowania suwakiem wydajemy dźwięki, a nie wypowiadamy całe słowa.

Kompilujemy nasz model, aby przygotować go do trenowania:

const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });

Korzystamy z optymalizatora Adama, który jest często używany w technologii deep learning, i categoricalCrossEntropy do straty – standardowej funkcji utraty wykorzystywanej do klasyfikacji. Krótko mówiąc, mierzy, jak bardzo prawdopodobne jest, że prawdopodobieństwo (jedno prawdopodobieństwo na klasę) będzie większe od prawdopodobieństwa 100% w klasie rzeczywistej i 0% w przypadku wszystkich pozostałych klas. Udostępniamy też accuracy jako wskaźnik do monitorowania, który da nam procent przykładów, które model jest poprawny po każdej epoce trenowania.

Szkolenia

Trenowanie trwa 10 razy (epoki) na danych przy użyciu wsadu 16 (przetwarzanie 16 przykładów naraz) i pokazuje aktualną dokładność w interfejsie:

await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });

9. Aktualizowanie suwaka w czasie rzeczywistym

Możemy już wytrenować model, dodajmy więc kod, aby formułować prognozy w czasie rzeczywistym, i przesuń suwak. Dodaj ten fragment bezpośrednio za tekstem „Pociąg”. w pliku index.html:

<br/><br/>
<button id="listen" onclick="listen()">Listen</button>
<input type="range" id="output" min="0" max="10" step="0.1">

A w pliku index.js te elementy:

async function moveSlider(labelTensor) {
 const label = (await labelTensor.data())[0];
 document.getElementById('console').textContent = label;
 if (label == 2) {
   return;
 }
 let delta = 0.1;
 const prevValue = +document.getElementById('output').value;
 document.getElementById('output').value =
     prevValue + (label === 0 ? -delta : delta);
}

function listen() {
 if (recognizer.isListening()) {
   recognizer.stopListening();
   toggleButtons(true);
   document.getElementById('listen').textContent = 'Listen';
   return;
 }
 toggleButtons(false);
 document.getElementById('listen').textContent = 'Stop';
 document.getElementById('listen').disabled = false;

 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   const vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   const input = tf.tensor(vals, [1, ...INPUT_SHAPE]);
   const probs = model.predict(input);
   const predLabel = probs.argMax(1);
   await moveSlider(predLabel);
   tf.dispose([input, probs, predLabel]);
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

Jak działa

Prognozowanie w czasie rzeczywistym

listen() słucha mikrofonu i przewiduje w czasie rzeczywistym. Kod jest bardzo podobny do metody collect(), która normalizuje nieprzetworzony spektrogram i usuwa wszystkie klatki oprócz ostatnich NUM_FRAMES. Jedyną różnicą jest to, że w celu uzyskania prognozy wywołujemy też wytrenowany model:

const probs = model.predict(input);
const predLabel = probs.argMax(1);
await moveSlider(predLabel);

Dane wyjściowe funkcji model.predict(input) to tensor o kształcie [1, numClasses], który reprezentuje rozkład prawdopodobieństwa na liczbę klas. Mówiąc prościej, jest to zestaw wartości ufności każdej z możliwych klas wyjściowych, który sumuje się do 1. Tensor ma wymiar zewnętrzny wynoszący 1, ponieważ właśnie on odpowiada rozmiarowi wsadu (pojedynczy przykład).

Aby przekonwertować rozkład prawdopodobieństwa na pojedynczą liczbę całkowitą reprezentującą najbardziej prawdopodobną klasę, wywołujemy funkcję probs.argMax(1), która zwraca indeks klas o najwyższym prawdopodobieństwie. Podajemy wynik „1” jako parametr osi, ponieważ chcemy obliczyć argMax dla ostatniego wymiaru, numClasses.

Aktualizowanie suwaka

moveSlider() zmniejsza wartość suwaka, jeśli etykieta ma wartość 0 („Po lewej”), zwiększa ją, jeśli etykieta ma wartość 1 („Po prawej”), i ignoruje, jeśli etykieta ma wartość 2 („Szum”).

Rozpraszanie tensorów

Aby wyczyścić pamięć GPU, musimy ręcznie wywołać funkcję tf.dispose() na wyjściowych procesorach Tensor. Alternatywą dla ręcznego wstawiania funkcji tf.dispose() jest pakowanie wywołań funkcji w obiekcie tf.tidy(), ale nie można go używać z funkcjami asynchronicznymi.

   tf.dispose([input, probs, predLabel]);

10. Testowanie ostatecznej wersji aplikacji

Otwórz w przeglądarce stronę index.html i zbieraj dane tak jak w poprzedniej sekcji, używając 3 przycisków odpowiadających 3 poleceniom. Pamiętaj, by nacisnąć i przytrzymać każdy przycisk przez 3–4 sekundy podczas zbierania danych.

Po zebraniu przykładów kliknij przycisk „Trenuj”. Spowoduje to rozpoczęcie trenowania modelu, a jego dokładność powinna przekroczyć 90%. Jeśli nie osiągniesz dobrej skuteczności modelu, spróbuj zebrać więcej danych.

Po zakończeniu trenowania naciśnij przycisk "Listen", aby wygenerować przepowiedź do mikrofonu i sterować suwakiem.

Więcej samouczków znajdziesz na http://js.tensorflow.org/.