community: Fix FastEmbedEmbeddings (#24462)
## Description

This PR:

- Fixes the validation error in `FastEmbedEmbeddings`.
- Adds support for the `batch_size` and `parallel` params.
- Removes support for very old FastEmbed versions.
- Updates the FastEmbed doc with the new params.

Associated Issues:

- Resolves #24039
- Resolves https://github.com/qdrant/fastembed/issues/296
This commit is contained in:
parent 73ec24fc56
commit 51b15448cc
FastEmbed embeddings docs notebook:

```diff
@@ -73,16 +73,25 @@
 "- `max_length: int` (default: 512)\n",
 " > The maximum number of tokens. Unknown behavior for values > 512.\n",
 "\n",
-"- `cache_dir: Optional[str]`\n",
+"- `cache_dir: Optional[str]` (default: None)\n",
 " > The path to the cache directory. Defaults to `local_cache` in the parent directory.\n",
 "\n",
-"- `threads: Optional[int]`\n",
-" > The number of threads a single onnxruntime session can use. Defaults to None.\n",
+"- `threads: Optional[int]` (default: None)\n",
+" > The number of threads a single onnxruntime session can use.\n",
 "\n",
 "- `doc_embed_type: Literal[\"default\", \"passage\"]` (default: \"default\")\n",
 " > \"default\": Uses FastEmbed's default embedding method.\n",
 " \n",
-" > \"passage\": Prefixes the text with \"passage\" before embedding."
+" > \"passage\": Prefixes the text with \"passage\" before embedding.\n",
+"\n",
+"- `batch_size: int` (default: 256)\n",
+" > Batch size for encoding. Higher values will use more memory, but be faster.\n",
+"\n",
+"- `parallel: Optional[int]` (default: None)\n",
+"\n",
+" > If `>1`, data-parallel encoding will be used, recommended for offline encoding of large datasets.\n",
+" > If `0`, use all available cores.\n",
+" > If `None`, don't use data-parallel processing, use default onnxruntime threading instead."
 ]
 },
 {
```
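The updated parameter list maps directly onto a constructor call. A minimal sketch (the model name is the class default; the other values are illustrative, not part of this diff):

```python
# Illustrative sketch of the parameters documented above.
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings

embeddings = FastEmbedEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",  # the class default
    max_length=512,   # max tokens; behavior above 512 is undefined
    threads=None,     # threads per onnxruntime session
    batch_size=256,   # new in this PR: texts encoded per batch
    parallel=None,    # new in this PR: None = default onnxruntime threading
)
```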
Qdrant vector store docs notebook, dense search section:

```diff
@@ -317,7 +317,7 @@
 "To search with only dense vectors,\n",
 "\n",
 "- The `retrieval_mode` parameter should be set to `RetrievalMode.DENSE`(default).\n",
-"- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided for the `embedding` parameter."
+"- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided to the `embedding` parameter."
 ]
 },
 {
```
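For context, a dense-only setup matching the two bullets above might look like this sketch (assumes `pip install langchain-qdrant` and an in-memory Qdrant instance; the collection name is made up):

```python
# Sketch: dense-only retrieval. Assumes the langchain-qdrant package.
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_qdrant import QdrantVectorStore, RetrievalMode

qdrant = QdrantVectorStore.from_texts(
    ["foo bar", "bar foo"],
    embedding=FastEmbedEmbeddings(),
    location=":memory:",
    collection_name="dense_demo",
    retrieval_mode=RetrievalMode.DENSE,  # the default mode
)
results = qdrant.similarity_search("foo", k=1)
```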
Qdrant vector store docs notebook, hybrid search section:

```diff
@@ -407,7 +407,7 @@
 "To perform a hybrid search using dense and sparse vectors with score fusion,\n",
 "\n",
 "- The `retrieval_mode` parameter should be set to `RetrievalMode.HYBRID`.\n",
-"- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided for the `embedding` parameter.\n",
+"- A [dense embeddings](https://python.langchain.com/v0.2/docs/integrations/text_embedding/) value should be provided to the `embedding` parameter.\n",
 "- An implementation of the [`SparseEmbeddings`](https://github.com/langchain-ai/langchain/blob/master/libs/partners/qdrant/langchain_qdrant/sparse_embeddings.py) interface using any sparse embeddings provider has to be provided as value to the `sparse_embedding` parameter.\n",
 "\n",
 "Note that if you've added documents with the `HYBRID` mode, you can switch to any retrieval mode when searching. Since both the dense and sparse vectors are available in the collection."
```
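A corresponding hybrid sketch (again assuming `langchain-qdrant`; `FastEmbedSparse` with `Qdrant/bm25` is one available `SparseEmbeddings` implementation):

```python
# Sketch: hybrid retrieval with score fusion. Names are illustrative.
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode

qdrant = QdrantVectorStore.from_texts(
    ["foo bar", "bar foo"],
    embedding=FastEmbedEmbeddings(),                             # dense side
    sparse_embedding=FastEmbedSparse(model_name="Qdrant/bm25"),  # sparse side
    location=":memory:",
    collection_name="hybrid_demo",
    retrieval_mode=RetrievalMode.HYBRID,
)
# HYBRID writes store both vector types, so any retrieval mode works later.
results = qdrant.similarity_search("foo", k=1)
```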
langchain_community/embeddings/fastembed.py:

```diff
@@ -1,3 +1,5 @@
+import importlib
+import importlib.metadata
 from typing import Any, Dict, List, Literal, Optional
 
 import numpy as np
@@ -5,6 +7,8 @@ from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra
 from langchain_core.utils import pre_init
 
+MIN_VERSION = "0.2.0"
+
 
 class FastEmbedEmbeddings(BaseModel, Embeddings):
     """Qdrant FastEmbedding models.
@@ -48,12 +52,24 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
         The available options are: "default" and "passage"
     """
 
+    batch_size: int = 256
+    """Batch size for encoding. Higher values will use more memory, but be faster.
+    Defaults to 256.
+    """
+
+    parallel: Optional[int] = None
+    """If `>1`, parallel encoding is used, recommended for encoding of large datasets.
+    If `0`, use all available cores.
+    If `None`, don't use data-parallel processing, use default onnxruntime threading.
+    Defaults to `None`.
+    """
+
     _model: Any  # : :meta private:
 
     class Config:
         """Configuration for this pydantic object."""
 
-        extra = Extra.forbid
+        extra = Extra.allow
 
     @pre_init
     def validate_environment(cls, values: Dict) -> Dict:
```
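The `Extra.forbid` to `Extra.allow` switch is the part that clears the reported validation error. A minimal sketch of the difference, using the same `langchain_core.pydantic_v1` shim the module imports (class and field names here are illustrative):

```python
# Illustrative: pydantic v1 Extra semantics. Extra.forbid rejects unknown
# constructor kwargs with a ValidationError; Extra.allow stores them.
from langchain_core.pydantic_v1 import BaseModel, Extra, ValidationError

class Strict(BaseModel):
    name: str = "demo"

    class Config:
        extra = Extra.forbid

class Lenient(BaseModel):
    name: str = "demo"

    class Config:
        extra = Extra.allow

Lenient(name="ok", injected="kept")  # accepted and kept on the instance
try:
    Strict(name="ok", injected="boom")
except ValidationError as err:
    print(err)  # "extra fields not permitted"
```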
The rewritten `validate_environment` in the same file:

```diff
@@ -64,31 +80,25 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
         threads = values.get("threads")
 
         try:
-            # >= v0.2.0
-            from fastembed import TextEmbedding
-
-            values["_model"] = TextEmbedding(
-                model_name=model_name,
-                max_length=max_length,
-                cache_dir=cache_dir,
-                threads=threads,
-            )
-        except ImportError as ie:
-            try:
-                # < v0.2.0
-                from fastembed.embedding import FlagEmbedding
-
-                values["_model"] = FlagEmbedding(
-                    model_name=model_name,
-                    max_length=max_length,
-                    cache_dir=cache_dir,
-                    threads=threads,
-                )
-            except ImportError:
-                raise ImportError(
-                    "Could not import 'fastembed' Python package. "
-                    "Please install it with `pip install fastembed`."
-                ) from ie
+            fastembed = importlib.import_module("fastembed")
+
+        except ModuleNotFoundError:
+            raise ImportError(
+                "Could not import 'fastembed' Python package. "
+                "Please install it with `pip install fastembed`."
+            )
+
+        if importlib.metadata.version("fastembed") < MIN_VERSION:
+            raise ImportError(
+                'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
+            )
+
+        values["_model"] = fastembed.TextEmbedding(
+            model_name=model_name,
+            max_length=max_length,
+            cache_dir=cache_dir,
+            threads=threads,
+        )
         return values
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
```
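The rewrite replaces the old try-new-import/fall-back-to-old-import cascade with a single import plus an explicit version gate. A standalone sketch of the same pattern:

```python
# Standalone sketch of the import-then-version-gate pattern used above.
# importlib.metadata.version() reads installed distribution metadata, and
# ModuleNotFoundError is the ImportError subclass raised for absent packages.
import importlib
import importlib.metadata

MIN_VERSION = "0.2.0"

try:
    fastembed = importlib.import_module("fastembed")
except ModuleNotFoundError:
    raise ImportError(
        "Could not import 'fastembed' Python package. "
        "Please install it with `pip install fastembed`."
    )

if importlib.metadata.version("fastembed") < MIN_VERSION:
    raise ImportError('FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.')
```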
And the embedding methods, which now forward the new kwargs:

```diff
@@ -102,9 +112,13 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
         """
         embeddings: List[np.ndarray]
         if self.doc_embed_type == "passage":
-            embeddings = self._model.passage_embed(texts)
+            embeddings = self._model.passage_embed(
+                texts, batch_size=self.batch_size, parallel=self.parallel
+            )
         else:
-            embeddings = self._model.embed(texts)
+            embeddings = self._model.embed(
+                texts, batch_size=self.batch_size, parallel=self.parallel
+            )
         return [e.tolist() for e in embeddings]
 
     def embed_query(self, text: str) -> List[float]:
@@ -116,5 +130,9 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
         """
-        query_embeddings: np.ndarray = next(self._model.query_embed(text))
+        query_embeddings: np.ndarray = next(
+            self._model.query_embed(
+                text, batch_size=self.batch_size, parallel=self.parallel
+            )
+        )
         return query_embeddings.tolist()
```
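With the fields plumbed through, both entry points pass `batch_size` and `parallel` down to FastEmbed. A short usage sketch:

```python
# Usage sketch: kwargs set on the constructor flow into the
# passage_embed / embed / query_embed calls changed above.
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings

emb = FastEmbedEmbeddings(batch_size=64, parallel=0)  # parallel=0: use all cores
doc_vectors = emb.embed_documents(["foo bar", "bar foo"])
query_vector = emb.embed_query("foo bar")
assert len(doc_vectors) == 2
assert len(query_vector) == 384  # dimensionality of the default model
```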
FastEmbed embeddings integration tests (the module importing `FastEmbedEmbeddings` from `langchain_community.embeddings.fastembed`):

```diff
@@ -11,8 +11,9 @@ from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
 @pytest.mark.parametrize("max_length", [50, 512])
 @pytest.mark.parametrize("doc_embed_type", ["default", "passage"])
 @pytest.mark.parametrize("threads", [0, 10])
+@pytest.mark.parametrize("batch_size", [1, 10])
 def test_fastembed_embedding_documents(
-    model_name: str, max_length: int, doc_embed_type: str, threads: int
+    model_name: str, max_length: int, doc_embed_type: str, threads: int, batch_size: int
 ) -> None:
     """Test fastembed embeddings for documents."""
     documents = ["foo bar", "bar foo"]
@@ -21,6 +22,7 @@ def test_fastembed_embedding_documents(
         max_length=max_length,
         doc_embed_type=doc_embed_type,  # type: ignore[arg-type]
         threads=threads,
+        batch_size=batch_size,
     )
     output = embedding.embed_documents(documents)
     assert len(output) == 2
@@ -31,10 +33,15 @@ def test_fastembed_embedding_documents(
     "model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"]
 )
 @pytest.mark.parametrize("max_length", [50, 512])
-def test_fastembed_embedding_query(model_name: str, max_length: int) -> None:
+@pytest.mark.parametrize("batch_size", [1, 10])
+def test_fastembed_embedding_query(
+    model_name: str, max_length: int, batch_size: int
+) -> None:
     """Test fastembed embeddings for query."""
     document = "foo bar"
-    embedding = FastEmbedEmbeddings(model_name=model_name, max_length=max_length)  # type: ignore[call-arg]
+    embedding = FastEmbedEmbeddings(
+        model_name=model_name, max_length=max_length, batch_size=batch_size
+    )  # type: ignore[call-arg]
     output = embedding.embed_query(document)
     assert len(output) == 384
 
```
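One detail worth noting for reviewers: stacked `@pytest.mark.parametrize` decorators take the Cartesian product of their values, so the added two-value `batch_size` axis doubles each test's case count. A minimal illustration:

```python
# Illustrative: stacked parametrize decorators multiply (2 * 2 = 4 cases here).
import pytest

@pytest.mark.parametrize("a", [1, 2])
@pytest.mark.parametrize("b", ["x", "y"])
def test_cross_product(a: int, b: str) -> None:
    assert isinstance(a, int) and isinstance(b, str)
```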