Release v0.3.0
Breaking: vmap over apply_tesseract now requires an explicit vmap_method argument
If you use jax.vmap, jax.jacfwd, or jax.jacrev with apply_tesseract, calls will now raise NotImplementedError unless you specify a vmap_method (jax.jacfwd and jax.jacrev are built on vmap internally, so they are affected too). This change prevents silent batching behavior that could be incorrect or inefficient.
To migrate, add vmap_method to your apply_tesseract calls:
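```python
import jax
from tesseract_jax import apply_tesseract

# tess (a Tesseract) and batch (batched inputs) are assumed to be defined already

# Before (no longer works: raises NotImplementedError)
jax.vmap(lambda x: apply_tesseract(tess, x))(batch)

# After: pass vmap_method explicitly; pick one of
jax.vmap(lambda x: apply_tesseract(tess, x, vmap_method="sequential"))(batch)         # safe default, one call per element
jax.vmap(lambda x: apply_tesseract(tess, x, vmap_method="expand_dims"))(batch)        # fast, single batched call (if the Tesseract supports it)
jax.vmap(lambda x: apply_tesseract(tess, x, vmap_method="broadcast_all"))(batch)      # single call, broadcasts unbatched args
jax.vmap(lambda x: apply_tesseract(tess, x, vmap_method="auto_experimental"))(batch)  # tries to auto-detect the best strategy from the schema
```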
Start with "sequential" if unsure — it works with any Tesseract. Use "expand_dims" or "auto_experimental" for performance once you’ve verified your Tesseract handles batched inputs. See the full guide for more help.
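The names "sequential", "expand_dims" and "broadcast_all" match the vmap_method values accepted by jax.pure_callback. Assuming they carry the same meaning here, the NumPy sketch below models what each strategy asks of the function being batched. It is a conceptual model only, not tesseract-jax's implementation: toy_apply is a hypothetical stand-in for a Tesseract, and "auto_experimental" is not modeled.

```python
import numpy as np

# Hypothetical stand-in for a Tesseract's apply endpoint; counts its calls.
CALLS = {"n": 0}


def toy_apply(a, b):
    CALLS["n"] += 1
    return a * b + 1.0  # elementwise, so it broadcasts like NumPy


def batch_size(args, batched):
    # Size of the leading (batch) axis, taken from the first batched argument.
    return next(x.shape[0] for x, is_b in zip(args, batched) if is_b)


def run_sequential(fn, args, batched):
    # "sequential": one call per batch element; unbatched args passed unchanged.
    n = batch_size(args, batched)
    outs = [fn(*(x[i] if is_b else x for x, is_b in zip(args, batched))) for i in range(n)]
    return np.stack(outs)


def run_expand_dims(fn, args, batched):
    # "expand_dims": one call; unbatched args get a leading axis of size 1,
    # so fn itself must broadcast that axis against the batch.
    return fn(*(x if is_b else np.expand_dims(x, 0) for x, is_b in zip(args, batched)))


def run_broadcast_all(fn, args, batched):
    # "broadcast_all": one call; unbatched args are broadcast to the full batch size.
    n = batch_size(args, batched)
    return fn(*(x if is_b else np.broadcast_to(x, (n,) + x.shape) for x, is_b in zip(args, batched)))


a = np.arange(12.0).reshape(4, 3)  # batched: 4 elements of shape (3,)
b = np.array([1.0, 2.0, 3.0])      # unbatched: shared by every element
batched = (True, False)

results = {}
for name, strategy in [("sequential", run_sequential),
                       ("expand_dims", run_expand_dims),
                       ("broadcast_all", run_broadcast_all)]:
    CALLS["n"] = 0
    out = strategy(toy_apply, (a, b), batched)
    results[name] = (out, CALLS["n"])
    print(name, out.shape, "calls:", CALLS["n"])
```

All three strategies give the same result here, but only "sequential" never feeds the function an input with an extra leading axis. "expand_dims" works in this toy only because toy_apply broadcasts; a Tesseract whose schema fixes input shapes would reject the (1, 3) and (4, 3) arrays, which is why "sequential" is the safe starting point.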
New: Debug pipelines with sow / save_intermediates
You can now tag and capture intermediate values (and their derivatives) inside multi-Tesseract pipelines:
```python
import jax
from tesseract_jax import apply_tesseract, sow, save_intermediates

# tess1, tess2 (Tesseracts) and inputs are assumed to be defined already
def pipeline(inputs):
    res = apply_tesseract(tess1, inputs)
    res = sow(res, "after_tess1")  # tag this intermediate
    res = apply_tesseract(tess2, res)
    return res["output"].sum()

# Capture intermediates alongside gradients
grads, intermediates = save_intermediates(jax.grad(pipeline))(inputs)
print(intermediates["after_tess1"]["primal"])     # forward value
print(intermediates["after_tess1"]["cotangent"])  # gradient at that point
```
Works with jax.grad, jax.jvp, jax.vmap, jax.jit, and combinations thereof. See Debugging pipelines for more information.
Also in this release
- Python scalars and array-likes are now accepted as inputs: no more wrapping every float in jnp.array() before calling apply_tesseract (#155)
- Several bug fixes: dictionary inputs (#127), partial derivatives with JIT (#99), VJP index bug with static args (#159), and better error messages (#136, #144)
What’s Changed
Features
- Tests for fori and scan (#135)
- Implement sow/save_intermediates for easier debugging of Tesseract pipelines (#150)
- Support Python scalars and generic arrays as inputs to apply_tesseract (#155)
- [breaking] Support "expand_dims", "broadcast_all" and "auto" vmap methods (#162)
Bug Fixes
- Tesseracts with dictionary inputs fail (#127)
- Better error messages for missing endpoints (#136)
- Error message in fem example notebook (#144)
- Partial derivatives + JIT causes error (#99)
- VJP index bug with static args (#159)
Refactor
- Move helpers to tree_util (#142)
- Reduce code duplication in test suite (#161)
Full diff: v0.2.3...v0.3.0 (pasteurlabs/tesseract-jax on GitHub)