From f4ee3c8a223ea4b3d7bd9e08fcd3769c70e3446c Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Fri, 19 Jul 2024 15:03:19 -0700 Subject: [PATCH] infra: add min version testing to pr test flow (#24358) xfailing some sql tests that do not currently work on sqlalchemy v1 #22207 was very much not sqlalchemy v1 compatible. Moving forward, implementations should be compatible with both to pass CI --- .github/scripts/get_min_versions.py | 7 +- .github/workflows/_test.yml | 18 +++++ .../chat_message_histories/sql.py | 7 +- .../langchain_community/storage/sql.py | 65 ++++++++++++++----- .../utilities/sql_database.py | 2 +- .../tests/unit_tests/storage/test_sql.py | 7 ++ .../tests/unit_tests/test_sql_database.py | 3 + .../unit_tests/test_sql_database_schema.py | 9 +++ 8 files changed, 97 insertions(+), 21 deletions(-) diff --git a/.github/scripts/get_min_versions.py b/.github/scripts/get_min_versions.py index 35740a43d4..ab12186f3c 100644 --- a/.github/scripts/get_min_versions.py +++ b/.github/scripts/get_min_versions.py @@ -1,6 +1,11 @@ import sys -import tomllib +if sys.version_info >= (3, 11): + import tomllib +else: + # for python 3.10 and below, which doesnt have stdlib tomllib + import tomli as tomllib + from packaging.version import parse as parse_version import re diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 841f3796f3..23fbd3e5d2 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -65,3 +65,21 @@ jobs: # grep will exit non-zero if the target message isn't found, # and `set -e` above will cause the step to fail. echo "$STATUS" | grep 'nothing to commit, working tree clean' + + - name: Get minimum versions + working-directory: ${{ inputs.working-directory }} + id: min-version + run: | + poetry run pip install packaging tomli + min_versions="$(poetry run python $GITHUB_WORKSPACE/.github/scripts/get_min_versions.py pyproject.toml)" + echo "min-versions=$min_versions" >> "$GITHUB_OUTPUT" + echo "min-versions=$min_versions" + + - name: Run unit tests with minimum dependency versions + if: ${{ steps.min-version.outputs.min-versions != '' }} + env: + MIN_VERSIONS: ${{ steps.min-version.outputs.min-versions }} + run: | + poetry run pip install --force-reinstall $MIN_VERSIONS --editable . + make tests + working-directory: ${{ inputs.working-directory }} diff --git a/libs/community/langchain_community/chat_message_histories/sql.py b/libs/community/langchain_community/chat_message_histories/sql.py index cb839b4c3a..32e56c8d3c 100644 --- a/libs/community/langchain_community/chat_message_histories/sql.py +++ b/libs/community/langchain_community/chat_message_histories/sql.py @@ -32,7 +32,6 @@ from sqlalchemy.engine import Engine from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, - async_sessionmaker, create_async_engine, ) from sqlalchemy.orm import ( @@ -44,6 +43,12 @@ from sqlalchemy.orm import ( sessionmaker, ) +try: + from sqlalchemy.ext.asyncio import async_sessionmaker +except ImportError: + # dummy for sqlalchemy < 2 + async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore + logger = logging.getLogger(__name__) diff --git a/libs/community/langchain_community/storage/sql.py b/libs/community/langchain_community/storage/sql.py index 170372f4e7..aae7e1149e 100644 --- a/libs/community/langchain_community/storage/sql.py +++ b/libs/community/langchain_community/storage/sql.py @@ -17,48 +17,74 @@ from typing import ( from langchain_core.stores import BaseStore from sqlalchemy import ( - Engine, LargeBinary, + Text, and_, create_engine, delete, select, ) +from sqlalchemy.engine.base import Engine from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, - async_sessionmaker, create_async_engine, ) from sqlalchemy.orm import ( Mapped, Session, declarative_base, - mapped_column, sessionmaker, ) +try: + from sqlalchemy.ext.asyncio import async_sessionmaker +except ImportError: + # dummy for sqlalchemy < 2 + async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore + Base = declarative_base() +try: + from sqlalchemy.orm import mapped_column + + class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc] + """Table used to save values.""" + + # ATTENTION: + # Prior to modifying this table, please determine whether + # we should create migrations for this table to make sure + # users do not experience data loss. + __tablename__ = "langchain_key_value_stores" + + namespace: Mapped[str] = mapped_column( + primary_key=True, index=True, nullable=False + ) + key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False) + value = mapped_column(LargeBinary, index=False, nullable=False) + +except ImportError: + # dummy for sqlalchemy < 2 + from sqlalchemy import Column + + class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc,no-redef] + """Table used to save values.""" + + # ATTENTION: + # Prior to modifying this table, please determine whether + # we should create migrations for this table to make sure + # users do not experience data loss. + __tablename__ = "langchain_key_value_stores" + + namespace = Column(Text(), primary_key=True, index=True, nullable=False) + key = Column(Text(), primary_key=True, index=True, nullable=False) + value = Column(LargeBinary, index=False, nullable=False) + def items_equal(x: Any, y: Any) -> bool: return x == y -class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc] - """Table used to save values.""" - - # ATTENTION: - # Prior to modifying this table, please determine whether - # we should create migrations for this table to make sure - # users do not experience data loss. - __tablename__ = "langchain_key_value_stores" - - namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False) - key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False) - value = mapped_column(LargeBinary, index=False, nullable=False) - - # This is a fix of original SQLStore. # This can will be removed when a PR will be merged. class SQLStore(BaseStore[str, bytes]): @@ -135,7 +161,10 @@ class SQLStore(BaseStore[str, bytes]): self.namespace = namespace def create_schema(self) -> None: - Base.metadata.create_all(self.engine) + Base.metadata.create_all(self.engine) # problem in sqlalchemy v1 + # sqlalchemy.exc.CompileError: (in table 'langchain_key_value_stores', + # column 'namespace'): Can't generate DDL for NullType(); did you forget + # to specify a type on this Column? async def acreate_schema(self) -> None: assert isinstance(self.engine, AsyncEngine) diff --git a/libs/community/langchain_community/utilities/sql_database.py b/libs/community/langchain_community/utilities/sql_database.py index de2ad3fd44..fb3e4c39f7 100644 --- a/libs/community/langchain_community/utilities/sql_database.py +++ b/libs/community/langchain_community/utilities/sql_database.py @@ -338,7 +338,7 @@ class SQLDatabase: continue # Ignore JSON datatyped columns - for k, v in table.columns.items(): + for k, v in table.columns.items(): # AttributeError: items in sqlalchemy v1 if type(v.type) is NullType: table._columns.remove(v) diff --git a/libs/community/tests/unit_tests/storage/test_sql.py b/libs/community/tests/unit_tests/storage/test_sql.py index 084f0e2d19..1f4163224b 100644 --- a/libs/community/tests/unit_tests/storage/test_sql.py +++ b/libs/community/tests/unit_tests/storage/test_sql.py @@ -1,12 +1,16 @@ from typing import AsyncGenerator, Generator, cast import pytest +import sqlalchemy as sa from langchain.storage._lc_store import create_kv_docstore, create_lc_store from langchain_core.documents import Document from langchain_core.stores import BaseStore +from packaging import version from langchain_community.storage.sql import SQLStore +is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1 + @pytest.fixture def sql_store() -> Generator[SQLStore, None, None]: @@ -22,6 +26,7 @@ async def async_sql_store() -> AsyncGenerator[SQLStore, None]: yield store +@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues") def test_create_lc_store(sql_store: SQLStore) -> None: """Test that a docstore is created from a base store.""" docstore: BaseStore[str, Document] = cast( @@ -34,6 +39,7 @@ def test_create_lc_store(sql_store: SQLStore) -> None: assert fetched_doc.metadata == {"key": "value"} +@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues") def test_create_kv_store(sql_store: SQLStore) -> None: """Test that a docstore is created from a base store.""" docstore = create_kv_docstore(sql_store) @@ -57,6 +63,7 @@ async def test_async_create_kv_store(async_sql_store: SQLStore) -> None: assert fetched_doc.metadata == {"key": "value"} +@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues") def test_sample_sql_docstore(sql_store: SQLStore) -> None: # Set values for keys sql_store.mset([("key1", b"value1"), ("key2", b"value2")]) diff --git a/libs/community/tests/unit_tests/test_sql_database.py b/libs/community/tests/unit_tests/test_sql_database.py index 6bd37d4052..6acb734a54 100644 --- a/libs/community/tests/unit_tests/test_sql_database.py +++ b/libs/community/tests/unit_tests/test_sql_database.py @@ -55,6 +55,7 @@ def db_lazy_reflection(engine: Engine) -> SQLDatabase: return SQLDatabase(engine, lazy_table_reflection=True) +@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues") def test_table_info(db: SQLDatabase) -> None: """Test that table info is constructed properly.""" output = db.table_info @@ -85,6 +86,7 @@ def test_table_info(db: SQLDatabase) -> None: assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split())) +@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues") def test_table_info_lazy_reflection(db_lazy_reflection: SQLDatabase) -> None: """Test that table info with lazy reflection""" assert len(db_lazy_reflection._metadata.sorted_tables) == 0 @@ -111,6 +113,7 @@ def test_table_info_lazy_reflection(db_lazy_reflection: SQLDatabase) -> None: assert db_lazy_reflection._metadata.sorted_tables[1].name == "user" +@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues") def test_table_info_w_sample_rows(db: SQLDatabase) -> None: """Test that table info is constructed properly.""" diff --git a/libs/community/tests/unit_tests/test_sql_database_schema.py b/libs/community/tests/unit_tests/test_sql_database_schema.py index 6809b8f329..2b1e815a05 100644 --- a/libs/community/tests/unit_tests/test_sql_database_schema.py +++ b/libs/community/tests/unit_tests/test_sql_database_schema.py @@ -18,6 +18,9 @@ from sqlalchemy import ( insert, schema, ) +import sqlalchemy as sa + +from packaging import version from langchain_community.utilities.sql_database import SQLDatabase @@ -43,6 +46,9 @@ company = Table( ) +@pytest.mark.xfail( + version.parse(sa.__version__).major == 1, reason="SQLAlchemy 1.x issues" +) def test_table_info() -> None: """Test that table info is constructed properly.""" engine = create_engine("duckdb:///:memory:") @@ -65,6 +71,9 @@ def test_table_info() -> None: assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split())) +@pytest.mark.xfail( + version.parse(sa.__version__).major == 1, reason="SQLAlchemy 1.x issues" +) def test_sql_database_run() -> None: """Test that commands can be run successfully and returned in correct format.""" engine = create_engine("duckdb:///:memory:")