# Source code for sctrial.stats.module_scores

from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from anndata import AnnData
from scipy.stats import wilcoxon
from statsmodels.stats.multitest import multipletests

from ..design import TrialDesign
from ._utils import encode_visit

# Names exported by a wildcard import of this module; the underscore-prefixed
# helpers defined in this file are intentionally not listed.
__all__ = [
    "module_score_pseudobulk",
    "module_score_did_by_pool",
    "module_score_within_arm_by_pool",
]


def _map_pool(celltype: str, pool_map: dict[str, Sequence[str]] | None) -> str | None:
    """Return the first pool in *pool_map* whose labels contain *celltype*.

    Pools are tried in the mapping's iteration order. ``None`` is returned
    when *pool_map* is ``None`` or when no pool lists *celltype*.
    """
    if pool_map is None:
        return None
    candidates = (name for name, members in pool_map.items() if celltype in members)
    return next(candidates, None)


def module_score_pseudobulk(
    adata: AnnData,
    module_cols: Sequence[str],
    design: TrialDesign,
    visits: tuple[str, str],
    *,
    pool_col: str | None = None,
    pool_map: dict[str, Sequence[str]] | None = None,
    celltype_col: str | None = None,
    min_cells_per_group: int = 5,
    exclude_crossovers: bool = True,
) -> pd.DataFrame:
    """Build pseudobulk module scores (participant × visit × pool × module).

    Cell-level module scores in ``adata.obs`` are averaged within each
    participant × visit × [arm] × pool group, one row per module.

    Parameters
    ----------
    adata
        AnnData object; only ``adata.obs`` is read (a copy is taken).
    module_cols
        Columns in ``adata.obs`` with module scores.
    design
        TrialDesign object. ``participant_col``, ``visit_col``, ``arm_col``
        and ``crossover_col`` are read from it.
    visits
        Tuple of (baseline, followup) visit labels; other visits are dropped.
    pool_col
        Column in ``adata.obs`` to use as pool labels directly.
    pool_map
        Mapping of pool -> list of celltypes. Used if ``pool_col`` is None.
    celltype_col
        Column for fine cell types (used with ``pool_map``).
    min_cells_per_group
        Minimum cells per participant × visit × [arm] × pool to keep. The
        count includes cells whose score for a given module is missing.
    exclude_crossovers
        If True and ``design.crossover_col`` is set and present in
        ``adata.obs``, keep only cells where that column equals 0.

    Returns
    -------
    pd.DataFrame
        Long-format table with the design's participant and visit columns,
        the arm column (only when ``design.arm_col`` is not None), ``pool``,
        ``module``, ``module_score`` and ``n_cells``.

    Raises
    ------
    KeyError
        If a module column, ``pool_col`` or ``celltype_col`` is missing from
        ``adata.obs``.
    ValueError
        If neither ``pool_col`` nor (``pool_map`` and ``celltype_col``) is
        provided.
    """
    module_cols = list(module_cols)

    df = adata.obs.copy()
    if exclude_crossovers and design.crossover_col and design.crossover_col in df.columns:
        df = df[df[design.crossover_col] == 0].copy()
    df = df[df[design.visit_col].isin(visits)].copy()

    missing = [c for c in module_cols if c not in df.columns]
    if missing:
        raise KeyError(f"Module score columns missing: {missing[:5]}")

    if pool_col is None:
        if pool_map is None or celltype_col is None:
            raise ValueError("Provide pool_col or (pool_map + celltype_col).")
        # Mirror the pool_col branch: fail with a message naming the argument
        # instead of a bare KeyError from the column lookup.
        if celltype_col not in df.columns:
            raise KeyError(f"celltype_col '{celltype_col}' not found in adata.obs")
        df["pool"] = df[celltype_col].apply(lambda x: _map_pool(x, pool_map))
    else:
        if pool_col not in df.columns:
            raise KeyError(f"pool_col '{pool_col}' not found in adata.obs")
        df["pool"] = df[pool_col]

    # Cells whose cell type maps to no pool (or whose pool label is missing)
    # are excluded from the aggregation.
    df = df.dropna(subset=["pool"])

    # Build column lists conditionally — arm_col may be None for single-arm
    id_cols = [design.participant_col, design.visit_col]
    if design.arm_col is not None:
        df[design.arm_col] = df[design.arm_col].astype(str).str.strip()
        id_cols.append(design.arm_col)

    # Long format: participant × visit × [arm] × pool × module
    group_cols = id_cols + ["pool"]
    long_df = df[group_cols + module_cols].melt(
        id_vars=group_cols,
        value_vars=module_cols,
        var_name="module",
        value_name="module_score",
    )
    long_df = long_df.dropna(subset=["module_score"])

    # Pseudobulk aggregation: mean score per group and module
    pb = (
        long_df.groupby(group_cols + ["module"], observed=True)["module_score"]
        .mean()
        .reset_index()
    )
    pb = pb[pb["module_score"].notna()].copy()

    # Cell counts are per group (not per module) and come from the cell table
    # before per-module NaN scores were dropped.
    counts = df.groupby(group_cols, observed=True).size().reset_index(name="n_cells")
    pb = pb.merge(counts, on=group_cols, how="left")
    pb = pb[pb["n_cells"] >= min_cells_per_group].copy()
    return pb
def _perm_test_diff(
    delta: pd.Series,
    arms: pd.Series,
    n_perm: int,
    seed: int,
    treated_label: str | None = None,
) -> float:
    """Permutation test for difference between arm-level mean deltas.

    The statistic is ``mean(delta | treated) - mean(delta | not treated)``.
    Arm labels are shuffled *n_perm* times and the p-value is
    ``(count + 1) / (n_perm + 1)``, where ``count`` is the number of
    permutations whose absolute statistic is at least the observed one.

    Parameters
    ----------
    delta
        Participant-level change scores (post − pre).
    arms
        Arm labels aligned to *delta*.
    n_perm
        Number of random permutations.
    seed
        Random seed for reproducibility.
    treated_label
        Label identifying the treated arm. If ``None``, the first element of
        *arms* is used.

    Returns
    -------
    float
        Two-sided permutation p-value in ``[0, 1]``, or ``NaN`` when the test
        is undefined (no observations, one of the two groups is empty, or the
        observed statistic is not finite).
    """
    values = delta.to_numpy(dtype=float, na_value=np.nan)
    labels = np.asarray(arms.to_numpy())
    if labels.size == 0:
        return float("nan")
    if treated_label is None:
        treated_label = labels[0]

    is_treated = np.asarray(labels == treated_label)
    n_treated = int(is_treated.sum())
    # With an empty group the observed statistic is NaN, every ">=" comparison
    # is False, and the smallest possible p-value would be returned.
    if n_treated == 0 or n_treated == labels.size:
        return float("nan")

    obs = values[is_treated].mean() - values[~is_treated].mean()
    if not np.isfinite(obs):
        return float("nan")

    # Permutations that reproduce the observed split can differ from `obs` by
    # rounding only; a small data-scaled tolerance counts them as ties.
    abs_obs = float(np.abs(obs))
    tol = 1e-12 * max(abs_obs, float(np.abs(values).max()))
    threshold = abs_obs - tol

    rng = np.random.default_rng(seed)
    count = 0
    for _ in range(n_perm):
        perm = rng.permutation(labels)
        perm_diff = values[perm == treated_label].mean() - values[perm != treated_label].mean()
        if np.abs(perm_diff) >= threshold:
            count += 1
    return float((count + 1) / (n_perm + 1))
def _bh_adjust(p_values: pd.Series) -> pd.Series:
    """Return Benjamini-Hochberg q-values aligned to *p_values*.

    Missing p-values stay ``NaN`` and are excluded from the correction family.
    """
    q_values = pd.Series(np.nan, index=p_values.index, dtype=float)
    tested = p_values.notna()
    if tested.any():
        q_values.loc[tested] = multipletests(p_values.loc[tested], method="fdr_bh")[1]
    return q_values


def _unpaired_visit_pvalue(frame: pd.DataFrame) -> float:
    """P-value of the visit effect from an HC1-robust OLS within one arm (NaN on failure)."""
    try:
        fit = smf.ols("module_score ~ visit_num", data=frame).fit(cov_type="HC1")
        return float(fit.pvalues.get("visit_num", np.nan))
    except (ValueError, TypeError):
        return float("nan")


def module_score_did_by_pool(
    pb: pd.DataFrame,
    design: TrialDesign,
    visits: tuple[str, str],
    *,
    min_paired: int = 2,
    n_perm: int = 1000,
    seed: int = 42,
    fdr_within: str | None = "module",
    fdr_global: bool = True,
    allow_unpaired: bool = False,
) -> pd.DataFrame:
    """Compute DiD on module scores by pool with permutation p-values.

    Parameters
    ----------
    pb
        Output of module_score_pseudobulk().
    design
        TrialDesign object. Only rows whose arm equals ``design.arm_treated``
        or ``design.arm_control`` (compared as stripped strings) are used.
    visits
        Tuple of (baseline, followup) visit labels.
    min_paired
        Paired mode only: minimum number of participants with a score at
        both visits, counted across both arms together, for a (pool, module)
        to be tested.
    n_perm
        Number of permutations for DiD p-values (paired mode).
    seed
        Random seed. The same seed is passed to the permutation test of
        every (pool, module).
    fdr_within
        If "module", FDR is computed within each module across pools.
        If "pool", FDR is computed within each pool across modules.
        If None, global FDR.
    fdr_global
        Only used when *fdr_within* is not None. If True (default), an
        additional ``FDR_DiD_global`` column is added that applies BH-FDR
        correction globally across **all** tests, mirroring the behaviour
        when ``fdr_within=None``. This is useful because per-group FDR
        (``FDR_DiD``) controls the false discovery rate only within each
        group and does **not** control the overall false discovery rate
        across all tests.
    allow_unpaired
        If True, every (pool, module) is analysed with an unpaired OLS DiD
        (module_score ~ visit + arm + visit×arm, HC1 covariance) using all
        available participant-visit observations, and the paired analysis
        is not run. Intended for data without paired participants; interpret
        cautiously.

    Returns
    -------
    pd.DataFrame
        One row per (pool, module) with columns: ``pool``, ``module``,
        ``mean_delta_treated``, ``mean_delta_control``, ``beta_DiD``,
        ``p_DiD``, ``p_treated``, ``p_control``, ``n_units``, ``FDR_DiD``.
        When *fdr_within* is set and *fdr_global* is True, an additional
        ``FDR_DiD_global`` column contains the globally-corrected q-values.
        If no (pool, module) can be tested, an empty frame with these same
        columns is returned.

    Warns
    -----
    UserWarning
        When *fdr_within* is not None and at least one test was run, as a
        reminder that per-group FDR does not control the overall rate.
    """
    columns = [
        "pool",
        "module",
        "mean_delta_treated",
        "mean_delta_control",
        "beta_DiD",
        "p_DiD",
        "p_treated",
        "p_control",
        "n_units",
    ]
    rows: list[dict[str, Any]] = []
    arm_treated = str(design.arm_treated).strip()
    arm_control = str(design.arm_control).strip()

    for (pool, module), sub in pb.groupby(["pool", "module"], observed=True):
        sub = sub[sub[design.visit_col].isin(visits)].copy()
        sub[design.arm_col] = sub[design.arm_col].astype(str).str.strip()
        sub = sub[sub[design.arm_col].isin([arm_treated, arm_control])].copy()
        if sub[design.visit_col].nunique() < 2 or sub[design.arm_col].nunique() < 2:
            continue

        if allow_unpaired:
            df = encode_visit(sub.copy(), design.visit_col, visits).reset_index(drop=True)
            df["arm_bin"] = (df[design.arm_col] == arm_treated).astype(int)
            model = smf.ols("module_score ~ visit_num + arm_bin + visit_num:arm_bin", data=df)
            fit = model.fit(cov_type="HC1")

            # Compute mean deltas from unpaired group means
            treated = df[df["arm_bin"] == 1]
            control = df[df["arm_bin"] == 0]
            treated_means = treated.groupby(design.visit_col)["module_score"].mean()
            control_means = control.groupby(design.visit_col)["module_score"].mean()
            if visits[0] not in treated_means or visits[1] not in treated_means:
                continue
            if visits[0] not in control_means or visits[1] not in control_means:
                continue
            mean_delta_treated = float(treated_means.loc[visits[1]] - treated_means.loc[visits[0]])
            mean_delta_control = float(control_means.loc[visits[1]] - control_means.loc[visits[0]])

            # Within-arm visit effect (unpaired OLS) for reporting p-values
            p_treated = _unpaired_visit_pvalue(treated)
            p_control = _unpaired_visit_pvalue(control)

            # Report effective participant count after model row drops
            model_row_idx = fit.model.data.row_labels
            n_units_eff = int(df[design.participant_col].loc[model_row_idx].nunique())

            rows.append(
                {
                    "pool": pool,
                    "module": module,
                    "mean_delta_treated": mean_delta_treated,
                    "mean_delta_control": mean_delta_control,
                    "beta_DiD": float(fit.params.get("visit_num:arm_bin", np.nan)),
                    "p_DiD": float(fit.pvalues.get("visit_num:arm_bin", np.nan)),
                    "p_treated": p_treated,
                    "p_control": p_control,
                    "n_units": n_units_eff,
                }
            )
            continue

        wide = sub.pivot_table(
            index=design.participant_col,
            columns=design.visit_col,
            values="module_score",
            aggfunc="mean",
        )
        if visits[0] not in wide.columns or visits[1] not in wide.columns:
            continue
        wide = wide.dropna()
        if len(wide) < min_paired:
            continue

        # Keep arm labels and deltas in their own frame so the pivot's
        # visit-labelled column index is never extended with new names.
        pid_arm = sub.groupby(design.participant_col)[design.arm_col].first().to_dict()
        deltas = pd.DataFrame(
            {
                "arm": wide.index.map(pid_arm).to_numpy(),
                "delta": (wide[visits[1]] - wide[visits[0]]).to_numpy(),
            },
            index=wide.index,
        )
        deltas = deltas.dropna(subset=["delta"])
        if deltas["arm"].nunique() < 2:
            continue

        # paired within-arm tests (Wilcoxon signed-rank, needs >= 3 pairs)
        p_arm: dict[Any, float] = {}
        for arm_label in deltas["arm"].unique():
            arm_delta = deltas.loc[deltas["arm"] == arm_label, "delta"].to_numpy()
            if len(arm_delta) < 3:
                p = np.nan
            elif np.allclose(arm_delta, 0):
                # No change for anyone: report "no evidence" explicitly
                # rather than calling wilcoxon on all-zero differences.
                p = 1.0
            else:
                try:
                    _, p = wilcoxon(arm_delta)
                except (ValueError, TypeError):
                    p = np.nan
            p_arm[arm_label] = p

        arm_means = deltas.groupby("arm", observed=True)["delta"].mean()
        mean_delta_treated = float(arm_means.get(arm_treated, np.nan))
        mean_delta_control = float(arm_means.get(arm_control, np.nan))
        did = mean_delta_treated - mean_delta_control

        p_did = _perm_test_diff(
            deltas["delta"],
            deltas["arm"],
            n_perm=n_perm,
            seed=seed,
            treated_label=arm_treated,
        )

        rows.append(
            {
                "pool": pool,
                "module": module,
                "mean_delta_treated": mean_delta_treated,
                "mean_delta_control": mean_delta_control,
                "beta_DiD": float(did),
                "p_DiD": float(p_did),
                "p_treated": float(p_arm.get(arm_treated, np.nan)),
                "p_control": float(p_arm.get(arm_control, np.nan)),
                "n_units": int(deltas.index.nunique()),
            }
        )

    res = pd.DataFrame(rows, columns=columns)
    add_global = fdr_within is not None and fdr_global

    if res.empty:
        # Keep the schema stable so callers can select result columns even
        # when nothing was testable; no warning is emitted in this case.
        res["FDR_DiD"] = np.nan
        if add_global:
            res["FDR_DiD_global"] = np.nan
        return res

    if fdr_within is None:
        res["FDR_DiD"] = _bh_adjust(res["p_DiD"])
        return res

    warnings.warn(
        f"FDR correction is applied within each '{fdr_within}' group "
        f"(column FDR_DiD). Per-group FDR does not control the overall "
        f"false discovery rate across all tests. Consult "
        f"FDR_DiD_global (when fdr_global=True) for a globally "
        f"corrected q-value.",
        stacklevel=2,
    )
    res["FDR_DiD"] = np.nan
    for _, grp in res.groupby(fdr_within, observed=True):
        res.loc[grp.index, "FDR_DiD"] = _bh_adjust(grp["p_DiD"])
    if add_global:
        res["FDR_DiD_global"] = _bh_adjust(res["p_DiD"])
    return res
def module_score_within_arm_by_pool(
    pb: pd.DataFrame,
    design: TrialDesign,
    visits: tuple[str, str],
    *,
    min_paired: int = 3,
    fdr_within: str | None = "module",
) -> pd.DataFrame:
    """Within-arm paired tests on module scores by pool.

    For each (pool, module), participants with a score at both visits are
    paired and the follow-up − baseline differences are tested with a
    Wilcoxon signed-rank test.

    NOTE(review): this function does not split *pb* by arm itself;
    presumably callers pass a *pb* already restricted to one arm — confirm.

    Parameters
    ----------
    pb
        Output of module_score_pseudobulk().
    design
        TrialDesign object (``participant_col`` and ``visit_col`` are read).
    visits
        Tuple of (baseline, followup) visit labels. Other visits present in
        *pb* are ignored.
    min_paired
        Minimum paired participants per pool/module.
    fdr_within
        If "module", FDR is computed within each module across pools.
        If "pool", FDR is computed within each pool across modules.
        If None, global FDR.

    Returns
    -------
    pd.DataFrame
        One row per (pool, module) with columns: ``pool``, ``module``,
        ``mean_delta``, ``p_time``, ``n_units``, ``FDR_time``. If no
        (pool, module) can be tested, an empty frame with these same columns
        is returned.
    """
    columns = ["pool", "module", "mean_delta", "p_time", "n_units"]
    rows: list[dict[str, Any]] = []

    for (pool, module), sub in pb.groupby(["pool", "module"], observed=True):
        wide = sub.pivot_table(
            index=design.participant_col,
            columns=design.visit_col,
            values="module_score",
            aggfunc="mean",
        )
        if visits[0] not in wide.columns or visits[1] not in wide.columns:
            continue

        # Pair on the two requested visits only: dropping NaNs across every
        # visit column would discard participants who merely lack some other
        # visit that happens to be present in `pb`.
        paired = wide[[visits[0], visits[1]]].dropna()
        if paired.empty or len(paired) < min_paired:
            continue

        pre = paired[visits[0]].to_numpy(dtype=float)
        post = paired[visits[1]].to_numpy(dtype=float)
        delta = post - pre

        if np.allclose(delta, 0):
            # Same convention as the per-arm tests in module_score_did_by_pool:
            # all-zero differences are reported as p = 1 instead of NaN.
            p_val = 1.0
        else:
            try:
                _, p_val = wilcoxon(delta)
            except (ValueError, TypeError):
                p_val = np.nan

        rows.append(
            {
                "pool": pool,
                "module": module,
                "mean_delta": float(delta.mean()),
                "p_time": float(p_val),
                "n_units": int(len(paired)),
            }
        )

    res = pd.DataFrame(rows, columns=columns)
    res["FDR_time"] = np.nan
    if res.empty:
        # Schema-stable empty result so callers can still select columns.
        return res

    if fdr_within is None:
        mask = res["p_time"].notna()
        if mask.sum() > 0:
            res.loc[mask, "FDR_time"] = multipletests(res.loc[mask, "p_time"], method="fdr_bh")[1]
    else:
        for _, grp in res.groupby(fdr_within, observed=True):
            mask = grp["p_time"].notna()
            if mask.sum() > 0:
                res.loc[grp.index[mask], "FDR_time"] = multipletests(
                    grp.loc[mask, "p_time"], method="fdr_bh"
                )[1]
    return res