Source code for dotime.prior

"""DoTime: Main orchestrator for sampling temporal SCMs with interventions."""

from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn

from dotime._activations import Tanh, TanhReLU, TanhX2
from dotime._sampling import ShiftedExponentialSampler
from dotime.chain_scm import ChainSCMBuilder
from dotime.interventions import InterventionSampler, InterventionSpec
from dotime.regime_switching import RegimeSwitchingTemporalSCM
from dotime.regime_switching_builder import RegimeSwitchingSCMBuilder
from dotime.temporal_scm import TemporalSCM
from dotime.temporal_scm_builder import TemporalSCMBuilder
from dotime.utils import DEFAULT_CONFIG


class Sin(nn.Module):
    def forward(self, x):
        return torch.sin(x)


class Cos(nn.Module):
    def forward(self, x):
        return torch.cos(x)


class Abs(nn.Module):
    def forward(self, x):
        return torch.abs(x)


class Square(nn.Module):
    def forward(self, x):
        return torch.pow(x, 2)


[docs] class DoTime: """ Prior distribution over temporal SCMs with interventions. Main interface for generating synthetic causal time series data. """
[docs] def __init__( self, config: dict[str, Any] | None = None, seed: int = 42, chain_prob: float = 0.15, regime_switching_prob: float = 0.15, ): """ Parameters ---------- config : Dict[str, Any], optional Configuration dictionary. If None, uses DEFAULT_CONFIG. seed : int Random seed for reproducibility. chain_prob : float Probability of generating a chain SCM (default 0.15). regime_switching_prob : float Probability of generating a regime-switching SCM (default 0.15). """ # Merge config with defaults self.config = {**DEFAULT_CONFIG} if config is not None: self.config.update(config) self.seed = seed self.chain_prob = chain_prob self.regime_switching_prob = regime_switching_prob self.generator = torch.Generator() self.generator.manual_seed(seed) # Activation functions (from paper + Do-PFN) self.activations = [ nn.Identity(), # Linear Tanh(), # tanh TanhX2(), # tanh(x^2) TanhReLU(), # tanh(relu(x)) nn.ReLU(), # relu # Additional nonlinear functions from the paper Sin(), # sin Cos(), # cos Abs(), # abs Square(), # x^2 ] # Chain SCM builder self.chain_builder = ChainSCMBuilder( activations=self.activations, device=self.config["device"], )
# Regime-switching SCM builder (will be instantiated per sample) # since it depends on sampled N
[docs] def sample_scm(self) -> TemporalSCM: """Sample a temporal SCM from the prior. Distribution: - chain_prob: chain SCMs - regime_switching_prob: regime-switching SCMs - remaining: diverse nonlinear SCMs Returns ------- TemporalSCM Sampled temporal SCM (or compatible regime-switching SCM). """ # Decide SCM type rand_val = torch.rand(1, generator=self.generator).item() if rand_val < self.chain_prob: # Sample chain SCM scm = self.chain_builder.sample(self.generator) elif rand_val < self.chain_prob + self.regime_switching_prob: # Sample regime-switching SCM N = int( torch.randint(3, self.config["N_max"] + 1, (1,), generator=self.generator).item() ) K = int( torch.randint(1, self.config["K_max"] + 1, (1,), generator=self.generator).item() ) rs_builder = RegimeSwitchingSCMBuilder( num_nodes=N, max_lag=K, activations=self.activations, gamma=self.config["gamma"], sigma_w=self.config["sigma_w"], sigma_b=self.config["sigma_b"], device=self.config["device"], ) scm = rs_builder.sample(self.generator) else: # Sample diverse nonlinear SCM # Sample hyperparameters N = int( torch.randint(3, self.config["N_max"] + 1, (1,), generator=self.generator).item() ) K = int( torch.randint(1, self.config["K_max"] + 1, (1,), generator=self.generator).item() ) # Sample edge probability from Beta distribution alpha, beta = self.config["alpha"], self.config["beta"] edge_prob = float(torch.distributions.Beta(alpha, beta).sample().item()) # Sample dropout probability dropout_prob = float(torch.rand(1, generator=self.generator).item() * 0.3) # Up to 30% # Create noise distributions root_std_dist = ShiftedExponentialSampler(rate=1.0, shift=0.1) non_root_std_dist = ShiftedExponentialSampler(rate=10.0, shift=0.01) # Create SCM builder scm_builder = TemporalSCMBuilder( num_nodes=N, max_lag=K, edge_prob=edge_prob, dropout_prob=dropout_prob, gamma=self.config["gamma"], activations=self.activations, root_std_dist=root_std_dist, non_root_std_dist=non_root_std_dist, root_mean=self.config["root_mean"], non_root_mean=self.config["non_root_mean"], sigma_w=self.config["sigma_w"], sigma_b=self.config["sigma_b"], device=self.config["device"], ) # Sample SCM scm = scm_builder.sample(self.generator) return scm
[docs] def generate_pair( self, T: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor, InterventionSpec, TemporalSCM]: """Generate a pair of observational and interventional time series. Parameters ---------- T : int, optional Length of time series. If None, uses config default. Returns ------- Tuple[torch.Tensor, torch.Tensor, InterventionSpec, TemporalSCM] (X_obs, X_int, intervention_spec, scm) """ if T is None: T = self.config["T"] # Sample SCM scm = self.sample_scm() N = len(scm._topo) # Sample intervention intervention_sampler = InterventionSampler( N=N, T=T, generator=self.generator, ) intervention = intervention_sampler.sample() # Generate observational data X_obs = scm.sample_observational( T=T, burn_in=self.config["burn_in"], generator=self.generator, ) # Generate interventional data X_int = scm.sample_interventional( T=T, intervention=intervention, burn_in=self.config["burn_in"], generator=self.generator, ) return X_obs, X_int, intervention, scm
[docs] def generate_regime_pair( self, T: int | None = None, num_regimes: int = 2, ) -> tuple[torch.Tensor, torch.Tensor, InterventionSpec, RegimeSwitchingTemporalSCM]: """Generate a paired (obs, int) trajectory from a regime-switching SCM. Like :meth:`generate_pair` but forces a regime-switching SCM with a fixed number of regimes (for the regime-density benchmark tiers). """ if T is None: T = self.config["T"] N = int(torch.randint(3, self.config["N_max"] + 1, (1,), generator=self.generator).item()) K = int(torch.randint(1, self.config["K_max"] + 1, (1,), generator=self.generator).item()) rs_builder = RegimeSwitchingSCMBuilder( num_nodes=N, max_lag=K, activations=self.activations, gamma=self.config["gamma"], sigma_w=self.config["sigma_w"], sigma_b=self.config["sigma_b"], device=self.config["device"], ) scm = rs_builder.sample(self.generator, num_regimes=num_regimes) intervention = InterventionSampler(N=N, T=T, generator=self.generator).sample() X_obs = scm.sample_observational( T=T, burn_in=self.config["burn_in"], generator=self.generator ) X_int = scm.sample_interventional( T=T, intervention=intervention, burn_in=self.config["burn_in"], generator=self.generator ) return X_obs, X_int, intervention, scm
[docs] def generate_dataset( self, n_scms: int, T: int | None = None, ) -> list[tuple[torch.Tensor, torch.Tensor, InterventionSpec]]: """Generate a dataset of paired observational/interventional time series. Parameters ---------- n_scms : int Number of SCMs to sample. T : int, optional Length of time series. If None, uses config default. Returns ------- List[Tuple[torch.Tensor, torch.Tensor, InterventionSpec]] List of (X_obs, X_int, intervention_spec) tuples. """ dataset = [] for i in range(n_scms): X_obs, X_int, intervention, _scm = self.generate_pair(T=T) dataset.append((X_obs, X_int, intervention)) if (i + 1) % 100 == 0: print(f"Generated {i + 1}/{n_scms} SCM pairs...") return dataset
[docs] def generate_training_tuples( self, n_scms: int, T: int | None = None, ) -> list[tuple[torch.Tensor, list[int], list[int], Any, torch.Tensor]]: """Generate training tuples for PFN training. Format: (X_obs, targets, times, values, Y_int_tau) Parameters ---------- n_scms : int Number of SCMs to sample. T : int, optional Length of time series. If None, uses config default. Returns ------- List[Tuple[torch.Tensor, List[int], List[int], Any, torch.Tensor]] Training tuples suitable for PFN training. """ if T is None: T = self.config["T"] training_data = [] for i in range(n_scms): X_obs, X_int, intervention, _scm = self.generate_pair(T=T) # Extract target variable outcomes at intervention times target_idx = intervention.targets[0] if len(intervention.targets) > 0 else 0 Y_int_tau = X_int[:, target_idx] training_data.append( ( X_obs, intervention.targets, intervention.times, intervention.values, Y_int_tau, ) ) if (i + 1) % 100 == 0: print(f"Generated {i + 1}/{n_scms} training tuples...") return training_data