Featured image of post TensorFlow.jsでブラウザで機械学習 Featured image of post TensorFlow.jsでブラウザで機械学習

TensorFlow.jsでブラウザで機械学習

TensorFlow.jsを使ったブラウザ上の機械学習を解説。モデル読み込みと推論、転移学習、WebGL/WebGPUバックエンド、姿勢検出、プライバシー保護MLまでカバー。

ブラウザ上での機械学習は、サーバーコストの削減、ユーザープライバシーの保護、オフライン対応のインテリジェントアプリケーションを可能にします。TensorFlow.jsは、WebGLおよびWebGPUバックエンドによるGPUアクセラレーションを活用し、JavaScript開発者にMLへの扉を開きます。本記事では、事前学習モデルの読み込み、転移学習、リアルタイム姿勢検出、本番運用の考慮点を解説します。

なぜブラウザでMLか?

クライアントサイドMLの利点は4つあります。サーバーコストが不要(ユーザーのデバイスで推論を実行)、完全なプライバシー(データがマシン外に出ない)、オフライン対応(モデル読み込み後はネットワーク不要)、低レイテンシ(予測のラウンドトリップ不要)です。一方で、計算能力とメモリの制約、モバイルでのバッテリー消費、5〜200MBのモデルダウンロードサイズが課題です。


TensorFlow.js概要

TensorFlow.jsは3つのAPIを提供します。Core APIは低レベルのテンソル操作を、Layers APIはKerasに類似した高レベルモデル構築を、ConverterはPythonで訓練されたモデルをTensorFlow.js形式に変換します:

tensorflowjs_converter --input_format=keras \
  path/to/model.h5 path/to/tfjs-model/

GPUアクセラレーションはWebGL(成熟、float32のみ)とWebGPU(高性能、float16対応、Chrome 113+)で利用可能です。WASMバックエンドがCPUフォールバックを提供します。


モデル読み込みと推論

事前学習モデルの読み込みと推論は数行のコードで実現できます:

const model = await tf.loadGraphModel(
  'https://tfhub.dev/google/tfjs-model/mobilenet_v2/1/default/1/model.json'
);
const img = document.getElementById('image');
const tensor = tf.browser.fromPixels(img).expandDims(0);
const predictions = await model.predict(tensor);
const topClass = predictions.argMax(1).dataSync()[0];

メモリ管理にはtf.tidy()を使用し、中間テンソルを自動的に破棄します:

const result = tf.tidy(() => {
  const processed = imgTensor.div(255).expandDims(0);
  return model.predict(processed);
});

転移学習とカスタムモデル

転移学習により、少ないデータセットでカスタム分類器を訓練できます。MobileNetなどの事前学習済み特徴抽出器を読み込み、新しい分類層を追加し、ベース層を凍結します:

const baseModel = await tf.loadGraphModel(MOBILETEN_URL);
const classifier = tf.sequential();
classifier.add(tf.layers.dense({ units: 128, activation: 'relu', inputShape: [1024] }));
classifier.add(tf.layers.dense({ units: NUM_CLASSES, activation: 'softmax' }));

baseModel.trainable = false;
const combined = tf.sequential();
combined.add(baseModel);
combined.add(classifier);
combined.compile({ optimizer: 'adam', loss: 'categoricalCrossentropy' });
await combined.fit(trainingData, trainingLabels, { epochs: 10 });

このアプローチは、クラスあたり100サンプル程度の少量データでも効果的です。


リアルタイム姿勢検出

TensorFlow.jsのモデルゾーンには、姿勢推定や物体検出の事前構築済み検出器が含まれています:

const detector = await poseDetection.createDetector(
  poseDetection.SupportedModels.MoveNet
);
const video = document.getElementById('webcam');
const poses = await detector.estimatePoses(video);

poses.forEach(pose => {
  pose.keypoints.forEach(kp => {
    ctx.beginPath();
    ctx.arc(kp.x, kp.y, 5, 0, 2 * Math.PI);
    ctx.fillStyle = 'red';
    ctx.fill();
  });
});

フレームスキップや入力サイズの縮小により、パフォーマンスを最適化できます。


バックエンド比較

バックエンドGPU必須精度成熟度MobileNet推論速度
WebGLはいfloat32成熟~15ms
WebGPUはいfloat16/32実験的~8ms
WASMいいえfloat32安定~40ms
CPUいいえfloat32安定~120ms

WebGPUは一般的なモデルでWebGLの約2倍の速度を提供しますが、対応ブラウザは現在Chromiumベースに限定されます。


プロダクション考慮点

モデルサイズは5〜200MBになります。量子化とプルーニングでサイズを削減し、IndexedDB(tf.io.IndexedDB)にキャッシュして再ダウンロードを防ぎます。モバイルではバッテリー消費についてユーザーに警告し、tf.tidy()でメモリリークを防止します。ハードウェアアクセラレーションが利用できない場合のグレースフルなフォールバックも実装しましょう。

クライアントサイド推論により、機密データがユーザーのデバイスから出ることはありません。医療診断支援、個人財務分析、機密文書の分類など、プライバシーが重要なユースケースで特に価値を発揮します。

TensorFlow.jsはWeb開発者にMLを身近なものにします。ゼロからの訓練よりも、事前学習モデルと転移学習から始め、特定のモデルに最適なバックエンドをプロファイリングし、tfjs-modelsコレクションの姿勢検出や顔ランドマーク、手の追跡を活用しましょう。