Fix max token length huggingface models

This commit is contained in:
Laurel Orr 2022-07-31 22:14:45 +00:00
parent d610e3c800
commit b094cecb4f

View File

@ -47,6 +47,7 @@ class Pipeline:
):
"""Initialize."""
self.model = model
self.max_length = model.config.max_position_embeddings
self.tokenizer = tokenizer
self.device = (
torch.device("cpu")
@ -63,7 +64,12 @@ class Pipeline:
Returns:
generated text.
"""
encoded_prompt = self.tokenizer.encode(text, return_tensors="pt")
# If text is longer than max model length, we reduce max input length to ensure
# the user indicated generation tokens is preserved.
max_input_length = kwargs.get("max_input_length")
encoded_prompt = self.tokenizer.encode(
text, max_length=max_input_length, return_tensors="pt"
)
encoded_prompt = encoded_prompt.to(self.device)
output_sequences = self.model.generate( # type: ignore
encoded_prompt,
@ -219,14 +225,18 @@ class HuggingFaceModel(Model):
list of generated text (list of length 1 for 1 generation).
"""
num_return = kwargs.get("n")
max_input_len = self.pipeline.max_length - kwargs.get("max_tokens")
# Add tokens for length
encoded_prompt_with_special = self.pipeline.tokenizer.encode(prompt)
encoded_prompt_with_special = self.pipeline.tokenizer.encode(
prompt, max_length=max_input_len
)
# Remove tokens as the pipeline removes special tokens upon return
encoded_prompt_without_special = self.pipeline.tokenizer.encode(
prompt, add_special_tokens=False
prompt, max_length=max_input_len, add_special_tokens=False
)
result = self.pipeline(
prompt,
max_input_length=max_input_len,
max_length=kwargs.get("max_tokens") + len(encoded_prompt_with_special),
temperature=kwargs.get("temperature"),
repetition_penalty=kwargs.get("repetition_penalty"),