fix(embeddings): huggingface hub embeddings and TEI (#14489)

**Description:** This PR fixes `HuggingFaceHubEmbeddings` by making the
API token optional (as in the client beneath). Most models don't require
one. I also updated the notebook for TEI (text-embeddings-inference)
accordingly as requested here #14288. In addition, I fixed a mistake in
the POST call parameters.

**Tag maintainers:** @baskaryan
pull/14637/head
Massimiliano Pronesti 7 months ago committed by GitHub
parent 5da79e150b
commit 6080c98108
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,7 +14,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "579f0677-aa06-4ad8-a816-3520c8d6923c",
"metadata": {
"tags": []
@ -50,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "22b09777-5ba3-4fbe-81cf-a702a55df9c4",
"metadata": {
"tags": []
@ -62,45 +62,19 @@
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c26fca9f-cfdb-45e5-a0bd-f677ff8b9d92",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
"Enter your HF API Key:\n",
"\n",
" ········\n"
]
}
],
"source": [
"from getpass import getpass\n",
"\n",
"huggingfacehub_api_token = getpass(\"Enter your HF API Key:\\n\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"id": "f9a92970-16f4-458c-b186-2a83e9f7d840",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"embeddings = HuggingFaceHubEmbeddings(\n",
" model=\"http://localhost:8080\", huggingfacehub_api_token=huggingfacehub_api_token\n",
")"
"embeddings = HuggingFaceHubEmbeddings(model=\"http://localhost:8080\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 4,
"id": "42105438-9fee-460a-9c52-b7c595722758",
"metadata": {
"tags": []
@ -112,7 +86,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 5,
"id": "20167762-0988-4205-bbd4-1f20fd9dd247",
"metadata": {
"tags": []
@ -124,7 +98,7 @@
"[0.018113142, 0.00302585, -0.049911194]"
]
},
"execution_count": 8,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@ -136,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"id": "54b87cf6-86ad-46f5-b2cd-17eb43cb4d0b",
"metadata": {
"tags": []
@ -145,6 +119,29 @@
"source": [
"doc_result = embeddings.embed_documents([text])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6fba8be9-fabf-4972-8334-aa56ed9893e1",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"[0.018113142, 0.00302585, -0.049911194]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"doc_result[0][:3]"
]
}
],
"metadata": {

@ -1,9 +1,9 @@
import json
import os
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
VALID_TASKS = ("feature-extraction",)
@ -48,9 +48,10 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
huggingfacehub_api_token = values["huggingfacehub_api_token"] or os.getenv(
"HUGGINGFACEHUB_API_TOKEN"
)
try:
from huggingface_hub import InferenceClient
@ -92,7 +93,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
texts = [text.replace("\n", " ") for text in texts]
_model_kwargs = self.model_kwargs or {}
responses = self.client.post(
json={"inputs": texts, "parameters": _model_kwargs, "task": self.task}
json={"inputs": texts, "parameters": _model_kwargs}, task=self.task
)
return json.loads(responses.decode())

Loading…
Cancel
Save