diff --git a/src/client/remote_block.py b/src/client/remote_block.py index 12ba83c..c02f0a2 100644 --- a/src/client/remote_block.py +++ b/src/client/remote_block.py @@ -29,6 +29,11 @@ class RemoteTransformerBlock(RemoteExpert): def stub(self) -> StubBase: return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id) + def forward(self, inputs: torch.Tensor, **kwargs): + for k, v in kwargs.items(): + assert v is None, f"Extra keyword arguments are not yet supported (got {k} = {v})" + return super().forward(inputs) + def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession: """Initialize a new inference session with the specified remote server""" _ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker diff --git a/src/client/remote_sequential.py b/src/client/remote_sequential.py index 8cb49dd..69d58be 100644 --- a/src/client/remote_sequential.py +++ b/src/client/remote_sequential.py @@ -16,7 +16,9 @@ logger = get_logger(__file__) class RemoteSequential(nn.Sequential): - """A sequence of transformer blocks hosted by the swarm""" + """ + A sequence of transformer blocks hosted by the swarm. + """ def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: Optional[str] = None, max_retries: int = 3): logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")