WIP, switching to another PR

partial_rollback
Your Name 9 months ago
parent 09e9da6eb1
commit 84ebd57105

@@ -112,9 +112,6 @@ class TransformerBackend(ModuleBackend):
    def backward(
        self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs
    ) -> Tuple[Union[torch.Tensor, Any], ...]:
        args = [x.detach().requires_grad_(True) if x.is_floating_point() else x.detach() for x in args]
        # ^-- TODO remove this AFTER PR#467; make sure args are passed properly and retain requires_grad
        assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor))
        with self._peft_module.using_adapter(active_adapter), torch.enable_grad():
            (outputs,) = self.module(*args, **kwargs)
            assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape
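
Note (illustrative aside, not part of the diff): the backward method above relies on a detach-then-requires_grad_ pattern to rebuild a local autograd graph before re-running the block under torch.enable_grad(). A minimal standalone sketch of that pattern, with hypothetical names (rerun_backward, module) and assuming the module returns a 1-tuple as in the hunk:

import torch

def rerun_backward(module, grad_outputs, *args):
    # Detach inputs from any upstream graph; floating-point tensors get a fresh
    # requires_grad flag so gradients can flow back to them locally.
    args = [x.detach().requires_grad_(True) if x.is_floating_point() else x.detach() for x in args]
    with torch.enable_grad():
        (outputs,) = module(*args)  # assumes a module that returns a 1-tuple
        assert outputs.shape == grad_outputs.shape
        torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,))
    # The detached inputs are autograd leaves, so their gradients land in .grad
    return [x.grad if x.is_floating_point() else None for x in args]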

@@ -92,7 +92,7 @@ async def run_rpc_backward(
        requested_backends, flat_tensors, args_structure
    )
    # Cast inputs & grad outputs to backend dtype
    assert hidden_states.ndim == 3
    assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
    num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
    hidden_states = hidden_states.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
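
Note (illustrative aside, not part of the diff): the casts above move activations into the first backend's dtype and incoming gradients into the last backend's dtype, while num_tokens follows from the [batch, seq_len, hidden] layout guaranteed by the ndim == 3 assert. A hedged standalone sketch of just that bookkeeping (cast_for_backward is a hypothetical helper):

import torch

def cast_for_backward(hidden_states, grad_outputs, first_dtype, last_dtype):
    assert hidden_states.ndim == 3  # [batch, seq_len, hidden]
    num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
    hidden_states = hidden_states.to(first_dtype)  # inputs enter in the first block's dtype
    grad_outputs = grad_outputs.to(last_dtype)     # grads are w.r.t. the last block's output
    return hidden_states, grad_outputs, num_tokens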

@@ -9,7 +9,7 @@ from petals.server.handler import CACHE_TOKENS_AVAILABLE
from test_utils import *


@pytest.mark.forked
@pytest.mark.skip
def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
    config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"]  # PeerID from server2.id
