Addestramento di TensorFlow.js nel codelab Node.js

1. Introduzione

In questo codelab, imparerai a creare un server web Node.js per addestrare e classificare i tipi di lancio del baseball sul lato server utilizzando TensorFlow.js, una libreria di machine learning potente e flessibile per JavaScript. Creerai un'applicazione web per addestrare un modello a prevedere il tipo di pitch dai dati dei sensori di inclinazione e per richiamare la previsione da un client web. Una versione completamente funzionante di questo codelab è presente nel repository GitHub tfjs-examples.

Obiettivi didattici

  • Installare e configurare il pacchetto npm tensorflow.js per l'utilizzo con Node.js.
  • Come accedere ai dati di addestramento e test nell'ambiente Node.js.
  • Come addestrare un modello con TensorFlow.js in un server Node.js.
  • Eseguire il deployment del modello addestrato per l'inferenza in un'applicazione client/server.

E ora iniziamo!

2. Requisiti

Per completare questo codelab, ti serviranno:

  1. Una versione recente di Chrome o di un altro browser moderno.
  2. Un editor di testo e un terminale di comando in esecuzione localmente sulla tua macchina.
  3. Conoscenza di HTML, CSS, JavaScript e Chrome DevTools (o gli strumenti di sviluppo del browser che preferisci).
  4. Una comprensione concettuale di alto livello delle reti neurali. Se hai bisogno di una presentazione o di un ripasso, guarda questo video di 3blue1brown o questo video sul deep learning in JavaScript di Ashi Krishnan.

3. Configura un'app Node.js

Installare Node.js e npm. Per le piattaforme e le dipendenze supportate, consulta la guida all'installazione di tfjs-node.

Crea una directory chiamata ./baseball per la nostra app Node.js. Copia i file package.json e webpack.config.js collegati in questa directory per configurare le dipendenze del pacchetto npm (incluso il pacchetto npm @tensorflow/tfjs-node). Quindi esegui npm install per installare le dipendenze.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

Ora è tutto pronto per scrivere del codice e addestrare un modello.

4. Configura i dati di addestramento e test

Utilizzerai i dati di addestramento e di test come file CSV dei link riportati di seguito. Scarica ed esplora i dati contenuti in questi file:

pitch_type_training_data.csv

pitch_type_test_data.csv

Diamo un'occhiata ad alcuni dati di addestramento di esempio:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

Esistono otto funzionalità di input, che descrivono i dati del sensore del passo:

  • velocità della sfera (vx0, vy0, vz0)
  • accelerazione della palla (ax, ay, az)
  • velocità iniziale del tono
  • se il lanciatore è mancino o meno

e un'etichetta di output:

  • pitch_code che indica uno dei sette tipi di proposta musicale: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

L'obiettivo è creare un modello che sia in grado di prevedere il tipo di beccheggio in base ai dati del sensore della pendenza.

Prima di creare il modello, devi preparare i dati di addestramento e test. Crea il file pitch_type.js nella directory baseball/ dir e copia al suo interno il codice seguente. Questo codice carica i dati di addestramento e test utilizzando l'API tf.data.csv. Inoltre, normalizza i dati (sempre consigliata) utilizzando una scala di normalizzazione min-max.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. Crea un modello per classificare i tipi di proposta musicale

Ora è tutto pronto per creare il modello. Utilizza l'API tf.layers per collegare gli input (forma dei valori dei sensori del passo [8]) a 3 strati nascosti completamente connessi costituiti da unità di attivazione ReLU, seguiti da uno strato di output softmax composto da 7 unità, ciascuna delle quali rappresenta uno dei tipi di passo di output.

Addestra il modello con l'ottimizzatore di Adam e la funzione di perdita sparseCategoricalCrossentropy. Per saperne di più su queste scelte, consulta la guida ai modelli di addestramento.

Aggiungi il seguente codice alla fine di pitch_type.js:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

Attiva l'addestramento dal codice del server principale che scriverai in seguito.

Per completare il modulo pitch_type.js, scriviamo una funzione per valutare il set di dati di convalida e test, prevedere il tipo di proposta per un singolo campione e calcolare le metriche di accuratezza. Aggiungi questo codice alla fine di pitch_type.js:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. Addestra il modello sul server

Scrivi il codice server per eseguire l'addestramento e la valutazione del modello in un nuovo file chiamato server.js. Per prima cosa, crea un server HTTP e apri una connessione socket bidirezionale utilizzando l'API socket.io. Quindi, esegui l'addestramento del modello utilizzando l'API model.fitDataset e valuta l'accuratezza del modello utilizzando il metodo pitch_type.evaluate() che hai scritto in precedenza. Addestra e valuta 10 iterazioni, stampando le metriche sulla console.

Copia il seguente codice su server.js:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

A questo punto, sei pronto per eseguire e testare il server. Dovresti vedere qualcosa di simile, con il server che esegue l'addestramento di un'epoca in ogni iterazione (puoi anche utilizzare l'API model.fitDataset per addestrare più epoche con una sola chiamata). Se si verificano errori a questo punto, controlla l'installazione del nodo e di npm.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Digita Ctrl+C per arrestare il server in esecuzione. Eseguiremo di nuovo la richiesta nel passaggio successivo.

7. Crea la pagina client e visualizza il codice

Ora che il server è pronto, il passaggio successivo è scrivere il codice client che viene eseguito nel browser. Creare una semplice pagina per richiamare la previsione del modello sul server e visualizzare il risultato. Viene utilizzato socket.io per la comunicazione client/server.

Per prima cosa, crea index.html nella cartella baseball/:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

Quindi, crea un nuovo file client.js nella cartella baseball/ con il codice riportato di seguito:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

Il client gestisce il messaggio socket trainingComplete per visualizzare un pulsante di previsione. Quando viene fatto clic su questo pulsante, il client invia un messaggio socket con i dati dei sensori di esempio. Dopo aver ricevuto un messaggio predictResult, viene visualizzata la previsione sulla pagina.

8. Esegui l'app

Esegui sia il server che il client per vedere come funziona l'app completa:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

Apri la pagina del client nel browser ( http://localhost:8080). Al termine dell'addestramento del modello, fai clic sul pulsante Prevedi esempio. Dovresti vedere nel browser il risultato della previsione. Non esitare a modificare i dati dei sensori campione con alcuni esempi del file CSV di prova e vedere quanto è precisa la previsione del modello.

9. Che cosa hai imparato

In questo codelab, hai implementato una semplice applicazione web di machine learning utilizzando TensorFlow.js. Hai addestrato un modello personalizzato per classificare i tipi di campo da baseball a partire dai dati dei sensori. Hai scritto il codice Node.js per eseguire l'addestramento sul server e chiamare l'inferenza sul modello addestrato utilizzando i dati inviati dal client.

Assicurati di visitare tensorflow.org/js per altri esempi e demo con codice che illustrano come utilizzare TensorFlow.js nelle tue applicazioni.