Source code for sctrial.benchmark.simulator

"""Hierarchical gamma-Poisson simulator for scRNA-seq clinical trial data.

Generates realistic cell-level count data for benchmarking, inspired by
rescueSim's generative framework (Crowell et al., 2020) but implemented
natively in Python for computational efficiency.

The simulator produces data within a single cell type — matching how
edgeR, dreamlet, NEBULA, and sctrial are actually applied in practice.

Design families
---------------
- **Two-arm longitudinal**: Treated/Control × Pre/Post, β₂ = interaction
- **Single-arm paired**: All participants same arm, Pre/Post, β₁ = time

Generative model (per gene g, participant i, visit j, cell k)::

    Library size:         L_ijk ~ LogNormal(log(target_lib), σ_lib)
    Participant RE:       α_ig  ~ N(0, σ²_participant)
    Log-mean:             log(μ_ijgk) = β₀g + α_ig + β₁·Post_j
                                        + β₂g·(Treat_i × Post_j) + log(L_ijk)
    Overdispersion:       Y_ijgk ~ NegBin(μ_ijgk, θ_g)
    Cell counts:          n_cells_ij ~ Poisson(λ) or LogNormal (imbalanced)
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Literal

import anndata as ad
import numpy as np
import pandas as pd
from scipy import sparse

logger = logging.getLogger(__name__)


@dataclass
class SimulationConfig:
    """Settings that fully determine one simulated trial.

    Parameters
    ----------
    design : {"two_arm", "single_arm"}
        Trial design family.
    n_per_arm : int
        Number of participants in each arm.  A balanced two-arm trial
        therefore has ``2 * n_per_arm`` participants; a single-arm trial has
        exactly ``n_per_arm``.
    n_genes : int
        Number of genes to simulate (named ``gene_0`` … ``gene_{n_genes-1}``
        by ``simulate_trial``).
    effects : dict[str, float]
        Maps gene name to its true effect size.  Any gene that is not listed
        is a null gene (β₂ = 0).
    mean_cells_per_visit : int
        Expected number of cells for each participant-visit.
    cell_count_mode : {"poisson", "lognormal"}
        Distribution of the cell count across participant-visits.
    cell_count_cv : float
        Coefficient of variation of the lognormal cell counts; unused in
        ``"poisson"`` mode.
    arm_ratio : tuple[int, int] | None
        Treated:control allocation for two-arm designs, e.g. ``(3, 7)``.
        ``simulate_trial`` uses the two entries directly as the number of
        treated and control participants.  ``None`` means equal allocation
        of ``n_per_arm`` each.
    missing_rate : float
        Fraction of participants without a Post visit (0.0 = no attrition).
    dispersion_mode : {"calibrated", "fixed", "extreme"}
        Strategy used to draw the gene-specific dispersion θ_g.
    dispersion_fixed : float
        θ shared by every gene when ``dispersion_mode="fixed"``
        (``simulate_trial`` clips θ to ``[0.01, 50]``).
    participant_sd : float
        Standard deviation of the participant random intercept
        (σ_participant).
    baseline_mean : float
        Mean of the log-scale gene baselines β₀g.
    baseline_sd : float
        Standard deviation of β₀g across genes.
    time_effect : float
        Global time effect β₁, shared by all genes.
    target_library_size : int
        Target library size of a cell.
    library_size_sd : float
        Standard deviation of the log library size.
    seed : int
        Seed of the random number generator.
    """

    # --- Trial layout ---
    design: Literal["two_arm", "single_arm"] = "two_arm"
    n_per_arm: int = 20

    # --- Genes and their true effects ---
    n_genes: int = 50
    effects: dict[str, float] = field(default_factory=dict)

    # --- Cells per participant-visit ---
    mean_cells_per_visit: int = 500
    cell_count_mode: Literal["poisson", "lognormal"] = "poisson"
    cell_count_cv: float = 0.5

    # --- Allocation and attrition ---
    arm_ratio: tuple[int, int] | None = None
    missing_rate: float = 0.0

    # --- Overdispersion ---
    dispersion_mode: Literal["calibrated", "fixed", "extreme"] = "calibrated"
    dispersion_fixed: float = 10.0

    # --- Mean structure ---
    participant_sd: float = 0.3
    baseline_mean: float = 2.0
    baseline_sd: float = 1.0
    time_effect: float = 0.1

    # --- Library size ---
    target_library_size: int = 10000
    library_size_sd: float = 0.3

    # --- Reproducibility ---
    seed: int = 42
def simulate_trial(cfg: SimulationConfig) -> dict:
    """Generate a complete simulated trial dataset.

    Counts are drawn cell by cell from a gamma-Poisson (negative binomial)
    model whose log-mean is
    ``β₀g + α_ig + β₁·Post + β₂g·(Treated × Post) + log(library size)``.
    Random draws happen in a fixed order, so the result is fully determined
    by ``cfg`` (including ``cfg.seed``).

    Parameters
    ----------
    cfg : SimulationConfig
        Simulation settings.  For two-arm designs with ``arm_ratio`` set, its
        two entries are used directly as the number of treated and control
        participants (``n_per_arm`` is then ignored).

    Returns
    -------
    dict with keys:
        "adata" : AnnData — cell-level counts (X = raw counts as CSR matrix,
            obs has ``participant``, ``arm``, ``visit``, ``library_size``)
        "pseudobulk_means" : DataFrame — participant-visit mean expression
            (for sctrial/Wilcoxon)
        "pseudobulk_counts" : DataFrame — participant-visit summed counts
            (for edgeR/limma/dreamlet)
        "truth" : dict[str, float] — gene → true interaction effect
        "config" : SimulationConfig

    Notes
    -----
    Entries of ``cfg.effects`` whose key is not one of the simulated gene
    names (``gene_0`` … ``gene_{n_genes-1}``) cannot be applied; they are
    reported through a log warning instead of being dropped silently.
    """
    rng = np.random.default_rng(cfg.seed)

    # --- Participant structure ---
    # n_per_arm = participants PER ARM (not total)
    if cfg.design == "two_arm":
        if cfg.arm_ratio is not None:
            # arm_ratio entries are taken as absolute participant counts,
            # e.g. (3, 7) -> 3 treated and 7 control participants.
            n_treat, n_ctrl = cfg.arm_ratio[0], cfg.arm_ratio[1]
        else:
            n_treat = n_ctrl = cfg.n_per_arm
    else:
        n_treat, n_ctrl = cfg.n_per_arm, 0
    arms = ["Treated"] * n_treat + ["Control"] * n_ctrl

    n_total = len(arms)
    participant_ids = [f"P{i:03d}" for i in range(n_total)]
    visits = ["Pre", "Post"]

    # --- Gene parameters ---
    gene_names = [f"gene_{i}" for i in range(cfg.n_genes)]

    # Baseline expression: log-rate = log(gene_mean / library_size).
    # baseline_mean and baseline_sd define the distribution of log-rates
    # BEFORE the library-size offset is added.  No clipping — extreme low
    # values produce sparse genes.
    beta0 = rng.normal(cfg.baseline_mean, cfg.baseline_sd, size=cfg.n_genes)

    # Gene-specific dispersion θ_g (small θ = strong overdispersion).
    theta: np.ndarray
    if cfg.dispersion_mode == "fixed":
        theta = np.full(cfg.n_genes, cfg.dispersion_fixed)
    elif cfg.dispersion_mode == "extreme":
        # Very low θ = extremely high variance
        theta = np.asarray(rng.uniform(0.01, 0.5, size=cfg.n_genes))
    else:
        # Calibrated: lognormal around a median of 0.15
        theta = np.asarray(rng.lognormal(np.log(0.15), 1.0, size=cfg.n_genes))
    theta = np.asarray(np.clip(theta, 0.01, 50.0))

    # True effects.  A key that matches no simulated gene would otherwise
    # vanish without a trace, so surface it to the user.
    known_genes = set(gene_names)
    unmatched = [name for name in cfg.effects if name not in known_genes]
    if unmatched:
        logger.warning(
            "Ignoring %d effect(s) for genes that are not simulated: %s",
            len(unmatched),
            unmatched,
        )
    true_effects = {name: cfg.effects.get(name, 0.0) for name in gene_names}
    # β₂ as a vector so the interaction is applied in one array operation.
    beta2 = np.array([true_effects[name] for name in gene_names], dtype=float)

    # Participant random intercepts (per gene)
    alpha = rng.normal(0, cfg.participant_sd, size=(n_total, cfg.n_genes))

    # --- Determine which participant-visits exist (attrition) ---
    missing_post = set()
    if cfg.missing_rate > 0:
        # Any positive rate drops at least one participant's Post visit.
        n_missing = max(1, int(n_total * cfg.missing_rate))
        missing_post = set(
            rng.choice(n_total, size=n_missing, replace=False).tolist()
        )

    # Loop-invariant parameters of the lognormal cell-count distribution;
    # the location is shifted so that the mean equals mean_cells_per_visit.
    lognormal_cells = cfg.cell_count_mode == "lognormal"
    if lognormal_cells:
        sigma_cells = np.sqrt(np.log(1 + cfg.cell_count_cv**2))
        log_mean_cells = np.log(cfg.mean_cells_per_visit) - sigma_cells**2 / 2

    theta_broad = theta[np.newaxis, :]  # (1, G)

    # --- Generate cells ---
    # obs metadata is collected column-wise (one list per column) instead of
    # one dict per cell.
    obs_participant: list = []
    obs_arm: list = []
    obs_visit: list = []
    lib_size_chunks = []
    count_chunks = []

    for i, (pid, arm) in enumerate(zip(participant_ids, arms)):
        is_treated = 1.0 if arm == "Treated" else 0.0
        # β₀ + α_i is shared by both visits of this participant: (1, G)
        gene_base = beta0[np.newaxis, :] + alpha[i, :][np.newaxis, :]

        for visit in visits:
            if visit == "Post" and i in missing_post:
                continue

            # Cell count for this participant-visit
            if lognormal_cells:
                n_cells = int(rng.lognormal(log_mean_cells, sigma_cells))
            else:
                n_cells = rng.poisson(cfg.mean_cells_per_visit)
            n_cells = int(max(10, min(n_cells, 20000)))  # safety bounds

            is_post = 1.0 if visit == "Post" else 0.0

            # Library sizes for each cell
            log_lib = rng.normal(
                np.log(cfg.target_library_size),
                cfg.library_size_sd,
                size=n_cells,
            )

            # log(μ) = β₀ + α_i + β₁·Post + β₂·(Treat×Post) + log(L)
            # Shape: (n_cells, n_genes)
            log_mu = (
                gene_base
                + cfg.time_effect * is_post
                + log_lib[:, np.newaxis]
            )
            log_mu += beta2[np.newaxis, :] * (is_post * is_treated)
            mu = np.exp(log_mu)

            # Gamma-Poisson mixture (equivalent to NegBin, but valid for
            # non-integer θ):
            #   rate  ~ Gamma(shape=θ, scale=μ/θ)
            #   count ~ Poisson(rate)
            rate = rng.gamma(shape=theta_broad, scale=mu / theta_broad)
            counts = rng.poisson(rate).astype(np.int32)

            obs_participant.extend([pid] * n_cells)
            obs_arm.extend([arm] * n_cells)
            obs_visit.extend([visit] * n_cells)
            lib_size_chunks.append(np.exp(log_lib))
            count_chunks.append(counts)

    # --- Assemble cell-level tables ---
    obs_df = pd.DataFrame(
        {
            "participant": obs_participant,
            "arm": obs_arm,
            "visit": obs_visit,
            "library_size": np.concatenate(lib_size_chunks),
        }
    )
    X = np.vstack(count_chunks)

    # --- Pseudobulk aggregation ---
    # Two outputs: means (for sctrial/Wilcoxon) and summed counts
    # (for edgeR/limma/dreamlet).  obs_df keeps its default positional index
    # here, so each group's index values are row positions in X.
    pb_mean_rows = []
    pb_count_rows = []
    for (pid, visit), grp in obs_df.groupby(["participant", "visit"]):
        cell_block = X[grp.index.to_numpy()]
        meta = {"participant": pid, "visit": visit, "arm": grp["arm"].iloc[0]}
        pb_mean_rows.append(
            {**meta, **dict(zip(gene_names, cell_block.mean(axis=0)))}
        )
        pb_count_rows.append(
            {**meta, **dict(zip(gene_names, cell_block.sum(axis=0).tolist()))}
        )

    pseudobulk_means = pd.DataFrame(pb_mean_rows)
    pseudobulk_counts = pd.DataFrame(pb_count_rows)

    # --- Assemble AnnData ---
    # Hand AnnData its own, already string-indexed copy of the metadata so
    # obs_df is never modified behind our back.
    adata_obs = obs_df.copy()
    adata_obs.index = [f"cell_{i}" for i in range(len(adata_obs))]
    adata = ad.AnnData(
        X=sparse.csr_matrix(X),
        obs=adata_obs,
        var=pd.DataFrame(index=gene_names),
    )

    return {
        "adata": adata,
        "pseudobulk_means": pseudobulk_means,
        "pseudobulk_counts": pseudobulk_counts,
        "truth": true_effects,
        "config": cfg,
    }
def _dense_float(matrix) -> np.ndarray:
    """Return *matrix* as a dense float ndarray, densifying sparse input."""
    if sparse.issparse(matrix):
        matrix = matrix.toarray()
    return np.asarray(matrix, dtype=float)


def calibrate_from_real_data(
    adata_real,
    layer: str | None = None,
    count_layer: str | None = None,
    participant_col: str = "participant",
    visit_col: str = "visit",
    seed: int | None = 0,
) -> dict:
    """Extract distributional parameters from real scRNA-seq data.

    Uses raw counts for dispersion/library-size calibration (critical for
    the NB generative model), and normalized expression for ICC estimation.

    Parameters
    ----------
    adata_real : AnnData
        Real dataset.  Only ``.X``, ``.layers``, ``.obs`` and ``.n_vars`` are
        accessed.
    layer : str, optional
        Normalized expression layer (for ICC).  If None, uses ``.X``.
    count_layer : str, optional
        Raw count layer for dispersion and library size calibration.
        If None, auto-detects: ``.X`` is used when its first 100 rows are
        integer-valued, otherwise a ``"counts"`` layer if present.
        If no raw counts are available, estimates from normalized data.
    participant_col : str
        Column name for participant IDs.
    visit_col : str
        Column name for visit/timepoint.
    seed : int or None
        Seed for the random subsample of (at most 50) genes used for the ICC
        estimate.  The default makes repeated calls return identical values;
        pass ``None`` for a fresh, non-reproducible subsample.

    Returns
    -------
    dict with keys:
        "mean_cells_per_visit" : float
        "cell_count_cv" : float
        "participant_icc" : float (median across sampled genes; 0.1 if it
            cannot be estimated)
        "dispersion_median" : float
        "baseline_mean" : float (log-scale, on raw counts)
        "baseline_sd" : float
        "library_size_mean" : float (log-scale, on raw counts)
        "library_size_sd" : float
        "has_raw_counts" : bool
    """
    import warnings

    obs = adata_real.obs

    # --- Find raw counts ---
    raw_source = None
    if count_layer is not None:
        raw_source = adata_real.layers[count_layer]
    else:
        # Auto-detect: check if .X is integer counts (first 100 cells only)
        X_test = adata_real.X
        if sparse.issparse(X_test):
            X_test_dense = X_test[:100].toarray()
        else:
            X_test_dense = X_test[:100]
        if np.allclose(X_test_dense, X_test_dense.astype(int)):
            raw_source = adata_real.X
        elif "counts" in (adata_real.layers or {}):
            raw_source = adata_real.layers["counts"]
    has_raw_counts = raw_source is not None
    X_counts = _dense_float(raw_source) if has_raw_counts else None

    # --- Normalized expression (for ICC) ---
    norm_source = adata_real.layers[layer] if layer else adata_real.X
    if norm_source is raw_source:
        # Same underlying matrix: reuse the dense copy (read-only below)
        # instead of materialising it a second time.
        X_norm = X_counts
    else:
        X_norm = _dense_float(norm_source)

    # --- Cell counts per participant-visit ---
    cell_counts = obs.groupby([participant_col, visit_col]).size()
    mean_cells = cell_counts.mean()
    cv_cells = cell_counts.std() / mean_cells if mean_cells > 0 else 0.5

    # --- Count-based statistics ---
    if has_raw_counts:
        # Library sizes from raw counts
        lib_sizes = X_counts.sum(axis=1)
        log_lib = np.log(lib_sizes + 1)
        avg_lib = lib_sizes.mean()

        # Gene-level stats on counts
        gene_means = X_counts.mean(axis=0)
        gene_vars = X_counts.var(axis=0)

        # Log-rate: log(gene_mean / avg_library_size).
        # This is the baseline parameter β₀g in the generative model
        # BEFORE the library-size offset.  Only expressed genes contribute.
        mask_expressed = gene_means > 0
        log_rates = np.log(gene_means / avg_lib + 1e-10)
        log_means = log_rates[mask_expressed]

        # Dispersion: method of moments on counts
        # Var = mu + mu^2/theta  =>  theta = mu^2 / (var - mu)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            denom = gene_vars - gene_means
            # Genes with var <= mean get a tiny positive denominator so the
            # estimate stays finite (it is then capped by the clip below).
            denom = np.where(denom > 0, denom, 1e-6)
            theta_est = gene_means**2 / denom
            theta_est = np.clip(theta_est, 0.01, 1000)
    else:
        # No raw counts — estimate from normalized data with default
        # dispersions chosen from the overall sparsity.
        zero_frac = (X_norm == 0).mean()

        lib_sizes_norm = X_norm.sum(axis=1)
        log_lib = np.log(np.clip(lib_sizes_norm, 1, None))

        gene_means = X_norm.mean(axis=0)
        log_means = np.log(gene_means + 1)

        # Dispersion cannot be estimated reliably from normalized values:
        # use a moderate default, lowered when the data are very sparse.
        theta_est = np.full_like(gene_means, 5.0)
        if zero_frac > 0.9:
            theta_est[:] = 2.0  # high sparsity → high overdispersion

    # --- ICC from normalized expression ---
    # Per gene: between-participant sum of squares divided by the total sum
    # of squares, over a random subsample of genes.
    iccs = []
    groups = np.asarray(obs[participant_col])
    unique_groups = np.unique(groups)
    if len(unique_groups) >= 3:
        # Local generator: reproducible and independent of global RNG state.
        rng = np.random.default_rng(seed)
        n_genes_sample = min(50, adata_real.n_vars)
        gene_idx = rng.choice(adata_real.n_vars, n_genes_sample, replace=False)

        # Group membership does not depend on the gene: compute it once.
        group_masks = [groups == grp for grp in unique_groups]
        group_sizes = [int(mask.sum()) for mask in group_masks]

        for g in gene_idx:
            expr = X_norm[:, g]
            grand_mean = expr.mean()
            ss_total = np.sum((expr - grand_mean) ** 2)
            if ss_total > 0:
                ss_between = sum(
                    size * (expr[mask].mean() - grand_mean) ** 2
                    for mask, size in zip(group_masks, group_sizes)
                )
                iccs.append(ss_between / ss_total)

    return {
        "mean_cells_per_visit": float(mean_cells),
        "cell_count_cv": float(cv_cells),
        "participant_icc": float(np.median(iccs)) if iccs else 0.1,
        "dispersion_median": float(np.median(theta_est)),
        "baseline_mean": float(log_means.mean()),
        "baseline_sd": float(log_means.std()),
        "library_size_mean": float(log_lib.mean()),
        "library_size_sd": float(log_lib.std()),
        "has_raw_counts": has_raw_counts,
    }
def validate_simulator(
    cfg: SimulationConfig,
    adata_real,
    layer: str | None = None,
    participant_col: str = "participant",
    visit_col: str = "visit",
) -> dict:
    """Simulate a trial from ``cfg`` and collect summaries next to real data.

    Both datasets are summarised on a common scale.  When the first 100
    cells of the real matrix are integer-valued it is treated as raw counts
    and the simulated counts are used as they are; otherwise the simulated
    counts are converted to log1p-CPM first.

    Parameters
    ----------
    cfg : SimulationConfig
        Configuration passed to ``simulate_trial``.
    adata_real : AnnData
        Real dataset to compare against.
    layer : str, optional
        Layer of ``adata_real`` to use.  If None, uses ``.X``.
    participant_col : str
        Column of ``adata_real.obs`` holding participant IDs.
    visit_col : str
        Column of ``adata_real.obs`` holding the visit/timepoint.

    Returns
    -------
    dict with keys:
        "cell_level" : dict — per-gene means, variances and zero fractions
            plus per-cell library sizes, each for simulated and real data
        "pseudobulk_level" : dict — cells per participant-visit for both
            datasets, and the flattened simulated pseudobulk means
    """
    simulated = simulate_trial(cfg)
    sim_adata = simulated["adata"]

    sim_X_raw = sim_adata.X
    if sparse.issparse(sim_X_raw):
        sim_X_raw = sim_X_raw.toarray()

    real_X = adata_real.layers[layer] if layer else adata_real.X
    if sparse.issparse(real_X):
        real_X = real_X.toarray()
    real_X = np.asarray(real_X, dtype=float)

    # Integer-valued head of the real matrix => treat it as raw counts.
    real_head = real_X[:100]
    if np.allclose(real_head, real_head.astype(int)):
        sim_X = sim_X_raw.astype(float)
    else:
        # Bring the simulated counts to log1p-CPM to match normalized data;
        # empty cells get a library size of 1 to avoid dividing by zero.
        lib = sim_X_raw.sum(axis=1, keepdims=True).astype(float)
        lib = np.where(lib > 0, lib, 1.0)
        sim_X = np.log1p(sim_X_raw / lib * 1e6)

    # --- Cell-level summaries (per gene) ---
    real_means = real_X.mean(axis=0)
    cell_level = {
        "sim_gene_means": sim_X.mean(axis=0),
        "sim_gene_vars": sim_X.var(axis=0),
        "real_gene_means": real_means,
        "real_gene_vars": real_X.var(axis=0),
        "sim_zero_fraction": (sim_X == 0).mean(axis=0),
        "real_zero_fraction": (real_X == 0).mean(axis=0),
    }

    # --- Library size on a MATCHED number of genes ---
    # A 50-gene simulation cannot be compared with whole-transcriptome
    # totals, so when the real data have more genes, keep as many real genes
    # as were simulated, spaced evenly along the ranking by mean expression.
    n_sim_genes = sim_X.shape[1]
    n_real_genes = real_X.shape[1]
    if n_sim_genes < n_real_genes:
        genes_by_mean = np.argsort(real_means)
        picks = np.linspace(0, n_real_genes - 1, n_sim_genes, dtype=int)
        real_lib = real_X[:, genes_by_mean[picks]].sum(axis=1)
    else:
        real_lib = real_X.sum(axis=1)
    cell_level["sim_library_sizes"] = sim_X.sum(axis=1)
    cell_level["real_library_sizes"] = real_lib

    # --- Pseudobulk-level summaries ---
    sim_groups = sim_adata.obs.groupby(["participant", "visit"])
    real_groups = adata_real.obs.groupby([participant_col, visit_col])

    sim_pb = simulated["pseudobulk_means"]
    gene_cols = [col for col in sim_pb.columns if col.startswith("gene_")]

    pseudobulk_level = {
        "sim_cell_counts": sim_groups.size().values,
        "real_cell_counts": real_groups.size().values,
        "sim_pb_means": sim_pb[gene_cols].values.flatten(),
    }

    return {"cell_level": cell_level, "pseudobulk_level": pseudobulk_level}