Extend DynamoDBChatMessageHistory to support composite keys (#9896)

- Description: Adds two optional parameters to the
DynamoDBChatMessageHistory class to enable users to pass in a name for
their PrimaryKey, or a Key object itself to enable the use of composite
keys, a common DynamoDB paradigm.
  
[AWS DynamoDB Key
docs](https://aws.amazon.com/blogs/database/choosing-the-right-dynamodb-partition-key/)
  
  - Issue: N/A
  - Dependencies: N/A
  - Twitter handle: N/A

---------

Co-authored-by: Josh White <josh@ctrlstack.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Josh White 2023-09-03 17:05:16 -05:00 committed by GitHub
parent 872d829201
commit bc8cceebf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 126 additions and 141 deletions

View File

@ -28,7 +28,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 10,
"id": "93ce1811", "id": "93ce1811",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -71,7 +71,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 11,
"id": "d15e3302", "id": "d15e3302",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -87,18 +87,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 12,
"id": "64fc465e", "id": "64fc465e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": "[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n AIMessage(content='whats up?', additional_kwargs={}, example=False),\n HumanMessage(content='hi!', additional_kwargs={}, example=False),\n AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
"[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n",
" AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
]
}, },
"execution_count": 3, "execution_count": 12,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -119,7 +116,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 13,
"id": "225713c8", "id": "225713c8",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -133,6 +130,81 @@
")" ")"
] ]
}, },
{
"cell_type": "markdown",
"source": [
"## DynamoDBChatMessageHistory With Different Keys Composite Keys\n",
"The default key for DynamoDBChatMessageHistory is ```{\"SessionId\": self.session_id}```, but you can modify this to match your table design.\n",
"\n",
"### Primary Key Name\n",
"You may modify the primary key by passing in a primary_key_name value in the constructor, resulting in the following:\n",
"```{self.primary_key_name: self.session_id}```\n",
"\n",
"### Composite Keys\n",
"When using an existing DynamoDB table, you may need to modify the key structure from the default of to something including a Sort Key. To do this you may use the ```key``` parameter.\n",
"\n",
"Passing a value for key will override the primary_key parameter, and the resulting key structure will be the passed value.\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n"
]
},
{
"data": {
"text/plain": "[HumanMessage(content='hello, composite dynamodb table!', additional_kwargs={}, example=False)]"
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.memory.chat_message_histories import DynamoDBChatMessageHistory\n",
"\n",
"composite_table = dynamodb.create_table(\n",
" TableName=\"CompositeTable\",\n",
" KeySchema=[{\"AttributeName\": \"PK\", \"KeyType\": \"HASH\"}, {\"AttributeName\": \"SK\", \"KeyType\": \"RANGE\"}],\n",
" AttributeDefinitions=[{\"AttributeName\": \"PK\", \"AttributeType\": \"S\"}, {\"AttributeName\": \"SK\", \"AttributeType\": \"S\"}],\n",
" BillingMode=\"PAY_PER_REQUEST\",\n",
")\n",
"\n",
"# Wait until the table exists.\n",
"composite_table.meta.client.get_waiter(\"table_exists\").wait(TableName=\"CompositeTable\")\n",
"\n",
"# Print out some data about the table.\n",
"print(composite_table.item_count)\n",
"\n",
"my_key = {\n",
" \"PK\": \"session_id::0\",\n",
" \"SK\": \"langchain_history\",\n",
"}\n",
"\n",
"composite_key_history = DynamoDBChatMessageHistory(\n",
" table_name=\"CompositeTable\",\n",
" session_id=\"0\",\n",
" endpoint_url=\"http://localhost.localstack.cloud:4566\",\n",
" key=my_key,\n",
")\n",
"\n",
"composite_key_history.add_user_message(\"hello, composite dynamodb table!\")\n",
"\n",
"composite_key_history.messages"
],
"metadata": {
"collapsed": false
}
},
{ {
"attachments": {}, "attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
@ -144,7 +216,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 15,
"id": "f92d9499", "id": "f92d9499",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -165,7 +237,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 16,
"id": "1167eeba", "id": "1167eeba",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -184,10 +256,24 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 17,
"id": "fce085c5", "id": "fce085c5",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"ename": "ValidationError",
"evalue": "1 validation error for ChatOpenAI\n__root__\n Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a named parameter. (type=value_error)",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mValidationError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[0;32mIn[17], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m llm \u001B[38;5;241m=\u001B[39m \u001B[43mChatOpenAI\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtemperature\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m 2\u001B[0m agent_chain \u001B[38;5;241m=\u001B[39m initialize_agent(\n\u001B[1;32m 3\u001B[0m tools,\n\u001B[1;32m 4\u001B[0m llm,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 7\u001B[0m memory\u001B[38;5;241m=\u001B[39mmemory,\n\u001B[1;32m 8\u001B[0m )\n",
"File \u001B[0;32m~/Documents/projects/langchain/libs/langchain/langchain/load/serializable.py:74\u001B[0m, in \u001B[0;36mSerializable.__init__\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 73\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m__init__\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs: Any) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m---> 74\u001B[0m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[38;5;21;43m__init__\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 75\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_lc_kwargs \u001B[38;5;241m=\u001B[39m kwargs\n",
"File \u001B[0;32m~/Documents/projects/langchain/.venv/lib/python3.9/site-packages/pydantic/main.py:341\u001B[0m, in \u001B[0;36mpydantic.main.BaseModel.__init__\u001B[0;34m()\u001B[0m\n",
"\u001B[0;31mValidationError\u001B[0m: 1 validation error for ChatOpenAI\n__root__\n Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a named parameter. (type=value_error)"
]
}
],
"source": [ "source": [
"llm = ChatOpenAI(temperature=0)\n", "llm = ChatOpenAI(temperature=0)\n",
"agent_chain = initialize_agent(\n", "agent_chain = initialize_agent(\n",
@ -201,152 +287,42 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": null,
"id": "952a3103", "id": "952a3103",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"Hello! How can I assist you today?\"\n",
"}\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Hello! How can I assist you today?'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"agent_chain.run(input=\"Hello!\")" "agent_chain.run(input=\"Hello!\")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": null,
"id": "54c4aaf4", "id": "54c4aaf4",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m{\n",
" \"action\": \"python_repl\",\n",
" \"action_input\": \"import requests\\nfrom bs4 import BeautifulSoup\\n\\nurl = 'https://en.wikipedia.org/wiki/Twitter'\\nresponse = requests.get(url)\\nsoup = BeautifulSoup(response.content, 'html.parser')\\nowner = soup.find('th', text='Owner').find_next_sibling('td').text.strip()\\nprint(owner)\"\n",
"}\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mX Corp. (2023present)Twitter, Inc. (20062023)\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"X Corp. (2023present)Twitter, Inc. (20062023)\"\n",
"}\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'X Corp. (2023present)Twitter, Inc. (20062023)'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"agent_chain.run(input=\"Who owns Twitter?\")" "agent_chain.run(input=\"Who owns Twitter?\")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": null,
"id": "f9013118", "id": "f9013118",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"Hello Bob! How can I assist you today?\"\n",
"}\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Hello Bob! How can I assist you today?'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"agent_chain.run(input=\"My name is Bob.\")" "agent_chain.run(input=\"My name is Bob.\")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": null,
"id": "405e5315", "id": "405e5315",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"Your name is Bob.\"\n",
"}\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Your name is Bob.'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"agent_chain.run(input=\"Who am I?\")" "agent_chain.run(input=\"Who am I?\")\n"
] ]
} }
], ],

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import List, Optional from typing import Dict, List, Optional
from langchain.schema import ( from langchain.schema import (
BaseChatMessageHistory, BaseChatMessageHistory,
@ -17,8 +17,7 @@ logger = logging.getLogger(__name__)
class DynamoDBChatMessageHistory(BaseChatMessageHistory): class DynamoDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in AWS DynamoDB. """Chat message history that stores history in AWS DynamoDB.
This class expects that a DynamoDB table with name `table_name` This class expects that a DynamoDB table exists with name `table_name`
and a partition Key of `SessionId` is present.
Args: Args:
table_name: name of the DynamoDB table table_name: name of the DynamoDB table
@ -28,10 +27,21 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
is optional and useful for test purposes, like using Localstack. is optional and useful for test purposes, like using Localstack.
If you plan to use AWS cloud service, you normally don't have to If you plan to use AWS cloud service, you normally don't have to
worry about setting the endpoint_url. worry about setting the endpoint_url.
primary_key_name: name of the primary key of the DynamoDB table. This argument
is optional, defaulting to "SessionId".
key: an optional dictionary with a custom primary and secondary key.
This argument is optional, but useful when using composite dynamodb keys, or
isolating records based off of application details such as a user id.
This may also contain global and local secondary index keys.
""" """
def __init__( def __init__(
self, table_name: str, session_id: str, endpoint_url: Optional[str] = None self,
table_name: str,
session_id: str,
endpoint_url: Optional[str] = None,
primary_key_name: str = "SessionId",
key: Optional[Dict[str, str]] = None,
): ):
import boto3 import boto3
@ -41,6 +51,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
client = boto3.resource("dynamodb") client = boto3.resource("dynamodb")
self.table = client.Table(table_name) self.table = client.Table(table_name)
self.session_id = session_id self.session_id = session_id
self.key: Dict = key or {primary_key_name: session_id}
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore
@ -49,7 +60,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
response = None response = None
try: try:
response = self.table.get_item(Key={"SessionId": self.session_id}) response = self.table.get_item(Key=self.key)
except ClientError as error: except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException": if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning("No record found with session id: %s", self.session_id) logger.warning("No record found with session id: %s", self.session_id)
@ -73,9 +84,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
messages.append(_message) messages.append(_message)
try: try:
self.table.put_item( self.table.put_item(Item={**self.key, "History": messages})
Item={"SessionId": self.session_id, "History": messages}
)
except ClientError as err: except ClientError as err:
logger.error(err) logger.error(err)
@ -84,6 +93,6 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
try: try:
self.table.delete_item(Key={"SessionId": self.session_id}) self.table.delete_item(self.key)
except ClientError as err: except ClientError as err:
logger.error(err) logger.error(err)