probably break everything

Your Name 2023-09-06 03:30:26 +03:00
parent 056cd77f11
commit a23bd73f3b
2 changed files with 29 additions and 11 deletions

File 1 of 2

@@ -4,7 +4,7 @@ import asyncio
import itertools
import time
import uuid
-from typing import AsyncIterator, List, Optional, Sequence, Tuple
+from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
import torch
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
@@ -43,7 +43,7 @@ class _ServerInferenceSession:
**metadata,
):
self.config = config
-self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
+self.span, self.span_uids = span, span_uids
self.num_blocks = len(span_uids)
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
@@ -67,7 +67,6 @@ class _ServerInferenceSession:
**metadata,
) -> _ServerInferenceSession:
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
-# TODO YOZH you don't need rpc info here
stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
inputs_queue = asyncio.Queue()
outputs_stream = await asyncio.wait_for(
@@ -89,7 +88,7 @@
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
-*,
+*block_kwargs: Dict[str, Any],
step_id: str,
) -> torch.Tensor:
"""
@@ -97,6 +96,7 @@
:param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
"""
+# TODO record previous kwargs in case of server failure!!!
if self.closed:
raise Exception("Session is closed, cannot perform step")
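
For context, the signature change above makes step() variadic: any positional dicts passed after hypo_ids are collected into block_kwargs, and step_id becomes keyword-only because it follows the catch-all. A minimal self-contained sketch of that calling convention (the function and the attention_mask name below are illustrative stand-ins, not Petals code):

from typing import Any, Dict
import torch

def step(inputs: torch.Tensor, *block_kwargs: Dict[str, Any], step_id: str) -> None:
    # block_kwargs holds one dict of extra tensors per block; step_id
    # can only be passed by keyword, mirroring the new signature.
    print(f"{step_id}: {len(block_kwargs)} per-block kwargs dicts")

step(
    torch.zeros(1, 1, 8),
    {"attention_mask": torch.ones(1, 1)},  # kwargs for the span's 1st block
    {"attention_mask": torch.ones(1, 1)},  # kwargs for the span's 2nd block
    step_id="demo-step",
)
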
@@ -115,6 +115,7 @@
else:
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
+assert len(block_kwargs) in (0, self.span.length)
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
@@ -131,7 +132,7 @@
assert hypo_ids.dtype == torch.int64
# serialize inputs and put them into the queue
-input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
+input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
@@ -141,7 +142,7 @@
if next_servers:
request_metadata["next_servers"] = next_servers
request_metadata["args_structure"] = args_structure
args_structure = request_metadata.setdefault("args_structure", args_structure)
# TODO YOZH FIX THIS BEFORE THE END OF THIS PR
# TODO: make possible to use different compression method for different tensors
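
The serialization change above leans on flattening: pack_args_kwargs turns the nested mix of tensors and per-block dicts into a flat tensor list plus a structure descriptor, and the setdefault call both records args_structure in the request metadata and reuses the stored value locally. A rough, runnable sketch of the assumed flatten step (this pack() is a stand-in written for illustration, not hivemind's or Petals' actual helper):

import torch

def pack(*args):
    # Separate the tensors from the nesting that holds them: dicts
    # contribute their values in key order, plain tensors pass through.
    flat, structure = [], []
    for a in args:
        if isinstance(a, dict):
            structure.append(sorted(a))
            flat.extend(a[k] for k in sorted(a))
        else:
            structure.append(None)
            flat.append(a)
    return flat, structure

inputs, prompts = torch.zeros(1, 2, 4), torch.zeros(1, 1, 2, 4)
flat_tensors, args_structure = pack(inputs, prompts, {"attention_mask": torch.ones(1, 2)})
# args_structure travels once per session in request_metadata (hence
# setdefault), so the server can rebuild the original (args, kwargs) nesting.
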
@@ -277,11 +278,22 @@ class InferenceSession:
assert not self._closed and not self._server_sessions
return self
-def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+def step(
+    self,
+    inputs: torch.Tensor,
+    prompts: Optional[torch.Tensor] = None,
+    *block_kwargs: Sequence[Dict[str, torch.Tensor]],
+    **kwargs,
+) -> torch.Tensor:
assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+num_blocks = len(self._sequence_manager)
+if len(block_kwargs) == 1:
+    block_kwargs = block_kwargs * num_blocks
+assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
@@ -312,7 +324,11 @@
server_session = self._server_sessions[server_idx]
inputs = server_session.step(
-inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+inputs,
+prompts[server_session.span.start : server_session.span.end],
+*block_kwargs[server_session.span.start : server_session.span.end],
+step_id=step_id,
+**kwargs,
)
server_idx += 1
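
Taken together, the two hunks above implement a broadcast-then-slice dispatch: a single kwargs dict is replicated into one dict per block, and each server session receives the slice covering its span. A self-contained sketch of that logic (block count, span bounds, and the attention_mask name are made up for illustration):

num_blocks = 8
block_kwargs = ({"attention_mask": None},) * num_blocks  # broadcast one dict to all blocks
assert len(block_kwargs) in (0, num_blocks)

spans = [(0, 3), (3, 8)]  # two servers covering blocks 0-2 and 3-7
for start, end in spans:
    kwargs_for_server = block_kwargs[start:end]  # same slicing as prompts
    print(f"server [{start}:{end}] gets {len(kwargs_for_server)} kwargs dicts")
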

File 2 of 2

@@ -52,7 +52,7 @@ async def sequential_forward(
if len(block_kwargs) == 1:
block_kwargs = block_kwargs * (end_index - start_index)
assert (
-len(block_kwargs) in (0, end_index - start_index)
+not block_kwargs or len(block_kwargs) == end_index - start_index
), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
assert is_dummy(prompts) or len(prompts) == len(
@@ -222,7 +222,8 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
"""
@staticmethod
-def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+def forward(ctx, sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompts: torch.Tensor):
+# TODO add kwargs here; figure out a way to split kwargs across servers
batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
input_batches = tuple(batch.requires_grad_(inputs.requires_grad) for batch in input_batches)
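
The argument reorder in forward() has a knock-on effect visible in the last hunk below: torch.autograd.Function requires backward() to return one gradient per forward() input, in matching positions, so the None placeholder for the non-tensor sequence_manager moves to the front of the returned tuple. A minimal self-contained illustration of that rule (demo class only, not Petals code):

import torch

class _Demo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, manager, inputs, prompts):
        # manager stands in for a non-tensor argument like sequence_manager
        return inputs * 2 + prompts

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward() input, positionally:
        # None for manager, then grads w.r.t. inputs and prompts.
        return None, grad_output * 2, grad_output

x = torch.ones(3, requires_grad=True)
p = torch.zeros(3, requires_grad=True)
_Demo.apply(object(), x, p).sum().backward()
print(x.grad, p.grad)  # tensor([2., 2., 2.]) tensor([1., 1., 1.])
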
@@ -271,4 +272,5 @@
grad_inputs = torch.cat(grad_input_batches, dim=0)
dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
-return (grad_inputs, grad_prompts, None)
+# TODO return grads w.r.t. kwargs here
+return (None, grad_inputs, grad_prompts)