check for none

client
justheuristic 2 years ago
parent 471e47c0f5
commit 7903bd8f9f

@ -29,6 +29,11 @@ class RemoteTransformerBlock(RemoteExpert):
def stub(self) -> StubBase:
return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
def forward(self, inputs: torch.Tensor, **kwargs):
for k, v in kwargs.items():
assert v is None, f"Extra keyword arguments are not yet supported (got {k} = {v})"
return super().forward(inputs)
def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
"""Initialize a new inference session with the specified remote server"""
_ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker

@ -16,7 +16,9 @@ logger = get_logger(__file__)
class RemoteSequential(nn.Sequential):
"""A sequence of transformer blocks hosted by the swarm"""
"""
A sequence of transformer blocks hosted by the swarm.
"""
def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: Optional[str] = None, max_retries: int = 3):
logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")

Loading…
Cancel
Save