|
|
|
@ -43,7 +43,7 @@ class _ServerInferenceSession:
|
|
|
|
|
**metadata,
|
|
|
|
|
):
|
|
|
|
|
self.config = config
|
|
|
|
|
self.span, self.span_uids = span, span_uids
|
|
|
|
|
self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
|
|
|
|
|
self.num_blocks = len(span_uids)
|
|
|
|
|
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
|
|
|
|
|
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
|
|
|
|
@ -283,7 +283,6 @@ class InferenceSession:
|
|
|
|
|
inputs: torch.Tensor,
|
|
|
|
|
prompts: Optional[torch.Tensor] = None,
|
|
|
|
|
*block_kwargs: Sequence[Dict[str, torch.Tensor]],
|
|
|
|
|
**kwargs,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
assert not self._closed
|
|
|
|
|
if torch.is_grad_enabled():
|
|
|
|
@ -328,7 +327,6 @@ class InferenceSession:
|
|
|
|
|
prompts[server_session.span.start : server_session.span.end],
|
|
|
|
|
*block_kwargs[server_session.span.start : server_session.span.end],
|
|
|
|
|
step_id=step_id,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
server_idx += 1
|
|
|
|
|