forked from Archives/langchain
Hf emb device (#3266)
Make it possible to control the HuggingFaceEmbeddings and HuggingFaceInstructEmbeddings client model kwargs. Additionally, the cache folder was added for HuggingFaceInstructEmbedding as the client inherits from SentenceTransformer (client of HuggingFaceEmbeddings). It can be useful, especially to control the client device, as it will be defaulted to GPU by sentence_transformers if there is any. --------- Co-authored-by: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com>
This commit is contained in:
parent
d7942a9f19
commit
36720cb57f
@ -1,7 +1,7 @@
|
||||
"""Wrapper around HuggingFace embedding models."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
@ -22,8 +22,10 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||
hf = HuggingFaceEmbeddings(model_name=model_name)
|
||||
model_kwargs = {'device': 'cpu'}
|
||||
hf = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
@ -32,6 +34,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
cache_folder: Optional[str] = None
|
||||
"""Path to store models.
|
||||
Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the sentence_transformer."""
|
||||
@ -40,7 +44,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
import sentence_transformers
|
||||
|
||||
self.client = sentence_transformers.SentenceTransformer(
|
||||
self.model_name, self.cache_folder
|
||||
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||
)
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
@ -90,13 +94,22 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
||||
|
||||
model_name = "hkunlp/instructor-large"
|
||||
hf = HuggingFaceInstructEmbeddings(model_name=model_name)
|
||||
model_kwargs = {'device': 'cpu'}
|
||||
hf = HuggingFaceInstructEmbeddings(
|
||||
model_name=model_name, model_kwargs=model_kwargs
|
||||
)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_name: str = DEFAULT_INSTRUCT_MODEL
|
||||
"""Model name to use."""
|
||||
cache_folder: Optional[str] = None
|
||||
"""Path to store models.
|
||||
Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Key word arguments to pass to the model."""
|
||||
embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
|
||||
"""Instruction to use for embedding documents."""
|
||||
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
|
||||
@ -108,7 +121,9 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
||||
try:
|
||||
from InstructorEmbedding import INSTRUCTOR
|
||||
|
||||
self.client = INSTRUCTOR(self.model_name)
|
||||
self.client = INSTRUCTOR(
|
||||
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ValueError("Dependencies for InstructorEmbedding not found.") from e
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user