
@jax-js/jax

v0.1.9


Numerical computing and ML in the browser


jax-js is a machine learning framework for the browser. It aims to bring JAX-style, high-performance CPU and GPU kernels to JavaScript, so you can run numerical applications on the web.

npm i @jax-js/jax

Under the hood, it translates array operations into a compiler representation, then synthesizes kernels in WebAssembly and WebGPU.

The library is written from scratch, with zero external dependencies, and maintains close API compatibility with NumPy/JAX. Since everything runs client-side, jax-js is arguably the most portable GPU ML framework: it runs anywhere a browser can run.

Quickstart

import { numpy as np } from "@jax-js/jax";

// Array operations, compatible with JAX/NumPy.
const x = np.array([1, 2, 3]);
const y = x.mul(4); // [4, 8, 12]

Web usage (CDN)

In vanilla JavaScript (without a bundler), just import from a module script tag. This is the easiest way to get started on a blank HTML page.

<script type="module">
  import { numpy as np } from "https://esm.sh/@jax-js/jax";
</script>

Examples

Cool things that the community has made with jax-js, along with more demos, are showcased on the official website.

Feature comparison

Here's a quick, high-level comparison with other popular web ML runtimes:

| Feature | jax-js | TensorFlow.js | onnxruntime-web |
| ------------------------------- | ---------- | --------------- | ------------------ |
| Overview | | | |
| API style | JAX/NumPy | TensorFlow-like | Static ONNX graphs |
| Latest release | 2026 | ⚠️ 2024 | 2026 |
| Speed | Fastest | Fast | Fastest |
| Bundle size (gzip) | 80 KB | 269 KB | 90 KB + 24 MB Wasm |
| Autodiff & JIT | | | |
| Gradients | ✅ | ✅ | ❌ |
| Jacobian and Hessian | ✅ | ❌ | ❌ |
| jvp() forward differentiation | ✅ | ❌ | ❌ |
| jit() kernel fusion | ✅ | ❌ | ❌ |
| vmap() auto-vectorization | ✅ | ❌ | ❌ |
| Graph capture | ✅ | ❌ | ✅ |
| Backends & Data | | | |
| WebGPU backend | ✅ | 🟡 Preview | ✅ |
| WebGL backend | ✅ | ✅ | ✅ |
| Wasm (CPU) backend | ✅ | ✅ | ✅ |
| Eager array API | ✅ | ✅ | ❌ |
| Run ONNX models | 🟡 Partial | ❌ | ✅ |
| Read safetensors | ✅ | ❌ | ❌ |
| Float64 | ✅ | ❌ | ❌ |
| Float32 | ✅ | ✅ | ✅ |
| Float16 | ✅ | ❌ | ✅ |
| BFloat16 | ❌ | ❌ | ❌ |
| Packed Uint8 | ❌ | ❌ | 🟡 Partial |
| Mixed precision | ✅ | ❌ | ✅ |
| Mixed devices | ✅ | ❌ | ❌ |
| Ops & Numerics | | | |
| Arithmetic functions | ✅ | ✅ | ✅ |
| Matrix multiplication | ✅ | ✅ | ✅ |
| General einsum | ✅ | 🟡 Partial | 🟡 Partial |
| Sorting | ✅ | ❌ | ❌ |
| Activation functions | ✅ | ✅ | ✅ |
| NaN/Inf numerics | ✅ | ✅ | ✅ |
| Basic convolutions | ✅ | ✅ | ✅ |
| n-d convolutions | ✅ | ❌ | ✅ |
| Strided/dilated convolution | ✅ | ✅ | ✅ |
| Cholesky, Lstsq | ✅ | ❌ | ❌ |
| LU, Solve, Determinant | ✅ | ❌ | ❌ |
| SVD | ❌ | ❌ | ❌ |
| FFT | ✅ | ✅ | ✅ |
| Basic RNG (Uniform, Normal) | ✅ | ✅ | ✅ |
| Advanced RNG | ✅ | ❌ | ❌ |

Tutorial

Programming in jax-js looks very similar to JAX, just in JavaScript.

Arrays

Create an array with np.array():

import { numpy as np } from "@jax-js/jax";

const ar = np.array([1, 2, 3]);

By default, this is a float32 array, but you can specify a different dtype:

const ar = np.array([1, 2, 3], { dtype: np.int32 });

For more efficient construction, create an array from a JS TypedArray buffer:

const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
const ar = np.array(buf).reshape([2, 3]);

Once you're done with it, you can unwrap a jax.Array back into JavaScript. This will also apply any pending operations or lazy updates:

// 1) Returns a possibly nested JavaScript array.
ar.js();
await ar.jsAsync(); // Faster, non-blocking

// 2) Returns a flat TypedArray data buffer.
ar.dataSync();
await ar.data(); // Fastest, non-blocking
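As an illustration of what .js() returns, here is how a flat data buffer plus a shape corresponds to nested JavaScript arrays (the toNested helper is hypothetical, not part of the jax-js API):

```javascript
// Hypothetical sketch (not part of jax-js): map a flat TypedArray buffer
// plus a shape to the nested arrays that .js() conceptually returns.
function toNested(buf, shape) {
  if (shape.length === 0) return buf[0]; // scalar case
  const [head, ...rest] = shape;
  const stride = rest.reduce((a, b) => a * b, 1); // elements per sub-array
  const out = [];
  for (let i = 0; i < head; i++) {
    out.push(toNested(buf.subarray(i * stride, (i + 1) * stride), rest));
  }
  return out;
}

const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
toNested(buf, [2, 3]); // => [[10, 20, 30], [100, 200, 300]]
```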

Arrays can have mathematical operations applied to them. For example:

import { numpy as np, scipySpecial as special } from "@jax-js/jax";

const x = np.arange(100).astype(np.float32); // [0, 1, ..., 99] as float32

const y1 = x.ref.add(x.ref); // x + x
const y2 = np.sin(x.ref); // sin(x)
const y3 = np.tanh(x.ref).mul(5); // 5 * tanh(x)
const y4 = special.erfc(x.ref); // erfc(x)

Notice that in the above code, we used x.ref. This is because of jax-js's memory model: it uses reference-counted ownership to track when an Array's memory can be freed. More on this below.

Reference counting

Big Arrays take up a lot of memory. Python ML libraries override the __del__() method to free memory, but JavaScript has no comparable API for running object destructors. This means that references have to be tracked manually. jax-js tries to make this as ergonomic as possible, so you don't accidentally leak memory in a loop.

Every jax.Array has a reference count. This satisfies the following rules:

  • Whenever you create an Array, its reference count starts at 1.
  • When an Array's reference count reaches 0, it is freed and can no longer be used.
  • Given an Array a:
    • Accessing a.ref returns a and changes its reference count by +1.
    • Passing a as an argument to any function changes its reference count by -1.
    • Calling a.dispose() also changes its reference count by -1.

What this means is that all functions in jax-js take ownership of their arguments. Whenever you pass an Array as an argument, pass it directly to dispose of it, or use .ref if you'd like to use it again later.

You must follow these rules in your own functions as well! All combinators like jvp, grad, and jit assume that you follow these argument-passing conventions, and they respect them in turn.

// Bad: Uses `x` twice, decrementing its reference count twice.
function foo_bad(x: np.Array, y: np.Array) {
  return x.add(x.mul(y));
}

// Good: The first usage of `x` is `x.ref`, adding +1 to refcount.
function foo_good(x: np.Array, y: np.Array) {
  return x.ref.add(x.mul(y));
}

Here's another example:

// Bad: Doesn't consume `x` in the `if`-branch.
function bar_bad(x: np.Array, skip: boolean) {
  if (skip) return np.zeros(x.shape);
  return x;
}

// Good: Consumes `x` the one time in each branch.
function bar_good(x: np.Array, skip: boolean) {
  if (skip) {
    const ret = np.zeros(x.shape);
    x.dispose();
    return ret;
  }
  return x;
}

You can assume that every function in jax-js takes ownership properly, apart from a couple of very rare, documented exceptions.
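The rules above can be sketched as a toy reference counter in plain JavaScript (illustrative only; this is not how jax-js is implemented internally):

```javascript
// Toy sketch of the reference-counting rules (not actual jax-js internals).
class Handle {
  constructor(data) {
    this.data = data;
    this.refcount = 1; // rule: creation starts the count at 1
  }
  get ref() {
    this.refcount += 1; // rule: accessing .ref is +1
    return this;
  }
  dispose() {
    this.refcount -= 1; // rule: dispose() (or consuming as an argument) is -1
    if (this.refcount === 0) this.data = null; // freed, no longer usable
  }
}

const a = new Handle(new Float32Array([1, 2, 3]));
const b = a.ref; // refcount: 2
b.dispose(); // refcount: 1, still alive
a.dispose(); // refcount: 0, memory released
```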

grad(), vmap() and jit()

JAX's signature composable transformations are also supported in jax-js. Here is a simple example of using grad and vmap to compute the derivative of a function:

import { numpy as np, grad, vmap } from "@jax-js/jax";

const x = np.linspace(-10, 10, 1000);

const y1 = vmap(grad(np.sin))(x.ref); // d/dx sin(x) = cos(x)
const y2 = np.cos(x);

np.allclose(y1, y2); // => true
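As a sanity check, the same identity can be verified numerically in plain JavaScript with central differences, independent of jax-js autodiff (the numericalGrad helper here is hypothetical, shown only for intuition):

```javascript
// Numerical check of d/dx sin(x) = cos(x) via central differences,
// independent of jax-js autodiff.
const h = 1e-5;
const numericalGrad = (f) => (x) => (f(x + h) - f(x - h)) / (2 * h);

const dSin = numericalGrad(Math.sin);
Math.abs(dSin(1.0) - Math.cos(1.0)) < 1e-8; // => true
```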

The jit function is especially useful when running long sequences of primitives on GPU, since it fuses operations together into a single kernel dispatch. This reduces memory traffic on hardware accelerators, where bandwidth, rather than raw FLOPs, is usually the bottleneck. For instance:

export const hypot = jit(function hypot(x1: np.Array, x2: np.Array) {
  return np.sqrt(np.square(x1).add(np.square(x2)));
});

Without JIT, the hypot() function would require four kernel dispatches: two multiplies, one add, and one sqrt. JIT fuses these together into a single kernel that does it all at once.
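The effect can be sketched in plain JavaScript (a conceptual analogy, not actual jax-js kernels): the unfused version makes four passes over memory, while the fused version does the same math in a single pass.

```javascript
// Conceptual sketch of kernel fusion (not actual jax-js kernels):
// same math, four memory passes vs. one.
function hypotUnfused(x1, x2) {
  const a = x1.map((v) => v * v); // pass 1: square
  const b = x2.map((v) => v * v); // pass 2: square
  const c = a.map((v, i) => v + b[i]); // pass 3: add
  return c.map(Math.sqrt); // pass 4: sqrt
}

function hypotFused(x1, x2) {
  // One pass: the whole expression is evaluated per element.
  return x1.map((v, i) => Math.sqrt(v * v + x2[i] * x2[i]));
}

const x1 = new Float32Array([3, 5]);
const x2 = new Float32Array([4, 12]);
hypotFused(x1, x2); // => Float32Array [5, 13]
```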

All functional transformations can take typed JsTree inputs and outputs. These are similar to JAX's pytrees: just structures of nested JavaScript objects and arrays. For instance:

import { grad, numpy as np } from "@jax-js/jax";

type Params = {
  foo: np.Array;
  bar: np.Array[];
};

function getSums(p: Params) {
  const fooSum = p.foo.sum();
  const barSum = p.bar.map((x) => x.sum()).reduce(np.add);
  return fooSum.add(barSum);
}

grad(getSums)({
  foo: np.array([1, 2, 3]),
  bar: [np.array([10]), np.array([11, 12])],
});
// => { foo: [1, 1, 1], bar: [[1], [1, 1]] }

Note that you need to use type alias syntax rather than interface to define fine-grained JsTree types.
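To build intuition for the JsTree structure, here is a sketch of mapping a function over nested objects and arrays (the treeMap helper is hypothetical; jax-js's transformations perform this kind of traversal internally):

```javascript
// Illustrative sketch of JsTree traversal: apply fn to every leaf,
// preserving the nested object/array shape. Leaves here are numbers;
// in jax-js they would be Arrays.
function treeMap(fn, tree) {
  if (Array.isArray(tree)) return tree.map((t) => treeMap(fn, t));
  if (tree !== null && typeof tree === "object" && tree.constructor === Object) {
    return Object.fromEntries(
      Object.entries(tree).map(([k, v]) => [k, treeMap(fn, v)])
    );
  }
  return fn(tree); // leaf
}

treeMap((x) => x * 2, { foo: 1, bar: [10, { baz: 11 }] });
// => { foo: 2, bar: [20, { baz: 22 }] }
```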

Devices

Similar to JAX, jax-js has a concept of "devices": a device is a backend that stores Arrays in memory and determines how to execute compiled operations on them.

There are currently 4 devices in jax-js:

  • cpu: Slow, interpreted JS, only meant for debugging.
  • wasm: WebAssembly, currently single-threaded and blocking.
  • webgpu: WebGPU, available on supported browsers (Chrome, Firefox, Safari, iOS).
  • webgl: WebGL2, via fragment shaders. This is an older graphics API that runs on almost all browsers, but it is much slower than WebGPU. It's offered on a best-effort basis and not as well-supported.

We recommend webgpu for best performance, especially when running neural networks. The default device is wasm, but you can change this at startup time:

import { defaultDevice, init } from "@jax-js/jax";

const devices = await init(); // Starts all available backends.

if (devices.includes("webgpu")) {
  defaultDevice("webgpu");
} else {
  console.warn("WebGPU is not supported, falling back to Wasm.");
}
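Note that init() probes the available backends itself; the underlying browser capability check for WebGPU boils down to the standard navigator.gpu feature test:

```javascript
// Standard WebGPU feature detection (the plain-platform check; jax-js's
// init() does its own probing, this is only what it boils down to).
const hasWebGPU = typeof navigator !== "undefined" && "gpu" in navigator;

if (!hasWebGPU) {
  console.warn("WebGPU is not supported in this environment.");
}
```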

You can also place individual arrays on specific devices:

import { devicePut, numpy as np } from "@jax-js/jax";

const ar = np.array([1, 2, 3]); // Starts with device="wasm"
await devicePut(ar, "webgpu"); // Now device="webgpu"

Helper libraries

There are other libraries in the @jax-js namespace that work with jax-js, or that can be used on their own in other projects.

  • @jax-js/loaders can load tensors from various formats like Safetensors, includes a fast and compliant implementation of BPE, and caches HTTP requests for large assets like model weights in OPFS.
  • @jax-js/onnx is a model loader from the ONNX format into native jax-js functions.
  • @jax-js/optax provides implementations of optimizers like Adam and SGD.

Performance

The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual browsers. This library also uniquely offers the jit() feature, which fuses operations together and records an execution graph. jax-js achieves over 7000 GFLOP/s for matrix multiplication on an Apple M4 Max chip (try it).

On that benchmark, it's significantly faster than both TensorFlow.js and ONNX Runtime Web, which both rely on handwritten libraries of custom kernels.

It's still early though. There's a lot of low-hanging fruit to continue optimizing the library, as well as unique optimizations such as FlashAttention variants.

API Reference

That's all for this short tutorial. Please see the generated API reference for detailed documentation.

Development

The following technical details are for contributing to jax-js and modifying its internals.

This repository is managed by pnpm. You can compile and build all packages in watch mode with:

pnpm install
pnpm run build:watch

The pnpm install command automatically sets up Git hooks via Husky. Pre-commit hooks will run ESLint and Prettier on staged files to ensure code quality.

You can also run linting and formatting manually:

pnpm lint          # Run ESLint
pnpm format        # Format all files with Prettier
pnpm format:check  # Check formatting without writing
pnpm check         # Run TypeScript type checking

Then you can run tests in a headless browser using Vitest.

pnpm exec playwright install
pnpm test

We are pinned to an older version of Playwright that supports WebGPU in headless mode; on newer versions, the WebGPU tests are skipped.

To start a Vite dev server running the website, demos and REPL:

pnpm -C website dev

Future work / help wanted

Contributions are welcome! Some fruitful areas to look into:

  • Adding support for more JAX functions and operations, see compatibility table.
  • Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD and multithreading. (Even single-threaded Wasm could be ~20x faster.)
  • Adding support for jax.profiling, in particular the start and end trace functions. We should be able to generate traceEvents from backends (especially on GPU, with precise timestamp queries) to help with model performance debugging.
  • Helping the JIT compiler to fuse operations in more cases, like tanh branches.
  • Making a fast transformer inference engine, comparing against onnxruntime-web.

Join our Discord server to chat with the community.