community[minor]: Add native async support to SQLChatMessageHistory (#22065)

# package community: Fix SQLChatMessageHistory

## Description
Here is a rewrite of `SQLChatMessageHistory` to properly implement the
asynchronous approach. The code circumvents [issue
22021](https://github.com/langchain-ai/langchain/issues/22021) by
accepting a synchronous call to `def add_messages()` in an asynchronous
scenario. This bypasses the bug.

For the same reasons as in [PR
22](https://github.com/langchain-ai/langchain-postgres/pull/32) of
`langchain-postgres`, we use a lazy strategy for table creation. Indeed,
the promise of the constructor cannot be fulfilled without this. It is
not possible to invoke a synchronous call in a constructor. We
compensate for this by waiting for the next asynchronous method call to
create the table.

The goal of the `PostgresChatMessageHistory` class (in
`langchain-postgres`) is, among other things, to be able to recycle
database connections. The implementation of the class is problematic, as
we have demonstrated in [issue
22021](https://github.com/langchain-ai/langchain/issues/22021).

Our new implementation of `SQLChatMessageHistory` achieves this by using
a singleton of type (`Async`)`Engine` for the database connection. The
connection pool is managed by this singleton, and the code is then
reentrant.

We also accept the type `str` (optionally complemented by `async_mode`.
I know you don't like this much, but it's the only way to allow an
asynchronous connection string).

In order to unify the different classes handling database connections,
we have renamed `connection_string` to `connection`, and `Session` to
`session_maker`.

Now, a single transaction is used to add a list of messages. Thus, a
crash during this write operation will not leave the database in an
unstable state with a partially added message list. This makes the code
resilient.

We believe that the `PostgresChatMessageHistory` class is no longer
necessary and can be replaced by:
```
PostgresChatMessageHistory = SQLChatMessageHistory
```
This also fixes the bug.


## Issue
- [issue 22021](https://github.com/langchain-ai/langchain/issues/22021)
  - Bug in _exit_history()
  - Bugs in PostgresChatMessageHistory and sync usage
  - Bugs in PostgresChatMessageHistory and async usage
- [issue
36](https://github.com/langchain-ai/langchain-postgres/issues/36)
 ## Twitter handle:
pprados

## Tests
- libs/community/tests/unit_tests/chat_message_histories/test_sql.py
(add async test)

@baskaryan, @eyurtsev or @hwchase17 can you check this PR ?
And, I've been waiting a long time for validation from other PRs. Can
you take a look?
- [PR 32](https://github.com/langchain-ai/langchain-postgres/pull/32)
- [PR 15575](https://github.com/langchain-ai/langchain/pull/15575)
- [PR 13200](https://github.com/langchain-ai/langchain/pull/13200)

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
pull/22482/head
Philippe PRADOS 2 months ago committed by GitHub
parent 59bef31997
commit 8250c177de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,9 +1,22 @@
import asyncio
import contextlib
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import (
Any,
AsyncGenerator,
Dict,
Generator,
List,
Optional,
Sequence,
Union,
cast,
)
from sqlalchemy import Column, Integer, Text, create_engine
from langchain_core._api import deprecated, warn_deprecated
from sqlalchemy import Column, Integer, Text, delete, select
try:
from sqlalchemy.orm import declarative_base
@ -15,7 +28,22 @@ from langchain_core.messages import (
message_to_dict,
messages_from_dict,
)
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import (
Session as SQLSession,
)
from sqlalchemy.orm import (
declarative_base,
scoped_session,
sessionmaker,
)
logger = logging.getLogger(__name__)
@ -80,36 +108,98 @@ class DefaultMessageConverter(BaseMessageConverter):
return self.model_class
DBConnection = Union[AsyncEngine, Engine, str]
_warned_once_already = False
class SQLChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in an SQL database."""
@property
@deprecated("0.2.2", removal="0.3.0", alternative="session_maker")
def Session(self) -> Union[scoped_session, async_sessionmaker]:
return self.session_maker
def __init__(
self,
session_id: str,
connection_string: str,
connection_string: Optional[str] = None,
table_name: str = "message_store",
session_id_field_name: str = "session_id",
custom_message_converter: Optional[BaseMessageConverter] = None,
connection: Union[None, DBConnection] = None,
engine_args: Optional[Dict[str, Any]] = None,
async_mode: Optional[bool] = None, # Use only if connection is a string
):
self.connection_string = connection_string
self.engine = create_engine(connection_string, echo=False)
assert not (
connection_string and connection
), "connection_string and connection are mutually exclusive"
if connection_string:
global _warned_once_already
if not _warned_once_already:
warn_deprecated(
since="0.2.2",
removal="0.3.0",
name="connection_string",
alternative="Use connection instead",
)
_warned_once_already = True
connection = connection_string
self.connection_string = connection_string
if isinstance(connection, str):
self.async_mode = async_mode
if async_mode:
self.async_engine = create_async_engine(
connection, **(engine_args or {})
)
else:
self.engine = create_engine(url=connection, **(engine_args or {}))
elif isinstance(connection, Engine):
self.async_mode = False
self.engine = connection
elif isinstance(connection, AsyncEngine):
self.async_mode = True
self.async_engine = connection
else:
raise ValueError(
"connection should be a connection string or an instance of "
"sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine"
)
# To be consistent with others SQL implementations, rename to session_maker
self.session_maker: Union[scoped_session, async_sessionmaker]
if self.async_mode:
self.session_maker = async_sessionmaker(bind=self.async_engine)
else:
self.session_maker = scoped_session(sessionmaker(bind=self.engine))
self.session_id_field_name = session_id_field_name
self.converter = custom_message_converter or DefaultMessageConverter(table_name)
self.sql_model_class = self.converter.get_sql_model_class()
if not hasattr(self.sql_model_class, session_id_field_name):
raise ValueError("SQL model class must have session_id column")
self._create_table_if_not_exists()
self._table_created = False
if not self.async_mode:
self._create_table_if_not_exists()
self.session_id = session_id
self.Session = sessionmaker(self.engine)
def _create_table_if_not_exists(self) -> None:
self.sql_model_class.metadata.create_all(self.engine)
self._table_created = True
async def _acreate_table_if_not_exists(self) -> None:
if not self._table_created:
assert self.async_mode, "This method must be called with async_mode"
async with self.async_engine.begin() as conn:
await conn.run_sync(self.sql_model_class.metadata.create_all)
self._table_created = True
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all messages from db"""
with self.Session() as session:
with self._make_sync_session() as session:
result = (
session.query(self.sql_model_class)
.where(
@ -123,18 +213,105 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
messages.append(self.converter.from_sql_model(record))
return messages
def get_messages(self) -> List[BaseMessage]:
return self.messages
async def aget_messages(self) -> List[BaseMessage]:
"""Retrieve all messages from db"""
await self._acreate_table_if_not_exists()
async with self._make_async_session() as session:
stmt = (
select(self.sql_model_class)
.where(
getattr(self.sql_model_class, self.session_id_field_name)
== self.session_id
)
.order_by(self.sql_model_class.id.asc())
)
result = await session.execute(stmt)
messages = []
for record in result.scalars():
messages.append(self.converter.from_sql_model(record))
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in db"""
with self.Session() as session:
with self._make_sync_session() as session:
session.add(self.converter.to_sql_model(message, self.session_id))
session.commit()
async def aadd_message(self, message: BaseMessage) -> None:
"""Add a Message object to the store.
Args:
message: A BaseMessage object to store.
"""
await self._acreate_table_if_not_exists()
async with self._make_async_session() as session:
session.add(self.converter.to_sql_model(message, self.session_id))
await session.commit()
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
# The method RunnableWithMessageHistory._exit_history() call
# add_message method by mistake and not aadd_message.
# See https://github.com/langchain-ai/langchain/issues/22021
if self.async_mode:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.aadd_messages(messages))
else:
with self._make_sync_session() as session:
for message in messages:
session.add(self.converter.to_sql_model(message, self.session_id))
session.commit()
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
# Add all messages in one transaction
await self._acreate_table_if_not_exists()
async with self.session_maker() as session:
for message in messages:
session.add(self.converter.to_sql_model(message, self.session_id))
await session.commit()
def clear(self) -> None:
"""Clear session memory from db"""
with self.Session() as session:
with self._make_sync_session() as session:
session.query(self.sql_model_class).filter(
getattr(self.sql_model_class, self.session_id_field_name)
== self.session_id
).delete()
session.commit()
async def aclear(self) -> None:
"""Clear session memory from db"""
await self._acreate_table_if_not_exists()
async with self._make_async_session() as session:
stmt = delete(self.sql_model_class).filter(
getattr(self.sql_model_class, self.session_id_field_name)
== self.session_id
)
await session.execute(stmt)
await session.commit()
@contextlib.contextmanager
def _make_sync_session(self) -> Generator[SQLSession, None, None]:
"""Make an async session."""
if self.async_mode:
raise ValueError(
"Attempting to use a sync method in when async mode is turned on. "
"Please use the corresponding async method instead."
)
with self.session_maker() as session:
yield cast(SQLSession, session)
@contextlib.asynccontextmanager
async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
"""Make an async session."""
if not self.async_mode:
raise ValueError(
"Attempting to use an async method in when sync mode is turned on. "
"Please use the corresponding async method instead."
)
async with self.session_maker() as session:
yield cast(AsyncSession, session)

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "aenum"
@ -3475,6 +3475,7 @@ files = [
{file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
{file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
{file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
{file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
{file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
{file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
{file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
@ -3985,7 +3986,7 @@ files = [
[[package]]
name = "langchain"
version = "0.2.1"
version = "0.2.2"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -4026,7 +4027,7 @@ url = "../langchain"
[[package]]
name = "langchain-core"
version = "0.2.3"
version = "0.2.4"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -4035,7 +4036,7 @@ develop = true
[package.dependencies]
jsonpatch = "^1.33"
langsmith = "^0.1.65"
langsmith = "^0.1.66"
packaging = "^23.2"
pydantic = ">=1,<3"
PyYAML = ">=5.3"
@ -4050,7 +4051,7 @@ url = "../core"
[[package]]
name = "langchain-text-splitters"
version = "0.2.0"
version = "0.2.1"
description = "LangChain text splitting utilities"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -6123,8 +6124,6 @@ files = [
{file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
{file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
{file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
{file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
{file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
{file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
{file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
{file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
@ -6167,7 +6166,6 @@ files = [
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
@ -6176,8 +6174,6 @@ files = [
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
@ -7175,7 +7171,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -10217,9 +10212,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[extras]
cli = ["typer"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "azure-identity", "azure-search-documents", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpathlib", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "oracledb", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "simsimd", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"]
extended-testing = ["aiosqlite", "aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "azure-identity", "azure-search-documents", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpathlib", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "oracledb", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "simsimd", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "77ccc0105fabe1735497289125bb276822101a6a9b1c2b596bf49b8f30b8068d"
content-hash = "22bdadbd8a34235ba0cd923d9b380d362caa64f000053a5f91f9d163e8b41aad"

@ -291,6 +291,7 @@ extended_testing = [
"pyjwt",
"oracledb",
"simsimd",
"aiosqlite"
]
[tool.ruff]

@ -1,8 +1,8 @@
from pathlib import Path
from typing import Any, Generator, Tuple
from typing import Any, AsyncGenerator, Generator, List, Tuple
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from sqlalchemy import Column, Integer, Text
from sqlalchemy.orm import DeclarativeBase
@ -17,16 +17,23 @@ def con_str(tmp_path: Path) -> str:
return con_str
@pytest.fixture()
def acon_str(tmp_path: Path) -> str:
file_path = tmp_path / "adb.sqlite3"
con_str = f"sqlite+aiosqlite:///{file_path}"
return con_str
@pytest.fixture()
def sql_histories(
con_str: str,
) -> Generator[Tuple[SQLChatMessageHistory, SQLChatMessageHistory], None, None]:
message_history = SQLChatMessageHistory(
session_id="123", connection_string=con_str, table_name="test_table"
session_id="123", connection=con_str, table_name="test_table"
)
# Create history for other session
other_history = SQLChatMessageHistory(
session_id="456", connection_string=con_str, table_name="test_table"
session_id="456", connection=con_str, table_name="test_table"
)
yield message_history, other_history
@ -34,12 +41,38 @@ def sql_histories(
other_history.clear()
@pytest.fixture()
async def asql_histories(
acon_str: str,
) -> AsyncGenerator[Tuple[SQLChatMessageHistory, SQLChatMessageHistory], None]:
message_history = SQLChatMessageHistory(
session_id="123",
connection=acon_str,
table_name="test_table",
async_mode=True,
engine_args={"echo": False},
)
# Create history for other session
other_history = SQLChatMessageHistory(
session_id="456",
connection=acon_str,
table_name="test_table",
async_mode=True,
engine_args={"echo": False},
)
yield message_history, other_history
await message_history.aclear()
await other_history.aclear()
def test_add_messages(
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
) -> None:
sql_history, other_history = sql_histories
sql_history.add_user_message("Hello!")
sql_history.add_ai_message("Hi there!")
sql_history.add_messages(
[HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
)
messages = sql_history.messages
assert len(messages) == 2
@ -49,39 +82,94 @@ def test_add_messages(
assert messages[1].content == "Hi there!"
@pytest.mark.requires("aiosqlite")
async def test_async_add_messages(
asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
) -> None:
sql_history, other_history = asql_histories
await sql_history.aadd_messages(
[HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
)
messages = await sql_history.aget_messages()
assert len(messages) == 2
assert isinstance(messages[0], HumanMessage)
assert isinstance(messages[1], AIMessage)
assert messages[0].content == "Hello!"
assert messages[1].content == "Hi there!"
def test_multiple_sessions(
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
) -> None:
sql_history, other_history = sql_histories
sql_history.add_user_message("Hello!")
sql_history.add_ai_message("Hi there!")
sql_history.add_user_message("Whats cracking?")
sql_history.add_messages(
[
HumanMessage(content="Hello!"),
AIMessage(content="Hi there!"),
HumanMessage(content="Whats cracking?"),
]
)
# Ensure the messages are added correctly in the first session
assert len(sql_history.messages) == 3, "waat"
assert sql_history.messages[0].content == "Hello!"
assert sql_history.messages[1].content == "Hi there!"
assert sql_history.messages[2].content == "Whats cracking?"
messages = sql_history.messages
assert len(messages) == 3, "waat"
assert messages[0].content == "Hello!"
assert messages[1].content == "Hi there!"
assert messages[2].content == "Whats cracking?"
# second session
other_history.add_user_message("Hellox")
other_history.add_messages([HumanMessage(content="Hellox")])
assert len(other_history.messages) == 1
assert len(sql_history.messages) == 3
messages = sql_history.messages
assert len(messages) == 3
assert other_history.messages[0].content == "Hellox"
assert sql_history.messages[0].content == "Hello!"
assert sql_history.messages[1].content == "Hi there!"
assert sql_history.messages[2].content == "Whats cracking?"
assert messages[0].content == "Hello!"
assert messages[1].content == "Hi there!"
assert messages[2].content == "Whats cracking?"
@pytest.mark.requires("aiosqlite")
async def test_async_multiple_sessions(
asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
) -> None:
sql_history, other_history = asql_histories
await sql_history.aadd_messages(
[
HumanMessage(content="Hello!"),
AIMessage(content="Hi there!"),
HumanMessage(content="Whats cracking?"),
]
)
# Ensure the messages are added correctly in the first session
messages: List[BaseMessage] = await sql_history.aget_messages()
assert len(messages) == 3, "waat"
assert messages[0].content == "Hello!"
assert messages[1].content == "Hi there!"
assert messages[2].content == "Whats cracking?"
# second session
await other_history.aadd_messages([HumanMessage(content="Hellox")])
messages = await sql_history.aget_messages()
assert len(await other_history.aget_messages()) == 1
assert len(messages) == 3
assert (await other_history.aget_messages())[0].content == "Hellox"
assert messages[0].content == "Hello!"
assert messages[1].content == "Hi there!"
assert messages[2].content == "Whats cracking?"
def test_clear_messages(
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
) -> None:
sql_history, other_history = sql_histories
sql_history.add_user_message("Hello!")
sql_history.add_ai_message("Hi there!")
sql_history.add_messages(
[HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
)
assert len(sql_history.messages) == 2
# Now create another history with different session id
other_history.add_user_message("Hellox")
other_history.add_messages([HumanMessage(content="Hellox")])
assert len(other_history.messages) == 1
assert len(sql_history.messages) == 2
# Now clear the first history
@ -90,6 +178,25 @@ def test_clear_messages(
assert len(other_history.messages) == 1
@pytest.mark.requires("aiosqlite")
async def test_async_clear_messages(
asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
) -> None:
sql_history, other_history = asql_histories
await sql_history.aadd_messages(
[HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
)
assert len(await sql_history.aget_messages()) == 2
# Now create another history with different session id
await other_history.aadd_messages([HumanMessage(content="Hellox")])
assert len(await other_history.aget_messages()) == 1
assert len(await sql_history.aget_messages()) == 2
# Now clear the first history
await sql_history.aclear()
assert len(await sql_history.aget_messages()) == 0
assert len(await other_history.aget_messages()) == 1
def test_model_no_session_id_field_error(con_str: str) -> None:
class Base(DeclarativeBase):
pass

Loading…
Cancel
Save