|
|
@ -29,6 +29,11 @@ class RemoteTransformerBlock(RemoteExpert):
|
|
|
|
def stub(self) -> StubBase:
|
|
|
|
def stub(self) -> StubBase:
|
|
|
|
return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
|
|
|
|
return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, inputs: torch.Tensor, **kwargs):
|
|
|
|
|
|
|
|
for k, v in kwargs.items():
|
|
|
|
|
|
|
|
assert v is None, f"Extra keyword arguments are not yet supported (got {k} = {v})"
|
|
|
|
|
|
|
|
return super().forward(inputs)
|
|
|
|
|
|
|
|
|
|
|
|
def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
|
|
|
|
def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
|
|
|
|
"""Initialize a new inference session with the specified remote server"""
|
|
|
|
"""Initialize a new inference session with the specified remote server"""
|
|
|
|
_ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker
|
|
|
|
_ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker
|
|
|
|