|
|
|
@ -28,24 +28,26 @@
|
|
|
|
|
|
|
|
|
|
### Examples
|
|
|
|
|
|
|
|
|
|
Solving a sequence classification task via soft prompt tuning of BLOOM-176B:
|
|
|
|
|
Petals integrates seamlessly with PyTorch and the Hugging Face [Transformers](https://github.com/huggingface/transformers) library.
|
|
|
|
|
|
|
|
|
|
This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a sequence classification task via soft prompt tuning:
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
# Initialize distributed BLOOM with soft prompts
|
|
|
|
|
model = AutoModelForPromptTuning.from_pretrained(
|
|
|
|
|
"bigscience/distributed-bloom")
|
|
|
|
|
# Define optimizer for prompts and linear head
|
|
|
|
|
optimizer = torch.optim.AdamW(model.parameters())
|
|
|
|
|
# Initialize distributed BLOOM and connect to the swarm
|
|
|
|
|
model = DistributedBloomForCausalLM.from_pretrained(
|
|
|
|
|
"bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
|
|
|
|
|
) # Embeddings & prompts are on your device, BLOOM blocks are distributed
|
|
|
|
|
|
|
|
|
|
print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))
|
|
|
|
|
|
|
|
|
|
# Training (updates only local prompts / adapters)
|
|
|
|
|
optimizer = torch.optim.AdamW(model.parameters())
|
|
|
|
|
for input_ids, labels in data_loader:
|
|
|
|
|
# Forward pass with local and remote layers
|
|
|
|
|
outputs = model.forward(input_ids)
|
|
|
|
|
loss = cross_entropy(outputs.logits, labels)
|
|
|
|
|
|
|
|
|
|
# Distributed backward w.r.t. local params
|
|
|
|
|
loss.backward() # Compute model.prompts.grad
|
|
|
|
|
optimizer.step() # Update local params only
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
loss.backward()
|
|
|
|
|
optimizer.step()
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
### 🚧 This project is in active development
|
|
|
|
@ -76,6 +78,8 @@ This is important because it's technically possible for peers serving model laye
|
|
|
|
|
|
|
|
|
|
## Installation
|
|
|
|
|
|
|
|
|
|
__[To be updated soon]__
|
|
|
|
|
|
|
|
|
|
```bash
|
|
|
|
|
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
|
|
|
|
|
pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
|
|
|
|