Harrison/hf pipeline (#780)

Co-authored-by: Parth Chadha <parth29@gmail.com>
ankush/async-llmchain
Harrison Chase 1 year ago committed by GitHub
parent c658f0aed3
commit 5bb2952860
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,4 +1,6 @@
"""Wrapper around HuggingFace Pipeline APIs."""
import importlib.util
import logging
from typing import Any, List, Mapping, Optional
from pydantic import BaseModel, Extra
@ -10,6 +12,8 @@ DEFAULT_MODEL_ID = "gpt2"
DEFAULT_TASK = "text-generation"
VALID_TASKS = ("text2text-generation", "text-generation")
logger = logging.getLogger()
class HuggingFacePipeline(LLM, BaseModel):
"""Wrapper around HuggingFace Pipeline API.
@ -56,6 +60,7 @@ class HuggingFacePipeline(LLM, BaseModel):
cls,
model_id: str,
task: str,
device: int = -1,
model_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> LLM:
@ -68,8 +73,16 @@ class HuggingFacePipeline(LLM, BaseModel):
)
from transformers import pipeline as hf_pipeline
_model_kwargs = model_kwargs or {}
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"Please it install it with `pip install transformers`."
)
_model_kwargs = model_kwargs or {}
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
try:
if task == "text-generation":
model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
elif task == "text2text-generation":
@ -79,25 +92,47 @@ class HuggingFacePipeline(LLM, BaseModel):
f"Got invalid task {task}, "
f"currently only {VALID_TASKS} are supported"
)
pipeline = hf_pipeline(
task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs
)
if pipeline.task not in VALID_TASKS:
except ImportError as e:
raise ValueError(
f"Could not load the {task} model due to missing dependencies."
) from e
if importlib.util.find_spec("torch") is not None:
import torch
cuda_device_count = torch.cuda.device_count()
if device < -1 or (device >= cuda_device_count):
raise ValueError(
f"Got invalid task {pipeline.task}, "
f"currently only {VALID_TASKS} are supported"
f"Got device=={device}, "
f"device is required to be within [-1, {cuda_device_count})"
)
return cls(
pipeline=pipeline,
model_id=model_id,
model_kwargs=_model_kwargs,
**kwargs,
)
except ImportError:
if device < 0 and cuda_device_count > 0:
logger.warning(
"Device has %d GPUs available. "
"Provide device={deviceId} to `from_model_id` to use available"
"GPUs for execution. deviceId is -1 (default) for CPU and "
"can be a positive integer associated with CUDA device id.",
cuda_device_count,
)
pipeline = hf_pipeline(
task=task,
model=model,
tokenizer=tokenizer,
device=device,
model_kwargs=_model_kwargs,
)
if pipeline.task not in VALID_TASKS:
raise ValueError(
"Could not import transformers python package. "
"Please it install it with `pip install transformers`."
f"Got invalid task {pipeline.task}, "
f"currently only {VALID_TASKS} are supported"
)
return cls(
pipeline=pipeline,
model_id=model_id,
model_kwargs=_model_kwargs,
**kwargs,
)
@property
def _identifying_params(self) -> Mapping[str, Any]:

Loading…
Cancel
Save