Pass kwargs from initialize_agent into agent classmethod (#799)

# Problem
I noticed that in order to change the prefix of the prompt in the
`zero-shot-react-description` agent
we had to dig around to subset strings deep into the agent's attributes.
It requires the user to inspect a long chain of attributes and classes.

`initialize_agent -> AgentExecutor -> Agent -> LLMChain -> Prompt from
Agent.create_prompt`

``` python
agent = initialize_agent(
    tools=tools,
    llm=fake_llm,
    agent="zero-shot-react-description"
)
prompt_str = agent.agent.llm_chain.prompt.template
new_prompt_str = change_prefix(prompt_str)
agent.agent.llm_chain.prompt.template = new_prompt_str
```

# Implemented Solution

`initialize_agent` accepts `**kwargs` but passes it to `AgentExecutor`
but not `ZeroShotAgent`, by simply giving the kwargs to the agent class
methods we can support changing the prefix and suffix for one agent
while allowing future agents to take advantage of `initialize_agent`.


```
agent = initialize_agent(
    tools=tools,
    llm=fake_llm,
    agent="zero-shot-react-description",
    agent_kwargs={"prefix": prefix, "suffix": suffix}
)
```

To be fair, this was before finding docs around custom agents here:
https://langchain.readthedocs.io/en/latest/modules/agents/examples/custom_agent.html?highlight=custom%20#custom-llmchain
but i find that my use case just needed to change the prefix a little.


# Changes

* Pass kwargs to Agent class method
* Added a test to check suffix and prefix

---------

Co-authored-by: Jason Liu <jason@jxnl.coA>
ankush/async-llmchain
Jason Liu 1 year ago committed by GitHub
parent c331009440
commit 54f9e4287f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,6 +14,7 @@ def initialize_agent(
agent: Optional[str] = None,
callback_manager: Optional[BaseCallbackManager] = None,
agent_path: Optional[str] = None,
agent_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Load agent given tools and LLM.
@ -50,8 +51,9 @@ def initialize_agent(
f"Valid types are: {AGENT_TO_CLASS.keys()}."
)
agent_cls = AGENT_TO_CLASS[agent]
agent_kwargs = agent_kwargs or {}
agent_obj = agent_cls.from_llm_and_tools(
llm, tools, callback_manager=callback_manager
llm, tools, callback_manager=callback_manager, **agent_kwargs
)
elif agent_path is not None:
agent_obj = load_agent(

@ -196,3 +196,29 @@ def test_agent_tool_return_direct() -> None:
output = agent.run("when was langchain made")
assert output == "misalignment"
def test_agent_with_new_prefix_suffix() -> None:
"""Test agent initilization kwargs with new prefix and suffix."""
fake_llm = FakeListLLM(
responses=["FooBarBaz\nAction: Search\nAction Input: misalignment"]
)
tools = [
Tool("Search", lambda x: x, "Useful for searching", return_direct=True),
]
prefix = "FooBarBaz"
suffix = "Begin now!\nInput: {input}\nThought: {agent_scratchpad}"
agent = initialize_agent(
tools=tools,
llm=fake_llm,
agent="zero-shot-react-description",
agent_kwargs={"prefix": prefix, "suffix": suffix},
)
# avoids "BasePromptTemplate" has no attribute "template" error
assert hasattr(agent.agent.llm_chain.prompt, "template")
prompt_str = agent.agent.llm_chain.prompt.template
assert prompt_str.startswith(prefix), "Prompt does not start with prefix"
assert prompt_str.endswith(suffix), "Prompt does not end with suffix"

Loading…
Cancel
Save