From 50556f3b35f155a93656ea7baf7caa4e0b33ec72 Mon Sep 17 00:00:00 2001 From: Fei Wang Date: Mon, 19 Jun 2023 10:00:40 +0800 Subject: [PATCH] support memory for functions (#6165) #### Before submitting Add memory support for `OpenAIFunctionsAgent` like `StructuredChatAgent`. #### Who can review? @hwchase17 --------- Co-authored-by: Harrison Chase --- .../how_to/add_memory_openai_functions.ipynb | 233 ++++++++++++++++++ .../use_toolkits_with_openai_functions.ipynb | 173 +++++++++++++ .../agents/openai_functions_agent/base.py | 74 +++++- langchain/prompts/chat.py | 20 +- tests/unit_tests/prompts/test_chat.py | 20 ++ 5 files changed, 507 insertions(+), 13 deletions(-) create mode 100644 docs/extras/modules/agents/how_to/add_memory_openai_functions.ipynb create mode 100644 docs/extras/modules/agents/how_to/use_toolkits_with_openai_functions.ipynb diff --git a/docs/extras/modules/agents/how_to/add_memory_openai_functions.ipynb b/docs/extras/modules/agents/how_to/add_memory_openai_functions.ipynb new file mode 100644 index 00000000..7da668e2 --- /dev/null +++ b/docs/extras/modules/agents/how_to/add_memory_openai_functions.ipynb @@ -0,0 +1,233 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0c9954e9", + "metadata": {}, + "source": [ + "# Add Memory to OpenAI Functions Agent\n", + "\n", + "This notebook goes over how to add memory to OpenAI Functions agent." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ac594f26", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from langchain import (\n", + " LLMMathChain,\n", + " OpenAI,\n", + " SerpAPIWrapper,\n", + " SQLDatabase,\n", + " SQLDatabaseChain,\n", + ")\n", + "from langchain.agents import initialize_agent, Tool\n", + "from langchain.agents import AgentType\n", + "from langchain.chat_models import ChatOpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1e7844e7", + "metadata": {}, + "outputs": [], + "source": [ + "llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")\n", + "search = SerpAPIWrapper()\n", + "llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)\n", + "db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n", + "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)\n", + "tools = [\n", + " Tool(\n", + " name=\"Search\",\n", + " func=search.run,\n", + " description=\"useful for when you need to answer questions about current events. You should ask targeted questions\",\n", + " ),\n", + " Tool(\n", + " name=\"Calculator\",\n", + " func=llm_math_chain.run,\n", + " description=\"useful for when you need to answer questions about math\",\n", + " ),\n", + " Tool(\n", + " name=\"FooBar-DB\",\n", + " func=db_chain.run,\n", + " description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question containing full context\",\n", + " ),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "54ca3b82", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.prompts import MessagesPlaceholder\n", + "from langchain.memory import ConversationBufferMemory\n", + "agent_kwargs = {\n", + " \"extra_prompt_messages\": [MessagesPlaceholder(variable_name=\"memory\")],\n", + "}\n", + "memory = ConversationBufferMemory(memory_key=\"memory\", return_messages=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "81af5658", + "metadata": {}, + "outputs": [], + "source": [ + "agent = initialize_agent(\n", + " tools, \n", + " llm, \n", + " agent=AgentType.OPENAI_FUNCTIONS, \n", + " verbose=True, \n", + " agent_kwargs=agent_kwargs, \n", + " memory=memory\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8ab08f43", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mHello! How can I assist you today?\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'Hello! How can I assist you today?'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"hi\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "520a81f4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mNice to meet you, Bob! How can I help you today?\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'Nice to meet you, Bob! How can I help you today?'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"my name is bob\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8bc4a69f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mYour name is Bob.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'Your name is Bob.'" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"whats my name\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40def1b7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/extras/modules/agents/how_to/use_toolkits_with_openai_functions.ipynb b/docs/extras/modules/agents/how_to/use_toolkits_with_openai_functions.ipynb new file mode 100644 index 00000000..f17aff18 --- /dev/null +++ b/docs/extras/modules/agents/how_to/use_toolkits_with_openai_functions.ipynb @@ -0,0 +1,173 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "af49b410", + "metadata": {}, + "source": [ + "# Use ToolKits with OpenAI Functions\n", + "\n", + "This notebook shows how to use the OpenAI functions agent with arbitrary toolkits." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "af6496bd", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain import (\n", + " LLMMathChain,\n", + " OpenAI,\n", + " SerpAPIWrapper,\n", + " SQLDatabase,\n", + " SQLDatabaseChain,\n", + ")\n", + "from langchain.agents import initialize_agent, Tool\n", + "from langchain.agents import AgentType\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n", + "from langchain.schema import SystemMessage" + ] + }, + { + "cell_type": "markdown", + "id": "1b7ee35f", + "metadata": {}, + "source": [ + "Load the toolkit" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0423c32c", + "metadata": {}, + "outputs": [], + "source": [ + "db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n", + "toolkit = SQLDatabaseToolkit(llm=ChatOpenAI(), db=db)" + ] + }, + { + "cell_type": "markdown", + "id": "203fa80a", + "metadata": {}, + "source": [ + "Set a system message specific to that toolkit" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e4edb101", + "metadata": {}, + "outputs": [], + "source": [ + "agent_kwargs = {\n", + " \"system_message\": SystemMessage(content=\"You are an expert SQL data analyst.\")\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "e0c67b60", + "metadata": {}, + "outputs": [], + "source": [ + "llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")\n", + "agent = initialize_agent(\n", + " toolkit.get_tools(), \n", + " llm, \n", + " agent=AgentType.OPENAI_FUNCTIONS, \n", + " verbose=True, \n", + " agent_kwargs=agent_kwargs,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "93619811", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(DISTINCT artist_name) AS num_artists FROM artists'}`\n", + "\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3mError: (sqlite3.OperationalError) no such table: artists\n", + "[SQL: SELECT COUNT(DISTINCT artist_name) AS num_artists FROM artists]\n", + "(Background on this error at: https://sqlalche.me/e/20/e3q8)\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_list_tables` with `{}`\n", + "\n", + "\n", + "\u001b[0m\u001b[38;5;200m\u001b[1;3mMediaType, Track, Playlist, sales_table, Customer, Genre, PlaylistTrack, Artist, Invoice, Album, InvoiceLine, Employee\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(DISTINCT artist_id) AS num_artists FROM Artist'}`\n", + "\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3mError: (sqlite3.OperationalError) no such column: artist_id\n", + "[SQL: SELECT COUNT(DISTINCT artist_id) AS num_artists FROM Artist]\n", + "(Background on this error at: https://sqlalche.me/e/20/e3q8)\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(DISTINCT Name) AS num_artists FROM Artist'}`\n", + "\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m[(275,)]\u001b[0m\u001b[32;1m\u001b[1;3mThere are 275 different artists in the database.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'There are 275 different artists in the database.'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"how many different artists are there?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34415bad", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/agents/openai_functions_agent/base.py b/langchain/agents/openai_functions_agent/base.py index b8407d49..87108e5f 100644 --- a/langchain/agents/openai_functions_agent/base.py +++ b/langchain/agents/openai_functions_agent/base.py @@ -4,6 +4,8 @@ from dataclasses import dataclass from json import JSONDecodeError from typing import Any, List, Optional, Sequence, Tuple, Union +from pydantic import root_validator + from langchain.agents import BaseSingleActionAgent from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager @@ -138,7 +140,16 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]: class OpenAIFunctionsAgent(BaseSingleActionAgent): - """An Agent driven by OpenAIs function powered API.""" + """An Agent driven by OpenAIs function powered API. + + Args: + llm: This should be an instance of ChatOpenAI, specifically a model + that supports using `functions`. + tools: The tools this agent has access to. + prompt: The prompt for this agent, should support agent_scratchpad as one + of the variables. For an easy way to construct this prompt, use + `OpenAIFunctionsAgent.create_prompt(...)` + """ llm: BaseLanguageModel tools: Sequence[BaseTool] @@ -148,6 +159,22 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): """Get allowed tools.""" return list([t.name for t in self.tools]) + @root_validator + def validate_llm(cls, values: dict) -> dict: + if not isinstance(values["llm"], ChatOpenAI): + raise ValueError("Only supported with ChatOpenAI models.") + return values + + @root_validator + def validate_prompt(cls, values: dict) -> dict: + prompt: BasePromptTemplate = values["prompt"] + if "agent_scratchpad" not in prompt.input_variables: + raise ValueError( + "`agent_scratchpad` should be one of the variables in the prompt, " + f"got {prompt.input_variables}" + ) + return values + @property def input_keys(self) -> List[str]: """Get input keys. Input refers to user input here.""" @@ -164,17 +191,20 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. + Args: intermediate_steps: Steps the LLM has taken to date, along with observations **kwargs: User inputs. + Returns: Action specifying what tool to use. """ - user_input = kwargs["input"] agent_scratchpad = _format_intermediate_steps(intermediate_steps) - prompt = self.prompt.format_prompt( - input=user_input, agent_scratchpad=agent_scratchpad - ) + selected_inputs = { + k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" + } + full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) + prompt = self.prompt.format_prompt(**full_inputs) messages = prompt.to_messages() predicted_message = self.llm.predict_messages( messages, functions=self.functions, callbacks=callbacks @@ -189,18 +219,21 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. + Args: intermediate_steps: Steps the LLM has taken to date, along with observations **kwargs: User inputs. + Returns: Action specifying what tool to use. """ - user_input = kwargs["input"] agent_scratchpad = _format_intermediate_steps(intermediate_steps) - prompt = self.prompt.format_prompt( - input=user_input, agent_scratchpad=agent_scratchpad - ) + selected_inputs = { + k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" + } + full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) + prompt = self.prompt.format_prompt(**full_inputs) messages = prompt.to_messages() predicted_message = await self.llm.apredict_messages( messages, functions=self.functions @@ -214,7 +247,20 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): system_message: Optional[SystemMessage] = SystemMessage( content="You are a helpful AI assistant." ), + extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, ) -> BasePromptTemplate: + """Create prompt for this agent. + + Args: + system_message: Message to use as the system message that will be the + first in the prompt. + extra_prompt_messages: Prompt messages that will be placed between the + system message and the new human input. + + Returns: + A prompt template to pass into this agent. + """ + _prompts = extra_prompt_messages or [] messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] if system_message: messages = [system_message] @@ -223,12 +269,12 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): messages.extend( [ + *_prompts, HumanMessagePromptTemplate.from_template("{input}"), MessagesPlaceholder(variable_name="agent_scratchpad"), ] ) - input_variables = ["input", "agent_scratchpad"] - return ChatPromptTemplate(input_variables=input_variables, messages=messages) + return ChatPromptTemplate(messages=messages) @classmethod def from_llm_and_tools( @@ -236,6 +282,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): llm: BaseLanguageModel, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, + extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, system_message: Optional[SystemMessage] = SystemMessage( content="You are a helpful AI assistant." ), @@ -244,7 +291,10 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): """Construct an agent from an LLM and tools.""" if not isinstance(llm, ChatOpenAI): raise ValueError("Only supported with OpenAI models.") - prompt = cls.create_prompt(system_message=system_message) + prompt = cls.create_prompt( + extra_prompt_messages=extra_prompt_messages, + system_message=system_message, + ) return cls( llm=llm, prompt=prompt, diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index 1edc8f6c..bc814158 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union -from pydantic import Field +from pydantic import Field, root_validator from langchain.load.serializable import Serializable from langchain.memory.buffer import get_buffer_string @@ -161,6 +161,24 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): input_variables: List[str] messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] + @root_validator(pre=True) + def validate_input_variables(cls, values: dict) -> dict: + messages = values["messages"] + input_vars = set() + for message in messages: + if isinstance(message, BaseMessagePromptTemplate): + input_vars.update(message.input_variables) + if "input_variables" in values: + if input_vars != set(values["input_variables"]): + raise ValueError( + "Got mismatched input_variables. " + f"Expected: {input_vars}. " + f"Got: {values['input_variables']}" + ) + else: + values["input_variables"] = list(input_vars) + return values + @classmethod def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate: prompt_template = PromptTemplate.from_template(template, **kwargs) diff --git a/tests/unit_tests/prompts/test_chat.py b/tests/unit_tests/prompts/test_chat.py index a9844a78..d1faac58 100644 --- a/tests/unit_tests/prompts/test_chat.py +++ b/tests/unit_tests/prompts/test_chat.py @@ -1,6 +1,8 @@ from pathlib import Path from typing import List +import pytest + from langchain.prompts import PromptTemplate from langchain.prompts.chat import ( AIMessagePromptTemplate, @@ -142,3 +144,21 @@ def test_chat_prompt_template_with_messages() -> None: ) prompt_value_messages = prompt_value.to_messages() assert prompt_value_messages[-1] == HumanMessage(content="foo") + + +def test_chat_invalid_input_variables_extra() -> None: + messages = [HumanMessage(content="foo")] + with pytest.raises(ValueError): + ChatPromptTemplate(messages=messages, input_variables=["foo"]) + + +def test_chat_invalid_input_variables_missing() -> None: + messages = [HumanMessagePromptTemplate.from_template("{foo}")] + with pytest.raises(ValueError): + ChatPromptTemplate(messages=messages, input_variables=[]) + + +def test_infer_variables() -> None: + messages = [HumanMessagePromptTemplate.from_template("{foo}")] + prompt = ChatPromptTemplate(messages=messages) + assert prompt.input_variables == ["foo"]