
Machine Learning in the Browser with TensorFlow.js

Machine learning in the browser with TensorFlow.js. Covers model training, transfer learning, WebGL/WebGPU backends, pose detection, and privacy-preserving ML.

Machine learning in the browser eliminates server costs, preserves user privacy, and enables offline-capable intelligent applications. TensorFlow.js brings ML to JavaScript developers with GPU-accelerated inference and training, powered by WebGL and WebGPU backends. This article covers loading pre-trained models, transfer learning, real-time pose detection, and production deployment considerations.

Why ML in the Browser?

Running ML models client-side offers four key advantages: no inference server costs (inference runs on the user’s device, although the model files still need static hosting), stronger privacy (input data does not have to leave the machine), offline capability (no network required after the model has loaded), and low latency (no network round-trip for predictions). The trade-offs include limited compute power, memory constraints, battery drain on mobile devices, and large model downloads (roughly 5-200 MB).


TensorFlow.js Overview

TensorFlow.js provides three main building blocks. The Core API offers low-level tensor operations. The Layers API enables high-level model building similar to Keras. The Converter is a command-line tool that transforms Python-trained Keras or TensorFlow models into the TensorFlow.js format:

tensorflowjs_converter --input_format=keras \
  path/to/model.h5 path/to/tfjs-model/

GPU acceleration is available via WebGL (mature, widely supported, float32 by default) and WebGPU (better performance, float16 support, Chrome 113+). The WASM backend serves as a CPU-based fallback using XNNPACK optimization, and a plain JavaScript CPU backend is the last resort.
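The fallback order described above can be sketched as a small helper. This is a minimal sketch, not an official recipe: `pickBackend` and its default preference list are hypothetical names, `tf` is the TensorFlow.js namespace passed in as an argument, and it assumes the WebGPU and WASM backend packages have already been loaded if they are to be tried.

```javascript
// Try backends in order of preference and keep the first one that initializes.
// `tf` is the TensorFlow.js namespace (assumption: passed in by the caller).
async function pickBackend(tf, preferred = ['webgpu', 'webgl', 'wasm', 'cpu']) {
  for (const name of preferred) {
    try {
      // tf.setBackend resolves to false when the backend is not registered
      // or fails to initialize on this device.
      const ok = await tf.setBackend(name);
      if (ok) {
        await tf.ready(); // wait until the backend is fully initialized
        return name;
      }
    } catch (err) {
      // Treat a thrown error like a failed initialization and try the next one.
    }
  }
  throw new Error('No TensorFlow.js backend could be initialized');
}
```

Calling `await pickBackend(tf)` once at startup, before loading any model, keeps the rest of the application independent of which backend was chosen.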


Model Loading and Inference

Loading a pre-trained model and running inference requires only a few lines of code:

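// Assumes TensorFlow.js is already available as `tf` (script tag or import).
// For TF Hub module URLs, pass { fromTFHub: true } as the second argument.
const model = await tf.loadGraphModel(
  'https://tfhub.dev/google/tfjs-model/mobilenet_v2/1/default/1/model.json'
);
const img = document.getElementById('image');
// fromPixels yields int32 pixels; MobileNet-style models usually expect float [1, 224, 224, 3] in [0, 1].
const classIndex = tf.tidy(() => {
  const input = tf.image.resizeBilinear(tf.browser.fromPixels(img), [224, 224])
    .toFloat().div(255).expandDims(0);
  return model.predict(input).argMax(1);
});
const topClass = (await classIndex.data())[0];
classIndex.dispose();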
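A single top class is often less useful than the few most likely ones. The helper below is a minimal sketch of ranking the scores read back from a model's output tensor; `topK` and its `labels` argument are hypothetical names, and it assumes the scores have already been copied to a plain or typed array (for example with `await outputTensor.data()`).

```javascript
// Return the k highest-scoring classes from an array of scores.
// `labels` is an optional array mapping class index to a readable name.
function topK(scores, k = 3, labels = []) {
  return Array.from(scores, (score, index) => ({
    index,
    score,
    label: labels[index] ?? `class ${index}`,
  }))
    .sort((a, b) => b.score - a.score) // highest score first
    .slice(0, k);
}
```

For example, `topK([0.1, 0.7, 0.2], 2, ['cat', 'dog', 'bird'])` ranks `dog` first and `bird` second.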

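Memory management is critical because tensors backed by GPU memory are not garbage-collected automatically. Use tf.tidy() to dispose intermediate tensors created inside a synchronous callback; the tensor it returns stays alive and must be disposed manually once it is no longer needed: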

const result = tf.tidy(() => {
  const processed = imgTensor.div(255).expandDims(0);
  return model.predict(processed);
});
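One way to check whether a piece of code leaks is to compare the live tensor count before and after it runs. The sketch below relies on `tf.memory().numTensors`; `countLeakedTensors` is a hypothetical helper name, `tf` is passed in as an argument, and the callback is assumed to be synchronous.

```javascript
// Report how many tensors a synchronous function leaves allocated.
// `tf` is the TensorFlow.js namespace (assumption: passed in by the caller).
function countLeakedTensors(tf, fn) {
  const before = tf.memory().numTensors;
  fn();
  return tf.memory().numTensors - before;
}
```

A result of 0 for a function that disposes everything it creates is the goal; a count that grows on every call usually points to a missing `tf.tidy()` or `dispose()`.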

Transfer Learning and Custom Models

Transfer learning enables training custom classifiers with small datasets. The process loads a pre-trained feature extractor (like MobileNet) without its top classification layer, keeps it frozen, and trains new layers on the features it produces:

// Assumes MOBILENET_URL points to a feature-vector (headless) MobileNet that
// outputs 1024 values per image; adjust inputShape if your model differs.
const baseModel = await tf.loadGraphModel(MOBILENET_URL);
const classifier = tf.sequential();
classifier.add(tf.layers.dense({ units: 128, activation: 'relu', inputShape: [1024] }));
classifier.add(tf.layers.dense({ units: NUM_CLASSES, activation: 'softmax' }));
classifier.compile({ optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy'] });

// A GraphModel is inference-only, so the base is frozen by construction:
// run it once to get embeddings, then train only the classifier on them.
// trainingData: preprocessed image batch; trainingLabels: one-hot labels.
const embeddings = tf.tidy(() => baseModel.predict(trainingData));
await classifier.fit(embeddings, trainingLabels, { epochs: 10 });
embeddings.dispose();

This approach often works well for custom image classification tasks, sometimes with as few as 100 samples per class, though the number needed depends on how visually distinct the classes are.
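The categoricalCrossentropy loss used above expects one-hot label vectors rather than class names. A minimal sketch of that encoding in plain JavaScript follows; `oneHotLabels` and the example class names are hypothetical, and the resulting nested array could be turned into a tensor with `tf.tensor2d(...)`.

```javascript
// Convert class names into one-hot rows, one row per sample.
// classNames fixes the column order, e.g. ['cat', 'dog', 'bird'].
function oneHotLabels(labels, classNames) {
  return labels.map((label) => {
    const index = classNames.indexOf(label);
    if (index === -1) throw new Error(`Unknown class: ${label}`);
    const row = new Array(classNames.length).fill(0);
    row[index] = 1;
    return row;
  });
}
```

For example, `oneHotLabels(['dog', 'cat'], ['cat', 'dog', 'bird'])` yields `[[0, 1, 0], [1, 0, 0]]`.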


Real-Time Pose Detection

The tfjs-models collection provides pre-built detectors for pose estimation, object detection, and face landmarks. Pose estimation ships as the @tensorflow-models/pose-detection package:

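// Assumes `poseDetection` comes from @tensorflow-models/pose-detection and that a
// TensorFlow.js backend is loaded; 'overlay' is a hypothetical canvas element id.
const detector = await poseDetection.createDetector(
  poseDetection.SupportedModels.MoveNet
);
const video = document.getElementById('webcam');
const ctx = document.getElementById('overlay').getContext('2d');
const poses = await detector.estimatePoses(video);

// Draw keypoints on canvas overlay, skipping low-confidence ones
ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);
poses.forEach(pose => {
  pose.keypoints.forEach(kp => {
    if ((kp.score ?? 1) < 0.3) return; // 0.3 is an illustrative threshold
    ctx.beginPath();
    ctx.arc(kp.x, kp.y, 5, 0, 2 * Math.PI);
    ctx.fillStyle = 'red';
    ctx.fill();
  });
});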

Performance optimization techniques include skipping frames, resizing input to smaller dimensions, and selecting lighter model variants.
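Frame skipping, the first of these techniques, can be sketched as follows. This is only an illustration: `createFrameSkipper` and `startPoseLoop` are hypothetical helpers, the default of running on every third frame is an arbitrary choice, and the loop assumes a browser environment with `requestAnimationFrame` plus a detector like the one created above.

```javascript
// Returns a function that answers true on every `every`-th call (starting with the first).
function createFrameSkipper(every = 3) {
  let frame = 0;
  return () => frame++ % every === 0;
}

// Browser-only loop: estimates poses on a subset of frames and never
// starts a new estimate while the previous one is still running.
function startPoseLoop(detector, video, onPoses, every = 3) {
  const shouldRun = createFrameSkipper(every);
  let busy = false;
  async function tick() {
    if (!busy && shouldRun()) {
      busy = true;
      try {
        onPoses(await detector.estimatePoses(video));
      } catch (err) {
        console.error('Pose estimation failed:', err);
      } finally {
        busy = false;
      }
    }
    requestAnimationFrame(tick);
  }
  requestAnimationFrame(tick);
}
```

The `busy` flag matters on slow devices: without it, estimates can pile up faster than they complete.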


Backend Performance Comparison

Backend   GPU Required   Precision    Maturity       MobileNet Speed
WebGL     Yes            float32      Mature         ~15 ms
WebGPU    Yes            float16/32   Experimental   ~8 ms
WASM      No             float32      Stable         ~40 ms
CPU       No             float32      Stable         ~120 ms

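WebGPU offers approximately 2x speedup over WebGL for common models, but browser support has been largely limited to Chromium-based browsers, so check current compatibility before relying on it. The timings above are indicative and vary considerably with the device and model.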
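Because such numbers depend heavily on hardware, it is worth measuring on the devices you target. The helper below is a rough sketch of a timing harness; `benchmark` and its option names are hypothetical, and `runOnce` is assumed to be a function you supply that performs one full inference and awaits the output values (for example via `await output.data()`) so that GPU work has actually finished.

```javascript
// Time an async operation: a few untimed warm-up runs, then the median of timed runs.
// Warm-up matters because the first GPU runs include one-off setup such as shader compilation.
async function benchmark(runOnce, { warmup = 3, runs = 10 } = {}) {
  if (runs < 1) throw new Error('runs must be at least 1');
  for (let i = 0; i < warmup; i++) await runOnce();
  const times = [];
  for (let i = 0; i < runs; i++) {
    const start = performance.now();
    await runOnce();
    times.push(performance.now() - start);
  }
  times.sort((a, b) => a - b);
  const mid = Math.floor(times.length / 2);
  // Median in milliseconds; less sensitive to outliers than the mean.
  return times.length % 2 ? times[mid] : (times[mid - 1] + times[mid]) / 2;
}
```

Running the same `runOnce` after switching backends gives a like-for-like comparison for your specific model.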


Production Considerations

Models can range from roughly 5 to 200 MB. Apply quantization and pruning to reduce size. Cache models in IndexedDB by saving them under an indexeddb:// URL (for example model.save('indexeddb://my-model')) to avoid re-downloading. Show loading progress with the onProgress option of tf.loadGraphModel() or tf.loadLayersModel(). On mobile, warn users about battery impact. Use tf.tidy() diligently to prevent memory leaks. Always provide a graceful fallback when hardware acceleration is unavailable.
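The caching and progress advice can be combined in one loader. This is a minimal sketch under stated assumptions: `loadCachedModel` is a hypothetical helper, `tf` is passed in as an argument, the model is a Layers model (a graph model would use `tf.loadGraphModel` in the same way), and the cache key such as `'indexeddb://my-model'` is an illustrative name.

```javascript
// Load a model from IndexedDB if present; otherwise download it, report progress, and cache it.
async function loadCachedModel(tf, remoteUrl, cacheKey, onProgress = () => {}) {
  try {
    // Rejects if nothing has been saved under this key yet.
    return await tf.loadLayersModel(cacheKey);
  } catch (err) {
    // onProgress receives a fraction between 0 and 1 while weights download.
    const model = await tf.loadLayersModel(remoteUrl, { onProgress });
    await model.save(cacheKey);
    return model;
  }
}
```

Note that this sketch never invalidates the cached copy; including a version in the cache key is one simple way to roll out updated models.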


Privacy-Preserving ML

Client-side inference ensures sensitive data never leaves the user’s device. This is particularly valuable for medical diagnosis support, personal finance analysis, and document classification with confidential content. Federated learning and differential privacy concepts can be applied in browser-based training scenarios.
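To make the federated learning idea concrete, the core aggregation step (federated averaging) can be sketched in a few lines. This is an illustration only, not a TensorFlow.js API: `federatedAverage` and the shape of the `updates` objects are hypothetical, each client's weights are assumed to be flattened into one array of equal length, and a real system would add secure aggregation and, for differential privacy, clipping and calibrated noise.

```javascript
// Combine client weight vectors into one, weighting each client by its sample count.
// updates: [{ weights: number[] | Float32Array, numSamples: number }, ...]
function federatedAverage(updates) {
  const total = updates.reduce((sum, u) => sum + u.numSamples, 0);
  if (total <= 0) throw new Error('Need at least one training sample');
  const averaged = new Float32Array(updates[0].weights.length);
  for (const { weights, numSamples } of updates) {
    if (weights.length !== averaged.length) {
      throw new Error('All clients must send weight vectors of the same length');
    }
    const share = numSamples / total; // this client's contribution
    for (let i = 0; i < weights.length; i++) averaged[i] += weights[i] * share;
  }
  return averaged;
}
```

Only the weight updates are shared in this scheme; the raw training data stays on each device, which is what makes the approach attractive for the use cases above.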

TensorFlow.js makes ML accessible to web developers. Start with pre-trained models and transfer learning before attempting training from scratch. Profile the WebGL versus WebGPU backend for your specific model, and explore the tfjs-models collection for pose detection, face landmarks, hand tracking, and text toxicity classification.