@ -1,9 +1,9 @@
import json
import json
import os
from typing import Any , Dict , List , Optional
from typing import Any , Dict , List , Optional
from langchain_core . embeddings import Embeddings
from langchain_core . embeddings import Embeddings
from langchain_core . pydantic_v1 import BaseModel , Extra , root_validator
from langchain_core . pydantic_v1 import BaseModel , Extra , root_validator
from langchain_core . utils import get_from_dict_or_env
DEFAULT_MODEL = " sentence-transformers/all-mpnet-base-v2 "
DEFAULT_MODEL = " sentence-transformers/all-mpnet-base-v2 "
VALID_TASKS = ( " feature-extraction " , )
VALID_TASKS = ( " feature-extraction " , )
@ -48,9 +48,10 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
@root_validator ( )
@root_validator ( )
def validate_environment ( cls , values : Dict ) - > Dict :
def validate_environment ( cls , values : Dict ) - > Dict :
""" Validate that api key and python package exists in environment. """
""" Validate that api key and python package exists in environment. """
huggingfacehub_api_token = get_from_dict_or_ env(
huggingfacehub_api_token = values[ " huggingfacehub_api_token " ] or os . get env(
values , " huggingfacehub_api_token " , " HUGGINGFACEHUB_API_TOKEN "
" HUGGINGFACEHUB_API_TOKEN "
)
)
try :
try :
from huggingface_hub import InferenceClient
from huggingface_hub import InferenceClient
@ -92,7 +93,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
texts = [ text . replace ( " \n " , " " ) for text in texts ]
texts = [ text . replace ( " \n " , " " ) for text in texts ]
_model_kwargs = self . model_kwargs or { }
_model_kwargs = self . model_kwargs or { }
responses = self . client . post (
responses = self . client . post (
json = { " inputs " : texts , " parameters " : _model_kwargs , " task " : self . task }
json = { " inputs " : texts , " parameters " : _model_kwargs }, task = self . task
)
)
return json . loads ( responses . decode ( ) )
return json . loads ( responses . decode ( ) )