Source code for fenicsx_plotly._fenicsx_plotly

import os
import typing
from pathlib import Path

import dolfinx
import numpy as np
import numpy.typing as npt
import plotly
import plotly.graph_objects as go
import plotly.io as pio
import ufl
from plotly.basedatatypes import BaseTraceType as _BaseTraceType

# Figures are displayed unless FENICS_PLOTLY_SHOW is set to an integer that
# evaluates to zero. A value that cannot be parsed as an integer falls back to
# showing figures.
try:
    _SHOW_PLOT = bool(int(os.getenv("FENICS_PLOTLY_SHOW", "1")))
except ValueError:
    _SHOW_PLOT = True

# Name of the default plotly renderer, overridable via FENICS_PLOTLY_RENDERER.
# `os.getenv` only returns a string (or the default), so there is nothing to
# catch here.
_RENDERER = os.getenv("FENICS_PLOTLY_RENDERER", "notebook")


def set_renderer(renderer: str) -> None:
    """Set the default plotly renderer.

    Parameters
    ----------
    renderer : str
        Value assigned to ``plotly.io.renderers.default``.
    """
    registry = pio.renderers
    registry.default = renderer


# Apply the configured default renderer as soon as the module is imported.
set_renderer(_RENDERER)


def savefig(
    fig: go.FigureWidget,
    filename: typing.Union[str, Path],
    save_config: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> None:
    """Save figure to file

    Parameters
    ----------
    fig : `plotly.graph_objects.Figure`
        The figure that you want to save
    filename : Path or str
        Path to the destination where you want to
        save the figure
    save_config : dict, optional
        Additional entries merged into the ``config`` dictionary
        that is passed to `plotly.offline.plot`, by default None

    Raises
    ------
    AssertionError
        If the parent folder of `filename` does not exist.
    """

    fname = Path(filename)
    outdir = fname.parent
    # Raise explicitly rather than with an `assert` statement: assertions are
    # stripped when Python runs with -O, which would silently skip this check.
    # The exception type is kept so existing callers see the same error.
    if not outdir.exists():
        raise AssertionError(f"Folder {outdir} does not exist")

    # Defaults for the "download plot as image" button of the exported page.
    config: typing.Dict[str, typing.Any] = {
        "toImageButtonOptions": {
            "filename": fname.stem,
            "width": 1500,
            "height": 1200,
        },
    }
    if save_config is not None:
        # User supplied entries take precedence over the defaults above.
        config.update(save_config)

    plotly.offline.plot(fig, filename=fname.as_posix(), auto_open=False, config=config)


def _get_triangles(mesh: dolfinx.mesh.Mesh) -> npt.NDArray[np.int_]:
    """Return the vertex indices of every 2-dimensional entity of the mesh.

    Parameters
    ----------
    mesh : dolfinx.mesh.Mesh
        The mesh whose 2-dimensional entities are collected

    Returns
    -------
    np.ndarray
        Integer array of shape ``(3, number of entities)`` where column ``f``
        holds the three vertices linked to entity ``f`` in the
        ``(2, 0)`` connectivity.

    Raises
    ------
    ValueError
        If some 2-dimensional entity is not linked to exactly three vertices.
    """
    # The marker accepts every point, so all 2-dimensional entities are returned.
    faces = dolfinx.mesh.locate_entities(
        mesh,
        2,
        lambda x: np.full(x.shape[1], True, dtype=bool),
    )

    mesh.topology.create_connectivity(2, 0)
    conn = mesh.topology.connectivity(2, 0)

    # The flat reshape below is only valid when every entity has three links.
    if np.any(np.diff(conn.offsets) != 3):
        raise ValueError(
            "Expected every 2-dimensional entity to be connected to exactly 3 vertices",
        )

    # Row f of `links` equals `conn.links(f)`, which lets all columns be
    # filled in one assignment instead of a Python loop over the entities.
    links = conn.array.reshape(-1, 3)
    triangle = np.zeros((3, faces.size), dtype=int)
    triangle[:, faces] = links[faces].T

    return triangle


def _surface_plot_mesh(
    mesh: dolfinx.mesh.Mesh,
    color: str = "gray",
    opacity: float = 1.0,
    **kwargs,
):
    """Build a uniformly colored `Mesh3d` trace of the mesh surface.

    Parameters
    ----------
    mesh : dolfinx.mesh.Mesh
        The mesh to draw
    color : str, optional
        Color of the surface, by default "gray"
    opacity : float, optional
        Opacity of the surface, by default 1.0
    **kwargs
        Accepted so callers can forward a shared option set; not used here.

    Returns
    -------
    plotly.graph_objects.Mesh3d
        The surface trace
    """
    points = mesh.geometry.x
    corners = _get_triangles(mesh)

    if len(points[0]) == 2:
        # Two coordinate columns only: append a zero z-column.
        points = np.column_stack((points, np.zeros(points.shape[0])))

    return go.Mesh3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        i=corners[0],
        j=corners[1],
        k=corners[2],
        flatshading=True,
        color=color,
        opacity=opacity,
        lighting={"ambient": 1},
    )


def _get_cells(mesh: dolfinx.mesh.Mesh) -> np.ndarray:
    """Return the transpose of the geometry dofmap of ``mesh``."""
    return mesh.geometry.dofmap.T


def _wireframe_plot_mesh(mesh: dolfinx.mesh.Mesh, **kwargs) -> go.Scatter3d:
    """Build a line trace showing the mesh as a wireframe.

    Parameters
    ----------
    mesh : dolfinx.mesh.Mesh
        The mesh to draw
    **kwargs
        Accepted so callers can forward a shared option set; not used here.

    Returns
    -------
    plotly.graph_objects.Scatter3d
        Trace in "lines" mode
    """
    points = mesh.geometry.x
    if len(points[0]) == 2:
        # Two coordinate columns only: append a zero z-column.
        points = np.column_stack((points, np.zeros(points.shape[0])))

    xs: list = []
    ys: list = []
    zs: list = []
    # NOTE(review): each row of the transposed geometry dofmap is drawn as one
    # polyline - confirm that these rows correspond to cells and not to local
    # node positions.
    for group in _get_cells(mesh):
        segment = points[group]
        # A None entry is appended after every group of coordinates.
        for axis, target in enumerate((xs, ys, zs)):
            target.extend(segment[:, axis].tolist())
            target.append(None)

    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode="lines",
        name="",
        line={"color": "rgb(70,70,70)", "width": 2},
        hoverinfo="none",
    )


def _plot_dofs(
    functionspace: dolfinx.fem.FunctionSpace,
    size: int,
    **kwargs,
) -> go.Scatter3d:
    """Build a marker trace located at the dof coordinates of a function space.

    Parameters
    ----------
    functionspace : dolfinx.fem.FunctionSpace
        Space whose dof coordinates are tabulated
    size : int
        Marker size
    **kwargs
        Only ``name`` is read; it becomes the name of the trace.

    Returns
    -------
    plotly.graph_objects.Scatter3d
        Trace in "markers" mode
    """
    locations = functionspace.tabulate_dof_coordinates()
    if len(locations[0]) == 2:
        # Two coordinate columns only: append a zero z-column.
        locations = np.column_stack((locations, np.zeros(locations.shape[0])))

    label = kwargs.get("name")
    return go.Scatter3d(
        x=locations[:, 0],
        y=locations[:, 1],
        z=locations[:, 2],
        mode="markers",
        name=label,
        marker={"size": size},
    )


def _get_vertex_values(function: dolfinx.fem.Function) -> np.ndarray:
    """Return the values of ``function`` represented in a linear Lagrange space.

    Parameters
    ----------
    function : dolfinx.fem.Function
        Scalar, vector or (rank 2) tensor valued function

    Returns
    -------
    np.ndarray
        For a scalar function the dof array of the ("P", 1) representation.
        For vector and tensor functions an array of shape
        ``(N, *function.ufl_shape)`` with ``N = mesh.geometry.x.shape[0]``,
        filled component by component.
    """
    fs = function.function_space
    mesh = fs.mesh
    shape = function.ufl_shape
    # NOTE(review): the component arrays below are sized by the number of
    # geometry points, which assumes the P1 dof array has that same length and
    # ordering - confirm.

    if len(shape) == 0:  # FiniteElement
        el = fs.ufl_element()
        # TODO: Ask Jørgen if there is a better way
        # Where is the `.compute_vertex_values()` method?
        if (el.family(), el.degree()) != ("P", 1):
            # Interpolate into a linear lagrange space
            V = dolfinx.fem.FunctionSpace(mesh, ("P", 1))
            u = dolfinx.fem.Function(V)
            u.interpolate(function)
            res = u.x.array
        else:
            res = function.x.array
    elif len(shape) == 1:  # Vector Element
        res = np.zeros((mesh.geometry.x.shape[0], shape[0]))
        for i in range(shape[0]):
            res[:, i] = _get_vertex_values(function.sub(i).collapse())

    else:  # Tensor Element
        n_rows, n_cols = shape[0], shape[1]
        res = np.zeros((mesh.geometry.x.shape[0], n_rows, n_cols))
        for i in range(n_rows):
            for j in range(n_cols):
                # Sub-functions are addressed row by row, so entry (i, j) has
                # flat index i * n_cols + j. Using the number of columns here
                # (rather than the number of rows) also covers non-square
                # tensors.
                res[:, i, j] = _get_vertex_values(
                    function.sub(i * n_cols + j).collapse(),
                )
    return res


def _surface_plot_function(
    function: dolfinx.fem.Function,
    colorscale: str = "inferno",
    showscale: bool = True,
    intensitymode: str = "vertex",
    **kwargs,
) -> go.Mesh3d:
    """Build a `Mesh3d` trace of the mesh colored by the values of ``function``.

    Parameters
    ----------
    function : dolfinx.fem.Function
        Function supplying the intensities
    colorscale : str, optional
        The colorscale, by default "inferno"
    showscale : bool, optional
        Forwarded to the trace as ``showscale``, by default True
    intensitymode : str, optional
        Forwarded to the trace as ``intensitymode``, by default "vertex"
    **kwargs
        Accepted so callers can forward a shared option set; not used here.

    Returns
    -------
    plotly.graph_objects.Mesh3d
        The colored surface trace
    """
    mesh = function.function_space.mesh

    values = _get_vertex_values(function=function)
    corners = _get_triangles(mesh)
    points = mesh.geometry.x

    # One hover label per value.
    labels = ["val:" + "%.5f" % value for value in values]

    if len(points[0]) == 2:
        # Two coordinate columns only: append a zero z-column.
        points = np.column_stack((points, np.zeros(points.shape[0])))

    trace_options = {
        "x": points[:, 0],
        "y": points[:, 1],
        "z": points[:, 2],
        "i": corners[0],
        "j": corners[1],
        "k": corners[2],
        "flatshading": True,
        "intensitymode": intensitymode,
        "intensity": values,
        "colorscale": colorscale,
        "lighting": {"ambient": 1},
        "name": "",
        "hoverinfo": "all",
        "text": labels,
        "showscale": showscale,
    }
    return go.Mesh3d(**trace_options)


def _scatter_plot_function(
    function: dolfinx.fem.Function,
    colorscale,
    showscale=True,
    size=10,
    **kwargs,
) -> go.Scatter3d:
    """Build a marker trace at the dof coordinates colored by the dof values.

    Parameters
    ----------
    function : dolfinx.fem.Function
        Function whose dof array colors the markers
    colorscale : str
        Colorscale applied to the markers
    showscale : bool, optional
        Accepted for signature compatibility with the other plot helpers;
        currently not applied to the trace. By default True
    size : int, optional
        Marker size, by default 10
    **kwargs
        Accepted so callers can forward a shared option set; not used here.

    Returns
    -------
    plotly.graph_objects.Scatter3d
        Trace in "markers" mode
    """
    dofs_coord = function.function_space.tabulate_dof_coordinates()
    if len(dofs_coord[0, :]) == 2:
        # Two coordinate columns only: append a zero z-column.
        dofs_coord = np.c_[dofs_coord, np.zeros(len(dofs_coord[:, 0]))]

    # Only the dof coordinates are plotted, so the mesh geometry that used to
    # be fetched (and padded) here without ever being used has been dropped.
    val = function.x.array
    hoverinfo = ["val:" + "%.5f" % item for item in val]

    points = go.Scatter3d(
        x=dofs_coord[:, 0],
        y=dofs_coord[:, 1],
        z=dofs_coord[:, 2],
        mode="markers",
        marker=dict(size=size, color=val, colorscale=colorscale),
        hoverinfo="all",
        text=hoverinfo,
    )

    return points


def _cone_plot(
    function: dolfinx.fem.Function,
    size: int = 10,
    showscale: bool = True,
    normalize: bool = False,
    **kwargs,
) -> go.Cone:
    """Build a cone trace for a vector valued function.

    Parameters
    ----------
    function : dolfinx.fem.Function
        Vector valued function with two or three components
    size : int, optional
        Passed to the trace as ``sizeref``, by default 10
    showscale : bool, optional
        Show colorbar, by default True
    normalize : bool, optional
        Scale every vector to unit length; vectors of zero length are left
        as zero. By default False
    **kwargs
        Accepted so callers can forward a shared option set; not used here.

    Returns
    -------
    plotly.graph_objects.Cone
        The cone trace
    """
    mesh = function.function_space.mesh
    points = mesh.geometry.x
    vectors = _get_vertex_values(function)

    if len(points[0, :]) == 2:
        # Two coordinate columns only: append a zero z-column.
        points = np.c_[points, np.zeros(len(points[:, 0]))]

    if vectors.shape[1] == 2:
        # Two components only: append a zero third component.
        vectors = np.c_[vectors, np.zeros(len(vectors[:, 0]))]

    if normalize:
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # Divide only where the norm is non-zero; a plain division would turn
        # zero vectors into NaN and emit a runtime warning.
        vectors = np.divide(
            vectors,
            norms,
            out=np.zeros_like(vectors, dtype=float),
            where=norms > 0,
        )

    cones = go.Cone(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        u=vectors[:, 0],
        v=vectors[:, 1],
        w=vectors[:, 2],
        sizemode="absolute",
        sizeref=size,
        showscale=showscale,
    )

    return cones


def _handle_mesh(obj: dolfinx.mesh.Mesh, **kwargs) -> list[_BaseTraceType]:
    """Collect the traces used to plot a mesh.

    The wireframe trace is always included. The surface trace is added in
    front of it only when the ``wireframe`` keyword is falsy (the default
    used here when the keyword is absent).
    """
    wireframe_only = bool(kwargs.get("wireframe", False))
    traces = [] if wireframe_only else [_surface_plot_mesh(obj, **kwargs)]
    traces.append(_wireframe_plot_mesh(obj))
    return traces


def _handle_function_space(
    obj: dolfinx.fem.FunctionSpace,
    **kwargs,
) -> list[_BaseTraceType]:
    """Collect the traces used to plot a function space.

    The dof markers come first; the mesh wireframe follows unless the
    ``wireframe`` keyword is falsy (it defaults to True).
    """
    traces = [_plot_dofs(obj, **kwargs)]
    if kwargs.get("wireframe", True):
        traces.append(_wireframe_plot_mesh(obj.mesh, **kwargs))
    return traces


def _handle_scalar_function(
    obj: dolfinx.fem.Function,
    scatter: bool = False,
    **kwargs,
) -> _BaseTraceType:
    """Plot a scalar function as a scatter trace or as a surface trace.

    Parameters
    ----------
    obj : dolfinx.fem.Function
        The scalar function
    scatter : bool, optional
        Use the scatter plot instead of the surface plot, by default False
    **kwargs
        Forwarded to the selected plot helper
    """
    plotter = _scatter_plot_function if scatter else _surface_plot_function
    return plotter(obj, **kwargs)


def _handle_vector_function(
    obj: dolfinx.fem.Function,
    component: typing.Optional[str] = None,
    **kwargs,
) -> _BaseTraceType:
    """Plot a vector function as cones, by magnitude or by single component.

    Parameters
    ----------
    obj : dolfinx.fem.Function
        The vector valued function
    component : str, optional
        One of "magnitude", "x", "y" or "z" (case-insensitive). When None a
        cone plot of the full vector field is returned. By default None
    **kwargs
        Forwarded to the selected plot helper

    Raises
    ------
    KeyError
        If `component` is not one of the names listed above.
    RuntimeError
        If the requested component does not exist in the function space.
    """
    fs = obj.function_space

    if component is None:
        return _cone_plot(obj, **kwargs)

    # Compare case-insensitively for every name, so that "Magnitude" is
    # accepted just like "X" already was.
    key = component.lower()

    if key == "magnitude":
        V, _ = fs.sub(0).collapse()
        magnitude = dolfinx.fem.Function(V)
        magnitude.interpolate(
            dolfinx.fem.Expression(
                ufl.sqrt(ufl.inner(obj, obj)),
                V.element.interpolation_points(),
            ),
        )
        return _surface_plot_function(magnitude, **kwargs)

    # Extract x, y or z
    i = {"x": 0, "y": 1, "z": 2}[key]
    if i >= fs.num_sub_spaces:
        raise RuntimeError(
            f"Cannot extract component from subspace {i} for"
            f" function space with {fs.num_sub_spaces}"
            " number of subspaces.",
        )
    return _surface_plot_function(obj.sub(i).collapse(), **kwargs)


def _handle_function(
    obj: dolfinx.fem.Function,
    **kwargs,
) -> list[_BaseTraceType]:
    """Collect the traces used to plot a function.

    Scalar (rank 0) and vector (rank 1) functions contribute one trace;
    functions of any other rank contribute none. The mesh wireframe is
    appended unless the ``wireframe`` keyword is falsy (it defaults to True).
    """
    traces = []
    rank = len(obj.ufl_shape)

    if rank == 0:  # Scalar Function
        traces.append(_handle_scalar_function(obj, **kwargs))
    elif rank == 1:  # Vector Function
        traces.append(_handle_vector_function(obj, **kwargs))

    if kwargs.get("wireframe", True):
        traces.append(_wireframe_plot_mesh(obj.function_space.mesh))

    return traces


def _handle_meshtags(
    obj: dolfinx.mesh.MeshTags,
    colorscale: str = "inferno",
    **kwargs,
) -> list[_BaseTraceType]:
    """Collect the traces used to plot mesh tags of dimension 2.

    Parameters
    ----------
    obj : dolfinx.mesh.MeshTags
        The tags to plot; ``obj.dim`` has to be 2
    colorscale : str, optional
        The colorscale, by default "inferno"
    **kwargs
        Must contain ``mesh``. ``wireframe`` (default True) controls whether
        the mesh wireframe is appended.

    Raises
    ------
    NotImplementedError
        If the tags are not of dimension 2.
    RuntimeError
        If no mesh is supplied through the keyword arguments.
    """
    if obj.dim != 2:
        raise NotImplementedError("Plotting of MeshTags is only supported for facets")

    mesh = kwargs.get("mesh")
    if mesh is None:
        raise RuntimeError(
            "Please provide mesh as a keyword argument when plotting MeshTags",
        )

    points = mesh.geometry.x
    if len(points[0]) == 2:
        # Two coordinate columns only: append a zero z-column.
        points = np.column_stack((points, np.zeros(points.shape[0])))

    corners = _get_triangles(mesh)

    # One intensity per triangle; triangles without a tag keep the value zero.
    tags = np.zeros(corners.shape[1])
    tags[obj.indices] = obj.values

    labels = ["val:" + "%d" % tag for tag in tags]

    traces = [
        go.Mesh3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            i=corners[0],
            j=corners[1],
            k=corners[2],
            flatshading=True,
            intensity=tags,
            colorscale=colorscale,
            lighting={"ambient": 1},
            name="",
            hoverinfo="all",
            text=labels,
            intensitymode="cell",
        ),
    ]

    if kwargs.get("wireframe", True):
        traces.append(_wireframe_plot_mesh(mesh))

    return traces


def _plot_dirichlet_bc(
    obj: dolfinx.fem.bcs.DirichletBC,
    size: int = 10,
    colorscale: str = "inferno",
    **kwargs,
) -> go.Scatter3d:
    """Build a scatter trace of the function attached to a Dirichlet condition.

    Parameters
    ----------
    obj : dolfinx.fem.bcs.DirichletBC
        Boundary condition whose ``g`` attribute is plotted
    size : int, optional
        Marker size, by default 10
    colorscale : str, optional
        The colorscale, by default "inferno"
    **kwargs
        Forwarded to the scatter plot helper

    Returns
    -------
    plotly.graph_objects.Scatter3d
        A single trace (not a list of traces)
    """
    return _scatter_plot_function(
        function=obj.g,
        size=size,
        colorscale=colorscale,
        **kwargs,
    )


def _handle_dirichlet_bc(
    obj: dolfinx.fem.bcs.DirichletBC,
    **kwargs,
) -> list[_BaseTraceType]:
    """Collect the traces used to plot a Dirichlet boundary condition.

    The scatter trace of the boundary values comes first and is always
    followed by the wireframe of the mesh of ``obj.function_space``.
    """
    markers = _plot_dirichlet_bc(obj, **kwargs)
    wireframe = _wireframe_plot_mesh(obj.function_space.mesh, **kwargs)
    return [markers, wireframe]


class FEniCSPlotFig:
    """Thin wrapper around a plotly figure returned by the plot function."""

    def __init__(self, fig: go.FigureWidget) -> None:
        # The wrapped plotly figure; replaced by `add_plot`.
        self.figure = fig

    def add_plot(self, fig: typing.Union["FEniCSPlotFig", go.FigureWidget]) -> None:
        """Append the traces of another plot to this figure.

        Parameters
        ----------
        fig : FEniCSPlotFig or plotly figure
            The plot whose traces are added. The layout of this figure
            is kept.
        """
        # Accept both the wrapper (as before) and a bare plotly figure.
        other = fig.figure if isinstance(fig, FEniCSPlotFig) else fig
        data = list(self.figure.data) + list(other.data)
        self.figure = go.FigureWidget(data=data, layout=self.figure.layout)

    def show(self) -> None:
        """Show the figure unless showing is disabled by `_SHOW_PLOT`."""
        if _SHOW_PLOT:
            self.figure.show()

    def save(self, filename: str) -> None:
        """Save the figure to `filename` using `savefig`."""
        savefig(self.figure, filename)


def plot(
    obj,
    colorscale: str = "inferno",
    wireframe: bool = True,
    scatter: bool = False,
    size: int = 10,
    name: str = "f",
    color: str = "gray",
    opacity: float = 1.0,
    show_grid: bool = False,
    size_frame: typing.Optional[typing.Tuple[int, int]] = None,
    background: typing.Tuple[int, int, int] = (242, 242, 242),
    normalize: bool = False,
    component: typing.Optional[str] = None,
    showscale: bool = True,
    show: bool = True,
    filename: typing.Optional[str] = None,
    **kwargs,
) -> FEniCSPlotFig:
    """Plot FEniCSx object

    Parameters
    ----------
    obj : Mesh, Function, FunctionSpace, MeshTags, DirichletBC
        FEniCSx object to be plotted
    colorscale : str, optional
        The colorscale, by default "inferno"
    wireframe : bool, optional
        Whether you want to show the mesh in wireframe, by default True
    scatter : bool, optional
        Plot function as scatter plot, by default False
    size : int, optional
        Size of scatter points, by default 10
    name : str, optional
        Name to show up in legend, by default "f"
    color : str, optional
        Color to be plotted on the mesh, by default "gray"
    opacity : float, optional
        Opacity of surface, by default 1.0
    show_grid : bool, optional
        Show x, y (and z) axis grid, by default False
    size_frame : tuple of int, optional
        Width and height of the plot, by default None
    background : tuple, optional
        Background color of the plot as (r, g, b),
        by default (242, 242, 242)
    normalize : bool, optional
        For vectors, normalize them to have unit length, by default False
    component : str, optional
        Plot a component (["Magnitude", "x", "y", "z"]) for vector,
        by default None
    showscale : bool, optional
        Show colorbar, by default True
    show : bool, optional
        Show figure, by default True
    filename : str, optional
        Path to file where you want to save the figure, by default None
    **kwargs
        Additional keyword arguments forwarded to the handler
        selected for `obj`

    Returns
    -------
    FEniCSPlotFig
        Wrapper around the created figure

    Raises
    ------
    TypeError
        If object to be plotted is not recognized.
    """
    # Select the handler matching the type of the object.
    if isinstance(obj, dolfinx.mesh.Mesh):
        handle = _handle_mesh
    elif isinstance(obj, dolfinx.fem.Function):
        handle = _handle_function
    elif isinstance(obj, dolfinx.mesh.MeshTags):
        handle = _handle_meshtags
    elif isinstance(obj, dolfinx.fem.FunctionSpace):
        handle = _handle_function_space
    elif isinstance(obj, dolfinx.fem.bcs.DirichletBC):
        handle = _handle_dirichlet_bc
    else:
        raise TypeError(f"Cannot plot object of type {type(obj)}")

    # Every handler receives the same set of options.
    options = {
        "scatter": scatter,
        "colorscale": colorscale,
        "normalize": normalize,
        "size": size,
        "size_frame": size_frame,
        "component": component,
        "opacity": opacity,
        "show_grid": show_grid,
        "color": color,
        "wireframe": wireframe,
        "showscale": showscale,
        "name": name,
    }
    traces = handle(obj, **options, **kwargs)

    layout = go.Layout(
        scene_xaxis_visible=show_grid,
        scene_yaxis_visible=show_grid,
        scene_zaxis_visible=show_grid,
        paper_bgcolor=f"rgb{background}",
        margin={"l": 80, "r": 80, "t": 50, "b": 50},
        scene={"aspectmode": "data"},
    )
    if size_frame is not None:
        layout.update(width=size_frame[0], height=size_frame[1])

    figure = go.FigureWidget(data=traces, layout=layout)
    figure.update_layout(hovermode="closest")

    if filename is not None:
        savefig(figure, filename)

    if show and _SHOW_PLOT:
        figure.show()

    return FEniCSPlotFig(figure)