|
|
|
@ -61,24 +61,36 @@ class HuggingFacePipeline(LLM, BaseModel):
|
|
|
|
|
) -> LLM:
|
|
|
|
|
"""Construct the pipeline object from model_id and task."""
|
|
|
|
|
try:
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
from transformers import (
|
|
|
|
|
AutoModelForCausalLM,
|
|
|
|
|
AutoModelForSeq2SeqLM,
|
|
|
|
|
AutoTokenizer,
|
|
|
|
|
)
|
|
|
|
|
from transformers import pipeline as hf_pipeline
|
|
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(model_id)
|
|
|
|
|
if task == "text-generation":
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(model_id)
|
|
|
|
|
elif task == "text2text-generation":
|
|
|
|
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Got invalid task {task}, "
|
|
|
|
|
f"currently only {VALID_TASKS} are supported"
|
|
|
|
|
)
|
|
|
|
|
_model_kwargs = model_kwargs or {}
|
|
|
|
|
pipeline = hf_pipeline(
|
|
|
|
|
task=task, model=model, tokenizer=tokenizer, **model_kwargs
|
|
|
|
|
task=task, model=model, tokenizer=tokenizer, **_model_kwargs
|
|
|
|
|
)
|
|
|
|
|
if pipeline.task not in VALID_TASKS:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Got invalid task {pipeline.task}, "
|
|
|
|
|
f"currently only {VALID_TASKS} are supported"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return cls(
|
|
|
|
|
pipeline=pipeline,
|
|
|
|
|
model_id=model_id,
|
|
|
|
|
model_kwargs=model_kwargs,
|
|
|
|
|
model_kwargs=_model_kwargs,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
except ImportError:
|
|
|
|
@ -100,7 +112,7 @@ class HuggingFacePipeline(LLM, BaseModel):
|
|
|
|
|
return "huggingface_pipeline"
|
|
|
|
|
|
|
|
|
|
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
|
|
|
|
response = self.pipeline(text_inputs=prompt)
|
|
|
|
|
response = self.pipeline(prompt)
|
|
|
|
|
if self.pipeline.task == "text-generation":
|
|
|
|
|
# Text generation return includes the starter text.
|
|
|
|
|
text = response[0]["generated_text"][len(prompt) :]
|
|
|
|
|