mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
fix for cosmos not loading old messages (#4094)
I noticed cosmos was not loading old messages properly, fixed now.
This commit is contained in:
parent
d84df25466
commit
f4c8502e61
@ -17,7 +17,7 @@ from langchain.schema import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from azure.cosmos import ContainerProxy, CosmosClient
|
from azure.cosmos import ContainerProxy
|
||||||
|
|
||||||
|
|
||||||
class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
||||||
@ -60,19 +60,10 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.ttl = ttl
|
self.ttl = ttl
|
||||||
|
|
||||||
self._client: Optional[CosmosClient] = None
|
|
||||||
self._container: Optional[ContainerProxy] = None
|
|
||||||
self.messages: List[BaseMessage] = []
|
self.messages: List[BaseMessage] = []
|
||||||
|
|
||||||
def prepare_cosmos(self) -> None:
|
|
||||||
"""Prepare the CosmosDB client.
|
|
||||||
|
|
||||||
Use this function or the context manager to make sure your database is ready.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
|
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
|
||||||
CosmosClient,
|
CosmosClient,
|
||||||
PartitionKey,
|
|
||||||
)
|
)
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@ -88,6 +79,21 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Either a connection string or a credential must be set.")
|
raise ValueError("Either a connection string or a credential must be set.")
|
||||||
|
self._container: Optional[ContainerProxy] = None
|
||||||
|
|
||||||
|
def prepare_cosmos(self) -> None:
|
||||||
|
"""Prepare the CosmosDB client.
|
||||||
|
|
||||||
|
Use this function or the context manager to make sure your database is ready.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501
|
||||||
|
PartitionKey,
|
||||||
|
)
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
|
||||||
|
) from exc
|
||||||
database = self._client.create_database_if_not_exists(self.cosmos_database)
|
database = self._client.create_database_if_not_exists(self.cosmos_database)
|
||||||
self._container = database.create_container_if_not_exists(
|
self._container = database.create_container_if_not_exists(
|
||||||
self.cosmos_container,
|
self.cosmos_container,
|
||||||
@ -98,11 +104,9 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
|
|
||||||
def __enter__(self) -> "CosmosDBChatMessageHistory":
|
def __enter__(self) -> "CosmosDBChatMessageHistory":
|
||||||
"""Context manager entry point."""
|
"""Context manager entry point."""
|
||||||
if self._client:
|
self._client.__enter__()
|
||||||
self._client.__enter__()
|
self.prepare_cosmos()
|
||||||
self.prepare_cosmos()
|
return self
|
||||||
return self
|
|
||||||
raise ValueError("Client not initialized")
|
|
||||||
|
|
||||||
def __exit__(
|
def __exit__(
|
||||||
self,
|
self,
|
||||||
@ -112,8 +116,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Context manager exit"""
|
"""Context manager exit"""
|
||||||
self.upsert_messages()
|
self.upsert_messages()
|
||||||
if self._client:
|
self._client.__exit__(exc_type, exc_val, traceback)
|
||||||
self._client.__exit__(exc_type, exc_val, traceback)
|
|
||||||
|
|
||||||
def load_messages(self) -> None:
|
def load_messages(self) -> None:
|
||||||
"""Retrieve the messages from Cosmos"""
|
"""Retrieve the messages from Cosmos"""
|
||||||
@ -134,11 +137,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
except CosmosHttpResponseError:
|
except CosmosHttpResponseError:
|
||||||
logger.info("no session found")
|
logger.info("no session found")
|
||||||
return
|
return
|
||||||
if (
|
if "messages" in item and len(item["messages"]) > 0:
|
||||||
"messages" in item
|
|
||||||
and len(item["messages"]) > 0
|
|
||||||
and isinstance(item["messages"][0], list)
|
|
||||||
):
|
|
||||||
self.messages = messages_from_dict(item["messages"])
|
self.messages = messages_from_dict(item["messages"])
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_user_message(self, message: str) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user