forked from Archives/langchain
Harrison/structured chat mem (#4652)
Co-authored-by: d 3 n 7 <29033313+d3n7@users.noreply.github.com>
This commit is contained in:
parent
9ba3a798c4
commit
873b0c7eb6
@ -16,7 +16,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"id": "ccc8ff98",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -98,7 +98,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"id": "4f4aa234-9746-47d8-bec7-d76081ac3ef6",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@ -111,9 +111,17 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mAction:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"action\": \"Final Answer\",\n",
|
||||
" \"action_input\": \"Hello Erica, how can I assist you today?\"\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"Hi Erica! How can I assist you today?\n"
|
||||
"Hello Erica, how can I assist you today?\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -274,10 +282,119 @@
|
||||
"print(response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "42473442",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Adding in memory\n",
|
||||
"\n",
|
||||
"Here is how you add in memory to this agent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "b5a0dd2a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import MessagesPlaceholder\n",
|
||||
"from langchain.memory import ConversationBufferMemory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "91b9288f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat_history = MessagesPlaceholder(variable_name=\"chat_history\")\n",
|
||||
"memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "dba9e0d9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent_chain = initialize_agent(\n",
|
||||
" tools, \n",
|
||||
" llm, \n",
|
||||
" agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, \n",
|
||||
" verbose=True, \n",
|
||||
" memory=memory, \n",
|
||||
" agent_kwargs = {\n",
|
||||
" \"memory_prompts\": [chat_history],\n",
|
||||
" \"input_variables\": [\"input\", \"agent_scratchpad\", \"chat_history\"]\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "a9509461",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mAction:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"action\": \"Final Answer\",\n",
|
||||
" \"action_input\": \"Hi Erica! How can I assist you today?\"\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"Hi Erica! How can I assist you today?\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"response = await agent_chain.arun(input=\"Hi I'm Erica.\")\n",
|
||||
"print(response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "412cedd2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mYour name is Erica.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"Your name is Erica.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"response = await agent_chain.arun(input=\"whats my name?\")\n",
|
||||
"print(response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ebd7ae33-f67d-4378-ac79-9d91e0c8f53a",
|
||||
"id": "9af1a713",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
@ -299,7 +416,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -76,6 +76,7 @@ class StructuredChatAgent(Agent):
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
@ -85,12 +86,14 @@ class StructuredChatAgent(Agent):
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
|
||||
messages = [
|
||||
SystemMessagePromptTemplate.from_template(template),
|
||||
HumanMessagePromptTemplate.from_template(human_message_template),
|
||||
]
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
_memory_prompts = memory_prompts or []
|
||||
messages = [
|
||||
SystemMessagePromptTemplate.from_template(template),
|
||||
*_memory_prompts,
|
||||
HumanMessagePromptTemplate.from_template(human_message_template),
|
||||
]
|
||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
@classmethod
|
||||
@ -105,6 +108,7 @@ class StructuredChatAgent(Agent):
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
@ -116,6 +120,7 @@ class StructuredChatAgent(Agent):
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
|
@ -88,7 +88,7 @@ class LangChainTracer(BaseTracer):
|
||||
name=serialized.get("name"),
|
||||
parent_run_id=parent_run_id,
|
||||
serialized=serialized,
|
||||
inputs={"messages": messages_to_dict(batch) for batch in messages},
|
||||
inputs={"messages": [messages_to_dict(batch) for batch in messages]},
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=execution_order,
|
||||
|
Loading…
Reference in New Issue
Block a user