Update code example as in #58

update-readme-disclaimers-faq
Alexander Borzunov 2 years ago committed by GitHub
parent 640e9d67dc
commit 780780a8a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -28,24 +28,26 @@
### Examples
Solving a sequence classification task via soft prompt tuning of BLOOM-176B:
Petals integrates seamlessly with PyTorch and the Hugging Face [Transformers](https://github.com/huggingface/transformers) library.
This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a sequence classification task via soft prompt tuning:
```python
# Initialize distributed BLOOM with soft prompts
model = AutoModelForPromptTuning.from_pretrained(
"bigscience/distributed-bloom")
# Define optimizer for prompts and linear head
optimizer = torch.optim.AdamW(model.parameters())
# Initialize distributed BLOOM and connect to the swarm
model = DistributedBloomForCausalLM.from_pretrained(
"bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
) # Embeddings & prompts are on your device, BLOOM blocks are distributed
print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))
# Training (updates only local prompts / adapters)
optimizer = torch.optim.AdamW(model.parameters())
for input_ids, labels in data_loader:
# Forward pass with local and remote layers
outputs = model.forward(input_ids)
loss = cross_entropy(outputs.logits, labels)
# Distributed backward w.r.t. local params
loss.backward() # Compute model.prompts.grad
optimizer.step() # Update local params only
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
### 🚧 This project is in active development
@ -76,6 +78,8 @@ This is important because it's technically possible for peers serving model laye
## Installation
__[To be updated soon]__
```bash
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html

Loading…
Cancel
Save