|
|
|
@@ -112,9 +112,6 @@ class TransformerBackend(ModuleBackend):
|
|
|
|
|
def backward(
|
|
|
|
|
self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs
|
|
|
|
|
) -> Tuple[Union[torch.Tensor, Any], ...]:
|
|
|
|
|
args = [x.detach().requires_grad_(True) if x.is_floating_point() else x.detach() for x in args]
|
|
|
|
|
# ^-- TODO remove this AFTER PR#467; make sure args are passed properly and retain requires_grad
|
|
|
|
|
assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor))
|
|
|
|
|
with self._peft_module.using_adapter(active_adapter), torch.enable_grad():
|
|
|
|
|
(outputs,) = self.module(*args, **kwargs)
|
|
|
|
|
assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape
|
|
|
|
|