Source code for gotran.solver.scipysolver

import warnings
from copy import deepcopy
import numpy as np
try:
    import scipy.integrate as spi
    has_scipy = True
except:
    has_scipy = False

from gotran.common import warning
from .odesolver import Solver, ODESolverError

__all__ = ["ScipySolver", "has_scipy"]


class ScipySolver(Solver):
    """ODE solver backend built on :func:`scipy.integrate.odeint`.

    The solver-specific options are the keyword arguments accepted by
    ``odeint`` (see :meth:`list_solver_options`). Keyword arguments passed
    to the constructor that are not in that list are still forwarded to
    :class:`Solver`, but are not passed on to ``odeint``.
    """

    def __init__(self, ode, **options):
        """Create a scipy-backed solver for *ode*.

        Arguments
        ---------
        ode
            The ODE to solve; forwarded unchanged to :class:`Solver`.
        **options
            Solver options. All of them are forwarded to :class:`Solver`;
            those whose key appears in :meth:`list_solver_options` also
            override the corresponding ``odeint`` default.

        Raises
        ------
        ImportError
            If scipy could not be imported.
        """
        msg = "Chosen backend is scipy, but scipy is not installed"
        # Explicit check instead of ``assert`` so it is not stripped
        # when Python runs with ``-O``.
        if not has_scipy:
            raise ImportError(msg)

        Solver.__init__(self, ode, **options)

        # Start from the odeint defaults and keep only the user options
        # that odeint actually understands.
        solver_options = ScipySolver.list_solver_options()
        solver_options.update(
            {key: value for key, value in options.items() if key in solver_options}
        )
        self._options = solver_options

    def get_options(self):
        """Return the dict of ``odeint`` options used by this solver.

        The solver's own dict is returned (not a copy), so changes made to
        it are picked up by subsequent solves.
        """
        return self._options

    @staticmethod
    def list_solver_options():
        """Return a fresh dict of the supported ``odeint`` keyword options.

        The keys are keyword arguments of :func:`scipy.integrate.odeint`, and
        the values are the same defaults ``odeint`` itself uses.
        """
        return {
            "atol": None,
            "col_deriv": 0,
            "full_output": 0,
            "h0": 0.0,
            "hmax": 0.0,
            "hmin": 0.0,
            "ixpr": 0,
            "ml": None,
            "mu": None,
            "mxhnil": 0,
            "mxordn": 12,
            "mxords": 5,
            "mxstep": 0,
            "printmessg": 0,
            "rtol": None,
            "tcrit": None,
        }

    @staticmethod
    def _reduced_max_step(hmax, tsteps):
        """Return a smaller ``hmax`` to use when retrying a failed solve.

        ``odeint`` treats ``hmax == 0`` as "solver-determined" (no explicit
        limit), and halving 0 would leave it at 0. The first reduction
        therefore starts from half the widest interval between the
        requested output times. After that, the limit is simply halved.
        """
        if hmax is not None and hmax > 0:
            return hmax / 2.0
        intervals = np.abs(np.diff(np.asarray(tsteps, dtype=float)))
        if intervals.size == 0 or not intervals.max() > 0:
            # No usable interval to derive a limit from; leave unchanged.
            return hmax
        return float(intervals.max()) / 2.0

    def _solve(self, tsteps, attempts=3):
        """
        Solve ode using scipy.integrate.odeint

        Arguments
        ---------
        tsteps : array
            The time steps at which the solution is returned.
        attempts : int
            If integration fails, i.e. scipy emits a warning during the
            solve, the maximum step size (``hmax``) is reduced and the
            problem is solved again. This variable controls how many times
            we try to solve the problem. Values below 1 are treated as 1.
            Default: 3

        Returns
        -------
        t : array
            ``tsteps`` (unchanged).
        y : array
            The solution returned by ``odeint`` for each entry of
            ``tsteps``. If ``full_output`` is set, the info dict is
            discarded.

        Raises
        ------
        ODESolverError
            If the last attempt still produced an ``ODEintWarning``.
        """
        # Work on a private copy so step-size reductions made while
        # retrying never leak back into self._options.
        options = deepcopy(self._options)
        full_output = bool(options.get("full_output"))

        def rhs(y, t):
            # odeint calls func(y, t); the model rhs takes (t, y, params).
            return self._rhs(t, y, self._model_params)

        # Guarantee at least one attempt so results are always bound.
        attempts = max(1, int(attempts))

        for attempt in range(attempts):
            # scipy only *warns* when odeint fails, so record every warning
            # raised during the solve and treat any of them as a failure.
            with warnings.catch_warnings(record=True) as caught_warnings:
                # Always record warnings, not only the first occurrence.
                warnings.simplefilter("always")
                results = spi.odeint(
                    rhs, self._y0, tsteps, Dfun=self._jac, **options
                )

            # With full_output odeint returns (y, infodict).
            y = results[0] if full_output else results

            if not caught_warnings:
                break

            # Reduce the maximum step size before the next attempt.
            if attempt + 1 < attempts:
                options["hmax"] = ScipySolver._reduced_max_step(
                    options["hmax"], tsteps
                )

        # Warnings from the final attempt are reported; an ODEintWarning
        # means odeint itself did not succeed, so it is fatal.
        for caught in caught_warnings:
            msg = "Catched warning {}\n{}".format(caught.category, caught.message)
            warning(msg)
            if issubclass(caught.category, spi.ODEintWarning):
                raise ODESolverError(msg)

        return tsteps, y