From 51b15448cc16f670d9d74f233451cbd98a4d3d54 Mon Sep 17 00:00:00 2001 From: Anush Date: Tue, 30 Jul 2024 22:12:46 +0530 Subject: [PATCH] community: Fix FastEmbedEmbeddings (#24462) ## Description This PR: - Fixes the validation error in `FastEmbedEmbeddings`. - Adds support for `batch_size`, `parallel` params. - Removes support for very old FastEmbed versions. - Updates the FastEmbed doc with the new params. Associated Issues: - Resolves #24039 - Resolves #https://github.com/qdrant/fastembed/issues/296 --- .../text_embedding/fastembed.ipynb | 17 ++++- .../integrations/vectorstores/qdrant.ipynb | 4 +- .../embeddings/fastembed.py | 74 ++++++++++++------- .../embeddings/test_fastembed.py | 13 +++- 4 files changed, 71 insertions(+), 37 deletions(-) diff --git a/docs/docs/integrations/text_embedding/fastembed.ipynb b/docs/docs/integrations/text_embedding/fastembed.ipynb index 311de53a54..2bb56a7005 100644 --- a/docs/docs/integrations/text_embedding/fastembed.ipynb +++ b/docs/docs/integrations/text_embedding/fastembed.ipynb @@ -73,16 +73,25 @@ "- `max_length: int` (default: 512)\n", " > The maximum number of tokens. Unknown behavior for values > 512.\n", "\n", - "- `cache_dir: Optional[str]`\n", + "- `cache_dir: Optional[str]` (default: None)\n", " > The path to the cache directory. Defaults to `local_cache` in the parent directory.\n", "\n", - "- `threads: Optional[int]`\n", - " > The number of threads a single onnxruntime session can use. Defaults to None.\n", + "- `threads: Optional[int]` (default: None)\n", + " > The number of threads a single onnxruntime session can use.\n", "\n", "- `doc_embed_type: Literal[\"default\", \"passage\"]` (default: \"default\")\n", " > \"default\": Uses FastEmbed's default embedding method.\n", " \n", - " > \"passage\": Prefixes the text with \"passage\" before embedding." + " > \"passage\": Prefixes the text with \"passage\" before embedding.\n", + "\n", + "- `batch_size: int` (default: 256)\n", + " > Batch size for encoding. Higher values will use more memory, but be faster.\n", + "\n", + "- `parallel: Optional[int]` (default: None)\n", + "\n", + " > If `>1`, data-parallel encoding will be used, recommended for offline encoding of large datasets.\n", + " > If `0`, use all available cores.\n", + " > If `None`, don't use data-parallel processing, use default onnxruntime threading instead." ] }, { diff --git a/docs/docs/integrations/vectorstores/qdrant.ipynb b/docs/docs/integrations/vectorstores/qdrant.ipynb index bf9eed3266..e6d7ba00ba 100644 --- a/docs/docs/integrations/vectorstores/qdrant.ipynb +++ b/docs/docs/integrations/vectorstores/qdrant.ipynb @@ -317,7 +317,7 @@ "To search with only dense vectors,\n", "\n", "- The `retrieval_mode` parameter should be set to `RetrievalMode.DENSE`(default).\n", - "- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided for the `embedding` parameter." + "- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided to the `embedding` parameter." ] }, { @@ -407,7 +407,7 @@ "To perform a hybrid search using dense and sparse vectors with score fusion,\n", "\n", "- The `retrieval_mode` parameter should be set to `RetrievalMode.HYBRID`.\n", - "- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided for the `embedding` parameter.\n", + "- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided to the `embedding` parameter.\n", "- An implementation of the [`SparseEmbeddings`](https://github.com/langchain-ai/langchain/blob/master/libs/partners/qdrant/langchain_qdrant/sparse_embeddings.py) interface using any sparse embeddings provider has to be provided as value to the `sparse_embedding` parameter.\n", "\n", "Note that if you've added documents with the `HYBRID` mode, you can switch to any retrieval mode when searching. Since both the dense and sparse vectors are available in the collection." diff --git a/libs/community/langchain_community/embeddings/fastembed.py b/libs/community/langchain_community/embeddings/fastembed.py index a5ebfdebda..dd4a97922e 100644 --- a/libs/community/langchain_community/embeddings/fastembed.py +++ b/libs/community/langchain_community/embeddings/fastembed.py @@ -1,3 +1,5 @@ +import importlib +import importlib.metadata from typing import Any, Dict, List, Literal, Optional import numpy as np @@ -5,6 +7,8 @@ from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra from langchain_core.utils import pre_init +MIN_VERSION = "0.2.0" + class FastEmbedEmbeddings(BaseModel, Embeddings): """Qdrant FastEmbedding models. @@ -48,12 +52,24 @@ class FastEmbedEmbeddings(BaseModel, Embeddings): The available options are: "default" and "passage" """ + batch_size: int = 256 + """Batch size for encoding. Higher values will use more memory, but be faster. + Defaults to 256. + """ + + parallel: Optional[int] = None + """If `>1`, parallel encoding is used, recommended for encoding of large datasets. + If `0`, use all available cores. + If `None`, don't use data-parallel processing, use default onnxruntime threading. + Defaults to `None`. + """ + _model: Any # : :meta private: class Config: """Configuration for this pydantic object.""" - extra = Extra.forbid + extra = Extra.allow @pre_init def validate_environment(cls, values: Dict) -> Dict: @@ -64,31 +80,25 @@ class FastEmbedEmbeddings(BaseModel, Embeddings): threads = values.get("threads") try: - # >= v0.2.0 - from fastembed import TextEmbedding - - values["_model"] = TextEmbedding( - model_name=model_name, - max_length=max_length, - cache_dir=cache_dir, - threads=threads, + fastembed = importlib.import_module("fastembed") + + except ModuleNotFoundError: + raise ImportError( + "Could not import 'fastembed' Python package. " + "Please install it with `pip install fastembed`." ) - except ImportError as ie: - try: - # < v0.2.0 - from fastembed.embedding import FlagEmbedding - - values["_model"] = FlagEmbedding( - model_name=model_name, - max_length=max_length, - cache_dir=cache_dir, - threads=threads, - ) - except ImportError: - raise ImportError( - "Could not import 'fastembed' Python package. " - "Please install it with `pip install fastembed`." - ) from ie + + if importlib.metadata.version("fastembed") < MIN_VERSION: + raise ImportError( + 'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.' + ) + + values["_model"] = fastembed.TextEmbedding( + model_name=model_name, + max_length=max_length, + cache_dir=cache_dir, + threads=threads, + ) return values def embed_documents(self, texts: List[str]) -> List[List[float]]: @@ -102,9 +112,13 @@ class FastEmbedEmbeddings(BaseModel, Embeddings): """ embeddings: List[np.ndarray] if self.doc_embed_type == "passage": - embeddings = self._model.passage_embed(texts) + embeddings = self._model.passage_embed( + texts, batch_size=self.batch_size, parallel=self.parallel + ) else: - embeddings = self._model.embed(texts) + embeddings = self._model.embed( + texts, batch_size=self.batch_size, parallel=self.parallel + ) return [e.tolist() for e in embeddings] def embed_query(self, text: str) -> List[float]: @@ -116,5 +130,9 @@ class FastEmbedEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - query_embeddings: np.ndarray = next(self._model.query_embed(text)) + query_embeddings: np.ndarray = next( + self._model.query_embed( + text, batch_size=self.batch_size, parallel=self.parallel + ) + ) return query_embeddings.tolist() diff --git a/libs/community/tests/integration_tests/embeddings/test_fastembed.py b/libs/community/tests/integration_tests/embeddings/test_fastembed.py index 9aa2027ca6..09cede659b 100644 --- a/libs/community/tests/integration_tests/embeddings/test_fastembed.py +++ b/libs/community/tests/integration_tests/embeddings/test_fastembed.py @@ -11,8 +11,9 @@ from langchain_community.embeddings.fastembed import FastEmbedEmbeddings @pytest.mark.parametrize("max_length", [50, 512]) @pytest.mark.parametrize("doc_embed_type", ["default", "passage"]) @pytest.mark.parametrize("threads", [0, 10]) +@pytest.mark.parametrize("batch_size", [1, 10]) def test_fastembed_embedding_documents( - model_name: str, max_length: int, doc_embed_type: str, threads: int + model_name: str, max_length: int, doc_embed_type: str, threads: int, batch_size: int ) -> None: """Test fastembed embeddings for documents.""" documents = ["foo bar", "bar foo"] @@ -21,6 +22,7 @@ def test_fastembed_embedding_documents( max_length=max_length, doc_embed_type=doc_embed_type, # type: ignore[arg-type] threads=threads, + batch_size=batch_size, ) output = embedding.embed_documents(documents) assert len(output) == 2 @@ -31,10 +33,15 @@ def test_fastembed_embedding_documents( "model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"] ) @pytest.mark.parametrize("max_length", [50, 512]) -def test_fastembed_embedding_query(model_name: str, max_length: int) -> None: +@pytest.mark.parametrize("batch_size", [1, 10]) +def test_fastembed_embedding_query( + model_name: str, max_length: int, batch_size: int +) -> None: """Test fastembed embeddings for query.""" document = "foo bar" - embedding = FastEmbedEmbeddings(model_name=model_name, max_length=max_length) # type: ignore[call-arg] + embedding = FastEmbedEmbeddings( + model_name=model_name, max_length=max_length, batch_size=batch_size + ) # type: ignore[call-arg] output = embedding.embed_query(document) assert len(output) == 384