mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
fix for cosmos not loading old messages (#4094)
I noticed cosmos was not loading old messages properly, fixed now.
This commit is contained in:
parent
d84df25466
commit
f4c8502e61
@ -17,7 +17,7 @@ from langchain.schema import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from azure.cosmos import ContainerProxy, CosmosClient
|
||||
from azure.cosmos import ContainerProxy
|
||||
|
||||
|
||||
class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
||||
@ -60,19 +60,10 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
||||
self.user_id = user_id
|
||||
self.ttl = ttl
|
||||
|
||||
self._client: Optional[CosmosClient] = None
|
||||
self._container: Optional[ContainerProxy] = None
|
||||
self.messages: List[BaseMessage] = []
|
||||
|
||||
def prepare_cosmos(self) -> None:
|
||||
"""Prepare the CosmosDB client.
|
||||
|
||||
Use this function or the context manager to make sure your database is ready.
|
||||
"""
|
||||
try:
|
||||
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
|
||||
CosmosClient,
|
||||
PartitionKey,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
@ -88,6 +79,21 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either a connection string or a credential must be set.")
|
||||
self._container: Optional[ContainerProxy] = None
|
||||
|
||||
def prepare_cosmos(self) -> None:
|
||||
"""Prepare the CosmosDB client.
|
||||
|
||||
Use this function or the context manager to make sure your database is ready.
|
||||
"""
|
||||
try:
|
||||
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
|
||||
PartitionKey,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
|
||||
) from exc
|
||||
database = self._client.create_database_if_not_exists(self.cosmos_database)
|
||||
self._container = database.create_container_if_not_exists(
|
||||
self.cosmos_container,
|
||||
@ -98,11 +104,9 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
||||
|
||||
def __enter__(self) -> "CosmosDBChatMessageHistory":
|
||||
"""Context manager entry point."""
|
||||
if self._client:
|
||||
self._client.__enter__()
|
||||
self.prepare_cosmos()
|
||||
return self
|
||||
raise ValueError("Client not initialized")
|
||||
self._client.__enter__()
|
||||
self.prepare_cosmos()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
@ -112,8 +116,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
||||
) -> None:
|
||||
"""Context manager exit"""
|
||||
self.upsert_messages()
|
||||
if self._client:
|
||||
self._client.__exit__(exc_type, exc_val, traceback)
|
||||
self._client.__exit__(exc_type, exc_val, traceback)
|
||||
|
||||
def load_messages(self) -> None:
|
||||
"""Retrieve the messages from Cosmos"""
|
||||
@ -134,11 +137,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
||||
except CosmosHttpResponseError:
|
||||
logger.info("no session found")
|
||||
return
|
||||
if (
|
||||
"messages" in item
|
||||
and len(item["messages"]) > 0
|
||||
and isinstance(item["messages"][0], list)
|
||||
):
|
||||
if "messages" in item and len(item["messages"]) > 0:
|
||||
self.messages = messages_from_dict(item["messages"])
|
||||
|
||||
def add_user_message(self, message: str) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user