From e41f0b341c735a8fa715b56996c47168e8065035 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 13 Jun 2023 18:51:01 -0700 Subject: [PATCH] add functions agent (#6113) --- .../examples/openai_functions_agent.ipynb | 165 +++++++++++++ langchain/agents/agent_types.py | 1 + .../agents/openai_functions_agent/__init__.py | 0 .../agents/openai_functions_agent/base.py | 232 ++++++++++++++++++ langchain/agents/types.py | 2 + langchain/chat_models/openai.py | 7 + langchain/schema.py | 21 +- 7 files changed, 426 insertions(+), 2 deletions(-) create mode 100644 docs/modules/agents/agents/examples/openai_functions_agent.ipynb create mode 100644 langchain/agents/openai_functions_agent/__init__.py create mode 100644 langchain/agents/openai_functions_agent/base.py diff --git a/docs/modules/agents/agents/examples/openai_functions_agent.ipynb b/docs/modules/agents/agents/examples/openai_functions_agent.ipynb new file mode 100644 index 0000000000..b3592e6d35 --- /dev/null +++ b/docs/modules/agents/agents/examples/openai_functions_agent.ipynb @@ -0,0 +1,165 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9502d5b0", + "metadata": {}, + "source": [ + "# OpenAI Functions Agent\n", + "\n", + "This notebook showcases using an agent that uses the OpenAI functions ability" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c0a83623", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, SQLDatabase, SQLDatabaseChain\n", + "from langchain.agents import initialize_agent, Tool\n", + "from langchain.agents import AgentType\n", + "from langchain.chat_models import ChatOpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "6fefaba2", + "metadata": {}, + "outputs": [], + "source": [ + "llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")\n", + "search = SerpAPIWrapper()\n", + "llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)\n", + "db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n", + "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)\n", + "tools = [\n", + " Tool(\n", + " name = \"Search\",\n", + " func=search.run,\n", + " description=\"useful for when you need to answer questions about current events. You should ask targeted questions\"\n", + " ),\n", + " Tool(\n", + " name=\"Calculator\",\n", + " func=llm_math_chain.run,\n", + " description=\"useful for when you need to answer questions about math\"\n", + " ),\n", + " Tool(\n", + " name=\"FooBar-DB\",\n", + " func=db_chain.run,\n", + " description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question containing full context\"\n", + " )\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9ff6cee9", + "metadata": {}, + "outputs": [], + "source": [ + "mrkl = initialize_agent(tools, llm, agent=AgentType.OPENAI_FUNCTIONS, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ba8e4cbe", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Error in on_chain_start callback: 'name'\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `Search` with `{'query': 'Leo DiCaprio girlfriend'}`\n", + "\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3mAmidst his casual romance with Gigi, Leo allegedly entered a relationship with 19-year old model, Eden Polani, in February 2023.\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Error in on_chain_start callback: 'name'\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `Calculator` with `{'expression': '19^0.43'}`\n", + "\n", + "\n", + "\u001b[0m19^0.43\u001b[32;1m\u001b[1;3m```text\n", + "19**0.43\n", + "```\n", + "...numexpr.evaluate(\"19**0.43\")...\n", + "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m3.547023357958959\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\u001b[33;1m\u001b[1;3mAnswer: 3.547023357958959\u001b[0m\u001b[32;1m\u001b[1;3mLeo DiCaprio's girlfriend is reportedly Eden Polani. Her current age raised to the power of 0.43 is approximately 3.55.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "\"Leo DiCaprio's girlfriend is reportedly Eden Polani. Her current age raised to the power of 0.43 is approximately 3.55.\"" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mrkl.run(\"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f5f6743", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/agents/agent_types.py b/langchain/agents/agent_types.py index c952f2a67c..6d9aff1f13 100644 --- a/langchain/agents/agent_types.py +++ b/langchain/agents/agent_types.py @@ -11,3 +11,4 @@ class AgentType(str, Enum): STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION = ( "structured-chat-zero-shot-react-description" ) + OPENAI_FUNCTIONS = "openai-functions" diff --git a/langchain/agents/openai_functions_agent/__init__.py b/langchain/agents/openai_functions_agent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/langchain/agents/openai_functions_agent/base.py b/langchain/agents/openai_functions_agent/base.py new file mode 100644 index 0000000000..2276f1a754 --- /dev/null +++ b/langchain/agents/openai_functions_agent/base.py @@ -0,0 +1,232 @@ +"""Module implements an agent that uses OpenAI's APIs function enabled API.""" +import json +from dataclasses import dataclass +from json import JSONDecodeError +from typing import Any, List, Optional, Sequence, Tuple, Union + +from langchain.agents import BaseSingleActionAgent +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import ( + Callbacks, +) +from langchain.chat_models.openai import ChatOpenAI +from langchain.prompts.base import BasePromptTemplate +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, +) +from langchain.schema import ( + AgentAction, + AgentFinish, + AIMessage, + BaseMessage, + FunctionMessage, + SystemMessage, +) +from langchain.tools import BaseTool +from langchain.tools.convert_to_openai import format_tool_to_openai_function + + +@dataclass +class _FunctionsAgentAction(AgentAction): + message_log: List[BaseMessage] + + +def _convert_agent_action_to_messages(agent_action: AgentAction) -> List[BaseMessage]: + """Convert an agent action to a message. + + This code is used to reconstruct the original AI message from the agent action. + + Args: + agent_action: Agent action to convert. + + Returns: + AIMessage that corresponds to the original tool invocation. + """ + if not isinstance(agent_action, _FunctionsAgentAction): + raise ValueError("This agent type only works with _FunctionsAgentAction") + return agent_action.message_log + + +def _create_function_message( + agent_action: AgentAction, observation: str +) -> FunctionMessage: + """Convert agent action and observation into a function message. + Args: + agent_action: the tool invocation request from the agent + observation: the result of the tool invocation + Returns: + FunctionMessage that corresponds to the original tool invocation + """ + if not isinstance(observation, str): + content = json.dumps(observation) + else: + content = observation + return FunctionMessage( + name=agent_action.tool, + content=content, + ) + + +def _format_intermediate_steps( + intermediate_steps: List[Tuple[AgentAction, str]], +) -> List[BaseMessage]: + """Format intermediate steps. + Args: + intermediate_steps: Steps the LLM has taken to date, along with observations + Returns: + list of messages to send to the LLM for the next prediction + """ + messages = [] + + for intermediate_step in intermediate_steps: + agent_action, observation = intermediate_step + messages.extend(_convert_agent_action_to_messages(agent_action)) + messages.append(_create_function_message(agent_action, observation)) + + return messages + + +def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]: + """Parse an AI message.""" + if not isinstance(message, AIMessage): + raise TypeError(f"Expected an AI message got {type(message)}") + + function_call = message.additional_kwargs.get("function_call", {}) + + if function_call: + function_call = message.additional_kwargs["function_call"] + function_name = function_call["name"] + try: + _tool_input = json.loads(function_call["arguments"]) + except JSONDecodeError: + raise ValueError( + f"Could not parse tool input: {function_call} because " + f"the `arguments` is not valid JSON." + ) + + # HACK HACK HACK: + # The code that encodes tool input into Open AI uses a special variable + # name called `__arg1` to handle old style tools that do not expose a + # schema and expect a single string argument as an input. + # We unpack the argument here if it exists. + # Open AI does not support passing in a JSON array as an argument. + if "__arg1" in _tool_input: + tool_input = _tool_input["__arg1"] + else: + tool_input = _tool_input + + content_msg = "responded: {content}\n" if message.content else "\n" + + return _FunctionsAgentAction( + tool=function_name, + tool_input=tool_input, + log=f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n", + message_log=[message], + ) + + return AgentFinish(return_values={"output": message.content}, log=message.content) + + +class OpenAIFunctionsAgent(BaseSingleActionAgent): + """An Agent driven by OpenAIs function powered API.""" + + llm: BaseLanguageModel + tools: Sequence[BaseTool] + prompt: BasePromptTemplate + + def get_allowed_tools(self) -> List[str]: + """Get allowed tools.""" + return list([t.name for t in self.tools]) + + @property + def input_keys(self) -> List[str]: + """Get input keys. Input refers to user input here.""" + return ["input"] + + @property + def functions(self) -> List[dict]: + return [dict(format_tool_to_openai_function(t)) for t in self.tools] + + def plan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + Args: + intermediate_steps: Steps the LLM has taken to date, along with observations + **kwargs: User inputs. + Returns: + Action specifying what tool to use. + """ + user_input = kwargs["input"] + agent_scratchpad = _format_intermediate_steps(intermediate_steps) + prompt = self.prompt.format_prompt( + input=user_input, agent_scratchpad=agent_scratchpad + ) + messages = prompt.to_messages() + predicted_message = self.llm.predict_messages( + messages, functions=self.functions, callbacks=callbacks + ) + agent_decision = _parse_ai_message(predicted_message) + return agent_decision + + async def aplan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + **kwargs: User inputs. + Returns: + Action specifying what tool to use. + """ + user_input = kwargs["input"] + agent_scratchpad = _format_intermediate_steps(intermediate_steps) + prompt = self.prompt.format_prompt( + input=user_input, agent_scratchpad=agent_scratchpad + ) + messages = prompt.to_messages() + predicted_message = await self.llm.apredict_messages( + messages, functions=self.functions + ) + agent_decision = _parse_ai_message(predicted_message) + return agent_decision + + @classmethod + def create_prompt(cls) -> BasePromptTemplate: + messages = [ + SystemMessage(content="You are a helpful AI assistant."), + HumanMessagePromptTemplate.from_template("{input}"), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + input_variables = ["input", "agent_scratchpad"] + return ChatPromptTemplate(input_variables=input_variables, messages=messages) + + @classmethod + def from_llm_and_tools( + cls, + llm: BaseLanguageModel, + tools: Sequence[BaseTool], + callback_manager: Optional[BaseCallbackManager] = None, + **kwargs: Any, + ) -> BaseSingleActionAgent: + """Construct an agent from an LLM and tools.""" + if not isinstance(llm, ChatOpenAI): + raise ValueError("Only supported with OpenAI models.") + prompt = cls.create_prompt() + return cls( + llm=llm, + prompt=prompt, + tools=tools, + callback_manager=callback_manager, + **kwargs, + ) diff --git a/langchain/agents/types.py b/langchain/agents/types.py index 20d630d8ab..b9e73c5fcb 100644 --- a/langchain/agents/types.py +++ b/langchain/agents/types.py @@ -6,6 +6,7 @@ from langchain.agents.chat.base import ChatAgent from langchain.agents.conversational.base import ConversationalAgent from langchain.agents.conversational_chat.base import ConversationalChatAgent from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent from langchain.agents.react.base import ReActDocstoreAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.structured_chat.base import StructuredChatAgent @@ -18,4 +19,5 @@ AGENT_TO_CLASS: Dict[AgentType, Type[BaseSingleActionAgent]] = { AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION: ChatAgent, AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION: ConversationalChatAgent, AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION: StructuredChatAgent, + AgentType.OPENAI_FUNCTIONS: OpenAIFunctionsAgent, } diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index db21ff8bc9..6fe9cc9048 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -35,6 +35,7 @@ from langchain.schema import ( ChatGeneration, ChatMessage, ChatResult, + FunctionMessage, HumanMessage, SystemMessage, ) @@ -120,6 +121,12 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: message_dict["function_call"] = message.additional_kwargs["function_call"] elif isinstance(message, SystemMessage): message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } else: raise ValueError(f"Got unknown type {message}") if "name" in message.additional_kwargs: diff --git a/langchain/schema.py b/langchain/schema.py index b2f76e705d..ef675f6fc6 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -2,6 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import ( Any, Dict, @@ -34,15 +35,22 @@ def get_buffer_string( role = ai_prefix elif isinstance(m, SystemMessage): role = "System" + elif isinstance(m, FunctionMessage): + role = "Function" elif isinstance(m, ChatMessage): role = m.role else: raise ValueError(f"Got unsupported message type: {m}") - string_messages.append(f"{role}: {m.content}") + message = f"{role}: {m.content}" + if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs: + message += f"{m.additional_kwargs['function_call']}" + string_messages.append(message) + return "\n".join(string_messages) -class AgentAction(NamedTuple): +@dataclass +class AgentAction: """Agent's action to take.""" tool: str @@ -112,6 +120,15 @@ class SystemMessage(BaseMessage): return "system" +class FunctionMessage(BaseMessage): + name: str + + @property + def type(self) -> str: + """Type of the message, used for serialization.""" + return "function" + + class ChatMessage(BaseMessage): """Type of message with arbitrary speaker."""