community[minor]: Add native async support to SQLChatMessageHistory (#22065)

# package community: Fix SQLChatMessageHistory

## Description
Here is a rewrite of `SQLChatMessageHistory` that properly implements the
asynchronous approach. The code works around [issue
22021](https://github.com/langchain-ai/langchain/issues/22021) by
accepting a synchronous call to `add_messages()` even in an asynchronous
scenario, which bypasses the bug.

For the same reasons as in [PR
32](https://github.com/langchain-ai/langchain-postgres/pull/32) of
`langchain-postgres`, we use a lazy strategy for table creation: the
constructor alone cannot guarantee that the table exists, because an
asynchronous call cannot be awaited inside a constructor. We compensate
by creating the table on the first asynchronous method call, as sketched
below.
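
For illustration, a minimal sketch of this behaviour, assuming the
`aiosqlite` driver is installed; the file name and session id are only
placeholders:
```
import asyncio

from langchain_community.chat_message_histories import SQLChatMessageHistory


async def main() -> None:
    # The constructor only records the connection; no table is created yet,
    # because nothing can be awaited here.
    history = SQLChatMessageHistory(
        session_id="demo",
        connection="sqlite+aiosqlite:///demo.sqlite3",
        async_mode=True,
    )
    # The first asynchronous call creates the table if it does not exist yet.
    print(await history.aget_messages())


asyncio.run(main())
```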

The goal of the `PostgresChatMessageHistory` class (in
`langchain-postgres`) is, among other things, to be able to recycle
database connections. Its current implementation is problematic, as
demonstrated in [issue
22021](https://github.com/langchain-ai/langchain/issues/22021).

Our new implementation of `SQLChatMessageHistory` achieves this by using
a single shared (`Async`)`Engine` for the database connection. The
connection pool is managed by this engine, which makes the code
reentrant (see the sketch below).
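
For example, a minimal sketch of sharing one engine, and therefore one
connection pool, across several histories; the database URL and session
ids are purely illustrative:
```
from langchain_community.chat_message_histories import SQLChatMessageHistory
from sqlalchemy.ext.asyncio import create_async_engine

# A single AsyncEngine owns the connection pool and is shared by all instances.
engine = create_async_engine("postgresql+asyncpg://user:password@localhost/chat")

history_a = SQLChatMessageHistory(session_id="session-a", connection=engine)
history_b = SQLChatMessageHistory(session_id="session-b", connection=engine)
```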

We also accept a connection string of type `str`, optionally complemented
by `async_mode`. I know you don't like this much, but it's the only way
to allow an asynchronous connection string, as shown below.
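
A minimal sketch of both forms with SQLite drivers; a bare string cannot
tell us whether the engine must be asynchronous, hence the extra flag:
```
from langchain_community.chat_message_histories import SQLChatMessageHistory

# Synchronous driver: the default is a synchronous Engine.
sync_history = SQLChatMessageHistory(
    session_id="demo",
    connection="sqlite:///demo.sqlite3",
)

# Asynchronous driver: async_mode=True makes the class build an AsyncEngine.
async_history = SQLChatMessageHistory(
    session_id="demo",
    connection="sqlite+aiosqlite:///demo.sqlite3",
    async_mode=True,
)
```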

In order to unify the different classes handling database connections,
we have renamed `connection_string` to `connection`, and `Session` to
`session_maker`.

Now, a single transaction is used to add a list of messages. Thus, a
crash during this write operation will not leave the database in an
inconsistent state with a partially written message list. This makes the
code more resilient (see the sketch below).
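
A minimal sketch of the batch write, again with an illustrative
`aiosqlite` URL; either every message in the list is committed or none
of them are:
```
import asyncio

from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage


async def main() -> None:
    history = SQLChatMessageHistory(
        session_id="demo",
        connection="sqlite+aiosqlite:///demo.sqlite3",
        async_mode=True,
    )
    # All messages in the list are written in a single transaction.
    await history.aadd_messages(
        [HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
    )


asyncio.run(main())
```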

We believe that the `PostgresChatMessageHistory` class is no longer
necessary and can be replaced by:
```
PostgresChatMessageHistory = SQLChatMessageHistory
```
This also fixes the bug.


## Issue
- [issue 22021](https://github.com/langchain-ai/langchain/issues/22021)
  - Bug in _exit_history()
  - Bugs in PostgresChatMessageHistory and sync usage
  - Bugs in PostgresChatMessageHistory and async usage
- [issue
36](https://github.com/langchain-ai/langchain-postgres/issues/36)
## Twitter handle
pprados

## Tests
- libs/community/tests/unit_tests/chat_message_histories/test_sql.py
(async tests added)

@baskaryan, @eyurtsev or @hwchase17, can you check this PR?
I have also been waiting a long time for a review of my other PRs. Can
you take a look?
- [PR 32](https://github.com/langchain-ai/langchain-postgres/pull/32)
- [PR 15575](https://github.com/langchain-ai/langchain/pull/15575)
- [PR 13200](https://github.com/langchain-ai/langchain/pull/13200)

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>

libs/community/langchain_community/chat_message_histories/sql.py

@@ -1,9 +1,22 @@
+import asyncio
+import contextlib
 import json
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional
+from typing import (
+    Any,
+    AsyncGenerator,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Union,
+    cast,
+)
 
-from sqlalchemy import Column, Integer, Text, create_engine
+from langchain_core._api import deprecated, warn_deprecated
+from sqlalchemy import Column, Integer, Text, delete, select
 
 try:
     from sqlalchemy.orm import declarative_base
@@ -15,7 +28,22 @@ from langchain_core.messages import (
     message_to_dict,
     messages_from_dict,
 )
 
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy import create_engine
+from sqlalchemy.engine import Engine
+from sqlalchemy.ext.asyncio import (
+    AsyncEngine,
+    AsyncSession,
+    async_sessionmaker,
+    create_async_engine,
+)
+from sqlalchemy.orm import (
+    Session as SQLSession,
+)
+from sqlalchemy.orm import (
+    declarative_base,
+    scoped_session,
+    sessionmaker,
+)
 
 logger = logging.getLogger(__name__)
@@ -80,36 +108,98 @@ class DefaultMessageConverter(BaseMessageConverter):
         return self.model_class
 
 
+DBConnection = Union[AsyncEngine, Engine, str]
+
+_warned_once_already = False
+
+
 class SQLChatMessageHistory(BaseChatMessageHistory):
     """Chat message history stored in an SQL database."""
 
+    @property
+    @deprecated("0.2.2", removal="0.3.0", alternative="session_maker")
+    def Session(self) -> Union[scoped_session, async_sessionmaker]:
+        return self.session_maker
+
     def __init__(
         self,
         session_id: str,
-        connection_string: str,
+        connection_string: Optional[str] = None,
         table_name: str = "message_store",
         session_id_field_name: str = "session_id",
         custom_message_converter: Optional[BaseMessageConverter] = None,
+        connection: Union[None, DBConnection] = None,
+        engine_args: Optional[Dict[str, Any]] = None,
+        async_mode: Optional[bool] = None,  # Use only if connection is a string
     ):
-        self.connection_string = connection_string
-        self.engine = create_engine(connection_string, echo=False)
+        assert not (
+            connection_string and connection
+        ), "connection_string and connection are mutually exclusive"
+        if connection_string:
+            global _warned_once_already
+            if not _warned_once_already:
+                warn_deprecated(
+                    since="0.2.2",
+                    removal="0.3.0",
+                    name="connection_string",
+                    alternative="Use connection instead",
+                )
+                _warned_once_already = True
+            connection = connection_string
+            self.connection_string = connection_string
+        if isinstance(connection, str):
+            self.async_mode = async_mode
+            if async_mode:
+                self.async_engine = create_async_engine(
+                    connection, **(engine_args or {})
+                )
+            else:
+                self.engine = create_engine(url=connection, **(engine_args or {}))
+        elif isinstance(connection, Engine):
+            self.async_mode = False
+            self.engine = connection
+        elif isinstance(connection, AsyncEngine):
+            self.async_mode = True
+            self.async_engine = connection
+        else:
+            raise ValueError(
+                "connection should be a connection string or an instance of "
+                "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine"
+            )
+
+        # To be consistent with others SQL implementations, rename to session_maker
+        self.session_maker: Union[scoped_session, async_sessionmaker]
+        if self.async_mode:
+            self.session_maker = async_sessionmaker(bind=self.async_engine)
+        else:
+            self.session_maker = scoped_session(sessionmaker(bind=self.engine))
+
         self.session_id_field_name = session_id_field_name
         self.converter = custom_message_converter or DefaultMessageConverter(table_name)
         self.sql_model_class = self.converter.get_sql_model_class()
         if not hasattr(self.sql_model_class, session_id_field_name):
             raise ValueError("SQL model class must have session_id column")
-        self._create_table_if_not_exists()
 
+        self._table_created = False
+        if not self.async_mode:
+            self._create_table_if_not_exists()
         self.session_id = session_id
-        self.Session = sessionmaker(self.engine)
 
     def _create_table_if_not_exists(self) -> None:
         self.sql_model_class.metadata.create_all(self.engine)
+        self._table_created = True
+
+    async def _acreate_table_if_not_exists(self) -> None:
+        if not self._table_created:
+            assert self.async_mode, "This method must be called with async_mode"
+            async with self.async_engine.begin() as conn:
+                await conn.run_sync(self.sql_model_class.metadata.create_all)
+            self._table_created = True
 
     @property
     def messages(self) -> List[BaseMessage]:  # type: ignore
         """Retrieve all messages from db"""
-        with self.Session() as session:
+        with self._make_sync_session() as session:
             result = (
                 session.query(self.sql_model_class)
                 .where(
@@ -123,18 +213,105 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
             messages.append(self.converter.from_sql_model(record))
         return messages
 
+    def get_messages(self) -> List[BaseMessage]:
+        return self.messages
+
+    async def aget_messages(self) -> List[BaseMessage]:
+        """Retrieve all messages from db"""
+        await self._acreate_table_if_not_exists()
+        async with self._make_async_session() as session:
+            stmt = (
+                select(self.sql_model_class)
+                .where(
+                    getattr(self.sql_model_class, self.session_id_field_name)
+                    == self.session_id
+                )
+                .order_by(self.sql_model_class.id.asc())
+            )
+            result = await session.execute(stmt)
+            messages = []
+            for record in result.scalars():
+                messages.append(self.converter.from_sql_model(record))
+            return messages
+
     def add_message(self, message: BaseMessage) -> None:
         """Append the message to the record in db"""
-        with self.Session() as session:
+        with self._make_sync_session() as session:
             session.add(self.converter.to_sql_model(message, self.session_id))
             session.commit()
 
+    async def aadd_message(self, message: BaseMessage) -> None:
+        """Add a Message object to the store.
+
+        Args:
+            message: A BaseMessage object to store.
+        """
+        await self._acreate_table_if_not_exists()
+        async with self._make_async_session() as session:
+            session.add(self.converter.to_sql_model(message, self.session_id))
+            await session.commit()
+
+    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
+        # The method RunnableWithMessageHistory._exit_history() call
+        # add_message method by mistake and not aadd_message.
+        # See https://github.com/langchain-ai/langchain/issues/22021
+        if self.async_mode:
+            loop = asyncio.get_event_loop()
+            loop.run_until_complete(self.aadd_messages(messages))
+        else:
+            with self._make_sync_session() as session:
+                for message in messages:
+                    session.add(self.converter.to_sql_model(message, self.session_id))
+                session.commit()
+
+    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
+        # Add all messages in one transaction
+        await self._acreate_table_if_not_exists()
+        async with self.session_maker() as session:
+            for message in messages:
+                session.add(self.converter.to_sql_model(message, self.session_id))
+            await session.commit()
+
     def clear(self) -> None:
         """Clear session memory from db"""
-        with self.Session() as session:
+        with self._make_sync_session() as session:
             session.query(self.sql_model_class).filter(
                 getattr(self.sql_model_class, self.session_id_field_name)
                 == self.session_id
             ).delete()
             session.commit()
+
+    async def aclear(self) -> None:
+        """Clear session memory from db"""
+        await self._acreate_table_if_not_exists()
+        async with self._make_async_session() as session:
+            stmt = delete(self.sql_model_class).filter(
+                getattr(self.sql_model_class, self.session_id_field_name)
+                == self.session_id
+            )
+            await session.execute(stmt)
+            await session.commit()
+
+    @contextlib.contextmanager
+    def _make_sync_session(self) -> Generator[SQLSession, None, None]:
+        """Make an async session."""
+        if self.async_mode:
+            raise ValueError(
+                "Attempting to use a sync method in when async mode is turned on. "
+                "Please use the corresponding async method instead."
+            )
+        with self.session_maker() as session:
+            yield cast(SQLSession, session)
+
+    @contextlib.asynccontextmanager
+    async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
+        """Make an async session."""
+        if not self.async_mode:
+            raise ValueError(
+                "Attempting to use an async method in when sync mode is turned on. "
+                "Please use the corresponding async method instead."
+            )
+        async with self.session_maker() as session:
+            yield cast(AsyncSession, session)

libs/community/poetry.lock

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
 
 [[package]]
 name = "aenum"
@@ -3475,6 +3475,7 @@ files = [
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
     {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
+    {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
@@ -3985,7 +3986,7 @@ files = [
 
 [[package]]
 name = "langchain"
-version = "0.2.1"
+version = "0.2.2"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -4026,7 +4027,7 @@ url = "../langchain"
 
 [[package]]
 name = "langchain-core"
-version = "0.2.3"
+version = "0.2.4"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -4035,7 +4036,7 @@ develop = true
 
 [package.dependencies]
 jsonpatch = "^1.33"
-langsmith = "^0.1.65"
+langsmith = "^0.1.66"
 packaging = "^23.2"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"
@@ -4050,7 +4051,7 @@ url = "../core"
 
 [[package]]
 name = "langchain-text-splitters"
-version = "0.2.0"
+version = "0.2.1"
 description = "LangChain text splitting utilities"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -6123,8 +6124,6 @@ files = [
     {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
     {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
     {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
-    {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
-    {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
     {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
     {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
     {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
@@ -6167,7 +6166,6 @@ files = [
     {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
     {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
-    {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
@@ -6176,8 +6174,6 @@ files = [
     {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
-    {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
-    {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
     {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
     {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
     {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
@@ -7175,7 +7171,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -10217,9 +10212,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 
 [extras]
 cli = ["typer"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "azure-identity", "azure-search-documents", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpathlib", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "oracledb", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "simsimd", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"]
+extended-testing = ["aiosqlite", "aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "azure-identity", "azure-search-documents", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpathlib", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "oracledb", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "simsimd", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "77ccc0105fabe1735497289125bb276822101a6a9b1c2b596bf49b8f30b8068d"
+content-hash = "22bdadbd8a34235ba0cd923d9b380d362caa64f000053a5f91f9d163e8b41aad"

libs/community/pyproject.toml

@@ -291,6 +291,7 @@ extended_testing = [
     "pyjwt",
     "oracledb",
     "simsimd",
+    "aiosqlite"
 ]
 
 [tool.ruff]

libs/community/tests/unit_tests/chat_message_histories/test_sql.py

@@ -1,8 +1,8 @@
 from pathlib import Path
-from typing import Any, Generator, Tuple
+from typing import Any, AsyncGenerator, Generator, List, Tuple
 
 import pytest
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from sqlalchemy import Column, Integer, Text
 from sqlalchemy.orm import DeclarativeBase
 
@@ -17,16 +17,23 @@ def con_str(tmp_path: Path) -> str:
     return con_str
 
 
+@pytest.fixture()
+def acon_str(tmp_path: Path) -> str:
+    file_path = tmp_path / "adb.sqlite3"
+    con_str = f"sqlite+aiosqlite:///{file_path}"
+    return con_str
+
+
 @pytest.fixture()
 def sql_histories(
     con_str: str,
 ) -> Generator[Tuple[SQLChatMessageHistory, SQLChatMessageHistory], None, None]:
     message_history = SQLChatMessageHistory(
-        session_id="123", connection_string=con_str, table_name="test_table"
+        session_id="123", connection=con_str, table_name="test_table"
     )
     # Create history for other session
     other_history = SQLChatMessageHistory(
-        session_id="456", connection_string=con_str, table_name="test_table"
+        session_id="456", connection=con_str, table_name="test_table"
     )
 
     yield message_history, other_history
@@ -34,12 +41,38 @@ def sql_histories(
     other_history.clear()
 
 
+@pytest.fixture()
+async def asql_histories(
+    acon_str: str,
+) -> AsyncGenerator[Tuple[SQLChatMessageHistory, SQLChatMessageHistory], None]:
+    message_history = SQLChatMessageHistory(
+        session_id="123",
+        connection=acon_str,
+        table_name="test_table",
+        async_mode=True,
+        engine_args={"echo": False},
+    )
+    # Create history for other session
+    other_history = SQLChatMessageHistory(
+        session_id="456",
+        connection=acon_str,
+        table_name="test_table",
+        async_mode=True,
+        engine_args={"echo": False},
+    )
+
+    yield message_history, other_history
+    await message_history.aclear()
+    await other_history.aclear()
+
+
 def test_add_messages(
     sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
 ) -> None:
     sql_history, other_history = sql_histories
-    sql_history.add_user_message("Hello!")
-    sql_history.add_ai_message("Hi there!")
+    sql_history.add_messages(
+        [HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
+    )
 
     messages = sql_history.messages
     assert len(messages) == 2
@@ -49,39 +82,94 @@ def test_add_messages(
     assert messages[1].content == "Hi there!"
 
 
+@pytest.mark.requires("aiosqlite")
+async def test_async_add_messages(
+    asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
+) -> None:
+    sql_history, other_history = asql_histories
+    await sql_history.aadd_messages(
+        [HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
+    )
+
+    messages = await sql_history.aget_messages()
+    assert len(messages) == 2
+    assert isinstance(messages[0], HumanMessage)
+    assert isinstance(messages[1], AIMessage)
+    assert messages[0].content == "Hello!"
+    assert messages[1].content == "Hi there!"
+
+
 def test_multiple_sessions(
     sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
 ) -> None:
     sql_history, other_history = sql_histories
-    sql_history.add_user_message("Hello!")
-    sql_history.add_ai_message("Hi there!")
-    sql_history.add_user_message("Whats cracking?")
+    sql_history.add_messages(
+        [
+            HumanMessage(content="Hello!"),
+            AIMessage(content="Hi there!"),
+            HumanMessage(content="Whats cracking?"),
+        ]
+    )
 
     # Ensure the messages are added correctly in the first session
-    assert len(sql_history.messages) == 3, "waat"
-    assert sql_history.messages[0].content == "Hello!"
-    assert sql_history.messages[1].content == "Hi there!"
-    assert sql_history.messages[2].content == "Whats cracking?"
+    messages = sql_history.messages
+    assert len(messages) == 3, "waat"
+    assert messages[0].content == "Hello!"
+    assert messages[1].content == "Hi there!"
+    assert messages[2].content == "Whats cracking?"
 
     # second session
-    other_history.add_user_message("Hellox")
+    other_history.add_messages([HumanMessage(content="Hellox")])
     assert len(other_history.messages) == 1
-    assert len(sql_history.messages) == 3
+    messages = sql_history.messages
+    assert len(messages) == 3
     assert other_history.messages[0].content == "Hellox"
-    assert sql_history.messages[0].content == "Hello!"
-    assert sql_history.messages[1].content == "Hi there!"
-    assert sql_history.messages[2].content == "Whats cracking?"
+    assert messages[0].content == "Hello!"
+    assert messages[1].content == "Hi there!"
+    assert messages[2].content == "Whats cracking?"
 
 
+@pytest.mark.requires("aiosqlite")
+async def test_async_multiple_sessions(
+    asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
+) -> None:
+    sql_history, other_history = asql_histories
+    await sql_history.aadd_messages(
+        [
+            HumanMessage(content="Hello!"),
+            AIMessage(content="Hi there!"),
+            HumanMessage(content="Whats cracking?"),
+        ]
+    )
+
+    # Ensure the messages are added correctly in the first session
+    messages: List[BaseMessage] = await sql_history.aget_messages()
+    assert len(messages) == 3, "waat"
+    assert messages[0].content == "Hello!"
+    assert messages[1].content == "Hi there!"
+    assert messages[2].content == "Whats cracking?"
+
+    # second session
+    await other_history.aadd_messages([HumanMessage(content="Hellox")])
+    messages = await sql_history.aget_messages()
+    assert len(await other_history.aget_messages()) == 1
+    assert len(messages) == 3
+    assert (await other_history.aget_messages())[0].content == "Hellox"
+    assert messages[0].content == "Hello!"
+    assert messages[1].content == "Hi there!"
+    assert messages[2].content == "Whats cracking?"
+
+
 def test_clear_messages(
     sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
 ) -> None:
     sql_history, other_history = sql_histories
-    sql_history.add_user_message("Hello!")
-    sql_history.add_ai_message("Hi there!")
+    sql_history.add_messages(
+        [HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
    )
     assert len(sql_history.messages) == 2
     # Now create another history with different session id
-    other_history.add_user_message("Hellox")
+    other_history.add_messages([HumanMessage(content="Hellox")])
     assert len(other_history.messages) == 1
     assert len(sql_history.messages) == 2
     # Now clear the first history
@@ -90,6 +178,25 @@ def test_clear_messages(
     assert len(other_history.messages) == 1
 
 
+@pytest.mark.requires("aiosqlite")
+async def test_async_clear_messages(
+    asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
+) -> None:
+    sql_history, other_history = asql_histories
+    await sql_history.aadd_messages(
+        [HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
+    )
+    assert len(await sql_history.aget_messages()) == 2
+    # Now create another history with different session id
+    await other_history.aadd_messages([HumanMessage(content="Hellox")])
+    assert len(await other_history.aget_messages()) == 1
+    assert len(await sql_history.aget_messages()) == 2
+    # Now clear the first history
+    await sql_history.aclear()
+    assert len(await sql_history.aget_messages()) == 0
+    assert len(await other_history.aget_messages()) == 1
+
+
 def test_model_no_session_id_field_error(con_str: str) -> None:
     class Base(DeclarativeBase):
         pass
