diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py
index b1e3699..c625b08 100644
--- a/src/petals/client/inference_session.py
+++ b/src/petals/client/inference_session.py
@@ -43,7 +43,7 @@ class _ServerInferenceSession:
         **metadata,
     ):
         self.config = config
-        self.span, self.span_uids = span, span_uids
+        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
         self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
@@ -283,7 +283,6 @@ class InferenceSession:
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         *block_kwargs: Sequence[Dict[str, torch.Tensor]],
-        **kwargs,
    ) -> torch.Tensor:
        assert not self._closed
        if torch.is_grad_enabled():
@@ -328,7 +327,6 @@ class InferenceSession:
                prompts[server_session.span.start : server_session.span.end],
                *block_kwargs[server_session.span.start : server_session.span.end],
                step_id=step_id,
-                **kwargs,
            )
            server_idx += 1
 
diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py
index a6359bc..4d43f31 100644
--- a/src/petals/client/remote_sequential.py
+++ b/src/petals/client/remote_sequential.py
@@ -53,7 +53,7 @@ class RemoteSequential(nn.Module):
         assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
         if self.active_session is None:
             assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
-            return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args, **kwargs)
+            return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args)
         else:
             return self.active_session.step(inputs, prompts, *args, **kwargs)
 
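
A minimal usage sketch (not part of the patch): after this change, per-block tensors travel positionally through *block_kwargs, one dict per remote block, while the catch-all **kwargs no longer flows through InferenceSession.step; RemoteSequential.forward still accepts kwargs but only forwards them to an active session, never to the autograd path. The `session`, `num_blocks`, and tensor sizes below are illustrative assumptions, not names taken from the diff.

# Hypothetical caller, assuming `session` is an already-open InferenceSession
# spanning `num_blocks` remote blocks (both are assumptions for illustration):
import torch

num_blocks = 4
hidden_size = 4096
inputs = torch.randn(1, 1, hidden_size)  # [batch_size, seq_length, hidden_size]

# One kwargs dict per remote block; blocks that need no extras get an empty dict.
block_kwargs = [{} for _ in range(num_blocks)]

# prompts=None here; per-block extras are unpacked positionally, matching the
# new signature step(inputs, prompts=None, *block_kwargs).
outputs = session.step(inputs, None, *block_kwargs)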