temporary rollback: allow kwargs only at first inference step

pull/467/head
Your Name 9 months ago
parent 3048c3b3ad
commit 3f06b53b1d

@ -7,15 +7,13 @@ import uuid
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
import torch
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P
from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from hivemind.utils import MSGPackSerializer, anext, get_logger, nested_flatten
from petals.client.config import ClientConfig
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
from petals.utils.packaging import pack_args_kwargs
@ -32,23 +30,21 @@ class _ServerInferenceSession:
def __init__(
self,
config: ClientConfig,
sequence_manager: RemoteSequenceManager,
span: RemoteSpanInfo,
span_uids: Sequence[ModuleUID],
rpc_info: RPCInfo,
inputs_queue: asyncio.Queue,
outputs_aiter: AsyncIterator,
*,
outputs_stream: AsyncIterator,
*block_kwargs,
max_length: int,
**metadata,
):
self.config = config
self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
self.sequence_manager = sequence_manager
self.span, self.span_uids = span, span_uids
self.num_blocks = len(span_uids)
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_stream
self.session_id = str(uuid.uuid4())
self.session_metadata = dict(max_length=max_length, **metadata)
self.max_length = max_length
self.stepped = False
self.closed = False
@ -56,24 +52,26 @@ class _ServerInferenceSession:
self.history = None # Used in case of server failures to regenerate attention caches on new servers
self.next_session = None
self.block_kwargs = block_kwargs
assert len(self.block_kwargs) in (0, self.num_blocks)
@classmethod
async def create(
cls,
config: ClientConfig,
p2p: P2P,
sequence_manager: RemoteSequenceManager,
span: RemoteSpanInfo,
span_uids: Sequence[RemoteSpanInfo],
rpc_info: RPCInfo,
**metadata,
span_uids: Sequence[ModuleUID],
*block_kwargs: Dict[str, Any],
**kwargs,
) -> _ServerInferenceSession:
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
inputs_queue = asyncio.Queue()
outputs_stream = await asyncio.wait_for(
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
config.connect_timeout,
sequence_manager.config.connect_timeout,
)
return cls(config, span, span_uids, rpc_info, inputs_queue, outputs_stream, **metadata)
return cls(sequence_manager, span, span_uids, inputs_queue, outputs_stream, *block_kwargs, **kwargs)
@staticmethod
async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@ -87,7 +85,7 @@ class _ServerInferenceSession:
self,
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
*block_kwargs: Dict[str, Any],
*,
hypo_ids: Optional[torch.Tensor] = None,
step_id: str,
) -> torch.Tensor:
@ -96,7 +94,6 @@ class _ServerInferenceSession:
:param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
"""
# TODO record previous kwargs in case of server failure!!!
if self.closed:
raise Exception("Session is closed, cannot perform step")
@ -112,10 +109,11 @@ class _ServerInferenceSession:
if not self.stepped:
inputs = self.history # Pass full inputs including prefix
block_kwargs = self.block_kwargs
else:
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
block_kwargs = []
assert len(block_kwargs) in (0, self.span.length)
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
@ -131,39 +129,50 @@ class _ServerInferenceSession:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64
# serialize inputs and put them into the queue
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
request_metadata.update(self.session_metadata)
elif self.config.use_server_to_server:
metadata = dict(session_id=self.session_id, step_id=step_id, max_length=self.max_length)
metadata.update(
self.sequence_manager.get_request_metadata(
self.span.peer_id,
"rpc_inference",
self.span_uids,
inputs,
prompts,
*block_kwargs,
max_length=self.max_length,
session_id=self.session_id,
step_id=step_id,
)
)
if self.stepped and self.sequence_manager.config.use_server_to_server:
next_servers = self._collect_next_servers()
if next_servers:
request_metadata["next_servers"] = next_servers
metadata["next_servers"] = next_servers
args_structure = request_metadata.setdefault("args_structure", args_structure)
codecs = self.sequence_manager.get_compression_codecs(
self.span.peer_id, "rpc_inference", self.span_uids, inputs, prompts, *block_kwargs
)
# TODO YOZH FIX THIS BEFORE THE END OF THIS PR
# TODO: make possible to use different compression method for different tensors
server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
compression = server_side_inference_schema[0].compression
inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
# serialize inputs and put them into the queue
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
args_structure = metadata.setdefault("args_structure", args_structure)
# TODO: create more explicit way to check servers schema and client's structure
assert len(input_tensors) >= len(
server_side_inference_schema
), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
if codecs is None:
codecs = [runtime_pb2.CompressionType.NONE] * len(input_tensors)
else:
codecs = list(nested_flatten(codecs))
assert len(codecs) == len(
input_tensors
), f"got {len(input_tensors)} tensors but {len(codecs)} compression codecs"
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=CHAIN_DELIMITER.join(self.span_uids),
tensors=[
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(input_tensors, inference_schema)
serialize_torch_tensor(tensor, compression)
for tensor, compression in zip(input_tensors, codecs)
],
metadata=MSGPackSerializer.dumps(request_metadata),
metadata=MSGPackSerializer.dumps(metadata),
)
)
)
@ -190,7 +199,7 @@ class _ServerInferenceSession:
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
await self._inputs_queue.put(inputs_serialized)
self.stepped = True
return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)
return await asyncio.wait_for(anext(self._outputs_stream), self.sequence_manager.config.request_timeout)
def close(self):
"""Finish a given inference session, close the underlying connection"""
@ -227,7 +236,7 @@ class InferenceSession:
An interface to a multi-step *inference* session for a sequence of remote transformer blocks
"""
def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int, *block_kwargs: Dict[str, Any]):
self._sequence_manager = sequence_manager
self._closed = False
self._server_sessions = []
@ -235,6 +244,12 @@ class InferenceSession:
self._max_length = max_length
self.output_ids = None
num_blocks = len(self._sequence_manager)
if len(block_kwargs) == 1:
block_kwargs = block_kwargs * num_blocks
assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
self.block_kwargs = block_kwargs
@property
def num_blocks(self) -> int:
return len(self._sequence_manager)
@ -247,17 +262,13 @@ class InferenceSession:
server_sessions = []
try:
for span in chosen_spans:
span_uids = self._sequence_manager.block_uids[span.start : span.end]
metadata = self._sequence_manager.get_request_metadata(span.peer_id, "rpc_inference", span_uids)
session = RemoteExpertWorker.run_coroutine(
_ServerInferenceSession.create(
self._sequence_manager.config,
self._sequence_manager.state.p2p,
self._sequence_manager,
span,
span_uids,
rpc_info=self._sequence_manager.rpc_info, # TODO not actually needed
self._sequence_manager.block_uids[span.start : span.end],
*self.block_kwargs[span.start : span.end],
max_length=self._max_length,
**metadata,
)
)
server_sessions.append(session)
@ -282,18 +293,13 @@ class InferenceSession:
self,
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
*block_kwargs: Sequence[Dict[str, torch.Tensor]],
hypo_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
num_blocks = len(self._sequence_manager)
if len(block_kwargs) == 1:
block_kwargs = block_kwargs * num_blocks
assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
@ -326,9 +332,8 @@ class InferenceSession:
inputs = server_session.step(
inputs,
prompts[server_session.span.start : server_session.span.end],
*block_kwargs[server_session.span.start : server_session.span.end],
step_id=step_id,
hypo_ids=hypo_ids,
step_id=step_id,
)
server_idx += 1
@ -354,7 +359,7 @@ class InferenceSession:
outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
return outputs
def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int):
# If there is a failed server session, this code closes it
self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])

Loading…
Cancel
Save