From c38cafd6c2a92a0d698982831ef9a29cbbd5cc29 Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Tue, 2 May 2023 05:21:46 +0200 Subject: [PATCH] Add connection string auth to cosmos (#3867) Adds a connection string option for the cosmos memory, in case AAD auth is not enabled on the cosmos instance. --- .../chat_message_histories/cosmos_db.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/langchain/memory/chat_message_histories/cosmos_db.py b/langchain/memory/chat_message_histories/cosmos_db.py index b4f8a264..6da81570 100644 --- a/langchain/memory/chat_message_histories/cosmos_db.py +++ b/langchain/memory/chat_message_histories/cosmos_db.py @@ -28,26 +28,34 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory): cosmos_endpoint: str, cosmos_database: str, cosmos_container: str, - credential: Any, session_id: str, user_id: str, + credential: Any = None, + connection_string: Optional[str] = None, ttl: Optional[int] = None, ): """ Initializes a new instance of the CosmosDBChatMessageHistory class. + Make sure to call prepare_cosmos or use the context manager to make + sure your database is ready. + + Either a credential or a connection string must be provided. + :param cosmos_endpoint: The connection endpoint for the Azure Cosmos DB account. :param cosmos_database: The name of the database to use. :param cosmos_container: The name of the container to use. - :param credential: The credential to use to authenticate to Azure Cosmos DB. :param session_id: The session ID to use, can be overwritten while loading. :param user_id: The user ID to use, can be overwritten while loading. + :param credential: The credential to use to authenticate to Azure Cosmos DB. + :param connection_string: The connection string to use to authenticate. :param ttl: The time to live (in seconds) to use for documents in the container. """ self.cosmos_endpoint = cosmos_endpoint self.cosmos_database = cosmos_database self.cosmos_container = cosmos_container self.credential = credential + self.conn_string = connection_string self.session_id = session_id self.user_id = user_id self.ttl = ttl @@ -70,9 +78,16 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory): raise ImportError( "You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501 ) from exc - self._client = CosmosClient( - url=self.cosmos_endpoint, credential=self.credential - ) + if self.credential: + self._client = CosmosClient( + url=self.cosmos_endpoint, credential=self.credential + ) + elif self.conn_string: + self._client = CosmosClient.from_connection_string( + conn_str=self.conn_string + ) + else: + raise ValueError("Either a connection string or a credential must be set.") database = self._client.create_database_if_not_exists(self.cosmos_database) self._container = database.create_container_if_not_exists( self.cosmos_container,