forked from Archives/langchain
Hf emb device (#3266)
Make it possible to control the HuggingFaceEmbeddings and HuggingFaceInstructEmbeddings client model kwargs. Additionally, the cache folder was added for HuggingFaceInstructEmbedding as the client inherits from SentenceTransformer (client of HuggingFaceEmbeddings). It can be useful, especially to control the client device, as it will be defaulted to GPU by sentence_transformers if there is any. --------- Co-authored-by: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com>
This commit is contained in:
parent
d7942a9f19
commit
36720cb57f
@ -1,7 +1,7 @@
|
|||||||
"""Wrapper around HuggingFace embedding models."""
|
"""Wrapper around HuggingFace embedding models."""
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra
|
from pydantic import BaseModel, Extra, Field
|
||||||
|
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
|
|
||||||
@ -22,8 +22,10 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
model_name = "sentence-transformers/all-mpnet-base-v2"
|
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||||
hf = HuggingFaceEmbeddings(model_name=model_name)
|
model_kwargs = {'device': 'cpu'}
|
||||||
|
hf = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
client: Any #: :meta private:
|
client: Any #: :meta private:
|
||||||
@ -32,6 +34,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
cache_folder: Optional[str] = None
|
cache_folder: Optional[str] = None
|
||||||
"""Path to store models.
|
"""Path to store models.
|
||||||
Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable."""
|
Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable."""
|
||||||
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""Key word arguments to pass to the model."""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any):
|
def __init__(self, **kwargs: Any):
|
||||||
"""Initialize the sentence_transformer."""
|
"""Initialize the sentence_transformer."""
|
||||||
@ -40,7 +44,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
import sentence_transformers
|
import sentence_transformers
|
||||||
|
|
||||||
self.client = sentence_transformers.SentenceTransformer(
|
self.client = sentence_transformers.SentenceTransformer(
|
||||||
self.model_name, self.cache_folder
|
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -90,13 +94,22 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
|||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
||||||
|
|
||||||
model_name = "hkunlp/instructor-large"
|
model_name = "hkunlp/instructor-large"
|
||||||
hf = HuggingFaceInstructEmbeddings(model_name=model_name)
|
model_kwargs = {'device': 'cpu'}
|
||||||
|
hf = HuggingFaceInstructEmbeddings(
|
||||||
|
model_name=model_name, model_kwargs=model_kwargs
|
||||||
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
client: Any #: :meta private:
|
client: Any #: :meta private:
|
||||||
model_name: str = DEFAULT_INSTRUCT_MODEL
|
model_name: str = DEFAULT_INSTRUCT_MODEL
|
||||||
"""Model name to use."""
|
"""Model name to use."""
|
||||||
|
cache_folder: Optional[str] = None
|
||||||
|
"""Path to store models.
|
||||||
|
Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable."""
|
||||||
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""Key word arguments to pass to the model."""
|
||||||
embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
|
embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
|
||||||
"""Instruction to use for embedding documents."""
|
"""Instruction to use for embedding documents."""
|
||||||
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
|
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
|
||||||
@ -108,7 +121,9 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
|||||||
try:
|
try:
|
||||||
from InstructorEmbedding import INSTRUCTOR
|
from InstructorEmbedding import INSTRUCTOR
|
||||||
|
|
||||||
self.client = INSTRUCTOR(self.model_name)
|
self.client = INSTRUCTOR(
|
||||||
|
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||||
|
)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ValueError("Dependencies for InstructorEmbedding not found.") from e
|
raise ValueError("Dependencies for InstructorEmbedding not found.") from e
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user