From bc8cceebf7b2d8e056b905926a6009367b6a8b14 Mon Sep 17 00:00:00 2001 From: Josh White <31215641+joshualwhite@users.noreply.github.com> Date: Sun, 3 Sep 2023 17:05:16 -0500 Subject: [PATCH] Extend DynamoDBChatMessageHistory to support composite keys (#9896) - Description: Adds two optional parameters to the DynamoDBChatMessageHistory class to enable users to pass in a name for their PrimaryKey, or a Key object itself to enable the use of composite keys, a common DynamoDB paradigm. [AWS DynamoDB Key docs](https://aws.amazon.com/blogs/database/choosing-the-right-dynamodb-partition-key/) - Issue: N/A - Dependencies: N/A - Twitter handle: N/A --------- Co-authored-by: Josh White Co-authored-by: Bagatur --- .../dynamodb_chat_message_history.ipynb | 240 ++++++++---------- .../memory/chat_message_histories/dynamodb.py | 27 +- 2 files changed, 126 insertions(+), 141 deletions(-) diff --git a/docs/extras/integrations/memory/dynamodb_chat_message_history.ipynb b/docs/extras/integrations/memory/dynamodb_chat_message_history.ipynb index a5c4dd0981..53e7230e2b 100644 --- a/docs/extras/integrations/memory/dynamodb_chat_message_history.ipynb +++ b/docs/extras/integrations/memory/dynamodb_chat_message_history.ipynb @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 10, "id": "93ce1811", "metadata": {}, "outputs": [ @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 11, "id": "d15e3302", "metadata": {}, "outputs": [], @@ -87,18 +87,15 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 12, "id": "64fc465e", "metadata": {}, "outputs": [ { "data": { - "text/plain": [ - "[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n", - " AIMessage(content='whats up?', additional_kwargs={}, example=False)]" - ] + "text/plain": "[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n AIMessage(content='whats up?', additional_kwargs={}, example=False),\n HumanMessage(content='hi!', additional_kwargs={}, example=False),\n AIMessage(content='whats up?', additional_kwargs={}, example=False)]" }, - "execution_count": 3, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -119,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "225713c8", "metadata": {}, "outputs": [], @@ -133,6 +130,81 @@ ")" ] }, + { + "cell_type": "markdown", + "source": [ + "## DynamoDBChatMessageHistory With Different Keys Composite Keys\n", + "The default key for DynamoDBChatMessageHistory is ```{\"SessionId\": self.session_id}```, but you can modify this to match your table design.\n", + "\n", + "### Primary Key Name\n", + "You may modify the primary key by passing in a primary_key_name value in the constructor, resulting in the following:\n", + "```{self.primary_key_name: self.session_id}```\n", + "\n", + "### Composite Keys\n", + "When using an existing DynamoDB table, you may need to modify the key structure from the default of to something including a Sort Key. To do this you may use the ```key``` parameter.\n", + "\n", + "Passing a value for key will override the primary_key parameter, and the resulting key structure will be the passed value.\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 14, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n" + ] + }, + { + "data": { + "text/plain": "[HumanMessage(content='hello, composite dynamodb table!', additional_kwargs={}, example=False)]" + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain.memory.chat_message_histories import DynamoDBChatMessageHistory\n", + "\n", + "composite_table = dynamodb.create_table(\n", + " TableName=\"CompositeTable\",\n", + " KeySchema=[{\"AttributeName\": \"PK\", \"KeyType\": \"HASH\"}, {\"AttributeName\": \"SK\", \"KeyType\": \"RANGE\"}],\n", + " AttributeDefinitions=[{\"AttributeName\": \"PK\", \"AttributeType\": \"S\"}, {\"AttributeName\": \"SK\", \"AttributeType\": \"S\"}],\n", + " BillingMode=\"PAY_PER_REQUEST\",\n", + ")\n", + "\n", + "# Wait until the table exists.\n", + "composite_table.meta.client.get_waiter(\"table_exists\").wait(TableName=\"CompositeTable\")\n", + "\n", + "# Print out some data about the table.\n", + "print(composite_table.item_count)\n", + "\n", + "my_key = {\n", + " \"PK\": \"session_id::0\",\n", + " \"SK\": \"langchain_history\",\n", + "}\n", + "\n", + "composite_key_history = DynamoDBChatMessageHistory(\n", + " table_name=\"CompositeTable\",\n", + " session_id=\"0\",\n", + " endpoint_url=\"http://localhost.localstack.cloud:4566\",\n", + " key=my_key,\n", + ")\n", + "\n", + "composite_key_history.add_user_message(\"hello, composite dynamodb table!\")\n", + "\n", + "composite_key_history.messages" + ], + "metadata": { + "collapsed": false + } + }, { "attachments": {}, "cell_type": "markdown", @@ -144,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 15, "id": "f92d9499", "metadata": {}, "outputs": [], @@ -165,7 +237,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 16, "id": "1167eeba", "metadata": {}, "outputs": [], @@ -184,10 +256,24 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 17, "id": "fce085c5", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "ValidationError", + "evalue": "1 validation error for ChatOpenAI\n__root__\n Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a named parameter. (type=value_error)", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mValidationError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[17], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m llm \u001B[38;5;241m=\u001B[39m \u001B[43mChatOpenAI\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtemperature\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m 2\u001B[0m agent_chain \u001B[38;5;241m=\u001B[39m initialize_agent(\n\u001B[1;32m 3\u001B[0m tools,\n\u001B[1;32m 4\u001B[0m llm,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 7\u001B[0m memory\u001B[38;5;241m=\u001B[39mmemory,\n\u001B[1;32m 8\u001B[0m )\n", + "File \u001B[0;32m~/Documents/projects/langchain/libs/langchain/langchain/load/serializable.py:74\u001B[0m, in \u001B[0;36mSerializable.__init__\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 73\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m__init__\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs: Any) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m---> 74\u001B[0m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[38;5;21;43m__init__\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 75\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_lc_kwargs \u001B[38;5;241m=\u001B[39m kwargs\n", + "File \u001B[0;32m~/Documents/projects/langchain/.venv/lib/python3.9/site-packages/pydantic/main.py:341\u001B[0m, in \u001B[0;36mpydantic.main.BaseModel.__init__\u001B[0;34m()\u001B[0m\n", + "\u001B[0;31mValidationError\u001B[0m: 1 validation error for ChatOpenAI\n__root__\n Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a named parameter. (type=value_error)" + ] + } + ], "source": [ "llm = ChatOpenAI(temperature=0)\n", "agent_chain = initialize_agent(\n", @@ -201,152 +287,42 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "952a3103", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m{\n", - " \"action\": \"Final Answer\",\n", - " \"action_input\": \"Hello! How can I assist you today?\"\n", - "}\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "'Hello! How can I assist you today?'" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "agent_chain.run(input=\"Hello!\")" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "54c4aaf4", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m{\n", - " \"action\": \"python_repl\",\n", - " \"action_input\": \"import requests\\nfrom bs4 import BeautifulSoup\\n\\nurl = 'https://en.wikipedia.org/wiki/Twitter'\\nresponse = requests.get(url)\\nsoup = BeautifulSoup(response.content, 'html.parser')\\nowner = soup.find('th', text='Owner').find_next_sibling('td').text.strip()\\nprint(owner)\"\n", - "}\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mX Corp. (2023–present)Twitter, Inc. (2006–2023)\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m{\n", - " \"action\": \"Final Answer\",\n", - " \"action_input\": \"X Corp. (2023–present)Twitter, Inc. (2006–2023)\"\n", - "}\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "'X Corp. (2023–present)Twitter, Inc. (2006–2023)'" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "agent_chain.run(input=\"Who owns Twitter?\")" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "f9013118", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m{\n", - " \"action\": \"Final Answer\",\n", - " \"action_input\": \"Hello Bob! How can I assist you today?\"\n", - "}\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "'Hello Bob! How can I assist you today?'" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "agent_chain.run(input=\"My name is Bob.\")" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "405e5315", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m{\n", - " \"action\": \"Final Answer\",\n", - " \"action_input\": \"Your name is Bob.\"\n", - "}\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "'Your name is Bob.'" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "agent_chain.run(input=\"Who am I?\")" + "agent_chain.run(input=\"Who am I?\")\n" ] } ], diff --git a/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py b/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py index 3800017516..704efa9ea7 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py +++ b/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import Dict, List, Optional from langchain.schema import ( BaseChatMessageHistory, @@ -17,8 +17,7 @@ logger = logging.getLogger(__name__) class DynamoDBChatMessageHistory(BaseChatMessageHistory): """Chat message history that stores history in AWS DynamoDB. - This class expects that a DynamoDB table with name `table_name` - and a partition Key of `SessionId` is present. + This class expects that a DynamoDB table exists with name `table_name` Args: table_name: name of the DynamoDB table @@ -28,10 +27,21 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): is optional and useful for test purposes, like using Localstack. If you plan to use AWS cloud service, you normally don't have to worry about setting the endpoint_url. + primary_key_name: name of the primary key of the DynamoDB table. This argument + is optional, defaulting to "SessionId". + key: an optional dictionary with a custom primary and secondary key. + This argument is optional, but useful when using composite dynamodb keys, or + isolating records based off of application details such as a user id. + This may also contain global and local secondary index keys. """ def __init__( - self, table_name: str, session_id: str, endpoint_url: Optional[str] = None + self, + table_name: str, + session_id: str, + endpoint_url: Optional[str] = None, + primary_key_name: str = "SessionId", + key: Optional[Dict[str, str]] = None, ): import boto3 @@ -41,6 +51,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): client = boto3.resource("dynamodb") self.table = client.Table(table_name) self.session_id = session_id + self.key: Dict = key or {primary_key_name: session_id} @property def messages(self) -> List[BaseMessage]: # type: ignore @@ -49,7 +60,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): response = None try: - response = self.table.get_item(Key={"SessionId": self.session_id}) + response = self.table.get_item(Key=self.key) except ClientError as error: if error.response["Error"]["Code"] == "ResourceNotFoundException": logger.warning("No record found with session id: %s", self.session_id) @@ -73,9 +84,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): messages.append(_message) try: - self.table.put_item( - Item={"SessionId": self.session_id, "History": messages} - ) + self.table.put_item(Item={**self.key, "History": messages}) except ClientError as err: logger.error(err) @@ -84,6 +93,6 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): from botocore.exceptions import ClientError try: - self.table.delete_item(Key={"SessionId": self.session_id}) + self.table.delete_item(self.key) except ClientError as err: logger.error(err)