Overview
This example explains how to implement a PyTorch-based Tesseract for the Rosenbrock function and use it in an optimization loop to find the global minimum starting from a random initial guess.
This Tesseract provides functionality for computing both the function's value and its Jacobian matrix.
The whole Tesseract can be downloaded here: pytorch_optimization.zip (11.6 KB)
Required Dependencies
The following additional dependencies are required to build and run this Tesseract:
- NumPy (numpy==1.26.4)
- PyTorch (torch==2.6.0)
- Matplotlib (matplotlib==3.10.0)
These can be stored in a tesseract_requirements.txt file.
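For reference, a minimal tesseract_requirements.txt matching the versions above would simply list the pinned packages:

numpy==1.26.4
torch==2.6.0
matplotlib==3.10.0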
tesseract_config.yaml
This file should be structured as follows:
name: "pytorch_optimization"
version: "1.0.0"
build_config:
  # Platform to build the container for. In general, images can only be executed
  # on the platform they were built for. For apple silicon use "linux/arm64"
  # target_platform: "native"
Code Breakdown
Importing Dependencies
import torch
from pydantic import BaseModel, Field
from tesseract_core.runtime import Differentiable, Float32
- torch: Provides tensor operations and automatic differentiation.
- pydantic: Used for input/output schema validation.
- tesseract_core.runtime: Defines custom types like Differentiable and Float32 for schema validation.
Implementing the Log-Rosenbrock Function
def log_rosenbrock(x: float, y: float, a: float = 1.0, b: float = 100.0):
    """The log-Rosenbrock function."""
    rosenbrock = (a - x) ** 2 + b * (y - x**2) ** 2
    rosenbrock = torch.as_tensor(rosenbrock)
    return torch.log(rosenbrock + 1e-5)
- Computes the classic Rosenbrock function, which has a global minimum at (x, y) = (a, a^2) (verified in the sanity check below).
- Converts the computed value into a PyTorch tensor.
- Applies a logarithm to stabilize numerical calculations, adding 1e-5 to prevent log(0) errors.
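As a quick sanity check (a sketch, assuming the function above is available in tesseract_api.py), evaluating at the global minimum of the default parameters (a=1, b=100) reduces to log(1e-5):

from tesseract_api import log_rosenbrock

# At the minimum (x, y) = (a, a^2) = (1, 1) the Rosenbrock term is 0,
# so the function returns log(0 + 1e-5), about -11.51.
print(log_rosenbrock(1.0, 1.0))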
Defining Input and Output Schemas
class InputSchema(BaseModel):
    x: Differentiable[Float32] = Field(description="X-value of inputs")
    y: Differentiable[Float32] = Field(description="Y-value of inputs")
    a: Float32
    b: Float32
- Defines an InputSchema that validates user inputs.
- x and y are differentiable floating-point values.
- a and b are parameters that control the function shape.
class OutputSchema(BaseModel):
    loss: Differentiable[Float32] = Field(description="Rosenbrock loss function value")
- Defines an OutputSchema to wrap the computed loss value.
Computing the Function Value
def apply(inputs: InputSchema) -> OutputSchema:
    loss = log_rosenbrock(
        inputs.x,
        inputs.y,
        inputs.a,
        inputs.b,
    )
    return OutputSchema(loss=loss)
- Takes InputSchema as input and computes the loss via the log-Rosenbrock function.
- Returns the result wrapped in OutputSchema.
Computing the Jacobian Matrix
def jacobian(
    inputs: InputSchema,
    jac_inputs: set[str],
    jac_outputs: set[str],
):
    inputs.x = torch.as_tensor(inputs.x)
    inputs.y = torch.as_tensor(inputs.y)
    inputs.a = torch.as_tensor(inputs.a)
    inputs.b = torch.as_tensor(inputs.b)
    for key in jac_inputs:
        setattr(
            inputs,
            key,
            torch.nn.Parameter(getattr(inputs, key)),
        )
    jac_result = {dy: {} for dy in jac_outputs}
    with torch.enable_grad():
        output = log_rosenbrock(inputs.x, inputs.y, inputs.a, inputs.b)
        for dx in jac_inputs:
            grads = torch.autograd.grad(output, getattr(inputs, dx), retain_graph=True)[0]
            for dy in jac_outputs:
                jac_result[dy][dx] = grads
    return jac_result
- Converts input values to PyTorch tensors.
- Marks the specified input variables as differentiable.
- Computes the Jacobian matrix using torch.autograd.grad.
- Returns a nested dictionary mapping each requested output to its gradients with respect to each requested input (the structure is sketched below).
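For instance, with jac_inputs={"x", "y"} and jac_outputs={"loss"}, the result has the nested form (placeholders shown instead of actual values):

# {"loss": {"x": tensor(d_loss_dx), "y": tensor(d_loss_dy)}}

where each entry is the gradient of that output with respect to that input.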
Building the Tesseract
Go to the folder where the Tesseract code lives and run:
tesseract build .
Usage Examples
Evaluating the Function
inputs = InputSchema(x=1.0, y=2.0, a=1.0, b=100.0)
result = apply(inputs)
print(result.loss)
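For these inputs the Rosenbrock value is (1 - 1)^2 + 100 * (2 - 1)^2 = 100, so the printed loss should be approximately log(100), about 4.61.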
Computing the Jacobian
jac_result = jacobian(inputs, jac_inputs={"x", "y"}, jac_outputs={"loss"})
print(jac_result)
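For these inputs the analytic gradients of the log-loss are d(loss)/dx = -400/100 = -4.0 and d(loss)/dy = 200/100 = 2.0 (the 1e-5 offset is negligible here), so the printed dictionary should be close to {"loss": {"x": -4.0, "y": 2.0}}.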
Optimization Pipeline with Tesseract
The following code snippet demonstrates how to optimize the Rosenbrock function using PyTorch and Tesseract:
import argparse
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch._dynamo
torch._dynamo.config.disable = True
from tesseract_api import log_rosenbrock
from tesseract_core import Tesseract
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--niter", help="The number of iterations", default=300)
parser.add_argument("-l", "--lr", help="The learning rate", default=1e-2)
parser.add_argument("-s", "--seed", help="The random seed", default=100)
parser.add_argument("--jac_inputs", nargs="+", help="The variables to optimize, e.g. 'x y'", default=["x", "y"])
parser.add_argument("--opt", help="The optimizer to use, ['SGD', 'Adam']", default="Adam")
args = vars(parser.parse_args())
torch.manual_seed(int(args["seed"]))
x0 = torch.nn.Parameter(torch.randn(1) * 2)
y0 = torch.nn.Parameter(torch.randn(1) * 2)
jac_inputs = args["jac_inputs"]
a = 1.0
b = 100.0
inputs = {"x": x0.detach().item(), "y": y0.detach().item(), "a": a, "b": b}
assert args["opt"] in ["SGD", "Adam"], "Only Adam or SGD optimizer supported."
opt = getattr(torch.optim, args["opt"])((x0, y0), lr=float(args["lr"]))
losses = []
x = []
y = []
with Tesseract.from_image(image="pytorch_optimization") as pytorch_optimization:
    for _ in range(int(args["niter"])):
        opt.zero_grad()
        output = pytorch_optimization.apply(inputs)
        loss = output["loss"]
        losses.append(loss)
        x.append(inputs["x"])
        y.append(inputs["y"])
        jacobian_response = pytorch_optimization.jacobian(inputs=inputs, jac_inputs=jac_inputs, jac_outputs=["loss"])
        if "x" in jac_inputs:
            x0.grad = torch.as_tensor(jacobian_response["loss"]["x"]).float().unsqueeze(0)
        if "y" in jac_inputs:
            y0.grad = torch.as_tensor(jacobian_response["loss"]["y"]).float().unsqueeze(0)
        opt.step()
        inputs["x"] = x0.detach().item()
        inputs["y"] = y0.detach().item()
X, Y = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(-3, 3, 100), indexing="xy")
plt.imshow(log_rosenbrock(X, Y, a, b), extent=[-3, 3, 3, -3], cmap="Blues_r")
plt.colorbar()
plt.scatter(x, y, c=np.arange(len(x)))
plt.scatter([a], [a**2], marker="*", s=500, color="r")
plt.xlabel("X [-]")
plt.ylabel("Y [-]")
plt.show()
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].plot(losses)
axes[0].set_title("loss")
axes[0].set_xlabel("iteration")
axes[1].plot(x)
axes[1].set_title("X")
axes[1].set_xlabel("iteration")
axes[1].axhline(a, c="r", lw=3, alpha=0.5)
axes[2].plot(y)
axes[2].set_title("Y")
axes[2].set_xlabel("iteration")
axes[2].axhline(a**2, c="r", lw=3, alpha=0.5)
plt.show()
This is the outcome of the optimization process, where the correct minimum (red star) is reached:
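Assuming the snippet above is saved as a standalone script (the filename, e.g. optimize.py, is arbitrary), it can be run with different settings via the command-line flags it defines, for instance: python optimize.py --niter 500 --lr 0.01 --opt SGD --jac_inputs x y.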