add functions agent (#6113)

This commit is contained in:
Harrison Chase 2023-06-13 18:51:01 -07:00 committed by GitHub
parent b3b155d488
commit e41f0b341c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 426 additions and 2 deletions

View File

@ -0,0 +1,165 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9502d5b0",
"metadata": {},
"source": [
"# OpenAI Functions Agent\n",
"\n",
"This notebook showcases using an agent that uses the OpenAI functions ability"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c0a83623",
"metadata": {},
"outputs": [],
"source": [
"from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, SQLDatabase, SQLDatabaseChain\n",
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.agents import AgentType\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6fefaba2",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")\n",
"search = SerpAPIWrapper()\n",
"llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)\n",
"db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n",
"db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)\n",
"tools = [\n",
" Tool(\n",
" name = \"Search\",\n",
" func=search.run,\n",
" description=\"useful for when you need to answer questions about current events. You should ask targeted questions\"\n",
" ),\n",
" Tool(\n",
" name=\"Calculator\",\n",
" func=llm_math_chain.run,\n",
" description=\"useful for when you need to answer questions about math\"\n",
" ),\n",
" Tool(\n",
" name=\"FooBar-DB\",\n",
" func=db_chain.run,\n",
" description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question containing full context\"\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9ff6cee9",
"metadata": {},
"outputs": [],
"source": [
"mrkl = initialize_agent(tools, llm, agent=AgentType.OPENAI_FUNCTIONS, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ba8e4cbe",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Error in on_chain_start callback: 'name'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `Search` with `{'query': 'Leo DiCaprio girlfriend'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mAmidst his casual romance with Gigi, Leo allegedly entered a relationship with 19-year old model, Eden Polani, in February 2023.\u001b[0m"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Error in on_chain_start callback: 'name'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `Calculator` with `{'expression': '19^0.43'}`\n",
"\n",
"\n",
"\u001b[0m19^0.43\u001b[32;1m\u001b[1;3m```text\n",
"19**0.43\n",
"```\n",
"...numexpr.evaluate(\"19**0.43\")...\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m3.547023357958959\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[33;1m\u001b[1;3mAnswer: 3.547023357958959\u001b[0m\u001b[32;1m\u001b[1;3mLeo DiCaprio's girlfriend is reportedly Eden Polani. Her current age raised to the power of 0.43 is approximately 3.55.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"\"Leo DiCaprio's girlfriend is reportedly Eden Polani. Her current age raised to the power of 0.43 is approximately 3.55.\""
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mrkl.run(\"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f5f6743",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -11,3 +11,4 @@ class AgentType(str, Enum):
STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION = ( STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION = (
"structured-chat-zero-shot-react-description" "structured-chat-zero-shot-react-description"
) )
OPENAI_FUNCTIONS = "openai-functions"

View File

@ -0,0 +1,232 @@
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
import json
from dataclasses import dataclass
from json import JSONDecodeError
from typing import Any, List, Optional, Sequence, Tuple, Union
from langchain.agents import BaseSingleActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
Callbacks,
)
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.schema import (
AgentAction,
AgentFinish,
AIMessage,
BaseMessage,
FunctionMessage,
SystemMessage,
)
from langchain.tools import BaseTool
from langchain.tools.convert_to_openai import format_tool_to_openai_function
@dataclass
class _FunctionsAgentAction(AgentAction):
message_log: List[BaseMessage]
def _convert_agent_action_to_messages(agent_action: AgentAction) -> List[BaseMessage]:
"""Convert an agent action to a message.
This code is used to reconstruct the original AI message from the agent action.
Args:
agent_action: Agent action to convert.
Returns:
AIMessage that corresponds to the original tool invocation.
"""
if not isinstance(agent_action, _FunctionsAgentAction):
raise ValueError("This agent type only works with _FunctionsAgentAction")
return agent_action.message_log
def _create_function_message(
agent_action: AgentAction, observation: str
) -> FunctionMessage:
"""Convert agent action and observation into a function message.
Args:
agent_action: the tool invocation request from the agent
observation: the result of the tool invocation
Returns:
FunctionMessage that corresponds to the original tool invocation
"""
if not isinstance(observation, str):
content = json.dumps(observation)
else:
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)
def _format_intermediate_steps(
intermediate_steps: List[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Format intermediate steps.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
Returns:
list of messages to send to the LLM for the next prediction
"""
messages = []
for intermediate_step in intermediate_steps:
agent_action, observation = intermediate_step
messages.extend(_convert_agent_action_to_messages(agent_action))
messages.append(_create_function_message(agent_action, observation))
return messages
def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
"""Parse an AI message."""
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}")
function_call = message.additional_kwargs.get("function_call", {})
if function_call:
function_call = message.additional_kwargs["function_call"]
function_name = function_call["name"]
try:
_tool_input = json.loads(function_call["arguments"])
except JSONDecodeError:
raise ValueError(
f"Could not parse tool input: {function_call} because "
f"the `arguments` is not valid JSON."
)
# HACK HACK HACK:
# The code that encodes tool input into Open AI uses a special variable
# name called `__arg1` to handle old style tools that do not expose a
# schema and expect a single string argument as an input.
# We unpack the argument here if it exists.
# Open AI does not support passing in a JSON array as an argument.
if "__arg1" in _tool_input:
tool_input = _tool_input["__arg1"]
else:
tool_input = _tool_input
content_msg = "responded: {content}\n" if message.content else "\n"
return _FunctionsAgentAction(
tool=function_name,
tool_input=tool_input,
log=f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n",
message_log=[message],
)
return AgentFinish(return_values={"output": message.content}, log=message.content)
class OpenAIFunctionsAgent(BaseSingleActionAgent):
"""An Agent driven by OpenAIs function powered API."""
llm: BaseLanguageModel
tools: Sequence[BaseTool]
prompt: BasePromptTemplate
def get_allowed_tools(self) -> List[str]:
"""Get allowed tools."""
return list([t.name for t in self.tools])
@property
def input_keys(self) -> List[str]:
"""Get input keys. Input refers to user input here."""
return ["input"]
@property
def functions(self) -> List[dict]:
return [dict(format_tool_to_openai_function(t)) for t in self.tools]
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
user_input = kwargs["input"]
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
prompt = self.prompt.format_prompt(
input=user_input, agent_scratchpad=agent_scratchpad
)
messages = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
user_input = kwargs["input"]
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
prompt = self.prompt.format_prompt(
input=user_input, agent_scratchpad=agent_scratchpad
)
messages = prompt.to_messages()
predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
@classmethod
def create_prompt(cls) -> BasePromptTemplate:
messages = [
SystemMessage(content="You are a helpful AI assistant."),
HumanMessagePromptTemplate.from_template("{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
input_variables = ["input", "agent_scratchpad"]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> BaseSingleActionAgent:
"""Construct an agent from an LLM and tools."""
if not isinstance(llm, ChatOpenAI):
raise ValueError("Only supported with OpenAI models.")
prompt = cls.create_prompt()
return cls(
llm=llm,
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)

View File

@ -6,6 +6,7 @@ from langchain.agents.chat.base import ChatAgent
from langchain.agents.conversational.base import ConversationalAgent from langchain.agents.conversational.base import ConversationalAgent
from langchain.agents.conversational_chat.base import ConversationalChatAgent from langchain.agents.conversational_chat.base import ConversationalChatAgent
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.react.base import ReActDocstoreAgent from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.structured_chat.base import StructuredChatAgent from langchain.agents.structured_chat.base import StructuredChatAgent
@ -18,4 +19,5 @@ AGENT_TO_CLASS: Dict[AgentType, Type[BaseSingleActionAgent]] = {
AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION: ChatAgent, AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION: ChatAgent,
AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION: ConversationalChatAgent, AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION: ConversationalChatAgent,
AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION: StructuredChatAgent, AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION: StructuredChatAgent,
AgentType.OPENAI_FUNCTIONS: OpenAIFunctionsAgent,
} }

View File

@ -35,6 +35,7 @@ from langchain.schema import (
ChatGeneration, ChatGeneration,
ChatMessage, ChatMessage,
ChatResult, ChatResult,
FunctionMessage,
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
) )
@ -120,6 +121,12 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict["function_call"] = message.additional_kwargs["function_call"] message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else: else:
raise ValueError(f"Got unknown type {message}") raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs: if "name" in message.additional_kwargs:

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import ( from typing import (
Any, Any,
Dict, Dict,
@ -34,15 +35,22 @@ def get_buffer_string(
role = ai_prefix role = ai_prefix
elif isinstance(m, SystemMessage): elif isinstance(m, SystemMessage):
role = "System" role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ChatMessage): elif isinstance(m, ChatMessage):
role = m.role role = m.role
else: else:
raise ValueError(f"Got unsupported message type: {m}") raise ValueError(f"Got unsupported message type: {m}")
string_messages.append(f"{role}: {m.content}") message = f"{role}: {m.content}"
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
message += f"{m.additional_kwargs['function_call']}"
string_messages.append(message)
return "\n".join(string_messages) return "\n".join(string_messages)
class AgentAction(NamedTuple): @dataclass
class AgentAction:
"""Agent's action to take.""" """Agent's action to take."""
tool: str tool: str
@ -112,6 +120,15 @@ class SystemMessage(BaseMessage):
return "system" return "system"
class FunctionMessage(BaseMessage):
name: str
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "function"
class ChatMessage(BaseMessage): class ChatMessage(BaseMessage):
"""Type of message with arbitrary speaker.""" """Type of message with arbitrary speaker."""