diff --git a/docs/docs/integrations/text_embedding/huggingfacehub.ipynb b/docs/docs/integrations/text_embedding/huggingfacehub.ipynb index b14b52c751..da96bb3cb5 100644 --- a/docs/docs/integrations/text_embedding/huggingfacehub.ipynb +++ b/docs/docs/integrations/text_embedding/huggingfacehub.ipynb @@ -106,7 +106,7 @@ "metadata": {}, "outputs": [ { - "name": "stdin", + "name": "stdout", "output_type": "stream", "text": [ "Enter your HF Inference API Key:\n", @@ -148,6 +148,75 @@ "query_result = embeddings.embed_query(text)\n", "query_result[:3]" ] + }, + { + "cell_type": "markdown", + "id": "19ef2d31", + "metadata": {}, + "source": [ + "## Hugging Face Hub\n", + "We can also generate embeddings locally via the Hugging Face Hub package, which requires us to install ``huggingface_hub ``" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39e85945", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install huggingface_hub" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c78a2779", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.embeddings import HuggingFaceHubEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "116f3ce7", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = HuggingFaceHubEmbeddings()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6f97ee9", + "metadata": {}, + "outputs": [], + "source": [ + "text = \"This is a test document.\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb6adc67", + "metadata": {}, + "outputs": [], + "source": [ + "query_result = embeddings.embed_query(text)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f42c311", + "metadata": {}, + "outputs": [], + "source": [ + "query_result[:3]" + ] } ], "metadata": { diff --git a/libs/community/langchain_community/embeddings/huggingface_hub.py 
b/libs/community/langchain_community/embeddings/huggingface_hub.py index 21d6113f05..465f8c6925 100644 --- a/libs/community/langchain_community/embeddings/huggingface_hub.py +++ b/libs/community/langchain_community/embeddings/huggingface_hub.py @@ -29,6 +29,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings): """ client: Any #: :meta private: + async_client: Any #: :meta private: model: Optional[str] = None """Model name to use.""" repo_id: Optional[str] = None @@ -53,7 +54,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings): ) try: - from huggingface_hub import InferenceClient + from huggingface_hub import AsyncInferenceClient, InferenceClient if values["model"]: values["repo_id"] = values["model"] @@ -67,12 +68,20 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings): model=values["model"], token=huggingfacehub_api_token, ) + + async_client = AsyncInferenceClient( + model=values["model"], + token=huggingfacehub_api_token, + ) + if values["task"] not in VALID_TASKS: raise ValueError( f"Got invalid task {values['task']}, " f"currently only {VALID_TASKS} are supported" ) values["client"] = client + values["async_client"] = async_client + except ImportError: raise ImportError( "Could not import huggingface_hub python package. " @@ -97,6 +106,23 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings): ) return json.loads(responses.decode()) + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + """Async Call to HuggingFaceHub's embedding endpoint for embedding search docs. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + # replace newlines, which can negatively affect performance. 
+ texts = [text.replace("\n", " ") for text in texts] + _model_kwargs = self.model_kwargs or {} + responses = await self.async_client.post( + json={"inputs": texts, "parameters": _model_kwargs}, task=self.task + ) + return json.loads(responses.decode()) + def embed_query(self, text: str) -> List[float]: """Call out to HuggingFaceHub's embedding endpoint for embedding query text. @@ -108,3 +134,15 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings): """ response = self.embed_documents([text])[0] return response + + async def aembed_query(self, text: str) -> List[float]: + """Async Call to HuggingFaceHub's embedding endpoint for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + response = (await self.aembed_documents([text]))[0] + return response diff --git a/libs/community/tests/integration_tests/embeddings/test_huggingface_hub.py b/libs/community/tests/integration_tests/embeddings/test_huggingface_hub.py index b22bc94088..0a47d97f53 100644 --- a/libs/community/tests/integration_tests/embeddings/test_huggingface_hub.py +++ b/libs/community/tests/integration_tests/embeddings/test_huggingface_hub.py @@ -13,6 +13,15 @@ def test_huggingfacehub_embedding_documents() -> None: assert len(output[0]) == 768 +async def test_huggingfacehub_embedding_async_documents() -> None: + """Test huggingfacehub embeddings.""" + documents = ["foo bar"] + embedding = HuggingFaceHubEmbeddings() + output = await embedding.aembed_documents(documents) + assert len(output) == 1 + assert len(output[0]) == 768 + + def test_huggingfacehub_embedding_query() -> None: """Test huggingfacehub embeddings.""" document = "foo bar" @@ -21,6 +30,14 @@ def test_huggingfacehub_embedding_query() -> None: assert len(output) == 768 + +async def test_huggingfacehub_embedding_async_query() -> None: + """Test huggingfacehub embeddings.""" + document = "foo bar" + embedding = HuggingFaceHubEmbeddings() + output = await embedding.aembed_query(document) + 
assert len(output) == 768 + + def test_huggingfacehub_embedding_invalid_repo() -> None: """Test huggingfacehub embedding repo id validation.""" # Only sentence-transformers models are currently supported.