|
|
|
@ -17,7 +17,7 @@ from langchain.schema import (
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from azure.cosmos import ContainerProxy, CosmosClient
|
|
|
|
|
from azure.cosmos import ContainerProxy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|
|
|
@ -60,19 +60,10 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|
|
|
|
self.user_id = user_id
|
|
|
|
|
self.ttl = ttl
|
|
|
|
|
|
|
|
|
|
self._client: Optional[CosmosClient] = None
|
|
|
|
|
self._container: Optional[ContainerProxy] = None
|
|
|
|
|
self.messages: List[BaseMessage] = []
|
|
|
|
|
|
|
|
|
|
def prepare_cosmos(self) -> None:
|
|
|
|
|
"""Prepare the CosmosDB client.
|
|
|
|
|
|
|
|
|
|
Use this function or the context manager to make sure your database is ready.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
|
|
|
|
|
CosmosClient,
|
|
|
|
|
PartitionKey,
|
|
|
|
|
)
|
|
|
|
|
except ImportError as exc:
|
|
|
|
|
raise ImportError(
|
|
|
|
@ -88,6 +79,21 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Either a connection string or a credential must be set.")
|
|
|
|
|
self._container: Optional[ContainerProxy] = None
|
|
|
|
|
|
|
|
|
|
def prepare_cosmos(self) -> None:
|
|
|
|
|
"""Prepare the CosmosDB client.
|
|
|
|
|
|
|
|
|
|
Use this function or the context manager to make sure your database is ready.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
|
|
|
|
|
PartitionKey,
|
|
|
|
|
)
|
|
|
|
|
except ImportError as exc:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
|
|
|
|
|
) from exc
|
|
|
|
|
database = self._client.create_database_if_not_exists(self.cosmos_database)
|
|
|
|
|
self._container = database.create_container_if_not_exists(
|
|
|
|
|
self.cosmos_container,
|
|
|
|
@ -98,11 +104,9 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|
|
|
|
|
|
|
|
|
def __enter__(self) -> "CosmosDBChatMessageHistory":
|
|
|
|
|
"""Context manager entry point."""
|
|
|
|
|
if self._client:
|
|
|
|
|
self._client.__enter__()
|
|
|
|
|
self.prepare_cosmos()
|
|
|
|
|
return self
|
|
|
|
|
raise ValueError("Client not initialized")
|
|
|
|
|
self._client.__enter__()
|
|
|
|
|
self.prepare_cosmos()
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def __exit__(
|
|
|
|
|
self,
|
|
|
|
@ -112,8 +116,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Context manager exit"""
|
|
|
|
|
self.upsert_messages()
|
|
|
|
|
if self._client:
|
|
|
|
|
self._client.__exit__(exc_type, exc_val, traceback)
|
|
|
|
|
self._client.__exit__(exc_type, exc_val, traceback)
|
|
|
|
|
|
|
|
|
|
def load_messages(self) -> None:
|
|
|
|
|
"""Retrieve the messages from Cosmos"""
|
|
|
|
@ -134,11 +137,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|
|
|
|
except CosmosHttpResponseError:
|
|
|
|
|
logger.info("no session found")
|
|
|
|
|
return
|
|
|
|
|
if (
|
|
|
|
|
"messages" in item
|
|
|
|
|
and len(item["messages"]) > 0
|
|
|
|
|
and isinstance(item["messages"][0], list)
|
|
|
|
|
):
|
|
|
|
|
if "messages" in item and len(item["messages"]) > 0:
|
|
|
|
|
self.messages = messages_from_dict(item["messages"])
|
|
|
|
|
|
|
|
|
|
def add_user_message(self, message: str) -> None:
|
|
|
|
|