From a23bd73f3b0cfc0972b3b2f854f7c33314bd0b56 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Wed, 6 Sep 2023 03:30:26 +0300
Subject: [PATCH] probably break everything

---
 src/petals/client/inference_session.py   | 32 ++++++++++++++++++------
 src/petals/client/sequential_autograd.py |  8 +++---
 2 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py
index 31fa5e7..a3f1130 100644
--- a/src/petals/client/inference_session.py
+++ b/src/petals/client/inference_session.py
@@ -4,7 +4,7 @@ import asyncio
 import itertools
 import time
 import uuid
-from typing import AsyncIterator, List, Optional, Sequence, Tuple
+from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
@@ -43,7 +43,7 @@ class _ServerInferenceSession:
         **metadata,
     ):
         self.config = config
-        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
+        self.span, self.span_uids = span, span_uids
         self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
@@ -67,7 +67,6 @@ class _ServerInferenceSession:
         **metadata,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
-        # TODO YOZH you don't need rpc info here
         stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
@@ -89,7 +88,7 @@ class _ServerInferenceSession:
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-        *,
+        *block_kwargs: Dict[str, Any],
         step_id: str,
     ) -> torch.Tensor:
         """
@@ -97,6 +96,7 @@ class _ServerInferenceSession:
         :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs, if specified,
             deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
         """
+        # TODO record previous kwargs in case of server failure!!!
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
@@ -115,6 +115,7 @@ class _ServerInferenceSession:
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
+        assert len(block_kwargs) in (0, self.span.length)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
@@ -131,7 +132,7 @@ class _ServerInferenceSession:
             assert hypo_ids.dtype == torch.int64
 
         # serialize inputs and put them into the queue
-        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
+        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
 
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
@@ -141,7 +142,7 @@ class _ServerInferenceSession:
             if next_servers:
                 request_metadata["next_servers"] = next_servers
 
-        request_metadata["args_structure"] = args_structure
+        args_structure = request_metadata.setdefault("args_structure", args_structure)  # TODO YOZH FIX THIS BEFORE THE END OF THIS PR
 
         # TODO: make possible to use different compression method for different tensors
@@ -277,11 +278,22 @@ class InferenceSession:
         assert not self._closed and not self._server_sessions
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(
+        self,
+        inputs: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        *block_kwargs: Sequence[Dict[str, torch.Tensor]],
+        **kwargs,
+    ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
 
+        num_blocks = len(self._sequence_manager)
+        if len(block_kwargs) == 1:
+            block_kwargs = block_kwargs * num_blocks
+        assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
+
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
@@ -312,7 +324,11 @@ class InferenceSession:
 
                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
-                        inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                        inputs,
+                        prompts[server_session.span.start : server_session.span.end],
+                        *block_kwargs[server_session.span.start : server_session.span.end],
+                        step_id=step_id,
+                        **kwargs,
                     )
 
                     server_idx += 1
diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py
index a86df31..1748490 100644
--- a/src/petals/client/sequential_autograd.py
+++ b/src/petals/client/sequential_autograd.py
@@ -52,7 +52,7 @@ async def sequential_forward(
     if len(block_kwargs) == 1:
         block_kwargs = block_kwargs * (end_index - start_index)
     assert (
-        len(block_kwargs) in (0, end_index - start_index)
+        not block_kwargs or len(block_kwargs) == end_index - start_index
     ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
     assert is_dummy(prompts) or len(prompts) == len(
@@ -222,7 +222,8 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(ctx, sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompts: torch.Tensor):
+        # TODO add kwargs here; figure out a way to split kwargs across servers
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
         input_batches = tuple(batch.requires_grad_(inputs.requires_grad) for batch in input_batches)
@@ -271,4 +272,5 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         grad_inputs = torch.cat(grad_input_batches, dim=0)
         dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
         grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
-        return (grad_inputs, grad_prompts, None)
+        # TODO return grads w.r.t. kwargs here
+        return (None, grad_inputs, grad_prompts)
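
Usage sketch (not part of the patch): the snippet below shows how the per-block kwargs added above could be passed through InferenceSession.step. The model name, the attribute path to the remote blocks, and the empty kwargs dicts are illustrative placeholders; this patch does not define what the kwargs contain.

    import torch
    from petals import AutoDistributedModelForCausalLM

    # Illustrative model; any Petals-served model is assumed to work the same way.
    model = AutoDistributedModelForCausalLM.from_pretrained("bigscience/bloom-560m-petals")
    num_blocks = model.config.num_hidden_layers

    # One kwargs dict per remote block; a single dict would be broadcast to every block
    # by the len(block_kwargs) == 1 branch added in InferenceSession.step.
    block_kwargs = tuple({} for _ in range(num_blocks))

    with model.transformer.h.inference_session(max_length=16) as session:
        hidden_states = torch.randn(1, 1, model.config.hidden_size, dtype=model.dtype)
        # Per-block kwargs follow prompts as extra positional arguments; each server
        # receives its slice block_kwargs[span.start : span.end], mirroring prompts.
        outputs = session.step(hidden_states, None, *block_kwargs)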