"""Figure I/O management utilities.
Provides functions for saving Matplotlib figures in various formats and
rendering them as SVG or other image formats in Jupyter environments.
"""
from __future__ import annotations
__all__ = ["save_and_show", "save_formats", "show"]
import warnings
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any
from xml.dom import minidom
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from ._helpers import create_parent_path
# Image extensions matplotlib's savefig knows how to write — if a caller
# passes a path that already ends with one of these, we strip it so the
# requested ``formats`` are appended cleanly instead of producing
# ``name.png.png`` / ``name.png.svg``.
_KNOWN_IMAGE_SUFFIXES = frozenset(
    {
        ".png",
        ".pdf",
        ".svg",
        ".svgz",
        ".eps",
        ".ps",
        ".jpg",
        ".jpeg",
        ".tif",
        ".tiff",
        ".webp",
        ".raw",
        ".rgba",
    }
)


def _normalize_image_stem(image_stem: str) -> str:
    """Strip a trailing image suffix from ``image_stem`` if present.

    Callers occasionally pass ``"out/chart.png"`` instead of the
    documented ``"out/chart"`` (often by reusing a path variable
    that already carries an extension). Without normalization
    ``save_formats(..., formats=("png", "svg"))`` would emit
    ``chart.png.png`` and ``chart.png.svg`` — silent file-naming
    bugs that pollute output directories. Normalize once at the
    boundary and emit a :class:`UserWarning` so call sites can be
    cleaned up over time.

    Parameters
    ----------
    image_stem : str
        Output path, ideally without an extension.

    Returns
    -------
    str
        ``image_stem`` with a trailing suffix listed in
        ``_KNOWN_IMAGE_SUFFIXES`` removed (matched case-insensitively);
        otherwise ``image_stem`` unchanged.

    Warns
    -----
    UserWarning
        When a suffix was stripped.

    Notes
    -----
    A stem whose text does not literally end with the detected suffix
    (for example ``"out/chart.png/"``) is returned untouched and without
    a warning.
    """
    suffix = Path(image_stem).suffix.lower()
    if suffix not in _KNOWN_IMAGE_SUFFIXES:
        return image_stem

    # ``Path`` discards trailing separators and ``.`` components before it
    # computes the suffix, so the raw string may not end with that suffix at
    # all. Slicing blindly would then cut into the file name itself
    # ("out/chart.png/" -> "out/chart."), so verify the tail first.
    if image_stem[-len(suffix):].lower() != suffix:
        return image_stem

    normalized = image_stem[: -len(suffix)]
    warnings.warn(
        (
            f"save_formats: image_stem {image_stem!r} ends with image "
            f"suffix {suffix!r}; stripping to {normalized!r}. "
            "Pass the path *without* an extension to silence this."
        ),
        UserWarning,
        # Skip this helper and its caller so the warning is attributed to
        # the code that invoked the caller.
        stacklevel=3,
    )
    return normalized
[docs]
def _rescale_svg_markup(svg_markup: str, size: int, unit: str) -> str:
    """Return ``svg_markup`` resized to ``size`` wide, keeping its aspect ratio.

    The root element's ``width`` and ``height`` attributes are read, the
    ``unit`` text is removed from each, and the remainder is parsed as a
    float. When either value cannot be parsed, or the width is not
    positive, the markup is returned unchanged. Otherwise both attributes
    are rewritten on the parsed root element and that element is
    serialized back to text.
    """
    # The SVG payload is produced by matplotlib's own SVG backend (via
    # IPython.display.SVG) on dartwork-mpl figures, never user-supplied
    # XML, so XXE/billion-laughs vectors do not apply here.
    root = minidom.parseString(svg_markup).documentElement
    if root is None:
        return svg_markup

    try:
        width = float(root.getAttribute("width").replace(unit, ""))
        height = float(root.getAttribute("height").replace(unit, ""))
    except ValueError:
        # Missing attribute, percentage, or a different unit: leave as is.
        return svg_markup
    if width <= 0:
        return svg_markup

    aspect_ratio = height / width
    # Edit the attributes on the DOM rather than text-replacing
    # ``width="<number><unit>"``: that depends on the number's exact
    # spelling in the file and can also hit nested elements or
    # ``stroke-width="..."``.
    root.setAttribute("width", f"{size}{unit}")
    root.setAttribute("height", f"{int(size * aspect_ratio)}{unit}")
    return root.toxml()


def show(image_path: str, size: int = 600, unit: str = "pt") -> None:
    """Load an SVG image and display it at the specified size in a browser or Jupyter.

    The image's root ``width`` is set to ``size`` and its ``height`` is
    scaled by the original aspect ratio. If the original dimensions cannot
    be read in ``unit`` (or the width is not positive), the image is
    displayed without resizing.

    Parameters
    ----------
    image_path : str
        Path to the SVG image to display.
    size : int, optional
        Desired output width. Default is 600.
    unit : str, optional
        Unit for the width ('pt', 'px', etc.). Default is 'pt'.

    Raises
    ------
    ImportError
        If IPython is not installed. ``show`` renders inline in Jupyter
        via IPython, which is an optional extra — install it with
        ``pip install "dartwork-mpl[notebook]"``.
    """
    try:
        from IPython.display import HTML, SVG, display
    except ImportError as exc:  # pragma: no cover - exercised via mock
        raise ImportError(
            "dm.show() needs IPython for inline Jupyter display, which is "
            "an optional extra. Install it with "
            "'pip install \"dartwork-mpl[notebook]\"' "
            "(or 'uv add \"dartwork-mpl[notebook]\"')."
        ) from exc

    # NOTE(review): the path is handed over as ``data``; this relies on
    # IPython loading the file's contents from it — confirm against the
    # IPython.display.SVG documentation.
    svg_obj = SVG(data=image_path)  # type: ignore[no-untyped-call]
    markup = _rescale_svg_markup(svg_obj.data, size, unit)
    display(HTML(markup))  # type: ignore[no-untyped-call]
[docs]
def save_and_show(
    fig: Figure,
    image_path: str | None = None,
    size: int = 600,
    unit: str = "pt",
    *,
    adopt_orphan_tick_font: bool = True,
    **kwargs: Any,
) -> None:
    """Write a figure to disk, close it, and display the saved file.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save and display. It is closed with ``plt.close`` once
        it has been written.
    image_path : str | None, optional
        Destination of the saved image. When ``None``, a temporary
        ``.svg`` file is used and removed again after display (also when
        saving or displaying fails).
    size : int, optional
        Display width forwarded to :func:`show`. Default is 600.
    unit : str, optional
        Unit for ``size`` ('pt', 'px', etc.), forwarded to :func:`show`.
        Default is 'pt'.
    adopt_orphan_tick_font : bool, optional
        If ``True`` (default), call
        :func:`~dartwork_mpl.layout.adopt_axis_label_font` on the figure
        before saving so unlabeled axes' tick labels take the axis-label
        font, matching :func:`save_formats`. This mutates the figure (see
        that function's Notes). Pass ``False`` to leave tick fonts alone.
    **kwargs
        Extra keyword arguments forwarded to ``savefig``. ``bbox_inches``
        is always passed as ``None`` by this function, so supplying it
        here raises ``TypeError``.
    """

    def _save_close_show(target: str) -> None:
        # Shared tail of both branches: write, release the figure, display.
        fig.savefig(target, bbox_inches=None, **kwargs)
        plt.close(fig)
        show(target, size=size, unit=unit)

    if adopt_orphan_tick_font:
        # Imported lazily, only when the feature is requested.
        from .layout import adopt_axis_label_font

        adopt_axis_label_font(fig)

    if image_path is not None:
        # presumably ensures the destination directory exists — verify
        # against ``_helpers.create_parent_path``.
        create_parent_path(image_path)
        _save_close_show(image_path)
        return

    # Only the file *name* is needed: reserve it, let the handle close, and
    # keep the file on disk (delete=False) so savefig can write to the path.
    with NamedTemporaryFile(suffix=".svg", delete=False) as scratch:
        scratch_name = scratch.name
    try:
        _save_close_show(scratch_name)
    finally:
        Path(scratch_name).unlink(missing_ok=True)