Source code for sctrial.convenience

"""Convenience functions for quick trial analysis workflows."""

from __future__ import annotations

import logging
from typing import Literal

import pandas as pd
from anndata import AnnData

from .design import TrialDesign
from .preprocessing import add_log1p_cpm_layer
from .scoring import score_gene_sets
from .stats.did import did_table

# Module-level logger, named after this module via ``__name__``; the
# functions below use it for progress messages and design summaries.
logger = logging.getLogger(__name__)

# Names exported by a star-import of this module. The underscore-prefixed
# helpers defined further down are intentionally not listed.
__all__ = [
    "quick_did",
    "auto_detect_design",
]


def quick_did(
    adata: AnnData,
    module_scores: dict[str, list[str]],
    visits: tuple[str, str],
    participant_col: str = "participant_id",
    visit_col: str = "visit",
    arm_col: str = "arm",
    arm_treated: str = "Treated",
    arm_control: str = "Control",
    celltype_col: str | None = None,
    layer: str = "log1p_cpm",
    counts_layer: str = "counts",
    score_method: Literal["zmean", "mean"] = "zmean",
    min_genes: int = 3,
    **kwargs,
) -> pd.DataFrame:
    """One-line wrapper for the most common DiD workflow.

    This function combines preprocessing, scoring, and DiD analysis
    in a single call for quick exploration.

    Parameters
    ----------
    adata
        AnnData object with trial data.
    module_scores
        Dictionary mapping module names to gene lists.
    visits
        Tuple of (baseline, followup) visit labels. The order matters:
        the first element is the baseline visit and the second is the
        follow-up, which determines the sign of the DiD estimate.
    participant_col
        Column name for participant identifiers.
    visit_col
        Column name for visit labels.
    arm_col
        Column name for treatment arm.
    arm_treated
        Label for treated arm.
    arm_control
        Label for control arm.
    celltype_col
        Optional column name for cell types.
    layer
        Layer name for normalized expression. If not present, will be created.
    counts_layer
        Layer name for raw counts (used if layer needs to be created).
    score_method
        Method for gene set scoring ('mean' or 'zmean').
    min_genes
        Minimum number of overlapping genes for a module score to be
        computed. Modules with fewer overlapping genes get NaN scores.
        Default is 3.
    **kwargs
        Additional arguments passed to did_table().

    Returns
    -------
    pd.DataFrame
        DiD results table.

    Raises
    ------
    ValueError
        If ``visit_col`` or ``arm_col`` is missing from ``adata.obs``, or
        if ``arm_treated`` / ``arm_control`` does not occur in ``arm_col``.

    Examples
    --------
    >>> gene_sets = {
    ...     "OXPHOS": ["COX7A1", "ATP5F1A", "NDUFA1"],
    ...     "Glycolysis": ["PKM", "LDHA", "HK2"]
    ... }
    >>> res = quick_did(
    ...     adata,
    ...     module_scores=gene_sets,
    ...     visits=("V1", "V2")
    ... )
    >>> print(res[["feature", "beta_DiD", "p_DiD", "FDR_DiD"]])
    """
    # Validate that visit_col exists before doing anything else
    if visit_col not in adata.obs.columns:
        raise ValueError(
            f"Column '{visit_col}' not found in adata.obs. "
            f"Available columns: {list(adata.obs.columns[:20])}"
        )

    # Validate the arm column, then the arm labels, before any expensive work.
    if arm_col not in adata.obs.columns:
        raise ValueError(f"Column '{arm_col}' not found in adata.obs.")
    actual_arms = set(adata.obs[arm_col].unique())
    # Stringify before sorting: the column may mix strings with missing
    # values (float NaN), and sorting the raw values would raise TypeError,
    # hiding the ValueError the caller is supposed to see.
    available = sorted(str(a) for a in actual_arms)
    for arg_name, label in (("arm_treated", arm_treated), ("arm_control", arm_control)):
        if label not in actual_arms:
            raise ValueError(
                f"{arg_name}='{label}' not found in column '{arm_col}'. "
                f"Available values: {available}"
            )

    # 1. Create design
    design = TrialDesign(
        participant_col=participant_col,
        visit_col=visit_col,
        arm_col=arm_col,
        arm_treated=arm_treated,
        arm_control=arm_control,
        celltype_col=celltype_col,
    )

    # 2. Ensure preprocessing
    if layer not in adata.layers:
        logger.info("Creating '%s' layer from '%s'...", layer, counts_layer)
        adata = add_log1p_cpm_layer(adata, counts_layer=counts_layer, out_layer=layer)

    # 3. Score gene sets
    logger.info("Scoring %d gene sets...", len(module_scores))
    adata = score_gene_sets(
        adata,
        module_scores,
        layer=layer,
        method=score_method,
        prefix="ms_",
        min_genes=min_genes,
    )

    # 4. Run DiD on the score features; names mirror the "ms_" prefix
    # passed to score_gene_sets above.
    features = [f"ms_{name}" for name in module_scores]
    logger.info("Running DiD for %d features...", len(features))
    return did_table(adata, features=features, design=design, visits=visits, **kwargs)
def _detect_column_patterns() -> dict[str, list[str]]:
    """Return column name patterns for auto-detection.

    Patterns are tried in order: exact match first, then word-boundary
    partial match (see :func:`_find_column`).

    Returns
    -------
    dict[str, list[str]]
        Mapping with keys ``"participant"``, ``"visit"``, ``"arm"`` and
        ``"celltype"``, each holding lowercase candidate names in priority
        order.
    """
    return {
        "participant": [
            "participant_id",
            "participant",
            "patient_id",
            "patient",
            "donor_id",
            "donor",
            "subject_id",
            "subject",
            "sample_id",
        ],
        "visit": [
            "visit",
            "timepoint",
            "time_point",
            "time",
            "day",
            "week",
            "collection_day",
        ],
        "arm": [
            "arm",
            "treatment_arm",
            "arm_id",
            "treatment",
            "group",
            "condition",
        ],
        "celltype": [
            "celltype",
            "cell_type",
            "cluster",
            "annotation",
            "cell_annotation",
            "celltype_major",
        ],
    }


def _auto_detect_arm_labels(
    arms: list[str],
    arm_treated: str | None,
    arm_control: str | None,
) -> tuple[str | None, str | None]:
    """Auto-detect arm labels based on common keywords.

    Each label is classified at most once. A label that contains a control
    keyword, or a negated treatment keyword such as ``"untreated"`` or
    ``"inactive"``, is a control candidate and is never considered for the
    treated role. A label is also never assigned to a role when it already
    fills the other one, so the two returned labels cannot coincide as a
    result of detection.

    Parameters
    ----------
    arms
        Arm labels found in the data.
    arm_treated
        Caller-supplied treated label, or ``None`` to detect it.
    arm_control
        Caller-supplied control label, or ``None`` to detect it.

    Returns
    -------
    tuple[str | None, str | None]
        ``(arm_treated, arm_control)``. Caller-supplied values are returned
        unchanged; a role that could not be detected stays ``None``.
    """
    treated_keywords = ("treat", "drug", "active", "intervention")
    control_keywords = ("control", "placebo", "sham", "vehicle")
    # Substring matching would otherwise read these as treated arms, since
    # "untreated" contains "treat" and "inactive" contains "active".
    negated_keywords = (
        "untreat",
        "nontreat",
        "non-treat",
        "non_treat",
        "no treat",
        "no_treat",
        "no-treat",
        "inactive",
    )

    for arm in arms:
        label = str(arm)
        lowered = label.lower()
        looks_control = any(kw in lowered for kw in control_keywords) or any(
            kw in lowered for kw in negated_keywords
        )
        if looks_control:
            # Control takes precedence: "Drug_control" or "Active control"
            # describe a control arm even though they mention a treatment.
            if arm_control is None and label != arm_treated:
                arm_control = label
            continue
        if any(kw in lowered for kw in treated_keywords):
            if arm_treated is None and label != arm_control:
                arm_treated = label
    return arm_treated, arm_control


def _print_design_summary(
    participant_col: str,
    visit_col: str,
    arm_col: str,
    arm_treated: str,
    arm_control: str,
    celltype_col: str | None,
) -> None:
    """Log a summary of the detected trial design.

    Despite the name, nothing is printed: every line goes through the
    module logger at INFO level, except the verification reminder, which
    is logged at WARNING level.
    """
    logger.info("\n" + "=" * 60)
    logger.info("AUTO-DETECTED TRIAL DESIGN")
    logger.info("=" * 60)
    logger.info("Participant column: %s", participant_col)
    logger.info("Visit column: %s", visit_col)
    logger.info("Arm column: %s", arm_col)
    logger.info(" Treated arm: %s", arm_treated)
    logger.info(" Control arm: %s", arm_control)
    if celltype_col:
        logger.info("Cell type column: %s", celltype_col)
    else:
        logger.info("Cell type column: (not detected)")
    logger.info("=" * 60)
    logger.warning("Please verify this design is correct before using!")
    logger.info("=" * 60 + "\n")
def auto_detect_design(
    adata: AnnData,
    arm_treated: str | None = None,
    arm_control: str | None = None,
) -> TrialDesign:
    """Auto-detect trial design from common column naming patterns.

    Looks for common patterns in column names:

    - participant: ``participant_id``, ``patient_id``, ``donor_id``,
      ``subject_id``, ``sample_id``
    - visit: ``visit``, ``timepoint``, ``time``, ``day``, ``week``,
      ``time_point``
    - arm: ``arm``, ``treatment_arm``, ``arm_id``, ``treatment``, ``group``,
      ``condition``
    - celltype: ``celltype``, ``cell_type``, ``cluster``, ``annotation``,
      ``cell_annotation``

    Column matching uses exact (case-insensitive) match first, then
    word-boundary partial match (pattern must appear at a word boundary
    in the column name, so ``"arm"`` matches ``"arm_id"`` but not
    ``"farm_id"``).

    Parameters
    ----------
    adata
        AnnData object to analyze.
    arm_treated
        Optional: specify the label for treated arm. If not provided, the
        function tries keyword-based detection (e.g. "Treated", "Drug",
        "Active"). Raises if detection fails.
    arm_control
        Optional: specify the label for control arm. If not provided, the
        function tries keyword-based detection (e.g. "Control", "Placebo").
        Raises if detection fails.

    Returns
    -------
    TrialDesign
        Detected design (may need manual adjustment).

    Raises
    ------
    ValueError
        If required columns cannot be detected, or if arm labels cannot be
        determined (more than 2 arms, fewer than 2 arms, no keyword match
        for ambiguous labels, or detection resolving both roles to the
        same label).

    Notes
    -----
    When exactly one of ``arm_treated`` / ``arm_control`` is supplied, the
    arm column holds exactly two string labels, and the supplied label is
    one of them, the missing label is taken to be the other arm. Keyword
    detection is only used when that elimination is not possible.

    Examples
    --------
    >>> design = auto_detect_design(adata)
    >>> print(f"Detected design: {design}")
    >>> # Verify the detected design. If arm labels need adjustment,
    >>> # create a new TrialDesign with corrected values:
    >>> from dataclasses import replace
    >>> design = replace(design, arm_treated="Drug_A", arm_control="Placebo")
    """
    obs_cols = adata.obs.columns.tolist()
    # str() keeps non-string column names (e.g. integers) from crashing here.
    obs_cols_lower = [str(c).lower() for c in obs_cols]

    patterns = _detect_column_patterns()
    participant_col = _find_column(obs_cols, obs_cols_lower, patterns["participant"])
    visit_col = _find_column(obs_cols, obs_cols_lower, patterns["visit"])
    arm_col = _find_column(obs_cols, obs_cols_lower, patterns["arm"])
    celltype_col = _find_column(obs_cols, obs_cols_lower, patterns["celltype"], required=False)

    # Defensive backstop: _find_column raises by default for required
    # columns, so these only trigger if that helper returns None.
    if participant_col is None:
        raise ValueError(
            "Could not detect participant column. Please specify manually or rename column."
        )
    if visit_col is None:
        raise ValueError("Could not detect visit column. Please specify manually or rename column.")
    if arm_col is None:
        raise ValueError("Could not detect arm column. Please specify manually or rename column.")

    # Auto-detect arm labels if not provided
    if arm_treated is None or arm_control is None:
        unique_arms = adata.obs[arm_col].unique()
        n_arms = len(unique_arms)

        if n_arms == 0:
            # Empty data: indexing unique_arms[0] below would raise IndexError.
            raise ValueError(
                f"Found no arm values in column '{arm_col}'. DiD requires at least 2 arms."
            )
        if n_arms == 1:
            raise ValueError(
                f"Found only 1 arm in column '{arm_col}': '{unique_arms[0]}'. "
                "DiD requires at least 2 arms."
            )
        if n_arms > 2:
            raise ValueError(
                f"Found {n_arms} arms in column '{arm_col}': "
                f"{sorted(str(a) for a in unique_arms)}. "
                "Please specify arm_treated and arm_control manually."
            )

        # Exactly two arms. If the caller named one of them, the other is
        # unambiguous; this is elimination, not a guess by encounter order.
        # Only genuine string labels are considered, so a missing value is
        # never promoted to an arm label.
        string_arms = [a for a in unique_arms if isinstance(a, str)]
        known = arm_treated if arm_treated is not None else arm_control
        if known is not None and len(string_arms) == 2 and known in string_arms:
            other = string_arms[1] if string_arms[0] == known else string_arms[0]
            if arm_treated is None:
                arm_treated = str(other)
            else:
                arm_control = str(other)
        else:
            # Try keyword-based detection
            arm_treated, arm_control = _auto_detect_arm_labels(
                [str(a) for a in unique_arms], arm_treated, arm_control
            )

        # If the labels are still unresolved, raise instead of silently
        # assigning by encounter order.
        if arm_treated is None or arm_control is None:
            raise ValueError(
                f"Found 2 arms {sorted(str(a) for a in unique_arms)} in column '{arm_col}' but could not "
                "determine which is treated vs control. "
                "Please specify arm_treated and arm_control."
            )
        if arm_treated == arm_control:
            # A label matching both keyword families must not fill both roles.
            raise ValueError(
                f"Arm detection in column '{arm_col}' resolved both treated and control "
                f"to '{arm_treated}'. Please specify arm_treated and arm_control."
            )

        logger.info(
            "Auto-detected arms: treated='%s', control='%s'",
            arm_treated,
            arm_control,
        )
        logger.warning("Please verify these are correct!")

    # Create design
    design = TrialDesign(
        participant_col=participant_col,
        visit_col=visit_col,
        arm_col=arm_col,
        arm_treated=arm_treated,
        arm_control=arm_control,
        celltype_col=celltype_col,
    )

    # Log the summary so the user can verify the detected design
    _print_design_summary(
        participant_col,
        visit_col,
        arm_col,
        arm_treated,
        arm_control,
        celltype_col,
    )

    return design
def _find_column(
    columns: list[str],
    columns_lower: list[str],
    patterns: list[str],
    required: bool = True,
) -> str | None:
    """Locate the first column that matches any of the given patterns.

    Patterns are processed in the order given. For each pattern two passes
    are made over ``columns_lower`` before moving on to the next pattern:

    1. **Exact match** (case-insensitive): ``col.lower() == pattern``
    2. **Word-boundary partial match**: the pattern must begin at the start
       of the name or right after ``_`` or ``-``. This keeps ``"arm"`` from
       matching ``"farm_id"`` while still matching ``"arm_treatment"`` or
       ``"treatment_arm"``.

    Parameters
    ----------
    columns
        Column names in their original casing.
    columns_lower
        Lowercased column names, index-aligned with ``columns``.
    patterns
        Lowercase candidate names, in priority order.
    required
        Whether a failed search raises (``True``) or returns ``None``.

    Returns
    -------
    str | None
        The matching entry of ``columns``, or ``None`` when nothing matched
        and ``required`` is false.

    Raises
    ------
    ValueError
        If nothing matched and ``required`` is true.
    """
    import re

    for candidate in patterns:
        # Pass 1: exact (case-insensitive) match.
        exact_idx = next(
            (idx for idx, name in enumerate(columns_lower) if name == candidate),
            None,
        )
        if exact_idx is not None:
            return columns[exact_idx]

        # Pass 2: candidate anchored at start of the name or after "_" / "-".
        anchored = re.compile(r"(?:^|[_\-])" + re.escape(candidate))
        partial_idx = next(
            (idx for idx, name in enumerate(columns_lower) if anchored.search(name)),
            None,
        )
        if partial_idx is not None:
            return columns[partial_idx]

    if not required:
        return None

    raise ValueError(
        f"Could not auto-detect column. Tried patterns: {patterns}\n"
        f"Available columns: {columns[:20]}\n"
        "Please specify the column manually when creating TrialDesign."
    )