# Source code for the ``scgpt.model.grad_reverse`` module.

from typing import Tuple

import torch
from torch.autograd import Function


class GradReverse(Function):
    """Gradient reversal autograd function.

    In the forward pass the input is returned unchanged. In the backward pass
    the incoming gradient is multiplied by ``-lambd``. This lets a layer act
    as the identity on activations while flipping the sign of the gradient.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, lambd: float) -> torch.Tensor:
        """Return ``x`` unchanged and remember ``lambd`` for the backward pass.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            x: Input tensor.
            lambd: Scale applied to the reversed gradient.

        Returns:
            A view of ``x`` with the same shape and values.
        """
        ctx.lambd = lambd
        # Return a view rather than ``x`` itself so the output is a distinct
        # tensor object that autograd can attach this Function's backward to
        # (the usual idiom for identity-style autograd Functions).
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Negate and scale the upstream gradient.

        Args:
            ctx: Autograd context holding ``lambd`` saved in ``forward``.
            grad_output: Gradient of the loss w.r.t. the forward output.

        Returns:
            One gradient per ``forward`` input, in order: ``-lambd * grad_output``
            for ``x``, and ``None`` for ``lambd``. ``lambd`` is annotated as a
            plain float and receives no gradient.
        """
        return grad_output.neg() * ctx.lambd, None
def grad_reverse(x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor:
    """Apply gradient reversal to ``x``.

    The returned tensor has the same values as ``x``. During backpropagation,
    the gradient flowing back into ``x`` is multiplied by ``-lambd``
    (see ``GradReverse.backward``).

    Args:
        x: Input tensor.
        lambd: Scale of the reversed gradient; defaults to ``1.0``.

    Returns:
        Tensor with the same values as ``x``.
    """
    reversed_x = GradReverse.apply(x, lambd)
    return reversed_x