mirror of
https://github.com/bigscience-workshop/petals
synced 2024-10-31 09:20:41 +00:00
probably break everyting
This commit is contained in:
parent
056cd77f11
commit
a23bd73f3b
@ -4,7 +4,7 @@ import asyncio
|
||||
import itertools
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncIterator, List, Optional, Sequence, Tuple
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
|
||||
@ -43,7 +43,7 @@ class _ServerInferenceSession:
|
||||
**metadata,
|
||||
):
|
||||
self.config = config
|
||||
self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
|
||||
self.span, self.span_uids = span, span_uids
|
||||
self.num_blocks = len(span_uids)
|
||||
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
|
||||
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
|
||||
@ -67,7 +67,6 @@ class _ServerInferenceSession:
|
||||
**metadata,
|
||||
) -> _ServerInferenceSession:
|
||||
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
|
||||
# TODO YOZH you don't need rpc info here
|
||||
stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
|
||||
inputs_queue = asyncio.Queue()
|
||||
outputs_stream = await asyncio.wait_for(
|
||||
@ -89,7 +88,7 @@ class _ServerInferenceSession:
|
||||
inputs: torch.Tensor,
|
||||
prompts: Optional[torch.Tensor] = None,
|
||||
hypo_ids: Optional[torch.Tensor] = None,
|
||||
*,
|
||||
*block_kwargs: Dict[str, Any],
|
||||
step_id: str,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -97,6 +96,7 @@ class _ServerInferenceSession:
|
||||
:param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
|
||||
if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
|
||||
"""
|
||||
# TODO record previous kwargs in case of server failure!!!
|
||||
if self.closed:
|
||||
raise Exception("Session is closed, cannot perform step")
|
||||
|
||||
@ -115,6 +115,7 @@ class _ServerInferenceSession:
|
||||
else:
|
||||
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
|
||||
|
||||
assert len(block_kwargs) in (0, self.span.length)
|
||||
if prompts is None or is_dummy(prompts):
|
||||
prompts = DUMMY
|
||||
else:
|
||||
@ -131,7 +132,7 @@ class _ServerInferenceSession:
|
||||
assert hypo_ids.dtype == torch.int64
|
||||
|
||||
# serialize inputs and put them into the queue
|
||||
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
|
||||
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
|
||||
|
||||
request_metadata = dict(session_id=self.session_id, step_id=step_id)
|
||||
if not self.stepped:
|
||||
@ -141,7 +142,7 @@ class _ServerInferenceSession:
|
||||
if next_servers:
|
||||
request_metadata["next_servers"] = next_servers
|
||||
|
||||
request_metadata["args_structure"] = args_structure
|
||||
args_structure = request_metadata.setdefault("args_structure", args_structure)
|
||||
|
||||
# TODO YOZH FIX THIS BEFORE THE END OF THIS PR
|
||||
# TODO: make possible to use different compression method for different tensors
|
||||
@ -277,11 +278,22 @@ class InferenceSession:
|
||||
assert not self._closed and not self._server_sessions
|
||||
return self
|
||||
|
||||
def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
||||
def step(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
prompts: Optional[torch.Tensor] = None,
|
||||
*block_kwargs: Sequence[Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert not self._closed
|
||||
if torch.is_grad_enabled():
|
||||
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
|
||||
|
||||
num_blocks = len(self._sequence_manager)
|
||||
if len(block_kwargs) == 1:
|
||||
block_kwargs = block_kwargs * num_blocks
|
||||
assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
|
||||
|
||||
if prompts is None or is_dummy(prompts):
|
||||
prompts = DUMMY
|
||||
else:
|
||||
@ -312,7 +324,11 @@ class InferenceSession:
|
||||
|
||||
server_session = self._server_sessions[server_idx]
|
||||
inputs = server_session.step(
|
||||
inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
|
||||
inputs,
|
||||
prompts[server_session.span.start : server_session.span.end],
|
||||
*block_kwargs[server_session.span.start : server_session.span.end],
|
||||
step_id=step_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
server_idx += 1
|
||||
|
@ -52,7 +52,7 @@ async def sequential_forward(
|
||||
if len(block_kwargs) == 1:
|
||||
block_kwargs = block_kwargs * (end_index - start_index)
|
||||
assert (
|
||||
len(block_kwargs) in (0, end_index - start_index)
|
||||
not block_kwargs or len(block_kwargs) == end_index - start_index
|
||||
), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
|
||||
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
|
||||
assert is_dummy(prompts) or len(prompts) == len(
|
||||
@ -222,7 +222,8 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
|
||||
def forward(ctx, sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompts: torch.Tensor):
|
||||
# TODO add kwargs here; figure out a way to split kwargs across servers
|
||||
batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
|
||||
input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
|
||||
input_batches = tuple(batch.requires_grad_(inputs.requires_grad) for batch in input_batches)
|
||||
@ -271,4 +272,5 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
|
||||
grad_inputs = torch.cat(grad_input_batches, dim=0)
|
||||
dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
|
||||
grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
|
||||
return (grad_inputs, grad_prompts, None)
|
||||
# TODO return grads w.r.t. kwargs here
|
||||
return (None, grad_inputs, grad_prompts)
|
||||
|
Loading…
Reference in New Issue
Block a user