GRN Inference on Pre-trained Model

This tutorial uses the pre-trained blood model to demonstrate gene regulatory network (GRN) inference, focusing on gene program extraction and network visualization. We also show the cell-type-specific activation of these gene programs on the Immune Human dataset, as a soft validation of zero-shot performance.

Note that GRN inference can be performed on both pre-trained and fine-tuned models, as showcased in our manuscript.

Users may perform scGPT’s gene-embedding-based GRN inference in the following steps:

1. Load the trained scGPT model (pre-trained or fine-tuned) and the data

2. Retrieve scGPT's gene embeddings

3. Extract gene programs from scGPT's gene embedding network

4. Visualize gene program activations on dataset of interest

5. Visualize the interconnectivity of genes within select gene programs
[3]:
import copy
import json
import os
from pathlib import Path
import sys
import warnings

import torch
from anndata import AnnData
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import pandas as pd
import tqdm
import gseapy as gp

from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)

sys.path.insert(0, "../")
import scgpt as scg
from scgpt.tasks import GeneEmbedding
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.model import TransformerModel
from scgpt.preprocess import Preprocessor
from scgpt.utils import set_seed

os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')
Global seed set to 0
[4]:
set_seed(42)
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
n_hvg = 1200
n_bins = 51
mask_value = -1
pad_value = -2
n_input_bins = n_bins

Step 1: Load pre-trained model and dataset

1.1 Load pre-trained model

The blood pre-trained model can be downloaded via this link.

[5]:
# Specify model path; here we load the pre-trained scGPT blood model
model_dir = Path("../save/scGPT_bc")
model_config_file = model_dir / "args.json"
model_file = model_dir / "best_model.pt"
vocab_file = model_dir / "vocab.json"

vocab = GeneVocab.from_file(vocab_file)
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)

# Retrieve model parameters from config files
with open(model_config_file, "r") as f:
    model_configs = json.load(f)
print(
    f"Resume model from {model_file}, the model args will override the "
    f"config {model_config_file}."
)
embsize = model_configs["embsize"]
nhead = model_configs["nheads"]
d_hid = model_configs["d_hid"]
nlayers = model_configs["nlayers"]
n_layers_cls = model_configs["n_layers_cls"]

gene2idx = vocab.get_stoi()
Resume model from ../save/scGPT_bc/best_model.pt, the model args will override the config ../save/scGPT_bc/args.json.
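The special-token step in the cell above can be illustrated without scGPT. The sketch below uses a plain dictionary as a stand-in for `GeneVocab`; the helper name `append_missing_tokens` and the toy gene entries are hypothetical, and the real class may assign indices differently.

```python
def append_missing_tokens(stoi, tokens):
    """Return a copy of a token -> index mapping with any missing tokens appended."""
    stoi = dict(stoi)  # copy so the caller's mapping is left untouched
    for token in tokens:
        if token not in stoi:
            # Assumes indices are contiguous (0..n-1), so the next free index is len(stoi)
            stoi[token] = len(stoi)
    return stoi


# Toy vocabulary that already contains "<pad>" but not "<cls>" or "<eoc>"
toy_vocab = {"CD3D": 0, "CD8A": 1, "<pad>": 2}
extended = append_missing_tokens(toy_vocab, ["<pad>", "<cls>", "<eoc>"])
```

Tokens that are already present keep their index, so re-running the step is harmless.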
[6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    vocab=vocab,
    pad_value=pad_value,
    n_input_bins=n_input_bins,
)

try:
    model.load_state_dict(torch.load(model_file, map_location=device))
    print(f"Loading all model params from {model_file}")
except RuntimeError:
    # Strict loading failed (missing/unexpected keys or shape mismatches):
    # only load params that are in the model and match the size
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_file, map_location=device)
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    for k, v in pretrained_dict.items():
        print(f"Loading params {k} with shape {v.shape}")
    # Update and load once, after the loop, rather than once per parameter
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

model.to(device)
Using simple batchnorm instead of domain specific batchnorm
Loading params encoder.embedding.weight with shape torch.Size([36574, 512])
Loading params encoder.enc_norm.weight with shape torch.Size([512])
Loading params encoder.enc_norm.bias with shape torch.Size([512])
Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
Loading params value_encoder.linear1.bias with shape torch.Size([512])
Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
Loading params value_encoder.linear2.bias with shape torch.Size([512])
Loading params value_encoder.norm.weight with shape torch.Size([512])
Loading params value_encoder.norm.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.1.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.1.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.1.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.1.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.1.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.1.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.1.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.1.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.1.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.1.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.2.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.2.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.2.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.2.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.2.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.2.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.2.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.2.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.2.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.2.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.3.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.3.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.3.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.3.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.3.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.3.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.3.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.3.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.3.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.3.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.4.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.4.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.4.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.4.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.4.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.4.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.4.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.4.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.4.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.4.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.5.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.5.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.5.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.5.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.5.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.5.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.5.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.5.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.5.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.5.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.6.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.6.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.6.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.6.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.6.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.6.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.6.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.6.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.6.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.6.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.7.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.7.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.7.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.7.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.7.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.7.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.7.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.7.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.7.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.7.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.8.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.8.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.8.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.8.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.8.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.8.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.8.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.8.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.8.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.8.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.9.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.9.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.9.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.9.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.9.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.9.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.9.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.9.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.9.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.9.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.10.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.10.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.10.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.10.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.10.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.10.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.10.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.10.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.10.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.10.norm2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.11.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.11.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.11.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.11.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.11.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.11.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.11.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.11.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.11.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.11.norm2.bias with shape torch.Size([512])
Loading params decoder.fc.0.weight with shape torch.Size([512, 512])
Loading params decoder.fc.0.bias with shape torch.Size([512])
Loading params decoder.fc.2.weight with shape torch.Size([512, 512])
Loading params decoder.fc.2.bias with shape torch.Size([512])
Loading params decoder.fc.4.weight with shape torch.Size([1, 512])
Loading params decoder.fc.4.bias with shape torch.Size([1])
[6]:
TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(36574, 512, padding_idx=36571)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.5, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (bn): BatchNorm1d(512, eps=6.1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
      (2): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
      (3): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
      (4): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
      (5): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
      (6): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
      (7): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
      (8): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
      (9): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
      (10): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
      (11): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
    )
  )
  (decoder): ExprDecoder(
    (fc): Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=512, out_features=1, bias=True)
    )
  )
  (cls_decoder): ClsDecoder(
    (_decoder): ModuleList(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): ReLU()
      (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (3): Linear(in_features=512, out_features=512, bias=True)
      (4): ReLU()
      (5): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (out_layer): Linear(in_features=512, out_features=1, bias=True)
  )
  (sim): Similarity(
    (cos): CosineSimilarity()
  )
  (creterion_cce): CrossEntropyLoss()
)
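The fallback branch of the loading cell keeps only checkpoint entries whose name and shape both match the freshly built model. A minimal sketch of that filtering rule follows, with NumPy arrays standing in for tensors; the helper name `filter_compatible` and the toy parameter names are hypothetical.

```python
import numpy as np


def filter_compatible(checkpoint, model_state):
    """Keep checkpoint entries whose key exists in the model with the same shape."""
    return {
        name: value
        for name, value in checkpoint.items()
        if name in model_state and value.shape == model_state[name].shape
    }


# Toy "model" state and a checkpoint that only partially matches it
model_state = {
    "encoder.weight": np.zeros((4, 3)),
    "decoder.bias": np.zeros(1),
    "cls_decoder.weight": np.zeros((2, 3)),
}
checkpoint = {
    "encoder.weight": np.ones((4, 3)),
    "decoder.bias": np.ones(5),       # shape mismatch: skipped
    "extra.weight": np.ones((2, 2)),  # not in the model: skipped
}

compatible = filter_compatible(checkpoint, model_state)
merged = {**model_state, **compatible}  # skipped entries keep the model's initial values
```

Any parameter that does not appear in the "Loading params" log above was skipped by this rule and keeps its initial value.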

1.2 Load dataset of interest

The Immune Human dataset can be downloaded via this link.

[7]:
# Specify data path; here we load the Immune Human dataset
data_dir = Path("../data")
adata = sc.read(
    str(data_dir / "Immune_ALL_human.h5ad"), cache=True
)  # 33506 × 12303
ori_batch_col = "batch"
adata.obs["celltype"] = adata.obs["final_annotation"].astype(str)
data_is_raw = False
[8]:
# Preprocess the data following the scGPT data pre-processing pipeline
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=3,  # step 1: filter genes by counts
    filter_cell_by_counts=False,  # step 2: filter cells by counts
    normalize_total=1e4,  # step 3: whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # step 4: whether to log1p the normalized data
    result_log1p_key="X_log1p",  # the key in adata.layers to store the log1p data
    subset_hvg=n_hvg,  # step 5: whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=n_bins,  # step 6: whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key="batch")
scGPT - INFO - Filtering genes by counts ...
scGPT - INFO - Filtering cells by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Subsetting highly variable genes ...
scGPT - INFO - Binning data ...
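The last preprocessing step converts expression values into a fixed number of bins (`n_bins`). The tutorial does not show how `Preprocessor` does this, so the following is only a sketch of one plausible scheme, per-cell quantile binning with zeros kept in their own bin. The function name `bin_cell` and the toy values are hypothetical.

```python
import numpy as np


def bin_cell(values, n_bins):
    """Map one cell's expression values to integer bins; zeros stay in bin 0."""
    values = np.asarray(values, dtype=float)
    binned = np.zeros(values.shape, dtype=int)
    nonzero = values > 0
    if nonzero.any():
        # Bin edges are quantiles of this cell's own non-zero values
        edges = np.quantile(values[nonzero], np.linspace(0, 1, n_bins - 1))
        # np.digitize yields 1..len(edges) for values at or above the lowest edge
        binned[nonzero] = np.digitize(values[nonzero], edges)
    return binned


toy_cell = [0, 1, 2, 3, 4, 0, 5]
binned = bin_cell(toy_cell, n_bins=4)
```

Because the edges are computed per cell, the same raw count can fall into different bins in different cells. This makes the binned values relative expression levels, not absolute ones.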

Step 2: Retrieve scGPT’s gene embeddings

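Note that scGPT’s gene embeddings are, technically, data independent: they are read from the model’s gene embedding layer and do not depend on the input dataset. The vocabulary of the pre-trained foundation model covers more than 30,000 genes. For simplicity, we focus here on the subset of HVGs specific to the data at hand.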

[9]:
# Retrieve the data-independent gene embeddings from scGPT
gene_ids = np.array(list(gene2idx.values()))
with torch.no_grad():  # inference only, so no gradients are needed
    gene_embeddings = model.encoder(torch.tensor(gene_ids, dtype=torch.long).to(device))
gene_embeddings = gene_embeddings.detach().cpu().numpy()
[10]:
# Filter on the intersection between the Immune Human HVGs found in step 1.2 and scGPT's 30+K foundation model vocab
hvg_set = set(adata.var.index)  # set membership avoids rebuilding a list for every gene
gene_embeddings = {gene: gene_embeddings[i] for i, gene in enumerate(gene2idx.keys()) if gene in hvg_set}
print('Retrieved gene embeddings for {} genes.'.format(len(gene_embeddings)))
Retrieved gene embeddings for 1173 genes.
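The filtering step keeps only genes that appear in both the model vocabulary and the dataset. A self-contained sketch with a toy vocabulary and a random embedding table is given below; all names and values are hypothetical stand-ins for `gene2idx`, the model's embedding layer, and `adata.var.index`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy vocabulary and embedding table, one row per vocabulary entry
toy_gene2idx = {"CD3D": 0, "CD3E": 1, "MKI67": 2, "<pad>": 3}
embedding_table = rng.normal(size=(len(toy_gene2idx), 8))

# Genes present in the dataset; "NOVEL1" is not in the vocabulary
dataset_genes = {"CD3E", "MKI67", "NOVEL1"}

# Keep embeddings only for genes found in both the vocabulary and the dataset
toy_embeddings = {
    gene: embedding_table[idx]
    for gene, idx in toy_gene2idx.items()
    if gene in dataset_genes
}

# Dataset genes without an embedding are silently dropped by the filter
missing = sorted(dataset_genes - set(toy_gene2idx))
```

Dataset genes absent from the vocabulary are dropped without warning, which presumably explains why fewer embeddings are retrieved above than the `n_hvg = 1200` genes selected during preprocessing.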
[11]:
# Construct gene embedding network
embed = GeneEmbedding(gene_embeddings)
100%|██████████| 1173/1173 [00:00<00:00, 1365127.25it/s]
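The internals of `GeneEmbedding` are not shown in this tutorial. As an assumption about what a gene embedding network can look like, the sketch below links each gene to its `k` most cosine-similar genes. The function `knn_edges` and the toy vectors are hypothetical.

```python
import numpy as np


def knn_edges(vectors, names, k=2):
    """Link each gene to its k most cosine-similar genes (excluding itself)."""
    mat = np.asarray(vectors, dtype=float)
    unit = mat / np.linalg.norm(mat, axis=1, keepdims=True)  # L2-normalize rows
    sim = unit @ unit.T  # cosine similarity between every pair of genes
    np.fill_diagonal(sim, -np.inf)  # never pick a gene as its own neighbour
    edges = []
    for i, name in enumerate(names):
        # Indices of the k largest similarities in row i, best first
        for j in np.argsort(sim[i])[::-1][:k]:
            edges.append((name, names[j], float(sim[i, j])))
    return edges


names = ["A", "B", "C", "D"]
vectors = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
edges = knn_edges(vectors, names, k=1)
```

In this toy example, A and B point to each other, as do C and D, giving two tightly connected groups that a clustering step could then separate.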

Step 3: Extract gene programs from gene embedding network

3.1 Perform Louvain clustering on the gene embedding network

[12]:
# Perform Louvain clustering with desired resolution; here we specify resolution=40
gdata = embed.get_adata(resolution=40)
# Retrieve the gene clusters
metagenes = embed.get_metagenes(gdata)
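The `resolution` argument controls how fine-grained the clusters are. As a rough illustration, assume the clustering optimizes a resolution-weighted modularity, a common objective for Louvain-style methods (the tutorial does not show what `get_adata` does internally). The sketch below evaluates that objective for two candidate partitions of a toy graph. The function `modularity` and the graph are hypothetical.

```python
import numpy as np


def modularity(adj, labels, resolution=1.0):
    """Resolution-weighted modularity of a partition of an undirected graph.

    Q = (1 / 2m) * sum_ij [A_ij - resolution * k_i * k_j / (2m)] * [c_i == c_j]
    """
    adj = np.asarray(adj, dtype=float)
    labels = np.asarray(labels)
    degrees = adj.sum(axis=1)
    two_m = degrees.sum()  # twice the number of edges
    same = labels[:, None] == labels[None, :]  # True where i and j share a cluster
    expected = resolution * np.outer(degrees, degrees) / two_m
    return float(((adj - expected) * same).sum() / two_m)


# Two triangles (0-1-2 and 3-4-5) joined by a single bridge edge 2-3
toy_edges = [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]
adj = np.zeros((6, 6))
for i, j in toy_edges:
    adj[i, j] = adj[j, i] = 1.0

split = [0, 0, 0, 1, 1, 1]   # one cluster per triangle
merged = [0, 0, 0, 0, 0, 0]  # everything in one cluster

q_split_high = modularity(adj, split, resolution=1.0)
q_merged_high = modularity(adj, merged, resolution=1.0)
q_split_low = modularity(adj, split, resolution=0.1)
q_merged_low = modularity(adj, merged, resolution=0.1)
```

At resolution 1.0 the split partition scores higher, while at 0.1 the merged one does. Under this assumption, a large value such as `resolution=40` favors many small gene programs.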

3.2 Filter on clusters with 5 or more genes

[13]:
# Obtain the set of gene programs from clusters with #genes >= 5
mgs = dict()
for mg, genes in metagenes.items():
    if len(genes) >= 5:
        mgs[mg] = genes
[14]:
# Here are the gene programs identified
mgs
[14]:
{'46': ['ZNF683', 'JAKMIP1', 'LAG3', 'FKBP11', 'ZBP1'],
 '7': ['ZFP36', 'TNFAIP3', 'DUSP1', 'FOS', 'CD69', 'KLF6', 'PPP1R15A', 'JUN'],
 '34': ['YPEL4', 'TMEM86B', 'TMEM63B', 'CYB5R1', 'RUNDC3A', 'DNAJB2'],
 '25': ['YPEL3', 'TLE1', 'CXXC5', 'RBM38', 'ODC1', 'TFDP2'],
 '43': ['YOD1', 'FLCN', 'CPEB4', 'FBXO30', 'MTMR3'],
 '14': ['VSIG4', 'C1QA', 'C1QC', 'C1QB', 'ACP5', 'CD163', 'ABCG2'],
 '71': ['VPREB3', 'VPREB1', 'SOCS2', 'IGLL1', 'DNTT'],
 '17': ['UROD', 'CMAS', 'CDC27', 'MINPP1', 'CPOX', 'RFESD'],
 '58': ['UBE2O', 'FN3K', 'PIM1', 'PNP', 'TRIM58'],
 '47': ['TXK', 'SPOCK2', 'SKAP1', 'CD96', 'RNF125'],
 '22': ['TTN', 'ICA1L', 'KCNQ1OT1', 'BCL9L', 'ABCA5', 'GPR82'],
 '27': ['TSC22D1', 'MEG3', 'PTCRA', 'PGRMC1', 'MLH3', 'GRAP2'],
 '21': ['TRAT1', 'LEF1', 'RCAN3', 'BCL11B', 'TCF7', 'OXNAD1'],
 '0': ['TPX2',
  'STMN1',
  'UBE2C',
  'MKI67',
  'TOP2A',
  'CCNB1',
  'CDK1',
  'TYMS',
  'TK1',
  'NUSAP1',
  'BIRC5',
  'PTTG1',
  'CENPF',
  'RRM2'],
 '8': ['TPM2', 'TPM1', 'TGFB1I1', 'CALD1', 'MYL9', 'MYH11', 'MYLK'],
 '15': ['TNRC6C', 'DHRS3', 'FBXO32', 'APBA2', 'CEP290', 'PLEKHA1'],
 '42': ['TNFSF10', 'DPYSL2', 'AP1S2', 'GLIPR1', 'FGL2'],
 '68': ['TNFRSF17', 'MZB1', 'DERL3', 'CD38', 'POU2AF1'],
 '4': ['TNFRSF13B',
  'SNRPN',
  'SLC4A10',
  'FCRL2',
  'SPINK2',
  'CAPN12',
  'CELA1',
  'CPA5',
  'LILRA4'],
 '12': ['TMOD1', 'SOX6', 'ART4', 'SPTB', 'AK1', 'OSBP2', 'KANK2'],
 '37': ['TMEM154', 'HVCN1', 'NCF1', 'CKAP4', 'POU2F2'],
 '29': ['TLR10', 'FGD2', 'HLA-DOB', 'FCRL5', 'CD180', 'MACROD2'],
 '26': ['TIMP1', 'S100A11', 'LGALS1', 'ANXA2', 'LGALS3', 'IGFBP7'],
 '45': ['TCL1A', 'FCRLA', 'CD22', 'CD72', 'CD19'],
 '67': ['TAGAP', 'PHACTR1', 'PMAIP1', 'NFKBID', 'CD83'],
 '59': ['SYTL2', 'SMAGP', 'LAIR2', 'DTHD1', 'NCALD'],
 '13': ['SYNE2', 'SYNE1', 'PIK3R1', 'OPTN', 'CBLB', 'RORA', 'PARP8'],
 '54': ['STYK1', 'SP140', 'LPAR5', 'FCRL3', 'PRKD2'],
 '11': ['STXBP2', 'SPI1', 'TKT', 'AIF1', 'PYCARD', 'COTL1', 'LST1'],
 '9': ['STRBP', 'TRAF5', 'GNB5', 'BACH2', 'CDCA7L', 'CHD7', 'AFF3'],
 '57': ['STRADB', 'ISCA1', 'CCS', 'NPRL3', 'PITHD1'],
 '1': ['STAT1',
  'ISG15',
  'CMPK2',
  'IFIT2',
  'IFIT1',
  'IFI44L',
  'RSAD2',
  'IFI6',
  'SAMD9',
  'IFIT3',
  'MX1',
  'IRF7'],
 '55': ['SPON2', 'MYOM2', 'CLIC3', 'CHST2', 'SH2D1B'],
 '16': ['SPEF2', 'ECHDC2', 'DUOX1', 'CLUAP1', 'CCDC146', 'CLIC6'],
 '52': ['SLC38A5', 'KCNH2', 'POLE2', 'CERS3', 'DHRS13'],
 '51': ['SLC38A11', 'IL6', 'CA3', 'FAM177B', 'PTPRS'],
 '49': ['SH3BGRL2', 'PDZK1IP1', 'CTDSPL', 'PDGFA', 'CTTN'],
 '32': ['SEMA4F', 'CARNS1', 'ARHGAP12', 'XYLT2', 'RECK', 'CBFA2T2'],
 '61': ['FCGR3A', 'DOK2', 'CX3CR1', 'ABI3', 'RHOC'],
 '19': ['CCR10', 'SLAMF1', 'IL2RA', 'CCR4', 'CCR6', 'CTLA4'],
 '72': ['ZFP36L1', 'EGR1', 'NFKBIZ', 'RHOB', 'NR4A1'],
 '50': ['VMP1', 'GPCPD1', 'NAMPT', 'CD55', 'SOD2'],
 '5': ['PRSS57',
  'AZU1',
  'CLEC11A',
  'CTSG',
  'ELANE',
  'MS4A3',
  'MPO',
  'RETN',
  'RNASE2'],
 '65': ['LILRA5', 'LILRB2', 'GTPBP2', 'HCK', 'LRRC25'],
 '3': ['LCK', 'CD2', 'LTB', 'CD3E', 'CD3G', 'CD8B', 'CD8A', 'IL7R', 'CD3D'],
 '70': ['LBH', 'MGAT4A', 'RHOH', 'SLC38A1', 'STK17A'],
 '36': ['PICALM', 'FOXO3', 'QKI', 'CTNNB1', 'RAPGEF2'],
 '41': ['IL6ST', 'ZNF37A', 'ITGA6', 'ACTN1', 'PDK1'],
 '35': ['P2RX5', 'PNOC', 'PKIG', 'BLK', 'OSBPL10', 'RRAS2'],
 '30': ['ID3', 'GPM6B', 'BEX5', 'CRIP2', 'APOLD1', 'TSHZ2'],
 '2': ['HLA-DRB5',
  'HLA-DQA1',
  'HLA-DPB1',
  'HLA-DRA',
  'HLA-DMA',
  'HLA-DQB1',
  'CD74',
  'HLA-DQA2',
  'HLA-DRB1',
  'HLA-DPA1'],
 '24': ['CLEC4A', 'NRG1', 'TMEM107', 'CLEC4D', 'PTCH2', 'S1PR3'],
 '10': ['GNG11', 'CLEC1B', 'DNASE1L3', 'ESAM', 'CLDN5', 'MYCT1', 'EGFL7'],
 '18': ['GLUL', 'SLC40A1', 'APOE', 'FABP5', 'CTSB', 'CTSD'],
 '31': ['GIMAP4', 'GIMAP7', 'ETS1', 'GIMAP6', 'SLFN5', 'DENND2D'],
 '40': ['FGR', 'TNFRSF1B', 'AOAH', 'AGTRAP', 'MYO1F'],
 '56': ['ZSCAN18', 'MAP9', 'ZCWPW1', 'PLXNA3', 'ZNF274'],
 '6': ['VMO1', 'LYPD2', 'LGALS3BP', 'AQP3', 'HES4', 'SCGB3A1', 'IRF9', 'CD59'],
 '64': ['PASK', 'GATA3', 'CD40LG', 'AP3M2', 'P2RY10'],
 '44': ['CYP4F3', 'ARG1', 'LCN2', 'HP', 'S100P'],
 '28': ['PTPN7', 'LYAR', 'TCOF1', 'GTF3C1', 'CMC1', 'MATK'],
 '33': ['HDC', 'CYTL1', 'CPA3', 'CLC', 'LGALS4', 'GATA2'],
 '66': ['DAB2', 'F13A1', 'MAF', 'PRDM1', 'LPAR6'],
 '39': ['IGSF6', 'CPVL', 'CFP', 'LY86', 'RAB32'],
 '69': ['APOBEC3A', 'IFITM3', 'MT1F', 'IFI27', 'MT2A'],
 '48': ['NELL2', 'CAMK4', 'ABLIM1', 'PDE3B', 'SATB1'],
 '62': ['PLBD1', 'GCA', 'NUP214', 'PGD', 'LTA4H'],
 '38': ['GP1BA', 'ARHGAP6', 'ITGA2B', 'SELP', 'RAB27B'],
 '23': ['CCL20', 'CCL8', 'CXCL1', 'CXCL3', 'OSM', 'CCL2'],
 '53': ['BTN3A1', 'CALCOCO1', 'RAB2B', 'OPA1', 'ATG4A'],
 '20': ['NFIL3', 'EGR3', 'NLRP3', 'OTUD1', 'PTGS2', 'TRIB1'],
 '60': ['RAB3IP', 'MXI1', 'PGM2L1', 'FHIT', 'UBE2H'],
 '63': ['MAL', 'MYC', 'PIM2', 'CISH', 'S1PR1']}

Step 4: Visualize gene program activation on the Immune Human dataset

[15]:
sns.set(font_scale=0.35)
embed.score_metagenes(adata, metagenes)
embed.plot_metagenes_scores(adata, mgs, "celltype")
<Figure size 360x936 with 0 Axes>
<Figure size 432x288 with 0 Axes>
_images/tutorial_grn_21_2.png

Step 5: Visualize network connectivity within a gene program of interest

We can further visualize the connectivity between genes within any gene program of interest from Step 4. As an example, we use gene program 3, which consists of the CD3 cluster, the CD8 cluster, and several other genes. In the visualization, strong connections (by cosine similarity) are highlighted in blue: among CD3D, CD3E, and CD3G, and between CD8A and CD8B.

[16]:
# Retrieve gene program 3 which contains the CD3 gene set
CD_genes = mgs['3']
print(CD_genes)
# Compute cosine similarities among genes in this gene program
similarity_frames = []
for i in tqdm.tqdm(CD_genes):
    df = embed.compute_similarities(i, CD_genes)
    df['Gene1'] = i
    similarity_frames.append(df)
# DataFrame.append was removed in pandas 2.0, so concatenate once instead
df_CD = pd.concat(similarity_frames, ignore_index=True)
# Filter out edges from each gene to itself (self-similarity is ~1)
df_CD_sub = df_CD[df_CD['Similarity'] < 0.99].sort_values(by='Gene')
['LCK', 'CD2', 'LTB', 'CD3E', 'CD3G', 'CD8B', 'CD8A', 'IL7R', 'CD3D']
100%|██████████| 9/9 [00:00<00:00, 254.81it/s]
[17]:
# Creates a graph from the cosine similarity network
input_node_weights = [(row['Gene'], row['Gene1'], round(row['Similarity'], 2)) for i, row in df_CD_sub.iterrows()]
G = nx.Graph()
G.add_weighted_edges_from(input_node_weights)
[18]:
# Plot the cosine similarity network; strong edges (> select threshold) are highlighted
thresh = 0.4
plt.figure(figsize=(20, 20))
widths = nx.get_edge_attributes(G, 'weight')

elarge = [(u, v) for (u, v, d) in G.edges(data=True) if d["weight"] > thresh]
esmall = [(u, v) for (u, v, d) in G.edges(data=True) if d["weight"] <= thresh]

pos = nx.spring_layout(G, k=0.4, iterations=15, seed=3)

width_large = {}
width_small = {}
for edge, v in widths.items():
    if v > thresh:
        width_large[edge] = v * 10
    else:
        # Clip negative similarities to zero so the line width stays valid
        width_small[edge] = max(v, 0) * 10

# Weak edges in light blue, strong edges in blue
nx.draw_networkx_edges(G, pos,
                       edgelist=list(width_small.keys()),
                       width=list(width_small.values()),
                       edge_color="lightblue",
                       alpha=0.8)
nx.draw_networkx_edges(G, pos,
                       edgelist=list(width_large.keys()),
                       width=list(width_large.values()),
                       edge_color="blue",
                       alpha=0.5)
# node labels
nx.draw_networkx_labels(G, pos, font_size=25, font_family="sans-serif")
# edge weight labels
d = nx.get_edge_attributes(G, "weight")
edge_labels = {k: d[k] for k in elarge}
nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=15)

ax = plt.gca()
ax.margins(0.08)
plt.axis("off")
plt.show()
_images/tutorial_grn_25_0.png
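
The edge weights in this network are cosine similarities between gene embeddings. The following is a minimal sketch of that computation and of the thresholding into strong and weak edges, written in plain Python so it runs without the model. The gene names and three-dimensional vectors are made up for illustration, and the helper `cosine_similarity` is hypothetical; it is not claimed to match the internals of `embed.compute_similarities`.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Hypothetical toy "embeddings"; real gene embeddings have many more dimensions
toy_embeddings = {
    "GENE_A": [1.0, 0.0, 0.0],
    "GENE_B": [1.0, 1.0, 0.0],
    "GENE_C": [0.0, 0.0, 1.0],
}

thresh = 0.4  # same cutoff as in the plotting cell
genes = list(toy_embeddings)

# One undirected edge per gene pair, weighted by rounded cosine similarity
edges = {}
for idx, g1 in enumerate(genes):
    for g2 in genes[idx + 1:]:
        sim = cosine_similarity(toy_embeddings[g1], toy_embeddings[g2])
        edges[(g1, g2)] = round(sim, 2)

# Edges above the threshold are the ones drawn in blue and labeled
strong_edges = {pair: w for pair, w in edges.items() if w > thresh}
print(edges)
print(strong_edges)
```

Only the pair whose vectors point in a similar direction passes the threshold; orthogonal vectors have similarity 0 and would be drawn as weak edges.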

Step 6: Reactome pathway analysis

Again using gene program 3 as an example, users may perform pathway enrichment analysis to identify related pathways. In the paper, we used the Bonferroni correction: the significance level (0.05) is divided by the total number of tests performed, that is, the number of terms in the queried database(s), to obtain the adjusted p-value threshold.

[19]:
# Meta info about the number of terms (tests) in the databases
df_database = pd.DataFrame(
    data=[
        ['GO_Biological_Process_2021', 6036],
        ['GO_Molecular_Function_2021', 1274],
        ['Reactome_2022', 1818],
    ],
    columns=['dataset', 'term'],
)
[20]:
# Select desired database for query; here use Reactome as an example
databases = ['Reactome_2022']
m = df_database[df_database['dataset'].isin(databases)]['term'].sum()
# p-value correction for total number of tests done
p_thresh = 0.05/m
[21]:
# Perform pathway enrichment analysis using the gseapy package in the Reactome database
df = pd.DataFrame()
enr_Reactome = gp.enrichr(gene_list=CD_genes,
                          gene_sets=databases,
                          organism='Human',
                          outdir='test/enr_Reactome',
                          cutoff=0.5)
out = enr_Reactome.results
out = out[out['P-value'] < p_thresh]
# DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent
df = pd.concat([df, out], ignore_index=True)
df
[21]:
| | Gene_set | Term | Overlap | P-value | Adjusted P-value | Old P-value | Old Adjusted P-value | Odds Ratio | Combined Score | Genes |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Reactome_2022 | Translocation Of ZAP-70 To Immunological Synap... | 3/17 | 4.270574e-08 | 0.000002 | 0 | 0 | 713.464286 | 12106.727222 | LCK;CD3E;CD3D |
| 1 | Reactome_2022 | Phosphorylation Of CD3 And TCR Zeta Chains R-H... | 3/20 | 7.154717e-08 | 0.000002 | 0 | 0 | 587.470588 | 9665.600052 | LCK;CD3E;CD3D |
| 2 | Reactome_2022 | PD-1 Signaling R-HSA-389948 | 3/21 | 8.345311e-08 | 0.000002 | 0 | 0 | 554.805556 | 9042.765180 | LCK;CD3E;CD3D |
| 3 | Reactome_2022 | Immunoregulatory Interactions Between A Lympho... | 4/123 | 1.675722e-07 | 0.000003 | 0 | 0 | 133.593277 | 2084.302453 | CD8B;CD8A;CD3E;CD3D |
| 4 | Reactome_2022 | Generation Of Second Messenger Molecules R-HSA... | 3/32 | 3.104597e-07 | 0.000004 | 0 | 0 | 344.172414 | 5157.496482 | LCK;CD3E;CD3D |
| 5 | Reactome_2022 | Immune System R-HSA-168256 | 7/1943 | 2.439382e-06 | 0.000024 | 0 | 0 | 32.640754 | 421.841460 | CD8B;LCK;CD8A;LTB;CD3E;IL7R;CD3D |
| 6 | Reactome_2022 | Costimulation By CD28 Family R-HSA-388841 | 3/68 | 3.111647e-06 | 0.000027 | 0 | 0 | 153.276923 | 1943.606325 | LCK;CD3E;CD3D |
| 7 | Reactome_2022 | Adaptive Immune System R-HSA-1280218 | 5/733 | 7.270042e-06 | 0.000055 | 0 | 0 | 33.075206 | 391.337520 | CD8B;LCK;CD8A;CD3E;CD3D |
| 8 | Reactome_2022 | Downstream TCR Signaling R-HSA-202424 | 3/94 | 8.274085e-06 | 0.000055 | 0 | 0 | 109.340659 | 1279.546191 | LCK;CD3E;CD3D |
| 9 | Reactome_2022 | TCR Signaling R-HSA-202403 | 3/116 | 1.556787e-05 | 0.000093 | 0 | 0 | 87.955752 | 973.696687 | LCK;CD3E;CD3D |
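
As a worked check of the Bonferroni threshold used to filter the results, the sketch below recomputes it from the term counts listed in the `df_database` cell. The helper `bonferroni_threshold` is a hypothetical name introduced here for illustration. With Reactome alone (1818 terms) the threshold is about 2.75e-05. Summing the term counts of all three databases, as the `m = ... .sum()` line would do if all were selected, gives a stricter threshold.

```python
def bonferroni_threshold(alpha, n_tests):
    """Per-test significance threshold under the Bonferroni correction."""
    if n_tests <= 0:
        raise ValueError("n_tests must be positive")
    return alpha / n_tests


# Term counts copied from the df_database cell
n_terms = {
    "GO_Biological_Process_2021": 6036,
    "GO_Molecular_Function_2021": 1274,
    "Reactome_2022": 1818,
}

# Threshold when querying Reactome only, as in this tutorial
reactome_thresh = bonferroni_threshold(0.05, n_terms["Reactome_2022"])

# Threshold if all three databases were queried together
all_thresh = bonferroni_threshold(0.05, sum(n_terms.values()))

print(f"Reactome only:       {reactome_thresh:.3e}")
print(f"All three databases: {all_thresh:.3e}")

# The largest P-value reported for gene program 3 (TCR Signaling)
largest_reported_p = 1.556787e-05
print(largest_reported_p < reactome_thresh)
```

The last line confirms that every reported pathway falls below the Reactome-only threshold, which is consistent with the filter `out['P-value'] < p_thresh`.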