diff --git a/README.md b/README.md
index 913d223..b8cd54b 100644
--- a/README.md
+++ b/README.md
@@ -11,12 +11,13 @@
 ## Key features

-- Run inference or fine-tune [BLOOM-176B](https://huggingface.co/bigscience/bloom) by joining compute resources with people all over the Internet. No need to have high-end GPUs.
-- One inference step takes ≈ 1 sec — much faster than possible with offloading. Enough for chatbots and other interactive apps.
-- Employ any fine-tuning and sampling methods by accessing model's hidden states and changing its control flow — something you can't do in proprietary APIs.
+- Run inference or fine-tune large language models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) by joining compute resources with people all over the Internet. No need to have high-end GPUs.
+- It's difficult to fit the whole BLOOM-176B into GPU memory [unless](https://twitter.com/Tim_Dettmers/status/1559892918395031552) you have multiple high-end GPUs. Instead, **Petals** allows you to load and serve a small part of the model, then team up with people serving all the other parts to run inference or fine-tuning.
+- This way, one inference step takes ≈ 1 sec — much faster than possible with offloading. Enough for chatbots and other interactive apps.
+- Beyond traditional language model APIs — you can employ any fine-tuning and sampling methods by executing custom paths through the model or accessing its hidden states. This combines the comforts of an API with the flexibility of PyTorch.

- [Read paper] | [View website]
+ [Read paper] | [View website]

 ## How it works?
@@ -25,36 +26,62 @@

-### 🚧 This project is in active development
+### 🛠️ Examples

-Be careful: some features may not work, interfaces may change, and we have no detailed docs yet (see [roadmap](https://github.com/bigscience-workshop/petals/issues/12)).
+Petals integrates seamlessly with PyTorch and the Hugging Face [Transformers](https://github.com/huggingface/transformers) library.

-A stable version of the code and a public swarm open to everyone will be released in November 2022. You can [subscribe](https://petals.ml/) to be emailed when it happens or fill in [this form](https://forms.gle/TV3wtRPeHewjZ1vH9) to help the public launch by donating GPU time. In the meantime, you can launch and use your own private swarm.
+This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a sequence classification task via soft prompt tuning:

-## Code examples
+```python
+# Initialize distributed BLOOM and connect to the swarm
+model = DistributedBloomForCausalLM.from_pretrained(
+    "bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
+)  # Embeddings & prompts are on your device, BLOOM blocks are distributed

-Solving a sequence classification task via soft prompt tuning of BLOOM-176B:
+print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))

-```python
-# Initialize distributed BLOOM with soft prompts
-model = AutoModelForPromptTuning.from_pretrained(
-    "bigscience/distributed-bloom")
-# Define optimizer for prompts and linear head
+# Training (updates only local prompts / adapters)
 optimizer = torch.optim.AdamW(model.parameters())
-
 for input_ids, labels in data_loader:
-    # Forward pass with local and remote layers
     outputs = model.forward(input_ids)
     loss = cross_entropy(outputs.logits, labels)
-
-    # Distributed backward w.r.t. local params
-    loss.backward()   # Compute model.prompts.grad
-    optimizer.step()  # Update local params only
     optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
 ```
+### 🚧 This project is in active development
+
+Be careful: some features may not work, interfaces may change, and we have no detailed docs yet (see [roadmap](https://github.com/bigscience-workshop/petals/issues/12)).
+
+A stable version of the code and a public swarm open to everyone will be released in November 2022. You can [subscribe](https://petals.ml/) to be emailed when it happens or fill in [this form](https://forms.gle/TV3wtRPeHewjZ1vH9) to help the public launch by donating GPU time. In the meantime, you can launch and use your own private swarm.
+
+### 🔒 Privacy and security
+
+If you work with sensitive data, you should only use a private swarm (or a subset of servers in the public swarm) hosted by people and institutions you trust, who are authorized to process this data.
+
+This is important because it's technically possible for peers serving model layers to recover input data or model outputs. Also, malicious peers may alter their outputs to influence the results you receive. See a more detailed discussion in Section 4 of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
+
+## FAQ
+
+1. **What's the motivation for people to host model layers in the public swarm?**
+
+    People who run inference and fine-tuning themselves get a certain speedup if they host a part of the model locally. Some may also be motivated to "give back" to the community that helps them run the model (similarly to how [BitTorrent](https://en.wikipedia.org/wiki/BitTorrent) users help others by sharing data they have already downloaded).
+
+    Since this may not be enough for everyone, we are also working on introducing explicit __incentives__ ("bloom points") for people donating their GPU time to the public swarm. Once this system is ready, people who have earned these points will be able to spend them on inference/fine-tuning with higher priority or increased security guarantees, or (maybe) exchange them for other rewards.
+
+2. **Why is the platform named "Petals"?**
+
+    "Petals" is a metaphor for people serving different parts of the model. Together, they host the entire language model — [BLOOM](https://huggingface.co/bigscience/bloom).
+
+    While our platform focuses on BLOOM now, we aim to support more [foundation models](https://arxiv.org/abs/2108.07258) in the future.
+
 ## Installation
+🚧 **Note:** These are short instructions for running a private swarm with a test 6B version of BLOOM. We will replace them with instructions involving the full 176B BLOOM and more detailed explanations soon (in a day or two).
+
+--------------------------------------------------------------------------------
+
 ```bash
 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
 pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
diff --git a/src/client/inference_session.py b/src/client/inference_session.py
index 24852df..bb1455f 100644
--- a/src/client/inference_session.py
+++ b/src/client/inference_session.py
@@ -22,6 +22,7 @@ from hivemind.proto import runtime_pb2
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -44,6 +45,7 @@ class RemoteTransformerBlockInferenceSession:
         max_length: int,
     ):
         self.uid, self.rpc_info = uid, rpc_info
+        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
@@ -69,19 +71,43 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"

-    def step(self, new_hidden_states: torch.Tensor):
-        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+    def step(
+        self,
+        new_hidden_states: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None,
+    ):
+        """
+        Inference step: send a chunk of input tensors and receive a chunk of outputs
+        :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs;
+            if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
+        """
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]"
+            assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
+            assert prompts.shape[2] <= new_hidden_states.shape[1]
+            assert prompts.shape[3] == new_hidden_states.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY
+        else:
+            assert len(hypo_ids) == len(new_hidden_states)
+            assert hypo_ids.dtype == torch.int64
+
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
+        inputs = (new_hidden_states, prompts, hypo_ids)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
                     ],
                     metadata=self._serialized_metadata if not self.stepped else None,
                 )
@@ -161,12 +187,16 @@ class RemoteSequentialInferenceSession:
         return self

-    def step(self, inputs: torch.Tensor):
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
         assert not self.closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
         for session in self.inference_sessions:
-            outputs = session.step(inputs)
+            outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
         return inputs
diff --git a/src/client/remote_generation.py b/src/client/remote_generation.py
index e4875cc..d2be2c9 100644
--- a/src/client/remote_generation.py
+++ b/src/client/remote_generation.py
@@ -105,11 +105,12 @@ class RemoteGenerationMixin:
         hypo_ids = torch.arange(outputs[0].size(0))
         while True:
             embs = self.transformer.word_embeddings(outputs[-1])
+            intermediate_prompts = None
             if self.config.pre_seq_len > 0 and len(outputs) == 1:
-                prompts, _ = self.transformer.get_prompt(embs.size(0))
+                prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                 embs = torch.cat([prompts, embs], dim=1)
             embs = self.transformer.word_embeddings_layernorm(embs)
-            hidden_state = sess.step(embs)[:, -1]
+            hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
             hidden_state = self.transformer.ln_f(hidden_state)
             lm_logits = self.lm_head(hidden_state)

diff --git a/src/server/backend.py b/src/server/backend.py
index 9929770..27ee1ad 100644
--- a/src/server/backend.py
+++ b/src/server/backend.py
@@ -1,15 +1,16 @@
 """Code for serving bloom blocks via hivemind-server"""
 from queue import Empty
-from typing import Optional, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple

 import torch
-from hivemind import use_hivemind_log_handler
+from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.utils import InvalidStateError, get_logger

 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
+from src.utils.misc import is_dummy

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -55,18 +56,28 @@ class TransformerBackend(ModuleBackend):
             self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
         )
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
+        self.inference_schema = (
+            (
+                *self.args_schema,
+                BatchTensorDescriptor((), dtype=self.dtype),
+                BatchTensorDescriptor((), dtype=torch.int64),
+            ),
+            self.kwargs_schema,
+        )

     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
-            hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
+            (hidden_states, hypo_ids) = inputs
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"

             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
+                if not is_dummy(hypo_ids):
+                    cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
                 layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
                 hidden_states, (new_k, new_v) = self.module.forward(
@@ -85,3 +96,7 @@ class TransformerBackend(ModuleBackend):

     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(super().get_info(), inference_schema=self.inference_schema)
diff --git a/src/server/handler.py b/src/server/handler.py
index 27ed562..b2e15f7 100644
--- a/src/server/handler.py
+++ b/src/server/handler.py
@@ -64,41 +64,56 @@ class TransformerConnectionHandler(ConnectionHandler):
             async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+                    hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+                    # Cast inputs to backend dtype
+                    hidden_states = hidden_states.to(requested_backends[0].dtype)
+                    assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
+
+                    # parse deep prompts (optional argument)
+                    if prompts is None or is_dummy(prompts):
+                        prompts = [DUMMY] * len(requested_backends)
+                    else:
+                        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+                    if len(requested_backends) != len(prompts):
+                        raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+                    length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
                     if prefix_length + length_increment > max_length:
                         raise ValueError(
                             f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                             f" exceeds pre-allocated maximum {max_length}"
                         )
-                    # Cast inputs to backend dtype
-                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
                     # run request tensors through all requested modules, update caches
-                    for backend, cache_handle in zip(requested_backends, cache_handles):
+                    for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+                        if not is_dummy(prompt):
+                            hidden_states[:, : prompt.shape[1]] += prompt
+
                         cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+                        assert isinstance(
+                            hidden_states, torch.Tensor
+                        ), f"hidden states must be tensor, got {type(hidden_states)}"
                         assert (
-                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                            hidden_states.ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-
-                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
-                        assert isinstance(hidden_states, (list, tuple))
-                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                        (hidden_states,) = await backend.inference_pool.submit_task(
+                            cache_metadata, hidden_states, hypo_ids
+                        )

                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
                         tensors=[
                             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                             for result, proto in zip(
-                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
+                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
                             )
                         ]
                     )

                     # prepare for next step
-                    prefix_length += hidden_states[0].shape[1]
+                    prefix_length += hidden_states.shape[1]
                     request = await (anext(requests))
         finally:
             print("CLOSED RPC_INFERENCE")
@@ -238,23 +253,20 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    hidden_states, *prompts = flat_tensors
+    hidden_states, prompts = flat_tensors
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     assert hidden_states.ndim == 3
-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
         if not is_dummy(prompt):
-            hidden_states[:, :pre_seq_len] += prompt
+            hidden_states[:, : prompt.shape[1]] += prompt
         (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         assert isinstance(hidden_states, torch.Tensor)
         assert (
@@ -268,18 +280,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
 async def _rpc_backward(
     *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, *prompts = flat_tensors
+    inputs, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
@@ -287,13 +296,13 @@
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs[:, :pre_seq_len] += prompt
+            inputs[:, : prompt.shape[1]] += prompt
         inter_inputs.append(inputs)
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)

     if not is_dummy(prompts[-1]):
-        inputs[:, :pre_seq_len] += prompts[-1]
+        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
     inter_inputs.append(inputs)

     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
@@ -303,7 +312,7 @@
         (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
-            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
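
A minimal sketch of the shape contract behind the new `step(new_hidden_states, prompts=..., hypo_ids=...)` arguments, mirroring the assertions added in `src/client/inference_session.py`. It is not part of the patch: the sizes are made up, and the final call is left as a comment because a real `session` requires a running swarm.

```python
import torch

# Made-up sizes; real values come from the served model and the client's batch.
num_blocks, batch_size, prefix_len = 4, 2, 8     # blocks in this session, sequences, deep-prompt length
seq_len, hid_size = 16, 1024                     # tokens sent in this step, hidden dimension

new_hidden_states = torch.randn(batch_size, seq_len, hid_size)

# Deep prompts: one learned prefix per remote block, added to that block's input.
prompts = torch.zeros(num_blocks, batch_size, prefix_len, hid_size)

# Hypothesis indices for beam search: int64, one entry per sequence in the batch.
hypo_ids = torch.arange(batch_size, dtype=torch.int64)

# The same checks that RemoteTransformerBlockInferenceSession.step() now performs:
assert prompts.ndim == 4                                    # [num_layers, batch_size, prefix_len, hid_size]
assert prompts.shape[0] == num_blocks                       # one prefix per block served in this session
assert prompts.shape[1] in (new_hidden_states.shape[0], 1)  # per-sequence, or 1 to broadcast over the batch
assert prompts.shape[2] <= new_hidden_states.shape[1]       # prefix fits into the current chunk
assert prompts.shape[3] == new_hidden_states.shape[2]       # same hidden size as the model
assert hypo_ids.dtype == torch.int64 and len(hypo_ids) == len(new_hidden_states)

# With a live swarm, a client would then call (names as in the patch):
# outputs = session.step(new_hidden_states, prompts=prompts, hypo_ids=hypo_ids)
```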
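A second, equally hypothetical sketch (pure PyTorch, toy sizes) of what the server side does with these tensors. The two operations are taken from the diff: `_rpc_forward` and `rpc_inference` add each block's deep prompt to the first `prefix_len` positions of that block's input, and `TransformerBackend.inference_step` reorders the attention cache by `hypo_ids`; the cache's middle dimensions and the index values below are stand-ins.

```python
import torch

batch_size, seq_len, hid_size, prefix_len = 2, 16, 1024, 8

hidden_states = torch.randn(batch_size, seq_len, hid_size)
prompt = torch.zeros(batch_size, prefix_len, hid_size)     # one block's deep prompt

# As in _rpc_forward / rpc_inference: add the prompt to the prefix of this block's input.
hidden_states[:, : prompt.shape[1]] += prompt

# As in TransformerBackend.inference_step: hypo_ids reorder the attention cache in place,
# so each hypothesis continues from the right past keys/values after beam reordering.
cache = torch.randn(2, batch_size, 4, seq_len, hid_size)   # toy stand-in; dim 0 holds keys/values
hypo_ids = torch.tensor([1, 1], dtype=torch.int64)         # e.g. both beams continue from sequence 1
cache[:, :] = cache[:, hypo_ids]
```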