Harrison/structured chat mem (#4652)

Co-authored-by: d 3 n 7 <29033313+d3n7@users.noreply.github.com>
Harrison Chase authored 2023-05-13 21:45:42 -07:00, committed by GitHub
parent 9ba3a798c4
commit 873b0c7eb6
3 changed files with 132 additions and 10 deletions

View File

@@ -16,7 +16,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"id": "ccc8ff98",
"metadata": {},
"outputs": [],
@@ -98,7 +98,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"id": "4f4aa234-9746-47d8-bec7-d76081ac3ef6",
"metadata": {
"tags": []
@@ -111,9 +111,17 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
+ "\u001b[32;1m\u001b[1;3mAction:\n",
+ "```\n",
+ "{\n",
+ " \"action\": \"Final Answer\",\n",
+ " \"action_input\": \"Hello Erica, how can I assist you today?\"\n",
+ "}\n",
+ "```\n",
+ "\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
- "Hi Erica! How can I assist you today?\n"
+ "Hello Erica, how can I assist you today?\n"
]
}
],
@@ -274,10 +282,119 @@
"print(response)"
]
},
{
"cell_type": "markdown",
"id": "42473442",
"metadata": {},
"source": [
"## Adding in memory\n",
"\n",
"Here is how you add in memory to this agent"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b5a0dd2a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import MessagesPlaceholder\n",
"from langchain.memory import ConversationBufferMemory"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "91b9288f",
"metadata": {},
"outputs": [],
"source": [
"chat_history = MessagesPlaceholder(variable_name=\"chat_history\")\n",
"memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dba9e0d9",
"metadata": {},
"outputs": [],
"source": [
"agent_chain = initialize_agent(\n",
" tools, \n",
" llm, \n",
" agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, \n",
" verbose=True, \n",
" memory=memory, \n",
" agent_kwargs = {\n",
" \"memory_prompts\": [chat_history],\n",
" \"input_variables\": [\"input\", \"agent_scratchpad\", \"chat_history\"]\n",
" }\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a9509461",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mAction:\n",
"```\n",
"{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"Hi Erica! How can I assist you today?\"\n",
"}\n",
"```\n",
"\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"Hi Erica! How can I assist you today?\n"
]
}
],
"source": [
"response = await agent_chain.arun(input=\"Hi I'm Erica.\")\n",
"print(response)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "412cedd2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mYour name is Erica.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"Your name is Erica.\n"
]
}
],
"source": [
"response = await agent_chain.arun(input=\"whats my name?\")\n",
"print(response)"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "ebd7ae33-f67d-4378-ac79-9d91e0c8f53a",
+ "id": "9af1a713",
"metadata": {},
"outputs": [],
"source": []
@@ -299,7 +416,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.2"
+ "version": "3.9.1"
}
},
"nbformat": 4,

View File

@@ -76,6 +76,7 @@ class StructuredChatAgent(Agent):
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
+ memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
@@ -85,12 +86,14 @@ class StructuredChatAgent(Agent):
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
- messages = [
- SystemMessagePromptTemplate.from_template(template),
- HumanMessagePromptTemplate.from_template(human_message_template),
- ]
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
+ _memory_prompts = memory_prompts or []
+ messages = [
+ SystemMessagePromptTemplate.from_template(template),
+ *_memory_prompts,
+ HumanMessagePromptTemplate.from_template(human_message_template),
+ ]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)

@classmethod
@@ -105,6 +108,7 @@ class StructuredChatAgent(Agent):
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
+ memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
@@ -116,6 +120,7 @@ class StructuredChatAgent(Agent):
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
+ memory_prompts=memory_prompts,
)
llm_chain = LLMChain(
llm=llm,
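
For reference, the memory_prompts added above are spliced between the system and human messages when the prompt is built. A minimal sketch (not part of this commit) of inspecting that layout via StructuredChatAgent.create_prompt; the module path and the empty tool list are assumptions for illustration:

from langchain.agents.structured_chat.base import StructuredChatAgent  # assumed module path
from langchain.prompts import MessagesPlaceholder

chat_history = MessagesPlaceholder(variable_name="chat_history")

# Build the prompt the same way initialize_agent does when agent_kwargs carries
# memory_prompts and input_variables.
prompt = StructuredChatAgent.create_prompt(
    tools=[],  # empty tool list, purely to inspect the message layout
    memory_prompts=[chat_history],
    input_variables=["input", "agent_scratchpad", "chat_history"],
)

# Expected order: SystemMessagePromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
print([type(m).__name__ for m in prompt.messages])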

View File

@@ -88,7 +88,7 @@ class LangChainTracer(BaseTracer):
name=serialized.get("name"),
parent_run_id=parent_run_id,
serialized=serialized,
- inputs={"messages": messages_to_dict(batch) for batch in messages},
+ inputs={"messages": [messages_to_dict(batch) for batch in messages]},
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=execution_order,
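
The tracer hunk above fixes a subtle bug: a dict comprehension with the constant key "messages" overwrites its single entry on every iteration, so only the last batch of chat messages was recorded; wrapping the comprehension in a list keeps one entry per batch. A minimal sketch of the difference, with a hypothetical two-batch input:

from langchain.schema import HumanMessage, messages_to_dict

# Hypothetical input: two batches of chat messages, as the chat-model tracing hook receives them.
messages = [[HumanMessage(content="batch one")], [HumanMessage(content="batch two")]]

# Old: constant-key dict comprehension -- each iteration overwrites "messages".
buggy = {"messages": messages_to_dict(batch) for batch in messages}

# New: collect every converted batch in a list.
fixed = {"messages": [messages_to_dict(batch) for batch in messages]}

print(len(buggy["messages"]))  # 1 -- only "batch two" survived
print(len(fixed["messages"]))  # 2 -- one entry per batch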