# Source code for sctrial.stats.simulation

"""Monte Carlo simulation engine for DiD method comparison.

Generates synthetic **cell-level** scRNA-seq-like data with known ground-truth
DiD effects, enabling controlled comparison of statistical methods under
realistic conditions (variable cell counts, participant heterogeneity,
cell-level noise).
"""

from __future__ import annotations

import logging

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


def simulate_did_data(
    n_participants: int = 20,
    n_genes: int = 50,
    n_cells_per_participant: int = 100,
    effect_sizes: dict[str, float] | None = None,
    noise_sd: float = 1.0,
    baseline_mean: float = 5.0,
    participant_sd: float = 0.5,
    time_effect: float = 0.1,
    seed: int = 42,
) -> dict:
    """Generate synthetic cell-level DiD data with known ground-truth effects.

    Creates cell-level expression data for a two-arm (Treated vs Control),
    two-timepoint (Pre vs Post) design with participant random intercepts,
    variable cell counts per participant-visit, and cell-level noise.

    The data-generating process is:

    .. math::

        Y_{ijk} = \\mu + \\alpha_i + \\beta_1 \\text{Post}_j
                  + \\beta_2 (\\text{Treat}_i \\times \\text{Post}_j)
                  + \\epsilon_{ijk}

    where *i* indexes participants, *j* indexes visits, *k* indexes cells,
    :math:`\\alpha_i \\sim N(0, \\sigma_{\\text{participant}}^2)`, and
    :math:`\\epsilon_{ijk} \\sim N(0, \\sigma_{\\text{noise}}^2)`.

    Parameters
    ----------
    n_participants : int
        Total participants (split equally between arms). Must be a positive
        even integer.
    n_genes : int
        Number of genes to simulate. Genes are named ``gene_0`` ...
        ``gene_{n_genes - 1}``.
    n_cells_per_participant : int
        Mean cells per participant-visit. Actual counts are drawn from
        Poisson(n_cells_per_participant) with a floor of 20.
    effect_sizes : dict
        Mapping of gene_name -> true DiD effect (beta_DiD). Genes not in
        this dict have effect = 0 (null). Keys that do not match a generated
        gene name are ignored.
    noise_sd : float
        Cell-level residual standard deviation.
    baseline_mean : float
        Grand mean expression level.
    participant_sd : float
        Between-participant standard deviation (random intercept, drawn
        independently per gene).
    time_effect : float
        Main effect of time (Pre->Post shift, applied in both arms).
    seed : int
        Random seed for reproducibility.

    Returns
    -------
    dict with keys:
        "adata" : AnnData — cell-level expression matrix (float32) with obs
            columns ``participant``, ``visit``, ``arm``
        "pseudobulk" : DataFrame — participant-visit means with ``n_cells``
        "truth" : dict mapping gene_name -> true_beta_DiD
        "params" : dict of simulation parameters

    Raises
    ------
    ValueError
        If ``n_participants`` is not a positive even integer.
    """
    # Validate before importing the optional heavy dependency or drawing any
    # random numbers, so bad arguments fail fast with a clear message.
    if n_participants % 2 != 0:
        raise ValueError(
            f"n_participants must be even (got {n_participants}); "
            "participants are split equally between arms."
        )
    if n_participants <= 0:
        raise ValueError(
            f"n_participants must be a positive even integer (got {n_participants})."
        )

    import anndata as ad

    rng = np.random.default_rng(seed)
    effect_sizes = effect_sizes or {}
    n_per_arm = n_participants // 2

    gene_names = [f"gene_{i}" for i in range(n_genes)]

    # Ground truth per gene, plus the same values as a vector so the
    # Treated x Post shift can be applied to all genes at once.
    truth = {g: effect_sizes.get(g, 0.0) for g in gene_names}
    did_shift = np.array([truth[g] for g in gene_names], dtype=np.float64)

    obs_cols: dict[str, list[str]] = {"participant": [], "visit": [], "arm": []}
    X_blocks: list[np.ndarray] = []
    pb_rows: list[dict] = []  # pseudobulk aggregation

    for arm in ["Control", "Treated"]:
        for p in range(n_per_arm):
            pid = f"{arm[0]}{p}"
            # Participant random intercept (per-gene)
            participant_effect = rng.normal(0, participant_sd, size=n_genes)

            for visit_idx, visit in enumerate(["Pre", "Post"]):
                # Variable cell count (Poisson with floor)
                n_cells = max(20, rng.poisson(n_cells_per_participant))

                # Noise-free expected expression shared by every cell of this
                # participant-visit; terms are added in a fixed order.
                expected = np.full(n_genes, baseline_mean, dtype=np.float64)
                expected += participant_effect
                expected += time_effect * visit_idx
                # DiD interaction: only Treated x Post
                if arm == "Treated" and visit == "Post":
                    expected += did_shift

                # One draw for all cells of this visit.
                # NOTE(review): this relies on numpy filling the array
                # row-major, so the random stream matches n_cells successive
                # draws of size n_genes — confirm if bit-for-bit seed
                # stability across versions matters.
                cell_X = expected + rng.normal(0, noise_sd, size=(n_cells, n_genes))

                obs_cols["participant"].extend([pid] * n_cells)
                obs_cols["visit"].extend([visit] * n_cells)
                obs_cols["arm"].extend([arm] * n_cells)
                X_blocks.append(cell_X)

                # Pre-compute pseudobulk (mean per participant-visit)
                pb_row = {
                    "participant": pid,
                    "visit": visit,
                    "arm": arm,
                    "n_cells": n_cells,
                }
                pb_row.update(zip(gene_names, cell_X.mean(axis=0)))
                pb_rows.append(pb_row)

    # Build AnnData
    obs = pd.DataFrame(obs_cols)
    X = np.vstack(X_blocks).astype(np.float32)
    adata = ad.AnnData(
        X=X,
        obs=obs,
        var=pd.DataFrame(index=gene_names),
    )

    pb = pd.DataFrame(pb_rows)

    return {
        "adata": adata,
        "pseudobulk": pb,
        "truth": truth,
        "params": {
            "n_participants": n_participants,
            "n_genes": n_genes,
            "n_cells_per_participant": n_cells_per_participant,
            "noise_sd": noise_sd,
            "baseline_mean": baseline_mean,
            "participant_sd": participant_sd,
            "time_effect": time_effect,
            "seed": seed,
        },
    }
def _run_single_iteration(args: tuple) -> list[dict]:
    """Simulate one dataset and run every requested method on it.

    Takes a single tuple so it can be handed directly to a process pool.
    The tuple holds, in order: iteration index, iteration seed,
    ``n_participants``, ``n_genes``, ``n_cells_per_participant``,
    ``effect_sizes``, ``noise_sd``, ``methods`` and ``sim_kwargs``.
    All warnings raised while the iteration runs are suppressed.
    """
    import warnings

    (
        it,
        it_seed,
        n_participants,
        n_genes,
        n_cells_per_participant,
        effect_sizes,
        noise_sd,
        methods,
        sim_kwargs,
    ) = args

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return _run_single_iteration_inner(
            it=it,
            it_seed=it_seed,
            n_participants=n_participants,
            n_genes=n_genes,
            n_cells_per_participant=n_cells_per_participant,
            effect_sizes=effect_sizes,
            noise_sd=noise_sd,
            methods=methods,
            sim_kwargs=sim_kwargs,
        )


def _run_single_iteration_inner(
    it,
    it_seed,
    n_participants,
    n_genes,
    n_cells_per_participant,
    effect_sizes,
    noise_sd,
    methods,
    sim_kwargs,
) -> list[dict]:
    """Do the work of one iteration; the caller suppresses warnings.

    Returns one row per (method, gene) with the true effect, the estimate,
    the p-value and the confidence bounds. Any value a method did not report
    is filled with NaN.

    Raises
    ------
    ValueError
        If ``methods`` contains a name that is not one of the known methods.
    """
    sim = simulate_did_data(
        n_participants=n_participants,
        n_genes=n_genes,
        n_cells_per_participant=n_cells_per_participant,
        effect_sizes=effect_sizes,
        noise_sd=noise_sd,
        seed=it_seed,
        **sim_kwargs,
    )
    adata, pb, truth = sim["adata"], sim["pseudobulk"], sim["truth"]
    gene_cols = [col for col in pb.columns if col.startswith("gene_")]

    # Cell-level methods get the AnnData; pseudobulk methods get the
    # participant-visit means. The table is built here because the runner
    # functions are defined further down the module.
    runners = {
        "sctrial_did": lambda: _run_sctrial_did(adata, gene_cols),
        "mixed_did": lambda: _run_mixed_did(adata, gene_cols),
        "wilcoxon": lambda: _run_wilcoxon(pb, gene_cols),
        "pseudobulk_ols": lambda: _run_pseudobulk_ols(pb, gene_cols),
    }

    rows = []
    for method in methods:
        runner = runners.get(method)
        if runner is None:
            raise ValueError(f"Unknown method: {method}")
        results = runner()

        for gene in gene_cols:
            fit = results.get(gene, {})
            rows.append(
                {
                    "iteration": it,
                    "method": method,
                    "gene": gene,
                    "true_beta": truth[gene],
                    "estimated_beta": fit.get("beta", np.nan),
                    "pvalue": fit.get("pvalue", np.nan),
                    "ci_lo": fit.get("ci_lo", np.nan),
                    "ci_hi": fit.get("ci_hi", np.nan),
                }
            )

    return rows
def run_method_comparison(
    n_participants: int = 20,
    n_genes: int = 50,
    effect_sizes: dict[str, float] | None = None,
    noise_sd: float = 1.0,
    n_iterations: int = 200,
    methods: list[str] | None = None,
    seed: int = 42,
    n_cells_per_participant: int = 100,
    n_jobs: int | None = None,
    **sim_kwargs,
) -> pd.DataFrame:
    """Run Monte Carlo comparison of DiD methods on simulated data.

    For every iteration a fresh dataset is simulated with its own seed
    (derived deterministically from ``seed``) and each requested method is
    run on that same dataset. Cell-level methods (``"sctrial_did"``,
    ``"mixed_did"``) receive the full AnnData; pseudobulk methods
    (``"pseudobulk_ols"``, ``"wilcoxon"``) receive the pre-aggregated
    participant-visit means.

    Iterations are parallelised across worker processes when ``n_jobs > 1``.

    Parameters
    ----------
    n_participants, n_genes, n_cells_per_participant, effect_sizes, noise_sd
        Forwarded to ``simulate_did_data`` for every iteration.
    n_iterations : int
        Number of simulation repetitions.
    methods : list of str
        Methods to compare. Options: ``"sctrial_did"``, ``"mixed_did"``,
        ``"pseudobulk_ols"``, ``"wilcoxon"``. Default: all four.
    seed : int
        Master seed from which the per-iteration seeds are drawn.
    n_jobs : int, optional
        Number of parallel workers. Defaults to ``min(cpu_count, 8)``.
        Set to 1 to disable parallelisation. No pool is started when there
        is at most one iteration to run.
    **sim_kwargs
        Extra keyword arguments forwarded to ``simulate_did_data``.

    Returns
    -------
    DataFrame with columns:
        iteration, method, gene, true_beta, estimated_beta, pvalue,
        ci_lo, ci_hi

        The columns are present even when no iterations are run.
    """
    import multiprocessing as mp

    result_columns = [
        "iteration",
        "method",
        "gene",
        "true_beta",
        "estimated_beta",
        "pvalue",
        "ci_lo",
        "ci_hi",
    ]

    methods = methods or ["sctrial_did", "mixed_did", "pseudobulk_ols", "wilcoxon"]
    rng = np.random.default_rng(seed)

    if n_jobs is None:
        n_jobs = min(mp.cpu_count(), 8)
    # Spawning a pool for zero or one task is pure overhead. Only
    # downgrade valid worker counts so invalid ones still reach mp.Pool
    # and raise there as before.
    if n_jobs > 1 and n_iterations <= 1:
        n_jobs = 1

    # Pre-generate seeds for reproducibility (one scalar draw per iteration
    # keeps the seed sequence stable for a given master seed).
    it_seeds = [int(rng.integers(0, 2**31)) for _ in range(n_iterations)]

    task_args = [
        (
            it,
            it_seeds[it],
            n_participants,
            n_genes,
            n_cells_per_participant,
            effect_sizes,
            noise_sd,
            methods,
            sim_kwargs,
        )
        for it in range(n_iterations)
    ]

    all_rows: list[dict] = []

    def _collect(batches) -> None:
        # Count finished iterations directly instead of inferring the count
        # from the number of rows, which breaks when an iteration yields no
        # rows (e.g. n_genes == 0).
        for done, batch in enumerate(batches, start=1):
            all_rows.extend(batch)
            if done % 10 == 0:
                logger.info("  %d/%d iterations complete", done, n_iterations)

    if n_jobs == 1:
        # Sequential fallback
        _collect(_run_single_iteration(args) for args in task_args)
    else:
        with mp.Pool(n_jobs) as pool:
            _collect(pool.imap(_run_single_iteration, task_args))

    return pd.DataFrame(all_rows, columns=result_columns)
def _run_sctrial_did(adata, gene_cols: list[str]) -> dict: """Run sctrial did_table on cell-level AnnData. Mirrors the real user workflow: cell-level data → did_table with ``aggregate="cell"`` so that cluster-robust standard errors have many observations per cluster (participant), producing reliable inference. """ import warnings from ..design import TrialDesign from .did import did_table design = TrialDesign( participant_col="participant", visit_col="visit", arm_col="arm", arm_treated="Treated", arm_control="Control", ) with warnings.catch_warnings(): warnings.simplefilter("ignore") res = did_table( adata, gene_cols, design, visits=("Pre", "Post"), aggregate="cell", standardize=False, ) out = {} for _, row in res.iterrows(): out[row["feature"]] = { "beta": row["beta_DiD"], "pvalue": row["p_DiD"], "ci_lo": row.get("ci_lo_DiD", np.nan), "ci_hi": row.get("ci_hi_DiD", np.nan), } return out def _run_mixed_did(adata, gene_cols: list[str]) -> dict: """Run mixed-effects DiD with participant as random intercept. Receives cell-level AnnData. Uses ``aggregate="participant_visit"`` because the mixed model's random intercept absorbs participant heterogeneity and only needs pseudobulk-level data. 
""" import warnings from ..design import TrialDesign from .mixed_effects import did_table_mixed design = TrialDesign( participant_col="participant", visit_col="visit", arm_col="arm", arm_treated="Treated", arm_control="Control", celltype_col=None, ) with warnings.catch_warnings(): warnings.simplefilter("ignore") res = did_table_mixed( adata, gene_cols, design, visits=("Pre", "Post"), aggregate="participant_visit", standardize=False, ) out = {} for _, row in res.iterrows(): # Skip non-converged fits — their coefficients are unreliable if not row.get("converged", True): logger.debug("Mixed DiD did not converge for %s", row["feature"]) continue out[row["feature"]] = { "beta": row["beta_DiD"], "pvalue": row["p_DiD"], "ci_lo": row.get("ci_lower", np.nan), "ci_hi": row.get("ci_upper", np.nan), } return out def _run_wilcoxon(pb: pd.DataFrame, gene_cols: list[str]) -> dict: """Naive cross-sectional Wilcoxon on post-treatment treated vs control. This deliberately ignores pre-treatment data to demonstrate the cost of not accounting for baseline differences (participant random intercepts). It serves as a comparator that discards longitudinal information. """ from scipy.stats import mannwhitneyu post = pb[pb["visit"] == "Post"] treated = post[post["arm"] == "Treated"] control = post[post["arm"] == "Control"] out = {} n_failures = 0 for g in gene_cols: t_vals = treated[g].values c_vals = control[g].values try: _, pval = mannwhitneyu(t_vals, c_vals, alternative="two-sided") out[g] = { "beta": t_vals.mean() - c_vals.mean(), "pvalue": pval, } except (ValueError, TypeError) as exc: logger.debug("Wilcoxon failed for %s: %s", g, exc) n_failures += 1 out[g] = {} if n_failures: logger.warning("Wilcoxon: %d/%d genes failed", n_failures, len(gene_cols)) return out def _run_pseudobulk_ols(pb: pd.DataFrame, gene_cols: list[str]) -> dict: """OLS DiD without participant fixed effects. Intentionally omits participant FE. 
Without accounting for within- participant correlation the resulting standard errors may be mis-sized (conservative or anti-conservative depending on the variance structure). """ import statsmodels.formula.api as smf pb = pb.copy() pb["arm_bin"] = (pb["arm"] == "Treated").astype(float) pb["visit_bin"] = (pb["visit"] == "Post").astype(float) pb["interaction"] = pb["arm_bin"] * pb["visit_bin"] out = {} n_failures = 0 for g in gene_cols: try: # Backtick-quote gene names for formula safety fit = smf.ols(f"Q('{g}') ~ arm_bin + visit_bin + interaction", data=pb).fit() beta = fit.params["interaction"] pval = fit.pvalues["interaction"] ci = fit.conf_int().loc["interaction"] out[g] = { "beta": beta, "pvalue": pval, "ci_lo": ci.iloc[0], "ci_hi": ci.iloc[1], } except (ValueError, np.linalg.LinAlgError, KeyError) as exc: logger.debug("Pseudobulk OLS failed for %s: %s", g, exc) n_failures += 1 out[g] = {} if n_failures: logger.warning("Pseudobulk OLS: %d/%d genes failed", n_failures, len(gene_cols)) return out