# Source code for dbs_annotator.utils.report_chart_utils

"""
Shared chart utilities for report generation.

Provides reusable functions for building matplotlib-based scale timeline
charts used by both the session and longitudinal exporters.  Centralising
the chart logic avoids code duplication and guarantees consistent styling
across report types.
"""

from __future__ import annotations

import logging
from io import BytesIO
from typing import Any

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from docx.document import Document as DocumentType
from docx.shared import Inches
from matplotlib import cm
from matplotlib.ticker import MaxNLocator

matplotlib.use("Agg")

logger = logging.getLogger(__name__)

# ── Colour constants (normalised 0-1 RGBA for matplotlib) ─────────
# Fill colours for the vertical bands that build_scales_chart draws behind
# the best (darker) and second-best (lighter) x-points, as ranked by the
# Aggregate Index. Each channel is a 0-255 value divided by 255; the fourth
# (alpha) channel is below 1.0, so both bands are partially transparent.
BEST_GREEN = (150 / 255, 210 / 255, 160 / 255, 160 / 255)  # best x-point band
SECOND_GREEN = (200 / 255, 235 / 255, 205 / 255, 130 / 255)  # second-best band


# ── Scale-target helpers ────────────────────────────────────────────


def parse_scale_targets(
    prefs: list[tuple[str, str, str, str, str]] | None,
) -> dict[str, dict[str, Any]]:
    """Turn raw user optimisation preferences into a per-scale target map.

    Each preference is ``(name, min, max, mode, custom_value)``. The bounds
    are coerced to floats (0.0 when unparsable) and swapped if given in
    reverse order. ``mode`` is matched case-insensitively after stripping:

    * ``"low"`` / ``"min"``   -> type ``"min"``, value = lower bound
    * ``"high"`` / ``"max"``  -> type ``"max"``, value = upper bound
    * ``"custom"``            -> type ``"custom"``, value = float(custom_value)
      (0.0 when unparsable)

    Entries with fewer than five fields, or with any other mode, are skipped.

    Args:
        prefs: preference tuples, or *None* for "no preferences".

    Returns:
        ``{scale_name: {"type": ..., "value": ..., "lower": ..., "upper": ...}}``
    """

    def _as_float(raw: Any) -> float:
        # Missing or malformed numbers fall back to 0.0 rather than raising.
        try:
            return float(raw)
        except (TypeError, ValueError):
            return 0.0

    result: dict[str, dict[str, Any]] = {}
    for entry in prefs or []:
        if len(entry) < 5:
            continue
        name, raw_min, raw_max, mode, raw_custom = entry

        lo, hi = _as_float(raw_min), _as_float(raw_max)
        if lo > hi:
            lo, hi = hi, lo

        kind = str(mode).strip().lower()
        if kind in ("low", "min"):
            target_type, target_value = "min", lo
        elif kind in ("high", "max"):
            target_type, target_value = "max", hi
        elif kind == "custom":
            target_type, target_value = "custom", _as_float(raw_custom)
        else:
            continue  # unknown optimisation mode: no target for this scale

        result[name] = {
            "type": target_type,
            "value": target_value,
            "lower": lo,
            "upper": hi,
        }
    return result
def compute_aggregate_index(
    scale_data: dict[str, dict[int, float]],
    all_points: list[int],
    scale_targets: dict[str, dict[str, Any]],
) -> dict[int, float]:
    """Score every x-point by how close its scales sit to their targets.

    Each observed value is normalised against its scale's *declared*
    ``[lower, upper]`` range (not against the observed data). With ``z`` the
    value's position inside that range, clamped to ``[0, 1]``:

    * ``"min"`` targets score ``1 - z`` (lower is better);
    * ``"max"`` targets score ``z`` (higher is better);
    * ``"custom"`` targets score ``1 - |value - target| / reach`` (clamped),
      where *reach* is the larger distance from the target to either bound.

    A degenerate range (``upper <= lower``), a zero reach, or an unrecognised
    target type yields a neutral 0.5. Scales with a target carry weight 1.0.
    Scales without a target contribute a neutral 0.5 at weight 0.5.

    Args:
        scale_data: ``{scale_name: {x_point: value}}``.
        all_points: x-axis keys to evaluate; output keys follow this order.
        scale_targets: mapping as produced by :func:`parse_scale_targets`.

    Returns:
        ``{x_point: weighted_mean_score}``, where 1.0 is best. Points at which
        no scale has a value are omitted.
    """

    def _unit_clamp(raw: float) -> float:
        return max(0.0, min(1.0, raw))

    def _target_score(info: dict[str, Any], observed: float) -> float:
        kind = info["type"]
        goal = info["value"]
        lo = float(info.get("lower", 0.0))
        hi = float(info.get("upper", 0.0))
        span = hi - lo
        if span <= 0:
            return 0.5
        z = _unit_clamp((observed - lo) / span)
        if kind == "min":
            return 1.0 - z
        if kind == "max":
            return z
        if kind == "custom":
            # Furthest any in-range value can be from the target.
            reach = max(abs(goal - lo), abs(hi - goal))
            if reach > 0:
                return 1.0 - _unit_clamp(abs(observed - goal) / reach)
            return 0.5
        return 0.5

    scores_by_point: dict[int, float] = {}
    for point in all_points:
        contributions: list[tuple[float, float]] = []  # (weight, score)
        for name, series in scale_data.items():
            if point not in series:
                continue
            observed = series[point]
            if name in scale_targets:
                contributions.append(
                    (1.0, _target_score(scale_targets[name], observed))
                )
            else:
                contributions.append((0.5, 0.5))

        if not contributions:
            continue
        total_weight = sum(weight for weight, _ in contributions)
        if total_weight > 0:
            scores_by_point[point] = (
                sum(weight * score for weight, score in contributions)
                / total_weight
            )
        else:
            scores_by_point[point] = 0.5
    return scores_by_point
def get_declared_scale_range(
    scale_targets: dict[str, dict[str, Any]],
) -> tuple[float, float] | None:
    """Return the envelope ``(min, max)`` spanning every declared scale range.

    Each target's ``lower``/``upper`` (default 0.0) is coerced to float and
    re-ordered if reversed. Targets whose bounds cannot be converted are
    ignored.

    Returns:
        ``(overall_min, overall_max)``, or *None* when there are no usable
        targets or the envelope has zero or negative width.
    """
    bounds: list[tuple[float, float]] = []
    for info in scale_targets.values():
        try:
            lo = float(info.get("lower", 0.0))
            hi = float(info.get("upper", 0.0))
        except (TypeError, ValueError):
            continue
        if hi < lo:
            lo, hi = hi, lo
        bounds.append((lo, hi))

    if not bounds:
        return None

    envelope_low = min(lo for lo, _ in bounds)
    envelope_high = max(hi for _, hi in bounds)
    if envelope_high <= envelope_low:
        return None
    return envelope_low, envelope_high
def find_best_and_second(
    index_vals: dict[int, float],
) -> tuple[int | None, int | None]:
    """Pick the x-points with the highest and second-highest index scores.

    Ties keep the mapping's insertion order: among equal scores, the key
    inserted first ranks higher.

    Returns:
        ``(best_x, second_best_x)``. Both are *None* for an empty mapping, and
        the second is *None* when only one point exists.
    """
    if not index_vals:
        return None, None
    best, *rest = sorted(index_vals, key=index_vals.__getitem__, reverse=True)
    runner_up = rest[0] if rest else None
    return best, runner_up
# ── Chart rendering ──────────────────────────────────────────────── # Matplotlib line styles matching the old set _LINE_STYLES = ["-", "--", ":", "-.", (0, (3, 1, 1, 1, 1, 1))]
def _get_colormap(name: str) -> Any:
    """Look up a matplotlib colormap by *name*, across matplotlib versions.

    ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed
    in 3.9, where calling it raises ``AttributeError``. Prefer the
    ``matplotlib.colormaps`` registry (matplotlib >= 3.5) and use the legacy
    accessor only as a fallback on older installs.
    """
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        return registry[name]
    return cm.get_cmap(name)


def _plot_aggregate_index(
    ax1: Any,
    scale_data: dict[str, dict[int, float]],
    all_x: list[int],
    scale_targets: dict[str, dict[str, Any]],
) -> tuple[Any, int | None, int | None]:
    """Draw the Aggregate Index as a thick black line on a right-hand twin axis.

    Returns:
        ``(ax2, best_x, second_x)``. All three are *None* when no index
        values could be computed, in which case no twin axis is created.
    """
    index_vals = compute_aggregate_index(scale_data, all_x, scale_targets)
    if not index_vals:
        return None, None, None

    ax2 = ax1.twinx()
    ix = sorted(index_vals)
    iy = [index_vals[x] for x in ix]
    ax2.plot(
        ix,
        iy,
        color="black",
        linewidth=3,
        marker="D",
        markersize=7,
        markerfacecolor="black",
        markeredgecolor="black",
        label="Aggregate Index",
        zorder=5,
    )
    ax2.set_ylim(0, 1)
    ax2.yaxis.set_label_position("right")
    ax2.yaxis.tick_right()
    ax2.set_ylabel(
        "Aggregate Index Score (best = 1.0)",
        fontsize=12,
        fontfamily="Arial",
        color="black",
        fontweight="bold",
    )
    ax2.tick_params(axis="y", labelsize=10)
    for tick_label in ax2.get_yticklabels():
        tick_label.set_fontweight("bold")
        tick_label.set_color("black")
    ax2.spines["right"].set_linewidth(2.0)
    ax2.spines["right"].set_color("black")

    best_x, second_x = find_best_and_second(index_vals)
    return ax2, best_x, second_x


def _apply_x_ticks(
    ax1: Any,
    x_ticks: list[tuple[int, str]] | None,
    rotate_x_ticks: bool,
) -> None:
    """Apply custom x ticks (with padded limits), or force integer-only ticks."""
    if x_ticks is not None:
        ax1.set_xticks([t[0] for t in x_ticks])
        tick_labels = [t[1] for t in x_ticks]
        if rotate_x_ticks:
            ax1.set_xticklabels(tick_labels, rotation=90, ha="center", fontsize=10)
        else:
            ax1.set_xticklabels(tick_labels, fontsize=10)

    if x_ticks:
        # Pad half a unit on each side so the outermost points and their
        # +/-0.35 highlight bands are not clipped at the plot edges.
        ax1.set_xlim(
            min(t[0] for t in x_ticks) - 0.5,
            max(t[0] for t in x_ticks) + 0.5,
        )
    else:
        # Block IDs are discrete integers; avoid fractional tick labels.
        ax1.xaxis.set_major_locator(MaxNLocator(integer=True))


def _figure_to_png(fig: Any, dpi: int) -> bytes:
    """Render *fig* to PNG bytes on a white background."""
    with BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight", facecolor="white")
        return buf.getvalue()


def build_scales_chart(
    scale_data: dict[str, dict[int, float]],
    scale_prefs: list[tuple[str, str, str, str, str]] | None,
    *,
    title: str = "",
    x_label: str = "X",
    y_label: str = "Scale Value",
    x_ticks: list[tuple[int, str]] | None = None,
    width: int = 1100,
    height: int = 520,
    show_general_index: bool = True,
    rotate_x_ticks: bool = False,
) -> bytes | None:
    """Build a scale-trend chart and return it as PNG bytes.

    The chart includes:

    * One line per scale, coloured from the ``Dark2`` colormap and cycling
      through ``_LINE_STYLES``. Missing x-points appear as gaps in the line.
    * A thick black Aggregate Index line on a right-hand 0-1 axis (when
      *show_general_index* is set and there are at least 2 scales).
    * Green vertical bands behind the best (darker) and second-best
      (lighter) x-points by Aggregate Index.
    * A compact single-row legend above the plot.

    When the preferences declare scale ranges, the left y-axis is limited to
    their overall envelope.

    Args:
        scale_data: ``{scale_name: {x_point: value}}``
        scale_prefs: user optimisation preferences (may be *None*)
        title: chart title text (empty string for no title)
        x_label: bottom-axis label
        y_label: left-axis label
        x_ticks: optional custom bottom-axis ticks as ``(position, label)``
        width: image width in px
        height: image height in px (150 px are added when rotating ticks)
        show_general_index: whether to draw the Aggregate Index line
        rotate_x_ticks: whether to rotate x-axis tick labels by 90 degrees

    Returns:
        PNG image bytes, or *None* when *scale_data* is empty or rendering
        fails (the failure is logged).
    """
    fig = None
    try:
        n_scales = len(scale_data)
        if n_scales == 0:
            return None

        dpi = 100
        # Rotated tick labels need extra vertical room below the axis.
        actual_height = height + 150 if rotate_x_ticks else height

        cmap_obj = _get_colormap("Dark2")
        colors = [cmap_obj(i % cmap_obj.N)[:3] for i in range(n_scales)]
        has_index = show_general_index and n_scales >= 2

        fig, ax1 = plt.subplots(figsize=(width / dpi, actual_height / dpi), dpi=dpi)
        fig.patch.set_facecolor("white")
        ax1.set_facecolor("white")

        # Plot every scale on the union of x-points. A point that a scale lacks
        # becomes NaN, which matplotlib renders as a gap in that line.
        all_x = sorted({x for pts in scale_data.values() for x in pts})
        xs_arr = np.array(all_x, dtype=float)
        scale_targets = parse_scale_targets(scale_prefs)

        for idx, (sname, pts) in enumerate(scale_data.items()):
            c = colors[idx]
            ys_arr = np.array([pts.get(x, float("nan")) for x in all_x], dtype=float)
            ax1.plot(
                xs_arr,
                ys_arr,
                color=c,
                linewidth=2,
                linestyle=_LINE_STYLES[idx % len(_LINE_STYLES)],
                marker="o",
                markersize=6,
                markerfacecolor=c,
                markeredgecolor=c,
                markeredgewidth=1,
                label=sname,
            )

        ax2 = None
        best_x: int | None = None
        second_x: int | None = None
        if has_index:
            ax2, best_x, second_x = _plot_aggregate_index(
                ax1, scale_data, all_x, scale_targets
            )

        # zorder=0 keeps the highlight bands behind lines and markers.
        if best_x is not None:
            ax1.axvspan(best_x - 0.35, best_x + 0.35, color=BEST_GREEN, zorder=0)
        if second_x is not None and second_x != best_x:
            ax1.axvspan(second_x - 0.35, second_x + 0.35, color=SECOND_GREEN, zorder=0)

        if title:
            ax1.set_title(title, fontsize=16, fontfamily="Arial", color="black")
        ax1.set_ylabel(y_label, fontsize=12, fontfamily="Arial", color="black")
        ax1.set_xlabel(x_label, fontsize=12, fontfamily="Arial", color="black")

        declared_range = get_declared_scale_range(scale_targets)
        if declared_range is not None:
            ax1.set_ylim(*declared_range)

        ax1.tick_params(axis="both", labelsize=10)
        ax1.grid(True, alpha=0.3)

        _apply_x_ticks(ax1, x_ticks, rotate_x_ticks)

        # Merge legend entries from both axes into one figure-level legend.
        handles, labels = ax1.get_legend_handles_labels()
        if ax2 is not None:
            handles2, labels2 = ax2.get_legend_handles_labels()
            handles = handles + handles2
            labels = labels + labels2
        if handles:
            fig.legend(
                handles,
                labels,
                loc="upper center",
                ncol=max(1, len(handles)),
                fontsize=9,
                frameon=True,
                facecolor="white",
                edgecolor=(0.6, 0.6, 0.6, 0.4),
                framealpha=0.9,
                bbox_to_anchor=(0.5, 1.0),
            )
        # Reserve the top strip of the figure for the legend.
        fig.tight_layout(rect=(0, 0, 1, 0.93))

        return _figure_to_png(fig, dpi)
    except Exception as exc:
        logger.exception("build_scales_chart failed: %s", exc)
        return None
    finally:
        # Always release the pyplot-managed figure, including on failure;
        # otherwise pyplot keeps it alive and memory grows per failed chart.
        if fig is not None:
            plt.close(fig)
def add_chart_to_doc(
    doc: DocumentType,
    png_bytes: bytes | None,
    *,
    heading: str | None = None,
    heading_level: int = 2,
    width_inches: float | None = None,
    fallback_message: str = "Chart generation error.",
) -> None:
    """Insert a PNG chart (or a fallback message) at the end of a Word document.

    Args:
        doc: python-docx Document instance.
        png_bytes: raw PNG bytes. *None* or empty bytes mean the chart could
            not be produced, and *fallback_message* is written instead.
        heading: optional heading text added above the chart.
        heading_level: heading level (default 2).
        width_inches: image width in inches. *None* uses the usable width
            (page width minus left/right margins) of the document's last
            section, with a minimum of 4 inches.
        fallback_message: paragraph text used when no image is available.
    """
    if heading:
        doc.add_heading(heading, level=heading_level)

    # Empty bytes are treated like a failed render: add_picture cannot parse them.
    if not png_bytes:
        doc.add_paragraph(fallback_message)
        return

    if width_inches is None:
        # New content is appended at the end of the body, which belongs to the
        # document's last section, so size the image against that section.
        section = doc.sections[-1]
        usable_emu = (
            int(section.page_width or 0)
            - int(section.left_margin or 0)
            - int(section.right_margin or 0)
        )
        # python-docx lengths are EMUs (914400 per inch), not twips.
        width_inches = max(4.0, usable_emu / 914400)

    with BytesIO(png_bytes) as img_buf:
        doc.add_picture(img_buf, width=Inches(width_inches))
    doc.add_paragraph()