Đào tạo về TensorFlow.js trong Lớp học lập trình Node.js

1. Giới thiệu

Trong lớp học lập trình này, bạn sẽ tìm hiểu cách tạo máy chủ web Node.js để huấn luyện và phân loại các loại sân bóng chày ở phía máy chủ bằng TensorFlow.js, một thư viện máy học mạnh mẽ và linh hoạt dành cho JavaScript. Bạn sẽ xây dựng một ứng dụng web để huấn luyện một mô hình giúp dự đoán loại đề cử từ dữ liệu cảm biến đề cử cũng như để gọi ra thông tin dự đoán của một ứng dụng web. Phiên bản hoạt động đầy đủ của lớp học lập trình này có trong kho lưu trữ GitHub tfjs-examples.

Kiến thức bạn sẽ học được

  • Cách cài đặt và thiết lập gói npm tensorflow.js để sử dụng với Node.js.
  • Cách truy cập vào dữ liệu huấn luyện và thử nghiệm trong môi trường Node.js.
  • Cách huấn luyện mô hình bằng TensorFlow.js trong máy chủ Node.js.
  • Cách triển khai mô hình đã huấn luyện để suy luận trong ứng dụng máy khách/máy chủ.

Hãy cùng bắt đầu nào!

2. Yêu cầu

Để hoàn tất lớp học lập trình này, bạn sẽ cần:

  1. Phiên bản Chrome mới nhất hoặc một trình duyệt hiện đại khác.
  2. Một trình chỉnh sửa văn bản và cửa sổ dòng lệnh chạy trên máy của bạn.
  3. Có kiến thức về HTML, CSS, JavaScript và Công cụ của Chrome cho nhà phát triển (hoặc công cụ cho nhà phát triển của trình duyệt mà bạn ưu tiên).
  4. Hiểu biết khái niệm ở mức độ cao về mạng nơron. Nếu bạn cần giới thiệu hoặc ôn tập lại, hãy cân nhắc xem video này của 3blue1brown hoặc video về Học sâu bằng JavaScript của Ashi Krishnan.

3. Thiết lập ứng dụng Node.js

Cài đặt Node.js và npm. Đối với các nền tảng và phần phụ thuộc được hỗ trợ, vui lòng xem hướng dẫn cài đặt nút tfjs.

Tạo một thư mục có tên là ./xx cho ứng dụng Node.js của chúng tôi. Sao chép package.jsonwebpack.config.js đã liên kết vào thư mục này để định cấu hình các phần phụ thuộc của gói npm (bao gồm cả gói npm @tensorflow/tfjs-node). Sau đó, hãy chạy lệnh cài đặt npm để cài đặt các phần phụ thuộc.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

Bây giờ, bạn đã sẵn sàng viết một số mã và huấn luyện mô hình!

4. Thiết lập dữ liệu huấn luyện và kiểm thử

Bạn sẽ sử dụng dữ liệu huấn luyện và kiểm tra dưới dạng tệp CSV từ các liên kết bên dưới. Tải xuống và khám phá dữ liệu trong các tệp sau:

pitch_type_training_data.csv

pitch_type_test_data.csv

Hãy xem một số dữ liệu huấn luyện mẫu:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

Có 8 tính năng đầu vào – mô tả dữ liệu cảm biến cao độ:

  • vận tốc của bóng (vx0, vy0, vz0)
  • gia tốc của bóng (ax, ay, az)
  • tốc độ cao độ ban đầu
  • người ném bóng có thuận tay trái hay không

và một nhãn đầu ra:

  • Pitch_code biểu thị một trong 7 loại đề cử: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

Mục tiêu là xây dựng một mô hình có thể dự đoán loại cao độ dựa trên dữ liệu cảm biến cao độ.

Trước khi tạo mô hình, bạn cần chuẩn bị dữ liệu huấn luyện và kiểm thử. Tạo tệp sân_type.js trong bóng chày/ thư mục và sao chép đoạn mã sau vào tệp đó. Mã này tải dữ liệu huấn luyện và thử nghiệm bằng cách sử dụng API tf.data.csv. Việc này cũng chuẩn hoá dữ liệu (luôn được khuyến nghị) bằng cách sử dụng thang đo chuẩn hoá tối thiểu.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. Tạo mô hình để phân loại các loại bài thuyết trình bán hàng

Giờ thì bạn đã sẵn sàng xây dựng mô hình. Sử dụng API tf.layers để kết nối các đầu vào (hình dạng của [8] giá trị cảm biến độ cao) với 3 lớp ẩn được kết nối đầy đủ bao gồm các đơn vị kích hoạt ReLU, theo sau là một lớp đầu ra mềm max bao gồm 7 đơn vị, mỗi đơn vị đại diện cho một trong các kiểu cao độ đầu ra.

Huấn luyện mô hình bằng trình tối ưu hoá adam và hàm mất sparseCategoricalCrossentropy. Để biết thêm thông tin về những lựa chọn này, hãy tham khảo hướng dẫn về mô hình huấn luyện.

Thêm đoạn mã sau vào cuối tập_type.js:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

Kích hoạt quá trình huấn luyện qua mã máy chủ chính mà bạn sẽ viết sau này.

Để hoàn thành mô-đun sân_type.js, hãy viết một hàm để đánh giá tập dữ liệu xác thực và kiểm tra, dự đoán loại đề cử cho một mẫu duy nhất và tính toán các chỉ số về độ chính xác. Thêm mã này vào cuối cuốiitch_type.js:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. Mô hình tàu hoả trên máy chủ

Viết mã máy chủ để thực hiện huấn luyện và đánh giá mô hình trong một tệp mới có tên là server.js. Trước tiên, hãy tạo một máy chủ HTTP rồi mở kết nối ổ cắm hai chiều bằng API socket.io. Sau đó, hãy thực thi quy trình huấn luyện mô hình bằng API model.fitDataset và đánh giá độ chính xác của mô hình bằng phương thức pitch_type.evaluate() mà bạn đã viết trước đó. Đào tạo và đánh giá 10 lần lặp lại, in chỉ số ra bảng điều khiển.

Sao chép mã bên dưới vào server.js:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

Đến đây, bạn đã sẵn sàng chạy và kiểm tra máy chủ! Bạn sẽ thấy như thế này, với máy chủ huấn luyện một thời gian bắt đầu của hệ thống trong mỗi lần lặp lại (bạn cũng có thể sử dụng API model.fitDataset để huấn luyện nhiều thời gian bắt đầu của hệ thống bằng một lệnh gọi). Nếu bạn gặp bất kỳ lỗi nào vào lúc này, vui lòng kiểm tra cài đặt nút và npm của bạn.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Gõ Ctrl-C để dừng máy chủ đang chạy. Chúng ta sẽ chạy lại mã này trong bước tiếp theo.

7. Tạo trang khách hàng và hiển thị mã

Giờ máy chủ đã sẵn sàng, bước tiếp theo là viết mã ứng dụng khách và chạy mã đó trong trình duyệt. Tạo một trang đơn giản để gọi thông tin dự đoán của mô hình trên máy chủ và hiển thị kết quả. Tuỳ chọn này sử dụng socket.io để giao tiếp máy khách/máy chủ.

Trước tiên, tạo chỉ mục.html trong thư mục bóng chày/:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

Sau đó, tạo một tệp client.js mới trong bóng chày/ thư mục với mã dưới đây:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

Ứng dụng xử lý thông báo về ổ cắm trainingComplete để cho thấy nút dự đoán. Khi người dùng nhấp vào nút này, máy khách sẽ gửi một thông báo về ổ cắm kèm theo dữ liệu cảm biến mẫu. Sau khi nhận được tin nhắn predictResult, thông tin gợi ý sẽ hiển thị trên trang.

8. Chạy ứng dụng

Chạy cả máy chủ và máy khách để xem toàn bộ ứng dụng hoạt động:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

Mở trang ứng dụng trong trình duyệt của bạn ( http://localhost:8080). Khi quá trình huấn luyện mô hình kết thúc, hãy nhấp vào nút Forecast Sample (Dự đoán mẫu). Bạn sẽ thấy kết quả gợi ý hiển thị trong trình duyệt. Bạn có thể chỉnh sửa dữ liệu cảm biến mẫu bằng một số ví dụ trong tệp CSV thử nghiệm để xem mô hình dự đoán chính xác đến mức nào.

9. Kiến thức bạn học được

Trong lớp học lập trình này, bạn đã triển khai một ứng dụng web đơn giản trong lĩnh vực máy học bằng cách sử dụng TensorFlow.js. Bạn đã huấn luyện một mô hình tuỳ chỉnh để phân loại các loại sân bóng chày từ dữ liệu cảm biến. Bạn đã viết mã Node.js để thực thi quá trình huấn luyện trên máy chủ và gọi dự đoán trên mô hình đã huấn luyện bằng cách sử dụng dữ liệu gửi từ máy khách.

Hãy nhớ truy cập vào tensorflow.org/js để xem thêm các ví dụ và bản minh hoạ có đoạn mã để xem cách bạn có thể sử dụng TensorFlow.js trong các ứng dụng của mình.