אימון TensorFlow.js ב-Codelab של Node.js

1. מבוא

בשיעור הזה ב-Codelab, נלמד איך לפתח שרת אינטרנט מסוג Node.js כדי לאמן ולסווג את סוגי הגשות הבייסבול בצד השרת באמצעות TensorFlow.js, ספרייה עוצמתית וגמישה ללמידת מכונה בשביל JavaScript. אתם תפתחו אפליקציית אינטרנט כדי לאמן מודל לחזות את סוג גובה הצליל מנתוני החיישן של גובה הצליל ולהפעיל חיזוי מלקוח אינטרנט. גרסה תקינה של ה-Codelab הזה נמצאת במאגר tfjs-examples ב-GitHub.

מה תלמדו

  • איך להתקין ולהגדיר את חבילת ה-npm של tensorflow.js לשימוש עם Node.js.
  • איך לגשת לנתוני האימון והבדיקה בסביבת Node.js.
  • איך לאמן מודל באמצעות TensorFlow.js בשרת Node.js.
  • איך לפרוס את המודל שאומן לצורך הסקת מסקנות באפליקציית לקוח/שרת.

אז בואו נתחיל!

2. דרישות

כדי להשלים את הקורס Codelab, צריך:

  1. גרסה עדכנית של Chrome או דפדפן מתקדם אחר.
  2. כלי לעריכת טקסט ומסוף פקודות שפועלים באופן מקומי במחשב שלכם.
  3. ידע ב-HTML, ב-CSS, ב-JavaScript וב-כלי הפיתוח ל-Chrome (או בכלי הפיתוח המועדפים עליכם בדפדפנים).
  4. הבנה של מושגים ברמה גבוהה של רשתות נוירונים. אם אתם צריכים מבוא או רענון, תוכלו לצפות בסרטון הזה ב-3blue1brown או בסרטון הזה בנושא למידה עמוקה (Deep Learning) ב-JavaScript של אשי קרישנן.

3. הגדרה של אפליקציית Node.js

מתקינים את Node.js ו-npm. מידע על פלטפורמות ויחסי תלות נתמכים זמין במדריך ההתקנה של צומת tfjs-node.

יוצרים ספרייה בשם ./baseball עבור אפליקציית Node.js. מעתיקים את package.json ו-webpack.config.js המקושרים לספרייה הזו כדי להגדיר את יחסי התלות של חבילת ה-npm (כולל חבילת ה-npm @tensorflow/tfjs-node). לאחר מכן מריצים התקנת NPM כדי להתקין את יחסי התלות.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

עכשיו אתם מוכנים לכתוב קוד ולאמן מודל!

4. הגדרת נתוני האימון והבדיקה

נתוני האימון והבדיקה ישמשו אותך כקובצי CSV מהקישורים הבאים. מורידים את הנתונים בקבצים הבאים ובודקים אותם:

pitch_type_training_data.csv

pitch_type_test_data.csv

בואו נבחן כמה נתוני אימון לדוגמה:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

יש שמונה תכונות קלט שמתארות את הנתונים של חיישן הגובה-רוחב:

  • מהירות הכדור (vx0, vy0, vz0)
  • האצת כדור (ax, ay, az)
  • מהירות ההתחלה של גובה הצליל
  • אם מבצע ההגשות הוא יד שמאל או לא

ותווית פלט אחת:

  • Pitch_code שמציין אחד משבעה סוגים של הצעות לשירים: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

המטרה היא לפתח מודל שיכול לחזות את סוג גובה הצליל על סמך נתונים מחיישן של גובה הצליל.

לפני יצירת המודל, צריך להכין את נתוני האימון והבדיקה. יוצרים את הקובץ pull_type.js ב-בייסבול/ dir ומעתיקים אליו את הקוד הבא. הקוד הזה טוען נתוני אימון ובדיקה באמצעות ממשק API של tf.data.csv. הוא גם מנרמל את הנתונים (שמומלץ תמיד) באמצעות סולם נירמול מינימלי-מקסימלי.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. יצירת מודל לסיווג סוגי ההצעות לשיר

עכשיו אתם מוכנים לבנות את המודל. משתמשים ב-API של tf.layers כדי לחבר את ערכי הקלט (בצורה של [8] ערכים של חיישן הגובה) ל-3 שכבות מוסתרות שמחוברות באופן מלא, שמורכבות מיחידות הפעלת ReLU, ואחריה שכבת פלט softmax אחת שמורכבת מ-7 יחידות, שכל אחת מייצגת את אחד מסוגי גובה הצליל.

אפשר לאמן את המודל באמצעות הכלי לאופטימיזציה של adam והפונקציה sparseCategoricalCrossentropy Losentropy. מידע נוסף על האפשרויות האלה זמין במדריך למודלים של אימון.

מוסיפים את הקוד הבא לסוף של pull_type.js:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

מפעילים את האימון מקוד השרת הראשי שתכתוב מאוחר יותר.

כדי להשלים את המודול get_type.js, נכתוב פונקציה שמעריכה את קבוצת הנתונים לאימות ובודקת את מערך הנתונים, חיזוי של סוג מצגת המכירות לדוגמה אחת וחישוב מדדי הדיוק. יש להוסיף את הקוד הבא בסוף השדה Pitch_type.js:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. מודל אימון בשרת

לכתוב את קוד השרת כדי לבצע אימון והערכה של מודלים בקובץ חדש שנקרא server.js. קודם כול, יוצרים שרת HTTP ופותחים חיבור שקע דו-כיווני באמצעות ה-API socket.io. לאחר מכן תצטרכו לבצע אימון מודלים באמצעות ה-API model.fitDataset, ולהעריך את מידת הדיוק של המודל באמצעות השיטה pitch_type.evaluate() שכתבתם קודם. אימון והערכה של 10 חזרות, והדפסת מדדים למסוף.

מעתיקים את הקוד הבא ל-server.js:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

בשלב הזה אתם מוכנים להפעיל את השרת ולבדוק אותו! אתם אמורים לראות משהו כזה, כאשר אימון השרת אימון אחד בכל איטרציה (אפשר גם להשתמש ב-API model.fitDataset כדי לאמן מספר תקופות של זמן מערכת באמצעות קריאה אחת). אם תיתקלו בשגיאות בשלב זה, כדאי לבדוק את ההתקנה של הצומת וה-npm.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

כדי להפסיק את פעולת השרת, מקלידים Ctrl-C. נפעיל אותה שוב בשלב הבא.

7. יצירת דף לקוח וקוד לתצוגה

עכשיו השרת מוכן, השלב הבא הוא לכתוב את קוד הלקוח שיפעל בדפדפן. יוצרים דף פשוט כדי להפעיל חיזוי של מודל בשרת ולהציג את התוצאה. נעשה שימוש ב-socket.io לתקשורת עם לקוח/שרת.

קודם כל, יוצרים index.html בתיקיית הבייסבול/:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

לאחר מכן יוצרים קובץ client.js חדש בתיקיית הבייסבול/ בתיקייה עם הקוד הבא:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

הלקוח מטפל בהודעה בשקע trainingComplete כדי להציג לחצן חיזוי. כשמשתמש לוחץ על הלחצן הזה, הלקוח שולח הודעת socket עם נתוני חיישנים לדוגמה. כשמקבלים הודעה מסוג predictResult, החיזוי מוצג בדף.

8. הפעלת האפליקציה

מריצים גם את השרת וגם את הלקוח כדי לראות את כל האפליקציה בפעולה:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

פותחים את דף הלקוח בדפדפן ( http://localhost:8080). כאשר אימון המודל מסתיים, לוחצים על הלחצן חיזוי דגימה. תוצאת חיזוי אמורה להופיע בדפדפן. אתם יכולים לשנות את נתוני החיישנים לדוגמה בעזרת כמה דוגמאות מקובץ ה-CSV לבדיקה, ולראות עד כמה המודל חזוי.

9. מה למדתם

ב-Codelab הזה יישמת אפליקציית אינטרנט פשוטה של למידת מכונה באמצעות TensorFlow.js. אימנת מודל מותאם אישית לסיווג סוגים של מגרש בייסבול על סמך נתוני חיישנים. כתבתם קוד של Node.js כדי לבצע אימון בשרת ולהסיק מסקנות מהמודל המאומן באמצעות נתונים שנשלחו מהלקוח.

בכתובת tensorflow.org/js תוכלו למצוא דוגמאות נוספות והדגמות עם קוד, שיעזרו לכם להבין איך להשתמש ב-TensorFlow.js באפליקציות.