TensorFlow.js-Training in Node.js-Codelab

1. Einführung

In diesem Codelab erfahren Sie, wie Sie einen Node.js-Webserver erstellen, um Baseball-Pitch-Typen serverseitig mit TensorFlow.js, einer leistungsstarken und flexiblen Bibliothek für maschinelles Lernen, für JavaScript zu trainieren und zu klassifizieren. Sie erstellen eine Webanwendung, um ein Modell zu trainieren, um die Art der Tonhöhe anhand von Tonhöhensensordaten vorherzusagen und die Vorhersage aus einem Webclient aufzurufen. Eine voll funktionsfähige Version dieses Codelabs befindet sich im GitHub-Repository tfjs-examples.

Lerninhalte

  • Hier erfahren Sie, wie Sie das npm-Paket für tensorflow.js zur Verwendung mit Node.js installieren und einrichten.
  • Wie Sie in der Node.js-Umgebung auf Trainings- und Testdaten zugreifen.
  • Hier erfahren Sie, wie Sie ein Modell mit TensorFlow.js auf einem Node.js-Server trainieren.
  • Trainiertes Modell für Inferenz in einer Client-/Serveranwendung bereitstellen

Legen wir los!

2. Voraussetzungen

Für dieses Codelab benötigen Sie Folgendes:

  1. Eine aktuelle Version von Chrome oder einem anderen aktuellen Browser
  2. Ein Texteditor und ein Befehlsterminal, die lokal auf Ihrem Computer ausgeführt werden
  3. Kenntnisse in HTML, CSS, JavaScript und Chrome-Entwicklertools bzw. den Entwicklertools Ihrer bevorzugten Browser
  4. Ein allgemeines Verständnis von neuronalen Netzwerken. Wenn Sie eine Einführung oder Auffrischung benötigen, können Sie sich dieses Video von 3blue1brown oder dieses Video zu Deep Learning in JavaScript von Ashi Krishnan ansehen.

3. Node.js-Anwendung einrichten

Installieren Sie Node.js und npm. Informationen zu unterstützten Plattformen und Abhängigkeiten finden Sie in der Installationsanleitung für tfjs-node.

Erstellen Sie für die Node.js-App ein Verzeichnis mit dem Namen ./Baseball. Kopieren Sie die verknüpfte Datei package.json und webpack.config.js in dieses Verzeichnis, um die Abhängigkeiten des npm-Pakets zu konfigurieren (einschließlich des npm-Pakets @tensorflow/tfjs-node). Führen Sie dann npm install aus, um die Abhängigkeiten zu installieren.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

Jetzt können Sie Code schreiben und ein Modell trainieren.

4. Trainings- und Testdaten einrichten

Sie verwenden die Trainings- und Testdaten als CSV-Dateien aus den folgenden Links. Laden Sie die Daten in diesen Dateien herunter und untersuchen Sie sie:

pitch_type_training_data.csv

pitch_type_test_data.csv

Sehen wir uns einige Beispieltrainingsdaten an:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

Es gibt acht Eingabemerkmale, die die Tonhöhensensordaten beschreiben:

  • Ballgeschwindigkeit (vx0, vy0, vz0)
  • Ballbeschleunigung (ax, ay, az)
  • Anfangsgeschwindigkeit der Tonhöhe
  • unabhängig davon, ob der Pitcher Linkshänder ist oder nicht

und ein Ausgabelabel:

  • Pitch_Code, der einen von sieben Vorschlagstypen angibt: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

Das Ziel besteht darin, ein Modell zu erstellen, das den Tonhöhentyp anhand von Tonhöhensensordaten vorhersagen kann.

Bevor Sie das Modell erstellen, müssen Sie die Trainings- und Testdaten vorbereiten. Erstellen Sie die Datei „court_type.js“ im „Baseball-/Verzeichnis“ und kopieren Sie den folgenden Code in diese Datei. Dieser Code lädt Trainings- und Testdaten mithilfe der API tf.data.csv. Außerdem normalisiert er die Daten (was immer empfohlen wird) mit einer Min-Max-Normalisierungsskala.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. Modell zum Klassifizieren der Arten von Pitches erstellen

Jetzt können Sie das Modell erstellen. Verwenden Sie die tf.layers API, um die Eingaben (Form der Tonhöhensensorwerte [8]) mit drei verborgenen, vollständig verbundenen Ebenen zu verbinden, die aus ReLU-Aktivierungseinheiten bestehen, gefolgt von einer Softmax-Ausgabeschicht aus sieben Einheiten, die jeweils einen der Ausgabetypen für die Tonhöhe darstellen.

Trainieren Sie das Modell mit dem Adam-Optimierer und der sparseCategoricalCrossentropy-Verlustfunktion. Weitere Informationen zu diesen Optionen finden Sie im Leitfaden zum Trainieren von Modellen.

Fügen Sie am Ende von Pitch_type.js den folgenden Code ein:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

Lösen Sie das Training über den Hauptservercode aus, den Sie später schreiben.

Um das Modul Pitch_type.js abzuschließen, schreiben wir eine Funktion, um das Validierungs- und Test-Dataset auszuwerten, einen Pitch-Typ für eine einzelne Stichprobe vorherzusagen und die Genauigkeitsmesswerte zu berechnen. Hängen Sie diesen Code am Ende von Pitch_type.js an:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. Modell auf dem Server trainieren

Schreiben Sie den Servercode zum Trainieren und Auswerten des Modells in eine neue Datei mit dem Namen „server.js“. Erstellen Sie zuerst einen HTTP-Server und öffnen Sie mithilfe der Socket.io API eine bidirektionale Socket-Verbindung. Führen Sie dann das Modelltraining mit der model.fitDataset API aus und bewerten Sie die Modellgenauigkeit mit der zuvor geschriebenen pitch_type.evaluate()-Methode. Trainieren und für 10 Iterationen bewerten, wobei Messwerte an die Konsole ausgegeben werden

Kopieren Sie den folgenden Code in server.js:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

Jetzt können Sie den Server ausführen und testen. Sie sollten in etwa Folgendes sehen, wobei der Server in jeder Iteration eine Epoche trainiert. Sie können auch die model.fitDataset API verwenden, um mehrere Epochen mit einem Aufruf zu trainieren. Wenn jetzt Fehler auftreten, überprüfen Sie die Knoten- und npm-Installation.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Drücken Sie Strg + C, um den aktiven Server zu beenden. Wir führen ihn im nächsten Schritt noch einmal aus.

7. Clientseite erstellen und Code anzeigen

Da der Server nun bereit ist, muss der Clientcode geschrieben werden. Der Code wird dann im Browser ausgeführt. Erstellen Sie eine einfache Seite, um die Modellvorhersage auf dem Server aufzurufen und das Ergebnis anzuzeigen. Dabei wird socket.io für die Client/Server-Kommunikation verwendet.

Erstellen Sie zunächst im Ordner „Baseball/“ die Datei „index.html“:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

Erstellen Sie dann mit dem folgenden Code eine neue Datei „client.js“ im Ordner „Baseball/“:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

Der Client verarbeitet die Socket-Nachricht trainingComplete, um eine Vorhersageschaltfläche anzuzeigen. Wenn auf diese Schaltfläche geklickt wird, sendet der Client eine Socket-Nachricht mit Beispielsensordaten. Beim Empfang einer predictResult-Nachricht wird die Vervollständigung auf der Seite angezeigt.

8. Anwendung ausführen

Führen Sie sowohl den Server als auch den Client aus, um die vollständige Anwendung in Aktion zu sehen:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

Öffnen Sie die Clientseite in Ihrem Browser ( http://localhost:8080). Klicken Sie nach Abschluss des Modelltrainings auf die Schaltfläche Vorhersagebeispiel. Im Browser sollte eine Vervollständigung angezeigt werden. Sie können die Beispielsensordaten mit einigen Beispielen aus der CSV-Testdatei ändern und sehen, wie genau das Modell vorhersagt.

9. Das haben Sie gelernt

In diesem Codelab haben Sie mithilfe von TensorFlow.js eine einfache Webanwendung für maschinelles Lernen implementiert. Sie haben ein benutzerdefiniertes Modell zur Klassifizierung von Baseballfeldtypen anhand von Sensordaten trainiert. Sie haben Node.js-Code geschrieben, um das Training auf dem Server auszuführen und mithilfe der vom Client gesendeten Daten die Inferenz für das trainierte Modell aufzurufen.

Unter tensorflow.org/js finden Sie weitere Beispiele und Demos mit Code, um zu erfahren, wie Sie TensorFlow.js in Ihren Anwendungen verwenden können.