التدريب على TensorFlow.js في الدرس التطبيقي حول ترميز Node.js

1. مقدمة

ستتعلّم في هذا الدرس التطبيقي كيفية إنشاء خادم ويب Node.js لتدريب أنواع ملاعب البيسبول وتصنيفها على جهة الخادم باستخدام TensorFlow.js، وهي مكتبة فعّالة ومرنة لتعلُّم الآلة في JavaScript. ستقوم بإنشاء تطبيق ويب لتدريب نموذج للتنبؤ بنوع درجة الصوت من بيانات جهاز استشعار درجة الصوت، ولاستدعاء التنبؤ من عميل ويب. تتوفّر نسخة صالحة بالكامل من هذا الدرس التطبيقي في مستودع GitHub tfjs-examples.

ما ستتعرَّف عليه

  • كيفية تثبيت وإعداد حزمة tensorflow.js npm لاستخدامها مع Node.js.
  • كيفية الوصول إلى بيانات التدريب والاختبار في بيئة Node.js.
  • كيفية تدريب نموذج باستخدام TensorFlow.js في خادم Node.js.
  • كيفية نشر النموذج المدرَّب للاستنتاج في تطبيق العميل/الخادم.

هيا بنا نبدأ!

2. المتطلبات

لإكمال هذا الدرس التطبيقي حول الترميز، ستحتاج إلى:

  1. إصدار حديث من Chrome أو أي متصفّح حديث آخر
  2. محرِّر نصوص ومحطة أوامر يتم تشغيلهما محليًا على جهازك
  3. معرفة HTML وCSS وJavaScript وChrome DevTools (أو أدوات مطوري البرامج للمتصفحات المفضلة لديك).
  4. فهم نظري عالي المستوى للشبكات العصبونية. إذا كنت بحاجة إلى مقدمة أو تنشيط للذاكرة، يمكنك مشاهدة هذا الفيديو من قناة 3blue1brown أو هذا الفيديو حول التعلم المتعمق بلغة JavaScript من إعداد "آشي كريشنان".

3- إعداد تطبيق Node.js

ثبّت Node.js وnpm. بالنسبة إلى الأنظمة الأساسية والتبعيات المتوافقة، يُرجى الاطّلاع على دليل تثبيت tfjs-node.

أنشئ دليلاً باسم ./baseball لتطبيق Node.js لدينا. انسخ ملفي package.json وwebpack.config.js المرتبطَين في هذا الدليل لإعداد تبعيات حزمة npm (بما في ذلك حزمة @tensorflow/tfjs-node npm). ثم شغّل npm install لتثبيت التبعيات.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

والآن أنت على استعداد لكتابة بعض التعليمات البرمجية وتدريب نموذج!

4. إعداد بيانات التدريب والاختبار

ستستخدم بيانات التدريب والاختبار كملفات CSV من الروابط أدناه. تنزيل البيانات واستكشافها في هذه الملفات:

pitch_type_training_data.csv

pitch_type_test_data.csv

لنلقِ نظرة على بعض نماذج بيانات التدريب:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

هناك ثماني ميزات إدخال تصف بيانات أداة استشعار حدّة الصوت:

  • السرعة الكرة (vx0، vy0، vz0)
  • تسريع الكرات (ax، ay، az)
  • سرعة بدء درجة الصوت
  • ما إذا كان رامي الأسد عُسرًا أم لا

وتصنيف ناتج واحد:

  • "رمز_الاقتراح" الذي يشير إلى أحد أنواع الاقتراحات السبعة: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

الهدف هو إنشاء نموذج قادر على التنبؤ بنوع درجة الصوت استنادًا إلى بيانات جهاز استشعار درجة الصوت.

قبل إنشاء النموذج، تحتاج إلى إعداد بيانات التدريب والاختبار. أنشئ الملف لِth_type.js في الحقل "basball/ dir"، وانسخ الكود التالي فيه. يُحمِّل هذا الرمز بيانات التدريب والاختبار باستخدام واجهة برمجة التطبيقات tf.data.csv. كما أنها تعمل على تسوية البيانات (والتي يوصى بها دائمًا) باستخدام مقياس تسوية من الحد الأدنى والأقصى.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5- إنشاء نموذج لتصنيف أنواع الأغاني المقترَحة

أنت الآن جاهز لبناء النموذج. استخدِم واجهة برمجة التطبيقات tf.layers لتوصيل مصادر الإدخال (شكل [8] قيم أداة استشعار درجة الصوت) بثلاث طبقات مخفية متصلة بالكامل تتألف من وحدات تفعيل ReLU، متبوعة بطبقة إخراج softmax واحدة تتكون من 7 وحدات، تمثل كل منها أحد أنواع درجات الصوت الناتجة.

درِّب النموذج باستخدام أداة تحسين الأداء "آدم" ودالة الفقدان المتفرقة CategoricalCrossentropy. لمزيد من المعلومات عن هذه الخيارات، يُرجى الرجوع إلى دليل نماذج التدريب.

أضِف الرمز التالي في نهاية Pitch_type.js:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

شغِّل التدريب من رمز الخادم الرئيسي الذي ستكتبه لاحقًا.

لإكمال وحدة اللغة Pitch_type.js، لنقم بكتابة دالة لتقييم مجموعة بيانات الاختبار والتحقق من الصحة، وتوقع نوع الاقتراح لعينة واحدة، وحساب مقاييس الدقة. إلحاق هذا الرمز بنهاية Pitch_type.js:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6- نموذج القطار على الخادم

اكتب الرمز البرمجي للخادم لتنفيذ تدريب النموذج وتقييمه في ملف جديد باسم server.js. أولاً، أنشئ خادم HTTP وافتح اتصال مقبس ثنائي الاتجاه باستخدام واجهة برمجة تطبيقات Socket.io. بعد ذلك، يمكنك تنفيذ تدريب النموذج باستخدام واجهة برمجة التطبيقات model.fitDataset، وتقييم دقة النموذج باستخدام طريقة pitch_type.evaluate() التي كتبتها سابقًا. التدريب وتقييم 10 تكرارات، وطباعة المقاييس إلى وحدة التحكم.

انسخ الرمز أدناه إلى server.js:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

في هذه المرحلة، تكون جاهزًا لتشغيل الخادم واختباره. من المفترض أن تظهر لك على النحو التالي، مع تدريب الخادم فترة واحدة في كل تكرار (يمكنك أيضًا استخدام واجهة برمجة التطبيقات model.fitDataset لتدريب فترات متعددة من خلال استدعاء واحد). في حال مواجهة أي أخطاء في هذه المرحلة، يُرجى التحقّق من العُقدة وتثبيت npm.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

اكتب Ctrl-C لإيقاف الخادم قيد التشغيل. سنُعيد تشغيله مرة أخرى في الخطوة التالية.

7. إنشاء صفحة العميل وعرض الرمز

والآن بعد أن أصبح الخادم جاهزًا، تتمثل الخطوة التالية في كتابة كود العميل ويتم تشغيله في المتصفح. إنشاء صفحة بسيطة لاستدعاء تنبؤ النموذج على الخادم وعرض النتيجة. يستخدم هذا Socket.io لاتصال العميل/الخادم.

أولاً، قم بإنشاء index.html في المجلد "baseball/":

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

ثم قم بإنشاء ملف جديد client.js في مجلد baseball/ بالتعليمة البرمجية أدناه:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

يتعامل العميل مع رسالة المقبس trainingComplete لعرض زر التوقّع. عند النقر على هذا الزر، يرسل العميل رسالة مقبس تحتوي على عينة من بيانات أداة الاستشعار. عند تلقّي رسالة predictResult، يتم عرض توقّع البحث في الصفحة.

8. تشغيل التطبيق

عليك تشغيل كلٍ من الخادم والعميل للاطلاع على التطبيق الكامل قيد التشغيل:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

افتح صفحة العميل في متصفّحك ( http://localhost:8080). عند انتهاء تدريب النموذج، انقر على الزر توقُّع العينة. من المفترض أن تظهر لك نتيجة توقّعات في المتصفح. لا تتردد في تعديل بيانات أداة استشعار العينة من خلال بعض الأمثلة من ملف CSV التجريبي ومعرفة مدى دقة تنبؤ النموذج.

9. ما تعلمته

في هذا الدرس التطبيقي حول الترميز، نفّذت تطبيقًا بسيطًا لتعلُّم الآلة باستخدام TensorFlow.js. لقد درّبت نموذجًا مخصّصًا لتصنيف أنواع ملاعب البيسبول من بيانات المستشعر. لقد كتبت رمز Node.js لتنفيذ تدريب على الخادم واستدعِ النموذج على النموذج المدرَّب باستخدام البيانات المرسلة من العميل.

احرص على الانتقال إلى tensorflow.org/js للاطّلاع على المزيد من الأمثلة والعروض التوضيحية التي تتضمّن رموزًا لمعرفة كيفية استخدام TensorFlow.js في تطبيقاتك.