# Source code for dotime.regime_switching

"""Regime-switching temporal SCMs for DoTime.

This module implements temporal SCMs where the causal structure and/or mechanisms
can change over time according to a discrete Markov process.
"""

import warnings

import numpy as np
import torch

from dotime._sampling import DistributionSampler
from dotime.interventions import InterventionSpec, InterventionType
from dotime.temporal_graph import TemporalDAG
from dotime.temporal_mechanism import TemporalMechanism
from dotime.utils import check_divergence, clip_values


class RegimeSwitchingTemporalSCM:
    """
    Regime-switching temporal SCM.

    At each time step, a discrete regime variable d_t ∈ {0,...,R-1} determines
    which causal structure and mechanisms are active:

        X_t^(i) = f_i^(d_t)(Pa_{G^(d_t)}(X_t^(i))) + ε_t^(i)

    Regime transitions follow a Markov chain:

        d_t ~ Categorical(P[d_{t-1}, :])

    Notes
    -----
    * Sampled series always use the variable order of the first regime
      (``dags[0].topo_order``) as their column layout, regardless of which
      regime is active at a given step.
    * The initial regime is drawn uniformly from ``{0,...,R-1}``.
    * Regimes are drawn from NumPy's global random state (``np.random``);
      the ``generator`` accepted by the sampling methods is only forwarded to
      the noise samplers.
    """

    def __init__(
        self,
        dags: list[TemporalDAG],
        mechanisms: list[dict[str, TemporalMechanism]],
        noise: dict[str, DistributionSampler],
        transition_matrix: np.ndarray,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        """
        Parameters
        ----------
        dags : List[TemporalDAG]
            List of temporal DAGs, one per regime.
        mechanisms : List[Dict[str, TemporalMechanism]]
            List of mechanism dictionaries, one per regime.
        noise : Dict[str, DistributionSampler]
            Noise distributions (shared across regimes).
        transition_matrix : np.ndarray
            Regime transition matrix P of shape (R, R).
        device : torch.device
            Device for computation.
        dtype : torch.dtype
            Data type.

        Raises
        ------
        AssertionError
            If the number of mechanism dictionaries or the shape of
            ``transition_matrix`` does not match the number of regimes.
        ValueError
            If a regime's DAG does not contain the same variables as the
            first regime's DAG.
        """
        self.dags = dags
        self.mechanisms = mechanisms
        self.noise = noise
        self.transition_matrix = transition_matrix
        self.device = device
        self.dtype = dtype

        self.num_regimes = len(dags)
        assert len(mechanisms) == self.num_regimes, (
            "mechanisms must contain one dictionary per regime"
        )
        assert transition_matrix.shape == (self.num_regimes, self.num_regimes), (
            "transition_matrix must have shape (R, R)"
        )

        # Canonical variable -> column mapping. Every regime reads and writes
        # the simulation buffer through this single layout, so a column means
        # the same variable even when regimes have different topological orders.
        reference_order = list(dags[0].topo_order) if dags else []
        self._columns = {v: col for col, v in enumerate(reference_order)}

        # Precompute parents for each regime
        self._regime_parents = []
        for r, dag in enumerate(dags):
            if set(dag.topo_order) != set(reference_order):
                raise ValueError(
                    f"Regime {r} does not define the same variables as regime 0."
                )
            self._regime_parents.append(
                {
                    "instant": {v: list(dag.G_0.predecessors(v)) for v in dag.topo_order},
                    "lagged": self._compute_lagged_parents(dag),
                    "topo": dag.topo_order,
                }
            )

    def _compute_lagged_parents(self, dag: TemporalDAG) -> dict[str, list[list[str]]]:
        """Compute lagged parents for a given DAG.

        For every variable ``v`` the result holds one list per lag
        ``k = 0..dag.K-1``; that list contains each variable ``u`` for which
        ``dag.G_lags[k][idx(u), idx(v)] > 0``. Rows and columns of the lag
        matrices are indexed by position in ``dag.topo_order``.
        """
        # NOTE(review): assumes the lag matrices are laid out in
        # ``dag.topo_order`` order, as the original indexing did - confirm
        # against TemporalDAG.
        order = dag.topo_order
        lagged_parents = {}
        for v_idx, v in enumerate(order):
            lagged_parents[v] = [
                [u for j, u in enumerate(order) if dag.G_lags[k][j, v_idx] > 0]
                for k in range(dag.K)
            ]
        return lagged_parents

    def _evaluate_mechanism(
        self,
        v: str,
        t: int,
        buffer: torch.Tensor,
        parents: dict,
        mechs: dict[str, TemporalMechanism],
        generator: torch.Generator | None,
    ) -> torch.Tensor:
        """Draw noise for ``v`` and evaluate its mechanism at time ``t`` (unclipped)."""
        cols = self._columns
        instant_values = {p: buffer[t, cols[p]] for p in parents["instant"][v]}

        # Entry k corresponds to lag k + 1. A lag reaching before the start
        # of the buffer contributes an empty parent dictionary.
        lagged_values = []
        for k, parents_k in enumerate(parents["lagged"][v]):
            if t >= k + 1:
                lagged_values.append({p: buffer[t - k - 1, cols[p]] for p in parents_k})
            else:
                lagged_values.append({})

        eps = self.noise[v].sample(generator=generator)
        eps_tensor = torch.tensor([eps], device=self.device, dtype=self.dtype)
        return mechs[v](instant_values, lagged_values, eps_tensor)

    def _simulate(
        self,
        T: int,
        burn_in: int,
        generator: torch.Generator | None,
        return_regimes: bool,
        intervention: InterventionSpec | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
        """Shared forward simulation behind both public sampling methods."""
        total_T = T + burn_in
        N = len(self.dags[0].topo_order)
        cols = self._columns

        buffer = torch.zeros(total_T, N, device=self.device, dtype=self.dtype)
        regimes = np.zeros(total_T, dtype=np.int32)

        # Regime draws stay interleaved with the noise draws (rather than being
        # pre-sampled) so the order of random-number consumption is unchanged.
        regime = np.random.randint(0, self.num_regimes)
        for t in range(total_T):
            if t > 0:
                regime = np.random.choice(self.num_regimes, p=self.transition_matrix[regime])
            regimes[t] = regime

            parents = self._regime_parents[regime]
            mechs = self.mechanisms[regime]
            rel_t = t - burn_in  # time index relative to the end of burn-in

            # Visit variables in the active regime's topological order so that
            # instantaneous parents are already filled in, but store them in
            # the canonical column.
            for v in parents["topo"]:
                col = cols[v]
                intervened = (
                    intervention is not None
                    and col in intervention.targets
                    and rel_t in intervention.times
                )

                if intervened:
                    if intervention.intervention_type == InterventionType.HARD:
                        buffer[t, col] = intervention.values
                        continue
                    if intervention.intervention_type == InterventionType.TIME_VARYING:
                        buffer[t, col] = intervention.values(rel_t)
                        continue

                value = self._evaluate_mechanism(v, t, buffer, parents, mechs, generator)
                if intervened:
                    # SOFT: shift the mechanism output before clipping.
                    value = value + intervention.values
                buffer[t, col] = clip_values(value).item()

        if check_divergence(buffer):
            # stacklevel=3 attributes the warning to the caller of the public
            # sampling method, one frame further up than this helper.
            warnings.warn(
                "Regime-switching SCM diverged; returning zeros.", RuntimeWarning, stacklevel=3
            )
            series = torch.zeros(T, N, device=self.device, dtype=self.dtype)
        else:
            series = buffer[burn_in:]

        if return_regimes:
            return series, regimes[burn_in:]
        return series

    @torch.no_grad()
    def sample_observational(
        self,
        T: int,
        burn_in: int = 50,
        generator: torch.Generator | None = None,
        return_regimes: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
        """
        Sample observational data with regime switching.

        Parameters
        ----------
        T : int
            Number of time steps to return.
        burn_in : int
            Number of initial simulated steps that are discarded.
        generator : torch.Generator, optional
            Forwarded to each noise sampler's ``sample`` call. The regime
            path itself is drawn from NumPy's global random state.
        return_regimes : bool
            If True, return (X, regimes) tuple. Otherwise just X.

        Returns
        -------
        torch.Tensor or Tuple[torch.Tensor, np.ndarray]
            Time series (T, N), or (time series, regimes) if return_regimes=True.
            If the simulation diverges, a ``RuntimeWarning`` is issued and the
            time series is all zeros.
        """
        return self._simulate(T, burn_in, generator, return_regimes)

    @torch.no_grad()
    def sample_interventional(
        self,
        T: int,
        intervention: InterventionSpec,
        burn_in: int = 50,
        generator: torch.Generator | None = None,
        return_regimes: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
        """
        Sample interventional data with regime switching.

        A variable is intervened on at a step when its column index is in
        ``intervention.targets`` and the step's post-burn-in time index is in
        ``intervention.times``. A HARD intervention stores
        ``intervention.values`` directly, a TIME_VARYING one stores
        ``intervention.values(time)``, and any other type (SOFT) adds
        ``intervention.values`` to the mechanism output before clipping.

        Parameters
        ----------
        T : int
            Number of time steps to return.
        intervention : InterventionSpec
            Intervention to apply.
            NOTE(review): assumes ``targets`` holds integer column indices
            (positions in ``dags[0].topo_order``) - confirm against
            InterventionSpec.
        burn_in : int
            Number of initial simulated steps that are discarded.
        generator : torch.Generator, optional
            Forwarded to each noise sampler's ``sample`` call. The regime
            path itself is drawn from NumPy's global random state.
        return_regimes : bool
            If True, return (X, regimes) tuple. Otherwise just X.

        Returns
        -------
        torch.Tensor or Tuple[torch.Tensor, np.ndarray]
            Time series (T, N), or (time series, regimes) if return_regimes=True.
            If the simulation diverges, a ``RuntimeWarning`` is issued and the
            time series is all zeros.
        """
        return self._simulate(T, burn_in, generator, return_regimes, intervention)