"""Built-in dataset loaders for clinical trial scRNA-seq cohorts."""
from __future__ import annotations
import gc
import gzip
import logging
import re
import tarfile
import urllib.error
import urllib.request
import warnings
from collections.abc import Sequence
from io import StringIO
from pathlib import Path
import anndata as ad
import numpy as np
import pandas as pd
import scipy.sparse as sp
from scipy.io import mmread
from statsmodels.stats.multitest import multipletests
from .utils import get_counts_matrix
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Dataset root: resolved relative to the *repository* (two levels up from
# src/sctrial/), so loaders work regardless of the caller's cwd.
# When installed as a proper package (no repo checkout), falls back to cwd.
# ---------------------------------------------------------------------------
_PACKAGE_DIR = Path(__file__).resolve().parent # src/sctrial/
_REPO_ROOT = _PACKAGE_DIR.parent.parent # sc_trial_inference/
_DATASETS_ROOT = _REPO_ROOT / "datasets"
if not _DATASETS_ROOT.is_dir():
# Installed package without repo structure — fall back to cwd
_DATASETS_ROOT = Path.cwd() / "datasets"
def _default_data_dir(name: str) -> str:
    """Build the absolute default ``data_dir`` for the dataset called *name*.

    The result is ``<datasets root>/<name>`` rendered as a string, ready to be
    used as a loader's ``data_dir`` argument.
    """
    dataset_dir = _DATASETS_ROOT.joinpath(name)
    return str(dataset_dir)
# ---------------------------------------------------------------------------
# Marker-based cell-type annotation for immune cells
# Approach adapted from Diab_duod project: Leiden clustering -> Wilcoxon
# marker finding -> weighted scoring against canonical marker gene sets
# ---------------------------------------------------------------------------
# Immune cell markers from Sade-Feldman et al. (Cell 2018, Fig 1, Table S1).
# The paper identifies 11 clusters (G1-G11) among CD45+ sorted cells:
# G1: B cells, G2: Plasma cells, G3: Monocytes/Macrophages,
# G4: Dendritic cells, G5-G11: T/NK/NKT subtypes.
# CD8 T cells are further split into memory-like (CD8_G) and
# exhausted (CD8_B) states.
_IMMUNE_MARKERS: dict[str, set[str]] = {
"CD8 T cell": {
"CD8A",
"CD8B",
"GZMK",
"GZMB",
"GZMA",
"PRF1",
"NKG7",
"CD3D",
"CD3E",
"IFNG",
"EOMES",
"TBX21",
},
"CD4 T cell": {
"CD4",
"IL7R",
"CCR7",
"LEF1",
"TCF7",
"ICOS",
"CD3D",
"CD3E",
"CD40LG",
},
"Treg": {
"FOXP3",
"IL2RA",
"CTLA4",
"IKZF2",
"TNFRSF18",
"CD4",
"CD3D",
"CD3E",
},
"B cell": {
# G1 markers from publication: IGKC, LTB, LY9, SELL, TCF7, CCR7
"MS4A1",
"CD79A",
"CD79B",
"BANK1",
"CD74",
"CD19",
"PAX5",
"BLK",
"IGKC",
"LTB",
"LY9",
},
"Plasma cell": {
# G2 cluster
"MZB1",
"SDC1",
"XBP1",
"JCHAIN",
"PRDM1",
"IGKC",
"IGHG1",
},
"NK cell": {
"KLRD1",
"KLRF1",
"KLRB1",
"GNLY",
"PRF1",
"NKG7",
"GZMB",
"NCAM1",
"FCGR3A",
},
"Monocyte/Macrophage": {
# G3 cluster
"CD14",
"CD68",
"LYZ",
"CST3",
"S100A8",
"S100A9",
"C1QA",
"C1QB",
"C1QC",
"MRC1",
"CSF1R",
"FCGR3A",
},
"Dendritic cell": {
# G4 cluster
"FCER1A",
"CLEC10A",
"CD1C",
"ITGAX",
"HLA-DRA",
"HLA-DQA1",
},
}
# Annotation parameters (following Diab_duod conventions)
_ANNOT_TOP_N = 50 # top markers per cluster from Wilcoxon
_ANNOT_MIN_LFC = 0.25 # minimum log fold-change
_ANNOT_MAX_FDR = 0.1 # maximum adjusted p-value
_ANNOT_SECOND_DELTA = 0.25 # delta to flag ambiguous clusters
_ANNOT_MIN_ACCEPT = 0.3 # minimum weighted score to accept a label
def _weighted_marker_score(
marker_df: pd.DataFrame,
gene_set: set[str],
) -> tuple[float, list[str]]:
"""Compute weighted score of a gene set against ranked cluster markers.
Weight per gene = (1 / rank) * log1p(exp(clipped_logFC)).
Mirrors the ``weighted_score`` function from the Diab_duod project.
"""
hits = marker_df[marker_df["names"].isin(gene_set)]
if hits.empty:
return 0.0, []
lfc = hits["logfoldchanges"].clip(lower=0).values.astype(float)
ranks = hits["rank"].values.astype(float)
weights = (1.0 / ranks) * np.log1p(np.exp(lfc))
top_genes = hits["names"].values[np.argsort(-weights)].tolist()
return float(weights.sum()), top_genes
def _cluster_marker_table(result, cluster: str) -> pd.DataFrame:
    """Return the ranked (unfiltered) Wilcoxon marker table for one cluster.

    *result* is ``adata.uns["rank_genes_groups"]``: its ``names``,
    ``logfoldchanges`` and ``pvals_adj`` entries are structured arrays with
    one field per group, row *i* holding the (i+1)-th ranked gene.
    """
    names = np.asarray(result["names"][cluster])
    lfcs = np.asarray(result["logfoldchanges"][cluster], dtype=float)
    return pd.DataFrame(
        {
            "names": names,
            # A NaN fold change carries no enrichment evidence: treat it as 0.
            "logfoldchanges": np.where(np.isnan(lfcs), 0.0, lfcs),
            "pvals_adj": np.asarray(result["pvals_adj"][cluster], dtype=float),
            "rank": np.arange(1, len(names) + 1),
        }
    )


def _score_immune_labels(marker_df: pd.DataFrame) -> dict[str, float]:
    """Score *marker_df* against every canonical marker set, keeping scores > 0."""
    scores: dict[str, float] = {}
    for cell_type, gene_set in _IMMUNE_MARKERS.items():
        score, _ = _weighted_marker_score(marker_df, gene_set)
        if score > 0:
            scores[cell_type] = score
    return scores


def _annotate_immune_celltypes(adata: ad.AnnData) -> pd.Series:
    """Assign cell types to immune cells via cluster-level marker scoring.

    Pipeline (adapted from Diab_duod project):
    1. Use pre-computed Leiden clusters (from caller), or compute them here.
    2. Find differentially expressed markers per cluster (Wilcoxon).
    3. Score each cluster against canonical immune marker sets using
       a rank-weighted scoring function.
    4. Assign the best-scoring cell type to each cluster.

    Parameters
    ----------
    adata : AnnData
        Must contain expression values (TPM) in ``adata.X`` and gene names
        in ``adata.var_names``. If ``adata.obs["leiden"]`` already exists
        (with PCA/neighbors pre-computed), those clusters are reused so that
        annotation and UMAP share the same embedding. Cluster labels must be
        integer-like strings (they are sorted with ``int``). *adata* itself
        is not modified.

    Returns
    -------
    pd.Series
        Cell-type labels named ``"cell_type"`` and indexed like ``adata.obs``.
        Clusters with no overlap against any marker set get ``"Unassigned"``.
    """
    import scanpy as sc  # local import to avoid top-level dependency

    # Work on a lightweight object holding only the expression matrix and the
    # cell/gene names. ``adata.copy()`` would also duplicate every layer and
    # embedding, which costs several GB for full-size datasets.
    X = adata.X
    if X.max() > 50:
        # Looks like raw TPM: log-transform. log1p returns a new matrix, so
        # the caller's data is never touched.
        X = X.log1p() if sp.issparse(X) else np.log1p(X)
    else:
        X = X.copy()
    aw = ad.AnnData(
        X=X,
        obs=pd.DataFrame(index=adata.obs_names.copy()),
        var=pd.DataFrame(index=adata.var_names.copy()),
    )

    if "leiden" in adata.obs.columns:
        # Reuse pre-computed Leiden clusters (same embedding used for UMAP)
        aw.obs["leiden"] = adata.obs["leiden"].values
        logger.info(" Using pre-computed Leiden clusters for annotation...")
    else:
        # Fallback: compute PCA -> neighbors -> Leiden internally
        logger.info(" Computing PCA for cell-type annotation...")
        sc.pp.highly_variable_genes(aw, n_top_genes=2000, flavor="seurat")
        sc.tl.pca(aw, n_comps=30)
        sc.pp.neighbors(aw, n_neighbors=15, n_pcs=20)
        sc.tl.leiden(aw, resolution=1.0)

    # Wilcoxon marker finding per cluster
    logger.info(" Finding cluster markers (Wilcoxon)...")
    sc.tl.rank_genes_groups(aw, groupby="leiden", method="wilcoxon", n_genes=_ANNOT_TOP_N)
    result = aw.uns["rank_genes_groups"]

    clusters = sorted(aw.obs["leiden"].unique(), key=int)
    cluster_labels: dict[str, str] = {}
    for cl in clusters:
        all_markers = _cluster_marker_table(result, cl)
        strict = all_markers[
            (all_markers["logfoldchanges"] >= _ANNOT_MIN_LFC)
            & (all_markers["pvals_adj"] <= _ANNOT_MAX_FDR)
        ]
        label_scores = _score_immune_labels(strict)
        if not label_scores:
            # Either the LFC/FDR filter removed every marker, or none of the
            # survivors overlaps a canonical set: fall back to the unfiltered
            # top-N markers (negative LFCs are clipped to 0 when scoring).
            label_scores = _score_immune_labels(all_markers)
        if not label_scores:
            cluster_labels[cl] = "Unassigned"
            logger.warning(
                f" Cluster {cl}: no marker overlap with any canonical "
                f"gene set — labelled 'Unassigned'"
            )
            continue

        # sorted() is stable, so equal scores keep _IMMUNE_MARKERS order.
        ranked = sorted(label_scores.items(), key=lambda item: -item[1])
        best_label, best_score = ranked[0]
        # Always assign the best-scoring type (no "Unknown immune")
        cluster_labels[cl] = best_label
        # Log ambiguous clusters
        if len(ranked) > 1:
            second_label, second_score = ranked[1]
            if best_score - second_score < _ANNOT_SECOND_DELTA:
                logger.debug(
                    f" Cluster {cl}: {best_label} ({best_score:.2f}) "
                    f"vs {second_label} ({second_score:.2f}) [ambiguous]"
                )

    # Map cluster -> cell type onto every cell
    labels = aw.obs["leiden"].map(cluster_labels)
    labels.index = adata.obs.index
    labels.name = "cell_type"
    # Log summary
    for ct, n in labels.value_counts().items():
        logger.info(f" {ct}: {n:,} cells")
    del aw
    gc.collect()
    return labels
# Names exported by ``from <this module> import *``.
__all__ = [
    "load_sade_feldman", "load_stephenson_data", "load_vaccine_gse171964",
    "load_aml", "load_cart",
    "harmonize_response", "count_paired", "verify_paired_participants",
    "categorize_celltype", "ensure_fdr",
]
def _resolve_dir_with_files(p: str, required_files: Sequence[str]) -> Path:
"""Resolve a directory path with required files."""
path = Path(p)
if path.is_absolute():
if all((path / f).exists() for f in required_files):
return path
for base in [Path.cwd(), *Path.cwd().parents]:
cand = base / path
if all((cand / f).exists() for f in required_files):
return cand
return path
def _resolve_file(p: str) -> Path:
path = Path(p)
if path.is_absolute() and path.exists():
return path
for base in [Path.cwd(), *Path.cwd().parents]:
cand = base / path
if cand.exists():
return cand
return path
def _params_match(prev: dict, current: dict) -> bool:
"""Robustly compare processing parameters, handling None/NaN/list differences.
Returns True if every key in *current* has a matching value in *prev*.
Extra keys in *prev* (e.g. metadata added later) are tolerated.
"""
if not set(current.keys()).issubset(set(prev.keys())):
return False
for key in current:
v1, v2 = prev.get(key), current.get(key)
# Handle None comparisons (h5ad may store None differently)
if v1 is None or (isinstance(v1, float) and np.isnan(v1)):
v1 = None
if v2 is None or (isinstance(v2, float) and np.isnan(v2)):
v2 = None
# Handle string "None"
if isinstance(v1, str) and v1 in ("None", "null"):
v1 = None
if isinstance(v2, str) and v2 in ("None", "null"):
v2 = None
# Handle list/array comparisons (h5ad may convert lists to numpy arrays)
if isinstance(v1, (list, tuple, np.ndarray)) and isinstance(v2, (list, tuple, np.ndarray)):
if list(v1) != list(v2):
return False
continue
if v1 != v2:
return False
return True
def _looks_log1p(X, sample: int = 10000, seed: int = 0) -> bool:
"""Check if a matrix looks like log-transformed counts."""
if X is None:
return False
if sp.issparse(X):
data = X.data
else:
data = np.asarray(X).ravel()
if data.size == 0:
return False
data = data[np.isfinite(data)]
if data.size == 0:
return False
rng = np.random.default_rng(seed)
if data.size > sample:
data = rng.choice(data, size=sample, replace=False)
return (
(data.min() >= 0)
and (data.max() < 50)
and (not np.allclose(data, np.round(data), atol=1e-3))
)
def _download_file(url: str, dest: Path, label: str = "file") -> None:
    """Download a single file with error handling and partial-file cleanup.

    The payload is written to a sibling ``<dest>.part`` file first and only
    renamed onto *dest* once the transfer completed. A failed or interrupted
    download (including Ctrl-C) therefore never leaves a truncated file at
    *dest* that a later ``dest.exists()`` check would mistake for a complete
    download.

    Parameters
    ----------
    url : str
        URL to download from.
    dest : Path
        Local destination path.
    label : str
        Human-readable label for log messages (e.g. "TPM file").

    Raises
    ------
    RuntimeError
        If the download fails; the underlying error is chained as the cause.
    """
    logger.info(f"Downloading {label} from {url}...")
    tmp = dest.with_name(dest.name + ".part")
    try:
        urllib.request.urlretrieve(url, str(tmp))
        tmp.replace(dest)  # atomic on the same file system
    except (urllib.error.URLError, urllib.error.HTTPError, OSError) as e:
        raise RuntimeError(
            f"Failed to download {label} from {url}: {e}. "
            f"Please download manually and place it in {dest.parent}"
        ) from e
    finally:
        # No-op after a successful rename; otherwise drops partial data.
        tmp.unlink(missing_ok=True)
    logger.info(f"Successfully downloaded {label}: {dest}")
def _get_counts_matrix(adata: ad.AnnData) -> tuple[np.ndarray | None, str | None]:
    """Get the counts matrix from the AnnData object.

    Private alias that forwards to ``get_counts_matrix`` from ``.utils`` and
    returns its result unchanged.

    Returns
    -------
    tuple
        ``(counts, source)`` as produced by ``get_counts_matrix``. Both may be
        ``None`` per the annotation; ``load_stephenson_data`` treats a ``None``
        matrix as "no raw counts found" and logs ``source``.
    """
    return get_counts_matrix(adata)
[docs]
def load_sade_feldman(
    data_dir: str | None = None,
    processed_name: str = "sade_feldman_processed_v6.h5ad",
    max_cells_per_participant_visit: int | None = None,
    seed: int = 42,
    allow_download: bool = False,
    force_reprocess: bool = False,
) -> ad.AnnData:
    """Load and preprocess Sade-Feldman melanoma immunotherapy dataset (GSE120575).

    Parameters
    ----------
    data_dir : str or None
        Directory containing (or to store) the raw data files. Defaults to
        the ``sade_feldman`` folder under the package's datasets root.
    processed_name : str
        Filename for the cached processed h5ad file (stored under
        ``<data_dir>/processed/``).
    max_cells_per_participant_visit : int or None
        Maximum number of cells to retain per participant-visit pair.
    seed : int
        Random seed for the per-participant-visit subsampling.
    allow_download : bool
        If True, download missing files from GEO automatically.
    force_reprocess : bool
        If True, reprocess even when a cached file exists.

    Returns
    -------
    AnnData
        The processed AnnData object (also written to the cache path).

    Raises
    ------
    ImportError
        If scanpy is not installed (checked before any download).
    FileNotFoundError
        If raw files are missing and ``allow_download`` is False.
    ValueError
        If the TPM file's headers and matrix are inconsistent.
    """
    processing_params = {
        "version": "v6",
        "max_cells_per_participant_visit": max_cells_per_participant_visit,
        "seed": seed,
        "assay": "TPM",
    }
    data_dir = data_dir or _default_data_dir("sade_feldman")
    data_dir_path = Path(data_dir)
    processed_path = data_dir_path / "processed" / processed_name
    if not force_reprocess and processed_path.exists():
        adata = ad.read_h5ad(processed_path)
        prev = adata.uns.get("processing_params", {})
        if prev:
            if _params_match(prev, processing_params):
                logger.info(
                    f"Loaded processed Sade-Feldman dataset: {adata.n_obs:,} cells, {adata.n_vars:,} genes"
                )
                return adata
            logger.info("Processed file parameters differ; reprocessing.")
            logger.debug(f" Stored: {prev}")
            logger.debug(f" Current: {processing_params}")
            # Release the stale cached object before loading the raw data.
            del adata
        else:
            warnings.warn(
                "Cached file lacks processing_params metadata; cannot verify it matches "
                "current settings. Consider reprocessing with force_reprocess=True.",
                UserWarning,
                stacklevel=2,
            )
            logger.info(
                f"Loaded processed Sade-Feldman dataset: {adata.n_obs:,} cells, {adata.n_vars:,} genes"
            )
            return adata
    raw_dir = data_dir_path / "raw"
    raw_dir_resolved = _resolve_dir_with_files(
        str(raw_dir),
        [
            "GSE120575_Sade_Feldman_melanoma_single_cells_TPM_GEO.txt.gz",
            "GSE120575_patient_ID_single_cells.txt.gz",
        ],
    )
    # Check scanpy availability BEFORE downloading to avoid wasted bandwidth.
    # scanpy is required for cell-type annotation (Leiden + Wilcoxon scoring).
    try:
        import scanpy as sc
    except ImportError:
        raise ImportError(
            "scanpy is required for Sade-Feldman cell type annotation. "
            "Install with: pip install sctrial[plots] or pip install scanpy"
        ) from None
    tpm_path = raw_dir_resolved / "GSE120575_Sade_Feldman_melanoma_single_cells_TPM_GEO.txt.gz"
    meta_path = raw_dir_resolved / "GSE120575_patient_ID_single_cells.txt.gz"
    _GEO_BASE = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE120575&format=file&file="
    _sade_feldman_files = [
        (
            tpm_path,
            _GEO_BASE
            + "GSE120575%5FSade%5FFeldman%5Fmelanoma%5Fsingle%5Fcells%5FTPM%5FGEO%2Etxt%2Egz",
            "TPM file",
        ),
        (
            meta_path,
            _GEO_BASE + "GSE120575%5Fpatient%5FID%5Fsingle%5Fcells%2Etxt%2Egz",
            "metadata file",
        ),
    ]
    missing = [(p, url, label) for p, url, label in _sade_feldman_files if not p.exists()]
    if missing:
        if not allow_download:
            names = ", ".join(str(p) for p, _, _ in missing)
            raise FileNotFoundError(
                f"Missing file(s): {names}. Download from GEO: "
                "https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE120575"
            )
        raw_dir_resolved.mkdir(parents=True, exist_ok=True)
        for dest, url, label in missing:
            _download_file(url, dest, label)
    logger.info("Processing raw data (this may take a minute)...")
    with gzip.open(tpm_path, "rt") as f:
        # Two header rows: sample IDs, then one time label per sample.
        sample_ids = f.readline().strip().split("\t")
        time_labels = f.readline().strip().split("\t")
        if len(sample_ids) != len(time_labels):
            raise ValueError("TPM file header rows have inconsistent lengths.")
        # Parse the remaining rows straight from the open handle instead of
        # first reading the whole decompressed file into one giant string.
        df = pd.read_csv(f, sep="\t", header=None)
    # Drop an all-NaN trailing column (presumably from trailing tabs).
    if df.iloc[:, -1].isna().all():
        df = df.iloc[:, :-1]
    genes = df.iloc[:, 0].astype(str).values
    mat = df.iloc[:, 1:]
    if mat.shape[1] != len(sample_ids):
        raise ValueError(f"TPM matrix columns ({mat.shape[1]}) != sample IDs ({len(sample_ids)}).")
    mat.columns = sample_ids
    meta = pd.read_csv(meta_path, sep="\t", skiprows=19, encoding="latin1")
    meta = meta.rename(
        columns={
            "title": "sample_id",
            "characteristics: patinet ID (Pre=baseline; Post= on treatment)": "patient_raw",
            "characteristics: response": "response",
        }
    )
    # Drop missing IDs *before* casting: afterwards NaN would be the string "nan".
    meta = meta.dropna(subset=["sample_id"]).copy()
    meta["sample_id"] = meta["sample_id"].astype(str)
    meta = meta[meta["sample_id"].str.match(r"^[A-Z]\d+_P\d+_M\d+")].copy()
    meta = meta[meta["response"].isin(["Responder", "Non-responder"])].copy()
    meta["visit"] = meta["patient_raw"].astype(str).str.split("_").str[0]
    meta["participant_id"] = meta["patient_raw"].astype(str).str.extract(r"(P\d+)")[0]
    time_map = dict(zip(sample_ids, time_labels))
    meta["time_label"] = meta["sample_id"].map(time_map)
    meta = meta.set_index("sample_id")
    meta = meta.loc[[s for s in sample_ids if s in meta.index]].copy()
    # Cells x genes, rows in the same order as ``meta``.
    X = mat.loc[:, meta.index].to_numpy(dtype=np.float32).T
    adata = ad.AnnData(X=np.ascontiguousarray(X))
    adata.obs = meta.copy()
    adata.var_names = genes
    # The parsed text tables are no longer needed; free them before the
    # memory-heavy embedding steps.
    del df, mat, X
    if max_cells_per_participant_visit is not None:
        rng = np.random.default_rng(seed)
        keep_indices: list = []
        # dropna=False keeps cells with a missing participant/visit, matching
        # the no-subsampling branch below (which keeps every cell).
        for _, group in adata.obs.groupby(
            ["participant_id", "visit"], observed=True, dropna=False
        ):
            n_cells = len(group)
            if n_cells > max_cells_per_participant_visit:
                keep = rng.choice(group.index, size=max_cells_per_participant_visit, replace=False)
            else:
                keep = group.index.values
            keep_indices.extend(keep)
        adata = adata[keep_indices].copy()
        logger.info(
            f"Stratified sampling: {adata.n_obs:,} cells (max {max_cells_per_participant_visit} per participant-visit)"
        )
    else:
        logger.info(f"Using full dataset: {adata.n_obs:,} cells (no subsampling)")
    adata.layers["tpm"] = adata.X.copy()
    adata.layers["log1p_tpm"] = adata.X.copy() if _looks_log1p(adata.X) else np.log1p(adata.X)
    # ── PCA → neighbors → UMAP → Leiden ────────────────────────────────
    # Compute BEFORE annotation so that cell-type labels are assigned to
    # the SAME Leiden clusters that the UMAP is built from.
    logger.info("Computing PCA / neighbors / UMAP / Leiden...")
    adata_work = adata.copy()
    adata_work.X = adata_work.layers["log1p_tpm"]
    sc.pp.highly_variable_genes(adata_work, n_top_genes=3000, flavor="seurat")
    adata_hvg = adata_work[:, adata_work.var["highly_variable"]].copy()
    sc.pp.scale(adata_hvg, max_value=10)
    sc.tl.pca(adata_hvg, n_comps=50)
    adata.obsm["X_pca"] = adata_hvg.obsm["X_pca"]
    del adata_work, adata_hvg
    sc.pp.neighbors(adata, use_rep="X_pca", n_neighbors=15)
    sc.tl.umap(adata)
    sc.tl.leiden(adata, resolution=1.0)
    # ── Cell type annotation ────────────────────────────────────────────
    # Uses the Leiden clusters computed above (same embedding as UMAP).
    logger.info("Annotating cell types from marker genes...")
    adata.obs["cell_type"] = _annotate_immune_celltypes(adata)
    adata.uns["processing_params"] = processing_params
    adata.uns["data_source"] = "GSE120575"
    adata.uns["paper"] = "Sade-Feldman et al., Cell 2018"
    processed_path.parent.mkdir(parents=True, exist_ok=True)
    adata.write_h5ad(processed_path)
    logger.info(f"Saved processed file: {processed_path}")
    logger.info(f"Loaded Sade-Feldman dataset: {adata.n_obs:,} cells, {adata.n_vars:,} genes")
    return adata
[docs]
def load_stephenson_data(
    data_dir: str | None = None,
    processed_name: str = "stephenson_covid19_v3.h5ad",
    seed: int = 42,
    allow_download: bool = False,
    force_reprocess: bool = False,
    *,
    data_path: str | None = None,
) -> ad.AnnData:
    """Load and preprocess Stephenson COVID-19 dataset (E-MTAB-10026).

    Cells are restricted to ``Mild``/``Severe`` collection status and to a
    valid days-from-onset bin (``DFO_0-7``, ``DFO_8-14``, ``DFO_15+``); raw
    counts are stored in ``layers["counts"]``.

    Parameters
    ----------
    data_dir : str
        Directory containing (or to store) the raw data files.
    processed_name : str
        Filename for the cached processed h5ad file.
    seed : int
        Random seed for reproducibility. Currently unused (no random
        subsampling is performed); kept for API consistency.
    allow_download : bool
        If True, download the data file automatically when missing.
    force_reprocess : bool
        If True, reprocess even when a cached file exists.
    data_path : str or None
        .. deprecated:: 0.2.2
            Use *data_dir* instead. When supplied, *data_dir* is ignored
            and the parent directory of *data_path* is used.

    Returns
    -------
    AnnData
        The processed AnnData object. Its ``uns["processing_params"]`` is
        compared against the current settings on the next cached load.
    """
    processing_params = {"version": "v3"}
    data_dir = data_dir or _default_data_dir("stephenson")
    # Backward compat: if someone passes an .h5ad file path positionally
    # as data_dir (old API had data_path as first param), treat it as data_path.
    if data_path is None and str(data_dir).endswith(".h5ad"):
        data_path = data_dir
        data_dir = _default_data_dir("stephenson")  # reset to default
    if data_path is not None:
        warnings.warn(
            "load_stephenson_data(data_path=...) is deprecated. Use data_dir=... instead.",
            FutureWarning,
            stacklevel=2,
        )
        raw_file = _resolve_file(data_path)
        data_dir_path = (
            raw_file.parent.parent if raw_file.exists() else Path(data_path).parent.parent
        )
    else:
        data_dir_path = Path(data_dir)
        raw_file = data_dir_path / "raw" / "covid_portal_210320_with_raw.h5ad"
    processed_path = data_dir_path / "processed" / processed_name
    if not force_reprocess and processed_path.exists():
        adata = ad.read_h5ad(processed_path)
        prev = adata.uns.get("processing_params", {})
        if prev and not _params_match(prev, processing_params):
            logger.info("Processed file parameters differ; reprocessing.")
            logger.debug(f" Stored: {prev}")
            logger.debug(f" Current: {processing_params}")
            # Release the stale cached object before loading the raw data.
            del adata
        else:
            if not prev:
                # Caches written by older versions carry no parameters.
                warnings.warn(
                    "Cached file lacks processing_params metadata; cannot verify it matches "
                    "current settings. Consider reprocessing with force_reprocess=True.",
                    UserWarning,
                    stacklevel=2,
                )
            logger.info(f"Loaded cached file: {processed_path}")
            logger.info(f" {adata.n_obs:,} cells, {adata.n_vars:,} genes")
            return adata
    if not raw_file.exists():
        # Also check old location (data_dir directly, not raw subdir)
        legacy_raw = data_dir_path / "covid_portal_210320_with_raw.h5ad"
        if legacy_raw.exists():
            raw_file = legacy_raw
        elif not allow_download:
            raise FileNotFoundError(
                f"Data not found at {raw_file}. Download from: "
                "https://www.ebi.ac.uk/biostudies/files/E-MTAB-10026/"
            )
        else:
            raw_file.parent.mkdir(parents=True, exist_ok=True)
            url = (
                "https://www.ebi.ac.uk/biostudies/files/E-MTAB-10026/"
                "covid_portal_210320_with_raw.h5ad"
            )
            _download_file(url, raw_file, "Stephenson COVID-19 data")
    logger.info("Processing raw data...")
    adata = ad.read_h5ad(raw_file)
    X_counts, source = _get_counts_matrix(adata)
    if X_counts is None:
        raise ValueError("No raw counts found in dataset.")
    adata.layers["counts"] = X_counts
    logger.info(f" Counts source: {source}")
    obs = adata.obs.copy()
    obs["severity"] = obs["Status_on_day_collection_summary"].astype(str)
    obs = obs[obs["severity"].isin(["Mild", "Severe"])].copy()
    logger.info(f" After severity filter: {len(obs):,} cells")
    obs["dfo"] = pd.to_numeric(obs["Days_from_onset"], errors="coerce")
    # Unparseable days become NaN -> "nan" after astype(str) and are
    # removed by the membership filter below.
    obs["dfo_bin"] = pd.cut(
        obs["dfo"],
        bins=[-np.inf, 7, 14, np.inf],
        labels=["DFO_0-7", "DFO_8-14", "DFO_15+"],
    ).astype(str)
    valid_dfo = obs["dfo_bin"].isin(["DFO_0-7", "DFO_8-14", "DFO_15+"])
    obs = obs[valid_dfo].copy()
    logger.info(f" After DFO filter: {len(obs):,} cells")
    if "Collection_Day" in obs.columns:
        obs["collection_day"] = obs["Collection_Day"].astype(str)
    obs["participant_id"] = obs["patient_id"].astype(str)
    obs["celltype"] = obs["full_clustering"].astype(str)
    adata = adata[obs.index].copy()
    adata.obs = obs
    # Record the settings so later cached loads can be validated (and do not
    # warn about missing metadata).
    adata.uns["processing_params"] = processing_params
    processed_path.parent.mkdir(parents=True, exist_ok=True)
    adata.write_h5ad(processed_path)
    logger.info(f" Saved: {processed_path}")
    logger.info(f" Final: {adata.n_obs:,} cells, {adata.n_vars:,} genes")
    return adata
[docs]
def load_vaccine_gse171964(
    data_dir: str | None = None,
    processed_name: str = "vaccine_gse171964.h5ad",
    max_participants: int | None = None,
    max_cells_per_group: int | None = None,
    seed: int = 42,
    allow_download: bool = False,
    force_reprocess: bool = False,
) -> ad.AnnData:
    """Load and preprocess GSE171964 PBMC vaccine time course data (Day 0 vs Day 7).

    Only cells from days 0 and 7 are kept, and only participants (``pt_id``)
    observed on both days.

    Parameters
    ----------
    data_dir : str
        Directory containing (or to store) the raw data files.
    processed_name : str
        Filename for the cached processed h5ad file.
    max_participants : int or None
        Maximum number of participants to retain (random subset).
    max_cells_per_group : int or None
        Maximum number of cells per participant-day-celltype group.
    seed : int
        Random seed for the participant and per-group cell subsampling.
    allow_download : bool
        If True, download missing files from GEO automatically.
    force_reprocess : bool
        If True, reprocess even when a cached file exists.

    Returns
    -------
    AnnData
        The processed AnnData object with raw counts in ``layers["counts"]``.
    """
    processing_params = {
        "version": "v2",
        "max_participants": max_participants,
        "max_cells_per_group": max_cells_per_group,
        "seed": seed,
        "days": [0, 7],
    }
    data_dir = data_dir or _default_data_dir("vaccine_gse171964")
    data_dir_path = Path(data_dir)
    processed_path = data_dir_path / "processed" / processed_name
    if not force_reprocess and processed_path.exists():
        adata = ad.read_h5ad(processed_path)
        prev = adata.uns.get("processing_params", {})
        if prev:
            if _params_match(prev, processing_params):
                logger.info(
                    f"Loaded processed vaccine dataset (GSE171964): {adata.n_obs} cells, {adata.n_vars} genes"
                )
                return adata
            logger.info("Processed file parameters differ; reprocessing.")
            logger.debug(f" Stored: {prev}")
            logger.debug(f" Current: {processing_params}")
            # Release the stale cached object before loading the raw data.
            del adata
        else:
            warnings.warn(
                "Cached file lacks processing_params metadata; cannot verify it matches "
                "current settings. Consider reprocessing with force_reprocess=True.",
                UserWarning,
                stacklevel=2,
            )
            logger.info(
                f"Loaded processed vaccine dataset (GSE171964): {adata.n_obs} cells, {adata.n_vars} genes"
            )
            return adata
    raw_dir = data_dir_path / "raw"
    raw_dir_resolved = _resolve_dir_with_files(
        str(raw_dir),
        [
            "GSE171964_barcodes_v2.tsv.gz",
            "GSE171964_feats_v2.tsv.gz",
            "GSE171964_geo_pheno_v2.csv.gz",
            "GSE171964_countsmatrix_v2.mtx.gz",
        ],
    )
    barcodes_path = raw_dir_resolved / "GSE171964_barcodes_v2.tsv.gz"
    feats_path = raw_dir_resolved / "GSE171964_feats_v2.tsv.gz"
    pheno_path = raw_dir_resolved / "GSE171964_geo_pheno_v2.csv.gz"
    mtx_path = raw_dir_resolved / "GSE171964_countsmatrix_v2.mtx.gz"
    _GEO_BASE_V = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE171964&format=file&file="
    _vaccine_files = [
        (barcodes_path, _GEO_BASE_V + "GSE171964%5Fbarcodes%5Fv2%2Etsv%2Egz", "barcodes file"),
        (feats_path, _GEO_BASE_V + "GSE171964%5Ffeats%5Fv2%2Etsv%2Egz", "features file"),
        (pheno_path, _GEO_BASE_V + "GSE171964%5Fgeo%5Fpheno%5Fv2%2Ecsv%2Egz", "pheno file"),
        (mtx_path, _GEO_BASE_V + "GSE171964%5Fcountsmatrix%5Fv2%2Emtx%2Egz", "counts matrix"),
    ]
    missing = [(p, url, label) for p, url, label in _vaccine_files if not p.exists()]
    if missing:
        if not allow_download:
            names = ", ".join(str(p) for p, _, _ in missing)
            raise FileNotFoundError(
                f"Missing file(s): {names}. Download from GEO: "
                "https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE171964 "
                "or set allow_download=True to fetch automatically."
            )
        raw_dir_resolved.mkdir(parents=True, exist_ok=True)
        for dest, url, label in missing:
            _download_file(url, dest, label)
    # Whitespace-separated tables with a header row; column 1 holds the IDs.
    barcodes = (
        pd.read_csv(barcodes_path, sep="\\s+", header=None, engine="python", skiprows=1)[1]
        .astype(str)
        .str.strip('"')
        .tolist()
    )
    features = (
        pd.read_csv(feats_path, sep="\\s+", header=None, engine="python", skiprows=1)[1]
        .astype(str)
        .str.strip('"')
        .tolist()
    )
    with gzip.open(mtx_path, "rb") as f:
        X = mmread(f).tocsr()
    if X.shape[0] == len(features) and X.shape[1] == len(barcodes):
        # Stored genes x cells: transpose to cells x genes and keep CSR so the
        # row (cell) subsetting below stays efficient.
        X = X.T.tocsr()
    elif X.shape[0] == len(barcodes) and X.shape[1] == len(features):
        pass
    else:
        raise ValueError("Matrix dimensions do not match barcodes/features.")
    adata = ad.AnnData(X=X)
    adata.obs_names = barcodes
    adata.var_names = features
    pheno = pd.read_csv(pheno_path)
    pheno["barcode"] = pheno["barcode"].astype(str)
    pheno = pheno.set_index("barcode")
    pheno = pheno.loc[adata.obs_names]
    adata.obs = pheno
    adata = adata[adata.obs["day"].isin([0, 7])].copy()
    # Keep only participants observed on both days.
    paired = adata.obs.groupby("pt_id", observed=True)["day"].nunique()
    keep_ids = paired[paired >= 2].index
    adata = adata[adata.obs["pt_id"].isin(keep_ids)].copy()
    rng = np.random.default_rng(seed)
    uniq_ids = adata.obs["pt_id"].unique()
    # If max_participants is None, use all participants (no subsampling)
    if max_participants is not None:
        n = min(len(uniq_ids), max_participants)
        sel = rng.choice(uniq_ids, size=n, replace=False)
        adata = adata[adata.obs["pt_id"].isin(sel)].copy()
    if max_cells_per_group is not None:
        grp = ["pt_id", "day", "clustnm"]
        # Explicit loop instead of groupby.apply (whose handling of grouping
        # columns is deprecated in pandas 2.2); each group is sampled with the
        # same random_state, giving the same cells in the same group order.
        keep: list = []
        for _, group in adata.obs.groupby(grp, observed=True):
            n_keep = min(len(group), max_cells_per_group)
            keep.extend(group.sample(n_keep, random_state=seed).index)
        adata = adata[pd.Index(keep)].copy()
    adata.layers["counts"] = adata.X.copy()
    adata.uns["processing_params"] = processing_params
    processed_path.parent.mkdir(parents=True, exist_ok=True)
    adata.write_h5ad(processed_path)
    logger.info(f"Saved processed file: {processed_path}")
    logger.info(f"Loaded vaccine dataset (GSE171964): {adata.n_obs} cells, {adata.n_vars} genes")
    logger.info(f"Days: {adata.obs['day'].unique()}")
    logger.info(f"Participants: {adata.obs['pt_id'].nunique()}")
    logger.info(f"Cell types: {adata.obs['clustnm'].nunique()}")
    return adata
[docs]
def count_paired(
    obs: pd.DataFrame,
    visit_col: str,
    visits: Sequence[str],
    participant_col: str = "participant_id",
) -> int:
    """Count participants that have at least one cell at each of two visits.

    Only the first two entries of *visits* (baseline, follow-up) are used.

    Parameters
    ----------
    obs
        Per-cell table holding participant and visit columns.
    visit_col
        Column of *obs* holding visit labels.
    visits
        Visit labels, e.g. ``["baseline", "followup"]``.
    participant_col
        Column of *obs* holding participant IDs.

    Returns
    -------
    int
        Number of participants present at both visits; ``0`` when either
        visit label never occurs.

    Raises
    ------
    ValueError
        If *visits* has fewer than two labels.
    """
    if len(visits) < 2:
        raise ValueError(
            f"visits must contain at least 2 labels, got {len(visits)}: {list(visits)}"
        )
    baseline, followup = visits[0], visits[1]
    cells_per_visit = (
        obs.groupby([participant_col, visit_col], observed=True)
        .size()
        .unstack(fill_value=0)
    )
    present = cells_per_visit.columns
    if baseline not in present or followup not in present:
        return 0
    both = cells_per_visit[baseline].gt(0) & cells_per_visit[followup].gt(0)
    return int(both.sum())
[docs]
def verify_paired_participants(
obs: pd.DataFrame,
visit_col: str,
visits: Sequence[str],
features: Sequence[str] | None = None,
participant_col: str = "participant_id",
) -> dict:
"""Validate paired participants by visit presence and optional feature completeness.
Parameters
----------
obs
DataFrame containing the participant-visit data.
visit_col
Column name in `obs` to use for visit labels.
visits
Sequence of visit labels to check (e.g. ["baseline", "followup"]).
features
Sequence of feature names to check.
participant_col
Column name in `obs` to use for participant IDs.
Returns
-------
dict
A dictionary containing the following keys:
- paired_ids: set of participant IDs with both visits (and non-NaN features if provided)
- dropped_ids: list of participant IDs dropped by validation
- n_paired: count of paired_ids
- n_total: total unique participants
"""
if len(visits) < 2:
raise ValueError(
f"visits must contain at least 2 labels, got {len(visits)}: {list(visits)}"
)
wide = obs.groupby([participant_col, visit_col], observed=True).size().unstack(fill_value=0)
if visits[0] not in wide.columns or visits[1] not in wide.columns:
paired_ids = set()
else:
paired_ids = set(wide[(wide[visits[0]] > 0) & (wide[visits[1]] > 0)].index)
if features:
grouped = obs.groupby([participant_col, visit_col], observed=True)[list(features)]
# Use .first() instead of .mean() so categorical/string features
# don't raise TypeError. For numeric columns the NaN-presence check
# below is still correct because .first() returns NaN for empty groups.
df_pv = grouped.first().reset_index()
valid_ids: set | None = None
for feat in features:
wide_feat = df_pv.pivot(index=participant_col, columns=visit_col, values=feat)
if visits[0] not in wide_feat.columns or visits[1] not in wide_feat.columns:
feat_valid = set()
else:
mask = wide_feat[visits[0]].notna() & wide_feat[visits[1]].notna()
feat_valid = set(wide_feat[mask].index)
valid_ids = feat_valid if valid_ids is None else (valid_ids & feat_valid)
if valid_ids is not None:
paired_ids = paired_ids & valid_ids
all_ids = set(obs[participant_col].unique())
return {
"paired_ids": paired_ids,
"dropped_ids": sorted(all_ids - paired_ids),
"n_paired": len(paired_ids),
"n_total": len(all_ids),
}
[docs]
def categorize_celltype(ct: str) -> str:
    """Map fine-grained cell types to coarse lineages (COVID-19 example).

    Matching is case-insensitive and the first rule that applies wins:

    1. ``CD4`` (not followed by a digit, so ``CD40``/``CD45`` do not match),
       ``Th1``, ``Th2``, ``Treg`` -> ``"CD4_T"``
    2. ``CD8`` (not followed by a digit, so ``CD83`` does not match),
       ``cytotoxic`` -> ``"CD8_T"``
    3. ``NK`` not preceded by a letter (so ``Unknown`` does not match),
       ``natural killer`` -> ``"NK"``
    4. ``dc``, ``dendritic`` -> ``"DCs"``
    5. ``b cell``, ``b_cell``, a leading ``B_``/``B `` prefix, ``plasma``
       -> ``"B_cells"``
    6. ``mono``, ``cd14``, ``cd16`` -> ``"Monocytes"``

    Anything else maps to ``"Other"``.

    Parameters
    ----------
    ct
        Cell type string (non-strings are converted with ``str``).

    Returns
    -------
    str
        Coarse lineage string.
    """
    ct_lower = str(ct).lower()
    if re.search(r"cd4(?!\d)", ct_lower) or any(t in ct_lower for t in ("th1", "th2", "treg")):
        return "CD4_T"
    if re.search(r"cd8(?!\d)", ct_lower) or "cytotoxic" in ct_lower:
        return "CD8_T"
    if re.search(r"(?<![a-z])nk", ct_lower) or "natural killer" in ct_lower:
        return "NK"
    # DC check must precede B cell check so "plasmacytoid dendritic cell"
    # is not captured by the "plasma" substring in the B cell rule.
    if "dc" in ct_lower or "dendritic" in ct_lower:
        return "DCs"
    if (
        "b cell" in ct_lower
        or "b_cell" in ct_lower
        or "plasma" in ct_lower
        or re.match(r"b[_ ]", ct_lower)  # e.g. "B_naive", "B_switched_memory"
    ):
        return "B_cells"
    if "mono" in ct_lower or "cd14" in ct_lower or "cd16" in ct_lower:
        return "Monocytes"
    return "Other"
def _extract_aml_sample_name(filename: str) -> str | None:
"""Extract sample name from AML filename (ignoring GSM number)."""
m = re.search(r"GSM\d+_(.+)\.(dem|anno)\.txt\.gz", filename)
return m.group(1) if m else None
def _parse_aml_sample_info(sample_name: str) -> tuple[str, str, int]:
"""Parse patient ID and day from AML sample name."""
if "-D" in sample_name:
parts = sample_name.rsplit("-D", 1)
patient = parts[0]
try:
day = int(parts[1])
except ValueError:
day = 0
else:
patient = sample_name
day = 0
return sample_name, patient, day
def _scatter_gene_columns(
    X: sp.csr_matrix,
    genes: Sequence[str],
    gene_to_idx: dict[str, int],
    n_genes_total: int,
) -> sp.csr_matrix:
    """Re-index the columns of ``X`` (one per entry of ``genes``) into the global gene space.

    Genes absent from ``gene_to_idx`` are dropped. If a gene name occurs more
    than once, its last column is kept. Global genes missing from this sample
    become all-zero columns.
    """
    # Destination column for every source column (-1 = drop). Iterating in
    # order and overwriting keeps the last occurrence of a duplicated gene,
    # matching sequential column assignment.
    last_src_for_dest: dict[int, int] = {}
    for j, gene in enumerate(genes):
        k = gene_to_idx.get(gene)
        if k is not None:
            last_src_for_dest[k] = j
    dest = np.full(len(genes), -1, dtype=np.int64)
    for k, j in last_src_for_dest.items():
        dest[j] = k

    coo = X.tocoo()
    new_cols = dest[coo.col]
    keep = new_cols >= 0
    # Each kept source column maps to a distinct destination column, so no
    # (row, col) pair repeats and the COO -> CSR conversion sums nothing.
    return sp.csr_matrix(
        (coo.data[keep], (coo.row[keep], new_cols[keep])),
        shape=(X.shape[0], n_genes_total),
        dtype=np.float32,
    )


def _process_aml_raw(
    raw_dir: Path,
    max_cells_per_sample: int | None = None,
    seed: int = 42,
) -> ad.AnnData:
    """Process raw GSE116256 AML files into an AnnData object.

    Reads per-sample expression (.dem.txt.gz) and annotation (.anno.txt.gz)
    files, combines them, applies QC, normalisation, and computes embeddings.
    Cell-type labels come from the original van Galen et al. annotations.

    Parameters
    ----------
    raw_dir
        Directory containing ``GSM*_*.dem.txt.gz`` / ``GSM*_*.anno.txt.gz``
        pairs. Expression files are read as genes × cells (gene names in the
        first column). Annotation files are indexed by cell barcode.
    max_cells_per_sample
        If truthy, randomly subsample each sample to at most this many cells.
    seed
        Seed for the subsampling RNG.

    Returns
    -------
    AnnData
        Cells × genes object over the union of genes from all samples. Raw
        values are in ``layers["counts"]``, log-normalised values in ``X`` and
        ``layers["log1p_norm"]``, and PCA/UMAP embeddings are computed.

    Raises
    ------
    ValueError
        If no matched expression/annotation pairs exist, or no sample shares
        any cell between its two files.
    """
    import scanpy as sc

    # Build file mapping: match dem ↔ anno by sample name
    dem_files: dict[str, Path] = {}
    anno_files: dict[str, Path] = {}
    for fp in raw_dir.glob("GSM*_*.dem.txt.gz"):
        name = _extract_aml_sample_name(fp.name)
        if name:
            dem_files[name] = fp
    for fp in raw_dir.glob("GSM*_*.anno.txt.gz"):
        name = _extract_aml_sample_name(fp.name)
        if name:
            anno_files[name] = fp
    matched = sorted(set(dem_files) & set(anno_files))
    if not matched:
        raise ValueError(
            f"No matched sample pairs found in {raw_dir}. "
            f"Found {len(dem_files)} expression and {len(anno_files)} annotation files."
        )
    logger.info(f"Found {len(matched)} matched AML sample pairs")

    # Pass 1: collect all gene names (first column only, to keep this cheap)
    all_genes: set[str] = set()
    for sname in matched:
        with gzip.open(dem_files[sname], "rt") as f:
            df = pd.read_csv(f, sep="\t", usecols=[0])
        all_genes.update(df.iloc[:, 0].tolist())
    all_genes_sorted = sorted(all_genes)
    gene_to_idx = {g: i for i, g in enumerate(all_genes_sorted)}
    logger.info(f"Total unique genes across samples: {len(all_genes_sorted)}")

    # Pass 2: load expression + annotations
    rng = np.random.default_rng(seed)
    all_X: list[sp.csr_matrix] = []
    all_obs: list[pd.DataFrame] = []
    n_genes_total = len(all_genes_sorted)
    for sname in matched:
        with gzip.open(dem_files[sname], "rt") as f:
            expr_df = pd.read_csv(f, sep="\t", index_col=0)
        with gzip.open(anno_files[sname], "rt") as f:
            anno_df = pd.read_csv(f, sep="\t", index_col=0)
        common_cells = sorted(set(expr_df.columns) & set(anno_df.index))
        if not common_cells:
            logger.warning(f"No common cells for {sname}, skipping")
            continue
        if max_cells_per_sample and len(common_cells) > max_cells_per_sample:
            common_cells = list(rng.choice(common_cells, max_cells_per_sample, replace=False))
        expr_df = expr_df[common_cells]
        anno_df = anno_df.loc[common_cells]
        X_raw = sp.csr_matrix(expr_df.T.values.astype(np.float32))
        # Scatter this sample's genes into the shared sorted gene space in one
        # vectorised step instead of per-column sparse assignment.
        X_re = _scatter_gene_columns(X_raw, list(expr_df.index), gene_to_idx, n_genes_total)
        _, patient, day = _parse_aml_sample_info(sname)
        # Prefix barcodes with the sample name so obs names are unique
        unique_cells = [f"{sname}_{c}" for c in common_cells]
        obs_df = anno_df.copy()
        obs_df.index = unique_cells
        obs_df["sample_id"] = sname
        obs_df["patient_id"] = patient
        obs_df["day"] = day
        all_X.append(X_re)
        all_obs.append(obs_df)
        del X_raw, X_re
        gc.collect()
    if not all_X:
        raise ValueError("No samples could be loaded from raw AML data.")
    X_combined = sp.vstack(all_X)
    obs_combined = pd.concat(all_obs, axis=0)
    adata = ad.AnnData(X=X_combined, obs=obs_combined, var=pd.DataFrame(index=all_genes_sorted))
    del all_X
    gc.collect()

    # ── QC and normalisation ──────────────────────────────────────────
    adata.layers["counts"] = adata.X.copy()
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)
    n_before = adata.n_obs
    sc.pp.filter_cells(adata, min_genes=200)
    sc.pp.filter_cells(adata, max_genes=6000)
    adata = adata[adata.obs["pct_counts_mt"] < 20].copy()
    sc.pp.filter_genes(adata, min_cells=10)
    logger.info(f"QC: {n_before:,} → {adata.n_obs:,} cells, {adata.n_vars:,} genes")
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    adata.layers["log1p_norm"] = adata.X.copy()

    # ── Standardised sctrial obs columns ──────────────────────────────
    obs = adata.obs
    obs["participant_id"] = obs["patient_id"].astype(str)
    # Day 0 is the pre-treatment visit; any later day counts as post-treatment
    obs["visit"] = obs["day"].apply(lambda d: "Pre" if d == 0 else "Post")
    obs["sample_type"] = obs["patient_id"].apply(
        lambda pid: "AML" if str(pid).startswith("AML") else "Healthy"
    )
    if "CellType" in obs.columns:
        obs["cell_type"] = obs["CellType"]
    elif "PredictionRefined" in obs.columns:
        obs["cell_type"] = obs["PredictionRefined"]
    else:
        obs["cell_type"] = "Unknown"
    if "PredictionRefined" in obs.columns:
        obs["is_malignant"] = obs["PredictionRefined"].apply(
            lambda x: "malignant" in str(x).lower() if pd.notna(x) else False
        )
    else:
        obs["is_malignant"] = False
    obs["response"] = (
        obs["sample_type"].map({"AML": "Treatment", "Healthy": "Control"}).fillna("Unknown")
    )
    adata.obs = obs

    # ── Embeddings (HVG → PCA → neighbours → UMAP) ───────────────────
    sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor="seurat", subset=False)
    sc.tl.pca(adata, n_comps=50, use_highly_variable=True)
    sc.pp.neighbors(adata, n_neighbors=15, n_pcs=30)
    sc.tl.umap(adata)

    # Store paired-patient list: AML patients with a day-0 and a later sample
    aml_obs = adata.obs[adata.obs["sample_type"] == "AML"]
    paired = []
    for pid in aml_obs["participant_id"].unique():
        days = set(aml_obs.loc[aml_obs["participant_id"] == pid, "day"].unique())
        if 0 in days and any(d > 0 for d in days):
            paired.append(pid)
    adata.uns["paired_aml_patients"] = paired
    adata.uns["dataset"] = "GSE116256"
    adata.uns["paper"] = "van Galen et al., Cell 2019"
    adata.uns["description"] = "AML chemotherapy longitudinal scRNA-seq"
    return adata
[docs]
def load_aml(
    data_dir: str | None = None,
    processed_name: str = "gse116256_aml_processed.h5ad",
    max_cells_per_sample: int | None = None,
    seed: int = 42,
    allow_download: bool = False,
    force_reprocess: bool = False,
) -> ad.AnnData:
    """Load the van Galen AML chemotherapy dataset (GSE116256).

    This dataset contains pre/post-chemotherapy bone marrow samples from AML
    patients with cell-type annotations and treatment-response metadata.

    Parameters
    ----------
    data_dir : str or None
        Directory containing (or to store) the data files. Defaults to the
        ``aml`` subdirectory of the package's dataset root.
        Raw files go in ``<data_dir>/raw/`` and the processed cache in
        ``<data_dir>/processed/``.
    processed_name : str
        Filename for the cached processed h5ad file.
    max_cells_per_sample : int or None
        Maximum cells to keep per sample after subsampling. ``None`` (or 0)
        keeps every cell.
    seed : int
        Random seed for reproducibility.
    allow_download : bool
        If True, download raw data from GEO when not found locally.
    force_reprocess : bool
        If True, reprocess even when a cached file exists.

    Returns
    -------
    AnnData
        The processed AnnData object.

    Raises
    ------
    FileNotFoundError
        If processing is needed, no raw files are present and
        ``allow_download`` is False.
    ImportError
        If raw data must be processed but scanpy is not installed.
    ValueError
        If the downloaded archive contains members that would be extracted
        outside ``<data_dir>/raw`` (on interpreters without tar extraction
        filters; newer interpreters raise ``tarfile.FilterError`` instead), or
        if no samples can be assembled from the raw files.

    Notes
    -----
    The raw data is automatically downloaded from GEO when
    ``allow_download=True``:
    https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE116256

    A cached file is reused when its stored ``processing_params`` match the
    current arguments. A cache without that metadata is reused with a warning.

    Reference: van Galen et al., Cell 2019.

    Examples
    --------
    >>> adata = sctrial.load_aml(allow_download=True)
    """
    data_dir = data_dir or _default_data_dir("aml")
    data_dir_path = Path(data_dir)
    processing_params = {
        "version": "v1",
        "max_cells_per_sample": max_cells_per_sample,
        "seed": seed,
    }

    # ── Try to load cached processed file ─────────────────────────────
    processed_path = data_dir_path / "processed" / processed_name
    if not force_reprocess and processed_path.exists():
        adata = ad.read_h5ad(processed_path)
        prev = adata.uns.get("processing_params", {})
        if prev:
            if _params_match(prev, processing_params):
                logger.info(
                    f"Loaded AML dataset (GSE116256): {adata.n_obs:,} cells, {adata.n_vars:,} genes"
                )
                return adata
            logger.info("Processed file parameters differ; reprocessing.")
        else:
            warnings.warn(
                "Cached file lacks processing_params metadata; cannot verify it matches "
                "current settings. Consider reprocessing with force_reprocess=True.",
                UserWarning,
                stacklevel=2,
            )
            logger.info(
                f"Loaded AML dataset (GSE116256): {adata.n_obs:,} cells, {adata.n_vars:,} genes"
            )
            return adata

    # ── Locate or download raw files ──────────────────────────────────
    try:
        import scanpy  # noqa: F401
    except ImportError:
        raise ImportError(
            "scanpy is required for AML dataset processing. "
            "Install with: pip install sctrial[plots] or pip install scanpy"
        ) from None
    raw_dir = data_dir_path / "raw"
    found_raw = raw_dir if raw_dir.is_dir() and list(raw_dir.glob("GSM*_*.dem.txt.gz")) else None
    if found_raw is None:
        if not allow_download:
            raise FileNotFoundError(
                f"AML dataset not found. Searched for raw files and "
                f"'{processed_name}' in several locations including {data_dir_path}. "
                "Download from GEO: "
                "https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE116256 "
                "or set allow_download=True to fetch automatically."
            )
        # Download tar from GEO and extract
        raw_dir.mkdir(parents=True, exist_ok=True)
        tar_url = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE116256&format=file"
        tar_dest = raw_dir / "GSE116256_RAW.tar"
        _download_file(tar_url, tar_dest, "GSE116256 supplementary tar")
        logger.info("Extracting raw files...")
        with tarfile.open(tar_dest, "r") as tf:
            # The archive comes from the network: never let a member escape raw_dir.
            if hasattr(tarfile, "data_filter"):
                # Extraction filters (3.12+ and security backports) reject
                # absolute paths, ".." components, escaping links and devices.
                tf.extractall(path=raw_dir, filter="data")
            else:
                root = raw_dir.resolve()
                for member in tf.getmembers():
                    target = (root / member.name).resolve()
                    inside = target == root or root in target.parents
                    if not inside or not (member.isfile() or member.isdir()):
                        raise ValueError(
                            f"Refusing to extract unsafe tar member {member.name!r} "
                            f"from {tar_dest}"
                        )
                tf.extractall(path=raw_dir)
        # Clean up tar to save disk space
        tar_dest.unlink(missing_ok=True)
        found_raw = raw_dir

    # ── Process raw data ──────────────────────────────────────────────
    logger.info("Processing raw AML data (this may take several minutes)...")
    adata = _process_aml_raw(found_raw, max_cells_per_sample=max_cells_per_sample, seed=seed)
    adata.uns["processing_params"] = processing_params

    # Save cache
    processed_dir = data_dir_path / "processed"
    processed_dir.mkdir(parents=True, exist_ok=True)
    out_path = processed_dir / processed_name
    adata.write_h5ad(out_path)
    logger.info(f"Saved processed AML file: {out_path}")
    logger.info(f"Loaded AML dataset (GSE116256): {adata.n_obs:,} cells, {adata.n_vars:,} genes")
    return adata
def _parse_cart_sample_info(filename: str) -> tuple[str | None, str | None, int]:
"""Parse patient and timepoint from CAR-T filename."""
m = re.search(r"GSM\d+_(P\d+)_(.+)_rna\.csv\.gz", filename)
if not m:
return None, None, -1
patient = m.group(1)
tp_raw = m.group(2)
if "Leukapheresis" in tp_raw:
return patient, "Leukapheresis", 0
if "4wk" in tp_raw:
return patient, "4wk_post", 28
if "6mo" in tp_raw:
return patient, "6mo_post", 180
if "12mo" in tp_raw:
return patient, "12mo_post", 365
return patient, tp_raw, -1
def _process_cart_raw(
    raw_dir: Path,
    max_cells_per_sample: int | None = None,
    seed: int = 42,
) -> ad.AnnData:
    """Process raw GSE290722 CAR-T files into an AnnData object.

    Reads per-sample expression CSV files, combines them, applies QC,
    normalisation, computes embeddings, and annotates cell types via
    Leiden clustering + Wilcoxon marker scoring.

    Parameters
    ----------
    raw_dir
        Directory containing ``GSM*_*_rna.csv.gz`` files. Each file is read
        as genes × cells with the gene names in its last column.
    max_cells_per_sample
        If truthy, randomly subsample each sample to at most this many cells.
    seed
        Seed for subsampling and for the KMeans clustering fallback.

    Returns
    -------
    AnnData
        Cells × genes object with raw values in ``layers["counts"]``,
        log-normalised values in ``X`` and ``layers["log1p_norm"]``, PCA/UMAP
        embeddings, ``leiden`` clusters and a ``cell_type`` annotation.

    Raises
    ------
    ValueError
        If no RNA files or no parseable samples are found, or if every sample
        has a gene set different from the first sample's.
    """
    import scanpy as sc

    rna_files = sorted(raw_dir.glob("GSM*_*_rna.csv.gz"))
    if not rna_files:
        raise ValueError(f"No RNA expression files found in {raw_dir}")
    samples: list[dict] = []
    for f in rna_files:
        patient, timepoint, days = _parse_cart_sample_info(f.name)
        if patient is None:
            continue
        samples.append(
            {
                "file": f,
                "patient": patient,
                "timepoint": timepoint,
                "days": days,
                "sample_id": f"{patient}_{timepoint}",
            }
        )
    if not samples:
        raise ValueError("No valid CAR-T samples found.")
    logger.info(f"Found {len(samples)} CAR-T RNA expression files")

    # Reference gene order from the first sample. Only the header and the gene
    # column (the last one) are read, not the whole expression matrix.
    first_header = pd.read_csv(samples[0]["file"], nrows=0).columns
    gene_names = (
        pd.read_csv(samples[0]["file"], usecols=[len(first_header) - 1]).iloc[:, 0].tolist()
    )
    logger.info(f"Detected {len(gene_names)} genes")

    # Load all samples
    rng = np.random.default_rng(seed)
    all_X: list[sp.csr_matrix] = []
    all_obs: list[pd.DataFrame] = []
    for sample in samples:
        df = pd.read_csv(sample["file"])
        gc_col = df.columns[-1]
        genes = df[gc_col].tolist()
        expr_df = df.drop(columns=[gc_col])
        cells = list(expr_df.columns)
        X = expr_df.values.T.astype(np.float32)  # cells × genes
        # Re-order genes if needed
        if genes != gene_names:
            if set(genes) == set(gene_names):
                g2i = {g: i for i, g in enumerate(genes)}
                new_order = [g2i[g] for g in gene_names]
                X = X[:, new_order]
            else:
                logger.warning(f"Gene mismatch for {sample['sample_id']}, skipping")
                continue
        # Subsample
        if max_cells_per_sample and X.shape[0] > max_cells_per_sample:
            idx = rng.choice(X.shape[0], max_cells_per_sample, replace=False)
            X = X[idx]
            cells = [cells[i] for i in idx]
        obs_df = pd.DataFrame(
            {
                "cell_id": cells,
                "sample_id": sample["sample_id"],
                "patient_id": sample["patient"],
                "timepoint": sample["timepoint"],
                "days_post_treatment": sample["days"],
            },
            index=[f"{sample['sample_id']}_{c}" for c in cells],
        )
        # Sparsify per sample so the combined dense cells × genes matrix is
        # never materialised (a dense vstack scales with the total cell count).
        all_X.append(sp.csr_matrix(X))
        all_obs.append(obs_df)
        del df, expr_df, X
        gc.collect()
    if not all_X:
        raise ValueError("No CAR-T samples could be loaded.")
    X_combined = sp.vstack(all_X, format="csr")
    obs_combined = pd.concat(all_obs, axis=0)
    adata = ad.AnnData(X=X_combined, obs=obs_combined, var=pd.DataFrame(index=gene_names))
    del all_X
    gc.collect()

    # ── QC and normalisation ──────────────────────────────────────────
    adata.layers["counts"] = adata.X.copy()
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)
    n_before = adata.n_obs
    sc.pp.filter_cells(adata, min_genes=200)
    sc.pp.filter_cells(adata, max_genes=6000)
    if "pct_counts_mt" in adata.obs.columns:
        adata = adata[adata.obs["pct_counts_mt"] < 20].copy()
    sc.pp.filter_genes(adata, min_cells=10)
    logger.info(f"QC: {n_before:,} → {adata.n_obs:,} cells, {adata.n_vars:,} genes")
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    adata.layers["log1p_norm"] = adata.X.copy()

    # ── Standardised sctrial obs columns ──────────────────────────────
    obs = adata.obs
    obs["participant_id"] = obs["patient_id"].astype(str)
    obs["visit"] = obs["timepoint"].apply(lambda t: "Pre" if t == "Leukapheresis" else "Post")
    obs["is_paired"] = False  # filled below
    obs["response"] = "CAR-T"
    adata.obs = obs

    # ── Embeddings + clustering ───────────────────────────────────────
    sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor="seurat", subset=False)
    sc.tl.pca(adata, n_comps=50, use_highly_variable=True)
    sc.pp.neighbors(adata, n_neighbors=15, n_pcs=30)
    sc.tl.umap(adata)

    # Clustering — try Leiden, fall back to KMeans. This is a deliberate
    # best-effort (presumably for environments without the Leiden backend),
    # but the failure is logged rather than silently swallowed.
    try:
        sc.tl.leiden(adata, resolution=0.8)
    except Exception as exc:
        logger.warning(f"Leiden clustering failed ({exc!r}); falling back to MiniBatchKMeans.")
        from sklearn.cluster import MiniBatchKMeans

        X_pca = adata.obsm["X_pca"][:, :20]
        kmeans = MiniBatchKMeans(n_clusters=15, random_state=seed, batch_size=1000)
        clusters = kmeans.fit_predict(X_pca)
        adata.obs["leiden"] = pd.Categorical([str(c) for c in clusters])

    # ── Cell type annotation (marker scoring) ─────────────────────────
    logger.info("Annotating CAR-T cell types...")
    adata.obs["cell_type"] = _annotate_immune_celltypes(adata)

    # Mark paired patients: a leukapheresis sample plus at least one other timepoint
    patients = adata.obs["participant_id"].unique()
    paired_patients = []
    for p in patients:
        tps = set(adata.obs.loc[adata.obs["participant_id"] == p, "timepoint"].unique())
        if "Leukapheresis" in tps and len(tps) > 1:
            paired_patients.append(p)
    adata.obs["is_paired"] = adata.obs["participant_id"].isin(paired_patients)
    adata.uns["paired_patients"] = paired_patients
    adata.uns["dataset"] = "GSE290722"
    adata.uns["trial"] = "ZUMA-1"
    adata.uns["description"] = "CAR-T therapy longitudinal scRNA-seq"
    return adata
[docs]
def load_cart(
    data_dir: str | None = None,
    processed_name: str = "gse290722_cart_processed.h5ad",
    max_cells_per_sample: int | None = None,
    seed: int = 42,
    allow_download: bool = False,
    force_reprocess: bool = False,
) -> ad.AnnData:
    """Load the CAR-T cell therapy dataset (GSE290722).

    This dataset contains pre/post-CAR-T infusion samples with cell-type
    annotations and treatment-response metadata from the ZUMA-1 trial.

    Parameters
    ----------
    data_dir : str or None
        Directory containing (or to store) the data files. Defaults to the
        ``cart`` subdirectory of the package's dataset root.
        Raw files go in ``<data_dir>/raw/`` and the processed cache in
        ``<data_dir>/processed/``.
    processed_name : str
        Filename for the cached processed h5ad file.
    max_cells_per_sample : int or None
        Maximum cells to keep per sample after subsampling. ``None`` (or 0)
        keeps every cell.
    seed : int
        Random seed for reproducibility.
    allow_download : bool
        If True, download raw data from GEO when not found locally.
    force_reprocess : bool
        If True, reprocess even when a cached file exists.

    Returns
    -------
    AnnData
        The processed AnnData object.

    Raises
    ------
    FileNotFoundError
        If processing is needed, no raw files are present and
        ``allow_download`` is False.
    ImportError
        If raw data must be processed but scanpy is not installed.
    ValueError
        If the downloaded archive contains members that would be extracted
        outside ``<data_dir>/raw`` (on interpreters without tar extraction
        filters; newer interpreters raise ``tarfile.FilterError`` instead), or
        if no samples can be assembled from the raw files.

    Notes
    -----
    The raw data is automatically downloaded from GEO when
    ``allow_download=True``:
    https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE290722

    A cached file is reused when its stored ``processing_params`` match the
    current arguments. A cache without that metadata is reused with a warning.

    Reference: GSE290722 CAR-T therapy dataset (ZUMA-1 trial).

    Examples
    --------
    >>> adata = sctrial.load_cart(allow_download=True)
    """
    data_dir = data_dir or _default_data_dir("cart")
    data_dir_path = Path(data_dir)
    processing_params = {
        "version": "v1",
        "max_cells_per_sample": max_cells_per_sample,
        "seed": seed,
    }

    # ── Try to load cached processed file ─────────────────────────────
    processed_path = data_dir_path / "processed" / processed_name
    if not force_reprocess and processed_path.exists():
        adata = ad.read_h5ad(processed_path)
        prev = adata.uns.get("processing_params", {})
        if prev:
            if _params_match(prev, processing_params):
                logger.info(
                    f"Loaded CAR-T dataset (GSE290722): {adata.n_obs:,} cells, {adata.n_vars:,} genes"
                )
                return adata
            logger.info("Processed file parameters differ; reprocessing.")
        else:
            warnings.warn(
                "Cached file lacks processing_params metadata; cannot verify it matches "
                "current settings. Consider reprocessing with force_reprocess=True.",
                UserWarning,
                stacklevel=2,
            )
            logger.info(
                f"Loaded CAR-T dataset (GSE290722): {adata.n_obs:,} cells, {adata.n_vars:,} genes"
            )
            return adata

    # ── Locate or download raw files ──────────────────────────────────
    try:
        import scanpy  # noqa: F401
    except ImportError:
        raise ImportError(
            "scanpy is required for CAR-T dataset processing. "
            "Install with: pip install sctrial[plots] or pip install scanpy"
        ) from None
    raw_dir = data_dir_path / "raw"
    found_raw = raw_dir if raw_dir.is_dir() and list(raw_dir.glob("GSM*_*_rna.csv.gz")) else None
    if found_raw is None:
        if not allow_download:
            raise FileNotFoundError(
                f"CAR-T dataset not found. Searched for raw files and "
                f"'{processed_name}' in several locations including {data_dir_path}. "
                "Download from GEO: "
                "https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE290722 "
                "or set allow_download=True to fetch automatically."
            )
        raw_dir.mkdir(parents=True, exist_ok=True)
        tar_url = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE290722&format=file"
        tar_dest = raw_dir / "GSE290722_RAW.tar"
        _download_file(tar_url, tar_dest, "GSE290722 supplementary tar")
        logger.info("Extracting raw files...")
        with tarfile.open(tar_dest, "r") as tf:
            # The archive comes from the network: never let a member escape raw_dir.
            if hasattr(tarfile, "data_filter"):
                # Extraction filters (3.12+ and security backports) reject
                # absolute paths, ".." components, escaping links and devices.
                tf.extractall(path=raw_dir, filter="data")
            else:
                root = raw_dir.resolve()
                for member in tf.getmembers():
                    target = (root / member.name).resolve()
                    inside = target == root or root in target.parents
                    if not inside or not (member.isfile() or member.isdir()):
                        raise ValueError(
                            f"Refusing to extract unsafe tar member {member.name!r} "
                            f"from {tar_dest}"
                        )
                tf.extractall(path=raw_dir)
        # Clean up tar to save disk space
        tar_dest.unlink(missing_ok=True)
        found_raw = raw_dir

    # ── Process raw data ──────────────────────────────────────────────
    logger.info("Processing raw CAR-T data (this may take several minutes)...")
    adata = _process_cart_raw(found_raw, max_cells_per_sample=max_cells_per_sample, seed=seed)
    adata.uns["processing_params"] = processing_params

    # Save cache
    processed_dir = data_dir_path / "processed"
    processed_dir.mkdir(parents=True, exist_ok=True)
    out_path = processed_dir / processed_name
    adata.write_h5ad(out_path)
    logger.info(f"Saved processed CAR-T file: {out_path}")
    logger.info(f"Loaded CAR-T dataset (GSE290722): {adata.n_obs:,} cells, {adata.n_vars:,} genes")
    return adata
[docs]
def harmonize_response(adata: ad.AnnData, *, force: bool = False) -> ad.AnnData:
    """Create a ``response_harmonized`` column with consistent labels.

    Maps various responder/non-responder column names and label formats
    (e.g. "R"/"NR", "Responder"/"Non-responder") to a standard vocabulary:
    ``"Responder"`` and ``"Non-responder"``.

    Labels are matched case-insensitively after stripping surrounding
    whitespace. Unrecognised labels are kept as their string form, and
    missing values stay missing (NaN). When ``participant_id`` is present,
    every cell of a participant then receives that participant's most
    frequent label. Ties are broken by the ordering of ``Series.mode``
    (smallest value first).

    Parameters
    ----------
    adata : AnnData
        Should contain one of ``response``, ``Response`` or
        ``clinical_response`` in ``.obs`` (checked in that order). If none is
        present and no ``response_harmonized`` column exists yet, ``adata``
        is returned without that column.
    force : bool
        If True, recompute even when the column already exists.

    Returns
    -------
    AnnData
        The input AnnData (modified in place) with ``response_harmonized``
        added to ``.obs``.
    """

    def _majority_label(labels: pd.Series):
        # Series.mode() drops NaN; fall back to the first value (possibly NaN)
        # when a participant has no non-missing label.
        modes = labels.mode()
        return modes.iloc[0] if len(modes) > 0 else labels.iloc[0]

    def _reconcile_by_participant() -> None:
        # observed=True: unobserved categories of a categorical participant_id
        # would otherwise be passed in as empty groups.
        per_participant = adata.obs.groupby("participant_id", observed=True)[
            "response_harmonized"
        ].agg(_majority_label)
        adata.obs["response_harmonized"] = adata.obs["participant_id"].map(per_participant)

    if force and "response_harmonized" in adata.obs.columns:
        del adata.obs["response_harmonized"]

    if "response_harmonized" in adata.obs.columns:
        # Column already present: only repair participants with conflicting labels.
        if "participant_id" in adata.obs.columns:
            n_per = adata.obs.groupby("participant_id", observed=True)[
                "response_harmonized"
            ].nunique()
            if (n_per > 1).any():
                _reconcile_by_participant()
        return adata

    # Keys are normalised (stripped, lower-cased) label spellings.
    mapping = {
        "responder": "Responder",
        "r": "Responder",
        "non-responder": "Non-responder",
        "non_responder": "Non-responder",
        "non responder": "Non-responder",
        "nonresponder": "Non-responder",
        "nr": "Non-responder",
    }

    def _harmonize_label(value):
        if pd.isna(value):
            return np.nan
        text = str(value)
        return mapping.get(text.strip().lower(), text)

    for col in ("response", "Response", "clinical_response"):
        if col in adata.obs.columns:
            # astype(object) turns categorical missing values into NaN so they
            # reach _harmonize_label as missing rather than as the string "nan".
            adata.obs["response_harmonized"] = (
                adata.obs[col].astype(object).map(_harmonize_label)
            )
            break

    if "response_harmonized" in adata.obs.columns and "participant_id" in adata.obs.columns:
        _reconcile_by_participant()
    return adata
[docs]
def ensure_fdr(df: pd.DataFrame, p_col: str = "p_time", fdr_col: str = "FDR_time") -> pd.DataFrame:
    """Add a Benjamini-Hochberg FDR column for a p-value column, in place.

    Parameters
    ----------
    df
        DataFrame containing the p-value column. It is modified in place.
    p_col
        Name of the column holding raw p-values.
    fdr_col
        Name of the column to write BH-adjusted p-values into.

    Returns
    -------
    pd.DataFrame
        The same ``df`` object (not a copy). It is returned unchanged when it
        is empty, when ``fdr_col`` already exists, or when ``p_col`` is
        absent. Otherwise ``fdr_col`` is added: rows with a missing p-value
        get NaN, and the remaining rows are adjusted jointly.
    """
    if df.empty or fdr_col in df.columns or p_col not in df.columns:
        return df
    mask = df[p_col].notna()
    df[fdr_col] = np.nan
    if mask.any():
        df.loc[mask, fdr_col] = multipletests(df.loc[mask, p_col], method="fdr_bh")[1]
    return df