Source code for sctrial.stats.pseudobulk

from __future__ import annotations

import logging
from collections.abc import Sequence
from typing import cast

import numpy as np
import pandas as pd
import scipy.sparse as sp
from anndata import AnnData
from scipy.stats import wilcoxon
from statsmodels.stats.multitest import multipletests

from ..design import TrialDesign
from ._utils import apply_fdr, encode_visit
from .did import did_fit

# Module-level logger; pseudobulk_expression uses it to warn about groups
# dropped for having zero total counts.
logger = logging.getLogger(__name__)

# Public API of this module (the names exported by ``from ... import *``).
__all__ = ["pseudobulk_expression", "pseudobulk_within_arm", "pseudobulk_did", "pseudobulk_export"]


def _get_layer(adata: AnnData, layer: str | None) -> np.ndarray | sp.csr_matrix:
    """Return the matrix whose values are aggregated.

    Parameters
    ----------
    adata
        AnnData object to read from.
    layer
        Key into ``adata.layers``; ``None`` selects ``adata.X`` instead.

    Returns
    -------
    np.ndarray | sp.csr_matrix
        Sparse inputs are converted to CSR format; anything else is returned
        unchanged (same object).
    """
    source = adata.X if layer is None else adata.layers[layer]
    if not sp.issparse(source):
        return source
    return source.tocsr()


def pseudobulk_expression(
    adata: AnnData,
    genes: Sequence[str],
    groupby: Sequence[str],
    counts_layer: str | None = "counts",
    scale: float = 1e6,
    log1p: bool = True,
    min_cells_per_group: int = 1,
    include_n_cells: bool = True,
) -> pd.DataFrame:
    """Compute pseudobulk (log1p-)CPM for a gene panel per group.

    Counts of all cells in a group are summed per gene and in total (over
    every gene of the layer, not just the panel); each gene's summed count is
    then normalized as ``gene_sum / total_sum * scale``.

    Parameters
    ----------
    adata
        AnnData object.
    genes
        Genes to summarize. Names absent from ``adata.var_names`` are skipped
        silently; duplicated names are summarized once.
    groupby
        Columns in ``adata.obs`` used for grouping (e.g., participant, visit,
        cell type).
    counts_layer
        Layer containing raw counts. Defaults to "counts"; ``None`` uses
        ``adata.X``.
    scale
        CPM scale factor (default 1e6).
    log1p
        If True, apply ``log1p`` to CPM.
    min_cells_per_group
        Groups with fewer cells than this are excluded.
    include_n_cells
        If True, include an integer ``n_cells`` column.

    Returns
    -------
    pd.DataFrame
        One row per group: the ``groupby`` columns, one column per gene,
        ``total_counts`` and (optionally) ``n_cells``. Groups whose total
        count is zero are dropped with a warning. An empty DataFrame is
        returned when no requested gene is present or no group survives
        ``min_cells_per_group``.
    """
    # Deduplicate (order-preserving) and keep only genes present in the data.
    genes = [g for g in dict.fromkeys(genes) if g in adata.var_names]
    if not genes:
        return pd.DataFrame()

    group_df = adata.obs[list(groupby)]

    # Positional rows to keep (None = all cells). Positional selection stays
    # correct when obs_names are not unique and avoids copying the whole
    # AnnData (all layers) just to drop small groups.
    keep_rows = None
    if min_cells_per_group > 1:
        sizes = group_df.value_counts()
        keep_groups = sizes[sizes >= min_cells_per_group].index
        in_kept = pd.MultiIndex.from_frame(group_df).isin(keep_groups)
        keep_rows = np.flatnonzero(in_kept)
        group_df = group_df.iloc[keep_rows]
    if group_df.empty:
        return pd.DataFrame()

    X = _get_layer(adata, counts_layer)
    gene_idx = [int(adata.var_names.get_loc(g)) for g in genes]
    # Library size uses all genes of the layer, not only the panel.
    total_counts = np.asarray(X.sum(axis=1), dtype=float).ravel()
    X_panel = X[:, gene_idx]
    if keep_rows is not None:
        total_counts = total_counts[keep_rows]
        X_panel = X_panel[keep_rows]

    group_codes, groups = pd.factorize(pd.MultiIndex.from_frame(group_df))
    n_groups = len(groups)
    n_obs = group_codes.shape[0]

    # One-hot (groups x cells) indicator: indicator @ M sums the rows of M
    # within each group, for both sparse and dense M.
    indicator = sp.csr_matrix(
        (np.ones(n_obs, dtype=float), (group_codes, np.arange(n_obs))),
        shape=(n_groups, n_obs),
    )
    if sp.issparse(X_panel):
        sums = (indicator @ X_panel).toarray()
    else:
        sums = np.asarray(indicator @ np.asarray(X_panel, dtype=float))
    totals = np.asarray(indicator @ total_counts, dtype=float).ravel()
    n_cells = np.bincount(group_codes, minlength=n_groups)

    df_sum = groups.to_frame(index=False)
    df_sum.columns = list(groupby)
    # Add all gene columns at once; per-column inserts fragment the frame.
    df_sum = pd.concat([df_sum, pd.DataFrame(sums, columns=genes)], axis=1)
    df_sum["total_counts"] = totals
    if include_n_cells:
        df_sum["n_cells"] = n_cells.astype(int)

    # CPM is undefined for groups without any counts.
    keep = totals != 0
    if not keep.all():
        n_dropped = int((~keep).sum())
        logger.warning(
            "Dropped %d group(s) with zero total counts before CPM normalization.",
            n_dropped,
        )
        df_sum = df_sum[keep].copy()

    cpm = sums[keep] / totals[keep][:, None] * scale
    if log1p:
        cpm = np.log1p(cpm)
    df_sum[genes] = cpm
    return df_sum


def pseudobulk_export(
    adata: AnnData,
    genes: Sequence[str],
    design: TrialDesign,
    *,
    visits: tuple[str, str] | None = None,
    celltype_col: str | None = None,
    counts_layer: str | None = "counts",
    min_cells_per_group: int = 1,
    log1p: bool = True,
) -> AnnData:
    """Export pseudobulk expression as a new AnnData object.

    Parameters
    ----------
    adata
        Input AnnData.
    genes
        List of genes to include (duplicates are kept once).
    design
        TrialDesign object.
    visits
        Optional (baseline, followup) visits to subset.
    celltype_col
        If provided, aggregate per participant-visit-celltype.
    counts_layer
        Layer to use for counts (default: "counts").
    min_cells_per_group
        Minimum cells per group to include.
    log1p
        Whether to log1p-transform CPM values.

    Returns
    -------
    AnnData
        AnnData with pseudobulk expression in ``.X`` (one row per group, one
        column per available gene) and group metadata in ``.obs``: the
        participant/visit (and, when present, arm and cell-type) columns plus
        the ``total_counts`` and ``n_cells`` QC columns.

    Raises
    ------
    ValueError
        If none of the requested genes appear in the pseudobulk output.
    """
    groupby: list[str] = [design.participant_col, design.visit_col]
    if design.arm_col is not None and design.arm_col in adata.obs.columns:
        groupby.append(design.arm_col)
    if celltype_col is not None:
        groupby.append(celltype_col)

    if visits is not None:
        # A view is enough: the data is only read, so copying every layer
        # of the full object would waste memory.
        adata = adata[adata.obs[design.visit_col].isin(visits).to_numpy()]

    df_sum = pseudobulk_expression(
        adata,
        genes=genes,
        groupby=groupby,
        counts_layer=counts_layer,
        min_cells_per_group=min_cells_per_group,
        log1p=log1p,
    )

    # Genes that survived filtering in pseudobulk_expression, deduplicated;
    # grouping columns must never be mistaken for genes.
    available_genes = [
        g for g in dict.fromkeys(genes) if g in df_sum.columns and g not in groupby
    ]
    if not available_genes:
        raise ValueError("No requested genes found in pseudobulk output.")

    qc_cols = [c for c in ("total_counts", "n_cells") if c in df_sum.columns]
    obs = df_sum[groupby + qc_cols].copy()
    # AnnData requires string obs_names; convert explicitly rather than via
    # AnnData's implicit (warning-emitting) int -> str conversion.
    obs.index = df_sum.index.astype(str)
    var = pd.DataFrame(index=list(available_genes))
    X = df_sum[available_genes].to_numpy(dtype=float)
    return AnnData(X=X, obs=obs, var=var)


def pseudobulk_did(
    adata: AnnData,
    genes: Sequence[str],
    design: TrialDesign,
    visits: tuple[str, str],
    *,
    celltype_col: str | None = None,
    counts_layer: str | None = "counts",
    log1p: bool = True,
    min_cells_per_group: int = 5,
    min_paired: int = 4,
    use_bootstrap: bool = False,
    n_boot: int = 999,
    seed: int = 42,
) -> pd.DataFrame:
    """Run DiD on pseudobulk expression (participant-level aggregates).

    Each participant x visit (optionally per cell type) pseudobulk profile is
    one observation. Within each pool (all cells, or one cell type), only
    participants observed at both visits are kept, and ``did_fit`` is run
    per gene with ``standardize=True``.

    Parameters
    ----------
    adata
        AnnData object.
    genes
        Genes to analyze. Names missing from ``adata.var_names`` are skipped;
        duplicates are tested once.
    design
        TrialDesign object; must define ``arm_col``.
    visits
        Tuple of (baseline, followup) visit labels.
    celltype_col
        Column in ``adata.obs`` with cell types; if given, DiD is run per
        cell type.
    counts_layer
        Layer in ``adata.layers`` with raw counts (``None`` uses ``adata.X``).
    log1p
        Whether to log1p the pseudobulk CPM values.
    min_cells_per_group
        Pseudobulk groups built from fewer cells are discarded.
    min_paired
        Pools with fewer participants observed at both visits are skipped.
    use_bootstrap
        Forwarded to ``did_fit``; recommended for small sample sizes.
    n_boot
        Number of bootstrap resamples, forwarded to ``did_fit``.
    seed
        Random seed forwarded to ``did_fit``.

    Returns
    -------
    pd.DataFrame
        One row per (pool, gene), containing the fields returned by
        ``did_fit`` plus:

        - feature: gene name.
        - beta_DiD: DiD estimate from ``did_fit`` (presumably the arm x visit
          interaction, i.e. the between-arm difference in baseline->followup
          change, on the scale implied by ``standardize=True``).
        - p_DiD: p-value from ``did_fit``.
        - FDR_DiD: Benjamini-Hochberg q-value across all rows.
        - n_units: number of participants used, as reported by ``did_fit``.
        - celltype / FDR_DiD_celltype: only with ``celltype_col``; the cell
          type, and BH q-values computed within each cell type.

        Empty if no gene is present or no pool has enough paired participants.

    Raises
    ------
    ValueError
        If ``design.arm_col`` is None.
    """
    # Deduplicate: a repeated gene would be tested twice and distort the FDR.
    genes = [g for g in dict.fromkeys(genes) if g in adata.var_names]
    if not genes:
        return pd.DataFrame()
    if design.arm_col is None:
        raise ValueError("pseudobulk_did requires a two-arm design (arm_col must not be None)")

    groupby: list[str] = [design.participant_col, design.visit_col, design.arm_col]
    if celltype_col is not None:
        groupby.append(celltype_col)

    pb = pseudobulk_expression(
        adata,
        genes=genes,
        groupby=groupby,
        counts_layer=counts_layer,
        log1p=log1p,
        include_n_cells=True,
    )
    if pb.empty:
        return pd.DataFrame()

    # Keep well-populated groups from the two compared visits only.
    keep = pb[design.visit_col].isin(visits)
    if "n_cells" in pb.columns:
        keep &= pb["n_cells"] >= min_cells_per_group
    pb = pb[keep].copy()
    if pb.empty:
        return pd.DataFrame()

    pb["arm_bin"] = (pb[design.arm_col] == design.arm_treated).astype(int)
    pb = encode_visit(pb, design.visit_col, visits)

    pools = [None] if celltype_col is None else sorted(pb[celltype_col].dropna().unique())

    rows = []
    for pool in pools:
        df_pool = pb if pool is None else pb[pb[celltype_col] == pool]

        # Paired participants: judged on row presence at each visit (not on
        # a single gene's values) so pairing is consistent across genes.
        presence = (
            df_pool.groupby([design.participant_col, design.visit_col], observed=True)
            .size()
            .unstack(fill_value=0)
            .reindex(columns=list(visits), fill_value=0)
        )
        paired = presence.index[(presence > 0).all(axis=1)]
        df_pool = df_pool[df_pool[design.participant_col].isin(paired)].copy()
        if df_pool[design.participant_col].nunique() < min_paired:
            continue

        for g in genes:
            out = did_fit(
                df_pool,
                y=g,
                unit=design.participant_col,
                time="visit_num",
                arm_bin="arm_bin",
                covariates=None,
                standardize=True,
                use_bootstrap=use_bootstrap,
                n_boot=n_boot,
                seed=seed,
            )
            row = dict(out)
            row["feature"] = g
            # n_units is already set by did_fit (post-row-drop count)
            if pool is not None:
                row["celltype"] = pool
            rows.append(row)

    res = pd.DataFrame(rows)
    if res.empty:
        return res

    res = apply_fdr(res, p_col="p_DiD", fdr_col="FDR_DiD")

    # Additional BH correction within each cell type.
    if celltype_col is not None:
        res["FDR_DiD_celltype"] = np.nan
        for _, sub in res.groupby("celltype", observed=True):
            pvals = sub["p_DiD"].dropna()
            if not pvals.empty:
                res.loc[pvals.index, "FDR_DiD_celltype"] = multipletests(
                    pvals.to_numpy(), method="fdr_bh"
                )[1]

    return res.reset_index(drop=True)


def pseudobulk_within_arm(
    adata: AnnData,
    genes: Sequence[str],
    participant_col: str,
    visit_col: str,
    visits: Sequence[str],
    celltype_col: str,
    counts_layer: str | None = "counts",
    min_paired: int = 3,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Compute within-arm pseudobulk deltas and Wilcoxon signed-rank tests.

    Expression is log1p-CPM pseudobulk per participant x visit x cell type.
    For each cell type, participants observed at both ``visits[0]`` and
    ``visits[1]`` are kept. For each gene, the paired delta
    ``followup - baseline`` is tested against zero with
    ``scipy.stats.wilcoxon``.

    Parameters
    ----------
    adata
        AnnData object.
    genes
        Genes to analyze. Names missing from ``adata.var_names`` are skipped;
        duplicates are tested once.
    participant_col
        Column in ``adata.obs`` with participant IDs.
    visit_col
        Column in ``adata.obs`` with visit labels.
    visits
        (baseline, followup) visit labels.
    celltype_col
        Column in ``adata.obs`` with cell types.
    counts_layer
        Layer in ``adata.layers`` with raw counts (``None`` uses ``adata.X``).
    min_paired
        Cell types with fewer paired participants are skipped, and tests with
        fewer deltas get ``p_time = NaN``.

    Returns
    -------
    tuple[pd.DataFrame, pd.DataFrame]
        A tuple of ``(summary_df, delta_long_df)``.
        ``summary_df`` columns: celltype, feature, n_units, mean_delta,
        median_delta, p_time, FDR_time.
        ``delta_long_df`` columns: celltype, feature, participant_id, delta.
        If there are no valid pairs, both DataFrames are empty.
    """
    pb = pseudobulk_expression(
        adata,
        genes=genes,
        groupby=[participant_col, visit_col, celltype_col],
        counts_layer=counts_layer,
        log1p=True,
    )
    if pb.empty:
        return pd.DataFrame(), pd.DataFrame()

    # pseudobulk_expression silently drops genes missing from var_names; only
    # test genes that actually have a column (pivoting on a missing one
    # raised KeyError).
    group_cols = {participant_col, visit_col, celltype_col}
    features = [g for g in dict.fromkeys(genes) if g in pb.columns and g not in group_cols]
    baseline, followup = visits[0], visits[1]

    rows = []
    deltas = []
    for ct in pb[celltype_col].unique():
        sub = pb[pb[celltype_col] == ct]
        if sub.empty:  # e.g. NaN cell type never compares equal
            continue

        # Participants with rows at both visits. reindex turns an absent
        # visit into an all-zero column; the previous scalar fallback made
        # the mask a plain bool (KeyError when both visits were absent).
        presence = (
            sub.groupby([participant_col, visit_col], observed=True)
            .size()
            .unstack(fill_value=0)
            .reindex(columns=[baseline, followup], fill_value=0)
        )
        paired = presence.index[(presence > 0).all(axis=1)]
        sub = sub[sub[participant_col].isin(paired)]
        if sub[participant_col].nunique() < min_paired:
            continue

        for g in features:
            w = sub.pivot_table(index=participant_col, columns=visit_col, values=g, aggfunc="mean")
            if baseline not in w.columns or followup not in w.columns:
                continue
            delta = (w[followup] - w[baseline]).dropna()
            deltas.extend(
                {"celltype": ct, "feature": g, "participant_id": pid, "delta": float(dv)}
                for pid, dv in delta.items()
            )

            p_val = np.nan
            if len(delta) >= min_paired:
                try:
                    _, p_val = wilcoxon(delta.to_numpy())
                except (ValueError, TypeError):
                    # e.g. all deltas zero on older SciPy versions
                    p_val = np.nan

            rows.append(
                {
                    "celltype": ct,
                    "feature": g,
                    "n_units": int(len(delta)),
                    "mean_delta": float(delta.mean()) if len(delta) else np.nan,
                    "median_delta": float(delta.median()) if len(delta) else np.nan,
                    "p_time": float(p_val),
                }
            )

    summary = pd.DataFrame(rows)
    if not summary.empty:
        summary = apply_fdr(summary, p_col="p_time", fdr_col="FDR_time")
    delta_long = pd.DataFrame(deltas)
    return summary, delta_long