# Monodomain convergence test
In this example, we demonstrate how to perform a convergence test for the monodomain model, using the forward Euler method for the ODE solver. We use the same test case as in the `tests/test_monodomain.py` file and compare the error in the solution for different spatial and temporal resolutions, using the L2 norm of the error as the error measure. The computed errors are cached in `convergence_<odespace>.json` files, and the results are plotted to `convergence.png`.
def simple_ode_forward_euler(states, t, dt, parameters):
    """Advance the toy ODE system v' = -s, s' = v by one forward Euler step.

    Parameters
    ----------
    states:
        Array whose first axis has length 2: ``states[0]`` holds v and
        ``states[1]`` holds s (each may be a scalar or a per-dof array).
    t:
        Current time. Unused, but part of the callback signature expected
        by ``beat.odesolver.DolfinODESolver`` (see ``main``).
    dt:
        Time step size.
    parameters:
        Unused; accepted for the same callback-signature reason as ``t``.

    Returns
    -------
    A new array with the same shape and dtype as ``states`` containing the
    updated (v, s). The input array is not modified.
    """
    v, s = states
    # Every entry is overwritten below, so no zero-initialisation is needed.
    stepped = np.empty_like(states)
    stepped[0] = v - dt * s
    stepped[1] = s + dt * v
    return stepped
def _compute_errors(comm, odespace, Ns, dts, M, t0, T):
    """Solve the monodomain test problem for every (dt, N) pair and return L2 errors.

    For each time step ``dt`` and resolution ``N`` this builds an ``N x N``
    unit-square triangle mesh and couples ``beat.MonodomainModel`` to
    ``simple_ode_forward_euler`` via ``beat.MonodomainSplittingSolver``. It then
    advances from ``t0`` to ``T`` and measures
    ``sqrt(sum over ranks of integral (pde.state - v_exact_func(x, T))**2 dx)``.

    Parameters
    ----------
    comm:
        MPI communicator used for the mesh and for reducing the error.
    odespace:
        ODE function-space string (e.g. ``"P_1"``), passed to
        ``beat.utils.space_from_string``.
    Ns, dts:
        Mesh resolutions and time step sizes to sweep over.
    M:
        Value forwarded as ``M`` to ``beat.MonodomainModel``.
    t0, T:
        Start and end time of each simulation.

    Returns
    -------
    dict mapping ``str(dt)`` to a list of errors ordered like ``Ns``. The
    values are identical on every rank because the error is all-reduced.
    """
    errors = {}
    for dt in dts:
        if comm.rank == 0:
            print(f"Running for dt={dt}")
        errors_dt = []
        for N in Ns:
            if comm.rank == 0:
                print(f"Running for N={N}")
            mesh = dolfinx.mesh.create_unit_square(
                comm,
                N,
                N,
                dolfinx.mesh.CellType.triangle,
            )
            time = dolfinx.fem.Constant(mesh, dolfinx.default_scalar_type(t0))
            x = ufl.SpatialCoordinate(mesh)
            I_s = ac_func(x, time)
            pde = beat.MonodomainModel(time=time, mesh=mesh, M=M, I_s=I_s)

            V_ode = beat.utils.space_from_string(odespace, mesh, dim=1)
            v_ode = dolfinx.fem.Function(V_ode)
            s = dolfinx.fem.Function(V_ode)
            s.interpolate(
                dolfinx.fem.Expression(
                    s_exact_func(x, time),
                    V_ode.element.interpolation_points(),
                ),
            )
            # Row 0 (v) starts at zero; row 1 holds s interpolated at t0.
            s_arr = s.x.array
            init_states = np.zeros((2, s_arr.size))
            init_states[1, :] = s_arr

            ode = beat.odesolver.DolfinODESolver(
                v_ode=v_ode,
                v_pde=pde.state,
                fun=simple_ode_forward_euler,
                init_states=init_states,
                parameters=None,
                num_states=2,
                v_index=0,
            )
            solver = beat.MonodomainSplittingSolver(pde=pde, ode=ode, theta=1.0)
            solver.solve((t0, T), dt=dt)

            v_exact = v_exact_func(x, T)
            error_form = dolfinx.fem.form((pde.state - v_exact) ** 2 * ufl.dx)
            # assemble_scalar is rank-local; sum the contributions before the sqrt.
            E = np.sqrt(
                comm.allreduce(dolfinx.fem.assemble_scalar(error_form), MPI.SUM),
            )
            errors_dt.append(float(E))
        errors[str(dt)] = errors_dt
    return errors


def _plot_errors(ax_h, ax_dt, errors, Ns, dts):
    """Draw log-log convergence curves for one ODE space.

    ``ax_h`` gets one curve per time step (error vs mesh size h = 1/N) plus
    O(h) and O(h^2) reference lines. ``ax_dt`` gets one curve per mesh
    resolution (error vs dt) plus O(dt) and O(dt^2) reference lines.

    ``errors`` maps ``str(dt)`` to a list of errors ordered like ``Ns``.

    Returns ``(lines_dt, labels_dt, lines_N, labels_N)`` legend handles/labels.
    """
    hs = [1 / N for N in Ns]

    lines_dt = []
    labels_dt = []
    for dt, errs in errors.items():
        (line,) = ax_h.loglog(hs, errs, "-o")
        lines_dt.append(line)
        labels_dt.append(f"dt={dt}")
    (line,) = ax_h.loglog(hs, [5 / N**2 for N in Ns], "--", color="gray")
    lines_dt.append(line)
    labels_dt.append("$O(h^2)$")
    (line,) = ax_h.loglog(hs, [0.5 / N for N in Ns], ":", color="gray")
    lines_dt.append(line)
    labels_dt.append("$O(h)$")
    # The x data is the mesh size 1/N, not N itself.
    ax_h.set_xlabel("h = 1/N")

    lines_N = []
    labels_N = []
    for i, N in enumerate(Ns):
        errs = [errors[str(dt)][i] for dt in dts]
        (line,) = ax_dt.loglog(dts, errs, "-o")
        lines_N.append(line)
        labels_N.append(f"N={N}")
    (line,) = ax_dt.loglog(dts, [dt**2 for dt in dts], "--", color="gray")
    lines_N.append(line)
    labels_N.append(r"$O(\Delta t^2)$")
    (line,) = ax_dt.loglog(dts, [0.08 * dt for dt in dts], ":", color="gray")
    lines_N.append(line)
    labels_N.append(r"$O(\Delta t)$")
    ax_dt.set_xlabel("dt")

    return lines_dt, labels_dt, lines_N, labels_N


def main():
    """Run the monodomain convergence study and save ``convergence.png``.

    For each ODE space in ``odespaces`` the errors are loaded from
    ``convergence_<odespace>.json`` if that file exists. Otherwise they are
    computed with ``_compute_errors`` and cached there. The figure has one
    column per ODE space: the top row shows error vs mesh size, the bottom
    row error vs time step.

    Safe to run under MPI: the cache decision is made on rank 0 and broadcast,
    so all ranks enter the collective solves together. Files and progress
    output are written by rank 0 only.
    """
    M = 1.0
    T = 1.0
    t0 = 0.0
    comm = MPI.COMM_WORLD
    odespaces = ["P_1", "P_2", "DG_1"]
    Ns = [2**level for level in range(3, 8)]
    dts = [2 ** (-i) for i in range(3, 9)]

    fig, ax = plt.subplots(
        2,
        len(odespaces),
        figsize=(10, 8),
        sharey="row",
        sharex="row",
    )
    lines_dt, labels_dt, lines_N, labels_N = [], [], [], []
    for k, odespace in enumerate(odespaces):
        error_fname = Path(f"convergence_{odespace}.json")
        # If ranks disagreed here, some would skip the collective solve and hang.
        if comm.bcast(error_fname.is_file(), root=0):
            errors = json.loads(error_fname.read_text())
        else:
            errors = _compute_errors(comm, odespace, Ns, dts, M, t0, T)
            if comm.rank == 0:
                error_fname.write_text(json.dumps(errors))

        lines_dt, labels_dt, lines_N, labels_N = _plot_errors(
            ax[0, k], ax[1, k], errors, Ns, dts,
        )
        ax[0, k].set_title(" ".join(odespace.split("_")))
        if k == 0:
            ax[0, k].set_ylabel("Error vs h")
            ax[1, k].set_ylabel("Error vs dt")

    for axi in ax.flatten():
        axi.grid()
        axi.set_ylim([1e-3, 0.1])
    fig.subplots_adjust(right=0.8)
    # Labels are identical across columns, so the last column's handles suffice.
    fig.legend(
        lines_dt,
        labels_dt,
        loc="upper center",
        bbox_to_anchor=(0.9, 0.85),
    )
    fig.legend(
        lines_N,
        labels_N,
        loc="upper center",
        bbox_to_anchor=(0.87, 0.4),
    )
    if comm.rank == 0:
        fig.savefig("convergence.png")
if __name__ == "__main__":
    # Run the convergence study only when executed as a script, not on import.
    main()
Running for dt=0.125
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.0625
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.03125
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.015625
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.0078125
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.00390625
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.125
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.0625
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.03125
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.015625
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.0078125
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.00390625
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.125
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.0625
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.03125
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.015625
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.0078125
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
Running for dt=0.00390625
Running for N=8
Running for N=16
Running for N=32
Running for N=64
Running for N=128
