from __future__ import annotations
import sys
import warnings
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, TypedDict, cast
import numpy as np
import pandas as pd
import scipy.sparse as sp
import statsmodels.formula.api as smf
from anndata import AnnData
if sys.version_info >= (3, 11):
from typing import NotRequired
else:
from typing_extensions import NotRequired
from ..adata_tools import subset_primary
from ..design import TrialDesign
from ..utils import wild_cluster_bootstrap_t
from ._utils import apply_fdr, encode_visit, standardize_series
# Minimum number of clusters for reliable cluster-robust standard errors.
# Cameron & Miller (2015) recommend 42+, with 10 as absolute minimum.
# Ref: Cameron, A. Colin, and Douglas L. Miller. "A practitioner’s guide to cluster-robust inference." Journal of human resources 50.2 (2015): 317-372.
# did_fit emits a UserWarning when cov_type == "cluster" and the number of
# unique units in the fitted data is below this threshold.
MIN_CLUSTERS_FOR_ROBUST_SE = 10
[docs]
@dataclass(frozen=True)
class DiDConfig:
    """Bundle of settings for a Difference-in-Differences (DiD) run.

    Pass an instance as ``config=`` to ``did_table`` or
    ``did_table_parallel``. When a config is supplied, each field below
    replaces the keyword argument of the same name in those functions.
    Instances are immutable (``frozen=True``).
    """

    aggregate: AggregateMode = "participant_visit"
    """Aggregation mode: cell-level or one of the participant-level modes."""
    layer: str | None = None
    """Expression layer to read genes from (None uses adata.X)."""
    standardize: bool = True
    """Whether to z-score each outcome before model fitting."""
    agg: AggregateFunc = "mean"
    """Aggregation function applied to features within each group."""
    covariates: list[str] | None = None
    """Optional covariate columns to include as fixed effects."""
    use_bootstrap: bool = False
    """If True, use wild cluster bootstrap for p-values."""
    n_boot: int = 999
    """Number of bootstrap iterations."""
    seed: int = 42
    """Random seed for bootstrap reproducibility."""
    exclude_crossovers: bool = True
    """Whether to drop crossover cells if crossover_col is set."""
def _validate_did_fit_inputs(df: pd.DataFrame, cols: list[str]) -> None:
    """Check that every column required by did_fit exists in ``df``.

    Parameters
    ----------
    df
        Frame about to be handed to the model.
    cols
        Column names that must be present.

    Raises
    ------
    KeyError
        If any name in ``cols`` is absent from ``df.columns``. The message
        lists the absent names in the order they were requested.
    """
    available = df.columns
    absent = []
    for name in cols:
        if name not in available:
            absent.append(name)
    if absent:
        raise KeyError(f"Missing required columns for did_fit: {absent}")
def _standardize_outcome(
    tmp: pd.DataFrame,
    y: str,
    standardize: bool,
    outcome_col: str = "outcome_std",
) -> pd.DataFrame | None:
    """Write the model outcome into ``tmp[outcome_col]``.

    ``tmp`` is modified in place and the same object is returned.

    Parameters
    ----------
    tmp
        Working frame; receives the new ``outcome_col`` column.
    y
        Name of the raw outcome column.
    standardize
        If False, the outcome is just ``tmp[y]`` cast to float. If True, the
        outcome is whatever ``standardize_series`` returns for ``y``
        (called with ``min_std=1e-8``).
    outcome_col
        Name of the column to create.

    Returns
    -------
    DataFrame or None
        ``tmp`` with the outcome column added, or None when
        ``standardize_series`` reports failure through its second return
        value (presumably a near-constant outcome -- confirm in ``_utils``).
    """
    if standardize:
        scaled, usable = standardize_series(tmp, y, min_std=1e-8)
        if not usable:
            return None
        tmp[outcome_col] = scaled
    else:
        tmp[outcome_col] = tmp[y].astype(float)
    return tmp
def _build_did_formula(
    time: str,
    arm_bin: str,
    unit: str,
    covariates: list[str] | None,
    outcome_col: str = "outcome_std",
) -> str:
    """Assemble the formula string for the fixed-effects DiD model.

    The right-hand side is, in order: the time main effect, the
    ``time:arm_bin`` interaction (the DiD term), a categorical fixed effect
    ``C(unit)``, then each covariate name verbatim.

    Parameters
    ----------
    time, arm_bin, unit
        Column names for the time variable, treatment indicator and unit.
    covariates
        Extra column names appended as-is; None or empty adds nothing.
    outcome_col
        Column name used on the left-hand side.

    Returns
    -------
    str
        Formula such as ``"outcome_std ~ t + t:a + C(u) + age"``.
    """
    rhs_terms = [time, f"{time}:{arm_bin}", f"C({unit})"]
    if covariates:
        rhs_terms.extend(covariates)
    return f"{outcome_col} ~ " + " + ".join(rhs_terms)
def _prepare_did_obs(
    adata: AnnData,
    design: TrialDesign,
    visits: tuple[str, str],
    celltype: str | None,
    exclude_crossovers: bool,
) -> tuple[AnnData, pd.DataFrame]:
    """Subset the data for a DiD run and derive the model helper columns.

    Steps, in order: ``subset_primary`` restricts ``adata`` using the design,
    the two visits and the crossover flag; an optional cell-type filter is
    applied; ``encode_visit`` is run on a copy of ``obs``; and a 0/1
    ``arm_bin`` column is added (1 where ``design.arm_col`` equals
    ``design.arm_treated``).

    NOTE: the cell-type filter is applied only when ``celltype`` is not None
    *and* ``design.celltype_col`` is truthy. If ``celltype`` is given but the
    design has no cell-type column, the request is silently ignored.

    Returns
    -------
    tuple[AnnData, DataFrame]
        The subset AnnData and its prepared ``obs`` copy.
    """
    subset = subset_primary(adata, design, visits=visits, exclude_crossovers=exclude_crossovers)

    wants_celltype = celltype is not None and bool(design.celltype_col)
    if wants_celltype:
        is_target = subset.obs[design.celltype_col] == celltype
        subset = subset[is_target].copy()

    obs = encode_visit(subset.obs.copy(), design.visit_col, visits)
    treated = obs[design.arm_col] == design.arm_treated
    obs["arm_bin"] = treated.astype(int)
    return subset, obs
def _add_feature_columns(
df: pd.DataFrame,
ad: AnnData,
features: Sequence[str],
layer: str | None,
) -> tuple[pd.DataFrame, list[str]]:
"""Add feature columns to the DataFrame."""
# features can be genes (in var_names) or obs columns
# Optimization: Extract all genes from X/layer at once if they are in var_names
genes_to_extract = [f for f in features if f in ad.var_names and f not in ad.obs.columns]
feature_data: dict[str, np.ndarray] = {}
final_features: list[str] = []
missing: list[str] = []
if genes_to_extract:
ad_sub = ad[:, genes_to_extract]
X = ad_sub.layers[layer] if layer is not None else ad_sub.X
if sp.issparse(X):
if isinstance(X, sp.coo_matrix):
X = X.tocsr()
X = X.toarray()
else:
X = np.asarray(X)
for i, g in enumerate(genes_to_extract):
feature_data[g] = X[:, i]
for feat in features:
if feat in ad.obs.columns:
feature_data[feat] = ad.obs[feat].values
final_features.append(feat)
elif feat in genes_to_extract:
final_features.append(feat)
else:
missing.append(feat)
if feature_data:
df = pd.concat([df, pd.DataFrame(feature_data, index=df.index)], axis=1)
if missing:
shown = missing[:5]
suffix = f" ({len(missing)} total)" if len(missing) > 5 else ""
raise KeyError(f"Features not found in obs or var_names: {shown}{suffix}")
if not final_features:
raise ValueError("No numeric features found to analyze.")
return df, final_features
def _aggregate_for_did(
    df: pd.DataFrame,
    final_features: list[str],
    design: TrialDesign,
    visits: tuple[str, str],
    aggregate: AggregateMode,
    agg: AggregateFunc,
    covariates: list[str] | None,
) -> tuple[pd.DataFrame, str, str, str]:
    """Collapse cell-level rows to the requested level and keep paired units.

    Parameters
    ----------
    df
        Cell-level frame holding the design columns, ``arm_bin``, the
        features and any covariates. It is not modified.
    final_features
        Feature columns to summarise.
    design
        Supplies the participant, visit, arm and cell-type column names.
    visits
        The two visit labels; units lacking either one are dropped.
    aggregate
        ``"participant_visit"`` or ``"participant_visit_celltype"`` group the
        rows and summarise them; any other value leaves the data at cell
        level.
    agg
        Reducer for features and numeric covariates. ``"mean"`` and
        ``"median"`` are handed to pandas unchanged. ``"pct_pos"`` is computed
        here as the fraction (0-1) of cells in the group whose value is
        strictly greater than zero.
    covariates
        Optional covariate columns. Non-numeric covariates must be constant
        within each group.

    Returns
    -------
    tuple[DataFrame, str, str, str]
        The prepared frame followed by the unit, time (``"visit_num"``) and
        treatment-indicator (``"arm_bin"``) column names.

    Raises
    ------
    ValueError
        If ``participant_visit_celltype`` is requested without a cell-type
        column, or a non-numeric covariate varies within a group.
    """
    unit = design.participant_col
    time = "visit_num"
    arm_bin = "arm_bin"

    if aggregate == "participant_visit":
        grp_cols = [design.participant_col, design.visit_col, design.arm_col]
        df_use = _summarize_groups(
            df, grp_cols, final_features, agg, covariates, "participant-visit"
        )
    elif aggregate == "participant_visit_celltype":
        if design.celltype_col is None:
            raise ValueError("celltype_col is None; cannot use participant_visit_celltype")
        grp_cols = [
            design.participant_col,
            design.visit_col,
            design.arm_col,
            design.celltype_col,
        ]
        df_use = _summarize_groups(
            df, grp_cols, final_features, agg, covariates, "participant-visit-celltype"
        )
    else:
        # cell-level: no summarising, work on a private copy
        df_use = df.copy()

    df_use = _ensure_paired(df_use, unit=unit, time=design.visit_col, visits=visits)
    # Grouping does not carry visit_num through, so it is (re)derived here.
    df_use = encode_visit(df_use, design.visit_col, visits)
    return df_use, unit, time, arm_bin


def _summarize_groups(
    df: pd.DataFrame,
    grp_cols: list[str],
    final_features: list[str],
    agg: str,
    covariates: list[str] | None,
    level_label: str,
) -> pd.DataFrame:
    """Group ``df`` by ``grp_cols`` and reduce features, covariates and counts."""
    # Shallow copy: adding the counter column must not leak into the caller's
    # frame, and a deep copy of a wide expression table would be wasteful.
    frame = df.copy(deep=False)
    frame["n_cells"] = 1

    feature_func = str(agg)
    if agg == "pct_pos":
        # pandas has no "pct_pos" reducer. Binarise once (vectorised) so the
        # group mean equals the fraction of positive cells. NaN counts as
        # not positive.
        feats = list(dict.fromkeys(final_features))
        positive = (frame[feats] > 0).astype(float)
        frame = pd.concat([frame.drop(columns=feats), positive], axis=1)
        feature_func = "mean"

    cov_spec: dict[str, str] = {}
    for c in covariates or []:
        if pd.api.types.is_numeric_dtype(frame[c]):
            # Under pct_pos, covariates are averaged rather than binarised.
            cov_spec[c] = feature_func
        else:
            nunique = frame.groupby(grp_cols, observed=True)[c].nunique()
            if nunique.max() > 1:
                raise ValueError(
                    f"Covariate '{c}' varies within {level_label}; "
                    "use numeric or constant covariates only."
                )
            cov_spec[c] = "first"

    spec = {
        **{f: feature_func for f in final_features},
        **cov_spec,
        "n_cells": "sum",
        "arm_bin": "first",
    }
    return frame.groupby(grp_cols, observed=True).agg(spec).reset_index()
# NOTE: these aliases are defined after DiDConfig and the helper signatures
# that mention them. That is safe only because
# ``from __future__ import annotations`` keeps annotations unevaluated.
AggregateMode = Literal["cell", "participant_visit", "participant_visit_celltype"]
"""Supported aggregation modes for DiD analysis."""
AggregateFunc = Literal["mean", "median", "pct_pos"]
"""Supported aggregation functions for grouped feature summaries."""
[docs]
class DidFitResult(TypedDict):
    """Structured return for did_fit.

    The first six keys are always present. The ``NotRequired`` keys are
    absent when did_fit returns early (fewer than 4 units, or outcome
    standardisation failed); the ``*_boot`` keys are filled only when the
    bootstrap ran.
    """

    # Coefficient, standard error and p-value of the ``time:arm_bin`` term.
    # p_DiD is replaced by the bootstrap p-value when the bootstrap ran.
    beta_DiD: float
    se_DiD: float
    p_DiD: float
    # Coefficient and p-value of the ``time`` main effect.
    beta_time: float
    p_time: float
    # Number of unique units in the rows used.
    n_units: int
    # Square root of the fitted model's ``scale``.
    resid_sd: NotRequired[float]
    # Outputs of wild_cluster_bootstrap_t.
    p_DiD_boot: NotRequired[float]
    se_DiD_boot: NotRequired[float]
    ci_lo_boot: NotRequired[float]
    ci_hi_boot: NotRequired[float]
    # The requested cov_type, or "nonrobust" after the degenerate-SE fallback.
    cov_type_used: NotRequired[str]
def _ensure_paired(df: pd.DataFrame, unit: str, time: str, visits: tuple[str, str]) -> pd.DataFrame:
    """Keep only units that have at least one row at each requested visit.

    Parameters
    ----------
    df
        Long-format frame containing the ``unit`` and ``time`` columns.
    unit
        Column identifying the unit (participant).
    time
        Column holding the visit label.
    visits
        The two visit labels a unit must have to be retained.

    Returns
    -------
    DataFrame
        Copy of the rows of ``df`` whose unit is observed at both visits,
        with original row order and index preserved. The result is empty
        (same columns) when no unit qualifies, including when ``df`` is empty
        or neither visit label occurs at all.
    """
    required = list(dict.fromkeys(visits))
    # Counting distinct required visits per unit avoids pivoting to a wide
    # table, whose columns disappear when a visit label never occurs.
    in_scope = df.loc[df[time].isin(required), [unit, time]].drop_duplicates()
    visits_per_unit = in_scope.groupby(unit, observed=True)[time].nunique()
    keep = visits_per_unit[visits_per_unit == len(required)].index
    return df[df[unit].isin(keep)].copy()
[docs]
def did_fit(
    df: pd.DataFrame,
    y: str,
    unit: str,
    time: str,
    arm_bin: str,
    covariates: list[str] | None = None,
    cov_type: str = "cluster",
    standardize: bool = True,
    use_bootstrap: bool = False,
    n_boot: int = 999,
    seed: int = 42,
) -> DidFitResult:
    """Fit fixed-effects Difference-in-Differences (DiD) model.

    **Mathematical Model**

    The DiD model with participant fixed effects:

    .. math::

        Y_{it} = \\alpha_i + \\beta_1 \\cdot \\text{Post}_t
                 + \\beta_2 \\cdot (\\text{Treat}_i \\times \\text{Post}_t)
                 + \\epsilon_{it}

    where :math:`Y_{it}` is the outcome for participant *i* at time *t*,
    :math:`\\alpha_i` is a participant-specific intercept (fixed effect),
    :math:`\\text{Post}_t` is an indicator for follow-up visit,
    :math:`\\text{Treat}_i` is the treatment arm indicator, and
    :math:`\\beta_2` is the **DiD coefficient** (causal estimand of interest).

    **Null Hypothesis**

    H₀: β₂ = 0 (no differential treatment effect over time).

    The DiD estimator:

    .. math::

        \\hat{\\beta}_2 = (\\bar{Y}_{T,post} - \\bar{Y}_{T,pre})
                        - (\\bar{Y}_{C,post} - \\bar{Y}_{C,pre})

    **Statistical Assumptions**

    - **Parallel trends**: In absence of treatment, both groups would follow
      same trajectory. Cannot be tested directly but can check pre-trends.
    - **No anticipation**: Treatment effect only after treatment starts.
    - **SUTVA**: No spillover between participants.

    **Implementation Notes**

    - Rows with a missing value in any model column are dropped first.
    - Requires at least 4 unique units (participants) to estimate fixed effects.
      Returns NaN for all estimates if n_units < 4.
    - Features with near-zero variance (std < 1e-8) return NaN.
    - Cluster-robust standard errors account for within-participant correlation.
      A ``UserWarning`` is issued when there are fewer than
      ``MIN_CLUSTERS_FOR_ROBUST_SE`` clusters.
    - If the cluster-robust SE or p-value of the DiD term is not finite, the
      model is re-fitted with nonrobust SE and a ``UserWarning`` is issued.
    - If ``n_cells`` column is present, Weighted Least Squares (WLS) is used.
      Weights are proportional to n_cells (inverse-variance weighting for
      participant-level means, since Var(mean) = σ²/n).

    Parameters
    ----------
    df : DataFrame
        Long-format data with columns for unit, time, arm_bin, y, and covariates.
    y : str
        Name of the outcome column.
    unit : str
        Name of the participant/unit column.
    time : str
        Name of the time variable (numeric 0/1).
    arm_bin : str
        Name of the treatment indicator column (0/1).
    covariates : list of str, optional
        Additional covariate columns to include as fixed effects.
    cov_type : str
        Covariance type for standard errors ('cluster' recommended).
    standardize : bool
        If True, z-score the outcome before fitting (recommended for
        interpretable effect sizes).
    use_bootstrap : bool
        If True, use Wild Cluster Bootstrap for p-values (recommended
        when n_participants < 15).
    n_boot : int
        Number of bootstrap iterations (999 or 1999 for publication).
    seed : int
        Random seed for reproducibility.

    Returns
    -------
    DidFitResult
        Dictionary with keys ``beta_DiD``, ``se_DiD``, ``p_DiD``,
        ``beta_time``, ``p_time``, ``n_units``. A completed fit also
        includes ``resid_sd`` and ``cov_type_used``; the NaN early returns
        described above contain only the first six keys.
        When ``use_bootstrap=True`` and the DiD term was estimated, also
        includes ``p_DiD_boot``, ``se_DiD_boot``, ``ci_lo_boot``,
        ``ci_hi_boot``, and ``p_DiD`` is set to the bootstrap p-value.

    Raises
    ------
    ValueError
        If ``df`` is None or empty, or ``n_boot`` < 1.
    KeyError
        If a required column is missing from ``df``.
    """
    if df is None or df.empty:
        raise ValueError("df must be a non-empty DataFrame.")
    if n_boot < 1:
        raise ValueError("n_boot must be >= 1.")

    cols = [unit, time, arm_bin, y]
    if covariates:
        cols.extend(covariates)
    # Include n_cells for WLS weighting if available
    if "n_cells" in df.columns:
        cols.append("n_cells")
    _validate_did_fit_inputs(df, cols)

    # Fresh 0..n-1 index so model row labels can be mapped back to tmp.
    tmp = df[cols].dropna().copy().reset_index(drop=True)
    n_units = tmp[unit].nunique()
    if n_units < 4:
        return _nan_did_result(n_units)

    # time is assumed numeric 0/1 already
    standardized = _standardize_outcome(tmp, y, standardize, outcome_col="outcome_std")
    if standardized is None:
        return _nan_did_result(n_units)
    tmp = standardized

    formula = _build_did_formula(time, arm_bin, unit, covariates, outcome_col="outcome_std")

    # Inverse-variance weighting for pre-aggregated means:
    # Var(mean_i) = sigma^2 / n_i, so weight_i = n_i (proportional to 1/Var)
    weights = None
    if "n_cells" in tmp.columns:
        w = tmp["n_cells"].values.astype(float)
        if np.all(np.isfinite(w)) and np.all(w > 0):
            weights = w
    if weights is not None:
        model = smf.wls(formula, data=tmp, weights=weights)
    else:
        model = smf.ols(formula, data=tmp)

    # Warn if using cluster-robust SE with few clusters (unreliable inference)
    if cov_type == "cluster" and n_units < MIN_CLUSTERS_FOR_ROBUST_SE:
        warnings.warn(
            f"Only {n_units} clusters (participants) available. Cluster-robust standard "
            f"errors are unreliable with fewer than {MIN_CLUSTERS_FOR_ROBUST_SE} clusters. "
            f"Consider using use_bootstrap=True for more reliable p-values.",
            UserWarning,
            stacklevel=2,
        )

    fit = model.fit(
        cov_type=cov_type, cov_kwds={"groups": tmp[unit]} if cov_type == "cluster" else None
    )
    term = f"{time}:{arm_bin}"
    se_did = float(fit.bse.get(term, np.nan))
    p_did = float(fit.pvalues.get(term, np.nan))

    # Fallback: if cluster-robust SE is NaN (degenerate, e.g. only 2 obs per
    # cluster after participant_visit aggregation), re-fit with nonrobust SE.
    # The participant FE already absorbs within-participant correlation, so
    # homoskedastic SE is a valid (though conservative) alternative.
    #
    # Justification for nonrobust bootstrap in fallback mode:
    # When cluster-robust SE is degenerate (typically because each cluster has
    # only 2 observations after participant_visit aggregation), the wild cluster
    # bootstrap with nonrobust covariance is methodologically valid because:
    #   (a) participant fixed effects already absorb within-cluster correlation,
    #   (b) with only 2 obs per cluster, heteroskedasticity cannot be estimated
    #       within clusters anyway, and
    #   (c) the Rademacher sign-flip at the cluster level still provides valid
    #       finite-sample inference (Webb 2023, J. Econometrics).
    effective_cov_type = cov_type
    if cov_type == "cluster" and (not np.isfinite(se_did) or not np.isfinite(p_did)):
        warnings.warn(
            f"Cluster-robust SE is degenerate (NaN) for feature '{y}' with "
            f"{int(n_units)} clusters. Falling back to nonrobust "
            f"(homoskedastic) SE. This typically occurs when each cluster has "
            f"very few observations (e.g. participant_visit aggregation with "
            f"2 visits). Participant fixed effects still absorb within-cluster "
            f"correlation. Set use_bootstrap=True for more reliable p-values.",
            UserWarning,
            stacklevel=2,
        )
        fit = model.fit()  # nonrobust
        se_did = float(fit.bse.get(term, np.nan))
        p_did = float(fit.pvalues.get(term, np.nan))
        effective_cov_type = "nonrobust"

    # Report the effective participant count after any model row drops
    # (statsmodels may drop rows during formula parsing, e.g. singular FE levels)
    model_row_idx = fit.model.data.row_labels
    units_in_model = tmp[unit].loc[model_row_idx]
    n_units_eff = int(units_in_model.nunique())

    res: dict[str, float | int | str] = {
        "beta_DiD": float(fit.params.get(term, np.nan)),
        "se_DiD": se_did,
        "p_DiD": p_did,
        "beta_time": float(fit.params.get(time, np.nan)),
        "p_time": float(fit.pvalues.get(time, np.nan)),
        "n_units": n_units_eff,
        "resid_sd": float(np.sqrt(fit.scale)),
        "cov_type_used": effective_cov_type,
    }

    if use_bootstrap and term in fit.params:
        # Use model_row_idx to align clusters with the actual rows used in the fit
        boot_res = wild_cluster_bootstrap_t(
            fit,
            X=fit.model.exog,
            clusters=np.asarray(units_in_model.to_numpy()),
            term_name=term,
            B=n_boot,
            seed=seed,
            cov_type=effective_cov_type,
        )
        res["p_DiD_boot"] = boot_res.p_boot
        res["se_DiD_boot"] = boot_res.se_boot
        res["ci_lo_boot"] = boot_res.ci_lo
        res["ci_hi_boot"] = boot_res.ci_hi
        # Use bootstrap p-value as primary if requested
        res["p_DiD"] = boot_res.p_boot

    return cast(DidFitResult, res)


def _nan_did_result(n_units: int) -> DidFitResult:
    """Return the all-NaN result used when a feature cannot be fitted."""
    return {
        "beta_DiD": np.nan,
        "se_DiD": np.nan,
        "p_DiD": np.nan,
        "beta_time": np.nan,
        "p_time": np.nan,
        "n_units": n_units,
    }
[docs]
def did_table(
    adata: AnnData,
    features: Sequence[str],
    design: TrialDesign,
    visits: tuple[str, str],
    exclude_crossovers: bool = True,
    celltype: str | None = None,
    aggregate: AggregateMode = "participant_visit",
    layer: str | None = None,
    standardize: bool = True,
    agg: AggregateFunc = "mean",
    covariates: list[str] | None = None,
    use_bootstrap: bool = False,
    n_boot: int = 999,
    seed: int = 42,
    config: DiDConfig | None = None,
) -> pd.DataFrame:
    """Run Difference-in-Differences (DiD) for a list of features.

    A fixed-effects DiD model is fitted once per feature to test for
    treatment-induced longitudinal change. The function targets feature
    'panels' (tens to hundreds), such as module scores or selected gene sets.

    Statistical Model:
        y ~ visit + visit:arm + covariates + C(participant)

    The coefficient of the 'visit:arm' interaction (beta_DiD) is the primary
    estimand.

    Parameters
    ----------
    adata
        AnnData object containing expression data and metadata.
    features
        Features to test: gene names in `adata.var_names` or
        observation-level scores in `adata.obs.columns`.
    design
        A `TrialDesign` object specifying the metadata columns.
    visits
        A tuple of (baseline, followup) visit labels.
    exclude_crossovers
        If True, excludes observations where `design.crossover_col` is True.
        Recommended for primary randomized analysis.
    celltype
        If provided, subsets the analysis to a specific cell type.
    aggregate
        Aggregation mode:

        - 'cell': Fit model on individual cells (not recommended for p-values,
          as it treats cells as independent).
        - 'participant_visit': Average features per participant-visit before
          fitting. This is the recommended approach for clinical inference.
        - 'participant_visit_celltype': Average per participant-visit-celltype.
    layer
        Layer to extract gene expression from. If None, uses `adata.X`.
    standardize
        If True, z-scores the outcome variable before fitting to provide
        standardized effect sizes.
    agg
        Aggregation function: 'mean', 'median', or 'pct_pos'.
    covariates
        Additional columns in `adata.obs` to include as fixed effects
        in the model (e.g., ['age', 'sex', 'batch']).
        Covariates must be **numeric** or **constant within each
        participant-visit** group. Non-numeric covariates are aggregated with
        "first" and will raise an error if they vary within a
        participant-visit (or participant-visit-celltype).
    use_bootstrap
        If True, uses Wild Cluster Bootstrap to calculate p-values.
        Recommended for small sample sizes (e.g. < 15 participants per group).
    n_boot
        Number of bootstrap iterations.
    seed
        Random seed for bootstrap. The same seed is used for every feature.
    config
        Optional DiDConfig object. If provided, its values override the
        corresponding keyword arguments (aggregate, layer, standardize, agg,
        covariates, use_bootstrap, n_boot, seed, and exclude_crossovers).

    Returns
    -------
    pd.DataFrame
        One row per feature, sorted by ``p_DiD``, holding the did_fit
        outputs, a ``feature`` column and the ``FDR_DiD`` column added by
        ``apply_fdr``.

    Raises
    ------
    KeyError
        If a covariate is not a column of ``adata.obs``.

    Examples
    --------
    >>> res = did_table(adata, features=["ms_OXPHOS"], design=design, visits=("V1", "V2"))
    >>> print(res[["feature", "beta_DiD", "p_DiD"]])
    """
    # A config object, when given, wins over the individual keyword arguments.
    if config is not None:
        aggregate, layer = config.aggregate, config.layer
        standardize, agg = config.standardize, config.agg
        covariates = config.covariates
        use_bootstrap, n_boot, seed = config.use_bootstrap, config.n_boot, config.seed
        exclude_crossovers = config.exclude_crossovers

    ad, obs = _prepare_did_obs(adata, design, visits, celltype, exclude_crossovers)

    # Carry every column that any aggregation mode might group on.
    keep_cols = [design.participant_col, design.visit_col, design.arm_col, "visit_num", "arm_bin"]
    if design.celltype_col and design.celltype_col in obs.columns:
        keep_cols.append(design.celltype_col)
    for cov in covariates or []:
        if cov not in obs.columns:
            raise KeyError(f"Covariate '{cov}' not found in adata.obs")
        keep_cols.append(cov)

    frame, final_features = _add_feature_columns(obs[keep_cols].copy(), ad, features, layer)
    df_use, unit, time, arm_bin = _aggregate_for_did(
        frame, final_features, design, visits, aggregate, agg, covariates
    )

    fit_kwargs = dict(
        unit=unit,
        time=time,
        arm_bin=arm_bin,
        covariates=covariates,
        standardize=standardize,
        use_bootstrap=use_bootstrap,
        n_boot=n_boot,
        seed=seed,
    )
    rows = [
        {**did_fit(df_use, y=feat, **fit_kwargs), "feature": feat}
        for feat in final_features
    ]

    ranked = pd.DataFrame(rows).sort_values("p_DiD")
    ranked = apply_fdr(ranked, p_col="p_DiD", fdr_col="FDR_DiD")
    return ranked.reset_index(drop=True)
[docs]
def get_did_aggregated_df(
    adata: AnnData,
    features: Sequence[str],
    design: TrialDesign,
    visits: tuple[str, str],
    *,
    layer: str | None = None,
    exclude_crossovers: bool = True,
    celltype: str | None = None,
    aggregate: AggregateMode = "participant_visit",
    agg: AggregateFunc = "mean",
    covariates: list[str] | None = None,
) -> tuple[pd.DataFrame, str, str, str]:
    """Build aggregated DataFrame for DiD, for use in permutation tests.

    Returns (df_use, unit, time, arm_bin). With the default
    ``aggregate="participant_visit"``, df_use has one row per
    participant-visit with feature values and arm_bin; the other modes give
    one row per participant-visit-celltype or per cell. Permute arm_bin
    at participant level and call did_fit for each permutation.

    The data is prepared in the same way as in `did_table`.

    Parameters
    ----------
    adata
        AnnData object.
    features
        List of genes or module scores.
    design
        A `TrialDesign` object.
    visits
        Tuple of (pre, post) visit labels.
    layer
        Layer to use for gene expression.
    exclude_crossovers
        Whether to drop crossover cells if crossover_col is set.
    celltype
        The cell type to analyze.
    aggregate
        Aggregation mode (see `did_table`).
    agg
        Aggregation function.
    covariates
        Optional covariate columns to include as fixed effects. Non-numeric
        covariates must be constant within participant-visit.

    Returns
    -------
    tuple[pd.DataFrame, str, str, str]
        Tuple containing the aggregated DataFrame, the unit column name, the
        time column name, and the arm_bin column name.

    Raises
    ------
    KeyError
        If a covariate is not a column of ``adata.obs``.
    """
    ad, obs = _prepare_did_obs(adata, design, visits, celltype, exclude_crossovers)

    cols = [design.participant_col, design.visit_col, design.arm_col, "visit_num", "arm_bin"]
    # The cell-type column must travel with the frame, otherwise grouping for
    # aggregate="participant_visit_celltype" fails on a missing column.
    if design.celltype_col and design.celltype_col in obs.columns:
        cols.append(design.celltype_col)
    for c in covariates or []:
        if c not in obs.columns:
            raise KeyError(f"Covariate '{c}' not found in adata.obs")
        cols.append(c)

    df = obs[cols].copy()
    df, final_features = _add_feature_columns(df, ad, features, layer)
    df_use, unit, time, arm_bin = _aggregate_for_did(
        df, final_features, design, visits, aggregate, agg, covariates
    )
    return df_use, unit, time, arm_bin
[docs]
def did_table_parallel(
    adata: AnnData,
    features: Sequence[str],
    design: TrialDesign,
    visits: tuple[str, str],
    exclude_crossovers: bool = True,
    celltype: str | None = None,
    aggregate: AggregateMode = "participant_visit",
    layer: str | None = None,
    standardize: bool = True,
    agg: AggregateFunc = "mean",
    covariates: list[str] | None = None,
    use_bootstrap: bool = False,
    n_boot: int = 999,
    seed: int = 42,
    n_jobs: int = -1,
    backend: str = "loky",
    batch_size: int | None = None,
    config: DiDConfig | None = None,
) -> pd.DataFrame:
    """Parallelized DiD across features using joblib.

    This mirrors `did_table` but runs the per-feature model fits in parallel.
    It is most useful for large feature panels (hundreds to thousands of
    genes). Data preparation and aggregation happen once, before the fits
    are dispatched.

    Seeding differs between the two execution paths: with ``n_jobs == 1`` the
    call is forwarded to `did_table`, which passes ``seed`` unchanged to every
    feature; otherwise each feature receives its own seed spawned from
    ``np.random.SeedSequence(seed)``.

    Parameters
    ----------
    adata, features, design, visits, exclude_crossovers, celltype, aggregate,
    layer, standardize, agg, covariates, use_bootstrap, n_boot, config
        Same meaning as in `did_table`.
    seed
        Root seed for the bootstrap (see the seeding note above).
    n_jobs
        Number of parallel jobs (joblib convention). Use -1 for all cores.
    backend
        Joblib backend ("loky", "multiprocessing", or "threading").
    batch_size
        Optional joblib batch size. If None, joblib chooses.

    Returns
    -------
    pd.DataFrame
        One row per feature, sorted by ``p_DiD``, with an ``FDR_DiD`` column.

    Raises
    ------
    ImportError
        If joblib is not installed and ``n_jobs != 1``.
    KeyError
        If a covariate is not a column of ``adata.obs``.
    """
    if n_jobs == 1:
        # Serial request: hand everything to did_table untouched.
        return did_table(
            adata=adata,
            features=features,
            design=design,
            visits=visits,
            exclude_crossovers=exclude_crossovers,
            celltype=celltype,
            aggregate=aggregate,
            layer=layer,
            standardize=standardize,
            agg=agg,
            covariates=covariates,
            use_bootstrap=use_bootstrap,
            n_boot=n_boot,
            seed=seed,
            config=config,
        )

    try:
        from joblib import Parallel, delayed
    except ImportError as exc:  # pragma: no cover
        raise ImportError("joblib is required for did_table_parallel") from exc

    # A config object, when given, wins over the individual keyword arguments.
    if config is not None:
        aggregate, layer = config.aggregate, config.layer
        standardize, agg = config.standardize, config.agg
        covariates = config.covariates
        use_bootstrap, n_boot, seed = config.use_bootstrap, config.n_boot, config.seed
        exclude_crossovers = config.exclude_crossovers

    ad, obs = _prepare_did_obs(adata, design, visits, celltype, exclude_crossovers)

    keep_cols = [design.participant_col, design.visit_col, design.arm_col, "visit_num", "arm_bin"]
    if design.celltype_col and design.celltype_col in obs.columns:
        keep_cols.append(design.celltype_col)
    for cov in covariates or []:
        if cov not in obs.columns:
            raise KeyError(f"Covariate '{cov}' not found in adata.obs")
        keep_cols.append(cov)

    frame, final_features = _add_feature_columns(obs[keep_cols].copy(), ad, features, layer)
    df_use, unit, time, arm_bin = _aggregate_for_did(
        frame, final_features, design, visits, aggregate, agg, covariates
    )

    # Generate independent seeds for each parallel feature using SeedSequence
    children = np.random.SeedSequence(seed).spawn(len(final_features))
    feature_seeds = [int(child.generate_state(1)[0]) for child in children]

    def _fit_one(feat: str, feat_seed: int) -> dict:
        """Fit a single feature with did_fit and tag the row with its name."""
        fitted = did_fit(
            df_use,
            y=feat,
            unit=unit,
            time=time,
            arm_bin=arm_bin,
            covariates=covariates,
            standardize=standardize,
            use_bootstrap=use_bootstrap,
            n_boot=n_boot,
            seed=feat_seed,
        )
        return {**fitted, "feature": feat}

    job_batch = batch_size if batch_size is not None else "auto"
    runner = Parallel(n_jobs=n_jobs, backend=backend, batch_size=job_batch)
    rows = runner(
        delayed(_fit_one)(feat, feat_seed)
        for feat, feat_seed in zip(final_features, feature_seeds)
    )

    ranked = pd.DataFrame(rows).sort_values("p_DiD")
    ranked = apply_fdr(ranked, p_col="p_DiD", fdr_col="FDR_DiD")
    return ranked.reset_index(drop=True)
[docs]
def did_table_by_celltype(
    adata: AnnData,
    features: Sequence[str],
    design: TrialDesign,
    visits: tuple[str, str],
    celltypes: Sequence[str] | None = None,
    exclude_crossovers: bool = True,
    aggregate: AggregateMode = "participant_visit",
    layer: str | None = None,
    standardize: bool = True,
    agg: AggregateFunc = "mean",
    covariates: list[str] | None = None,
    use_bootstrap: bool = False,
    n_boot: int = 999,
    seed: int = 42,
) -> pd.DataFrame:
    """Run `did_table` stratified by cell type.

    `did_table` is called once per cell type. A cell type whose run raises
    ``ValueError``, ``numpy.linalg.LinAlgError`` or ``KeyError`` is skipped
    with a warning instead of aborting the whole analysis.

    Parameters
    ----------
    adata
        AnnData object.
    features
        Genes or module scores to test.
    design
        A `TrialDesign` object; ``design.celltype_col`` must be set.
    visits
        Two visit labels for comparison.
    celltypes
        Subset of cell types to analyze. If None, uses the sorted non-null
        unique values of ``adata.obs[design.celltype_col]``.
    exclude_crossovers
        Exclude cells marked as crossovers.
    aggregate
        Level of aggregation.
    layer
        Expression layer.
    standardize
        Standardize expression to unit variance.
    agg
        Aggregation function.
    covariates
        Obs columns to include as covariates.
    use_bootstrap
        Use Wild Cluster Bootstrap for p-values.
    n_boot
        Number of bootstrap iterations.
    seed
        Random seed.

    Returns
    -------
    pd.DataFrame
        Per-cell-type `did_table` results stacked together, with a
        ``celltype`` column and an ``FDR_DiD_stratified`` column computed over
        all rows. An empty DataFrame is returned when no cell type succeeded.

    Raises
    ------
    ValueError
        If ``design.celltype_col`` is None.
    """
    if design.celltype_col is None:
        raise ValueError("design.celltype_col must be set for stratified analysis.")

    targets = (
        celltypes
        if celltypes is not None
        else sorted(adata.obs[design.celltype_col].dropna().unique())
    )

    per_celltype = []
    for ct in targets:
        try:
            table = did_table(
                adata,
                features=features,
                design=design,
                visits=visits,
                celltype=ct,
                exclude_crossovers=exclude_crossovers,
                aggregate=aggregate,
                layer=layer,
                standardize=standardize,
                agg=agg,
                covariates=covariates,
                use_bootstrap=use_bootstrap,
                n_boot=n_boot,
                seed=seed,
            )
            table["celltype"] = ct
            per_celltype.append(table)
        except (ValueError, np.linalg.LinAlgError, KeyError) as err:
            # Common to fail if celltype has too few cells/participants
            warnings.warn(f"Failed DiD for celltype '{ct}': {err}")

    if not per_celltype:
        return pd.DataFrame()

    combined = pd.concat(per_celltype, ignore_index=True)
    # Recalculate FDR across all tests
    return apply_fdr(combined, p_col="p_DiD", fdr_col="FDR_DiD_stratified")