آموزش TensorFlow.js در Node.js Codelab

1. مقدمه

در این Codelab، نحوه ساخت یک وب سرور Node.js را برای آموزش و طبقه بندی انواع زمین بیسبال در سمت سرور با استفاده از TensorFlow.js ، یک کتابخانه یادگیری ماشینی قدرتمند و انعطاف پذیر برای جاوا اسکریپت، یاد خواهید گرفت. شما یک برنامه وب می سازید تا مدلی را برای پیش بینی نوع گام از روی داده های حسگر pitch، و برای فراخوانی پیش بینی از یک سرویس گیرنده وب، آموزش دهید. یک نسخه کاملاً کارآمد از این Codelab در مخزن tfjs-examples GitHub موجود است.

چیزی که یاد خواهید گرفت

  • نحوه نصب و راه اندازی بسته npm tensorflow.js برای استفاده با Node.js.
  • نحوه دسترسی به داده های آموزشی و آزمایشی در محیط Node.js.
  • نحوه آموزش یک مدل با TensorFlow.js در سرور Node.js.
  • نحوه استقرار مدل آموزش دیده برای استنتاج در یک برنامه مشتری/سرور.

پس بیایید شروع کنیم!

2. الزامات

برای تکمیل این Codelab به موارد زیر نیاز دارید:

  1. نسخه اخیر کروم یا مرورگر مدرن دیگری.
  2. یک ویرایشگر متن و ترمینال فرمان که به صورت محلی روی دستگاه شما اجرا می شود.
  3. دانش HTML، CSS، جاوا اسکریپت و ابزارهای توسعه دهنده کروم (یا ابزارهای توسعه دهنده مرورگرهای دلخواه شما).
  4. درک مفهومی سطح بالا از شبکه های عصبی . اگر به یک مقدمه یا تجدید نظر نیاز دارید، این ویدیو را توسط 3blue1brown یا این ویدیو را در آموزش عمیق در جاوا اسکریپت توسط Ashi Krishnan تماشا کنید.

3. یک برنامه Node.js راه اندازی کنید

Node.js و npm را نصب کنید. برای پلتفرم‌ها و وابستگی‌های پشتیبانی‌شده، لطفاً راهنمای نصب tfjs-node را ببینید.

یک دایرکتوری به نام ./baseball برای برنامه Node.js ما ایجاد کنید. بسته.json و webpack.config.js پیوند داده شده را در این دایرکتوری کپی کنید تا وابستگی های بسته npm (از جمله بسته @tensorflow/tfjs-node npm) را پیکربندی کنید. سپس npm install را برای نصب وابستگی ها اجرا کنید.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

حالا شما آماده نوشتن کد و آموزش یک مدل هستید!

4. راه اندازی داده های آموزش و آزمون

شما از داده های آموزش و آزمون به عنوان فایل CSV از لینک های زیر استفاده خواهید کرد. داده های این فایل ها را دانلود و کاوش کنید:

pitch_type_training_data.csv

pitch_type_test_data.csv

بیایید به نمونه داده های آموزشی نگاه کنیم:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

هشت ویژگی ورودی وجود دارد - داده های سنسور پیچ را توصیف می کند:

  • سرعت توپ (vx0، vy0، vz0)
  • شتاب توپ (ax، ay، az)
  • سرعت شروع زمین
  • پارچ چپ دست باشد یا نه

و یک برچسب خروجی:

  • pitch_code که یکی از هفت نوع زمین را نشان می دهد: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

هدف، ساخت مدلی است که قادر به پیش‌بینی نوع گام داده‌های سنسور گام باشد.

قبل از ایجاد مدل، باید داده های آموزشی و آزمایشی را آماده کنید. فایل pitch_type.js را در baseball/dir ایجاد کنید و کد زیر را در آن کپی کنید. این کد داده های آموزشی و آزمایشی را با استفاده از tf.data.csv API بارگیری می کند. همچنین داده ها را (که همیشه توصیه می شود) با استفاده از مقیاس نرمال سازی حداقل حداکثر، عادی می کند.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. ایجاد مدل برای طبقه بندی انواع زمین

اکنون شما آماده ساخت مدل هستید. از tf.layers API برای اتصال ورودی‌ها (شکل [8] مقادیر سنسور pitch) به 3 لایه پنهان کاملاً متصل متشکل از واحدهای فعال‌سازی ReLU، و به دنبال آن یک لایه خروجی softmax متشکل از 7 واحد استفاده کنید که هر یک نماینده یکی از خروجی‌ها هستند. انواع زمین

مدل را با بهینه‌ساز adam و تابع از دست دادن Crossentropy sparseCategoricalCrossentropy آموزش دهید. برای اطلاعات بیشتر در مورد این انتخاب ها، به راهنمای مدل های آموزشی مراجعه کنید.

کد زیر را به انتهای pitch_type.js اضافه کنید:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

آموزش را از کد سرور اصلی که بعداً می نویسید فعال کنید.

برای تکمیل ماژول pitch_type.js، بیایید تابعی بنویسیم تا اعتبارسنجی و مجموعه داده‌های آزمایشی را ارزیابی کند، یک نوع پیچ را برای یک نمونه پیش‌بینی کند و معیارهای دقت را محاسبه کند. این کد را به انتهای pitch_type.js اضافه کنید:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. مدل قطار روی سرور

کد سرور را برای انجام آموزش و ارزیابی مدل در فایل جدیدی به نام server.js بنویسید. ابتدا یک سرور HTTP ایجاد کنید و یک اتصال سوکت دو طرفه را با استفاده از API socket.io باز کنید. سپس آموزش مدل را با استفاده از model.fitDataset API اجرا کنید و دقت مدل را با استفاده از روش pitch_type.evaluate() که قبلا نوشتید ارزیابی کنید. آموزش و ارزیابی برای 10 تکرار، چاپ معیارها در کنسول.

کد زیر را در server.js کپی کنید:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

در این مرحله شما آماده اجرا و تست سرور هستید! شما باید چیزی شبیه به این را ببینید، با آموزش سرور یک دوره در هر تکرار (همچنین می توانید از model.fitDataset API برای آموزش چندین دوره با یک تماس استفاده کنید). اگر در این مرحله با خطا مواجه شدید، لطفاً نصب node و npm خود را بررسی کنید.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Ctrl-C را تایپ کنید تا سرور در حال اجرا متوقف شود. در مرحله بعد دوباره آن را اجرا می کنیم.

7. صفحه مشتری و کد نمایش ایجاد کنید

اکنون که سرور آماده است، مرحله بعدی نوشتن کد مشتری است که در مرورگر اجرا می شود. یک صفحه ساده برای فراخوانی پیش بینی مدل در سرور ایجاد کنید و نتیجه را نمایش دهید. این از socket.io برای ارتباط مشتری/سرور استفاده می کند.

ابتدا index.html را در پوشه baseball/ ایجاد کنید:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

سپس یک فایل client.js جدید در پوشه baseball/ با کد زیر ایجاد کنید:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

کلاینت پیام سوکت trainingComplete را برای نمایش یک دکمه پیش بینی مدیریت می کند. هنگامی که این دکمه کلیک می شود، مشتری یک پیام سوکت با داده های نمونه سنسور ارسال می کند. با دریافت پیام predictResult ، پیش بینی را در صفحه نمایش می دهد.

8. برنامه را اجرا کنید

سرور و کلاینت را اجرا کنید تا برنامه کامل را در عمل ببینید:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

صفحه مشتری را در مرورگر خود باز کنید ( http://localhost:8080 ). وقتی آموزش مدل به پایان رسید، روی دکمه Predict Sample کلیک کنید. شما باید یک نتیجه پیش بینی نمایش داده شده در مرورگر را ببینید. به راحتی می توانید داده های نمونه سنسور را با چند نمونه از فایل CSV آزمایشی اصلاح کنید و ببینید که مدل چقدر دقیق پیش بینی می کند.

9. آنچه یاد گرفتید

در این Codelab، شما یک برنامه وب یادگیری ماشین ساده را با استفاده از TensorFlow.js پیاده سازی کردید. شما یک مدل سفارشی برای طبقه بندی انواع زمین بیسبال از روی داده های حسگر آموزش دادید. شما کد Node.js را برای اجرای آموزش روی سرور نوشتید و با استفاده از داده های ارسال شده از مشتری، استنتاج را روی مدل آموزش دیده فراخوانی کرد.

حتماً به tensorflow.org/js برای مثال‌ها و نسخه‌های نمایشی بیشتر با کد مراجعه کنید تا ببینید چگونه می‌توانید از TensorFlow.js در برنامه‌های خود استفاده کنید.