fix for cosmos not loading old messages (#4094)

I noticed cosmos was not loading old messages properly, fixed now.
This commit is contained in:
Eduard van Valkenburg 2023-05-08 21:48:15 +02:00 committed by GitHub
parent d84df25466
commit f4c8502e61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,7 +17,7 @@ from langchain.schema import (
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from azure.cosmos import ContainerProxy, CosmosClient
from azure.cosmos import ContainerProxy
class CosmosDBChatMessageHistory(BaseChatMessageHistory):
@ -60,19 +60,10 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
self.user_id = user_id
self.ttl = ttl
self._client: Optional[CosmosClient] = None
self._container: Optional[ContainerProxy] = None
self.messages: List[BaseMessage] = []
def prepare_cosmos(self) -> None:
"""Prepare the CosmosDB client.
Use this function or the context manager to make sure your database is ready.
"""
try:
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
CosmosClient,
PartitionKey,
)
except ImportError as exc:
raise ImportError(
@ -88,6 +79,21 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
)
else:
raise ValueError("Either a connection string or a credential must be set.")
self._container: Optional[ContainerProxy] = None
def prepare_cosmos(self) -> None:
"""Prepare the CosmosDB client.
Use this function or the context manager to make sure your database is ready.
"""
try:
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
PartitionKey,
)
except ImportError as exc:
raise ImportError(
"You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
) from exc
database = self._client.create_database_if_not_exists(self.cosmos_database)
self._container = database.create_container_if_not_exists(
self.cosmos_container,
@ -98,11 +104,9 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
def __enter__(self) -> "CosmosDBChatMessageHistory":
"""Context manager entry point."""
if self._client:
self._client.__enter__()
self.prepare_cosmos()
return self
raise ValueError("Client not initialized")
self._client.__enter__()
self.prepare_cosmos()
return self
def __exit__(
self,
@ -112,8 +116,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
) -> None:
"""Context manager exit"""
self.upsert_messages()
if self._client:
self._client.__exit__(exc_type, exc_val, traceback)
self._client.__exit__(exc_type, exc_val, traceback)
def load_messages(self) -> None:
"""Retrieve the messages from Cosmos"""
@ -134,11 +137,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
except CosmosHttpResponseError:
logger.info("no session found")
return
if (
"messages" in item
and len(item["messages"]) > 0
and isinstance(item["messages"][0], list)
):
if "messages" in item and len(item["messages"]) > 0:
self.messages = messages_from_dict(item["messages"])
def add_user_message(self, message: str) -> None: