From 8250c177de67b224ed8e2f84d9b5a21354651343 Mon Sep 17 00:00:00 2001
From: Philippe PRADOS
Date: Wed, 5 Jun 2024 17:10:38 +0200
Subject: [PATCH] community[minor]: Add native async support to SQLChatMessageHistory (#22065)

# package community: Fix SQLChatMessageHistory

## Description

Here is a rewrite of `SQLChatMessageHistory` that properly implements the asynchronous approach. The code works around [issue 22021](https://github.com/langchain-ai/langchain/issues/22021) by also accepting a synchronous call to `add_messages()` in an asynchronous scenario, which bypasses the bug (a sketch of the affected `RunnableWithMessageHistory` wiring appears after the diff).

For the same reasons as in [PR 32](https://github.com/langchain-ai/langchain-postgres/pull/32) of `langchain-postgres`, we use a lazy strategy for table creation: the constructor cannot fulfil its promise on its own, because a constructor cannot await an asynchronous call. We compensate by creating the table on the next asynchronous method call.

The goal of the `PostgresChatMessageHistory` class (in `langchain-postgres`) is, among other things, to be able to recycle database connections. That implementation is problematic, as we demonstrated in [issue 22021](https://github.com/langchain-ai/langchain/issues/22021). Our new implementation of `SQLChatMessageHistory` achieves this by using a single (`Async`)`Engine` for the database connection; the connection pool is managed by that engine, and the code is reentrant.

We also accept a `str` connection, optionally complemented by `async_mode` (not ideal, but it is the only way to mark a connection string as asynchronous). To unify the different classes handling database connections, we renamed `connection_string` to `connection` and `Session` to `session_maker`.

A single transaction is now used to add a list of messages, so a crash during this write operation cannot leave the database in an inconsistent state with a partially added message list. This makes the code resilient.

We believe that the `PostgresChatMessageHistory` class is no longer necessary and can be replaced by:

```
PostgresChatMessageHistory = SQLChatMessageHistory
```

This also fixes the bug.

## Issue
- [issue 22021](https://github.com/langchain-ai/langchain/issues/22021)
  - Bug in _exit_history()
  - Bugs in PostgresChatMessageHistory and sync usage
  - Bugs in PostgresChatMessageHistory and async usage
- [issue 36](https://github.com/langchain-ai/langchain-postgres/issues/36)

## Twitter handle: pprados

## Tests
- libs/community/tests/unit_tests/chat_message_histories/test_sql.py (add async test)

@baskaryan, @eyurtsev or @hwchase17, can you check this PR? I have also been waiting a long time for reviews of other PRs, listed below after the usage sketch. Can you take a look?
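For reviewers, here is a minimal usage sketch of the resulting API (synchronous, asynchronous, and shared-engine usage). The SQLite DSNs and session ids are illustrative only, and the async path assumes `aiosqlite` is installed:

```python
import asyncio

from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage
from sqlalchemy.ext.asyncio import create_async_engine

# Synchronous usage: a connection string (or an sqlalchemy Engine) passed as `connection`.
history = SQLChatMessageHistory(
    session_id="session-1",
    connection="sqlite:///chat_history.db",
)
history.add_messages([HumanMessage(content="Hello!"), AIMessage(content="Hi there!")])
print(history.get_messages())


async def main() -> None:
    # Asynchronous usage: an async DSN plus async_mode=True.
    # The table is created lazily, on the first asynchronous call.
    ahistory = SQLChatMessageHistory(
        session_id="session-1",
        connection="sqlite+aiosqlite:///chat_history.db",
        async_mode=True,
    )
    await ahistory.aadd_messages(
        [HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
    )
    print(await ahistory.aget_messages())

    # Connection recycling: several histories can share one (Async)Engine;
    # async_mode is inferred from the engine type.
    engine = create_async_engine("sqlite+aiosqlite:///chat_history.db")
    alice = SQLChatMessageHistory(session_id="alice", connection=engine)
    bob = SQLChatMessageHistory(session_id="bob", connection=engine)
    await alice.aadd_message(HumanMessage(content="Hi, I am Alice"))
    await bob.aclear()


asyncio.run(main())
```

The deprecated `connection_string` parameter and the `Session` property keep working for now, but they emit a deprecation warning pointing at `connection` and `session_maker`.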
- [PR 32](https://github.com/langchain-ai/langchain-postgres/pull/32) - [PR 15575](https://github.com/langchain-ai/langchain/pull/15575) - [PR 13200](https://github.com/langchain-ai/langchain/pull/13200) --------- Co-authored-by: Eugene Yurtsev --- .../chat_message_histories/sql.py | 199 +++++++++++++++++- libs/community/poetry.lock | 21 +- libs/community/pyproject.toml | 1 + .../chat_message_histories/test_sql.py | 149 +++++++++++-- 4 files changed, 325 insertions(+), 45 deletions(-) diff --git a/libs/community/langchain_community/chat_message_histories/sql.py b/libs/community/langchain_community/chat_message_histories/sql.py index 01cafeb9bc..9264dbeff0 100644 --- a/libs/community/langchain_community/chat_message_histories/sql.py +++ b/libs/community/langchain_community/chat_message_histories/sql.py @@ -1,9 +1,22 @@ +import asyncio +import contextlib import json import logging from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import ( + Any, + AsyncGenerator, + Dict, + Generator, + List, + Optional, + Sequence, + Union, + cast, +) -from sqlalchemy import Column, Integer, Text, create_engine +from langchain_core._api import deprecated, warn_deprecated +from sqlalchemy import Column, Integer, Text, delete, select try: from sqlalchemy.orm import declarative_base @@ -15,7 +28,22 @@ from langchain_core.messages import ( message_to_dict, messages_from_dict, ) -from sqlalchemy.orm import sessionmaker +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import ( + Session as SQLSession, +) +from sqlalchemy.orm import ( + declarative_base, + scoped_session, + sessionmaker, +) logger = logging.getLogger(__name__) @@ -80,36 +108,98 @@ class DefaultMessageConverter(BaseMessageConverter): return self.model_class +DBConnection = Union[AsyncEngine, Engine, str] + +_warned_once_already = False + + class SQLChatMessageHistory(BaseChatMessageHistory): """Chat message history stored in an SQL database.""" + @property + @deprecated("0.2.2", removal="0.3.0", alternative="session_maker") + def Session(self) -> Union[scoped_session, async_sessionmaker]: + return self.session_maker + def __init__( self, session_id: str, - connection_string: str, + connection_string: Optional[str] = None, table_name: str = "message_store", session_id_field_name: str = "session_id", custom_message_converter: Optional[BaseMessageConverter] = None, + connection: Union[None, DBConnection] = None, + engine_args: Optional[Dict[str, Any]] = None, + async_mode: Optional[bool] = None, # Use only if connection is a string ): - self.connection_string = connection_string - self.engine = create_engine(connection_string, echo=False) + assert not ( + connection_string and connection + ), "connection_string and connection are mutually exclusive" + if connection_string: + global _warned_once_already + if not _warned_once_already: + warn_deprecated( + since="0.2.2", + removal="0.3.0", + name="connection_string", + alternative="Use connection instead", + ) + _warned_once_already = True + connection = connection_string + self.connection_string = connection_string + if isinstance(connection, str): + self.async_mode = async_mode + if async_mode: + self.async_engine = create_async_engine( + connection, **(engine_args or {}) + ) + else: + self.engine = create_engine(url=connection, **(engine_args or {})) + elif isinstance(connection, Engine): + 
self.async_mode = False + self.engine = connection + elif isinstance(connection, AsyncEngine): + self.async_mode = True + self.async_engine = connection + else: + raise ValueError( + "connection should be a connection string or an instance of " + "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine" + ) + + # To be consistent with other SQL implementations, rename to session_maker + self.session_maker: Union[scoped_session, async_sessionmaker] + if self.async_mode: + self.session_maker = async_sessionmaker(bind=self.async_engine) + else: + self.session_maker = scoped_session(sessionmaker(bind=self.engine)) + self.session_id_field_name = session_id_field_name self.converter = custom_message_converter or DefaultMessageConverter(table_name) self.sql_model_class = self.converter.get_sql_model_class() if not hasattr(self.sql_model_class, session_id_field_name): raise ValueError("SQL model class must have session_id column") - self._create_table_if_not_exists() + self._table_created = False + if not self.async_mode: + self._create_table_if_not_exists() self.session_id = session_id - self.Session = sessionmaker(self.engine) def _create_table_if_not_exists(self) -> None: self.sql_model_class.metadata.create_all(self.engine) + self._table_created = True + + async def _acreate_table_if_not_exists(self) -> None: + if not self._table_created: + assert self.async_mode, "This method must be called with async_mode" + async with self.async_engine.begin() as conn: + await conn.run_sync(self.sql_model_class.metadata.create_all) + self._table_created = True @property def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all messages from db""" - with self.Session() as session: + with self._make_sync_session() as session: result = ( session.query(self.sql_model_class) .where( @@ -123,18 +213,105 @@ class SQLChatMessageHistory(BaseChatMessageHistory): messages.append(self.converter.from_sql_model(record)) return messages + def get_messages(self) -> List[BaseMessage]: + return self.messages + + async def aget_messages(self) -> List[BaseMessage]: + """Retrieve all messages from db""" + await self._acreate_table_if_not_exists() + async with self._make_async_session() as session: + stmt = ( + select(self.sql_model_class) + .where( + getattr(self.sql_model_class, self.session_id_field_name) + == self.session_id + ) + .order_by(self.sql_model_class.id.asc()) + ) + result = await session.execute(stmt) + messages = [] + for record in result.scalars(): + messages.append(self.converter.from_sql_model(record)) + return messages + def add_message(self, message: BaseMessage) -> None: """Append the message to the record in db""" - with self.Session() as session: + with self._make_sync_session() as session: session.add(self.converter.to_sql_model(message, self.session_id)) session.commit() + async def aadd_message(self, message: BaseMessage) -> None: + """Add a Message object to the store. + + Args: + message: A BaseMessage object to store. + """ + await self._acreate_table_if_not_exists() + async with self._make_async_session() as session: + session.add(self.converter.to_sql_model(message, self.session_id)) + await session.commit() + + def add_messages(self, messages: Sequence[BaseMessage]) -> None: + # The method RunnableWithMessageHistory._exit_history() calls + # add_message by mistake instead of aadd_message. 
+ # See https://github.com/langchain-ai/langchain/issues/22021 + if self.async_mode: + loop = asyncio.get_event_loop() + loop.run_until_complete(self.aadd_messages(messages)) + else: + with self._make_sync_session() as session: + for message in messages: + session.add(self.converter.to_sql_model(message, self.session_id)) + session.commit() + + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + # Add all messages in one transaction + await self._acreate_table_if_not_exists() + async with self.session_maker() as session: + for message in messages: + session.add(self.converter.to_sql_model(message, self.session_id)) + await session.commit() + def clear(self) -> None: """Clear session memory from db""" - with self.Session() as session: + with self._make_sync_session() as session: session.query(self.sql_model_class).filter( getattr(self.sql_model_class, self.session_id_field_name) == self.session_id ).delete() session.commit() + + async def aclear(self) -> None: + """Clear session memory from db""" + + await self._acreate_table_if_not_exists() + async with self._make_async_session() as session: + stmt = delete(self.sql_model_class).filter( + getattr(self.sql_model_class, self.session_id_field_name) + == self.session_id + ) + await session.execute(stmt) + await session.commit() + + @contextlib.contextmanager + def _make_sync_session(self) -> Generator[SQLSession, None, None]: + """Make a sync session.""" + if self.async_mode: + raise ValueError( + "Attempting to use a sync method when async mode is turned on. " + "Please use the corresponding async method instead." + ) + with self.session_maker() as session: + yield cast(SQLSession, session) + + @contextlib.asynccontextmanager + async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]: + """Make an async session.""" + if not self.async_mode: + raise ValueError( + "Attempting to use an async method when sync mode is turned on. " + "Please use the corresponding sync method instead." + ) + async with self.session_maker() as session: + yield cast(AsyncSession, session) diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock index 4d3693e232..397939274f 100644 --- a/libs/community/poetry.lock +++ b/libs/community/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. 
[[package]] name = "aenum" @@ -3475,6 +3475,7 @@ files = [ {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"}, {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"}, {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"}, + {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"}, @@ -3985,7 +3986,7 @@ files = [ [[package]] name = "langchain" -version = "0.2.1" +version = "0.2.2" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -4026,7 +4027,7 @@ url = "../langchain" [[package]] name = "langchain-core" -version = "0.2.3" +version = "0.2.4" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -4035,7 +4036,7 @@ develop = true [package.dependencies] jsonpatch = "^1.33" -langsmith = "^0.1.65" +langsmith = "^0.1.66" packaging = "^23.2" pydantic = ">=1,<3" PyYAML = ">=5.3" @@ -4050,7 +4051,7 @@ url = "../core" [[package]] name = "langchain-text-splitters" -version = "0.2.0" +version = "0.2.1" description = "LangChain text splitting utilities" optional = false python-versions = ">=3.8.1,<4.0" @@ -6123,8 +6124,6 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, - {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, - {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -6167,7 +6166,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, - {file = 
"psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, @@ -6176,8 +6174,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, @@ -7175,7 +7171,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -10217,9 +10212,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] cli = ["typer"] -extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "azure-identity", "azure-search-documents", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpathlib", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", 
"friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "oracledb", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "simsimd", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"] +extended-testing = ["aiosqlite", "aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "azure-identity", "azure-search-documents", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpathlib", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "oracledb", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "simsimd", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "77ccc0105fabe1735497289125bb276822101a6a9b1c2b596bf49b8f30b8068d" +content-hash = "22bdadbd8a34235ba0cd923d9b380d362caa64f000053a5f91f9d163e8b41aad" diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index b1697ece0b..d2f27f9100 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -291,6 +291,7 @@ extended_testing = [ "pyjwt", "oracledb", "simsimd", + "aiosqlite" ] [tool.ruff] diff --git a/libs/community/tests/unit_tests/chat_message_histories/test_sql.py b/libs/community/tests/unit_tests/chat_message_histories/test_sql.py index 8f3f74b003..68e7e216ce 100644 --- a/libs/community/tests/unit_tests/chat_message_histories/test_sql.py +++ b/libs/community/tests/unit_tests/chat_message_histories/test_sql.py @@ -1,8 +1,8 @@ from pathlib import Path -from typing import Any, Generator, Tuple +from typing import Any, AsyncGenerator, Generator, List, Tuple import pytest -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from sqlalchemy import Column, Integer, Text from sqlalchemy.orm import DeclarativeBase @@ -17,16 +17,23 @@ def con_str(tmp_path: Path) -> str: return con_str +@pytest.fixture() +def acon_str(tmp_path: Path) -> str: + file_path = tmp_path / "adb.sqlite3" + con_str = f"sqlite+aiosqlite:///{file_path}" + return con_str + + 
@pytest.fixture() def sql_histories( con_str: str, ) -> Generator[Tuple[SQLChatMessageHistory, SQLChatMessageHistory], None, None]: message_history = SQLChatMessageHistory( - session_id="123", connection_string=con_str, table_name="test_table" + session_id="123", connection=con_str, table_name="test_table" ) # Create history for other session other_history = SQLChatMessageHistory( - session_id="456", connection_string=con_str, table_name="test_table" + session_id="456", connection=con_str, table_name="test_table" ) yield message_history, other_history @@ -34,12 +41,38 @@ def sql_histories( other_history.clear() +@pytest.fixture() +async def asql_histories( + acon_str: str, +) -> AsyncGenerator[Tuple[SQLChatMessageHistory, SQLChatMessageHistory], None]: + message_history = SQLChatMessageHistory( + session_id="123", + connection=acon_str, + table_name="test_table", + async_mode=True, + engine_args={"echo": False}, + ) + # Create history for other session + other_history = SQLChatMessageHistory( + session_id="456", + connection=acon_str, + table_name="test_table", + async_mode=True, + engine_args={"echo": False}, + ) + + yield message_history, other_history + await message_history.aclear() + await other_history.aclear() + + def test_add_messages( sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory], ) -> None: sql_history, other_history = sql_histories - sql_history.add_user_message("Hello!") - sql_history.add_ai_message("Hi there!") + sql_history.add_messages( + [HumanMessage(content="Hello!"), AIMessage(content="Hi there!")] + ) messages = sql_history.messages assert len(messages) == 2 @@ -49,39 +82,94 @@ def test_add_messages( assert messages[1].content == "Hi there!" +@pytest.mark.requires("aiosqlite") +async def test_async_add_messages( + asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory], +) -> None: + sql_history, other_history = asql_histories + await sql_history.aadd_messages( + [HumanMessage(content="Hello!"), AIMessage(content="Hi there!")] + ) + + messages = await sql_history.aget_messages() + assert len(messages) == 2 + assert isinstance(messages[0], HumanMessage) + assert isinstance(messages[1], AIMessage) + assert messages[0].content == "Hello!" + assert messages[1].content == "Hi there!" + + def test_multiple_sessions( sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory], ) -> None: sql_history, other_history = sql_histories - sql_history.add_user_message("Hello!") - sql_history.add_ai_message("Hi there!") - sql_history.add_user_message("Whats cracking?") + sql_history.add_messages( + [ + HumanMessage(content="Hello!"), + AIMessage(content="Hi there!"), + HumanMessage(content="Whats cracking?"), + ] + ) # Ensure the messages are added correctly in the first session - assert len(sql_history.messages) == 3, "waat" - assert sql_history.messages[0].content == "Hello!" - assert sql_history.messages[1].content == "Hi there!" - assert sql_history.messages[2].content == "Whats cracking?" + messages = sql_history.messages + assert len(messages) == 3, "waat" + assert messages[0].content == "Hello!" + assert messages[1].content == "Hi there!" + assert messages[2].content == "Whats cracking?" 
# second session - other_history.add_user_message("Hellox") + other_history.add_messages([HumanMessage(content="Hellox")]) assert len(other_history.messages) == 1 - assert len(sql_history.messages) == 3 + messages = sql_history.messages + assert len(messages) == 3 assert other_history.messages[0].content == "Hellox" - assert sql_history.messages[0].content == "Hello!" - assert sql_history.messages[1].content == "Hi there!" - assert sql_history.messages[2].content == "Whats cracking?" + assert messages[0].content == "Hello!" + assert messages[1].content == "Hi there!" + assert messages[2].content == "Whats cracking?" + + +@pytest.mark.requires("aiosqlite") +async def test_async_multiple_sessions( + asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory], +) -> None: + sql_history, other_history = asql_histories + await sql_history.aadd_messages( + [ + HumanMessage(content="Hello!"), + AIMessage(content="Hi there!"), + HumanMessage(content="Whats cracking?"), + ] + ) + + # Ensure the messages are added correctly in the first session + messages: List[BaseMessage] = await sql_history.aget_messages() + assert len(messages) == 3, "waat" + assert messages[0].content == "Hello!" + assert messages[1].content == "Hi there!" + assert messages[2].content == "Whats cracking?" + + # second session + await other_history.aadd_messages([HumanMessage(content="Hellox")]) + messages = await sql_history.aget_messages() + assert len(await other_history.aget_messages()) == 1 + assert len(messages) == 3 + assert (await other_history.aget_messages())[0].content == "Hellox" + assert messages[0].content == "Hello!" + assert messages[1].content == "Hi there!" + assert messages[2].content == "Whats cracking?" def test_clear_messages( sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory], ) -> None: sql_history, other_history = sql_histories - sql_history.add_user_message("Hello!") - sql_history.add_ai_message("Hi there!") + sql_history.add_messages( + [HumanMessage(content="Hello!"), AIMessage(content="Hi there!")] + ) assert len(sql_history.messages) == 2 # Now create another history with different session id - other_history.add_user_message("Hellox") + other_history.add_messages([HumanMessage(content="Hellox")]) assert len(other_history.messages) == 1 assert len(sql_history.messages) == 2 # Now clear the first history @@ -90,6 +178,25 @@ def test_clear_messages( assert len(other_history.messages) == 1 +@pytest.mark.requires("aiosqlite") +async def test_async_clear_messages( + asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory], +) -> None: + sql_history, other_history = asql_histories + await sql_history.aadd_messages( + [HumanMessage(content="Hello!"), AIMessage(content="Hi there!")] + ) + assert len(await sql_history.aget_messages()) == 2 + # Now create another history with different session id + await other_history.aadd_messages([HumanMessage(content="Hellox")]) + assert len(await other_history.aget_messages()) == 1 + assert len(await sql_history.aget_messages()) == 2 + # Now clear the first history + await sql_history.aclear() + assert len(await sql_history.aget_messages()) == 0 + assert len(await other_history.aget_messages()) == 1 + + def test_model_no_session_id_field_error(con_str: str) -> None: class Base(DeclarativeBase): pass
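To close the loop on the `RunnableWithMessageHistory._exit_history()` scenario from issue 22021, here is a hedged sketch of the async wiring this PR is meant to support. The echo runnable stands in for a real model, the DSN and session id are illustrative, and whether this exact wiring runs depends on the installed `langchain-core` version:

```python
import asyncio
from typing import List

from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory


def get_session_history(session_id: str) -> SQLChatMessageHistory:
    # One history per session id; async_mode=True because the chain is driven by ainvoke().
    return SQLChatMessageHistory(
        session_id=session_id,
        connection="sqlite+aiosqlite:///chat_history.db",
        async_mode=True,
    )


def echo(messages: List[BaseMessage]) -> AIMessage:
    # Stand-in for a chat model: answer with the last message received.
    return AIMessage(content=f"Echo: {messages[-1].content}")


chain = RunnableWithMessageHistory(RunnableLambda(echo), get_session_history)


async def main() -> None:
    # Each ainvoke() reads the history for the session, runs the runnable,
    # and appends the new human/AI messages to the SQL store.
    answer = await chain.ainvoke(
        [HumanMessage(content="Hello!")],
        config={"configurable": {"session_id": "session-1"}},
    )
    print(answer.content)


asyncio.run(main())
```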