Hf emb device (#3266)

Make it possible to control the HuggingFaceEmbeddings and HuggingFaceInstructEmbeddings client model kwargs. Additionally, the cache folder was added for HuggingFaceInstructEmbedding as the client inherits from SentenceTransformer (client of HuggingFaceEmbeddings).

It can be useful, especially to control the client device, as it will be defaulted to GPU by sentence_transformers if there is any.

---------

Co-authored-by: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com>
This commit is contained in:
Davis Chase 2023-04-20 20:41:22 -07:00 committed by GitHub
parent d7942a9f19
commit 36720cb57f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,7 @@
"""Wrapper around HuggingFace embedding models."""
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra
from pydantic import BaseModel, Extra, Field
from langchain.embeddings.base import Embeddings
@ -22,8 +22,10 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
.. code-block:: python
from langchain.embeddings import HuggingFaceEmbeddings
model_name = "sentence-transformers/all-mpnet-base-v2"
hf = HuggingFaceEmbeddings(model_name=model_name)
model_kwargs = {'device': 'cpu'}
hf = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
"""
client: Any #: :meta private:
@ -32,6 +34,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
cache_folder: Optional[str] = None
"""Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Key word arguments to pass to the model."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
@ -40,7 +44,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
import sentence_transformers
self.client = sentence_transformers.SentenceTransformer(
self.model_name, self.cache_folder
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
except ImportError:
raise ValueError(
@ -90,13 +94,22 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
.. code-block:: python
from langchain.embeddings import HuggingFaceInstructEmbeddings
model_name = "hkunlp/instructor-large"
hf = HuggingFaceInstructEmbeddings(model_name=model_name)
model_kwargs = {'device': 'cpu'}
hf = HuggingFaceInstructEmbeddings(
model_name=model_name, model_kwargs=model_kwargs
)
"""
client: Any #: :meta private:
model_name: str = DEFAULT_INSTRUCT_MODEL
"""Model name to use."""
cache_folder: Optional[str] = None
"""Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Key word arguments to pass to the model."""
embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
"""Instruction to use for embedding documents."""
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
@ -108,7 +121,9 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
try:
from InstructorEmbedding import INSTRUCTOR
self.client = INSTRUCTOR(self.model_name)
self.client = INSTRUCTOR(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
except ImportError as e:
raise ValueError("Dependencies for InstructorEmbedding not found.") from e