from __future__ import annotations
import logging
from collections.abc import Sequence
import numpy as np
import pandas as pd
from anndata import AnnData
from ..design import TrialDesign
from ._extract import extract_gene_matrix
from ._utils import aggregate_features, apply_fdr
logger = logging.getLogger(__name__)
__all__ = ["hazard_regression_with_features"]
[docs]
def hazard_regression_with_features(
    adata: AnnData,
    features: Sequence[str],
    design: TrialDesign,
    time_col: str,
    event_col: str,
    *,
    visit: str | None = None,
    layer: str | None = None,
    agg: str = "mean",
    covariates: Sequence[str] | None = None,
    standardize: bool = True,
    fdr: bool = True,
) -> pd.DataFrame:
    """Cox proportional hazards regression using single-cell features.

    Features are aggregated to the participant level and one Cox model is
    fitted per feature, optionally adjusting for covariates.

    Parameters
    ----------
    adata
        AnnData object.
    features
        Features to test (obs columns or gene names in ``var_names``).
        A name present in both is treated as an obs column. Duplicates are
        collapsed (first occurrence wins) and names found in neither place
        are skipped with a logged warning.
    design
        TrialDesign with participant and visit columns.
    time_col
        Column in ``adata.obs`` containing survival time. Rows whose time is
        not strictly positive (including missing values) are dropped with a
        logged warning.
    event_col
        Column in ``adata.obs`` indicating event (1) or censoring (0).
    visit
        Optional visit label to subset prior to aggregation (e.g., baseline).
    layer
        Expression layer to use for gene features.
    agg
        Aggregation function for features ("mean", "median", or "pct_pos").
    covariates
        Optional covariate columns to include in the Cox model.
    standardize
        If True, z-score each feature (sample standard deviation, ``ddof=1``)
        before fitting.
    fdr
        If True, add FDR correction across all features.

    Returns
    -------
    pd.DataFrame
        One row per feature, sorted by p-value, with columns ``feature``,
        ``HR``, ``HR_low``, ``HR_high``, ``p`` and ``n`` (number of
        participants used in the fit). When ``fdr`` is True the frame is
        additionally passed through ``apply_fdr`` with ``fdr_col="FDR"``.
        Features that could not be fitted (fewer than 5 complete
        participants, near-zero variance, convergence failure, or
        non-finite estimates) are reported with NaN statistics.

    Raises
    ------
    ImportError
        If ``lifelines`` is not installed.
    KeyError
        If the participant, time, event or covariate columns are missing
        from ``adata.obs``.
    ValueError
        If every observation has a non-positive survival time, or if time,
        event or a covariate varies within a participant.
    """
    try:
        from lifelines import CoxPHFitter
        from lifelines.exceptions import ConvergenceError as _ConvergenceError
    except ImportError as exc:  # pragma: no cover
        raise ImportError("lifelines is required for survival analysis") from exc

    # Fixed output schema so callers see the same columns whether or not
    # any individual fit succeeded.
    result_cols = ["feature", "HR", "HR_low", "HR_high", "p", "n"]
    # Below this many complete participants a fit is not attempted.
    min_participants = 5
    covariate_cols = list(covariates) if covariates else []

    obs = adata.obs.copy()
    if visit is not None:
        obs = obs[obs[design.visit_col] == visit].copy()

    base_cols = [design.participant_col, time_col, event_col, *covariate_cols]
    missing = [c for c in base_cols if c not in obs.columns]
    if missing:
        raise KeyError(f"Missing required columns in adata.obs: {missing}")

    # Validate survival times are positive before building df.
    # NaN compares False here, so missing times are dropped as well.
    valid_mask = obs[time_col] > 0
    if not valid_mask.all():
        n_bad = int((~valid_mask).sum())
        obs = obs[valid_mask].copy()
        if obs.empty:
            raise ValueError(
                f"All {n_bad} observations have non-positive survival times in '{time_col}'."
            )
        logger.warning(
            "Dropping %d observation(s) with non-positive or missing survival time in '%s'.",
            n_bad,
            time_col,
        )

    df = obs[base_cols].copy()

    # Resolve each requested feature exactly once; duplicate names would
    # otherwise create duplicate DataFrame columns downstream.
    requested = list(dict.fromkeys(features))
    obs_feats = [f for f in requested if f in obs.columns]
    gene_feats = [f for f in requested if f not in obs.columns and f in adata.var_names]
    unknown = [f for f in requested if f not in obs.columns and f not in adata.var_names]
    if unknown:
        logger.warning(
            "Skipping %d feature(s) not found in adata.obs or adata.var_names: %s",
            len(unknown),
            unknown,
        )

    for feat in obs_feats:
        df[feat] = obs[feat].values
    if gene_feats:
        mat = extract_gene_matrix(adata[obs.index], gene_feats, layer=layer)
        df_genes = pd.DataFrame(mat, columns=gene_feats, index=df.index)
        df = pd.concat([df, df_genes], axis=1)

    grp = design.participant_col
    meta_cols = base_cols[1:]

    # Enforce constant time/event/covariates within participant
    for col in meta_cols:
        if df.groupby(grp, observed=True)[col].nunique().max() > 1:
            raise ValueError(f"Column '{col}' varies within participant; cannot aggregate.")

    feat_cols = obs_feats + gene_feats
    if not feat_cols:
        # Nothing to model: return an empty, correctly-shaped frame instead
        # of failing later when sorting a column-less DataFrame by "p".
        # The "FDR" column mirrors the fdr_col name passed to apply_fdr below.
        empty_cols = result_cols + (["FDR"] if fdr else [])
        return pd.DataFrame(columns=empty_cols)

    # Aggregate features to participant level
    df_feats = aggregate_features(df, [grp], feat_cols, agg)
    df_meta = df.groupby(grp, observed=True)[meta_cols].first().reset_index()
    df_part = pd.merge(df_meta, df_feats, on=grp, how="inner")

    # Ensure survival columns are retained after merge: pd.merge suffixes
    # overlapping non-key columns, so a feature sharing a name with a
    # survival/covariate column would otherwise leave that name absent.
    for col in meta_cols:
        if col not in df_part.columns:
            df_part[col] = df_meta.set_index(grp).loc[df_part[grp], col].values

    results = []
    for feat in feat_cols:
        # Complete-case analysis per feature.
        df_model = df_part[[time_col, event_col, feat] + covariate_cols].dropna()
        n_used = df_model.shape[0]

        if n_used < min_participants:
            results.append(_unfit_row(feat, n_used))
            continue

        if standardize:
            std = df_model[feat].std(ddof=1)
            # A (near-)constant feature cannot be z-scored or estimated.
            if not np.isfinite(std) or std < 1e-12:
                results.append(_unfit_row(feat, n_used))
                continue
            df_model[feat] = (df_model[feat] - df_model[feat].mean()) / std

        try:
            cph = CoxPHFitter()
            cph.fit(df_model, duration_col=time_col, event_col=event_col)
        except (_ConvergenceError, np.linalg.LinAlgError):
            # Expected: convergence failure or singular matrix from separation
            results.append(_unfit_row(feat, n_used))
            continue
        except Exception:
            # Unexpected error: log and propagate so data/config bugs
            # are not silently swallowed as NaN results.
            logger.exception("Unexpected error fitting Cox model for feature '%s'", feat)
            raise

        hr = float(np.exp(cph.params_[feat]))
        ci = cph.confidence_intervals_.loc[feat]
        hr_low = float(np.exp(ci.iloc[0]))
        hr_high = float(np.exp(ci.iloc[1]))
        p_val = float(cph.summary.loc[feat, "p"])

        # Detect unstable estimates from separation (infinite CI bounds)
        if not (np.isfinite(hr) and np.isfinite(hr_low) and np.isfinite(hr_high)):
            results.append(_unfit_row(feat, n_used))
            continue

        results.append(
            {
                "feature": feat,
                "HR": hr,
                "HR_low": hr_low,
                "HR_high": hr_high,
                "p": p_val,
                "n": n_used,
            }
        )

    out = pd.DataFrame(results, columns=result_cols).sort_values("p")
    if fdr:
        out = apply_fdr(out, p_col="p", fdr_col="FDR")
    return out.reset_index(drop=True)


def _unfit_row(feature: str, n: int) -> dict[str, object]:
    """Return a result row with NaN statistics for a feature that was not fitted."""
    return {
        "feature": feature,
        "HR": np.nan,
        "HR_low": np.nan,
        "HR_high": np.nan,
        "p": np.nan,
        "n": n,
    }