mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Extend DynamoDBChatMessageHistory to support composite keys (#9896)
- Description: Adds two optional parameters to the DynamoDBChatMessageHistory class to enable users to pass in a name for their PrimaryKey, or a Key object itself to enable the use of composite keys, a common DynamoDB paradigm. [AWS DynamoDB Key docs](https://aws.amazon.com/blogs/database/choosing-the-right-dynamodb-partition-key/) - Issue: N/A - Dependencies: N/A - Twitter handle: N/A --------- Co-authored-by: Josh White <josh@ctrlstack.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
872d829201
commit
bc8cceebf7
@ -28,7 +28,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 10,
|
||||
"id": "93ce1811",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -71,7 +71,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 11,
|
||||
"id": "d15e3302",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -87,18 +87,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 12,
|
||||
"id": "64fc465e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n",
|
||||
" AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
|
||||
]
|
||||
"text/plain": "[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n AIMessage(content='whats up?', additional_kwargs={}, example=False),\n HumanMessage(content='hi!', additional_kwargs={}, example=False),\n AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
|
||||
},
|
||||
"execution_count": 3,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -119,7 +116,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 13,
|
||||
"id": "225713c8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -133,6 +130,81 @@
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## DynamoDBChatMessageHistory With Different Keys Composite Keys\n",
|
||||
"The default key for DynamoDBChatMessageHistory is ```{\"SessionId\": self.session_id}```, but you can modify this to match your table design.\n",
|
||||
"\n",
|
||||
"### Primary Key Name\n",
|
||||
"You may modify the primary key by passing in a primary_key_name value in the constructor, resulting in the following:\n",
|
||||
"```{self.primary_key_name: self.session_id}```\n",
|
||||
"\n",
|
||||
"### Composite Keys\n",
|
||||
"When using an existing DynamoDB table, you may need to modify the key structure from the default of to something including a Sort Key. To do this you may use the ```key``` parameter.\n",
|
||||
"\n",
|
||||
"Passing a value for key will override the primary_key parameter, and the resulting key structure will be the passed value.\n"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "[HumanMessage(content='hello, composite dynamodb table!', additional_kwargs={}, example=False)]"
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.memory.chat_message_histories import DynamoDBChatMessageHistory\n",
|
||||
"\n",
|
||||
"composite_table = dynamodb.create_table(\n",
|
||||
" TableName=\"CompositeTable\",\n",
|
||||
" KeySchema=[{\"AttributeName\": \"PK\", \"KeyType\": \"HASH\"}, {\"AttributeName\": \"SK\", \"KeyType\": \"RANGE\"}],\n",
|
||||
" AttributeDefinitions=[{\"AttributeName\": \"PK\", \"AttributeType\": \"S\"}, {\"AttributeName\": \"SK\", \"AttributeType\": \"S\"}],\n",
|
||||
" BillingMode=\"PAY_PER_REQUEST\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Wait until the table exists.\n",
|
||||
"composite_table.meta.client.get_waiter(\"table_exists\").wait(TableName=\"CompositeTable\")\n",
|
||||
"\n",
|
||||
"# Print out some data about the table.\n",
|
||||
"print(composite_table.item_count)\n",
|
||||
"\n",
|
||||
"my_key = {\n",
|
||||
" \"PK\": \"session_id::0\",\n",
|
||||
" \"SK\": \"langchain_history\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"composite_key_history = DynamoDBChatMessageHistory(\n",
|
||||
" table_name=\"CompositeTable\",\n",
|
||||
" session_id=\"0\",\n",
|
||||
" endpoint_url=\"http://localhost.localstack.cloud:4566\",\n",
|
||||
" key=my_key,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"composite_key_history.add_user_message(\"hello, composite dynamodb table!\")\n",
|
||||
"\n",
|
||||
"composite_key_history.messages"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
@ -144,7 +216,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 15,
|
||||
"id": "f92d9499",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -165,7 +237,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 16,
|
||||
"id": "1167eeba",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -184,10 +256,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 17,
|
||||
"id": "fce085c5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "ValidationError",
|
||||
"evalue": "1 validation error for ChatOpenAI\n__root__\n Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a named parameter. (type=value_error)",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
|
||||
"\u001B[0;31mValidationError\u001B[0m Traceback (most recent call last)",
|
||||
"Cell \u001B[0;32mIn[17], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m llm \u001B[38;5;241m=\u001B[39m \u001B[43mChatOpenAI\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtemperature\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m 2\u001B[0m agent_chain \u001B[38;5;241m=\u001B[39m initialize_agent(\n\u001B[1;32m 3\u001B[0m tools,\n\u001B[1;32m 4\u001B[0m llm,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 7\u001B[0m memory\u001B[38;5;241m=\u001B[39mmemory,\n\u001B[1;32m 8\u001B[0m )\n",
|
||||
"File \u001B[0;32m~/Documents/projects/langchain/libs/langchain/langchain/load/serializable.py:74\u001B[0m, in \u001B[0;36mSerializable.__init__\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 73\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m__init__\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs: Any) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m---> 74\u001B[0m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[38;5;21;43m__init__\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 75\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_lc_kwargs \u001B[38;5;241m=\u001B[39m kwargs\n",
|
||||
"File \u001B[0;32m~/Documents/projects/langchain/.venv/lib/python3.9/site-packages/pydantic/main.py:341\u001B[0m, in \u001B[0;36mpydantic.main.BaseModel.__init__\u001B[0;34m()\u001B[0m\n",
|
||||
"\u001B[0;31mValidationError\u001B[0m: 1 validation error for ChatOpenAI\n__root__\n Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a named parameter. (type=value_error)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm = ChatOpenAI(temperature=0)\n",
|
||||
"agent_chain = initialize_agent(\n",
|
||||
@ -201,152 +287,42 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"id": "952a3103",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m{\n",
|
||||
" \"action\": \"Final Answer\",\n",
|
||||
" \"action_input\": \"Hello! How can I assist you today?\"\n",
|
||||
"}\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Hello! How can I assist you today?'"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent_chain.run(input=\"Hello!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"id": "54c4aaf4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m{\n",
|
||||
" \"action\": \"python_repl\",\n",
|
||||
" \"action_input\": \"import requests\\nfrom bs4 import BeautifulSoup\\n\\nurl = 'https://en.wikipedia.org/wiki/Twitter'\\nresponse = requests.get(url)\\nsoup = BeautifulSoup(response.content, 'html.parser')\\nowner = soup.find('th', text='Owner').find_next_sibling('td').text.strip()\\nprint(owner)\"\n",
|
||||
"}\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mX Corp. (2023–present)Twitter, Inc. (2006–2023)\n",
|
||||
"\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m{\n",
|
||||
" \"action\": \"Final Answer\",\n",
|
||||
" \"action_input\": \"X Corp. (2023–present)Twitter, Inc. (2006–2023)\"\n",
|
||||
"}\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'X Corp. (2023–present)Twitter, Inc. (2006–2023)'"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent_chain.run(input=\"Who owns Twitter?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": null,
|
||||
"id": "f9013118",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m{\n",
|
||||
" \"action\": \"Final Answer\",\n",
|
||||
" \"action_input\": \"Hello Bob! How can I assist you today?\"\n",
|
||||
"}\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Hello Bob! How can I assist you today?'"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent_chain.run(input=\"My name is Bob.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": null,
|
||||
"id": "405e5315",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m{\n",
|
||||
" \"action\": \"Final Answer\",\n",
|
||||
" \"action_input\": \"Your name is Bob.\"\n",
|
||||
"}\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Your name is Bob.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent_chain.run(input=\"Who am I?\")"
|
||||
"agent_chain.run(input=\"Who am I?\")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
@ -17,8 +17,7 @@ logger = logging.getLogger(__name__)
|
||||
class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history that stores history in AWS DynamoDB.
|
||||
|
||||
This class expects that a DynamoDB table with name `table_name`
|
||||
and a partition Key of `SessionId` is present.
|
||||
This class expects that a DynamoDB table exists with name `table_name`
|
||||
|
||||
Args:
|
||||
table_name: name of the DynamoDB table
|
||||
@ -28,10 +27,21 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
is optional and useful for test purposes, like using Localstack.
|
||||
If you plan to use AWS cloud service, you normally don't have to
|
||||
worry about setting the endpoint_url.
|
||||
primary_key_name: name of the primary key of the DynamoDB table. This argument
|
||||
is optional, defaulting to "SessionId".
|
||||
key: an optional dictionary with a custom primary and secondary key.
|
||||
This argument is optional, but useful when using composite dynamodb keys, or
|
||||
isolating records based off of application details such as a user id.
|
||||
This may also contain global and local secondary index keys.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, table_name: str, session_id: str, endpoint_url: Optional[str] = None
|
||||
self,
|
||||
table_name: str,
|
||||
session_id: str,
|
||||
endpoint_url: Optional[str] = None,
|
||||
primary_key_name: str = "SessionId",
|
||||
key: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
import boto3
|
||||
|
||||
@ -41,6 +51,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
client = boto3.resource("dynamodb")
|
||||
self.table = client.Table(table_name)
|
||||
self.session_id = session_id
|
||||
self.key: Dict = key or {primary_key_name: session_id}
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
@ -49,7 +60,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
|
||||
response = None
|
||||
try:
|
||||
response = self.table.get_item(Key={"SessionId": self.session_id})
|
||||
response = self.table.get_item(Key=self.key)
|
||||
except ClientError as error:
|
||||
if error.response["Error"]["Code"] == "ResourceNotFoundException":
|
||||
logger.warning("No record found with session id: %s", self.session_id)
|
||||
@ -73,9 +84,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
messages.append(_message)
|
||||
|
||||
try:
|
||||
self.table.put_item(
|
||||
Item={"SessionId": self.session_id, "History": messages}
|
||||
)
|
||||
self.table.put_item(Item={**self.key, "History": messages})
|
||||
except ClientError as err:
|
||||
logger.error(err)
|
||||
|
||||
@ -84,6 +93,6 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
try:
|
||||
self.table.delete_item(Key={"SessionId": self.session_id})
|
||||
self.table.delete_item(self.key)
|
||||
except ClientError as err:
|
||||
logger.error(err)
|
||||
|
Loading…
Reference in New Issue
Block a user