diff --git a/langchain/memory/chat_message_histories/cosmos_db.py b/langchain/memory/chat_message_histories/cosmos_db.py index 8030907c..3c021928 100644 --- a/langchain/memory/chat_message_histories/cosmos_db.py +++ b/langchain/memory/chat_message_histories/cosmos_db.py @@ -33,6 +33,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory): credential: Any = None, connection_string: Optional[str] = None, ttl: Optional[int] = None, + cosmos_client_kwargs: Optional[dict] = None, ): """ Initializes a new instance of the CosmosDBChatMessageHistory class. @@ -50,6 +51,7 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory): :param credential: The credential to use to authenticate to Azure Cosmos DB. :param connection_string: The connection string to use to authenticate. :param ttl: The time to live (in seconds) to use for documents in the container. + :param cosmos_client_kwargs: Additional kwargs to pass to the CosmosClient. """ self.cosmos_endpoint = cosmos_endpoint self.cosmos_database = cosmos_database @@ -71,11 +73,14 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory): ) from exc if self.credential: self._client = CosmosClient( - url=self.cosmos_endpoint, credential=self.credential + url=self.cosmos_endpoint, + credential=self.credential, + **cosmos_client_kwargs or {}, ) elif self.conn_string: self._client = CosmosClient.from_connection_string( - conn_str=self.conn_string + conn_str=self.conn_string, + **cosmos_client_kwargs or {}, ) else: raise ValueError("Either a connection string or a credential must be set.")