Adding a new language

In this example we will go through the process needed to add a new language to gotranx. We will do this by adding support for JAX.

Note

If you end up adding support for your own language, we would be really glad if you could submit a PR so that we can include it as part of gotranx.

Adding some tests

The first thing you might want to do is to add a test that checks that your implementation works correctly. For example, you could load and solve the Lorenz attractor using the new JAX code generator. That code could look as follows:

from pathlib import Path
import time
import numpy as np
import matplotlib.pyplot as plt
import jax
import gotranx

from codegen import JaxCodeGenerator

# Define an ODE (the Lorenz system) and write it to a .ode file
ode_fname = Path("lorentz.ode")
ode_fname.write_text("""
parameters(sigma=12.0, rho=21.0, beta=2.4)
states(x=1.0, y=2.0,z=3.05)
dx_dt = sigma * (y - x)
dy_dt = x * (rho - z) - y  # m/s
dz_dt = x * y - beta * z
""")

# Load the ODE with gotranx
ode = gotranx.load_ode(ode_fname)

# Generate code with our own JAX code generator (defined in codegen.py)
codegen = JaxCodeGenerator(ode)

comp = [
    codegen.imports(),
    codegen.parameter_index(),
    codegen.state_index(),
    codegen.monitor_index(),
    codegen.missing_index(),
    codegen.initial_parameter_values(),
    codegen.initial_state_values(),
    codegen.rhs(),
    codegen.monitor_values(),
    codegen.scheme(gotranx.schemes.get_scheme("forward_explicit_euler"))
]

code = codegen._format("\n".join(comp))

# Execute the generated code; the functions it defines end up in `model`
model = {}
exec(code, model)

# Initial values
y = model["init_state_values"](x=1.0, y=2.0, z=3.05)
p = model["init_parameter_values"](sigma=12.0, rho=21.0, beta=2.4)
dt = 0.005
T = 200
t = np.arange(0, T, dt)


# We integrate the system using jax.lax.scan, which calls `step` once for
# every element of `t`. The carry holds (current time, current state). The
# scanned element itself is ignored because the time is tracked in the carry.
@jax.jit
def step(carry, _):
    t_i, y_i = carry
    y_next = model["forward_explicit_euler"](y_i, t_i, dt, p)
    return (t_i + dt, y_next), y_next


# Integrate the system and measure the elapsed time (this includes tracing
# and compiling `step`, since this is the first call)
t0 = time.perf_counter()
_, state = jax.lax.scan(step, (0.0, y), t)
# JAX dispatches work asynchronously; wait for the result before stopping
# the clock so that the timing covers the actual computation
state.block_until_ready()
t1 = time.perf_counter()
print(f"Elapsed time: {t1 - t0} s")

# Finally, plot the results
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.plot(state[:, 0], state[:, 1], state[:, 2], lw=0.5)
ax.set_xlabel("X Axis")
ax.set_ylabel("Y Axis")
ax.set_zlabel("Z Axis")
ax.set_title("Lorenz Attractor")
plt.show()
2024-11-05 20:05:34 [info     ] Load ode lorentz.ode          
2024-11-05 20:05:35 [info     ] Num states 3                  
2024-11-05 20:05:35 [info     ] Num parameters 3              
Elapsed time: 0.05233125300000552 s
../../_images/a4eb48f640a500a8f278ec2f46dbb05f7911e8d88acf59a233e4ff7de12e9908.png

Here we use gotranx to load the ODE, and then generate code for it with our own JAX code generator, implemented in the module called codegen.

Note

You could also fork the repository and add your code generator to the gotranx project. Have a look at the project now to see how the jax code generator has been added.

Adding the code generator

The file codegen.py contains the following lines

import sympy

from gotranx.codegen.python import PythonCodeGenerator, GotranPythonCodePrinter

import template

class JaxPrinter(GotranPythonCodePrinter):
    """SymPy code printer producing JAX-compatible Python.

    JAX arrays are immutable, so ``values[i] = expr`` cannot be emitted.
    Assignments into the indexed symbol ``values`` are printed as plain
    variable assignments ``_values_<i> = expr`` instead. Every other
    assignment is delegated to the parent Python printer.
    """

    def _print_Assignment(self, expr):
        target = expr.lhs
        is_values_entry = (
            isinstance(target, sympy.tensor.indexed.Indexed)
            and target.base.name == "values"
        )
        if not is_values_entry:
            return super()._print_Assignment(expr)

        idx = self._print(target.indices[0])
        rhs_code = self._print(expr.rhs)
        return f"_{target.base.name}_{idx} = {rhs_code}"


class JaxCodeGenerator(PythonCodeGenerator):
    """Code generator that emits JAX code.

    Reuses PythonCodeGenerator and swaps in three pieces: the JaxPrinter
    code printer, the import header, and the JAX template module.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Override the printer (presumably set up by the parent constructor)
        # so that assignments into `values` are printed as `_values_<i> = ...`
        self._printer = JaxPrinter()

    def imports(self) -> str:
        """Return the import lines placed at the top of the generated module."""
        header_lines = (
            "import jax",
            "import jax.numpy as numpy",
            # JAX uses 32-bit floats unless 64-bit mode is switched on
            'jax.config.update("jax_enable_x64", True)',
        )
        return "\n".join(header_lines)

    @property
    def template(self):
        """The template module (template.py) used to render generated functions."""
        return template

There are two classes in this module: the code printer JaxPrinter and the code generator JaxCodeGenerator.

What you will see is that the code we want to generate is very similar to the code that is already generated by the python code generator. Therefore we will subclass the existing classes from gotranx.

Code printers

The class JaxPrinter inherits from GotranPythonCodePrinter, which in turn inherits from sympy.printing.pycode.PythonCodePrinter. This is essentially where all the magic happens. Our ODE is just a collection of SymPy expressions, which we can print with the code printers provided by SymPy. Please check out the official documentation to see the existing code printers and how they can be adapted to your needs.

The only modification we make to the code printer is to change how values are assigned. In the original Python code we first allocate a NumPy array with a given shape and fill that array with values as we go through the function, e.g.

def rhs(t, states, parameters):
    """Sketch of the right-hand side generated by the plain Python backend."""
    # Allocate an output array shaped like `states` ...
    values = numpy.zeros_like(states)
    # ... and write each right-hand-side entry into it in place, by index
    values[0] = 42
    values[1] = 43
    ...  # one indexed assignment per entry
    return values

However, this is not possible in JAX, since JAX arrays are immutable. Instead (note that there are many ways to handle this), we assign each value to its own variable and build the array at the end:

import jax
import jax.numpy as numpy


@jax.jit
def rhs(t, states, parameters):
    """Sketch of the JAX-friendly right-hand side we want to generate.

    JAX arrays are immutable, so instead of allocating an output array and
    writing into it, each entry is stored in its own variable and the result
    array is built once, at the end.
    """
    _values_0 = 42
    _values_1 = 43
    # ... one _values_<i> assignment per entry
    return numpy.array([
        _values_0,
        _values_1,
        # ... one element per _values_<i>
    ])

For now, do not worry about the imports, the decorator or the return statement. The method _print_Assignment is called whenever SymPy prints an assignment:

class JaxPrinter(GotranPythonCodePrinter):
    """SymPy code printer producing JAX-compatible Python.

    JAX arrays are immutable, so ``values[i] = expr`` cannot be emitted.
    Assignments into the indexed symbol ``values`` are printed as plain
    variable assignments ``_values_<i> = expr`` instead. Every other
    assignment is delegated to the parent Python printer.
    """

    def _print_Assignment(self, expr):
        target = expr.lhs
        is_values_entry = (
            isinstance(target, sympy.tensor.indexed.Indexed)
            and target.base.name == "values"
        )
        if not is_values_entry:
            return super()._print_Assignment(expr)

        idx = self._print(target.indices[0])
        rhs_code = self._print(expr.rhs)
        return f"_{target.base.name}_{idx} = {rhs_code}"

Here we check whether the left-hand side of the assignment is an indexed symbol whose base is named values. If it is, we change the output from the default printing, which is

values[index] = rhs

to the following form, where index is replaced by the printed index (e.g. _values_0)

_values_index = rhs

Code generator

The class JaxCodeGenerator derives (via PythonCodeGenerator) from the abstract base class CodeGenerator, which implements the functionality to generate all the methods we need. For example, we need one method to generate code for the right-hand side (i.e. def rhs(t, states, parameters)), and another method to generate code for initializing the default parameter values. Both methods use the same printer to convert the SymPy expressions into code, but each needs its own logic to decide what goes into the generated function.

We want the class JaxCodeGenerator to change three things compared to PythonCodeGenerator. First, we set the code printer to the JaxPrinter that we just discussed. Next, we add the imports that the generated code needs at the top (we also enable float64, since JAX arrays are float32 by default).

Finally, we set the template to our custom template, which we implement in a second module called template.py (see the next section). This module implements the templates for the different methods that we want to generate.

Adding a new template

Now we will implement the template in template.py, which needs to follow the Template protocol. In other words, we need to implement all the methods that are part of the Template class. Our template will be pretty close to the Python template.

For state_index, parameter_index, monitor_index and missing_index we reuse the implementations from the Python template, so we simply import them directly. We also import the acc helper, which is used in init_state_values and init_parameter_values.

from __future__ import annotations
from textwrap import dedent, indent
import functools
from structlog import get_logger

from gotranx.templates.python import acc, state_index, parameter_index, monitor_index, missing_index


# Public names of this template module: the three templates defined in this
# file plus the index helpers re-exported from gotranx.templates.python.
__all__ = [
    "init_state_values", "init_parameter_values", "method",
    "parameter_index", "state_index", "monitor_index", "missing_index",
]

# structlog logger shared by the template functions in this module.
logger = get_logger()

Next we implement a new version of init_state_values, which is only a slight modification of the function from the Python template.

def init_state_values(name, state_names, state_values, code):
    """Render a jitted ``init_state_values(**values)`` function.

    The generated function builds a float64 array holding the default state
    values, then overrides entries given as keyword arguments, locating each
    one with ``state_index``. Because JAX arrays are immutable, overrides use
    ``.at[...].set(...)``, which returns an updated copy.

    Arguments
    ---------
    name : str
        Name of the array variable inside the generated function.
    state_names : sequence of str
        State names, in the same order as ``state_values``.
    state_values : sequence
        Default state values, in array order.
    code : object
        Not used by this template.

    Returns
    -------
    str
        Source code of the generated function.
    """
    logger.debug(f"Generating init_state_values with {len(state_values)} values")
    pairs = [f"{n}={v}" for n, v in zip(state_names, state_values)]
    # functools.reduce raises TypeError on an empty sequence, so only emit
    # the "# name=value" comment when there is at least one entry.
    values_comment = indent("#" + functools.reduce(acc, pairs), "    ") if pairs else ""

    values = ", ".join(map(str, state_values))
    return dedent(
        f'''
@jax.jit
def init_state_values(**values):
    """Initialize state values
    """
{values_comment}

    {name} = numpy.array([{values}], dtype=numpy.float64)

    for key, value in values.items():
        {name} = {name}.at[state_index(key)].set(value)

    return {name}
''',
    )

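Here we have added the @jax.jit decorator to the generated function and used the JAX notation array.at[index].set(value) for assigning a value at a specific index. Since JAX arrays are immutable, this returns an updated copy instead of modifying the array in place.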

For init_parameter_values we do the same:

def init_parameter_values(name, parameter_names, parameter_values, code):
    """Render a jitted ``init_parameter_values(**values)`` function.

    The generated function builds a float64 array holding the default
    parameter values, then overrides entries given as keyword arguments,
    locating each one with ``parameter_index``. Because JAX arrays are
    immutable, overrides use ``.at[...].set(...)``, which returns an
    updated copy.

    Arguments
    ---------
    name : str
        Name of the array variable inside the generated function.
    parameter_names : sequence of str
        Parameter names, in the same order as ``parameter_values``.
    parameter_values : sequence
        Default parameter values, in array order.
    code : object
        Not used by this template.

    Returns
    -------
    str
        Source code of the generated function.
    """
    logger.debug(f"Generating init_parameter_values with {len(parameter_values)} values")
    pairs = [f"{n}={v}" for n, v in zip(parameter_names, parameter_values)]
    # functools.reduce raises TypeError on an empty sequence (e.g. an ODE
    # without parameters), so only emit the "# name=value" comment when
    # there is at least one entry.
    values_comment = indent("#" + functools.reduce(acc, pairs), "    ") if pairs else ""

    values = ", ".join(map(str, parameter_values))
    return dedent(
        f'''
@jax.jit
def init_parameter_values(**values):
    """Initialize parameter values
    """
{values_comment}

    {name} = numpy.array([{values}], dtype=numpy.float64)

    for key, value in values.items():
        {name} = {name}.at[parameter_index(key)].set(value)

    return {name}
''',
    )

Finally, for the method function we just need to make sure to create the correct return array, which should now be an array containing the values _values_0, _values_1, and so on.

def method(
    name,
    args,
    states,
    parameters,
    values,
    num_return_values: int,
    missing_variables: str = "",
    **kwargs,
):
    """Render a jitted function that returns its results as one JAX array.

    The generated function unpacks states and parameters, assigns any missing
    variables, evaluates the expressions (``values``) and finally returns
    ``numpy.array([_values_0, _values_1, ...])``. The ``_values_<i>`` names
    match what JaxPrinter emits for assignments into ``values``.

    Arguments
    ---------
    name : str
        Name of the generated function.
    args : str
        Argument list inserted verbatim into the generated signature.
    states, parameters, values : str
        Code snippets placed (indented) in the generated body.
    num_return_values : int
        Number of ``_values_<i>`` entries in the returned array.
    missing_variables : str
        Optional code snippet placed after the parameter assignments.
    **kwargs
        Ignored; logged at debug level when present.

    Returns
    -------
    str
        Source code of the generated function.
    """
    logger.debug(f"Generating method '{name}', with {num_return_values} return values.")
    if kwargs:
        logger.debug(f"Unused kwargs: {kwargs}")

    pad = "    "
    # Every entry, including the last, is followed by ", ". The formatted
    # output lists one entry per line, presumably because the formatter
    # honours this trailing comma (TODO confirm).
    entries = "".join(f"_values_{i}, " for i in range(num_return_values))
    body_return = indent("return numpy.array([" + entries + "])", pad)
    body_states, body_parameters, body_values, body_missing = (
        indent(snippet, pad) for snippet in (states, parameters, values, missing_variables)
    )

    return dedent(
        f"""
@jax.jit
def {name}({args}):

    # Assign states
{body_states}

    # Assign parameters
{body_parameters}
{body_missing}
    # Assign expressions
{body_values}

{body_return}
""",
    )

Summary

# codegen.py
import sympy

from gotranx.codegen.python import PythonCodeGenerator, GotranPythonCodePrinter

import template

class JaxPrinter(GotranPythonCodePrinter):
    """SymPy code printer producing JAX-compatible Python.

    JAX arrays are immutable, so ``values[i] = expr`` cannot be emitted.
    Assignments into the indexed symbol ``values`` are printed as plain
    variable assignments ``_values_<i> = expr`` instead. Every other
    assignment is delegated to the parent Python printer.
    """

    def _print_Assignment(self, expr):
        target = expr.lhs
        is_values_entry = (
            isinstance(target, sympy.tensor.indexed.Indexed)
            and target.base.name == "values"
        )
        if not is_values_entry:
            return super()._print_Assignment(expr)

        idx = self._print(target.indices[0])
        rhs_code = self._print(expr.rhs)
        return f"_{target.base.name}_{idx} = {rhs_code}"


class JaxCodeGenerator(PythonCodeGenerator):
    """Code generator that emits JAX code.

    Reuses PythonCodeGenerator and swaps in three pieces: the JaxPrinter
    code printer, the import header, and the JAX template module.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Override the printer (presumably set up by the parent constructor)
        # so that assignments into `values` are printed as `_values_<i> = ...`
        self._printer = JaxPrinter()

    def imports(self) -> str:
        """Return the import lines placed at the top of the generated module."""
        header_lines = (
            "import jax",
            "import jax.numpy as numpy",
            # JAX uses 32-bit floats unless 64-bit mode is switched on
            'jax.config.update("jax_enable_x64", True)',
        )
        return "\n".join(header_lines)

    @property
    def template(self):
        """The template module (template.py) used to render generated functions."""
        return template

# template.py


from __future__ import annotations
from textwrap import dedent, indent
import functools
from structlog import get_logger

from gotranx.templates.python import acc, state_index, parameter_index, monitor_index, missing_index


# Public names of this template module: the three templates defined in this
# file plus the index helpers re-exported from gotranx.templates.python.
__all__ = [
    "init_state_values", "init_parameter_values", "method",
    "parameter_index", "state_index", "monitor_index", "missing_index",
]

# structlog logger shared by the template functions in this module.
logger = get_logger()


def init_state_values(name, state_names, state_values, code):
    """Render a jitted ``init_state_values(**values)`` function.

    The generated function builds a float64 array holding the default state
    values, then overrides entries given as keyword arguments, locating each
    one with ``state_index``. Because JAX arrays are immutable, overrides use
    ``.at[...].set(...)``, which returns an updated copy.

    Arguments
    ---------
    name : str
        Name of the array variable inside the generated function.
    state_names : sequence of str
        State names, in the same order as ``state_values``.
    state_values : sequence
        Default state values, in array order.
    code : object
        Not used by this template.

    Returns
    -------
    str
        Source code of the generated function.
    """
    logger.debug(f"Generating init_state_values with {len(state_values)} values")
    pairs = [f"{n}={v}" for n, v in zip(state_names, state_values)]
    # functools.reduce raises TypeError on an empty sequence, so only emit
    # the "# name=value" comment when there is at least one entry.
    values_comment = indent("#" + functools.reduce(acc, pairs), "    ") if pairs else ""

    values = ", ".join(map(str, state_values))
    return dedent(
        f'''
@jax.jit
def init_state_values(**values):
    """Initialize state values
    """
{values_comment}

    {name} = numpy.array([{values}], dtype=numpy.float64)

    for key, value in values.items():
        {name} = {name}.at[state_index(key)].set(value)

    return {name}
''',
    )


def init_parameter_values(name, parameter_names, parameter_values, code):
    """Render a jitted ``init_parameter_values(**values)`` function.

    The generated function builds a float64 array holding the default
    parameter values, then overrides entries given as keyword arguments,
    locating each one with ``parameter_index``. Because JAX arrays are
    immutable, overrides use ``.at[...].set(...)``, which returns an
    updated copy.

    Arguments
    ---------
    name : str
        Name of the array variable inside the generated function.
    parameter_names : sequence of str
        Parameter names, in the same order as ``parameter_values``.
    parameter_values : sequence
        Default parameter values, in array order.
    code : object
        Not used by this template.

    Returns
    -------
    str
        Source code of the generated function.
    """
    logger.debug(f"Generating init_parameter_values with {len(parameter_values)} values")
    pairs = [f"{n}={v}" for n, v in zip(parameter_names, parameter_values)]
    # functools.reduce raises TypeError on an empty sequence (e.g. an ODE
    # without parameters), so only emit the "# name=value" comment when
    # there is at least one entry.
    values_comment = indent("#" + functools.reduce(acc, pairs), "    ") if pairs else ""

    values = ", ".join(map(str, parameter_values))
    return dedent(
        f'''
@jax.jit
def init_parameter_values(**values):
    """Initialize parameter values
    """
{values_comment}

    {name} = numpy.array([{values}], dtype=numpy.float64)

    for key, value in values.items():
        {name} = {name}.at[parameter_index(key)].set(value)

    return {name}
''',
    )


def method(
    name,
    args,
    states,
    parameters,
    values,
    num_return_values: int,
    missing_variables: str = "",
    **kwargs,
):
    """Render a jitted function that returns its results as one JAX array.

    The generated function unpacks states and parameters, assigns any missing
    variables, evaluates the expressions (``values``) and finally returns
    ``numpy.array([_values_0, _values_1, ...])``. The ``_values_<i>`` names
    match what JaxPrinter emits for assignments into ``values``.

    Arguments
    ---------
    name : str
        Name of the generated function.
    args : str
        Argument list inserted verbatim into the generated signature.
    states, parameters, values : str
        Code snippets placed (indented) in the generated body.
    num_return_values : int
        Number of ``_values_<i>`` entries in the returned array.
    missing_variables : str
        Optional code snippet placed after the parameter assignments.
    **kwargs
        Ignored; logged at debug level when present.

    Returns
    -------
    str
        Source code of the generated function.
    """
    logger.debug(f"Generating method '{name}', with {num_return_values} return values.")
    if kwargs:
        logger.debug(f"Unused kwargs: {kwargs}")

    pad = "    "
    # Every entry, including the last, is followed by ", ". The formatted
    # output lists one entry per line, presumably because the formatter
    # honours this trailing comma (TODO confirm).
    entries = "".join(f"_values_{i}, " for i in range(num_return_values))
    body_return = indent("return numpy.array([" + entries + "])", pad)
    body_states, body_parameters, body_values, body_missing = (
        indent(snippet, pad) for snippet in (states, parameters, values, missing_variables)
    )

    return dedent(
        f"""
@jax.jit
def {name}({args}):

    # Assign states
{body_states}

    # Assign parameters
{body_parameters}
{body_missing}
    # Assign expressions
{body_values}

{body_return}
""",
    )

Finally, let us print out the generated code:

# Show the complete module generated for the Lorenz ODE
print(code)
import jax
import jax.numpy as numpy

jax.config.update("jax_enable_x64", True)
parameter = {"beta": 0, "rho": 1, "sigma": 2}


def parameter_index(name: str) -> int:
    """Return the index of the parameter with the given name

    Arguments
    ---------
    name : str
        The name of the parameter

    Returns
    -------
    int
        The index of the parameter

    Raises
    ------
    KeyError
        If the name is not a valid parameter
    """

    return parameter[name]


state = {"x": 0, "y": 1, "z": 2}


def state_index(name: str) -> int:
    """Return the index of the state with the given name

    Arguments
    ---------
    name : str
        The name of the state

    Returns
    -------
    int
        The index of the state

    Raises
    ------
    KeyError
        If the name is not a valid state
    """

    return state[name]


monitor = {"dx_dt": 0, "dy_dt": 1, "dz_dt": 2}


def monitor_index(name: str) -> int:
    """Return the index of the monitor with the given name

    Arguments
    ---------
    name : str
        The name of the monitor

    Returns
    -------
    int
        The index of the monitor

    Raises
    ------
    KeyError
        If the name is not a valid monitor
    """

    return monitor[name]


@jax.jit
def init_parameter_values(**values):
    """Initialize parameter values"""
    # beta=2.4, rho=21.0, sigma=12.0

    parameters = numpy.array([2.4, 21.0, 12.0], dtype=numpy.float64)

    for key, value in values.items():
        parameters = parameters.at[parameter_index(key)].set(value)

    return parameters


@jax.jit
def init_state_values(**values):
    """Initialize state values"""
    # x=1.0, y=2.0, z=3.05

    states = numpy.array([1.0, 2.0, 3.05], dtype=numpy.float64)

    for key, value in values.items():
        states = states.at[state_index(key)].set(value)

    return states


@jax.jit
def rhs(t, states, parameters):

    # Assign states
    x = states[0]
    y = states[1]
    z = states[2]

    # Assign parameters
    beta = parameters[0]
    rho = parameters[1]
    sigma = parameters[2]

    # Assign expressions
    dx_dt = sigma * (-x + y)
    _values_0 = dx_dt
    dy_dt = x * (rho - z) - y
    _values_1 = dy_dt
    dz_dt = -beta * z + x * y
    _values_2 = dz_dt

    return numpy.array(
        [
            _values_0,
            _values_1,
            _values_2,
        ]
    )


@jax.jit
def monitor_values(t, states, parameters):

    # Assign states
    x = states[0]
    y = states[1]
    z = states[2]

    # Assign parameters
    beta = parameters[0]
    rho = parameters[1]
    sigma = parameters[2]

    # Assign expressions
    dx_dt = sigma * (-x + y)
    _values_0 = dx_dt
    dy_dt = x * (rho - z) - y
    _values_1 = dy_dt
    dz_dt = -beta * z + x * y
    _values_2 = dz_dt

    return numpy.array(
        [
            _values_0,
            _values_1,
            _values_2,
        ]
    )


@jax.jit
def forward_explicit_euler(states, t, dt, parameters):

    # Assign states
    x = states[0]
    y = states[1]
    z = states[2]

    # Assign parameters
    beta = parameters[0]
    rho = parameters[1]
    sigma = parameters[2]

    # Assign expressions
    dx_dt = sigma * (-x + y)
    _values_0 = dt * dx_dt + x
    dy_dt = x * (rho - z) - y
    _values_1 = dt * dy_dt + y
    dz_dt = -beta * z + x * y
    _values_2 = dt * dz_dt + z

    return numpy.array(
        [
            _values_0,
            _values_1,
            _values_2,
        ]
    )