From 36720cb57f0c12ce27d246858b126f1d0eda2fb2 Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Thu, 20 Apr 2023 20:41:22 -0700 Subject: [PATCH] Hf emb device (#3266) Make it possible to control the model kwargs passed to the HuggingFaceEmbeddings and HuggingFaceInstructEmbeddings clients. Additionally, a cache_folder option was added to HuggingFaceInstructEmbeddings, since its client inherits from SentenceTransformer (the client of HuggingFaceEmbeddings). Passing model kwargs is especially useful for controlling the client device, which sentence_transformers defaults to a GPU if one is available. --------- Co-authored-by: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> --- langchain/embeddings/huggingface.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/langchain/embeddings/huggingface.py b/langchain/embeddings/huggingface.py index 217b6a84..b0bd03e9 100644 --- a/langchain/embeddings/huggingface.py +++ b/langchain/embeddings/huggingface.py @@ -1,7 +1,7 @@ """Wrapper around HuggingFace embedding models.""" -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Extra +from pydantic import BaseModel, Extra, Field from langchain.embeddings.base import Embeddings @@ -22,8 +22,10 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): .. code-block:: python from langchain.embeddings import HuggingFaceEmbeddings + model_name = "sentence-transformers/all-mpnet-base-v2" - hf = HuggingFaceEmbeddings(model_name=model_name) + model_kwargs = {'device': 'cpu'} + hf = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs) """ client: Any #: :meta private: @@ -32,6 +34,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): cache_folder: Optional[str] = None """Path to store models. 
Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Key word arguments to pass to the model.""" def __init__(self, **kwargs: Any): """Initialize the sentence_transformer.""" @@ -40,7 +44,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): import sentence_transformers self.client = sentence_transformers.SentenceTransformer( - self.model_name, self.cache_folder + self.model_name, cache_folder=self.cache_folder, **self.model_kwargs ) except ImportError: raise ValueError( @@ -90,13 +94,22 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): .. code-block:: python from langchain.embeddings import HuggingFaceInstructEmbeddings + model_name = "hkunlp/instructor-large" - hf = HuggingFaceInstructEmbeddings(model_name=model_name) + model_kwargs = {'device': 'cpu'} + hf = HuggingFaceInstructEmbeddings( + model_name=model_name, model_kwargs=model_kwargs + ) """ client: Any #: :meta private: model_name: str = DEFAULT_INSTRUCT_MODEL """Model name to use.""" + cache_folder: Optional[str] = None + """Path to store models. + Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Key word arguments to pass to the model.""" embed_instruction: str = DEFAULT_EMBED_INSTRUCTION """Instruction to use for embedding documents.""" query_instruction: str = DEFAULT_QUERY_INSTRUCTION @@ -108,7 +121,9 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): try: from InstructorEmbedding import INSTRUCTOR - self.client = INSTRUCTOR(self.model_name) + self.client = INSTRUCTOR( + self.model_name, cache_folder=self.cache_folder, **self.model_kwargs + ) except ImportError as e: raise ValueError("Dependencies for InstructorEmbedding not found.") from e