From d0b5af34cd9fdeaccdecd5692dabdb030fd38c56 Mon Sep 17 00:00:00 2001
From: Vadim Peretokin
Date: Sun, 6 Aug 2023 14:47:21 +0200
Subject: [PATCH 1/8] Fix typo and make blocks message more informative (#437)

The message really doesn't tell me much as a user, since I never touched update_period to begin with:

```
Aug 06 09:43:07.287 [WARN] [petals.server.server.run:701] Declaring blocs to DHT takes more than --update_period, consider increasing it
```

Made it better and more informative.
---
 src/petals/server/server.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 5c47270..7772fa6 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -698,7 +698,9 @@ class ModuleAnnouncerThread(threading.Thread):
             delay = self.update_period - (time.perf_counter() - start_time)
             if delay < 0:
-                logger.warning("Declaring blocs to DHT takes more than --update_period, consider increasing it")
+                logger.warning(
+                    f"Declaring blocks to DHT takes more than --update_period, consider increasing it (currently {self.update_period})"
+                )
             self.trigger.wait(max(delay, 0))
             self.trigger.clear()

From 679397df0c8fd374084eaa4b55b63e7279600053 Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Sun, 6 Aug 2023 17:11:49 +0400
Subject: [PATCH 2/8] Update Discord links from channels to forums (#440)

As our Discord community grows, we found it difficult to look for open and resolved issues in **#running-a-client** and **#running-a-server** channels, as well as navigate through interleaving conversations happening there.

That's why we recreated these channels as Discord forums, where different discussions are separated into different posts.
---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 982a1b5..f042051 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat...
 
 🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
 
-💬 **Any questions?** Ping us in [our Discord](https://discord.gg/J29mCBNBvm)!
+💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
 
 ### Connect your GPU and increase Petals capacity
 
@@ -68,7 +68,7 @@ These commands will host a part of [Stable Beluga 2](https://huggingface.co/stab
 python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE
 ```
 
-💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/D9MwApKgWa)!
+💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
 
 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).

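As background for the warning reworded in [PATCH 1/8] above: the announcer thread measures how long declaring its blocks to the DHT took and then sleeps only for whatever is left of `--update_period`, so a negative remainder means the declaration alone already exceeded the period. A minimal, self-contained sketch of that timing pattern (the `declare_blocks` stub and the constant below are illustrative placeholders, not the real Petals internals):

```python
import threading
import time

UPDATE_PERIOD = 0.5  # seconds; stands in for the --update_period CLI argument
trigger = threading.Event()


def declare_blocks() -> None:
    """Illustrative stub: pretend that declaring served blocks to the DHT is slow."""
    time.sleep(0.8)


start_time = time.perf_counter()
declare_blocks()

# Same arithmetic as in the patch: sleep only for what is left of the period.
delay = UPDATE_PERIOD - (time.perf_counter() - start_time)
if delay < 0:
    # The declaration alone took longer than the period, so the next announcement is already late.
    print(f"Declaring blocks to DHT takes more than --update_period, consider increasing it (currently {UPDATE_PERIOD})")
trigger.wait(max(delay, 0))  # clamp to zero so we never wait with a negative timeout
```
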
From b58141ef667534a4db4d3aa6905164484361d438 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 6 Aug 2023 18:55:22 +0400 Subject: [PATCH 3/8] Remove distracting links from readme (#441) --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f042051..ea7919a 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@

-Generate text with distributed [LLaMA 2 (70B)](https://huggingface.co/meta-llama/Llama-2-70b-hf), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2), [LLaMA-65B](https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md), [Guanaco-65B](https://huggingface.co/timdettmers/guanaco-65b) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) and fine‑tune them for your own tasks — right from your desktop computer or Google Colab:
+Generate text with distributed **LLaMA 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks — right from your desktop computer or Google Colab:
 
 ```python
 from transformers import AutoTokenizer
@@ -96,8 +96,8 @@ Learning more:
 
 ## How does it work?
 
-- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning.
-- Single-batch inference runs at up to 6 steps/sec for LLaMA 2 (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
+- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
+- Single-batch inference runs at **up to 6 steps/sec** for **LLaMA 2** (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
 - Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.
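To make the "interactive apps" bullet above concrete, here is a sketch of chatbot-style generation with the client shown at the top of the README. It assumes the `petals` package's `AutoDistributedModelForCausalLM` and the `inference_session(max_length=...)` context manager used in the project's chat examples; the exact way to reuse a session (the `session=` argument below) may differ between versions:

```python
# A sketch of interactive, chatbot-style generation with the Petals client.
# Assumes the client API referenced in the README; `inference_session` and `session=`
# follow the project's chat examples and may differ between versions.
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "stabilityai/StableBeluga2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

# Keeping one inference session lets the servers keep their attention caches between turns,
# so each step sends only the new tokens instead of re-running the whole prefix.
with model.inference_session(max_length=512) as sess:
    for user_turn in ["Hi! How are you?", "Name three planets."]:
        prefix = tokenizer(f"Human: {user_turn}\nAssistant:", return_tensors="pt")["input_ids"]
        outputs = model.generate(prefix, max_new_tokens=40, session=sess)
        print(tokenizer.decode(outputs[0, prefix.shape[1]:]))
```
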

From 32fbab5192da5ee26ee2a0a7eade6d101690a70d Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 7 Aug 2023 02:22:21 +0400 Subject: [PATCH 4/8] Remove deprecated comment in fine-tuning notebook (#443) --- examples/prompt-tuning-sst2.ipynb | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index 9123c1a..b6f2d8a 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -92,9 +92,6 @@ }, "outputs": [], "source": [ - "# Choose a model you'd like to prompt-tune. We recommend starting with\n", - "# a smaller model (bigscience/bloom-7b1-petals) for faster prototyping.\n", - "# The code below uses LLaMA-65B.\n", "MODEL_NAME = \"enoch/llama-65b-hf\"\n", "\n", "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n", From 593d980ad8676fa0b43d8cc266b0dbb2052c31ec Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 7 Aug 2023 02:33:42 +0400 Subject: [PATCH 5/8] Use bitsandbytes 0.41.1 (#442) --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 7c04686..1560f2b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,7 @@ packages = find: python_requires = >=3.8 install_requires = torch>=1.12 - bitsandbytes==0.40.1.post1 + bitsandbytes==0.41.1 accelerate>=0.20.3,<0.21.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 From ac9b5467067735a885f7a4dfad689a2a2dc7f594 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 7 Aug 2023 14:32:51 +0300 Subject: [PATCH 6/8] [Refactor] extract block forward, backward and inference into a separate file (#435) This PR does not change any functionality. It merely moves stuff around. List of changes: handler.py/_rpc_forward became block_methods/rpc_forward handler.py/_rpc_backward became block_methods/rpc_backward the math bits of rpc_inference were extracted into block_methods/iterate_rpc_inference --------- Co-authored-by: Your Name Co-authored-by: artek0chumak Co-authored-by: Aleksandr Borzunov --- src/petals/server/block_functions.py | 195 +++++++++++++++++++++++++++ src/petals/server/handler.py | 192 +++----------------------- 2 files changed, 214 insertions(+), 173 deletions(-) create mode 100644 src/petals/server/block_functions.py diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py new file mode 100644 index 0000000..9208deb --- /dev/null +++ b/src/petals/server/block_functions.py @@ -0,0 +1,195 @@ +""" +This module implements server-side computations on served blocks: forward, backward and inference; used by handler +""" +from __future__ import annotations + +from typing import AsyncIterator, Optional, Sequence, Tuple, Union + +import torch +from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor +from hivemind.moe.expert_uid import ExpertUID +from hivemind.proto import runtime_pb2 +from hivemind.utils.nested import nested_flatten + +from petals.data_structures import InferenceMetadata +from petals.server.backend import TransformerBackend +from petals.server.memory_cache import Handle +from petals.server.task_pool import PrioritizedTaskPool +from petals.server.task_prioritizer import TaskPrioritizerBase +from petals.utils.misc import DUMMY, is_dummy + + +async def run_rpc_forward( + *flat_tensors: torch.Tensor, + requested_backends: Sequence[TransformerBackend], + active_adapter: str = "", + prioritizer: TaskPrioritizerBase, + points: int = 0, +) -> torch.Tensor: + """ + Run forward pass on deserialized 
inputs and prompts, used by rpc_forward and rpc_forward_stream + + :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors + :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy) + :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass + :returns: hidden states after the last layer [batch_size, seq_length, hid_size] + """ + hidden_states, prompts = flat_tensors + dtype = requested_backends[0].dtype + # check parse input tensors and cast dtypes + hidden_states = hidden_states.to(dtype) + assert hidden_states.ndim == 3 + if prompts is None or is_dummy(prompts): + prompts = [DUMMY] * len(requested_backends) + else: + prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + + # Run a chain of requested backends + for backend, prompt in zip(requested_backends, prompts): + if not is_dummy(prompt): + hidden_states[:, : prompt.shape[1]] += prompt + + assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + priority = prioritizer.prioritize( + hidden_states, points=points / len(requested_backends), backend=backend, type="forward" + ) + (hidden_states,) = await backend.forward_pool.submit_task( + hidden_states, + active_adapter, + priority=priority, + ) + assert isinstance(hidden_states, torch.Tensor) + assert ( + hidden_states.ndim == 3 + ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" + + return hidden_states + + +async def run_rpc_backward( + *flat_tensors: torch.Tensor, + requested_backends: Sequence[TransformerBackend], + active_adapter: str = "", + prioritizer: TaskPrioritizerBase, + points: int = 0, +) -> Union[torch.Tensor, Sequence[torch.Tensor]]: + inputs, grad_outputs, prompts = flat_tensors + # Cast inputs & grad outputs to backend dtype + inputs = inputs.to(requested_backends[0].dtype) + grad_outputs = grad_outputs.to(requested_backends[-1].dtype) + + if prompts is None or is_dummy(prompts): + prompts = [DUMMY] * len(requested_backends) + else: + prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + + # Run a forward chain to collect intermediate inputs + # Note that we do not forward for the last module since we do not need its output + inter_inputs = [] + for backend, prompt in zip(requested_backends[:-1], prompts[:-1]): + assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" + if not is_dummy(prompt): + inputs[:, : prompt.shape[1]] += prompt + inter_inputs.append(inputs) + assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + priority = prioritizer.prioritize( + inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward" + ) + (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) + + assert isinstance(inputs, torch.Tensor) + + if not is_dummy(prompts[-1]): + inputs[:, : prompts[-1].shape[1]] += prompts[-1] + inter_inputs.append(inputs) + + assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" + grad_prompts_reversed = [] + # Run a chain of requested backends + for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))): + assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only 
prioritized pools" + priority = prioritizer.prioritize( + inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" + ) + (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) + + assert isinstance(grad_outputs, torch.Tensor) + if not is_dummy(prompt): + grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) + + grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY + return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape + + +async def iterate_rpc_inference( + requested_uids: Sequence[ExpertUID], + requested_backends: Sequence[TransformerBackend], + active_adapter: Optional[str], + input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]], + cache_handles: Sequence[Sequence[Handle]], + max_length: int, + prioritizer: TaskPrioritizerBase, + points: int, +) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]: + assert len(cache_handles) == len(requested_backends) + + prefix_length = 0 + point_per_piece = points / max_length if max_length > 0 else 0.0 + + async for request, step_metadata in input_iterator: + hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors) + + # Cast inputs to backend dtype + hidden_states = hidden_states.to(requested_backends[0].dtype) + assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" + + # parse deep prompts (optional argument) + has_prompts = prompts is not None and not is_dummy(prompts) + if not has_prompts: + prompts = [None] * len(requested_backends) + else: + prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts] + + if not (len(requested_backends) == len(prompts)): + raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends") + + length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq) + if prefix_length + length_increment > max_length: + raise ValueError( + f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}" + f" exceeds pre-allocated maximum {max_length}" + ) + + priority = prioritizer.prioritize( + hidden_states, + hypo_ids, + points=point_per_piece, + requested_uids=requested_uids, + type="inference", + ) + + inference_infos = tuple( + InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) + for uid, handles in zip(requested_uids, cache_handles) + ) + + if hidden_states.numel() == 0: + pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. 
+ # when user wants to pre-allocate cache or check that server *can* allocate that cache + else: + assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor" + (hidden_states,) = await requested_backends[0].inference_pool.submit_task( + hidden_states, hypo_ids, inference_infos, *prompts, priority=priority + ) + + # serialize and send last layer outputs + output_tensors = [ + serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) + for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)) + ] + can_push = not has_prompts + yield output_tensors, can_push + + # prepare for next step + prefix_length += length_increment diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index d3776de..b9be294 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -6,7 +6,7 @@ import multiprocessing as mp import sys from enum import Enum from itertools import chain -from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple import torch from async_timeout import timeout @@ -29,12 +29,11 @@ from hivemind.utils.logging import get_logger from hivemind.utils.streaming import split_for_streaming import petals -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID from petals.server.backend import TransformerBackend +from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward from petals.server.memory_cache import Handle -from petals.server.task_pool import PrioritizedTaskPool from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase -from petals.utils.misc import DUMMY, is_dummy logger = get_logger(__name__) @@ -147,7 +146,6 @@ class TransformerConnectionHandler(ConnectionHandler): metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") - active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) session_id = metadata.get("session_id") if not requested_uids: @@ -163,78 +161,28 @@ class TransformerConnectionHandler(ConnectionHandler): f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}" ) - point_per_piece = points / max_length if max_length > 0 else 0.0 batch_size = request.tensors[0].size[0] if request.tensors else 1 - prefix_length = 0 async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles: - assert len(cache_handles) == len(requested_backends) - first_request = request background_tasks = set() - async for request, metadata in self._iterate_inference_steps( - first_request, requests, session_id, requested_uids, context + async for output_tensors, can_push in iterate_rpc_inference( + requested_uids=requested_uids, + requested_backends=requested_backends, + active_adapter=self._get_active_adapter(metadata), + input_iterator=self._iterate_inference_steps( + request, requests, session_id, requested_uids, context + ), + cache_handles=cache_handles, + max_length=max_length, + prioritizer=self._prioritizer, + points=points, ): - hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors) - - # Cast inputs to backend dtype - 
hidden_states = hidden_states.to(requested_backends[0].dtype) - assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" - - # parse deep prompts (optional argument) - has_prompts = prompts is not None and not is_dummy(prompts) - if not has_prompts: - prompts = [None] * len(requested_backends) - else: - prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] - prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts] - - if not (len(requested_backends) == len(prompts)): - raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends") - - length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq) - if prefix_length + length_increment > max_length: - raise ValueError( - f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}" - f" exceeds pre-allocated maximum {max_length}" - ) - - priority = self._prioritizer.prioritize( - hidden_states, - hypo_ids, - points=point_per_piece, - requested_uids=requested_uids, - type="inference", - ) - - inference_infos = tuple( - InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) - for uid, handles in zip(requested_uids, cache_handles) - ) - - if hidden_states.numel() == 0: - pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. - # when user wants to pre-allocate cache or check that server *can* allocate that cache - else: - assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor" - (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task( - hidden_states, hypo_ids, inference_infos, *prompts, priority=priority - ) - - # serialize and send last layer outputs - output_tensors = [ - serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) - for result, proto in zip( - (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema) - ) - ] - if not has_prompts: + if can_push: task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata)) background_tasks.add(task) # Keep reference until it is done to save it from GC task.add_done_callback(background_tasks.discard) yield runtime_pb2.ExpertResponse(tensors=output_tensors) - # prepare for next step - prefix_length += length_increment finally: self._log_request("rpc_inference.close", requested_uids, context) @@ -408,7 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) ), f"rpc_forward should have number of points as number or None, got {points}" - hidden_states = await _rpc_forward( + hidden_states = await run_rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -435,7 +383,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) ), f"rpc_forward_stream should have number of points as number or None, got {points}" - hidden_states = await _rpc_forward( + hidden_states = await run_rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -486,7 +434,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) ), f"rpc_backward should have number of points as number or None, got {points}" - grads = await _rpc_backward( + grads = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -511,7 +459,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) 
), f"rpc_backward_stream should have number of points as number or None, got {points}" - grads = await _rpc_backward( + grads = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -621,105 +569,3 @@ class TransformerConnectionHandler(ConnectionHandler): result.update(block_info) return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result)) - - -async def _rpc_forward( - *flat_tensors: torch.Tensor, - requested_backends: Sequence[TransformerBackend], - active_adapter: str = "", - prioritizer: TaskPrioritizerBase, - points: int = 0, -) -> torch.Tensor: - """ - Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream - - :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors - :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy) - :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass - :returns: hidden states after the last layer [batch_size, seq_length, hid_size] - """ - hidden_states, prompts = flat_tensors - dtype = requested_backends[0].dtype - # check parse input tensors and cast dtypes - hidden_states = hidden_states.to(dtype) - assert hidden_states.ndim == 3 - if prompts is None or is_dummy(prompts): - prompts = [DUMMY] * len(requested_backends) - else: - prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] - - # Run a chain of requested backends - for backend, prompt in zip(requested_backends, prompts): - if not is_dummy(prompt): - hidden_states[:, : prompt.shape[1]] += prompt - - assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" - priority = prioritizer.prioritize( - hidden_states, points=points / len(requested_backends), backend=backend, type="forward" - ) - (hidden_states,) = await backend.forward_pool.submit_task( - hidden_states, - active_adapter, - priority=priority, - ) - assert isinstance(hidden_states, torch.Tensor) - assert ( - hidden_states.ndim == 3 - ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" - - return hidden_states - - -async def _rpc_backward( - *flat_tensors: torch.Tensor, - requested_backends: Sequence[TransformerBackend], - active_adapter: str = "", - prioritizer: TaskPrioritizerBase, - points: int = 0, -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - inputs, grad_outputs, prompts = flat_tensors - # Cast inputs & grad outputs to backend dtype - inputs = inputs.to(requested_backends[0].dtype) - grad_outputs = grad_outputs.to(requested_backends[-1].dtype) - - if prompts is None or is_dummy(prompts): - prompts = [DUMMY] * len(requested_backends) - else: - prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] - - # Run a forward chain to collect intermediate inputs - # Note that we do not forward for the last module since we do not need its output - inter_inputs = [] - for backend, prompt in zip(requested_backends[:-1], prompts[:-1]): - assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" - if not is_dummy(prompt): - inputs[:, : prompt.shape[1]] += prompt - inter_inputs.append(inputs) - assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" - priority = prioritizer.prioritize( - inputs, points=points / len(requested_backends), 
backend=backend, type="forward_in_backward" - ) - (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) - - assert isinstance(inputs, torch.Tensor) - - if not is_dummy(prompts[-1]): - inputs[:, : prompts[-1].shape[1]] += prompts[-1] - inter_inputs.append(inputs) - - assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" - grad_prompts_reversed = [] - # Run a chain of requested backends - for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))): - assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" - priority = prioritizer.prioritize( - inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" - ) - (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) - - assert isinstance(grad_outputs, torch.Tensor) - if not is_dummy(prompt): - grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) - - grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY - return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape From 00d48dcbe12ec26a8fd8aaab37030f994ed84555 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 7 Aug 2023 19:47:22 +0400 Subject: [PATCH 7/8] Override float32 in config to bfloat16 (#431) --- README.md | 4 ++-- src/petals/server/block_utils.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ea7919a..aa93a43 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ Petals is a community-run system — we rely on people sharing their GPUs. Y ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia pip install git+https://github.com/bigscience-workshop/petals -python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16 +python -m petals.cli.run_server stabilityai/StableBeluga2 ``` ๐ŸชŸ **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows). @@ -57,7 +57,7 @@ python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16 ```bash sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \ - python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2 --torch_dtype float16 + python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2 ``` These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from ๐Ÿค— [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index eb5300e..effce82 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -11,7 +11,8 @@ def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype] """If dtype is "auto", resolves it using BloomConfig. 
Returns `dtype` intact otherwise.""" if dtype not in ("auto", None): return dtype - if config.torch_dtype not in ("auto", None): + if config.torch_dtype not in ("auto", None, torch.float32): + # If config specifies float32, we override it to the default dtype below return config.torch_dtype return torch.bfloat16 From 2a150770a47d2ee8fd9f9bf129a54d1afdf29e64 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 7 Aug 2023 21:43:21 +0400 Subject: [PATCH 8/8] Prefer longer servers for fine-tuning, exclude unreachable (#448) We choose longer servers to minimize the number of hops but leave some randomization to distribute the load. We also exclude servers known to be unreachable. --- src/petals/client/routing/sequence_manager.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index b19d468..7328cdc 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -50,7 +50,7 @@ class SequenceManagerConfig: ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo) - max_pinged: int = 5 # max servers to ping from each sequence side, per update + max_pinged: int = 3 # max servers to ping from each sequence side, per update ping_timeout: float = 2 # max time to wait for pings, per update @@ -293,6 +293,8 @@ class RemoteSequenceManager: return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]: + client_server_rtts = self.ping_aggregator.to_dict() + span_sequence = [] current_index = start_index while current_index < end_index: @@ -300,7 +302,13 @@ class RemoteSequenceManager: if not candidate_spans: raise MissingBlocksError(current_index) - span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64) + # We choose longer servers to minimize the number of hops but leave some randomization + # to distribute the load. We also exclude servers known to be unreachable. + eps = 1e-6 + span_weights = np.array( + [span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans], + dtype=np.float64, + ) chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum()) assert chosen_span.start <= current_index < chosen_span.end @@ -361,9 +369,13 @@ class RemoteSequenceManager: self.state.sequence_info.update_(new_block_infos) first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]] + middle_servers = [ + span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans + ] last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]] pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged)) + pinged_servers = set(sample_up_to(middle_servers, self.config.max_pinged)) pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged)) self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)
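
For readers following the routing change in [PATCH 8/8], the sketch below isolates the new weighting rule with made-up spans and RTT measurements; the real code reads these from `spans_containing_block` and `ping_aggregator`, and the `Span` class here is a simplified stand-in for `RemoteSpanInfo`:

```python
# Simplified illustration of the span weighting from PATCH 8/8: prefer longer spans
# (fewer hops), keep some randomness to spread load, and give a near-zero weight to
# servers whose measured RTT is infinite (unreachable). `Span` is a made-up stand-in.
from dataclasses import dataclass

import numpy as np


@dataclass
class Span:
    peer_id: str
    length: int  # number of consecutive blocks this server holds


candidate_spans = [Span("peer_A", 40), Span("peer_B", 10), Span("peer_C", 25)]
client_server_rtts = {"peer_A": 0.12, "peer_B": np.inf, "peer_C": 0.30}  # peer_B is unreachable

eps = 1e-6
span_weights = np.array(
    [span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
    dtype=np.float64,
)
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
print(chosen_span.peer_id, span_weights / span_weights.sum())
```

Giving unreachable servers a tiny weight instead of dropping them outright keeps the probability vector valid even when every candidate for a block currently looks unreachable.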