Lesson 41 of 48
Advanced api

Wasm + Machine Learning

Why Run ML in the Browser with Wasm?

Machine learning inference in the browser has unique advantages:

  Traditional ML Pipeline           Wasm ML Pipeline

  ┌────────┐     ┌────────┐        ┌─────────────────────┐
  │ Client │────>│ Server │        │ Browser             │
  │        │     │ (GPU)  │        │                     │
  │ Send   │     │ Run    │        │ ┌────────────────┐  │
  │ data   │     │ model  │        │ │ Wasm ML Runtime│  │
  │ over   │     │ return │        │ │                │  │
  │ network│<────│ result │        │ │ Load model     │  │
  └────────┘     └────────┘        │ │ Run inference  │  │
                                   │ │ Return result  │  │
  ● Network latency                │ └────────────────┘  │
  ● Privacy: data leaves device    │                     │
  ● Server cost (GPU $$$)          │ ● No network RTT    │
                                   │ ● Data stays local  │
                                   │ ● No server cost    │
                                   └─────────────────────┘
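
  ┌──────────────────┬─────────────────────┬──────────────────────┬──────────────────────┐
  │ Factor           │ Server ML           │ Wasm ML (Browser)    │ WebGPU ML            │
  ├──────────────────┼─────────────────────┼──────────────────────┼──────────────────────┤
  │ Latency          │ Network RTT         │ No network RTT       │ No network RTT       │
  │ Privacy          │ Data sent to server │ Data stays on device │ Data stays on device │
  │ Hardware         │ Server GPU          │ CPU (SIMD, threads)  │ Client GPU           │
  │ Cost             │ Per-request         │ Free (client CPU)    │ Free (client GPU)    │
  │ Model size limit │ Unlimited           │ ~100 MB practical    │ ~1 GB practical      │
  │ Accuracy         │ Full precision      │ Full precision       │ May use f16          │
  │ Offline capable  │ No                  │ Yes (Service Worker) │ Yes (Service Worker) │
  │ Browser support  │ N/A (server)        │ Universal            │ Recent browsers only │
  └──────────────────┴─────────────────────┴──────────────────────┴──────────────────────┘

Multi-threaded Wasm relies on SharedArrayBuffer, which browsers only enable on cross-origin-isolated pages (served with COOP/COEP headers).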

ML Runtimes for Wasm

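tract (ONNX inference in Rust)

tract, developed by Sonos, is a pure-Rust inference engine for ONNX and NNEF models (with partial TensorFlow support). It is one of the most mature Rust ML runtimes that compile to Wasm: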

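use std::io::Cursor;
use tract_onnx::prelude::*;

/// Load an ONNX model and classify one image.
/// Wasm has no filesystem, so the model arrives as bytes (fetched by
/// JavaScript) instead of a path. `pixels` is an already-normalized
/// 1x3x224x224 image flattened in NCHW order.
/// (In a real app, load the model once and reuse it across calls.)
fn classify(model_bytes: &[u8], pixels: &[f32]) -> TractResult<Vec<f32>> {
    let model = tract_onnx::onnx()
        .model_for_read(&mut Cursor::new(model_bytes))?
        // Fix the input shape so tract can optimize the graph
        .with_input_fact(0, f32::fact([1, 3, 224, 224]).into())?
        .into_optimized()?
        .into_runnable()?;

    // Build the input tensor from the flat NCHW buffer
    let input: Tensor = tract_ndarray::Array4::from_shape_fn(
        (1, 3, 224, 224),
        |(_, c, y, x)| pixels[(c * 224 + y) * 224 + x],
    ).into();

    // Run inference; the first output holds the class scores
    let result = model.run(tvec!(input.into()))?;
    let output = result[0].to_array_view::<f32>()?;
    Ok(output.iter().copied().collect())
}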

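candle (Hugging Face)

candle is Hugging Face's minimalist ML framework for Rust. It loads safetensors weights and compiles to Wasm, where it runs on its CPU backend:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu; // Wasm builds use the CPU backend
    let weights = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], (3,), &device)?;
    let input = Tensor::from_vec(vec![0.5f32, 0.3, 0.2], (3,), &device)?;
    // Element-wise product (not a matrix multiply): roughly [0.5, 0.6, 0.6]
    let output = (weights * input)?;
    println!("{output}");
    Ok(())
}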

Comparison of Wasm ML runtimes

  ┌─────────────┬──────────────┬───────────────┬────────────────┐
  │ Runtime     │ Model Format │ Wasm Support  │ Best For       │
  ├─────────────┼──────────────┼───────────────┼────────────────┤
  │ tract       │ ONNX, TF     │ Excellent     │ Classic models │
  │ candle      │ safetensors  │ Good          │ LLMs, HF models│
  │ burn        │ burn format  │ Experimental  │ Custom training│
  │ ort (onnx)  │ ONNX         │ Via wasm-pack │ ONNX ecosystem │
  └─────────────┴──────────────┴───────────────┴────────────────┘

Tensor Operations in Detail

The core of any ML runtime is tensor math. Here are the key operations:

Matrix Multiplication (GEMM)

The most performance-critical operation in neural networks:

  Matrix A (2x3)     Matrix B (3x2)     Result (2x2)

  ┌─────────────┐   ┌─────────┐       ┌───────────┐
  │ 1   2   3   │   │ 7   8   │       │ 58    64  │
  │             │ × │ 9  10   │  =    │           │
  │ 4   5   6   │   │11  12   │       │139   154  │
  └─────────────┘   └─────────┘       └───────────┘

  result[0][0] = 1*7 + 2*9  + 3*11  = 58
  result[0][1] = 1*8 + 2*10 + 3*12  = 64
  result[1][0] = 4*7 + 5*9  + 6*11  = 139
  result[1][1] = 4*8 + 5*10 + 6*12  = 154
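The four sums above are all a GEMM kernel computes, repeated at scale. As a minimal sketch (not tract's or candle's actual kernel), here is a naive matrix multiply in Rust over row-major `f32` slices; the name `matmul` and the row-major layout are our own choices. Production runtimes layer cache blocking and SIMD on top of the same loop nest.

```rust
/// Naive GEMM: C = A × B for row-major matrices.
/// `a` is m×k, `b` is k×n, and the returned `c` is m×n.
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k, "A must be m×k");
    assert_eq!(b.len(), k * n, "B must be k×n");
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for p in 0..k {
            // i-p-j order: the inner loop walks rows of B and C contiguously
            let a_ip = a[i * k + p];
            for j in 0..n {
                c[i * n + j] += a_ip * b[p * n + j];
            }
        }
    }
    c
}

fn main() {
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2×3, row-major
    let b = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]; // 3×2, row-major
    let c = matmul(&a, &b, 2, 3, 2);
    println!("{:?}", c); // [58.0, 64.0, 139.0, 154.0]
}
```

The i-p-j loop order keeps the innermost loop on contiguous memory, which is friendlier to the cache than the textbook i-j-p order while producing the same result.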

Activation Functions

  [Figure: plots of the ReLU, Sigmoid and Softmax activation functions]
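
  ┌────────────┬───────────────────────────┬──────────┬─────────────────────────┐
  │ Activation │ Formula                   │ Range    │ Used For                │
  ├────────────┼───────────────────────────┼──────────┼─────────────────────────┤
  │ ReLU       │ max(0, x)                 │ [0, inf) │ Hidden layers (default) │
  │ Sigmoid    │ 1 / (1 + e^-x)            │ (0, 1)   │ Binary classification   │
  │ Softmax    │ e^xi / sum(e^xj)          │ (0, 1)   │ Multi-class output      │
  │ Tanh       │ (e^x - e^-x)/(e^x + e^-x) │ (-1, 1)  │ RNNs, hidden layers     │
  └────────────┴───────────────────────────┴──────────┴─────────────────────────┘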
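The formulas translate directly into code. A minimal sketch in Rust (the function names are ours, not from a specific runtime); Tanh is left out because Rust's standard library already provides `f32::tanh`. The softmax uses the usual max-subtraction trick so that large inputs do not overflow; its outputs still sum to 1.

```rust
/// ReLU: max(0, x)
fn relu(x: f32) -> f32 {
    x.max(0.0)
}

/// Sigmoid: 1 / (1 + e^-x), output in (0, 1)
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Softmax: e^xi / sum(e^xj). Subtracting the max first prevents overflow
/// without changing the result (the common factor e^-max cancels out).
fn softmax(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    println!("relu(-1.5) = {}", relu(-1.5));
    println!("sigmoid(0) = {}", sigmoid(0.0)); // 0.5
    println!("tanh(0.5)  = {}", 0.5f32.tanh()); // provided by std
    println!("softmax    = {:?}", softmax(&[2.0, 1.0, 0.1]));
}
```

Without the max subtraction, `softmax(&[1000.0, 1000.0])` would compute `e^1000`, which overflows `f32` to infinity and yields NaN; with it, the result is `[0.5, 0.5]`.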

Loading Models in the Browser

  ┌─────────────────────────────────────────────────────┐
  │  Model Loading Pipeline                             │
  │                                                     │
  │  1. fetch("model.onnx")                             │
  │     │                                               │
  │     ▼                                               │
  │  2. ArrayBuffer (raw bytes)                         │
  │     │                                               │
  │     ▼                                               │
  │  3. Pass to Wasm: load_model(&bytes)                │
  │     │                                               │
  │     ├── Parse model graph (protobuf/flatbuf)        │
  │     ├── Allocate weight tensors in Wasm memory      │
  │     ├── Optimize graph (fuse ops, constant fold)    │
  │     └── Return opaque handle                        │
  │     │                                               │
  │     ▼                                               │
  │  4. Ready for inference: predict(handle, input)     │
  └─────────────────────────────────────────────────────┘
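The last step hands JavaScript an opaque handle instead of a raw pointer into Wasm memory. The sketch below shows that pattern using only the standard library. Everything in it is hypothetical and for illustration: the `Registry` type, the toy model (a single dense layer rather than a parsed ONNX graph), and its byte format (two little-endian `u32` sizes, then the `f32` weights and biases). In a real crate, `load_model` and `predict` would be exported to JavaScript (for example with `#[wasm_bindgen]`) and the registry kept in a global.

```rust
use std::collections::HashMap;

/// Hypothetical toy model: a single dense layer, y = W·x + b.
struct DenseModel {
    n_in: usize,
    n_out: usize,
    weights: Vec<f32>, // n_out × n_in, row-major
    bias: Vec<f32>,    // length n_out
}

/// Maps opaque integer handles (all JavaScript ever sees) to loaded models.
#[derive(Default)]
struct Registry {
    next_handle: u32,
    models: HashMap<u32, DenseModel>,
}

/// Read a little-endian u32 at byte offset `at`, or None if out of bounds.
fn read_u32(bytes: &[u8], at: usize) -> Option<u32> {
    let b = bytes.get(at..at.checked_add(4)?)?;
    Some(u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
}

/// Read `count` little-endian f32 values starting at byte offset `at`.
fn read_f32s(bytes: &[u8], at: usize, count: usize) -> Option<Vec<f32>> {
    let end = at.checked_add(count.checked_mul(4)?)?;
    let raw = bytes.get(at..end)?;
    Some(
        raw.chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
    )
}

impl Registry {
    /// Parse the (hypothetical) byte format, allocate the weight tensors,
    /// and return an opaque handle. Returns None on malformed input.
    fn load_model(&mut self, bytes: &[u8]) -> Option<u32> {
        let n_in = read_u32(bytes, 0)? as usize;
        let n_out = read_u32(bytes, 4)? as usize;
        let n_weights = n_in.checked_mul(n_out)?;
        let weights = read_f32s(bytes, 8, n_weights)?;
        // Cannot overflow: read_f32s above proved 8 + n_weights * 4 fits
        let bias = read_f32s(bytes, 8 + n_weights * 4, n_out)?;
        let handle = self.next_handle;
        self.next_handle += 1;
        self.models.insert(handle, DenseModel { n_in, n_out, weights, bias });
        Some(handle)
    }

    /// Run inference for a handle previously returned by `load_model`.
    fn predict(&self, handle: u32, input: &[f32]) -> Option<Vec<f32>> {
        let m = self.models.get(&handle)?;
        if input.len() != m.n_in {
            return None; // shape mismatch
        }
        let out: Vec<f32> = (0..m.n_out)
            .map(|o| {
                let row = &m.weights[o * m.n_in..(o + 1) * m.n_in];
                row.iter().zip(input).map(|(w, x)| w * x).sum::<f32>() + m.bias[o]
            })
            .collect();
        Some(out)
    }
}

fn main() {
    // Encode a 2-input, 1-output model: W = [0.5, -1.0], b = [0.25]
    let mut bytes = Vec::new();
    bytes.extend_from_slice(&2u32.to_le_bytes());
    bytes.extend_from_slice(&1u32.to_le_bytes());
    for v in [0.5f32, -1.0, 0.25] {
        bytes.extend_from_slice(&v.to_le_bytes());
    }
    let mut registry = Registry::default();
    let handle = registry.load_model(&bytes).expect("well-formed model");
    println!("{:?}", registry.predict(handle, &[4.0, 1.0])); // Some([1.25])
}
```

Returning `None` for unknown handles, truncated bytes, or mismatched input lengths keeps bad calls from JavaScript from panicking inside the Wasm module.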
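
// JavaScript side (assuming an ES module built with `wasm-pack build --target web`)
import init, { WasmModel } from './pkg/ml_wasm.js';

// Index of the highest score in the model output
function argmax(values) {
    let best = 0;
    for (let i = 1; i < values.length; i++) {
        if (values[i] > values[best]) best = i;
    }
    return best;
}

async function loadAndPredict(imageElement) {
    await init(); // instantiate the .wasm module

    // Fetch model bytes
    const response = await fetch('/models/mobilenet_v2.onnx');
    if (!response.ok) throw new Error(`Model fetch failed: HTTP ${response.status}`);
    const modelBytes = new Uint8Array(await response.arrayBuffer());

    // Load into Wasm (WasmModel is the class exported by our Rust crate)
    const model = WasmModel.load(modelBytes);
    try {
        // Prepare input: 224x224 RGB image as a flat Float32Array.
        // preprocessImage and IMAGENET_CLASSES are app-provided helpers.
        const input = preprocessImage(imageElement);

        // Run inference (a Vec<f32> returned from Rust arrives as a Float32Array)
        const output = model.predict(input);
        const topClass = argmax(output);
        console.log(`Prediction: ${IMAGENET_CLASSES[topClass]}`);
    } finally {
        model.free(); // release the Rust-side memory behind the handle
    }
}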
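The JavaScript above leaves `preprocessImage` to the application. One common approach, sketched here in Rust so it could run inside the Wasm module, converts the RGBA bytes that a canvas returns from `getImageData()` into the planar NCHW `f32` layout the model expects. Assumptions: the image has already been resized to the model's input size (224x224 for MobileNet V2), and the mean/std constants are the values commonly used for ImageNet-trained models; check what your model was trained with. The function name `rgba_to_nchw` is hypothetical.

```rust
/// Per-channel normalization constants commonly used for ImageNet-trained
/// models (assumption: confirm against your model's training setup).
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Convert interleaved RGBA bytes (row-major, as returned by a canvas)
/// into a normalized, planar NCHW buffer of length 3 * height * width.
fn rgba_to_nchw(rgba: &[u8], width: usize, height: usize) -> Vec<f32> {
    assert_eq!(rgba.len(), width * height * 4, "expected RGBA pixels");
    let plane = width * height; // values per channel plane
    let mut out = vec![0.0f32; 3 * plane];
    for (i, px) in rgba.chunks_exact(4).enumerate() {
        // i == y * width + x: the pixel's offset inside each channel plane
        for c in 0..3 {
            let v = px[c] as f32 / 255.0; // scale to [0, 1]; alpha (px[3]) is dropped
            out[c * plane + i] = (v - MEAN[c]) / STD[c];
        }
    }
    out
}

fn main() {
    // A 2x1 image: one pure-red pixel followed by one pure-blue pixel
    let rgba: [u8; 8] = [255, 0, 0, 255, 0, 0, 255, 255];
    let tensor = rgba_to_nchw(&rgba, 2, 1);
    // Layout: [R plane (2 values), G plane (2 values), B plane (2 values)]
    println!("{tensor:?}");
}
```

Doing this conversion in Wasm avoids a per-pixel JavaScript loop; the Uint8ClampedArray from `getImageData()` can be passed across as a byte slice.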

Performance: Wasm vs TensorFlow.js

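Approximate benchmarks on a typical laptop, one input per inference (the chart shows MobileNet V2). Results vary with hardware, browser, and whether Wasm SIMD and threads are enabled: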

  Inference Time (ms) — lower is better

  TF.js (WebGL)  ████████████  12ms
  Wasm (tract)   ██████████████████  18ms
  TF.js (Wasm)   ████████████████████  20ms
  TF.js (CPU/JS) █████████████████████████████████████████  42ms

  ← faster                                     slower →
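
  ┌──────────────────────┬──────────────┬──────────┬─────────────┐
  │ Runtime              │ MobileNet V2 │ ResNet50 │ BERT (base) │
  ├──────────────────────┼──────────────┼──────────┼─────────────┤
  │ TF.js WebGL          │ ~12ms        │ ~45ms    │ ~120ms      │
  │ Wasm (tract)         │ ~18ms        │ ~80ms    │ ~300ms      │
  │ TF.js Wasm backend   │ ~20ms        │ ~90ms    │ ~350ms      │
  │ TF.js CPU (pure JS)  │ ~42ms        │ ~200ms   │ ~800ms      │
  │ Native (PyTorch CPU) │ ~8ms         │ ~30ms    │ ~80ms       │
  └──────────────────────┴──────────────┴──────────┴─────────────┘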

Key takeaways:

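  • Wasm is about 2-3x faster than pure JavaScript here (18 vs 42 ms on MobileNet V2, 300 vs 800 ms on BERT)
  • GPU backends (WebGL, WebGPU) beat Wasm, and the gap grows with model size: WebGL is 1.5x faster on MobileNet V2 and 2.5x faster on BERT
  • Wasm can still win for very small models (much smaller than MobileNet V2), where GPU upload and dispatch overhead outweighs the compute itself
  • Wasm is more predictable (no GPU driver variability)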
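Figures like these depend heavily on how they are measured. A minimal timing-loop sketch: run a few untimed warm-up inferences so one-time costs (lazy initialization, memory growth) are excluded, then report the median of several timed runs, which resists outliers better than the mean. The clock is passed in as a closure because `std::time::Instant::now()` is not supported on the `wasm32-unknown-unknown` target; in the browser you would pass a closure that calls `performance.now()` instead. The function name `median_latency_ms` is hypothetical.

```rust
use std::time::Instant;

/// Median time in ms of `runs` calls to `infer`, measured after `warmup`
/// untimed calls. `now_ms` must return a timestamp in milliseconds.
fn median_latency_ms(
    mut infer: impl FnMut(),
    mut now_ms: impl FnMut() -> f64,
    warmup: usize,
    runs: usize,
) -> f64 {
    assert!(runs > 0, "need at least one timed run");
    for _ in 0..warmup {
        infer(); // absorb one-time costs before measuring
    }
    let mut samples: Vec<f64> = (0..runs)
        .map(|_| {
            let start = now_ms();
            infer();
            now_ms() - start
        })
        .collect();
    samples.sort_by(|a, b| a.total_cmp(b));
    let mid = runs / 2;
    if runs % 2 == 1 {
        samples[mid]
    } else {
        (samples[mid - 1] + samples[mid]) / 2.0
    }
}

fn main() {
    // Native build: measure with Instant. In the browser, pass a closure
    // that calls performance.now() (e.g. via the web-sys crate) instead.
    let t0 = Instant::now();
    let clock = move || t0.elapsed().as_secs_f64() * 1000.0;

    // Stand-in workload; replace with a real model.predict(...) call
    let mut acc = 0.0f64;
    let ms = median_latency_ms(
        || {
            for i in 0..100_000 {
                acc += (i as f64).sqrt();
            }
        },
        clock,
        3,
        11,
    );
    println!("median: {ms:.3} ms (checksum {acc:.1})");
}
```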

When to Use Wasm ML vs WebGPU

  Decision Matrix:

                     Small model           Large model
                     (< 10M params)        (> 100M params)
  ┌─────────────────┬────────────────────┬────────────────────┐
  │ Need offline    │  Wasm ✓            │  Wasm (if fits)    │
  │ support?        │                    │  or WebGPU         │
  ├─────────────────┼────────────────────┼────────────────────┤
  │ Consistent      │  Wasm ✓            │  Wasm              │
  │ performance?    │  (no GPU variance) │  (slower, steady)  │
  ├─────────────────┼────────────────────┼────────────────────┤
  │ Maximum speed?  │  Wasm ✓            │  WebGPU ✓          │
  │                 │  (GPU overhead     │  (GPU parallelism  │
  │                 │   not worth it)    │   dominates)       │
  ├─────────────────┼────────────────────┼────────────────────┤
  │ Browser compat? │  Wasm ✓            │  Wasm ✓            │
  │                 │  (universal)       │  WebGPU still new  │
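  └─────────────────┴────────────────────┴────────────────────┘

Models between 10M and 100M parameters fall between the two columns and have no default answer. Also note that in the benchmarks above WebGL was already faster than Wasm on MobileNet V2, which has only about 3.5M parameters, so "Maximum speed: Wasm" really applies to very small models; measure on your target devices.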

Practical Use Cases

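  ┌──────────────────────────┬─────────────────────┬─────────────────────────────────┐
  │ Application              │ Model Type          │ Why Wasm?                       │
  ├──────────────────────────┼─────────────────────┼─────────────────────────────────┤
  │ Spam detection           │ Logistic regression │ Tiny model, instant prediction  │
  │ Image classification     │ MobileNet           │ Runs on any device, offline     │
  │ Text autocomplete        │ Small RNN           │ Low latency, privacy            │
  │ Pose estimation (camera) │ PoseNet             │ Real-time, no server round-trip │
  │ Document OCR             │ CRNN                │ Sensitive documents stay local  │
  │ Audio keyword detection  │ Small CNN           │ Always-on, low power            │
  │ Anomaly detection (IoT)  │ Autoencoder         │ Edge device, no connectivity    │
  └──────────────────────────┴─────────────────────┴─────────────────────────────────┘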
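To make the first row concrete: logistic-regression inference is a dot product, a bias, and a sigmoid, so it costs only a handful of multiply-adds. A minimal sketch, assuming a hypothetical four-word vocabulary, made-up weights, and naive whitespace tokenization; a real filter would load weights trained offline and use a proper tokenizer.

```rust
/// Hypothetical vocabulary and weights for a toy spam filter.
const VOCAB: [&str; 4] = ["free", "winner", "meeting", "invoice"];
const WEIGHTS: [f32; 4] = [1.5, 2.0, -1.0, -0.5];
const BIAS: f32 = -1.0;

/// Bag-of-words features: count each vocabulary word (case-insensitive,
/// naive whitespace split, punctuation not stripped).
fn featurize(text: &str) -> [f32; 4] {
    let mut features = [0.0f32; 4];
    for word in text.split_whitespace() {
        let word = word.to_lowercase();
        if let Some(i) = VOCAB.iter().position(|v| *v == word) {
            features[i] += 1.0;
        }
    }
    features
}

/// Logistic regression: P(spam) = sigmoid(w · x + b).
fn spam_probability(text: &str) -> f32 {
    let x = featurize(text);
    let z: f32 = WEIGHTS.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum::<f32>() + BIAS;
    1.0 / (1.0 + (-z).exp())
}

fn main() {
    for msg in ["FREE winner free", "meeting moved, invoice attached"] {
        println!("{:.3}  {msg}", spam_probability(msg));
    }
}
```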

Summary

Wasm brings ML inference to the browser with full precision, offline capability, and no server-side inference cost. Use tract for ONNX models, candle for Hugging Face models, or build custom tensor operations in Rust. In the benchmarks above, Wasm ML runs 2-3x faster than pure JavaScript; it is a good fit for small models where GPU overhead is not worth it, and whenever you need predictable performance or universal browser support. For large models (LLMs, diffusion), combine Wasm with WebGPU for GPU acceleration.
