
@wlearn/mitra

v0.1.0

Mitra Tab2D in-context learning model as a wlearn Estimator. Wraps the 72M-parameter tabular foundation model from Amazon/AutoGluon in the standard wlearn fit() / predict() / save() / load() API.

Mitra uses in-context learning: instead of gradient-based training, fit() selects a support set from your data, and predict() passes that support set alongside your query data through the ONNX model in a single forward pass.

Data storage warning

fit() stores a subset of your training data (up to maxSupport rows, default 512) inside the model instance. This support set is:

  • Held in memory as Float32Array for the lifetime of the instance
  • Serialized into the .wlrn bundle when you call save()
  • Required for every predict() call (it is the model's "context")

If your training data is sensitive, be aware that saved bundles contain real data points. The maxSupport parameter controls how many rows are stored. Call dispose() to release the support set from memory.

Install

npm install @wlearn/mitra onnxruntime-node   # Node.js
npm install @wlearn/mitra onnxruntime-web    # Browser

You also need the ONNX model files (see "ONNX conversion" below).

Usage

import { MitraClassifier } from '@wlearn/mitra'
import * as ort from 'onnxruntime-node'
import { readFileSync } from 'fs'

// Load ONNX model (290 MB)
const onnxBytes = readFileSync('mitra-classifier.onnx')
const session = await ort.InferenceSession.create(onnxBytes.buffer)

// Create estimator
const model = await MitraClassifier.create(session, {
  maxSupport: 512,  // max support set size (default)
  seed: 42          // RNG seed for support set sampling
}, { ort })

// Fit: selects and stores a support set from training data
model.fit(X_train, y_train)

// Predict: runs ONNX inference with stored support set as context
const predictions = await model.predict(X_test)
const probabilities = await model.predictProba(X_test)
const acc = await model.score(X_test, y_test)

// Save/load (requires the same ONNX model to load)
const bundle = model.save()  // Uint8Array (.wlrn format)
const loaded = await MitraClassifier.load(bundle, session, { ort })

model.dispose()

Regressor

import { MitraRegressor } from '@wlearn/mitra'

const session = await ort.InferenceSession.create(regressorOnnxBytes.buffer)
const model = await MitraRegressor.create(session, { maxSupport: 512 }, { ort })

model.fit(X_train, y_train)
const predictions = await model.predict(X_test)
const r2 = await model.score(X_test, y_test)

Registry integration

import { registerLoaders } from '@wlearn/mitra'
import { load } from '@wlearn/core'

// Register loaders so core.load() can dispatch .wlrn bundles
registerLoaders(classifierSession, regressorSession, { ort })
const model = await load(bundleBytes)

API

MitraClassifier

| Method | Returns | Description |
|--------|---------|-------------|
| static create(onnxSource, params?, opts?) | Promise<MitraClassifier> | Factory. onnxSource is an InferenceSession or Uint8Array of ONNX bytes |
| fit(X, y) | this | Select support set from training data (sync) |
| predict(X) | Promise<Float64Array> | Class label predictions |
| predictProba(X) | Promise<Float64Array> | Class probabilities (rows * nClasses) |
| score(X, y) | Promise<number> | Accuracy |
| save() | Uint8Array | Serialize to .wlrn bundle |
| static load(bytes, onnxSource, opts?) | Promise<MitraClassifier> | Deserialize |
| dispose() | void | Release session and support set |
| getParams() | object | { maxSupport, seed } |
| setParams(p) | this | Update params |
| capabilities | object | { classifier: true, predictProba: true, ... } |
| classes | Int32Array | Sorted unique class labels |
| nrClass | number | Number of classes |
| nrFeature | number | Number of features |

MitraRegressor

Same API minus predictProba, classes, and nrClass. score() returns R².

Parameters

| Param | Default | Description |
|-------|---------|-------------|
| maxSupport | 512 | Maximum support set size. If training data exceeds this, a subset is sampled |
| seed | 42 | RNG seed for deterministic support set sampling |

Bundle format

The .wlrn bundle stores the support set, not the ONNX model (which is 290 MB). Loading requires the ONNX model to be provided separately.

| Artifact | Format | Contents |
|----------|--------|----------|
| meta | JSON | { nFeatures, nSupport, classes, onnxSha256, seed } |
| support_x | raw float32 | Support set features, row-major |
| support_y | raw int32/float32 | Support set labels (int32 for classifier, float32 for regressor) |

TypeIds: wlearn.mitra_onnx.classifier@1, wlearn.mitra_onnx.regressor@1

ONNX model variants

| Model | HuggingFace | Output |
|-------|-------------|--------|
| Classifier | autogluon/mitra-classifier | (B, N_query, 10) logits |
| Regressor | autogluon/mitra-regressor | (B, N_query) values |

ONNX conversion

pip install -r requirements.txt

Convert

python convert.py              # both variants
python convert.py --variant classifier
python convert.py --variant regressor

Produces mitra-classifier.onnx and/or mitra-regressor.onnx.

Verify

Compare PyTorch and ONNX Runtime outputs:

python verify.py
python verify.py --atol 1e-4

ONNX model inputs

All dimensions are dynamic (variable batch, support/query/feature counts).

| Input | Shape | Type |
|-------|-------|------|
| x_support | (B, N_support, N_features) | float32 |
| y_support | (B, N_support) | int64 (classifier) / float32 (regressor) |
| x_query | (B, N_query, N_features) | float32 |
| padding_obs_support | (B, N_support) | bool |

Note: padding_features and padding_obs_query are accepted by the PyTorch model for API compatibility but are unused in the CPU code path (they only matter for flash attention). The ONNX tracer correctly eliminates them from the graph.

Testing

Requires ONNX models in the repo root (run python convert.py first).

npm install
npm test

26 tests covering: create, fit, predict, predictProba, score, save/load round-trip, dispose, error handling, support set selection, determinism.

What was changed for ONNX export

The upstream AutoGluon implementation (_internal/models/tab2d.py, _internal/models/embedding.py) uses several ops that the ONNX exporter cannot trace or has no opset mapping for. Each replacement below preserves numerical equivalence while using only standard ONNX ops.

Quantile computation (Tab2DQuantileEmbeddingX)

The upstream computes 999 quantiles of x_support along the observation axis with torch.quantile, which has no ONNX opset mapping.

Replacement: torch.sort along dim 1, then torch.gather at fractional index positions with linear interpolation between the floor and ceil indices. Produces identical quantile boundaries.
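The equivalence is easy to check numerically. Below is a minimal sketch of the sort + gather + linear-interpolation scheme, using NumPy in place of the actual PyTorch ops, and assuming evenly spaced quantile levels spanning 0 to 1 inclusive (the upstream grid may differ):

```python
import numpy as np

def quantiles_via_sort(x, n_q=999):
    """Quantiles along the observation axis using only sort, gather,
    and linear interpolation -- the ops with clean ONNX mappings."""
    s = np.sort(x, axis=0)                    # sorted observations, per feature
    n = x.shape[0]
    q = np.linspace(0.0, 1.0, n_q)            # assumed quantile levels
    pos = q * (n - 1)                         # fractional index positions
    lo = np.floor(pos).astype(int)
    hi = np.ceil(pos).astype(int)
    frac = (pos - lo)[:, None]
    return s[lo] * (1 - frac) + s[hi] * frac  # lerp between floor/ceil indices

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4))                 # 256 observations, 4 features
ours = quantiles_via_sort(x)
ref = np.quantile(x, np.linspace(0, 1, 999), axis=0)  # reference implementation
assert np.allclose(ours, ref)
```

The lerp at fractional positions reproduces the default "linear" interpolation rule, so the boundaries match the reference to floating-point precision.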

Bucketize / searchsorted (Tab2DQuantileEmbeddingX)

The upstream maps each value to its quantile bin using torch.vmap(torch.bucketize, in_dims=(0,0)). Both vmap (not traceable) and bucketize / searchsorted (no ONNX op) are unavailable.

Replacement: broadcasting comparison. For values (b, f, s) and boundaries (b, f, 999), compute (values.unsqueeze(-1) >= boundaries.unsqueeze(-2)).sum(-1). This counts how many boundaries each value exceeds, which is exactly the bucket index. O(n * 999) instead of O(n * log 999), but 999 is small and the operation is pure element-wise ONNX ops (GreaterOrEqual, Cast, ReduceSum).
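The counting trick can be verified against a reference searchsorted; here is a NumPy sketch (2-D batch for brevity, but the comparison broadcasts the same way over (b, f, s) vs. (b, f, 999)):

```python
import numpy as np

def bucketize_broadcast(values, boundaries):
    """Bucket index = how many boundaries each value is >=.
    Uses only element-wise ops (GreaterOrEqual, Cast, ReduceSum in ONNX)."""
    return (values[..., :, None] >= boundaries[..., None, :]).sum(-1)

rng = np.random.default_rng(0)
bounds = np.sort(rng.normal(size=(3, 999)), axis=-1)  # sorted per-row boundaries
vals = rng.normal(size=(3, 64))
ours = bucketize_broadcast(vals, bounds)

# Reference: per-row searchsorted (side='right' counts boundaries <= value)
ref = np.stack([np.searchsorted(b, v, side='right')
                for b, v in zip(bounds, vals)])
assert np.array_equal(ours, ref)
```

For continuous inputs the two agree exactly; the broadcast produces a (…, s, 999) boolean intermediate, which is the memory cost of trading the binary search for pure element-wise ops.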

In-place masked assignment (Tab2DQuantileEmbeddingX, Tab2DEmbeddingY*)

The upstream uses x_support[padding_mask] = 9999 and y_support[padding_obs] = 0 to set padded positions. In-place mutation through boolean indexing is not traceable.

Replacement: torch.where(mask, fill_value, x). Functionally identical, produces a new tensor instead of mutating.
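A NumPy sketch of the equivalence (np.where standing in for torch.where):

```python
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)
mask = x > 5.0

# In-place masked assignment (not traceable by the ONNX exporter)
a = x.copy()
a[mask] = 9999.0

# Functional replacement: builds a new tensor instead of mutating
b = np.where(mask, np.float32(9999.0), x)

assert np.array_equal(a, b)
```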

einops / einx (Tab2D, Layer, MultiheadAttention, embeddings)

The upstream uses einops.rearrange, einops.pack/unpack, einx.rearrange, and einx.sum throughout. These are external libraries the ONNX tracer cannot see through.

Replacements:

  • einx.rearrange("b s f -> b s f 1", x) -- x.unsqueeze(-1)
  • einops.rearrange("b n -> b n 1", y) -- y.unsqueeze(-1)
  • einops.rearrange("b s f d -> (b f) s d", x) -- x.permute(0,2,1,3).reshape(b*f, s, d)
  • einops.rearrange("(b f) s d -> b s f d", x, b=b) -- x.reshape(b, f, s, d).permute(0,2,1,3)
  • einops.rearrange("b s f d -> (b s) f d", x) -- x.reshape(b*s, f, d)
  • einops.rearrange("b t (h d) -> b h t d", q, h=h) -- q.reshape(b, t, h, d).permute(0,2,1,3)
  • einops.pack((y, x), "b s * d") -- torch.cat([y, x], dim=2)
  • einops.unpack(q, pack_info, "b s * c") -- q[:, :, 0, :] (index the y slot)
  • einx.sum("b [s] f", x) -- x.sum(dim=1, keepdim=True)

Gradient checkpointing (Tab2D)

The upstream wraps each layer call in torch.utils.checkpoint.checkpoint(layer, ...) which is a training-only optimization not compatible with tracing.

Replacement: direct call layer(support, query). The model is exported in eval mode so checkpointing has no effect anyway.

Flash attention path (Tab2D, Layer, Padder)

The upstream has two code paths: a flash attention path (CUDA with flash_attn library) and a CPU path using F.scaled_dot_product_attention. The flash attention path uses flash_attn_varlen_func, unpad_input/pad_input, and a Padder class -- none of which are ONNX-exportable.

Replacement: only the CPU path is reimplemented. F.scaled_dot_product_attention maps cleanly to standard ONNX attention ops. The Padder class and all flash attention imports are removed entirely.

Summary table

| Original | Replacement | Location |
|----------|-------------|----------|
| torch.quantile | sort + gather + lerp | Tab2DQuantileEmbeddingX |
| torch.vmap(torch.bucketize) | broadcast compare + sum | Tab2DQuantileEmbeddingX |
| x[mask] = val (in-place) | torch.where | embeddings |
| einops.rearrange / einx.rearrange | reshape / permute / unsqueeze | everywhere |
| einops.pack / unpack | torch.cat / indexing | Tab2D.forward |
| einx.sum | torch.sum | Tab2DQuantileEmbeddingX |
| checkpoint(layer, ...) | layer(...) | Tab2D.forward |
| Flash attention path + Padder | Removed (CPU path only) | Tab2D, Layer |

State dict keys match upstream exactly -- safetensors load without any key renaming.

License

The Mitra model weights are Apache-2.0 licensed by Amazon/AutoGluon. This conversion code is Apache-2.0 licensed.