# Source code for STAVAG.STAVAG

# -*- coding: utf-8 -*-
"""
Created on April 2 18:59:29 2025

@author: Qunlun Shen
"""
from __future__ import annotations
from typing import Dict, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import scanpy as sc
import lightgbm as lgb
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.multioutput import MultiOutputRegressor
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster


def generate_coord_dict(n_dim: int) -> Dict[str, Optional[pd.DataFrame]]:
    """
    Build a placeholder dict for coordinate axes.

    Args:
        n_dim: Number of coordinate dimensions, between 0 and 4 inclusive.

    Returns:
        A dict mapping the first ``n_dim`` axis names of ('x', 'y', 'z', 'a')
        to None. For example ``{'x': None, 'y': None}`` for two dimensions.

    Raises:
        ValueError: If ``n_dim`` is greater than four or negative.
    """
    if n_dim > 4:
        raise ValueError(
            "Only up to n_dim=4 is supported. "
            "Please provide coords with at most four columns."
        )
    if n_dim < 0:
        # A negative value would slice from the end of the axis list and
        # silently return the wrong axes.
        raise ValueError(f"n_dim must be non-negative; got {n_dim}.")
    return dict.fromkeys(['x', 'y', 'z', 'a'][:n_dim])
def calculate_sps(
    coord_dict_raw: Dict[str, pd.DataFrame],
    coord_dict_rand: Dict[str, np.ndarray],
    n_dim: int,
    keys: Optional[Sequence[str]] = None,
) -> Dict[str, pd.DataFrame]:
    """
    Compute STAVAG priority scores (sps) for each axis.

    For an observed importance ``v`` and a random-baseline array ``r`` the score
    is ``(count(r >= v) + 1) / (len(r) + 1)``. This is the right-tail proportion
    of random importances at least as large as ``v``, with a +1 pseudo-count so
    that it is never zero. Lower scores mean the observed importance stands out
    more from the random baseline.

    Args:
        coord_dict_raw: Dict of DataFrames per axis. Each DataFrame must contain
            columns 'Feature' and 'Importance'.
        coord_dict_rand: Dict of random importance arrays per axis.
        n_dim: Number of coordinate dimensions (number of axes processed).
        keys: Optional explicit axis names. If empty or None, the first
            ``n_dim`` of ('x', 'y', 'z', 'a') are used.

    Returns:
        The same dict as ``coord_dict_raw``, with a new column 'sps' added to
        each processed axis DataFrame.

    Note:
        The DataFrames in ``coord_dict_raw`` are modified in place.
    """
    if not keys:
        keys = ['x', 'y', 'z', 'a'][:n_dim]
    for k in range(n_dim):
        axis = keys[k]
        rand = np.asarray(coord_dict_rand[axis], dtype=float).ravel()
        # NaN never satisfies ``>=``, so exclude it from the counting pool.
        # The denominator below still uses the full original length.
        pool = np.sort(rand[~np.isnan(rand)])
        observed = np.asarray(coord_dict_raw[axis]['Importance'], dtype=float)
        # Entries at or after the left insertion point are exactly those >= v.
        # This replaces an O(n_genes * n_random) pairwise comparison with a
        # single O((n_genes + n_random) log n_random) sorted search.
        n_ge = pool.size - np.searchsorted(pool, observed, side='left')
        coord_dict_raw[axis]['sps'] = (n_ge + 1) / (len(coord_dict_rand[axis]) + 1)
    return coord_dict_raw
def keep_variant_genes(
    coord_dict_raw: Dict[str, pd.DataFrame],
    coord_dict_rand: Dict[str, np.ndarray],
    n_dim: int,
    threshold: float = 0.05,
    keys: Optional[Sequence[str]] = None,
) -> Dict[str, pd.DataFrame]:
    """
    Filter genes whose observed importance exceeds a random baseline.

    For each axis, only rows whose 'Importance' is strictly greater than the
    ``100 * (1 - threshold)`` percentile of that axis's random importances are
    kept.

    Args:
        coord_dict_raw: Dict of DataFrames per axis with an 'Importance' column.
        coord_dict_rand: Dict of random importance arrays per axis.
        n_dim: Number of coordinate dimensions (number of axes processed).
        threshold: Significance level. For example, 0.05 compares against the
            95th percentile of the random distribution; larger values keep more
            genes.
        keys: Optional explicit axis names. If empty or None, the first
            ``n_dim`` of ('x', 'y', 'z', 'a') are used.

    Returns:
        The same dict object as ``coord_dict_raw``, with each processed axis
        entry replaced by its filtered DataFrame.
    """
    axis_names = keys if keys else ['x', 'y', 'z', 'a'][:n_dim]
    for idx in range(n_dim):
        axis = axis_names[idx]
        cutoff = np.percentile(coord_dict_rand[axis], 100 * (1 - threshold))
        table = coord_dict_raw[axis]
        coord_dict_raw[axis] = table.loc[table['Importance'] > cutoff]
    return coord_dict_raw
def DVG_detection(
    adata: sc.AnnData,
    coords: np.ndarray,
    sps: bool = False,
    threshold: float = 0.05,
    num_perm: int = 1,
) -> Dict[str, pd.DataFrame]:
    """
    Detect Directionally Variable Genes (DVGs) using regression on spatial coordinates.

    One LightGBM regressor per coordinate axis (via ``MultiOutputRegressor``)
    predicts that coordinate from gene expression. Each gene's gain importance
    is compared with importances obtained after permuting the rows (cells) of
    the expression matrix, which breaks the expression/coordinate pairing.

    Args:
        adata (AnnData): AnnData with expression matrix ``adata.X`` and gene
            names ``adata.var.index``.
        coords (ndarray): Spatial coordinates of cells with shape
            ``(n_cells, n_dim)`` where ``2 <= n_dim <= 4``. Columns map to the
            axis keys 'x', 'y', 'z', 'a' in order.
        sps (bool, optional): Only used when ``num_perm == 1``. If True, compute
            STAVAG priority scores (column 'sps') for every gene instead of
            filtering. Defaults to False.
        threshold (float, optional): Selection cutoff; larger values keep more
            genes. Defaults to 0.05.
            - ``num_perm == 1`` and ``sps=False``: genes whose importance
              exceeds the ``100 * (1 - threshold)`` percentile of the
              random-baseline importances are kept.
            - ``num_perm > 1``: genes with ``pval <= threshold`` are kept;
              ``None`` disables filtering.
        num_perm (int, optional): Number of permutations used to build an
            empirical null distribution of feature importances.
            - If 1: single random baseline (original behavior).
            - Otherwise: must be >= 100; empirical p-values are computed for
              each gene.

    Returns:
        Dict[str, DataFrame]: One DataFrame per axis key.
            - ``num_perm == 1``: columns ['Feature', 'Importance'] (plus 'sps'
              when ``sps=True``), sorted by decreasing importance.
            - ``num_perm > 1``: columns ['Feature', 'Importance', 'null_mean',
              'pval'], sorted by increasing p-value, filtered by ``threshold``,
              and re-indexed from 0.

    Raises:
        ValueError: If ``coords`` is not two-dimensional, has fewer than two or
            more than four columns, or if ``num_perm`` is neither 1 nor >= 100.

    Note:
        Calls ``np.random.seed(0)``, so NumPy's global random state is reset on
        every call. This makes the permutations reproducible.
    """
    np.random.seed(0)

    # Validate everything before any (expensive) model fitting.
    if np.ndim(coords) != 2:
        raise ValueError("coords must be a 2-D array of shape (n_cells, n_dim).")
    n_dim = coords.shape[1]
    if n_dim < 2:
        raise ValueError("n_dim must be at least 2.")
    if n_dim > 4:
        # Without this check the permutation path would silently ignore every
        # axis beyond the fourth.
        raise ValueError(
            "Only up to n_dim=4 is supported. "
            "Please provide coords with at most four columns."
        )
    if num_perm != 1 and num_perm < 100:
        raise ValueError(
            f"num_perm must be >= 100 when using multiple permutations; got {num_perm}."
        )

    # Axis labels based on dimensionality
    keys = ['x', 'y', 'z', 'a'][:n_dim]
    X = adata.X  # Gene expression matrix (cells x genes)
    Y = coords.copy()  # Spatial coordinates
    num_rows, n_genes = X.shape

    # LightGBM regression parameters
    params = dict(
        n_estimators=1000,
        learning_rate=0.05,
        num_leaves=31,
        objective="regression",
        metric="mse",
        boosting_type="gbdt",
        colsample_bytree=0.2,
        subsample=0.9,
        subsample_freq=5,
        importance_type='gain',
        verbosity=-1,
    )

    # Fit on the real expression/coordinate pairing to get observed importances.
    model = MultiOutputRegressor(lgb.LGBMRegressor(**params))
    model.fit(X, Y)

    # ===== Case 1: num_perm == 1 (single random baseline) =====
    if num_perm == 1:
        coord_dict_raw = {
            axis: pd.DataFrame({
                'Feature': adata.var.index,
                'Importance': estimator.feature_importances_,
            }).sort_values(by='Importance', ascending=False)
            for axis, estimator in zip(keys, model.estimators_)
        }

        # Random baseline: shuffle cells so expression no longer matches coords.
        shuffled_indices = np.random.permutation(num_rows)
        rand_model = MultiOutputRegressor(lgb.LGBMRegressor(**params))
        rand_model.fit(X[shuffled_indices, :], Y)
        coord_dict_rand = {
            axis: estimator.feature_importances_
            for axis, estimator in zip(keys, rand_model.estimators_)
        }

        if sps:
            return calculate_sps(coord_dict_raw, coord_dict_rand, n_dim)
        return keep_variant_genes(
            coord_dict_raw, coord_dict_rand, n_dim, threshold=threshold
        )

    # ===== Case 2: num_perm >= 100 (empirical null + permutation p-values) =====
    obs_importances: Dict[str, np.ndarray] = {
        axis: estimator.feature_importances_.astype(float)
        for axis, estimator in zip(keys, model.estimators_)
    }
    # Accumulate exceedance counts and running sums per permutation instead of
    # storing a (num_perm, n_genes) matrix per axis, which can reach hundreds of MB.
    exceed_counts = {axis: np.zeros(n_genes, dtype=np.int64) for axis in keys}
    null_sums = {axis: np.zeros(n_genes, dtype=float) for axis in keys}

    for _ in tqdm_notebook(range(num_perm)):
        shuffled_indices = np.random.permutation(num_rows)
        perm_model = MultiOutputRegressor(lgb.LGBMRegressor(**params))
        perm_model.fit(X[shuffled_indices, :], Y)
        for axis, estimator in zip(keys, perm_model.estimators_):
            null_imp = estimator.feature_importances_.astype(float)
            exceed_counts[axis] += null_imp >= obs_importances[axis]
            null_sums[axis] += null_imp

    coord_results: Dict[str, pd.DataFrame] = {}
    for axis in keys:
        # Empirical p-value: p_j = (1 + count(null_b >= obs_j)) / (num_perm + 1)
        pvals = (exceed_counts[axis] + 1.0) / (num_perm + 1.0)
        df = pd.DataFrame({
            "Feature": adata.var.index,
            "Importance": obs_importances[axis],
            "null_mean": null_sums[axis] / num_perm,
            "pval": pvals,
        })
        df = df.sort_values(by="pval", ascending=True)
        if threshold is not None:
            df = df[df["pval"] <= threshold]
        coord_results[axis] = df.reset_index(drop=True)
    return coord_results
def TVG_detection(
    adata: sc.AnnData,
    coords: np.ndarray,
    sps: bool = False,
    threshold: float = 0.05,
    num_perm: int = 1,
) -> Dict[str, pd.DataFrame]:
    """
    Detect Temporally Variable Genes (TVGs) using regression on a 1D time coordinate.

    A LightGBM regressor predicts the time coordinate from gene expression.
    Each gene's gain importance is compared with importances obtained after
    permuting the rows (cells) of the expression matrix.

    Args:
        adata (AnnData): AnnData containing expression matrix ``adata.X`` and
            gene names ``adata.var.index``.
        coords (ndarray): Temporal coordinate of cells, with shape
            ``(n_cells, 1)`` or ``(n_cells,)``.
        sps (bool, optional): Only used when ``num_perm == 1``. If True, compute
            STAVAG priority scores (column 'sps') for every gene instead of
            filtering. Defaults to False.
        threshold (float, optional): Selection cutoff; larger values keep more
            genes. Defaults to 0.05.
            - ``num_perm == 1`` and ``sps=False``: significance level passed to
              ``keep_variant_genes``. Genes whose importance exceeds the
              ``100 * (1 - threshold)`` percentile of the random-baseline
              importances are kept.
            - ``num_perm > 1``: p-value cutoff (``pval <= threshold``); ``None``
              disables filtering.
        num_perm (int, optional): Number of permutations used to build an
            empirical null distribution of feature importances.
            - If 1: single random baseline (original behavior).
            - Otherwise: must be >= 100; empirical permutation p-values are
              computed for each gene.

    Returns:
        Dict[str, DataFrame]: A dict with the single key ``'T'``.
            - ``num_perm == 1``: columns ['Feature', 'Importance'] (plus 'sps'
              when ``sps=True``), sorted by decreasing importance.
            - ``num_perm > 1``: columns ['Feature', 'Importance', 'null_mean',
              'pval'], sorted by increasing p-value, filtered by ``threshold``,
              and re-indexed from 0.

    Raises:
        ValueError: If ``coords`` does not describe exactly one time column,
            or if ``num_perm`` is neither 1 nor >= 100.

    Note:
        Calls ``np.random.seed(0)``, so NumPy's global random state is reset on
        every call. This makes the permutations reproducible.
    """
    np.random.seed(0)

    # Accept both (n_cells,) and (n_cells, 1); anything else is not a single time axis.
    coords_arr = np.array(coords)  # copy, like the original ``coords.copy()``
    if coords_arr.ndim not in (1, 2) or (coords_arr.ndim == 2 and coords_arr.shape[1] != 1):
        raise ValueError("n_dim must be 1 for TVG_detection.")
    if num_perm != 1 and num_perm < 100:
        raise ValueError(
            f"num_perm must be >= 100 when using multiple permutations; got {num_perm}."
        )

    n_dim = 1
    keys = ['T']
    X = adata.X
    # LightGBM expects a 1-D target; a column vector triggers a conversion warning.
    Y = coords_arr.reshape(-1)
    num_rows, n_genes = X.shape

    # LightGBM regression parameters
    params = dict(
        n_estimators=1000,
        learning_rate=0.05,
        num_leaves=31,
        objective="regression",
        metric="mse",
        boosting_type="gbdt",
        colsample_bytree=0.2,
        subsample=0.9,
        subsample_freq=5,
        importance_type='gain',
        verbosity=-1,
    )

    # Observed importances from the model fitted on the real pairing.
    base_model = lgb.LGBMRegressor(**params)
    base_model.fit(X, Y)
    obs_importances = base_model.booster_.feature_importance(
        importance_type='gain'
    ).astype(float)

    # ===== Case 1: num_perm == 1 (single random baseline) =====
    if num_perm == 1:
        coord_dict_raw: Dict[str, pd.DataFrame] = {
            'T': pd.DataFrame({
                'Feature': adata.var.index,
                'Importance': obs_importances,
            }).sort_values(by='Importance', ascending=False)
        }

        # Random baseline: shuffle cells so expression no longer matches time.
        shuffled_indices = np.random.permutation(num_rows)
        rand_model = lgb.LGBMRegressor(**params)
        rand_model.fit(X[shuffled_indices, :], Y)
        coord_dict_rand: Dict[str, np.ndarray] = {
            'T': rand_model.booster_.feature_importance(importance_type='gain')
        }

        if sps:
            return calculate_sps(coord_dict_raw, coord_dict_rand, n_dim, keys=keys)
        return keep_variant_genes(
            coord_dict_raw, coord_dict_rand, n_dim, threshold=threshold, keys=keys
        )

    # ===== Case 2: num_perm >= 100 (empirical null + permutation p-values) =====
    # Running counts/sums avoid materialising a (num_perm, n_genes) matrix.
    exceed_counts = np.zeros(n_genes, dtype=np.int64)
    null_sum = np.zeros(n_genes, dtype=float)
    for _ in tqdm_notebook(range(num_perm)):
        shuffled_indices = np.random.permutation(num_rows)
        perm_model = lgb.LGBMRegressor(**params)
        perm_model.fit(X[shuffled_indices, :], Y)
        null_imp = perm_model.booster_.feature_importance(
            importance_type='gain'
        ).astype(float)
        exceed_counts += null_imp >= obs_importances
        null_sum += null_imp

    # Empirical p-value: p_j = (1 + count(null_b >= obs_j)) / (num_perm + 1)
    pvals = (exceed_counts + 1.0) / (num_perm + 1.0)
    df = pd.DataFrame({
        "Feature": adata.var.index,
        "Importance": obs_importances,
        "null_mean": null_sum / num_perm,
        "pval": pvals,
    })
    df = df.sort_values(by="pval", ascending=True)
    if threshold is not None:
        df = df[df["pval"] <= threshold]
    return {"T": df.reset_index(drop=True)}
def gene_modules(adata: sc.AnnData, gene_list: Sequence[str]) -> Tuple[np.ndarray, pd.DataFrame, pd.DataFrame]:
    """
    Cluster selected genes into modules by complete-linkage hierarchical clustering.

    The gene-by-gene correlation matrix (``DataFrame.corr`` default, i.e.
    Pearson) is passed to ``scipy.cluster.hierarchy.linkage`` as an observation
    matrix. Each gene's row of correlations is therefore treated as its feature
    vector, and genes are compared by correlation distance between those rows.

    Args:
        adata (AnnData): AnnData that contains the expression matrix ``adata.X``
            and gene names in ``adata.var.index``.
        gene_list (Sequence[str]): Genes to include; each should exist in
            ``adata.var.index``.

    Returns:
        Tuple[np.ndarray, pd.DataFrame, pd.DataFrame]:
            Z: Linkage matrix from hierarchical clustering.
            corr: Gene-to-gene correlation matrix; index and columns are the
                genes in ``gene_list``.
            df: Expression of the selected genes; rows are cells, columns are genes.

    Note:
        NOTE(review): genes with constant expression yield NaN correlations,
        which ``linkage`` will likely reject as non-finite. Confirm inputs are
        filtered upstream.
    """
    expr = adata[:, gene_list].to_df()
    gene_corr = expr.corr()
    linkage_matrix = linkage(gene_corr, method='complete', metric='correlation')
    return linkage_matrix, gene_corr, expr