TensorFlow.js – זיהוי אודיו באמצעות למידת העברה

1. מבוא

ב-Codelab הזה תבנה רשת לזיהוי אודיו ותשתמש בה כדי לשלוט בפס הזזה בדפדפן באמצעות השמעת צלילים. אתם תשתמשו ב-TensorFlow.js, ספרייה חזקה וגמישה ללמידת מכונה ב-JavaScript.

קודם כול, טוענים ומפעילים מודל שעבר אימון מראש שיכול לזהות 20 פקודות קוליות. לאחר מכן תשתמשו במיקרופון, תצרו ותאמנים רשת נוירונים פשוטה שמזהה את הצלילים שלכם וגורמת לפס ההזזה לעבור שמאלה או ימינה.

ה-Codelab הזה לא יסקור את התיאוריה שמאחורי מודלים של זיהוי אודיו. אם אתם סקרנים לגבי הנושא הזה, כדאי לעיין במדריך הזה.

יצרנו גם מילון מונחים של מונחים של למידת מכונה שמופיעים ב-Codelab הזה.

מה תלמדו

  • איך לטעון מודל לזיהוי פקודות דיבור שעברו אימון מקדים
  • איך לבצע חיזויים בזמן אמת באמצעות המיקרופון
  • איך לאמן מודל לזיהוי אודיו מותאם אישית ולהשתמש בו עם המיקרופון של הדפדפן

אז בואו נתחיל.

2. דרישות

כדי להשלים את ה-Codelab הזה צריך:

  1. גרסה עדכנית של Chrome או דפדפן מתקדם אחר.
  2. כלי לעריכת טקסט, שפועל באופן מקומי במחשב שלכם או באינטרנט באמצעות משהו כמו Codepen או Glitch.
  3. ידע ב-HTML, ב-CSS, ב-JavaScript וב-כלי הפיתוח ל-Chrome (או בכלי הפיתוח המועדפים עליכם בדפדפנים).
  4. הבנה של מושגים ברמה גבוהה של רשתות נוירונים. אם אתם צריכים מבוא או רענון, תוכלו לצפות בסרטון הזה ב-3blue1brown או בסרטון הזה בנושא למידה עמוקה (Deep Learning) ב-JavaScript של אשי קרישנן.

3. טעינת TensorFlow.js ומודל האודיו

צריך לפתוח את index.html בכלי עריכה ולהוסיף את התוכן הבא:

<html>
  <head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/speech-commands"></script>
  </head>
  <body>
    <div id="console"></div>
    <script src="index.js"></script>
  </body>
</html>

תג <script> הראשון מיובא את ספריית TensorFlow.js, וה<script> השני מייבאת את מודל פקודות הדיבור שעבר אימון מראש. התג <div id="console"> ישמש להצגת הפלט של המודל.

4. חיזוי בזמן אמת

לאחר מכן, פותחים/יוצרים את הקובץ index.js בעורך קוד וכוללים את הקוד הבא:

let recognizer;

function predictWord() {
 // Array of words that the recognizer is trained to recognize.
 const words = recognizer.wordLabels();
 recognizer.listen(({scores}) => {
   // Turn scores into a list of (score,word) pairs.
   scores = Array.from(scores).map((s, i) => ({score: s, word: words[i]}));
   // Find the most probable word.
   scores.sort((s1, s2) => s2.score - s1.score);
   document.querySelector('#console').textContent = scores[0].word;
 }, {probabilityThreshold: 0.75});
}

async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 predictWord();
}

app();

5. בדיקת החיזוי

מוודאים שלמכשיר יש מיקרופון. חשוב לציין שהתכונה הזו תפעל גם בטלפון נייד! כדי להפעיל את דף האינטרנט, פותחים את index.html בדפדפן. אם בחרת לעבוד מקובץ מקומי, צריך להפעיל שרת אינטרנט כדי לגשת למיקרופון ולהשתמש ב-http://localhost:port/.

כדי להפעיל שרת אינטרנט פשוט ביציאה 8000:

python -m SimpleHTTPServer

הורדת המודל עשויה להימשך זמן מה, אז יש להתאזר בסבלנות. ברגע שהמודל נטען, אתה אמור לראות מילה בחלק העליון של הדף. המודל אומן לזהות את המספרים מ-0 עד 9 וכמה פקודות נוספות כמו "left", "right", "yes", "no" וכו'.

אומרים אחת מהמילים האלה. האם הוא נאמר בצורה נכונה? משחקים עם probabilityThreshold, שקובעים את תדירות ההפעלה של המודל. המשמעות של 0.75 היא שהמודל יופעל כשיש סבירות של יותר מ-75% שהוא שומע מילה מסוימת.

מידע נוסף על המודל 'פקודות קוליות' ועל ה-API שלו זמין ב-README.md ב-GitHub.

6. איסוף נתונים

כדי שיהיה כיף, נשתמש בצלילים קצרים במקום במילים שלמות כדי לשלוט בפס ההזזה.

אתם עומדים לאמן מודל לזהות 3 פקודות שונות: 'Left', 'Right' ו'רעש' שיגרמו לפס ההזזה לנוע שמאלה או ימינה. זיהוי 'רעש' (אין צורך בפעולה נוספת) היא קריטית בזיהוי דיבור, כי אנחנו רוצים שפס ההזזה יגיב רק כאשר אנחנו מפיקים את הצליל הנכון, ולא כאשר אנו מדברים או זזים.

  1. קודם כל, אנחנו צריכים לאסוף נתונים. כדי להוסיף ממשק משתמש פשוט לאפליקציה, צריך להוסיף את הקוד הבא בתוך התג <body> לפני <div id="console">:
<button id="left" onmousedown="collect(0)" onmouseup="collect(null)">Left</button>
<button id="right" onmousedown="collect(1)" onmouseup="collect(null)">Right</button>
<button id="noise" onmousedown="collect(2)" onmouseup="collect(null)">Noise</button>
  1. הוסף את זה ל-index.js:
// One frame is ~23ms of audio.
const NUM_FRAMES = 3;
let examples = [];

function collect(label) {
 if (recognizer.isListening()) {
   return recognizer.stopListening();
 }
 if (label == null) {
   return;
 }
 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   examples.push({vals, label});
   document.querySelector('#console').textContent =
       `${examples.length} examples collected`;
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

function normalize(x) {
 const mean = -100;
 const std = 10;
 return x.map(x => (x - mean) / std);
}
  1. הסרת predictWord() מהחשבון app():
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // predictWord() no longer called.
}

להסבר מפורט

הקוד הזה עלול להיות מפחיד בהתחלה, אז נסביר אותו.

הוספנו לממשק המשתמש שלנו שלושה לחצנים שמסומנים בתווית 'שמאל', 'ימינה' ו'רעש', בהתאם לשלוש הפקודות שאנחנו רוצים שהמודל שלנו יזהו. לחיצה על הלחצנים האלה מפעילה את הפונקציה collect() החדשה שנוספה, שיוצרת דוגמאות לאימון למודל שלנו.

collect() משייך label לפלט של recognizer.listen(). מכיוון שהפרמטר includeSpectrogram נכון, recognizer.listen() מספק את הספקטרוגרמה הגולמית (נתוני תדירות) לשנייה אחת של אודיו, בחלוקה ל-43 פריימים, כך שכל פריים שווה כ-23 אלפיות שנייה של אודיו:

recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
...
}, {includeSpectrogram: true});

מאחר שאנחנו רוצים להשתמש בצלילים קצרים במקום במילים כדי לשלוט בפס ההזזה, אנחנו מביאים בחשבון רק את 3 הפריימים האחרונים (כ-70 אלפיות שנייה):

let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));

וכדי להימנע מבעיות מספריות, אנחנו מנרמלים את הנתונים כך שממוצע הנתונים יהיה 0 וסטיית תקן של 1. במקרה כזה, ערכי הספקטרוגרמה הם בדרך כלל מספרים שליליים גדולים בערך של -100 ובסטייה של 10:

const mean = -100;
const std = 10;
return x.map(x => (x - mean) / std);

לבסוף, כל דוגמה לאימון כוללת 2 שדות:

  • label****: 0, 1 ו-2 עבור 'Left', 'Right' וגם 'רעש' בהתאמה.
  • vals****: 696 מספרים שמחזיקים את פרטי התדירות (ספקטרוגרמה)

ואנחנו מאחסנים את כל הנתונים במשתנה examples:

examples.push({vals, label});

7. איסוף נתוני בדיקה

פותחים את index.html בדפדפן, אמורים להופיע 3 לחצנים שתואמים ל-3 הפקודות. אם את/ה עובד/ת מקובץ מקומי, כדי לגשת למיקרופון צריך להפעיל שרת אינטרנט ולהשתמש ב-http://localhost:port/.

כדי להפעיל שרת אינטרנט פשוט ביציאה 8000:

python -m SimpleHTTPServer

כדי לאסוף דוגמאות לכל פקודה, יש להשמיע צליל עקבי באופן חוזר (או מתמשך) תוך לחיצה ארוכה על כל לחצן במשך 3-4 שניות. על כל תווית לאסוף כ-150 דוגמאות. לדוגמה, אנחנו יכולים להקיש באצבעות על 'שמאל', לשרוק ל'ימין' ולהחליף בין שתיקה ולדבר לאמירת 'רעש'.

ככל שתאספו דוגמאות נוספות, המונה שמוצג בדף אמור לעלות. אפשר גם לבדוק את הנתונים באמצעות קריאה ל-console.log() במשתנה examples במסוף. בשלב הזה המטרה היא לבדוק את תהליך איסוף הנתונים. בהמשך, תאספו מחדש את הנתונים כשתבדקו את כל האפליקציה.

8. אימון מודל

  1. מוסיפים רכבת. הלחצן מיד אחרי "רעש" הלחצן בגוף הטקסט ב-index.html:
<br/><br/>
<button id="train" onclick="train()">Train</button>
  1. מוסיפים את הקוד הבא לקוד הקיים ב-index.js:
const INPUT_SHAPE = [NUM_FRAMES, 232, 1];
let model;

async function train() {
 toggleButtons(false);
 const ys = tf.oneHot(examples.map(e => e.label), 3);
 const xsShape = [examples.length, ...INPUT_SHAPE];
 const xs = tf.tensor(flatten(examples.map(e => e.vals)), xsShape);

 await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });
 tf.dispose([xs, ys]);
 toggleButtons(true);
}

function buildModel() {
 model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));
 const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });
}

function toggleButtons(enable) {
 document.querySelectorAll('button').forEach(b => b.disabled = !enable);
}

function flatten(tensors) {
 const size = tensors[0].length;
 const result = new Float32Array(tensors.length * size);
 tensors.forEach((arr, i) => result.set(arr, i * size));
 return result;
}
  1. קוראים לפונקציה buildModel() כשהאפליקציה נטענת:
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // Add this line.
 buildModel();
}

בשלב הזה אם תרעננו את האפליקציה, יופיע הכיתוב רכבת חדש. לחצן. אפשר להתנסות באימון על ידי איסוף מחדש של הנתונים ולחיצה על 'אימון', או להמתין עד שלב 10 כדי לבדוק את האימון יחד עם החיזוי.

הסבר על התהליך

ברמה הכללית אנחנו עושים שני דברים: buildModel() מגדיר את ארכיטקטורת המודל ו-train() מאמן את המודל באמצעות הנתונים שנאספו.

ארכיטקטורת מודלים

במודל יש 4 שכבות: שכבה קונבולוציה שמעבדת את נתוני האודיו (המיוצגת כספקטרוגרמה), שכבת מאגר מקסימלית, שכבה שטוחה ושכבה צפופה שממפה אל 3 הפעולות:

model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));

צורת הקלט של המודל היא [NUM_FRAMES, 232, 1], כאשר כל פריים הוא באורך 23 אלפיות שנייה של אודיו שמכיל 232 מספרים שתואמים לתדרים שונים (232 נבחרו כי זו כמות קטגוריות התדירות שנדרשות כדי לתעד את הקול האנושי). ב-Codelab הזה אנחנו משתמשים בדוגמאות שהן באורך של 3 פריימים (דגימות של כ-70 אלפיות השנייה), כי אנחנו מפיקים צלילים במקום לומר מילים שלמות כדי לשלוט בפס ההזזה.

אנחנו יוצרים את המודל שלנו כדי להכין אותו לאימון:

const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });

אנחנו משתמשים בכלי האופטימיזציה של אדם, כלי אופטימיזציה נפוץ בלמידה עמוקה (Deep Learning), וב-categoricalCrossEntropy בפונקציית אובדן, שמשמשת לסיווג. בקיצור, הוא מודד את המרחק בין ההסתברויות החזויות (הסתברות אחת לכל כיתה) מהסתברות של 100% במחלקה האמיתית, והסתברות של 0% לכל שאר המחלקות. אנחנו גם מספקים את accuracy כמדד למעקב, שניתן לראות בו את אחוז הדוגמאות שבהן המודל מקבל נכון אחרי כל תקופת אימון.

הדרכה

האימון עובר 10 פעמים (תקופות) על הנתונים, תוך שימוש בכמות גדולה של 16 פעמים (עיבוד 16 דוגמאות בכל פעם) ומראה את הדיוק הנוכחי בממשק המשתמש:

await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });

9. עדכון פס ההזזה בזמן אמת

עכשיו, אחרי שאנחנו יכולים לאמן את המודל שלנו, נוסיף קוד שיבצע תחזיות בזמן אמת ונזיז את פס ההזזה. צריך להוסיף את הקטע הזה מיד אחרי ה'רכבת'. הלחצן ב-index.html:

<br/><br/>
<button id="listen" onclick="listen()">Listen</button>
<input type="range" id="output" min="0" max="10" step="0.1">

וגם הקוד הבא ב-index.js:

async function moveSlider(labelTensor) {
 const label = (await labelTensor.data())[0];
 document.getElementById('console').textContent = label;
 if (label == 2) {
   return;
 }
 let delta = 0.1;
 const prevValue = +document.getElementById('output').value;
 document.getElementById('output').value =
     prevValue + (label === 0 ? -delta : delta);
}

function listen() {
 if (recognizer.isListening()) {
   recognizer.stopListening();
   toggleButtons(true);
   document.getElementById('listen').textContent = 'Listen';
   return;
 }
 toggleButtons(false);
 document.getElementById('listen').textContent = 'Stop';
 document.getElementById('listen').disabled = false;

 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   const vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   const input = tf.tensor(vals, [1, ...INPUT_SHAPE]);
   const probs = model.predict(input);
   const predLabel = probs.argMax(1);
   await moveSlider(predLabel);
   tf.dispose([input, probs, predLabel]);
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

הסבר על התהליך

חיזוי בזמן אמת

listen() מאזין למיקרופון ומפיק חיזויים בזמן אמת. הקוד דומה מאוד לשיטה collect(), שמנרמלת את הספקטרוגרמה הגולמית ומבטלת את כל הפריימים מלבד NUM_FRAMES האחרונים. ההבדל היחיד הוא שאנחנו קוראים גם למודל המאומן כדי לקבל חיזוי:

const probs = model.predict(input);
const predLabel = probs.argMax(1);
await moveSlider(predLabel);

הפלט של model.predict(input) הוא Tensor בצורה [1, numClasses] שמייצגת התפלגות של הסתברות לפי מספר המחלקות. במילים פשוטות יותר, זו רק קבוצת סמך של כל אחד מסיווגי הפלט האפשריים, שמסתכמת ב-1. ל-Tensor יש המימד החיצוני של 1 כי זה גודל הקבוצה (דוגמה אחת).

כדי להמיר את התפלגות ההסתברות למספר שלם יחיד שמייצג את המחלקה הסבירה ביותר, אנחנו מפעילים את הפונקציה probs.argMax(1), שמחזירה את אינדקס המחלקה עם ההסתברות הגבוהה ביותר. אנחנו מעבירים את הערך '1' בתור פרמטר הציר כי אנחנו רוצים לחשב את argMax לפי המאפיין האחרון, numClasses.

עדכון פס ההזזה

moveSlider() מקטין את ערך פס ההזזה אם התווית היא 0 ("שמאל") , מגדילה אותו אם התווית היא 1 ("ימין") ומתעלמת אם התווית היא 2 ("רעש").

השמטת טינזורים

כדי לפנות מקום בזיכרון ה-GPU חשוב לקרוא ידנית ל-tf.dispose() בפלט Tensors. החלופה ל-tf.dispose() הידנית היא הפעלות של פונקציות בתוך tf.tidy(), אבל לא ניתן להשתמש בה עם פונקציות אסינכרוניות.

   tf.dispose([input, probs, predLabel]);

10. בדיקת האפליקציה הסופית

פותחים את index.html בדפדפן ואוספים נתונים כפי שעשיתם בקטע הקודם, באמצעות 3 הלחצנים שתואמים ל-3 הפקודות. חשוב לזכור ללחוץ לחיצה ארוכה על כל לחצן למשך 3-4 שניות בזמן איסוף הנתונים.

אחרי שאוספים דוגמאות, לוחצים על הלחצן רכבת. כך תתחיל אימון המודל ותוכלו לראות שרמת הדיוק של המודל גבוהה מ-90%. אם הביצועים של המודל לא טובים, נסו לאסוף עוד נתונים.

בסיום האימון, לוחצים על הלחצן Listen (האזנה) כדי לקבל חיזויים מהמיקרופון ולשלוט בפס ההזזה.

תוכלו למצוא מדריכים נוספים בכתובת http://js.tensorflow.org/.