community: Fix FastEmbedEmbeddings (#24462)

## Description

This PR:
- Fixes the validation error in `FastEmbedEmbeddings`.
- Adds support for the `batch_size` and `parallel` params.
- Removes support for very old FastEmbed versions.
- Updates the FastEmbed doc with the new params.

Associated Issues:
- Resolves #24039
- Resolves https://github.com/qdrant/fastembed/issues/296
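
As a quick usage sketch of the new params (the model name and values here are illustrative, not mandated by this PR):

```python
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings

# batch_size trades memory for throughput; parallel=0 uses all available
# cores for data-parallel encoding of large offline datasets.
embeddings = FastEmbedEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",
    batch_size=512,
    parallel=0,
)

doc_vectors = embeddings.embed_documents(["hello world", "goodbye world"])
query_vector = embeddings.embed_query("hello world")
```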
Commit 51b15448cc (parent 73ec24fc56), authored by Anush on 2024-07-30 22:12:46 +05:30, committed via GitHub.
4 changed files with 69 additions and 35 deletions

File: FastEmbed text embedding docs notebook

@@ -73,16 +73,25 @@
     "- `max_length: int` (default: 512)\n",
     " > The maximum number of tokens. Unknown behavior for values > 512.\n",
     "\n",
-    "- `cache_dir: Optional[str]`\n",
+    "- `cache_dir: Optional[str]` (default: None)\n",
     " > The path to the cache directory. Defaults to `local_cache` in the parent directory.\n",
     "\n",
-    "- `threads: Optional[int]`\n",
-    " > The number of threads a single onnxruntime session can use. Defaults to None.\n",
+    "- `threads: Optional[int]` (default: None)\n",
+    " > The number of threads a single onnxruntime session can use.\n",
     "\n",
     "- `doc_embed_type: Literal[\"default\", \"passage\"]` (default: \"default\")\n",
     " > \"default\": Uses FastEmbed's default embedding method.\n",
     " \n",
-    " > \"passage\": Prefixes the text with \"passage\" before embedding."
+    " > \"passage\": Prefixes the text with \"passage\" before embedding.\n",
+    "\n",
+    "- `batch_size: int` (default: 256)\n",
+    " > Batch size for encoding. Higher values will use more memory, but be faster.\n",
+    "\n",
+    "- `parallel: Optional[int]` (default: None)\n",
+    "\n",
+    " > If `>1`, data-parallel encoding will be used, recommended for offline encoding of large datasets.\n",
+    " > If `0`, use all available cores.\n",
+    " > If `None`, don't use data-parallel processing, use default onnxruntime threading instead."
    ]
   },
   {
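
Taken together, a constructor call exercising the documented params might look like this (the values are illustrative, not recommendations):

```python
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings

embeddings = FastEmbedEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",
    max_length=512,            # tokens beyond this are unsupported territory
    cache_dir=None,            # default: `local_cache` in the parent directory
    threads=None,              # onnxruntime session threads
    doc_embed_type="passage",  # prefix documents with "passage" before embedding
    batch_size=256,            # higher = more memory, faster encoding
    parallel=None,             # None = default onnxruntime threading
)
```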

File: Qdrant vector store docs notebook

@ -317,7 +317,7 @@
"To search with only dense vectors,\n", "To search with only dense vectors,\n",
"\n", "\n",
"- The `retrieval_mode` parameter should be set to `RetrievalMode.DENSE`(default).\n", "- The `retrieval_mode` parameter should be set to `RetrievalMode.DENSE`(default).\n",
"- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided for the `embedding` parameter." "- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided to the `embedding` parameter."
] ]
}, },
{ {
@ -407,7 +407,7 @@
"To perform a hybrid search using dense and sparse vectors with score fusion,\n", "To perform a hybrid search using dense and sparse vectors with score fusion,\n",
"\n", "\n",
"- The `retrieval_mode` parameter should be set to `RetrievalMode.HYBRID`.\n", "- The `retrieval_mode` parameter should be set to `RetrievalMode.HYBRID`.\n",
"- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided for the `embedding` parameter.\n", "- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided to the `embedding` parameter.\n",
"- An implementation of the [`SparseEmbeddings`](https://github.com/langchain-ai/langchain/blob/master/libs/partners/qdrant/langchain_qdrant/sparse_embeddings.py) interface using any sparse embeddings provider has to be provided as value to the `sparse_embedding` parameter.\n", "- An implementation of the [`SparseEmbeddings`](https://github.com/langchain-ai/langchain/blob/master/libs/partners/qdrant/langchain_qdrant/sparse_embeddings.py) interface using any sparse embeddings provider has to be provided as value to the `sparse_embedding` parameter.\n",
"\n", "\n",
"Note that if you've added documents with the `HYBRID` mode, you can switch to any retrieval mode when searching. Since both the dense and sparse vectors are available in the collection." "Note that if you've added documents with the `HYBRID` mode, you can switch to any retrieval mode when searching. Since both the dense and sparse vectors are available in the collection."

File: langchain_community/embeddings/fastembed.py

@@ -1,3 +1,5 @@
+import importlib
+import importlib.metadata
 from typing import Any, Dict, List, Literal, Optional
 
 import numpy as np
@@ -5,6 +7,8 @@ from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra
 from langchain_core.utils import pre_init
 
+MIN_VERSION = "0.2.0"
+
 
 class FastEmbedEmbeddings(BaseModel, Embeddings):
     """Qdrant FastEmbedding models.
@@ -48,12 +52,24 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
     The available options are: "default" and "passage"
     """
 
+    batch_size: int = 256
+    """Batch size for encoding. Higher values will use more memory, but be faster.
+    Defaults to 256.
+    """
+
+    parallel: Optional[int] = None
+    """If `>1`, parallel encoding is used, recommended for encoding of large datasets.
+    If `0`, use all available cores.
+    If `None`, don't use data-parallel processing, use default onnxruntime threading.
+    Defaults to `None`.
+    """
+
     _model: Any  # : :meta private:
 
     class Config:
         """Configuration for this pydantic object."""
 
-        extra = Extra.forbid
+        extra = Extra.allow
 
     @pre_init
     def validate_environment(cls, values: Dict) -> Dict:
@@ -64,31 +80,25 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
         threads = values.get("threads")
 
         try:
-            # >= v0.2.0
-            from fastembed import TextEmbedding
-
-            values["_model"] = TextEmbedding(
-                model_name=model_name,
-                max_length=max_length,
-                cache_dir=cache_dir,
-                threads=threads,
-            )
-        except ImportError as ie:
-            try:
-                # < v0.2.0
-                from fastembed.embedding import FlagEmbedding
-
-                values["_model"] = FlagEmbedding(
-                    model_name=model_name,
-                    max_length=max_length,
-                    cache_dir=cache_dir,
-                    threads=threads,
-                )
-            except ImportError:
-                raise ImportError(
-                    "Could not import 'fastembed' Python package. "
-                    "Please install it with `pip install fastembed`."
-                ) from ie
+            fastembed = importlib.import_module("fastembed")
+        except ModuleNotFoundError:
+            raise ImportError(
+                "Could not import 'fastembed' Python package. "
+                "Please install it with `pip install fastembed`."
+            )
+
+        if importlib.metadata.version("fastembed") < MIN_VERSION:
+            raise ImportError(
+                'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
+            )
+
+        values["_model"] = fastembed.TextEmbedding(
+            model_name=model_name,
+            max_length=max_length,
+            cache_dir=cache_dir,
+            threads=threads,
+        )
         return values
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
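
One caveat worth flagging: `importlib.metadata.version("fastembed") < MIN_VERSION` compares version strings lexicographically, so for example `"0.10.0" < "0.2.0"` evaluates to `True`. A sketch of the same guard using `packaging` (an alternative not used in this PR) sidesteps that edge case:

```python
import importlib.metadata

from packaging.version import Version

MIN_VERSION = "0.2.0"

installed = importlib.metadata.version("fastembed")
# Version("0.10.0") > Version("0.2.0"), unlike the raw string comparison.
if Version(installed) < Version(MIN_VERSION):
    raise ImportError(
        'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
    )
```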
@@ -102,9 +112,13 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
         """
         embeddings: List[np.ndarray]
         if self.doc_embed_type == "passage":
-            embeddings = self._model.passage_embed(texts)
+            embeddings = self._model.passage_embed(
+                texts, batch_size=self.batch_size, parallel=self.parallel
+            )
         else:
-            embeddings = self._model.embed(texts)
+            embeddings = self._model.embed(
+                texts, batch_size=self.batch_size, parallel=self.parallel
+            )
         return [e.tolist() for e in embeddings]
 
     def embed_query(self, text: str) -> List[float]:
@@ -116,5 +130,9 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
         """
-        query_embeddings: np.ndarray = next(self._model.query_embed(text))
+        query_embeddings: np.ndarray = next(
+            self._model.query_embed(
+                text, batch_size=self.batch_size, parallel=self.parallel
+            )
+        )
         return query_embeddings.tolist()
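
Context for the list comprehension and `next(...)` calls above: in fastembed >= 0.2, `embed`, `passage_embed`, and `query_embed` return lazy generators of `np.ndarray`, which is why the wrapper has to materialize them. A minimal sketch against fastembed directly (the model name is illustrative):

```python
from fastembed import TextEmbedding

model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")

# embed() yields one np.ndarray per input text; nothing is computed until
# the generator is consumed. batch_size and parallel control batching and
# data-parallel encoding respectively.
vectors = list(model.embed(["foo bar", "bar foo"], batch_size=256, parallel=None))
assert len(vectors) == 2
```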

File: FastEmbedEmbeddings integration tests

@@ -11,8 +11,9 @@ from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
 @pytest.mark.parametrize("max_length", [50, 512])
 @pytest.mark.parametrize("doc_embed_type", ["default", "passage"])
 @pytest.mark.parametrize("threads", [0, 10])
+@pytest.mark.parametrize("batch_size", [1, 10])
 def test_fastembed_embedding_documents(
-    model_name: str, max_length: int, doc_embed_type: str, threads: int
+    model_name: str, max_length: int, doc_embed_type: str, threads: int, batch_size: int
 ) -> None:
     """Test fastembed embeddings for documents."""
     documents = ["foo bar", "bar foo"]
@@ -21,6 +22,7 @@ def test_fastembed_embedding_documents(
         max_length=max_length,
         doc_embed_type=doc_embed_type,  # type: ignore[arg-type]
         threads=threads,
+        batch_size=batch_size,
     )
     output = embedding.embed_documents(documents)
     assert len(output) == 2
@@ -31,10 +33,15 @@ def test_fastembed_embedding_documents(
     "model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"]
 )
 @pytest.mark.parametrize("max_length", [50, 512])
-def test_fastembed_embedding_query(model_name: str, max_length: int) -> None:
+@pytest.mark.parametrize("batch_size", [1, 10])
+def test_fastembed_embedding_query(
+    model_name: str, max_length: int, batch_size: int
+) -> None:
     """Test fastembed embeddings for query."""
     document = "foo bar"
-    embedding = FastEmbedEmbeddings(model_name=model_name, max_length=max_length)  # type: ignore[call-arg]
+    embedding = FastEmbedEmbeddings(
+        model_name=model_name, max_length=max_length, batch_size=batch_size
+    )  # type: ignore[call-arg]
     output = embedding.embed_query(document)
     assert len(output) == 384
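
Stacked `@pytest.mark.parametrize` decorators expand combinatorially, so the documents test now covers every combination of the five axes (a quick sanity check of the resulting case count):

```python
import itertools

# Documents test matrix: 2 models x 2 max_lengths x 2 embed types
# x 2 thread counts x 2 batch sizes = 32 test cases.
cases = list(
    itertools.product(
        ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"],
        [50, 512],               # max_length
        ["default", "passage"],  # doc_embed_type
        [0, 10],                 # threads
        [1, 10],                 # batch_size
    )
)
assert len(cases) == 32
```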