Source code for dotime.evaluation

"""Evaluation harness for DoTime benchmark suites.

This module ports the metric functions and aggregation helpers from the
Do-Over-Time-PFN evaluation code (``dotime/eval/metrics.py`` and the
``scripts/tscm_identifiability.py`` reference harness) into a single
dependency-light surface (torch + numpy only — R² is computed directly rather
than via scikit-learn so it stays in the core install).

**Public surface**

- metric functions: :func:`compute_rmse`, :func:`compute_mae`,
  :func:`compute_nmse`, :func:`compute_r2`.
- :func:`direction_accuracy` — sign-consistent accuracy, near-zero targets excluded.
- :func:`bootstrap_ci` — bootstrap mean/std/CI over per-sample values.
- :func:`evaluate` — run a baseline over a suite, aggregating pooled and
  per-structure metrics.
- :class:`Results` — holds the aggregated metrics with ``.summary()`` and
  ``.to_dict()``.
"""

from __future__ import annotations

from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import numpy as np
import torch

if TYPE_CHECKING:
    from dotime.baselines import Baseline
    from dotime.benchmarks import BenchmarkSuite

# Public API of this module. Helpers such as ``_aggregate`` and
# ``_DEFAULT_METRICS`` are intentionally not exported.
__all__ = [
    "DIR_ACC_EPS",
    "Results",
    "bootstrap_ci",
    "compute_mae",
    "compute_nmse",
    "compute_r2",
    "compute_rmse",
    "direction_accuracy",
    "evaluate",
]

# Near-zero targets are ambiguous for sign-based direction accuracy and are
# excluded from that metric (reported separately as ``n_excluded``).
# The threshold is absolute (a target is kept when ``|target| >= DIR_ACC_EPS``),
# so it is expressed in the same units as the targets and is not scale-invariant.
DIR_ACC_EPS = 0.1


# --------------------------------------------------------------------------- #
# Pointwise metrics
# --------------------------------------------------------------------------- #


def compute_rmse(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """Root mean squared error.

    Args:
        predictions: Predicted values.
        targets: Ground-truth values, combined elementwise with ``predictions``
            (standard torch broadcasting applies, so pass matching shapes).

    Returns:
        ``sqrt(mean((predictions - targets) ** 2))`` as a Python float; NaN for
        empty inputs.

    Integer, bool and half-precision inputs are promoted to ``float32`` before
    the reduction; ``float32``/``float64`` inputs are used as-is.
    """
    # torch.mean rejects integer dtypes and half precision accumulates poorly,
    # so compute in at least float32 (float64 stays float64).
    dtype = torch.promote_types(torch.result_type(predictions, targets), torch.float32)
    diff = predictions.to(dtype) - targets.to(dtype)
    return torch.sqrt(torch.mean(diff**2)).item()
def compute_mae(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """Mean absolute error.

    Args:
        predictions: Predicted values.
        targets: Ground-truth values, combined elementwise with ``predictions``
            (standard torch broadcasting applies, so pass matching shapes).

    Returns:
        ``mean(|predictions - targets|)`` as a Python float; NaN for empty inputs.

    Integer, bool and half-precision inputs are promoted to ``float32`` before
    the reduction; ``float32``/``float64`` inputs are used as-is.
    """
    # torch.mean rejects integer dtypes; promote to at least float32.
    dtype = torch.promote_types(torch.result_type(predictions, targets), torch.float32)
    diff = predictions.to(dtype) - targets.to(dtype)
    return torch.mean(torch.abs(diff)).item()
def compute_nmse(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """Normalized MSE: ``MSE / Var(targets)``.

    Equals 1.0 for a predict-the-mean baseline, <1.0 when better, >1.0 worse.
    ``Var`` is the population variance (``unbiased=False``).

    Integer, bool and half-precision inputs are promoted to ``float32`` first;
    ``float32``/``float64`` inputs are used as-is.

    Returns NaN when there are fewer than two targets or the variance is below
    ``1e-8``.
    """
    if targets.numel() < 2:
        return float("nan")
    # torch.mean / torch.var reject integer dtypes and half precision
    # accumulates poorly, so work in at least float32.
    dtype = torch.promote_types(torch.result_type(predictions, targets), torch.float32)
    predictions = predictions.to(dtype)
    targets = targets.to(dtype)
    mse = torch.mean((predictions - targets) ** 2)
    var = torch.var(targets, unbiased=False)
    if var < 1e-8:
        return float("nan")
    return (mse / var).item()
def compute_r2(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """Coefficient of determination, ``1 - SS_res / SS_tot``.

    Computed directly (no scikit-learn) so it stays in the core install.
    Integer, bool and half-precision inputs are promoted to ``float32``;
    ``float64`` inputs keep their precision (they are not downcast, which could
    otherwise merge distinct large-magnitude targets and zero out ``SS_tot``).

    Returns NaN when the total sum of squares is below ``1e-12``.
    """
    dtype = torch.promote_types(torch.result_type(predictions, targets), torch.float32)
    targets = targets.to(dtype)
    predictions = predictions.to(dtype)
    ss_res = torch.sum((targets - predictions) ** 2)
    ss_tot = torch.sum((targets - targets.mean()) ** 2)
    if ss_tot < 1e-12:
        return float("nan")
    return (1.0 - ss_res / ss_tot).item()
def direction_accuracy(
    preds: torch.Tensor, targets: torch.Tensor, eps: float = DIR_ACC_EPS
) -> dict[str, float | int]:
    """Fraction of confidently-signed targets whose sign the prediction matches.

    A target participates only when ``|target| >= eps``; smaller targets are
    too close to zero for their sign to be meaningful and are counted in
    ``n_excluded`` instead. A prediction of exactly zero never matches.

    Returns:
        ``{"accuracy": ..., "n_valid": ..., "n_excluded": ...}``. ``accuracy``
        is NaN when ``preds`` is empty or no target clears ``eps``.
    """
    total = preds.numel()
    if total == 0:
        return {"accuracy": float("nan"), "n_valid": 0, "n_excluded": 0}

    keep = targets.abs() >= eps
    kept = int(keep.sum().item())
    dropped = int(total - kept)

    if not kept:
        accuracy = float("nan")
    else:
        hits = preds[keep].sign() == targets[keep].sign()
        accuracy = hits.float().mean().item()
    return {"accuracy": accuracy, "n_valid": kept, "n_excluded": dropped}
def bootstrap_ci(
    values: Iterable[float], n: int = 1000, alpha: float = 0.05, seed: int = 0
) -> tuple[float, float, float, float]:
    """Bootstrap ``(mean, std, ci_low, ci_high)`` over per-sample values.

    ``None`` and NaN entries are dropped first. NaN is recognised for any
    numeric scalar type, including ``np.float32``, not just Python ``float``.
    ``mean`` and ``std`` describe the remaining values; ``std`` is the
    population standard deviation (``ddof=0``). ``ci_low`` and ``ci_high`` are
    percentile-method bounds at confidence ``1 - alpha``. They come from ``n``
    resampled means drawn with a generator seeded by ``seed``, so results are
    reproducible.

    Returns NaNs for an empty input; a degenerate ``(v, 0, v, v)`` for a single
    value.

    Raises:
        ValueError: if ``n < 1`` while at least two values remain to resample.
    """
    arr = np.asarray([v for v in values if v is not None], dtype=np.float64)
    # Filter NaN after conversion so every numeric NaN is caught; np.float32
    # (unlike np.float64) is not a ``float`` subclass.
    arr = arr[~np.isnan(arr)]
    if arr.size == 0:
        return float("nan"), float("nan"), float("nan"), float("nan")
    if arr.size == 1:
        v = float(arr[0])
        return v, 0.0, v, v
    if n < 1:
        raise ValueError(f"n must be >= 1 bootstrap resamples, got {n}")
    rng = np.random.default_rng(seed)
    # Each row is one resample (with replacement) of the same size as ``arr``.
    idx = rng.integers(0, arr.size, size=(n, arr.size))
    boot_means = arr[idx].mean(axis=1)
    ci_low = float(np.quantile(boot_means, alpha / 2))
    ci_high = float(np.quantile(boot_means, 1 - alpha / 2))
    return float(arr.mean()), float(arr.std()), ci_low, ci_high
# --------------------------------------------------------------------------- #
# Aggregated results container
# --------------------------------------------------------------------------- #
@dataclass
class Results:
    """Aggregated evaluation results for one baseline on one suite."""

    suite: str  # suite name
    baseline: str  # baseline label
    n_episodes: int  # number of episodes evaluated
    n_queries: int  # number of pooled predictions
    pooled: dict[str, float]  # metric name -> value over all pooled queries
    # structure label -> (metric name -> value) restricted to that structure
    per_structure: dict[str, dict[str, float]] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """JSON-friendly view of the results.

        The metric dicts are copied, so mutating the returned mapping does not
        alter this object. NaN metric values are kept as-is; ``json.dumps``
        writes them as the non-standard ``NaN`` token.
        """
        return {
            "suite": self.suite,
            "baseline": self.baseline,
            "n_episodes": self.n_episodes,
            "n_queries": self.n_queries,
            "pooled": dict(self.pooled),
            "per_structure": {k: dict(v) for k, v in self.per_structure.items()},
        }

    def summary(self) -> str:
        """Human-readable results table.

        Columns are ``rmse``, ``mae``, ``nmse``, ``r2`` and ``dir_acc``. Any
        other metric names present in :attr:`pooled` (e.g. from custom
        metrics) follow in sorted order. Missing values print as ``nan``.
        Values that cannot be formatted as floats (e.g. ``None``) print via
        ``str`` instead of raising.
        """
        lines = [
            f"Suite: {self.suite}",
            f"Baseline: {self.baseline}",
            f"Episodes: {self.n_episodes} Queries: {self.n_queries}",
            "",
        ]
        default_cols = ["rmse", "mae", "nmse", "r2", "dir_acc"]
        extra_cols = sorted(k for k in self.pooled if k not in default_cols)
        cols = default_cols + extra_cols
        header = f"{'group':<22}" + "".join(f"{c:>10}" for c in cols)
        lines.append(header)
        lines.append("-" * len(header))

        def _cell(value: object) -> str:
            # EAFP: numeric values (incl. numpy scalars and 0-d tensors) get 4
            # decimals; anything format() rejects falls back to its str().
            try:
                return f"{value:>10.4f}"
            except (TypeError, ValueError):
                return f"{value!s:>10}"

        def _row(name: str, m: dict[str, float]) -> str:
            return f"{name:<22}" + "".join(_cell(m.get(c, float("nan"))) for c in cols)

        lines.append(_row("pooled", self.pooled))
        lines.extend(_row(s, self.per_structure[s]) for s in sorted(self.per_structure))
        return "\n".join(lines)
# --------------------------------------------------------------------------- #
# Evaluation loop
# --------------------------------------------------------------------------- #

# Metric name -> fn(predictions, targets) -> float; used by ``evaluate`` when no
# ``metrics`` mapping is supplied.
_DEFAULT_METRICS: dict[str, Callable[[torch.Tensor, torch.Tensor], float]] = dict(
    rmse=compute_rmse,
    mae=compute_mae,
    nmse=compute_nmse,
    r2=compute_r2,
)


def _aggregate(
    preds: torch.Tensor,
    targets: torch.Tensor,
    metrics: dict[str, Callable[[torch.Tensor, torch.Tensor], float]],
) -> dict[str, float]:
    """Apply each metric to ``(preds, targets)`` and append direction accuracy.

    The result has one entry per key of ``metrics`` (in iteration order),
    followed by ``dir_acc``. A caller metric that is also named ``dir_acc`` is
    overwritten by it.
    """
    results: dict[str, float] = {}
    for metric_name, metric_fn in metrics.items():
        results[metric_name] = metric_fn(preds, targets)
    results["dir_acc"] = direction_accuracy(preds, targets)["accuracy"]
    return results
def evaluate(
    model: Baseline,
    suite: BenchmarkSuite,
    metrics: dict[str, Callable[[torch.Tensor, torch.Tensor], float]] | None = None,
) -> Results:
    """Evaluate a baseline over every episode of a suite.

    Calls ``model.predict(episode)`` for each episode, pools predictions and
    ground-truth targets (``episode.y_true``) across all queries, and reports
    pooled and per-structure metrics.

    Args:
        model: Object exposing ``predict(episode)``. Its ``name`` attribute
            (falling back to the class name) labels the results.
        suite: Iterable of episodes with a ``meta.name``. Each episode is read
            for ``y_true``, ``structure`` (``None`` skips per-structure grouping)
            and ``scm_id`` (used in error messages only).
        metrics: Mapping of metric name to ``fn(predictions, targets) -> float``.
            ``None`` selects the default RMSE/MAE/NMSE/R² set. An explicit empty
            mapping is respected. ``dir_acc`` is always added.

    Returns:
        :class:`Results` with pooled metrics over every query and metrics per
        episode ``structure``.

    Raises:
        ValueError: if the baseline returns a different number of predictions
            than an episode has queries, or if the suite yields no episodes.
    """
    if metrics is None:
        # ``is None`` rather than truthiness: ``{}`` means "no extra metrics",
        # not "use the defaults".
        metrics = _DEFAULT_METRICS

    all_preds: list[torch.Tensor] = []
    all_targets: list[torch.Tensor] = []
    by_struct: dict[str, list[tuple[torch.Tensor, torch.Tensor]]] = {}
    n_episodes = 0
    for ep in suite:
        # detach() avoids keeping each episode's autograd graph alive in the
        # pooled lists. cpu() puts device-resident predictions next to the
        # targets so they can be subtracted and concatenated.
        pred = (
            torch.as_tensor(model.predict(ep), dtype=torch.float32)
            .detach()
            .cpu()
            .reshape(-1)
        )
        target = torch.as_tensor(ep.y_true, dtype=torch.float32).detach().cpu().reshape(-1)
        if pred.numel() != target.numel():
            raise ValueError(
                f"baseline {getattr(model, 'name', model)!r} returned {pred.numel()} "
                f"predictions for {target.numel()} queries in episode {ep.scm_id}"
            )
        all_preds.append(pred)
        all_targets.append(target)
        if ep.structure is not None:
            by_struct.setdefault(ep.structure, []).append((pred, target))
        n_episodes += 1

    if not all_preds:
        raise ValueError(f"suite {suite.meta.name!r} contains no episodes")

    preds = torch.cat(all_preds)
    targets = torch.cat(all_targets)
    per_structure = {
        struct: _aggregate(
            torch.cat([p for p, _ in pairs]),
            torch.cat([t for _, t in pairs]),
            metrics,
        )
        for struct, pairs in by_struct.items()
    }
    return Results(
        suite=suite.meta.name,
        baseline=getattr(model, "name", type(model).__name__),
        n_episodes=n_episodes,
        n_queries=int(preds.numel()),
        pooled=_aggregate(preds, targets, metrics),
        per_structure=per_structure,
    )