GRN Inference on Pre-trained Model
Here we use the pre-trained blood model as an example for GRN inference, particularly regarding gene program extraction and network visualization. We also present the cell-type specific activations within these gene programs on the Immune Human dataset, as a soft validation for the zero-shot performance.
Note that GRN inference can be performed on both pre-trained and fine-tuned models, as showcased in our manuscript.
Users may perform scGPT’s gene-embedding-based GRN inference in the following steps:
1. Load the trained scGPT model (pre-trained or fine-tuned) and the data
2. Retrieve scGPT's gene embeddings
3. Extract gene programs from scGPT's gene embedding network
4. Visualize gene program activations on the dataset of interest
5. Visualize the interconnectivity of genes within selected gene programs
6. Run Reactome pathway analysis on a selected gene program
[3]:
import copy
import json
import os
from pathlib import Path
import sys
import warnings
import torch
from anndata import AnnData
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import pandas as pd
import tqdm
import gseapy as gp
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
sys.path.insert(0, "../")
import scgpt as scg
from scgpt.tasks import GeneEmbedding
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.model import TransformerModel
from scgpt.preprocess import Preprocessor
from scgpt.utils import set_seed
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')
Global seed set to 0
[4]:
set_seed(42)
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
n_hvg = 1200
n_bins = 51
mask_value = -1
pad_value = -2
n_input_bins = n_bins
Step 1: Load pre-trained model and dataset
1.1 Load pre-trained model
The blood pre-trained model can be downloaded via this link.
[5]:
# Specify model path; here we load the pre-trained scGPT blood model
model_dir = Path("../save/scGPT_bc")
model_config_file = model_dir / "args.json"
model_file = model_dir / "best_model.pt"
vocab_file = model_dir / "vocab.json"
vocab = GeneVocab.from_file(vocab_file)
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)
# Retrieve model parameters from config files
with open(model_config_file, "r") as f:
    model_configs = json.load(f)
print(
    f"Resume model from {model_file}, the model args will override the "
    f"config {model_config_file}."
)
embsize = model_configs["embsize"]
nhead = model_configs["nheads"]
d_hid = model_configs["d_hid"]
nlayers = model_configs["nlayers"]
n_layers_cls = model_configs["n_layers_cls"]
gene2idx = vocab.get_stoi()
Resume model from ../save/scGPT_bc/best_model.pt, the model args will override the config ../save/scGPT_bc/args.json.
[6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ntokens = len(vocab) # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    vocab=vocab,
    pad_value=pad_value,
    n_input_bins=n_input_bins,
)
try:
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine
    model.load_state_dict(torch.load(model_file, map_location=device))
    print(f"Loading all model params from {model_file}")
except Exception:
    # Fall back to loading only the params that are in the model and match the size
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_file, map_location=device)
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    for k, v in pretrained_dict.items():
        print(f"Loading params {k} with shape {v.shape}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
model.to(device)
Using simple batchnorm instead of domain specific batchnorm
Loading params encoder.embedding.weight with shape torch.Size([36574, 512])
Loading params encoder.enc_norm.weight with shape torch.Size([512])
Loading params encoder.enc_norm.bias with shape torch.Size([512])
Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
Loading params value_encoder.linear1.bias with shape torch.Size([512])
Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
Loading params value_encoder.linear2.bias with shape torch.Size([512])
Loading params value_encoder.norm.weight with shape torch.Size([512])
Loading params value_encoder.norm.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.1.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.1.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.1.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.1.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.1.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.1.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.1.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.1.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.1.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.1.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.2.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.2.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.2.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.2.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.2.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.2.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.2.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.2.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.2.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.2.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.3.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.3.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.3.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.3.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.3.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.3.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.3.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.3.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.3.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.3.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.4.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.4.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.4.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.4.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.4.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.4.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.4.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.4.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.4.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.4.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.5.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.5.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.5.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.5.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.5.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.5.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.5.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.5.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.5.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.5.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.6.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.6.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.6.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.6.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.6.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.6.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.6.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.6.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.6.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.6.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.7.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.7.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.7.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.7.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.7.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.7.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.7.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.7.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.7.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.7.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.8.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.8.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.8.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.8.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.8.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.8.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.8.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.8.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.8.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.8.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.9.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.9.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.9.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.9.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.9.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.9.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.9.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.9.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.9.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.9.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.10.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.10.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.10.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.10.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.10.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.10.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.10.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.10.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.10.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.10.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.11.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.11.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.11.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.11.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.11.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.11.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.11.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.11.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.11.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.11.norm2.bias with shape torch.Size([512])
Loading params decoder.fc.0.weight with shape torch.Size([512, 512])
Loading params decoder.fc.0.bias with shape torch.Size([512])
Loading params decoder.fc.2.weight with shape torch.Size([512, 512])
Loading params decoder.fc.2.bias with shape torch.Size([512])
Loading params decoder.fc.4.weight with shape torch.Size([1, 512])
Loading params decoder.fc.4.bias with shape torch.Size([1])
[6]:
TransformerModel(
(encoder): GeneEncoder(
(embedding): Embedding(36574, 512, padding_idx=36571)
(enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
(value_encoder): ContinuousValueEncoder(
(dropout): Dropout(p=0.5, inplace=False)
(linear1): Linear(in_features=1, out_features=512, bias=True)
(activation): ReLU()
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
(bn): BatchNorm1d(512, eps=6.1e-05, momentum=0.1, affine=True, track_running_stats=True)
(transformer_encoder): TransformerEncoder(
(layers): ModuleList(
(0): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
(1): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
(2): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
(3): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
(4): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
(5): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
(6): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
(7): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
(8): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
(9): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
(10): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
(11): TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.5, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
)
)
)
(decoder): ExprDecoder(
(fc): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=512, out_features=1, bias=True)
)
)
(cls_decoder): ClsDecoder(
(_decoder): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): ReLU()
(2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(3): Linear(in_features=512, out_features=512, bias=True)
(4): ReLU()
(5): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
(out_layer): Linear(in_features=512, out_features=1, bias=True)
)
(sim): Similarity(
(cos): CosineSimilarity()
)
(creterion_cce): CrossEntropyLoss()
)
1.2 Load dataset of interest
The Immune Human dataset can be downloaded via this link.
[7]:
# Specify data path; here we load the Immune Human dataset
data_dir = Path("../data")
adata = sc.read(
    str(data_dir / "Immune_ALL_human.h5ad"), cache=True
)  # 33506 × 12303
ori_batch_col = "batch"
adata.obs["celltype"] = adata.obs["final_annotation"].astype(str)
data_is_raw = False
[8]:
# Preprocess the data following the scGPT data pre-processing pipeline
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=3,  # step 1
    filter_cell_by_counts=False,  # step 2
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=n_hvg,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=n_bins,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key="batch")
scGPT - INFO - Filtering genes by counts ...
scGPT - INFO - Filtering cells by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Subsetting highly variable genes ...
scGPT - INFO - Binning data ...
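To give a feel for what the binning step does, here is a minimal sketch that assumes per-cell quantile binning of the non-zero expression values, with zeros kept in their own bin. The `bin_expression` helper and the toy values are hypothetical, and the actual `Preprocessor` may differ in its details.

```python
import numpy as np

def bin_expression(row, n_bins):
    """Map one cell's non-zero values to integer bins 1..n_bins-1; zeros stay 0."""
    row = np.asarray(row, dtype=float)
    binned = np.zeros(row.shape, dtype=int)
    nonzero = row > 0
    if nonzero.any():
        # Bin edges are quantiles of this cell's non-zero values only
        edges = np.quantile(row[nonzero], np.linspace(0, 1, n_bins - 1))
        binned[nonzero] = np.digitize(row[nonzero], edges)
    return binned

# Toy expression vector for a single cell (made-up values)
toy_row = [0, 1, 2, 3, 4, 0]
toy_binned = bin_expression(toy_row, n_bins=5)
```

Because the edges are computed per cell, the bin index reflects a gene's relative expression rank within that cell rather than its absolute count.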
Step 2: Retrieve scGPT’s gene embeddings
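Note that scGPT’s gene embeddings are, technically, data independent: they come from the pre-trained model rather than from the dataset at hand. Overall, the pre-trained foundation model covers 30K+ genes. For simplicity, we focus here on the subset of highly variable genes (HVGs) selected from the dataset in step 1.2.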
[9]:
# Retrieve the data-independent gene embeddings from scGPT
gene_ids = np.array(list(gene2idx.values()))
gene_embeddings = model.encoder(torch.tensor(gene_ids, dtype=torch.long).to(device))
gene_embeddings = gene_embeddings.detach().cpu().numpy()
[10]:
# Filter on the intersection between the Immune Human HVGs found in step 1.2 and scGPT's 30+K foundation model vocab
hvg_genes = set(adata.var.index)  # set membership avoids rebuilding a list for every vocab entry
gene_embeddings = {gene: gene_embeddings[i] for i, gene in enumerate(gene2idx.keys()) if gene in hvg_genes}
print('Retrieved gene embeddings for {} genes.'.format(len(gene_embeddings)))
Retrieved gene embeddings for 1173 genes.
[11]:
# Construct gene embedding network
embed = GeneEmbedding(gene_embeddings)
100%|██████████| 1173/1173 [00:00<00:00, 1365127.25it/s]
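To make the idea of a gene embedding network concrete, the sketch below builds a pairwise cosine-similarity matrix from a small dictionary of embedding vectors. It is an illustration only, not the `GeneEmbedding` implementation; the `cosine_similarity_matrix` helper, the gene names, and the vectors are all hypothetical.

```python
import numpy as np

def cosine_similarity_matrix(embeddings):
    """Return the gene names and their pairwise cosine-similarity matrix."""
    genes = list(embeddings)
    mat = np.stack([np.asarray(embeddings[g], dtype=float) for g in genes])
    # Scale each row to unit length so that a dot product equals the cosine
    unit = mat / np.linalg.norm(mat, axis=1, keepdims=True)
    return genes, unit @ unit.T

# Toy 3-dimensional embeddings (made-up values, for illustration only)
toy_embeddings = {
    "GENE_A": [1.0, 0.0, 0.0],
    "GENE_B": [2.0, 0.0, 0.0],  # same direction as GENE_A
    "GENE_C": [0.0, 3.0, 0.0],  # orthogonal to GENE_A
}
genes, sim = cosine_similarity_matrix(toy_embeddings)
```

Each entry `sim[i, j]` can be read as the weight of the edge between `genes[i]` and `genes[j]`; vector length does not matter, only direction.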
Step 3: Extract gene programs from gene embedding network
3.1 Perform Louvain clustering on the gene embedding network
[12]:
# Perform Louvain clustering with desired resolution; here we specify resolution=40
gdata = embed.get_adata(resolution=40)
# Retrieve the gene clusters
metagenes = embed.get_metagenes(gdata)
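As a rough picture of how Louvain clustering groups genes into programs, the sketch below runs networkx's `louvain_communities` on a toy weighted graph. networkx is used here as a stand-in and is not necessarily what `get_adata` does internally; the graph, node names, and edge weights are made up.

```python
import networkx as nx
from networkx.algorithms.community import louvain_communities

# Toy similarity graph: two tightly connected groups joined by one weak edge
G_toy = nx.Graph()
group_a = ["A1", "A2", "A3", "A4"]
group_b = ["B1", "B2", "B3", "B4"]
for group in (group_a, group_b):
    for i, u in enumerate(group):
        for v in group[i + 1:]:
            G_toy.add_edge(u, v, weight=1.0)
G_toy.add_edge("A1", "B1", weight=0.1)

# A resolution above 1 favors smaller communities; the seed makes the run repeatable
communities = louvain_communities(G_toy, weight="weight", resolution=1.0, seed=0)
programs = {str(k): sorted(c) for k, c in enumerate(communities)}
```

In this toy case the two dense groups are recovered as separate communities. Raising the resolution generally yields more and smaller clusters, which is in line with the high resolution used above to obtain compact gene programs.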
3.2 Filter on clusters with 5 or more genes
[13]:
# Obtain the set of gene programs from clusters with #genes >= 5
mgs = dict()
for mg, genes in metagenes.items():
    if len(genes) >= 5:
        mgs[mg] = genes
[14]:
# Here are the gene programs identified
mgs
[14]:
{'46': ['ZNF683', 'JAKMIP1', 'LAG3', 'FKBP11', 'ZBP1'],
'7': ['ZFP36', 'TNFAIP3', 'DUSP1', 'FOS', 'CD69', 'KLF6', 'PPP1R15A', 'JUN'],
'34': ['YPEL4', 'TMEM86B', 'TMEM63B', 'CYB5R1', 'RUNDC3A', 'DNAJB2'],
'25': ['YPEL3', 'TLE1', 'CXXC5', 'RBM38', 'ODC1', 'TFDP2'],
'43': ['YOD1', 'FLCN', 'CPEB4', 'FBXO30', 'MTMR3'],
'14': ['VSIG4', 'C1QA', 'C1QC', 'C1QB', 'ACP5', 'CD163', 'ABCG2'],
'71': ['VPREB3', 'VPREB1', 'SOCS2', 'IGLL1', 'DNTT'],
'17': ['UROD', 'CMAS', 'CDC27', 'MINPP1', 'CPOX', 'RFESD'],
'58': ['UBE2O', 'FN3K', 'PIM1', 'PNP', 'TRIM58'],
'47': ['TXK', 'SPOCK2', 'SKAP1', 'CD96', 'RNF125'],
'22': ['TTN', 'ICA1L', 'KCNQ1OT1', 'BCL9L', 'ABCA5', 'GPR82'],
'27': ['TSC22D1', 'MEG3', 'PTCRA', 'PGRMC1', 'MLH3', 'GRAP2'],
'21': ['TRAT1', 'LEF1', 'RCAN3', 'BCL11B', 'TCF7', 'OXNAD1'],
'0': ['TPX2',
'STMN1',
'UBE2C',
'MKI67',
'TOP2A',
'CCNB1',
'CDK1',
'TYMS',
'TK1',
'NUSAP1',
'BIRC5',
'PTTG1',
'CENPF',
'RRM2'],
'8': ['TPM2', 'TPM1', 'TGFB1I1', 'CALD1', 'MYL9', 'MYH11', 'MYLK'],
'15': ['TNRC6C', 'DHRS3', 'FBXO32', 'APBA2', 'CEP290', 'PLEKHA1'],
'42': ['TNFSF10', 'DPYSL2', 'AP1S2', 'GLIPR1', 'FGL2'],
'68': ['TNFRSF17', 'MZB1', 'DERL3', 'CD38', 'POU2AF1'],
'4': ['TNFRSF13B',
'SNRPN',
'SLC4A10',
'FCRL2',
'SPINK2',
'CAPN12',
'CELA1',
'CPA5',
'LILRA4'],
'12': ['TMOD1', 'SOX6', 'ART4', 'SPTB', 'AK1', 'OSBP2', 'KANK2'],
'37': ['TMEM154', 'HVCN1', 'NCF1', 'CKAP4', 'POU2F2'],
'29': ['TLR10', 'FGD2', 'HLA-DOB', 'FCRL5', 'CD180', 'MACROD2'],
'26': ['TIMP1', 'S100A11', 'LGALS1', 'ANXA2', 'LGALS3', 'IGFBP7'],
'45': ['TCL1A', 'FCRLA', 'CD22', 'CD72', 'CD19'],
'67': ['TAGAP', 'PHACTR1', 'PMAIP1', 'NFKBID', 'CD83'],
'59': ['SYTL2', 'SMAGP', 'LAIR2', 'DTHD1', 'NCALD'],
'13': ['SYNE2', 'SYNE1', 'PIK3R1', 'OPTN', 'CBLB', 'RORA', 'PARP8'],
'54': ['STYK1', 'SP140', 'LPAR5', 'FCRL3', 'PRKD2'],
'11': ['STXBP2', 'SPI1', 'TKT', 'AIF1', 'PYCARD', 'COTL1', 'LST1'],
'9': ['STRBP', 'TRAF5', 'GNB5', 'BACH2', 'CDCA7L', 'CHD7', 'AFF3'],
'57': ['STRADB', 'ISCA1', 'CCS', 'NPRL3', 'PITHD1'],
'1': ['STAT1',
'ISG15',
'CMPK2',
'IFIT2',
'IFIT1',
'IFI44L',
'RSAD2',
'IFI6',
'SAMD9',
'IFIT3',
'MX1',
'IRF7'],
'55': ['SPON2', 'MYOM2', 'CLIC3', 'CHST2', 'SH2D1B'],
'16': ['SPEF2', 'ECHDC2', 'DUOX1', 'CLUAP1', 'CCDC146', 'CLIC6'],
'52': ['SLC38A5', 'KCNH2', 'POLE2', 'CERS3', 'DHRS13'],
'51': ['SLC38A11', 'IL6', 'CA3', 'FAM177B', 'PTPRS'],
'49': ['SH3BGRL2', 'PDZK1IP1', 'CTDSPL', 'PDGFA', 'CTTN'],
'32': ['SEMA4F', 'CARNS1', 'ARHGAP12', 'XYLT2', 'RECK', 'CBFA2T2'],
'61': ['FCGR3A', 'DOK2', 'CX3CR1', 'ABI3', 'RHOC'],
'19': ['CCR10', 'SLAMF1', 'IL2RA', 'CCR4', 'CCR6', 'CTLA4'],
'72': ['ZFP36L1', 'EGR1', 'NFKBIZ', 'RHOB', 'NR4A1'],
'50': ['VMP1', 'GPCPD1', 'NAMPT', 'CD55', 'SOD2'],
'5': ['PRSS57',
'AZU1',
'CLEC11A',
'CTSG',
'ELANE',
'MS4A3',
'MPO',
'RETN',
'RNASE2'],
'65': ['LILRA5', 'LILRB2', 'GTPBP2', 'HCK', 'LRRC25'],
'3': ['LCK', 'CD2', 'LTB', 'CD3E', 'CD3G', 'CD8B', 'CD8A', 'IL7R', 'CD3D'],
'70': ['LBH', 'MGAT4A', 'RHOH', 'SLC38A1', 'STK17A'],
'36': ['PICALM', 'FOXO3', 'QKI', 'CTNNB1', 'RAPGEF2'],
'41': ['IL6ST', 'ZNF37A', 'ITGA6', 'ACTN1', 'PDK1'],
'35': ['P2RX5', 'PNOC', 'PKIG', 'BLK', 'OSBPL10', 'RRAS2'],
'30': ['ID3', 'GPM6B', 'BEX5', 'CRIP2', 'APOLD1', 'TSHZ2'],
'2': ['HLA-DRB5',
'HLA-DQA1',
'HLA-DPB1',
'HLA-DRA',
'HLA-DMA',
'HLA-DQB1',
'CD74',
'HLA-DQA2',
'HLA-DRB1',
'HLA-DPA1'],
'24': ['CLEC4A', 'NRG1', 'TMEM107', 'CLEC4D', 'PTCH2', 'S1PR3'],
'10': ['GNG11', 'CLEC1B', 'DNASE1L3', 'ESAM', 'CLDN5', 'MYCT1', 'EGFL7'],
'18': ['GLUL', 'SLC40A1', 'APOE', 'FABP5', 'CTSB', 'CTSD'],
'31': ['GIMAP4', 'GIMAP7', 'ETS1', 'GIMAP6', 'SLFN5', 'DENND2D'],
'40': ['FGR', 'TNFRSF1B', 'AOAH', 'AGTRAP', 'MYO1F'],
'56': ['ZSCAN18', 'MAP9', 'ZCWPW1', 'PLXNA3', 'ZNF274'],
'6': ['VMO1', 'LYPD2', 'LGALS3BP', 'AQP3', 'HES4', 'SCGB3A1', 'IRF9', 'CD59'],
'64': ['PASK', 'GATA3', 'CD40LG', 'AP3M2', 'P2RY10'],
'44': ['CYP4F3', 'ARG1', 'LCN2', 'HP', 'S100P'],
'28': ['PTPN7', 'LYAR', 'TCOF1', 'GTF3C1', 'CMC1', 'MATK'],
'33': ['HDC', 'CYTL1', 'CPA3', 'CLC', 'LGALS4', 'GATA2'],
'66': ['DAB2', 'F13A1', 'MAF', 'PRDM1', 'LPAR6'],
'39': ['IGSF6', 'CPVL', 'CFP', 'LY86', 'RAB32'],
'69': ['APOBEC3A', 'IFITM3', 'MT1F', 'IFI27', 'MT2A'],
'48': ['NELL2', 'CAMK4', 'ABLIM1', 'PDE3B', 'SATB1'],
'62': ['PLBD1', 'GCA', 'NUP214', 'PGD', 'LTA4H'],
'38': ['GP1BA', 'ARHGAP6', 'ITGA2B', 'SELP', 'RAB27B'],
'23': ['CCL20', 'CCL8', 'CXCL1', 'CXCL3', 'OSM', 'CCL2'],
'53': ['BTN3A1', 'CALCOCO1', 'RAB2B', 'OPA1', 'ATG4A'],
'20': ['NFIL3', 'EGR3', 'NLRP3', 'OTUD1', 'PTGS2', 'TRIB1'],
'60': ['RAB3IP', 'MXI1', 'PGM2L1', 'FHIT', 'UBE2H'],
'63': ['MAL', 'MYC', 'PIM2', 'CISH', 'S1PR1']}
Step 4: Visualize gene program activation on the Immune Human dataset
[15]:
sns.set(font_scale=0.35)
embed.score_metagenes(adata, metagenes)
embed.plot_metagenes_scores(adata, mgs, "celltype")
[Figure: gene program activation scores across cell types in the Immune Human dataset]

Step 5: Visualize network connectivity within desired gene program
We can further visualize the connectivity between genes within any gene program of interest from Step 4. Here we take gene program 3 as an example, which contains the CD3 cluster, the CD8 cluster, and other genes. In the visualization, edges are weighted by cosine similarity and strong connections are highlighted in blue; they appear among CD3D, CD3E, and CD3G, as well as between CD8A and CD8B.
[16]:
# Retrieve gene program 3 which contains the CD3 gene set
CD_genes = mgs['3']
print(CD_genes)
# Compute cosine similarities among genes in this gene program
similarity_frames = []
for i in tqdm.tqdm(CD_genes):
    df = embed.compute_similarities(i, CD_genes)
    df['Gene1'] = i
    similarity_frames.append(df)
# DataFrame.append was removed in pandas 2.0, so the per-gene tables are concatenated instead
df_CD = pd.concat(similarity_frames)
df_CD_sub = df_CD[df_CD['Similarity']<0.99].sort_values(by='Gene') # Filter out edges from each gene to itself
['LCK', 'CD2', 'LTB', 'CD3E', 'CD3G', 'CD8B', 'CD8A', 'IL7R', 'CD3D']
100%|██████████| 9/9 [00:00<00:00, 254.81it/s]
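The similarities computed above are cosine similarities between gene embedding vectors. As a minimal sketch of that measure, the example below computes pairwise cosine similarity on toy vectors with NumPy. The helper `cosine_similarity_matrix` and the toy values are assumptions made for illustration; they are not part of the scGPT API and do not reproduce scGPT's embeddings.

```python
import numpy as np


def cosine_similarity_matrix(vectors: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of `vectors` (hypothetical helper)."""
    # Scale every row to unit length, then take all pairwise dot products
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    unit = vectors / norms
    return unit @ unit.T


# Toy 2-D "embeddings" (made-up values): rows 0 and 1 point the same way,
# row 2 is orthogonal to both
toy_embeddings = np.array([
    [1.0, 0.0],
    [2.0, 0.0],
    [0.0, 3.0],
])
sim = cosine_similarity_matrix(toy_embeddings)
print(sim)
```

Each vector has similarity 1 with itself, which is why the cell above drops rows with similarity of 0.99 or higher before building the network.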
[17]:
# Create a graph from the cosine similarity network
input_node_weights = [(row['Gene'], row['Gene1'], round(row['Similarity'], 2)) for _, row in df_CD_sub.iterrows()]
G = nx.Graph()
G.add_weighted_edges_from(input_node_weights)
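Because the similarity table lists each gene pair in both directions, it may help to see how an undirected `nx.Graph` handles that. The sketch below uses made-up weights (not values from the model) to show that the two directions collapse into a single edge, and how a threshold then separates strong from weak edges.

```python
import networkx as nx

# Made-up similarities for illustration; the first pair appears in both directions
toy_edges = [
    ("CD3D", "CD3E", 0.8),
    ("CD3E", "CD3D", 0.8),
    ("CD3D", "CD8A", 0.3),
]

G_toy = nx.Graph()
G_toy.add_weighted_edges_from(toy_edges)

# An undirected graph stores (u, v) and (v, u) as one edge
n_edges = G_toy.number_of_edges()

# Split edges by an illustrative threshold
toy_thresh = 0.4
strong = [(u, v) for u, v, d in G_toy.edges(data=True) if d["weight"] > toy_thresh]
weak = [(u, v) for u, v, d in G_toy.edges(data=True) if d["weight"] <= toy_thresh]
print(n_edges, strong, weak)
```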
[18]:
# Plot the cosine similarity network; strong edges (> select threshold) are highlighted
thresh = 0.4
plt.figure(figsize=(20, 20))
widths = nx.get_edge_attributes(G, 'weight')
elarge = [(u, v) for (u, v, d) in G.edges(data=True) if d["weight"] > thresh]
esmall = [(u, v) for (u, v, d) in G.edges(data=True) if d["weight"] <= thresh]
pos = nx.spring_layout(G, k=0.4, iterations=15, seed=3)
# Scale edge widths by similarity; negative similarities are drawn with zero width
width_large = {}
width_small = {}
for edge, v in widths.items():
    if v > thresh:
        width_large[edge] = v*10
    else:
        width_small[edge] = max(v, 0)*10
# weak edges
nx.draw_networkx_edges(G, pos,
                       edgelist=list(width_small.keys()),
                       width=list(width_small.values()),
                       edge_color='lightblue',
                       alpha=0.8)
# strong edges
nx.draw_networkx_edges(G, pos,
                       edgelist=list(width_large.keys()),
                       width=list(width_large.values()),
                       alpha=0.5,
                       edge_color="blue")
# node labels
nx.draw_networkx_labels(G, pos, font_size=25, font_family="sans-serif")
# edge weight labels
d = nx.get_edge_attributes(G, "weight")
edge_labels = {k: d[k] for k in elarge}
nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=15)
ax = plt.gca()
ax.margins(0.08)
plt.axis("off")
plt.show()

Step 6: Reactome pathway analysis
Again using gene program 3 as an example, users may perform pathway enrichment analysis to identify related pathways. In the paper, we applied the Bonferroni correction: the significance level of 0.05 is divided by the total number of tests performed, that is, the number of terms in the queried databases.
[19]:
# Meta info about the number of terms (tests) in the databases
df_database = pd.DataFrame(
    data=[['GO_Biological_Process_2021', 6036],
          ['GO_Molecular_Function_2021', 1274],
          ['Reactome_2022', 1818]],
    columns=['dataset', 'term'])
[20]:
# Select desired database for query; here use Reactome as an example
databases = ['Reactome_2022']
m = df_database[df_database['dataset'].isin(databases)]['term'].sum()
# p-value correction for total number of tests done
p_thresh = 0.05/m
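As a worked check of the correction above, the sketch below recomputes the threshold from the term counts listed in the previous cell. The helper `bonferroni_threshold` is a hypothetical name introduced only for this illustration. With Reactome alone (1818 terms) the threshold is roughly 2.75e-05; querying all three databases at once (9128 terms) would tighten it to roughly 5.48e-06.

```python
# Number of terms (tests) per database, copied from the cell above
term_counts = {
    "GO_Biological_Process_2021": 6036,
    "GO_Molecular_Function_2021": 1274,
    "Reactome_2022": 1818,
}


def bonferroni_threshold(selected, alpha=0.05):
    """Divide alpha by the total number of tests (hypothetical helper)."""
    m = sum(term_counts[name] for name in selected)
    return alpha / m


reactome_only = bonferroni_threshold(["Reactome_2022"])
all_three = bonferroni_threshold(list(term_counts))
print(f"{reactome_only:.3g}", f"{all_three:.3g}")
```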
[21]:
# Perform pathway enrichment analysis using the gseapy package in the Reactome database
enr_Reactome = gp.enrichr(gene_list=CD_genes,
                          gene_sets=databases,
                          organism='Human',
                          outdir='test/enr_Reactome',
                          cutoff=0.5)
out = enr_Reactome.results
# Keep only terms that pass the Bonferroni-corrected threshold
out = out[out['P-value'] < p_thresh]
# DataFrame.append was removed in pandas 2.0; reset the index directly instead
df = out.reset_index(drop=True)
df
[21]:
| | Gene_set | Term | Overlap | P-value | Adjusted P-value | Old P-value | Old Adjusted P-value | Odds Ratio | Combined Score | Genes |
---|---|---|---|---|---|---|---|---|---|---|
0 | Reactome_2022 | Translocation Of ZAP-70 To Immunological Synap... | 3/17 | 4.270574e-08 | 0.000002 | 0 | 0 | 713.464286 | 12106.727222 | LCK;CD3E;CD3D |
1 | Reactome_2022 | Phosphorylation Of CD3 And TCR Zeta Chains R-H... | 3/20 | 7.154717e-08 | 0.000002 | 0 | 0 | 587.470588 | 9665.600052 | LCK;CD3E;CD3D |
2 | Reactome_2022 | PD-1 Signaling R-HSA-389948 | 3/21 | 8.345311e-08 | 0.000002 | 0 | 0 | 554.805556 | 9042.765180 | LCK;CD3E;CD3D |
3 | Reactome_2022 | Immunoregulatory Interactions Between A Lympho... | 4/123 | 1.675722e-07 | 0.000003 | 0 | 0 | 133.593277 | 2084.302453 | CD8B;CD8A;CD3E;CD3D |
4 | Reactome_2022 | Generation Of Second Messenger Molecules R-HSA... | 3/32 | 3.104597e-07 | 0.000004 | 0 | 0 | 344.172414 | 5157.496482 | LCK;CD3E;CD3D |
5 | Reactome_2022 | Immune System R-HSA-168256 | 7/1943 | 2.439382e-06 | 0.000024 | 0 | 0 | 32.640754 | 421.841460 | CD8B;LCK;CD8A;LTB;CD3E;IL7R;CD3D |
6 | Reactome_2022 | Costimulation By CD28 Family R-HSA-388841 | 3/68 | 3.111647e-06 | 0.000027 | 0 | 0 | 153.276923 | 1943.606325 | LCK;CD3E;CD3D |
7 | Reactome_2022 | Adaptive Immune System R-HSA-1280218 | 5/733 | 7.270042e-06 | 0.000055 | 0 | 0 | 33.075206 | 391.337520 | CD8B;LCK;CD8A;CD3E;CD3D |
8 | Reactome_2022 | Downstream TCR Signaling R-HSA-202424 | 3/94 | 8.274085e-06 | 0.000055 | 0 | 0 | 109.340659 | 1279.546191 | LCK;CD3E;CD3D |
9 | Reactome_2022 | TCR Signaling R-HSA-202403 | 3/116 | 1.556787e-05 | 0.000093 | 0 | 0 | 87.955752 | 973.696687 | LCK;CD3E;CD3D |