@@ -84,12 +84,7 @@ class _ServerInferenceSession:
                 break  # this message means "done sending"
 
     def step(
-        self,
-        inputs: torch.Tensor,
-        prompts: Optional[torch.Tensor] = None,
-        hypo_ids: Optional[torch.Tensor] = None,
-        *,
-        step_id: str,
+        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -114,21 +109,6 @@ class _ServerInferenceSession:
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
-        if prompts is None or is_dummy(prompts):
-            prompts = DUMMY
-        else:
-            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
-            assert prompts.shape[0] == self.num_blocks
-            assert prompts.shape[1] in (inputs.shape[0], 1)
-            assert prompts.shape[2] <= inputs.shape[1]
-            assert prompts.shape[3] == inputs.shape[2]
-
-        if hypo_ids is None or is_dummy(hypo_ids):
-            hypo_ids = DUMMY_INT64
-        else:
-            assert len(hypo_ids) == len(inputs)
-            assert hypo_ids.dtype == torch.int64
-
         # serialize inputs and put them into the queue
         input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
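Both hunks above rely on the dummy-tensor convention: `_ServerInferenceSession.step` no longer normalizes or validates its arguments and instead trusts its caller to pass either concrete tensors or zero-element placeholders. A minimal sketch of that convention, assuming definitions in the style of `petals.utils.misc` (the real definitions live outside this diff and may differ):

```python
import torch

# Assumed placeholder convention (an assumption in the style of
# petals.utils.misc, not shown in this diff):
DUMMY = torch.empty(0)                           # stands in for "no prompts"
DUMMY_INT64 = torch.empty(0, dtype=torch.int64)  # stands in for "no hypo_ids"

def is_dummy(tensor: torch.Tensor) -> bool:
    # A zero-element tensor marks an argument to be ignored downstream.
    return tensor.numel() == 0
```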
@@ -275,7 +255,9 @@ class InferenceSession:
         assert not self._closed and not self._server_sessions
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(
+        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -285,11 +267,21 @@ class InferenceSession:
         else:
             assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
             assert prompts.shape[0] == self.num_blocks
             assert prompts.shape[1] in (inputs.shape[0], 1)
             assert prompts.shape[2] <= inputs.shape[1]
             assert prompts.shape[3] == inputs.shape[2]
 
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY_INT64
+        else:
+            assert len(hypo_ids) == len(inputs)
+            assert hypo_ids.dtype == torch.int64
+
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
         prompts = prompts.cpu()
+        hypo_ids = hypo_ids.cpu()
         step_id = str(uuid.uuid4())
 
         n_input_tokens = inputs.shape[1]
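The new `hypo_ids` checks enforce a simple caller contract: one int64 entry per row of `inputs`, naming the hypothesis each row continues. A self-contained example of a batch that satisfies these asserts (all shapes here are illustrative, not taken from the diff):

```python
import torch

# Illustrative shapes only: 3 beams, 1 new token each, hidden size 8.
batch_size, n_new_tokens, hidden_size = 3, 1, 8
inputs = torch.randn(batch_size, n_new_tokens, hidden_size)

# hypo_ids[i] = index of the previous hypothesis that beam i continues.
hypo_ids = torch.tensor([0, 0, 2], dtype=torch.int64)

# The same checks performed by InferenceSession.step above:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64
```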
@@ -310,7 +302,7 @@ class InferenceSession:
 
             server_session = self._server_sessions[server_idx]
             inputs = server_session.step(
-                inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
             )
 
             server_idx += 1
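Forwarding `hypo_ids` to every `server_session.step` call lets each server re-index its attention cache when beam search reorders or duplicates hypotheses. A toy sketch of that re-indexing, assuming (this is an assumption about server-side behavior, not code from this diff) the server gathers cache rows by hypothesis index:

```python
import torch

# Toy cache: [batch, seq_len, hidden] with one row per hypothesis.
cache = torch.arange(3 * 2 * 4, dtype=torch.float32).reshape(3, 2, 4)

# Beam search kept hypothesis 1 twice and hypothesis 0 once:
hypo_ids = torch.tensor([1, 0, 1], dtype=torch.int64)

# Assumed server-side behavior: select (and duplicate) cache rows so that
# row i now holds the history of the hypothesis that beam i continues.
cache = cache[hypo_ids]
assert cache.shape == (3, 2, 4)
```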