# Source code for sctrial.workflow
"""End-to-end trial analysis workflow: TrialWorkflow and convenience pipeline."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Literal
from anndata import AnnData
from .design import TrialDesign
from .preprocessing import add_log1p_cpm_layer
from .scoring import score_gene_sets
from .stats.did import DiDConfig, did_table
# [docs]
@dataclass
class TrialWorkflow:
    """Fluent, chainable wrapper around common sctrial analysis steps.

    Each step method delegates to the module-level function of the same
    name (preprocessing, gene-set scoring, DiD analysis) and returns the
    workflow itself so that calls can be chained.

    Attributes
    ----------
    adata
        The AnnData object the workflow operates on. It is rebound to
        whatever the preprocessing and scoring helpers return.
    last_result
        Output of the most recent ``did_table`` call, or ``None`` if no
        analysis has been run yet.
    """

    adata: AnnData
    last_result: Any | None = None

    def add_log1p_cpm_layer(self, counts_layer: str = "counts") -> TrialWorkflow:
        """Add a log1p-CPM layer to the workflow AnnData.

        Parameters
        ----------
        counts_layer
            Layer name in ``adata.layers`` to use for counts data.

        Returns
        -------
        TrialWorkflow
            This workflow, with ``self.adata`` rebound to the helper's
            return value.
        """
        # The bare name below resolves to the module-level import, not to
        # this method: class-body names are not visible inside method bodies.
        self.adata = add_log1p_cpm_layer(self.adata, counts_layer=counts_layer)
        return self

    def score_gene_sets(
        self,
        gene_sets: dict[str, list[str]],
        *,
        layer: str | None = None,
        method: Literal["zmean", "mean"] = "zmean",
        prefix: str = "ms_",
        min_genes: int = 5,
        overwrite: bool = False,
    ) -> TrialWorkflow:
        """Score gene sets and store module scores in ``adata.obs``.

        All arguments are forwarded unchanged to the module-level
        ``score_gene_sets`` function; the descriptions below summarize the
        documented contract of that function.

        Parameters
        ----------
        gene_sets
            Dictionary mapping set names to lists of gene names. Each value
            must be a ``list`` (not a bare string). Duplicate gene names
            within a set are automatically removed.
        layer
            Layer name in ``adata.layers`` to use for expression data.
            If None, uses ``adata.X``. For log1p-CPM workflows, use
            ``layer="log1p_cpm"``.
        method
            Scoring method. ``"zmean"`` (default) z-scores each gene across
            cells then averages; ``"mean"`` uses raw mean expression.
        prefix
            Prefix to add to column names (e.g., ``ms_`` for module scores).
        min_genes
            Minimum number of genes from the set that must be present in
            the data. Default is 5.
        overwrite
            If False (default), skip gene sets that already have a column
            in ``adata.obs``.

        Returns
        -------
        TrialWorkflow
            This workflow, with ``self.adata`` rebound to the helper's
            return value.
        """
        # Resolves to the imported module-level function (see note in
        # ``add_log1p_cpm_layer``).
        self.adata = score_gene_sets(
            self.adata,
            gene_sets,
            layer=layer,
            method=method,
            prefix=prefix,
            min_genes=min_genes,
            overwrite=overwrite,
        )
        return self

    def did_table(
        self,
        features: list[str],
        *,
        design: TrialDesign,
        visits: tuple[str, str],
        config: DiDConfig | None = None,
    ) -> TrialWorkflow:
        """Run DiD and store the result on the workflow.

        Parameters
        ----------
        features
            List of feature names to analyze.
        design
            Trial design object.
        visits
            Tuple of (baseline, followup) visit labels.
        config
            Configuration for the DiD analysis; see
            ``sctrial.stats.did.DiDConfig``. ``None`` is passed through to
            the underlying function unchanged.

        Returns
        -------
        TrialWorkflow
            This workflow, with the DiD output stored in
            ``self.last_result``. ``self.adata`` is not reassigned here.
        """
        # Unlike the preprocessing/scoring steps, the return value is kept
        # as the workflow's result rather than replacing ``self.adata``.
        self.last_result = did_table(
            self.adata,
            features=features,
            design=design,
            visits=visits,
            config=config,
        )
        return self

    def result(self) -> Any:
        """Return the last result computed in the workflow.

        Returns
        -------
        Any
            The value stored by the most recent ``did_table`` call, or
            ``None`` if no analysis has been run.
        """
        return self.last_result
# [docs]
def workflow(adata: AnnData) -> TrialWorkflow:
    """Create a TrialWorkflow for fluent chaining.

    Parameters
    ----------
    adata
        AnnData object with trial data.

    Returns
    -------
    TrialWorkflow
        A new workflow wrapping ``adata``; its ``last_result`` starts at
        the dataclass default of ``None``.
    """
    return TrialWorkflow(adata=adata)