# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Causal module."""
from __future__ import annotations
import itertools as it
import re
import warnings
from collections.abc import Sequence
from typing import Annotated, Literal
try:
import networkx as nx
except ImportError: # Optional dependency
nx = None # type: ignore[assignment]
import numpy as np
import pandas as pd
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pydantic import Field, InstanceOf, validate_call
from pymc_extras.prior import Prior
try:
from dowhy import CausalModel
except ImportError:
class LazyCausalModel:
"""Lazy import of dowhy's CausalModel."""
def __init__(self, *args, **kwargs):
msg = (
"To use Causal Graph functionality, please install the optional dependencies with: "
"pip install pymc-marketing[dag]"
)
raise ImportError(msg)
CausalModel = LazyCausalModel
class BuildModelFromDAG:
"""Build a PyMC probabilistic model directly from a Causal DAG and a tabular dataset.
The class interprets a Directed Acyclic Graph (DAG) where each node is a column
in the provided `df`. For every edge ``A -> B`` it creates a slope prior for
the contribution of ``A`` into the mean of ``B``. Each node receives a
likelihood prior. Dims and coords are used to align and index observed data
via ``pm.Data`` and xarray.
Parameters
----------
dag : str
DAG in DOT format (e.g. ``digraph { A -> B; B -> C; }``) or as a simple
comma/newline separated list of edges (e.g. ``"A->B, B->C"``).
df : pandas.DataFrame
DataFrame that contains a column for every node present in the DAG and
all columns named by the provided ``dims``.
target : str
Name of the target node present in both the DAG and ``df``. This is not
used to restrict modeling but is validated to exist in the DAG.
dims : tuple[str, ...]
Dims for the observed variables and likelihoods (e.g. ``("date", "channel")``).
coords : dict
Mapping from dim names to coordinate values. All coord keys must exist as
columns in ``df`` and will be used to pivot the data to match dims.
model_config : dict, optional
Optional configuration with priors for keys ``"intercept"``, ``"slope"`` and
``"likelihood"``. Values should be ``pymc_extras.prior.Prior`` instances.
        Missing keys fall back to :attr:`default_model_config`.
Examples
--------
Minimal example using DOT format:
.. code-block:: python
import numpy as np
import pandas as pd
from pymc_marketing.mmm.causal import BuildModelFromDAG
dates = pd.date_range("2024-01-01", periods=5, freq="D")
df = pd.DataFrame(
{
"date": dates,
"X": np.random.normal(size=5),
"Y": np.random.normal(size=5),
}
)
dag = "digraph { X -> Y; }"
dims = ("date",)
coords = {"date": dates}
builder = BuildModelFromDAG(
dag=dag, df=df, target="Y", dims=dims, coords=coords
)
model = builder.build()
Edge-list format and custom likelihood prior:
.. code-block:: python
from pymc_extras.prior import Prior
dag = "X->Y" # equivalent to the DOT example above
model_config = {
"likelihood": Prior(
"StudentT", nu=5, sigma=Prior("HalfNormal", sigma=1), dims=("date",)
),
}
builder = BuildModelFromDAG(
dag=dag,
df=df,
target="Y",
dims=("date",),
coords={"date": dates},
model_config=model_config,
)
model = builder.build()
"""
@validate_call
def __init__(
self,
*,
dag: str = Field(..., description="DAG in DOT string format or A->B list"),
df: InstanceOf[pd.DataFrame] = Field(
..., description="DataFrame containing all DAG node columns"
),
target: str = Field(..., description="Target node name present in DAG and df"),
dims: tuple[str, ...] = Field(
..., description="Dims for observed/likelihood variables"
),
coords: dict = Field(
...,
description=(
"Required coords mapping for dims and priors. All coord keys must exist as columns in df."
),
),
model_config: dict | None = Field(
None,
description=(
"Optional model config with Priors for 'intercept', 'slope' and "
"'likelihood'. Keys not supplied fall back to defaults."
),
),
) -> None:
self.dag = dag
self.df = df
self.target = target
self.dims = dims
self.coords = coords
# Parse graph and validate target
self.graph = self._parse_dag(self.dag)
self.nodes = list(nx.topological_sort(self.graph))
if self.target not in self.nodes:
raise ValueError(f"Target '{self.target}' not in DAG nodes: {self.nodes}")
# Merge provided model_config with defaults
provided = model_config
self.model_config = self.default_model_config
if provided is not None:
self.model_config.update(provided)
# Validate required priors are present and of correct type
self._validate_model_config_priors()
# Validate coords are present and consistent with dims, priors, and df
self._validate_coords_required_are_consistent()
# Validate prior dims consistency early (does not require building the model)
self._warning_if_slope_dims_dont_match_likelihood_dims()
self._validate_intercept_dims_match_slope_dims()
@property
def default_model_config(self) -> dict[str, Prior]:
"""Default priors for intercepts, slopes and likelihood using ``pymc_extras.Prior``.
Returns
-------
dict
Dictionary with keys ``"intercept"``, ``"slope"`` and ``"likelihood"``
mapping to ``Prior`` instances with dims derived from
            :attr:`dims`.
"""
slope_dims = tuple(dim for dim in (self.dims or ()) if dim != "date")
return {
"intercept": Prior("Normal", mu=0, sigma=1, dims=slope_dims),
"slope": Prior("Normal", mu=0, sigma=1, dims=slope_dims),
"likelihood": Prior(
"Normal",
sigma=Prior("HalfNormal", sigma=1),
dims=self.dims,
),
}
@staticmethod
def _parse_dag(dag_str: str) -> nx.DiGraph:
"""Parse DOT digraph or edge-list string into a directed acyclic graph."""
if nx is None:
raise ImportError(
"To use Causal Graph functionality, please install the optional dependencies with: "
"pip install pymc-marketing[dag]"
)
# Primary format: DOT digraph
s = dag_str.strip()
g = nx.DiGraph()
if s.lower().startswith("digraph"):
# Extract content within the first top-level {...}
brace_start = s.find("{")
brace_end = s.rfind("}")
if brace_start == -1 or brace_end == -1 or brace_end <= brace_start:
raise ValueError("Malformed DOT digraph: missing braces")
body = s[brace_start + 1 : brace_end]
# Remove comments (// ... or # ... at line end)
lines = []
for raw_line in body.splitlines():
line = re.split(r"//|#", raw_line, maxsplit=1)[0].strip()
if line:
lines.append(line)
body = "\n".join(lines)
# Find edges "A -> B" possibly ending with ';'
for m in re.finditer(
r"\b([A-Za-z0-9_]+)\s*->\s*([A-Za-z0-9_]+)\s*;?", body
):
a, b = m.group(1), m.group(2)
g.add_edge(a, b)
# Find standalone node declarations (lines with single identifier, optional ';')
for raw_line in body.splitlines():
line = raw_line.strip().rstrip(";")
if not line or "->" in line or "[" in line or "]" in line:
continue
mnode = re.match(r"^([A-Za-z0-9_]+)$", line)
if mnode:
g.add_node(mnode.group(1))
else:
# Fallback: simple comma/newline-separated "A->B" tokens
edges: list[tuple[str, str]] = []
for token in re.split(r"[,\n]+", s):
token = token.strip().rstrip(";")
if not token:
continue
medge = re.match(r"^([A-Za-z0-9_]+)\s*->\s*([A-Za-z0-9_]+)$", token)
if not medge:
raise ValueError(f"Invalid edge token: '{token}'")
a, b = medge.group(1), medge.group(2)
edges.append((a, b))
g.add_edges_from(edges)
if not nx.is_directed_acyclic_graph(g):
raise ValueError("Provided graph is not a DAG.")
return g
def _warning_if_slope_dims_dont_match_likelihood_dims(self) -> None:
"""Warn if slope prior dims differ from likelihood dims without the 'date' dim."""
slope_prior = self.model_config["slope"]
likelihood_prior = self.model_config["likelihood"]
like_dims = getattr(likelihood_prior, "dims", None)
if isinstance(like_dims, str):
like_dims = (like_dims,)
elif isinstance(like_dims, list):
like_dims = tuple(like_dims)
# Guard against None dims (treat as empty)
if like_dims is None:
expected_slope_dims = ()
else:
expected_slope_dims = tuple(dim for dim in like_dims if dim != "date")
        slope_dims = getattr(slope_prior, "dims", None)
        if isinstance(slope_dims, str):
            slope_dims = (slope_dims,)
        elif isinstance(slope_dims, list):
            slope_dims = tuple(slope_dims)
        elif not isinstance(slope_dims, tuple):
            # None (or any unrecognized type) is treated as "no dims". Checking
            # str/list first keeps those branches reachable; previously the
            # tuple check ran first and silently collapsed them to ().
            slope_dims = ()
if slope_dims != expected_slope_dims:
warnings.warn(
(
"Slope prior dims "
f"{slope_dims if slope_dims else '()'} do not match expected dims "
f"{expected_slope_dims} (likelihood dims without 'date')."
),
stacklevel=2,
)
def _validate_intercept_dims_match_slope_dims(self) -> None:
"""Ensure intercept prior dims match slope prior dims exactly."""
def _to_tuple(maybe_dims):
if maybe_dims is None:
return tuple()
if isinstance(maybe_dims, str):
return (maybe_dims,)
if isinstance(maybe_dims, list | tuple):
return tuple(maybe_dims)
return tuple()
slope_dims = _to_tuple(getattr(self.model_config["slope"], "dims", None))
intercept_dims = _to_tuple(
getattr(self.model_config["intercept"], "dims", None)
)
if slope_dims != intercept_dims:
raise ValueError(
"model_config['intercept'].dims must match model_config['slope'].dims. "
f"Got intercept dims {intercept_dims or '()'} and slope dims {slope_dims or '()'}."
)
def _validate_model_config_priors(self) -> None:
"""Ensure required model_config entries are Prior instances.
Enforces that keys 'slope' and 'likelihood' exist and are Prior objects,
so downstream code can safely index and call Prior helper methods.
"""
required_keys = ("intercept", "slope", "likelihood")
for key in required_keys:
if key not in self.model_config:
raise ValueError(f"model_config must include '{key}' as a Prior.")
for key in required_keys:
if not isinstance(self.model_config[key], Prior):
raise TypeError(
f"model_config['{key}'] must be a Prior, got "
f"{type(self.model_config[key]).__name__}."
)
def _validate_coords_required_are_consistent(self) -> None:
"""Validate mutual consistency among dims, coords, priors, and data columns."""
if self.coords is None:
raise ValueError("'coords' is required and cannot be None.")
# 1) All coords keys must correspond to columns in the dataset
for key in self.coords.keys():
if key not in self.df.columns:
raise KeyError(
f"Coordinate key '{key}' not found in DataFrame columns. Present columns: {list(self.df.columns)}"
)
# 2) Ensure dims are present in coords
for d in self.dims:
if d not in self.coords:
raise ValueError(f"Missing coordinate values for dim '{d}' in coords.")
# 3) Ensure Prior.dims exist in coords (for all top-level priors we manage)
        def _to_tuple(maybe_dims):
            if isinstance(maybe_dims, str):
                return (maybe_dims,)
            if isinstance(maybe_dims, list | tuple):
                return tuple(maybe_dims)
            return tuple()
for prior_name, prior in self.model_config.items():
if not isinstance(prior, Prior):
continue
for d in _to_tuple(getattr(prior, "dims", None)):
if d not in self.coords:
raise ValueError(
f"Dim '{d}' declared in Prior '{prior_name}' must be present in coords."
)
# 4) Enforce that likelihood dims match class dims exactly
likelihood_prior = self.model_config["likelihood"]
likelihood_dims = _to_tuple(getattr(likelihood_prior, "dims", None))
if likelihood_dims and tuple(self.dims) != likelihood_dims:
raise ValueError(
"Likelihood Prior dims "
f"{likelihood_dims} must match class dims {tuple(self.dims)}. "
"When supplying a custom model_config, ensure likelihood.dims equals the 'dims' argument."
)
def _parents(self, node: str) -> list[str]:
"""Return the list of parent node names for the given DAG node."""
return list(self.graph.predecessors(node))
def build(self) -> pm.Model:
"""Construct and return the PyMC model implied by the DAG and data.
The method creates a ``pm.Data`` container for every node to align the
observed data with the declared ``dims``. For each edge ``A -> B``, a
slope prior is instantiated from ``model_config['slope']`` and used in the
mean of node ``B``'s likelihood, which is instantiated from
``model_config['likelihood']``.
Returns
-------
pymc.Model
A fully specified model with slopes and likelihoods for all nodes.
Examples
--------
Build a model and sample from it:
.. code-block:: python
builder = BuildModelFromDAG(
dag="A->B", df=df, target="B", dims=("date",), coords={"date": dates}
)
model = builder.build()
with model:
idata = pm.sample(100, tune=100, chains=2, cores=2)
Multi-dimensional dims (e.g. date and country):
.. code-block:: python
dims = ("date", "country")
coords = {"date": dates, "country": ["Venezuela", "Colombia"]}
builder = BuildModelFromDAG(
dag="A->B, B->Y", df=df, target="Y", dims=dims, coords=coords
)
model = builder.build()
"""
dims = self.dims
coords = self.coords
with pm.Model(coords=coords) as model:
data_containers: dict[str, pm.Data] = {}
for node in self.nodes:
if node not in self.df.columns:
raise KeyError(f"Column '{node}' not found in df.")
# Ensure observed data has shape consistent with declared dims by pivoting via xarray
indexed = self.df.set_index(list(dims))
xarr = indexed.to_xarray()[node]
values = xarr.values
data_containers[node] = pm.Data(f"_{node}", values, dims=dims)
# For each node add slope priors per parent and likelihood with sigma prior
slope_rvs: dict[tuple[str, str], pt.TensorVariable] = {}
# Create priors in a stable deterministic order
for node in self.nodes:
parents = self._parents(node)
# Slopes for each parent -> node
mu_expr = 0
for parent in parents:
slope_name = f"{parent.lower()}{node.lower()}"
slope_rv = self.model_config["slope"].create_variable(slope_name)
slope_rvs[(parent, node)] = slope_rv
mu_expr += slope_rv * data_containers[parent]
intercept_rv = self.model_config["intercept"].create_variable(
f"{node.lower()}_intercept"
)
self.model_config["likelihood"].create_likelihood_variable(
name=node,
mu=mu_expr + intercept_rv,
observed=data_containers[node],
)
self.model = model
return self.model
def model_graph(self):
"""Return a Graphviz visualization of the built PyMC model.
Returns
-------
graphviz.Source
Graphviz object representing the model graph.
Examples
--------
.. code-block:: python
model = builder.build()
g = builder.model_graph()
g
"""
if not hasattr(self, "model"):
raise RuntimeError("Call build() first.")
return pm.model_to_graphviz(self.model)
def dag_graph(self):
"""Return a copy of the parsed DAG as a NetworkX directed graph.
Returns
-------
networkx.DiGraph
A directed acyclic graph with the same nodes and edges as the input DAG.
Examples
--------
.. code-block:: python
g = builder.dag_graph()
list(g.edges())
"""
if nx is None:
raise ImportError(
"To use Causal Graph functionality, please install the optional dependencies with: "
"pip install pymc-marketing[dag]"
)
g = nx.DiGraph()
g.add_nodes_from(self.graph.nodes)
g.add_edges_from(self.graph.edges)
return g
class TBFPC:
r"""
Target-first Bayes Factor PC (TBF-PC) causal discovery algorithm.
This algorithm is a target-oriented variant of the Peter–Clark (PC) algorithm,
using Bayes factors (via ΔBIC approximation) as the conditional independence test.
For each conditional independence test of the form
.. math::
H_0 : Y \perp X \mid S
\quad \text{vs.} \quad
H_1 : Y \not\!\perp X \mid S
we compare two linear models:
.. math::
M_0 : Y \sim S
\\
M_1 : Y \sim S + X
where :math:`S` is a conditioning set of variables.
The Bayesian Information Criterion (BIC) is defined as
.. math::
\mathrm{BIC}(M) = n \log\!\left(\frac{\mathrm{RSS}}{n}\right)
+ k \log(n),
with residual sum of squares :math:`\mathrm{RSS}`, sample size :math:`n`,
and number of parameters :math:`k`.
The Bayes factor is approximated by
.. math::
\log \mathrm{BF}_{10} \approx -\tfrac{1}{2}
\left[ \mathrm{BIC}(M_1) - \mathrm{BIC}(M_0) \right].
Independence is declared if :math:`\mathrm{BF}_{10} < \tau`,
where :math:`\tau` is set via the ``bf_thresh`` parameter.
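    For intuition, the same statistic can be reproduced with plain NumPy
    (a standalone sketch of the test, not the class internals):

    .. code-block:: python

        import numpy as np


        def bf10(y, X0, X1):
            # Approximate BF10 comparing M1 (adds X) against M0.
            n = len(y)

            def bic(X):
                beta, *_ = np.linalg.lstsq(X, y, rcond=None)
                rss = np.sum((y - X @ beta) ** 2)
                return n * np.log(rss / n) + X.shape[1] * np.log(n)

            return np.exp(-0.5 * (bic(X1) - bic(X0)))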
Target Edge Rules
-----------------
Different rules govern how driver → target edges are retained:
- ``"any"``:
keep :math:`X \to Y` unless **any** conditioning set renders
:math:`X \perp Y \mid S`.
- ``"conservative"``:
keep :math:`X \to Y` if **at least one** conditioning set shows
dependence.
- ``"fullS"``:
test only with the **full set** of other drivers as :math:`S`.
Examples
--------
**1. Basic usage with full conditioning set**
.. code-block:: python
import numpy as np, pandas as pd
rng = np.random.default_rng(7)
n = 2000
C = rng.gamma(2,1,n)
A = 0.7*C + rng.gamma(2,1,n)
D = 0.5*C + rng.gamma(2,1,n)
B = 0.8*A + rng.gamma(2,1,n)
Y = 0.9*B + 0.6*D + 0.7*C + rng.gamma(2,1,n)
df = pd.DataFrame({"A":A,"B":B,"C":C,"D":D,"Y":Y})
df = (df - df.mean())/df.std() # recommended scaling
model = TBFPC(target="Y", target_edge_rule="fullS")
model.fit(df, drivers=["A","B","C","D"])
print(model.get_directed_edges())
print(model.get_undirected_edges())
print(model.to_digraph())
**2. Using forbidden edges**
You can specify edges that must *not* be tested or included
(prior knowledge about the domain).
.. code-block:: python
model = TBFPC(
target="Y",
target_edge_rule="any",
forbidden_edges=[("A","C")] # forbid A--C
)
model.fit(df, drivers=["A","B","C","D"])
print(model.to_digraph())
**3. Conservative rule**
Keeps driver → target edges if **any conditioning set**
shows dependence.
.. code-block:: python
model = TBFPC(target="Y", target_edge_rule="conservative")
model.fit(df, drivers=["A","B","C","D"])
print(model.to_digraph())
References
----------
- Spirtes, Glymour, Scheines (2000). *Causation, Prediction, and Search*. MIT Press. [PC algorithm]
- Spirtes & Glymour (1991). "An Algorithm for Fast Recovery of Sparse Causal Graphs."
- Kass, R. & Raftery, A. (1995). "Bayes Factors."
"""
@validate_call(config=dict(arbitrary_types_allowed=True))
def __init__(
self,
target: Annotated[
str,
Field(
min_length=1,
description="Name of the outcome variable to orient the search.",
),
],
*,
target_edge_rule: Literal["any", "conservative", "fullS"] = "any",
bf_thresh: Annotated[float, Field(gt=0.0)] = 1.0,
forbidden_edges: Sequence[tuple[str, str]] | None = None,
):
"""Create a new TBFPC causal discovery model.
Parameters
----------
target
Variable name for the model outcome; must be present in the data
used during fitting.
target_edge_rule
Rule that controls which driver → target edges are retained.
Options are ``"any"``, ``"conservative"``, and ``"fullS"``.
bf_thresh
Positive Bayes factor threshold applied during conditional
independence tests.
forbidden_edges
Optional sequence of node pairs that must not be connected in the
learned graph.
"""
warnings.warn(
"TBFPC is experimental and its API may change; use with caution.",
UserWarning,
stacklevel=2,
)
self.target = target
self.target_edge_rule = target_edge_rule
self.bf_thresh = float(bf_thresh)
self.forbidden_edges: set[tuple[str, str]] = set(forbidden_edges or [])
# Internal state
self.sep_sets: dict[tuple[str, str], set[str]] = {}
self._adj_directed: set[tuple[str, str]] = set()
self._adj_undirected: set[tuple[str, str]] = set()
self.nodes_: list[str] = []
self.test_results: dict[tuple[str, str, frozenset], dict[str, float]] = {}
# Shared response vector for symbolic BIC computation
# Initialized with placeholder; will be updated with actual data during fitting
self.y_sh = pytensor.shared(np.zeros(1, dtype="float64"), name="y_sh")
self._bic_fn = self._build_symbolic_bic_fn()
def _key(self, u: str, v: str) -> tuple[str, str]:
"""Return a sorted 2-tuple key for an undirected edge between ``u`` and ``v``."""
return (u, v) if u <= v else (v, u)
def _set_sep(self, u: str, v: str, S: Sequence[str]) -> None:
"""Record the separation set ``S`` for the node pair ``(u, v)``."""
self.sep_sets[self._key(u, v)] = set(S)
def _has_forbidden(self, u: str, v: str) -> bool:
"""Return True if edge ``u—v`` is forbidden in either direction."""
return (u, v) in self.forbidden_edges or (v, u) in self.forbidden_edges
def _add_directed(self, u: str, v: str) -> None:
"""Add a directed edge ``u -> v`` if not forbidden; drop undirected if present."""
if not self._has_forbidden(u, v):
self._adj_undirected.discard(self._key(u, v))
self._adj_directed.add((u, v))
def _add_undirected(self, u: str, v: str) -> None:
"""Add an undirected edge ``u -- v`` if allowed and not already directed."""
if (
not self._has_forbidden(u, v)
and (u, v) not in self._adj_directed
and (v, u) not in self._adj_directed
):
self._adj_undirected.add(self._key(u, v))
def _remove_all(self, u: str, v: str) -> None:
"""Remove any edge (directed or undirected) between ``u`` and ``v``."""
self._adj_undirected.discard(self._key(u, v))
self._adj_directed.discard((u, v))
self._adj_directed.discard((v, u))
def _build_symbolic_bic_fn(self):
"""Build a BIC callable using a fast solver with a pseudoinverse fallback."""
X = pt.matrix("X")
n = pt.iscalar("n")
xtx = pt.dot(X.T, X)
xty = pt.dot(X.T, self.y_sh)
beta_solve = pt.linalg.solve(xtx, xty)
resid_solve = self.y_sh - pt.dot(X, beta_solve)
rss_solve = pt.sum(resid_solve**2)
beta_pinv = pt.nlinalg.pinv(X) @ self.y_sh
resid_pinv = self.y_sh - pt.dot(X, beta_pinv)
rss_pinv = pt.sum(resid_pinv**2)
k = X.shape[1]
nf = pt.cast(n, "float64")
rss_solve_safe = pt.maximum(rss_solve, np.finfo("float64").tiny)
rss_pinv_safe = pt.maximum(rss_pinv, np.finfo("float64").tiny)
bic_solve = nf * pt.log(rss_solve_safe / nf) + k * pt.log(nf)
bic_pinv = nf * pt.log(rss_pinv_safe / nf) + k * pt.log(nf)
bic_solve_fn = pytensor.function(
[X, n], [bic_solve, rss_solve], on_unused_input="ignore", mode="FAST_RUN"
)
bic_pinv_fn = pytensor.function(
[X, n], bic_pinv, on_unused_input="ignore", mode="FAST_RUN"
)
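        # Prefer the normal-equations solve; fall back to the pseudoinverse
        # when the solve raises or the residual sum of squares degenerates.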
def bic_fn(X_val: np.ndarray, n_val: int) -> float:
try:
bic_value, rss_value = bic_solve_fn(X_val, n_val)
if np.isfinite(rss_value) and rss_value > np.finfo("float64").tiny:
return float(bic_value)
except (np.linalg.LinAlgError, RuntimeError, ValueError):
pass
return float(bic_pinv_fn(X_val, n_val))
return bic_fn
def _ci_independent(
self, df: pd.DataFrame, x: str, y: str, cond: Sequence[str]
) -> bool:
"""Return True if ΔBIC indicates independence of ``x`` and ``y`` given ``cond``."""
if self._has_forbidden(x, y):
return True
n = len(df)
self.y_sh.set_value(df[y].to_numpy().astype("float64"))
if len(cond) == 0:
X0 = np.ones((n, 1))
else:
X0 = np.column_stack([np.ones(n), df[list(cond)].to_numpy()])
X1 = np.column_stack([X0, df[x].to_numpy()])
bic0 = float(self._bic_fn(X0, n))
bic1 = float(self._bic_fn(X1, n))
delta_bic = bic1 - bic0
logBF10 = -0.5 * delta_bic
BF10 = np.exp(logBF10)
result = {
"bic0": bic0,
"bic1": bic1,
"delta_bic": delta_bic,
"logBF10": logBF10,
"BF10": BF10,
"independent": BF10 < self.bf_thresh,
"conditioning_set": list(cond),
}
self.test_results[(x, y, frozenset(cond))] = result
return result["independent"]
def _test_target_edges(self, df: pd.DataFrame, drivers: Sequence[str]) -> None:
"""Phase 1: test driver→target edges according to ``target_edge_rule``."""
for xi in drivers:
nbrs = [d for d in drivers if d != xi]
max_k = min(3, len(nbrs))
all_sets = [S for k in range(max_k + 1) for S in it.combinations(nbrs, k)]
if self.target_edge_rule == "any":
keep = True
for S in all_sets:
if self._ci_independent(df, xi, self.target, S):
self._set_sep(xi, self.target, S)
keep = False
break
if keep:
self._add_directed(xi, self.target)
else:
self._remove_all(xi, self.target)
elif self.target_edge_rule == "conservative":
indep_all = True
for S in all_sets:
if not self._ci_independent(df, xi, self.target, S):
indep_all = False
else:
self._set_sep(xi, self.target, S)
if indep_all:
self._remove_all(xi, self.target)
else:
self._add_directed(xi, self.target)
elif self.target_edge_rule == "fullS":
S = tuple(nbrs)
if self._ci_independent(df, xi, self.target, S):
self._set_sep(xi, self.target, S)
self._remove_all(xi, self.target)
else:
self._add_directed(xi, self.target)
def _test_driver_skeleton(self, df: pd.DataFrame, drivers: Sequence[str]) -> None:
"""Phase 2: build the undirected driver skeleton via pairwise CI tests."""
for xi, xj in it.combinations(drivers, 2):
others = [d for d in drivers if d not in (xi, xj)]
max_k = min(3, len(others))
dependent = True
sep_rec = False
for k in range(max_k + 1):
for S in it.combinations(others, k):
if self._ci_independent(df, xi, xj, S):
self._set_sep(xi, xj, S)
dependent = False
sep_rec = True
break
if sep_rec:
break
if dependent:
self._add_undirected(xi, xj)
else:
self._remove_all(xi, xj)
def fit(self, df: pd.DataFrame, drivers: Sequence[str]):
"""Fit the TBFPC procedure to the supplied dataframe.
Parameters
----------
df : pandas.DataFrame
Dataset containing the target column and every candidate driver.
drivers : Sequence[str]
Iterable of column names to treat as potential drivers of the
target.
Returns
-------
TBFPC
The fitted instance (``self``) with internal adjacency structures
populated.
Examples
--------
.. code-block:: python
model = TBFPC(target="Y", target_edge_rule="fullS")
model.fit(df, drivers=["A", "B", "C"])
"""
self.sep_sets.clear()
self._adj_directed.clear()
self._adj_undirected.clear()
self.test_results.clear()
self._test_target_edges(df, drivers)
self._test_driver_skeleton(df, drivers)
        self.nodes_ = [*drivers, self.target]
return self
def get_directed_edges(self) -> list[tuple[str, str]]:
"""Return directed edges learned by the algorithm.
Returns
-------
list[tuple[str, str]]
Sorted list of ``(u, v)`` pairs representing oriented edges.
Examples
--------
.. code-block:: python
directed = model.get_directed_edges()
"""
return sorted(self._adj_directed)
def get_undirected_edges(self) -> list[tuple[str, str]]:
"""Return undirected edges remaining after orientation.
Returns
-------
list[tuple[str, str]]
Sorted list of ``(u, v)`` pairs for unresolved adjacencies.
Examples
--------
.. code-block:: python
skeleton = model.get_undirected_edges()
"""
return sorted(self._adj_undirected)
def get_test_results(self, x: str, y: str) -> list[dict[str, float]]:
"""Return ΔBIC diagnostics for the unordered pair ``(x, y)``.
Parameters
----------
x : str
Name of the first variable in the pair.
y : str
Name of the second variable in the pair.
Returns
-------
list[dict[str, float]]
Each dictionary contains ``bic0``, ``bic1``, ``delta_bic``,
``logBF10``, ``BF10``, and the conditioning set used during the
test.
Examples
--------
.. code-block:: python
stats = model.get_test_results("A", "Y")
"""
return [v for (xi, yi, _), v in self.test_results.items() if {xi, yi} == {x, y}]
def summary(self) -> str:
"""Render a text summary of the learned graph and test count.
Returns
-------
str
Multiline string describing directed edges, undirected edges, and
the number of conditional independence tests executed.
Examples
--------
.. code-block:: python
print(model.summary())
"""
lines = ["=== Directed edges ==="]
for u, v in self.get_directed_edges():
lines.append(f"{u} -> {v}")
lines.append("=== Undirected edges ===")
for u, v in self.get_undirected_edges():
lines.append(f"{u} -- {v}")
lines.append("=== Number of CI tests run ===")
lines.append(str(len(self.test_results)))
return "\n".join(lines)
def to_digraph(self) -> str:
"""Return the learned graph encoded in DOT format.
Returns
-------
str
DOT string compatible with Graphviz rendering utilities.
Examples
--------
.. code-block:: python
dot_str = model.to_digraph()
"""
lines = ["digraph G {", " node [shape=ellipse];"]
for n in self.nodes_:
if n == self.target:
lines.append(f' "{n}" [style=filled, fillcolor="#eef5ff"];')
else:
lines.append(f' "{n}";')
for u, v in self.get_directed_edges():
lines.append(f' "{u}" -> "{v}";')
for u, v in self.get_undirected_edges():
lines.append(f' "{u}" -> "{v}" [style=dashed, dir=none];')
lines.append("}")
return "\n".join(lines)
class TBF_FCI:
r"""
    Target-first Bayes Factor temporal causal discovery (TBF-FCI).

    This is a time-series adaptation of TBF-PC. It combines ideas from
temporal FCI/PCMCI with a Bayes-factor ΔBIC conditional independence test.
For each test :math:`X \perp Y \mid S`, compare:
.. math::
M_0 : Y \sim S
\\
M_1 : Y \sim S + X
with BIC scores
.. math::
\mathrm{BIC}(M) = n \log\!\left(\tfrac{\mathrm{RSS}}{n}\right)
+ k \log(n),
and Bayes factor approximation
.. math::
\log \mathrm{BF}_{10} \approx -\tfrac{1}{2}
\left[ \mathrm{BIC}(M_1) - \mathrm{BIC}(M_0) \right].
Declare independence if :math:`\mathrm{BF}_{10} < \tau`.
Parameters
----------
target : str
Name of the target variable (at time t).
target_edge_rule : {"any", "conservative", "fullS"}
Rule for keeping lagged → target edges.
bf_thresh : float, default=1.0
Declare independence if BF10 < bf_thresh.
forbidden_edges : list of tuple[str, str], optional
Prior knowledge: edges to exclude.
max_lag : int, default=2
Maximum lag to include (t-1, t-2, …).
allow_contemporaneous : bool, default=True
Whether to allow contemporaneous edges at time t.
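
    Examples
    --------
    A minimal synthetic run (hypothetical column names; a sketch, not a
    prescribed workflow):

    .. code-block:: python

        import numpy as np
        import pandas as pd

        rng = np.random.default_rng(0)
        n = 500
        A = rng.normal(size=n)
        B = 0.8 * np.roll(A, 1) + rng.normal(size=n)
        Y = 0.9 * np.roll(B, 1) + rng.normal(size=n)
        df = pd.DataFrame({"A": A, "B": B, "Y": Y})

        model = TBF_FCI(target="Y", max_lag=2)
        model.fit(df, drivers=["A", "B"])
        print(model.summary())
        print(model.to_digraph(collapsed=True))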
"""
@validate_call(config=dict(arbitrary_types_allowed=True))
def __init__(
self,
target: Annotated[
str,
Field(
min_length=1,
description="Name of the outcome variable at time t.",
),
],
*,
target_edge_rule: Literal["any", "conservative", "fullS"] = "any",
bf_thresh: Annotated[float, Field(gt=0.0)] = 1.0,
forbidden_edges: Sequence[tuple[str, str]] | None = None,
max_lag: Annotated[int, Field(ge=0)] = 2,
allow_contemporaneous: bool = True,
):
"""Create a new temporal TBF-PC causal discovery model.
Parameters
----------
target
Target variable name at time ``t`` that the algorithm orients
toward.
target_edge_rule
Rule used to retain lagged → target edges. Choose from
``"any"``, ``"conservative"``, or ``"fullS"``.
bf_thresh
Positive Bayes factor threshold applied during conditional
independence testing.
forbidden_edges
Optional sequence of node pairs that must be excluded from the
final graph.
max_lag
Maximum lag (inclusive) to consider when constructing temporal
drivers.
allow_contemporaneous
Whether contemporaneous edges at time ``t`` are permitted.
"""
warnings.warn(
"TBF_FCI is experimental and its API may change; use with caution.",
UserWarning,
stacklevel=2,
)
self.target = target
self.target_edge_rule = target_edge_rule
self.bf_thresh = float(bf_thresh)
self.max_lag = int(max_lag)
self.allow_contemporaneous = allow_contemporaneous
self.forbidden_edges: set[tuple[str, str]] = self._expand_edges(forbidden_edges)
self.sep_sets: dict[tuple[str, str], set[str]] = {}
self._adj_directed: set[tuple[str, str]] = set()
self._adj_undirected: set[tuple[str, str]] = set()
self.nodes_: list[str] = []
self.test_results: dict[tuple[str, str, frozenset], dict[str, float]] = {}
# Shared response vector for symbolic BIC computation
# Initialized with placeholder; will be updated with actual data during fitting
self.y_sh = pytensor.shared(np.zeros(1, dtype="float64"), name="y_sh")
self._bic_fn = self._build_symbolic_bic_fn()
def _lag_name(self, var: str, lag: int) -> str:
"""Return canonical lagged variable name like ``X[t-2]`` or ``X[t]``."""
return f"{var}[t-{lag}]" if lag > 0 else f"{var}[t]"
def _parse_lag(self, name: str) -> tuple[str, int]:
"""Parse a lagged variable name into its base and lag components."""
if "[t-" in name:
base, lagpart = name.split("[t-")
return base, int(lagpart[:-1])
if "[t]" in name:
return name.replace("[t]", ""), 0
return name, 0
def _expand_edges(
self, forbidden_edges: Sequence[tuple[str, str]] | None
) -> set[tuple[str, str]]:
"""Expand collapsed forbidden edge pairs into all lagged variants."""
expanded = set()
if forbidden_edges:
for u, v in forbidden_edges:
if "[t" in u or "[t" in v:
expanded.add((u, v))
else:
for lag_u in range(0, self.max_lag + 1):
for lag_v in range(0, self.max_lag + 1):
u_name = f"{u}[t-{lag_u}]" if lag_u > 0 else f"{u}[t]"
v_name = f"{v}[t-{lag_v}]" if lag_v > 0 else f"{v}[t]"
expanded.add((u_name, v_name))
return expanded
def _build_lagged_df(
self, df: pd.DataFrame, variables: Sequence[str]
) -> pd.DataFrame:
"""Construct a time-unrolled dataframe up to ``max_lag`` for variables."""
frames = {}
for lag in range(0, self.max_lag + 1):
shifted = df[variables].shift(lag)
shifted.columns = [self._lag_name(c, lag) for c in shifted.columns]
frames[lag] = shifted
out = pd.concat(frames.values(), axis=1).iloc[self.max_lag :]
return out.astype("float64")
def _admissible_cond_set(
self, all_vars: Sequence[str], x: str, y: str
) -> list[str]:
"""Return conditioning variables admissible for testing ``x`` and ``y``."""
_, lag_x = self._parse_lag(x)
_, lag_y = self._parse_lag(y)
max_time = min(lag_x, lag_y)
keep = []
for z in all_vars:
if z in (x, y):
continue
_, lag_z = self._parse_lag(z)
if lag_z >= max_time:
keep.append(z)
return keep
def _key(self, u: str, v: str) -> tuple[str, str]:
"""Return sorted tuple key for undirected edges between ``u`` and ``v``."""
return (u, v) if u <= v else (v, u)
def _set_sep(self, u: str, v: str, S: Sequence[str]) -> None:
"""Store separation set ``S`` associated with nodes ``u`` and ``v``."""
self.sep_sets[self._key(u, v)] = set(S)
def _has_forbidden(self, u: str, v: str) -> bool:
"""Return True if the edge between ``u`` and ``v`` is forbidden."""
return (u, v) in self.forbidden_edges or (v, u) in self.forbidden_edges
def _add_directed(self, u: str, v: str) -> None:
"""Insert directed edge ``u -> v`` unless forbidden."""
if not self._has_forbidden(u, v):
self._adj_undirected.discard(self._key(u, v))
self._adj_directed.add((u, v))
def _add_undirected(self, u: str, v: str) -> None:
"""Insert undirected edge ``u -- v`` when no orientation is forced."""
if (
not self._has_forbidden(u, v)
and (u, v) not in self._adj_directed
and (v, u) not in self._adj_directed
):
self._adj_undirected.add(self._key(u, v))
def _remove_all(self, u: str, v: str) -> None:
"""Remove any edge (directed or undirected) between ``u`` and ``v``."""
self._adj_undirected.discard(self._key(u, v))
self._adj_directed.discard((u, v))
self._adj_directed.discard((v, u))
def _build_symbolic_bic_fn(self):
"""Build a BIC callable using a fast solver with fallback pseudoinverse."""
X = pt.matrix("X")
n = pt.iscalar("n")
xtx = pt.dot(X.T, X)
xty = pt.dot(X.T, self.y_sh)
beta_solve = pt.linalg.solve(xtx, xty)
resid_solve = self.y_sh - pt.dot(X, beta_solve)
rss_solve = pt.sum(resid_solve**2)
beta_pinv = pt.nlinalg.pinv(X) @ self.y_sh
resid_pinv = self.y_sh - pt.dot(X, beta_pinv)
rss_pinv = pt.sum(resid_pinv**2)
k = X.shape[1]
bic_solve = n * pt.log(rss_solve / n) + k * pt.log(n)
bic_pinv = n * pt.log(rss_pinv / n) + k * pt.log(n)
bic_solve_fn = pytensor.function(
[X, n], bic_solve, on_unused_input="ignore", mode="FAST_RUN"
)
bic_pinv_fn = pytensor.function(
[X, n], bic_pinv, on_unused_input="ignore", mode="FAST_RUN"
)
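        # As above: try the fast solve first and fall back to the
        # pseudoinverse when it fails or yields a non-finite BIC.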
def bic_fn(X_val: np.ndarray, n_val: int) -> float:
try:
value = float(bic_solve_fn(X_val, n_val))
if np.isfinite(value):
return value
except (np.linalg.LinAlgError, RuntimeError, ValueError):
pass
return float(bic_pinv_fn(X_val, n_val))
return bic_fn
def _ci_independent(
self, df: pd.DataFrame, x: str, y: str, cond: Sequence[str]
) -> bool:
"""Return True if Bayes factor suggests independence of ``x`` and ``y``."""
if self._has_forbidden(x, y):
return True
n = len(df)
self.y_sh.set_value(df[y].to_numpy().astype("float64"))
if len(cond) == 0:
X0 = np.ones((n, 1))
else:
X0 = np.column_stack([np.ones(n), df[list(cond)].to_numpy()])
X1 = np.column_stack([X0, df[x].to_numpy()])
bic0 = float(self._bic_fn(X0, n))
bic1 = float(self._bic_fn(X1, n))
delta_bic = bic1 - bic0
logBF10 = -0.5 * delta_bic
BF10 = np.exp(logBF10)
result = {
"bic0": bic0,
"bic1": bic1,
"delta_bic": delta_bic,
"logBF10": logBF10,
"BF10": BF10,
"independent": BF10 < self.bf_thresh,
"conditioning_set": list(cond),
}
self.test_results[(x, y, frozenset(cond))] = result
return result["independent"]
def _stageA_target_lagged(self, L: pd.DataFrame, drivers: Sequence[str]) -> None:
"""Evaluate lagged driver → target edges according to edge rule."""
y = self._lag_name(self.target, 0)
all_cols = list(L.columns)
for v in drivers:
for lag in range(1, self.max_lag + 1):
x = self._lag_name(v, lag)
cand = self._admissible_cond_set(all_cols, x, y)
max_k = min(3, len(cand))
all_sets = [
S for k in range(max_k + 1) for S in it.combinations(cand, k)
]
if self.target_edge_rule == "fullS":
all_sets = [tuple(cand)]
if self.target_edge_rule == "any":
keep = True
for S in all_sets:
if self._ci_independent(L, x, y, S):
self._set_sep(x, y, S)
keep = False
break
if keep:
self._add_directed(x, y)
else:
self._remove_all(x, y)
elif self.target_edge_rule == "conservative":
indep_all = True
for S in all_sets:
if not self._ci_independent(L, x, y, S):
indep_all = False
else:
self._set_sep(x, y, S)
if indep_all:
self._remove_all(x, y)
else:
self._add_directed(x, y)
elif self.target_edge_rule == "fullS":
S = all_sets[0]
if self._ci_independent(L, x, y, S):
self._set_sep(x, y, S)
self._remove_all(x, y)
else:
self._add_directed(x, y)
def _stageA_driver_lagged(self, L: pd.DataFrame, drivers: Sequence[str]) -> None:
"""Build lagged driver skeleton via conditional independence tests."""
        # Compare base variable names rather than using startswith, which would
        # also drop drivers whose names merely begin with the target's name.
        cols = [c for c in L.columns if self._parse_lag(c)[0] != self.target]
for xi, xj in it.combinations(cols, 2):
_, li = self._parse_lag(xi)
_, lj = self._parse_lag(xj)
if li == 0 and lj == 0:
continue
cand = self._admissible_cond_set(
[*cols, self._lag_name(self.target, 0)], xi, xj
)
max_k = min(3, len(cand))
dependent, found_sep = True, False
for k in range(max_k + 1):
for S in it.combinations(cand, k):
if self._ci_independent(L, xi, xj, S):
self._set_sep(xi, xj, S)
dependent = False
found_sep = True
break
if found_sep:
break
if dependent:
self._add_undirected(xi, xj)
else:
self._remove_all(xi, xj)
def _parents_of(self, node: str) -> list[str]:
"""Return list of parents for ``node`` using directed adjacencies."""
return [u for (u, v) in self._adj_directed if v == node]
def _stageB_contemporaneous(self, L: pd.DataFrame, drivers: Sequence[str]) -> None:
"""Test contemporaneous (time ``t``) relations among variables."""
y_nodes = [self._lag_name(v, 0) for v in [*drivers, self.target]]
for xi, xj in it.combinations(y_nodes, 2):
base_S = list(set(self._parents_of(xi) + self._parents_of(xj)))
cand_extra = [z for z in y_nodes if z not in (xi, xj)]
max_k = 2
dependent, found_sep = True, False
for k in range(max_k + 1):
for extra in it.combinations(cand_extra, k):
S = tuple(sorted(set(base_S).union(extra)))
if self._ci_independent(L, xi, xj, S):
self._set_sep(xi, xj, S)
dependent = False
found_sep = True
break
if found_sep:
break
if dependent:
self._add_undirected(xi, xj)
else:
self._remove_all(xi, xj)
def fit(self, df: pd.DataFrame, drivers: Sequence[str]):
"""Fit the temporal causal discovery algorithm to ``df``.
Parameters
----------
df : pandas.DataFrame
Input dataframe containing the target column and every driver
column.
drivers : Sequence[str]
Iterable of column names to be treated as drivers of the target.
Returns
-------
TBF_FCI
The fitted instance with internal adjacency structures populated.
Examples
--------
.. code-block:: python
model = TBF_FCI(target="Y", max_lag=2)
model.fit(df, drivers=["A", "B"])
"""
self.sep_sets.clear()
self._adj_directed.clear()
self._adj_undirected.clear()
self.test_results.clear()
        all_vars = [*drivers, self.target]
L = self._build_lagged_df(df, all_vars)
self.nodes_ = list(L.columns)
self._stageA_target_lagged(L, drivers)
self._stageA_driver_lagged(L, drivers)
if self.allow_contemporaneous:
self._stageB_contemporaneous(L, drivers)
return self
def collapsed_summary(
self,
) -> tuple[list[tuple[str, str, int]], list[tuple[str, str]]]:
"""Summarize lagged edges into a driver-level view.
Returns
-------
tuple[list[tuple[str, str, int]], list[tuple[str, str]]]
A tuple with directed edges represented as ``(u, v, lag)`` and
contemporaneous undirected edges represented as ``(u, v)`` pairs.
Examples
--------
.. code-block:: python
directed, undirected = model.collapsed_summary()
"""
collapsed_directed: list[tuple[str, str, int]] = []
for u, v in self._adj_directed:
base_u, lag_u = self._parse_lag(u)
base_v, lag_v = self._parse_lag(v)
if lag_v == 0:
collapsed_directed.append((base_u, base_v, lag_u))
collapsed_undirected: list[tuple[str, str]] = []
for u, v in self._adj_undirected:
base_u, lag_u = self._parse_lag(u)
base_v, lag_v = self._parse_lag(v)
if lag_u == lag_v == 0:
collapsed_undirected.append((base_u, base_v))
return collapsed_directed, collapsed_undirected
def get_directed_edges(self) -> list[tuple[str, str]]:
"""Return directed edges in the time-unrolled graph.
Returns
-------
list[tuple[str, str]]
Sorted list of directed edges in the expanded (lagged) graph.
Examples
--------
.. code-block:: python
directed = model.get_directed_edges()
"""
return sorted(self._adj_directed)
def get_undirected_edges(self) -> list[tuple[str, str]]:
"""Return undirected edges in the time-unrolled graph.
Returns
-------
list[tuple[str, str]]
Sorted list of undirected edges among lagged variables.
Examples
--------
.. code-block:: python
undirected = model.get_undirected_edges()
"""
return sorted(self._adj_undirected)
def summary(self) -> str:
"""Return a human-readable summary of edges and test count.
Returns
-------
str
Multiline description of directed edges, undirected edges, and the
number of conditional independence tests executed.
Examples
--------
.. code-block:: python
print(model.summary())
"""
lines = ["=== Directed edges ==="]
for u, v in self.get_directed_edges():
lines.append(f"{u} -> {v}")
lines.append("=== Undirected edges ===")
for u, v in self.get_undirected_edges():
lines.append(f"{u} -- {v}")
lines.append("=== Number of CI tests run ===")
lines.append(str(len(self.test_results)))
return "\n".join(lines)
def to_digraph(self, collapsed: bool = True) -> str:
"""Export the learned graph as DOT text.
Parameters
----------
collapsed : bool, default True
``True`` collapses the time-unrolled graph into driver-level nodes
with lag annotations; ``False`` returns the full lag-expanded
structure.
Returns
-------
str
DOT format string suitable for Graphviz rendering.
Examples
--------
.. code-block:: python
dot_text = model.to_digraph(collapsed=True)
"""
lines = ["digraph G {", " node [shape=ellipse];"]
if not collapsed:
# --- original time-unrolled graph ---
for n in self.nodes_:
if n == self._lag_name(self.target, 0):
lines.append(f' "{n}" [style=filled, fillcolor="#eef5ff"];')
else:
lines.append(f' "{n}";')
for u, v in self.get_directed_edges():
lines.append(f' "{u}" -> "{v}";')
for u, v in self.get_undirected_edges():
lines.append(f' "{u}" -> "{v}" [style=dashed, dir=none];')
else:
directed, undirected = self.collapsed_summary()
base_nodes = {self._parse_lag(n)[0] for n in self.nodes_}
for n in base_nodes:
if n == self.target:
lines.append(f' "{n}" [style=filled, fillcolor="#eef5ff"];')
else:
lines.append(f' "{n}";')
for u, v, lag in directed:
lines.append(f' "{u}" -> "{v}" [label="lag {lag}"];')
for u, v in undirected:
lines.append(f' "{u}" -> "{v}" [style=dashed, dir=none, label="t"];')
lines.append("}")
return "\n".join(lines)
class CausalGraphModel:
"""Represent a causal model based on a Directed Acyclic Graph (DAG).
Provides methods to analyze causal relationships and determine the minimal adjustment set
for backdoor adjustment between treatment and outcome variables.
Parameters
----------
causal_model : CausalModel
An instance of dowhy's CausalModel, representing the causal graph and its relationships.
treatment : list[str]
A list of treatment variable names.
outcome : str
The outcome variable name.
References
----------
.. [1] https://github.com/microsoft/dowhy
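
    Examples
    --------
    A sketch with a hypothetical confounder ``Z``:

    .. code-block:: python

        graph = "digraph { Z -> X; Z -> Y; X -> Y; }"
        model = CausalGraphModel.build_graphical_model(
            graph=graph, treatment=["X"], outcome="Y"
        )
        model.get_unique_adjustment_nodes()  # ["Z"]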
"""
def __init__(
self, causal_model: CausalModel, treatment: list[str] | tuple[str], outcome: str
) -> None:
self.causal_model = causal_model
self.treatment = treatment
self.outcome = outcome
@classmethod
def build_graphical_model(
cls, graph: str, treatment: list[str] | tuple[str], outcome: str
) -> CausalGraphModel:
"""Create a CausalGraphModel from a string representation of a graph.
Parameters
----------
graph : str
A string representation of the graph (e.g., String in DOT format).
treatment : list[str]
A list of treatment variable names.
outcome : str
The outcome variable name.
Returns
-------
CausalGraphModel
An instance of CausalGraphModel constructed from the given graph string.
"""
causal_model = CausalModel(
data=pd.DataFrame(), graph=graph, treatment=treatment, outcome=outcome
)
return cls(causal_model, treatment, outcome)
def get_backdoor_paths(self) -> list[list[str]]:
"""Find all backdoor paths between the combined treatment and outcome variables.
Returns
-------
list[list[str]]
A list of backdoor paths, where each path is represented as a list of variable names.
References
----------
.. [1] Causal Inference in Statistics: A Primer
By Judea Pearl, Madelyn Glymour, Nicholas P. Jewell · 2016
"""
# Use DoWhy's internal method to get backdoor paths for all treatments combined
return self.causal_model._graph.get_backdoor_paths(
nodes1=self.treatment, nodes2=[self.outcome]
)
def get_unique_adjustment_nodes(self) -> list[str]:
"""Compute the minimal adjustment set required for backdoor adjustment across all treatments.
Returns
-------
list[str]
A list of unique adjustment variables needed to block all backdoor paths.
"""
paths = self.get_backdoor_paths()
# Flatten paths and exclude treatments and outcome from adjustment set
adjustment_nodes = set(
node
for path in paths
for node in path
if node not in self.treatment and node != self.outcome
)
return list(adjustment_nodes)
def compute_adjustment_sets(
self,
channel_columns: list[str] | tuple[str],
control_columns: list[str] | None = None,
) -> list[str] | None:
"""Compute minimal adjustment sets and handle warnings."""
channel_columns = list(channel_columns)
if control_columns is None:
return control_columns
self.adjustment_set = self.get_unique_adjustment_nodes()
common_controls = set(control_columns).intersection(self.adjustment_set)
unique_controls = set(control_columns) - set(self.adjustment_set)
if unique_controls:
warnings.warn(
f"Columns {unique_controls} are not in the adjustment set. Controls are being modified.",
stacklevel=2,
)
control_columns = list(common_controls - set(channel_columns))
self.minimal_adjustment_set = control_columns + list(channel_columns)
for column in self.adjustment_set:
if column not in control_columns and column not in channel_columns:
                warnings.warn(
                    f"Column {column} in adjustment set not found in data. "
                    "Not controlling for this may induce bias in treatment effect estimates.",
                    stacklevel=2,
                )
return control_columns