# Source code for sctrial.stats.cv

"""Cross-validation utilities for effect stability assessment.

This module provides leave-one-out (LOO) and k-fold cross-validation
for assessing the robustness of DiD estimates.

Why Cross-Validation for DiD?
-----------------------------
Cross-validation in the DiD context serves different purposes than in
predictive modeling:

1. **Influence diagnostics**: Identify participants with outsized influence
   on effect estimates (potential outliers or data quality issues).

2. **Effect stability**: Assess how robust the DiD estimate is to the
   exclusion of individual participants.

3. **Generalizability**: Estimate how well the effect might replicate
   in new samples (though true generalization requires new data).

Leave-One-Out (LOO)
-------------------
For each participant i:
    1. Fit DiD model excluding participant i
    2. Record beta_DiD^(-i)
    3. Compare to full-sample beta_DiD

Metrics:
- Influence: ``|beta_DiD - beta_DiD^(-i)| / SE(beta_DiD)``
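- This influence score acts as a DFBETA-style analog of Cook's distance for
  the DiD coefficient (no separate Cook's D statistic is computed)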

K-Fold Cross-Validation
-----------------------
1. Randomly partition participants into K folds
2. For each fold k:
   - Fit DiD on participants NOT in fold k
   - Record estimate
3. Report mean, SD, and CI of estimates

Interpretation Guidelines
-------------------------
- High LOO variance: Effect driven by few participants
- Consistently signed CV estimates: Robust effect
- Estimate changes sign across folds: Unreliable effect
"""

from __future__ import annotations

import warnings
from collections.abc import Sequence

import numpy as np
import pandas as pd
from anndata import AnnData

from ..adata_tools import subset_primary
from ..design import TrialDesign
from .did import AggregateMode, did_table

__all__ = [
    "loo_cv_did",
    "kfold_cv_did",
    "influence_diagnostics",
    "cv_summary",
]


def _loo_influence(beta_full: float, beta_loo: float, se_full: float) -> float:
    """Return ``|beta_full - beta_loo| / se_full``, or NaN when undefined.

    The score is undefined when any input is missing (NaN or None) or when
    ``se_full`` is not strictly positive.
    """
    if pd.isna(beta_full) or pd.isna(beta_loo) or pd.isna(se_full) or se_full <= 0:
        return np.nan
    return abs(beta_full - beta_loo) / se_full


def loo_cv_did(
    adata: AnnData,
    features: Sequence[str],
    design: TrialDesign,
    visits: tuple[str, str],
    layer: str | None = None,
    exclude_crossovers: bool = True,
    aggregate: AggregateMode = "participant_visit",
    standardize: bool = True,
) -> pd.DataFrame:
    """Leave-one-out cross-validation for DiD analysis.

    For each participant, fits the DiD model without that participant and
    records the effect estimate. This reveals which participants have the
    largest influence on the results.

    Parameters
    ----------
    adata
        AnnData object.
    features
        Features to analyze.
    design
        TrialDesign object.
    visits
        (baseline, followup) visits.
    layer
        Expression layer.
    exclude_crossovers
        Exclude crossover participants.
    aggregate
        Aggregation mode passed to did_table (default "participant_visit").
    standardize
        Whether to z-score the outcome variable before fitting (default True).

    Returns
    -------
    pd.DataFrame
        Long-format results, one row per (feature, excluded participant):

        - feature: Feature name
        - excluded: ID of excluded participant
        - beta_DiD: DiD estimate without this participant (NaN if the fit failed)
        - se_DiD: Standard error of that estimate (NaN if the fit failed)
        - beta_full: Full-sample DiD estimate for the feature
        - influence: ``|beta_full - beta_loo| / se_full``, NaN when any of
          these is missing or ``se_full <= 0``

    Warns
    -----
    UserWarning
        When did_table raises ValueError, LinAlgError or KeyError for a
        leave-one-out subset; that participant's rows are then filled with NaN.

    Examples
    --------
    >>> loo = loo_cv_did(adata, features=["sig_IFN"], design=design, visits=visits)
    >>> # Find influential participants
    >>> influential = loo[loo["influence"] > 1.0]
    >>> print(f"Influential participants: {influential['excluded'].unique()}")
    """
    # Participants that enter the primary analysis; each one is left out in turn.
    ad = subset_primary(adata, design, visits, exclude_crossovers=exclude_crossovers)
    participants = ad.obs[design.participant_col].unique().tolist()

    # Identical did_table settings for the full-sample fit and every LOO fit,
    # so the influence scores compare like with like.
    did_kwargs = {
        "exclude_crossovers": exclude_crossovers,
        "layer": layer,
        "aggregate": aggregate,
        "standardize": standardize,
    }

    # Full-sample estimates for comparison
    full_res = did_table(adata, features, design, visits, **did_kwargs).set_index("feature")
    full_betas = full_res["beta_DiD"].to_dict()
    full_ses = full_res["se_DiD"].to_dict()

    participant_ids = adata.obs[design.participant_col]
    rows: list[dict] = []
    for pid in participants:
        ad_loo = adata[participant_ids != pid].copy()
        try:
            res_loo = did_table(ad_loo, features, design, visits, **did_kwargs)
            # Collect this participant's rows locally: if extraction fails
            # part-way, no partial rows are left next to the NaN fallback rows.
            pid_rows = []
            for _, row in res_loo.iterrows():
                feat = row["feature"]
                beta_loo = row["beta_DiD"]
                beta_full = full_betas.get(feat, np.nan)
                pid_rows.append(
                    {
                        "feature": feat,
                        "excluded": pid,
                        "beta_DiD": beta_loo,
                        "se_DiD": row["se_DiD"],
                        "beta_full": beta_full,
                        "influence": _loo_influence(
                            beta_full, beta_loo, full_ses.get(feat, np.nan)
                        ),
                    }
                )
        except (ValueError, np.linalg.LinAlgError, KeyError) as exc:
            warnings.warn(
                f"LOO CV: model failed when excluding participant {pid}: {exc}",
                stacklevel=2,
            )
            pid_rows = [
                {
                    "feature": feat,
                    "excluded": pid,
                    "beta_DiD": np.nan,
                    "se_DiD": np.nan,
                    "beta_full": full_betas.get(feat, np.nan),
                    "influence": np.nan,
                }
                for feat in features
            ]
        rows.extend(pid_rows)

    return pd.DataFrame(rows)
def _stratified_folds(
    arm_groups: dict,
    k: int,
    rng: np.random.Generator,
) -> list[list]:
    """Shuffle participants within each arm and deal them round-robin into ``k`` folds.

    The fold counter carries over from one arm to the next. As a result, fold
    sizes differ by at most one, no fold is empty when there are at least ``k``
    participants, and each arm's members are spread evenly across the folds.
    """
    folds: list[list] = [[] for _ in range(k)]
    position = 0
    for members in arm_groups.values():
        # Permute indices rather than the IDs themselves so numpy never coerces
        # participant IDs (e.g. mixed int/str) into a common array dtype.
        for idx in rng.permutation(len(members)):
            folds[position % k].append(members[idx])
            position += 1
    return folds


def kfold_cv_did(
    adata: AnnData,
    features: Sequence[str],
    design: TrialDesign,
    visits: tuple[str, str],
    k: int = 5,
    n_repeats: int = 10,
    seed: int = 42,
    layer: str | None = None,
    exclude_crossovers: bool = True,
    aggregate: AggregateMode | None = None,
    standardize: bool | None = None,
) -> pd.DataFrame:
    """K-fold cross-validation for DiD effect stability.

    Randomly partitions participants into K folds, stratified by arm, and
    estimates DiD using K-1 folds. Repeating this process gives a distribution
    of estimates that reflects sampling variability.

    Parameters
    ----------
    adata
        AnnData object.
    features
        Features to analyze.
    design
        TrialDesign object.
    visits
        (baseline, followup) visits.
    k
        Number of folds (default 5). If there are fewer participants than
        ``k``, it is reduced to ``max(2, n_participants // 2)``.
    n_repeats
        Number of times to repeat CV (default 10).
    seed
        Random seed.
    layer
        Expression layer.
    exclude_crossovers
        Exclude crossover participants.
    aggregate
        Aggregation mode forwarded to did_table. ``None`` (default) leaves
        did_table's own default in effect.
    standardize
        Whether did_table z-scores the outcome. ``None`` (default) leaves
        did_table's own default in effect.

    Returns
    -------
    pd.DataFrame
        Results with one row per feature:

        - feature: Feature name
        - beta_full: Full-sample estimate
        - beta_cv_mean: Mean CV estimate
        - beta_cv_sd: SD of CV estimates
        - beta_cv_lower, beta_cv_upper: 2.5-97.5 percentiles
        - cv_stability: beta_cv_mean / beta_cv_sd (signed; larger absolute
          value = more stable)
        - sign_consistency: Proportion of CV estimates with same sign as full
        - n_cv_samples: Number of non-NaN CV estimates. Summaries are NaN
          when this is below 3.

    Warns
    -----
    UserWarning
        When some fitted folds raise ValueError, LinAlgError or KeyError in
        did_table, or when no fold could be fitted at all.

    Examples
    --------
    >>> cv = kfold_cv_did(adata, features=genes, design=design, visits=visits, k=5)
    >>> # Check stability
    >>> stable = cv[cv["sign_consistency"] > 0.9]
    >>> print(f"Stable effects: {len(stable)} / {len(cv)}")
    """
    rng = np.random.default_rng(seed)

    # Get paired participants
    ad = subset_primary(adata, design, visits, exclude_crossovers=exclude_crossovers)
    participants = ad.obs[design.participant_col].unique().tolist()
    n_participants = len(participants)
    if n_participants < k:
        k = max(2, n_participants // 2)

    # Same did_table settings for the full-sample fit and every fold fit;
    # options left as None fall back to did_table's defaults.
    did_kwargs: dict = {"exclude_crossovers": exclude_crossovers, "layer": layer}
    if aggregate is not None:
        did_kwargs["aggregate"] = aggregate
    if standardize is not None:
        did_kwargs["standardize"] = standardize

    # Full-sample estimates
    full_res = did_table(adata, features, design, visits, **did_kwargs)
    full_betas = full_res.set_index("feature")["beta_DiD"].to_dict()

    # Participant -> arm mapping for stratified fold assignment
    arm_map = (
        ad.obs[[design.participant_col, design.arm_col]]
        .drop_duplicates(subset=[design.participant_col])
        .set_index(design.participant_col)[design.arm_col]
    )
    arms = arm_map.unique()
    arm_groups = {arm: [p for p in participants if arm_map[p] == arm] for arm in arms}

    cv_estimates: dict[str, list[float]] = {feat: [] for feat in features}
    n_attempted = 0
    n_failed_folds = 0

    for _ in range(n_repeats):
        for held_out in _stratified_folds(arm_groups, k, rng):
            # An empty held-out fold would just refit the full sample and
            # bias the CV distribution toward the full estimate.
            if not held_out:
                continue
            excluded = set(held_out)
            included = [p for p in participants if p not in excluded]
            if len(included) < 4:  # Need minimum participants
                continue
            # Every arm must be represented in the training set
            if len({arm_map[p] for p in included}) < len(arms):
                continue

            n_attempted += 1
            ad_fold = adata[adata.obs[design.participant_col].isin(included)].copy()
            try:
                res_fold = did_table(ad_fold, features, design, visits, **did_kwargs)
                for _, row in res_fold.iterrows():
                    feat = row["feature"]
                    if feat in cv_estimates:
                        cv_estimates[feat].append(row["beta_DiD"])
            except (ValueError, np.linalg.LinAlgError, KeyError):
                n_failed_folds += 1

    if n_attempted == 0:
        warnings.warn(
            "K-fold CV: no fold had enough participants in every arm to fit; "
            "CV summaries will be NaN.",
            stacklevel=2,
        )
    elif n_failed_folds > 0:
        warnings.warn(
            f"K-fold CV: {n_failed_folds}/{n_attempted} folds failed to fit.",
            stacklevel=2,
        )

    # Summarize CV estimates
    rows = []
    for feat in features:
        # dtype=float turns any None estimates into NaN so they are filtered out
        estimates = np.asarray(cv_estimates[feat], dtype=float)
        valid = estimates[~np.isnan(estimates)]
        beta_full = full_betas.get(feat, np.nan)

        cv_mean = cv_sd = cv_lower = cv_upper = stability = sign_consistency = np.nan
        if len(valid) >= 3:
            cv_mean = np.mean(valid)
            cv_sd = np.std(valid, ddof=1)
            cv_lower = np.percentile(valid, 2.5)
            cv_upper = np.percentile(valid, 97.5)
            stability = cv_mean / cv_sd if cv_sd > 0 else np.nan
            if not pd.isna(beta_full):
                same_sign = np.sum(np.sign(valid) == np.sign(beta_full))
                sign_consistency = same_sign / len(valid)

        rows.append(
            {
                "feature": feat,
                "beta_full": beta_full,
                "beta_cv_mean": cv_mean,
                "beta_cv_sd": cv_sd,
                "beta_cv_lower": cv_lower,
                "beta_cv_upper": cv_upper,
                "cv_stability": stability,
                "sign_consistency": sign_consistency,
                "n_cv_samples": len(valid),
            }
        )

    return pd.DataFrame(rows)
def influence_diagnostics(
    loo_results: pd.DataFrame,
    threshold: float = 1.0,
) -> pd.DataFrame:
    """Summarize influence diagnostics from LOO results.

    Parameters
    ----------
    loo_results
        Output from loo_cv_did. Must contain the columns feature, excluded,
        influence, beta_DiD and beta_full.
    threshold
        Influence threshold for flagging (default 1.0 = 1 SE shift).

    Returns
    -------
    pd.DataFrame
        Per-participant influence with columns:

        - feature: Feature name
        - excluded: ID of excluded participant
        - influence: Influence score for this participant
        - beta_DiD: DiD estimate without this participant
        - beta_full: Full-sample DiD estimate
        - is_influential: Whether ``|influence| > threshold`` (False for NaN)

        Empty input yields an empty frame that still has these columns, so
        downstream column access keeps working.
    """
    columns = ["feature", "excluded", "influence", "beta_DiD", "beta_full", "is_influential"]
    if loo_results.empty:
        return pd.DataFrame(columns=columns)

    result = loo_results.copy()
    # NaN influence (failed LOO fits) compares False, so those rows are never flagged.
    result["is_influential"] = result["influence"].abs() > threshold
    return result[columns]
def cv_summary(
    cv_results: pd.DataFrame,
    alpha: float = 0.05,
) -> pd.DataFrame:
    """Summarize CV results per feature (works with both LOO and k-fold output).

    The input format is detected by the presence of an ``excluded`` column
    (LOO output). Otherwise the input is treated as k-fold output.

    Parameters
    ----------
    cv_results
        Output from loo_cv_did or kfold_cv_did.
    alpha
        Currently unused; accepted for API compatibility.

    Returns
    -------
    pd.DataFrame
        Summary per feature with:

        - feature: Feature name
        - mean_estimate: Full-sample estimate (``beta_full``)
        - mean_loo: Mean of the LOO estimates, or ``beta_cv_mean`` for k-fold input
        - std_loo: SD of the LOO estimates, or ``beta_cv_sd`` for k-fold input
        - cv: Coefficient of variation ``|std_loo / mean_loo|`` (NaN when
          ``|mean_loo| <= 1e-12``)
        - n_loo: Number of non-NaN estimates (``n_cv_samples`` for k-fold input)

        For LOO input, mean_loo/std_loo/cv are NaN when fewer than two
        estimates are available.
    """
    columns = ["feature", "mean_estimate", "mean_loo", "std_loo", "cv", "n_loo"]

    if "excluded" in cv_results.columns:
        # LOO format: one row per feature-participant
        rows = []
        for feat, group in cv_results.groupby("feature"):
            betas = group["beta_DiD"].dropna()
            beta_full = group["beta_full"].iloc[0] if "beta_full" in group.columns else np.nan
            if len(betas) >= 2:
                mean_loo = betas.mean()
                std_loo = betas.std(ddof=1)
                cv = abs(std_loo / mean_loo) if abs(mean_loo) > 1e-12 else np.nan
            else:
                mean_loo = std_loo = cv = np.nan
            rows.append(
                {
                    "feature": feat,
                    "mean_estimate": beta_full,
                    "mean_loo": mean_loo,
                    "std_loo": std_loo,
                    "cv": cv,
                    "n_loo": len(betas),
                }
            )
        # Explicit columns keep the schema stable even for empty input.
        return pd.DataFrame(rows, columns=columns)

    # k-fold format: one row per feature
    result = cv_results[["feature"]].copy()
    result["mean_estimate"] = cv_results.get("beta_full", np.nan)
    result["mean_loo"] = cv_results.get("beta_cv_mean", np.nan)
    result["std_loo"] = cv_results.get("beta_cv_sd", np.nan)
    result["cv"] = np.where(
        result["mean_loo"].abs() > 1e-12,
        result["std_loo"].abs() / result["mean_loo"].abs(),
        np.nan,
    )
    # Mirror the LOO branch's sample-count column.
    result["n_loo"] = cv_results.get("n_cv_samples", np.nan)
    return result