|
|
|
@ -87,8 +87,8 @@ class _ServerInferenceSession:
|
|
|
|
|
self,
|
|
|
|
|
inputs: torch.Tensor,
|
|
|
|
|
prompts: Optional[torch.Tensor] = None,
|
|
|
|
|
hypo_ids: Optional[torch.Tensor] = None,
|
|
|
|
|
*block_kwargs: Dict[str, Any],
|
|
|
|
|
hypo_ids: Optional[torch.Tensor] = None,
|
|
|
|
|
step_id: str,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
"""
|
|
|
|
@ -283,6 +283,7 @@ class InferenceSession:
|
|
|
|
|
inputs: torch.Tensor,
|
|
|
|
|
prompts: Optional[torch.Tensor] = None,
|
|
|
|
|
*block_kwargs: Sequence[Dict[str, torch.Tensor]],
|
|
|
|
|
hypo_ids: Optional[torch.Tensor] = None,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
assert not self._closed
|
|
|
|
|
if torch.is_grad_enabled():
|
|
|
|
@ -327,6 +328,7 @@ class InferenceSession:
|
|
|
|
|
prompts[server_session.span.start : server_session.span.end],
|
|
|
|
|
*block_kwargs[server_session.span.start : server_session.span.end],
|
|
|
|
|
step_id=step_id,
|
|
|
|
|
hypo_ids=hypo_ids,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
server_idx += 1
|
|
|
|
|