From ef22ebe431a40d23c444ecc65f4e8c8402fd4e33 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 18 Jul 2024 17:41:11 -0400 Subject: [PATCH] standard-tests[patch]: Add pytest assert rewrites (#24408) This will surface nice error messages in subclasses that fail assertions. --- .../tests/unit_tests/stores/test_in_memory.py | 29 +++++++++++++++++++ .../integration_tests/__init__.py | 18 ++++++++++++ .../integration_tests/vectorstores.py | 1 + .../unit_tests/__init__.py | 13 +++++++++ .../unit_tests/test_in_memory_base_store.py | 1 + 5 files changed, 62 insertions(+) diff --git a/libs/core/tests/unit_tests/stores/test_in_memory.py b/libs/core/tests/unit_tests/stores/test_in_memory.py index d664954f6b..6c2346e393 100644 --- a/libs/core/tests/unit_tests/stores/test_in_memory.py +++ b/libs/core/tests/unit_tests/stores/test_in_memory.py @@ -1,6 +1,35 @@ +from typing import Tuple + +import pytest +from langchain_standard_tests.integration_tests.base_store import ( + BaseStoreAsyncTests, + BaseStoreSyncTests, +) + from langchain_core.stores import InMemoryStore +# Check against standard tests +class TestSyncInMemoryStore(BaseStoreSyncTests): + @pytest.fixture + def kv_store(self) -> InMemoryStore: + return InMemoryStore() + + @pytest.fixture + def three_values(self) -> Tuple[str, str, str]: # type: ignore + return "value1", "value2", "value3" + + +class TestAsyncInMemoryStore(BaseStoreAsyncTests): + @pytest.fixture + async def kv_store(self) -> InMemoryStore: + return InMemoryStore() + + @pytest.fixture + def three_values(self) -> Tuple[str, str, str]: # type: ignore + return "value1", "value2", "value3" + + def test_mget() -> None: store = InMemoryStore() store.mset([("key1", "value1"), ("key2", "value2")]) diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/__init__.py b/libs/standard-tests/langchain_standard_tests/integration_tests/__init__.py index dbf12101d1..f304bff238 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/__init__.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/__init__.py @@ -1,3 +1,21 @@ +# ruff: noqa: E402 +import pytest + +# Rewrite assert statements for test suite so that implementations can +# see the full error message from failed asserts. +# https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html#assertion-rewriting +modules = [ + "base_store", + "cache", + "chat_models", + "vectorstores", +] + +for module in modules: + pytest.register_assert_rewrite( + f"langchain_standard_tests.integration_tests.{module}" + ) + from langchain_standard_tests.integration_tests.chat_models import ( ChatModelIntegrationTests, ) diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/vectorstores.py b/libs/standard-tests/langchain_standard_tests/integration_tests/vectorstores.py index 4e9dc90505..83770099ab 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/vectorstores.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/vectorstores.py @@ -1,4 +1,5 @@ """Test suite to test vectostores.""" + import inspect from abc import ABC, abstractmethod diff --git a/libs/standard-tests/langchain_standard_tests/unit_tests/__init__.py b/libs/standard-tests/langchain_standard_tests/unit_tests/__init__.py index eabff172d4..418330b5dc 100644 --- a/libs/standard-tests/langchain_standard_tests/unit_tests/__init__.py +++ b/libs/standard-tests/langchain_standard_tests/unit_tests/__init__.py @@ -1,3 +1,16 @@ +# ruff: noqa: E402 +import pytest + +# Rewrite assert statements for test suite so that implementations can +# see the full error message from failed asserts. +# https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html#assertion-rewriting +modules = [ + "chat_models", +] + +for module in modules: + pytest.register_assert_rewrite(f"langchain_standard_tests.unit_tests.{module}") + from langchain_standard_tests.unit_tests.chat_models import ChatModelUnitTests __all__ = ["ChatModelUnitTests"] diff --git a/libs/standard-tests/tests/unit_tests/test_in_memory_base_store.py b/libs/standard-tests/tests/unit_tests/test_in_memory_base_store.py index 245b096554..5171c14c16 100644 --- a/libs/standard-tests/tests/unit_tests/test_in_memory_base_store.py +++ b/libs/standard-tests/tests/unit_tests/test_in_memory_base_store.py @@ -1,4 +1,5 @@ """Tests for the InMemoryStore class.""" + from typing import Tuple import pytest