Atelier de programmation sur l'entraînement TensorFlow.js dans Node.js

1. Introduction

Dans cet atelier de programmation, vous allez apprendre à créer un serveur Web Node.js pour entraîner et classer des types de lancers de baseball côté serveur à l'aide de TensorFlow.js, une bibliothèque de machine learning puissante et flexible pour JavaScript. Vous allez créer une application Web pour entraîner un modèle à prédire le type d'argumentaire à partir des données de capteurs de hauteur et pour appeler une prédiction à partir d'un client Web. Le dépôt GitHub tfjs-examples contient une version entièrement fonctionnelle de cet atelier de programmation.

Points abordés

  • Installer et configurer le package npm tensorflow.js à utiliser avec Node.js
  • Accéder aux données d'entraînement et de test dans l'environnement Node.js
  • Entraîner un modèle avec TensorFlow.js sur un serveur Node.js
  • Déployer le modèle entraîné pour l'inférence dans une application client/serveur

C'est parti !

2. Conditions requises

Pour suivre cet atelier de programmation, vous devez disposer des éléments suivants:

  1. Une version récente de Chrome ou un autre navigateur récent.
  2. Un éditeur de texte et un terminal de commande s'exécutant localement sur votre machine
  3. Connaissance des langages HTML, CSS et JavaScript, ainsi que des outils pour les développeurs Chrome (ou des outils de développement de votre navigateur préféré)
  4. Une compréhension conceptuelle de haut niveau des réseaux de neurones Si vous avez besoin d'une introduction ou d'un rappel, regardez cette vidéo de 3blue1brown ou cette vidéo sur le deep learning en JavaScript d'Ashi Krishnan.

3. Configurer une application Node.js

Installez Node.js et npm. Pour connaître les plates-formes et dépendances compatibles, consultez le guide d'installation de tfjs-node.

Créez un répertoire appelé ./baseball pour notre application Node.js. Copiez les fichiers package.json et webpack.config.js associés dans ce répertoire pour configurer les dépendances du package npm (y compris le package npm @tensorflow/tfjs-node). Ensuite, exécutez npm install pour installer les dépendances.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

Vous êtes maintenant prêt à écrire du code et à entraîner un modèle.

4. Configurer les données d'entraînement et de test

Vous utiliserez les données d'entraînement et de test sous forme de fichiers CSV à partir des liens ci-dessous. Téléchargez et explorez les données contenues dans ces fichiers:

pitch_type_training_data.csv

pitch_type_test_data.csv

Examinons quelques exemples de données d'entraînement:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

Huit caractéristiques d'entrée décrivent les données du capteur de hauteur:

  • vitesse de la balle (Vx0, vy0, vz0)
  • accélération de la balle (ax, ay, az)
  • vitesse de démarrage de l'argumentaire
  • si le lanceur est gaucher ou non

et une étiquette de sortie:

  • pitch_code correspondant à l'un des sept types de titres: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

L'objectif est de créer un modèle capable de prédire le type de tonalité d'après les données des capteurs de hauteur.

Avant de créer le modèle, vous devez préparer les données d'entraînement et de test. Créez le fichier pitch_type.js dans le répertoire baseball/ dir et copiez-y le code suivant. Ce code charge les données d'entraînement et de test à l'aide de l'API tf.data.csv. Il normalise également les données (ce qui est toujours recommandé) à l'aide d'une échelle de normalisation min-max.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. Créer un modèle pour classer les types de titres

Vous êtes maintenant prêt à créer le modèle. Utilisez l'API tf.layers pour connecter les entrées (forme de [8] valeurs du capteur de hauteur) à trois couches cachées entièrement connectées composées d'unités d'activation ReLU, suivies d'une couche de sortie softmax composée de sept unités, chacune représentant l'un des types de hauteur de la de sortie.

Entraîner le modèle avec l'optimiseur adam et la fonction de perte d'entropie sparseCategoricalCrossentropy Pour en savoir plus sur ces options, consultez le guide d'entraînement des modèles.

Ajoutez le code suivant à la fin de pitch_type.js:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

Déclenchez l'entraînement à partir du code du serveur principal que vous écrirez ultérieurement.

Pour terminer le module pitch_type.js, nous allons écrire une fonction permettant d'évaluer l'ensemble de données de validation et de test, de prédire un type d'argumentaire pour un seul échantillon et de calculer des métriques de précision. Ajoutez ce code à la fin de pitch_type.js:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. Entraîner le modèle sur le serveur

Écrivez le code du serveur permettant d'entraîner et d'évaluer le modèle dans un nouveau fichier nommé server.js. Commencez par créer un serveur HTTP et ouvrez une connexion de socket bidirectionnelle à l'aide de l'API socket.io. Exécutez ensuite l'entraînement du modèle avec l'API model.fitDataset, puis évaluez la justesse du modèle à l'aide de la méthode pitch_type.evaluate() que vous avez écrite précédemment. Entraînez et évaluez les applications pour 10 itérations, en imprimant les métriques dans la console.

Copiez le code ci-dessous dans le fichier server.js:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

À ce stade, vous êtes prêt à exécuter et à tester le serveur. Le résultat devrait ressembler à ceci, avec l'entraînement du serveur à une époque à chaque itération (vous pouvez également utiliser l'API model.fitDataset pour entraîner plusieurs époques avec un seul appel). Si vous rencontrez des erreurs à ce stade, veuillez vérifier l'installation de votre nœud et de npm.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Appuyez sur Ctrl+C pour arrêter le serveur en cours d'exécution. Nous l'exécuterons à nouveau à l'étape suivante.

7. Créer une page client et afficher le code

Maintenant que le serveur est prêt, l'étape suivante consiste à écrire le code client, qui s'exécute dans le navigateur. Créez une page simple pour appeler la prédiction du modèle sur le serveur et afficher le résultat. socket.io est utilisé pour la communication client/serveur.

Commencez par créer le fichier index.html dans le dossier baseball/ :

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

Créez ensuite un nouveau fichier client.js dans le dossier baseball/ avec le code ci-dessous:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

Le client gère le message de socket trainingComplete pour afficher un bouton de prédiction. Lorsque l'utilisateur clique sur ce bouton, le client envoie un message de socket contenant des échantillons de données de capteur. Lorsque vous recevez un message predictResult, la prédiction s'affiche sur la page.

8. Exécuter l'application

Exécutez à la fois le serveur et le client pour voir l'intégralité de l'application en action:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

Ouvrez la page du client dans votre navigateur ( http://localhost:8080). Une fois l'entraînement du modèle terminé, cliquez sur le bouton Predict Sample (Échantillon de prédiction). Un résultat de prédiction doit s'afficher dans le navigateur. N'hésitez pas à modifier l'échantillon de données du capteur avec quelques exemples du fichier CSV de test pour voir avec quelle précision le modèle prédit.

9. Ce que vous avez appris

Dans cet atelier de programmation, vous avez implémenté une application Web de machine learning simple à l'aide de TensorFlow.js. Vous avez entraîné un modèle personnalisé pour classer des types de lancers de baseball à partir de données de capteurs. Vous avez écrit du code Node.js pour exécuter l'entraînement sur le serveur, puis appelé l'inférence sur le modèle entraîné à l'aide des données envoyées par le client.

N'oubliez pas de consulter le site tensorflow.org/js pour obtenir d'autres exemples et démonstrations de code, et apprendre à utiliser TensorFlow.js dans vos applications.