@@ -7,15 +7,13 @@ import uuid
 from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
 
 import torch
 
-from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2P
 from hivemind.proto import runtime_pb2
-from hivemind.utils.tensor_descr import BatchTensorDescriptor
+from hivemind.utils import MSGPackSerializer, anext, get_logger, nested_flatten
 
-from petals.client.config import ClientConfig
 from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
-from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
 from petals.utils.packaging import pack_args_kwargs
@@ -32,23 +30,21 @@ class _ServerInferenceSession:
     def __init__(
         self,
-        config: ClientConfig,
+        sequence_manager: RemoteSequenceManager,
         span: RemoteSpanInfo,
         span_uids: Sequence[ModuleUID],
-        rpc_info: RPCInfo,
         inputs_queue: asyncio.Queue,
-        outputs_aiter: AsyncIterator,
-        *,
+        outputs_stream: AsyncIterator,
+        *block_kwargs,
         max_length: int,
         **metadata,
     ):
-        self.config = config
-        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
+        self.sequence_manager = sequence_manager
+        self.span, self.span_uids = span, span_uids
         self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
-        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_stream
         self.session_id = str(uuid.uuid4())
-        self.session_metadata = dict(max_length=max_length, **metadata)
+        self.max_length = max_length
         self.stepped = False
         self.closed = False
@@ -56,24 +52,26 @@ class _ServerInferenceSession:
         self.history = None  # Used in case of server failures to regenerate attention caches on new servers
         self.next_session = None
 
+        self.block_kwargs = block_kwargs
+        assert len(self.block_kwargs) in (0, self.num_blocks)
+
     @classmethod
     async def create(
         cls,
-        config: ClientConfig,
-        p2p: P2P,
+        sequence_manager: RemoteSequenceManager,
         span: RemoteSpanInfo,
-        span_uids: Sequence[RemoteSpanInfo],
-        rpc_info: RPCInfo,
-        **metadata,
+        span_uids: Sequence[ModuleUID],
+        *block_kwargs: Dict[str, Any],
+        **kwargs,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
-        stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
+        stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
             stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
-            config.connect_timeout,
+            sequence_manager.config.connect_timeout,
         )
-        return cls(config, span, span_uids, rpc_info, inputs_queue, outputs_stream, **metadata)
+        return cls(sequence_manager, span, span_uids, inputs_queue, outputs_stream, *block_kwargs, **kwargs)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
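Review note: `create()` above wires an `asyncio.Queue` into one long-lived bidirectional streaming RPC — requests put into the queue are yielded to `stub.rpc_inference`, and responses arrive on `outputs_stream`. A minimal stand-alone sketch of that queue-to-stream pattern (illustrative names, no hivemind dependency):

    import asyncio
    from typing import AsyncIterator

    async def read_inputs(queue: asyncio.Queue) -> AsyncIterator[str]:
        # Mirrors _read_inputs_from_queue: drain the queue, stop on a sentinel
        while True:
            item = await queue.get()
            if item is None:
                break
            yield item

    async def fake_rpc(requests: AsyncIterator[str]) -> AsyncIterator[str]:
        # Stand-in for stub.rpc_inference: one response per request
        async for request in requests:
            yield f"response to {request}"

    async def main():
        queue: asyncio.Queue = asyncio.Queue()
        outputs = fake_rpc(read_inputs(queue))
        await queue.put("step 1")          # analogous to _inputs_queue.put(...)
        print(await outputs.__anext__())   # analogous to anext(self._outputs_stream)
        await queue.put(None)              # close the stream

    asyncio.run(main())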
@@ -87,7 +85,7 @@ class _ServerInferenceSession:
         self,
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
-        *,
+        *block_kwargs: Dict[str, Any],
         hypo_ids: Optional[torch.Tensor] = None,
         step_id: str,
     ) -> torch.Tensor:
@@ -96,7 +94,6 @@ class _ServerInferenceSession:
         :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
             if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
         """
+        # TODO record previous kwargs in case of server failure!!!
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
@@ -112,10 +109,11 @@ class _ServerInferenceSession:
         if not self.stepped:
             inputs = self.history  # Pass full inputs including prefix
+            block_kwargs = self.block_kwargs
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+            block_kwargs = []
 
+        assert len(block_kwargs) in (0, self.span.length)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
@@ -131,39 +129,50 @@ class _ServerInferenceSession:
             assert len(hypo_ids) == len(inputs)
             assert hypo_ids.dtype == torch.int64
 
-        # serialize inputs and put them into the queue
-        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
-
-        request_metadata = dict(session_id=self.session_id, step_id=step_id)
-        if not self.stepped:
-            request_metadata.update(self.session_metadata)
-        elif self.config.use_server_to_server:
+        metadata = dict(session_id=self.session_id, step_id=step_id, max_length=self.max_length)
+        metadata.update(
+            self.sequence_manager.get_request_metadata(
+                self.span.peer_id,
+                "rpc_inference",
+                self.span_uids,
+                inputs,
+                prompts,
+                *block_kwargs,
+                max_length=self.max_length,
+                session_id=self.session_id,
+                step_id=step_id,
+            )
+        )
+        if self.stepped and self.sequence_manager.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
-                request_metadata["next_servers"] = next_servers
+                metadata["next_servers"] = next_servers
 
-        args_structure = request_metadata.setdefault("args_structure", args_structure)
+        codecs = self.sequence_manager.get_compression_codecs(
+            self.span.peer_id, "rpc_inference", self.span_uids, inputs, prompts, *block_kwargs
+        )
 
-        # TODO: make possible to use different compression method for different tensors
-        server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
-        compression = server_side_inference_schema[0].compression
-        inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
+        # TODO YOZH FIX THIS BEFORE THE END OF THIS PR
+        # serialize inputs and put them into the queue
+        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
+        args_structure = metadata.setdefault("args_structure", args_structure)
 
-        # TODO: create more explicit way to check servers schema and client's structure
-        assert len(input_tensors) >= len(
-            server_side_inference_schema
-        ), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
+        if codecs is None:
+            codecs = [runtime_pb2.CompressionType.NONE] * len(input_tensors)
+        else:
+            codecs = list(nested_flatten(codecs))
+            assert len(codecs) == len(
+                input_tensors
+            ), f"got {len(input_tensors)} tensors but {len(codecs)} compression codecs"
 
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=CHAIN_DELIMITER.join(self.span_uids),
                     tensors=[
-                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(input_tensors, inference_schema)
+                        serialize_torch_tensor(tensor, compression)
+                        for tensor, compression in zip(input_tensors, codecs)
                     ],
-                    metadata=MSGPackSerializer.dumps(request_metadata),
+                    metadata=MSGPackSerializer.dumps(metadata),
                 )
             )
         )
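Review note: the new serialization path flattens `(inputs, prompts, hypo_ids, *block_kwargs)` into one flat tensor list via `pack_args_kwargs`, then pairs every tensor with a compression codec before building the `ExpertRequest`. A minimal sketch of just the codec-alignment logic, using the import paths from this diff (the tensors and the `codecs = None` value are illustrative stand-ins for what `pack_args_kwargs` and `get_compression_codecs` would return):

    import torch
    from hivemind.compression import serialize_torch_tensor
    from hivemind.proto import runtime_pb2
    from hivemind.utils import nested_flatten

    # Stand-in for the flat tensor list produced by pack_args_kwargs(...)
    input_tensors = [torch.randn(1, 4, 16), torch.empty(0), torch.zeros(1, dtype=torch.int64)]

    codecs = None  # stand-in: no codecs configured for this span
    if codecs is None:
        codecs = [runtime_pb2.CompressionType.NONE] * len(input_tensors)
    else:
        codecs = list(nested_flatten(codecs))
        assert len(codecs) == len(input_tensors)

    # One serialized protobuf tensor per (tensor, codec) pair
    serialized = [serialize_torch_tensor(t, c) for t, c in zip(input_tensors, codecs)]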
@@ -190,7 +199,7 @@ class _ServerInferenceSession:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
         self.stepped = True
-        return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)
+        return await asyncio.wait_for(anext(self._outputs_stream), self.sequence_manager.config.request_timeout)
 
     def close(self):
         """Finish a given inference session, close the underlying connection"""
@@ -227,7 +236,7 @@ class InferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
+    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int, *block_kwargs: Dict[str, Any]):
         self._sequence_manager = sequence_manager
         self._closed = False
         self._server_sessions = []
@@ -235,6 +244,12 @@ class InferenceSession:
         self._max_length = max_length
         self.output_ids = None
 
+        num_blocks = len(self._sequence_manager)
+        if len(block_kwargs) == 1:
+            block_kwargs = block_kwargs * num_blocks
+        assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
+        self.block_kwargs = block_kwargs
+
     @property
     def num_blocks(self) -> int:
         return len(self._sequence_manager)
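Review note: with the broadcast rule above, callers can pass either one kwargs dict shared by every block or one dict per block; `step()` then forwards each server span only its slice (see the hunks below). A small pure-Python sketch of the intended semantics (names illustrative):

    num_blocks = 4
    block_kwargs = ({"attention_mask": None},)  # one dict, shared by all blocks

    if len(block_kwargs) == 1:
        block_kwargs = block_kwargs * num_blocks  # broadcast: one entry per block
    assert len(block_kwargs) in (0, num_blocks)

    # a server span holding blocks [1, 3) receives only its slice
    span_start, span_end = 1, 3
    kwargs_for_span = block_kwargs[span_start:span_end]
    assert len(kwargs_for_span) == span_end - span_start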
@@ -247,17 +262,13 @@ class InferenceSession:
         server_sessions = []
         try:
             for span in chosen_spans:
-                span_uids = self._sequence_manager.block_uids[span.start : span.end]
-                metadata = self._sequence_manager.get_request_metadata(span.peer_id, "rpc_inference", span_uids)
                 session = RemoteExpertWorker.run_coroutine(
                     _ServerInferenceSession.create(
-                        self._sequence_manager.config,
-                        self._sequence_manager.state.p2p,
+                        self._sequence_manager,
                         span,
-                        span_uids,
-                        rpc_info=self._sequence_manager.rpc_info,  # TODO not actually needed
+                        self._sequence_manager.block_uids[span.start : span.end],
+                        *self.block_kwargs[span.start : span.end],
                         max_length=self._max_length,
-                        **metadata,
                     )
                 )
                 server_sessions.append(session)
@@ -282,18 +293,13 @@ class InferenceSession:
         self,
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
+        *block_kwargs: Sequence[Dict[str, torch.Tensor]],
         hypo_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
 
-        num_blocks = len(self._sequence_manager)
-        if len(block_kwargs) == 1:
-            block_kwargs = block_kwargs * num_blocks
-        assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
-
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
@@ -326,9 +332,8 @@ class InferenceSession:
                     inputs = server_session.step(
                         inputs,
                         prompts[server_session.span.start : server_session.span.end],
-                        step_id=step_id,
+                        *block_kwargs[server_session.span.start : server_session.span.end],
                         hypo_ids=hypo_ids,
+                        step_id=step_id,
                     )
 
                     server_idx += 1
@@ -354,7 +359,7 @@ class InferenceSession:
         outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
         return outputs
 
-    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
+    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int):
         # If there is a failed server session, this code closes it
         self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
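Taken together, a client-side call under the new signatures might look like the sketch below (hypothetical usage, not part of the diff; assumes an initialized `RemoteSequenceManager` bound to `sequence_manager`, that the remote blocks accept an `attention_mask` kwarg, and the session's existing context-manager interface):

    import torch

    hidden_size = 4096  # illustrative model dimension

    # One kwargs dict is broadcast to all blocks by InferenceSession.__init__
    session = InferenceSession(sequence_manager, 128, {"attention_mask": None})
    with session:
        hidden_states = torch.randn(1, 1, hidden_size)
        outputs = session.step(hidden_states)  # prompts/hypo_ids default to dummy tensors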