Mirror of https://github.com/bigscience-workshop/petals (synced 2024-10-31 09:20:41 +00:00)
Commit 721f7d2db3: "unbreak everything"
Parent commit: 3bffcde0fe
@@ -43,7 +43,7 @@ class _ServerInferenceSession:
         **metadata,
     ):
         self.config = config
-        self.span, self.span_uids = span, span_uids
+        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
         self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
@@ -283,7 +283,6 @@ class InferenceSession:
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
-        *block_kwargs: Sequence[Dict[str, torch.Tensor]],
         **kwargs,
     ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
@@ -328,7 +327,6 @@ class InferenceSession:
                 prompts[server_session.span.start : server_session.span.end],
-                *block_kwargs[server_session.span.start : server_session.span.end],
                 step_id=step_id,
                 **kwargs,
             )

             server_idx += 1
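The two hunks above drop the per-block *block_kwargs argument from InferenceSession.step, leaving prompts (and any remaining **kwargs) as the per-server inputs. The sketch below only illustrates the slicing pattern that remains: each server owns a contiguous span of blocks and receives the matching slice of the deep prompts. The Span stand-in, the tensor shapes, and the two-server layout are illustrative assumptions, not part of this commit.

import torch

# Deep prompts: one prompt tensor per transformer block (illustrative shapes).
num_blocks, batch_size, prefix_len, hidden_size = 24, 1, 4, 1024
prompts = torch.zeros(num_blocks, batch_size, prefix_len, hidden_size)

class Span:  # stand-in for the span objects used above (start/end block indices)
    def __init__(self, start: int, end: int):
        self.start, self.end = start, end

# Two hypothetical servers covering blocks 0..11 and 12..23.
for span in (Span(0, 12), Span(12, 24)):
    per_server_prompts = prompts[span.start : span.end]
    assert per_server_prompts.shape[0] == span.end - span.start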
@@ -53,7 +53,7 @@ class RemoteSequential
         assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
         if self.active_session is None:
             assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
-            return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args, **kwargs)
+            return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args)
         else:
             return self.active_session.step(inputs, prompts, *args, **kwargs)
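The RemoteSequential hunk stops forwarding **kwargs into _RemoteSequentialAutogradFunction.apply. A plausible motivation (an assumption, not stated in the commit) is that torch.autograd.Function.apply expects its arguments positionally, and the assert above already guarantees every extra kwarg is None, so nothing is lost by dropping them. A minimal toy Function showing the positional-only calling convention:

import torch

class _Scale(torch.autograd.Function):
    """Toy autograd function, used only to illustrate the .apply() calling convention."""

    @staticmethod
    def forward(ctx, inputs, scale):
        ctx.scale = scale
        return inputs * scale

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient w.r.t. inputs; None for the non-tensor scale argument.
        return grad_output * ctx.scale, None

x = torch.randn(2, 3, requires_grad=True)
y = _Scale.apply(x, 2.0)   # arguments are passed positionally
y.sum().backward()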