support memory for functions (#6165)

#### Before submitting
Add memory support for `OpenAIFunctionsAgent` like
`StructuredChatAgent`.


#### Who can review?
 @hwchase17

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/6321/head^2
Fei Wang 1 year ago committed by GitHub
parent b2b9ded12f
commit 50556f3b35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,233 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0c9954e9",
"metadata": {},
"source": [
"# Add Memory to OpenAI Functions Agent\n",
"\n",
"This notebook goes over how to add memory to OpenAI Functions agent."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "ac594f26",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from langchain import (\n",
" LLMMathChain,\n",
" OpenAI,\n",
" SerpAPIWrapper,\n",
" SQLDatabase,\n",
" SQLDatabaseChain,\n",
")\n",
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.agents import AgentType\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1e7844e7",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")\n",
"search = SerpAPIWrapper()\n",
"llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)\n",
"db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n",
"db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)\n",
"tools = [\n",
" Tool(\n",
" name=\"Search\",\n",
" func=search.run,\n",
" description=\"useful for when you need to answer questions about current events. You should ask targeted questions\",\n",
" ),\n",
" Tool(\n",
" name=\"Calculator\",\n",
" func=llm_math_chain.run,\n",
" description=\"useful for when you need to answer questions about math\",\n",
" ),\n",
" Tool(\n",
" name=\"FooBar-DB\",\n",
" func=db_chain.run,\n",
" description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question containing full context\",\n",
" ),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "54ca3b82",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import MessagesPlaceholder\n",
"from langchain.memory import ConversationBufferMemory\n",
"agent_kwargs = {\n",
" \"extra_prompt_messages\": [MessagesPlaceholder(variable_name=\"memory\")],\n",
"}\n",
"memory = ConversationBufferMemory(memory_key=\"memory\", return_messages=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "81af5658",
"metadata": {},
"outputs": [],
"source": [
"agent = initialize_agent(\n",
" tools, \n",
" llm, \n",
" agent=AgentType.OPENAI_FUNCTIONS, \n",
" verbose=True, \n",
" agent_kwargs=agent_kwargs, \n",
" memory=memory\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8ab08f43",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mHello! How can I assist you today?\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Hello! How can I assist you today?'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"hi\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "520a81f4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mNice to meet you, Bob! How can I help you today?\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Nice to meet you, Bob! How can I help you today?'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"my name is bob\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8bc4a69f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mYour name is Bob.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Your name is Bob.'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"whats my name\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40def1b7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -0,0 +1,173 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "af49b410",
"metadata": {},
"source": [
"# Use ToolKits with OpenAI Functions\n",
"\n",
"This notebook shows how to use the OpenAI functions agent with arbitrary toolkits."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "af6496bd",
"metadata": {},
"outputs": [],
"source": [
"from langchain import (\n",
" LLMMathChain,\n",
" OpenAI,\n",
" SerpAPIWrapper,\n",
" SQLDatabase,\n",
" SQLDatabaseChain,\n",
")\n",
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.agents import AgentType\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n",
"from langchain.schema import SystemMessage"
]
},
{
"cell_type": "markdown",
"id": "1b7ee35f",
"metadata": {},
"source": [
"Load the toolkit"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0423c32c",
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n",
"toolkit = SQLDatabaseToolkit(llm=ChatOpenAI(), db=db)"
]
},
{
"cell_type": "markdown",
"id": "203fa80a",
"metadata": {},
"source": [
"Set a system message specific to that toolkit"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e4edb101",
"metadata": {},
"outputs": [],
"source": [
"agent_kwargs = {\n",
" \"system_message\": SystemMessage(content=\"You are an expert SQL data analyst.\")\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e0c67b60",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")\n",
"agent = initialize_agent(\n",
" toolkit.get_tools(), \n",
" llm, \n",
" agent=AgentType.OPENAI_FUNCTIONS, \n",
" verbose=True, \n",
" agent_kwargs=agent_kwargs,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "93619811",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(DISTINCT artist_name) AS num_artists FROM artists'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mError: (sqlite3.OperationalError) no such table: artists\n",
"[SQL: SELECT COUNT(DISTINCT artist_name) AS num_artists FROM artists]\n",
"(Background on this error at: https://sqlalche.me/e/20/e3q8)\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_list_tables` with `{}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mMediaType, Track, Playlist, sales_table, Customer, Genre, PlaylistTrack, Artist, Invoice, Album, InvoiceLine, Employee\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(DISTINCT artist_id) AS num_artists FROM Artist'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mError: (sqlite3.OperationalError) no such column: artist_id\n",
"[SQL: SELECT COUNT(DISTINCT artist_id) AS num_artists FROM Artist]\n",
"(Background on this error at: https://sqlalche.me/e/20/e3q8)\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(DISTINCT Name) AS num_artists FROM Artist'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[(275,)]\u001b[0m\u001b[32;1m\u001b[1;3mThere are 275 different artists in the database.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'There are 275 different artists in the database.'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"how many different artists are there?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34415bad",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -4,6 +4,8 @@ from dataclasses import dataclass
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, List, Optional, Sequence, Tuple, Union from typing import Any, List, Optional, Sequence, Tuple, Union
from pydantic import root_validator
from langchain.agents import BaseSingleActionAgent from langchain.agents import BaseSingleActionAgent
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
@ -138,7 +140,16 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
class OpenAIFunctionsAgent(BaseSingleActionAgent): class OpenAIFunctionsAgent(BaseSingleActionAgent):
"""An Agent driven by OpenAIs function powered API.""" """An Agent driven by OpenAIs function powered API.
Args:
llm: This should be an instance of ChatOpenAI, specifically a model
that supports using `functions`.
tools: The tools this agent has access to.
prompt: The prompt for this agent, should support agent_scratchpad as one
of the variables. For an easy way to construct this prompt, use
`OpenAIFunctionsAgent.create_prompt(...)`
"""
llm: BaseLanguageModel llm: BaseLanguageModel
tools: Sequence[BaseTool] tools: Sequence[BaseTool]
@ -148,6 +159,22 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
"""Get allowed tools.""" """Get allowed tools."""
return list([t.name for t in self.tools]) return list([t.name for t in self.tools])
@root_validator
def validate_llm(cls, values: dict) -> dict:
    """Reject any model that is not a ``ChatOpenAI`` instance.

    The agent drives the OpenAI function-calling API, so only
    ``ChatOpenAI`` models are accepted.
    """
    llm = values["llm"]
    if isinstance(llm, ChatOpenAI):
        return values
    raise ValueError("Only supported with ChatOpenAI models.")
@root_validator
def validate_prompt(cls, values: dict) -> dict:
    """Require the prompt to expose an ``agent_scratchpad`` variable.

    The agent formats its intermediate steps into ``agent_scratchpad``,
    so a prompt without that variable cannot work.
    """
    prompt: BasePromptTemplate = values["prompt"]
    if "agent_scratchpad" in prompt.input_variables:
        return values
    raise ValueError(
        "`agent_scratchpad` should be one of the variables in the prompt, "
        f"got {prompt.input_variables}"
    )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:
"""Get input keys. Input refers to user input here.""" """Get input keys. Input refers to user input here."""
@ -164,17 +191,20 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do. """Given input, decided what to do.
Args: Args:
intermediate_steps: Steps the LLM has taken to date, along with observations intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs. **kwargs: User inputs.
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
user_input = kwargs["input"]
agent_scratchpad = _format_intermediate_steps(intermediate_steps) agent_scratchpad = _format_intermediate_steps(intermediate_steps)
prompt = self.prompt.format_prompt( selected_inputs = {
input=user_input, agent_scratchpad=agent_scratchpad k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
) }
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages() messages = prompt.to_messages()
predicted_message = self.llm.predict_messages( predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks messages, functions=self.functions, callbacks=callbacks
@ -189,18 +219,21 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do. """Given input, decided what to do.
Args: Args:
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
along with observations along with observations
**kwargs: User inputs. **kwargs: User inputs.
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
user_input = kwargs["input"]
agent_scratchpad = _format_intermediate_steps(intermediate_steps) agent_scratchpad = _format_intermediate_steps(intermediate_steps)
prompt = self.prompt.format_prompt( selected_inputs = {
input=user_input, agent_scratchpad=agent_scratchpad k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
) }
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages() messages = prompt.to_messages()
predicted_message = await self.llm.apredict_messages( predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions messages, functions=self.functions
@ -214,7 +247,20 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
system_message: Optional[SystemMessage] = SystemMessage( system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant." content="You are a helpful AI assistant."
), ),
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
) -> BasePromptTemplate: ) -> BasePromptTemplate:
"""Create prompt for this agent.
Args:
system_message: Message to use as the system message that will be the
first in the prompt.
extra_prompt_messages: Prompt messages that will be placed between the
system message and the new human input.
Returns:
A prompt template to pass into this agent.
"""
_prompts = extra_prompt_messages or []
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message: if system_message:
messages = [system_message] messages = [system_message]
@ -223,12 +269,12 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
messages.extend( messages.extend(
[ [
*_prompts,
HumanMessagePromptTemplate.from_template("{input}"), HumanMessagePromptTemplate.from_template("{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"), MessagesPlaceholder(variable_name="agent_scratchpad"),
] ]
) )
input_variables = ["input", "agent_scratchpad"] return ChatPromptTemplate(messages=messages)
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
@ -236,6 +282,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
llm: BaseLanguageModel, llm: BaseLanguageModel,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage( system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant." content="You are a helpful AI assistant."
), ),
@ -244,7 +291,10 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
"""Construct an agent from an LLM and tools.""" """Construct an agent from an LLM and tools."""
if not isinstance(llm, ChatOpenAI): if not isinstance(llm, ChatOpenAI):
raise ValueError("Only supported with OpenAI models.") raise ValueError("Only supported with OpenAI models.")
prompt = cls.create_prompt(system_message=system_message) prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls( return cls(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
from pydantic import Field from pydantic import Field, root_validator
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.memory.buffer import get_buffer_string from langchain.memory.buffer import get_buffer_string
@ -161,6 +161,24 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
input_variables: List[str] input_variables: List[str]
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
@root_validator(pre=True)
def validate_input_variables(cls, values: dict) -> dict:
    """Infer or validate ``input_variables`` from the message templates.

    Collects the union of input variables declared by every
    ``BaseMessagePromptTemplate`` in ``messages`` (plain ``BaseMessage``
    instances contribute none). If the caller supplied
    ``input_variables`` explicitly, it must match the inferred set
    exactly; otherwise the inferred set is filled in automatically.

    Raises:
        ValueError: If an explicit ``input_variables`` disagrees with
            the variables used by the message templates.
    """
    # Use .get so a missing "messages" key surfaces as pydantic's own
    # "field required" error instead of a raw KeyError raised from this
    # pre-validator.
    messages = values.get("messages", [])
    input_vars = set()
    for message in messages:
        if isinstance(message, BaseMessagePromptTemplate):
            input_vars.update(message.input_variables)
    if "input_variables" in values:
        if input_vars != set(values["input_variables"]):
            raise ValueError(
                "Got mismatched input_variables. "
                f"Expected: {input_vars}. "
                f"Got: {values['input_variables']}"
            )
    else:
        values["input_variables"] = list(input_vars)
    return values
@classmethod @classmethod
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate: def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
prompt_template = PromptTemplate.from_template(template, **kwargs) prompt_template = PromptTemplate.from_template(template, **kwargs)

@ -1,6 +1,8 @@
from pathlib import Path from pathlib import Path
from typing import List from typing import List
import pytest
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ( from langchain.prompts.chat import (
AIMessagePromptTemplate, AIMessagePromptTemplate,
@ -142,3 +144,21 @@ def test_chat_prompt_template_with_messages() -> None:
) )
prompt_value_messages = prompt_value.to_messages() prompt_value_messages = prompt_value.to_messages()
assert prompt_value_messages[-1] == HumanMessage(content="foo") assert prompt_value_messages[-1] == HumanMessage(content="foo")
def test_chat_invalid_input_variables_extra() -> None:
    """Declaring an input variable no message template uses must fail."""
    msgs = [HumanMessage(content="foo")]
    with pytest.raises(ValueError):
        ChatPromptTemplate(messages=msgs, input_variables=["foo"])
def test_chat_invalid_input_variables_missing() -> None:
    """Omitting a variable that a message template requires must fail."""
    template_msgs = [HumanMessagePromptTemplate.from_template("{foo}")]
    with pytest.raises(ValueError):
        ChatPromptTemplate(messages=template_msgs, input_variables=[])
def test_infer_variables() -> None:
    """``input_variables`` are inferred from the templates when omitted."""
    template_msgs = [HumanMessagePromptTemplate.from_template("{foo}")]
    chat_prompt = ChatPromptTemplate(messages=template_msgs)
    assert chat_prompt.input_variables == ["foo"]

Loading…
Cancel
Save