diff --git a/README.md b/README.md
index 26927f3..f7be4c0 100644
--- a/README.md
+++ b/README.md
@@ -28,24 +28,26 @@
 ### Examples
 
-Solving a sequence classification task via soft prompt tuning of BLOOM-176B:
+Petals integrates seamlessly with PyTorch and the Hugging Face [Transformers](https://github.com/huggingface/transformers) library.
+
+This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a sequence classification task via soft prompt tuning:
 
 ```python
-# Initialize distributed BLOOM with soft prompts
-model = AutoModelForPromptTuning.from_pretrained(
-    "bigscience/distributed-bloom")
-# Define optimizer for prompts and linear head
-optimizer = torch.optim.AdamW(model.parameters())
+# Initialize distributed BLOOM and connect to the swarm
+model = DistributedBloomForCausalLM.from_pretrained(
+    "bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
+)  # Embeddings & prompts are on your device, BLOOM blocks are distributed
+
+print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))
 
+# Training (updates only local prompts / adapters)
+optimizer = torch.optim.AdamW(model.parameters())
 for input_ids, labels in data_loader:
-    # Forward pass with local and remote layers
     outputs = model.forward(input_ids)
     loss = cross_entropy(outputs.logits, labels)
-
-    # Distributed backward w.r.t. local params
-    loss.backward()  # Compute model.prompts.grad
-    optimizer.step()  # Update local params only
     optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
 ```
 
 ### 🚧 This project is in active development
 
@@ -76,6 +78,8 @@ This is important because it's technically possible for peers serving model laye
 
 ## Installation
 
+__[To be updated soon]__
+
 ```bash
 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
 pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
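
A note on the updated example: it references `tokenized_prefix` without defining it, and `initial_peers=SEE_BELOW` is a deliberate placeholder that should stay unfilled. For readers trying the snippet, here is a minimal sketch of the missing tokenization step, assuming the standard Hugging Face tokenizer API; the tokenizer repo name simply mirrors the model name in the diff, and the prompt text is arbitrary.

```python
# Sketch only: one way `tokenized_prefix` from the README snippet could be
# built. The tokenizer repo name is taken from the diff and is an assumption;
# `model` refers to the DistributedBloomForCausalLM instance created above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/distributed-bloom")
tokenized_prefix = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# With a model connected to the swarm, generation would then look like:
# outputs = model.generate(tokenized_prefix, max_new_tokens=5)
# print("Generated:", tokenizer.decode(outputs[0]))
```

The training loop that follows in the snippet is ordinary PyTorch: per the diff's own comments, only the locally held prompts and adapters receive gradient updates, while activations flow through the remotely hosted BLOOM blocks.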
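
Since the Installation section pins exact CUDA and PyTorch builds, a quick sanity check after running those commands can save debugging time. This check is a suggestion, not part of the diff:

```python
# Verify the pinned build installed correctly (hypothetical helper, not from
# the README). Expects torch 1.12.0 built against the cu113 CUDA 11.3 wheels.
import torch

print("torch version:", torch.__version__)          # expected: 1.12.0+cu113
print("CUDA available:", torch.cuda.is_available())  # True if the GPU stack works
```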