community[minor]: Add async methods to AstraDBChatMessageHistory (#17572)

pull/17569/head^2
Christophe Bornet 4 months ago committed by GitHub
parent ff1f985a2a
commit 387cacb881
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -3,9 +3,12 @@ from __future__ import annotations
import json
import time
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Sequence
from langchain_community.utilities.astradb import _AstraDBEnvironment
from langchain_community.utilities.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)
if TYPE_CHECKING:
from astrapy.db import AstraDB
@ -45,24 +48,30 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory):
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
) -> None:
"""Create an Astra DB chat message history."""
astra_env = _AstraDBEnvironment(
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
self.astra_db = astra_env.astra_db
self.collection = self.astra_db.create_collection(collection_name)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
self.session_id = session_id
self.collection_name = collection_name
@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> List[BaseMessage]:
"""Retrieve all session messages from DB"""
self.astra_env.ensure_db_setup()
message_blobs = [
doc["body_blob"]
for doc in sorted(
@ -82,16 +91,63 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory):
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
@messages.setter
def messages(self, messages: List[BaseMessage]) -> None:
raise NotImplementedError("Use add_messages instead")
async def aget_messages(self) -> List[BaseMessage]:
"""Retrieve all session messages from DB"""
await self.astra_env.aensure_db_setup()
docs = self.async_collection.paginated_find(
filter={
"session_id": self.session_id,
},
projection={
"timestamp": 1,
"body_blob": 1,
},
)
sorted_docs = sorted(
[doc async for doc in docs],
key=lambda _doc: _doc["timestamp"],
)
message_blobs = [doc["body_blob"] for doc in sorted_docs]
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Write a message to the table"""
self.collection.insert_one(
self.astra_env.ensure_db_setup()
docs = [
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
)
for message in messages
]
self.collection.chunked_insert_many(docs)
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Write a message to the table"""
await self.astra_env.aensure_db_setup()
docs = [
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
for message in messages
]
await self.async_collection.chunked_insert_many(docs)
def clear(self) -> None:
"""Clear session memory from DB"""
self.astra_env.ensure_db_setup()
self.collection.delete_many(filter={"session_id": self.session_id})
async def aclear(self) -> None:
"""Clear session memory from DB"""
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many(filter={"session_id": self.session_id})

@ -1,10 +1,11 @@
import os
from typing import Iterable
from typing import AsyncIterable, Iterable
import pytest
from langchain_community.chat_message_histories.astradb import (
AstraDBChatMessageHistory,
)
from langchain_community.utilities.astradb import SetupMode
from langchain_core.messages import AIMessage, HumanMessage
from langchain.memory import ConversationBufferMemory
@ -29,7 +30,7 @@ def history1() -> Iterable[AstraDBChatMessageHistory]:
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
yield history1
history1.astra_db.delete_collection("langchain_cmh_test")
history1.collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.fixture(scope="function")
@ -42,7 +43,35 @@ def history2() -> Iterable[AstraDBChatMessageHistory]:
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
yield history2
history2.astra_db.delete_collection("langchain_cmh_test")
history2.collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.fixture
async def async_history1() -> AsyncIterable[AstraDBChatMessageHistory]:
history1 = AstraDBChatMessageHistory(
session_id="async-session-test-1",
collection_name="langchain_cmh_test",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
setup_mode=SetupMode.ASYNC,
)
yield history1
await history1.async_collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.fixture(scope="function")
async def async_history2() -> AsyncIterable[AstraDBChatMessageHistory]:
history2 = AstraDBChatMessageHistory(
session_id="async-session-test-2",
collection_name="langchain_cmh_test",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
setup_mode=SetupMode.ASYNC,
)
yield history2
await history2.async_collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.mark.requires("astrapy")
@ -58,8 +87,12 @@ def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
assert memory.chat_memory.messages == []
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
memory.chat_memory.add_messages(
[
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
)
messages = memory.chat_memory.messages
expected = [
@ -74,6 +107,41 @@ def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
assert memory.chat_memory.messages == []
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
async def test_memory_with_message_store_async(
async_history1: AstraDBChatMessageHistory,
) -> None:
"""Test the memory with a message store."""
memory = ConversationBufferMemory(
memory_key="baz",
chat_memory=async_history1,
return_messages=True,
)
assert await memory.chat_memory.aget_messages() == []
# add some messages
await memory.chat_memory.aadd_messages(
[
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
)
messages = await memory.chat_memory.aget_messages()
expected = [
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
assert messages == expected
# clear the store
await memory.chat_memory.aclear()
assert await memory.chat_memory.aget_messages() == []
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
def test_memory_separate_session_ids(
@ -91,7 +159,7 @@ def test_memory_separate_session_ids(
return_messages=True,
)
memory1.chat_memory.add_ai_message("Just saying.")
memory1.chat_memory.add_messages([AIMessage(content="Just saying.")])
assert memory2.chat_memory.messages == []
@ -102,3 +170,33 @@ def test_memory_separate_session_ids(
memory1.chat_memory.clear()
assert memory1.chat_memory.messages == []
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
async def test_memory_separate_session_ids_async(
async_history1: AstraDBChatMessageHistory, async_history2: AstraDBChatMessageHistory
) -> None:
"""Test that separate session IDs do not share entries."""
memory1 = ConversationBufferMemory(
memory_key="mk1",
chat_memory=async_history1,
return_messages=True,
)
memory2 = ConversationBufferMemory(
memory_key="mk2",
chat_memory=async_history2,
return_messages=True,
)
await memory1.chat_memory.aadd_messages([AIMessage(content="Just saying.")])
assert await memory2.chat_memory.aget_messages() == []
await memory2.chat_memory.aclear()
assert await memory1.chat_memory.aget_messages() != []
await memory1.chat_memory.aclear()
assert await memory1.chat_memory.aget_messages() == []

Loading…
Cancel
Save