Extend DynamoDBChatMessageHistory to support composite keys (#9896)

- Description: Adds two optional parameters to the
DynamoDBChatMessageHistory class to enable users to pass in a name for
their PrimaryKey, or a Key object itself to enable the use of composite
keys, a common DynamoDB paradigm.
  
[AWS DynamoDB Key
docs](https://aws.amazon.com/blogs/database/choosing-the-right-dynamodb-partition-key/)
  
  - Issue: N/A
  - Dependencies: N/A
  - Twitter handle: N/A

---------

Co-authored-by: Josh White <josh@ctrlstack.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/10155/head
Josh White 1 year ago committed by GitHub
parent 872d829201
commit bc8cceebf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -28,7 +28,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 10,
"id": "93ce1811",
"metadata": {},
"outputs": [
@ -71,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 11,
"id": "d15e3302",
"metadata": {},
"outputs": [],
@ -87,18 +87,15 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 12,
"id": "64fc465e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n",
" AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
]
"text/plain": "[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n AIMessage(content='whats up?', additional_kwargs={}, example=False),\n HumanMessage(content='hi!', additional_kwargs={}, example=False),\n AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
},
"execution_count": 3,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@ -119,7 +116,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"id": "225713c8",
"metadata": {},
"outputs": [],
@ -133,6 +130,81 @@
")"
]
},
{
"cell_type": "markdown",
"source": [
"## DynamoDBChatMessageHistory With Different Keys Composite Keys\n",
"The default key for DynamoDBChatMessageHistory is ```{\"SessionId\": self.session_id}```, but you can modify this to match your table design.\n",
"\n",
"### Primary Key Name\n",
"You may modify the primary key by passing in a primary_key_name value in the constructor, resulting in the following:\n",
"```{self.primary_key_name: self.session_id}```\n",
"\n",
"### Composite Keys\n",
"When using an existing DynamoDB table, you may need to modify the key structure from the default of to something including a Sort Key. To do this you may use the ```key``` parameter.\n",
"\n",
"Passing a value for key will override the primary_key parameter, and the resulting key structure will be the passed value.\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n"
]
},
{
"data": {
"text/plain": "[HumanMessage(content='hello, composite dynamodb table!', additional_kwargs={}, example=False)]"
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.memory.chat_message_histories import DynamoDBChatMessageHistory\n",
"\n",
"composite_table = dynamodb.create_table(\n",
" TableName=\"CompositeTable\",\n",
" KeySchema=[{\"AttributeName\": \"PK\", \"KeyType\": \"HASH\"}, {\"AttributeName\": \"SK\", \"KeyType\": \"RANGE\"}],\n",
" AttributeDefinitions=[{\"AttributeName\": \"PK\", \"AttributeType\": \"S\"}, {\"AttributeName\": \"SK\", \"AttributeType\": \"S\"}],\n",
" BillingMode=\"PAY_PER_REQUEST\",\n",
")\n",
"\n",
"# Wait until the table exists.\n",
"composite_table.meta.client.get_waiter(\"table_exists\").wait(TableName=\"CompositeTable\")\n",
"\n",
"# Print out some data about the table.\n",
"print(composite_table.item_count)\n",
"\n",
"my_key = {\n",
" \"PK\": \"session_id::0\",\n",
" \"SK\": \"langchain_history\",\n",
"}\n",
"\n",
"composite_key_history = DynamoDBChatMessageHistory(\n",
" table_name=\"CompositeTable\",\n",
" session_id=\"0\",\n",
" endpoint_url=\"http://localhost.localstack.cloud:4566\",\n",
" key=my_key,\n",
")\n",
"\n",
"composite_key_history.add_user_message(\"hello, composite dynamodb table!\")\n",
"\n",
"composite_key_history.messages"
],
"metadata": {
"collapsed": false
}
},
{
"attachments": {},
"cell_type": "markdown",
@ -144,7 +216,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 15,
"id": "f92d9499",
"metadata": {},
"outputs": [],
@ -165,7 +237,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 16,
"id": "1167eeba",
"metadata": {},
"outputs": [],
@ -184,10 +256,24 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 17,
"id": "fce085c5",
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "ValidationError",
"evalue": "1 validation error for ChatOpenAI\n__root__\n Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a named parameter. (type=value_error)",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mValidationError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[0;32mIn[17], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m llm \u001B[38;5;241m=\u001B[39m \u001B[43mChatOpenAI\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtemperature\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m 2\u001B[0m agent_chain \u001B[38;5;241m=\u001B[39m initialize_agent(\n\u001B[1;32m 3\u001B[0m tools,\n\u001B[1;32m 4\u001B[0m llm,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 7\u001B[0m memory\u001B[38;5;241m=\u001B[39mmemory,\n\u001B[1;32m 8\u001B[0m )\n",
"File \u001B[0;32m~/Documents/projects/langchain/libs/langchain/langchain/load/serializable.py:74\u001B[0m, in \u001B[0;36mSerializable.__init__\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 73\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m__init__\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs: Any) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m---> 74\u001B[0m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[38;5;21;43m__init__\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 75\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_lc_kwargs \u001B[38;5;241m=\u001B[39m kwargs\n",
"File \u001B[0;32m~/Documents/projects/langchain/.venv/lib/python3.9/site-packages/pydantic/main.py:341\u001B[0m, in \u001B[0;36mpydantic.main.BaseModel.__init__\u001B[0;34m()\u001B[0m\n",
"\u001B[0;31mValidationError\u001B[0m: 1 validation error for ChatOpenAI\n__root__\n Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a named parameter. (type=value_error)"
]
}
],
"source": [
"llm = ChatOpenAI(temperature=0)\n",
"agent_chain = initialize_agent(\n",
@ -201,152 +287,42 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "952a3103",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"Hello! How can I assist you today?\"\n",
"}\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Hello! How can I assist you today?'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"agent_chain.run(input=\"Hello!\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "54c4aaf4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m{\n",
" \"action\": \"python_repl\",\n",
" \"action_input\": \"import requests\\nfrom bs4 import BeautifulSoup\\n\\nurl = 'https://en.wikipedia.org/wiki/Twitter'\\nresponse = requests.get(url)\\nsoup = BeautifulSoup(response.content, 'html.parser')\\nowner = soup.find('th', text='Owner').find_next_sibling('td').text.strip()\\nprint(owner)\"\n",
"}\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mX Corp. (2023present)Twitter, Inc. (20062023)\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"X Corp. (2023present)Twitter, Inc. (20062023)\"\n",
"}\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'X Corp. (2023present)Twitter, Inc. (20062023)'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"agent_chain.run(input=\"Who owns Twitter?\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "f9013118",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"Hello Bob! How can I assist you today?\"\n",
"}\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Hello Bob! How can I assist you today?'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"agent_chain.run(input=\"My name is Bob.\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "405e5315",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"Your name is Bob.\"\n",
"}\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Your name is Bob.'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"agent_chain.run(input=\"Who am I?\")"
"agent_chain.run(input=\"Who am I?\")\n"
]
}
],

@ -1,5 +1,5 @@
import logging
from typing import List, Optional
from typing import Dict, List, Optional
from langchain.schema import (
BaseChatMessageHistory,
@ -17,8 +17,7 @@ logger = logging.getLogger(__name__)
class DynamoDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in AWS DynamoDB.
This class expects that a DynamoDB table with name `table_name`
and a partition Key of `SessionId` is present.
This class expects that a DynamoDB table exists with name `table_name`
Args:
table_name: name of the DynamoDB table
@ -28,10 +27,21 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
is optional and useful for test purposes, like using Localstack.
If you plan to use AWS cloud service, you normally don't have to
worry about setting the endpoint_url.
primary_key_name: name of the primary key of the DynamoDB table. This argument
is optional, defaulting to "SessionId".
key: an optional dictionary with a custom primary and secondary key.
This argument is optional, but useful when using composite dynamodb keys, or
isolating records based off of application details such as a user id.
This may also contain global and local secondary index keys.
"""
def __init__(
self, table_name: str, session_id: str, endpoint_url: Optional[str] = None
self,
table_name: str,
session_id: str,
endpoint_url: Optional[str] = None,
primary_key_name: str = "SessionId",
key: Optional[Dict[str, str]] = None,
):
import boto3
@ -41,6 +51,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
client = boto3.resource("dynamodb")
self.table = client.Table(table_name)
self.session_id = session_id
self.key: Dict = key or {primary_key_name: session_id}
@property
def messages(self) -> List[BaseMessage]: # type: ignore
@ -49,7 +60,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
response = None
try:
response = self.table.get_item(Key={"SessionId": self.session_id})
response = self.table.get_item(Key=self.key)
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning("No record found with session id: %s", self.session_id)
@ -73,9 +84,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
messages.append(_message)
try:
self.table.put_item(
Item={"SessionId": self.session_id, "History": messages}
)
self.table.put_item(Item={**self.key, "History": messages})
except ClientError as err:
logger.error(err)
@ -84,6 +93,6 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
from botocore.exceptions import ClientError
try:
self.table.delete_item(Key={"SessionId": self.session_id})
self.table.delete_item(self.key)
except ClientError as err:
logger.error(err)

Loading…
Cancel
Save