# Fine-tuning Pre-trained Model for Perturbation Prediction
import json
import os
import sys
import time
import copy
from pathlib import Path
from typing import Iterable, List, Tuple, Dict, Union, Optional
import warnings
import torch
import numpy as np
import matplotlib
from torch import nn
from torch.nn import functional as F
from torchtext.vocab import Vocab
from torchtext._torchtext import (
Vocab as VocabPybind,
from torch_geometric.loader import DataLoader
from gears import PertData, GEARS
from gears.inference import compute_metrics, deeper_analysis, non_dropout_analysis
from gears.utils import create_cell_graph_dataset_for_prediction
sys.path.insert(0, "../")
import scgpt as scg
from scgpt.model import TransformerGenerator
from scgpt.loss import (
from scgpt.tokenizer import tokenize_batch, pad_batch, tokenize_and_pad_batch
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed, map_raw_id_to_vocab_id
# Save figures with an opaque (non-transparent) background.
matplotlib.rcParams.update({"savefig.transparent": False})
## Training Settings
# settings for data processing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
pad_value = 0  # for padding values
pert_pad_id = 2
n_hvg = 0  # number of highly variable genes
include_zero_gene = "all"  # include zero expr genes in training input, "all", "batch-wise", "row-wise", or False
max_seq_len = 1536

# settings for training
MLM = True  # whether to use masked language modeling, currently it is always on.
CLS = False  # celltype classification objective
CCE = False  # Contrastive cell embedding objective
MVC = False  # Masked value prediction for cell embedding
ECS = False  # Elastic cell similarity objective
cell_emb_style = "cls"
mvc_decoder_style = "inner product, detach"
amp = True
load_model = "../save/scGPT_human"
# Only checkpoint parameters whose names start with one of these prefixes are
# loaded (gene encoder, value encoder and the transformer stack); every other
# module keeps its fresh initialization.
load_param_prefixs = [
    "encoder",
    "value_encoder",
    "transformer_encoder",
]

# settings for optimizer
lr = 1e-4  # learning rate
batch_size = 64
eval_batch_size = 64
epochs = 15
schedule_interval = 1  # step_size of the StepLR scheduler
early_stop = 5

# settings for the model (when load_model is set, embsize/nhead/d_hid/nlayers/
# n_layers_cls are overridden by the checkpoint's args.json)
embsize = 512  # embedding dimension
d_hid = 512  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 12  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 8  # number of heads in nn.MultiheadAttention
n_layers_cls = 3
dropout = 0.2  # dropout probability
use_fast_transformer = True  # whether to use fast transformer

# logging
log_interval = 100

# dataset and evaluation choices
data_name = "adamson"
split = "simulation"
if data_name == "norman":
    perts_to_plot = ["SAMD1+ZBTB1"]
elif data_name == "adamson":
    perts_to_plot = ["KCTD16+ctrl"]
else:
    # Fail fast: otherwise `perts_to_plot` is silently left undefined and the
    # run only crashes with a NameError when it is first used.
    raise ValueError(
        f"No perturbations to plot are configured for data_name={data_name!r}; "
        "add an entry for this dataset."
    )
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Each run gets its own timestamped output directory.
save_dir = Path(f"./save/dev_perturb_{data_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"saving to {save_dir}")

# Mirror the scGPT logger into <save_dir>/run.log.
logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")
# log the run start time
logger.info(f"Running on {time.strftime('%Y-%m-%d %H:%M:%S')}")
saving to save/dev_perturb_adamson-Jul12-09-57
scGPT - INFO - Running on 2023-07-12 09:57:56
# Load the GEARS perturbation data from ./data, build the requested split
# (seeded), and create the dataloaders consumed by the epoch loop.
pert_data = PertData("./data")
pert_data.prepare_split(seed=1, split=split)
pert_data.get_dataloader(
    test_batch_size=eval_batch_size,
    batch_size=batch_size,
)
Found local copy...
Local copy of pyg dataset is detected. Loading...
Local copy of split is detected. Loading...
Simulation split test composition:
Creating dataloaders....
if load_model is not None:
    # Reuse the pretrained checkpoint's vocabulary and architecture settings.
    model_dir = Path(load_model)
    model_config_file = model_dir / "args.json"
    # NOTE(review): this used to be `model_dir / ""`, i.e. the directory itself,
    # which torch.load cannot read. "best_model.pt" is assumed from the usual
    # scGPT checkpoint layout -- confirm against the contents of `load_model`.
    model_file = model_dir / "best_model.pt"
    vocab_file = model_dir / "vocab.json"

    vocab = GeneVocab.from_file(vocab_file)
    # Make sure every special token (pad/cls/eoc) has an id in the vocabulary.
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    # Flag each dataset gene as present (1) or absent (-1) in the vocabulary.
    pert_data.adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in pert_data.adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
    logger.info(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    genes = pert_data.adata.var["gene_name"].tolist()

    # model: architecture hyper-parameters from the checkpoint config replace
    # the values set in the training settings.
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    logger.info(
        f"Resume model from {model_file}, the model args will override the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]
else:
    # No checkpoint: build a fresh vocabulary from the dataset genes plus the
    # special tokens. (Running this unconditionally would discard the
    # pretrained vocabulary loaded above.)
    genes = pert_data.adata.var["gene_name"].tolist()
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )  # bidirectional lookup [gene <-> int]

# Vocabulary id of every dataset gene, in dataset order; genes missing from the
# vocabulary fall back to the <pad> id.
gene_ids = np.array(
    [vocab[gene] if gene in vocab else vocab["<pad>"] for gene in genes], dtype=int
)
n_genes = len(genes)
scGPT - INFO - match 4399/5060 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from ../save/scGPT_human/, the model args will override the config ../save/scGPT_human/args.json.
# Create the scGPT perturbation model and load pretrained weights into it.
ntokens = len(vocab) # size of vocabulary
# NOTE(review): the constructor arguments and closing parenthesis of this call
# are missing from this copy of the notebook, so the cell does not parse as-is.
# The printed model (Embedding(60697, 512), pert_encoder Embedding(3, 512,
# padding_idx=2), 12 FlashTransformerEncoderLayers, cls_decoder out_features=1)
# suggests it was built from the settings above -- restore the call before running.
model = TransformerGenerator(
# Partial load: keep only checkpoint tensors whose names start with one of
# load_param_prefixs.
if load_param_prefixs is not None and load_model is not None:
# only load params that start with the prefix
model_dict = model.state_dict()
# torch.load without map_location restores tensors onto the device they were
# saved from.
pretrained_dict = torch.load(model_file)
# NOTE(review): the closing "}" of this dict comprehension is missing.
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if any([k.startswith(prefix) for prefix in load_param_prefixs])
# NOTE(review): the next line lost a newline and its `logger.info(f` call; the
# log output shows one "Loading params ..." line per selected tensor. No
# `model_dict.update(pretrained_dict)` / `model.load_state_dict(model_dict)`
# is visible in this branch, so as shown the selected weights are never
# copied into the model -- confirm and restore.
for k, v in pretrained_dict.items():"Loading params {k} with shape {v.shape}")
# Otherwise try to load the whole checkpoint.
elif load_model is not None:
# NOTE(review): a newline and `logger.info(f` are missing between the load and
# the message on this line.
model.load_state_dict(torch.load(model_file))"Loading all model params from {model_file}")
# NOTE(review): the header introducing the lines below is missing (presumably a
# fallback used when the full load fails); confirm the intended structure.
# only load params that are in the model and match the size
model_dict = model.state_dict()
pretrained_dict = torch.load(model_file)
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if k in model_dict and v.shape == model_dict[k].shape
for k, v in pretrained_dict.items():"Loading params {k} with shape {v.shape}")
Using simple batchnorm instead of domain specific batchnorm
scGPT - INFO - Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
scGPT - INFO - Loading params encoder.enc_norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params encoder.enc_norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
scGPT - INFO - Loading params value_encoder.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params value_encoder.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.norm2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.1.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.1.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.1.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.1.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.1.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.1.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.1.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.1.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.1.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.1.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.1.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.1.norm2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.2.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.2.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.2.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.2.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.2.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.2.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.2.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.2.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.2.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.2.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.2.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.2.norm2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.3.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.3.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.3.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.3.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.3.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.3.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.3.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.3.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.3.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.3.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.3.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.3.norm2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.4.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.4.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.4.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.4.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.4.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.4.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.4.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.4.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.4.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.4.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.4.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.4.norm2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.5.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.5.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.5.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.5.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.5.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.5.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.5.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.5.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.5.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.5.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.5.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.5.norm2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.6.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.6.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.6.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.6.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.6.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.6.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.6.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.6.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.6.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.6.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.6.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.6.norm2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.7.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.7.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.7.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.7.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.7.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.7.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.7.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.7.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.7.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.7.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.7.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.7.norm2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.8.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.8.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.8.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.8.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.8.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.8.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.8.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.8.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.8.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.8.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.8.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.8.norm2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.9.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.9.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.9.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.9.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.9.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.9.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.9.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.9.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.9.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.9.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.9.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.9.norm2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.10.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.10.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.10.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.10.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.10.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.10.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.10.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.10.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.10.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.10.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.10.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.10.norm2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.11.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.11.self_attn.Wqkv.bias with shape torch.Size([1536])
scGPT - INFO - Loading params transformer_encoder.layers.11.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.11.self_attn.out_proj.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.11.linear1.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.11.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.11.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.11.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.11.norm1.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.11.norm1.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.11.norm2.weight with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.11.norm2.bias with shape torch.Size([512])
(encoder): GeneEncoder(
(embedding): Embedding(60697, 512, padding_idx=60694)
(enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(value_encoder): ContinuousValueEncoder(
(dropout): Dropout(p=0.2, inplace=False)
(linear1): Linear(in_features=1, out_features=512, bias=True)
(activation): ReLU()
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(pert_encoder): Embedding(3, 512, padding_idx=2)
(bn): BatchNorm1d(512, eps=6.1e-05, momentum=0.1, affine=True, track_running_stats=True)
(transformer_encoder): TransformerEncoder(
(layers): ModuleList(
(0): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(1): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(2): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(3): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(4): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(5): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(6): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(7): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(8): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(9): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(10): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(11): FlashTransformerEncoderLayer(
(self_attn): FlashMHA(
(Wqkv): Linear(in_features=512, out_features=1536, bias=True)
(inner_attn): FlashAttention()
(out_proj): Linear(in_features=512, out_features=512, bias=True)
(linear1): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=512, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
(decoder): ExprDecoder(
(fc): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=512, out_features=1, bias=True)
(cls_decoder): ClsDecoder(
(_decoder): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): ReLU()
(2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(3): Linear(in_features=512, out_features=512, bias=True)
(4): ReLU()
(5): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(out_layer): Linear(in_features=512, out_features=1, bias=True)
(sim): Similarity(
(cos): CosineSimilarity()
(creterion_cce): CrossEntropyLoss()
# Loss functions: masked MSE for expression values, cross-entropy for the
# cell-type classification objective (CLS setting).
criterion = masked_mse_loss
criterion_cls = nn.CrossEntropyLoss()

# Adam with step-wise LR decay (x0.9 every `schedule_interval` scheduler steps)
# and a gradient scaler that is active only when `amp` is True.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=schedule_interval, gamma=0.9
)
scaler = torch.cuda.amp.GradScaler(enabled=amp)
# NOTE(review): many lines of `train` are missing from this copy of the
# notebook: the `train_loader` annotation and closing parenthesis of the
# signature, an `else:` in the gene-selection branch, the `model(...)` call
# arguments, and the backward / gradient-clipping / optimizer-step code around
# `warnings.catch_warnings`. It does not parse as-is.
def train(model: nn.Module, train_loader: -> None:
"""Train the model for one epoch.

Iterates `train_loader`, feeds each batch's selected genes through `model`,
and scores `output_dict["mlm_output"]` against `batch_data.y` with
`criterion` on every selected position. Running loss/MSE totals are
reported every `log_interval` batches. Reads module-level settings
(n_genes, device, include_zero_gene, max_seq_len, gene_ids, amp, criterion,
scaler, scheduler, log_interval) and the global loop variable `epoch`.
"""
total_loss, total_mse = 0.0, 0.0
start_time = time.time()
num_batches = len(train_loader)
for batch, batch_data in enumerate(train_loader):
# NOTE(review): no `batch_data.to(device)` is visible, while input_gene_ids
# below are created on `device` and used to index these tensors -- confirm.
batch_size = len(batch_data.y)
# Column 0 holds the per-gene expression value, column 1 the perturbation flag.
x: torch.Tensor = batch_data.x # (batch_size * n_genes, 2)
ori_gene_values = x[:, 0].view(batch_size, n_genes)
pert_flags = x[:, 1].long().view(batch_size, n_genes)
target_gene_values = batch_data.y # (batch_size, n_genes)
# Choose which gene columns are fed to the model.
if include_zero_gene in ["all", "batch-wise"]:
if include_zero_gene == "all":
input_gene_ids = torch.arange(n_genes, device=device, dtype=torch.long)
# NOTE(review): an `else:` (the batch-wise case, as in `evaluate`) appears to
# be missing before this assignment, along with the closing ")".
input_gene_ids = (
ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
# sample input_gene_id
if len(input_gene_ids) > max_seq_len:
# NOTE(review): the rest of this indexing expression is missing. randperm
# yields positions 0..len-1, not entries of input_gene_ids; in the
# batch-wise case those differ -- confirm whether
# `input_gene_ids[torch.randperm(...)[:max_seq_len]]` was intended.
input_gene_ids = torch.randperm(len(input_gene_ids), device=device)[
input_values = ori_gene_values[:, input_gene_ids]
input_pert_flags = pert_flags[:, input_gene_ids]
target_values = target_gene_values[:, input_gene_ids]
# Translate dataset gene indices to vocabulary ids, one row per sample.
mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1)
# src_key_padding_mask = mapped_input_gene_ids.eq(vocab[pad_token])
# All-False mask: no position is treated as padding. (Closing ")" missing.)
src_key_padding_mask = torch.zeros_like(
input_values, dtype=torch.bool, device=device
with torch.cuda.amp.autocast(enabled=amp):
output_dict = model(
output_values = output_dict["mlm_output"]
masked_positions = torch.ones_like(
input_values, dtype=torch.bool
) # Use all
loss = loss_mse = criterion(output_values, target_values, masked_positions)
with warnings.catch_warnings(record=True) as w:
# NOTE(review): this is a dangling keyword argument; it matches the
# `error_if_nonfinite` parameter of torch.nn.utils.clip_grad_norm_, whose
# call head (and the backward/unscale code before it) is missing. With the
# scaler enabled, non-finite norms are tolerated because GradScaler skips
# steps whose gradients contain inf/NaN.
error_if_nonfinite=False if scaler.is_enabled() else True,
if len(w) > 0:
# NOTE(review): the logger call wrapping this message is missing.
f"Found infinite gradient. This may be caused by the gradient "
f"scaler. The current scale is {scaler.get_scale()}. This warning "
"can be ignored if no longer occurs after autoscaling of the scaler."
# torch.cuda.empty_cache()
total_loss += loss.item()
total_mse += loss_mse.item()
# Periodic report. NOTE(review): the first window spans batches
# 0..log_interval (log_interval + 1 batches) but is divided by log_interval.
if batch % log_interval == 0 and batch > 0:
lr = scheduler.get_last_lr()[0]
ms_per_batch = (time.time() - start_time) * 1000 / log_interval
cur_loss = total_loss / log_interval
cur_mse = total_mse / log_interval
# ppl = math.exp(cur_loss)
# NOTE(review): the logger call wrapping this message is missing.
f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
f"loss {cur_loss:5.2f} | mse {cur_mse:5.2f} |"
total_loss = 0
total_mse = 0
start_time = time.time()
def evaluate(model: nn.Module, val_loader: DataLoader) -> Tuple[float, float]:
    """
    Evaluate the model on the evaluation data.

    Runs the model over every batch of ``val_loader`` without gradients and
    averages the loss and the masked relative error over the number of batches.

    Args:
        model (:class:`torch.nn.Module`): The model to evaluate. It is switched
            to eval mode so dropout is disabled during validation.
        val_loader: Loader yielding graph batches whose ``x`` is
            ``(batch_size * n_genes, 2)`` (expression, perturbation flag) and
            whose ``y`` is ``(batch_size, n_genes)``.

    Returns:
        ``(mean_loss, mean_relative_error)``, each averaged over batches.

    Raises:
        ValueError: If ``val_loader`` is empty, or if ``include_zero_gene`` is
            neither ``"all"`` nor ``"batch-wise"``.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0

    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("val_loader is empty; cannot compute validation metrics.")
    if include_zero_gene not in ["all", "batch-wise"]:
        # Without this check the input tensors below would never be defined.
        raise ValueError(
            f"Unsupported include_zero_gene={include_zero_gene!r}; "
            'expected "all" or "batch-wise".'
        )

    with torch.no_grad():
        for batch_data in val_loader:
            batch_size = len(batch_data.y)
            batch_data.to(device)
            x: torch.Tensor = batch_data.x  # (batch_size * n_genes, 2)
            ori_gene_values = x[:, 0].view(batch_size, n_genes)
            pert_flags = x[:, 1].long().view(batch_size, n_genes)
            target_gene_values = batch_data.y  # (batch_size, n_genes)

            if include_zero_gene == "all":
                input_gene_ids = torch.arange(n_genes, device=device, dtype=torch.long)
            else:  # "batch-wise": genes that are non-zero in at least one cell
                input_gene_ids = (
                    ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
                )

            # Randomly keep at most max_seq_len genes. Index *into*
            # input_gene_ids so a batch-wise subset keeps real gene ids instead
            # of positions 0..len-1 (identical to before in "all" mode).
            if len(input_gene_ids) > max_seq_len:
                keep = torch.randperm(len(input_gene_ids), device=device)[:max_seq_len]
                input_gene_ids = input_gene_ids[keep]

            input_values = ori_gene_values[:, input_gene_ids]
            input_pert_flags = pert_flags[:, input_gene_ids]
            target_values = target_gene_values[:, input_gene_ids]

            mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
            mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1)

            # Every row shares the same gene set, so nothing is padded.
            src_key_padding_mask = torch.zeros_like(
                input_values, dtype=torch.bool, device=input_values.device
            )

            with torch.cuda.amp.autocast(enabled=amp):
                output_dict = model(
                    mapped_input_gene_ids,
                    input_values,
                    input_pert_flags,
                    src_key_padding_mask=src_key_padding_mask,
                    CLS=CLS,
                    CCE=CCE,
                    MVC=MVC,
                    ECS=ECS,
                    do_sample=True,
                )
                output_values = output_dict["mlm_output"]

                # Score every selected gene rather than a masked subset.
                masked_positions = torch.ones_like(
                    input_values, dtype=torch.bool, device=input_values.device
                )
                loss = criterion(output_values, target_values, masked_positions)

            total_loss += loss.item()
            # .item() keeps the accumulator a Python float, not a device tensor.
            total_error += masked_relative_error(
                output_values, target_values, masked_positions
            ).item()

    return total_loss / num_batches, total_error / num_batches
# Fine-tune with early stopping on the validation loss, keeping a copy of the
# best-scoring model and checkpointing the current weights every epoch.
best_val_loss = float("inf")
best_model = None
patience = 0  # consecutive epochs without a validation improvement

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train_loader = pert_data.dataloader["train_loader"]
    valid_loader = pert_data.dataloader["val_loader"]

    train(model, train_loader)
    val_loss, val_mre = evaluate(model, valid_loader)
    elapsed = time.time() - epoch_start_time
    logger.info("-" * 89)
    logger.info(
        f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
        f"valid loss/mse {val_loss:5.4f} |"
    )
    logger.info("-" * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        logger.info(f"Best model with score {best_val_loss:5.4f}")
        patience = 0
    else:
        # Only count epochs that did not improve; reset on improvement above.
        patience += 1
        if patience >= early_stop:
            logger.info(f"Early stop at epoch {epoch}")
            break

    torch.save(model.state_dict(), save_dir / f"model_{epoch}.pt")
    scheduler.step()

if best_model is None:
    # val_loss never beat +inf (e.g. it was NaN every epoch); fall back to the
    # final weights so the evaluation cells below still have a model.
    logger.warning("No epoch improved the validation loss; using the last model.")
    best_model = copy.deepcopy(model)
scGPT - INFO - | epoch 1 | 100/849 batches | lr 0.0001 | ms/batch 349.85 | loss 0.13 | mse 0.13 |
scGPT - INFO - | epoch 1 | 200/849 batches | lr 0.0001 | ms/batch 347.33 | loss 0.09 | mse 0.09 |
scGPT - INFO - | epoch 1 | 300/849 batches | lr 0.0001 | ms/batch 347.91 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 1 | 400/849 batches | lr 0.0001 | ms/batch 347.70 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 1 | 500/849 batches | lr 0.0001 | ms/batch 347.07 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 1 | 600/849 batches | lr 0.0001 | ms/batch 347.23 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 1 | 700/849 batches | lr 0.0001 | ms/batch 347.21 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 1 | 800/849 batches | lr 0.0001 | ms/batch 347.31 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 1 | time: 299.08s | valid loss/mse 0.1373 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 0.1373
scGPT - INFO - | epoch 2 | 100/849 batches | lr 0.0001 | ms/batch 347.63 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 2 | 200/849 batches | lr 0.0001 | ms/batch 347.29 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 2 | 300/849 batches | lr 0.0001 | ms/batch 347.17 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 2 | 400/849 batches | lr 0.0001 | ms/batch 346.59 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 2 | 500/849 batches | lr 0.0001 | ms/batch 347.25 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 2 | 600/849 batches | lr 0.0001 | ms/batch 347.06 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 2 | 700/849 batches | lr 0.0001 | ms/batch 347.62 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 2 | 800/849 batches | lr 0.0001 | ms/batch 347.03 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 2 | time: 298.68s | valid loss/mse 0.1375 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | epoch 3 | 100/849 batches | lr 0.0001 | ms/batch 350.56 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 3 | 200/849 batches | lr 0.0001 | ms/batch 346.59 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 3 | 300/849 batches | lr 0.0001 | ms/batch 346.73 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 3 | 400/849 batches | lr 0.0001 | ms/batch 347.00 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 3 | 500/849 batches | lr 0.0001 | ms/batch 347.14 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 3 | 600/849 batches | lr 0.0001 | ms/batch 347.38 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 3 | 700/849 batches | lr 0.0001 | ms/batch 347.23 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 3 | 800/849 batches | lr 0.0001 | ms/batch 347.14 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 3 | time: 298.94s | valid loss/mse 0.1359 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 0.1359
scGPT - INFO - | epoch 4 | 100/849 batches | lr 0.0001 | ms/batch 350.52 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 4 | 200/849 batches | lr 0.0001 | ms/batch 346.88 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 4 | 300/849 batches | lr 0.0001 | ms/batch 347.00 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 4 | 400/849 batches | lr 0.0001 | ms/batch 347.52 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 4 | 500/849 batches | lr 0.0001 | ms/batch 346.55 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 4 | 600/849 batches | lr 0.0001 | ms/batch 347.25 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 4 | 700/849 batches | lr 0.0001 | ms/batch 347.23 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 4 | 800/849 batches | lr 0.0001 | ms/batch 346.89 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 4 | time: 298.89s | valid loss/mse 0.1354 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 0.1354
scGPT - INFO - | epoch 5 | 100/849 batches | lr 0.0001 | ms/batch 350.21 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 5 | 200/849 batches | lr 0.0001 | ms/batch 347.01 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 5 | 300/849 batches | lr 0.0001 | ms/batch 347.14 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 5 | 400/849 batches | lr 0.0001 | ms/batch 347.01 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 5 | 500/849 batches | lr 0.0001 | ms/batch 346.98 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 5 | 600/849 batches | lr 0.0001 | ms/batch 347.23 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 5 | 700/849 batches | lr 0.0001 | ms/batch 347.07 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 5 | 800/849 batches | lr 0.0001 | ms/batch 347.25 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 5 | time: 298.87s | valid loss/mse 0.1364 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | epoch 6 | 100/849 batches | lr 0.0001 | ms/batch 349.59 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 6 | 200/849 batches | lr 0.0001 | ms/batch 347.04 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 6 | 300/849 batches | lr 0.0001 | ms/batch 347.12 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 6 | 400/849 batches | lr 0.0001 | ms/batch 346.49 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 6 | 500/849 batches | lr 0.0001 | ms/batch 346.99 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 6 | 600/849 batches | lr 0.0001 | ms/batch 347.02 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 6 | 700/849 batches | lr 0.0001 | ms/batch 346.86 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 6 | 800/849 batches | lr 0.0001 | ms/batch 347.34 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 6 | time: 298.77s | valid loss/mse 0.1352 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 0.1352
scGPT - INFO - | epoch 7 | 100/849 batches | lr 0.0001 | ms/batch 349.76 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 7 | 200/849 batches | lr 0.0001 | ms/batch 346.70 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 7 | 300/849 batches | lr 0.0001 | ms/batch 346.65 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 7 | 400/849 batches | lr 0.0001 | ms/batch 347.09 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 7 | 500/849 batches | lr 0.0001 | ms/batch 346.92 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 7 | 600/849 batches | lr 0.0001 | ms/batch 346.92 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 7 | 700/849 batches | lr 0.0001 | ms/batch 347.06 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 7 | 800/849 batches | lr 0.0001 | ms/batch 347.31 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 7 | time: 298.74s | valid loss/mse 0.1368 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | epoch 8 | 100/849 batches | lr 0.0000 | ms/batch 349.62 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 8 | 200/849 batches | lr 0.0000 | ms/batch 346.76 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 8 | 300/849 batches | lr 0.0000 | ms/batch 346.87 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 8 | 400/849 batches | lr 0.0000 | ms/batch 347.17 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 8 | 500/849 batches | lr 0.0000 | ms/batch 347.15 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 8 | 600/849 batches | lr 0.0000 | ms/batch 347.34 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 8 | 700/849 batches | lr 0.0000 | ms/batch 347.14 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 8 | 800/849 batches | lr 0.0000 | ms/batch 346.92 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 8 | time: 298.82s | valid loss/mse 0.1352 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 0.1352
scGPT - INFO - | epoch 9 | 100/849 batches | lr 0.0000 | ms/batch 349.88 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 9 | 200/849 batches | lr 0.0000 | ms/batch 346.84 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 9 | 300/849 batches | lr 0.0000 | ms/batch 346.99 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 9 | 400/849 batches | lr 0.0000 | ms/batch 346.72 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 9 | 500/849 batches | lr 0.0000 | ms/batch 347.14 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 9 | 600/849 batches | lr 0.0000 | ms/batch 347.16 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 9 | 700/849 batches | lr 0.0000 | ms/batch 347.26 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 9 | 800/849 batches | lr 0.0000 | ms/batch 346.73 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 9 | time: 298.79s | valid loss/mse 0.1356 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | epoch 10 | 100/849 batches | lr 0.0000 | ms/batch 349.94 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 10 | 200/849 batches | lr 0.0000 | ms/batch 347.09 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 10 | 300/849 batches | lr 0.0000 | ms/batch 346.88 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 10 | 400/849 batches | lr 0.0000 | ms/batch 347.00 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 10 | 500/849 batches | lr 0.0000 | ms/batch 347.28 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 10 | 600/849 batches | lr 0.0000 | ms/batch 347.24 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 10 | 700/849 batches | lr 0.0000 | ms/batch 347.30 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 10 | 800/849 batches | lr 0.0000 | ms/batch 347.63 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 10 | time: 298.95s | valid loss/mse 0.1333 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 0.1333
scGPT - INFO - | epoch 11 | 100/849 batches | lr 0.0000 | ms/batch 350.28 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 11 | 200/849 batches | lr 0.0000 | ms/batch 347.03 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 11 | 300/849 batches | lr 0.0000 | ms/batch 346.87 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 11 | 400/849 batches | lr 0.0000 | ms/batch 346.88 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 11 | 500/849 batches | lr 0.0000 | ms/batch 346.83 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 11 | 600/849 batches | lr 0.0000 | ms/batch 347.43 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 11 | 700/849 batches | lr 0.0000 | ms/batch 346.82 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 11 | 800/849 batches | lr 0.0000 | ms/batch 346.90 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 11 | time: 298.78s | valid loss/mse 0.1357 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | epoch 12 | 100/849 batches | lr 0.0000 | ms/batch 349.86 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 12 | 200/849 batches | lr 0.0000 | ms/batch 347.24 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 12 | 300/849 batches | lr 0.0000 | ms/batch 347.01 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 12 | 400/849 batches | lr 0.0000 | ms/batch 347.18 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 12 | 500/849 batches | lr 0.0000 | ms/batch 347.16 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 12 | 600/849 batches | lr 0.0000 | ms/batch 346.99 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 12 | 700/849 batches | lr 0.0000 | ms/batch 347.13 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 12 | 800/849 batches | lr 0.0000 | ms/batch 347.15 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 12 | time: 298.85s | valid loss/mse 0.1321 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 0.1321
scGPT - INFO - | epoch 13 | 100/849 batches | lr 0.0000 | ms/batch 349.99 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 13 | 200/849 batches | lr 0.0000 | ms/batch 347.36 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 13 | 300/849 batches | lr 0.0000 | ms/batch 347.23 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 13 | 400/849 batches | lr 0.0000 | ms/batch 347.07 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 13 | 500/849 batches | lr 0.0000 | ms/batch 346.83 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 13 | 600/849 batches | lr 0.0000 | ms/batch 347.54 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 13 | 700/849 batches | lr 0.0000 | ms/batch 347.16 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 13 | 800/849 batches | lr 0.0000 | ms/batch 346.83 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 13 | time: 298.94s | valid loss/mse 0.1378 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | epoch 14 | 100/849 batches | lr 0.0000 | ms/batch 350.03 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 14 | 200/849 batches | lr 0.0000 | ms/batch 347.42 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 14 | 300/849 batches | lr 0.0000 | ms/batch 347.32 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 14 | 400/849 batches | lr 0.0000 | ms/batch 347.37 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 14 | 500/849 batches | lr 0.0000 | ms/batch 347.06 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 14 | 600/849 batches | lr 0.0000 | ms/batch 347.28 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 14 | 700/849 batches | lr 0.0000 | ms/batch 346.85 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 14 | 800/849 batches | lr 0.0000 | ms/batch 346.77 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 14 | time: 298.97s | valid loss/mse 0.1339 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | epoch 15 | 100/849 batches | lr 0.0000 | ms/batch 350.67 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 15 | 200/849 batches | lr 0.0000 | ms/batch 346.56 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 15 | 300/849 batches | lr 0.0000 | ms/batch 347.44 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 15 | 400/849 batches | lr 0.0000 | ms/batch 346.89 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 15 | 500/849 batches | lr 0.0000 | ms/batch 347.05 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 15 | 600/849 batches | lr 0.0000 | ms/batch 346.89 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 15 | 700/849 batches | lr 0.0000 | ms/batch 347.35 | loss 0.08 | mse 0.08 |
scGPT - INFO - | epoch 15 | 800/849 batches | lr 0.0000 | ms/batch 347.19 | loss 0.08 | mse 0.08 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch 15 | time: 298.90s | valid loss/mse 0.1361 |
scGPT - INFO - -----------------------------------------------------------------------------------------
# Persist the best model's weights; a file name is required because save_dir
# is a directory.
torch.save(best_model.state_dict(), save_dir / "best_model.pt")
## Evaluations
def predict(
    model: TransformerGenerator,
    pert_list: List[List[str]],
    pool_size: Optional[int] = None,
) -> Dict[str, np.ndarray]:
    """
    Predict the gene expression values for the given perturbations.

    Args:
        model (:class:`torch.nn.Module`): The model to use for prediction.
        pert_list (:obj:`List[List[str]]`): The perturbations to predict. Each
            entry is a list of genes perturbed together, e.g.
            ``[["FEV"], ["FEV", "SAMD11"]]``.
        pool_size (:obj:`int`, optional): For each perturbation, use this number
            of cells in the control and predict their perturbation results. Report
            the stats of these predictions. If `None`, use all control cells.

    Returns:
        A dict mapping ``"_".join(pert)`` to the mean predicted expression over
        the control cells used for that perturbation.

    Raises:
        ValueError: If a requested gene is not in the perturbation graph.
    """
    adata = pert_data.adata
    ctrl_adata = adata[adata.obs["condition"] == "ctrl"]
    if pool_size is None:
        pool_size = len(ctrl_adata.obs)
    gene_list = pert_data.gene_names.values.tolist()

    # Validate all requested genes before running the model; a set gives O(1)
    # membership checks instead of scanning the gene list each time.
    known_genes = set(gene_list)
    for pert in pert_list:
        for gene in pert:
            if gene not in known_genes:
                raise ValueError(
                    f"The gene {gene!r} is not in the perturbation graph. "
                    "Please select from GEARS.gene_list!"
                )

    model.eval()
    device = next(model.parameters()).device
    results_pred = {}
    with torch.no_grad():
        for pert in pert_list:
            cell_graphs = create_cell_graph_dataset_for_prediction(
                pert, ctrl_adata, gene_list, device, num_samples=pool_size
            )
            loader = DataLoader(cell_graphs, batch_size=eval_batch_size, shuffle=False)
            preds = [
                model.pred_perturb(
                    batch_data, include_zero_gene, gene_ids=gene_ids, amp=amp
                )
                for batch_data in loader
            ]
            preds = torch.cat(preds, dim=0)
            results_pred["_".join(pert)] = np.mean(preds.detach().cpu().numpy(), axis=0)

    return results_pred
def plot_perturbation(
    model: nn.Module,
    query: str,
    save_file: Optional[str] = None,
    pool_size: Optional[int] = None,
) -> None:
    """
    Plot predicted vs. observed expression change for one perturbation.

    For the top-20 non-dropout DE genes of ``query``, draw a boxplot of the
    observed per-cell change over the control mean. Overlay the model's mean
    predicted change as red dots.

    Args:
        model: Model passed to :func:`predict`.
        query: A condition from ``adata.obs["condition"]`` with genes joined by
            ``"+"`` (e.g. ``"FEV+ctrl"``). ``"ctrl"`` parts are dropped when
            building the prediction request.
        save_file: If given, the figure is also written to this path.
        pool_size: Number of control cells used by :func:`predict`.

    Raises:
        ValueError: If ``query`` names no perturbed gene.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt

    sns.set_theme(style="ticks", rc={"axes.facecolor": (0, 0, 0, 0)}, font_scale=1.5)

    adata = pert_data.adata
    gene2idx = pert_data.node_map
    cond2name = dict(adata.obs[["condition", "condition_name"]].values)
    gene_raw2id = dict(zip(adata.var.index.values, adata.var.gene_name.values))

    top_de_raw = adata.uns["top_non_dropout_de_20"][cond2name[query]]
    genes = [gene_raw2id[i] for i in top_de_raw]
    de_idx = [gene2idx[gene] for gene in genes]

    X = adata[adata.obs.condition == query].X
    # Accept both sparse and dense expression matrices.
    truth = (X.toarray() if hasattr(X, "toarray") else np.asarray(X))[:, de_idx]

    # "FEV+ctrl" -> ["FEV"], "A+B" -> ["A", "B"]; predict() keys its result
    # by the "_"-joined gene list.
    pert_genes = [g for g in query.split("+") if g != "ctrl"]
    if not pert_genes:
        raise ValueError(f"Query {query!r} does not name any perturbed gene.")
    pred = predict(model, [pert_genes], pool_size=pool_size)
    pred = pred["_".join(pert_genes)][de_idx]

    # de_idx holds column positions, so select positionally with .iloc
    # (plain [] on a string-indexed Series is a deprecated positional fallback).
    ctrl_means = (
        adata[adata.obs["condition"] == "ctrl"].to_df().mean().iloc[de_idx].values
    )

    pred = pred - ctrl_means
    truth = truth - ctrl_means

    plt.figure(figsize=[16.5, 4.5])
    plt.title(query)
    plt.boxplot(truth, showfliers=False, medianprops=dict(linewidth=0))
    # Boxplot positions are 1-based, so gene i is drawn at x = i + 1.
    plt.scatter(np.arange(1, len(pred) + 1), pred, color="red")
    plt.axhline(0, linestyle="dashed", color="green")

    ax = plt.gca()
    ax.xaxis.set_ticklabels(genes, rotation=90)

    plt.ylabel("Change in Gene Expression over Control", labelpad=10)
    plt.tick_params(axis="x", which="major", pad=5)
    plt.tick_params(axis="y", which="major", pad=5)

    if save_file:
        plt.savefig(save_file, bbox_inches="tight", transparent=False)
# Render one prediction-vs-truth figure per selected perturbation. For raw
# predictions instead, call e.g. predict(best_model, [["FEV"], ["FEV", "SAMD11"]]).
for perturbation in perts_to_plot:
    plot_perturbation(
        best_model,
        perturbation,
        pool_size=300,
        save_file=f"{save_dir}/{perturbation}.png",
    )

def eval_perturb(
    loader: DataLoader, model: TransformerGenerator, device: torch.device
) -> Dict:
    """
    Run model in inference mode using a given data loader.

    Args:
        loader: Loader yielding graph batches with ``pert``, ``y`` and
            ``de_idx`` attributes.
        model: The model whose ``pred_perturb`` produces predictions.
        device: Device the model and each batch are moved to.

    Returns:
        Dict with:

        - ``"pert_cat"``: the perturbation label of each cell.
        - ``"pred"`` / ``"truth"``: all genes, as float64 NumPy arrays.
        - ``"pred_de"`` / ``"truth_de"``: each cell's DE genes, as float64
          NumPy arrays.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    model.to(device)
    pert_cat = []
    pred = []
    truth = []
    pred_de = []
    truth_de = []

    with torch.no_grad():
        for batch in loader:
            batch.to(device)
            pert_cat.extend(batch.pert)
            p = model.pred_perturb(batch, include_zero_gene, gene_ids=gene_ids)
            t = batch.y
            pred.extend(p.cpu())
            truth.extend(t.cpu())

            # Differentially expressed genes of each cell in the batch.
            for row, de_idx in enumerate(batch.de_idx):
                pred_de.append(p[row, de_idx].cpu())
                truth_de.append(t[row, de_idx].cpu())

    if not pred:
        raise ValueError("loader yielded no batches; nothing to evaluate.")

    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    def _to_numpy(tensors):
        return torch.stack(tensors).detach().cpu().numpy().astype(np.float64)

    results = {"pert_cat": np.array(pert_cat)}
    # all genes
    results["pred"] = _to_numpy(pred)
    results["truth"] = _to_numpy(truth)
    # DE genes only
    results["pred_de"] = _to_numpy(pred_de)
    results["truth_de"] = _to_numpy(truth_de)
    return results
# Evaluate the best model on the test split and report overall metrics.
test_loader = pert_data.dataloader["test_loader"]
test_res = eval_perturb(test_loader, best_model, device)
test_metrics, test_pert_res = compute_metrics(test_res)
print(test_metrics)

# save the dicts in json
with open(f"{save_dir}/test_metrics.json", "w") as f:
    json.dump(test_metrics, f)
with open(f"{save_dir}/test_pert_res.json", "w") as f:
    json.dump(test_pert_res, f)

deeper_res = deeper_analysis(pert_data.adata, test_res)
non_dropout_res = non_dropout_analysis(pert_data.adata, test_res)

metrics = ["pearson_delta", "pearson_delta_de"]
metrics_non_dropout = [
    "pearson_delta_top20_de_non_dropout",
    "pearson_top20_de_non_dropout",
]

# Gather per-perturbation scores for each test subgroup: deeper_analysis
# supplies `metrics`, non_dropout_analysis supplies `metrics_non_dropout`.
subgroup_analysis = {}
for name, pert_list in pert_data.subgroup["test_subgroup"].items():
    scores = {m: [] for m in metrics + metrics_non_dropout}
    for pert in pert_list:
        for m in metrics:
            scores[m].append(deeper_res[pert][m])
        for m in metrics_non_dropout:
            scores[m].append(non_dropout_res[pert][m])
    subgroup_analysis[name] = scores

for name, result in subgroup_analysis.items():
    for m, values in result.items():
        # Subgroups with no test perturbations report NaN explicitly instead of
        # triggering numpy's "Mean of empty slice" RuntimeWarning.
        result[m] = float(np.mean(values)) if values else float("nan")
        logger.info("test_" + name + "_" + m + ": " + str(result[m]))
{'mse': 0.00633879785470548, 'mse_de': 0.12994101050959672, 'pearson': 0.9901751059128802, 'pearson_de': 0.9778788141012438}
scGPT - INFO - test_combo_seen0_pearson_delta: nan
scGPT - INFO - test_combo_seen0_pearson_delta_de: nan
scGPT - INFO - test_combo_seen0_pearson_delta_top20_de_non_dropout: nan
scGPT - INFO - test_combo_seen0_pearson_top20_de_non_dropout: nan
scGPT - INFO - test_combo_seen1_pearson_delta: nan
scGPT - INFO - test_combo_seen1_pearson_delta_de: nan
scGPT - INFO - test_combo_seen1_pearson_delta_top20_de_non_dropout: nan
scGPT - INFO - test_combo_seen1_pearson_top20_de_non_dropout: nan
scGPT - INFO - test_combo_seen2_pearson_delta: nan
scGPT - INFO - test_combo_seen2_pearson_delta_de: nan
scGPT - INFO - test_combo_seen2_pearson_delta_top20_de_non_dropout: nan
scGPT - INFO - test_combo_seen2_pearson_top20_de_non_dropout: nan
scGPT - INFO - test_unseen_single_pearson_delta: 0.6155381653776693
scGPT - INFO - test_unseen_single_pearson_delta_de: 0.7955997135623811
scGPT - INFO - test_unseen_single_pearson_delta_top20_de_non_dropout: 0.7941164267404693
scGPT - INFO - test_unseen_single_pearson_top20_de_non_dropout: 0.9749507618278279