From 54f9e4287f287aa7c29d8dd39019f1350df0d905 Mon Sep 17 00:00:00 2001 From: Jason Liu Date: Mon, 30 Jan 2023 17:54:09 -0500 Subject: [PATCH] Pass kwargs from initialize_agent into agent classmethod (#799) # Problem I noticed that in order to change the prefix of the prompt in the `zero-shot-react-description` agent we had to dig around to subset strings deep into the agent's attributes. It requires the user to inspect a long chain of attributes and classes. `initialize_agent -> AgentExecutor -> Agent -> LLMChain -> Prompt from Agent.create_prompt` ``` python agent = initialize_agent( tools=tools, llm=fake_llm, agent="zero-shot-react-description" ) prompt_str = agent.agent.llm_chain.prompt.template new_prompt_str = change_prefix(prompt_str) agent.agent.llm_chain.prompt.template = new_prompt_str ``` # Implemented Solution `initialize_agent` accepts `**kwargs` but passes it to `AgentExecutor` but not `ZeroShotAgent`, by simply giving the kwargs to the agent class methods we can support changing the prefix and suffix for one agent while allowing future agents to take advantage of `initialize_agent`. ``` agent = initialize_agent( tools=tools, llm=fake_llm, agent="zero-shot-react-description", agent_kwargs={"prefix": prefix, "suffix": suffix} ) ``` To be fair, this was before finding docs around custom agents here: https://langchain.readthedocs.io/en/latest/modules/agents/examples/custom_agent.html?highlight=custom%20#custom-llmchain but i find that my use case just needed to change the prefix a little. # Changes * Pass kwargs to Agent class method * Added a test to check suffix and prefix --------- Co-authored-by: Jason Liu --- langchain/agents/initialize.py | 4 +++- tests/unit_tests/agents/test_agent.py | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/langchain/agents/initialize.py b/langchain/agents/initialize.py index b0b9a78c80..872400316b 100644 --- a/langchain/agents/initialize.py +++ b/langchain/agents/initialize.py @@ -14,6 +14,7 @@ def initialize_agent( agent: Optional[str] = None, callback_manager: Optional[BaseCallbackManager] = None, agent_path: Optional[str] = None, + agent_kwargs: Optional[dict] = None, **kwargs: Any, ) -> AgentExecutor: """Load agent given tools and LLM. @@ -50,8 +51,9 @@ def initialize_agent( f"Valid types are: {AGENT_TO_CLASS.keys()}." ) agent_cls = AGENT_TO_CLASS[agent] + agent_kwargs = agent_kwargs or {} agent_obj = agent_cls.from_llm_and_tools( - llm, tools, callback_manager=callback_manager + llm, tools, callback_manager=callback_manager, **agent_kwargs ) elif agent_path is not None: agent_obj = load_agent( diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index b1a70ffb85..68fa1505af 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -196,3 +196,29 @@ def test_agent_tool_return_direct() -> None: output = agent.run("when was langchain made") assert output == "misalignment" + + +def test_agent_with_new_prefix_suffix() -> None: + """Test agent initilization kwargs with new prefix and suffix.""" + fake_llm = FakeListLLM( + responses=["FooBarBaz\nAction: Search\nAction Input: misalignment"] + ) + tools = [ + Tool("Search", lambda x: x, "Useful for searching", return_direct=True), + ] + prefix = "FooBarBaz" + + suffix = "Begin now!\nInput: {input}\nThought: {agent_scratchpad}" + + agent = initialize_agent( + tools=tools, + llm=fake_llm, + agent="zero-shot-react-description", + agent_kwargs={"prefix": prefix, "suffix": suffix}, + ) + + # avoids "BasePromptTemplate" has no attribute "template" error + assert hasattr(agent.agent.llm_chain.prompt, "template") + prompt_str = agent.agent.llm_chain.prompt.template + assert prompt_str.startswith(prefix), "Prompt does not start with prefix" + assert prompt_str.endswith(suffix), "Prompt does not end with suffix"