# Source code for sctrial.stats.gsea

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
from anndata import AnnData

from ..design import TrialDesign
from .comparisons import between_arm_comparison, within_arm_comparison
from .did import did_table
from .pseudobulk import pseudobulk_did

if TYPE_CHECKING:
    import gseapy as gp

try:
    import gseapy as gp
except ImportError:
    gp = None

# Public entry points of this module; the ranking helpers (leading
# underscore) are intentionally kept private.
__all__ = [
    "run_gsea_cross_sectional", "run_gsea_did", "run_gsea_did_by_celltype",
    "run_gsea_did_multi", "run_gsea_pseudobulk", "run_gsea_within_arm",
]


def _ensure_gseapy() -> None:
    """Raise ``ImportError`` unless the optional ``gseapy`` import succeeded.

    The module-level ``gp`` name is set to ``None`` when ``import gseapy``
    fails, so every public GSEA function calls this guard before using ``gp``.
    """
    if gp is not None:
        return
    raise ImportError(
        "gseapy is required for GSEA functions. Install with 'pip install gseapy'."
    )


def _rank_effect_table(
    res: pd.DataFrame,
    rank_by: str,
    min_units: int,
    *,
    beta_col: str,
    p_col: str,
    se_col: str,
) -> pd.DataFrame:
    """Build a descending gene ranking from a per-feature effect table.

    Shared implementation behind the DiD, between-arm and within-arm ranking
    helpers, which differ only in the names of their effect columns.

    Parameters
    ----------
    res
        Per-feature results with at least ``feature``, ``n_units`` and
        ``beta_col`` columns. ``p_col`` is read only for
        ``rank_by='signed_confidence'`` and ``se_col`` only for
        ``rank_by='tstat'``.
    rank_by
        ``'signed_confidence'`` (sign(beta) * -log10(p)), ``'beta'`` (raw
        effect) or ``'tstat'`` (beta / se).
    min_units
        Rows whose ``n_units`` is below this value (or missing) are dropped.
    beta_col, p_col, se_col
        Column names of the effect estimate, p-value and standard error.

    Returns
    -------
    pd.DataFrame
        Columns ``feature`` and ``rank``, sorted by ``rank`` descending, in the
        two-column form accepted by ``gseapy.prerank(rnk=...)``.

    Raises
    ------
    ValueError
        If ``rank_by`` is unknown, or if no row has ``n_units >= min_units``.
    """
    # Validate the metric first so a typo is reported as such instead of being
    # masked by the "insufficient data" error when the filter leaves no rows.
    if rank_by not in ("signed_confidence", "beta", "tstat"):
        raise ValueError(f"Unknown rank_by: {rank_by}")

    # NaN n_units compares False, so such rows are dropped here as well.
    valid = res[res["n_units"] >= min_units].copy()
    if valid.empty:
        raise ValueError(
            f"No genes have sufficient data (min_units={min_units}). "
            f"Try reducing min_units or checking your data."
        )

    # Missing effect estimates are treated as "no change" (rank 0).
    beta = valid[beta_col].fillna(0)
    if rank_by == "signed_confidence":
        # Missing p-values count as p=1; the 1e-12 offset avoids log10(0).
        valid["rank"] = np.sign(beta) * -np.log10(valid[p_col].fillna(1) + 1e-12)
    elif rank_by == "beta":
        valid["rank"] = beta
    else:  # "tstat"
        # Missing SEs default to 1; the offset guards against division by zero.
        valid["rank"] = beta / (valid[se_col].fillna(1) + 1e-12)

    return valid[["feature", "rank"]].dropna().sort_values("rank", ascending=False)


def _rank_did_results(
    res: pd.DataFrame,
    rank_by: str,
    min_units: int,
) -> pd.DataFrame:
    """Rank DiD results using the ``beta_DiD`` / ``p_DiD`` / ``se_DiD`` columns."""
    return _rank_effect_table(
        res, rank_by, min_units, beta_col="beta_DiD", p_col="p_DiD", se_col="se_DiD"
    )


def _rank_between_arm_results(
    res: pd.DataFrame,
    rank_by: str,
    min_units: int,
) -> pd.DataFrame:
    """Rank between-arm results using the ``beta_arm`` / ``p_arm`` / ``se_arm`` columns."""
    return _rank_effect_table(
        res, rank_by, min_units, beta_col="beta_arm", p_col="p_arm", se_col="se_arm"
    )


def _rank_within_arm_results(
    res: pd.DataFrame,
    rank_by: str,
    min_units: int,
) -> pd.DataFrame:
    """Rank within-arm results using the ``beta_time`` / ``p_time`` / ``se_time`` columns."""
    return _rank_effect_table(
        res, rank_by, min_units, beta_col="beta_time", p_col="p_time", se_col="se_time"
    )


def run_gsea_cross_sectional(
    adata: AnnData,
    gene_sets: str | dict[str, list[str]],
    design: TrialDesign,
    visit: str,
    layer: str | None = None,
    rank_by: str = "tstat",
    min_units: int = 4,
    return_obj: bool = False,
    **kwargs,
) -> pd.DataFrame | gp.Prerank:
    """Perform GSEA on cross-sectional data using between-arm comparison.

    Similar to run_gsea_did but for cross-sectional designs (single visit).
    Uses between_arm_comparison to rank genes by treatment vs control
    difference at a fixed visit, then runs gseapy.prerank.

    Parameters
    ----------
    adata
        AnnData object containing expression data.
    gene_sets
        A library name (e.g. 'KEGG_2021_Human') or a dictionary mapping
        pathway names to gene lists.
    design
        A `TrialDesign` object.
    visit
        The visit label to analyze (single timepoint).
    layer
        Layer to extract gene expression from.
    rank_by
        Metric for ranking genes: 'signed_confidence', 'beta', or 'tstat'.
    min_units
        Minimum number of participants required per gene.
    return_obj
        Whether to return the full gseapy object.
    **kwargs
        Additional parameters passed to gseapy.prerank.

    Returns
    -------
    pd.DataFrame or gseapy.Prerank
        Enrichment results or the gseapy result object.

    Raises
    ------
    ValueError
        If ``rank_by`` is unknown or no gene reaches ``min_units``.
    ImportError
        If gseapy is not installed.
    """
    # Fail fast on a bad ranking metric; otherwise the typo only surfaces after
    # between_arm_comparison has been run over every gene.
    if rank_by not in ("signed_confidence", "beta", "tstat"):
        raise ValueError(f"Unknown rank_by: {rank_by}")
    _ensure_gseapy()

    genes = adata.var_names.tolist()
    res = between_arm_comparison(
        adata,
        visit=visit,
        features=genes,
        design=design,
        layer=layer,
        aggregate="participant_visit",
        standardize=True,
        method="ols",
    )
    ranking = _rank_between_arm_results(res, rank_by=rank_by, min_units=min_units)

    pre_res = gp.prerank(rnk=ranking, gene_sets=gene_sets, **kwargs)
    if return_obj:
        return pre_res
    # gseapy >= 1.0 exposes the results table as ``res2d``.
    if hasattr(pre_res, "res2d"):
        return pre_res.res2d
    return pre_res
def run_gsea_within_arm(
    adata: AnnData,
    gene_sets: str | dict[str, list[str]],
    design: TrialDesign,
    arm: str,
    visits: tuple[str, str],
    layer: str | None = None,
    rank_by: str = "tstat",
    min_units: int = 4,
    return_obj: bool = False,
    **kwargs,
) -> pd.DataFrame | gp.Prerank:
    """Perform GSEA on longitudinal data using within-arm comparison.

    Similar to run_gsea_did but uses within_arm_comparison instead of DiD.
    Ranks genes by pre→post change within a single arm, then runs
    gseapy.prerank.

    Parameters
    ----------
    adata
        AnnData object containing expression data.
    gene_sets
        A library name (e.g. 'KEGG_2021_Human') or a dictionary mapping
        pathway names to gene lists.
    design
        A `TrialDesign` object.
    arm
        The arm to analyze (e.g., design.arm_treated).
    visits
        Tuple of (baseline, followup) visit labels.
    layer
        Layer to extract gene expression from.
    rank_by
        Metric for ranking genes: 'signed_confidence', 'beta', or 'tstat'.
    min_units
        Minimum number of participants required per gene.
    return_obj
        Whether to return the full gseapy object.
    **kwargs
        Additional parameters passed to gseapy.prerank.

    Returns
    -------
    pd.DataFrame or gseapy.Prerank
        Enrichment results or the gseapy result object.

    Raises
    ------
    ValueError
        If ``rank_by`` is unknown or no gene reaches ``min_units``.
    ImportError
        If gseapy is not installed.
    """
    # Fail fast on a bad ranking metric; otherwise the typo only surfaces after
    # within_arm_comparison has been run over every gene.
    if rank_by not in ("signed_confidence", "beta", "tstat"):
        raise ValueError(f"Unknown rank_by: {rank_by}")
    _ensure_gseapy()

    genes = adata.var_names.tolist()
    res = within_arm_comparison(
        adata,
        arm=arm,
        features=genes,
        design=design,
        visits=visits,
        layer=layer,
        aggregate="participant_visit",
        standardize=True,
    )
    ranking = _rank_within_arm_results(res, rank_by=rank_by, min_units=min_units)

    pre_res = gp.prerank(rnk=ranking, gene_sets=gene_sets, **kwargs)
    if return_obj:
        return pre_res
    # gseapy >= 1.0 exposes the results table as ``res2d``.
    if hasattr(pre_res, "res2d"):
        return pre_res.res2d
    return pre_res
def run_gsea_did(
    adata: AnnData,
    gene_sets: str | dict[str, list[str]],
    design: TrialDesign,
    visits: tuple[str, str],
    layer: str | None = None,
    exclude_crossovers: bool = True,
    celltype: str | None = None,
    rank_by: str = "signed_confidence",
    use_bootstrap: bool = False,
    n_boot: int = 999,
    min_units: int = 4,
    return_obj: bool = False,
    **kwargs,
) -> pd.DataFrame | gp.Prerank:
    """Perform Gene Set Enrichment Analysis (GSEA) on trial-aware rankings.

    This function calculates Difference-in-Differences (DiD) effect sizes for
    all genes, ranks them, and performs GSEA using `gseapy.prerank`. This
    approach ensures that enriched pathways represent treatment effects rather
    than baseline differences.

    Parameters
    ----------
    adata
        AnnData object containing expression data.
    gene_sets
        A library name (e.g. 'KEGG_2021_Human') or a dictionary mapping
        pathway names to gene lists.
    design
        A `TrialDesign` object.
    visits
        A tuple of (baseline, followup) visit labels.
    layer
        Layer to extract gene expression from. Recommended to use a
        normalized layer (e.g., 'log1p_cpm').
    exclude_crossovers
        Whether to exclude crossover cells from the DiD ranking.
    celltype
        Optional cell-type label forwarded to `did_table`, e.g. to compute the
        DiD within a single cell type as `run_gsea_did_by_celltype` does.
        ``None`` (default) passes no cell-type restriction.
    rank_by
        Metric for ranking genes:

        - 'signed_confidence': sign(beta_DiD) * -log10(p_DiD). Highlights
          genes with high effect and high significance.
        - 'beta': ranks genes solely by the DiD effect size.
        - 'tstat': ranks genes by the t-statistic (beta_DiD / se_DiD).
    use_bootstrap
        Whether to use Wild Cluster Bootstrap for DiD p-values (the p-values
        only affect the ranking when rank_by is 'signed_confidence').
    n_boot
        Number of bootstrap permutations.
    min_units
        Minimum number of paired participants required for a gene to be
        included in the ranking. Genes with fewer participants are filtered
        out before GSEA. Default is 4.
    return_obj
        Whether to return the full gseapy object. If False (default), returns
        the results DataFrame (`res2d`).
    **kwargs
        Additional parameters passed to `gseapy.prerank` (e.g.,
        `permutation_num`, `outdir`, `min_size`, `max_size`).

    Returns
    -------
    pd.DataFrame or gseapy.Prerank
        A DataFrame of enrichment results (if return_obj=False) or the gseapy
        result object.

    Raises
    ------
    ValueError
        If ``rank_by`` is unknown or no gene reaches ``min_units``.
    ImportError
        If gseapy is not installed.

    Examples
    --------
    >>> res = run_gsea_did(adata, gene_sets="KEGG_2021_Human", design=design, visits=("V1", "V2"))
    >>> print(res.head())
    """
    # Fail fast on a bad ranking metric: did_table runs over every gene (and
    # may bootstrap n_boot times), so a typo should not cost that computation.
    if rank_by not in ("signed_confidence", "beta", "tstat"):
        raise ValueError(f"Unknown rank_by: {rank_by}")
    _ensure_gseapy()

    # 1. Run DiD for all genes
    genes = adata.var_names.tolist()
    res = did_table(
        adata,
        features=genes,
        design=design,
        visits=visits,
        celltype=celltype,
        exclude_crossovers=exclude_crossovers,
        layer=layer,
        aggregate="participant_visit",
        use_bootstrap=use_bootstrap,
        n_boot=n_boot,
    )

    # 2. Drop genes with n_units < min_units and build the ranking
    ranking = _rank_did_results(res, rank_by=rank_by, min_units=min_units)

    # 3. Run GSEA Prerank
    pre_res = gp.prerank(rnk=ranking, gene_sets=gene_sets, **kwargs)
    if return_obj:
        return pre_res
    # Handle both gseapy >= 1.0 (has .res2d) and potentially older versions
    if hasattr(pre_res, "res2d"):
        return pre_res.res2d
    return pre_res
def run_gsea_did_multi(
    adata: AnnData,
    gene_sets: dict[str, str | dict[str, list[str]]],
    design: TrialDesign,
    visits: tuple[str, str],
    **kwargs,
) -> dict[str, pd.DataFrame | gp.Prerank]:
    """Run DiD-ranked GSEA once per gene-set collection.

    Parameters
    ----------
    adata
        AnnData object.
    gene_sets
        Mapping from a label to one gene-set collection. Each value is a
        library name or a dictionary mapping pathway names to gene lists.
    design
        TrialDesign object.
    visits
        Tuple of (baseline, followup) visit labels.
    **kwargs
        Forwarded unchanged to every `run_gsea_did` call. Its own options
        (e.g. `layer`, `rank_by`, `min_units`, `return_obj`) are consumed
        there, and the rest reach `gseapy.prerank` (e.g. `permutation_num`,
        `outdir`, `min_size`, `max_size`).

    Returns
    -------
    dict[str, pd.DataFrame | gp.Prerank]
        One `run_gsea_did` result per label, in the order of `gene_sets`.

    Notes
    -----
    Each collection triggers its own `run_gsea_did` call, so the DiD table is
    recomputed per collection. Any `outdir` in ``kwargs`` is shared by all
    runs.
    """
    _ensure_gseapy()
    return {
        label: run_gsea_did(
            adata,
            gene_sets=collection,
            design=design,
            visits=visits,
            **kwargs,
        )
        for label, collection in gene_sets.items()
    }
def run_gsea_did_by_celltype(
    adata: AnnData,
    gene_sets: str | dict[str, list[str]],
    design: TrialDesign,
    visits: tuple[str, str],
    celltypes: list[str] | None = None,
    **kwargs,
) -> dict[str, pd.DataFrame | gp.Prerank]:
    """Run DiD-ranked GSEA separately for each cell type.

    Parameters
    ----------
    adata
        AnnData object.
    gene_sets
        A library name (e.g. 'KEGG_2021_Human') or a dictionary mapping
        pathway names to gene lists. The same gene sets are used for every
        cell type.
    design
        TrialDesign object. `design.celltype_col` must be set.
    visits
        Tuple of (baseline, followup) visit labels.
    celltypes
        Cell types to analyze. If None, uses the sorted non-missing unique
        values of `adata.obs[design.celltype_col]`.
    **kwargs
        Forwarded to every `run_gsea_did` call. Its own options (e.g.
        `layer`, `rank_by`, `min_units`) are consumed there, and the rest
        reach `gseapy.prerank`.

    Returns
    -------
    dict[str, pd.DataFrame | gp.Prerank]
        One `run_gsea_did` result per cell type.

    Raises
    ------
    ValueError
        If `design.celltype_col` is None. Any error raised by `run_gsea_did`
        for an individual cell type (e.g. too few units) propagates.
    """
    _ensure_gseapy()
    if design.celltype_col is None:
        raise ValueError("design.celltype_col must be set for celltype GSEA.")

    targets = (
        celltypes
        if celltypes is not None
        else sorted(adata.obs[design.celltype_col].dropna().unique())
    )
    return {
        ct: run_gsea_did(
            adata,
            gene_sets=gene_sets,
            design=design,
            visits=visits,
            celltype=ct,
            **kwargs,
        )
        for ct in targets
    }
def run_gsea_pseudobulk(
    adata: AnnData,
    gene_sets: str | dict[str, list[str]],
    design: TrialDesign,
    visits: tuple[str, str],
    *,
    celltype_col: str | None = None,
    rank_by: str = "signed_confidence",
    min_units: int = 4,
    return_obj: bool = False,
    **kwargs,
) -> pd.DataFrame | gp.Prerank | dict[str, gp.Prerank]:
    """Run GSEA using pseudobulk DiD results.

    Parameters
    ----------
    adata
        AnnData object.
    gene_sets
        A library name (e.g. 'KEGG_2021_Human') or a dictionary mapping
        pathway names to gene lists.
    design
        TrialDesign object.
    visits
        Tuple of (baseline, followup) visit labels.
    celltype_col
        If provided and the results contain a ``celltype`` column, run GSEA
        separately for each cell type and concatenate results. Cell types
        without enough data are skipped with a warning.
    rank_by
        Metric for ranking genes. One of ``'signed_confidence'``
        (sign(beta_DiD) * -log10(p_DiD), default), ``'beta'`` (DiD effect
        size), or ``'tstat'`` (t-statistic beta_DiD / se_DiD).
    min_units
        Minimum number of paired participants required for a gene to be
        included in the ranking.
    return_obj
        Whether to return the full gseapy object. If False (default), returns
        the results DataFrame (`res2d`).
    **kwargs
        Additional parameters passed to `gseapy.prerank` (e.g.,
        `permutation_num`, `outdir`, `min_size`, `max_size`).

    Returns
    -------
    pd.DataFrame | gp.Prerank | dict[str, gp.Prerank]
        Without the per-celltype split, either the enrichment DataFrame
        (``return_obj=False``) or the gseapy result object.
        With the per-celltype split, either the concatenated DataFrame with an
        added ``celltype`` column (empty if every cell type was skipped), or a
        dict mapping cell type to gseapy object (``return_obj=True``).

    Raises
    ------
    ValueError
        If ``rank_by`` is unknown, or (without the per-celltype split) no
        gene reaches ``min_units``.
    ImportError
        If gseapy is not installed.
    """
    import warnings

    # Validate up front: the per-celltype branch catches ValueError to skip
    # cell types with too little data, which would otherwise also swallow an
    # unknown ``rank_by`` for every cell type and return an empty result.
    if rank_by not in ("signed_confidence", "beta", "tstat"):
        raise ValueError(f"Unknown rank_by: {rank_by}")
    _ensure_gseapy()

    res = pseudobulk_did(
        adata,
        genes=adata.var_names.tolist(),
        design=design,
        visits=visits,
        celltype_col=celltype_col,
    )

    # If celltype_col is provided and results contain per-celltype rows,
    # run GSEA separately for each cell type and concatenate results.
    if celltype_col is not None and "celltype" in res.columns:
        all_results: list[pd.DataFrame] = []
        obj_results: dict[str, gp.Prerank] = {}
        for ct, ct_res in res.groupby("celltype"):
            try:
                ct_ranking = _rank_did_results(ct_res, rank_by=rank_by, min_units=min_units)
            except ValueError as err:
                # Best-effort per cell type, but report what was dropped.
                warnings.warn(
                    f"Skipping celltype {ct!r} in pseudobulk GSEA: {err}",
                    stacklevel=2,
                )
                continue
            ct_pre = gp.prerank(rnk=ct_ranking, gene_sets=gene_sets, **kwargs)
            if return_obj:
                obj_results[ct] = ct_pre
            else:
                ct_df = ct_pre.res2d if hasattr(ct_pre, "res2d") else pd.DataFrame(ct_pre)
                ct_df["celltype"] = ct
                all_results.append(ct_df)
        if return_obj:
            return obj_results
        if not all_results:
            return pd.DataFrame()
        return pd.concat(all_results, ignore_index=True)

    ranking = _rank_did_results(res, rank_by=rank_by, min_units=min_units)
    pre_res = gp.prerank(
        rnk=ranking,
        gene_sets=gene_sets,
        **kwargs,
    )
    if return_obj:
        return pre_res
    # gseapy >= 1.0 exposes the results table as ``res2d``.
    if hasattr(pre_res, "res2d"):
        return pre_res.res2d
    return pre_res