Tesseract “recipes” provide ready-to-go templates that simplify the writing of API files and automate the generation of auto-diff endpoints using powerful third-party frameworks such as JAX. This example demonstrates how to use the Tesseract JAX recipe for fitting radial basis functions.
The problem
We shall attempt to fit a univariate function $f(x) \to y$ using a weighted sum of radial basis functions, trained on array-like data, i.e.

$$\hat{f}(x) = \sum_{i=1}^{N} w_i \, \phi(x - c_i) \approx y,$$

where the $c_i$ are the basis-function centers and the $w_i$ are their weights. We shall use Gaussian radial basis functions with a shared length-scale $s$:

$$\phi(r) = \frac{1}{s\sqrt{2\pi}} \exp\!\left(-\frac{r^2}{2 s^2}\right).$$
Initialisation
Simply activate a virtual environment in which tesseract-core
is installed and navigate to a new empty directory, then run
tesseract init --recipe jax --name rbf_fitting
You should now see the following pre-populated files in your working directory:
- tesseract_api.py
- tesseract_config.yaml
- tesseract_requirements.txt
Input and Output Schema
If you are happy to use CPU JAX on your local machine, all that is required is editing tesseract_api.py. We want to train the weights (and perhaps the length-scale) of our radial basis functions for a given set of centers, target points, and target values.
from pydantic import BaseModel

from tesseract_core.runtime import Array, Differentiable, Float32


class InputSchema(BaseModel):
    x_centers: Array[(None,), Float32]
    weights: Differentiable[Array[(None,), Float32]]
    scale: Differentiable[Float32]
    x_target: Array[(None,), Float32]
    y_target: Array[(None,), Float32]
Note that the shape annotation (None,)
here simply specifies one-dimensional arrays of arbitrary length. Descriptions and model validators can easily be added to the schema (as in the tesseract_api.py file within the zip file attached below).
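As a rough illustration (not necessarily the exact contents of the attached tesseract_api.py), descriptions and a simple cross-field check might look like the following, assuming Pydantic v2-style validators and that array fields expose `.shape` after validation:

```python
from pydantic import BaseModel, Field, model_validator

from tesseract_core.runtime import Array, Differentiable, Float32


class InputSchema(BaseModel):
    x_centers: Array[(None,), Float32] = Field(description="Centers of the radial basis functions.")
    weights: Differentiable[Array[(None,), Float32]] = Field(description="Weights of the radial basis functions.")
    scale: Differentiable[Float32] = Field(description="Length-scale shared by all basis functions.")
    x_target: Array[(None,), Float32] = Field(description="Training input points.")
    y_target: Array[(None,), Float32] = Field(description="Training target values.")

    @model_validator(mode="after")
    def check_shapes(self):
        # One weight per center, and one target value per target point
        if self.weights.shape != self.x_centers.shape:
            raise ValueError("weights and x_centers must have the same shape")
        if self.y_target.shape != self.x_target.shape:
            raise ValueError("y_target and x_target must have the same shape")
        return self
```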
For output we can provide both the predicted y values and a loss function (mean squared error):
class OutputSchema(BaseModel):
    y_pred: Differentiable[Array[(None,), Float32]]
    mse: Differentiable[Float32]
Core functionality
We need to implement the RBF approximation for a single target:
import jax
import jax.numpy as jnp
from jax.scipy.stats.norm import pdf as rbf  # Gaussian kernel = normal pdf


def rbf_approx(x, c, w, s):
    # Weighted sum of Gaussian basis functions centered at c, evaluated at a single point x
    return jnp.sum(w * rbf(x - c, scale=s))
And then extend this to work for an array of targets using a vectorizing map:
def rbf_approx_vmapped(x, c, w, s):
    # Vectorize over the array of target points x
    return jax.vmap(rbf_approx, in_axes=(0, None, None, None))(x, c, w, s)
The loss function is simply defined as:
def mse(y_target, y_pred):
    return jnp.mean((y_target - y_pred) ** 2)
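To build intuition for what the auto-generated endpoints will differentiate, note that this loss is an ordinary JAX-traceable function of the differentiable inputs, so you can already take gradients of it locally. This throwaway sketch is not part of the template:

```python
def loss_fn(weights, scale, x_centers, x_target, y_target):
    # Same computation as apply_jit below, expressed as a pure function of arrays
    y_pred = rbf_approx_vmapped(x_target, x_centers, weights, scale)
    return mse(y_target, y_pred)

# Gradient of the scalar loss with respect to the weights (positional argument 0)
grad_weights_fn = jax.grad(loss_fn, argnums=0)
```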
Putting it all together: the apply_jit function
In the JAX recipe, instead of implementing the apply function directly, we prescribe it in terms of a “jittable” function that takes a single Python dict conforming to InputSchema as input and returns a single Python dict conforming to OutputSchema (the reason being that Pydantic models complain when JAX passes tracer objects through them). A placeholder for apply_jit is already in the template; this needs to be replaced!
from jax import jit


@jit
def apply_jit(inputs: dict) -> dict:
    ordered_keys = ["x_target", "x_centers", "weights", "scale"]
    y_pred = rbf_approx_vmapped(*(inputs[key] for key in ordered_keys))
    return {"y_pred": y_pred, "mse": mse(inputs["y_target"], y_pred)}
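Before building, you can optionally sanity-check apply_jit in a plain Python session with arrays of any consistent shapes; the values below are arbitrary placeholders:

```python
example_inputs = {
    "x_centers": jnp.linspace(-3.0, 3.0, 20),
    "weights": jnp.zeros(20),
    "scale": jnp.float32(0.5),
    "x_target": jnp.linspace(-3.0, 3.0, 50),
    "y_target": jnp.sin(jnp.linspace(-3.0, 3.0, 50)),
}

outputs = apply_jit(example_inputs)
print(outputs["y_pred"].shape, float(outputs["mse"]))  # (50,) and a scalar MSE
```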
Auto-diff Endpoints
The JAX template already provides these for you; your Tesseract is ready to build!
However, just to be sure the Tesseract was implemented correctly before building, you can copy across the example*inputs.json files and test the endpoints we'll use below locally with the tesseract-runtime command:
$ pip install -r tesseract_requirements.txt
$ tesseract-runtime apply @example_inputs.json
{"y_pred":{"object_type":"array","shape":[50],"dtype":"float32","data":{"buffer":[4.1248459815979,2.457914352416992,-0.5994710922241211,-4.728244304656982,-9.438132286071777,-14.06363296508789,-17.672792434692383,-19.21015739440918,-18.01862907409668,-14.362250328063965,-9.364424705505371,-4.313163757324219,-0.008022308349609375,3.224073886871338,5.0479960441589355,4.952749729156494,2.7478413581848145,-0.8284816741943359,-4.185606956481934,-5.731773376464844,-4.820869445800781,-2.1094820499420166,0.8307108879089355,2.3777546882629395,1.7485218048095703,-0.592017650604248,-3.2679872512817383,-4.829357147216797,-4.522092342376709,-2.5778591632843018,0.04971001297235489,2.3230526447296143,3.759436845779419,4.6376423835754395,5.628350734710693,7.171075820922852,9.073172569274902,10.524554252624512,10.403829574584961,7.7067742347717285,2.034390926361084,-5.930616855621338,-14.228880882263184,-20.137781143188477,-21.48729705810547,-18.024473190307617,-11.644132614135742,-5.149086952209473,-0.5783791542053223,1.5724661350250244],"encoding":"json"}},"mse":{"object_type":"array","shape":[],"dtype":"float32","data":{"buffer":143.03921508789062,"encoding":"json"}}}
$ tesseract-runtime vector-jacobian-product @example_vjp_inputs.json
{"weights":{"object_type":"array","shape":[20],"dtype":"float32","data":{"buffer":[-9.413665771484375,-26.18354034423828,-42.3206787109375,-43.22483444213867,-28.82162857055664,-15.092391967773438,-14.821725845336914,-21.93976593017578,-22.008630752563477,-15.109295845031738,-8.894996643066406,-2.4718704223632812,4.410447120666504,8.857855796813965,10.989492416381836,6.1818318367004395,-7.348637580871582,-14.533407211303711,-6.561798095703125,1.6526353359222412],"encoding":"json"}}}
$ tesseract-runtime jacobian @example_jac_inputs.json
{"mse":{"weights":{"object_type":"array","shape":[20],"dtype":"float32","data":{"buffer":[-9.413665771484375,-26.18354034423828,-42.3206787109375,-43.22483444213867,-28.82162857055664,-15.092391967773438,-14.821725845336914,-21.93976593017578,-22.008630752563477,-15.109295845031738,-8.894996643066406,-2.4718704223632812,4.410447120666504,8.857855796813965,10.989492416381836,6.1818318367004395,-7.348637580871582,-14.533407211303711,-6.561798095703125,1.6526353359222412],"encoding":"json"}}}}
Interacting through the Python API
It is now time to build your Tesseract (tesseract build .). If you copy the pre-written example files optimization_routine.py, plotting.py and their requirements file all_requirements.txt (included in the zip attached below) to your working directory, they can be run with:
$ pip install -r opt_requirements.txt
$ python3 optimization_routine.py
Launching Tesseract
Starting ADAM optimization process for 100 iterations.
Completed in 0.78 seconds.
Starting L-BFGS-B optimization process for 100 iterations.
Completed in 0.55 seconds.
Starting Least Squares optimization process with Levenberg-Marquadt algorithm.
Completed in 0.13 seconds, with 5 calls to apply and 2 calls to jacobian.
All optimizations completed!
Plotting the results...
The resulting plots of the initial/final approximation and the ADAM and L-BFGS-B loss histories should appear in three separate windows:
Note that L-BFGS-B converges faster than the stochastic gradient descent method ADAM and runs in less time (for the same number of iterations). This is because we have a relatively small amount of training data (50 points) that we can train on simultaneously; SGD shines when the size of the training data makes this impossible. Furthermore, it is not at all surprising that the least-squares approach with the Levenberg-Marquardt algorithm converges essentially instantaneously (2 Jacobian calls), as fitting the RBF weights is a linear least-squares problem. Note, however, that the runtime per iteration is much larger due to the requirement to evaluate the entire Jacobian of the residuals with respect to the weights.
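As a short aside on why this is a linear problem (the matrix notation here is ours, not from the example files): with the centers and scale held fixed, the prediction is linear in the weights,

$$\hat{y}_j = \sum_i \Phi_{ji}\, w_i, \qquad \Phi_{ji} = \phi\!\left(x^{\mathrm{target}}_j - c_i\right),$$

so minimizing the sum of squared residuals is an ordinary linear least-squares problem whose minimizer satisfies the normal equations $\Phi^{\top}\Phi\, w = \Phi^{\top} y$. The local quadratic model used by Levenberg-Marquardt is then exact, which is why it converges after only a couple of Jacobian evaluations.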
We invite you to take a look through the example file(s) to follow how this is achieved. A summary of the key steps:
1. Set up initial conditions for the model and training data in a dict (see the sketch after this list).
2. Specify the inputs we want to optimize (just weights in the example).
3. Launch the Tesseract using the Python SDK context manager.
4. Perform the optimization routines.
5. Plot the results.
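As a rough sketch of step 1 (the actual values, number of points and ground-truth function live in optimization_routine.py; everything below is a placeholder, and we assume the SDK is happy to receive plain NumPy arrays keyed as in InputSchema):

```python
import numpy as np

n_centers, n_targets = 20, 50
x_target = np.linspace(-3.0, 3.0, n_targets, dtype=np.float32)
y_target = np.sin(2.0 * x_target)  # placeholder "ground truth" training data

inputs = {
    "x_centers": np.linspace(-3.0, 3.0, n_centers, dtype=np.float32),
    "weights": np.zeros(n_centers, dtype=np.float32),
    "scale": np.float32(0.5),
    "x_target": x_target,
    "y_target": y_target,
}

weights_0 = inputs["weights"].copy()  # starting point reused by each optimizer below
```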
We shall summarise steps 3 and 4 here for convenient reference:
import jax.numpy as jnp
import optax
from scipy.optimize import least_squares, minimize

# Import the Python SDK
from tesseract import Tesseract

max_iterations = 100

# `inputs`, `diff_inputs`, `y_target` and `weights_0` come from the setup steps above

# Launch Tesseract
with Tesseract.from_image(image="rbf_fitting") as tess:

    # Compute gradients using VJP
    def grad(inputs):
        vjp = tess.vector_jacobian_product(inputs, diff_inputs, ["mse"], {"mse": 1.0})
        return vjp[diff_inputs]

    # ADAM optimization
    optimizer = optax.adam(learning_rate=0.5)
    opt_state = optimizer.init(jnp.array(inputs["weights"]))
    for n_iteration in range(max_iterations):
        updates, opt_state = optimizer.update(grad(inputs), opt_state, inputs["weights"])
        inputs["weights"] = optax.apply_updates(inputs["weights"], updates)

    # L-BFGS-B optimization
    def apply_and_grad_wrapped(weights):
        inputs["weights"] = weights
        return tess.apply(inputs)["mse"], grad(inputs)

    lbfgsb_weights = minimize(
        apply_and_grad_wrapped,
        weights_0,
        method="L-BFGS-B",
        jac=True,
        options={"maxiter": max_iterations},
    ).x

    # scipy uses half the sum of squared residuals, convert to MSE
    normalize = jnp.sqrt(0.5 * len(y_target))

    def y_pred_wrapped(weights):
        inputs["weights"] = weights
        return (tess.apply(inputs)["y_pred"] - y_target) / normalize

    def jac_wrapped(weights):
        inputs["weights"] = weights
        return (
            tess.jacobian(inputs, ["weights"], ["y_pred"])["y_pred"]["weights"]
            / normalize
        )

    ls_weights = least_squares(
        y_pred_wrapped, weights_0, method="lm", jac=jac_wrapped, x_scale="jac"
    ).x
Note that as
mse
is a scalar variable we could just as easily access its gradient with respect to weights from the jacobian (i.e. tess.jacobian(inputs, ["weights"], ["mse"])["mse"]["weights"]
). However, we demonstrate using the VJP above as it is more efficient when calculating gradients of arbitrary scalar loss functions directly from arrays of residuals output by the Tesseract.
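For reference, a Jacobian-based version of the gradient helper, equivalent for this scalar output and intended to sit inside the same with Tesseract.from_image(...) block as above, could look like:

```python
def grad_via_jacobian(inputs):
    # d(mse)/d(weights); for a scalar output this matches the VJP with cotangent 1.0
    return tess.jacobian(inputs, ["weights"], ["mse"])["mse"]["weights"]
```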
Extensibility
The workflow outlined above is fully customisable, and individual “components” can easily be swapped in and out, such as:
- The radial basis function rbf (see the sketch after this list)
- The model itself (e.g. we could use a JAX-implemented neural network)
- The loss function
- The “ground truth” function we are trying to approximate
- The training data
- The optimizer
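For instance, swapping in a different radial basis function only requires redefining rbf with the same call signature as norm.pdf; the inverse multiquadric below is our own illustrative choice, not something used in the example files:

```python
import jax.numpy as jnp

def rbf(r, scale):
    # Inverse multiquadric basis function; same (r, scale=...) signature as norm.pdf
    return 1.0 / jnp.sqrt(1.0 + (r / scale) ** 2)
```

With this in place, rbf_approx and everything downstream (including the auto-diff endpoints) continues to work unchanged.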
Example Files
rbf_fitting.zip (162.5 KB)