Fix prompt tuning after #464 (#501)

Unfortunately, running inference in models with `"ptune" in config.tuning_mode` was broken after #464: the reserved session length did not account for the `pre_seq_len` prompt tokens, and the p-tuning prefix was stripped from the hidden states even when it had never been prepended. This commit reserves room for the prompt tokens when computing `session_max_length` and gates both the insertion and the removal of the prefix on the same `use_prompts` flag.
Alexander Borzunov 8 months ago committed by GitHub
parent dd4a3230bc
commit d40eb6c701

@@ -87,10 +87,11 @@ class RemoteGenerationMixin(_SkipTokensMixin):
             max_new_tokens is None
         ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
 
+        session_max_length = self.transformer.config.pre_seq_len
         if max_length is not None:
-            session_max_length = max_length
+            session_max_length += max_length
         else:
-            session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
+            session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
         context_manager = self.inference_session(max_length=session_max_length)
         with context_manager as session:

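With prompt tuning, every inference session also has to hold the `pre_seq_len` virtual prompt tokens, so the server-side attention cache must be reserved for them as well. A minimal sketch of the corrected arithmetic (the numbers below are made up for illustration):

# Illustrative values only -- pre_seq_len comes from the p-tuning config.
pre_seq_len = 16      # virtual prompt tokens prepended by p-tuning
prompt_len = 10       # length of the user's input, i.e. inputs.shape[1]
max_new_tokens = 50   # tokens to generate

# Before the fix: the prompt prefix was not counted, so the reserved
# cache was pre_seq_len tokens too short.
old_session_max_length = prompt_len + max_new_tokens     # 60

# After the fix: start from pre_seq_len and add the rest on top.
session_max_length = pre_seq_len
session_max_length += prompt_len + max_new_tokens        # 76
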
@@ -71,7 +71,8 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -88,7 +89,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         )
 
         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state

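The model-side change is the same in all three model families: the condition for prepending the p-tuning prefix (which also requires that the client holds the first block, `position == 0`) is now computed once into `use_prompts`, and the same flag guards the prefix removal. Previously the removal branch fired whenever `"ptune"` was enabled, so a client whose block sequence does not start at position 0 would slice `pre_seq_len` positions off hidden states that never had a prefix. A tiny sketch of that failure mode (tensor shapes are illustrative):

import torch

pre_seq_len, seq_len, hidden_size = 16, 10, 4096
# Hidden states of a client that never prepended the prompt prefix
# (e.g. its block sequence does not start at position 0).
hidden_states = torch.zeros(1, seq_len, hidden_size)

# Old behaviour: strip the prefix whenever "ptune" is enabled.
broken = hidden_states[:, pre_seq_len:]
print(broken.shape)   # torch.Size([1, 0, 4096]) -- the whole sequence is gone

# New behaviour: strip only if the prefix was actually added.
use_prompts = False   # the insertion condition was not met
fixed = hidden_states[:, pre_seq_len:] if use_prompts else hidden_states
print(fixed.shape)    # torch.Size([1, 10, 4096])
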
@@ -77,7 +77,8 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -94,7 +95,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         )
 
         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state

@@ -73,7 +73,8 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -90,7 +91,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         )
 
         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state

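For reference, a rough sketch of the inference path this commit repairs, assuming a p-tuned Petals client (the model id and `pre_seq_len` value are illustrative, not taken from this PR):

from transformers import AutoTokenizer
from petals import DistributedBloomForCausalLM

MODEL_NAME = "bigscience/bloom-petals"   # example model id, substitute your own
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME, tuning_mode="ptune", pre_seq_len=16
)

inputs = tokenizer("A quick test of prompt tuning:", return_tensors="pt")["input_ids"]
# The reserved session length now includes the 16 prompt tokens, so generation
# with "ptune" no longer overruns the server-side attention cache.
outputs = model.generate(inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0]))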