Source code for sctrial.design

"""Trial design specification: TrialDesign dataclass and design detection."""

from __future__ import annotations

import logging
from collections.abc import Sequence
from dataclasses import dataclass

import pandas as pd
from anndata import AnnData

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

# Public names exported by a star-import of this module.
__all__ = ["TrialDesign"]


def _sorted_labels(labels: set[object]) -> list[object]:
    """Return *labels* as a sorted list suitable for display in messages.

    Arm labels read from an object-dtype column may mix types (for example
    ``str`` and ``int``), which plain ``sorted`` cannot order. In that case
    fall back to a deterministic ordering by type name and string form so
    that building a diagnostic message never raises ``TypeError``.
    """
    items = list(labels)
    try:
        return sorted(items)
    except TypeError:
        return sorted(items, key=lambda value: (type(value).__name__, str(value)))


@dataclass(frozen=True)
class TrialDesign:
    """Describe the trial-design columns and metadata labels in `adata.obs`.

    The `TrialDesign` object centralizes the mapping of your study design to
    the AnnData object. It is used by almost all statistical and plotting
    functions in `sctrial`.
    """

    participant_col: str = "participant_id"
    """Name of the column containing unique participant identifiers."""

    visit_col: str = "visit"
    """Name of the column containing visit or timepoint labels."""

    arm_col: str | None = "arm"
    """Name of the column containing treatment arm assignments.

    Set to ``None`` for single-arm studies that lack an arm column.
    """

    arm_treated: str = "Treated"
    """The label in `arm_col` representing the treatment/experimental group."""

    arm_control: str = "Control"
    """The label in `arm_col` representing the control/placebo group."""

    celltype_col: str | None = "celltype"
    """Optional name of the column containing cell-type annotations."""

    crossover_col: str | None = None
    """Optional name of the column containing boolean-like indicators for crossover cells."""

    baseline_visit: str | None = None
    """Optional default baseline visit label (e.g., 'Baseline', 'V1')."""

    followup_visit: str | None = None
    """Optional default follow-up visit label (e.g., 'Follow-up', 'V2')."""

    def primary_visits(
        self,
        baseline: str | None = None,
        followup: str | None = None,
    ) -> tuple[str, str]:
        """Return (baseline, followup) visit labels.

        Parameters
        ----------
        baseline
            Optional explicit baseline visit label. If None, uses
            ``self.baseline_visit``.
        followup
            Optional explicit follow-up visit label. If None, uses
            ``self.followup_visit``.

        Returns
        -------
        tuple[str, str]
            Tuple of (baseline, followup) visit labels.

        Raises
        ------
        ValueError
            If either label is still unset after applying the defaults.
        """
        b = baseline if baseline is not None else self.baseline_visit
        f = followup if followup is not None else self.followup_visit
        if b is None or f is None:
            raise ValueError(
                "Primary visits not specified. Provide baseline/followup or set "
                "TrialDesign(baseline_visit=..., followup_visit=...)."
            )
        return (b, f)

    def required_cols(
        self,
        *,
        include_celltype: bool = False,
        include_crossover: bool = False,
    ) -> Sequence[str]:
        """Return required obs columns for this design.

        Parameters
        ----------
        include_celltype
            If True, include ``celltype_col`` when it is defined.
        include_crossover
            If True, include ``crossover_col`` when it is defined.

        Returns
        -------
        list[str]
            List of required columns. ``arm_col`` is omitted when it is
            ``None`` (single-arm design).
        """
        cols = [
            c
            for c in (self.participant_col, self.visit_col, self.arm_col)
            if c is not None
        ]
        if include_celltype and self.celltype_col is not None:
            cols.append(self.celltype_col)
        if include_crossover and self.crossover_col is not None:
            cols.append(self.crossover_col)
        return cols

    def validate(
        self,
        adata: AnnData,
        *,
        include_celltype: bool = False,
        include_crossover: bool = False,
        check_arm_labels: bool = True,
    ) -> None:
        """Validate that `adata.obs` contains required columns and labels.

        Parameters
        ----------
        adata
            AnnData object to validate.
        include_celltype
            If True, require ``celltype_col`` in ``adata.obs``.
        include_crossover
            If True, require ``crossover_col`` in ``adata.obs``.
        check_arm_labels
            If True, verify that treated/control labels are present in
            ``arm_col``. Skipped when ``arm_col`` is ``None``.

        Returns
        -------
        None

        Raises
        ------
        KeyError
            If required columns are missing.
        ValueError
            If ``arm_treated`` and ``arm_control`` are identical, or if arm
            labels are not found in ``adata.obs[self.arm_col]``.

        Warns
        -----
        UserWarning
            If ``arm_col`` contains labels other than the treated/control
            labels.
        """
        obs = adata.obs
        required = self.required_cols(
            include_celltype=include_celltype,
            include_crossover=include_crossover,
        )
        missing = [c for c in required if c not in obs.columns]
        if missing:
            raise KeyError(
                f"Missing required obs columns: {missing}. Available: {list(obs.columns)}"
            )

        if not check_arm_labels or self.arm_col is None:
            return

        if self.arm_treated == self.arm_control:
            raise ValueError(
                f"arm_treated and arm_control must be distinct for "
                f"between-arm analyses, got {self.arm_treated!r} for both. "
                f"For single-arm studies, use check_arm_labels=False."
            )

        # Missing values are not arm labels; drop them before comparing.
        arms = set(pd.Series(obs[self.arm_col]).dropna().unique().tolist())
        if (self.arm_treated not in arms) or (self.arm_control not in arms):
            # _sorted_labels keeps this message buildable even when the
            # column mixes label types that plain sorted() cannot order.
            raise ValueError(
                f"Arm labels not found in obs['{self.arm_col}']. "
                f"Expected treated='{self.arm_treated}', control='{self.arm_control}'. "
                f"Observed arms: {_sorted_labels(arms)}"
            )

        extra = arms - {self.arm_treated, self.arm_control}
        if extra:
            import warnings

            # stacklevel=2 attributes the warning to the caller of validate().
            warnings.warn(
                f"obs['{self.arm_col}'] contains arms beyond treated/control: "
                f"{_sorted_labels(extra)}. These participants will be treated as "
                f"control in arm_bin(). Consider subsetting your data to "
                f"only the two arms of interest.",
                UserWarning,
                stacklevel=2,
            )

    def arm_bin(self, obs: pd.DataFrame) -> pd.Series:
        """Return 0/1 treated indicator aligned to obs.index.

        Every row whose ``arm_col`` value equals ``arm_treated`` maps to 1;
        all other rows, including missing values and any additional arms,
        map to 0.

        Parameters
        ----------
        obs
            DataFrame containing the participant-visit data.

        Returns
        -------
        pd.Series
            A Series with 0/1 indicator of treated status.

        Raises
        ------
        ValueError
            If ``arm_col`` is None, or if arm_treated and arm_control are
            the same.
        KeyError
            If arm_col is not in obs.columns.
        """
        if self.arm_col is None:
            raise ValueError(
                "arm_bin() requires arm_col to be set. "
                "Single-arm designs (arm_col=None) do not have arm indicators."
            )
        if self.arm_treated == self.arm_control:
            raise ValueError(
                f"arm_bin() requires distinct arm labels, but "
                f"arm_treated == arm_control == {self.arm_treated!r}."
            )
        if self.arm_col not in obs.columns:
            raise KeyError(f"arm_col '{self.arm_col}' not in obs.")
        return (obs[self.arm_col] == self.arm_treated).astype(int)