Pelatihan TensorFlow.js di Codelab Node.js

1. Pengantar

Dalam Codelab ini, Anda akan mempelajari cara membangun server web Node.js untuk melatih dan mengklasifikasikan jenis lapangan bisbol di sisi server menggunakan TensorFlow.js, library machine learning yang canggih dan fleksibel untuk JavaScript. Anda akan membangun aplikasi web untuk melatih model guna memprediksi jenis pitch dari data sensor pitch, dan untuk memanggil prediksi dari klien web. Versi Codelab ini yang berfungsi sepenuhnya ada di repo GitHub tfjs-examples.

Yang akan Anda pelajari

  • Cara menginstal dan menyiapkan paket npm tensorflow.js untuk digunakan dengan Node.js.
  • Cara mengakses data pelatihan dan pengujian di lingkungan Node.js.
  • Cara melatih model dengan TensorFlow.js di server Node.js.
  • Cara men-deploy model yang telah dilatih untuk inferensi dalam aplikasi klien/server.

Mari kita mulai.

2. Persyaratan

Untuk menyelesaikan Codelab ini, Anda akan memerlukan:

  1. Chrome versi terbaru atau browser modern lainnya.
  2. Editor teks dan terminal perintah yang berjalan secara lokal di komputer Anda.
  3. Pengetahuan terkait HTML, CSS, JavaScript, dan Chrome DevTools (atau browser pilihan Anda).
  4. Pemahaman konseptual tingkat tinggi tentang jaringan neural. Jika Anda memerlukan pengantar atau penyegaran terkait materi yang akan dipelajari, pertimbangkan untuk menonton video karya 3blue1brown ini atau video tentang Deep Learning in JavaScript dari Ashi Krishnan.

3. Menyiapkan aplikasi Node.js

Instal Node.js dan npm. Untuk platform dan dependensi yang didukung, lihat panduan penginstalan node tfjs.

Buat direktori bernama ./baseball untuk aplikasi Node.js. Salin package.json dan webpack.config.js yang tertaut ke direktori ini untuk mengonfigurasi dependensi paket npm (termasuk paket npm @tensorflow/tfjs-node). Kemudian, jalankan penginstalan npm untuk menginstal dependensi.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

Sekarang Anda siap untuk menulis beberapa kode dan melatih model.

4. Menyiapkan data pelatihan dan pengujian

Anda akan menggunakan data pelatihan dan pengujian sebagai {i>file<i} CSV dari tautan di bawah ini. Download dan jelajahi data dalam file ini:

pitch_type_training_data.csv

pitch_type_test_data.csv

Mari kita lihat beberapa contoh data pelatihan:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

Ada delapan fitur input - yang menjelaskan data sensor pitch:

  • kecepatan bola (vx0, vy0, vz0)
  • akselerasi bola (kapak, ay, az)
  • kecepatan awal pitch
  • apakah pitcher tangan kiri atau tidak

dan satu label output:

  • pitch_code yang menandakan salah satu dari tujuh jenis saran lagu: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

Tujuannya adalah membangun model yang dapat memprediksi jenis pitch berdasarkan data sensor pitch.

Sebelum membuat model, Anda perlu menyiapkan data pelatihan dan pengujian. Buat file pitch_type.js di baseball/ dir, lalu salin kode berikut ke dalamnya. Kode ini memuat data pelatihan dan pengujian menggunakan API tf.data.csv. Tindakan ini juga menormalisasi data (yang selalu direkomendasikan) menggunakan skala normalisasi min-maks.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. Membuat model untuk mengklasifikasikan jenis saran lagu

Sekarang Anda siap membangun model. Gunakan tf.layers API untuk menghubungkan input (bentuk dari [8] nilai sensor pitch) ke 3 lapisan tersembunyi yang terhubung sepenuhnya yang terdiri dari unit aktivasi ULT, diikuti dengan satu lapisan output softmax yang terdiri dari 7 unit, yang masing-masing mewakili salah satu jenis pitch output.

Latih model dengan pengoptimal adam dan fungsi kerugian sparseCategoricalCrossentropy. Untuk mengetahui info selengkapnya tentang pilihan ini, lihat panduan model pelatihan.

Tambahkan kode berikut ke akhir pitch_type.js:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

Picu pelatihan dari kode server utama yang akan Anda tulis nanti.

Untuk menyelesaikan modul pitch_type.js, mari kita tulis fungsi untuk mengevaluasi set data validasi dan pengujian, memprediksi jenis pitch untuk satu sampel, dan menghitung metrik akurasi. Tambahkan kode ini ke akhir pitch_type.js:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. Melatih model di server

Tulis kode server untuk melakukan pelatihan dan evaluasi model dalam file baru yang disebut server.js. Pertama, buat server HTTP dan buka koneksi soket dua arah menggunakan socket.io API. Kemudian, jalankan pelatihan model menggunakan API model.fitDataset, dan evaluasi akurasi model menggunakan metode pitch_type.evaluate() yang telah Anda tulis sebelumnya. Latih dan evaluasi untuk 10 iterasi, yang akan mencetak metrik ke konsol.

Salin kode di bawah ke server.js:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

Pada tahap ini, Anda siap untuk menjalankan dan menguji server. Anda akan melihat sesuatu seperti ini, dengan server yang melatih satu epoch di setiap iterasi (Anda juga dapat menggunakan model.fitDataset API untuk melatih beberapa epoch dengan satu panggilan). Jika Anda menemukan error pada tahap ini, periksa penginstalan node dan npm Anda.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Ketik Ctrl-C untuk menghentikan server yang berjalan. Kami akan menjalankannya lagi di langkah berikutnya.

7. Membuat halaman klien dan menampilkan kode

Setelah server siap, langkah berikutnya adalah menulis kode klien yang berjalan di browser. Buat halaman sederhana untuk memanggil prediksi model di server dan menampilkan hasilnya. ini menggunakan socket.io untuk komunikasi klien/server.

Pertama, buat index.html di folder baseball/:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

Kemudian buat file baru client.js di folder bisbol/ dengan kode di bawah ini:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

Klien menangani pesan soket trainingComplete untuk menampilkan tombol prediksi. Saat tombol ini diklik, klien akan mengirimkan pesan soket dengan contoh data sensor. Setelah menerima pesan predictResult, prediksi akan ditampilkan di halaman.

8. Menjalankan aplikasi

Jalankan server dan klien untuk melihat cara kerja aplikasi lengkap:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

Buka halaman klien di browser Anda ( http://localhost:8080). Saat pelatihan model selesai, klik tombol Predict Sample. Anda akan melihat hasil prediksi yang ditampilkan di browser. Jangan ragu untuk mengubah sampel data sensor dengan beberapa contoh dari file CSV pengujian dan lihat seberapa akurat model tersebut dapat memprediksi.

9. Yang telah Anda pelajari

Dalam Codelab ini, Anda akan mengimplementasikan aplikasi web machine learning sederhana menggunakan TensorFlow.js. Anda telah melatih model kustom untuk mengklasifikasikan jenis lapangan bisbol dari data sensor. Anda telah menulis kode Node.js untuk menjalankan pelatihan di server, dan memanggil inferensi pada model yang dilatih menggunakan data yang dikirim dari klien.

Pastikan Anda mengunjungi tensorflow.org/js untuk melihat contoh dan demo lainnya beserta kode untuk mengetahui cara menggunakan TensorFlow.js di aplikasi Anda.