mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
Harrison/structured chat mem (#4652)
Co-authored-by: d 3 n 7 <29033313+d3n7@users.noreply.github.com>
This commit is contained in:
parent
9ba3a798c4
commit
873b0c7eb6
@ -16,7 +16,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 1,
|
||||||
"id": "ccc8ff98",
|
"id": "ccc8ff98",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -98,7 +98,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 5,
|
||||||
"id": "4f4aa234-9746-47d8-bec7-d76081ac3ef6",
|
"id": "4f4aa234-9746-47d8-bec7-d76081ac3ef6",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -111,9 +111,17 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mAction:\n",
|
||||||
|
"```\n",
|
||||||
|
"{\n",
|
||||||
|
" \"action\": \"Final Answer\",\n",
|
||||||
|
" \"action_input\": \"Hello Erica, how can I assist you today?\"\n",
|
||||||
|
"}\n",
|
||||||
|
"```\n",
|
||||||
|
"\u001b[0m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||||
"Hi Erica! How can I assist you today?\n"
|
"Hello Erica, how can I assist you today?\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -274,10 +282,119 @@
|
|||||||
"print(response)"
|
"print(response)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "42473442",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Adding in memory\n",
|
||||||
|
"\n",
|
||||||
|
"Here is how you add in memory to this agent"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "b5a0dd2a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.prompts import MessagesPlaceholder\n",
|
||||||
|
"from langchain.memory import ConversationBufferMemory"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "91b9288f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chat_history = MessagesPlaceholder(variable_name=\"chat_history\")\n",
|
||||||
|
"memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "dba9e0d9",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"agent_chain = initialize_agent(\n",
|
||||||
|
" tools, \n",
|
||||||
|
" llm, \n",
|
||||||
|
" agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, \n",
|
||||||
|
" verbose=True, \n",
|
||||||
|
" memory=memory, \n",
|
||||||
|
" agent_kwargs = {\n",
|
||||||
|
" \"memory_prompts\": [chat_history],\n",
|
||||||
|
" \"input_variables\": [\"input\", \"agent_scratchpad\", \"chat_history\"]\n",
|
||||||
|
" }\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "a9509461",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mAction:\n",
|
||||||
|
"```\n",
|
||||||
|
"{\n",
|
||||||
|
" \"action\": \"Final Answer\",\n",
|
||||||
|
" \"action_input\": \"Hi Erica! How can I assist you today?\"\n",
|
||||||
|
"}\n",
|
||||||
|
"```\n",
|
||||||
|
"\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||||
|
"Hi Erica! How can I assist you today?\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"response = await agent_chain.arun(input=\"Hi I'm Erica.\")\n",
|
||||||
|
"print(response)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "412cedd2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mYour name is Erica.\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||||
|
"Your name is Erica.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"response = await agent_chain.arun(input=\"whats my name?\")\n",
|
||||||
|
"print(response)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "ebd7ae33-f67d-4378-ac79-9d91e0c8f53a",
|
"id": "9af1a713",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
@ -299,7 +416,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.2"
|
"version": "3.9.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -76,6 +76,7 @@ class StructuredChatAgent(Agent):
|
|||||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||||
input_variables: Optional[List[str]] = None,
|
input_variables: Optional[List[str]] = None,
|
||||||
|
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||||
) -> BasePromptTemplate:
|
) -> BasePromptTemplate:
|
||||||
tool_strings = []
|
tool_strings = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
@ -85,12 +86,14 @@ class StructuredChatAgent(Agent):
|
|||||||
tool_names = ", ".join([tool.name for tool in tools])
|
tool_names = ", ".join([tool.name for tool in tools])
|
||||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||||
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
|
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
|
||||||
messages = [
|
|
||||||
SystemMessagePromptTemplate.from_template(template),
|
|
||||||
HumanMessagePromptTemplate.from_template(human_message_template),
|
|
||||||
]
|
|
||||||
if input_variables is None:
|
if input_variables is None:
|
||||||
input_variables = ["input", "agent_scratchpad"]
|
input_variables = ["input", "agent_scratchpad"]
|
||||||
|
_memory_prompts = memory_prompts or []
|
||||||
|
messages = [
|
||||||
|
SystemMessagePromptTemplate.from_template(template),
|
||||||
|
*_memory_prompts,
|
||||||
|
HumanMessagePromptTemplate.from_template(human_message_template),
|
||||||
|
]
|
||||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -105,6 +108,7 @@ class StructuredChatAgent(Agent):
|
|||||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||||
input_variables: Optional[List[str]] = None,
|
input_variables: Optional[List[str]] = None,
|
||||||
|
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Agent:
|
) -> Agent:
|
||||||
"""Construct an agent from an LLM and tools."""
|
"""Construct an agent from an LLM and tools."""
|
||||||
@ -116,6 +120,7 @@ class StructuredChatAgent(Agent):
|
|||||||
human_message_template=human_message_template,
|
human_message_template=human_message_template,
|
||||||
format_instructions=format_instructions,
|
format_instructions=format_instructions,
|
||||||
input_variables=input_variables,
|
input_variables=input_variables,
|
||||||
|
memory_prompts=memory_prompts,
|
||||||
)
|
)
|
||||||
llm_chain = LLMChain(
|
llm_chain = LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
|
@ -88,7 +88,7 @@ class LangChainTracer(BaseTracer):
|
|||||||
name=serialized.get("name"),
|
name=serialized.get("name"),
|
||||||
parent_run_id=parent_run_id,
|
parent_run_id=parent_run_id,
|
||||||
serialized=serialized,
|
serialized=serialized,
|
||||||
inputs={"messages": messages_to_dict(batch) for batch in messages},
|
inputs={"messages": [messages_to_dict(batch) for batch in messages]},
|
||||||
extra=kwargs,
|
extra=kwargs,
|
||||||
start_time=datetime.utcnow(),
|
start_time=datetime.utcnow(),
|
||||||
execution_order=execution_order,
|
execution_order=execution_order,
|
||||||
|
Loading…
Reference in New Issue
Block a user