From 5db6b796cf98caf4dcb168bb807ea92e70607cda Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Mon, 1 May 2023 20:27:41 -0700 Subject: [PATCH] Dev2049/hf emb encode kwargs (#3925) Thanks @amogkam for the addition! Refactored slightly --------- Co-authored-by: Amog Kamsetty --- langchain/embeddings/huggingface.py | 10 ++++++---- tests/integration_tests/embeddings/test_huggingface.py | 5 +---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/langchain/embeddings/huggingface.py b/langchain/embeddings/huggingface.py index 24256227..e00df753 100644 --- a/langchain/embeddings/huggingface.py +++ b/langchain/embeddings/huggingface.py @@ -36,6 +36,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Key word arguments to pass to the model.""" + encode_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Key word arguments to pass when calling the `encode` method of the model.""" def __init__(self, **kwargs: Any): """Initialize the sentence_transformer.""" @@ -68,7 +70,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): List of embeddings, one for each text. """ texts = list(map(lambda x: x.replace("\n", " "), texts)) - embeddings = self.client.encode(texts) + embeddings = self.client.encode(texts, **self.encode_kwargs) return embeddings.tolist() def embed_query(self, text: str) -> List[float]: @@ -81,7 +83,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): Embeddings for the text. """ text = text.replace("\n", " ") - embedding = self.client.encode(text) + embedding = self.client.encode(text, **self.encode_kwargs) return embedding.tolist() @@ -89,7 +91,7 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): """Wrapper around sentence_transformers embedding models. To use, you should have the ``sentence_transformers`` - and ``InstructorEmbedding`` python package installed. + and ``InstructorEmbedding`` python packages installed. Example: .. code-block:: python @@ -108,7 +110,7 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): """Model name to use.""" cache_folder: Optional[str] = None """Path to store models. - Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable.""" + Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Key word arguments to pass to the model.""" embed_instruction: str = DEFAULT_EMBED_INSTRUCTION diff --git a/tests/integration_tests/embeddings/test_huggingface.py b/tests/integration_tests/embeddings/test_huggingface.py index 4c941580..6928d6a2 100644 --- a/tests/integration_tests/embeddings/test_huggingface.py +++ b/tests/integration_tests/embeddings/test_huggingface.py @@ -1,5 +1,4 @@ """Test huggingface embeddings.""" -import unittest from langchain.embeddings.huggingface import ( HuggingFaceEmbeddings, @@ -7,7 +6,6 @@ from langchain.embeddings.huggingface import ( ) -@unittest.skip("This test causes a segfault.") def test_huggingface_embedding_documents() -> None: """Test huggingface embeddings.""" documents = ["foo bar"] @@ -17,11 +15,10 @@ def test_huggingface_embedding_documents() -> None: assert len(output[0]) == 768 -@unittest.skip("This test causes a segfault.") def test_huggingface_embedding_query() -> None: """Test huggingface embeddings.""" document = "foo bar" - embedding = HuggingFaceEmbeddings() + embedding = HuggingFaceEmbeddings(encode_kwargs={"batch_size": 16}) output = embedding.embed_query(document) assert len(output) == 768