mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
add embeddings integration tests (#25508)
This commit is contained in:
parent
a06818a654
commit
a2e90a5a43
@ -1,20 +1,17 @@
|
||||
"""Test Ollama embeddings."""
|
||||
|
||||
from typing import Type
|
||||
|
||||
from langchain_standard_tests.integration_tests import EmbeddingsIntegrationTests
|
||||
|
||||
from langchain_ollama.embeddings import OllamaEmbeddings
|
||||
|
||||
|
||||
def test_langchain_ollama_embedding_documents() -> None:
|
||||
"""Test cohere embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = OllamaEmbeddings(model="llama3")
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) > 0
|
||||
class TestOllamaEmbeddings(EmbeddingsIntegrationTests):
|
||||
@property
|
||||
def embeddings_class(self) -> Type[OllamaEmbeddings]:
|
||||
return OllamaEmbeddings
|
||||
|
||||
|
||||
def test_langchain_ollama_embedding_query() -> None:
|
||||
"""Test cohere embeddings."""
|
||||
document = "foo bar"
|
||||
embedding = OllamaEmbeddings(model="llama3")
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) > 0
|
||||
@property
|
||||
def embedding_model_params(self) -> dict:
|
||||
return {"model": "llama3:latest"}
|
||||
|
@ -9,6 +9,7 @@ modules = [
|
||||
"cache",
|
||||
"chat_models",
|
||||
"vectorstores",
|
||||
"embeddings",
|
||||
]
|
||||
|
||||
for module in modules:
|
||||
@ -19,7 +20,11 @@ for module in modules:
|
||||
from langchain_standard_tests.integration_tests.chat_models import (
|
||||
ChatModelIntegrationTests,
|
||||
)
|
||||
from langchain_standard_tests.integration_tests.embeddings import (
|
||||
EmbeddingsIntegrationTests,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ChatModelIntegrationTests",
|
||||
"EmbeddingsIntegrationTests",
|
||||
]
|
||||
|
@ -0,0 +1,49 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain_standard_tests.unit_tests.embeddings import EmbeddingsTests
|
||||
|
||||
|
||||
class EmbeddingsIntegrationTests(EmbeddingsTests):
|
||||
def test_embed_query(self, model: Embeddings) -> None:
|
||||
embedding_1 = model.embed_query("foo")
|
||||
|
||||
assert isinstance(embedding_1, List)
|
||||
assert isinstance(embedding_1[0], float)
|
||||
|
||||
embedding_2 = model.embed_query("bar")
|
||||
|
||||
assert len(embedding_1) > 0
|
||||
assert len(embedding_1) == len(embedding_2)
|
||||
|
||||
def test_embed_documents(self, model: Embeddings) -> None:
|
||||
documents = ["foo", "bar", "baz"]
|
||||
embeddings = model.embed_documents(documents)
|
||||
|
||||
assert len(embeddings) == len(documents)
|
||||
assert all(isinstance(embedding, List) for embedding in embeddings)
|
||||
assert all(isinstance(embedding[0], float) for embedding in embeddings)
|
||||
assert len(embeddings[0]) > 0
|
||||
assert all(len(embedding) == len(embeddings[0]) for embedding in embeddings)
|
||||
|
||||
async def test_aembed_query(self, model: Embeddings) -> None:
|
||||
embedding_1 = await model.aembed_query("foo")
|
||||
|
||||
assert isinstance(embedding_1, List)
|
||||
assert isinstance(embedding_1[0], float)
|
||||
|
||||
embedding_2 = await model.aembed_query("bar")
|
||||
|
||||
assert len(embedding_1) > 0
|
||||
assert len(embedding_1) == len(embedding_2)
|
||||
|
||||
async def test_aembed_documents(self, model: Embeddings) -> None:
|
||||
documents = ["foo", "bar", "baz"]
|
||||
embeddings = await model.aembed_documents(documents)
|
||||
|
||||
assert len(embeddings) == len(documents)
|
||||
assert all(isinstance(embedding, List) for embedding in embeddings)
|
||||
assert all(isinstance(embedding[0], float) for embedding in embeddings)
|
||||
assert len(embeddings[0]) > 0
|
||||
assert all(len(embedding) == len(embeddings[0]) for embedding in embeddings)
|
@ -6,6 +6,7 @@ import pytest
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html#assertion-rewriting
|
||||
modules = [
|
||||
"chat_models",
|
||||
"embeddings",
|
||||
]
|
||||
|
||||
for module in modules:
|
||||
@ -13,4 +14,4 @@ for module in modules:
|
||||
|
||||
from langchain_standard_tests.unit_tests.chat_models import ChatModelUnitTests
|
||||
|
||||
__all__ = ["ChatModelUnitTests"]
|
||||
__all__ = ["ChatModelUnitTests", "EmbeddingsUnitTests"]
|
||||
|
@ -0,0 +1,28 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain_standard_tests.base import BaseStandardTests
|
||||
|
||||
|
||||
class EmbeddingsTests(BaseStandardTests):
|
||||
@property
|
||||
@abstractmethod
|
||||
def embeddings_class(self) -> Type[Embeddings]:
|
||||
...
|
||||
|
||||
@property
|
||||
def embedding_model_params(self) -> dict:
|
||||
return {}
|
||||
|
||||
@pytest.fixture
|
||||
def model(self) -> Embeddings:
|
||||
return self.embeddings_class(**self.embedding_model_params)
|
||||
|
||||
|
||||
class EmbeddingsUnitTests(EmbeddingsTests):
|
||||
def test_init(self) -> None:
|
||||
model = self.embeddings_class(**self.embedding_model_params)
|
||||
assert model is not None
|
Loading…
Reference in New Issue
Block a user