import logging
from dataclasses import dataclass, field
from typing import Any, NamedTuple
from petsc4py import PETSc
import dolfinx
import numpy as np
import ufl
logger = logging.getLogger(__name__)
from scipy.signal import find_peaks
def detect_r_peaks(ecg_signal: np.ndarray, min_distance: float = 20) -> np.ndarray:
    """
    Detects R-peaks in the ECG signal.

    Only local maxima that reach at least half of the signal's global maximum
    are kept, which filters out small noisy peaks. When the global maximum is
    not positive no height threshold is applied.

    Parameters
    ----------
    ecg_signal : numpy.ndarray
        The ECG signal to be processed. Should be filtered.
    min_distance : float, optional
        Minimum distance between neighbouring R-peaks, expressed in samples
        (it is forwarded unchanged as ``distance`` to
        ``scipy.signal.find_peaks``). It corresponds to milliseconds only when
        the signal is sampled at 1 kHz. Defaults to 20.

    Returns
    -------
    numpy.ndarray
        Indices of the detected R-peaks in the ECG signal. Empty when the
        signal is empty or contains no qualifying peak.
    """
    signal = np.asarray(ecg_signal)
    # np.max raises on zero-size input; an empty signal simply has no peaks.
    if signal.size == 0:
        return np.array([], dtype=np.intp)

    # Height threshold relative to the signal's max value to avoid noisy peaks
    peak_value = np.max(signal)
    height_threshold = 0.5 * peak_value if peak_value > 0 else None

    peaks, _ = find_peaks(signal, distance=min_distance, height=height_threshold)
    return peaks
def detect_t_end(
    averaged_rr: np.ndarray,
    r_peak_index: int,
    window_start_offset: int = 50,
    window_end_offset: int = 400,
) -> int:
    """
    Detects the end of the T-wave in the averaged RR interval using a derivative-based method.

    The T-peak is taken as the sample with the largest absolute value inside
    the search window; the T-end is the sample at or after that peak where the
    first difference of the signal is most negative. This is a simplified
    approach; tangent methods are more common in the literature.

    Parameters
    ----------
    averaged_rr : numpy.ndarray
        The averaged RR interval. Should not be None.
    r_peak_index : int
        The index of the R-peak in the averaged RR interval.
    window_start_offset : int, optional
        Start of the search window relative to the R-peak, in samples (the
        value is added directly to ``r_peak_index``; it equals milliseconds
        only at a 1 kHz sampling rate). Defaults to 50.
    window_end_offset : int, optional
        End (exclusive) of the search window relative to the R-peak, in
        samples. Defaults to 400.

    Returns
    -------
    int
        Index of the T-wave end relative to the start of averaged_rr.

    Raises
    ------
    RuntimeError
        If ``averaged_rr`` is None or empty.
    ValueError
        If the clipped search window holds fewer than two samples, or if the
        T-peak is the last sample of the window so that nothing follows it.
    """
    if averaged_rr is None or len(averaged_rr) == 0:
        raise RuntimeError("Error: Cannot detect T-end on empty or None averaged RR interval.")

    # Search window relative to the R-peak, clipped to the bounds of the array
    search_start = max(0, r_peak_index + window_start_offset)
    search_end = min(len(averaged_rr), r_peak_index + window_end_offset)

    # At least two samples are needed to form a first difference. Failing
    # here gives a clear message instead of an opaque numpy error further down.
    if search_end - search_start < 2:
        raise ValueError("Invalid or too short search window for T-end detection.")

    signal_segment = averaged_rr[search_start:search_end]
    # First derivative (velocity) of the segment; one element shorter than it
    derivative = np.diff(signal_segment)

    # T-peak relative to the segment start (the T-wave can be positive or negative)
    t_peak_index_relative = int(np.argmax(np.abs(signal_segment)))

    # When the peak is the last sample there is no derivative left to search.
    if t_peak_index_relative >= len(derivative):
        raise ValueError("T-peak is too close to the end of the search window.")

    # Offset (from the T-peak) of the minimum derivative after the T-peak
    steepest_descent_offset = int(np.argmin(derivative[t_peak_index_relative:]))

    # Convert from segment-relative back to an index into averaged_rr
    t_end_index_absolute = search_start + t_peak_index_relative + steepest_descent_offset

    # Basic validation: T-end should be after the R-peak. This is only
    # reported; the value is still returned as a best effort.
    if t_end_index_absolute <= r_peak_index:
        logger.warning("Detected T-end is before or at the R-peak index.")

    return int(t_end_index_absolute)
# def correct_qt_interval(
# qt_interval_ms: float,
# rr_interval_duration_s: float,
# method: Literal["bazett", "fridericia"] = "bazett",
# ):
# """
# Corrects the QT interval for heart rate using Bazett's or Fridericia's formula.
# Parameters:
# ----------
# qt_interval_ms : float
# The QT interval in milliseconds. Can be None.
# rr_interval_duration_s : float
# The RR interval duration in seconds.
# method : str, optional
# The correction method ('bazett' or 'fridericia'). Defaults to 'bazett'.
# Returns:
# -------
# float
# The corrected QT interval (QTc) in milliseconds. Returns None if input is invalid.
# """
# qt_interval_s = qt_interval_ms / 1000.0 # Convert QT to seconds for formula
# if method.lower() == "bazett":
# # QTc = QT / sqrt(RR)
# qtc_s = qt_interval_s / np.sqrt(rr_interval_duration_s)
# elif method.lower() == "fridericia":
# # QTc = QT / cubic-sqrt(RR)
# qtc_s = qt_interval_s / (rr_interval_duration_s ** (1 / 3))
# else:
# raise ValueError(
# f"Invalid QTc correction method '{method}'. Use 'bazett' or 'fridericia'.",
# )
# qtc_ms = qtc_s * 1000.0 # Convert back to milliseconds
# return qtc_ms
class QTIntervalResult(NamedTuple):
    """Result record for a single interval measurement.

    Positional order is ``(qt_interval, start_index, end_index)``.
    """

    # Measured interval length.
    qt_interval: float
    # Sample index at which the measured interval begins.
    start_index: int
    # Sample index at which the measured interval ends.
    end_index: int
def qt_interval(
    t: np.ndarray,
    ecg_signal: np.ndarray,
    min_distance: float = 20.0,
    window_start_offset: int = 50,
    window_end_offset: int = 400,
) -> QTIntervalResult:
    """
    Measures the interval between the first detected R-peak and the end of
    the T-wave that follows it.

    The interval is measured on the first R-peak only and no heart-rate
    correction (QTc) is applied.

    Parameters
    ----------
    t : np.ndarray
        Time vector corresponding to the ECG signal. The returned interval is
        expressed in the same unit as this vector.
    ecg_signal : np.ndarray
        The ECG signal to be processed. Should be filtered.
    min_distance : float, optional
        Minimum distance between R-peaks, forwarded to ``detect_r_peaks``.
        Defaults to 20.
    window_start_offset : int, optional
        Start of the T-wave-end search window relative to the R-peak,
        forwarded to ``detect_t_end``. Defaults to 50.
    window_end_offset : int, optional
        End of the T-wave-end search window relative to the R-peak,
        forwarded to ``detect_t_end``. Defaults to 400.

    Returns
    -------
    QTIntervalResult
        A named tuple with the interval ``t[end_index] - t[start_index]``,
        the R-peak index (``start_index``) and the T-end index (``end_index``).

    Raises
    ------
    AssertionError
        If no R-peaks are detected.

    Any exception raised by ``detect_t_end`` propagates unchanged.
    """
    r_peaks = detect_r_peaks(ecg_signal=ecg_signal, min_distance=min_distance)
    if len(r_peaks) == 0:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is kept for
        # callers that already handle AssertionError.
        raise AssertionError("No R-peaks detected. Check signal quality and detection parameters.")

    # Plain int so both indices in the result have the same type
    r_peak_index = int(r_peaks[0])
    t_end_index = detect_t_end(
        ecg_signal,
        r_peak_index,
        window_start_offset=window_start_offset,
        window_end_offset=window_end_offset,
    )

    interval = t[t_end_index] - t[r_peak_index]
    return QTIntervalResult(
        qt_interval=interval,
        start_index=r_peak_index,
        end_index=t_end_index,
    )
# NOTE: a stray "[docs]" Sphinx view-source link marker was here; it is not part of the code.
@dataclass
class ECGRecovery:
    """Recover a source field from ``v`` and build point-wise integral forms from it.

    On construction a linear problem is set up on the function space of ``v``:
    find ``Im`` such that, for every test function ``w``,

        ``-C_m * Im * w * dx == inner(M * grad(v), grad(w)) * dx``

    The system matrix is assembled once in ``__post_init__``; ``solve``
    re-assembles only the right-hand side and stores the solution in
    ``self.sol``. ``eval`` then builds the form for the integral of
    ``sol / (4 * pi * sigma_b * |x - point|)`` over the domain.
    """

    # Field whose gradient drives the right-hand side; its function space and
    # mesh are reused for everything else in this class.
    v: dolfinx.fem.Function
    # Scales the integrand built by ``eval`` as 1 / (4 * pi * sigma_b).
    # NOTE(review): presumably the conductivity of the surrounding medium - confirm.
    sigma_b: float | dolfinx.fem.Constant = 1.0
    # Coefficient multiplying the mass term on the left-hand side.
    C_m: float | dolfinx.fem.Constant = 1.0
    # Integration measure; created in ``__post_init__`` when left as None.
    dx: ufl.Measure | None = None
    # Coefficient multiplying grad(v) on the right-hand side.
    M: float = 1.0
    # Options handed to ``dolfinx.fem.petsc.LinearProblem``.
    petsc_options: dict[str, Any] = field(
        default_factory=lambda: {
            "ksp_type": "cg",
            "pc_type": "sor",
            # "ksp_monitor": None,
            "ksp_rtol": 1.0e-8,
            "ksp_atol": 1.0e-8,
            # "ksp_error_if_not_converged": True,
        },
    )

    def __post_init__(self):
        """Build the variational problem and assemble its matrix once."""
        if self.dx is None:
            self.dx = ufl.dx(domain=self.mesh, metadata={"quadrature_degree": 4})

        # Solution holder, allocated a single time and reused by every solve()
        self.sol = dolfinx.fem.Function(self.V)

        w = ufl.TestFunction(self.V)
        Im = ufl.TrialFunction(self.V)

        self._lhs = -self.C_m * Im * w * self.dx
        self._rhs = ufl.inner(self.M * ufl.grad(self.v), ufl.grad(w)) * self.dx

        self.solver = dolfinx.fem.petsc.LinearProblem(
            self._lhs,
            self._rhs,
            u=self.sol,
            petsc_options=self.petsc_options,
        )
        # The left-hand side does not involve ``v``, so the matrix is
        # assembled here once rather than on every solve().
        dolfinx.fem.petsc.assemble_matrix(self.solver.A, self.solver.a)
        self.solver.A.assemble()

    @property
    def V(self) -> dolfinx.fem.FunctionSpace:
        """Function space of ``v``."""
        return self.v.function_space

    @property
    def mesh(self) -> dolfinx.mesh.Mesh:
        """Mesh underlying the function space of ``v``."""
        return self.v.function_space.mesh

    def solve(self):
        """Re-assemble the right-hand side and solve into ``self.sol``."""
        logger.debug("Solving ECG recovery")
        # Zero the vector first so the assembly below does not add to stale values
        with self.solver.b.localForm() as b_loc:
            b_loc.set(0)
        dolfinx.fem.petsc.assemble_vector(self.solver.b, self.solver.L)
        # Add ghost contributions back onto the owning entries
        self.solver.b.ghostUpdate(
            addv=PETSc.InsertMode.ADD,
            mode=PETSc.ScatterMode.REVERSE,
        )
        self.solver.solver.solve(self.solver.b, self.sol.x.petsc_vec)
        self.sol.x.scatter_forward()

    def eval(self, point) -> dolfinx.fem.forms.Form:
        """Build the form of the distance-weighted integral of ``sol`` for ``point``.

        Parameters
        ----------
        point
            Coordinates of the evaluation point; wrapped in a
            ``dolfinx.fem.Constant`` on the mesh.

        Returns
        -------
        dolfinx.fem.forms.Form
            Form for ``(1 / (4 * pi * sigma_b)) * sol / dist * dx``. The form
            is returned as is; this method does not assemble it.
        """
        r = ufl.SpatialCoordinate(self.mesh) - dolfinx.fem.Constant(self.mesh, point)
        # NOTE(review): relies on ``r**2`` of a vector-valued UFL expression
        # being the inner product of r with itself - confirm against UFL.
        dist = ufl.sqrt(r**2)
        return dolfinx.fem.form((1 / (4 * ufl.pi * self.sigma_b)) * (self.sol / dist) * self.dx)
def _check_attr(attr: np.ndarray | None):
if attr is None:
raise AttributeError(f"Missing attribute {attr}")
# Taken from https://en.wikipedia.org/wiki/Electrocardiography
class Leads12(NamedTuple):
RA: np.ndarray
LA: np.ndarray
LL: np.ndarray
RL: np.ndarray | None = None # Do we really need this?
V1: np.ndarray | None = None
V2: np.ndarray | None = None
V3: np.ndarray | None = None
V4: np.ndarray | None = None
V5: np.ndarray | None = None
V6: np.ndarray | None = None
@property
def I(self) -> np.ndarray:
"""Voltage between the (positive) left arm (LA)
electrode and right arm (RA) electrode"""
return self.LA - self.RA
@property
def II(self) -> np.ndarray:
"""Voltage between the (positive) left leg (LL)
electrode and the right arm (RA) electrode
"""
return self.LL - self.RA
@property
def III(self) -> np.ndarray:
"""Voltage between the (positive) left leg (LL)
electrode and the left arm (LA) electrode
"""
return self.LL - self.LA
@property
def Vw(self) -> np.ndarray:
"""Wilson's central terminal"""
return (1 / 3) * (self.RA + self.LA + self.LL)
@property
def aVR(self) -> np.ndarray:
"""Lead augmented vector right (aVR) has the positive
electrode on the right arm. The negative pole is a
combination of the left arm electrode and the left leg electrode
"""
return (3 / 2) * (self.RA - self.Vw)
@property
def aVL(self) -> np.ndarray:
"""Lead augmented vector left (aVL) has the positive electrode
on the left arm. The negative pole is a combination of the right
arm electrode and the left leg electrode
"""
return (3 / 2) * (self.LA - self.Vw)
@property
def aVF(self) -> np.ndarray:
"""Lead augmented vector foot (aVF) has the positive electrode on the
left leg. The negative pole is a combination of the right arm
electrode and the left arm electrode
"""
return (3 / 2) * (self.LL - self.Vw)
@property
def V1_(self) -> np.ndarray:
_check_attr(self.V1)
return self.V1 - self.Vw
@property
def V2_(self) -> np.ndarray:
_check_attr(self.V2)
return self.V2 - self.Vw
@property
def V3_(self) -> np.ndarray:
_check_attr(self.V3)
return self.V3 - self.Vw
@property
def V4_(self) -> np.ndarray:
_check_attr(self.V4)
return self.V4 - self.Vw
@property
def V5_(self) -> np.ndarray:
_check_attr(self.V5)
return self.V5 - self.Vw
@property
def V6_(self) -> np.ndarray:
_check_attr(self.V6)
return self.V6 - self.Vw
def example(
    sampling_rate_hz: int = 1000,
    duration_s: float = 10,
    heart_rate_bpm: float = 60,
    q_offset_ms: float = 40,
    s_offset_ms: float = 40,
    t_peak_offset_ms: float = 200,
    r_width_ms: float = 20,
    q_width_ms: float = 20,
    s_width_ms: float = 30,
    t_width_ms: float = 60,
    qrs_peak_time: float = 200,
    noise_amplitude: float = 0.0,
    wander_freq_hz: float = 0.2,
    wander_amplitude: float = 0.1,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Generate a synthetic ECG signal.

    Every beat is the sum of four Gaussian bumps (R and T positive, Q and S
    negative). Optional white noise and a sinusoidal baseline wander are
    added on top.

    Parameters
    ----------
    sampling_rate_hz : int
        Sampling rate in Hz.
    duration_s : float
        Duration of the signal in seconds.
    heart_rate_bpm : float
        Heart rate in beats per minute.
    q_offset_ms : float
        How long before the R-peak the Q wave is centred, in milliseconds.
    s_offset_ms : float
        How long after the R-peak the S wave is centred, in milliseconds.
    t_peak_offset_ms : float
        How long after the R-peak the T wave is centred, in milliseconds.
    r_width_ms : float
        Width of the R wave in milliseconds.
    q_width_ms : float
        Width of the Q wave in milliseconds.
    s_width_ms : float
        Width of the S wave in milliseconds.
    t_width_ms : float
        Width of the T wave in milliseconds.
    qrs_peak_time : float
        Time of the R-peak within each beat, in milliseconds after the start
        of that beat.
    noise_amplitude : float
        Amplitude of the noise to be added to the signal. No noise is added
        unless this is positive.
    wander_freq_hz : float
        Frequency of the baseline wander in Hz.
    wander_amplitude : float
        Amplitude of the baseline wander.

    Returns
    -------
    t_ms : np.ndarray
        Time vector in milliseconds.
    ecg_signal : np.ndarray
        Generated ECG signal.
    """
    # Convert time parameters to milliseconds
    duration_ms = duration_s * 1000
    rr_interval_s = 60.0 / heart_rate_bpm
    rr_interval_ms = rr_interval_s * 1000
    # Only whole beats that start inside the requested duration are generated
    num_beats = int(duration_s / rr_interval_s)

    # Time vector in milliseconds
    num_samples = int(duration_s * sampling_rate_hz)
    t_ms = np.linspace(0, duration_ms, num_samples, endpoint=False)
    ecg_signal = np.zeros_like(t_ms)

    def _bump(center_ms: float, width_ms: float) -> np.ndarray:
        # Unit-height Gaussian centred at center_ms, sampled on t_ms
        return np.exp(-(((t_ms - center_ms) / width_ms) ** 2))

    # Create multiple beats
    for i in range(num_beats):
        # R-peak time for the current beat: start of the beat plus the
        # in-beat offset. Both terms are in milliseconds, so the peak lands
        # qrs_peak_time ms into every beat regardless of the heart rate.
        r_peak_time_ms = i * rr_interval_ms + qrs_peak_time

        # Absolute times of the other wave components for this beat
        q_time_ms = r_peak_time_ms - q_offset_ms
        s_time_ms = r_peak_time_ms + s_offset_ms
        t_peak_time_ms = r_peak_time_ms + t_peak_offset_ms

        ecg_signal += 1.0 * _bump(r_peak_time_ms, r_width_ms)  # R peak
        ecg_signal -= 0.2 * _bump(q_time_ms, q_width_ms)  # Q wave
        ecg_signal -= 0.3 * _bump(s_time_ms, s_width_ms)  # S wave
        ecg_signal += 0.4 * _bump(t_peak_time_ms, t_width_ms)  # T wave

    # Add some baseline noise
    if noise_amplitude > 0:
        ecg_signal += noise_amplitude * np.random.randn(len(t_ms))

    # Add some baseline wander (low frequency noise); the frequency is
    # converted to cycles per millisecond to match t_ms
    wander_freq_per_ms = wander_freq_hz / 1000.0
    ecg_signal += wander_amplitude * np.sin(2 * np.pi * wander_freq_per_ms * t_ms)

    return t_ms, ecg_signal