diff --git a/README.md b/README.md
index 982a1b5..aa93a43 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@

-Generate text with distributed [LLaMA 2 (70B)](https://huggingface.co/meta-llama/Llama-2-70b-hf), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2), [LLaMA-65B](https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md), [Guanaco-65B](https://huggingface.co/timdettmers/guanaco-65b) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) and fine-tune them for your own tasks — right from your desktop computer or Google Colab:
+Generate text with distributed **LLaMA 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine-tune them for your own tasks — right from your desktop computer or Google Colab:
 ```python
 from transformers import AutoTokenizer
@@ -37,7 +37,7 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
 🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
-💬 **Any questions?** Ping us in [our Discord](https://discord.gg/J29mCBNBvm)!
+💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
 ### Connect your GPU and increase Petals capacity
@@ -48,7 +48,7 @@ Petals is a community-run system — we rely on people sharing their GPUs. Y
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
 pip install git+https://github.com/bigscience-workshop/petals
-python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16
+python -m petals.cli.run_server stabilityai/StableBeluga2
 ```
 🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows).
@@ -57,7 +57,7 @@ python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16
 ```bash
 sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
-    python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2 --torch_dtype float16
+    python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2
 ```
 These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.
@@ -68,7 +68,7 @@ These commands will host a part of [Stable Beluga 2](https://huggingface.co/stab
 python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE
 ```
-💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/D9MwApKgWa)!
+💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multiple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
@@ -96,8 +96,8 @@ Learning more:
 ## How does it work?
-- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning.
-- Single-batch inference runs at up to 6 steps/sec for LLaMA 2 (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
+- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
+- Single-batch inference runs at **up to 6 steps/sec** for **LLaMA 2** (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
 - Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.
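Note for reviewers: the body of the README's Python example falls between the hunks above, so it does not appear in this diff. For context, here is a minimal client-side sketch of that flow, assuming the petals 2.x `AutoDistributedModelForCausalLM` API; the model name and generation arguments are illustrative, not necessarily the exact elided lines.

```python
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM  # assumes the petals 2.x client package is installed

model_name = "stabilityai/StableBeluga2"  # illustrative; any Petals-supported repo should work

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Embeddings and the LM head load locally; transformer blocks are served by remote peers in the swarm
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))  # e.g. "A cat sat on a mat..."
```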

diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb
index 9123c1a..b6f2d8a 100644
--- a/examples/prompt-tuning-sst2.ipynb
+++ b/examples/prompt-tuning-sst2.ipynb
@@ -92,9 +92,6 @@
    },
    "outputs": [],
    "source": [
-    "# Choose a model you'd like to prompt-tune. We recommend starting with\n",
-    "# a smaller model (bigscience/bloom-7b1-petals) for faster prototyping.\n",
-    "# The code below uses LLaMA-65B.\n",
     "MODEL_NAME = \"enoch/llama-65b-hf\"\n",
     "\n",
     "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",
diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py
index b19d468..7328cdc 100644
--- a/src/petals/client/routing/sequence_manager.py
+++ b/src/petals/client/routing/sequence_manager.py
@@ -50,7 +50,7 @@ class SequenceManagerConfig:
     ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
     active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)
-    max_pinged: int = 5  # max servers to ping from each sequence side, per update
+    max_pinged: int = 3  # max servers to ping from each sequence side, per update
     ping_timeout: float = 2  # max time to wait for pings, per update
@@ -293,6 +293,8 @@ class RemoteSequenceManager:
         return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
     def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
+        client_server_rtts = self.ping_aggregator.to_dict()
+
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
@@ -300,7 +302,13 @@
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
-            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
+            # We choose longer servers to minimize the number of hops but leave some randomization
+            # to distribute the load. We also exclude servers known to be unreachable.
+            eps = 1e-6
+            span_weights = np.array(
+                [span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
+                dtype=np.float64,
+            )
             chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
             assert chosen_span.start <= current_index < chosen_span.end
@@ -361,9 +369,13 @@ class RemoteSequenceManager:
         self.state.sequence_info.update_(new_block_infos)
         first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
+        middle_servers = [
+            span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans
+        ]
         last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
         pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
+        pinged_servers |= set(sample_up_to(middle_servers, self.config.max_pinged))
         pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
         self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)
diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py
new file mode 100644
index 0000000..9208deb
--- /dev/null
+++ b/src/petals/server/block_functions.py
@@ -0,0 +1,195 @@
+"""
+This module implements server-side computations on served blocks: forward, backward and inference; used by handler
+"""
+from __future__ import annotations
+
+from typing import AsyncIterator, Optional, Sequence, Tuple, Union
+
+import torch
+from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.moe.expert_uid import ExpertUID
+from hivemind.proto import runtime_pb2
+from hivemind.utils.nested import nested_flatten
+
+from petals.data_structures import InferenceMetadata
+from petals.server.backend import TransformerBackend
+from petals.server.memory_cache import Handle
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.server.task_prioritizer import TaskPrioritizerBase
+from petals.utils.misc import DUMMY, is_dummy
+
+
+async def run_rpc_forward(
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
+    prioritizer: TaskPrioritizerBase,
+    points: int = 0,
+) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # check parse input tensors and cast dtypes
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
+    if prompts is None or is_dummy(prompts):
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+    # Run a chain of requested backends
+    for backend, prompt in zip(requested_backends, prompts):
+        if not is_dummy(prompt):
+            hidden_states[:, : prompt.shape[1]] += prompt
+
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            hidden_states, points=points / 
len(requested_backends), backend=backend, type="forward" + ) + (hidden_states,) = await backend.forward_pool.submit_task( + hidden_states, + active_adapter, + priority=priority, + ) + assert isinstance(hidden_states, torch.Tensor) + assert ( + hidden_states.ndim == 3 + ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" + + return hidden_states + + +async def run_rpc_backward( + *flat_tensors: torch.Tensor, + requested_backends: Sequence[TransformerBackend], + active_adapter: str = "", + prioritizer: TaskPrioritizerBase, + points: int = 0, +) -> Union[torch.Tensor, Sequence[torch.Tensor]]: + inputs, grad_outputs, prompts = flat_tensors + # Cast inputs & grad outputs to backend dtype + inputs = inputs.to(requested_backends[0].dtype) + grad_outputs = grad_outputs.to(requested_backends[-1].dtype) + + if prompts is None or is_dummy(prompts): + prompts = [DUMMY] * len(requested_backends) + else: + prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + + # Run a forward chain to collect intermediate inputs + # Note that we do not forward for the last module since we do not need its output + inter_inputs = [] + for backend, prompt in zip(requested_backends[:-1], prompts[:-1]): + assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" + if not is_dummy(prompt): + inputs[:, : prompt.shape[1]] += prompt + inter_inputs.append(inputs) + assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + priority = prioritizer.prioritize( + inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward" + ) + (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) + + assert isinstance(inputs, torch.Tensor) + + if not is_dummy(prompts[-1]): + inputs[:, : prompts[-1].shape[1]] += prompts[-1] + inter_inputs.append(inputs) + + assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" + grad_prompts_reversed = [] + # Run a chain of requested backends + for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))): + assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + priority = prioritizer.prioritize( + inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" + ) + (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) + + assert isinstance(grad_outputs, torch.Tensor) + if not is_dummy(prompt): + grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) + + grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY + return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape + + +async def iterate_rpc_inference( + requested_uids: Sequence[ExpertUID], + requested_backends: Sequence[TransformerBackend], + active_adapter: Optional[str], + input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]], + cache_handles: Sequence[Sequence[Handle]], + max_length: int, + prioritizer: TaskPrioritizerBase, + points: int, +) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]: + assert len(cache_handles) == len(requested_backends) + + prefix_length = 0 + point_per_piece = points / max_length if max_length > 0 else 0.0 + + async for request, step_metadata in 
input_iterator: + hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors) + + # Cast inputs to backend dtype + hidden_states = hidden_states.to(requested_backends[0].dtype) + assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" + + # parse deep prompts (optional argument) + has_prompts = prompts is not None and not is_dummy(prompts) + if not has_prompts: + prompts = [None] * len(requested_backends) + else: + prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts] + + if not (len(requested_backends) == len(prompts)): + raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends") + + length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq) + if prefix_length + length_increment > max_length: + raise ValueError( + f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}" + f" exceeds pre-allocated maximum {max_length}" + ) + + priority = prioritizer.prioritize( + hidden_states, + hypo_ids, + points=point_per_piece, + requested_uids=requested_uids, + type="inference", + ) + + inference_infos = tuple( + InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) + for uid, handles in zip(requested_uids, cache_handles) + ) + + if hidden_states.numel() == 0: + pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. + # when user wants to pre-allocate cache or check that server *can* allocate that cache + else: + assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor" + (hidden_states,) = await requested_backends[0].inference_pool.submit_task( + hidden_states, hypo_ids, inference_infos, *prompts, priority=priority + ) + + # serialize and send last layer outputs + output_tensors = [ + serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) + for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)) + ] + can_push = not has_prompts + yield output_tensors, can_push + + # prepare for next step + prefix_length += length_increment diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index eb5300e..effce82 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -11,7 +11,8 @@ def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype] """If dtype is "auto", resolves it using BloomConfig. 
Returns `dtype` intact otherwise.""" if dtype not in ("auto", None): return dtype - if config.torch_dtype not in ("auto", None): + if config.torch_dtype not in ("auto", None, torch.float32): + # If config specifies float32, we override it to the default dtype below return config.torch_dtype return torch.bfloat16 diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 637e2ee..4e668e4 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -6,7 +6,7 @@ import multiprocessing as mp import sys from enum import Enum from itertools import chain -from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple import torch from async_timeout import timeout @@ -29,12 +29,11 @@ from hivemind.utils.logging import get_logger from hivemind.utils.streaming import split_for_streaming import petals -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID from petals.server.backend import TransformerBackend +from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward from petals.server.memory_cache import Handle -from petals.server.task_pool import PrioritizedTaskPool from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase -from petals.utils.misc import DUMMY, is_dummy logger = get_logger(__name__) @@ -147,7 +146,6 @@ class TransformerConnectionHandler(ConnectionHandler): metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") - active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) session_id = metadata.get("session_id") if not requested_uids: @@ -163,78 +161,28 @@ class TransformerConnectionHandler(ConnectionHandler): f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}" ) - point_per_piece = points / max_length if max_length > 0 else 0.0 batch_size = request.tensors[0].size[0] if request.tensors else 1 - prefix_length = 0 async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles: - assert len(cache_handles) == len(requested_backends) - first_request = request background_tasks = set() - async for request, metadata in self._iterate_inference_steps( - first_request, requests, session_id, requested_uids, context + async for output_tensors, can_push in iterate_rpc_inference( + requested_uids=requested_uids, + requested_backends=requested_backends, + active_adapter=self._get_active_adapter(metadata), + input_iterator=self._iterate_inference_steps( + request, requests, session_id, requested_uids, context + ), + cache_handles=cache_handles, + max_length=max_length, + prioritizer=self._prioritizer, + points=points, ): - hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors) - - # Cast inputs to backend dtype - hidden_states = hidden_states.to(requested_backends[0].dtype) - assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" - - # parse deep prompts (optional argument) - has_prompts = prompts is not None and not is_dummy(prompts) - if not has_prompts: - prompts = [None] * len(requested_backends) - else: - prompts = [p.squeeze(0) for p in 
prompts.to(requested_backends[0].dtype).split(1, dim=0)] - prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts] - - if not (len(requested_backends) == len(prompts)): - raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends") - - length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq) - if prefix_length + length_increment > max_length: - raise ValueError( - f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}" - f" exceeds pre-allocated maximum {max_length}" - ) - - priority = self._prioritizer.prioritize( - hidden_states, - hypo_ids, - points=point_per_piece, - requested_uids=requested_uids, - type="inference", - ) - - inference_infos = tuple( - InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) - for uid, handles in zip(requested_uids, cache_handles) - ) - - if hidden_states.numel() == 0: - pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. - # when user wants to pre-allocate cache or check that server *can* allocate that cache - else: - assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor" - (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task( - hidden_states, hypo_ids, inference_infos, *prompts, priority=priority - ) - - # serialize and send last layer outputs - output_tensors = [ - serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) - for result, proto in zip( - (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema) - ) - ] - if not has_prompts: + if can_push: task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata)) background_tasks.add(task) # Keep reference until it is done to save it from GC task.add_done_callback(background_tasks.discard) yield runtime_pb2.ExpertResponse(tensors=output_tensors) - # prepare for next step - prefix_length += length_increment finally: self._log_request("rpc_inference.close", requested_uids, context) @@ -408,7 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) ), f"rpc_forward should have number of points as number or None, got {points}" - hidden_states = await _rpc_forward( + hidden_states = await run_rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -435,7 +383,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) ), f"rpc_forward_stream should have number of points as number or None, got {points}" - hidden_states = await _rpc_forward( + hidden_states = await run_rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -486,7 +434,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) ), f"rpc_backward should have number of points as number or None, got {points}" - grads = await _rpc_backward( + grads = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -511,7 +459,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) ), f"rpc_backward_stream should have number of points as number or None, got {points}" - grads = await _rpc_backward( + grads = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -621,105 +569,3 @@ class TransformerConnectionHandler(ConnectionHandler): result.update(block_info) return 
runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result)) - - -async def _rpc_forward( - *flat_tensors: torch.Tensor, - requested_backends: Sequence[TransformerBackend], - active_adapter: str = "", - prioritizer: TaskPrioritizerBase, - points: int = 0, -) -> torch.Tensor: - """ - Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream - - :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors - :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy) - :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass - :returns: hidden states after the last layer [batch_size, seq_length, hid_size] - """ - hidden_states, prompts = flat_tensors - dtype = requested_backends[0].dtype - # check parse input tensors and cast dtypes - hidden_states = hidden_states.to(dtype) - assert hidden_states.ndim == 3 - if prompts is None or is_dummy(prompts): - prompts = [DUMMY] * len(requested_backends) - else: - prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] - - # Run a chain of requested backends - for backend, prompt in zip(requested_backends, prompts): - if not is_dummy(prompt): - hidden_states[:, : prompt.shape[1]] += prompt - - assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" - priority = prioritizer.prioritize( - hidden_states, points=points / len(requested_backends), backend=backend, type="forward" - ) - (hidden_states,) = await backend.forward_pool.submit_task( - hidden_states, - active_adapter, - priority=priority, - ) - assert isinstance(hidden_states, torch.Tensor) - assert ( - hidden_states.ndim == 3 - ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" - - return hidden_states - - -async def _rpc_backward( - *flat_tensors: torch.Tensor, - requested_backends: Sequence[TransformerBackend], - active_adapter: str = "", - prioritizer: TaskPrioritizerBase, - points: int = 0, -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - inputs, grad_outputs, prompts = flat_tensors - # Cast inputs & grad outputs to backend dtype - inputs = inputs.to(requested_backends[0].dtype) - grad_outputs = grad_outputs.to(requested_backends[-1].dtype) - - if prompts is None or is_dummy(prompts): - prompts = [DUMMY] * len(requested_backends) - else: - prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] - - # Run a forward chain to collect intermediate inputs - # Note that we do not forward for the last module since we do not need its output - inter_inputs = [] - for backend, prompt in zip(requested_backends[:-1], prompts[:-1]): - assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" - if not is_dummy(prompt): - inputs[:, : prompt.shape[1]] += prompt - inter_inputs.append(inputs) - assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" - priority = prioritizer.prioritize( - inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward" - ) - (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) - - assert isinstance(inputs, torch.Tensor) - - if not is_dummy(prompts[-1]): - inputs[:, : prompts[-1].shape[1]] += prompts[-1] - inter_inputs.append(inputs) - - assert len(inter_inputs) == len(prompts) 
== len(requested_backends), "internal shape error during backward" - grad_prompts_reversed = [] - # Run a chain of requested backends - for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))): - assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" - priority = prioritizer.prioritize( - inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" - ) - (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) - - assert isinstance(grad_outputs, torch.Tensor) - if not is_dummy(prompt): - grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) - - grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY - return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 7f0de41..53e0b5e 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -690,7 +690,9 @@ class ModuleAnnouncerThread(threading.Thread): delay = self.update_period - (time.perf_counter() - start_time) if delay < 0: - logger.warning("Declaring blocs to DHT takes more than --update_period, consider increasing it") + logger.warning( + f"Declaring blocks to DHT takes more than --update_period, consider increasing it (currently {self.update_period})" + ) self.trigger.wait(max(delay, 0)) self.trigger.clear()
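Note for reviewers: the routing change in `sequence_manager.py` above switches `_make_sequence_with_max_throughput` from throughput-weighted sampling to length-weighted sampling, and it assigns a near-zero weight to peers whose measured RTT is infinite. Below is a standalone sketch of that weighting logic; the `Span` dataclass and the sample data are hypothetical stand-ins for petals' `RemoteSpanInfo` and its ping results, and only the weighting mirrors the diff.

```python
from dataclasses import dataclass
from typing import Dict, List

import numpy as np


@dataclass
class Span:  # hypothetical stand-in for petals' RemoteSpanInfo
    peer_id: str
    start: int
    end: int

    @property
    def length(self) -> int:
        return self.end - self.start


def choose_span(candidate_spans: List[Span], rtts: Dict[str, float], eps: float = 1e-6) -> Span:
    """Pick a span with probability proportional to its length; peers whose RTT
    is known to be infinite (unreachable) get a near-zero weight instead."""
    weights = np.array(
        [span.length if rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
        dtype=np.float64,
    )
    return np.random.choice(candidate_spans, p=weights / weights.sum())


# Usage: longer, reachable spans win most of the time, which reduces the number
# of hops per sequence while still spreading load across servers.
spans = [Span("peerA", 0, 40), Span("peerB", 0, 10), Span("peerC", 0, 80)]
rtts = {"peerA": 0.05, "peerB": 0.02, "peerC": np.inf}  # peerC looks unreachable
print(choose_span(spans, rtts).peer_id)  # usually "peerA"
```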