Fine-tuning a Pre-trained Model for Batch Integration

In this tutorial, we demonstrate how to fine-tune a pre-trained model on a new dataset for the batch integration task. We use the PBMC 10K dataset as an example and fine-tune the pre-trained whole-body model.

We summarize the fine-tuning pipeline in the following steps, which can serve as a general recipe for fine-tuning on integration tasks and beyond:

1. Specify hyper-parameter setup for integration task

2. Load and pre-process data

3. Load the pre-trained scGPT model

4. Fine-tune scGPT with task-specific objectives

5. Evaluate fine-tuned scGPT
[1]:
import copy
import gc
import json
import os
from pathlib import Path
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings

import torch
from anndata import AnnData
import scanpy as sc
import scvi
import numpy as np
import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)


sys.path.insert(0, "../")
import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, eval_scib_metrics, load_pretrained

sc.set_figure_params(figsize=(4, 4))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')
Global seed set to 0

Step 1: Specify hyper-parameter setup for integration task

Here we provide some hyper-parameter recommendations for the integration task. Note that the PBMC 10K dataset contains multiple batches to be integrated. Therefore, in addition to the default gene-expression modelling objectives, we also turn on the ECS (elastic cell similarity), DAR (domain adaptation via reverse backpropagation) and DSBN (domain-specific batch normalization) objectives specifically to facilitate batch integration.
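These objectives are added together into a single loss inside train() further on. As a hedged bookkeeping illustration, the hypothetical helper total_loss below reproduces that weighting on plain numbers: the masked-value MSE, the optional zero-probability (Bernoulli) terms, the two GEPC terms, the ECS term scaled by a fixed factor of 10, and the DAR batch-classification term scaled by dab_weight. The individual loss values normally come from the model; the sketch does not model any gradient reversal behind the DAR term.

```python
from typing import Dict


def total_loss(
    terms: Dict[str, float],
    gepc: bool = True,
    explicit_zero_prob: bool = True,
    ecs_thres: float = 0.8,
    dab_weight: float = 1.0,
) -> float:
    """Hypothetical helper: combine per-objective losses the way train() does."""
    loss = terms["mse"]  # masked gene-expression prediction (always on)
    if explicit_zero_prob:
        loss += terms["nzlp"]  # Bernoulli negative log-likelihood for zero values
    if gepc:
        loss += terms["mvc"]  # GEPC: predict expression from the cell embedding
    if gepc and explicit_zero_prob:
        loss += terms["mvc_nzlp"]  # zero-probability term of the GEPC head
    if ecs_thres > 0:
        loss += 10 * terms["ecs"]  # ECS is up-weighted by a fixed factor of 10
    loss += dab_weight * terms["dab"]  # DAR: batch-label classification loss
    return loss


example = {"mse": 1.0, "nzlp": 0.5, "mvc": 1.0, "mvc_nzlp": 0.5, "ecs": 0.1, "dab": 2.0}
print(total_loss(example))  # 1.0 + 0.5 + 1.0 + 0.5 + 10 * 0.1 + 2.0 = 6.0
```

Setting ecs_thres to 0.0 or GEPC to False drops the corresponding terms, mirroring the conditionals in train().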

[2]:
hyperparameter_defaults = dict(
    seed=42,
    dataset_name="PBMC_10K",  # Dataset name
    do_train=True,  # Whether to update model parameters during training
    load_model="../save/scGPT_human",  # Path to pre-trained model
    GEPC=True,  # Gene expression modelling for cell objective
    ecs_thres=0.8,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
    dab_weight=1.0,  # DAR objective weight for batch correction
    mask_ratio=0.4,  # Default mask ratio
    epochs=15,  # Default number of epochs for fine-tuning
    n_bins=51,  # Default number of bins for value binning in data pre-processing
    lr=1e-4,  # Default learning rate for fine-tuning
    batch_size=64,  # Default batch size for fine-tuning
    layer_size=128,
    nlayers=4,
    nhead=4,  # if load_model is set, layer_size, nlayers and nhead are ignored in favour of the model's args.json
    dropout=0.2,  # Default dropout rate during model fine-tuning
    schedule_ratio=0.9,  # Multiplicative learning-rate decay applied after each epoch
    save_eval_interval=5,  # Save and evaluate the best model every 5 epochs
    log_interval=100,  # Log training metrics every 100 batches
    fast_transformer=True,  # Use scGPT's fast transformer implementation
    pre_norm=False,  # Whether to apply LayerNorm before (pre-norm) rather than after each sublayer
    amp=True,  # Default setting: Automatic Mixed Precision
)
run = wandb.init(
    config=hyperparameter_defaults,
    project="scGPT",
    reinit=True,
    settings=wandb.Settings(start_method="fork"),
)
config = wandb.config
print(config)

set_seed(config.seed)

{'seed': 42, 'dataset_name': 'PBMC_10K', 'do_train': True, 'load_model': '../save/scGPT_human', 'GEPC': True, 'ecs_thres': 0.8, 'dab_weight': 1.0, 'mask_ratio': 0.4, 'epochs': 15, 'n_bins': 51, 'lr': 0.0001, 'batch_size': 64, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True}
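The ecs_thres setting controls the elastic cell similarity (ECS) objective. A minimal sketch of one plausible formulation, assuming ECS operates on cosine similarities between cell embeddings within a mini-batch: self-pairs and negative similarities are set to zero, and each pair contributes 1 - (sim - threshold)^2, so minimizing the mean pushes similar pairs (above the threshold) toward 1 and dissimilar pairs toward 0. The function ecs_loss is hypothetical, and the exact scGPT implementation may differ.

```python
import math
from typing import List, Sequence


def ecs_loss(embeddings: Sequence[Sequence[float]], threshold: float = 0.8) -> float:
    """Sketch of an elastic-cell-similarity loss over one mini-batch."""
    # L2-normalize every cell embedding so dot products are cosine similarities
    normed: List[List[float]] = []
    for emb in embeddings:
        norm = math.sqrt(sum(x * x for x in emb))
        normed.append([x / norm for x in emb])

    n = len(normed)
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i == j:
                sim = 0.0  # self-similarity is masked out
            else:
                sim = sum(a * b for a, b in zip(normed[i], normed[j]))
                sim = max(sim, 0.0)  # only positive similarities are kept
            # minimizing this pushes sim away from the threshold in either direction
            total += 1.0 - (sim - threshold) ** 2
    return total / (n * n)


print(round(ecs_loss([[1.0, 0.0], [1.0, 0.0]]), 4))  # 0.66
```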
[3]:
# settings for input and preprocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config.mask_ratio
mask_value = -1
pad_value = -2
n_input_bins = config.n_bins

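n_hvg = 1200  # number of highly variable genes
max_seq_len = n_hvg + 1  # +1 for the <cls> token
per_seq_batch_sample = True  # draw each mini-batch from a single sequencing batch
DSBN = True  # Domain-specific batch normalization
explicit_zero_prob = True  # whether to explicitly model the probability of zero expression (Bernoulli)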
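With DSBN = True, normalization statistics are kept per batch (domain) instead of being shared across all cells. A minimal single-feature sketch of the idea in training mode and without a learnable scale and shift (the model log later reports affine=False): each value is standardized with the mean and variance of its own domain. The function is hypothetical and ignores the running statistics used at inference time.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def domain_specific_batchnorm(
    values: Sequence[float], domains: Sequence[int], eps: float = 1e-5
) -> List[float]:
    """Normalize each value with the statistics of its own domain (no affine)."""
    groups: Dict[int, List[float]] = defaultdict(list)
    for v, d in zip(values, domains):
        groups[d].append(v)

    stats: Dict[int, Tuple[float, float]] = {}
    for d, vs in groups.items():
        mean = sum(vs) / len(vs)
        var = sum((v - mean) ** 2 for v in vs) / len(vs)  # biased variance, as in batchnorm training
        stats[d] = (mean, var)

    return [(v - stats[d][0]) / math.sqrt(stats[d][1] + eps) for v, d in zip(values, domains)]


print(domain_specific_batchnorm([1.0, 3.0, 10.0, 20.0], [0, 0, 1, 1]))  # close to [-1, 1, -1, 1]
```

A shared batchnorm would instead center all four values on the global mean, leaving the batch offset between domains 0 and 1 in the normalized features.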
[4]:
dataset_name = config.dataset_name
save_dir = Path(f"./save/dev_{dataset_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"save to {save_dir}")
logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")
save to save/dev_PBMC_10K-Jul03-15-53

Step 2: Load and pre-process data

2.1 Load the PBMC 10K data

[5]:
if dataset_name == "PBMC_10K":
    adata = scvi.data.pbmc_dataset()  # 11990 × 3346
    ori_batch_col = "batch"
    adata.obs["celltype"] = adata.obs["str_labels"].astype("category")
    adata.var = adata.var.set_index("gene_symbols")
    data_is_raw = True

# make the batch category column
adata.obs["str_batch"] = adata.obs[ori_batch_col].astype(str)
batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
adata.obs["batch_id"] = batch_id_labels
adata.var["gene_name"] = adata.var.index.tolist()
INFO     File data/gene_info_pbmc.csv already downloaded
INFO     File data/pbmc_metadata.pickle already downloaded
INFO     File data/pbmc8k/filtered_gene_bc_matrices.tar.gz already downloaded
INFO     Extracting tar file
INFO     Removing extracted data at data/pbmc8k/filtered_gene_bc_matrices
INFO     File data/pbmc4k/filtered_gene_bc_matrices.tar.gz already downloaded
INFO     Extracting tar file
INFO     Removing extracted data at data/pbmc4k/filtered_gene_bc_matrices

2.2 Cross-check gene set with the pre-trained model

We keep only the genes shared between the dataset and the pre-trained model's vocabulary for further fine-tuning.
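The filtering amounts to a membership test against the pre-trained vocabulary. A minimal sketch with a hypothetical toy vocabulary (a plain dict standing in for GeneVocab) that produces a message in the same form as the log line below. Genes are matched by exact name, so the data and the model must use the same naming convention (for example gene symbols on both sides).

```python
from typing import Dict, List, Tuple


def match_genes(genes: List[str], vocab: Dict[str, int]) -> Tuple[List[str], str]:
    """Keep genes present in the vocabulary and report the match rate."""
    kept = [g for g in genes if g in vocab]
    report = f"match {len(kept)}/{len(genes)} genes in vocabulary of size {len(vocab)}."
    return kept, report


toy_vocab = {"<pad>": 0, "CD3D": 1, "MS4A1": 2, "LYZ": 3}  # hypothetical toy vocabulary
kept, report = match_genes(["CD3D", "LYZ", "NOT_A_GENE"], toy_vocab)
print(report)  # match 2/3 genes in vocabulary of size 4.
```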

[6]:
if config.load_model is not None:
    model_dir = Path(config.load_model)
    model_config_file = model_dir / "args.json"
    model_file = model_dir / "best_model.pt"
    vocab_file = model_dir / "vocab.json"

    vocab = GeneVocab.from_file(vocab_file)
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    logger.info(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    adata = adata[:, adata.var["id_in_vocab"] >= 0]

    # model
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    logger.info(
        f"Resume model from {model_file}, the model args will be overriden by the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]
else:
    embsize = config.layer_size
    nhead = config.nhead
    nlayers = config.nlayers
    d_hid = config.layer_size
scGPT - INFO - match 3256/3346 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from ../save/scGPT_human/best_model.pt, the model args will be overriden by the config ../save/scGPT_human/args.json.

2.3 Pre-process the data

We follow the standard pipeline of depth normalization, log normalization, and highly variable gene (HVG) selection for data pre-processing. We then apply value binning to obtain the relative expression of each HVG.
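Binning makes values comparable across cells with different sequencing depth, because each cell's non-zero values are ranked against that cell's own distribution. A minimal pure-Python sketch, assuming the bin edges are per-cell quantiles of the non-zero values and that bin 0 is reserved for zero expression; scGPT's Preprocessor may compute edges and handle ties differently, and the function names are hypothetical.

```python
from bisect import bisect_left
from typing import List, Sequence


def _quantile(sorted_vals: Sequence[float], q: float) -> float:
    # linear interpolation between neighbouring order statistics
    pos = q * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (pos - lo)


def bin_cell(row: Sequence[float], n_bins: int) -> List[int]:
    """Map one cell's expression values to integer bins; zeros stay in bin 0."""
    nonzero = sorted(v for v in row if v > 0)
    if not nonzero:
        return [0] * len(row)
    n_levels = n_bins - 1  # bins 1..n_bins-1 hold non-zero values
    # interior quantile edges computed within this cell only
    edges = [_quantile(nonzero, k / n_levels) for k in range(1, n_levels)]
    return [0 if v <= 0 else 1 + bisect_left(edges, v) for v in row]


print(bin_cell([0.0, 1.0, 2.0, 3.0, 4.0], n_bins=3))  # [0, 1, 1, 2, 2]
```

With n_bins = 51 as configured above, non-zero values would fall into bins 1 to 50.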

[7]:
# set up the preprocessor, use the args to config the workflow
preprocessor = Preprocessor(
    use_key="X",  # the raw data to use ("X" refers to adata.X)
    filter_gene_by_counts=3,  # 1. filter genes by counts (threshold 3)
    filter_cell_by_counts=False,  # 2. whether to filter cells by counts
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=n_hvg,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=config.n_bins,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key="str_batch" if dataset_name != "heart_cell" else None)

scGPT - INFO - Filtering genes by counts ...
scGPT - INFO - Filtering cells by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Log1p transforming ...
scGPT - INFO - Subsetting highly variable genes ...
scGPT - INFO - Binning data ...
[8]:
if per_seq_batch_sample:
    # sort the adata by batch_id in advance
    adata_sorted = adata[adata.obs["batch_id"].argsort()].copy()

2.4 Tokenize the input data for model fine-tuning

[9]:
input_layer_key = "X_binned"
all_counts = (
    adata.layers[input_layer_key].toarray()
    if issparse(adata.layers[input_layer_key])
    else adata.layers[input_layer_key]
)
genes = adata.var["gene_name"].tolist()

celltypes_labels = adata.obs["celltype"].tolist()  # cell-type names (strings)
num_types = len(set(celltypes_labels))
celltypes_labels = np.array(celltypes_labels)

batch_ids = adata.obs["batch_id"].tolist()
num_batch_types = len(set(batch_ids))
batch_ids = np.array(batch_ids)

(
    train_data,
    valid_data,
    train_celltype_labels,
    valid_celltype_labels,
    train_batch_labels,
    valid_batch_labels,
) = train_test_split(
    all_counts, celltypes_labels, batch_ids, test_size=0.1, shuffle=True
)

[10]:
if config.load_model is None:
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )  # bidirectional lookup [gene <-> int]
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(vocab(genes), dtype=int)
[11]:
tokenized_train = tokenize_and_pad_batch(
    train_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=True,
)
tokenized_valid = tokenize_and_pad_batch(
    valid_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,
    include_zero_gene=True,
)
logger.info(
    f"train set number of samples: {tokenized_train['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_train['genes'].shape[1]}"
)
logger.info(
    f"valid set number of samples: {tokenized_valid['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_valid['genes'].shape[1]}"
)


scGPT - INFO - train set number of samples: 10791,
         feature length: 1201
scGPT - INFO - valid set number of samples: 1199,
         feature length: 1201
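prepare_data re-draws the mask at every epoch through random_mask_value. A minimal sketch, assuming a fixed fraction of the non-padding positions in each cell is masked with the count rounded down. If none of the 1201 tokens is padding, mask_ratio = 0.4 masks int(480.4) = 480 tokens, i.e. 480/1201 ≈ 0.3997, which matches the ratio printed in the training log. mask_row is a hypothetical helper, not the scGPT function.

```python
import random
from typing import List


def mask_row(
    row: List[float], mask_ratio: float, mask_value: float, pad_value: float, rng: random.Random
) -> List[float]:
    """Replace a fixed fraction of non-padding values with mask_value."""
    out = list(row)
    candidates = [i for i, v in enumerate(row) if v != pad_value]
    n_mask = int(len(candidates) * mask_ratio)  # rounded down
    for i in rng.sample(candidates, n_mask):
        out[i] = mask_value
    return out


rng = random.Random(0)
masked = mask_row([3.0, 1.0, 0.0, 5.0, 2.0, 4.0, 1.0, 0.0, -2.0, -2.0], 0.4, -1, -2, rng)
print(masked.count(-1))  # 3, i.e. int(8 * 0.4); the two padding values are untouched
print(int(1201 * 0.4) / 1201)  # about 0.3997
```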
[12]:
def prepare_data(sort_seq_batch=False) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    masked_values_train = random_mask_value(
        tokenized_train["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    masked_values_valid = random_mask_value(
        tokenized_valid["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    print(
        f"random masking at epoch {epoch:3d}, ratio of masked values in train: ",
        f"{(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}",
    )

    input_gene_ids_train, input_gene_ids_valid = (
        tokenized_train["genes"],
        tokenized_valid["genes"],
    )
    input_values_train, input_values_valid = masked_values_train, masked_values_valid
    target_values_train, target_values_valid = (
        tokenized_train["values"],
        tokenized_valid["values"],
    )

    tensor_batch_labels_train = torch.from_numpy(train_batch_labels).long()
    tensor_batch_labels_valid = torch.from_numpy(valid_batch_labels).long()

    if sort_seq_batch:
        train_sort_ids = np.argsort(train_batch_labels)
        input_gene_ids_train = input_gene_ids_train[train_sort_ids]
        input_values_train = input_values_train[train_sort_ids]
        target_values_train = target_values_train[train_sort_ids]
        tensor_batch_labels_train = tensor_batch_labels_train[train_sort_ids]

        valid_sort_ids = np.argsort(valid_batch_labels)
        input_gene_ids_valid = input_gene_ids_valid[valid_sort_ids]
        input_values_valid = input_values_valid[valid_sort_ids]
        target_values_valid = target_values_valid[valid_sort_ids]
        tensor_batch_labels_valid = tensor_batch_labels_valid[valid_sort_ids]

    train_data_pt = {
        "gene_ids": input_gene_ids_train,
        "values": input_values_train,
        "target_values": target_values_train,
        "batch_labels": tensor_batch_labels_train,
    }
    valid_data_pt = {
        "gene_ids": input_gene_ids_valid,
        "values": input_values_valid,
        "target_values": target_values_valid,
        "batch_labels": tensor_batch_labels_valid,
    }

    return train_data_pt, valid_data_pt


# dataset
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


# data_loader
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    intra_domain_shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
) -> DataLoader:
    dataset = SeqDataset(data_pt)

    if per_seq_batch_sample:
        # find the indices of samples in each seq batch
        subsets = []
        batch_labels_array = data_pt["batch_labels"].numpy()
        for batch_label in np.unique(batch_labels_array):
            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
            subsets.append(batch_indices)
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=SubsetsBatchSampler(
                subsets,
                batch_size,
                intra_subset_shuffle=intra_domain_shuffle,
                inter_subset_shuffle=shuffle,
                drop_last=drop_last,
            ),
            num_workers=num_workers,
            pin_memory=True,
        )
        return data_loader

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data_loader
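With per_seq_batch_sample = True, every mini-batch is drawn from a single sequencing batch, so each mini-batch contains cells from one domain only. A minimal sketch of that sampling logic, assuming SubsetsBatchSampler chunks each subset into mini-batches and can optionally shuffle cells within a subset and the order of the resulting mini-batches; subset_batches is hypothetical, not the scGPT class.

```python
import random
from typing import List, Sequence


def subset_batches(
    subsets: Sequence[Sequence[int]],
    batch_size: int,
    intra_subset_shuffle: bool = True,
    inter_subset_shuffle: bool = False,
    drop_last: bool = False,
    seed: int = 0,
) -> List[List[int]]:
    """Return mini-batches of indices, each drawn from exactly one subset."""
    rng = random.Random(seed)
    batches: List[List[int]] = []
    for subset in subsets:
        idx = list(subset)
        if intra_subset_shuffle:
            rng.shuffle(idx)  # shuffle cells within one sequencing batch
        for start in range(0, len(idx), batch_size):
            chunk = idx[start:start + batch_size]
            if drop_last and len(chunk) < batch_size:
                continue  # skip the incomplete last chunk of this subset
            batches.append(chunk)
    if inter_subset_shuffle:
        rng.shuffle(batches)  # interleave mini-batches from different subsets
    return batches


print(subset_batches([[0, 1, 2, 3, 4], [5, 6, 7]], batch_size=2))
```

The training loader in this notebook uses shuffle=False and intra_domain_shuffle=True, which in this sketch corresponds to inter_subset_shuffle=False and intra_subset_shuffle=True.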

Step 3: Load the pre-trained scGPT model

[13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    vocab=vocab,
    dropout=config.dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=config.GEPC,
    do_dab=True,
    use_batch_labels=True,
    num_batch_labels=num_batch_types,
    domain_spec_batchnorm=DSBN,
    n_input_bins=n_input_bins,
    ecs_threshold=config.ecs_thres,
    explicit_zero_prob=explicit_zero_prob,
    use_fast_transformer=config.fast_transformer,
    pre_norm=config.pre_norm,
)
if config.load_model is not None:
    # map_location lets the checkpoint load on a CPU-only machine as well
    load_pretrained(model, torch.load(model_file, map_location=device), verbose=False)

model.to(device)
wandb.watch(model)
Use domain specific batchnorm with affine=False
[14]:
criterion = masked_mse_loss
criterion_dab = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=config.lr, eps=1e-4 if config.amp else 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=config.schedule_ratio)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)
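StepLR with a step size of 1 multiplies the learning rate by schedule_ratio after every epoch (scheduler.step() is called at the end of each epoch loop), so epoch e trains at lr * 0.9^(e-1). A small worked sketch with a hypothetical helper lr_at_epoch; it also shows that the 05.4f format used in the training log rounds the decayed rate, so about 5.9e-05 at epoch 6 still prints as 0.0001.

```python
def lr_at_epoch(epoch: int, base_lr: float = 1e-4, gamma: float = 0.9) -> float:
    """Learning rate during a 1-indexed epoch with StepLR(step_size=1, gamma=gamma)."""
    return base_lr * gamma ** (epoch - 1)


for epoch in (1, 6, 15):
    lr = lr_at_epoch(epoch)
    # scientific notation vs. the rounded format used in the training log
    print(epoch, f"{lr:.2e}", f"{lr:05.4f}")
```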
[15]:
def train(model: nn.Module, loader: DataLoader) -> None:
    """
    Train the model for one epoch.
    """
    model.train()
    total_loss, total_mse, total_gepc = 0.0, 0.0, 0.0
    total_error = 0.0
    log_interval = config.log_interval
    start_time = time.time()

    num_batches = len(loader)
    for batch, batch_data in enumerate(loader):
        input_gene_ids = batch_data["gene_ids"].to(device)
        input_values = batch_data["values"].to(device)
        target_values = batch_data["target_values"].to(device)
        batch_labels = batch_data["batch_labels"].to(device)

        src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
        with torch.cuda.amp.autocast(enabled=config.amp):
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=batch_labels if DSBN else None,
                MVC=config.GEPC,
                ECS=config.ecs_thres > 0,
            )

            masked_positions = input_values.eq(mask_value)  # the positions to predict
            loss = loss_mse = criterion(
                output_dict["mlm_output"], target_values, masked_positions
            )
            metrics_to_log = {"train/mse": loss_mse.item()}
            if explicit_zero_prob:
                loss_zero_log_prob = criterion_neg_log_bernoulli(
                    output_dict["mlm_zero_probs"], target_values, masked_positions
                )
                loss = loss + loss_zero_log_prob
                metrics_to_log.update({"train/nzlp": loss_zero_log_prob.item()})
            if config.GEPC:
                loss_gepc = criterion(
                    output_dict["mvc_output"], target_values, masked_positions
                )
                loss = loss + loss_gepc
                metrics_to_log.update({"train/mvc": loss_gepc.item()})
            if config.GEPC and explicit_zero_prob:
                loss_gepc_zero_log_prob = criterion_neg_log_bernoulli(
                    output_dict["mvc_zero_probs"], target_values, masked_positions
                )
                loss = loss + loss_gepc_zero_log_prob
                metrics_to_log.update(
                    {"train/mvc_nzlp": loss_gepc_zero_log_prob.item()}
                )
            if config.ecs_thres > 0:
                loss_ecs = 10 * output_dict["loss_ecs"]
                loss = loss + loss_ecs
                metrics_to_log.update({"train/ecs": loss_ecs.item()})
            loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)
            loss = loss + config.dab_weight * loss_dab
            metrics_to_log.update({"train/dab": loss_dab.item()})

        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                error_if_nonfinite=False if scaler.is_enabled() else True,
            )
            if len(w) > 0:
                logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()

        wandb.log(metrics_to_log)

        with torch.no_grad():
            mre = masked_relative_error(
                output_dict["mlm_output"], target_values, masked_positions
            )

        total_loss += loss.item()
        total_mse += loss_mse.item()
        total_gepc += loss_gepc.item() if config.GEPC else 0.0
        total_error += mre.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_mse = total_mse / log_interval
            cur_gepc = total_gepc / log_interval if config.GEPC else 0.0
            cur_error = total_error / log_interval
            # ppl = math.exp(cur_loss)
            logger.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | mse {cur_mse:5.2f} | mre {cur_error:5.2f} |"
                + (f"gepc {cur_gepc:5.2f} |" if config.GEPC else "")
            )
            total_loss = 0
            total_mse = 0
            total_gepc = 0
            total_error = 0
            start_time = time.time()


def define_wandb_metrcis():
    wandb.define_metric("valid/mse", summary="min", step_metric="epoch")
    wandb.define_metric("valid/mre", summary="min", step_metric="epoch")
    wandb.define_metric("valid/dab", summary="min", step_metric="epoch")
    wandb.define_metric("valid/sum_mse_dab", summary="min", step_metric="epoch")
    wandb.define_metric("test/avg_bio", summary="max")


def evaluate(model: nn.Module, loader: DataLoader) -> Tuple[float, float]:
    """
    Evaluate the model on the evaluation data.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            batch_labels = batch_data["batch_labels"].to(device)

            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=batch_labels if DSBN else None,
                )
                output_values = output_dict["mlm_output"]

                masked_positions = input_values.eq(mask_value)
                loss = criterion(output_values, target_values, masked_positions)
                loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

            total_loss += loss.item() * len(input_gene_ids)
            total_error += masked_relative_error(
                output_values, target_values, masked_positions
            ).item() * len(input_gene_ids)
            total_dab += loss_dab.item() * len(input_gene_ids)
            total_num += len(input_gene_ids)

    wandb.log(
        {
            "valid/mse": total_loss / total_num,
            "valid/mre": total_error / total_num,
            "valid/dab": total_dab / total_num,
            "valid/sum_mse_dab": (total_loss + config.dab_weight * total_dab)
            / total_num,
            "epoch": epoch,
        },
    )

    return total_loss / total_num, total_error / total_num


def eval_testdata(
    model: nn.Module,
    adata_t: AnnData,
    include_types: List[str] = ["cls"],
) -> Optional[Dict]:
    """evaluate the model on test dataset of adata_t"""
    model.eval()

    # copy adata_t to avoid reusing results previously computed and stored in adata_t
    adata_t = adata_t.copy()

    all_counts = (
        adata_t.layers[input_layer_key].toarray()
        if issparse(adata_t.layers[input_layer_key])
        else adata_t.layers[input_layer_key]
    )

    celltypes_labels = adata_t.obs["celltype"].tolist()
    celltypes_labels = np.array(celltypes_labels)

    batch_ids = adata_t.obs["batch_id"].tolist()
    batch_ids = np.array(batch_ids)

    # Evaluate cls cell embeddings
    if "cls" in include_types:
        logger.info("Evaluating cls cell embeddings")
        tokenized_all = tokenize_and_pad_batch(
            all_counts,
            gene_ids,
            max_len=max_seq_len,
            vocab=vocab,
            pad_token=pad_token,
            pad_value=pad_value,
            append_cls=True,  # append <cls> token at the beginning
            include_zero_gene=True,
        )
        all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
        src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.amp):
            cell_embeddings = model.encode_batch(
                all_gene_ids,
                all_values.float(),
                src_key_padding_mask=src_key_padding_mask,
                batch_size=config.batch_size,
                batch_labels=torch.from_numpy(batch_ids).long() if DSBN else None,
                time_step=0,
                return_np=True,
            )
        cell_embeddings = cell_embeddings / np.linalg.norm(
            cell_embeddings, axis=1, keepdims=True
        )

        adata_t.obsm["X_scGPT"] = cell_embeddings

        results = {}
        try:
            results = eval_scib_metrics(adata_t)
        except Exception as e:
            traceback.print_exc()
            logger.error(e)

        sc.pp.neighbors(adata_t, use_rep="X_scGPT")
        sc.tl.umap(adata_t, min_dist=0.3)
        fig = sc.pl.umap(
            adata_t,
            color=["str_batch"],
            title=[f"batch, avg_bio = {results.get('avg_bio', 0.0):.4f}"],
            frameon=False,
            return_fig=True,
            show=False,
        )

        results["batch_umap"] = fig

        fig = sc.pl.umap(
            adata_t,
            color=["celltype"],
            title=[
                f"celltype, avg_bio = {results.get('avg_bio', 0.0):.4f}",
            ],
            frameon=False,
            return_fig=True,
            show=False,
        )

        results["celltype_umap"] = fig

    if len(include_types) == 1:
        return results
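train() and evaluate() score predictions only at masked positions. A minimal sketch of both metrics, assuming masked_relative_error divides the absolute error by the target plus a small epsilon (the epsilon value here is an assumption). Under that assumption, masked positions whose target is bin 0 (zero expression) yield enormous relative errors, which would explain the mre values in the millions in the training log while the MSE stays moderate.

```python
from typing import Sequence


def masked_mse(pred: Sequence[float], target: Sequence[float], mask: Sequence[bool]) -> float:
    """Mean squared error over masked positions only."""
    sq = [(p - t) ** 2 for p, t, m in zip(pred, target, mask) if m]
    return sum(sq) / len(sq)


def masked_rel_error(
    pred: Sequence[float], target: Sequence[float], mask: Sequence[bool], eps: float = 1e-6
) -> float:
    """Mean relative error over masked positions; eps avoids division by zero."""
    rel = [abs(p - t) / (t + eps) for p, t, m in zip(pred, target, mask) if m]
    return sum(rel) / len(rel)


pred, target, mask = [1.0, 2.0, 3.0], [1.0, 0.0, 5.0], [True, True, False]
print(masked_mse(pred, target, mask))  # 2.0
print(masked_rel_error(pred, target, mask))  # about 1e6: the zero target dominates
```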

Step 4: Fine-tune scGPT with task-specific objectives

[16]:
best_val_loss = float("inf")
best_avg_bio = 0.0
best_model = None
define_wandb_metrcis()

for epoch in range(1, config.epochs + 1):
    epoch_start_time = time.time()
    train_data_pt, valid_data_pt = prepare_data(sort_seq_batch=per_seq_batch_sample)
    train_loader = prepare_dataloader(
        train_data_pt,
        batch_size=config.batch_size,
        shuffle=False,
        intra_domain_shuffle=True,
        drop_last=False,
    )
    valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=config.batch_size,
        shuffle=False,
        intra_domain_shuffle=False,
        drop_last=False,
    )

    if config.do_train:
        train(
            model,
            loader=train_loader,
        )
    val_loss, val_mre = evaluate(
        model,
        loader=valid_loader,
    )
    elapsed = time.time() - epoch_start_time
    logger.info("-" * 89)
    logger.info(
        f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
        f"valid loss/mse {val_loss:5.4f} | mre {val_mre:5.4f}"
    )
    logger.info("-" * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        logger.info(f"Best model with score {best_val_loss:5.4f}")

    if epoch % config.save_eval_interval == 0 or epoch == config.epochs:
        logger.info(f"Saving model to {save_dir}")
        torch.save(best_model.state_dict(), save_dir / f"model_e{best_model_epoch}.pt")

        # eval on testdata
        results = eval_testdata(
            best_model,
            adata_t=adata_sorted if per_seq_batch_sample else adata,
            include_types=["cls"],
        )
        results["batch_umap"].savefig(
            save_dir / f"embeddings_batch_umap[cls]_e{best_model_epoch}.png", dpi=300
        )

        results["celltype_umap"].savefig(
            save_dir / f"embeddings_celltype_umap[cls]_e{best_model_epoch}.png", dpi=300
        )
        metrics_to_log = {"test/" + k: v for k, v in results.items()}
        metrics_to_log["test/batch_umap"] = wandb.Image(
            str(save_dir / f"embeddings_batch_umap[cls]_e{best_model_epoch}.png"),
            caption=f"batch avg_bio epoch {best_model_epoch}",
        )

        metrics_to_log["test/celltype_umap"] = wandb.Image(
            str(save_dir / f"embeddings_celltype_umap[cls]_e{best_model_epoch}.png"),
            caption=f"celltype avg_bio epoch {best_model_epoch}",
        )
        metrics_to_log["test/best_model_epoch"] = best_model_epoch
        wandb.log(metrics_to_log)
        wandb.log({"avg_bio": results.get("avg_bio", 0.0)})

    scheduler.step()
random masking at epoch   1, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch   1 | 100/170 batches | lr 0.0001 | ms/batch 757.40 | loss 173.33 | mse 89.35 | mre 2663176.96 |gepc 73.21 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   1 | time: 130.90s | valid loss/mse 75.9452 | mre 4443419.7672
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 75.9452
random masking at epoch   2, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch   2 | 100/170 batches | lr 0.0001 | ms/batch 747.07 | loss 116.37 | mse 51.26 | mre 1768156.19 |gepc 53.94 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   2 | time: 129.46s | valid loss/mse 50.5663 | mre 1523694.5974
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 50.5663
random masking at epoch   3, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch   3 | 100/170 batches | lr 0.0001 | ms/batch 742.33 | loss 110.88 | mse 48.25 | mre 1595595.39 |gepc 50.59 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   3 | time: 128.22s | valid loss/mse 47.3394 | mre 1329758.4306
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 47.3394
random masking at epoch   4, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch   4 | 100/170 batches | lr 0.0001 | ms/batch 745.60 | loss 109.07 | mse 47.46 | mre 1555602.24 |gepc 50.33 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   4 | time: 129.36s | valid loss/mse 45.9131 | mre 1213080.1915
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 45.9131
random masking at epoch   5, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch   5 | 100/170 batches | lr 0.0001 | ms/batch 748.42 | loss 105.71 | mse 47.35 | mre 1546922.04 |gepc 49.55 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   5 | time: 129.56s | valid loss/mse 46.9950 | mre 1717378.3617
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Saving model to save/dev_PBMC_10K-Jul03-15-53
scGPT - INFO - Evaluating cls cell embeddings
100%|██████████| 188/188 [00:54<00:00,  3.43it/s]
NMI...
ARI...
Silhouette score...
PC regression...
Graph connectivity...
scGPT - INFO -                                   0
NMI_cluster/label          0.814257
ARI_cluster/label          0.843118
ASW_label                  0.699138
ASW_label/batch            0.955014
PCR_batch                  0.825616
cell_cycle_conservation         NaN
isolated_label_F1               NaN
isolated_label_silhouette       NaN
graph_conn                 0.873379
kBET                            NaN
iLISI                           NaN
cLISI                           NaN
hvg_overlap                     NaN
trajectory                      NaN
scGPT - INFO - Biological Conservation Metrics:
ASW (cell-type): 0.6991, graph cLISI: nan, isolated label silhouette: nan,
Batch Effect Removal Metrics:
PCR_batch: 0.8256, ASW (batch): 0.9550, graph connectivity: 0.8734, graph iLISI: nan
random masking at epoch   6, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch   6 | 100/170 batches | lr 0.0001 | ms/batch 766.20 | loss 103.49 | mse 47.02 | mre 1527726.00 |gepc 48.65 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   6 | time: 132.67s | valid loss/mse 45.6797 | mre 1421589.9451
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 45.6797
random masking at epoch   7, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch   7 | 100/170 batches | lr 0.0001 | ms/batch 758.39 | loss 102.51 | mse 46.71 | mre 1518484.73 |gepc 48.05 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   7 | time: 132.07s | valid loss/mse 46.0452 | mre 1774720.6303
scGPT - INFO - -----------------------------------------------------------------------------------------
random masking at epoch   8, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch   8 | 100/170 batches | lr 0.0000 | ms/batch 756.79 | loss 102.26 | mse 46.71 | mre 1507389.60 |gepc 47.85 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   8 | time: 131.08s | valid loss/mse 46.2436 | mre 1736326.6960
scGPT - INFO - -----------------------------------------------------------------------------------------
random masking at epoch   9, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch   9 | 100/170 batches | lr 0.0000 | ms/batch 755.87 | loss 101.76 | mse 46.54 | mre 1504583.20 |gepc 47.90 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   9 | time: 131.23s | valid loss/mse 46.6282 | mre 1929727.1859
scGPT - INFO - -----------------------------------------------------------------------------------------
random masking at epoch  10, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch  10 | 100/170 batches | lr 0.0000 | ms/batch 750.91 | loss 101.39 | mse 46.51 | mre 1526292.84 |gepc 47.80 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  10 | time: 130.18s | valid loss/mse 45.2987 | mre 1240932.0140
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 45.2987
scGPT - INFO - Saving model to save/dev_PBMC_10K-Jul03-15-53
scGPT - INFO - Evaluating cls cell embeddings
100%|██████████| 188/188 [00:53<00:00,  3.49it/s]
NMI...
ARI...
Silhouette score...
PC regression...
Graph connectivity...
scGPT - INFO -                                   0
NMI_cluster/label          0.858342
ARI_cluster/label          0.892640
ASW_label                  0.745845
ASW_label/batch            0.947944
PCR_batch                  0.636703
cell_cycle_conservation         NaN
isolated_label_F1               NaN
isolated_label_silhouette       NaN
graph_conn                 0.895191
kBET                            NaN
iLISI                           NaN
cLISI                           NaN
hvg_overlap                     NaN
trajectory                      NaN
scGPT - INFO - Biological Conservation Metrics:
ASW (cell-type): 0.7458, graph cLISI: nan, isolated label silhouette: nan,
Batch Effect Removal Metrics:
PCR_batch: 0.6367, ASW (batch): 0.9479, graph connectivity: 0.8952, graph iLISI: nan
random masking at epoch  11, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch  11 | 100/170 batches | lr 0.0000 | ms/batch 755.50 | loss 100.61 | mse 46.27 | mre 1480725.00 |gepc 47.29 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  11 | time: 130.97s | valid loss/mse 45.3457 | mre 1603972.0365
scGPT - INFO - -----------------------------------------------------------------------------------------
random masking at epoch  12, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch  12 | 100/170 batches | lr 0.0000 | ms/batch 746.86 | loss 100.49 | mse 46.08 | mre 1485683.62 |gepc 47.37 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  12 | time: 130.04s | valid loss/mse 46.0835 | mre 1862575.2738
scGPT - INFO - -----------------------------------------------------------------------------------------
random masking at epoch  13, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch  13 | 100/170 batches | lr 0.0000 | ms/batch 754.29 | loss 100.68 | mse 46.31 | mre 1510157.08 |gepc 47.34 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  13 | time: 131.26s | valid loss/mse 45.9947 | mre 1604693.7741
scGPT - INFO - -----------------------------------------------------------------------------------------
random masking at epoch  14, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch  14 | 100/170 batches | lr 0.0000 | ms/batch 752.90 | loss 100.04 | mse 46.10 | mre 1473999.56 |gepc 47.03 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  14 | time: 130.39s | valid loss/mse 45.3608 | mre 1330498.1018
scGPT - INFO - -----------------------------------------------------------------------------------------
random masking at epoch  15, ratio of masked values in train:  0.3997
scGPT - INFO - | epoch  15 | 100/170 batches | lr 0.0000 | ms/batch 747.97 | loss 100.17 | mse 46.15 | mre 1473452.98 |gepc 47.09 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  15 | time: 129.55s | valid loss/mse 44.9956 | mre 1290994.5793
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 44.9956
scGPT - INFO - Saving model to save/dev_PBMC_10K-Jul03-15-53
scGPT - INFO - Evaluating cls cell embeddings
100%|██████████| 188/188 [00:54<00:00,  3.48it/s]
NMI...
ARI...
Silhouette score...
PC regression...
Graph connectivity...
scGPT - INFO -                                   0
NMI_cluster/label          0.855717
ARI_cluster/label          0.884772
ASW_label                  0.726386
ASW_label/batch            0.949340
PCR_batch                  0.525635
cell_cycle_conservation         NaN
isolated_label_F1               NaN
isolated_label_silhouette       NaN
graph_conn                 0.939571
kBET                            NaN
iLISI                           NaN
cLISI                           NaN
hvg_overlap                     NaN
trajectory                      NaN
scGPT - INFO - Biological Conservation Metrics:
ASW (cell-type): 0.7264, graph cLISI: nan, isolated label silhouette: nan,
Batch Effect Removal Metrics:
PCR_batch: 0.5256, ASW (batch): 0.9493, graph connectivity: 0.9396, graph iLISI: nan
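In the log above, the `lr` field reads 0.0001 up to epoch 7 and 0.0000 from epoch 8 onward, yet the training losses keep changing. That pattern fits four-decimal display rounding of a decaying learning rate better than the rate actually reaching zero. The sketch below assumes a starting rate of 1e-4 multiplied by 0.9 once per epoch (a per-epoch step decay). This section does not confirm those values. Under that assumption, the printed value flips to 0.0000 at exactly epoch 8. `lr_at_epoch` is a hypothetical helper; check the hyper-parameter cell for the real settings.

```python
def lr_at_epoch(epoch: int, base_lr: float = 1e-4, gamma: float = 0.9) -> float:
    """Learning rate during `epoch` (1-indexed), assuming one scheduler step per epoch.

    base_lr and gamma are assumptions, not values read from this log.
    """
    return base_lr * gamma ** (epoch - 1)


for epoch in range(1, 16):
    lr = lr_at_epoch(epoch)
    # The training log appears to format lr with 4 decimals, like f"{lr:.4f}".
    print(f"epoch {epoch:3d} | lr {lr:.4f} | exact {lr:.2e}")
```

If you want the log to show the decay, a format such as `{lr:.2e}` keeps small learning rates visible.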
[17]:
# save the best model
torch.save(best_model.state_dict(), save_dir / "best_model.pt")
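The `best_model` saved here appears to be the checkpoint with the lowest validation loss seen so far. Each "Best model with score" line in the training log coincides with a new minimum of `valid loss/mse`. The sketch below replays that selection rule on the logged losses. Epoch 1's value, 75.9452, is taken from the "Best model with score" line just before epoch 2 and is presumed to belong to epoch 1. The sketch also assumes evaluation every 5 epochs plus the last one, which matches the "Evaluating cls cell embeddings" blocks after epochs 5, 10 and 15. `replay_best_model` is a hypothetical helper, not an scGPT function.

```python
from typing import Dict, List, Tuple

# Validation loss/mse per epoch, copied from the "end of epoch" log lines.
VALID_LOSS: Dict[int, float] = {
    1: 75.9452, 2: 50.5663, 3: 47.3394, 4: 45.9131, 5: 46.9950,
    6: 45.6797, 7: 46.0452, 8: 46.2436, 9: 46.6282, 10: 45.2987,
    11: 45.3457, 12: 46.0835, 13: 45.9947, 14: 45.3608, 15: 44.9956,
}


def replay_best_model(
    valid_loss: Dict[int, float], eval_interval: int = 5
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Return (epochs that set a new best, [(eval_epoch, best_epoch_at_that_eval)])."""
    best_loss = float("inf")
    best_epoch = -1
    improved: List[int] = []
    evals: List[Tuple[int, int]] = []
    last_epoch = max(valid_loss)
    for epoch in sorted(valid_loss):
        if valid_loss[epoch] < best_loss:
            # Corresponds to a "Best model with score ..." line; a real training
            # loop would also keep a copy of the model weights here.
            best_loss = valid_loss[epoch]
            best_epoch = epoch
            improved.append(epoch)
        if epoch % eval_interval == 0 or epoch == last_epoch:
            # Periodic save + evaluation of the best model so far, not the current one.
            evals.append((epoch, best_epoch))
    return improved, evals


improved, evals = replay_best_model(VALID_LOSS)
print(improved)  # [1, 2, 3, 4, 6, 10, 15]
print(evals)     # [(5, 4), (10, 10), (15, 15)]
```

The replay also recovers the `test/best_model_epoch` value of 15 in the run summary.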
[18]:
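# log the saved weights to W&B as a model artifact
artifact = wandb.Artifact("best_model", type="model")
model_path = os.path.join(save_dir, "best_model.pt")
artifact.add_file(model_path)
run.log_artifact(artifact)

# finish the W&B run, then free memory held by unreferenced objects
run.finish()
wandb.finish()
gc.collect()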
Waiting for W&B process to finish... (success).

Run history:

avg_bio                 ▁█▇
epoch                   ▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
test/ARI_cluster/label  ▁█▇
test/ASW_label          ▁█▅
test/ASW_label/batch    █▁▂
test/NMI_cluster/label  ▁██
test/PCR_batch          █▄▁
test/avg_bio            ▁█▇
test/best_model_epoch   ▁▅█
test/graph_conn         ▁▃█
train/dab               ▁▁▁▁█▃▁▂▃▁▂▁▄▂▁▃▁▁▁▁▃▂▁▃▂▁▂▁▃▂▁▂▂▁▂▁▃▂▁▂
train/ecs               ████████▇▅▄▄▂▃▂▃▃▂▃▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mse               █▄▃▂▃▂▂▁▂▂▂▂▂▂▂▁▂▁▁▂▂▁▂▁▂▂▁▂▂▂▂▁▂▂▁▂▁▂▂▁
train/mvc               █▆▅▃▅▃▃▂▄▃▃▃▅▄▃▂▂▂▂▂▂▂▂▁▃▃▁▂▃▂▃▂▂▃▂▃▂▃▂▂
train/mvc_nzlp          ▄▂▃▁▂▃▇▃▇▇▂▇▃▇█▂▇▇▁▇▂▆▇▁▇█▁▇▁▇▇▁▇▇▁▇▁▇▇▁
train/nzlp              █▅▄▃▃▃▃▂▃▃▂▂▂▂▂▂▂▁▁▂▁▁▂▁▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁
valid/dab               █▄▅▂▂▂▂▁▁▁▁▁▁▁▁
valid/mre               █▂▁▁▂▁▂▂▃▁▂▂▂▁▁
valid/mse               █▂▂▁▁▁▁▁▁▁▁▁▁▁▁
valid/sum_mse_dab       █▂▂▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:

avg_bio                 0.82229
epoch                   15
test/ARI_cluster/label  0.88477
test/ASW_label          0.72639
test/ASW_label/batch    0.94934
test/NMI_cluster/label  0.85572
test/PCR_batch          0.52563
test/best_model_epoch   15
test/graph_conn         0.93957
train/dab               0.72778
train/ecs               4.44531
train/mse               43.72255
train/mvc               44.7052
train/mvc_nzlp          0.32442
train/nzlp              0.2438

Synced charmed-blaze-8: https://wandb.ai/scformer/scGPT/runs/vlmktqyd
Synced 5 W&B file(s), 6 media file(s), 1 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20230703_155312-vlmktqyd/logs
[18]:
92911
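The run summary reports `avg_bio` = 0.82229. To the printed precision, this equals the mean of `NMI_cluster/label`, `ARI_cluster/label` and `ASW_label` from the final evaluation. That suggests `avg_bio` averages those three biological-conservation scores. The sketch recomputes it for all three evaluations in the log, using values copied from the printed metric tables. The `EVALS` layout and the `avg_bio` helper are illustrative reconstructions, not scGPT's implementation.

```python
from statistics import mean
from typing import Dict

# Metrics printed by the three evaluations in the training log (after epochs 5, 10, 15).
EVALS: Dict[int, Dict[str, float]] = {
    5: {"NMI_cluster/label": 0.814257, "ARI_cluster/label": 0.843118, "ASW_label": 0.699138},
    10: {"NMI_cluster/label": 0.858342, "ARI_cluster/label": 0.892640, "ASW_label": 0.745845},
    15: {"NMI_cluster/label": 0.855717, "ARI_cluster/label": 0.884772, "ASW_label": 0.726386},
}

BIO_KEYS = ("NMI_cluster/label", "ARI_cluster/label", "ASW_label")


def avg_bio(metrics: Dict[str, float]) -> float:
    """Mean of NMI, ARI and cell-type ASW (assumed definition, inferred from the log)."""
    return mean(metrics[key] for key in BIO_KEYS)


for epoch, metrics in EVALS.items():
    print(f"epoch {epoch:2d}: avg_bio = {avg_bio(metrics):.5f}")
# epoch  5: avg_bio = 0.78550
# epoch 10: avg_bio = 0.83228
# epoch 15: avg_bio = 0.82229
```

In this run, the epoch-10 evaluation gives the highest `avg_bio`, which matches the peak in the `avg_bio` sparkline. However, epoch 15 had the lowest validation loss. Selecting checkpoints by validation loss therefore does not necessarily maximize the integration scores.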