Mirror of https://github.com/HazyResearch/manifest (synced 2024-11-02 09:40:58 +00:00)
Fix max token length for Hugging Face models
commit b094cecb4f (parent d610e3c800)
@@ -47,6 +47,7 @@ class Pipeline:
     ):
         """Initialize."""
         self.model = model
+        self.max_length = model.config.max_position_embeddings
         self.tokenizer = tokenizer
         self.device = (
             torch.device("cpu")
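
For orientation (not part of the diff): `max_position_embeddings` is the context-window size from the model config, so caching it as `self.max_length` gives the pipeline a hard cap on prompt plus generated tokens. A minimal sketch, assuming a GPT-2 checkpoint; note that not every architecture exposes this attribute under the same name:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    # GPT-2's config aliases max_position_embeddings to n_positions.
    print(model.config.max_position_embeddings)  # 1024
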
@@ -63,7 +64,12 @@ class Pipeline:
         Returns:
             generated text.
         """
-        encoded_prompt = self.tokenizer.encode(text, return_tensors="pt")
+        # If the text is longer than the max model length, reduce the max input
+        # length so the user-requested generation tokens are preserved.
+        max_input_length = kwargs.get("max_input_length")
+        encoded_prompt = self.tokenizer.encode(
+            text, max_length=max_input_length, return_tensors="pt"
+        )
         encoded_prompt = encoded_prompt.to(self.device)
         output_sequences = self.model.generate(  # type: ignore
             encoded_prompt,
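
A usage sketch of the tokenizer-level truncation this hunk relies on (my example, not from the commit; the 924 budget is an assumed value, and recent transformers releases also want an explicit truncation=True):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    long_text = "word " * 2000  # well past GPT-2's 1024-token window

    # Cap the prompt so the requested generation tokens still fit in the window.
    ids = tokenizer.encode(long_text, max_length=924, truncation=True, return_tensors="pt")
    print(ids.shape)  # torch.Size([1, 924])
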
@@ -219,14 +225,18 @@ class HuggingFaceModel(Model):
             list of generated text (list of length 1 for 1 generation).
         """
         num_return = kwargs.get("n")
+        max_input_len = self.pipeline.max_length - kwargs.get("max_tokens")
         # Add tokens for length
-        encoded_prompt_with_special = self.pipeline.tokenizer.encode(prompt)
+        encoded_prompt_with_special = self.pipeline.tokenizer.encode(
+            prompt, max_length=max_input_len
+        )
         # Remove tokens as the pipeline removes special tokens upon return
         encoded_prompt_without_special = self.pipeline.tokenizer.encode(
-            prompt, add_special_tokens=False
+            prompt, max_length=max_input_len, add_special_tokens=False
         )
         result = self.pipeline(
             prompt,
+            max_input_length=max_input_len,
             max_length=kwargs.get("max_tokens") + len(encoded_prompt_with_special),
             temperature=kwargs.get("temperature"),
             repetition_penalty=kwargs.get("repetition_penalty"),
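
The arithmetic behind `max_input_len`, as a worked example with assumed numbers (a 1024-token window and 100 requested generation tokens):

    max_length = 1024                        # pipeline.max_length (model context window)
    max_tokens = 100                         # user-requested generation budget
    max_input_len = max_length - max_tokens  # 924 tokens left for the prompt

    # encode(..., max_length=max_input_len) caps the prompt at 924 tokens, so
    # max_tokens + len(encoded_prompt_with_special) can never exceed 1024 and
    # generate() stays within the model's position embeddings.
    assert max_tokens + max_input_len <= max_length
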