[bug] fix prompt truncation HF

laurel/helm
Laurel Orr 2 years ago
parent 966fe6b5d4
commit ec78ac7cbf

@@ -4,6 +4,7 @@ from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List
 
+import torch
 from transformers import (
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
@@ -27,8 +28,12 @@ class GPT2Pipeline:
         """Initialize."""
         self.model = model
         self.tokenizer = tokenizer
-        self.device = device
-        self.model.to(self.device)  # type: ignore
+        self.device = (
+            torch.device("cpu")
+            if (device == -1 or not torch.cuda.is_available())
+            else torch.device(f"cuda:{device}")
+        )
+        self.model = self.model.to(self.device)  # type: ignore
 
     def __call__(self, text: str, **kwargs: Any) -> List[Dict[str, str]]:
         """Generate from text.
@@ -106,7 +111,7 @@ class HuggingFaceModel(Model):
             self.model_path, cache_dir=cache_dir
         )
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.pipeline = MODEL_PIPELINE[model_name](
+        self.pipeline = MODEL_PIPELINE[model_name](  # type: ignore
             model=model, tokenizer=tokenizer, device=device
         )
         self.returns_input = "gpt" in model_name
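
The returns_input flag captures a behavioral split in transformers generation: decoder-only (GPT-style) models include the prompt tokens in generate() output, while encoder-decoder models emit only new tokens, so only the former need the prompt sliced off. A quick demonstration with a causal model (the model choice is illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The capital of France is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=5)

# The decoded generation begins with the prompt itself, which is
# why the prompt prefix has to be stripped for "gpt" models.
print(tokenizer.decode(output_ids[0]))
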
@@ -142,8 +147,10 @@
             do_sample=kwargs.get("do_sample"),
             num_return_sequences=num_return,
         )
+        # Correctly returns prompt without extra spaces
+        decoded_prompt = self.pipeline.tokenizer.decode(encoded_prompt)
         if self.returns_input:
-            start_idx = len(prompt)
+            start_idx = len(decoded_prompt)
         else:
             start_idx = 0
         if num_return == 1:
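
This hunk is the actual truncation fix. Slicing the decoded generation with len(prompt) assumes the raw prompt survives the encode/decode round trip unchanged, but tokenizer.decode() can apply cleanup (for instance, removing spaces before punctuation), so the decoded prompt prefix may be shorter than the original string and the old slice could cut into or past the continuation. Measuring the offset on decoded_prompt aligns the slice with the text generate() actually produces. A small round-trip check (the prompt text is illustrative; exact cleanup behavior depends on the tokenizer and transformers version):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

prompt = "Hello , world !"  # extra spaces before punctuation
decoded = tokenizer.decode(tokenizer.encode(prompt))

# If cleanup fires, decoded is "Hello, world!" and the lengths differ,
# so a len(prompt) slice on generated text would start at the wrong offset.
print(repr(prompt), len(prompt))
print(repr(decoded), len(decoded))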
