from __future__ import annotations
import warnings
from collections.abc import Sequence
from typing import Literal
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from anndata import AnnData
from ..design import TrialDesign
from ._utils import apply_fdr, encode_visit, standardize_series
from .did import (
MIN_CLUSTERS_FOR_ROBUST_SE,
AggregateFunc,
AggregateMode,
_add_feature_columns,
_prepare_did_obs,
)
def _dichotomize_biomarker(
    biomarker: pd.Series,
    biomarker_col: str,
    threshold: float | None,
) -> tuple[pd.Series, float | None]:
    """Turn a raw biomarker column into a float 0/1 "biomarker high" indicator.

    Missing (or non-numeric) biomarker values stay ``NaN`` so that the affected
    rows are excluded downstream instead of being silently counted as "low".

    Parameters
    ----------
    biomarker
        Raw biomarker values, indexed like the prepared obs table.
    biomarker_col
        Column name; used only in the warning message.
    threshold
        Cut-off for numeric biomarkers. ``None`` means the median of the
        non-missing values.

    Returns
    -------
    tuple[pd.Series, float | None]
        The indicator (same index as ``biomarker``; 1.0 = high, 0.0 = low,
        NaN = missing) and the threshold that was applied. For binary
        biomarkers the values are used directly and ``threshold`` is returned
        unchanged.
    """
    is_binary = pd.api.types.is_bool_dtype(biomarker) or set(
        pd.unique(biomarker.dropna())
    ) <= {0, 1}
    if is_binary:
        # 0/1 or True/False: compare rather than cast to int, because an int
        # cast raises on NaN/NA. Missing entries are kept as NaN.
        is_high = biomarker.eq(1).fillna(False).astype(bool)
        return is_high.astype(float).where(biomarker.notna()), threshold

    numeric = pd.to_numeric(biomarker, errors="coerce")
    # Warn only about values that were actually coerced, not pre-existing gaps.
    if (numeric.isna() & biomarker.notna()).any():
        warnings.warn(
            f"biomarker_col '{biomarker_col}' contained non-numeric values that were coerced to NaN.",
            UserWarning,
            stacklevel=3,  # helper -> test_treatment_heterogeneity -> user code
        )
    # Plain float64 with NaN, so nullable integer dtypes (pd.NA) are handled too.
    values = pd.Series(numeric.to_numpy(dtype=float, na_value=np.nan), index=biomarker.index)
    if threshold is None:
        threshold = float(np.nanmedian(values))
    is_high = values > threshold
    # NaN > threshold is False; mask it back to NaN so it is not labelled "low".
    return is_high.astype(float).where(values.notna()), threshold


def _aggregate_participant_visit(
    obs_feat: pd.DataFrame,
    grp_cols: list[str],
    feature_cols: list[str],
    covariate_cols: list[str],
    agg: AggregateFunc,
) -> pd.DataFrame:
    """Collapse rows to one per group of ``grp_cols`` (participant, visit, arm).

    Feature columns are aggregated with ``agg``. The ``biomarker_high`` and
    ``arm_bin`` indicators are averaged. Covariate columns are averaged when
    numeric, otherwise their first value in the group is kept.
    """
    grouped = obs_feat.groupby(grp_cols, observed=True)
    parts = []
    if feature_cols:
        if agg == "mean":
            # Default path: numeric-only mean over the feature columns.
            parts.append(grouped[feature_cols].mean(numeric_only=True))
        else:
            parts.append(grouped[feature_cols].agg(agg))
    parts.append(grouped[["biomarker_high", "arm_bin"]].mean())
    if covariate_cols:
        how = {
            col: "mean" if pd.api.types.is_numeric_dtype(obs_feat[col]) else "first"
            for col in covariate_cols
        }
        parts.append(grouped[covariate_cols].agg(how))
    # All parts come from the same groupby, so their group indices line up.
    return pd.concat(parts, axis=1).reset_index()


def test_treatment_heterogeneity(
    adata: AnnData,
    features: Sequence[str],
    design: TrialDesign,
    visits: tuple[str, str],
    biomarker_col: str,
    *,
    threshold: float | None = None,
    aggregate: AggregateMode = "participant_visit",
    layer: str | None = None,
    standardize: bool = True,
    agg: AggregateFunc = "mean",
    covariates: list[str] | None = None,
    fdr: Literal["feature", "all"] = "feature",
) -> pd.DataFrame:
    """Test treatment-effect heterogeneity via a visit x arm x biomarker interaction.

    Observations are aggregated to one row per participant and visit. For each
    feature, an OLS model with participant fixed effects is fitted:

        outcome ~ visit + arm + biomarker_high + visit:arm + visit:biomarker_high
                  + arm:biomarker_high + visit:arm:biomarker_high + C(participant)
                  [+ covariates]

    Here "visit" is the numeric visit code added by ``encode_visit``. "arm" is 1
    for ``design.arm_treated`` and 0 otherwise. The coefficient of the three-way
    interaction (``visit_num:arm_bin:biomarker_high``) tests whether the
    difference-in-differences effect differs between biomarker-high and
    biomarker-low participants. Standard errors are clustered by participant.

    Parameters
    ----------
    adata
        AnnData object.
    features
        Features to test (genes or obs columns).
    design
        TrialDesign object.
    visits
        (baseline, followup) visit labels.
    biomarker_col
        Column in ``adata.obs`` used for stratification.
    threshold
        If provided, biomarker_high = biomarker > threshold. If None, the
        median of the non-missing biomarker values is used. The median is
        computed over the prepared obs rows, before aggregation. Binary
        biomarkers (0/1 or True/False) are used directly, and ``threshold``
        is then ignored. Missing biomarker values are excluded, not treated
        as "low". If the biomarker differs between the rows of one
        participant-visit, the averaged indicator can be fractional.
    aggregate
        Aggregation mode. Only ``"participant_visit"`` is supported.
    layer
        Expression layer for genes.
    standardize
        Whether to z-score outcomes before fitting.
    agg
        Aggregation applied to feature values within each participant-visit
        group. The default ``"mean"`` is a numeric-only mean. The biomarker and
        arm indicators are always averaged.
    covariates
        Additional formula terms appended to the model. Terms that are exact
        column names of the prepared obs / ``adata.obs`` are carried through the
        aggregation (numeric: mean, otherwise first value). Rows with missing
        values in those columns are dropped. Other expressions can only
        reference columns that exist after aggregation.
    fdr
        FDR correction mode. Each feature contributes a single heterogeneity
        p-value, and both ``"feature"`` and ``"all"`` call ``apply_fdr`` on the
        full result table. Any other value skips the correction.

    Returns
    -------
    pd.DataFrame
        One row per fitted feature, with columns ``feature``,
        ``beta_heterogeneity``, ``p_heterogeneity``, ``n_units`` (participants
        used in the fit) and ``threshold``, plus ``FDR_heterogeneity`` when FDR
        is applied. The table is empty if no feature could be fitted.

    Raises
    ------
    ValueError
        If ``aggregate`` is not ``"participant_visit"``.
    KeyError
        If ``biomarker_col`` is not a column of ``adata.obs``.

    Warns
    -----
    UserWarning
        When non-numeric biomarker values are coerced to NaN, or when a feature
        is skipped because it has fewer than ``MIN_CLUSTERS_FOR_ROBUST_SE``
        participants. Features whose outcome cannot be standardized, or whose
        biomarker indicator takes a single value, are skipped silently.

    Examples
    --------
    >>> res = test_treatment_heterogeneity(
    ...     adata,
    ...     features=["sig_IFN_Response", "sig_Cytotoxicity"],
    ...     design=design,
    ...     visits=("Pre", "Post"),
    ...     biomarker_col="baseline_crp",
    ...     threshold=5.0,
    ... )
    >>> res[["feature", "beta_heterogeneity", "p_heterogeneity"]].head()
    """
    # Validate cheap arguments before any expensive data preparation.
    if aggregate != "participant_visit":
        raise ValueError("test_treatment_heterogeneity supports participant_visit only.")
    if biomarker_col not in adata.obs.columns:
        raise KeyError(f"biomarker_col '{biomarker_col}' not found in adata.obs")

    ad, obs = _prepare_did_obs(adata, design, visits, celltype=None, exclude_crossovers=True)
    obs["arm_bin"] = (obs[design.arm_col] == design.arm_treated).astype(int)
    biomarker_high, threshold = _dichotomize_biomarker(
        adata.obs.loc[obs.index, biomarker_col], biomarker_col, threshold
    )
    obs["biomarker_high"] = biomarker_high.values

    unit = design.participant_col
    time = "visit_num"
    arm_bin = "arm_bin"
    grp_cols = [unit, design.visit_col, design.arm_col]

    # Covariates naming real columns must survive the aggregation, otherwise
    # the formula cannot find them in the aggregated frame.
    covariate_terms = list(covariates or [])
    reserved = {*grp_cols, "biomarker_high", arm_bin, time}
    covariate_cols = [
        col
        for col in dict.fromkeys(covariate_terms)
        if col not in reserved and (col in obs.columns or col in adata.obs.columns)
    ]
    for col in covariate_cols:
        if col not in obs.columns:
            obs[col] = adata.obs.loc[obs.index, col].values

    # Add feature columns (genes or obs)
    obs_feat, final_features = _add_feature_columns(obs.copy(), ad, list(features), layer)
    feature_cols = list(dict.fromkeys(final_features))
    df_use = _aggregate_participant_visit(
        obs_feat,
        grp_cols,
        feature_cols,
        [c for c in covariate_cols if c in obs_feat.columns and c not in feature_cols],
        agg,
    )
    df_use = encode_visit(df_use, design.visit_col, visits)
    # Keep the first of any same-named columns so df[name] is always a Series.
    df_use = df_use.loc[:, ~df_use.columns.duplicated()]
    covariate_cols = [c for c in covariate_cols if c in df_use.columns]

    # The model is identical for every feature; build it once.
    formula = (
        f"outcome_std ~ {time} + {arm_bin} + biomarker_high + "
        f"{time}:{arm_bin} + {time}:biomarker_high + {arm_bin}:biomarker_high + "
        f"{time}:{arm_bin}:biomarker_high + C({unit})"
    )
    if covariate_terms:
        formula += " + " + " + ".join(covariate_terms)
    term = f"{time}:{arm_bin}:biomarker_high"

    rows = []
    for feat in final_features:
        tmp = df_use.dropna(subset=[feat, "biomarker_high", *covariate_cols]).copy()
        n_clusters = tmp[unit].nunique()
        if n_clusters < MIN_CLUSTERS_FOR_ROBUST_SE:
            warnings.warn(
                f"Feature '{feat}' skipped: only {n_clusters} clusters "
                f"(need >= {MIN_CLUSTERS_FOR_ROBUST_SE} for robust SE).",
                stacklevel=2,
            )
            continue
        if standardize:
            y_std, ok = standardize_series(tmp, feat, min_std=1e-8)
            if not ok:
                continue
            tmp["outcome_std"] = y_std
        else:
            tmp["outcome_std"] = tmp[feat].astype(float)
        df = tmp.dropna(subset=["outcome_std"]).reset_index(drop=True)
        if df["biomarker_high"].nunique() < 2:
            continue

        model = smf.ols(formula, data=df)
        # Patsy may still drop rows (e.g. NaN inside a covariate expression);
        # take cluster groups from the rows the model actually kept, so they
        # stay aligned with the design matrix.
        groups = df.loc[model.data.row_labels, unit]
        fit = model.fit(cov_type="cluster", cov_kwds={"groups": groups})
        rows.append(
            {
                "feature": feat,
                "beta_heterogeneity": float(fit.params.get(term, np.nan)),
                "p_heterogeneity": float(fit.pvalues.get(term, np.nan)),
                "n_units": int(groups.nunique()),  # participants after model row drops
                "threshold": threshold,
            }
        )

    res = pd.DataFrame(rows)
    if res.empty:
        return res
    if fdr in {"all", "feature"}:
        res = apply_fdr(res, p_col="p_heterogeneity", fdr_col="FDR_heterogeneity")
    return res


# The public name starts with ``test_``; stop pytest from collecting it as a
# test whenever it is imported into a test module.
test_treatment_heterogeneity.__test__ = False