From d40eb6c7015c0de2914cb013601bdd47544d16ef Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Mon, 4 Sep 2023 12:25:29 +0400
Subject: [PATCH] Fix prompt tuning after #464 (#501)

Unfortunately, running inference in models with `"ptune" in config.tuning_mode` was broken after #464.
---
 src/petals/client/remote_generation.py | 5 +++--
 src/petals/models/bloom/model.py       | 5 +++--
 src/petals/models/falcon/model.py      | 5 +++--
 src/petals/models/llama/model.py       | 5 +++--
 4 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py
index e392b4f..97a115a 100644
--- a/src/petals/client/remote_generation.py
+++ b/src/petals/client/remote_generation.py
@@ -87,10 +87,11 @@ class RemoteGenerationMixin(_SkipTokensMixin):
             max_new_tokens is None
         ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
 
+        session_max_length = self.transformer.config.pre_seq_len
         if max_length is not None:
-            session_max_length = max_length
+            session_max_length += max_length
         else:
-            session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
+            session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
         context_manager = self.inference_session(max_length=session_max_length)
 
         with context_manager as session:
diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py
index cf83822..784418f 100644
--- a/src/petals/models/bloom/model.py
+++ b/src/petals/models/bloom/model.py
@@ -71,7 +71,8 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -88,7 +89,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         )
 
         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state
diff --git a/src/petals/models/falcon/model.py b/src/petals/models/falcon/model.py
index 3a2a6b0..32c0b6f 100644
--- a/src/petals/models/falcon/model.py
+++ b/src/petals/models/falcon/model.py
@@ -77,7 +77,8 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -94,7 +95,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         )
 
         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state
diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py
index a9dfcc1..3360f40 100644
--- a/src/petals/models/llama/model.py
+++ b/src/petals/models/llama/model.py
@@ -73,7 +73,8 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -90,7 +91,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         )
 
         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state
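
For context, the core of the fix is the session-length arithmetic: with prompt tuning, `pre_seq_len` trainable prompt tokens are prepended to the input embeddings, so the server-side attention cache has to reserve room for them on top of the user-visible sequence. The standalone sketch below reproduces that arithmetic outside of Petals; the function name and parameters are illustrative and not part of the library.

    def compute_session_max_length(pre_seq_len, input_len=0, max_length=None, max_new_tokens=None):
        """Illustrative mirror of the fixed session-length logic (not Petals code)."""
        assert (max_length is None) != (
            max_new_tokens is None
        ), "Set either `max_length` or `max_new_tokens`, but not both"
        # Reserve cache positions for the prepended prompt-tuning tokens first.
        session_max_length = pre_seq_len
        if max_length is not None:
            session_max_length += max_length
        else:
            session_max_length += input_len + max_new_tokens
        return session_max_length

    # Example: a model tuned with pre_seq_len=16, a 10-token input, and 32 new tokens
    # now reserves 16 + 10 + 32 = 58 cache positions; before the fix only 42 were
    # reserved, leaving no room for the prepended prompt tokens.
    print(compute_session_max_length(pre_seq_len=16, input_len=10, max_new_tokens=32))  # 58

The `use_prompts` flag in the model files addresses the second half of the problem: the prompt prefix is now stripped from `hidden_states` only when prompts were actually prepended, i.e. when the local shard holds position 0, instead of whenever `"ptune"` appears in `config.tuning_mode`.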