Add connection string auth to cosmos (#3867)

Adds a connection string option for the cosmos memory, in case AAD auth
is not enabled on the cosmos instance.
This commit is contained in:
Eduard van Valkenburg 2023-05-02 05:21:46 +02:00 committed by GitHub
parent bc7e4d5cd4
commit c38cafd6c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -28,26 +28,34 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
cosmos_endpoint: str, cosmos_endpoint: str,
cosmos_database: str, cosmos_database: str,
cosmos_container: str, cosmos_container: str,
credential: Any,
session_id: str, session_id: str,
user_id: str, user_id: str,
credential: Any = None,
connection_string: Optional[str] = None,
ttl: Optional[int] = None, ttl: Optional[int] = None,
): ):
""" """
Initializes a new instance of the CosmosDBChatMessageHistory class. Initializes a new instance of the CosmosDBChatMessageHistory class.
Make sure to call prepare_cosmos or use the context manager to make
sure your database is ready.
Either a credential or a connection string must be provided.
:param cosmos_endpoint: The connection endpoint for the Azure Cosmos DB account. :param cosmos_endpoint: The connection endpoint for the Azure Cosmos DB account.
:param cosmos_database: The name of the database to use. :param cosmos_database: The name of the database to use.
:param cosmos_container: The name of the container to use. :param cosmos_container: The name of the container to use.
:param credential: The credential to use to authenticate to Azure Cosmos DB.
:param session_id: The session ID to use, can be overwritten while loading. :param session_id: The session ID to use, can be overwritten while loading.
:param user_id: The user ID to use, can be overwritten while loading. :param user_id: The user ID to use, can be overwritten while loading.
:param credential: The credential to use to authenticate to Azure Cosmos DB.
:param connection_string: The connection string to use to authenticate.
:param ttl: The time to live (in seconds) to use for documents in the container. :param ttl: The time to live (in seconds) to use for documents in the container.
""" """
self.cosmos_endpoint = cosmos_endpoint self.cosmos_endpoint = cosmos_endpoint
self.cosmos_database = cosmos_database self.cosmos_database = cosmos_database
self.cosmos_container = cosmos_container self.cosmos_container = cosmos_container
self.credential = credential self.credential = credential
self.conn_string = connection_string
self.session_id = session_id self.session_id = session_id
self.user_id = user_id self.user_id = user_id
self.ttl = ttl self.ttl = ttl
@ -70,9 +78,16 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
raise ImportError( raise ImportError(
"You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501 "You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501
) from exc ) from exc
self._client = CosmosClient( if self.credential:
url=self.cosmos_endpoint, credential=self.credential self._client = CosmosClient(
) url=self.cosmos_endpoint, credential=self.credential
)
elif self.conn_string:
self._client = CosmosClient.from_connection_string(
conn_str=self.conn_string
)
else:
raise ValueError("Either a connection string or a credential must be set.")
database = self._client.create_database_if_not_exists(self.cosmos_database) database = self._client.create_database_if_not_exists(self.cosmos_database)
self._container = database.create_container_if_not_exists( self._container = database.create_container_if_not_exists(
self.cosmos_container, self.cosmos_container,