JAX Auto-diff templates: Gaussian radial basis function fitting

Tesseract “recipes” provide ready-to-go templates that simplify writing API files and automate the generation of auto-diff endpoints using powerful third-party frameworks such as JAX. This example demonstrates how to use the Tesseract JAX recipe to fit radial basis functions.

The problem

We shall attempt to fit a univariate function f(x)\to y using a weighted sum of radial basis functions, trained on array-like data, i.e.

f(x)\approx \sum_{i=0}^N w_i \cdot\phi(x-c_i).

We shall use Gaussian radial basis functions:

\phi(x) = \frac{\exp(-x^2/2\sigma^2)}{\sqrt{2\pi\sigma^2}}
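
This is simply the probability density function of a zero-mean normal distribution with standard deviation \sigma, which is why the core functionality below can reuse jax.scipy.stats.norm.pdf as the basis function. A minimal check, assuming \sigma = 1 for illustration:

import jax.numpy as jnp
from jax.scipy.stats import norm

sigma = 1.0  # assumed length-scale, for illustration only
x = jnp.linspace(-3.0, 3.0, 7)

# The Gaussian RBF above is exactly the normal pdf with scale sigma
phi = jnp.exp(-x**2 / (2 * sigma**2)) / jnp.sqrt(2 * jnp.pi * sigma**2)
assert jnp.allclose(phi, norm.pdf(x, scale=sigma))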

Initialisation

Simply activate a virtual environment in which tesseract-core is installed and navigate to a new empty directory, then run

tesseract init --recipe jax --name rbf_fitting

You should now see the following pre-populated files in your working directory:

  • tesseract_api.py
  • tesseract_config.yaml
  • tesseract_requirements.txt

Input and Output Schema

If you are happy to use CPU JAX on your local machine, all that is required is editing tesseract_api.py. We want to train the weights (and perhaps the length-scale) of our radial basis functions for a given set of centers, target locations and target values.

from pydantic import BaseModel
from tesseract_core.runtime import Array, Differentiable, Float32

class InputSchema(BaseModel):
    x_centers: Array[(None,), Float32]
    weights: Differentiable[Array[(None,), Float32]]
    scale: Differentiable[Float32]
    x_target: Array[(None,), Float32]
    y_target: Array[(None,), Float32]

Note that the shape annotation (None,) here simply specifies one-dimensional arrays of arbitrary length. Descriptions and model validators can easily be added to the schema (as in the tesseract_api.py file within the zip file attached below).
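
For instance, a field description and a simple consistency check might look like the sketch below (this is only illustrative and not necessarily identical to the attached file):

from pydantic import BaseModel, Field, model_validator
from tesseract_core.runtime import Array, Float32

class InputSchema(BaseModel):
    x_target: Array[(None,), Float32] = Field(description="Locations of the training targets")
    y_target: Array[(None,), Float32] = Field(description="Values of the training targets")
    # ... remaining fields as above ...

    @model_validator(mode="after")
    def check_target_lengths(self):
        # Each target location needs a corresponding target value
        if self.x_target.shape != self.y_target.shape:
            raise ValueError("x_target and y_target must have the same length")
        return self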

For output we can provide both the predicted y values and a loss function (mean squared error):

class OutputSchema(BaseModel):
    y_pred: Differentiable[Array[(None,), Float32]]
    mse: Differentiable[Float32]

Core functionality

We need to implement the RBF approximation for a single target:

import jax
import jax.numpy as jnp
from jax.scipy.stats.norm import pdf as rbf

def rbf_approx(x, c, w, s):
    return jnp.sum(w * rbf(x - c, scale=s))

And then extend this to work for an array of targets using a vectorizing map:

def rbf_approx_vmapped(x, c, w, s):
    return jax.vmap(rbf_approx, in_axes=(0, None, None, None))(x, c, w, s)

The loss function is simply defined as:

def mse(y_target, y_pred):
    return jnp.mean((y_target - y_pred) ** 2)
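
A quick shape check (with hypothetical sizes) confirms that the vectorised approximation and the loss behave as expected:

# Hypothetical sizes: 5 centres/weights, 8 target points
c, w = jnp.linspace(-1.0, 1.0, 5), jnp.ones(5)
x, y = jnp.linspace(-1.0, 1.0, 8), jnp.zeros(8)

y_pred = rbf_approx_vmapped(x, c, w, 0.5)
print(y_pred.shape)    # (8,): one prediction per target
print(mse(y, y_pred))  # scalar loss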

Putting it all together—The apply_jit function

In the JAX recipe, instead of implementing the apply function directly, we prescribe it in terms of a “jittable” function that takes a single Python dict conforming to InputSchema as input and returns a single Python dict conforming to OutputSchema (this is because Pydantic models complain when JAX passes tracer objects through them). A placeholder for apply_jit is already in the template; this needs to be replaced!

from jax import jit

@jit
def apply_jit(inputs: dict) -> dict:
    ordered_keys = ["x_target", "x_centers", "weights", "scale"]
    y_pred = rbf_approx_vmapped(*(inputs[key] for key in ordered_keys))
    return {"y_pred": y_pred, "mse": mse(inputs["y_target"], y_pred)}

Auto-diff Endpoints

The JAX template already provides these for you; your Tesseract is ready to build! :brick:

However, just to be sure the Tesseract was implemented correctly before building, you can copy across the example input files (example_inputs.json, example_vjp_inputs.json and example_jac_inputs.json) and test the endpoints we’ll use below locally with the tesseract-runtime command:

$ pip install -r tesseract_requirements.txt
$ tesseract-runtime apply @example_inputs.json
{"y_pred":{"object_type":"array","shape":[50],"dtype":"float32","data":{"buffer":[4.1248459815979,2.457914352416992,-0.5994710922241211,-4.728244304656982,-9.438132286071777,-14.06363296508789,-17.672792434692383,-19.21015739440918,-18.01862907409668,-14.362250328063965,-9.364424705505371,-4.313163757324219,-0.008022308349609375,3.224073886871338,5.0479960441589355,4.952749729156494,2.7478413581848145,-0.8284816741943359,-4.185606956481934,-5.731773376464844,-4.820869445800781,-2.1094820499420166,0.8307108879089355,2.3777546882629395,1.7485218048095703,-0.592017650604248,-3.2679872512817383,-4.829357147216797,-4.522092342376709,-2.5778591632843018,0.04971001297235489,2.3230526447296143,3.759436845779419,4.6376423835754395,5.628350734710693,7.171075820922852,9.073172569274902,10.524554252624512,10.403829574584961,7.7067742347717285,2.034390926361084,-5.930616855621338,-14.228880882263184,-20.137781143188477,-21.48729705810547,-18.024473190307617,-11.644132614135742,-5.149086952209473,-0.5783791542053223,1.5724661350250244],"encoding":"json"}},"mse":{"object_type":"array","shape":[],"dtype":"float32","data":{"buffer":143.03921508789062,"encoding":"json"}}}
$ tesseract-runtime vector-jacobian-product @example_vjp_inputs.json
{"weights":{"object_type":"array","shape":[20],"dtype":"float32","data":{"buffer":[-9.413665771484375,-26.18354034423828,-42.3206787109375,-43.22483444213867,-28.82162857055664,-15.092391967773438,-14.821725845336914,-21.93976593017578,-22.008630752563477,-15.109295845031738,-8.894996643066406,-2.4718704223632812,4.410447120666504,8.857855796813965,10.989492416381836,6.1818318367004395,-7.348637580871582,-14.533407211303711,-6.561798095703125,1.6526353359222412],"encoding":"json"}}}
$ tesseract-runtime jacobian @example_jac_inputs.json
{"mse":{"weights":{"object_type":"array","shape":[20],"dtype":"float32","data":{"buffer":[-9.413665771484375,-26.18354034423828,-42.3206787109375,-43.22483444213867,-28.82162857055664,-15.092391967773438,-14.821725845336914,-21.93976593017578,-22.008630752563477,-15.109295845031738,-8.894996643066406,-2.4718704223632812,4.410447120666504,8.857855796813965,10.989492416381836,6.1818318367004395,-7.348637580871582,-14.533407211303711,-6.561798095703125,1.6526353359222412],"encoding":"json"}}}}

Interacting through the Python API

It is now time to build your Tesseract (tesseract build .). If you copy the pre-written example files optimization_routine.py and plotting.py, along with the requirements file all_requirements.txt (all attached below), to your working directory, they can be run by:

$ pip install -r opt_requirements.txt
$ python3 optimization_routine.py
Launching Tesseract
Starting ADAM optimization process for 100 iterations.
Completed in 0.78 seconds.
Starting L-BFGS-B optimization process for 100 iterations.
Completed in 0.55 seconds.
Starting Least Squares optimization process with Levenberg-Marquadt algorithm.
Completed in 0.13 seconds, with 5 calls to apply and 2 calls to jacobian.
All optimizations completed!
Plotting the results...

The resulting plots of the initial/final approximation and the ADAM and L-BFGS-B loss histories should appear in three separate windows:

Note that L-BFGS-B converges faster than the stochastic gradient descent method ADAM and runs in less time (for the same number of iterations). This is because we have a relatively small amount of training data (50 points) that we can train on simultaneously; SGD shines when the size of the training data makes this impossible. Furthermore, it is not at all surprising that the least squares approach with the Levenberg-Marquardt algorithm converges essentially instantaneously (2 jacobian calls), as fitting the RBF weights is a linear least-squares problem. Note, however, that the runtime per iteration is much larger due to the requirement to evaluate the entire Jacobian of residuals with respect to weights.

We invite you to take a look through the example file(s) to follow how this is achieved. A summary of the key steps is:

  1. Set up initial conditions for the model and training data in a dict.
  2. Specify the inputs we want to optimize (just weights in the example); a sketch of steps 1 and 2 follows this list.
  3. Launch the Tesseract using the Python SDK context manager.
  4. Perform optimization routines.
  5. Plot results.
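
To keep the snippet below self-contained, here is a minimal sketch of steps 1 and 2 (all concrete values are hypothetical; the sizes of 50 target points and 20 centres simply match the example outputs above):

import numpy as np

# Step 1: training data and initial model parameters, gathered in a dict
# matching InputSchema (values below are illustrative only)
x_target = np.linspace(-5.0, 5.0, 50, dtype=np.float32)
y_target = (x_target * np.sin(2.0 * x_target)).astype(np.float32)  # stand-in ground truth

weights_0 = np.zeros(20, dtype=np.float32)
inputs = {
    "x_centers": np.linspace(-5.0, 5.0, 20, dtype=np.float32),
    "weights": weights_0,
    "scale": 0.5,
    "x_target": x_target,
    "y_target": y_target,
}

# Step 2: the inputs we want to optimize (just the weights here)
diff_inputs = ["weights"]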

We shall summarise step 4 here for convenient reference:

import jax.numpy as jnp
import optax
from scipy.optimize import least_squares, minimize

# Import the Python SDK
from tesseract import Tesseract

max_iterations = 100

# Launch Tesseract
with Tesseract.from_image(image="rbf_fitting") as tess:
    # Compute the gradient of mse w.r.t. the differentiable inputs using a VJP
    def grad(inputs):
        vjp = tess.vector_jacobian_product(inputs, diff_inputs, ["mse"], {"mse": 1.0})
        return vjp["weights"]

    # ADAM optimization
    optimizer = optax.adam(learning_rate=0.5)
    opt_state = optimizer.init(jnp.array(inputs["weights"]))

    for n_iteration in range(max_iterations):
        updates, opt_state = optimizer.update(grad(inputs), opt_state, inputs["weights"])
        inputs["weights"] = optax.apply_updates(inputs["weights"], updates)

    # L-BFGS-B optimization
    def apply_and_grad_wrapped(weights):
        inputs["weights"] = weights
        return tess.apply(inputs)["mse"], grad(inputs)

    lbfgsb_weights = minimize(
        apply_and_grad_wrapped,
        weights_0,
        method="L-BFGS-B",
        jac=True,
        options={"maxiter": max_iterations},
    ).x

    # scipy uses half sum of square residuals, convert to MSE
    normalize = jnp.sqrt(0.5 * len(y_target))

    def y_pred_wrapped(weights):
        inputs["weights"] = weights
        return (tess.apply(inputs)["y_pred"] - y_target) / normalize

    def jac_wrapped(weights):
        inputs["weights"] = weights
        return (
            tess.jacobian(inputs, ["weights"], ["y_pred"])["y_pred"]["weights"]
            / normalize
        )

    ls_weights = least_squares(
        y_pred_wrapped, weights_0, method="lm", jac=jac_wrapped, x_scale="jac"
    ).x

:information_source: Note that as mse is a scalar variable we could just as easily access its gradient with respect to weights from the jacobian (i.e. tess.jacobian(inputs, ["weights"], ["mse"])["mse"]["weights"]). However, we demonstrate using the VJP above as it is more efficient when calculating gradients of arbitrary scalar loss functions directly from arrays of residuals output by the Tesseract.

Extensibility

The workflow outlined above is fully customisable and individual “components” can easily be swapped in and out, such as:

  • The radial basis function rbf (see the sketch after this list)
  • The model itself (e.g. we could use a JAX-implemented neural network)
  • The loss function
  • The “ground truth” function we are trying to approximate
  • The training data
  • The optimizer
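
For example, swapping out the Gaussian kernel only requires redefining rbf with the same call signature; the inverse multiquadric below is purely illustrative:

import jax.numpy as jnp

# Hypothetical drop-in replacement for the Gaussian rbf: any function
# accepting (x, scale=...) can be used by rbf_approx unchanged
def rbf(x, scale=1.0):
    return 1.0 / jnp.sqrt(1.0 + (x / scale) ** 2)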

Example Files

rbf_fitting.zip (162.5 KB)
