From 873b0c7eb6523663099f9315b28677979c632681 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 13 May 2023 21:45:42 -0700 Subject: [PATCH] Harrison/structured chat mem (#4652) Co-authored-by: d 3 n 7 <29033313+d3n7@users.noreply.github.com> --- .../agents/examples/structured_chat.ipynb | 127 +++++++++++++++++- langchain/agents/structured_chat/base.py | 9 +- langchain/callbacks/tracers/langchain.py | 2 +- 3 files changed, 130 insertions(+), 8 deletions(-) diff --git a/docs/modules/agents/agents/examples/structured_chat.ipynb b/docs/modules/agents/agents/examples/structured_chat.ipynb index 6a03c0d9..2d280c78 100644 --- a/docs/modules/agents/agents/examples/structured_chat.ipynb +++ b/docs/modules/agents/agents/examples/structured_chat.ipynb @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "ccc8ff98", "metadata": {}, "outputs": [], @@ -98,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "4f4aa234-9746-47d8-bec7-d76081ac3ef6", "metadata": { "tags": [] @@ -111,9 +111,17 @@ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mAction:\n", + "```\n", + "{\n", + " \"action\": \"Final Answer\",\n", + " \"action_input\": \"Hello Erica, how can I assist you today?\"\n", + "}\n", + "```\n", + "\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "Hi Erica! How can I assist you today?\n" + "Hello Erica, how can I assist you today?\n" ] } ], @@ -274,10 +282,119 @@ "print(response)" ] }, + { + "cell_type": "markdown", + "id": "42473442", + "metadata": {}, + "source": [ + "## Adding in memory\n", + "\n", + "Here is how you add in memory to this agent" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b5a0dd2a", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.prompts import MessagesPlaceholder\n", + "from langchain.memory import ConversationBufferMemory" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "91b9288f", + "metadata": {}, + "outputs": [], + "source": [ + "chat_history = MessagesPlaceholder(variable_name=\"chat_history\")\n", + "memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "dba9e0d9", + "metadata": {}, + "outputs": [], + "source": [ + "agent_chain = initialize_agent(\n", + " tools, \n", + " llm, \n", + " agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, \n", + " verbose=True, \n", + " memory=memory, \n", + " agent_kwargs = {\n", + " \"memory_prompts\": [chat_history],\n", + " \"input_variables\": [\"input\", \"agent_scratchpad\", \"chat_history\"]\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a9509461", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mAction:\n", + "```\n", + "{\n", + " \"action\": \"Final Answer\",\n", + " \"action_input\": \"Hi Erica! How can I assist you today?\"\n", + "}\n", + "```\n", + "\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "Hi Erica! How can I assist you today?\n" + ] + } + ], + "source": [ + "response = await agent_chain.arun(input=\"Hi I'm Erica.\")\n", + "print(response)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "412cedd2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mYour name is Erica.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "Your name is Erica.\n" + ] + } + ], + "source": [ + "response = await agent_chain.arun(input=\"whats my name?\")\n", + "print(response)" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "ebd7ae33-f67d-4378-ac79-9d91e0c8f53a", + "id": "9af1a713", "metadata": {}, "outputs": [], "source": [] @@ -299,7 +416,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/langchain/agents/structured_chat/base.py b/langchain/agents/structured_chat/base.py index 81bc26cc..75860183 100644 --- a/langchain/agents/structured_chat/base.py +++ b/langchain/agents/structured_chat/base.py @@ -76,6 +76,7 @@ class StructuredChatAgent(Agent): human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, input_variables: Optional[List[str]] = None, + memory_prompts: Optional[List[BasePromptTemplate]] = None, ) -> BasePromptTemplate: tool_strings = [] for tool in tools: @@ -85,12 +86,14 @@ class StructuredChatAgent(Agent): tool_names = ", ".join([tool.name for tool in tools]) format_instructions = format_instructions.format(tool_names=tool_names) template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) + if input_variables is None: + input_variables = ["input", "agent_scratchpad"] + _memory_prompts = memory_prompts or [] messages = [ SystemMessagePromptTemplate.from_template(template), + *_memory_prompts, HumanMessagePromptTemplate.from_template(human_message_template), ] - if input_variables is None: - input_variables = ["input", "agent_scratchpad"] return ChatPromptTemplate(input_variables=input_variables, messages=messages) @classmethod @@ -105,6 +108,7 @@ class StructuredChatAgent(Agent): human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, input_variables: Optional[List[str]] = None, + memory_prompts: Optional[List[BasePromptTemplate]] = None, **kwargs: Any, ) -> Agent: """Construct an agent from an LLM and tools.""" @@ -116,6 +120,7 @@ class StructuredChatAgent(Agent): human_message_template=human_message_template, format_instructions=format_instructions, input_variables=input_variables, + memory_prompts=memory_prompts, ) llm_chain = LLMChain( llm=llm, diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index e893c07b..e8603905 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -88,7 +88,7 @@ class LangChainTracer(BaseTracer): name=serialized.get("name"), parent_run_id=parent_run_id, serialized=serialized, - inputs={"messages": messages_to_dict(batch) for batch in messages}, + inputs={"messages": [messages_to_dict(batch) for batch in messages]}, extra=kwargs, start_time=datetime.utcnow(), execution_order=execution_order,