Merge branch 'main' into justheuristic-patch-3

pull/60/head
justheuristic committed via GitHub
commit 0a39b0ba63

@@ -11,12 +11,13 @@
## Key features
- Run inference or fine-tune [BLOOM-176B](https://huggingface.co/bigscience/bloom) by joining compute resources with people all over the Internet. No need to have high-end GPUs.
- One inference step takes ≈ 1 sec — much faster than possible with offloading. Enough for chatbots and other interactive apps.
- Employ any fine-tuning and sampling methods by accessing the model's hidden states and changing its control flow — something you can't do in proprietary APIs.
- Run inference or fine-tune large language models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) by joining compute resources with people all over the Internet. No need to have high-end GPUs.
- It's difficult to fit the whole BLOOM-176B into GPU memory [unless](https://twitter.com/Tim_Dettmers/status/1559892918395031552) you have multiple high-end GPUs. Instead, **Petals** allows you to load and serve a small part of the model, then team up with people serving all the other parts to run inference or fine-tuning.
- This way, one inference step takes ≈ 1 sec — much faster than possible with offloading. Enough for chatbots and other interactive apps.
- Beyond traditional language model APIs — you can employ any fine-tuning and sampling methods by executing custom paths through the model or accessing its hidden states. This allows for the comforts of an API with the flexibility of PyTorch.
<p align="center">
<b><a href="https://petals.ml/petals.pdf">[Read paper]</a></b> | <b><a href="https://petals.ml/">[View website]</a></b>
<b><a href="https://arxiv.org/pdf/2209.01188.pdf">[Read paper]</a></b> | <b><a href="https://petals.ml/">[View website]</a></b>
</p>
## How does it work?
@@ -25,36 +26,62 @@
<img src="https://i.imgur.com/RTYF3yW.png" width="800">
</p>
### 🚧 This project is in active development
### 🛠️ Examples
Be careful: some features may not work, interfaces may change, and we have no detailed docs yet (see [roadmap](https://github.com/bigscience-workshop/petals/issues/12)).
Petals integrates seamlessly with PyTorch and the Hugging Face [Transformers](https://github.com/huggingface/transformers) library.
A stable version of the code and a public swarm open to everyone will be released in November 2022. You can [subscribe](https://petals.ml/) to be emailed when it happens or fill in [this form](https://forms.gle/TV3wtRPeHewjZ1vH9) to help the public launch by donating GPU time. In the meantime, you can launch and use your own private swarm.
This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a sequence classification task via soft prompt tuning:
## Code examples
```python
# Initialize distributed BLOOM and connect to the swarm
model = DistributedBloomForCausalLM.from_pretrained(
"bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
) # Embeddings & prompts are on your device, BLOOM blocks are distributed
Solving a sequence classification task via soft prompt tuning of BLOOM-176B:
print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))
```python
# Initialize distributed BLOOM with soft prompts
model = AutoModelForPromptTuning.from_pretrained(
"bigscience/distributed-bloom")
# Define optimizer for prompts and linear head
# Training (updates only local prompts / adapters)
optimizer = torch.optim.AdamW(model.parameters())
for input_ids, labels in data_loader:
# Forward pass with local and remote layers
outputs = model.forward(input_ids)
loss = cross_entropy(outputs.logits, labels)
# Distributed backward w.r.t. local params
loss.backward() # Compute model.prompts.grad
optimizer.step() # Update local params only
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
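For readability, here is the post-merge snippet assembled from the added lines above into one piece. `SEE_BELOW`, `tokenized_prefix`, and `data_loader` are placeholders from the README; the import lines are an assumption about this repo's layout and are not part of the diff:

```python
import torch
from torch.nn.functional import cross_entropy
from src.client.remote_model import DistributedBloomForCausalLM  # import path is a guess, not part of the diff

# Initialize distributed BLOOM and connect to the swarm
model = DistributedBloomForCausalLM.from_pretrained(
    "bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
)  # Embeddings & prompts are on your device, BLOOM blocks are distributed

# (a) Generate text
print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))

# (b) Training (updates only local prompts / adapters)
optimizer = torch.optim.AdamW(model.parameters())
for input_ids, labels in data_loader:
    outputs = model.forward(input_ids)
    loss = cross_entropy(outputs.logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```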
### 🚧 This project is in active development
Be careful: some features may not work, interfaces may change, and we have no detailed docs yet (see [roadmap](https://github.com/bigscience-workshop/petals/issues/12)).
A stable version of the code and a public swarm open to everyone will be released in November 2022. You can [subscribe](https://petals.ml/) to be emailed when it happens or fill in [this form](https://forms.gle/TV3wtRPeHewjZ1vH9) to help the public launch by donating GPU time. In the meantime, you can launch and use your own private swarm.
### 🔒 Privacy and security
If you work with sensitive data, you should only use a private swarm (or a subset of servers in the public swarm) hosted by people and institutions you trust, who are authorized to process this data.
This is important because it's technically possible for peers serving model layers to recover input data or model outputs. Also, if there are malicious peers, they may alter their outputs to influence the model outputs. See a more detailed discussion in Section 4 of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
## FAQ
1. **What's the motivation for people to host model layers in the public swarm?**
People who run inference and fine-tuning themselves get a certain speedup if they host a part of the model locally. Some may also be motivated to "give back" to the community that helps them run the model (similarly to how [BitTorrent](https://en.wikipedia.org/wiki/BitTorrent) users help others by sharing data they have already downloaded).
Since this may not be enough for everyone, we are also working on introducing explicit __incentives__ ("bloom points") for people donating their GPU time to the public swarm. Once this system is ready, people who have earned these points will be able to spend them on inference/fine-tuning with higher priority or increased security guarantees, or (maybe) exchange them for other rewards.
2. **Why is the platform named "Petals"?**
"Petals" is a metaphor for people serving different parts of the model. Together, they host the entire language model &mdash; [BLOOM](https://huggingface.co/bigscience/bloom).
While our platform focuses on BLOOM now, we aim to support more [foundation models](https://arxiv.org/abs/2108.07258) in the future.
## Installation
🚧 **Note:** These are short instructions for running a private swarm with a test 6B version of BLOOM. We will replace them with instructions involving the full 176B BLOOM and more detailed explanations soon (in a day or two).
--------------------------------------------------------------------------------
```bash
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html

@@ -22,6 +22,7 @@ from hivemind.proto import runtime_pb2
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -44,6 +45,7 @@ class RemoteTransformerBlockInferenceSession:
max_length: int,
):
self.uid, self.rpc_info = uid, rpc_info
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
# using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
@@ -69,19 +71,43 @@ class RemoteTransformerBlockInferenceSession:
if not next_input_message.uid and not next_input_message.tensors:
break # this message means "done sending"
def step(self, new_hidden_states: torch.Tensor):
"""Inference step: send a chunk of input tensors and receive a chunk of outputs"""
def step(
self,
new_hidden_states: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
):
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
:prompts: optional DEEP prompts, added to a prefix of each layer's outputs;
if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
"""
if self.closed:
raise Exception("Session is closed, cannot perform step")
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == self.num_blocks
assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
assert prompts.shape[2] <= new_hidden_states.shape[1]
assert prompts.shape[3] == new_hidden_states.shape[2]
if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY
else:
assert len(hypo_ids) == len(new_hidden_states)
assert hypo_ids.dtype == torch.int64
# serialize inputs and put them into the queue
inputs = (new_hidden_states,)
inputs = (new_hidden_states, prompts, hypo_ids)
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
tensors=[
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
],
metadata=self._serialized_metadata if not self.stepped else None,
)
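For illustration, a minimal sketch of tensor shapes that would satisfy the asserts in `step()` above, assuming a hypothetical session spanning 4 blocks; the actual call is left commented out because it requires a live session:

```python
import torch

# Toy shapes, chosen only to satisfy the asserts in step() above (not BLOOM's real sizes).
num_blocks, batch_size, hid_size = 4, 1, 64      # hypothetical session spanning 4 blocks
prefix_len, n_input_tokens = 8, 3

# On the first step, the client sends prompt embeddings + input embeddings as one chunk,
# along with per-layer ("deep") prompts and hypothesis ids.
new_hidden_states = torch.randn(batch_size, prefix_len + n_input_tokens, hid_size)
deep_prompts = torch.zeros(num_blocks, batch_size, prefix_len, hid_size)
hypo_ids = torch.arange(batch_size, dtype=torch.int64)  # identity order, i.e. no beam reordering

# outputs = session.step(new_hidden_states, prompts=deep_prompts, hypo_ids=hypo_ids)
# outputs.shape == new_hidden_states.shape
```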
@@ -161,12 +187,16 @@ class RemoteSequentialInferenceSession:
return self
def step(self, inputs: torch.Tensor):
def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
assert not self.closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
for session in self.inference_sessions:
outputs = session.step(inputs)
outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
inputs = outputs
return inputs

@@ -105,11 +105,12 @@ class RemoteGenerationMixin:
hypo_ids = torch.arange(outputs[0].size(0))
while True:
embs = self.transformer.word_embeddings(outputs[-1])
intermediate_prompts = None
if self.config.pre_seq_len > 0 and len(outputs) == 1:
prompts, _ = self.transformer.get_prompt(embs.size(0))
prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
embs = torch.cat([prompts, embs], dim=1)
embs = self.transformer.word_embeddings_layernorm(embs)
hidden_state = sess.step(embs)[:, -1]
hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
hidden_state = self.transformer.ln_f(hidden_state)
lm_logits = self.lm_head(hidden_state)

@@ -1,15 +1,16 @@
"""Code for serving bloom blocks via hivemind-server"""
from queue import Empty
from typing import Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
import torch
from hivemind import use_hivemind_log_handler
from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.moe.server.task_pool import TaskPool
from hivemind.utils import InvalidStateError, get_logger
from src.bloom.from_pretrained import BloomBlock
from src.server.cache import MemoryCache
from src.utils.misc import is_dummy
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -55,18 +56,28 @@ class TransformerBackend(ModuleBackend):
self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
)
self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
self.inference_schema = (
(
*self.args_schema,
BatchTensorDescriptor((), dtype=self.dtype),
BatchTensorDescriptor((), dtype=torch.int64),
),
self.kwargs_schema,
)
def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
with torch.inference_mode():
attention_cache_handle = int(cache_metadata[0, 0].item())
prefix_length = int(cache_metadata[0, 1].item())
hidden_states = inputs[0] # todo: in future, it would be best to support attention mask here
(hidden_states, hypo_ids) = inputs
assert (
hidden_states.ndim == 3
), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
with self.memory_cache.use_cache(attention_cache_handle) as cache:
assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
if not is_dummy(hypo_ids):
cache[:, :] = cache[:, hypo_ids] # in-place reorder cache by hypo ids
layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
hidden_states, (new_k, new_v) = self.module.forward(
@@ -85,3 +96,7 @@ class TransformerBackend(ModuleBackend):
def get_pools(self) -> Sequence[TaskPool]:
return self.forward_pool, self.backward_pool, self.inference_pool
def get_info(self) -> Dict[str, Any]:
"""Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
return dict(super().get_info(), inference_schema=self.inference_schema)
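As a side note on the `hypo_ids` reordering in `inference_step()` above, here is a toy, self-contained illustration of the in-place cache reorder; the shapes are made up and much smaller than the real attention cache:

```python
import torch

# cache[0] stores past keys and cache[1] stores past values, indexed by hypothesis.
cache = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)  # [keys/values, batch=3, ...]
hypo_ids = torch.tensor([2, 0, 0])  # e.g. beam search kept hypotheses 2, 0, 0

cache[:, :] = cache[:, hypo_ids]  # slot i now holds the past state of former hypothesis hypo_ids[i]
```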

@@ -64,41 +64,56 @@ class TransformerConnectionHandler(ConnectionHandler):
async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
assert len(cache_handles) == len(requested_backends)
while request.tensors: # iterate while user is willing to supply tensors
hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
length_increment = hidden_states[0].shape[1] # how many tokens are added this step (in each seq)
hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
if len(requested_backends) != len(prompts):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
# Cast inputs to backend dtype
hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
# run request tensors through all requested modules, update caches
for backend, cache_handle in zip(requested_backends, cache_handles):
for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
assert isinstance(
hidden_states, torch.Tensor
), f"hidden states must be tensor, got {type(hidden_states)}"
assert (
len(hidden_states) == 1 and hidden_states[0].ndim == 3
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
assert isinstance(hidden_states, (list, tuple))
assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
(hidden_states,) = await backend.inference_pool.submit_task(
cache_metadata, hidden_states, hypo_ids
)
# serialize and send last layer outputs
yield runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(
hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
(hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
)
]
)
# prepare for next step
prefix_length += hidden_states[0].shape[1]
prefix_length += hidden_states.shape[1]
request = await (anext(requests))
finally:
print("CLOSED RPC_INFERENCE")
@@ -238,23 +253,20 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
hidden_states, *prompts = flat_tensors
hidden_states, prompts = flat_tensors
dtype = requested_backends[0].dtype
# parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
assert hidden_states.ndim == 3
if not prompts or is_dummy(prompts[0]):
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
pre_seq_len = 0
else:
prompts = [prompts[0].to(requested_backends[0].dtype)]
prompts = [p.squeeze(0) for p in prompts[0].split(1)]
pre_seq_len = prompts[0].shape[-2]
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
if not is_dummy(prompt):
hidden_states[:, :pre_seq_len] += prompt
hidden_states[:, : prompt.shape[1]] += prompt
(hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
assert isinstance(hidden_states, torch.Tensor)
assert (
@@ -268,18 +280,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
async def _rpc_backward(
*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
inputs, grad_outputs, *prompts = flat_tensors
inputs, grad_outputs, prompts = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
if not prompts or is_dummy(prompts[0]):
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
pre_seq_len = 0
else:
prompts = [prompts[0].to(requested_backends[0].dtype)]
prompts = [p.squeeze(0) for p in prompts[0].split(1)]
pre_seq_len = prompts[0].shape[-2]
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
@@ -287,13 +296,13 @@ async def _rpc_backward(
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
if not is_dummy(prompt):
inputs[:, :pre_seq_len] += prompt
inputs[:, : prompt.shape[1]] += prompt
inter_inputs.append(inputs)
(inputs,) = await backend.forward_pool.submit_task(inputs)
assert isinstance(inputs, torch.Tensor)
if not is_dummy(prompts[-1]):
inputs[:, :pre_seq_len] += prompts[-1]
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
inter_inputs.append(inputs)
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
@@ -303,7 +312,7 @@ async def _rpc_backward(
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape
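A toy sketch of how `_rpc_backward` assembles the deep-prompt gradients above: since each prompt is added to a prefix slice of a layer's inputs, its gradient is simply the matching prefix slice of the gradient flowing into that layer (sizes below are invented):

```python
import torch

num_blocks, batch, seq_len, hid, prefix_len = 3, 2, 8, 16, 4

grad_prompts_reversed = []
for _ in range(num_blocks):                          # backends are visited in reverse order
    grad_outputs = torch.randn(batch, seq_len, hid)  # stand-in for one block's backward result
    grad_prompts_reversed.append(grad_outputs[:, :prefix_len].unsqueeze(0))

grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0)  # [num_blocks, batch, prefix_len, hid]
```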
