Hf emb device (#3266)

Make it possible to control the HuggingFaceEmbeddings and HuggingFaceInstructEmbeddings client model kwargs. Additionally, a cache_folder option was added to HuggingFaceInstructEmbeddings, since its client inherits from SentenceTransformer (the client used by HuggingFaceEmbeddings).

This is useful mainly for controlling the client device, since sentence_transformers defaults to the GPU whenever one is available.
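For illustration, a minimal usage sketch mirroring the docstring examples added in the diff below; the 'cpu' device value comes from those examples, and any other device string accepted by sentence_transformers could be passed the same way:

```python
from langchain.embeddings import HuggingFaceEmbeddings

# Pin the client to the CPU instead of letting sentence_transformers
# auto-select a GPU when one is available.
model_kwargs = {"device": "cpu"}
hf = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
    model_kwargs=model_kwargs,
)
embedding = hf.embed_query("Hello world")
```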

---------

Co-authored-by: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com>
Davis Chase, 2023-04-20 20:41:22 -07:00 (committed by GitHub)
parent d7942a9f19
commit 36720cb57f

@@ -1,7 +1,7 @@
 """Wrapper around HuggingFace embedding models."""
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel, Extra, Field
 from langchain.embeddings.base import Embeddings
@@ -22,8 +22,10 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         .. code-block:: python
             from langchain.embeddings import HuggingFaceEmbeddings
             model_name = "sentence-transformers/all-mpnet-base-v2"
-            hf = HuggingFaceEmbeddings(model_name=model_name)
+            model_kwargs = {'device': 'cpu'}
+            hf = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
     """
     client: Any  #: :meta private:
@@ -32,6 +34,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     cache_folder: Optional[str] = None
     """Path to store models.
     Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Key word arguments to pass to the model."""
     def __init__(self, **kwargs: Any):
         """Initialize the sentence_transformer."""
@@ -40,7 +44,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
             import sentence_transformers
             self.client = sentence_transformers.SentenceTransformer(
-                self.model_name, self.cache_folder
+                self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
             )
         except ImportError:
             raise ValueError(
@@ -90,13 +94,22 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
         .. code-block:: python
             from langchain.embeddings import HuggingFaceInstructEmbeddings
             model_name = "hkunlp/instructor-large"
-            hf = HuggingFaceInstructEmbeddings(model_name=model_name)
+            model_kwargs = {'device': 'cpu'}
+            hf = HuggingFaceInstructEmbeddings(
+                model_name=model_name, model_kwargs=model_kwargs
+            )
     """
     client: Any  #: :meta private:
     model_name: str = DEFAULT_INSTRUCT_MODEL
     """Model name to use."""
+    cache_folder: Optional[str] = None
+    """Path to store models.
+    Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Key word arguments to pass to the model."""
     embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
     """Instruction to use for embedding documents."""
     query_instruction: str = DEFAULT_QUERY_INSTRUCTION
@@ -108,7 +121,9 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
         try:
             from InstructorEmbedding import INSTRUCTOR
-            self.client = INSTRUCTOR(self.model_name)
+            self.client = INSTRUCTOR(
+                self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
+            )
         except ImportError as e:
             raise ValueError("Dependencies for InstructorEmbedding not found.") from e
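A similar sketch for HuggingFaceInstructEmbeddings, combining the new cache_folder and model_kwargs fields; the cache path below is only a placeholder for illustration (SENTENCE_TRANSFORMERS_HOME can be used instead):

```python
from langchain.embeddings import HuggingFaceInstructEmbeddings

hf = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-large",
    cache_folder="/tmp/instructor-models",  # placeholder path, not from the PR
    model_kwargs={"device": "cpu"},  # keep the instructor client on CPU
)
doc_embeddings = hf.embed_documents(["First document.", "Second document."])
```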