@ -1,7 +1,7 @@
""" Wrapper around HuggingFace embedding models. """
from typing import Any , List, Optional
from typing import Any , Dict, List, Optional
from pydantic import BaseModel , Extra
from pydantic import BaseModel , Extra , Field
from langchain . embeddings . base import Embeddings
@ -22,8 +22,10 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
. . code - block : : python
from langchain . embeddings import HuggingFaceEmbeddings
model_name = " sentence-transformers/all-mpnet-base-v2 "
hf = HuggingFaceEmbeddings ( model_name = model_name )
model_kwargs = { ' device ' : ' cpu ' }
hf = HuggingFaceEmbeddings ( model_name = model_name , model_kwargs = model_kwargs )
"""
client : Any #: :meta private:
@ -32,6 +34,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
cache_folder : Optional [ str ] = None
""" Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable . """
model_kwargs : Dict [ str , Any ] = Field ( default_factory = dict )
""" Key word arguments to pass to the model. """
def __init__ ( self , * * kwargs : Any ) :
""" Initialize the sentence_transformer. """
@ -40,7 +44,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
import sentence_transformers
self . client = sentence_transformers . SentenceTransformer (
self . model_name , self . cache_folder
self . model_name , cache_folder = self . cache_folder , * * self . model_kwargs
)
except ImportError :
raise ValueError (
@ -90,13 +94,22 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
. . code - block : : python
from langchain . embeddings import HuggingFaceInstructEmbeddings
model_name = " hkunlp/instructor-large "
hf = HuggingFaceInstructEmbeddings ( model_name = model_name )
model_kwargs = { ' device ' : ' cpu ' }
hf = HuggingFaceInstructEmbeddings (
model_name = model_name , model_kwargs = model_kwargs
)
"""
client : Any #: :meta private:
model_name : str = DEFAULT_INSTRUCT_MODEL
""" Model name to use. """
cache_folder : Optional [ str ] = None
""" Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable . """
model_kwargs : Dict [ str , Any ] = Field ( default_factory = dict )
""" Key word arguments to pass to the model. """
embed_instruction : str = DEFAULT_EMBED_INSTRUCTION
""" Instruction to use for embedding documents. """
query_instruction : str = DEFAULT_QUERY_INSTRUCTION
@ -108,7 +121,9 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
try :
from InstructorEmbedding import INSTRUCTOR
self . client = INSTRUCTOR ( self . model_name )
self . client = INSTRUCTOR (
self . model_name , cache_folder = self . cache_folder , * * self . model_kwargs
)
except ImportError as e :
raise ValueError ( " Dependencies for InstructorEmbedding not found. " ) from e