Source code for sctrial.validation

"""Data validation utilities for trial analysis."""

from __future__ import annotations

import logging
from collections.abc import Sequence
from typing import Any

import numpy as np
import pandas as pd
from anndata import AnnData

from .datasets import count_paired
from .design import TrialDesign
from .utils import get_counts_matrix

# Logger named after this module.
logger = logging.getLogger(__name__)

# Names that make up this module's public API.
__all__ = [
    "TrialDataValidator",
    "validate_adata",
    "validate_features",
    "diagnose_trial_data",
    "check_covariate_balance",
]


class TrialDataValidator:
    """Comprehensive validation for trial analysis data.

    Both checks are static methods, so they can be called on the class
    itself or on an instance.
    """

    @staticmethod
    def validate_adata(
        adata: AnnData,
        design: TrialDesign,
        strict: bool = False,
    ) -> list[str]:
        """Validate AnnData object for trial analysis.

        Parameters
        ----------
        adata
            AnnData object to validate.
        design
            TrialDesign specifying column names.
        strict
            If True, structural problems (a missing required or cell type
            column, fewer than 4 participants, fewer than 2 visits) raise
            immediately. If False, they are returned as messages. Missing
            values, the small-sample hint and the raw-counts hint are only
            ever reported as messages.

        Returns
        -------
        List of warning/error messages.

        Raises
        ------
        KeyError
            If ``strict`` and a required or cell type column is absent.
        ValueError
            If ``strict`` and there are too few participants or visits.

        Examples
        --------
        >>> validator = TrialDataValidator()
        >>> issues = validator.validate_adata(adata, design, strict=False)
        >>> if issues:
        ...     print(f"Found {len(issues)} issues:")
        ...     for issue in issues:
        ...         print(f" - {issue}")
        """
        issues: list[str] = []
        obs = adata.obs

        # Check required columns
        required_cols = [
            c
            for c in [design.participant_col, design.visit_col, design.arm_col]
            if c is not None
        ]
        for col in required_cols:
            if col not in obs.columns:
                msg = f"Required column '{col}' not found in adata.obs"
                if strict:
                    raise KeyError(msg)
                issues.append(msg)

        # Check for missing data
        for col in required_cols:
            if col in obs.columns:
                n_missing = obs[col].isna().sum()
                if n_missing > 0:
                    pct = 100 * n_missing / len(adata)
                    issues.append(f"{n_missing} ({pct:.1f}%) missing values in '{col}'")

        # Check sample size
        if design.participant_col in obs.columns:
            n_participants = obs[design.participant_col].nunique()
            if n_participants < 4:
                msg = f"Only {n_participants} participants (minimum 4 required for DiD)"
                if strict:
                    raise ValueError(msg)
                issues.append(msg)
            elif n_participants < 10:
                issues.append(
                    f"Only {n_participants} participants; consider using "
                    "bootstrap (use_bootstrap=True) for robust inference"
                )

        # Check paired data
        if design.visit_col in obs.columns:
            # Drop missing labels first: ``unique()`` keeps NaN, which would
            # otherwise be counted as a visit of its own.
            visits = obs[design.visit_col].dropna().unique()
            if len(visits) < 2:
                msg = f"Only {len(visits)} visit(s) found (need >= 2 for longitudinal analysis)"
                if strict:
                    raise ValueError(msg)
                issues.append(msg)

        # Check for counts layer
        counts, _source = get_counts_matrix(adata)
        if counts is None:
            issues.append(
                "No raw counts found in adata.layers['counts'], adata.raw.X, "
                "adata.layers['raw'], or adata.X. "
                "For best results, provide raw counts in adata.layers['counts']"
            )

        # Check cell type column if specified
        if design.celltype_col:
            if design.celltype_col not in obs.columns:
                msg = f"Cell type column '{design.celltype_col}' not found in adata.obs"
                if strict:
                    raise KeyError(msg)
                issues.append(msg)

        return issues

    @staticmethod
    def validate_features(
        adata: AnnData,
        features: Sequence[str],
        allow_missing: bool = False,
    ) -> tuple[list[str], list[str]]:
        """Validate feature names.

        A feature is valid when it names either a column of ``adata.obs``
        or an entry of ``adata.var_names``.

        Parameters
        ----------
        adata
            AnnData object.
        features
            List of feature names to validate. A single bare string is
            treated as one feature name rather than a sequence of characters.
        allow_missing
            If False, raises error for missing features.

        Returns
        -------
        Tuple of (valid_features, missing_features), each in input order.

        Raises
        ------
        KeyError
            If any feature is missing and ``allow_missing`` is False.

        Examples
        --------
        >>> valid, missing = TrialDataValidator.validate_features(
        ...     adata, ["Gene1", "Gene2", "NonExistent"]
        ... )
        >>> print(f"Valid: {valid}, Missing: {missing}")
        """
        # A str is itself a Sequence[str]; iterating it would validate each
        # character separately, which is never what the caller meant.
        if isinstance(features, str):
            features = [features]

        valid: list[str] = []
        missing: list[str] = []
        for feat in features:
            if feat in adata.obs.columns or feat in adata.var_names:
                valid.append(feat)
            else:
                missing.append(feat)

        if missing and not allow_missing:
            # Only the first ten names of each kind are shown to keep the
            # message readable on large datasets.
            raise KeyError(
                f"Features not found: {missing[:10]}\n"
                f"Available obs columns: {list(adata.obs.columns)[:10]}\n"
                f"Available var names: {list(adata.var_names)[:10]}\n"
                f"Hint: Check spelling and case sensitivity"
            )

        return valid, missing
def validate_adata(
    adata: AnnData,
    design: TrialDesign,
    strict: bool = False,
) -> list[str]:
    """Check an AnnData object for problems before trial analysis.

    Module-level shortcut that forwards its arguments unchanged to
    ``TrialDataValidator.validate_adata()``.

    Parameters
    ----------
    adata
        AnnData object to validate.
    design
        TrialDesign specifying column names.
    strict
        If True, raises exceptions. If False, returns warnings.

    Returns
    -------
    List of warning/error messages.
    """
    issues = TrialDataValidator.validate_adata(adata, design, strict=strict)
    return issues
def validate_features(
    adata: AnnData,
    features: Sequence[str],
    allow_missing: bool = False,
) -> tuple[list[str], list[str]]:
    """Split feature names into those present and those absent.

    Module-level shortcut that forwards its arguments unchanged to
    ``TrialDataValidator.validate_features()``.

    Parameters
    ----------
    adata
        AnnData object.
    features
        List of feature names to validate.
    allow_missing
        If False, raises error for missing features.

    Returns
    -------
    Tuple of (valid_features, missing_features).
    """
    valid_and_missing = TrialDataValidator.validate_features(
        adata,
        features,
        allow_missing,
    )
    return valid_and_missing
def _sorted_labels(column: pd.Series) -> list[Any]:
    """Return the sorted distinct non-missing values of *column*."""
    # NaN/None must go before sorting: mixed with strings they make
    # ``sorted`` raise TypeError, and they are not real labels anyway.
    return sorted(column.dropna().unique())


def diagnose_trial_data(
    adata: AnnData,
    design: TrialDesign,
    verbose: bool = True,
) -> dict[str, Any]:
    """Comprehensive diagnostic report for trial data.

    Every section is computed only when the columns it needs are present in
    ``adata.obs``, so the function reports on incomplete data instead of
    failing on it.

    Parameters
    ----------
    adata
        AnnData object to diagnose.
    design
        TrialDesign object.
    verbose
        If True, prints diagnostic report.

    Returns
    -------
    dict
        Diagnostic summary. Always present:

        - n_cells, n_genes
        - warnings (list of strings)
        - recommendations (list of strings)

        Present when the corresponding columns exist:

        - n_participants
        - n_visits, visits (sorted non-missing visit labels)
        - paired_participants (dict of visit pairs -> counts; needs the
          participant column and at least two visits)
        - n_arms, arms (sorted non-missing arm labels)
        - cells_per_participant_visit_mean / _median / _min, or
          cells_per_participant_mean when there is no visit column
        - n_celltypes, celltype_distribution

    Examples
    --------
    >>> diagnostics = diagnose_trial_data(adata, design, verbose=True)
    >>> if diagnostics['warnings']:
    ...     print("Warnings found:")
    ...     for w in diagnostics['warnings']:
    ...         print(f" - {w}")
    """
    report: dict[str, Any] = {}
    warnings_list: list[str] = []
    recommendations: list[str] = []

    obs = adata.obs
    has_participants = design.participant_col in obs.columns
    has_visits = design.visit_col in obs.columns

    # Basic counts
    report["n_cells"] = adata.n_obs
    report["n_genes"] = adata.n_vars

    if has_participants:
        report["n_participants"] = obs[design.participant_col].nunique()
        if report["n_participants"] < 4:
            warnings_list.append(
                f"Only {report['n_participants']} participants (minimum 4 required)"
            )
            recommendations.append("Consider pooling data from multiple cohorts")
        elif report["n_participants"] < 10:
            recommendations.append(
                "Sample size is small; use bootstrap inference (use_bootstrap=True)"
            )

    if has_visits:
        n_missing_visits = int(obs[design.visit_col].isna().sum())
        if n_missing_visits:
            warnings_list.append(
                f"{n_missing_visits} cells have missing values in '{design.visit_col}'"
            )

        visits = _sorted_labels(obs[design.visit_col])
        report["n_visits"] = len(visits)
        report["visits"] = visits

        # Check paired participants for all visit pairs. Pairing is defined
        # per participant, so it is skipped when that column is absent.
        if len(visits) >= 2 and has_participants:
            paired_counts = {}
            for i, v1 in enumerate(visits):
                for v2 in visits[i + 1 :]:
                    n_paired = count_paired(
                        obs, design.visit_col, [v1, v2], design.participant_col
                    )
                    paired_counts[(v1, v2)] = n_paired
                    if n_paired < 4:
                        warnings_list.append(
                            f"Only {n_paired} paired participants for {v1} vs {v2}"
                        )
            report["paired_participants"] = paired_counts

    if design.arm_col in obs.columns:
        arms = _sorted_labels(obs[design.arm_col])
        report["n_arms"] = len(arms)
        report["arms"] = arms

    # Check cell counts per participant
    if has_participants:
        if has_visits:
            cells_per_pv = obs.groupby(
                [design.participant_col, design.visit_col], observed=True
            ).size()
            report["cells_per_participant_visit_mean"] = cells_per_pv.mean()
            report["cells_per_participant_visit_median"] = cells_per_pv.median()
            report["cells_per_participant_visit_min"] = cells_per_pv.min()

            if cells_per_pv.min() < 10:
                warnings_list.append(
                    f"Some participant-visits have < 10 cells (min: {cells_per_pv.min()})"
                )
                recommendations.append("Consider QC filtering to remove low-quality samples")
        else:
            cells_per_p = obs.groupby(design.participant_col, observed=True).size()
            report["cells_per_participant_mean"] = cells_per_p.mean()

    # Check cell type distribution if available
    if design.celltype_col and design.celltype_col in obs.columns:
        celltype_counts = obs[design.celltype_col].value_counts()
        report["n_celltypes"] = len(celltype_counts)
        report["celltype_distribution"] = celltype_counts.to_dict()

    report["warnings"] = warnings_list
    report["recommendations"] = recommendations

    if verbose:
        _print_diagnostic_report(report)

    return report
def check_covariate_balance(
    adata: AnnData,
    design: TrialDesign,
    covariates: Sequence[str],
    *,
    visit: str | None = None,
    dropna: bool = True,
    smd_threshold: float = 0.1,
) -> pd.DataFrame:
    """Compute standardized mean differences (SMD) for baseline covariates.

    This compares treated vs control arms at a single visit (usually
    baseline) and reports SMD values that quantify imbalance. Cells are
    first collapsed to one value per participant (the first non-missing
    value within each participant/arm group) so that participants with
    many cells do not dominate the comparison.

    Parameters
    ----------
    adata
        AnnData object with trial metadata in ``adata.obs``.
    design
        TrialDesign describing participant, visit, and arm columns.
    covariates
        List of covariate column names in ``adata.obs``. A single bare
        string is treated as one covariate name.
    visit
        Visit label to use for balance checks. If None, uses
        ``design.baseline_visit``; otherwise raises if not available.
    dropna
        If True, drop rows with missing covariate values for each covariate.
    smd_threshold
        Absolute SMD threshold for "balanced" flag (default 0.1).

    Returns
    -------
    pd.DataFrame
        Table with columns ``covariate``, ``level``, ``mean_treated``,
        ``mean_control``, ``smd``, ``n_treated``, ``n_control`` and
        ``balanced``. Numeric covariates produce one row per covariate
        (``level`` is None). Categorical covariates produce one row per
        level, with proportions in the ``mean_*`` columns. Covariates with
        no usable data in either arm contribute no rows; the columns are
        present even when the table is empty.

    Raises
    ------
    KeyError
        If the visit, arm or participant column, or any covariate, is not
        a column of ``adata.obs``.
    ValueError
        If ``design.arm_col`` is None, no visit can be determined, or no
        observations exist for the chosen visit.
    """
    result_columns = [
        "covariate",
        "level",
        "mean_treated",
        "mean_control",
        "smd",
        "n_treated",
        "n_control",
        "balanced",
    ]

    obs_columns = adata.obs.columns
    if design.visit_col not in obs_columns:
        raise KeyError(f"visit_col '{design.visit_col}' not in adata.obs")
    if design.arm_col is None:
        raise ValueError(
            "check_covariate_balance() requires a two-arm design (arm_col must not be None)."
        )
    if design.arm_col not in obs_columns:
        raise KeyError(f"arm_col '{design.arm_col}' not in adata.obs")

    if visit is None:
        if design.baseline_visit is None:
            raise ValueError("visit must be provided or design.baseline_visit must be set.")
        visit = design.baseline_visit

    # Boolean indexing already yields a new frame; no extra copies needed.
    obs = adata.obs.loc[adata.obs[design.visit_col] == visit]
    if obs.empty:
        raise ValueError(f"No observations found for visit '{visit}'.")

    if design.participant_col not in obs_columns:
        raise KeyError(f"participant_col '{design.participant_col}' not in adata.obs")

    # A str is itself a sequence of characters; treat it as one name.
    if isinstance(covariates, str):
        covariates = [covariates]

    # Validate every name before selecting columns, so the caller gets this
    # message rather than an opaque pandas indexing error.
    for cov in covariates:
        if cov not in obs_columns:
            raise KeyError(f"Covariate '{cov}' not found in adata.obs")

    def build_row(
        cov: str,
        level: str | None,
        value_t: float,
        value_c: float,
        smd: float,
        n_t: int,
        n_c: int,
    ) -> dict[str, object]:
        """Assemble one output row; a non-finite SMD is never 'balanced'."""
        finite = bool(np.isfinite(smd))
        return {
            "covariate": cov,
            "level": level,
            "mean_treated": value_t,
            "mean_control": value_c,
            "smd": float(smd) if finite else np.nan,
            "n_treated": n_t,
            "n_control": n_c,
            "balanced": bool(finite and abs(smd) < smd_threshold),
        }

    # collapse to participant-level to avoid pseudoreplication
    group_cols = [design.participant_col, design.arm_col]
    rows: list[dict[str, object]] = []

    for cov in covariates:
        sub = obs[group_cols + [cov]]
        if dropna:
            sub = sub.dropna(subset=[cov])
        if sub.empty:
            continue

        # reduce to participant-level
        sub = sub.groupby(group_cols, observed=True)[cov].first().reset_index()
        treated = sub[sub[design.arm_col] == design.arm_treated][cov]
        control = sub[sub[design.arm_col] == design.arm_control][cov]
        if treated.empty or control.empty:
            continue

        n_t = int(treated.shape[0])
        n_c = int(control.shape[0])

        if pd.api.types.is_numeric_dtype(sub[cov]):
            mean_t = float(treated.mean())
            mean_c = float(control.mean())
            sd_t = float(treated.std(ddof=1))
            sd_c = float(control.std(ddof=1))
            # NaN (single participant) or zero SD in either arm leaves the
            # SMD undefined; the comparisons below are False for NaN.
            if sd_t > 0 and sd_c > 0:
                pooled = np.sqrt((sd_t**2 + sd_c**2) / 2)
            else:
                pooled = np.nan
            if np.isfinite(pooled) and pooled > 0:
                smd = (mean_t - mean_c) / pooled
            else:
                smd = np.nan
            rows.append(build_row(cov, None, mean_t, mean_c, smd, n_t, n_c))
        else:
            # categorical: compute SMD for each level via proportions
            treated_labels = treated.astype(str)
            control_labels = control.astype(str)
            for lvl in pd.Series(sub[cov].astype(str)).unique():
                p_t = float((treated_labels == lvl).mean())
                p_c = float((control_labels == lvl).mean())
                p_pool = (p_t + p_c) / 2
                # A level that is absent or universal in both arms has no
                # variance, so its SMD is undefined.
                denom = np.sqrt(p_pool * (1 - p_pool)) if 0 < p_pool < 1 else np.nan
                smd = (p_t - p_c) / denom if denom and np.isfinite(denom) else np.nan
                rows.append(build_row(cov, str(lvl), p_t, p_c, smd, n_t, n_c))

    # Explicit columns keep the schema stable when no row was produced.
    return pd.DataFrame(rows, columns=result_columns)
def _print_diagnostic_report(report: dict[str, Any]) -> None:
    """Emit a formatted diagnostic report through the module logger.

    Parameters
    ----------
    report
        Diagnostic summary dict. Every key is optional: missing summary
        counts are shown as ``N/A`` and missing sections are skipped.

    Notes
    -----
    The summary goes out at INFO level, the cell type distribution (first
    ten entries only) at DEBUG level, and warnings at WARNING level.
    """

    def with_thousands(value: Any) -> str:
        """Format *value* with thousands separators when it is numeric."""
        # The ',' format option is only valid for numbers; the 'N/A'
        # placeholder (or any other non-number) falls back to plain text.
        try:
            return f"{value:,}"
        except (TypeError, ValueError):
            return str(value)

    rule = "=" * 60

    logger.info(rule)
    logger.info("TRIAL DATA DIAGNOSTIC REPORT")
    logger.info(rule)

    logger.info("DATA SUMMARY")
    logger.info(f" Cells: {with_thousands(report.get('n_cells', 'N/A'))}")
    logger.info(f" Genes: {with_thousands(report.get('n_genes', 'N/A'))}")
    logger.info(f" Participants: {report.get('n_participants', 'N/A')}")
    logger.info(f" Visits: {report.get('n_visits', 'N/A')}")
    logger.info(f" Arms: {report.get('n_arms', 'N/A')}")

    if "visits" in report:
        logger.info(f" Visit labels: {', '.join(map(str, report['visits']))}")
    if "arms" in report:
        logger.info(f" Arm labels: {', '.join(map(str, report['arms']))}")

    if "paired_participants" in report:
        logger.info("PAIRED PARTICIPANTS")
        for (v1, v2), count in report["paired_participants"].items():
            status = "OK" if count >= 4 else "LOW"
            logger.info(f" [{status}] {v1} <-> {v2}: {count} paired")

    if "cells_per_participant_visit_mean" in report:
        logger.info("CELLS PER PARTICIPANT-VISIT")
        logger.info(f" Mean: {report['cells_per_participant_visit_mean']:.1f}")
        logger.info(f" Median: {report['cells_per_participant_visit_median']:.1f}")
        logger.info(f" Min: {report['cells_per_participant_visit_min']}")

    if "celltype_distribution" in report:
        logger.debug("CELL TYPE DISTRIBUTION")
        for ct, count in list(report["celltype_distribution"].items())[:10]:
            logger.debug(f" {ct}: {with_thousands(count)}")

    if report.get("warnings"):
        logger.warning(f"WARNINGS ({len(report['warnings'])})")
        for w in report["warnings"]:
            logger.warning(f" - {w}")

    if report.get("recommendations"):
        logger.info("RECOMMENDATIONS")
        for r in report["recommendations"]:
            logger.info(f" - {r}")

    logger.info(rule)