"""Cell-type abundance tests for sctrial (module ``sctrial.stats.abundance``)."""

from __future__ import annotations

import warnings

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from anndata import AnnData

from ..adata_tools import subset_primary
from ..design import TrialDesign
from ..utils import wild_cluster_bootstrap_t
from ._utils import apply_fdr, encode_visit
from .did import MIN_CLUSTERS_FOR_ROBUST_SE


# Transforms accepted by ``abundance_did``; anything else is rejected up front so a
# typo cannot silently fall through to raw proportions.
_VALID_TRANSFORMS = ("arcsin_sqrt", "logit", "none")
# Base result columns (FDR_DiD is appended by apply_fdr / the empty-result path).
_RESULT_COLS = ["celltype", "n_participants", "beta_DiD", "se_DiD", "p_DiD", "beta_time", "p_time"]
# Extra columns emitted only when use_bootstrap=True.
_BOOT_COLS = ["p_DiD_boot", "se_DiD_boot", "ci_lo_boot", "ci_hi_boot"]


def abundance_did(
    adata: AnnData,
    design: TrialDesign,
    visits: tuple[str, str],
    exclude_crossovers: bool = True,
    transform: str = "arcsin_sqrt",
    min_units: int = 5,
    covariates: list[str] | None = None,
    use_bootstrap: bool = False,
    n_boot: int = 999,
    seed: int = 42,
) -> pd.DataFrame:
    """Test treatment-induced cell-type abundance changes via DiD on proportions.

    Cell-type proportions are computed per participant-visit (cell types a
    sample lacks get an explicit proportion of 0), transformed, and a
    difference-in-differences model is fit separately for each cell type.

    Statistical Assumptions
    -----------------------
    - Requires `min_units` paired participants (default 5) per cell type, i.e.
      participants with cells at both visits.
    - Requires both treatment arms to be represented among paired participants.
    - Cell types with no variation in the transformed outcome are skipped.
    - Default model: ``y ~ visit_num + visit_num:arm_bin + C(participant)``
      (plus covariates) with **cluster-robust standard errors** clustered by
      participant. Rows with missing model values are dropped before fitting
      so the cluster labels stay aligned with the fitted rows.
    - When covariates are given and every covariate is constant within
      participant, the model is instead fit on per-participant change scores,
      ``delta ~ arm_bin + covariates``, with ordinary OLS standard errors.
      This avoids collinearity with the participant fixed effects.
    - If the primary fit fails (``ValueError``, ``LinAlgError`` or ``KeyError``),
      a change-score model using baseline-visit covariate values is fit
      instead. Such rows report NaN for ``beta_time``/``p_time`` and for the
      bootstrap columns.
    - The arcsin-sqrt transform is variance-stabilizing for proportions and is
      recommended for compositional data.

    Parameters
    ----------
    adata
        AnnData object.
    design
        A `TrialDesign` object. Must have `celltype_col` defined.
    visits
        Tuple of (baseline, followup) visit labels.
    exclude_crossovers
        Whether to exclude crossover cells.
    transform
        Mathematical transformation for proportions:
        - 'arcsin_sqrt': arcsin(sqrt(p)), variance-stabilizing for proportions.
        - 'logit': log(p / (1-p)) with p clipped to [1e-6, 1 - 1e-6],
          useful for extreme proportions.
        - 'none': use raw proportions (not recommended).
        Any other value raises ``ValueError``.
    min_units
        Minimum number of paired participants required for a cell type to be
        tested. Cell types with fewer paired participants are skipped.
    covariates
        Additional columns in `adata.obs` to include as fixed effects. Must be
        constant within participant-visit (e.g., age, sex); a ``ValueError``
        is raised otherwise.
    use_bootstrap
        If True, uses Wild Cluster Bootstrap for p-values. Recommended for
        small sample sizes (< 15 participants per arm).
    n_boot
        Number of bootstrap permutations.
    seed
        Random seed.

    Returns
    -------
    pd.DataFrame
        One row per tested cell type, sorted by ``p_DiD``, containing:
        - celltype: Cell type name
        - n_participants: Number of participants contributing to the fit
        - beta_DiD: Treatment effect (``visit_num:arm_bin`` in the
          fixed-effects model, ``arm_bin`` in change-score models)
        - se_DiD: Standard error of beta_DiD (cluster-robust in the
          fixed-effects model, ordinary OLS in change-score models)
        - p_DiD: P-value for the treatment effect. When ``use_bootstrap`` is
          True and the primary fit succeeds, this is the bootstrap p-value.
        - beta_time, p_time: Visit effect (NaN for change-score models)
        - p_DiD_boot, se_DiD_boot, ci_lo_boot, ci_hi_boot: Bootstrap results,
          present only when ``use_bootstrap`` is True
        - FDR_DiD: Benjamini-Hochberg FDR-corrected p-value
        If no cell type can be tested, an empty frame with these columns is
        returned.

    Raises
    ------
    ValueError
        If `design.celltype_col` is None, `transform` is not recognised, or a
        covariate varies within a participant-visit.

    Warns
    -----
    UserWarning
        When the cluster-robust fixed-effects model is used without the
        bootstrap and fewer than ``MIN_CLUSTERS_FOR_ROBUST_SE`` participants
        are available for a cell type.

    Interpretation notes
    --------------------
    The arcsin-sqrt transform stabilizes variance but is not on the original
    proportion scale. A positive beta_DiD indicates an increase in proportion
    in the treated arm relative to control. To interpret effect magnitude in
    raw proportions, inspect group-level proportions directly.

    Examples
    --------
    >>> ab_res = abundance_did(adata, design, visits=("V1", "V2"))
    >>> print(ab_res)
    """
    if design.celltype_col is None:
        raise ValueError("celltype_col is required for abundance_did")
    if transform not in _VALID_TRANSFORMS:
        raise ValueError(
            f"Unknown transform {transform!r}; expected one of {_VALID_TRANSFORMS}."
        )

    pcol = design.participant_col
    vcol = design.visit_col
    ctcol = design.celltype_col

    ad = subset_primary(adata, design, visits=visits, exclude_crossovers=exclude_crossovers)
    counts, totals = _celltype_proportions(ad.obs.copy(), design, covariates)
    counts["y"] = _transform_proportions(counts["prop"], transform)
    counts = encode_visit(counts, vcol, visits)
    counts["arm_bin"] = design.arm_bin(counts)

    # Paired participants: must have cells at both visits (overall, not per
    # cell type). Zero cells of a specific cell type is valid data (prop = 0)
    # because the zero-count expansion fills in missing combinations.
    wide_tot = totals.pivot_table(
        index=pcol,
        columns=vcol,
        values="total_cells",
        aggfunc="mean",
        observed=True,
    )
    paired_units = wide_tot[wide_tot[visits[0]].notna() & wide_tot[visits[1]].notna()].index

    rows: list[dict] = []
    for ct in sorted(counts[ctcol].unique()):
        # Restrict to this cell type and to participants observed at both visits.
        tmp = counts[(counts[ctcol] == ct) & counts[pcol].isin(paired_units)].copy()

        n_units = tmp[pcol].nunique()
        if n_units < min_units:
            continue
        # Both arms must be represented among the paired participants.
        arm_counts = tmp.groupby("arm_bin")[pcol].nunique()
        if (arm_counts > 0).sum() < 2:
            continue
        # A constant outcome carries no information about the DiD effect.
        if tmp["y"].nunique() < 2:
            continue

        # Covariates constant within participant are collinear with the
        # participant fixed effects, so switch to a change-score model.
        use_diff = False
        if covariates:
            per_unit = tmp.groupby(pcol, observed=True)[covariates].nunique(dropna=False)
            use_diff = bool((per_unit.max(axis=0) <= 1).all())

        diff_pids = None
        fit_df = None
        if use_diff:
            df_delta = _delta_frame(tmp, design, visits, covariates, tmp, min_units)
            if df_delta is None:
                continue
            # Keep participant IDs (positional) before resetting the index.
            diff_pids = df_delta.index.to_numpy()
            df_delta = df_delta.reset_index(drop=True)
            model = smf.ols(_with_covariates("delta ~ arm_bin", covariates), data=df_delta)
        else:
            # Drop incomplete rows ourselves. If the formula API dropped them
            # instead, the cluster `groups` passed to fit() would no longer line
            # up with the model rows.
            model_cols = ["y", "visit_num", "arm_bin", pcol, *(covariates or [])]
            fit_df = tmp.dropna(subset=model_cols).reset_index(drop=True)
            formula = f"y ~ visit_num + visit_num:arm_bin + C({pcol})"
            model = smf.ols(_with_covariates(formula, covariates), data=fit_df)

        try:
            # The warning only applies when cluster-robust SEs drive inference.
            # Skip it for change-score fits (plain OLS SEs) and when the
            # bootstrap it recommends is already in use.
            if not use_diff and not use_bootstrap and n_units < MIN_CLUSTERS_FOR_ROBUST_SE:
                warnings.warn(
                    f"Only {n_units} clusters (participants) available for celltype "
                    f"'{ct}'. Cluster-robust standard errors are unreliable with fewer "
                    f"than {MIN_CLUSTERS_FOR_ROBUST_SE} clusters. Consider using "
                    f"use_bootstrap=True for more reliable p-values.",
                    UserWarning,
                    stacklevel=2,
                )

            if use_diff:
                fit = model.fit()
                term = "arm_bin"
            else:
                # Cluster-robust SEs for consistency with did_fit.
                fit = model.fit(cov_type="cluster", cov_kwds={"groups": fit_df[pcol]})
                term = "visit_num:arm_bin"

            if term not in fit.params or np.isnan(fit.params[term]):
                raise ValueError("DiD term not estimable")

            # Map the fitted rows back to participant IDs. The frames were
            # reset_index'd, so row_labels are integer positions.
            model_row_idx = fit.model.data.row_labels
            if use_diff:
                clusters_aligned = diff_pids[model_row_idx]
            else:
                clusters_aligned = fit_df[pcol].loc[model_row_idx].to_numpy()
            n_units_eff = int(len(np.unique(clusters_aligned)))

            p_val = float(fit.pvalues[term])
            se_boot = ci_lo_boot = ci_hi_boot = np.nan
            if use_bootstrap:
                boot_res = wild_cluster_bootstrap_t(
                    fit,
                    X=fit.model.exog,
                    clusters=clusters_aligned,
                    term_name=term,
                    B=n_boot,
                    seed=seed,
                )
                p_val = boot_res.p_boot
                se_boot = boot_res.se_boot
                ci_lo_boot = boot_res.ci_lo
                ci_hi_boot = boot_res.ci_hi

            row_dict: dict = {
                "celltype": ct,
                "n_participants": n_units_eff,
                "beta_DiD": float(fit.params[term]),
                "se_DiD": float(fit.bse[term]),
                "p_DiD": p_val,
                "beta_time": float(fit.params.get("visit_num", np.nan)),
                "p_time": float(fit.pvalues.get("visit_num", np.nan)),
            }
            if use_bootstrap:
                row_dict["p_DiD_boot"] = p_val
                row_dict["se_DiD_boot"] = se_boot
                row_dict["ci_lo_boot"] = ci_lo_boot
                row_dict["ci_hi_boot"] = ci_hi_boot
            rows.append(row_dict)
        except (ValueError, np.linalg.LinAlgError, KeyError):
            # Fallback: change-score model without fixed effects, using
            # baseline (pre-treatment) covariate values so time-varying
            # covariates are not silently collapsed.
            try:
                cov_rows = tmp[tmp[vcol] == visits[0]] if covariates else tmp
                df_delta = _delta_frame(tmp, design, visits, covariates, cov_rows, min_units)
                if df_delta is None:
                    continue
                fit = smf.ols(
                    _with_covariates("delta ~ arm_bin", covariates), data=df_delta
                ).fit()
                term = "arm_bin"
                if term not in fit.params or np.isnan(fit.params[term]):
                    continue

                fallback_row: dict = {
                    "celltype": ct,
                    "n_participants": int(df_delta.shape[0]),
                    "beta_DiD": float(fit.params[term]),
                    "se_DiD": float(fit.bse[term]),
                    "p_DiD": float(fit.pvalues[term]),
                    "beta_time": np.nan,
                    "p_time": np.nan,
                }
                if use_bootstrap:
                    fallback_row["p_DiD_boot"] = np.nan
                    fallback_row["se_DiD_boot"] = np.nan
                    fallback_row["ci_lo_boot"] = np.nan
                    fallback_row["ci_hi_boot"] = np.nan
                rows.append(fallback_row)
            except (ValueError, np.linalg.LinAlgError, KeyError):
                continue

    if not rows:
        # Keep the schema identical to a non-empty result, including the
        # bootstrap columns when they were requested.
        columns = list(_RESULT_COLS)
        if use_bootstrap:
            columns += _BOOT_COLS
        columns.append("FDR_DiD")
        return pd.DataFrame(columns=columns)

    res = pd.DataFrame(rows).sort_values("p_DiD")
    res = apply_fdr(res, p_col="p_DiD", fdr_col="FDR_DiD")
    return res.reset_index(drop=True)


def _with_covariates(formula: str, covariates: list[str] | None) -> str:
    """Append each covariate as an additive term to a patsy formula string."""
    if not covariates:
        return formula
    return formula + " + " + " + ".join(covariates)


def _transform_proportions(prop: pd.Series, transform: str) -> pd.Series:
    """Map raw proportions onto the modelling scale named by ``transform``.

    Raises
    ------
    ValueError
        If ``transform`` is not one of ``_VALID_TRANSFORMS``.
    """
    if transform == "arcsin_sqrt":
        return np.arcsin(np.sqrt(prop.clip(0, 1)))
    if transform == "logit":
        # Clip away from 0/1 so absent or ubiquitous cell types stay finite.
        p = prop.clip(1e-6, 1 - 1e-6)
        return np.log(p / (1 - p))
    if transform == "none":
        return prop
    raise ValueError(f"Unknown transform {transform!r}; expected one of {_VALID_TRANSFORMS}.")


def _celltype_proportions(
    obs: pd.DataFrame, design: TrialDesign, covariates: list[str] | None
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Count cells per participant x visit x arm x cell type and convert to proportions.

    Returns ``(counts, totals)``:

    - ``counts`` has one row for every observed participant-visit-arm
      combination crossed with every observed cell type. Missing combinations
      get ``n_cells = 0``. It carries ``total_cells`` and ``prop`` columns and,
      if requested, the covariate columns.
    - ``totals`` has the total cell count per participant-visit-arm.

    Raises
    ------
    ValueError
        If a covariate takes more than one value within a participant-visit;
        merging such values would duplicate rows of ``counts``.
    """
    unit_cols = [design.participant_col, design.visit_col, design.arm_col]
    grp_cols = [*unit_cols, design.celltype_col]

    counts = obs.groupby(grp_cols, observed=True).size().reset_index(name="n_cells")
    totals = (
        counts.groupby(unit_cols, observed=True)["n_cells"]
        .sum()
        .reset_index(name="total_cells")
    )

    # Expand to every unit x cell type so absent cell types contribute prop = 0.
    celltypes = sorted(counts[design.celltype_col].unique())
    full_df = (
        totals[unit_cols]
        .drop_duplicates()
        .merge(pd.DataFrame({design.celltype_col: celltypes}), how="cross")
    )
    counts = counts.merge(full_df, on=grp_cols, how="right")
    counts["n_cells"] = counts["n_cells"].fillna(0)
    counts = counts.merge(totals, on=unit_cols, how="left")
    counts["total_cells"] = counts["total_cells"].fillna(0)
    counts["prop"] = counts["n_cells"] / counts["total_cells"].clip(lower=1)

    if covariates:
        key_cols = [design.participant_col, design.visit_col]
        cov_df = obs[key_cols + list(covariates)].drop_duplicates()
        # More than one distinct covariate row per participant-visit would
        # multiply count rows in the merge below and inflate the sample.
        if cov_df.duplicated(subset=key_cols).any():
            raise ValueError(
                "covariates must be constant within each participant-visit; "
                f"at least one of {list(covariates)} varies across cells of the "
                "same participant and visit."
            )
        counts = counts.merge(cov_df, on=key_cols, how="left")
    return counts, totals


def _delta_frame(
    tmp: pd.DataFrame,
    design: TrialDesign,
    visits: tuple[str, str],
    covariates: list[str] | None,
    cov_rows: pd.DataFrame,
    min_units: int,
) -> pd.DataFrame | None:
    """Build a one-row-per-participant change-score frame for a single cell type.

    Returns a frame indexed by participant with ``delta`` (followup minus
    baseline outcome), ``arm_bin`` and, when given, the covariates. Covariates
    take the first non-null value per participant from ``cov_rows``. Rows with
    any missing value are dropped. Returns None if either visit is absent, no
    participant has both visits, or fewer than ``min_units`` rows remain.
    """
    pcol = design.participant_col
    wide = tmp.pivot_table(
        index=pcol,
        columns=design.visit_col,
        values="y",
        aggfunc="mean",
        observed=True,
    )
    if visits[0] not in wide.columns or visits[1] not in wide.columns:
        return None
    delta = (wide[visits[1]] - wide[visits[0]]).dropna()
    if delta.empty:
        return None

    df_delta = delta.rename("delta").to_frame()
    df_delta["arm_bin"] = (
        tmp.groupby(pcol, observed=True)["arm_bin"].first().reindex(df_delta.index)
    )
    if covariates:
        cov_df = (
            cov_rows.groupby(pcol, observed=True)[covariates].first().reindex(df_delta.index)
        )
        df_delta = pd.concat([df_delta, cov_df], axis=1)
    df_delta = df_delta.dropna()
    if df_delta.shape[0] < min_units:
        return None
    return df_delta