Source code for dotime.temporal_scm

"""Temporal SCM with time-stepped forward simulation."""

import warnings

import torch

from dotime._sampling import DistributionSampler
from dotime.interventions import InterventionSpec, InterventionType
from dotime.temporal_graph import TemporalDAG
from dotime.temporal_mechanism import TemporalMechanism
from dotime.utils import check_divergence, clip_values


class TemporalSCM:
    """
    Temporal Structural Causal Model with time-stepped forward simulation.

    Extends Do-PFN's SCM to support temporal dependencies with lags.

    The model walks the variables in the DAG's topological order at every
    time step, feeding each mechanism the same-step values of its
    instantaneous parents and the earlier-step values of its lagged parents.
    """

    # Magnitude above which the periodic in-loop check declares the rollout
    # diverged and aborts early (compared against max |x| of the previous step).
    _EARLY_DIVERGENCE_LIMIT = 500

    def __init__(
        self,
        dag: TemporalDAG,
        mechanisms: dict[str, TemporalMechanism],
        noise: dict[str, DistributionSampler],
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        """
        Parameters
        ----------
        dag : TemporalDAG
            Temporal DAG with instantaneous and lagged edges.
        mechanisms : dict[str, TemporalMechanism]
            Mechanisms for each variable, keyed by variable name.
        noise : dict[str, DistributionSampler]
            Noise distributions for each variable, keyed by variable name.
        device : torch.device
            Device for computation.
        dtype : torch.dtype
            Data type for computation.
        """
        self.dag = dag
        self.mechanisms = mechanisms
        self.noise = noise
        self.device = device
        self.dtype = dtype

        # Store topology
        self._topo = dag.topo_order
        self._G_0 = dag.G_0
        self._G_lags = dag.G_lags
        self._K = dag.K

        # Name -> position in topological order (O(1) lookups instead of list scans)
        self._topo_idx = {v: i for i, v in enumerate(self._topo)}

        # Parents for each variable (by name, for mechanism compatibility)
        self._instant_parents = {v: list(self._G_0.predecessors(v)) for v in self._topo}
        self._lagged_parents = self._compute_lagged_parents()

        # Pre-resolved parent indices, keyed by the child's topological index
        # (avoids rebuilding these inside the simulation loop).
        self._instant_parent_idx = {
            i: [self._topo_idx[p] for p in self._instant_parents[v]]
            for i, v in enumerate(self._topo)
        }
        self._lagged_parent_idx = {
            i: [[self._topo_idx[p] for p in parents_k] for parents_k in self._lagged_parents[v]]
            for i, v in enumerate(self._topo)
        }

        # Parent names paired with their indices: the name is the key the
        # mechanism receives, the index addresses the simulation buffer column.
        self._instant_parent_pairs = {
            i: [(p, self._topo_idx[p]) for p in self._instant_parents[v]]
            for i, v in enumerate(self._topo)
        }
        self._lagged_parent_pairs = {
            i: [
                [(p, self._topo_idx[p]) for p in parents_k]
                for parents_k in self._lagged_parents[v]
            ]
            for i, v in enumerate(self._topo)
        }

    def _compute_lagged_parents(self) -> dict[str, list[list[str]]]:
        """Compute lagged parents for each variable.

        Returns
        -------
        dict[str, list[list[str]]]
            For every variable name, a list with one entry per lag matrix
            (``self._K`` entries); entry ``k`` holds the names of the parents
            at lag ``k + 1``, i.e. those ``p`` with ``G_lags[k][idx(p), idx(v)] > 0``.
        """
        lagged_parents = {}
        for v in self._topo:
            v_idx = self._topo_idx[v]
            parents_per_lag = []
            for k in range(self._K):
                G_k = self._G_lags[k]
                # Row = candidate parent, column = child; a positive entry is an edge.
                parents_per_lag.append(
                    [p for j, p in enumerate(self._topo) if G_k[j, v_idx] > 0]
                )
            lagged_parents[v] = parents_per_lag
        return lagged_parents

    @torch.no_grad()
    def _simulate(
        self,
        total_T: int,
        burn_in: int = 50,
        intervention: InterventionSpec | None = None,
        generator: torch.Generator | None = None,
        divergence_check_interval: int = 50,
    ) -> torch.Tensor | None:
        """Unified forward simulation, optionally with intervention.

        Parameters
        ----------
        total_T : int
            Total number of steps to simulate, burn-in included.
        burn_in : int
            Number of leading steps dropped from the returned tensor.
            Intervention times are interpreted relative to the end of burn-in
            (step ``t`` is matched as ``t - burn_in``).
        intervention : InterventionSpec, optional
            Intervention to apply; ``None`` simulates observationally.
        generator : torch.Generator, optional
            NOTE(review): accepted for interface compatibility but not used
            below -- noise is drawn through ``self.noise[v].distribution.sample``
            without it, so passing a generator does not by itself control the draw.
        divergence_check_interval : int
            Every this many steps the previous row is checked against
            ``_EARLY_DIVERGENCE_LIMIT``; a value ``<= 0`` disables the
            periodic check (the final check still runs).

        Returns
        -------
        torch.Tensor or None
            ``buffer[burn_in:]`` on success, or ``None`` on divergence.
        """
        n_vars = len(self._topo)
        buffer = torch.zeros(total_T, n_vars, device=self.device, dtype=self.dtype)

        # Pre-sample all noise (eliminates per-step tensor creation + RNG state swaps)
        all_noise = {
            v: self.noise[v]
            .distribution.sample((total_T,))
            .to(device=self.device, dtype=self.dtype)
            for v in self._topo
        }

        # Pre-compute intervention lookup sets for O(1) checks
        if intervention is not None:
            int_targets = set(intervention.targets)
            int_times = set(intervention.times)
            int_type = intervention.intervention_type
            int_values = intervention.values
        else:
            int_targets = set()
            int_times = set()
            int_type = None
            int_values = None

        # Forward simulation
        for t in range(total_T):
            # Early divergence detection on the most recently completed step
            if (
                divergence_check_interval > 0
                and t > 0
                and t % divergence_check_interval == 0
                and buffer[t - 1].abs().max() > self._EARLY_DIVERGENCE_LIMIT
            ):
                return None

            rel_t = t - burn_in
            time_is_intervened = rel_t in int_times

            for i, v in enumerate(self._topo):
                # Targets are matched against the topological index `i`
                # (presumably InterventionSpec.targets holds indices -- verify).
                intervened = time_is_intervened and i in int_targets

                if intervened:
                    if int_type == InterventionType.HARD:
                        buffer[t, i] = int_values
                        continue
                    if int_type == InterventionType.TIME_VARYING:
                        buffer[t, i] = int_values(rel_t)
                        continue
                    # SOFT: fall through to mechanism, add shift after

                # Instantaneous parents: same step, already filled because we
                # iterate in topological order.
                parent_values_instant = {
                    p: buffer[t, idx] for p, idx in self._instant_parent_pairs[i]
                }

                # Lagged parents: one dict per lag; empty while the lag would
                # reach before the first simulated step.
                parent_values_lagged = [
                    {p: buffer[t - lag, idx] for p, idx in pairs_k} if t >= lag else {}
                    for lag, pairs_k in enumerate(self._lagged_parent_pairs[i], start=1)
                ]

                # Noise (pre-sampled, already a tensor)
                eps = all_noise[v][t].unsqueeze(0)

                # Apply mechanism
                value = self.mechanisms[v](parent_values_instant, parent_values_lagged, eps)

                # Soft intervention shift
                if intervened and int_type == InterventionType.SOFT:
                    value = value + int_values

                buffer[t, i] = clip_values(value)

        # Final divergence check
        if check_divergence(buffer):
            return None

        return buffer[burn_in:]

    @torch.no_grad()
    def sample_observational(
        self,
        T: int,
        burn_in: int = 50,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        """
        Sample observational data from the temporal SCM.

        Parameters
        ----------
        T : int
            Length of time series to generate (after burn-in).
        burn_in : int
            Number of burn-in steps to discard.
        generator : torch.Generator, optional
            Passed through to the simulator. NOTE(review): the simulator does
            not currently use it when drawing noise.

        Returns
        -------
        torch.Tensor
            Time series data of shape (T, N) where N is number of variables.
            If the simulation diverges, a ``RuntimeWarning`` is issued and a
            zero tensor of the same shape is returned instead.
        """
        result = self._simulate(T + burn_in, burn_in=burn_in, generator=generator)
        if result is None:
            N = len(self._topo)
            warnings.warn(
                "SCM diverged during simulation; returning zeros.", RuntimeWarning, stacklevel=2
            )
            return torch.zeros(T, N, device=self.device, dtype=self.dtype)
        return result

    @torch.no_grad()
    def sample_interventional(
        self,
        T: int,
        intervention: InterventionSpec,
        burn_in: int = 50,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        """
        Sample interventional data from the temporal SCM.

        Parameters
        ----------
        T : int
            Length of time series to generate (after burn-in).
        intervention : InterventionSpec
            Intervention specification.
        burn_in : int
            Number of burn-in steps to discard.
        generator : torch.Generator, optional
            Passed through to the simulator. NOTE(review): the simulator does
            not currently use it when drawing noise.

        Returns
        -------
        torch.Tensor
            Time series data of shape (T, N) under intervention. If the
            simulation diverges, a ``RuntimeWarning`` is issued and a zero
            tensor of the same shape is returned instead.
        """
        result = self._simulate(
            T + burn_in,
            burn_in=burn_in,
            intervention=intervention,
            generator=generator,
        )
        if result is None:
            N = len(self._topo)
            warnings.warn(
                "SCM diverged during interventional simulation; returning zeros.",
                RuntimeWarning,
                stacklevel=2,
            )
            return torch.zeros(T, N, device=self.device, dtype=self.dtype)
        return result