ブラウザ上での機械学習は、サーバーコストの削減、ユーザープライバシーの保護、オフライン対応のインテリジェントアプリケーションを可能にします。TensorFlow.jsは、WebGLおよびWebGPUバックエンドによるGPUアクセラレーションを活用し、JavaScript開発者にMLへの扉を開きます。本記事では、事前学習モデルの読み込み、転移学習、リアルタイム姿勢検出、本番運用の考慮点を解説します。
なぜブラウザでMLか?
クライアントサイドMLの利点は4つあります。サーバーコストが不要(ユーザーのデバイスで推論を実行)、完全なプライバシー(データがマシン外に出ない)、オフライン対応(モデル読み込み後はネットワーク不要)、低レイテンシ(予測のラウンドトリップ不要)です。一方で、計算能力とメモリの制約、モバイルでのバッテリー消費、5〜200MBのモデルダウンロードサイズが課題です。
TensorFlow.js概要
TensorFlow.jsは3つのAPIを提供します。Core APIは低レベルのテンソル操作を、Layers APIはKerasに類似した高レベルモデル構築を、ConverterはPythonで訓練されたモデルをTensorFlow.js形式に変換します:
tensorflowjs_converter --input_format=keras \
path/to/model.h5 path/to/tfjs-model/
GPUアクセラレーションはWebGL(成熟、float32のみ)とWebGPU(高性能、float16対応、Chrome 113+)で利用可能です。WASMバックエンドがCPUフォールバックを提供します。
モデル読み込みと推論
事前学習モデルの読み込みと推論は数行のコードで実現できます:
const model = await tf.loadGraphModel(
'https://tfhub.dev/google/tfjs-model/mobilenet_v2/1/default/1/model.json'
);
const img = document.getElementById('image');
const tensor = tf.browser.fromPixels(img).expandDims(0);
const predictions = await model.predict(tensor);
const topClass = predictions.argMax(1).dataSync()[0];
メモリ管理にはtf.tidy()を使用し、中間テンソルを自動的に破棄します:
const result = tf.tidy(() => {
const processed = imgTensor.div(255).expandDims(0);
return model.predict(processed);
});
転移学習とカスタムモデル
転移学習により、少ないデータセットでカスタム分類器を訓練できます。MobileNetなどの事前学習済み特徴抽出器を読み込み、新しい分類層を追加し、ベース層を凍結します:
const baseModel = await tf.loadGraphModel(MOBILETEN_URL);
const classifier = tf.sequential();
classifier.add(tf.layers.dense({ units: 128, activation: 'relu', inputShape: [1024] }));
classifier.add(tf.layers.dense({ units: NUM_CLASSES, activation: 'softmax' }));
baseModel.trainable = false;
const combined = tf.sequential();
combined.add(baseModel);
combined.add(classifier);
combined.compile({ optimizer: 'adam', loss: 'categoricalCrossentropy' });
await combined.fit(trainingData, trainingLabels, { epochs: 10 });
このアプローチは、クラスあたり100サンプル程度の少量データでも効果的です。
リアルタイム姿勢検出
TensorFlow.jsのモデルゾーンには、姿勢推定や物体検出の事前構築済み検出器が含まれています:
const detector = await poseDetection.createDetector(
poseDetection.SupportedModels.MoveNet
);
const video = document.getElementById('webcam');
const poses = await detector.estimatePoses(video);
poses.forEach(pose => {
pose.keypoints.forEach(kp => {
ctx.beginPath();
ctx.arc(kp.x, kp.y, 5, 0, 2 * Math.PI);
ctx.fillStyle = 'red';
ctx.fill();
});
});
フレームスキップや入力サイズの縮小により、パフォーマンスを最適化できます。
バックエンド比較
| バックエンド | GPU必須 | 精度 | 成熟度 | MobileNet推論速度 |
|---|---|---|---|---|
| WebGL | はい | float32 | 成熟 | ~15ms |
| WebGPU | はい | float16/32 | 実験的 | ~8ms |
| WASM | いいえ | float32 | 安定 | ~40ms |
| CPU | いいえ | float32 | 安定 | ~120ms |
WebGPUは一般的なモデルでWebGLの約2倍の速度を提供しますが、対応ブラウザは現在Chromiumベースに限定されます。
プロダクション考慮点
モデルサイズは5〜200MBになります。量子化とプルーニングでサイズを削減し、IndexedDB(tf.io.IndexedDB)にキャッシュして再ダウンロードを防ぎます。モバイルではバッテリー消費についてユーザーに警告し、tf.tidy()でメモリリークを防止します。ハードウェアアクセラレーションが利用できない場合のグレースフルなフォールバックも実装しましょう。
クライアントサイド推論により、機密データがユーザーのデバイスから出ることはありません。医療診断支援、個人財務分析、機密文書の分類など、プライバシーが重要なユースケースで特に価値を発揮します。
TensorFlow.jsはWeb開発者にMLを身近なものにします。ゼロからの訓練よりも、事前学習モデルと転移学習から始め、特定のモデルに最適なバックエンドをプロファイリングし、tfjs-modelsコレクションの姿勢検出や顔ランドマーク、手の追跡を活用しましょう。
