diff --git a/docs/docs/integrations/text_embedding/text_embeddings_inference.ipynb b/docs/docs/integrations/text_embedding/text_embeddings_inference.ipynb index dafedba496..29c5b67ee2 100644 --- a/docs/docs/integrations/text_embedding/text_embeddings_inference.ipynb +++ b/docs/docs/integrations/text_embedding/text_embeddings_inference.ipynb @@ -14,7 +14,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "579f0677-aa06-4ad8-a816-3520c8d6923c", "metadata": { "tags": [] @@ -50,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "22b09777-5ba3-4fbe-81cf-a702a55df9c4", "metadata": { "tags": [] @@ -62,45 +62,19 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "c26fca9f-cfdb-45e5-a0bd-f677ff8b9d92", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdin", - "output_type": "stream", - "text": [ - "Enter your HF API Key:\n", - "\n", - " ········\n" - ] - } - ], - "source": [ - "from getpass import getpass\n", - "\n", - "huggingfacehub_api_token = getpass(\"Enter your HF API Key:\\n\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "id": "f9a92970-16f4-458c-b186-2a83e9f7d840", "metadata": { "tags": [] }, "outputs": [], "source": [ - "embeddings = HuggingFaceHubEmbeddings(\n", - " model=\"http://localhost:8080\", huggingfacehub_api_token=huggingfacehub_api_token\n", - ")" + "embeddings = HuggingFaceHubEmbeddings(model=\"http://localhost:8080\")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "id": "42105438-9fee-460a-9c52-b7c595722758", "metadata": { "tags": [] @@ -112,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 5, "id": "20167762-0988-4205-bbd4-1f20fd9dd247", "metadata": { "tags": [] @@ -124,7 +98,7 @@ "[0.018113142, 0.00302585, -0.049911194]" ] }, - "execution_count": 8, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -136,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 6, "id": "54b87cf6-86ad-46f5-b2cd-17eb43cb4d0b", "metadata": { "tags": [] @@ -145,6 +119,29 @@ "source": [ "doc_result = embeddings.embed_documents([text])" ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6fba8be9-fabf-4972-8334-aa56ed9893e1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.018113142, 0.00302585, -0.049911194]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "doc_result[0][:3]" + ] } ], "metadata": { diff --git a/libs/community/langchain_community/embeddings/huggingface_hub.py b/libs/community/langchain_community/embeddings/huggingface_hub.py index 773dddad7d..21d6113f05 100644 --- a/libs/community/langchain_community/embeddings/huggingface_hub.py +++ b/libs/community/langchain_community/embeddings/huggingface_hub.py @@ -1,9 +1,9 @@ import json +import os from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator -from langchain_core.utils import get_from_dict_or_env DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2" VALID_TASKS = ("feature-extraction",) @@ -48,9 +48,10 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - huggingfacehub_api_token = get_from_dict_or_env( - values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN" + huggingfacehub_api_token = values["huggingfacehub_api_token"] or os.getenv( + "HUGGINGFACEHUB_API_TOKEN" ) + try: from huggingface_hub import InferenceClient @@ -92,7 +93,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings): texts = [text.replace("\n", " ") for text in texts] _model_kwargs = self.model_kwargs or {} responses = self.client.post( - json={"inputs": texts, "parameters": _model_kwargs, "task": self.task} + json={"inputs": texts, "parameters": _model_kwargs}, task=self.task ) return json.loads(responses.decode())