Tesseract-JAX v0.3.0 released

Release v0.3.0

Breaking: vmap over apply_tesseract now requires an explicit vmap_method argument

If you use jax.vmap, jax.jacfwd, or jax.jacrev with apply_tesseract, those calls now raise NotImplementedError unless you pass an explicit vmap_method. The change prevents silently incorrect or inefficient batching.

To migrate, add vmap_method to your apply_tesseract calls:

import jax
from tesseract_jax import apply_tesseract

# Before (now raises NotImplementedError)
jax.vmap(lambda x: apply_tesseract(tess, x))(batch)

# After: pass one vmap_method to the apply_tesseract call inside the vmapped function
jax.vmap(lambda x: apply_tesseract(tess, x, vmap_method="sequential"))(batch)         # safe default, one call per element
jax.vmap(lambda x: apply_tesseract(tess, x, vmap_method="expand_dims"))(batch)        # fast, single batched call (if Tesseract supports it)
jax.vmap(lambda x: apply_tesseract(tess, x, vmap_method="broadcast_all"))(batch)      # single call, broadcasts unbatched args
jax.vmap(lambda x: apply_tesseract(tess, x, vmap_method="auto_experimental"))(batch)  # tries to auto-detect best strategy from schema

Start with "sequential" if unsure — it works with any Tesseract. Use "expand_dims" or "auto_experimental" for performance once you’ve verified your Tesseract handles batched inputs. See the full guide for more help.
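The three explicit strategies differ in how many times the Tesseract is called and in the shapes it receives. The NumPy simulation below is a sketch built on assumptions. toy_tesseract and run_batched are hypothetical names. The shape conventions follow the identically named vmap_method options of jax.pure_callback, which we assume Tesseract-JAX mirrors. auto_experimental is not modeled.

```python
import numpy as np

# Hypothetical stand-in for a Tesseract endpoint: it records the shapes it
# was called with and multiplies its two inputs elementwise.
calls = []

def toy_tesseract(x, scale):
    calls.append((x.shape, scale.shape))
    return x * scale

def run_batched(fn, x_batch, scale, vmap_method):
    """Simulate how one vmapped call might reach fn under each strategy.

    Assumption: x_batch is batched along axis 0, while scale is unbatched
    (shared by every batch element).
    """
    n = x_batch.shape[0]
    if vmap_method == "sequential":
        # One call per batch element, then stack the per-element results.
        return np.stack([fn(x_batch[i], scale) for i in range(n)])
    if vmap_method == "expand_dims":
        # One call; the unbatched arg gets a size-1 leading axis, so the
        # callee itself must broadcast it against the batch axis.
        return fn(x_batch, np.expand_dims(scale, 0))
    if vmap_method == "broadcast_all":
        # One call; the unbatched arg is broadcast to the full batch size.
        return fn(x_batch, np.broadcast_to(scale, (n,) + scale.shape))
    raise NotImplementedError(f"unknown vmap_method {vmap_method!r}")

# Usage: 3 batch elements of shape (2,), plus one shared scale vector.
x_batch = np.arange(6.0).reshape(3, 2)
scale = np.array([10.0, 100.0])
for method in ("sequential", "expand_dims", "broadcast_all"):
    calls.clear()
    out = run_batched(toy_tesseract, x_batch, scale, method)
    print(method, len(calls), calls[0])
```

In this simulation, "sequential" makes three calls on unbatched (2,)-shaped inputs, while "expand_dims" and "broadcast_all" each make a single call. The two single-call methods differ only in whether the shared scale arrives as shape (1, 2) or (3, 2). That difference is why "expand_dims" only works when the Tesseract itself broadcasts size-1 axes correctly.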

New: Debug pipelines with sow / save_intermediates

You can now tag and capture intermediate values (and their derivatives) inside multi-Tesseract pipelines:

import jax
from tesseract_jax import apply_tesseract, sow, save_intermediates

# tess1 and tess2 are assumed to be already-loaded Tesseracts
def pipeline(inputs):
    res = apply_tesseract(tess1, inputs)
    res = sow(res, "after_tess1")       # tag this intermediate
    res = apply_tesseract(tess2, res)
    return res["output"].sum()

# Capture intermediates alongside gradients
grads, intermediates = save_intermediates(jax.grad(pipeline))(inputs)
print(intermediates["after_tess1"]["primal"])     # forward value
print(intermediates["after_tess1"]["cotangent"])  # gradient at that point

Works with jax.grad, jax.jvp, jax.vmap, jax.jit, and combinations thereof. See Debugging pipelines for more information.
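The "primal" entry is the forward value at the tagged point. The "cotangent" entry is the gradient of the final scalar output with respect to that value. The sketch below works both out by hand for a two-stage pipeline in NumPy. stage1, stage2, and grad_with_intermediate are hypothetical stand-ins for tess1, tess2, and save_intermediates(jax.grad(...)), not part of Tesseract-JAX.

```python
import numpy as np

# Two toy stages standing in for tess1 and tess2 (hypothetical, not Tesseracts).
def stage1(x):
    return 3.0 * x            # y = 3x

def stage2(y):
    return np.sum(y ** 2)     # loss = sum(y^2)

def grad_with_intermediate(x):
    """Hand-written reverse mode that also records the tagged intermediate."""
    # Forward pass: the value at the tagged point is the "primal".
    y = stage1(x)
    loss = stage2(y)
    # Backward pass: d(loss)/dy = 2y is the "cotangent" arriving at the tag.
    y_bar = 2.0 * y
    # Continue backward through stage1 (dy/dx = 3) to get d(loss)/dx.
    x_bar = 3.0 * y_bar
    intermediates = {"after_stage1": {"primal": y, "cotangent": y_bar}}
    return loss, x_bar, intermediates

x = np.array([1.0, 2.0])
loss, grads, intermediates = grad_with_intermediate(x)
print(intermediates["after_stage1"]["primal"])
print(intermediates["after_stage1"]["cotangent"])
```

Because loss = sum(9x^2) here, the final gradient is 18x. The cotangent at the tag shows how much of that gradient is already present before it passes back through the first stage.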

Also in this release

  • Python scalars and array-likes now accepted as inputs — no more wrapping every float in jnp.array() before calling apply_tesseract (#155)
  • Several bug fixes: dictionary inputs (#127), partial derivatives with JIT (#99), VJP index bug with static args (#159), and better error messages (#136, #144)
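Before #155, callers normalized every leaf of their input dict themselves. The NumPy sketch below shows that manual step, which is no longer required. as_array_tree is a hypothetical helper for illustration only. Tesseract-JAX's own conversion is not shown here and may work differently, for example by producing JAX arrays.

```python
import numpy as np

def as_array_tree(tree):
    """Recursively convert Python scalars and sequences to NumPy arrays.

    Hypothetical helper: dicts are walked key by key; every other leaf
    goes through np.asarray.
    """
    if isinstance(tree, dict):
        return {k: as_array_tree(v) for k, v in tree.items()}
    return np.asarray(tree)

# A float, a list, and a nested int: all become arrays of the matching shape.
inputs = {"a": 1.5, "b": [1, 2, 3], "nested": {"c": 2}}
arrays = as_array_tree(inputs)
print({k: np.shape(v) for k, v in arrays.items() if not isinstance(v, dict)})
```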

What’s Changed

Features

  • Tests for fori and scan (#135)
  • Implement sow / save_intermediates for easier debugging of Tesseract pipelines (#150)
  • Support Python scalars and generic arrays as inputs to apply_tesseract (#155)
  • [breaking] Support "expand_dims", "broadcast_all" and "auto" vmap methods (#162)

Bug Fixes

  • Tesseracts with dictionary inputs fail (#127)
  • Better error messages for missing endpoints (#136)
  • Error message in fem example notebook (#144)
  • Partial derivatives + JIT causes error (#99)
  • VJP index bug with static args (#159)

Refactor

  • Move helpers to tree_util (#142)
  • Reduce code duplication in test suite (#161)

Full diff: v0.2.3...v0.3.0 (pasteurlabs/tesseract-jax on GitHub)