from __future__ import annotations
import abc
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple
import dolfinx
import numpy as np
import numpy.typing as npt
from .utils import local_project
# Small numerical tolerance. Not referenced anywhere in this section;
# presumably used elsewhere in the module or by importers (TODO confirm).
EPS = 1.0e-12
class ODEResults(NamedTuple):
    """Named pair of float64 arrays ``(y, t)``.

    By naming convention presumably the solution values and the matching
    time points (the type is not used in this section; TODO confirm).
    Fields are positional in the order ``y``, ``t``.
    """

    # First field: value array.
    y: npt.NDArray[np.float64]
    # Second field: time array.
    t: npt.NDArray[np.float64]
def solve(
    fun: Callable,
    t_bound: float,
    states: npt.NDArray,
    V: npt.NDArray,
    V_index: int,
    dt: float,
    parameters: npt.NDArray,
    t0: float = 0.0,
    extra: dict[str, float | npt.NDArray] | None = None,
) -> None:
    """Advance ``states`` with fixed steps of size ``dt`` and record one row per step.

    Starting at ``t0``, ``fun`` is called once per step for as long as
    ``t + dt < t_bound``. After each step, ``states[V_index, :]`` is copied
    into the next row of ``V``.

    ``fun`` is called as ``fun(states=..., t=..., parameters=..., dt=..., **extra)``
    and may follow either convention:

    * update ``states`` in place and return ``None``; or
    * return the new states, which are then written back into ``states``
      in place.

    Parameters
    ----------
    fun
        Step function (see above).
    t_bound
        A step is taken only while ``t + dt < t_bound``.
    states
        State array indexed as ``states[V_index, :]``; updated in place.
    V
        Output buffer; row ``i`` receives ``states[V_index, :]`` after step
        ``i``. It needs at least as many rows as steps taken, otherwise numpy
        raises ``IndexError``.
    V_index
        Row of ``states`` that is recorded into ``V``.
    dt
        Time step.
    parameters
        Passed unchanged to ``fun``.
    t0
        Start time.
    extra
        Additional keyword arguments forwarded to ``fun``.
    """
    if extra is None:
        extra = {}
    i = 0
    t = t0
    # NOTE: t is accumulated with ``t += dt``, so the exact number of steps is
    # subject to floating-point rounding near t_bound.
    while t + dt < t_bound:
        new_states = fun(states=states, t=t, parameters=parameters, dt=dt, **extra)
        # Step functions that return the new states (instead of mutating) must
        # have their result written back; in-place step functions return None.
        if new_states is not None:
            states[:] = new_states
        V[i, :] = states[V_index, :]
        i += 1
        t += dt
@dataclass
class ODESystemSolver:
    """Advance a block of ODE states with a user-supplied step function.

    ``states`` has one row per state variable and one column per point
    (see :attr:`num_states` / :attr:`num_points`). It is always updated in
    place, so any other object holding a reference to the same array sees
    the new values.

    ``fun`` is called as ``fun(states=..., t=..., parameters=..., dt=...)``
    and may either return the new states or update ``states`` in place and
    return ``None``.
    """

    fun: Callable
    states: npt.NDArray
    parameters: npt.NDArray

    @property
    def num_points(self) -> int:
        """Number of points (columns of ``states``)."""
        return self.states.shape[1]

    @property
    def num_states(self) -> int:
        """Number of state variables (rows of ``states``)."""
        return self.states.shape[0]

    def step(self, t0: float, dt: float) -> None:
        """Take one step of size ``dt`` starting at time ``t0``."""
        new_states = self.fun(states=self.states, t=t0, parameters=self.parameters, dt=dt)
        # An in-place step function returns None; writing that back would
        # overwrite the freshly updated states (NaN fill for float arrays).
        if new_states is not None:
            # Slice assignment keeps the array object shared with callers.
            self.states[:] = new_states
class BaseDolfinODESolver(abc.ABC):
    """Abstract interface for ODE solvers coupled to dolfinx functions.

    ``v_ode`` is the dolfinx function the ODE values are exchanged with
    (:meth:`to_dolfin` / :meth:`from_dolfin`); ``v_pde`` is its counterpart
    on the PDE side. :meth:`ode_to_pde` and :meth:`pde_to_ode` transfer
    between the two with ``local_project``.
    """

    v_ode: dolfinx.fem.Function
    v_pde: dolfinx.fem.Function
    _metadata: dict[str, Any] | None = None

    def _initialize_metadata(self):
        """Set ``_metadata`` from the element of ``v_ode``.

        For an element whose family is ``"Quadrature"`` this is
        ``{"quadrature_degree": <element degree>}``; for any other family it
        is ``None``. (Presumably consumed as form metadata elsewhere —
        TODO confirm.)
        """
        element = self.v_ode.ufl_element()
        is_quadrature = element.family_name == "Quadrature"
        self._metadata = {"quadrature_degree": element.degree()} if is_quadrature else None

    @abc.abstractmethod
    def to_dolfin(self) -> None:
        """Copy the solver's stored values into the dolfinx function(s)."""

    @abc.abstractmethod
    def from_dolfin(self) -> None:
        """Copy values from the dolfinx function(s) into the solver's storage."""

    def ode_to_pde(self) -> None:
        """Projects v_ode (DG0, quadrature space, ...) into v_pde (CG1)"""
        source, target = self.v_ode, self.v_pde
        local_project(source, target.function_space, target)

    def pde_to_ode(self) -> None:
        """Projects v_pde (CG1) into v_ode (DG0, quadrature space, ...)"""
        source, target = self.v_pde, self.v_ode
        local_project(source, target.function_space, target)

    @abc.abstractmethod
    def step(self, t0: float, dt: float) -> None:
        """Advance the ODE(s) by ``dt`` starting from time ``t0``."""

    @property
    @abc.abstractmethod
    def full_values(self) -> npt.NDArray:
        """All stored state values as a single array."""

    @abc.abstractmethod
    def assign_all_states(self, functions: list[dolfinx.fem.Function]) -> None:
        """Write every state into the matching function of ``functions``."""

    @abc.abstractmethod
    def states_to_dolfin(self, names: list[str] | None = None) -> list[dolfinx.fem.Function]:
        """Return one new dolfinx function per state, optionally named by ``names``."""
@dataclass
class DolfinODESolver(BaseDolfinODESolver):
    """Single ODE model evaluated at every entry of ``v_ode.x.array``.

    State values live in a numpy array of shape ``(num_states, num_points)``,
    where ``num_points`` is ``v_ode.x.array.size``. Row ``v_index`` is the one
    exchanged with ``v_ode`` by :meth:`to_dolfin` / :meth:`from_dolfin`.

    ``init_states`` is either the full ``(num_states, num_points)`` array
    (copied) or anything numpy can broadcast into every column, e.g. one
    value per state.
    """

    v_ode: dolfinx.fem.Function
    v_pde: dolfinx.fem.Function
    init_states: npt.NDArray
    parameters: npt.NDArray
    fun: Callable
    num_states: int
    v_index: int = 0

    def __post_init__(self):
        target_shape = self.shape
        if np.shape(self.init_states) != target_shape:
            # Broadcast the initial states into every column (point).
            storage = np.zeros(target_shape)
            storage.T[:] = self.init_states
        else:
            storage = np.copy(self.init_states)
        self._values = storage
        # The ODE solver receives the same array object (no copy), so each
        # step updates self._values in place.
        self._ode = ODESystemSolver(fun=self.fun, states=storage, parameters=self.parameters)
        self._initialize_metadata()

    def to_dolfin(self) -> None:
        """Copy state row ``v_index`` into ``v_ode``."""
        self.v_ode.x.array[:] = self._values[self.v_index]

    def from_dolfin(self) -> None:
        """Copy ``v_ode`` into state row ``v_index``."""
        self._values[self.v_index] = self.v_ode.x.array

    @property
    def values(self):
        """The state array (not a copy)."""
        return self._values

    @property
    def num_parameters(self) -> int:
        """Length of ``parameters``."""
        return len(self.parameters)

    @property
    def shape(self) -> tuple[int, int]:
        """``(num_states, num_points)``."""
        return self.num_states, self.num_points

    @property
    def num_points(self) -> int:
        """Number of entries in ``v_ode.x.array``."""
        return self.v_ode.x.array.size

    def step(self, t0: float, dt: float):
        """Advance all points by one step of size ``dt`` from ``t0``."""
        self._ode.step(t0=t0, dt=dt)

    @property
    def full_values(self):
        """Same array as :attr:`values` (there is only one ODE model)."""
        return self._values

    def assign_all_states(self, functions: list[dolfinx.fem.Function]) -> None:
        """Copy state row ``k`` into ``functions[k]`` for every state."""
        assert len(functions) == self._values.shape[0], "Number of functions must match number of states"
        for row_index, function in enumerate(functions):
            function.x.array[:] = self._values[row_index, :]

    def states_to_dolfin(self, names: list[str] | None = None) -> list[dolfinx.fem.Function]:
        """Create one function per state on ``v_ode``'s space and fill it.

        Functions are named ``names[k]``, or ``state_k`` when ``names`` is
        omitted; ``names`` must have one entry per state.
        """
        space = self.v_ode.function_space
        state_count = self._values.shape[0]
        if names is None:
            names = [f"state_{i}" for i in range(state_count)]
        else:
            assert len(names) == state_count, (
                "Number of names must match number of states, got "
                f"{len(names)} names, but number of states is {state_count}"
            )
        created = [dolfinx.fem.Function(space, name=name) for name in names]
        self.assign_all_states(created)
        return created
@dataclass
class DolfinMultiODESolver(BaseDolfinODESolver):
    """Applies a separate ODE model to each marker value.

    ``markers`` must have the same array size as ``v_ode`` (checked in
    ``__post_init__``). Every entry of ``markers.x.array`` equal to a key of
    ``init_states`` belongs to that marker. Each marker keeps its own numpy
    array of shape ``(num_states[marker], num_points(marker))`` and its own
    ``parameters``, ``fun`` and ``v_index`` entry. All dict-valued fields are
    looked up with the keys of ``init_states``.

    Entries of ``v_ode`` whose marker is not a key of ``init_states`` are left
    untouched by :meth:`to_dolfin`.
    """

    v_ode: dolfinx.fem.Function
    v_pde: dolfinx.fem.Function
    markers: dolfinx.fem.Function
    init_states: dict[int, npt.NDArray]
    parameters: dict[int, npt.NDArray]
    fun: dict[int, Callable]
    num_states: dict[int, int]
    v_index: dict[int, int]

    def __post_init__(self):
        if self.v_ode.x.array.size != self.markers.x.array.size:
            raise RuntimeError("Marker and voltage need to be in the same function space")
        self._marker_values = tuple(self.init_states.keys())
        self._num_points = {}
        self._odes = {}
        self._values = {}
        self._inds = {}
        self._initialize_full_values()
        marker_array = self.markers.x.array
        for marker in self._marker_values:
            # Boolean mask of the entries owned by this marker. It is cached and
            # reused by every transfer so the selection always matches the
            # size of self._values[marker].
            where = marker_array == marker
            self._num_points[marker] = int(np.count_nonzero(where))
            self._inds[marker] = where
            init = self.init_states[marker]
            if np.shape(init) == self.shape(marker):
                self._values[marker] = np.copy(init)
            else:
                # Broadcast e.g. one value per state into every point.
                self._values[marker] = np.zeros(self.shape(marker))
                self._values[marker].T[:] = init
            # The ODE solver shares this marker's array and updates it in place.
            self._odes[marker] = ODESystemSolver(
                fun=self.fun[marker],
                states=self._values[marker],
                parameters=self.parameters[marker],
            )
        self._initialize_metadata()

    def _initialize_full_values(self):
        """Allocate the combined state array if all markers have equal state counts.

        The combined array has shape ``(num_states, markers.x.array.size)``.
        With differing (or no) state counts it is not allocated and
        :attr:`full_values` raises instead.
        """
        state_counts = set(self.num_states.values())
        self._all_states_equal_size = len(state_counts) == 1
        if self._all_states_equal_size:
            (common_num_states,) = state_counts
            self._full_values = np.zeros((common_num_states, self.markers.x.array.size))

    def to_dolfin(self) -> None:
        """Assign values from numpy array to dolfinx function"""
        # Start from the current values so unmarked entries are preserved.
        arr = self.v_ode.x.array.copy()
        for marker in self._marker_values:
            arr[self._inds[marker]] = self._values[marker][self.v_index[marker], :]
        self.v_ode.x.array[:] = arr

    def from_dolfin(self) -> None:
        """Assign values from dolfinx function to numpy array"""
        arr = self.v_ode.x.array
        for marker in self._marker_values:
            self._values[marker][self.v_index[marker], :] = arr[self._inds[marker]]

    def values(self, marker: int) -> npt.NDArray:
        """State array of ``marker`` (not a copy)."""
        return self._values[marker]

    def num_parameters(self, marker: int) -> int:
        """Length of the parameter array of ``marker``."""
        return len(self.parameters[marker])

    def shape(self, marker: int) -> tuple[int, int]:
        """``(num_states[marker], num_points(marker))``."""
        return (self.num_states[marker], self._num_points[marker])

    def num_points(self, marker: int) -> int:
        """Number of entries whose marker equals ``marker``."""
        return self._num_points[marker]

    def step(self, t0: float, dt: float):
        """Advance every marker's ODE by one step of size ``dt`` from ``t0``."""
        for ode in self._odes.values():
            ode.step(t0=t0, dt=dt)

    def assign_all_states(self, functions: list[dolfinx.fem.Function]) -> None:
        """Write state ``k`` of every marker into ``functions[k]`` at that marker's entries."""
        # NOTE(review): the state count is taken from the first marker. If
        # markers have different state counts this either raises IndexError or
        # silently skips the extra states of other markers; consider checking
        # self._all_states_equal_size first.
        num_states = self._values[self._marker_values[0]].shape[0]
        assert len(functions) == num_states, "Number of functions must match number of states"
        for index, f in enumerate(functions):
            for marker in self._marker_values:
                f.x.array[self._inds[marker]] = self._values[marker][index, :]

    def states_to_dolfin(self, names: list[str] | None = None) -> list[dolfinx.fem.Function]:
        """Create one function per state on ``v_ode``'s space and fill it.

        Functions are named ``names[k]``, or ``state_k`` when ``names`` is
        omitted. The number of states is taken from the first marker.
        """
        V = self.v_ode.function_space
        num_states = self._values[self._marker_values[0]].shape[0]
        if names is not None:
            msg = (
                "Number of names must match number of states, got "
                f"{len(names)} names, but number of states is {num_states}"
            )
            assert len(names) == num_states, msg
        else:
            names = [f"state_{i}" for i in range(num_states)]
        functions = [dolfinx.fem.Function(V, name=name) for name in names]
        self.assign_all_states(functions)
        return functions

    @property
    def full_values(self):
        """All markers' states gathered into one ``(num_states, size)`` array.

        Only available when every marker has the same number of states.
        Entries not belonging to any marker stay zero. The same array object
        is returned (and refreshed) on every access.

        Raises
        ------
        RuntimeError
            If the markers have different numbers of states.
        """
        if not self._all_states_equal_size:
            msg = (
                "Cannot get full values since states are not of equal size. "
                f"Have {self.num_states=}, use .values(marker) instead"
            )
            raise RuntimeError(msg)
        for marker in self._marker_values:
            # Reuse the masks cached in __post_init__; they are guaranteed to
            # match the column count of self._values[marker].
            self._full_values[:, self._inds[marker]] = self._values[marker]
        return self._full_values