huggingface[patch]: make HuggingFaceEndpoint serializable (#27027)

This commit is contained in:
Bagatur 2024-10-01 13:16:10 -07:00 committed by GitHub
parent 9d10151123
commit b5e28d3a6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 829 additions and 791 deletions

View File

@ -4,14 +4,19 @@ import logging
import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
from huggingface_hub import ( # type: ignore[import-untyped]
AsyncInferenceClient,
InferenceClient,
login,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.utils import from_env, get_pydantic_field_names
from pydantic import ConfigDict, Field, model_validator
from langchain_core.utils import get_pydantic_field_names, secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
logger = logging.getLogger(__name__)
@ -73,10 +78,12 @@ class HuggingFaceEndpoint(LLM):
should be pass as env variable in `HF_INFERENCE_ENDPOINT`"""
repo_id: Optional[str] = None
"""Repo to use. If endpoint_url is not specified then this needs to given"""
huggingfacehub_api_token: Optional[str] = Field(
default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None)
huggingfacehub_api_token: Optional[SecretStr] = Field(
default_factory=secret_from_env(
["HUGGINGFACEHUB_API_TOKEN", "HF_TOKEN"], default=None
)
)
max_new_tokens: int = 512
max_new_tokens: int = Field(default=512, alias="max_tokens")
"""Maximum number of generated tokens"""
top_k: Optional[int] = None
"""The number of highest probability vocabulary tokens to keep for
@ -116,14 +123,15 @@ class HuggingFaceEndpoint(LLM):
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `call` not explicitly specified"""
model: str
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
task: Optional[str] = None
"""Task to call the model with.
Should be a task that returns `generated_text` or `summary_text`."""
model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)
@model_validator(mode="before")
@ -189,36 +197,23 @@ class HuggingFaceEndpoint(LLM):
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that package is installed and that the API token is valid."""
try:
from huggingface_hub import login # type: ignore[import]
except ImportError:
raise ImportError(
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
"HF_TOKEN"
)
if huggingfacehub_api_token is not None:
if self.huggingfacehub_api_token is not None:
try:
login(token=huggingfacehub_api_token)
login(token=self.huggingfacehub_api_token.get_secret_value())
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e
from huggingface_hub import AsyncInferenceClient, InferenceClient
# Instantiate clients with supported kwargs
sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters)
self.client = InferenceClient(
model=self.model,
timeout=self.timeout,
token=huggingfacehub_api_token,
token=self.huggingfacehub_api_token.get_secret_value()
if self.huggingfacehub_api_token
else None,
**{
key: value
for key, value in self.server_kwargs.items()
@ -230,7 +225,9 @@ class HuggingFaceEndpoint(LLM):
self.async_client = AsyncInferenceClient(
model=self.model,
timeout=self.timeout,
token=huggingfacehub_api_token,
token=self.huggingfacehub_api_token.get_secret_value()
if self.huggingfacehub_api_token
else None,
**{
key: value
for key, value in self.server_kwargs.items()
@ -426,3 +423,15 @@ class HuggingFaceEndpoint(LLM):
# break if stop sequence found
if stop_seq_found:
break
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> list[str]:
return ["langchain_huggingface", "llms"]
@property
def lc_secrets(self) -> dict[str, str]:
return {"huggingfacehub_api_token": "HUGGINGFACEHUB_API_TOKEN"}

File diff suppressed because it is too large Load Diff

View File

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-huggingface"
version = "0.1.0"
version = "0.1.1"
description = "An integration package connecting Hugging Face and LangChain"
authors = []
readme = "README.md"
@ -20,7 +20,7 @@ disallow_untyped_defs = "True"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
langchain-core = ">=0.3.0,<0.4"
langchain-core = ">=0.3.7,<0.4"
tokenizers = ">=0.19.1"
transformers = ">=4.39.0"
sentence-transformers = ">=2.6.0"