Sampling the Universe: HMC through JAX and Fortran with Enzyme-differentiated Tesseracts

Hi everyone,

Following up on my previous Tesseract + Disco-DJ optimization example, I put together a new demo showing field-level Bayesian inference through a coupled JAX + Fortran simulation pipeline using Tesseract and Enzyme.

The demo combines two Tesseracts:

  1. a Disco-DJ forward-model Tesseract, which evolves Fourier-space initial conditions into a nonlinear 3D matter density field, and
  2. a Fortran Tesseract implementing the Fluctuating Gunn-Peterson Approximation (FGPA), which turns that density field into synthetic Lyman-alpha forest skewers.
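As a rough illustration of the FGPA step, the sketch below maps a 1D overdensity skewer δ to transmitted flux F = exp(−τ), with optical depth τ = A(1+δ)^β. The functional form, the default β = 1.6 (a commonly used value), the real-space treatment (no peculiar velocities or thermal broadening) and the name `fgpa_flux` are assumptions for illustration; the demo's Fortran code may differ in all of these. Plain Python is used to keep the sketch dependency-free.

```python
import math

def fgpa_flux(delta, log_A, beta=1.6):
    """Map a 1D overdensity skewer delta to transmitted flux via the FGPA.

    Assumed form: tau_i = A * (1 + delta_i)**beta with A = exp(log_A),
    F_i = exp(-tau_i). Real space only: no peculiar velocities or thermal
    broadening. beta = 1.6 is an illustrative default.
    """
    A = math.exp(log_A)
    flux = []
    for d in delta:
        rho = 1.0 + d  # density in units of the cosmic mean
        if rho <= 0.0:
            raise ValueError("1 + delta must be positive")
        tau = A * rho ** beta        # optical depth of this pixel
        flux.append(math.exp(-tau))  # transmitted fraction, between 0 and 1
    return flux

# Usage: denser pixels absorb more, so the flux drops as delta grows
skewer = [-0.5, 0.0, 1.0, 3.0]
print([round(f, 3) for f in fgpa_flux(skewer, log_A=math.log(0.3))])
```

Because the flux decreases monotonically with density, each skewer pixel carries information about the density along that line of sight, while the amplitude A shifts the overall absorption level.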

The FGPA post-processing code is written in Fortran, and its derivatives are obtained automatically using Enzyme. Through tesseract-jax, these derivatives can then be combined with a JAX-based workflow for gradient-based inference.
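To make "its derivatives" concrete: for reverse-mode inference, what the Fortran routine has to provide is a vector-Jacobian product (VJP) that maps the gradient of the loss with respect to the output flux back to gradients with respect to the density field and log A; JAX then pulls the density gradient further back through the Disco-DJ Tesseract by the chain rule. The sketch below derives this VJP by hand for a simplified FGPA form (assumed here: τ = A(1+δ)^β, no velocity or thermal effects) and checks it against finite differences. It does not call Enzyme or tesseract-jax; `fgpa_flux`, `fgpa_vjp` and `central_difference` are hypothetical names.

```python
import math

def fgpa_flux(delta, log_A, beta=1.6):
    # Assumed FGPA form: F_i = exp(-A * (1 + delta_i)**beta) with A = exp(log_A)
    A = math.exp(log_A)
    return [math.exp(-A * (1.0 + d) ** beta) for d in delta]

def fgpa_vjp(delta, log_A, cotangent, beta=1.6):
    """Hand-derived VJP of fgpa_flux: maps dL/dF to (dL/d delta, dL/d log_A)."""
    A = math.exp(log_A)
    grad_delta, grad_log_A = [], 0.0
    for d, g in zip(delta, cotangent):
        tau = A * (1.0 + d) ** beta
        F = math.exp(-tau)
        # dF/d delta = -F * dtau/d delta, where dtau/d delta = A * beta * (1 + delta)**(beta - 1)
        grad_delta.append(-g * F * A * beta * (1.0 + d) ** (beta - 1.0))
        # A = exp(log_A) implies dtau/d log_A = tau, so dF/d log_A = -F * tau
        grad_log_A += -g * F * tau
    return grad_delta, grad_log_A

def central_difference(f, x, h=1e-6):
    # Second-order finite-difference approximation of df/dx
    return (f(x + h) - f(x - h)) / (2.0 * h)

# Check the log_A component against finite differences for L = sum_i g_i * F_i
delta, log_A, g = [0.2, -0.3, 1.5], math.log(0.5), [1.0, -2.0, 0.5]
loss = lambda dl, la: sum(gi * fi for gi, fi in zip(g, fgpa_flux(dl, la)))
grad_delta, grad_log_A = fgpa_vjp(delta, log_A, g)
fd_log_A = central_difference(lambda la: loss(delta, la), log_A)
print(grad_log_A, fd_log_A)
```

The point of using an AD tool like Enzyme is that this derivative does not have to be derived by hand for the real Fortran code, which is considerably more involved than this toy form.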

The notebook considers a Bayesian reconstruction problem: starting only from noisy 1D Lyman-alpha skewer observations, it samples the posterior over

  • the high-dimensional initial Fourier white-noise field $\epsilon$, and
  • a global FGPA amplitude parameter, $\log A$.
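For orientation, a plausible form of the target density is written out below. It assumes a standard-normal prior on the white-noise coordinates (the usual choice for whitened initial conditions), independent Gaussian noise with standard deviation $\sigma$ on the observed skewer pixels $d_i$, and a prior $p(\log A)$ whose form is not specified here; the notebook's actual likelihood and priors may differ.

```latex
\log p(\epsilon, \log A \mid d)
  = -\frac{1}{2}\,\lVert \epsilon \rVert^2
    - \frac{1}{2\sigma^2} \sum_{i} \bigl[ d_i - F_i(\epsilon, \log A) \bigr]^2
    + \log p(\log A) + \text{const}
```

Here $F_i(\epsilon, \log A)$ is the model flux in pixel $i$: initial conditions built from $\epsilon$, evolved by Disco-DJ, then mapped to skewers by the FGPA Tesseract.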

The inference is done with Hamiltonian Monte Carlo (HMC) using BlackJAX. Before HMC, the notebook runs a staged L-BFGS MAP optimization to find a good initial point.
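In the notebook, BlackJAX provides the HMC kernel. For readers less familiar with HMC, here is a minimal dependency-free sketch of a single transition: resample Gaussian momenta, integrate Hamiltonian dynamics with the leapfrog scheme, then accept or reject with a Metropolis test. It assumes an identity mass matrix, a fixed step size and number of steps, and a 2D standard normal toy target in place of the real posterior; it is not BlackJAX's implementation. Starting the chain at the toy target's mode mimics initializing from the MAP point.

```python
import math
import random

def leapfrog(q, p, grad_logp, step_size, n_steps):
    """Leapfrog integration of Hamiltonian dynamics with an identity mass matrix."""
    q, p = list(q), list(p)
    g = grad_logp(q)
    for _ in range(n_steps):
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]  # half kick
        q = [qi + step_size * pi for qi, pi in zip(q, p)]        # drift
        g = grad_logp(q)
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]  # half kick
    return q, p

def hmc_step(q, logp, grad_logp, rng, step_size=0.2, n_steps=10):
    """One HMC transition: resample momentum, integrate, Metropolis accept/reject."""
    p0 = [rng.gauss(0.0, 1.0) for _ in q]
    q1, p1 = leapfrog(q, p0, grad_logp, step_size, n_steps)
    # Hamiltonian H = -log p(q) + |p|^2 / 2
    h0 = -logp(q) + 0.5 * sum(x * x for x in p0)
    h1 = -logp(q1) + 0.5 * sum(x * x for x in p1)
    # Accept with probability min(1, exp(h0 - h1))
    if rng.random() < math.exp(min(0.0, h0 - h1)):
        return q1, True
    return q, False

# Toy target: 2D standard normal, standing in for the field-level posterior
logp = lambda q: -0.5 * sum(x * x for x in q)
grad_logp = lambda q: [-x for x in q]

rng = random.Random(0)
q = [0.0, 0.0]  # start at the mode, analogous to starting from the MAP estimate
samples, n_accept = [], 0
for _ in range(2000):
    q, accepted = hmc_step(q, logp, grad_logp, rng)
    samples.append(q)
    n_accept += accepted
print("acceptance rate:", n_accept / len(samples))
```

In the field-level setting, `q` would hold all field coordinates plus log A, and `grad_logp` would come from JAX autodiff through both Tesseracts. Gradients are what let HMC make long, high-acceptance moves in such a high-dimensional space.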

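In this run at a resolution of 64^3, HMC samples the posterior over the spherically UV-filtered Fourier modes of the initial white-noise field (modes beyond a spherical cutoff in |k| are removed). There are 68,529 active complex representatives (one per Hermitian-conjugate pair ±k, since the corresponding real-space field is real), so the field contributes 137,058 real coordinates (real and imaginary parts); together with the scalar log_A, this gives 137,059 sampled parameters.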
The chain recovers the true value of the FGPA parameter within the posterior uncertainty. The posterior samples reproduce the observed skewers well, while the density away from the skewer locations remains much less constrained, as expected, since the skewers only directly probe the density along their lines of sight.
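The parameter count above comes from keeping Fourier modes inside a sphere and storing one complex value per conjugate pair. A small sketch of that bookkeeping follows. The convention (integer wavevectors on an n^3 FFT grid, keep 0 < |k| < k_cut, drop the zero mode) and the name `count_active_modes` are assumptions, so it is not guaranteed to reproduce the 68,529 quoted above for the demo's actual cutoff.

```python
def count_active_modes(n, k_cut):
    """Count complex Fourier representatives kept by a spherical cutoff.

    Assumed convention (hypothetical): integer wavevectors k on an n^3 FFT
    grid, keep 0 < |k| < k_cut, one representative per conjugate pair (k, -k).
    """
    if k_cut > n / 2:
        raise ValueError("k_cut must not exceed the Nyquist frequency n/2")
    freqs = range(-(n // 2), n - n // 2)  # FFT frequencies for even or odd n
    n_kept = 0
    for kx in freqs:
        for ky in freqs:
            for kz in freqs:
                k2 = kx * kx + ky * ky + kz * kz
                if 0 < k2 < k_cut * k_cut:
                    n_kept += 1
    # For a real field, the mode at -k is the complex conjugate of the mode at k,
    # so only half of the kept nonzero modes are independent.
    return n_kept // 2

n_complex = count_active_modes(8, 1.5)
n_real = 2 * n_complex  # real and imaginary part of each representative
print(n_complex, n_real, n_real + 1)  # +1 for the global log_A; prints: 9 18 19
```

With k_cut ≤ n/2 and a strict inequality, no kept mode touches the Nyquist plane, so every kept k has its partner −k inside the sphere and the pairing is exact.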

[Figure: combined posterior samples]

Notebook on GitHub: DiscoDJ-Tesseract-Demo/discodj_tesseract_example/enzyme_fgpa_loga/field_level_fgpa_loga_demo.ipynb (FloList/DiscoDJ-Tesseract-Demo, branch main)

The example is meant to demonstrate:

  • Automatic differentiation through Fortran code using Enzyme
  • Field-level Hamiltonian Monte Carlo over a high-dimensional latent field together with a global model parameter
  • Coupling multiple simulation components into a differentiable inference pipeline using multiple Tesseracts

Happy to hear any thoughts or suggestions 🙂

Author: Florian List
