from __future__ import annotations
import abc
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, NamedTuple
import dolfinx
import numpy as np
import numpy.typing as npt
from .telemetry import BaseMonitor, NullMonitor
from .utils import local_project
EPS = 1e-12
logger = logging.getLogger(__name__)
[docs]
class ODEResults(NamedTuple):
y: npt.NDArray[np.float64]
t: npt.NDArray[np.float64]
def solve(
fun: np.NDArray,
t_bound: float,
states: np.NDArray,
V: np.NDArray,
V_index: int,
dt: float,
parameters: np.NDArray,
t0: float = 0.0,
extra: dict[str, float | npt.NDArray] | None = None,
):
if extra is None:
extra = {}
i = 0
t = t0
while t + dt < t_bound:
fun(states=states, t=t, parameters=parameters, dt=dt, **extra)
V[i, :] = states[V_index, :]
i += 1
t += dt
[docs]
@dataclass
class ODESystemSolver:
fun: Callable
states: npt.NDArray
parameters: npt.NDArray
monitor: BaseMonitor = field(default_factory=NullMonitor)
@property
def num_points(self) -> int:
return self.states.shape[1]
@property
def num_states(self) -> int:
return self.states.shape[0]
def step(self, t0: float, dt: float) -> None:
with self.monitor.track_time("ode_total_step"):
with self.monitor.track_time("ode_function_call"):
updated_states = self.fun(
states=self.states,
t=t0,
parameters=self.parameters,
dt=dt,
)
with self.monitor.track_time("ode_state_update"):
self.states[:] = updated_states
[docs]
class BaseDolfinODESolver(abc.ABC):
v_ode: dolfinx.fem.Function
v_pde: dolfinx.fem.Function
_metadata: dict[str, Any] | None = None
def _initialize_metadata(self):
if self.v_ode.ufl_element().family_name == "Quadrature":
self._metadata = {"quadrature_degree": self.v_ode.ufl_element().degree()}
else:
self._metadata = None
@abc.abstractmethod
def to_dolfin(self) -> None:
pass
@abc.abstractmethod
def from_dolfin(self) -> None:
pass
[docs]
def ode_to_pde(self) -> None:
"""Projects v_ode (DG0, quadrature space, ...) into v_pde (CG1)"""
local_project(
self.v_ode,
self.v_pde.function_space,
self.v_pde,
)
[docs]
def pde_to_ode(self) -> None:
"""Projects v_pde (CG1) into v_ode (DG0, quadrature space, ...)"""
local_project(
self.v_pde,
self.v_ode.function_space,
self.v_ode,
)
@abc.abstractmethod
def step(self, t0: float, dt: float) -> None:
pass
@property
@abc.abstractmethod
def full_values(self) -> npt.NDArray:
pass
@abc.abstractmethod
def assign_all_states(self, functions: list[dolfinx.fem.Function]) -> None:
pass
@abc.abstractmethod
def states_to_dolfin(self, names: list[str] | None = None) -> list[dolfinx.fem.Function]:
pass
[docs]
@dataclass
class DolfinODESolver(BaseDolfinODESolver):
v_ode: dolfinx.fem.Function
v_pde: dolfinx.fem.Function
init_states: npt.NDArray
parameters: npt.NDArray
fun: Callable
num_states: int
v_index: int = 0
monitor: BaseMonitor = field(default_factory=NullMonitor)
def __post_init__(self):
if np.shape(self.init_states) == self.shape:
self._values = np.copy(self.init_states)
else:
self._values = np.zeros(self.shape)
self._values.T[:] = self.init_states
self._ode = ODESystemSolver(
fun=self.fun,
states=self._values,
parameters=self.parameters,
monitor=self.monitor,
)
self._initialize_metadata()
[docs]
def to_dolfin(self) -> None:
"""Assign values from numpy array to dolfin function"""
self.v_ode.x.array[:] = self._values[self.v_index, :]
[docs]
def from_dolfin(self) -> None:
"""Assign values from dolfin function to numpy array"""
self._values[self.v_index, :] = self.v_ode.x.array
@property
def values(self):
return self._values
@property
def num_parameters(self) -> int:
return len(self.parameters)
@property
def shape(self) -> tuple[int, int]:
return (self.num_states, self.num_points)
@property
def num_points(self) -> int:
return self.v_ode.x.array.size
def step(self, t0: float, dt: float):
self._ode.step(t0=t0, dt=dt)
@property
def full_values(self):
return self._values
def assign_all_states(self, functions: list[dolfinx.fem.Function]) -> None:
num_states = self._values.shape[0]
assert len(functions) == num_states, "Number of functions must match number of states"
for index, f in enumerate(functions):
f.x.array[:] = self._values[index, :]
def states_to_dolfin(self, names: list[str] | None = None) -> list[dolfinx.fem.Function]:
V = self.v_ode.function_space
functions = []
num_states = self._values.shape[0]
if names is not None:
msg = (
"Number of names must match number of states, got "
f"{len(names)} names, but number of states is {num_states}"
)
assert len(names) == num_states, msg
else:
names = [f"state_{i}" for i in range(num_states)]
for name in names:
f = dolfinx.fem.Function(V, name=name)
functions.append(f)
self.assign_all_states(functions)
return functions
[docs]
@dataclass
class DolfinMultiODESolver(BaseDolfinODESolver):
v_ode: dolfinx.fem.Function
v_pde: dolfinx.fem.Function
markers: dolfinx.fem.Function
init_states: dict[int, npt.NDArray]
parameters: dict[int, npt.NDArray]
fun: dict[int, Callable]
num_states: dict[int, int]
v_index: dict[int, int]
monitor: BaseMonitor = field(default_factory=NullMonitor)
def __post_init__(self):
if self.v_ode.x.array.size != self.markers.x.array.size:
raise RuntimeError("Marker and voltage need to be in the same function space")
self._marker_values = tuple(self.init_states.keys())
self._num_points = {}
self._odes = {}
self._values = {}
self._inds = {}
self._initialize_full_values()
for marker in self._marker_values:
where = self.markers.x.array == marker
self._num_points[marker] = where.sum()
self._inds[marker] = where
if np.shape(self.init_states[marker]) == self.shape(marker):
self._values[marker] = np.copy(self.init_states[marker])
else:
self._values[marker] = np.zeros(self.shape(marker))
self._values[marker].T[:] = self.init_states[marker]
self._odes[marker] = ODESystemSolver(
fun=self.fun[marker],
states=self._values[marker],
parameters=self.parameters[marker],
monitor=self.monitor,
)
self._initialize_metadata()
def _initialize_full_values(self):
self._all_states_equal_size = (
np.array(tuple(self.num_states.values())) == tuple(self.num_states.values())[0]
).all()
if self._all_states_equal_size:
self._full_values = np.zeros(
(next(iter(self.num_states.values())), self.markers.x.array.size),
)
[docs]
def to_dolfin(self) -> None:
"""Assign values from numpy array to dolfinx function"""
arr = self.v_ode.x.array.copy()
for marker in self._marker_values:
arr[self._inds[marker]] = self._values[marker][self.v_index[marker], :]
self.v_ode.x.array[:] = arr
[docs]
def from_dolfin(self) -> None:
"""Assign values from dolfinx function to numpy array"""
arr = self.v_ode.x.array
for marker in self._marker_values:
self._values[marker][self.v_index[marker], :] = arr[self._inds[marker]]
def values(self, marker: int) -> npt.NDArray:
return self._values[marker]
def num_parameters(self, marker: int) -> int:
return len(self.parameters[marker])
def shape(self, marker: int) -> tuple[int, int]:
return (self.num_states[marker], self._num_points[marker])
def num_points(self, marker: int) -> int:
return self._num_points[marker]
def step(self, t0: float, dt: float):
with self.monitor.track_time("total_ode_step"):
for marker, ode in self._odes.items():
with self.monitor.track_time(f"marker_{marker}_ode_step"):
ode.step(t0=t0, dt=dt)
def assign_all_states(self, functions: list[dolfinx.fem.Function]) -> None:
num_states = self._values[self._marker_values[0]].shape[0]
assert len(functions) == num_states, "Number of functions must match number of states"
for index, f in enumerate(functions):
for marker in self._marker_values:
f.x.array[self._inds[marker]] = self._values[marker][index, :]
def states_to_dolfin(self, names: list[str] | None = None) -> list[dolfinx.fem.Function]:
V = self.v_ode.function_space
functions = []
num_states = self._values[self._marker_values[0]].shape[0]
if names is not None:
msg = (
"Number of names must match number of states, got "
f"{len(names)} names, but number of states is {num_states}"
)
assert len(names) == num_states, msg
else:
names = [f"state_{i}" for i in range(num_states)]
for name in names:
f = dolfinx.fem.Function(V, name=name)
functions.append(f)
self.assign_all_states(functions)
return functions
@property
def full_values(self):
if not self._all_states_equal_size:
msg = (
"Cannot get full values size states are not of equal size. "
f"Have {self.num_states=}, use .values(marker) instead"
)
raise RuntimeError(msg)
for marker in self._marker_values:
where = self.markers.x.array == marker
self._full_values[:, where] = self._values[marker]
return self._full_values