Adding a new language
In this example we will go through the process of adding a new language to gotranx. We will do this by adding support for jax.
Note
If you end up adding support for your own language, we would be really glad if you could submit a PR so that we can include it as part of gotranx.
Adding some tests
The first thing you might want to do is add a test that checks that your implementation works correctly. For example, you could load and solve the Lorenz system (whose solutions trace out the Lorenz attractor) using the new Jax code generator. That code could look as follows:
from pathlib import Path
import time

import numpy as np
import matplotlib.pyplot as plt
import jax

import gotranx
from codegen import JaxCodeGenerator

# Define an ODE (the Lorenz system)
ode_fname = Path("lorentz.ode")
ode_fname.write_text("""
parameters(sigma=12.0, rho=21.0, beta=2.4)
states(x=1.0, y=2.0,z=3.05)
dx_dt = sigma * (y - x)
dy_dt = x * (rho - z) - y # m/s
dz_dt = x * y - beta * z
""")

# Load ode
ode = gotranx.load_ode(ode_fname)

# Generate code with our custom jax code generator (defined in codegen.py)
codegen = JaxCodeGenerator(ode)
comp = [
    codegen.imports(),
    codegen.parameter_index(),
    codegen.state_index(),
    codegen.monitor_index(),
    codegen.missing_index(),
    codegen.initial_parameter_values(),
    codegen.initial_state_values(),
    codegen.rhs(),
    codegen.monitor_values(),
    codegen.scheme(gotranx.schemes.get_scheme("forward_explicit_euler")),
]
code = codegen._format("\n".join(comp))

# Execute the generated code; the functions it defines end up in `model`
model = {}
exec(code, model)

# Initial values
y = model["init_state_values"](x=1.0, y=2.0, z=3.05)
p = model["init_parameter_values"](sigma=12.0, rho=21.0, beta=2.4)
dt = 0.005
T = 200
t = np.arange(0, T, dt)


# We will integrate the system using the jax.lax.scan function,
# so we need a function that is called at each time step.
# The carry holds (current time, current state). The scanned value (an
# entry of `t`) is not needed, because time is tracked in the carry.
@jax.jit
def step(carry, _t):
    t, y = carry
    y = model["forward_explicit_euler"](y, t, dt, p)
    return (t + dt, y), y


# Integrate the system and measure the elapsed time.
# jax dispatches work asynchronously, so wait for the result before
# stopping the clock. The measured time also includes tracing/compilation.
t0 = time.perf_counter()
_, state = jax.lax.scan(step, (0.0, y), t)
state.block_until_ready()
t1 = time.perf_counter()
print(f"Elapsed time: {t1 - t0} s")

# Finally, plot the results
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.plot(state[:, 0], state[:, 1], state[:, 2], lw=0.5)
ax.set_xlabel("X Axis")
ax.set_ylabel("Y Axis")
ax.set_zlabel("Z Axis")
ax.set_title("Lorenz Attractor")
plt.show()
2024-11-05 20:05:34 [info ] Load ode lorentz.ode
2024-11-05 20:05:35 [info ] Num states 3
2024-11-05 20:05:35 [info ] Num parameters 3
Elapsed time: 0.05233125300000552 s
Here we use gotranx to load the ODE, and our own code generator for jax, which is implemented in a module called codegen.
Note
You could also fork the repository and add your code generator to the gotranx project. Have a look at the project to see how the jax code generator has been added there.
Adding the code generator
The file codegen.py contains the following lines:
import sympy
from gotranx.codegen.python import PythonCodeGenerator, GotranPythonCodePrinter
import template
class JaxPrinter(GotranPythonCodePrinter):
    """Python code printer that emits jax-friendly assignments.

    Because jax arrays are immutable, an assignment into an entry of the
    indexed ``values`` array is printed as a plain scalar variable named
    ``_values_<index>``. Every other assignment is printed by the parent
    printer unchanged.
    """

    def _print_Assignment(self, expr):
        lhs, rhs = expr.lhs, expr.rhs
        targets_values = (
            isinstance(lhs, sympy.tensor.indexed.Indexed) and lhs.base.name == "values"
        )
        if not targets_values:
            return super()._print_Assignment(expr)
        # e.g. values[0] = rhs  ->  _values_0 = rhs
        idx = self._print(lhs.indices[0])
        return f"_{lhs.base.name}_{idx} = {self._print(rhs)}"
class JaxCodeGenerator(PythonCodeGenerator):
    """Code generator producing jax-compatible Python code.

    Everything is inherited from ``PythonCodeGenerator`` except three
    things: the code printer, the import header and the template module.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Use the jax-aware printer (overrides any printer the parent assigned).
        self._printer = JaxPrinter()

    def imports(self) -> str:
        """Return the import header placed at the top of the generated code.

        ``jax.numpy`` is imported under the name ``numpy`` so that the
        generated code can keep calling ``numpy.<function>``. 64-bit floats
        are enabled because jax uses float32 by default.
        """
        header = (
            "import jax",
            "import jax.numpy as numpy",
            'jax.config.update("jax_enable_x64", True)',
        )
        return "\n".join(header)

    @property
    def template(self):
        """The module (``template.py``) that provides the code templates."""
        return template
There are two classes in this module: the code printer JaxPrinter and the code generator JaxCodeGenerator. As you will see, the code we want to generate is very similar to the code that is already generated by the Python code generator. Therefore we will subclass the existing classes from gotranx.
Code printers
The class JaxPrinter inherits from GotranPythonCodePrinter, which in turn inherits from sympy.printing.pycode.PythonCodePrinter. This is essentially where all the magic happens: our ODE is just a collection of sympy expressions, which we can print with the code printers provided by SymPy. Please check out the official SymPy documentation to see the existing code printers and how they can be adapted to your needs.
The only modification we make to the code printer is to change how entries of the values array are assigned. In the original Python code we first allocate a numpy array with a given shape, and then fill that array with values as we go through the function, e.g.
def rhs(t, states, parameters):
    # Illustration of the default Python output (assumes numpy is imported).
    # The output array is allocated up front with the same shape as `states`...
    values = numpy.zeros_like(states)
    # ...and then filled entry by entry, in place.
    values[0] = 42
    values[1] = 43
    ...  # one such assignment per state
    return values
However, this is not possible in jax, since jax arrays are immutable. What we would like to do instead (note that there are many ways to handle this) is the following:
import jax
import jax.numpy as numpy


@jax.jit
def rhs(t, states, parameters):
    # jax arrays are immutable, so instead of filling a preallocated array
    # each right-hand-side entry is stored in its own scalar variable...
    _values_0 = 42
    _values_1 = 43
    # ... one `_values_<i>` assignment per state
    # ...and the result array is built once, at the very end.
    return numpy.array([
        _values_0,
        _values_1,
    ])
For now, ignore the imports, the decorator and the return statement. The method _print_Assignment is called whenever sympy prints an assignment:
class JaxPrinter(GotranPythonCodePrinter):
    """Python code printer that emits jax-friendly assignments.

    Because jax arrays are immutable, an assignment into an entry of the
    indexed ``values`` array is printed as a plain scalar variable named
    ``_values_<index>``. Every other assignment is printed by the parent
    printer unchanged.
    """

    def _print_Assignment(self, expr):
        lhs, rhs = expr.lhs, expr.rhs
        targets_values = (
            isinstance(lhs, sympy.tensor.indexed.Indexed) and lhs.base.name == "values"
        )
        if not targets_values:
            return super()._print_Assignment(expr)
        # e.g. values[0] = rhs  ->  _values_0 = rhs
        idx = self._print(lhs.indices[0])
        return f"_{lhs.base.name}_{idx} = {self._print(rhs)}"
Here we check whether the left-hand side of the assignment is indexed and has the name values. If it does, we change it from the default printing, which is
values[index] = rhs
to the new form
_values_index = rhs
(for example, values[0] = rhs becomes _values_0 = rhs).
Code generator
The class JaxCodeGenerator is a subclass of PythonCodeGenerator, and thereby of the abstract base class CodeGenerator, which implements the functionality to generate all the methods we need. For example, we need one method to generate code for the right-hand side (i.e. def rhs(t, states, parameters)), and another method to generate code for initializing the default parameter values. For both of these methods we can use the same printer to convert the sympy expressions into code, but each method needs its own logic to decide what goes into it.
We want the class JaxCodeGenerator to change three things compared to PythonCodeGenerator. First, we set the code printer to be the JaxPrinter that we just discussed. Next, we add the imports that we need at the top of the generated code (we also enable float64, since jax arrays are float32 by default). Finally, we set the template to be our custom template, which we will implement in a second module called template.py (in the next section). This module implements the templates for the different methods that we want to generate.
Adding a new template
Now we will implement the template in template.py, which needs to follow the Template protocol. In other words, we need to implement all the methods that are part of the Template class. Our template will be pretty close to the python template.
For state_index, parameter_index, monitor_index and missing_index we will use the same functions as the python template, so we simply import those directly. We will also import the acc function, which is used to build the comment listing the default values in init_state_values:
from __future__ import annotations
from textwrap import dedent, indent
import functools
from structlog import get_logger
from gotranx.templates.python import acc, state_index, parameter_index, monitor_index, missing_index
# Public interface of this template module. The *_index templates are
# re-exported unchanged from gotranx's Python template; the remaining
# names are the jax-specific templates defined in this module.
__all__ = [
    "init_state_values",
    "init_parameter_values",
    "method",
    "parameter_index",
    "state_index",
    "monitor_index",
    "missing_index",
]

# structlog logger used for debug messages emitted while generating code.
logger = get_logger()
Next we will implement a new version of init_state_values, which is a slightly modified version of the function from the python template:
def init_state_values(name, state_names, state_values, code):
    """Return the source code of a jitted ``init_state_values`` function.

    Parameters
    ----------
    name : str
        Variable name used for the state array inside the generated function.
    state_names : sequence of str
        Names of the states, in the same order as ``state_values``.
    state_values : sequence
        Default state values; each one is rendered with ``str``.
    code : str
        Not used by this template (presumably required by the Template
        protocol signature).

    Returns
    -------
    str
        Source defining ``init_state_values(**values)``. The generated
        function builds the default array and overrides the entries passed as
        keyword arguments using jax's functional ``.at[...].set(...)`` update.
        It calls ``state_index``, which must also exist in the generated module.
    """
    logger.debug(f"Generating init_state_values with {len(state_values)} values")
    assignments = [f"{n}={v}" for n, v in zip(state_names, state_values)]
    # functools.reduce raises TypeError on an empty sequence when no initial
    # value is given, so an ODE without states would crash here. Emit a bare
    # comment instead.
    values_comment = indent(
        "#" + (functools.reduce(acc, assignments) if assignments else ""),
        "    ",
    )
    values = ", ".join(map(str, state_values))
    return dedent(
        f'''
@jax.jit
def init_state_values(**values):
    """Initialize state values
    """
{values_comment}
    {name} = numpy.array([{values}], dtype=numpy.float64)
    for key, value in values.items():
        {name} = {name}.at[state_index(key)].set(value)
    return {name}
''',
    )
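Here we have added the @jax.jit decorator to the generated function. We also use the jax notation for updating the value at a specific index (array.at[index].set(value)), which returns a new array rather than modifying the existing one.
For init_parameter_values we do the same: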
def init_parameter_values(name, parameter_names, parameter_values, code):
    """Return the source code of a jitted ``init_parameter_values`` function.

    Parameters
    ----------
    name : str
        Variable name used for the parameter array inside the generated
        function.
    parameter_names : sequence of str
        Names of the parameters, in the same order as ``parameter_values``.
    parameter_values : sequence
        Default parameter values; each one is rendered with ``str``.
    code : str
        Not used by this template (presumably required by the Template
        protocol signature).

    Returns
    -------
    str
        Source defining ``init_parameter_values(**values)``. The generated
        function builds the default array and overrides the entries passed as
        keyword arguments using jax's functional ``.at[...].set(...)`` update.
        It calls ``parameter_index``, which must also exist in the generated
        module.
    """
    logger.debug(f"Generating init_parameter_values with {len(parameter_values)} values")
    assignments = [f"{n}={v}" for n, v in zip(parameter_names, parameter_values)]
    # functools.reduce raises TypeError on an empty sequence when no initial
    # value is given, so an ODE without parameters would crash here. Emit a
    # bare comment instead.
    values_comment = indent(
        "#" + (functools.reduce(acc, assignments) if assignments else ""),
        "    ",
    )
    values = ", ".join(map(str, parameter_values))
    return dedent(
        f'''
@jax.jit
def init_parameter_values(**values):
    """Initialize parameter values
    """
{values_comment}
    {name} = numpy.array([{values}], dtype=numpy.float64)
    for key, value in values.items():
        {name} = {name}.at[parameter_index(key)].set(value)
    return {name}
''',
    )
Finally, for the method function we just need to make sure to create the correct return array, which should now be an array containing the values _values_0, _values_1, and so on.
def method(
    name,
    args,
    states,
    parameters,
    values,
    num_return_values: int,
    missing_variables: str = "",
    **kwargs,
):
    """Return the source code of a jitted function ``name(args)``.

    The generated body assigns the states and parameters, then any missing
    variables, then the expression code. The expression code is expected to
    define the scalars ``_values_0 .. _values_{num_return_values - 1}``.
    These are returned packed into a single ``numpy.array``, since jax
    arrays cannot be filled in place.

    Extra keyword arguments are accepted but ignored (they are only logged).
    """
    logger.debug(f"Generating method '{name}', with {num_return_values} return values.")
    if kwargs:
        logger.debug(f"Unused kwargs: {kwargs}")

    pad = "    "
    body_states, body_parameters, body_missing, body_values = (
        indent(section, pad)
        for section in (states, parameters, missing_variables, values)
    )
    entries = "".join(f"_values_{i}, " for i in range(num_return_values))
    body_return = indent(f"return numpy.array([{entries}])", pad)

    return dedent(
        f"""
@jax.jit
def {name}({args}):
    # Assign states
{body_states}
    # Assign parameters
{body_parameters}
{body_missing}
    # Assign expressions
{body_values}
{body_return}
""",
    )
Summary
# codegen.py
import sympy
from gotranx.codegen.python import PythonCodeGenerator, GotranPythonCodePrinter
import template
class JaxPrinter(GotranPythonCodePrinter):
    """Python code printer that emits jax-friendly assignments.

    Because jax arrays are immutable, an assignment into an entry of the
    indexed ``values`` array is printed as a plain scalar variable named
    ``_values_<index>``. Every other assignment is printed by the parent
    printer unchanged.
    """

    def _print_Assignment(self, expr):
        lhs, rhs = expr.lhs, expr.rhs
        targets_values = (
            isinstance(lhs, sympy.tensor.indexed.Indexed) and lhs.base.name == "values"
        )
        if not targets_values:
            return super()._print_Assignment(expr)
        # e.g. values[0] = rhs  ->  _values_0 = rhs
        idx = self._print(lhs.indices[0])
        return f"_{lhs.base.name}_{idx} = {self._print(rhs)}"
class JaxCodeGenerator(PythonCodeGenerator):
    """Code generator producing jax-compatible Python code.

    Everything is inherited from ``PythonCodeGenerator`` except three
    things: the code printer, the import header and the template module.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Use the jax-aware printer (overrides any printer the parent assigned).
        self._printer = JaxPrinter()

    def imports(self) -> str:
        """Return the import header placed at the top of the generated code.

        ``jax.numpy`` is imported under the name ``numpy`` so that the
        generated code can keep calling ``numpy.<function>``. 64-bit floats
        are enabled because jax uses float32 by default.
        """
        header = (
            "import jax",
            "import jax.numpy as numpy",
            'jax.config.update("jax_enable_x64", True)',
        )
        return "\n".join(header)

    @property
    def template(self):
        """The module (``template.py``) that provides the code templates."""
        return template
# template.py
from __future__ import annotations
from textwrap import dedent, indent
import functools
from structlog import get_logger
from gotranx.templates.python import acc, state_index, parameter_index, monitor_index, missing_index
# Public interface of this template module. The *_index templates are
# re-exported unchanged from gotranx's Python template; the remaining
# names are the jax-specific templates defined in this module.
__all__ = [
    "init_state_values",
    "init_parameter_values",
    "method",
    "parameter_index",
    "state_index",
    "monitor_index",
    "missing_index",
]

# structlog logger used for debug messages emitted while generating code.
logger = get_logger()
def init_state_values(name, state_names, state_values, code):
    """Return the source code of a jitted ``init_state_values`` function.

    Parameters
    ----------
    name : str
        Variable name used for the state array inside the generated function.
    state_names : sequence of str
        Names of the states, in the same order as ``state_values``.
    state_values : sequence
        Default state values; each one is rendered with ``str``.
    code : str
        Not used by this template (presumably required by the Template
        protocol signature).

    Returns
    -------
    str
        Source defining ``init_state_values(**values)``. The generated
        function builds the default array and overrides the entries passed as
        keyword arguments using jax's functional ``.at[...].set(...)`` update.
        It calls ``state_index``, which must also exist in the generated module.
    """
    logger.debug(f"Generating init_state_values with {len(state_values)} values")
    assignments = [f"{n}={v}" for n, v in zip(state_names, state_values)]
    # functools.reduce raises TypeError on an empty sequence when no initial
    # value is given, so an ODE without states would crash here. Emit a bare
    # comment instead.
    values_comment = indent(
        "#" + (functools.reduce(acc, assignments) if assignments else ""),
        "    ",
    )
    values = ", ".join(map(str, state_values))
    return dedent(
        f'''
@jax.jit
def init_state_values(**values):
    """Initialize state values
    """
{values_comment}
    {name} = numpy.array([{values}], dtype=numpy.float64)
    for key, value in values.items():
        {name} = {name}.at[state_index(key)].set(value)
    return {name}
''',
    )
def init_parameter_values(name, parameter_names, parameter_values, code):
    """Return the source code of a jitted ``init_parameter_values`` function.

    Parameters
    ----------
    name : str
        Variable name used for the parameter array inside the generated
        function.
    parameter_names : sequence of str
        Names of the parameters, in the same order as ``parameter_values``.
    parameter_values : sequence
        Default parameter values; each one is rendered with ``str``.
    code : str
        Not used by this template (presumably required by the Template
        protocol signature).

    Returns
    -------
    str
        Source defining ``init_parameter_values(**values)``. The generated
        function builds the default array and overrides the entries passed as
        keyword arguments using jax's functional ``.at[...].set(...)`` update.
        It calls ``parameter_index``, which must also exist in the generated
        module.
    """
    logger.debug(f"Generating init_parameter_values with {len(parameter_values)} values")
    assignments = [f"{n}={v}" for n, v in zip(parameter_names, parameter_values)]
    # functools.reduce raises TypeError on an empty sequence when no initial
    # value is given, so an ODE without parameters would crash here. Emit a
    # bare comment instead.
    values_comment = indent(
        "#" + (functools.reduce(acc, assignments) if assignments else ""),
        "    ",
    )
    values = ", ".join(map(str, parameter_values))
    return dedent(
        f'''
@jax.jit
def init_parameter_values(**values):
    """Initialize parameter values
    """
{values_comment}
    {name} = numpy.array([{values}], dtype=numpy.float64)
    for key, value in values.items():
        {name} = {name}.at[parameter_index(key)].set(value)
    return {name}
''',
    )
def method(
    name,
    args,
    states,
    parameters,
    values,
    num_return_values: int,
    missing_variables: str = "",
    **kwargs,
):
    """Return the source code of a jitted function ``name(args)``.

    The generated body assigns the states and parameters, then any missing
    variables, then the expression code. The expression code is expected to
    define the scalars ``_values_0 .. _values_{num_return_values - 1}``.
    These are returned packed into a single ``numpy.array``, since jax
    arrays cannot be filled in place.

    Extra keyword arguments are accepted but ignored (they are only logged).
    """
    logger.debug(f"Generating method '{name}', with {num_return_values} return values.")
    if kwargs:
        logger.debug(f"Unused kwargs: {kwargs}")

    pad = "    "
    body_states, body_parameters, body_missing, body_values = (
        indent(section, pad)
        for section in (states, parameters, missing_variables, values)
    )
    entries = "".join(f"_values_{i}, " for i in range(num_return_values))
    body_return = indent(f"return numpy.array([{entries}])", pad)

    return dedent(
        f"""
@jax.jit
def {name}({args}):
    # Assign states
{body_states}
    # Assign parameters
{body_parameters}
{body_missing}
    # Assign expressions
{body_values}
{body_return}
""",
    )
Finally, let us print out the generated code:
# Display the complete source of the generated module.
print(code, end="\n")
import jax
import jax.numpy as numpy
jax.config.update("jax_enable_x64", True)
parameter = {"beta": 0, "rho": 1, "sigma": 2}


def parameter_index(name: str) -> int:
    """Return the index of the parameter with the given name
    Arguments
    ---------
    name : str
        The name of the parameter
    Returns
    -------
    int
        The index of the parameter
    Raises
    ------
    KeyError
        If the name is not a valid parameter
    """
    return parameter[name]
state = {"x": 0, "y": 1, "z": 2}


def state_index(name: str) -> int:
    """Return the index of the state with the given name
    Arguments
    ---------
    name : str
        The name of the state
    Returns
    -------
    int
        The index of the state
    Raises
    ------
    KeyError
        If the name is not a valid state
    """
    return state[name]
monitor = {"dx_dt": 0, "dy_dt": 1, "dz_dt": 2}


def monitor_index(name: str) -> int:
    """Return the index of the monitor with the given name
    Arguments
    ---------
    name : str
        The name of the monitor
    Returns
    -------
    int
        The index of the monitor
    Raises
    ------
    KeyError
        If the name is not a valid monitor
    """
    return monitor[name]
@jax.jit
def init_parameter_values(**values):
    """Initialize parameter values"""
    # beta=2.4, rho=21.0, sigma=12.0
    parameters = numpy.array([2.4, 21.0, 12.0], dtype=numpy.float64)
    for key, value in values.items():
        parameters = parameters.at[parameter_index(key)].set(value)
    return parameters
@jax.jit
def init_state_values(**values):
    """Initialize state values"""
    # x=1.0, y=2.0, z=3.05
    states = numpy.array([1.0, 2.0, 3.05], dtype=numpy.float64)
    for key, value in values.items():
        states = states.at[state_index(key)].set(value)
    return states
@jax.jit
def rhs(t, states, parameters):
    # Assign states
    x = states[0]
    y = states[1]
    z = states[2]
    # Assign parameters
    beta = parameters[0]
    rho = parameters[1]
    sigma = parameters[2]
    # Assign expressions
    dx_dt = sigma * (-x + y)
    _values_0 = dx_dt
    dy_dt = x * (rho - z) - y
    _values_1 = dy_dt
    dz_dt = -beta * z + x * y
    _values_2 = dz_dt
    return numpy.array(
        [
            _values_0,
            _values_1,
            _values_2,
        ]
    )
@jax.jit
def monitor_values(t, states, parameters):
    # Assign states
    x = states[0]
    y = states[1]
    z = states[2]
    # Assign parameters
    beta = parameters[0]
    rho = parameters[1]
    sigma = parameters[2]
    # Assign expressions
    dx_dt = sigma * (-x + y)
    _values_0 = dx_dt
    dy_dt = x * (rho - z) - y
    _values_1 = dy_dt
    dz_dt = -beta * z + x * y
    _values_2 = dz_dt
    return numpy.array(
        [
            _values_0,
            _values_1,
            _values_2,
        ]
    )
@jax.jit
def forward_explicit_euler(states, t, dt, parameters):
    # Assign states
    x = states[0]
    y = states[1]
    z = states[2]
    # Assign parameters
    beta = parameters[0]
    rho = parameters[1]
    sigma = parameters[2]
    # Assign expressions
    dx_dt = sigma * (-x + y)
    _values_0 = dt * dx_dt + x
    dy_dt = x * (rho - z) - y
    _values_1 = dt * dy_dt + y
    dz_dt = -beta * z + x * y
    _values_2 = dt * dz_dt + z
    return numpy.array(
        [
            _values_0,
            _values_1,
            _values_2,
        ]
    )