Use slightly less memory in .generate() (#177)

Alexander Borzunov authored 1 year ago, committed by GitHub
parent 55698381d0
commit e27706358c

@@ -50,7 +50,7 @@ sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache --
 Check out more examples and tutorials:
-- Chatbot web app: [link](http://chat.petals.ml), [source code](https://github.com/borzunov/petals-chat)
+- Chatbot web app (connects to Petals via an HTTP endpoint): [link](http://chat.petals.ml), [source code](https://github.com/borzunov/petals-chat)
 - Training a personified chatbot: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
 - Fine-tuning BLOOM for text semantic classification: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
 - Launching your own swarm: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)

@@ -40,7 +40,7 @@ class RemoteGenerationMixin:
         return self.transformer.h.inference_session(**kwargs)
-    @torch.no_grad()
+    @torch.inference_mode()
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
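
The decorator swap above relies on `torch.inference_mode()` being a stricter, cheaper variant of `torch.no_grad()`: besides disabling gradient recording, it skips autograd bookkeeping such as version counters and view tracking, which saves a bit of memory and overhead during pure inference. A minimal sketch of the difference (the `layer` module is only an illustration, not Petals code):

```python
import torch

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)

# no_grad(): gradients are not recorded, but the outputs still carry
# autograd metadata (version counters, view tracking).
with torch.no_grad():
    y = layer(x)

# inference_mode(): additionally skips that bookkeeping, so the forward
# pass is slightly lighter; such tensors cannot be used in autograd later.
with torch.inference_mode():
    z = layer(x)

print(y.is_inference(), z.is_inference())  # False True
```
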
@@ -171,13 +171,15 @@ class RemoteGenerationMixin:
             seq_idx = outputs[0].size(1)
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
-                embs = self.transformer.word_embeddings(outputs[-1])
+                hidden_state = self.transformer.word_embeddings(outputs[-1])
                 intermediate_prompts = None
                 if self.config.pre_seq_len > 0 and len(outputs) == 1:
-                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
-                    embs = torch.cat([prompts, embs], dim=1)
-                embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = session.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+                    prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0))
+                    hidden_state = torch.cat([prompts, hidden_state], dim=1)
+                hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
+                hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
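
The second change keeps a single `hidden_state` name bound throughout the loop body instead of a separate `embs` tensor, so each intermediate activation can be freed as soon as the next one replaces it. A rough sketch of the pattern, using stand-in functions rather than the real Petals `session.step()`:

```python
import torch

emb_layer = torch.nn.Embedding(1000, 64)  # toy sizes; real BLOOM embeddings are far larger

def remote_step(h: torch.Tensor) -> torch.Tensor:
    # Stand-in for session.step(): any op that allocates a new tensor.
    return h * 2

token = torch.tensor([[42]])

# Before: both names stay bound until the end of the loop body, so the
# embedding output cannot be freed while the step output exists.
embs = emb_layer(token)
hidden_state = remote_step(embs)

# After: rebinding one name drops the reference to the previous tensor,
# letting its memory be reclaimed as soon as the next result is computed.
hidden_state = emb_layer(token)
hidden_state = remote_step(hidden_state)
```
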
