diff --git a/langchain/embeddings/huggingface.py b/langchain/embeddings/huggingface.py index 217b6a84..b0bd03e9 100644 --- a/langchain/embeddings/huggingface.py +++ b/langchain/embeddings/huggingface.py @@ -1,7 +1,7 @@ """Wrapper around HuggingFace embedding models.""" -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Extra +from pydantic import BaseModel, Extra, Field from langchain.embeddings.base import Embeddings @@ -22,8 +22,10 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): .. code-block:: python from langchain.embeddings import HuggingFaceEmbeddings + model_name = "sentence-transformers/all-mpnet-base-v2" - hf = HuggingFaceEmbeddings(model_name=model_name) + model_kwargs = {'device': 'cpu'} + hf = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs) """ client: Any #: :meta private: @@ -32,6 +34,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): cache_folder: Optional[str] = None """Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Key word arguments to pass to the model.""" def __init__(self, **kwargs: Any): """Initialize the sentence_transformer.""" @@ -40,7 +44,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): import sentence_transformers self.client = sentence_transformers.SentenceTransformer( - self.model_name, self.cache_folder + self.model_name, cache_folder=self.cache_folder, **self.model_kwargs ) except ImportError: raise ValueError( @@ -90,13 +94,22 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): .. code-block:: python from langchain.embeddings import HuggingFaceInstructEmbeddings + model_name = "hkunlp/instructor-large" - hf = HuggingFaceInstructEmbeddings(model_name=model_name) + model_kwargs = {'device': 'cpu'} + hf = HuggingFaceInstructEmbeddings( + model_name=model_name, model_kwargs=model_kwargs + ) """ client: Any #: :meta private: model_name: str = DEFAULT_INSTRUCT_MODEL """Model name to use.""" + cache_folder: Optional[str] = None + """Path to store models. + Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Key word arguments to pass to the model.""" embed_instruction: str = DEFAULT_EMBED_INSTRUCTION """Instruction to use for embedding documents.""" query_instruction: str = DEFAULT_QUERY_INSTRUCTION @@ -108,7 +121,9 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): try: from InstructorEmbedding import INSTRUCTOR - self.client = INSTRUCTOR(self.model_name) + self.client = INSTRUCTOR( + self.model_name, cache_folder=self.cache_folder, **self.model_kwargs + ) except ImportError as e: raise ValueError("Dependencies for InstructorEmbedding not found.") from e