langchain/libs/community/tests/integration_tests/embeddings/test_fastembed.py
Eugene Yurtsev 25fbe356b4
community[patch]: upgrade to recent version of mypy (#21616)
This PR upgrades community to a recent version of mypy. It inserts type:
ignore on all existing failures.
2024-05-13 14:55:07 -04:00

75 lines
2.7 KiB
Python

"""Test FastEmbed embeddings."""
import pytest
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
@pytest.mark.parametrize(
    "model_name",
    [
        "sentence-transformers/all-MiniLM-L6-v2",
        "BAAI/bge-small-en-v1.5",
    ],
)
@pytest.mark.parametrize("max_length", [50, 512])
@pytest.mark.parametrize("doc_embed_type", ["default", "passage"])
@pytest.mark.parametrize("threads", [0, 10])
def test_fastembed_embedding_documents(
    model_name: str, max_length: int, doc_embed_type: str, threads: int
) -> None:
    """Embed two short documents synchronously and check the result shape.

    The test runs once for every combination of the parametrized
    ``model_name``, ``max_length``, ``doc_embed_type`` and ``threads``
    values. Each case expects one vector per input document, and the first
    vector must contain 384 values.
    """
    expected_dim = 384
    texts = ["foo bar", "bar foo"]

    embedder = FastEmbedEmbeddings(  # type: ignore[call-arg]
        model_name=model_name,
        max_length=max_length,
        doc_embed_type=doc_embed_type,  # type: ignore[arg-type]
        threads=threads,
    )
    vectors = embedder.embed_documents(texts)

    # One embedding per document, checked before indexing into the result.
    assert len(vectors) == 2
    first_vector = vectors[0]
    assert len(first_vector) == expected_dim
@pytest.mark.parametrize(
    "model_name",
    [
        "sentence-transformers/all-MiniLM-L6-v2",
        "BAAI/bge-small-en-v1.5",
    ],
)
@pytest.mark.parametrize("max_length", [50, 512])
def test_fastembed_embedding_query(model_name: str, max_length: int) -> None:
    """Embed a single query string synchronously and check its length.

    The test runs for every combination of the parametrized ``model_name``
    and ``max_length`` values. Each case expects the query embedding to
    contain 384 values.
    """
    expected_dim = 384
    query_text = "foo bar"

    embedder = FastEmbedEmbeddings(  # type: ignore[call-arg]
        model_name=model_name,
        max_length=max_length,
    )
    query_vector = embedder.embed_query(query_text)

    assert len(query_vector) == expected_dim
@pytest.mark.parametrize(
    "model_name",
    [
        "sentence-transformers/all-MiniLM-L6-v2",
        "BAAI/bge-small-en-v1.5",
    ],
)
@pytest.mark.parametrize("max_length", [50, 512])
@pytest.mark.parametrize("doc_embed_type", ["default", "passage"])
@pytest.mark.parametrize("threads", [0, 10])
async def test_fastembed_async_embedding_documents(
    model_name: str, max_length: int, doc_embed_type: str, threads: int
) -> None:
    """Embed two short documents through the async API and check the shape.

    This is the asynchronous counterpart of the document-embedding test: it
    awaits ``aembed_documents`` for every combination of the parametrized
    ``model_name``, ``max_length``, ``doc_embed_type`` and ``threads``
    values. Each case expects one vector per input document, and the first
    vector must contain 384 values.

    NOTE(review): no asyncio marker is applied here; presumably the
    project's pytest configuration runs coroutine tests automatically --
    confirm.
    """
    expected_dim = 384
    texts = ["foo bar", "bar foo"]

    embedder = FastEmbedEmbeddings(  # type: ignore[call-arg]
        model_name=model_name,
        max_length=max_length,
        doc_embed_type=doc_embed_type,  # type: ignore[arg-type]
        threads=threads,
    )
    vectors = await embedder.aembed_documents(texts)

    # One embedding per document, checked before indexing into the result.
    assert len(vectors) == 2
    first_vector = vectors[0]
    assert len(first_vector) == expected_dim
@pytest.mark.parametrize(
    "model_name",
    [
        "sentence-transformers/all-MiniLM-L6-v2",
        "BAAI/bge-small-en-v1.5",
    ],
)
@pytest.mark.parametrize("max_length", [50, 512])
async def test_fastembed_async_embedding_query(
    model_name: str, max_length: int
) -> None:
    """Embed a single query string through the async API and check its length.

    This is the asynchronous counterpart of the query-embedding test: it
    awaits ``aembed_query`` for every combination of the parametrized
    ``model_name`` and ``max_length`` values. Each case expects the query
    embedding to contain 384 values.

    NOTE(review): no asyncio marker is applied here; presumably the
    project's pytest configuration runs coroutine tests automatically --
    confirm.
    """
    expected_dim = 384
    query_text = "foo bar"

    embedder = FastEmbedEmbeddings(  # type: ignore[call-arg]
        model_name=model_name,
        max_length=max_length,
    )
    query_vector = await embedder.aembed_query(query_text)

    assert len(query_vector) == expected_dim