Node.js Codelab의 TensorFlow.js 학습

1. 소개

이 Codelab에서는 JavaScript용 강력하고 유연한 머신러닝 라이브러리인 TensorFlow.js를 사용하여 서버 측에서 야구 투구 유형을 학습시키고 분류하는 Node.js 웹 서버를 빌드하는 방법을 알아봅니다. 피치 센서 데이터에서 피치 유형을 예측하도록 모델을 학습시키고 웹 클라이언트에서 예측을 호출하는 웹 애플리케이션을 빌드합니다. 이 Codelab의 완전하게 작동하는 버전은 tfjs-examples GitHub 저장소에 있습니다.

학습할 내용

  • Node.js와 함께 사용할 tensorflow.js npm 패키지를 설치하고 설정하는 방법
  • Node.js 환경에서 학습 및 테스트 데이터에 액세스하는 방법
  • Node.js 서버에서 TensorFlow.js를 사용하여 모델을 학습시키는 방법
  • 클라이언트/서버 애플리케이션에서 추론을 위해 학습된 모델을 배포하는 방법

그럼 시작해 보겠습니다.

2. 요구사항

이 Codelab을 완료하려면 다음이 필요합니다.

  1. 최신 버전의 Chrome 또는 다른 최신 브라우저
  2. 머신에서 로컬로 실행되는 텍스트 편집기 및 명령어 터미널
  3. HTML, CSS, JavaScript, Chrome DevTools (또는 선호하는 브라우저의 개발 도구)에 관한 지식
  4. 신경망에 관한 대략적인 개념 이해 소개나 복습이 필요한 경우 3blue1brown의 동영상 또는 아시 크리슈난의 JavaScript 딥 러닝에 관한 동영상을 시청해 보세요.

3. Node.js 앱 설정

Node.js 및 npm을 설치합니다. 지원되는 플랫폼 및 종속 항목은 tfjs-node 설치 가이드를 참조하세요.

Node.js 앱에 사용할 ./baseball이라는 디렉터리를 만듭니다. 연결된 package.jsonwebpack.config.js를 이 디렉터리에 복사하여 npm 패키지 종속 항목 (@tensorflow/tfjs-node npm 패키지 포함)을 구성합니다. 그런 다음 npm install을 실행하여 종속 항목을 설치합니다.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

이제 코드를 작성하고 모델을 학습시킬 준비가 되었습니다.

4. 학습 및 테스트 데이터 설정

학습 및 테스트 데이터는 아래 링크의 CSV 파일로 사용합니다. 다음 파일의 데이터를 다운로드하고 탐색하세요.

pitch_type_training_data.csv

pitch_type_test_data.csv

샘플 학습 데이터를 살펴보겠습니다.

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

피치 센서 데이터를 설명하는 8개의 입력 특성이 있습니다.

  • 볼 속도 (vx0, vy0, vz0)
  • 볼 가속도 (ax, ay, az)
  • 피치 시작 속도
  • 투수의 왼손잡이용인지 여부

출력 라벨 1개:

  • 7가지 추천곡 유형 중 하나를 나타내는pel_code: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

목표는 피치 센서 데이터에 따라 피치 유형을 예측할 수 있는 모델을 빌드하는 것입니다.

모델을 만들기 전에 학습 및 테스트 데이터를 준비해야 합니다. baseball/ dir에 Pitch_type.js 파일을 만들고 이 파일에 다음 코드를 복사합니다. 이 코드는 tf.data.csv API를 사용하여 학습 및 테스트 데이터를 로드합니다. 또한 최소-최대 정규화 척도를 사용하여 데이터를 정규화합니다 (항상 권장됨).

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. 피치 유형 분류를 위한 모델 만들기

이제 모델을 빌드할 준비가 되었습니다. tf.layers API를 사용하여 입력([8] 개의 피치 센서 값의 형태)을 ReLU 활성화 단위로 이루어진 3개의 완전 연결형 히든 레이어에 연결하고, 이어서 7개의 단위로 구성된 소프트맥스 출력 레이어 1개를 연결합니다(각각 출력 피치 유형 중 하나를 나타냄).

adam 옵티마이저와 sparseCategoricalCrossentropy 손실 함수를 사용하여 모델을 학습시킵니다. 이러한 선택사항에 대한 자세한 내용은 학습 모델 가이드를 참조하세요.

다음 코드를 train_type.js의 끝에 추가합니다.

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

나중에 작성할 기본 서버 코드에서 학습을 트리거합니다.

deck_type.js 모듈을 완료하기 위해 검증 및 테스트 데이터 세트를 평가하고, 단일 샘플의 피치 유형을 예측하고, 정확성 측정항목을 계산하는 함수를 작성해 보겠습니다. 다음 코드를 train_type.js의 끝에 추가하세요.

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. 서버에서 모델 학습

server.js라는 새 파일에서 모델 학습 및 평가를 수행하는 서버 코드를 작성합니다. 먼저 HTTP 서버를 만들고 socket.io API를 사용하여 양방향 소켓 연결을 엽니다. 그런 다음 model.fitDataset API를 사용하여 모델 학습을 실행하고 앞서 작성한 pitch_type.evaluate() 메서드를 사용하여 모델 정확성을 평가합니다. 10회 반복을 학습 및 평가하고 측정항목을 콘솔에 출력합니다.

아래 코드를 server.js에 복사합니다.

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

이제 서버를 실행하고 테스트할 준비가 되었습니다. 서버가 반복될 때마다 한 에포크를 학습하는 모습이 표시됩니다. model.fitDataset API를 사용하여 한 번의 호출로 여러 에포크를 학습시킬 수도 있습니다. 이 시점에 오류가 발생하면 노드 및 npm 설치를 확인하세요.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Ctrl-C를 입력하여 실행 중인 서버를 중지합니다. 다음 단계에서 다시 실행합니다.

7. 고객 페이지 만들기 및 코드 표시

이제 서버가 준비되었으므로 다음 단계는 브라우저에서 실행되는 클라이언트 코드를 작성하는 것입니다. 서버에서 모델 예측을 호출하고 결과를 표시하는 간단한 페이지를 만듭니다. 클라이언트/서버 통신에 socket.io를 사용합니다.

먼저 baseball/ 폴더에 index.html을 만듭니다.

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

그런 다음 아래 코드를 사용하여 baseball/ 폴더에 client.js 파일을 새로 만듭니다.

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

클라이언트는 trainingComplete 소켓 메시지를 처리하여 예측 버튼을 표시합니다. 이 버튼을 클릭하면 클라이언트가 샘플 센서 데이터가 포함된 소켓 메시지를 보냅니다. predictResult 메시지를 수신하면 페이지에 예상 검색어가 표시됩니다.

8. 앱 실행

서버와 클라이언트를 모두 실행하여 작동 중인 전체 앱을 확인합니다.

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

브라우저에서 클라이언트 페이지를 엽니다 ( http://localhost:8080). 모델 학습이 끝나면 샘플 예측 버튼을 클릭합니다. 브라우저에 예상 검색어 결과가 표시됩니다. 테스트 CSV 파일의 몇 가지 예를 사용하여 샘플 센서 데이터를 자유롭게 수정하고 모델이 얼마나 정확하게 예측하는지 확인해 보세요.

9. 학습한 내용

이 Codelab에서는 TensorFlow.js를 사용하여 간단한 머신러닝 웹 애플리케이션을 구현했습니다. 센서 데이터에서 야구 투구 유형을 분류하는 커스텀 모델을 학습시켰습니다. 서버에서 학습을 실행하고 클라이언트에서 전송한 데이터를 사용하여 학습된 모델에서 추론을 호출하도록 Node.js 코드를 작성했습니다.

애플리케이션에서 TensorFlow.js를 사용하는 방법을 알아보려면 코드가 포함된 더 많은 예제와 데모를 보려면 tensorflow.org/js를 방문하세요.