Source code for sctrial.benchmark.orchestrator

"""Benchmark orchestrator — runs the full simulation grid with parallelization.

Usage::

    from sctrial.benchmark.orchestrator import run_benchmark
    results = run_benchmark(n_jobs=25, output_dir="benchmark_results")

Or from command line::

    python -m sctrial.benchmark.orchestrator --n-jobs 25 --output-dir benchmark_results
"""

from __future__ import annotations

import logging
import multiprocessing as mp
import time
from pathlib import Path

import numpy as np
import pandas as pd

from .simulator import SimulationConfig, simulate_trial

# Module-level logger; per-method failures inside worker iterations are
# reported through it as warnings.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Scenario definitions (from the locked NatMeth benchmark plan)
# ---------------------------------------------------------------------------

# Shared simulation constants used when building the scenario grid.
_N_GENES = 50  # passed as ``n_genes`` in every scenario's config
_N_SIGNAL = 10  # number of genes given a non-zero effect in the DE scenarios
_N_ITERATIONS = 200  # default Monte Carlo iterations per scenario
_MEAN_CELLS = 500  # default ``mean_cells_per_visit`` (overridden in the "varying cells" scenarios)

# Core methods: the default ``methods`` list used by run_benchmark.
CORE_METHODS = [
    "sctrial_did",
    "dreamlet",
    "nebula",
    "wilcoxon_paired",
]

# Excluded from benchmark:
# - edger_qlf: no native repeated-measures support; ~arm*visit without
#   participant blocking is severely conservative (~0% FPR); participant
#   FE designs are rank-deficient with nested participants.
# - limma_voom: duplicateCorrelation crashes at n>=40; participant FE
#   designs are rank-deficient; unblocked ~arm*visit is conservative.
# Both lack proper paired-design handling. dreamlet (mixed model)
# represents the count-based pseudobulk class correctly.

# Internal sensitivity (not headline)
# NOTE(review): _dispatch_method has no branch for "sctrial_mixed", so
# requesting it would raise "Unknown method" there and be recorded as a failed
# fit for every gene — confirm whether a runner is still planned.
INTERNAL_METHODS = ["sctrial_mixed"]


def _make_effects(n_signal: int, beta: float, mixed_sign: bool = False) -> dict:
    """Create effect dict for n_signal genes."""
    effects = {}
    for i in range(n_signal):
        if mixed_sign and i % 2 == 1:
            effects[f"gene_{i}"] = -beta
        else:
            effects[f"gene_{i}"] = beta
    return effects


def build_scenario_grid(design: str = "two_arm") -> list[dict]:
    """Assemble every benchmark scenario for a single design family.

    Each entry of the returned list is a dict with three keys: ``name``
    (short identifier), ``description`` (human-readable summary) and
    ``config_kwargs`` (keyword arguments for the simulation config).
    The "unequal arms" family is only emitted when ``design == "two_arm"``.
    """
    grid: list[dict] = []

    def add(
        name: str,
        description: str,
        n_per_arm: int,
        effects: dict,
        mean_cells: int = _MEAN_CELLS,
        **extra,
    ) -> None:
        # Base keys first, scenario-specific overrides last, so the key order
        # of config_kwargs is the same for every scenario family.
        grid.append(
            {
                "name": name,
                "description": description,
                "config_kwargs": {
                    "design": design,
                    "n_per_arm": n_per_arm,
                    "n_genes": _N_GENES,
                    "effects": effects,
                    "mean_cells_per_visit": mean_cells,
                    **extra,
                },
            }
        )

    # 1. Complete null
    for n in (8, 12, 20, 40, 60):
        add(f"null_n{n}", f"Complete null, n={n} per arm", n, {})

    # 2. Null + nuisance heterogeneity
    for n in (20, 40):
        add(
            f"null_hetero_n{n}",
            f"Null + high participant SD + imbalanced cells, n={n}",
            n,
            {},
            participant_sd=0.8,
            cell_count_mode="lognormal",
            cell_count_cv=0.8,
        )

    # 3. Sparse DE (positive)
    for n in (20, 40, 60):
        for beta in (0.2, 0.5, 1.0):
            add(
                f"de_pos_n{n}_b{beta}",
                f"Sparse DE (positive), n={n}, beta={beta}",
                n,
                _make_effects(_N_SIGNAL, beta, mixed_sign=False),
            )

    # 4. Sparse DE (mixed sign)
    for n in (20, 40):
        for beta in (0.5, 1.0):
            add(
                f"de_mixed_n{n}_b{beta}",
                f"Sparse DE (mixed sign), n={n}, beta={beta}",
                n,
                _make_effects(_N_SIGNAL, beta, mixed_sign=True),
            )

    # 5. Varying cells
    for n_cells in (200, 1000, 5000):
        add(
            f"cells_{n_cells}_n40",
            f"Varying cells ({n_cells}/visit), n=40, beta=0.5",
            40,
            _make_effects(_N_SIGNAL, 0.5),
            mean_cells=n_cells,
        )

    # 6. Unequal arms (two-arm only)
    if design == "two_arm":
        for ratio in ((3, 7), (5, 10), (10, 20)):
            smaller, larger = ratio
            add(
                f"imbal_{smaller}v{larger}",
                f"Unequal arms {smaller}:{larger}, beta=0.5",
                smaller + larger,
                _make_effects(_N_SIGNAL, 0.5),
                arm_ratio=ratio,
            )

    # 7. Missing visits
    for rate in (0.1, 0.2):
        pct = int(rate * 100)
        add(
            f"missing_{pct}pct_n40",
            f"Missing {pct}% post visits, n=40, beta=0.5",
            40,
            _make_effects(_N_SIGNAL, 0.5),
            missing_rate=rate,
        )

    return grid


# ---------------------------------------------------------------------------
# Single-iteration worker (for multiprocessing)
# ---------------------------------------------------------------------------


def _run_single_iteration(args: tuple) -> list[dict]:
    """Simulate one dataset and run every requested method on it.

    Called directly in serial mode and via a process pool otherwise, which is
    why all inputs arrive packed in a single tuple.

    Parameters
    ----------
    args : tuple
        ``(scenario_name, iteration, seed, config_kwargs, methods)``.
        ``config_kwargs`` is copied before use and never mutated.

    Returns
    -------
    list of dict
        One row per (method, gene). A method that raises is logged and
        recorded with NaN estimates, ``converged=False`` and
        ``failure_mode="numerical"`` for every gene; the same defaults fill
        in any gene or field missing from a runner's output.
    """
    import warnings

    scenario_name, iteration, seed, config_kwargs, methods = args

    # For single-arm designs, force time_effect=0 so null scenarios are
    # truly null (the default 0.1 cancels in two-arm DiD but is detected
    # by single-arm methods testing Δ vs 0).
    kw = dict(config_kwargs)
    if kw.get("design") == "single_arm" and "time_effect" not in kw:
        kw["time_effect"] = 0.0

    rows = []
    # Scope the suppression to this call: a bare filterwarnings("ignore")
    # would permanently silence warnings in the caller's process when the
    # benchmark runs serially, and would add another filter entry per call.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        cfg = SimulationConfig(seed=seed, **kw)
        design_type = cfg.design
        sim = simulate_trial(cfg)
        gene_cols = [f"gene_{i}" for i in range(cfg.n_genes)]
        signal_genes = set(cfg.effects.keys())

        for method in methods:
            # perf_counter is monotonic, so a wall-clock adjustment cannot
            # produce a negative or inflated runtime.
            t0 = time.perf_counter()
            try:
                results = _dispatch_method(method, sim, gene_cols, design_type=design_type)
            except Exception as exc:
                # Best-effort by design: one failing method must not abort
                # the iteration for the remaining methods.
                logger.warning(
                    "Method %s failed on %s iter %d: %s", method, scenario_name, iteration, exc
                )
                results = {
                    g: {
                        "beta": np.nan,
                        "pvalue": np.nan,
                        "ci_lo": np.nan,
                        "ci_hi": np.nan,
                        "converged": False,
                        "failure_mode": "numerical",
                    }
                    for g in gene_cols
                }
            elapsed = time.perf_counter() - t0

            for gene in gene_cols:
                r = results.get(gene, {})
                rows.append(
                    {
                        "scenario": scenario_name,
                        "iteration": iteration,
                        "method": method,
                        "gene": gene,
                        "true_beta": cfg.effects.get(gene, 0.0),
                        "is_signal": gene in signal_genes,
                        "estimated_beta": r.get("beta", np.nan),
                        "pvalue": r.get("pvalue", np.nan),
                        "ci_lo": r.get("ci_lo", np.nan),
                        "ci_hi": r.get("ci_hi", np.nan),
                        "converged": r.get("converged", False),
                        "failure_mode": r.get("failure_mode", "numerical"),
                        # The method runs once for all genes; report each
                        # gene's equal share of that time.
                        "runtime_seconds": elapsed / len(gene_cols),
                        "n_per_arm": cfg.n_per_arm,
                        "mean_cells": cfg.mean_cells_per_visit,
                    }
                )

    return rows


def _dispatch_method(
    method: str,
    sim: dict,
    gene_cols: list[str],
    design_type: str = "two_arm",
) -> dict:
    """Route to the correct runner with the appropriate data representation.

    - sctrial, NEBULA: cell-level AnnData
    - edgeR, limma-voom, dreamlet: summed-count pseudobulk (true counts)
    - Wilcoxon (paired delta): mean-expression pseudobulk

    Parameters
    ----------
    design_type : str
        "two_arm" or "single_arm". Determines the statistical model used
        by R-based runners (interaction vs paired visit).
    """
    # Backwards compat: old sim dicts may have "pseudobulk" instead of split keys
    pb_counts = sim.get("pseudobulk_counts", sim.get("pseudobulk"))
    pb_means = sim.get("pseudobulk_means", sim.get("pseudobulk"))

    # Log-transformed pseudobulk means for methods that need expression-scale
    # input (sctrial_did, wilcoxon_paired). This ensures all methods report
    # betas on a comparable log-fold-change scale:
    #   - edgeR/limma/dreamlet internally log-transform raw counts
    #   - NEBULA fits a log-link NB model on raw counts
    #   - sctrial_did and wilcoxon_paired need log-transformed input explicitly
    if pb_means is not None:
        pb_log = pb_means.copy()
        gene_mask = [c for c in gene_cols if c in pb_log.columns]
        pb_log[gene_mask] = np.log1p(pb_log[gene_mask])
    else:
        pb_log = None

    if method == "sctrial_did":
        from .runners import sctrial_did

        # Pass log-pseudobulk DataFrame instead of raw cell-level AnnData
        # so that DiD betas are on log-expression scale, comparable to
        # edgeR/limma/dreamlet log-fold-changes
        return sctrial_did.run(pb_log, gene_cols, from_pseudobulk=True, design_type=design_type)
    elif method == "edger_qlf":
        from .runners import edger_qlf

        return edger_qlf.run(pb_counts, gene_cols, design_type=design_type)
    elif method == "limma_voom":
        from .runners import limma_voom

        return limma_voom.run(pb_counts, gene_cols, design_type=design_type)
    elif method == "dreamlet":
        from .runners import dreamlet_runner

        return dreamlet_runner.run(pb_counts, gene_cols, design_type=design_type)
    elif method == "nebula":
        from .runners import nebula_runner

        return nebula_runner.run(sim["adata"], gene_cols, design_type=design_type)
    elif method == "wilcoxon_paired":
        from .runners import wilcoxon_paired

        return wilcoxon_paired.run(pb_log, gene_cols, design_type=design_type)
    else:
        raise ValueError(f"Unknown method: {method}")


# ---------------------------------------------------------------------------
# Main orchestrator
# ---------------------------------------------------------------------------


def run_benchmark(
    designs: list[str] | None = None,
    methods: list[str] | None = None,
    n_iterations: int = _N_ITERATIONS,
    n_jobs: int = 1,
    output_dir: str | Path = "benchmark_results",
    resume: bool = True,
) -> pd.DataFrame:
    """Run the full NatMeth benchmark grid.

    Parameters
    ----------
    designs : list of str
        Design families to run. Default: ["two_arm", "single_arm"].
    methods : list of str
        Methods to benchmark. Default: CORE_METHODS.
    n_iterations : int
        Monte Carlo iterations per scenario.
    n_jobs : int
        Parallel workers. Use -1 for all cores.
    output_dir : str or Path
        Directory for output CSVs and figures.
    resume : bool
        If True, skip scenarios that already have results in output_dir.

    Returns
    -------
    DataFrame with all results concatenated (empty if nothing was run).

    Notes
    -----
    In-progress results are written to ``<scenario>.csv.partial`` and renamed
    to ``<scenario>.csv`` only once the scenario has finished, so an
    interrupted scenario is re-run on resume rather than reused as complete.

    Seeds are drawn for every scenario, cached or not, so the seeds a scenario
    receives do not depend on which other scenarios were skipped.
    """
    if designs is None:
        designs = ["two_arm", "single_arm"]
    if methods is None:
        methods = CORE_METHODS
    if n_jobs == -1:
        n_jobs = mp.cpu_count()

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    rng = np.random.default_rng(2024)
    all_results = []

    for design in designs:
        scenarios = build_scenario_grid(design)
        print(f"\n{'=' * 60}")
        print(f"Design: {design} — {len(scenarios)} scenarios × {n_iterations} iterations")
        print(f"Methods: {methods}")
        print(f"Parallel workers: {n_jobs}")
        print(f"{'=' * 60}")

        for si, scenario in enumerate(scenarios):
            name = f"{design}__{scenario['name']}"
            csv_path = output_dir / f"{name}.csv"
            # Sidecar for incremental saves; the suffix keeps it out of "*.csv" globs.
            partial_path = csv_path.with_name(csv_path.name + ".partial")

            # Pre-generate seeds. Done before the resume check so the RNG
            # stream advances identically whether or not this scenario is
            # cached; a resumed run then matches a fresh one.
            seeds = [int(rng.integers(0, 2**31)) for _ in range(n_iterations)]

            # Resume support
            if resume and csv_path.exists():
                print(f" [{si + 1}/{len(scenarios)}] {name} — CACHED, skipping")
                existing = pd.read_csv(csv_path)
                all_results.append(existing)
                continue

            print(f" [{si + 1}/{len(scenarios)}] {name}: {scenario['description']}")

            # Build task args
            task_args = [
                (name, it, seeds[it], scenario["config_kwargs"], methods)
                for it in range(n_iterations)
            ]

            t0 = time.time()
            all_rows: list = []
            flush_interval = 20  # write to disk every 20 iterations

            def _process_iteration(i: int, batch: list) -> None:
                all_rows.extend(batch)
                if (i + 1) % flush_interval == 0:
                    elapsed = time.time() - t0
                    eta = elapsed / (i + 1) * (n_iterations - i - 1)
                    print(
                        f" {i + 1}/{n_iterations} iterations "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)"
                    )
                    # Incremental save — protects against crashes. Goes to the
                    # sidecar so a crash cannot leave a truncated file at the
                    # path the resume check treats as complete.
                    pd.DataFrame(all_rows).to_csv(partial_path, index=False)

            if n_jobs == 1:
                for i, args in enumerate(task_args):
                    _process_iteration(i, _run_single_iteration(args))
            else:
                # Use 'spawn' context to avoid fork-inheriting corrupted R/rpy2
                # state. With 'fork', R subprocess calls inside workers
                # produce incorrect results (e.g., edgeR FPR=0.002 instead
                # of 0.05) even when using Rscript subprocess.
                ctx = mp.get_context("spawn")
                with ctx.Pool(n_jobs) as pool:
                    # imap yields in submission order, so ``i`` is the iteration index.
                    for i, batch in enumerate(pool.imap(_run_single_iteration, task_args)):
                        _process_iteration(i, batch)

            elapsed = time.time() - t0
            df = pd.DataFrame(all_rows)
            # Write the full result to the sidecar, then rename it into place
            # so the final CSV only ever appears complete.
            df.to_csv(partial_path, index=False)
            partial_path.replace(csv_path)
            all_results.append(df)
            print(f" Done in {elapsed:.0f}s → {csv_path.name}")

    # Combine all (pd.concat raises on an empty list, e.g. designs=[])
    if all_results:
        combined = pd.concat(all_results, ignore_index=True)
    else:
        combined = pd.DataFrame()
    combined_path = output_dir / "benchmark_combined.csv"
    combined.to_csv(combined_path, index=False)
    print(f"\nAll results saved → {combined_path}")
    print(f"Total rows: {len(combined):,}")

    return combined
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import argparse

    # (flag, add_argument keyword arguments), listed in help-output order.
    option_specs = [
        ("--n-jobs", {"type": int, "default": 1, "help": "Parallel workers (-1 = all cores)"}),
        (
            "--n-iterations",
            {
                "type": int,
                "default": _N_ITERATIONS,
                "help": "Monte Carlo iterations per scenario",
            },
        ),
        (
            "--output-dir",
            {"type": str, "default": "benchmark_results", "help": "Output directory"},
        ),
        (
            "--designs",
            {"nargs": "+", "default": ["two_arm", "single_arm"], "help": "Design families"},
        ),
        (
            "--methods",
            {"nargs": "+", "default": None, "help": "Methods to run (default: all core)"},
        ),
        ("--no-resume", {"action": "store_true", "help": "Don't skip existing results"}),
    ]

    cli = argparse.ArgumentParser(description="NatMeth benchmark: simulation grid")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    logging.basicConfig(level=logging.INFO)

    # --no-resume is the inverse of run_benchmark's ``resume`` flag.
    run_benchmark(
        designs=opts.designs,
        methods=opts.methods,
        n_iterations=opts.n_iterations,
        n_jobs=opts.n_jobs,
        output_dir=opts.output_dir,
        resume=not opts.no_resume,
    )