"""Reference baselines for DoTime benchmark suites.
A small registry maps baseline *names* to constructors, so the CLI and the
evaluation harness can request a baseline by string (mirroring the
``BASELINE_STRING_TO_CLASS`` table in the original ``tscm_identifiability.py``).
**Public surface**
- :class:`Baseline` — the predict interface every baseline implements.
- :func:`available` — list registered baseline names.
- :func:`get` — instantiate a baseline by name.
- :func:`register` — decorator to add a baseline to the registry.
Implemented: the trivial baselines (``Zero``, ``Mean``/TrajMean, ``AR1``,
``VAR-OLS``), the classical structural baselines (``BackDoorOLS``, ``IV2SLS``),
``Oracle`` (stored ground truth), and ``DoOverTimePFN`` (checkpoint-backed, the
``[models]`` extra). ``PCMCI+`` / ``BayesianITS`` / ``Chronos`` require the
``[baselines]`` extra and raise an actionable error until that dependency and
their wiring are present.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, ClassVar, Protocol, runtime_checkable
import numpy as np
import torch
if TYPE_CHECKING:
    # Needed by static type checkers only; ``TYPE_CHECKING`` is False at
    # runtime, so this import is never executed when the module is loaded.
    from dotime.benchmarks import Episode
__all__ = ["Baseline", "available", "get", "register"]
# --------------------------------------------------------------------------- #
# Interface
# --------------------------------------------------------------------------- #
@runtime_checkable
class Baseline(Protocol):
    """Structural interface shared by every baseline.

    A conforming object exposes a ``name`` attribute and a ``predict`` method.
    ``predict`` returns a 1-D tensor aligned with ``episode.query_target`` /
    ``episode.query_time`` — one predicted value per query.
    """

    # Identifier of the baseline; the built-in baselines set it to the same
    # string they are registered under.
    name: str

    def predict(self, episode: Episode) -> torch.Tensor:
        """Return the predicted interventional outcome for each query of ``episode``."""
# --------------------------------------------------------------------------- #
# Registry
# --------------------------------------------------------------------------- #
# Maps each baseline name to the callable that builds it.
_REGISTRY: dict[str, Callable[..., Baseline]] = {}


def register(name: str) -> Callable[[Callable[..., Baseline]], Callable[..., Baseline]]:
    """Return a decorator that files a baseline class or factory under ``name``.

    The decorated object is handed back unchanged, so it stays usable directly.

    Raises:
        ValueError: when the decorator is applied and ``name`` is already taken.
    """

    def _bind(factory: Callable[..., Baseline]) -> Callable[..., Baseline]:
        if name not in _REGISTRY:
            _REGISTRY[name] = factory
            return factory
        raise ValueError(f"baseline {name!r} is already registered")

    return _bind
def available() -> list[str]:
    """List every registered baseline name in sorted order."""
    names = list(_REGISTRY)
    names.sort()
    return names
def get(name: str, **kwargs: object) -> Baseline:
    """Build the baseline registered as ``name``.

    Keyword arguments are passed through to the registered constructor.

    Raises:
        KeyError: if ``name`` is not registered; the message lists the known names.
    """
    if name in _REGISTRY:
        factory = _REGISTRY[name]
        return factory(**kwargs)
    raise KeyError(f"unknown baseline {name!r}; available: {available()}")
# --------------------------------------------------------------------------- #
# Trivial baselines (fully implemented)
# --------------------------------------------------------------------------- #
@register("Zero")
class ZeroBaseline:
"""Predicts zero for every query. Sanity-check lower bound."""
name = "Zero"
def predict(self, episode: Episode) -> torch.Tensor:
return torch.zeros(episode.query_target.numel())
def _pre_onset_index(episode: Episode) -> int:
    """Return the earliest intervention time (the onset step).

    When the episode lists no intervention times, the number of rows of
    ``episode.x_obs`` is returned instead.
    """
    intervention_times = episode.intervention.times
    if intervention_times:
        return min(intervention_times)
    return episode.x_obs.shape[0]
@register("Mean")
class MeanBaseline:
"""Predicts the pre-intervention mean of the queried variable (a.k.a. TrajMean)."""
name = "Mean"
def predict(self, episode: Episode) -> torch.Tensor:
onset = _pre_onset_index(episode)
preds = []
for q in range(episode.query_target.numel()):
var = int(episode.query_target[q])
pre = episode.x_obs[:onset, var]
preds.append(pre.mean() if pre.numel() else episode.x_obs[:, var].mean())
return torch.stack(preds)
@register("AR1")
class AR1Baseline:
"""Predicts the last pre-intervention value of the queried variable."""
name = "AR1"
def predict(self, episode: Episode) -> torch.Tensor:
onset = _pre_onset_index(episode)
preds = []
for q in range(episode.query_target.numel()):
var = int(episode.query_target[q])
last = max(0, min(onset, episode.x_obs.shape[0]) - 1)
preds.append(episode.x_obs[last, var])
return torch.stack(preds)
@register("VAR-OLS")
class VAROLSBaseline:
"""Linear vector-autoregression fit by OLS on the observational trajectory.
A genuinely causal-naive baseline: it forecasts the queried variable from
its own and others' lagged values, ignoring the intervention semantics.
"""
name = "VAR-OLS"
def __init__(self, lag: int = 3):
self.lag = lag
def predict(self, episode: Episode) -> torch.Tensor:
x = episode.x_obs.detach().cpu().numpy() # (T, N)
coef, mean = self._fit(x)
preds = []
for q in range(episode.query_target.numel()):
var = int(episode.query_target[q])
# One-step-ahead prediction from the tail of the trajectory.
hist = x[-self.lag :].reshape(-1)
yhat = mean[var] + coef[var] @ (hist - np.tile(mean, self.lag))
preds.append(float(yhat))
return torch.tensor(preds, dtype=torch.float32)
def _fit(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
t, _n = x.shape
mean = x.mean(axis=0)
xc = x - mean
rows, targets = [], []
for s in range(self.lag, t):
rows.append(xc[s - self.lag : s].reshape(-1))
targets.append(xc[s])
a = np.asarray(rows) # (T-lag, lag*N)
b = np.asarray(targets) # (T-lag, N)
# Ridge-stabilised least squares: coef has shape (N, lag*N).
gram = a.T @ a + 1e-3 * np.eye(a.shape[1])
coef = np.linalg.solve(gram, a.T @ b).T
return coef, mean
# --------------------------------------------------------------------------- #
# Classical structural baselines (correct adjustment per graph)
# --------------------------------------------------------------------------- #
def _ols_fit(design: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Solve the ridge-regularised normal equations of ``target ~ [1, design]``.

    The returned vector holds the intercept first, then one coefficient per
    design column. ``1e-6`` times the identity is added to the Gram matrix
    before solving.
    """
    intercept = np.ones(len(design))
    features = np.column_stack([intercept, design])
    n_coef = features.shape[1]
    lhs = features.T @ features + 1e-6 * np.eye(n_coef)
    rhs = features.T @ target
    return np.linalg.solve(lhs, rhs)
@register("BackDoorOLS")
class BackDoorOLSBaseline:
"""Linear back-door adjustment: E[Y_t | do(A=v)] = E_X[ E[Y_t | A=v, X, Y_{t-1}] ].
Fits an OLS outcome model ``Y_t ~ A_t + X_t + Y_{t-1}`` on the pre-intervention
observational data (X = the adjustment set, i.e. all variables other than the
treatment A and outcome Y), then plugs the intervention value for A and
averages over the observed confounder distribution. Applicable to the
back-door family; on other structures it falls back to the pre-intervention
outcome mean.
"""
name = "BackDoorOLS"
_BACK_DOOR: ClassVar[set[str]] = {"back_door", "observed_confounder", "confounder_mediator"}
def predict(self, episode: Episode) -> torch.Tensor:
x = episode.x_obs.detach().cpu().numpy()
t_len, n = x.shape
a = episode.intervention.targets[0] if episode.intervention.targets else 0
onset = min(episode.intervention.times) if episode.intervention.times else t_len
preds = []
for q in range(episode.query_target.numel()):
y = int(episode.query_target[q])
adj = [v for v in range(n) if v not in (a, y)]
fit_end = max(2, min(onset, t_len))
if episode.structure not in self._BACK_DOOR or fit_end < 4:
preds.append(float(x[:fit_end, y].mean()))
continue
# Design over t in [1, fit_end): [A_t, X_t..., Y_{t-1}] -> Y_t
a_t = x[1:fit_end, a]
x_t = x[1:fit_end, adj] if adj else np.empty((fit_end - 1, 0))
y_prev = x[0 : fit_end - 1, y]
design = np.column_stack([a_t, x_t, y_prev])
coef = _ols_fit(design, x[1:fit_end, y])
# Predict at do(A = intervention value), averaging over observed X rows.
a_val = (
float(episode.intervention.values)
if isinstance(episode.intervention.values, (int, float))
else float(a_t.mean())
)
x_rows = x[1:fit_end, adj] if adj else np.empty((fit_end - 1, 0))
yhat = (
coef[0]
+ coef[1] * a_val
+ (x_rows @ coef[2 : 2 + len(adj)] if adj else 0.0)
+ coef[2 + len(adj)] * y_prev
)
preds.append(float(np.mean(yhat)))
return torch.tensor(preds, dtype=torch.float32)
@register("IV2SLS")
class IV2SLSBaseline:
"""Two-stage least squares for the instrumental-variable structure.
Stage 1 regresses the treatment on the instrument(s) ``Z``; stage 2 regresses
the outcome on the fitted treatment. The intervention effect is the stage-2
treatment coefficient: ``E[Y | do(A=v)] = beta_0 + beta_A * v``. On non-IV
structures it falls back to the pre-intervention outcome mean.
"""
name = "IV2SLS"
def predict(self, episode: Episode) -> torch.Tensor:
x = episode.x_obs.detach().cpu().numpy()
t_len, n = x.shape
a = episode.intervention.targets[0] if episode.intervention.targets else 0
onset = min(episode.intervention.times) if episode.intervention.times else t_len
fit_end = max(2, min(onset, t_len))
preds = []
for q in range(episode.query_target.numel()):
y = int(episode.query_target[q])
instruments = [v for v in range(n) if v not in (a, y)]
if episode.structure != "instrumental_variable" or fit_end < 4 or not instruments:
preds.append(float(x[:fit_end, y].mean()))
continue
z = x[:fit_end, instruments]
a_obs = x[:fit_end, a]
y_obs = x[:fit_end, y]
# Stage 1: A ~ Z ; Stage 2: Y ~ Ahat -> slope is the IV effect estimate.
s1 = _ols_fit(z, a_obs)
a_hat = s1[0] + z @ s1[1:]
# Weak-instrument guard: 2SLS is unreliable when Z explains little of A.
denom = float(np.var(a_obs))
stage1_r2 = float(np.var(a_hat)) / denom if denom > 1e-8 else 0.0
if stage1_r2 < 0.1:
preds.append(float(y_obs.mean()))
continue
s2 = _ols_fit(a_hat.reshape(-1, 1), y_obs)
beta_a = s2[1]
a_val = (
float(episode.intervention.values)
if isinstance(episode.intervention.values, (int, float))
else float(a_obs.mean())
)
# Centered prediction: baseline outcome + effect of moving A from its
# observed mean to the intervention value (robust to extrapolation).
preds.append(float(y_obs.mean() + beta_a * (a_val - a_obs.mean())))
return torch.tensor(preds, dtype=torch.float32)
# --------------------------------------------------------------------------- #
# Model-backed baselines (templates — wire to real implementations)
# --------------------------------------------------------------------------- #
@register("Oracle")
class OracleBaseline:
"""Ground-truth SCM rollout. Upper bound on synthetic suites only.
TODO(consolidate): the generating SCM is available at suite-build time;
persist the true counterfactual target into Episode.metadata (or compute it
from a stored SCM handle) and return it here. On suites without a stored
oracle this should raise a clear error rather than guess.
"""
name = "Oracle"
def predict(self, episode: Episode) -> torch.Tensor:
if "y_oracle" in episode.metadata:
return torch.as_tensor(episode.metadata["y_oracle"], dtype=torch.float32)
# If y_true carries the exact counterfactual for synthetic suites, use it.
if episode.y_true is not None and episode.y_true.numel():
return episode.y_true.float()
raise RuntimeError("Oracle baseline requires a stored ground-truth target")
@register("PCMCI+")
class PCMCIBaseline:
"""PCMCI+ causal discovery (tigramite) + linear effect estimate.
Requires the ``baselines`` extra: ``pip install 'dotime[baselines]'``.
TODO(consolidate): run PCMCI+ to recover the lagged graph, then estimate the
interventional effect by linear adjustment on the discovered parents.
"""
name = "PCMCI+"
def __init__(self, lag: int = 3, alpha: float = 0.05):
try:
import tigramite # noqa: F401
except ModuleNotFoundError as exc: # pragma: no cover
raise ImportError(
"PCMCI+ baseline needs the 'baselines' extra: pip install 'dotime[baselines]'"
) from exc
self.lag = lag
self.alpha = alpha
def predict(self, episode: Episode) -> torch.Tensor:
raise NotImplementedError("wire PCMCIBaseline.predict to tigramite + adjustment")
@register("BayesianITS")
class BayesianPiecewiseITSBaseline:
"""Bayesian piecewise interrupted-time-series reference (CausalPy).
Intended for dot-RegimeSwitch-v1, where a 2-regime episode reduces to a
classic ABA ITS design.
Requires the ``baselines`` extra.
TODO(consolidate): fit a CausalPy InterruptedTimeSeries on the pre/post
split implied by the intervention window; return the posterior-mean
counterfactual at the query time.
"""
name = "BayesianITS"
def __init__(self) -> None:
try:
import causalpy # noqa: F401
except ModuleNotFoundError as exc: # pragma: no cover
raise ImportError(
"Bayesian ITS baseline needs the 'baselines' extra: pip install 'dotime[baselines]'"
) from exc
def predict(self, episode: Episode) -> torch.Tensor:
raise NotImplementedError("wire BayesianPiecewiseITSBaseline.predict to CausalPy")
@register("Chronos")
class ChronosObservationalBaseline:
"""Chronos forecaster used observationally (intervention-unaware).
TODO(consolidate): reuse the existing `Chronos2Observational` wrapper from
the original baselines module rather than re-implementing it.
"""
name = "Chronos"
def predict(self, episode: Episode) -> torch.Tensor:
raise NotImplementedError("adapt Chronos2Observational into this interface")
# Integer code written into the batch for each intervention-type string.
_INT_TYPE_CODE = {"hard": 0, "soft": 1, "time_varying": 2}


def _episode_to_batch(episode: Episode, n_max: int, device: str) -> dict:
    """Convert a released Episode into the model's normalized, padded batch.

    Mirrors ``ExtendedDoTime.generate_sample``: causal masking (zero
    ``x_obs`` from the intervention onset), per-variable normalization over the
    pre-intervention window, and the intervention/query field encoding. Returns a
    batch of size 1 with the normalization stats so predictions can be mapped back
    to the raw scale.

    Only the first query (``query_target[0]`` / ``query_time[0]``) and the first
    intervention target are encoded.

    Args:
        episode: the episode to encode; ``x_obs`` is unpacked as ``(T, N)``.
        n_max: width the variable axis is zero-padded to.
        device: device every batch tensor is moved to or created on.

    Raises:
        RuntimeError: if the episode has more than ``n_max`` variables.
    """
    from dotime.normalization import normalize_batch

    x_obs = episode.x_obs
    t_len, n = x_obs.shape
    if n > n_max:
        # The padded copy below would fail with an opaque shape-mismatch
        # RuntimeError; keep the exception type but state the actual problem.
        raise RuntimeError(f"episode has {n} variables but the model supports at most {n_max}")

    intervention = episode.intervention
    times = intervention.times
    onset = min(times) if times else t_len
    first_time = float(min(times) if times else 0)
    last_time = float(max(times) if times else 0)
    int_target = intervention.targets[0] if intervention.targets else 0

    # Causal masking (idempotent if the episode is already masked).
    masked = x_obs.clone()
    masked[onset:] = 0.0
    x_padded = torch.zeros(t_len, n_max)
    x_padded[:, :n] = masked
    var_mask = torch.zeros(n_max)
    var_mask[:n] = 1.0

    # Non-scalar intervention values are encoded as 0.0.
    raw_value = intervention.values
    raw_value = float(raw_value) if isinstance(raw_value, (int, float)) else 0.0
    # Scale by the std of the treated variable before onset (floored at 1e-4).
    pre = x_obs[:onset, int_target] if onset > 0 else x_obs[:, int_target]
    scale = float(pre.std().item()) if pre.numel() > 1 else 1.0
    int_value_norm = raw_value / max(scale, 1e-4)

    def _norm_time(v: float) -> float:
        # Values <= 1.0 are taken to be already normalized; larger ones are step
        # indices divided by the trajectory length.
        # NOTE(review): a raw step index of 1 is therefore passed through as 1.0
        # rather than 1 / t_len — confirm this matches the training-time encoding.
        return v if v <= 1.0 else v / t_len

    q_time = float(episode.query_time[0]) if episode.query_time.numel() else float(t_len - 1)

    # OracleBaseline treats ``y_true`` as possibly None, so guard here as well.
    # NOTE(review): the zero placeholder assumes the target is not needed for
    # inference beyond being present in the batch — confirm against normalize_batch.
    y_true = episode.y_true
    y_field = y_true[:1] if y_true is not None else torch.zeros(1)

    batch = {
        "X_obs": x_padded.unsqueeze(0).to(device),
        "variable_mask": var_mask.unsqueeze(0).to(device),
        "int_onset_idx": torch.tensor([onset], device=device),
        "intervention_target": torch.tensor([int_target], device=device),
        "intervention_type": torch.tensor(
            [_INT_TYPE_CODE.get(intervention.intervention_type.value, 0)], device=device
        ),
        "intervention_value": torch.tensor([int_value_norm], dtype=torch.float32, device=device),
        "intervention_time_start": torch.tensor([_norm_time(first_time)], device=device),
        "intervention_time_end": torch.tensor([_norm_time(last_time)], device=device),
        "query_target": torch.tensor([int(episode.query_target[0])], device=device),
        "query_time": torch.tensor([_norm_time(q_time)], dtype=torch.float32, device=device),
        "Y_true": y_field.to(device),
    }
    # The returned dict is the same object passed to normalize_batch.
    normalize_batch(batch)
    return batch
@register("DoOverTimePFN")
class DoOverTimePFNBaseline:
"""The Do-Over-Time-PFN causal foundation model (the headline method).
Loads a trained checkpoint (``[models]`` extra) and predicts the raw
interventional outcome at the query: it builds the model's normalized batch
from the Episode, runs the model in normalized space, then maps the predicted
mean back to the raw scale with the query variable's normalization stats.
Pass a ``checkpoint`` path. NOTE: reproducing the paper's reference numbers
requires the checkpoint trained for the corresponding suite/structure and a
matched evaluation protocol (Phase 8 verification); this wiring is the
inference path, validated to run and produce finite predictions.
"""
name = "DoOverTimePFN"
def __init__(self, checkpoint: str | None = None, device: str = "cpu"):
if checkpoint is None:
raise ValueError(
"DoOverTimePFN baseline needs a trained checkpoint: "
"baselines.get('DoOverTimePFN', checkpoint='/path/to/best.pt')"
)
from dotime.models.loader import load_dotpfn
self.device = device
self.model = load_dotpfn(checkpoint, device=device)
self.n_max = int(getattr(self.model, "n_max", 41))
@torch.no_grad()
def predict(self, episode: Episode) -> torch.Tensor:
batch = _episode_to_batch(episode, self.n_max, self.device)
out = self.model(batch)
head = getattr(self.model, "quantile_head", None) or getattr(self.model, "bar_head", None)
if head is None:
raise RuntimeError("model has neither a quantile_head nor a bar_head")
pred_norm = head.predict_mean(out).reshape(-1)
# Map back to the raw scale with the query variable's stats.
q = int(episode.query_target[0])
mean = batch["_norm_means"][0, q]
std = batch["_norm_stds"][0, q]
return (pred_norm * std + mean).cpu()