forked from Archives/langchain
Add HuggingFace Hub Embeddings (#125)
Add support for calling HuggingFace embedding models using the HuggingFaceHub Inference API. New class mirrors the existing HuggingFaceHub LLM implementation. Currently only supports 'sentence-transformers' models. Closes #86harrison/logging_to_file
parent
d0415952f7
commit
b90e25f786
@ -1,6 +1,12 @@
|
||||
"""Wrappers around embedding modules."""
|
||||
from langchain.embeddings.cohere import CohereEmbeddings
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
__all__ = ["OpenAIEmbeddings", "HuggingFaceEmbeddings", "CohereEmbeddings"]
|
||||
__all__ = [
|
||||
"OpenAIEmbeddings",
|
||||
"HuggingFaceEmbeddings",
|
||||
"CohereEmbeddings",
|
||||
"HuggingFaceHubEmbeddings",
|
||||
]
|
||||
|
@ -0,0 +1,105 @@
|
||||
"""Wrapper around HuggingFace Hub embedding models."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
DEFAULT_REPO_ID = "sentence-transformers/all-mpnet-base-v2"
|
||||
VALID_TASKS = ("feature-extraction",)
|
||||
|
||||
|
||||
class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
|
||||
"""Wrapper around HuggingFaceHub embedding models.
|
||||
|
||||
To use, you should have the ``huggingface_hub`` python package installed, and the
|
||||
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import HuggingFaceHubEmbeddings
|
||||
repo_id = "sentence-transformers/all-mpnet-base-v2"
|
||||
hf = HuggingFaceHubEmbeddings(
|
||||
repo_id=repo_id,
|
||||
task="feature-extraction",
|
||||
huggingfacehub_api_token="my-api-key",
|
||||
)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
repo_id: str = DEFAULT_REPO_ID
|
||||
"""Model name to use."""
|
||||
task: Optional[str] = "feature-extraction"
|
||||
"""Task to call the model with."""
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
huggingfacehub_api_token: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
huggingfacehub_api_token = get_from_dict_or_env(
|
||||
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
|
||||
)
|
||||
try:
|
||||
from huggingface_hub.inference_api import InferenceApi
|
||||
|
||||
repo_id = values["repo_id"]
|
||||
if not repo_id.startswith("sentence-transformers"):
|
||||
raise ValueError(
|
||||
"Currently only 'sentence-transformers' embedding models "
|
||||
f"are supported. Got invalid 'repo_id' {repo_id}."
|
||||
)
|
||||
client = InferenceApi(
|
||||
repo_id=repo_id,
|
||||
token=huggingfacehub_api_token,
|
||||
task=values.get("task"),
|
||||
)
|
||||
if client.task not in VALID_TASKS:
|
||||
raise ValueError(
|
||||
f"Got invalid task {client.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
values["client"] = client
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import huggingface_hub python package. "
|
||||
"Please it install it with `pip install huggingface_hub`."
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
# replace newlines, which can negatively affect performance.
|
||||
texts = [text.replace("\n", " ") for text in texts]
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
responses = self.client(inputs=texts, params=_model_kwargs)
|
||||
return responses
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
response = self.embed_documents([text])[0]
|
||||
return response
|
@ -0,0 +1,19 @@
|
||||
"""Test HuggingFaceHub embeddings."""
|
||||
from langchain.embeddings import HuggingFaceHubEmbeddings
|
||||
|
||||
|
||||
def test_huggingfacehub_embedding_documents() -> None:
|
||||
"""Test huggingfacehub embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = HuggingFaceHubEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 768
|
||||
|
||||
|
||||
def test_huggingfacehub_embedding_query() -> None:
|
||||
"""Test huggingfacehub embeddings."""
|
||||
document = "foo bar"
|
||||
embedding = HuggingFaceHubEmbeddings()
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) == 768
|
@ -0,0 +1 @@
|
||||
"""All unit tests for Embeddings objects."""
|
@ -0,0 +1,11 @@
|
||||
"""Test HuggingFaceHub embeddings."""
|
||||
import pytest
|
||||
|
||||
from langchain.embeddings import HuggingFaceHubEmbeddings
|
||||
|
||||
|
||||
def test_huggingfacehub_embedding_invalid_repo() -> None:
|
||||
"""Test huggingfacehub embedding repo id validation."""
|
||||
# Only sentence-transformers models are currently supported.
|
||||
with pytest.raises(ValueError):
|
||||
HuggingFaceHubEmbeddings(repo_id="allenai/specter")
|
Loading…
Reference in New Issue