# Source code for dotime.interventions

"""Intervention specifications and sampling for DoTime."""

from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from typing import Any

import numpy as np
import torch


@dataclass
class StepIntervention:
    """Picklable step profile (a callable class used in place of a lambda).

    Calling the instance with a time index returns ``-2.0`` before
    ``step_time`` and ``2.0`` from ``step_time`` onwards.
    """

    step_time: int

    def __call__(self, t):
        # High level once the step time has been reached, low level otherwise.
        if t >= self.step_time:
            return 2.0
        return -2.0


@dataclass
class RampIntervention:
    """Picklable linear ramp profile (a callable class used in place of a lambda).

    The value is ``-2.0`` at ``start_time`` and grows by ``4.0`` over
    ``intervention_length`` time steps.
    """

    start_time: int
    intervention_length: int

    def __call__(self, t):
        elapsed = t - self.start_time
        # Multiply before dividing, matching the original evaluation order.
        rise = 4.0 * elapsed / self.intervention_length
        return -2.0 + rise


@dataclass
class SineIntervention:
    """Picklable sinusoidal profile (a callable class used in place of a lambda).

    Evaluates ``2.0 * sin(freq * (t - start_time))``, so the value is zero
    at ``start_time``.
    """

    start_time: int
    freq: float

    def __call__(self, t):
        phase = self.freq * (t - self.start_time)
        return 2.0 * np.sin(phase)


@dataclass
class TrajectoryIntervention:
    """Trajectory-based intervention profile (picklable replacement for lambda).

    ``trajectory_dict`` maps a time index to the intervention value at that
    time. Calling the instance with a time index returns the stored value, or
    ``0.0`` when the index has no entry.

    JSON object keys are always strings, so a trajectory that has been through
    ``json.dumps`` / ``json.loads`` comes back keyed by ``"12"`` instead of
    ``12``. String keys that spell an integer are therefore converted back to
    ``int`` on construction, so lookups by integer time index keep working
    after such a round trip.
    """

    trajectory_dict: dict

    def __post_init__(self):
        mapping = self.trajectory_dict
        # Only rebuild when there is something to convert, so a dict that is
        # already keyed by non-strings is kept as the very same object.
        if isinstance(mapping, dict) and any(isinstance(k, str) for k in mapping):
            self.trajectory_dict = {
                self._as_time_index(key): value for key, value in mapping.items()
            }

    @staticmethod
    def _as_time_index(key):
        """Return *key* as an ``int`` if it is a string spelling one, else unchanged."""
        if not isinstance(key, str):
            return key
        try:
            return int(key)
        except ValueError:
            # Not an integer spelling: leave the key exactly as supplied.
            return key

    def __call__(self, t):
        return self.trajectory_dict.get(t, 0.0)


class InterventionType(Enum):
    """Types of interventions.

    Each member's value is the lowercase string form of its name
    (``"hard"``, ``"soft"``, ``"time_varying"``).
    """

    HARD = "hard"  # do(X_i := c)
    SOFT = "soft"  # X_i = f_i(...) + delta
    TIME_VARYING = "time_varying"  # do(X_i := c(t))
# Picklable time-varying intervention profiles, keyed by name for (de)serialization.
# Keys are taken from each class's own ``__name__`` so they always match the
# ``type(v).__name__`` lookup performed when a profile is serialized.
_PROFILE_REGISTRY = {
    profile_cls.__name__: profile_cls
    for profile_cls in (
        StepIntervention,
        RampIntervention,
        SineIntervention,
        TrajectoryIntervention,
    )
}
@dataclass
class InterventionSpec:
    """Specification of an intervention on a temporal SCM.

    Attributes:
        targets: List of variable indices to intervene on
        times: List of time indices when intervention is active
        intervention_type: Type of intervention (hard, soft, time-varying)
        values: Intervention values (constant, shift, or time-varying function)
    """

    targets: list[int]
    times: list[int]
    intervention_type: InterventionType
    values: float | torch.Tensor | Callable

    def to_dict(self) -> dict:
        """JSON-serializable view of the spec (round-trips via :meth:`from_dict`).

        The ``values`` field is encoded by kind: a scalar, a dense ``tensor``
        list, or a named time-varying ``profile`` with its parameters.

        Returns:
            A dict with keys ``targets``, ``times``, ``intervention_type``
            (the enum's string value) and ``values`` (a ``{"kind": ...}`` dict).

        Raises:
            TypeError: If ``values`` is a callable that is not an instance of
                a registered profile dataclass.
        """
        from dataclasses import asdict, is_dataclass

        v = self.values
        values: dict[str, Any]
        if isinstance(v, torch.Tensor):
            values = {"kind": "tensor", "data": v.detach().cpu().tolist()}
        elif not isinstance(v, type) and is_dataclass(v) and type(v).__name__ in _PROFILE_REGISTRY:
            # v is a dataclass instance here; mypy's is_dataclass TypeGuard can't narrow it.
            values = {"kind": "profile", "name": type(v).__name__, "params": asdict(v)}  # type: ignore[call-overload]
        elif callable(v):
            # Arbitrary callables (lambdas, functions) have no stable encoding.
            raise TypeError(
                f"cannot serialize intervention value of type {type(v).__name__!r}; "
                "use a registered profile dataclass (StepIntervention, ...)"
            )
        else:
            values = {"kind": "scalar", "data": float(v)}

        return {
            "targets": list(self.targets),
            "times": list(self.times),
            "intervention_type": self.intervention_type.value,
            "values": values,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "InterventionSpec":
        """Reconstruct an :class:`InterventionSpec` from :meth:`to_dict` output.

        Tensor values are rebuilt as ``torch.float32`` tensors; profile values
        are rebuilt by calling the registered profile class with the stored
        parameters.

        Raises:
            ValueError: If the ``kind`` of the encoded value is not recognized.
            KeyError: If a profile name is not in the profile registry, or a
                required key is missing from ``d``.
        """
        v = d["values"]
        kind = v["kind"]
        if kind == "scalar":
            values: float | torch.Tensor | Callable = float(v["data"])
        elif kind == "tensor":
            values = torch.tensor(v["data"], dtype=torch.float32)
        elif kind == "profile":
            name = v["name"]
            try:
                profile_cls = _PROFILE_REGISTRY[name]
            except KeyError:
                # Same exception type as the bare lookup, but says what went wrong.
                raise KeyError(f"unknown intervention profile {name!r}") from None
            values = profile_cls(**v["params"])
        else:
            raise ValueError(f"unknown intervention value kind {kind!r}")

        return cls(
            targets=list(d["targets"]),
            times=list(d["times"]),
            intervention_type=InterventionType(d["intervention_type"]),
            values=values,
        )
class InterventionSampler: """Samples random intervention specifications for temporal SCMs.""" def __init__( self, N: int, T: int, p_hard: float = 0.5, p_soft: float = 0.3, p_time_varying: float = 0.2, max_targets: int = 2, min_intervention_length: int = 10, generator: torch.Generator | None = None, ): """ Parameters ---------- N : int Number of variables in the SCM. T : int Length of time series. p_hard : float Probability of hard intervention. p_soft : float Probability of soft intervention. p_time_varying : float Probability of time-varying intervention. max_targets : int Maximum number of variables to intervene on. min_intervention_length : int Minimum length of intervention period. generator : torch.Generator, optional RNG for reproducibility. """ self.N = N self.T = T self.p_hard = p_hard self.p_soft = p_soft self.p_time_varying = p_time_varying self.max_targets = min(max_targets, N) self.min_intervention_length = min_intervention_length self.generator = generator # Normalize probabilities total = p_hard + p_soft + p_time_varying self.p_hard /= total self.p_soft /= total self.p_time_varying /= total def sample(self) -> InterventionSpec: """Sample a random intervention specification. Returns ------- InterventionSpec Sampled intervention specification. 
""" # Sample intervention type r = torch.rand(1, generator=self.generator).item() if r < self.p_hard: intervention_type = InterventionType.HARD elif r < self.p_hard + self.p_soft: intervention_type = InterventionType.SOFT else: intervention_type = InterventionType.TIME_VARYING # Sample targets (1 to max_targets variables) num_targets = int( torch.randint(1, self.max_targets + 1, (1,), generator=self.generator).item() ) targets = torch.randperm(self.N, generator=self.generator)[:num_targets].tolist() # Sample intervention times (contiguous period) intervention_length = int( torch.randint( self.min_intervention_length, self.T - self.min_intervention_length + 1, (1,), generator=self.generator, ).item() ) start_time = int( torch.randint( self.min_intervention_length, self.T - intervention_length + 1, (1,), generator=self.generator, ).item() ) times = list(range(start_time, start_time + intervention_length)) # Sample intervention values based on type values: float | torch.Tensor | Callable if intervention_type == InterventionType.HARD: # Hard intervention: constant value value = torch.randn(1, generator=self.generator).item() * 2.0 values = value elif intervention_type == InterventionType.SOFT: # Soft intervention: additive shift value = torch.randn(1, generator=self.generator).item() * 1.0 values = value else: # TIME_VARYING # Time-varying intervention: choose profile type profile_type = int(torch.randint(0, 4, (1,), generator=self.generator).item()) if profile_type == 0: # Step function step_time = start_time + intervention_length // 2 values = StepIntervention(step_time) elif profile_type == 1: # Ramp values = RampIntervention(start_time, intervention_length) elif profile_type == 2: # Sinusoidal freq = 2 * np.pi / intervention_length values = SineIntervention(start_time, freq) else: # Sampled trajectory trajectory = torch.randn(intervention_length, generator=self.generator) * 2.0 trajectory_dict = { start_time + i: trajectory[i].item() for i in 
range(intervention_length) } values = TrajectoryIntervention(trajectory_dict) return InterventionSpec( targets=targets, times=times, intervention_type=intervention_type, values=values, )