from pathlib import Path
from shutil import copy2
from tempfile import NamedTemporaryFile
from xml.dom import minidom
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, SVG, display
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from matplotlib.transforms import ScaledTranslation
from scipy.optimize import OptimizeResult, minimize
def _create_parent_path_if_not_exists(path: str | Path) -> None:
    """
    Create the parent directory of ``path`` if it doesn't exist.

    Parameters
    ----------
    path : str or Path
        Path whose parent directory should exist afterwards.

    Raises
    ------
    FileExistsError
        If the parent path already exists but is not a directory.
    """
    # exist_ok=True makes the call idempotent and avoids the race between a
    # separate "exists?" check and the actual directory creation.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
[docs]
def set_decimal(ax: Axes, xn: int | None = None, yn: int | None = None) -> None:
    """
    Format tick labels with a fixed number of decimal places.

    For each axis whose precision is given, the current tick positions are
    re-applied and then labelled with fixed-point strings.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes object to modify.
    xn : int, optional
        Number of decimal places for x-axis tick labels. ``None`` leaves the
        x-axis untouched.
    yn : int, optional
        Number of decimal places for y-axis tick labels. ``None`` leaves the
        y-axis untouched.
    """

    def _as_fixed_point(positions, places: int) -> list[str]:
        # One fixed-point string per tick position.
        return [f"{position:.{places}f}" for position in positions]

    if xn is not None:
        x_positions = ax.get_xticks()
        # Pin the current positions before assigning fixed label strings.
        ax.set_xticks(x_positions)
        ax.set_xticklabels(_as_fixed_point(x_positions, xn))

    if yn is not None:
        y_positions = ax.get_yticks()
        ax.set_yticks(y_positions)
        ax.set_yticklabels(_as_fixed_point(y_positions, yn))
[docs]
def get_bounding_box(boxes: list) -> tuple[float, float, float, float]:
    """
    Compute the smallest box enclosing every box in ``boxes``.

    Parameters
    ----------
    boxes : list
        Box objects exposing ``p0`` (indexable lower-left corner), ``width``
        and ``height`` attributes.

    Returns
    -------
    tuple
        ``(min_x, min_y, bbox_width, bbox_height)`` of the enclosing box.
        With no boxes the extremes keep their initial infinite values.
    """
    # Each list starts with the same sentinel the running extreme would
    # start from, so the result matches a pairwise min/max fold exactly.
    lefts = [float("inf")]
    bottoms = [float("inf")]
    rights = [float("-inf")]
    tops = [float("-inf")]

    # Single pass over the input: collect all four edges of every box.
    for item in boxes:
        lefts.append(item.p0[0])
        bottoms.append(item.p0[1])
        rights.append(item.p0[0] + item.width)
        tops.append(item.p0[1] + item.height)

    left_edge = min(lefts)
    bottom_edge = min(bottoms)
    right_edge = max(rights)
    top_edge = max(tops)

    return (left_edge, bottom_edge, right_edge - left_edge, top_edge - bottom_edge)
[docs]
def simple_layout(
    fig: Figure,
    gs: GridSpec | None = None,
    margins: tuple[float, float, float, float] = (0.05, 0.05, 0.05, 0.05),
    bbox: tuple[float, float, float, float] = (0, 1, 0, 1),
    verbose: bool = False,
    gtol: float = 1e-2,
    bound_margin: float = 0.2,
    use_all_axes: bool = True,
    importance_weights: tuple[float, float, float, float] = (1, 1, 1, 1),
) -> OptimizeResult:
    """Apply simple layout to figure for given grid spec.

    The ``left``/``right``/``bottom``/``top`` parameters of the grid spec are
    optimized so that the combined tight bounding box of the axes matches the
    requested region of the figure, shrunk by the margins. The grid spec is
    left at the optimized values when the function returns.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure object.
    gs : matplotlib.gridspec.GridSpec, optional(default=None)
        Grid spec object. If None, use the grid spec of the first axes.
    margins : tuple(float, float, float, float), optional(default=(0.05, 0.05, 0.05, 0.05))
        Margins in inches, (left, right, bottom, top). They are converted to
        pixels with the figure DPI.
    bbox : tuple(float, float, float, float), optional(default=(0, 1, 0, 1))
        Bounding box in figure coordinates, (left, right, bottom, top).
    verbose : bool, optional(default=False)
        Currently unused; accepted for backward compatibility.
    gtol : float, optional(default=1e-2)
        Passed to the L-BFGS-B solver as its ``gtol`` option (tolerance on
        the projected gradient).
    bound_margin : float, optional(default=0.2)
        Width of the search interval for each grid spec parameter, measured
        inwards from the matching ``bbox`` edge.
    use_all_axes : bool, optional(default=True)
        If True, use all axes in the figure. If False, use only axes whose
        grid spec is ``gs``.
    importance_weights : tuple(float, float, float, float), optional(default=(1, 1, 1, 1))
        Weights applied to the four residuals, in the order
        (left edge, bottom edge, width, height).

    Returns
    -------
    result : scipy.optimize.OptimizeResult
        Optimization result.

    TODO
    ----
    - Upgrade bounds generation algorithm.
    - Readable code.
    """
    if gs is None:
        gs = fig.axes[0].get_gridspec()

    importance_weights = np.array(importance_weights)
    # Inches -> pixels, so margins share units with the tight bounding boxes.
    margins = np.array(margins) * fig.get_dpi()

    # Targets are (left, bottom, width, height) of the desired region. The
    # extent loses the margins on *both* sides: left + right for the width,
    # bottom + top for the height.
    frame = fig.bbox
    targets = np.array(
        [
            frame.width * bbox[0] + margins[0],
            frame.height * bbox[2] + margins[2],
            frame.width * (bbox[1] - bbox[0]) - margins[0] - margins[1],
            frame.height * (bbox[3] - bbox[2]) - margins[2] - margins[3],
        ]
    )
    scales = np.array([frame.width, frame.height, frame.width, frame.height])

    def _apply(x: np.ndarray) -> None:
        # Parameter order: left, right, bottom, top.
        gs.update(left=x[0], right=x[1], bottom=x[2], top=x[3])

    def fun(x: np.ndarray) -> float:
        _apply(x)
        if use_all_axes:
            axes = fig.axes
        else:
            axes = [ax for ax in fig.axes if ax.get_gridspec() is gs]
        # get_tightbbox() may return None (e.g. for invisible axes); such
        # axes do not contribute to the enclosing box.
        tight_boxes = [ax.get_tightbbox() for ax in axes]
        ax_bboxes = [box for box in tight_boxes if box is not None]
        values = np.array(get_bounding_box(ax_bboxes))
        # Residuals are normalized by the figure size before weighting.
        return np.square((values - targets) / scales * importance_weights).sum()

    # Order: left, right, bottom, top.
    bounds = [
        (bbox[0], bbox[0] + bound_margin),
        (bbox[1] - bound_margin, bbox[1]),
        (bbox[2], bbox[2] + bound_margin),
        (bbox[3] - bound_margin, bbox[3]),
    ]
    # Start from the midpoint of every interval. A gradient-free alternative
    # would be method="Nelder-Mead" with a relaxed options={"xatol": 1e-3}.
    result = minimize(
        fun,
        x0=np.array(bounds).mean(axis=1),
        bounds=bounds,
        method="L-BFGS-B",
        options={"gtol": gtol},
    )
    # The solver's last evaluation is not necessarily at the optimum, so set
    # the grid spec to the returned solution explicitly.
    _apply(result.x)
    return result
[docs]
def fs(n: int | float) -> float:
    """
    Offset the base font size by ``n``.

    Parameters
    ----------
    n : int or float
        Amount added to the ``font.size`` rcParam.

    Returns
    -------
    float
        ``plt.rcParams["font.size"] + n``.
    """
    base_size = plt.rcParams["font.size"]
    return base_size + n
[docs]
def fw(n: int) -> int:
    """
    Offset the base font weight by ``n`` steps of 100.

    The ``font.weight`` rcParam is used directly in an addition, so it must
    hold a numeric weight.

    Parameters
    ----------
    n : int
        Number of 100-unit steps added to the base font weight.

    Returns
    -------
    int
        ``plt.rcParams["font.weight"] + 100 * n``.
    """
    weight_step = 100
    base_weight = plt.rcParams["font.weight"]
    return base_weight + weight_step * n
def lw(n: int | float) -> float:
    """
    Offset the base line width by ``n``.

    Parameters
    ----------
    n : int or float
        Amount added to the ``lines.linewidth`` rcParam.

    Returns
    -------
    float
        ``plt.rcParams["lines.linewidth"] + n``.
    """
    base_width = plt.rcParams["lines.linewidth"]
    return base_width + n
[docs]
def mix_colors(
    color1: str | tuple[float, float, float],
    color2: str | tuple[float, float, float],
    alpha: float = 0.5,
) -> tuple[float, float, float]:
    """
    Blend two colors channel by channel.

    Parameters
    ----------
    color1 : color
        First color (any format accepted by matplotlib).
    color2 : color
        Second color (any format accepted by matplotlib).
    alpha : float, optional
        Weight of the first color; the second color gets ``1 - alpha``.

    Returns
    -------
    tuple
        RGB tuple of the mixed color.
    """
    first_rgb = mcolors.to_rgb(color1)
    second_rgb = mcolors.to_rgb(color2)
    second_weight = 1 - alpha

    # Weighted sum of each pair of channels.
    channels = [
        alpha * first + second_weight * second
        for first, second in zip(first_rgb, second_rgb, strict=False)
    ]
    return tuple(channels)
[docs]
def pseudo_alpha(
    color: str | tuple[float, float, float],
    alpha: float = 1.0,
    background: str | tuple[float, float, float] = "white",
) -> tuple[float, float, float]:
    """
    Emulate transparency by blending a color into a background.

    Parameters
    ----------
    color : color
        Color to apply pseudo-transparency to.
    alpha : float, optional
        Weight given to ``color`` in the blend; the background receives the
        remainder.
    background : color, optional
        Background color to mix with.

    Returns
    -------
    tuple
        RGB tuple of the resulting color.
    """
    # Delegates to mix_colors with `color` as the alpha-weighted input.
    blended = mix_colors(color1=color, color2=background, alpha=alpha)
    return blended
def cm2in(cm: float) -> float:
    """
    Express a length given in centimeters in inches.

    Parameters
    ----------
    cm : float
        Length in centimeters.

    Returns
    -------
    float
        The same length in inches (``cm / 2.54``).
    """
    centimeters_per_inch = 2.54
    inches = cm / centimeters_per_inch
    return inches
[docs]
def make_offset(x: float, y: float, fig: Figure) -> ScaledTranslation:
    """
    Build a translation transform from an offset given in points.

    Parameters
    ----------
    x : float
        X offset in points.
    y : float
        Y offset in points.
    fig : matplotlib.figure.Figure
        Figure whose ``dpi_scale_trans`` scales the translation.

    Returns
    -------
    matplotlib.transforms.ScaledTranslation
        Offset transform.
    """
    # Both offsets are divided by 72 before being handed to the transform.
    points_per_inch = 72
    shift = (x / points_per_inch, y / points_per_inch)
    return ScaledTranslation(*shift, fig.dpi_scale_trans)
[docs]
def show(image_path: str, size: int = 600, unit: str = "pt") -> None:
    """
    Display an SVG image with specified size.

    The ``width`` of the root element is set to ``size`` and its ``height``
    is scaled by the original height/width ratio (truncated to an integer).
    The resized markup is displayed through ``IPython.display.HTML``.

    Parameters
    ----------
    image_path : str
        Path to the SVG image.
    size : int, optional
        Desired width in specified units.
    unit : str, optional
        Unit for size ('pt', 'px', etc.), written after the new width and
        height values.

    Raises
    ------
    ValueError
        If the root element's ``width`` or ``height`` attribute has no
        parsable numeric part.
    """
    # Load the SVG and parse its markup.
    svg_obj = SVG(data=image_path)
    dom = minidom.parseString(svg_obj.data)
    root = dom.documentElement

    # Characters that may trail the number in a length such as "460.8pt".
    unit_chars = "abcdefghijklmnopqrstuvwxyz" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "%"

    def _numeric_length(attribute: str) -> float:
        # Parse only the numeric part, so the result does not depend on how
        # the number was formatted or on which unit suffix the file uses.
        text = root.getAttribute(attribute).strip()
        return float(text.rstrip(unit_chars))

    # Current size of the image.
    width = _numeric_length("width")
    height = _numeric_length("height")

    # Keep the aspect ratio when deriving the new height.
    desired_width = size
    aspect_ratio = height / width
    desired_height = int(desired_width * aspect_ratio)

    # Set the new size on the root element itself instead of text-replacing
    # in the whole document: this cannot miss the attribute because of its
    # textual form and cannot touch nested elements.
    root.setAttribute("width", f"{desired_width}{unit}")
    root.setAttribute("height", f"{desired_height}{unit}")

    # Display the resized SVG markup via HTML.
    display(HTML(root.toxml()))
[docs]
def save_and_show(
    fig: Figure,
    image_path: str | None = None,
    size: int = 600,
    unit: str = "pt",
    **kwargs,
) -> None:
    """
    Save a figure and display it.

    The figure is written with ``fig.savefig``, closed with ``plt.close``
    and then displayed with ``show``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save and display.
    image_path : str, optional
        Path to save the image. If None, a temporary ``.svg`` file is used
        and removed again once the image has been displayed.
    size : int, optional
        Display size.
    unit : str, optional
        Unit for size.
    **kwargs
        Additional arguments passed to savefig.
    """
    use_temp_file = image_path is None
    if use_temp_file:
        # Only a unique file name is needed here. delete=False keeps the
        # (empty) file after the handle is closed, so the name stays reserved
        # until savefig overwrites it; removal happens in `finally` below.
        with NamedTemporaryFile(suffix=".svg", delete=False) as tmp:
            image_path = tmp.name
    else:
        _create_parent_path_if_not_exists(image_path)

    try:
        fig.savefig(image_path, bbox_inches=None, **kwargs)
        plt.close(fig)
        show(image_path, size=size, unit=unit)
    finally:
        if use_temp_file:
            # Never leave the temporary SVG behind, even if saving or
            # displaying failed.
            Path(image_path).unlink(missing_ok=True)
def prompt_path(name: str) -> Path:
    """
    Locate a prompt guide file shipped next to this module.

    Parameters
    ----------
    name : str
        Name of the prompt guide ('layout-guide' or 'general-guide').

    Returns
    -------
    Path
        Path to ``asset/prompt/<name>.md`` under this module's directory.

    Raises
    ------
    ValueError
        If the prompt guide is not found.
    """
    module_dir = Path(__file__).parent
    candidate: Path = module_dir / f"asset/prompt/{name}.md"
    if candidate.exists():
        return candidate
    raise ValueError(f"Prompt guide not found: {name}")
def get_prompt(name: str) -> str:
    """
    Return the text of a prompt guide file.

    Parameters
    ----------
    name : str
        Name of the prompt guide ('layout-guide' or 'general-guide').

    Returns
    -------
    str
        Content of the prompt guide file, decoded as UTF-8.

    Raises
    ------
    ValueError
        If the prompt guide is not found.
    """
    return prompt_path(name).read_text(encoding="utf-8")
def list_prompts() -> list[str]:
    """
    Enumerate the prompt guides shipped next to this module.

    Returns
    -------
    list[str]
        Sorted names (file stems) of the ``*.md`` files in ``asset/prompt``;
        empty if that directory does not exist.
    """
    prompt_dir: Path = Path(__file__).parent / "asset/prompt"
    if prompt_dir.exists():
        guide_names = [guide_file.stem for guide_file in prompt_dir.glob("*.md")]
        guide_names.sort()
        return guide_names
    return []
def copy_prompt(name: str, destination: str | Path) -> Path:
    """
    Copy a prompt guide file to the specified destination.

    Parameters
    ----------
    name : str
        Name of the prompt guide ('layout-guide' or 'general-guide').
    destination : str or Path
        Where to copy the prompt file. An existing directory, or a
        non-existing path without a file suffix, is treated as a directory
        and the file is stored in it as ``<name>.md``. Any other path is
        used as the exact target file.

    Returns
    -------
    Path
        Path to the copied file.

    Raises
    ------
    ValueError
        If the prompt guide is not found.
    OSError
        May propagate from creating the parent directory or from the copy.

    Examples
    --------
    >>> import dartwork_mpl as dm
    >>>
    >>> # Copy to a directory (keeps original filename)
    >>> copied_path = dm.copy_prompt('layout-guide', '.cursor/rules/')
    >>> print(copied_path)
    .cursor/rules/layout-guide.md
    >>>
    >>> # Copy to a specific file path
    >>> copied_path = dm.copy_prompt('general-guide', '.cursor/rules/my-guide.md')
    >>> print(copied_path)
    .cursor/rules/my-guide.md
    """
    source = prompt_path(name)
    target = Path(destination)

    # Decide whether `destination` names a directory or the target file.
    if target.is_dir():
        treat_as_directory = True
    else:
        treat_as_directory = not (target.exists() or target.suffix)

    if treat_as_directory:
        target = target / f"{name}.md"

    # The parent directory must exist before the file can be copied into it.
    _create_parent_path_if_not_exists(target)
    copy2(source, target)
    return target