@@ -4,7 +4,7 @@ import asyncio
 import itertools
 import time
 import uuid
-from typing import AsyncIterator, List, Optional, Sequence, Tuple
+from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
@@ -43,7 +43,7 @@ class _ServerInferenceSession:
         **metadata,
     ):
         self.config = config
-        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
+        self.span, self.span_uids = span, span_uids
+        self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
@@ -67,7 +67,6 @@ class _ServerInferenceSession:
         **metadata,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
-        # TODO YOZH you don't need rpc info here
        stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
@@ -89,7 +88,7 @@ class _ServerInferenceSession:
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-        *,
+        *block_kwargs: Dict[str, Any],
         step_id: str,
     ) -> torch.Tensor:
         """
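
Note on the hunk above: the bare keyword-only marker `*` becomes a variadic `*block_kwargs` parameter, so any positional dicts passed after `hypo_ids` are collected as per-block keyword arguments, while `step_id` stays keyword-only. A minimal sketch of the resulting call shape (the `session` object, masks, and kwargs contents are hypothetical, not from this diff):

    # Hypothetical call against the new signature; `session` spans 2 blocks.
    outputs = session.step(
        inputs,                     # torch.Tensor [batch, n_tokens, hid_size]
        prompts,                    # optional deep prompts (or DUMMY)
        hypo_ids,                   # optional int64 beam re-ordering indices
        {"attention_mask": mask0},  # kwargs for the span's first block (hypothetical)
        {"attention_mask": mask1},  # kwargs for the span's second block (hypothetical)
        step_id=str(uuid.uuid4()),  # still keyword-only
    )
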
@@ -97,6 +96,7 @@ class _ServerInferenceSession:
         :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
             if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
         """
+        # TODO record previous kwargs in case of server failure!!!
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
@@ -115,6 +115,7 @@ class _ServerInferenceSession:
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
+        assert len(block_kwargs) in (0, self.span.length)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
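
The new assert pins down the contract for the variadic parameter: a caller passes either no per-block kwargs at all, or exactly one dict per block served by this span. Illustrated with hypothetical `kw*` dicts and a span of length 3:

    session.step(x, step_id=sid)                              # 0 dicts: allowed
    session.step(x, None, None, kw0, kw1, kw2, step_id=sid)   # span.length dicts: allowed
    session.step(x, None, None, kw0, step_id=sid)             # any other count: AssertionError
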
@@ -131,7 +132,7 @@ class _ServerInferenceSession:
         assert hypo_ids.dtype == torch.int64
 
         # serialize inputs and put them into the queue
-        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
+        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
 
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
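
Appending `*block_kwargs` to `pack_args_kwargs` folds the per-block dicts into the same flat tensor stream as `inputs`/`prompts`/`hypo_ids`, with `args_structure` recording the nesting so the server can reassemble them. A rough round-trip sketch, assuming the mirrored `unpack_args_kwargs(flat_tensors, args_structure)` helper (as in `petals.utils.packaging`) restores the original (args, kwargs) pair:

    # Hedged sketch; the trailing dict is a hypothetical per-block kwarg.
    flat_tensors, args_structure = pack_args_kwargs(
        inputs, prompts, hypo_ids, {"attention_mask": mask}
    )
    # flat_tensors: every torch.Tensor found, in traversal order;
    # args_structure: a serializable skeleton with placeholders for them.
    args, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
    # args == (inputs, prompts, hypo_ids, {"attention_mask": mask}), kwargs == {}
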
@@ -141,7 +142,7 @@ class _ServerInferenceSession:
             if next_servers:
                 request_metadata["next_servers"] = next_servers
 
-        request_metadata["args_structure"] = args_structure
+        args_structure = request_metadata.setdefault("args_structure", args_structure)
 
         # TODO YOZH FIX THIS BEFORE THE END OF THIS PR
         # TODO: make possible to use different compression method for different tensors
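
The switch to `dict.setdefault` makes the metadata idempotent: if an `args_structure` is already present in `request_metadata` (presumably merged in when `not self.stepped`), it is kept and the local variable is rebound to it; otherwise the freshly packed structure is inserted. Plain stdlib behavior, for reference (`cached` and `fresh` are illustrative names):

    d = {"args_structure": cached}
    d.setdefault("args_structure", fresh)   # returns cached; dict unchanged
    {}.setdefault("args_structure", fresh)  # returns fresh; key is inserted
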
@@ -277,11 +278,22 @@ class InferenceSession:
         assert not self._closed and not self._server_sessions
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(
+        self,
+        inputs: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        *block_kwargs: Dict[str, torch.Tensor],
+        **kwargs,
+    ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
 
+        num_blocks = len(self._sequence_manager)
+        if len(block_kwargs) == 1:
+            block_kwargs = block_kwargs * num_blocks
+        assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
+
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
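
In the client-facing `InferenceSession.step`, a single kwargs dict is now broadcast across all blocks: `block_kwargs * num_blocks` replicates the tuple so each block sees the same dict. Hedged usage sketch (`sess`, the 24-block size, and the "alpha" kwarg are hypothetical):

    out = sess.step(inputs, prompts, {"alpha": torch.tensor(0.5)})
    # with num_blocks == 24, this is equivalent to:
    out = sess.step(inputs, prompts, *([{"alpha": torch.tensor(0.5)}] * 24))
    # note: tuple multiplication copies references, so all 24 entries alias one dict
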
@@ -312,7 +324,11 @@ class InferenceSession:
 
                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
-                        inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                        inputs,
+                        prompts[server_session.span.start : server_session.span.end],
+                        *block_kwargs[server_session.span.start : server_session.span.end],
+                        step_id=step_id,
+                        **kwargs,
                     )
 
                     server_idx += 1
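
The per-server dispatch reuses the `span.start:span.end` slicing convention already applied to deep prompts, so each remote server receives exactly the kwargs dicts for the blocks it hosts (or none, when the tuple is empty). For example, with 8 blocks split across two hypothetical servers:

    # server A: span.start=0, span.end=5 -> block_kwargs[0:5] (5 dicts)
    # server B: span.start=5, span.end=8 -> block_kwargs[5:8] (3 dicts)
    # an empty block_kwargs slices to an empty tuple for every server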