agent refactor (#2801)

fix_agent_callbacks
Harrison Chase 1 year ago committed by GitHub
parent abfca72c0b
commit b2bc5ef56a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,265 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "f1390152",
"metadata": {},
"source": [
"# MRKL Chat\n",
"\n",
"This notebook showcases using an agent to replicate the MRKL chain using an agent optimized for chat models."
]
},
{
"cell_type": "markdown",
"id": "39ea3638",
"metadata": {},
"source": [
"This uses the example Chinook database.\n",
"To set it up follow the instructions on https://database.guide/2-sample-databases-sqlite/, placing the `.db` file in a notebooks folder at the root of this repository."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "ac561cc4",
"metadata": {},
"outputs": [],
"source": [
"from langchain import OpenAI, LLMMathChain, SerpAPIWrapper, SQLDatabase, SQLDatabaseChain\n",
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.agents import AgentType\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "07e96d99",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0)\n",
"llm1 = OpenAI(temperature=0)\n",
"search = SerpAPIWrapper()\n",
"llm_math_chain = LLMMathChain(llm=llm1, verbose=True)\n",
"db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n",
"db_chain = SQLDatabaseChain(llm=llm1, database=db, verbose=True)\n",
"tools = [\n",
" Tool(\n",
" name = \"Search\",\n",
" func=search.run,\n",
" description=\"useful for when you need to answer questions about current events. You should ask targeted questions\"\n",
" ),\n",
" Tool(\n",
" name=\"Calculator\",\n",
" func=llm_math_chain.run,\n",
" description=\"useful for when you need to answer questions about math\"\n",
" ),\n",
" Tool(\n",
" name=\"FooBar DB\",\n",
" func=db_chain.run,\n",
" description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question containing full context\"\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a069c4b6",
"metadata": {},
"outputs": [],
"source": [
"mrkl = initialize_agent(tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION_V2, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e603cd7d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to use the Search tool to find out who is Leo DiCaprio's girlfriend and then use the Calculator tool to raise her current age to the 0.43 power.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"Search\",\n",
" \"action_input\": \"Who is Leo DiCaprio's girlfriend?\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation:\u001b[36;1m\u001b[1;3mLeo briefly dated Blake Lively when she was 23 and he was 36, then jumped to Erin Heatherton, who was 22, and was with Toni Garrin from 2013 to ...\u001b[0m\u001b[32;1m\u001b[1;3mI need to search again to find out who is Leo DiCaprio's current girlfriend. \n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"Search\",\n",
" \"action_input\": \"Who is Leo DiCaprio's current girlfriend?\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation:\u001b[36;1m\u001b[1;3mLeonardo DiCaprio started dating 20-year-old model Camila Morrone in late 2017. They're still together today - and she's 22 now.\u001b[0m\u001b[32;1m\u001b[1;3mNow I need to use the Calculator tool to raise Camila Morrone's current age to the 0.43 power.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"Calculator\",\n",
" \"action_input\": \"22^0.43\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"22^0.43\u001b[32;1m\u001b[1;3m\n",
"```python\n",
"import math\n",
"print(math.pow(22, 0.43))\n",
"```\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m3.777824273683966\n",
"\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"Observation:\u001b[33;1m\u001b[1;3mAnswer: 3.777824273683966\n",
"\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
"\n",
"Final Answer: 3.777824273683966\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'3.777824273683966'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mrkl.run(\"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a5c07010",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mQuestion: What is the full name of the artist who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\n",
"Thought: I should use the Search tool to find the answer to the first part of the question and then use the FooBar DB tool to find the answer to the second part.\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"Search\",\n",
" \"action_input\": \"Who recently released an album called 'The Storm Before the Calm'\"\n",
"}\n",
"```\n",
"\u001b[0m\n",
"Observation:\u001b[36;1m\u001b[1;3mAlanis Morissette\u001b[0m\u001b[32;1m\u001b[1;3mNow that I know the artist's name, I can use the FooBar DB tool to find out if they are in the database and what albums of theirs are in it.\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"FooBar DB\",\n",
" \"action_input\": \"What albums does Alanis Morissette have in the database?\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"What albums does Alanis Morissette have in the database? \n",
"SQLQuery:"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/sqlalchemy/sql/sqltypes.py:726: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n",
" util.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m SELECT Title FROM Album WHERE ArtistId IN (SELECT ArtistId FROM Artist WHERE Name = 'Alanis Morissette') LIMIT 5;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[('Jagged Little Pill',)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m Alanis Morissette has the album 'Jagged Little Pill' in the database.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"Observation:\u001b[38;5;200m\u001b[1;3m Alanis Morissette has the album 'Jagged Little Pill' in the database.\u001b[0m\u001b[32;1m\u001b[1;3mI have found the answer to both parts of the question.\n",
"Final Answer: The artist who recently released an album called 'The Storm Before the Calm' is Alanis Morissette. The album 'Jagged Little Pill' is in the FooBar database.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"\"The artist who recently released an album called 'The Storm Before the Calm' is Alanis Morissette. The album 'Jagged Little Pill' is in the FooBar database.\""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mrkl.run(\"What is the full name of the artist who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af016a70",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -99,6 +99,16 @@ class BaseSingleActionAgent(BaseModel):
f"Got unsupported early_stopping_method `{early_stopping_method}`"
)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> BaseSingleActionAgent:
raise NotImplementedError
@property
def _agent_type(self) -> str:
"""Return Identifier of agent type."""

@ -8,3 +8,4 @@ class AgentType(str, Enum):
CONVERSATIONAL_REACT_DESCRIPTION = "conversational-react-description"
CHAT_ZERO_SHOT_REACT_DESCRIPTION = "chat-zero-shot-react-description"
CHAT_CONVERSATIONAL_REACT_DESCRIPTION = "chat-conversational-react-description"
CHAT_ZERO_SHOT_REACT_DESCRIPTION_V2 = "chat-zero-shot-react-description-002"

@ -0,0 +1,53 @@
from typing import Any, List, Optional, Sequence
from langchain.agents.agent import AgentOutputParser, LLMSingleActionAgent
from langchain.agents.chat_v2.prompt import (
FORMAT_INSTRUCTIONS,
PREFIX,
SUFFIX,
ChatOutputParser,
create_prompt,
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema import BaseLanguageModel
from langchain.tools import BaseTool
class ChatAgentV2(LLMSingleActionAgent):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
output_parser: Optional[AgentOutputParser] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> LLMSingleActionAgent:
"""Construct an agent from an LLM and tools."""
_stop = stop or ["Observation:"]
_output_parser = output_parser or ChatOutputParser()
prompt = create_prompt(
tools,
prefix=prefix,
suffix=suffix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
return cls(
llm_chain=llm_chain, output_parser=_output_parser, stop=_stop, **kwargs
)
@property
def _agent_type(self) -> str:
raise ValueError

@ -0,0 +1,84 @@
# flake8: noqa
import json
from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.agents.schema import AgentScratchPadChatPromptTemplate
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import AgentAction, AgentFinish
from langchain.tools.base import BaseTool
from typing import Sequence, Optional, List, Union
from langchain.agents.agent import AgentOutputParser
PREFIX = """Answer the following questions as best you can. You have access to the following tools:"""
FORMAT_INSTRUCTIONS = """The way you use the tools is by specifying a json blob.
Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).
The only values that should be in the "action" field are: {tool_names}
The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
ALWAYS use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action:
```
$JSON_BLOB
```
Observation: the result of the action
... (this Thought/Action/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question"""
SUFFIX = """Begin! Reminder to always use the exact characters `Final Answer` when responding."""
def create_prompt(
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
) -> BasePromptTemplate:
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
messages = [
SystemMessagePromptTemplate.from_template(template),
HumanMessagePromptTemplate.from_template("{input}\n\n{agent_scratchpad}"),
]
if input_variables is None:
input_variables = ["input", "intermediate_steps"]
return AgentScratchPadChatPromptTemplate(
input_variables=input_variables, messages=messages
)
class ChatOutputParser(AgentOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
if "Final Answer:" in text:
return AgentFinish(
# Return values is generally always a dictionary with a single `output` key
# It is not recommended to try anything else at the moment :)
return_values={"output": text.split("Final Answer:")[-1].strip()},
log=text,
)
try:
_, action, _ = text.split("```")
response = json.loads(action.strip())
agent_action = AgentAction(
tool=response["action"], tool_input=response["action_input"], log=text
)
return agent_action
except Exception:
raise ValueError(f"Could not parse LLM output: {text}")

@ -5,9 +5,10 @@ from typing import Any, List, Optional, Union
import yaml
from langchain.agents.agent import Agent
from langchain.agents.agent import BaseSingleActionAgent
from langchain.agents.agent_types import AgentType
from langchain.agents.chat.base import ChatAgent
from langchain.agents.chat_v2.base import ChatAgentV2
from langchain.agents.conversational.base import ConversationalAgent
from langchain.agents.conversational_chat.base import ConversationalChatAgent
from langchain.agents.mrkl.base import ZeroShotAgent
@ -25,6 +26,7 @@ AGENT_TO_CLASS = {
AgentType.CONVERSATIONAL_REACT_DESCRIPTION: ConversationalAgent,
AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION: ChatAgent,
AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION: ConversationalChatAgent,
AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION_V2: ChatAgentV2,
}
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
@ -32,7 +34,7 @@ URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/age
def _load_agent_from_tools(
config: dict, llm: BaseLLM, tools: List[Tool], **kwargs: Any
) -> Agent:
) -> BaseSingleActionAgent:
config_type = config.pop("_type")
if config_type not in AGENT_TO_CLASS:
raise ValueError(f"Loading {config_type} agent not supported")
@ -49,7 +51,7 @@ def load_agent_from_config(
llm: Optional[BaseLLM] = None,
tools: Optional[List[Tool]] = None,
**kwargs: Any,
) -> Agent:
) -> BaseSingleActionAgent:
"""Load agent from Config Dict."""
if "_type" not in config:
raise ValueError("Must specify an agent Type in config")
@ -82,7 +84,7 @@ def load_agent_from_config(
return agent_cls(**combined_config) # type: ignore
def load_agent(path: Union[str, Path], **kwargs: Any) -> Agent:
def load_agent(path: Union[str, Path], **kwargs: Any) -> BaseSingleActionAgent:
"""Unified method for loading a agent from LangChainHub or local fs."""
if hub_result := try_load_from_hub(
path, _load_agent_from_file, "agents", {"json", "yaml"}
@ -92,7 +94,9 @@ def load_agent(path: Union[str, Path], **kwargs: Any) -> Agent:
return _load_agent_from_file(path, **kwargs)
def _load_agent_from_file(file: Union[str, Path], **kwargs: Any) -> Agent:
def _load_agent_from_file(
file: Union[str, Path], **kwargs: Any
) -> BaseSingleActionAgent:
"""Load agent from file."""
# Convert file to Path object.
if isinstance(file, str):

@ -0,0 +1,28 @@
from typing import Any, Dict, List, Tuple
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema import AgentAction
class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
def _construct_agent_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:
if len(intermediate_steps) == 0:
return ""
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\nObservation: {observation}\nThought: "
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{thoughts}"
)
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
intermediate_steps = kwargs.pop("intermediate_steps")
kwargs["agent_scratchpad"] = self._construct_agent_scratchpad(
intermediate_steps
)
return kwargs
Loading…
Cancel
Save