Add Retrieval Example for AI Plugins (#2737)

This PR proposes
- An NLAToolkit method to instantiate from an AI Plugin URL
- A notebook that shows how to use that alongside an example of using a
Retriever object to lookup specs and route queries to them on the fly

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
fix_agent_callbacks
vowelparrot 1 year ago committed by GitHub
parent b5bbe601fb
commit 94a92abf24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,538 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ba5f8741",
"metadata": {},
"source": [
"# Custom Agent with PlugIn Retrieval\n",
"\n",
"This notebook combines two concepts in order to build a custom agent that can interact with AI Plugins:\n",
"\n",
"1. [Custom Agent with Retrieval](../../modules/agents/agents/custom_agent_with_plugin_retrieval.html): This introduces the concept of retrieving many tools, which is useful when trying to work with arbitrarily many plugins.\n",
"2. [Natural Language API Chains](../../modules/chains/examples/openapi.ipynb): This creates Natural Language wrappers around OpenAPI endpoints. This is useful because (1) plugins use OpenAPI endpoints under the hood, (2) wrapping them in an NLAChain allows the router agent to call it more easily.\n",
"\n",
"The novel idea introduced in this notebook is the idea of using retrieval to select not the tools explicitly, but the set of OpenAPI specs to use. We can then generate tools from those OpenAPI specs. The use case for this is when trying to get agents to use plugins. It may be more efficient to choose plugins first, then the endpoints, rather than the endpoints directly. This is because the plugins may contain more useful information for selection."
]
},
{
"cell_type": "markdown",
"id": "fea4812c",
"metadata": {},
"source": [
"## Set up environment\n",
"\n",
"Do necessary imports, etc."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9af9734e",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser\n",
"from langchain.prompts import StringPromptTemplate\n",
"from langchain import OpenAI, SerpAPIWrapper, LLMChain\n",
"from typing import List, Union\n",
"from langchain.schema import AgentAction, AgentFinish\n",
"from langchain.agents.agent_toolkits import NLAToolkit\n",
"from langchain.tools.plugin import AIPlugin\n",
"import re"
]
},
{
"cell_type": "markdown",
"id": "2f91d8b4",
"metadata": {},
"source": [
"## Setup LLM"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a1a3b59c",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)"
]
},
{
"cell_type": "markdown",
"id": "6df0253f",
"metadata": {},
"source": [
"## Set up plugins\n",
"\n",
"Load and index plugins"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "becda2a1",
"metadata": {},
"outputs": [],
"source": [
"urls = [\n",
" \"https://datasette.io/.well-known/ai-plugin.json\",\n",
" \"https://api.speak.com/.well-known/ai-plugin.json\",\n",
" \"https://www.wolframalpha.com/.well-known/ai-plugin.json\",\n",
" \"https://www.zapier.com/.well-known/ai-plugin.json\",\n",
" \"https://www.klarna.com/.well-known/ai-plugin.json\",\n",
" \"https://www.joinmilo.com/.well-known/ai-plugin.json\",\n",
" \"https://slack.com/.well-known/ai-plugin.json\",\n",
" \"https://schooldigger.com/.well-known/ai-plugin.json\",\n",
"]\n",
"\n",
"AI_PLUGINS = [AIPlugin.from_url(url) for url in urls]"
]
},
{
"cell_type": "markdown",
"id": "17362717",
"metadata": {},
"source": [
"## Tool Retriever\n",
"\n",
"We will use a vectorstore to create embeddings for each tool description. Then, for an incoming query we can create embeddings for that query and do a similarity search for relevant tools."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "77c4be4b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.vectorstores import FAISS\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.schema import Document"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9092a158",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Attempting to load an OpenAPI 3.0.1 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n",
"Attempting to load an OpenAPI 3.0.1 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n",
"Attempting to load an OpenAPI 3.0.1 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n",
"Attempting to load an OpenAPI 3.0.2 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n",
"Attempting to load an OpenAPI 3.0.1 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n",
"Attempting to load an OpenAPI 3.0.1 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n",
"Attempting to load an OpenAPI 3.0.1 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n",
"Attempting to load an OpenAPI 3.0.1 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n",
"Attempting to load a Swagger 2.0 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n"
]
}
],
"source": [
"embeddings = OpenAIEmbeddings()\n",
"docs = [\n",
" Document(page_content=plugin.description_for_model, \n",
" metadata={\"plugin_name\": plugin.name_for_model}\n",
" )\n",
" for plugin in AI_PLUGINS\n",
"]\n",
"vector_store = FAISS.from_documents(docs, embeddings)\n",
"toolkits_dict = {plugin.name_for_model: \n",
" NLAToolkit.from_llm_and_ai_plugin(llm, plugin) \n",
" for plugin in AI_PLUGINS}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "735a7566",
"metadata": {},
"outputs": [],
"source": [
"retriever = vector_store.as_retriever()\n",
"\n",
"def get_tools(query):\n",
" # Get documents, which contain the Plugins to use\n",
" docs = retriever.get_relevant_documents(query)\n",
" # Get the toolkits, one for each plugin\n",
" tool_kits = [toolkits_dict[d.metadata[\"plugin_name\"]] for d in docs]\n",
" # Get the tools: a separate NLAChain for each endpoint\n",
" tools = []\n",
" for tk in tool_kits:\n",
" tools.extend(tk.nla_tools)\n",
" return tools"
]
},
{
"cell_type": "markdown",
"id": "7699afd7",
"metadata": {},
"source": [
"We can now test this retriever to see if it seems to work."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "425f2886",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Milo.askMilo',\n",
" 'Zapier_Natural_Language_Actions_(NLA)_API_(Dynamic)_-_Beta.search_all_actions',\n",
" 'Zapier_Natural_Language_Actions_(NLA)_API_(Dynamic)_-_Beta.preview_a_zap',\n",
" 'Zapier_Natural_Language_Actions_(NLA)_API_(Dynamic)_-_Beta.get_configuration_link',\n",
" 'Zapier_Natural_Language_Actions_(NLA)_API_(Dynamic)_-_Beta.list_exposed_actions',\n",
" 'SchoolDigger_API_V2.0.Autocomplete_GetSchools',\n",
" 'SchoolDigger_API_V2.0.Districts_GetAllDistricts2',\n",
" 'SchoolDigger_API_V2.0.Districts_GetDistrict2',\n",
" 'SchoolDigger_API_V2.0.Rankings_GetSchoolRank2',\n",
" 'SchoolDigger_API_V2.0.Rankings_GetRank_District',\n",
" 'SchoolDigger_API_V2.0.Schools_GetAllSchools20',\n",
" 'SchoolDigger_API_V2.0.Schools_GetSchool20',\n",
" 'Speak.translate',\n",
" 'Speak.explainPhrase',\n",
" 'Speak.explainTask']"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tools = get_tools(\"What could I do today with my kiddo\")\n",
"[t.name for t in tools]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3aa88768",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Open_AI_Klarna_product_Api.productsUsingGET',\n",
" 'Milo.askMilo',\n",
" 'Zapier_Natural_Language_Actions_(NLA)_API_(Dynamic)_-_Beta.search_all_actions',\n",
" 'Zapier_Natural_Language_Actions_(NLA)_API_(Dynamic)_-_Beta.preview_a_zap',\n",
" 'Zapier_Natural_Language_Actions_(NLA)_API_(Dynamic)_-_Beta.get_configuration_link',\n",
" 'Zapier_Natural_Language_Actions_(NLA)_API_(Dynamic)_-_Beta.list_exposed_actions',\n",
" 'SchoolDigger_API_V2.0.Autocomplete_GetSchools',\n",
" 'SchoolDigger_API_V2.0.Districts_GetAllDistricts2',\n",
" 'SchoolDigger_API_V2.0.Districts_GetDistrict2',\n",
" 'SchoolDigger_API_V2.0.Rankings_GetSchoolRank2',\n",
" 'SchoolDigger_API_V2.0.Rankings_GetRank_District',\n",
" 'SchoolDigger_API_V2.0.Schools_GetAllSchools20',\n",
" 'SchoolDigger_API_V2.0.Schools_GetSchool20']"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tools = get_tools(\"what shirts can i buy?\")\n",
"[t.name for t in tools]"
]
},
{
"cell_type": "markdown",
"id": "2e7a075c",
"metadata": {},
"source": [
"## Prompt Template\n",
"\n",
"The prompt template is pretty standard, because we're not actually changing that much logic in the actual prompt template, but rather we are just changing how retrieval is done."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "339b1bb8",
"metadata": {},
"outputs": [],
"source": [
"# Set up the base template\n",
"template = \"\"\"Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:\n",
"\n",
"{tools}\n",
"\n",
"Use the following format:\n",
"\n",
"Question: the input question you must answer\n",
"Thought: you should always think about what to do\n",
"Action: the action to take, should be one of [{tool_names}]\n",
"Action Input: the input to the action\n",
"Observation: the result of the action\n",
"... (this Thought/Action/Action Input/Observation can repeat N times)\n",
"Thought: I now know the final answer\n",
"Final Answer: the final answer to the original input question\n",
"\n",
"Begin! Remember to speak as a pirate when giving your final answer. Use lots of \"Arg\"s\n",
"\n",
"Question: {input}\n",
"{agent_scratchpad}\"\"\""
]
},
{
"cell_type": "markdown",
"id": "1583acdc",
"metadata": {},
"source": [
"The custom prompt template now has the concept of a tools_getter, which we call on the input to select the tools to use"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "fd969d31",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"# Set up a prompt template\n",
"class CustomPromptTemplate(StringPromptTemplate):\n",
" # The template to use\n",
" template: str\n",
" ############## NEW ######################\n",
" # The list of tools available\n",
" tools_getter: Callable\n",
" \n",
" def format(self, **kwargs) -> str:\n",
" # Get the intermediate steps (AgentAction, Observation tuples)\n",
" # Format them in a particular way\n",
" intermediate_steps = kwargs.pop(\"intermediate_steps\")\n",
" thoughts = \"\"\n",
" for action, observation in intermediate_steps:\n",
" thoughts += action.log\n",
" thoughts += f\"\\nObservation: {observation}\\nThought: \"\n",
" # Set the agent_scratchpad variable to that value\n",
" kwargs[\"agent_scratchpad\"] = thoughts\n",
" ############## NEW ######################\n",
" tools = self.tools_getter(kwargs[\"input\"])\n",
" # Create a tools variable from the list of tools provided\n",
" kwargs[\"tools\"] = \"\\n\".join([f\"{tool.name}: {tool.description}\" for tool in tools])\n",
" # Create a list of tool names for the tools provided\n",
" kwargs[\"tool_names\"] = \", \".join([tool.name for tool in tools])\n",
" return self.template.format(**kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "798ef9fb",
"metadata": {},
"outputs": [],
"source": [
"prompt = CustomPromptTemplate(\n",
" template=template,\n",
" tools_getter=get_tools,\n",
" # This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically\n",
" # This includes the `intermediate_steps` variable because that is needed\n",
" input_variables=[\"input\", \"intermediate_steps\"]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ef3a1af3",
"metadata": {},
"source": [
"## Output Parser\n",
"\n",
"The output parser is unchanged from the previous notebook, since we are not changing anything about the output format."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7c6fe0d3",
"metadata": {},
"outputs": [],
"source": [
"class CustomOutputParser(AgentOutputParser):\n",
" \n",
" def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n",
" # Check if agent should finish\n",
" if \"Final Answer:\" in llm_output:\n",
" return AgentFinish(\n",
" # Return values is generally always a dictionary with a single `output` key\n",
" # It is not recommended to try anything else at the moment :)\n",
" return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n",
" log=llm_output,\n",
" )\n",
" # Parse out the action and action input\n",
" regex = r\"Action: (.*?)[\\n]*Action Input:[\\s]*(.*)\"\n",
" match = re.search(regex, llm_output, re.DOTALL)\n",
" if not match:\n",
" raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n",
" action = match.group(1).strip()\n",
" action_input = match.group(2)\n",
" # Return the action and action input\n",
" return AgentAction(tool=action, tool_input=action_input.strip(\" \").strip('\"'), log=llm_output)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d278706a",
"metadata": {},
"outputs": [],
"source": [
"output_parser = CustomOutputParser()"
]
},
{
"cell_type": "markdown",
"id": "170587b1",
"metadata": {},
"source": [
"## Set up LLM, stop sequence, and the agent\n",
"\n",
"Also the same as the previous notebook"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f9d4c374",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "9b1cc2a2",
"metadata": {},
"outputs": [],
"source": [
"# LLM chain consisting of the LLM and a prompt\n",
"llm_chain = LLMChain(llm=llm, prompt=prompt)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e4f5092f",
"metadata": {},
"outputs": [],
"source": [
"tool_names = [tool.name for tool in tools]\n",
"agent = LLMSingleActionAgent(\n",
" llm_chain=llm_chain, \n",
" output_parser=output_parser,\n",
" stop=[\"\\nObservation:\"], \n",
" allowed_tools=tool_names\n",
")"
]
},
{
"cell_type": "markdown",
"id": "aa8a5326",
"metadata": {},
"source": [
"## Use the Agent\n",
"\n",
"Now we can use it!"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "490604e9",
"metadata": {},
"outputs": [],
"source": [
"agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "653b1617",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to find a product API\n",
"Action: Open_AI_Klarna_product_Api.productsUsingGET\n",
"Action Input: shirts\u001b[0m\n",
"\n",
"Observation:\u001b[36;1m\u001b[1;3mI found 10 shirts from the API response. They range in price from $9.99 to $450.00 and come in a variety of materials, colors, and patterns.\u001b[0m\u001b[32;1m\u001b[1;3m I now know what shirts I can buy\n",
"Final Answer: Arg, I found 10 shirts from the API response. They range in price from $9.99 to $450.00 and come in a variety of materials, colors, and patterns.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Arg, I found 10 shirts from the API response. They range in price from $9.99 to $450.00 and come in a variety of materials, colors, and patterns.'"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\"what shirts can i buy?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2481ee76",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
},
"vscode": {
"interpreter": {
"hash": "18784188d7ecd866c0586ac068b02361a6896dc3a29b64f5cc957f09c590acef"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -1,5 +1,5 @@
"""Toolkit for interacting with API's using natural language."""
from __future__ import annotations
from typing import Any, List, Optional, Sequence
@ -11,6 +11,7 @@ from langchain.llms.base import BaseLLM
from langchain.requests import Requests
from langchain.tools.base import BaseTool
from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec
from langchain.tools.plugin import AIPlugin
class NLAToolkit(BaseToolkit):
@ -23,19 +24,18 @@ class NLAToolkit(BaseToolkit):
"""Get the tools for all the API operations."""
return list(self.nla_tools)
@classmethod
def from_llm_and_spec(
cls,
@staticmethod
def _get_http_operation_tools(
llm: BaseLLM,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any
) -> "NLAToolkit":
"""Instantiate the toolkit by creating tools for each operation."""
http_operation_tools: List[NLATool] = []
**kwargs: Any,
) -> List[NLATool]:
"""Get the tools for all the API operations."""
if not spec.paths:
return cls(nla_tools=http_operation_tools)
return []
http_operation_tools = []
for path in spec.paths:
for method in spec.get_methods_for_path(path):
endpoint_tool = NLATool.from_llm_and_method(
@ -45,9 +45,24 @@ class NLAToolkit(BaseToolkit):
spec=spec,
requests=requests,
verbose=verbose,
**kwargs
**kwargs,
)
http_operation_tools.append(endpoint_tool)
return http_operation_tools
@classmethod
def from_llm_and_spec(
cls,
llm: BaseLLM,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit by creating tools for each operation."""
http_operation_tools = cls._get_http_operation_tools(
llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
)
return cls(nla_tools=http_operation_tools)
@classmethod
@ -57,10 +72,45 @@ class NLAToolkit(BaseToolkit):
open_api_url: str,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any
) -> "NLAToolkit":
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL"""
spec = OpenAPISpec.from_url(open_api_url)
return cls.from_llm_and_spec(
llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
)
@classmethod
def from_llm_and_ai_plugin(
cls,
llm: BaseLLM,
ai_plugin: AIPlugin,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL"""
spec = OpenAPISpec.from_url(ai_plugin.api.url)
# TODO: Merge optional Auth information with the `requests` argument
return cls.from_llm_and_spec(
llm=llm,
spec=spec,
requests=requests,
verbose=verbose,
**kwargs,
)
@classmethod
def from_llm_and_ai_plugin_url(
cls,
llm: BaseLLM,
ai_plugin_url: str,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL"""
plugin = AIPlugin.from_url(ai_plugin_url)
return cls.from_llm_and_ai_plugin(
llm=llm, ai_plugin=plugin, requests=requests, verbose=verbose, **kwargs
)

@ -17,6 +17,8 @@ class ApiConfig(BaseModel):
class AIPlugin(BaseModel):
"""AI Plugin Definition."""
schema_version: str
name_for_model: str
name_for_human: str
@ -28,6 +30,12 @@ class AIPlugin(BaseModel):
contact_email: Optional[str]
legal_info_url: Optional[str]
@classmethod
def from_url(cls, url: str) -> AIPlugin:
"""Instantiate AIPlugin from a URL."""
response = requests.get(url).json()
return cls(**response)
def marshal_spec(txt: str) -> dict:
"""Convert the yaml or json serialized spec to a dict."""
@ -43,8 +51,7 @@ class AIPluginTool(BaseTool):
@classmethod
def from_plugin_url(cls, url: str) -> AIPluginTool:
response = requests.get(url).json()
plugin = AIPlugin(**response)
plugin = AIPlugin.from_url(url)
description = (
f"Call this tool to get the OpenAPI spec (and usage guide) "
f"for interacting with the {plugin.name_for_human} API. "

Loading…
Cancel
Save