diff --git a/langchain/memory/chat_message_histories/cosmos_db.py b/langchain/memory/chat_message_histories/cosmos_db.py index 6da81570..8030907c 100644 --- a/langchain/memory/chat_message_histories/cosmos_db.py +++ b/langchain/memory/chat_message_histories/cosmos_db.py @@ -17,7 +17,7 @@ from langchain.schema import ( logger = logging.getLogger(__name__) if TYPE_CHECKING: - from azure.cosmos import ContainerProxy, CosmosClient + from azure.cosmos import ContainerProxy class CosmosDBChatMessageHistory(BaseChatMessageHistory): @@ -60,19 +60,10 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory): self.user_id = user_id self.ttl = ttl - self._client: Optional[CosmosClient] = None - self._container: Optional[ContainerProxy] = None self.messages: List[BaseMessage] = [] - - def prepare_cosmos(self) -> None: - """Prepare the CosmosDB client. - - Use this function or the context manager to make sure your database is ready. - """ try: from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501 CosmosClient, - PartitionKey, ) except ImportError as exc: raise ImportError( @@ -88,6 +79,21 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory): ) else: raise ValueError("Either a connection string or a credential must be set.") + self._container: Optional[ContainerProxy] = None + + def prepare_cosmos(self) -> None: + """Prepare the CosmosDB client. + + Use this function or the context manager to make sure your database is ready. + """ + try: + from azure.cosmos import ( # pylint: disable=import-outside-toplevel # noqa: E501 + PartitionKey, + ) + except ImportError as exc: + raise ImportError( + "You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501 + ) from exc database = self._client.create_database_if_not_exists(self.cosmos_database) self._container = database.create_container_if_not_exists( self.cosmos_container, @@ -98,11 +104,9 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory): def __enter__(self) -> "CosmosDBChatMessageHistory": """Context manager entry point.""" - if self._client: - self._client.__enter__() - self.prepare_cosmos() - return self - raise ValueError("Client not initialized") + self._client.__enter__() + self.prepare_cosmos() + return self def __exit__( self, @@ -112,8 +116,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory): ) -> None: """Context manager exit""" self.upsert_messages() - if self._client: - self._client.__exit__(exc_type, exc_val, traceback) + self._client.__exit__(exc_type, exc_val, traceback) def load_messages(self) -> None: """Retrieve the messages from Cosmos""" @@ -134,11 +137,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory): except CosmosHttpResponseError: logger.info("no session found") return - if ( - "messages" in item - and len(item["messages"]) > 0 - and isinstance(item["messages"][0], list) - ): + if "messages" in item and len(item["messages"]) > 0: self.messages = messages_from_dict(item["messages"]) def add_user_message(self, message: str) -> None: