huggingface[patch]: make HuggingFaceEndpoint serializable (#27027)

Bagatur, 2024-10-01 13:16:10 -07:00, committed by GitHub
parent 9d10151123
commit b5e28d3a6d
3 changed files with 829 additions and 791 deletions

libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py

@@ -4,14 +4,19 @@ import logging
 import os
 from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
 
+from huggingface_hub import (  # type: ignore[import-untyped]
+    AsyncInferenceClient,
+    InferenceClient,
+    login,
+)
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
-from langchain_core.utils import from_env, get_pydantic_field_names
-from pydantic import ConfigDict, Field, model_validator
+from langchain_core.utils import get_pydantic_field_names, secret_from_env
+from pydantic import ConfigDict, Field, SecretStr, model_validator
 from typing_extensions import Self
 
 logger = logging.getLogger(__name__)
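The import change swaps `from_env` for `secret_from_env`. For readers unfamiliar with the helper, here is a minimal sketch of the behavior it supplies, under the assumption that `langchain_core.utils.secret_from_env` works this way; the name `secret_from_env_sketch` is ours, not the library's:

```python
# Illustration only, not the langchain_core source: secret_from_env returns a
# zero-argument factory that reads the first populated variable from a list of
# environment variable names and wraps the value in a SecretStr.
import os
from typing import Callable, List, Optional

from pydantic import SecretStr


def secret_from_env_sketch(
    keys: List[str], *, default: Optional[str] = None
) -> Callable[[], Optional[SecretStr]]:
    def factory() -> Optional[SecretStr]:
        for key in keys:
            if key in os.environ:
                return SecretStr(os.environ[key])
        return SecretStr(default) if default is not None else None

    return factory
```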
@@ -73,10 +78,12 @@ class HuggingFaceEndpoint(LLM):
     should be pass as env variable in `HF_INFERENCE_ENDPOINT`"""
     repo_id: Optional[str] = None
     """Repo to use. If endpoint_url is not specified then this needs to given"""
-    huggingfacehub_api_token: Optional[str] = Field(
-        default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None)
+    huggingfacehub_api_token: Optional[SecretStr] = Field(
+        default_factory=secret_from_env(
+            ["HUGGINGFACEHUB_API_TOKEN", "HF_TOKEN"], default=None
+        )
     )
-    max_new_tokens: int = 512
+    max_new_tokens: int = Field(default=512, alias="max_tokens")
     """Maximum number of generated tokens"""
     top_k: Optional[int] = None
     """The number of highest probability vocabulary tokens to keep for
@@ -116,14 +123,15 @@ class HuggingFaceEndpoint(LLM):
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `call` not explicitly specified"""
     model: str
-    client: Any = None  #: :meta private:
-    async_client: Any = None  #: :meta private:
+    client: Any = Field(default=None, exclude=True)  #: :meta private:
+    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     task: Optional[str] = None
     """Task to call the model with.
     Should be a task that returns `generated_text` or `summary_text`."""
 
     model_config = ConfigDict(
         extra="forbid",
+        populate_by_name=True,
     )
 
     @model_validator(mode="before")
@@ -189,36 +197,23 @@ class HuggingFaceEndpoint(LLM):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that package is installed and that the API token is valid."""
-        try:
-            from huggingface_hub import login  # type: ignore[import]
-        except ImportError:
-            raise ImportError(
-                "Could not import huggingface_hub python package. "
-                "Please install it with `pip install huggingface_hub`."
-            )
-
-        huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
-            "HF_TOKEN"
-        )
-        if huggingfacehub_api_token is not None:
+        if self.huggingfacehub_api_token is not None:
             try:
-                login(token=huggingfacehub_api_token)
+                login(token=self.huggingfacehub_api_token.get_secret_value())
             except Exception as e:
                 raise ValueError(
                     "Could not authenticate with huggingface_hub. "
                     "Please check your API token."
                 ) from e
-
-        from huggingface_hub import AsyncInferenceClient, InferenceClient
 
         # Instantiate clients with supported kwargs
         sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters)
         self.client = InferenceClient(
             model=self.model,
             timeout=self.timeout,
-            token=huggingfacehub_api_token,
+            token=self.huggingfacehub_api_token.get_secret_value()
+            if self.huggingfacehub_api_token
+            else None,
             **{
                 key: value
                 for key, value in self.server_kwargs.items()
@ -230,7 +225,9 @@ class HuggingFaceEndpoint(LLM):
self.async_client = AsyncInferenceClient( self.async_client = AsyncInferenceClient(
model=self.model, model=self.model,
timeout=self.timeout, timeout=self.timeout,
token=huggingfacehub_api_token, token=self.huggingfacehub_api_token.get_secret_value()
if self.huggingfacehub_api_token
else None,
**{ **{
key: value key: value
for key, value in self.server_kwargs.items() for key, value in self.server_kwargs.items()
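Both client constructors are fed only the `server_kwargs` entries that `inspect.signature` reports the target accepts. The filtering trick generalizes to any callable; a small sketch (the helper name is ours):

```python
# Generic version of the kwarg filtering used when instantiating
# InferenceClient/AsyncInferenceClient above: keep only keyword arguments
# that the target callable actually declares in its signature.
import inspect
from typing import Any, Callable, Dict


def filter_supported_kwargs(
    fn: Callable[..., Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    supported = set(inspect.signature(fn).parameters)
    return {k: v for k, v in kwargs.items() if k in supported}
```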
@@ -426,3 +423,15 @@ class HuggingFaceEndpoint(LLM):
             # break if stop sequence found
             if stop_seq_found:
                 break
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return True
+
+    @classmethod
+    def get_lc_namespace(cls) -> list[str]:
+        return ["langchain_huggingface", "llms"]
+
+    @property
+    def lc_secrets(self) -> dict[str, str]:
+        return {"huggingfacehub_api_token": "HUGGINGFACEHUB_API_TOKEN"}

File diff suppressed because it is too large.

libs/partners/huggingface/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "langchain-huggingface"
-version = "0.1.0"
+version = "0.1.1"
 description = "An integration package connecting Hugging Face and LangChain"
 authors = []
 readme = "README.md"
@@ -20,7 +20,7 @@ disallow_untyped_defs = "True"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-langchain-core = ">=0.3.0,<0.4"
+langchain-core = ">=0.3.7,<0.4"
 tokenizers = ">=0.19.1"
 transformers = ">=4.39.0"
 sentence-transformers = ">=2.6.0"