"""Colormap visualization functions.
Functions for classifying and plotting colormaps with category
badges and row-major layout.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import matplotlib as mpl
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
if TYPE_CHECKING:
from matplotlib.colors import Colormap
# ---------------------------------------------------------------------------
# Category badge colors (background, text)
# ---------------------------------------------------------------------------
# Keys are the category labels returned by ``classify_colormap``; each value
# is a pair of hex color strings ``(badge background, badge text)``.
# ``_plot_category`` looks the label up here and falls back to
# ``("#f5f5f5", "#333333")`` for a label that has no entry.
_CATEGORY_STYLE: dict[str, tuple[str, str]] = {
    "Single-Hue": ("#e3f2fd", "#1565c0"),
    "Multi-Hue": ("#e8f5e9", "#2e7d32"),
    "Diverging": ("#fff3e0", "#e65100"),
    "Cyclical": ("#f3e5f5", "#7b1fa2"),
    "Categorical": ("#fce4ec", "#c62828"),
}
# Explicit classification for the dartwork ``dc.*`` colormaps.
# ``classify_colormap`` consults this table first (keyed by ``cmap.name``)
# and returns the stored label without running any heuristic, so every
# value here must be one of the five category labels.
_CLASSIFICATION_OVERRIDES: dict[str, str] = {
    # Single-Hue
    "dc.obsidian": "Single-Hue",
    "dc.sapphire": "Single-Hue",
    "dc.emerald": "Single-Hue",
    "dc.ruby": "Single-Hue",
    "dc.amethyst": "Single-Hue",
    "dc.topaz": "Single-Hue",
    "dc.graphite": "Single-Hue",
    "dc.coral": "Single-Hue",
    # Multi-Hue
    "dc.aurora": "Multi-Hue",
    "dc.sunset_glow": "Multi-Hue",
    "dc.plasma_arc": "Multi-Hue",
    "dc.spring_bloom": "Multi-Hue",
    "dc.deep_sea": "Multi-Hue",
    "dc.autumn_leaf": "Multi-Hue",
    "dc.nebula_dust": "Multi-Hue",
    "dc.tropical_fruit": "Multi-Hue",
    # Diverging
    "dc.ice_fire": "Diverging",
    "dc.earth_sky": "Diverging",
    "dc.teal_rose": "Diverging",
    "dc.purple_lime": "Diverging",
    "dc.navy_gold": "Diverging",
    "dc.forest_brick": "Diverging",
    "dc.magenta_cyan": "Diverging",
    "dc.slate_orange": "Diverging",
    "dc.cool_warm": "Diverging",
    "dc.arctic_heat": "Diverging",
    "dc.frost_flame": "Diverging",
    "dc.water_fire": "Diverging",
    "dc.spring_autumn": "Diverging",
    "dc.summer_winter": "Diverging",
    "dc.electric_surge": "Diverging",
    "dc.neon_pulse": "Diverging",
    # Cyclical
    "dc.twilight_oklch": "Cyclical",
    "dc.phase_wheel": "Cyclical",
    "dc.color_wheel": "Cyclical",
    "dc.seasons": "Cyclical",
    "dc.day_night": "Cyclical",
    "dc.rainbow_cycle": "Cyclical",
    "dc.neon_wheel": "Cyclical",
    "dc.electric_cycle": "Cyclical",
    # Categorical
    "dc.vivid": "Categorical",
    "dc.lucid": "Categorical",
    "dc.chalk": "Categorical",
    "dc.vibrant": "Categorical",
    "dc.pastel": "Categorical",
    "dc.candy": "Categorical",
    "dc.pop": "Categorical",
    "dc.macaron": "Categorical",
}
[docs]
# Colormap names that are always reported as "Categorical" by name alone.
# NOTE(review): besides qualitative palettes this set also holds continuous
# maps such as "Spectral" and "hsv"; kept unchanged to preserve the existing
# grouping -- confirm whether that is intended.
_KNOWN_CATEGORICAL_CMAPS = frozenset(
    {
        "Accent",
        "Dark2",
        "Paired",
        "Pastel1",
        "Pastel2",
        "Set1",
        "Set2",
        "Set3",
        "tab10",
        "tab20",
        "tab20b",
        "tab20c",
        "Spectral",
        "prism",
        "hsv",
        "gist_rainbow",
        "rainbow",
        "nipy_spectral",
    }
)


def _folded_hue_range(hues: np.ndarray) -> float:
    """Return ``max(hues) - min(hues)``, folded to ``1 - range`` above 0.5."""
    hue_range = np.max(hues) - np.min(hues)
    if hue_range > 0.5:
        hue_range = 1 - hue_range
    return hue_range


def classify_colormap(cmap: Colormap) -> str:
    """Classify a colormap into one of five categories.

    The returned label is one of ``"Categorical"``, ``"Single-Hue"``,
    ``"Multi-Hue"``, ``"Diverging"`` or ``"Cyclical"``.

    Checks run in the following order and the first match wins:

    1. ``cmap.name`` has an entry in ``_CLASSIFICATION_OVERRIDES``.
    2. ``cmap.name`` is in the fixed set of names treated as categorical.
    3. Heuristics on 256 evenly spaced samples of the colormap:
       cyclical, categorical (flat plateaus, then evenly spaced jumps),
       diverging, and finally single- versus multi-hue.

    Parameters
    ----------
    cmap : matplotlib.colors.Colormap
        Colormap to classify. It is called with an array of positions in
        ``[0, 1]`` and must return one RGB(A) row per position.

    Returns
    -------
    str
        Category of the colormap.
    """
    # Name-based shortcuts come first so that no sampling work is done for
    # colormaps whose category is already known.
    name = getattr(cmap, "name", None)
    if name in _CLASSIFICATION_OVERRIDES:
        return _CLASSIFICATION_OVERRIDES[name]
    if name in _KNOWN_CATEGORICAL_CMAPS:
        return "Categorical"

    n_samples = 256
    half = n_samples // 2
    samples = cmap(np.linspace(0, 1, n_samples))[:, :3]

    # rgb_to_hsv works on (..., 3) arrays, so convert all rows at once.
    hsv_samples = mcolors.rgb_to_hsv(samples)
    hues = hsv_samples[:, 0]
    saturations = hsv_samples[:, 1]
    values = hsv_samples[:, 2]

    # Cyclical: both ends (almost) the same color, midpoint clearly not.
    start_end_diff = np.sqrt(np.sum((samples[0] - samples[-1]) ** 2))
    if start_end_diff < 0.01:
        mid_diff = np.sqrt(np.sum((samples[0] - samples[half]) ** 2))
        if mid_diff > 0.3:
            return "Cyclical"

    # RGB distance between each pair of neighbouring samples.
    color_diffs = np.sqrt(np.sum(np.diff(samples, axis=0) ** 2, axis=1))

    # Categorical by plateau detection: at least three runs of >= 3
    # consecutive near-zero steps, spread over more than 30% of the samples.
    plateau_indices = np.where(color_diffs < 0.001)[0]
    if len(plateau_indices) > 0:
        run_breaks = np.where(np.diff(plateau_indices) != 1)[0] + 1
        plateau_runs = np.split(plateau_indices, run_breaks)
        significant_plateaus = [run for run in plateau_runs if len(run) >= 3]
        if len(significant_plateaus) >= 3:
            plateau_positions = [np.mean(run) for run in significant_plateaus]
            position_range = max(plateau_positions) - min(plateau_positions)
            if position_range > n_samples * 0.3:
                return "Categorical"

    # Categorical by large jumps: a handful of big steps at fairly regular
    # intervals.
    large_color_jumps = np.where(color_diffs > 0.1)[0]
    if 3 < len(large_color_jumps) < n_samples // 8:
        jump_gaps = np.diff(large_color_jumps)
        if np.std(jump_gaps) < np.mean(jump_gaps) * 0.8:
            return "Categorical"

    # Diverging: the midpoint is clearly brighter or clearly darker than
    # both ends, and the two ends differ in hue.
    mid_value = values[half]
    start_value = values[0]
    end_value = values[-1]
    mid_is_peak = mid_value > start_value + 0.2 and mid_value > end_value + 0.2
    mid_is_trough = (
        mid_value < start_value - 0.2 and mid_value < end_value - 0.2
    )
    if mid_is_peak or mid_is_trough:
        end_gap = abs(hues[-1] - hues[0])
        if min(end_gap, 1 - end_gap) > 0.1:
            return "Diverging"

    # Sequential, preferred path: judge the hue spread only on samples
    # with saturation above 0.3, provided more than a quarter qualify.
    saturated = saturations > 0.3
    if np.count_nonzero(saturated) > n_samples // 4:
        if _folded_hue_range(hues[saturated]) < 0.01:
            return "Single-Hue"
        return "Multi-Hue"

    # Sequential, fallback path: use the hue spread over all samples.
    hue_range = _folded_hue_range(hues)
    # True when each step in the first half of ``values`` has the same sign
    # as (or one of the pair is zero with) the matching step in the second.
    halves_agree = np.all(
        np.diff(values[:half]) * np.diff(values[half:]) >= 0
    )
    if hue_range < 0.01 and halves_agree:
        return "Single-Hue"
    if hue_range > 0.01:
        return "Multi-Hue"

    # Tie-break on how uniform the wrap-aware hue steps are.
    hue_diffs = np.abs(np.diff(hues))
    hue_diffs = np.minimum(hue_diffs, 1 - hue_diffs)
    if np.std(hue_diffs) < 0.02:
        return "Single-Hue"
    return "Multi-Hue"
[docs]
def plot_colormaps(
    cmap_list: list[str] | list[Colormap] | None = None,
    ncols: int = 3,
    group_by_type: bool = True,
) -> list[Figure]:
    """Render colormap strips, optionally split into one figure per type.

    The figures are only created and returned; ``plt.show()`` is never
    called here, so displaying them is up to the caller.

    Parameters
    ----------
    cmap_list : list, optional
        Colormap names or colormap objects to draw. When omitted, every
        registered colormap whose name does not end in ``_r`` is used.
    ncols : int, optional
        Number of strip columns per figure, default 3.
    group_by_type : bool, optional
        When True, colormaps are classified with ``classify_colormap`` and
        each non-empty category gets its own figure. When False, all
        colormaps go into one figure.

    Returns
    -------
    list of matplotlib.figure.Figure
        One figure per non-empty category, in the order Single-Hue,
        Multi-Hue, Diverging, Cyclical, Categorical; or a one-element list
        when *group_by_type* is False.
    """
    from ..cmap import ensure_loaded as ensure_cmaps_loaded

    ensure_cmaps_loaded()

    if cmap_list is None:
        cmap_list = [
            name for name in mpl.colormaps.keys() if not name.endswith("_r")
        ]

    # Turn names into colormap objects; objects pass through untouched.
    resolved = []
    for entry in cmap_list:
        if isinstance(entry, str):
            entry = mpl.colormaps.get_cmap(entry)
        resolved.append(entry)

    # Two identical rows of a 0..1 ramp: the image shown in every strip.
    ramp = np.linspace(0, 1, 256)
    strip = np.vstack((ramp, ramp))

    if not group_by_type:
        return [_plot_flat(resolved, strip, ncols)]

    # ----- Group by category -----
    category_order = (
        "Single-Hue",
        "Multi-Hue",
        "Diverging",
        "Cyclical",
        "Categorical",
    )
    buckets: dict[str, list] = {label: [] for label in category_order}
    for cmap in resolved:
        buckets[classify_colormap(cmap)].append(cmap)

    # One figure per populated category, members ordered by lowercase name.
    return [
        _plot_category(
            sorted(buckets[label], key=lambda c: c.name.lower()),
            label,
            strip,
            ncols,
        )
        for label in category_order
        if buckets[label]
    ]
# ---------------------------------------------------------------------------
# Internal drawing helpers
# ---------------------------------------------------------------------------
def _plot_category(
    cmaps: list, category: str, gradient: np.ndarray, ncols: int
) -> Figure:
    """Draw a single category figure with badge header.

    The figure uses a grid of ``nrows + 1`` rows by *ncols* columns. The
    first (shorter) row is one axes spanning all columns that shows the
    category name and the number of colormaps; the remaining rows hold one
    colormap strip per cell, filled in row-major order.

    Parameters
    ----------
    cmaps : list
        Colormap objects to draw; each must expose ``.name``.
    category : str
        Category label shown in the header and used to look up the badge
        colors in ``_CATEGORY_STYLE``.
    gradient : numpy.ndarray
        Image passed to ``imshow`` for every strip.
    ncols : int
        Number of strip columns.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    nrows = (len(cmaps) + ncols - 1) // ncols
    figw = 6.4 * ncols / 1.5
    figh = 0.35 + 0.15 + (nrows + 1 + nrows * 0.1) * 0.44
    fig = plt.figure(figsize=(figw, figh))
    gs = plt.GridSpec(
        nrows + 1, ncols, figure=fig, height_ratios=[0.35] + [1] * nrows
    )

    # --- Category title with badge ---
    title_ax = fig.add_subplot(gs[0, :])
    bg_color, text_color = _CATEGORY_STYLE.get(category, ("#f5f5f5", "#333333"))
    title_ax.set_facecolor(bg_color)
    title_ax.text(
        0.5,
        0.5,
        category,
        fontsize=14,
        fontweight="bold",
        color=text_color,
        ha="center",
        va="center",
        transform=title_ax.transAxes,
    )
    # The count sits to the right of the centred name; the horizontal
    # offset grows with the length of the category label.
    title_ax.text(
        0.5 + len(category) * 0.012,
        0.5,
        f" ({len(cmaps)})",
        fontsize=10,
        color=text_color,
        alpha=0.7,
        ha="left",
        va="center",
        transform=title_ax.transAxes,
    )
    title_ax.set_axis_off()

    # --- Colormap strips (row-major order) ---
    for i, cmap in enumerate(cmaps):
        row, col = divmod(i, ncols)
        # Row 0 of the grid is the title, so strips start at grid row 1.
        ax = fig.add_subplot(gs[row + 1, col])
        ax.imshow(gradient, aspect="auto", cmap=cmap)
        ax.text(
            -0.01,
            0.5,
            cmap.name,
            va="center",
            ha="right",
            fontsize=10,
            color="#333333",
            transform=ax.transAxes,
        )
        ax.set_axis_off()

    # Trailing grid cells are simply left without axes. The previous code
    # added invisible placeholder axes, but counted the title as one cell
    # instead of a full row, so the placeholders were stacked on top of the
    # title and strip cells.

    fig.subplots_adjust(
        left=0.15 / ncols,
        right=0.99,
        top=1 - 0.2 / figh,
        bottom=0.1 / figh,
        hspace=0.15,
    )
    return fig
def _plot_flat(cmap_list: list, gradient: np.ndarray, ncols: int) -> Figure:
    """Draw all colormaps in a single figure without grouping.

    Colormaps are ordered by lowercase name and placed row-major into a
    grid of *ncols* columns. Grid cells left over after the last colormap
    are hidden.

    Parameters
    ----------
    cmap_list : list
        Colormap objects to draw; each must expose ``.name``. The list is
        not modified.
    gradient : numpy.ndarray
        Image passed to ``imshow`` for every strip.
    ncols : int
        Number of strip columns.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    # sorted() rather than list.sort() so the caller's list keeps its order.
    ordered = sorted(cmap_list, key=lambda c: c.name.lower())

    nrows = (len(ordered) + ncols - 1) // ncols
    figw = 6.4 * ncols / 1.5
    figh = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.44

    # squeeze=False always yields a 2-D array of axes, so the 1x1 and
    # single-row layouts need no special casing before flattening.
    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(figw, figh), squeeze=False
    )
    axs_flat = axs.flatten()  # C order, i.e. row-major

    for ax, cmap in zip(axs_flat, ordered):
        ax.imshow(gradient, aspect="auto", cmap=cmap)
        ax.text(
            -0.01,
            0.5,
            cmap.name,
            va="center",
            ha="right",
            fontsize=10,
            color="#333333",
            transform=ax.transAxes,
        )

    for ax in axs_flat:
        ax.set_axis_off()

    # Hide the unused cells at the end of the last row.
    for ax in axs_flat[len(ordered):]:
        ax.set_visible(False)

    # Single layout pass. An earlier second call with other margins was
    # dropped because these values replaced every one of its settings.
    fig.subplots_adjust(
        left=0.15 / ncols,
        right=0.99,
        top=1 - 0.2 / figh,
        bottom=0.1 / figh,
        hspace=0.15,
    )
    return fig