# Source code for sctrial.stats.mixed_effects

"""Mixed effects models for trial-aware single-cell inference.

This module provides linear mixed effects model alternatives to the fixed effects
approach used in the standard DiD functions.

Mathematical Background
-----------------------
**Fixed Effects vs Mixed Effects**:

Fixed Effects (current did_table approach):
    Y_it = α_i + β₁×Post + β₂×Treat×Post + ε_it

    - α_i: participant-specific intercept (absorbed as dummy variables)
    - Pros: No distributional assumptions on random effects
    - Cons: Cannot estimate participant-level variance; uses one degree of freedom per participant

Mixed Effects:
    Y_it = (α + u_i) + β₁×Post + β₂×Treat×Post + ε_it
    u_i ~ N(0, σ²_u)  (random intercept)
    ε_it ~ N(0, σ²_ε)

    - Pros: Estimates variance components; more efficient when assumptions hold
    - Cons: Requires distributional assumptions; can be biased if misspecified

**When to Use Each**:

Use Fixed Effects when:
- Participant effects may correlate with treatment assignment
- Small number of time points
- Interest is purely in within-participant changes

Use Mixed Effects when:
- Random sample of participants from a population
- Interest in generalizing to new participants
- Want to estimate variance components
- Large number of clusters with few observations each

**Recommendation**: For randomized trials, both approaches typically yield similar
treatment effect estimates. Fixed effects are more robust to misspecification,
so report both as a sensitivity analysis.

Model Specification
-------------------
The full DiD mixed model:

    Y_ijt = β₀ + β₁×Treat_j + β₂×Post_t + β₃×(Treat×Post)_jt + u_j + ε_ijt

    Random effects: u_j ~ N(0, σ²_u)  (participant)
    Residuals: ε_ijt ~ N(0, σ²_ε)

The DiD effect is β₃, same as in fixed effects.

Optional random slopes:
    Y_ijt = β₀ + β₁×Treat_j + β₂×Post_t + β₃×(Treat×Post)_jt + u_j + v_j×Post_t + ε_ijt

    This allows the time effect to vary by participant.
"""

from __future__ import annotations

import warnings
from collections.abc import Sequence

import numpy as np
import pandas as pd
from anndata import AnnData
from statsmodels.tools.sm_exceptions import ConvergenceWarning

from ..design import TrialDesign
from ._utils import aggregate_features, apply_fdr, standardize_series
from .did import AggregateFunc, AggregateMode, _ensure_paired

# Public API of this module; other module-level names are internal.
__all__ = [
    "did_mixed",
    "did_table_mixed",
    "compare_fixed_vs_mixed",
]


[docs] def did_mixed( df: pd.DataFrame, outcome: str, participant_col: str, time_col: str, arm_col: str, arm_treated: str, random_slope: bool = False, covariates: Sequence[str] | None = None, visits: tuple[str, str] | None = None, ) -> dict: """Fit a single DiD mixed effects model. Model specification: Y ~ Treat + Post + Treat:Post + (1 | Participant) Or with random slopes: Y ~ Treat + Post + Treat:Post + (1 + Post | Participant) Parameters ---------- df DataFrame with outcome and design columns. outcome Name of the outcome variable column. participant_col Column identifying participants (random effect grouping). time_col Column with time/visit indicators (binary: 0=pre, 1=post). arm_col Column with treatment arm labels. arm_treated Label for the treated arm. random_slope If True, include random slope for time. covariates Additional covariates to include as fixed effects. Returns ------- dict Results dictionary with keys: - beta_DiD: DiD coefficient - se_DiD: Standard error - p_DiD: P-value - ci_lower, ci_upper: 95% CI - var_participant: Random intercept variance - var_residual: Residual variance - icc: Intraclass correlation - converged: Model convergence status """ try: import statsmodels.formula.api as smf except ImportError: raise ImportError("statsmodels is required for mixed effects models") df = df.copy() # Encode binary variables df["arm_bin"] = (df[arm_col] == arm_treated).astype(int) # Ensure time is numeric time_vals = df[time_col].unique() if len(time_vals) != 2: return { "beta_DiD": np.nan, "se_DiD": np.nan, "p_DiD": np.nan, "ci_lower": np.nan, "ci_upper": np.nan, "var_participant": np.nan, "var_residual": np.nan, "icc": np.nan, "converged": False, "error": "Requires exactly 2 time points", } # Map time to 0/1 (pre=0, post=1) if visits is not None: time_sorted = list(visits) else: time_sorted = sorted(time_vals) df["post_num"] = df[time_col].map({time_sorted[0]: 0, time_sorted[1]: 1}).astype(float) # Build formula fixed_part = f"{outcome} ~ arm_bin + 
post_num + arm_bin:post_num" if covariates: for cov in covariates: if cov in df.columns: if pd.api.types.is_numeric_dtype(df[cov]): fixed_part += f" + {cov}" else: fixed_part += f" + C({cov})" # Random effects specification if random_slope: re_formula = "1 + post_num" else: re_formula = "1" # Fit model try: model = smf.mixedlm( fixed_part, df, groups=df[participant_col], re_formula=re_formula, ) with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always", ConvergenceWarning) fit = model.fit(method="powell", maxiter=500, full_output=False) # Determine convergence: trust fit.converged as primary indicator, # fall back to caught warnings only if the attribute is unavailable. if hasattr(fit, "converged") and fit.converged is not None: converged = fit.converged else: converged = not bool(caught) # Fallback optimizer: retry with lbfgs if Powell did not converge if not converged: with warnings.catch_warnings(record=True) as caught2: warnings.simplefilter("always", ConvergenceWarning) fit = model.fit(method="lbfgs", maxiter=500, full_output=False) if hasattr(fit, "converged") and fit.converged is not None: converged = fit.converged else: converged = not bool(caught2) except (ValueError, np.linalg.LinAlgError) as e: return { "beta_DiD": np.nan, "se_DiD": np.nan, "p_DiD": np.nan, "ci_lower": np.nan, "ci_upper": np.nan, "var_participant": np.nan, "var_residual": np.nan, "icc": np.nan, "converged": False, "error": str(e), } # Extract DiD coefficient did_term = "arm_bin:post_num" if did_term not in fit.params.index: did_term = "post_num:arm_bin" # Alternative ordering if did_term in fit.params.index: beta = float(fit.params[did_term]) se = float(fit.bse[did_term]) pval = float(fit.pvalues[did_term]) ci = fit.conf_int().loc[did_term] ci_lower, ci_upper = float(ci[0]), float(ci[1]) else: beta = se = pval = ci_lower = ci_upper = np.nan # Variance components var_resid = float(fit.scale) var_intercept = np.nan var_slope = np.nan if hasattr(fit, "cov_re") and 
fit.cov_re is not None: cov_re = fit.cov_re var_intercept = float(cov_re.iloc[0, 0]) # Extract slope variance if random slopes were fitted (2x2 matrix) if cov_re.shape[0] > 1 and cov_re.shape[1] > 1: var_slope = float(cov_re.iloc[1, 1]) # ICC (based on intercept variance only - standard definition) if not np.isnan(var_intercept) and (var_intercept + var_resid) > 0: icc = var_intercept / (var_intercept + var_resid) else: icc = np.nan result = { "beta_DiD": beta, "se_DiD": se, "p_DiD": pval, "ci_lower": ci_lower, "ci_upper": ci_upper, "var_participant": var_intercept, "var_residual": var_resid, "icc": icc, "converged": converged, } # Add slope variance if available (when random_slope=True) if not np.isnan(var_slope): result["var_slope"] = var_slope return result
def did_table_mixed(
    adata: AnnData,
    features: Sequence[str],
    design: TrialDesign,
    visits: tuple[str, str],
    exclude_crossovers: bool = True,
    aggregate: AggregateMode = "participant_visit",
    layer: str | None = None,
    agg: AggregateFunc = "mean",
    standardize: bool = True,
    random_slope: bool = False,
    covariates: Sequence[str] | None = None,
) -> pd.DataFrame:
    """Run DiD analysis using mixed effects models for multiple features.

    This is the mixed effects analog of did_table. The key difference is that
    participant effects are modeled as random draws from a normal distribution
    rather than fixed parameters.

    Model: Y_it ~ Treat + Post + Treat×Post + (1 | Participant)
    H₀: β_{Treat×Post} = 0 (no differential treatment effect)

    Parameters
    ----------
    adata
        AnnData object containing expression data.
    features
        List of features (genes or module scores) to test.
    design
        TrialDesign object specifying column mappings.
    visits
        Tuple of (baseline, followup) visit labels.
    exclude_crossovers
        Whether to exclude crossover participants.
    aggregate
        Aggregation mode: "participant_visit" or "cell".
    layer
        Expression layer to use (None for adata.X).
    agg
        Aggregation function ("mean", "median").
    standardize
        Whether to z-score outcomes before fitting.
    random_slope
        Include random slope for time (allows time effect to vary by participant).
    covariates
        Additional ``obs`` columns (or names from ``features``) to include as
        fixed effects. With ``aggregate="participant_visit"``, numeric
        (non-boolean) covariates are summarised with ``agg`` alongside the
        features; other covariates are added to the grouping keys, so they
        should be constant within each participant-visit.

    Returns
    -------
    pd.DataFrame
        Results with columns:
        - feature: Feature name
        - beta_DiD, se_DiD, p_DiD, FDR_DiD: DiD statistics
        - ci_lower, ci_upper: 95% CI
        - var_participant, var_residual, icc: Variance components
        - n_units: Number of participants
        - converged: Model convergence

    Raises
    ------
    ValueError
        If ``design.arm_col`` is None.
    KeyError
        If a feature is in neither ``obs`` nor ``var_names``, or a covariate
        is in neither ``obs`` nor ``features``.

    Examples
    --------
    >>> res_mixed = did_table_mixed(
    ...     adata, features=genes, design=design, visits=("Pre", "Post")
    ... )
    >>> print(res_mixed[["feature", "beta_DiD", "icc", "converged"]])

    See Also
    --------
    did_table : Fixed effects version (recommended for most applications).
    compare_fixed_vs_mixed : Compare both approaches.
    """
    from ..adata_tools import subset_primary
    from ._extract import extract_gene_matrix

    # Fail fast: the DiD contrast needs an arm column.
    if design.arm_col is None:
        raise ValueError("mixed-effects DiD requires a two-arm design (arm_col must not be None)")

    # Subset to analysis population
    ad = subset_primary(adata, design, visits, exclude_crossovers=exclude_crossovers)
    obs = ad.obs

    # Locate every requested feature in obs or var_names.
    obs_feats = [f for f in features if f in obs.columns]
    gene_feats = [f for f in features if f in ad.var_names and f not in obs.columns]
    missing = [f for f in features if f not in obs.columns and f not in ad.var_names]
    if missing:
        raise KeyError(f"Features not found in obs or var_names: {missing[:5]}")

    design_cols = [design.participant_col, design.visit_col, design.arm_col]
    feature_set = set(features)
    cov_list = list(dict.fromkeys(covariates or ()))
    missing_cov = [c for c in cov_list if c not in obs.columns and c not in feature_set]
    if missing_cov:
        raise KeyError(f"Covariates not found in obs or features: {missing_cov[:5]}")
    # Covariates to copy from obs: features are added below and design
    # columns are already included. Without this, did_mixed could not see the
    # covariates and would silently fit an unadjusted model.
    obs_covs = [c for c in cov_list if c not in design_cols and c not in feature_set]

    cols = list(design_cols)
    if design.celltype_col is not None:
        cols.append(design.celltype_col)
    cols += [c for c in obs_covs if c not in cols]
    df = obs[cols].copy()

    # Add feature values
    for feat in obs_feats:
        df[feat] = obs[feat].values
    if gene_feats:
        mat = extract_gene_matrix(ad, gene_feats, layer=layer)
        df_genes = pd.DataFrame(mat, columns=gene_feats, index=df.index)
        df = pd.concat([df, df_genes], axis=1)

    # Aggregate if requested. Numeric covariates are summarised like features;
    # the rest become grouping keys so they survive aggregation.
    if aggregate == "participant_visit":
        num_covs = [
            c
            for c in obs_covs
            if pd.api.types.is_numeric_dtype(df[c]) and not pd.api.types.is_bool_dtype(df[c])
        ]
        key_covs = [c for c in obs_covs if c not in num_covs]
        df = aggregate_features(
            df,
            grp_cols=design_cols + key_covs,
            features=list(features) + num_covs,
            agg=agg,
        )

    # Ensure paired data
    df = _ensure_paired(df, unit=design.participant_col, time=design.visit_col, visits=visits)
    n_units = df[design.participant_col].nunique()

    # Each fit only needs the design columns, covariates and its own feature;
    # copying just these avoids duplicating every feature column per feature.
    base_cols = design_cols + [c for c in cov_list if c in df.columns and c not in design_cols]
    stat_keys = (
        "beta_DiD",
        "se_DiD",
        "p_DiD",
        "ci_lower",
        "ci_upper",
        "var_participant",
        "var_residual",
        "icc",
    )

    rows = []
    for feat in features:
        df_feat = df[list(dict.fromkeys([*base_cols, feat]))].copy()
        if standardize:
            y_std, ok = standardize_series(df_feat, feat, min_std=1e-12)
            if not ok:
                # Constant (or near-constant) outcome: nothing to fit.
                rows.append(
                    {
                        "feature": feat,
                        **dict.fromkeys(stat_keys, np.nan),
                        "n_units": n_units,
                        "converged": False,
                    }
                )
                continue
            df_feat["outcome_std"] = y_std
        else:
            df_feat["outcome_std"] = df_feat[feat].astype(float)

        result = did_mixed(
            df_feat,
            outcome="outcome_std",
            participant_col=design.participant_col,
            time_col=design.visit_col,
            arm_col=design.arm_col,
            arm_treated=design.arm_treated,
            random_slope=random_slope,
            covariates=covariates,
            visits=visits,
        )
        rows.append(
            {
                "feature": feat,
                **{key: result[key] for key in stat_keys},
                "n_units": n_units,
                "converged": result["converged"],
            }
        )

    # Explicit columns keep the schema stable even when no rows are produced.
    res = pd.DataFrame(rows, columns=["feature", *stat_keys, "n_units", "converged"])
    res = apply_fdr(res, p_col="p_DiD", fdr_col="FDR_DiD")
    return res
def compare_fixed_vs_mixed(
    adata: AnnData,
    features: Sequence[str],
    design: TrialDesign,
    visits: tuple[str, str],
    **kwargs,
) -> pd.DataFrame:
    """Run fixed- and mixed-effects DiD side by side for sensitivity analysis.

    ``did_table`` (fixed effects) is run first, then ``did_table_mixed``. Their
    per-feature estimates are outer-joined on ``feature`` so that features
    missing from either result are still reported.

    Parameters
    ----------
    adata
        AnnData object.
    features
        Features to test.
    design
        TrialDesign object.
    visits
        (baseline, followup) visit labels.
    **kwargs
        Forwarded to both analyses, except ``random_slope``, which only the
        mixed-effects analysis receives.

    Returns
    -------
    pd.DataFrame
        One row per feature with columns:
        - feature
        - beta_fixed, se_fixed, p_fixed: Fixed effects results
        - beta_mixed, se_mixed, p_mixed, icc: Mixed effects results
        - beta_diff: beta_fixed minus beta_mixed
        - agreement: True when both methods make the same significance call
          (raw p < 0.05) and their estimates have the same sign

    Examples
    --------
    >>> comparison = compare_fixed_vs_mixed(
    ...     adata, features=genes, design=design, visits=("Pre", "Post")
    ... )
    >>> # Check agreement
    >>> print(f"Methods agree: {comparison['agreement'].mean():.0%}")
    """
    from .did import did_table

    # did_table does not accept mixed-model-only options (TypeError otherwise).
    mixed_only = frozenset({"random_slope"})
    fixed_kwargs = {key: val for key, val in kwargs.items() if key not in mixed_only}

    fixed_res = did_table(adata, features, design, visits, **fixed_kwargs)
    mixed_res = did_table_mixed(adata, features, design, visits, **kwargs)

    def _tagged(frame: pd.DataFrame, tag: str, extra: tuple[str, ...] = ()) -> pd.DataFrame:
        # Keep the shared DiD statistics and suffix them with the method tag.
        picked = frame[["feature", "beta_DiD", "se_DiD", "p_DiD", *extra]]
        return picked.rename(
            columns={
                "beta_DiD": f"beta_{tag}",
                "se_DiD": f"se_{tag}",
                "p_DiD": f"p_{tag}",
            }
        )

    merged = _tagged(fixed_res, "fixed").merge(
        _tagged(mixed_res, "mixed", ("icc",)), on="feature", how="outer"
    )

    beta_f = merged["beta_fixed"]
    beta_m = merged["beta_mixed"]
    merged["beta_diff"] = beta_f - beta_m

    sig_level = 0.05
    same_call = (merged["p_fixed"] < sig_level) == (merged["p_mixed"] < sig_level)
    same_sign = np.sign(beta_f) == np.sign(beta_m)
    merged["agreement"] = same_call & same_sign
    return merged