You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/langchain_community/agent_toolkits/openapi/planner.py

379 lines
13 KiB
Python

"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
import json
import re
from functools import partial
from typing import Any, Callable, Dict, List, Optional, cast
import yaml
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool, Tool
from langchain_community.agent_toolkits.openapi.planner_prompt import (
API_CONTROLLER_PROMPT,
API_CONTROLLER_TOOL_DESCRIPTION,
API_CONTROLLER_TOOL_NAME,
API_ORCHESTRATOR_PROMPT,
API_PLANNER_PROMPT,
API_PLANNER_TOOL_DESCRIPTION,
API_PLANNER_TOOL_NAME,
PARSING_DELETE_PROMPT,
PARSING_GET_PROMPT,
PARSING_PATCH_PROMPT,
PARSING_POST_PROMPT,
PARSING_PUT_PROMPT,
REQUESTS_DELETE_TOOL_DESCRIPTION,
REQUESTS_GET_TOOL_DESCRIPTION,
REQUESTS_PATCH_TOOL_DESCRIPTION,
REQUESTS_POST_TOOL_DESCRIPTION,
REQUESTS_PUT_TOOL_DESCRIPTION,
)
from langchain_community.agent_toolkits.openapi.spec import ReducedOpenAPISpec
from langchain_community.llms import OpenAI
from langchain_community.tools.requests.tool import BaseRequestsTool
from langchain_community.utilities.requests import RequestsWrapper
#
# Requests tools with LLM-instructed extraction of truncated responses.
#
# Of course, truncating so bluntly may lose a lot of valuable
# information in the response.
# However, the goal for now is to have only a single inference step.
#
# MAX_RESPONSE_LENGTH is the default value of ``response_length`` on the
# request tools in this module; each tool slices the response string to that
# many leading characters before passing it to its LLM chain.
MAX_RESPONSE_LENGTH = 5000
"""Maximum length of the response to be returned."""
def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any:
    """Build an ``LLMChain`` pairing ``prompt`` with a default ``OpenAI`` LLM.

    Args:
        prompt: Prompt template handed to the chain.

    Returns:
        A new ``LLMChain`` built from ``OpenAI()`` and ``prompt``.
    """
    # LLMChain is imported at call time rather than at module import time.
    from langchain.chains.llm import LLMChain

    default_llm = OpenAI()
    return LLMChain(llm=default_llm, prompt=prompt)
def _get_default_llm_chain_factory(
    prompt: BasePromptTemplate,
) -> Callable[[], Any]:
    """Return a zero-argument callable that builds the default chain for ``prompt``.

    The callable is a ``functools.partial`` binding ``prompt`` to
    ``_get_default_llm_chain``, so no chain is constructed until it is called.
    """
    chain_factory = partial(_get_default_llm_chain, prompt)
    return chain_factory
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
    """GET tool that truncates the response and lets an LLM chain extract from it."""

    name: str = "requests_get"
    """Name of the tool."""
    description = REQUESTS_GET_TOOL_DESCRIPTION
    """Description text of the tool."""
    response_length: int = MAX_RESPONSE_LENGTH
    """Number of leading characters of the response that are kept."""
    llm_chain: Any = Field(
        default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
    )
    """Chain whose ``predict`` extracts information from the response."""

    def _run(self, text: str) -> str:
        """Perform the GET request described by ``text`` and extract from the reply.

        Args:
            text: Input parsed with ``parse_json_markdown``. The parsed mapping
                must contain ``"url"`` and ``"output_instructions"`` and may
                contain ``"params"``.

        Returns:
            The stripped output of ``llm_chain.predict`` for the truncated
            response.

        Raises:
            json.JSONDecodeError: Propagated unchanged when parsing fails.
        """
        from langchain.output_parsers.json import parse_json_markdown

        try:
            payload = parse_json_markdown(text)
        except json.JSONDecodeError:
            # Nothing to recover here; let the caller see the parse failure.
            raise

        query_params = payload.get("params")
        raw_reply = self.requests_wrapper.get(payload["url"], params=query_params)
        # Keep only the leading ``response_length`` characters.
        truncated: str = cast(str, raw_reply)[: self.response_length]
        extracted = self.llm_chain.predict(
            response=truncated, instructions=payload["output_instructions"]
        )
        return extracted.strip()

    async def _arun(self, text: str) -> str:
        """Async execution is not implemented for this tool."""
        raise NotImplementedError()
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
    """POST tool that truncates the response and lets an LLM chain extract from it."""

    name: str = "requests_post"
    """Name of the tool."""
    description = REQUESTS_POST_TOOL_DESCRIPTION
    """Description text of the tool."""
    response_length: int = MAX_RESPONSE_LENGTH
    """Number of leading characters of the response that are kept."""
    llm_chain: Any = Field(
        default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
    )
    """Chain whose ``predict`` extracts information from the response."""

    def _run(self, text: str) -> str:
        """Perform the POST request described by ``text`` and extract from the reply.

        Args:
            text: Input parsed with ``parse_json_markdown``. The parsed mapping
                must contain ``"url"``, ``"data"`` and ``"output_instructions"``.

        Returns:
            The stripped output of ``llm_chain.predict`` for the truncated
            response.

        Raises:
            json.JSONDecodeError: Propagated unchanged when parsing fails.
        """
        from langchain.output_parsers.json import parse_json_markdown

        try:
            payload = parse_json_markdown(text)
        except json.JSONDecodeError:
            # Nothing to recover here; let the caller see the parse failure.
            raise

        raw_reply = self.requests_wrapper.post(payload["url"], payload["data"])
        # Keep only the leading ``response_length`` characters.
        truncated: str = cast(str, raw_reply)[: self.response_length]
        extracted = self.llm_chain.predict(
            response=truncated, instructions=payload["output_instructions"]
        )
        return extracted.strip()

    async def _arun(self, text: str) -> str:
        """Async execution is not implemented for this tool."""
        raise NotImplementedError()
class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
    """PATCH tool that truncates the response and lets an LLM chain extract from it."""

    name: str = "requests_patch"
    """Name of the tool."""
    description = REQUESTS_PATCH_TOOL_DESCRIPTION
    """Description text of the tool."""
    response_length: int = MAX_RESPONSE_LENGTH
    """Number of leading characters of the response that are kept."""
    llm_chain: Any = Field(
        default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
    )
    """Chain whose ``predict`` extracts information from the response."""

    def _run(self, text: str) -> str:
        """Perform the PATCH request described by ``text`` and extract from the reply.

        Args:
            text: Input parsed with ``parse_json_markdown``. The parsed mapping
                must contain ``"url"``, ``"data"`` and ``"output_instructions"``.

        Returns:
            The stripped output of ``llm_chain.predict`` for the truncated
            response.

        Raises:
            json.JSONDecodeError: Propagated unchanged when parsing fails.
        """
        from langchain.output_parsers.json import parse_json_markdown

        try:
            payload = parse_json_markdown(text)
        except json.JSONDecodeError:
            # Nothing to recover here; let the caller see the parse failure.
            raise

        raw_reply = self.requests_wrapper.patch(payload["url"], payload["data"])
        # Keep only the leading ``response_length`` characters.
        truncated: str = cast(str, raw_reply)[: self.response_length]
        extracted = self.llm_chain.predict(
            response=truncated, instructions=payload["output_instructions"]
        )
        return extracted.strip()

    async def _arun(self, text: str) -> str:
        """Async execution is not implemented for this tool."""
        raise NotImplementedError()
class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
    """PUT tool that truncates the response and lets an LLM chain extract from it."""

    name: str = "requests_put"
    """Name of the tool."""
    description = REQUESTS_PUT_TOOL_DESCRIPTION
    """Description text of the tool."""
    response_length: int = MAX_RESPONSE_LENGTH
    """Number of leading characters of the response that are kept."""
    llm_chain: Any = Field(
        default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
    )
    """Chain whose ``predict`` extracts information from the response."""

    def _run(self, text: str) -> str:
        """Perform the PUT request described by ``text`` and extract from the reply.

        Args:
            text: Input parsed with ``parse_json_markdown``. The parsed mapping
                must contain ``"url"``, ``"data"`` and ``"output_instructions"``.

        Returns:
            The stripped output of ``llm_chain.predict`` for the truncated
            response.

        Raises:
            json.JSONDecodeError: Propagated unchanged when parsing fails.
        """
        from langchain.output_parsers.json import parse_json_markdown

        try:
            payload = parse_json_markdown(text)
        except json.JSONDecodeError:
            # Nothing to recover here; let the caller see the parse failure.
            raise

        raw_reply = self.requests_wrapper.put(payload["url"], payload["data"])
        # Keep only the leading ``response_length`` characters.
        truncated: str = cast(str, raw_reply)[: self.response_length]
        extracted = self.llm_chain.predict(
            response=truncated, instructions=payload["output_instructions"]
        )
        return extracted.strip()

    async def _arun(self, text: str) -> str:
        """Async execution is not implemented for this tool."""
        raise NotImplementedError()
class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
    """DELETE tool that truncates the response and lets an LLM chain extract from it."""

    name: str = "requests_delete"
    """Name of the tool."""
    description = REQUESTS_DELETE_TOOL_DESCRIPTION
    """Description text of the tool."""
    response_length: Optional[int] = MAX_RESPONSE_LENGTH
    """Number of leading characters of the response that are kept."""
    llm_chain: Any = Field(
        default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT)
    )
    """Chain whose ``predict`` extracts information from the response."""

    def _run(self, text: str) -> str:
        """Perform the DELETE request described by ``text`` and extract from the reply.

        Args:
            text: Input parsed with ``parse_json_markdown``. The parsed mapping
                must contain ``"url"`` and ``"output_instructions"``.

        Returns:
            The stripped output of ``llm_chain.predict`` for the truncated
            response.

        Raises:
            json.JSONDecodeError: Propagated unchanged when parsing fails.
        """
        from langchain.output_parsers.json import parse_json_markdown

        try:
            payload = parse_json_markdown(text)
        except json.JSONDecodeError:
            # Nothing to recover here; let the caller see the parse failure.
            raise

        raw_reply = self.requests_wrapper.delete(payload["url"])
        # Slice bound is ``response_length``; a ``None`` bound keeps everything.
        truncated: str = cast(str, raw_reply)[: self.response_length]
        extracted = self.llm_chain.predict(
            response=truncated, instructions=payload["output_instructions"]
        )
        return extracted.strip()

    async def _arun(self, text: str) -> str:
        """Async execution is not implemented for this tool."""
        raise NotImplementedError()
#
# Orchestrator, planner, controller.
#
def _create_api_planner_tool(
    api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel
) -> Tool:
    """Wrap an LLM planning chain over the spec's endpoints as a ``Tool``.

    Args:
        api_spec: Spec whose ``endpoints`` are iterated as
            ``(name, description, docs)`` triples; only the name and
            description are used here.
        llm: Language model that runs the planner prompt.

    Returns:
        A ``Tool`` named ``API_PLANNER_TOOL_NAME`` whose function is the
        planner chain's ``run`` method.
    """
    from langchain.chains.llm import LLMChain

    endpoint_lines = []
    for name, description, _ in api_spec.endpoints:
        endpoint_lines.append(f"{name} {description}")

    # NOTE(review): entries are joined with "- " and no newline, so the list
    # only breaks across lines if the descriptions themselves end with one —
    # confirm this is intended.
    endpoints_text = "- " + "- ".join(endpoint_lines)

    planner_prompt = PromptTemplate(
        template=API_PLANNER_PROMPT,
        input_variables=["query"],
        partial_variables={"endpoints": endpoints_text},
    )
    planner_chain = LLMChain(llm=llm, prompt=planner_prompt)
    return Tool(
        name=API_PLANNER_TOOL_NAME,
        description=API_PLANNER_TOOL_DESCRIPTION,
        func=planner_chain.run,
    )
def _create_api_controller_agent(
    api_url: str,
    api_docs: str,
    requests_wrapper: RequestsWrapper,
    llm: BaseLanguageModel,
) -> Any:
    """Build an agent executor that can issue GET and POST requests.

    Args:
        api_url: Base URL substituted into the controller prompt.
        api_docs: Endpoint documentation substituted into the controller prompt.
        requests_wrapper: Wrapper handed to both request tools.
        llm: Language model used for the agent and for both parsing chains.

    Returns:
        The result of ``AgentExecutor.from_agent_and_tools`` with
        ``verbose=True``.
    """
    from langchain.agents.agent import AgentExecutor
    from langchain.agents.mrkl.base import ZeroShotAgent
    from langchain.chains.llm import LLMChain

    # One parsing chain per HTTP verb, both driven by the caller's LLM.
    get_parser = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
    post_parser = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)

    get_tool = RequestsGetToolWithParsing(
        requests_wrapper=requests_wrapper, llm_chain=get_parser
    )
    post_tool = RequestsPostToolWithParsing(
        requests_wrapper=requests_wrapper, llm_chain=post_parser
    )
    tools: List[BaseTool] = [get_tool, post_tool]

    tool_names = [tool.name for tool in tools]
    tool_summaries = [f"{tool.name}: {tool.description}" for tool in tools]

    controller_prompt = PromptTemplate(
        template=API_CONTROLLER_PROMPT,
        input_variables=["input", "agent_scratchpad"],
        partial_variables={
            "api_url": api_url,
            "api_docs": api_docs,
            "tool_names": ", ".join(tool_names),
            "tool_descriptions": "\n".join(tool_summaries),
        },
    )
    controller_chain = LLMChain(llm=llm, prompt=controller_prompt)
    agent = ZeroShotAgent(llm_chain=controller_chain, allowed_tools=tool_names)
    return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
def _create_api_controller_tool(
    api_spec: ReducedOpenAPISpec,
    requests_wrapper: RequestsWrapper,
    llm: BaseLanguageModel,
) -> Tool:
    """Expose controller as a tool.

    The tool is invoked with a plan from the planner, and dynamically
    creates a controller agent with relevant documentation only to
    constrain the context.

    Args:
        api_spec: Spec providing ``servers`` (the first entry's ``"url"`` is
            used as the base URL) and ``endpoints`` as
            ``(name, description, docs)`` triples.
        requests_wrapper: Wrapper forwarded to the controller agent.
        llm: Language model forwarded to the controller agent.

    Returns:
        A ``Tool`` named ``API_CONTROLLER_TOOL_NAME``. Its function raises
        ``ValueError`` when the plan mentions an endpoint that matches no
        endpoint in ``api_spec``.
    """
    base_url = api_spec.servers[0]["url"]  # TODO: do better.

    def _create_and_run_api_controller_agent(plan_str: str) -> str:
        """Collect docs for the endpoints named in ``plan_str`` and run the agent."""
        # "<METHOD> <route>" mentions in the plan. PUT is included so that PUT
        # steps are looked up like every other verb.
        pattern = r"\b(GET|POST|PATCH|PUT|DELETE)\s+(/\S+)*"
        matches = re.findall(pattern, plan_str)
        # Drop any query string so the route can be compared with spec names.
        endpoint_names = [
            "{method} {route}".format(method=method, route=route.split("?")[0])
            for method, route in matches
        ]

        # Compile each spec endpoint once: path parameters such as "{id}"
        # become ".*" so concrete routes in the plan can match the template.
        # A raw string is used to avoid invalid "\{" escape sequences.
        endpoint_regexes = [
            (re.compile(re.sub(r"\{.*?\}", ".*", name)), docs)
            for name, _, docs in api_spec.endpoints
        ]

        doc_sections: List[str] = []
        for endpoint_name in endpoint_names:
            matched_docs = [
                docs
                for regex_name, docs in endpoint_regexes
                if regex_name.match(endpoint_name)
            ]
            if not matched_docs:
                raise ValueError(f"{endpoint_name} endpoint does not exist.")
            doc_sections.extend(
                f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n"
                for docs in matched_docs
            )
        docs_str = "".join(doc_sections)

        agent = _create_api_controller_agent(base_url, docs_str, requests_wrapper, llm)
        return agent.run(plan_str)

    return Tool(
        name=API_CONTROLLER_TOOL_NAME,
        func=_create_and_run_api_controller_agent,
        description=API_CONTROLLER_TOOL_DESCRIPTION,
    )
def create_openapi_agent(
    api_spec: ReducedOpenAPISpec,
    requests_wrapper: RequestsWrapper,
    llm: BaseLanguageModel,
    shared_memory: Optional[Any] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    verbose: bool = True,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Any:
    """Build the orchestrator agent with an OpenAPI planner and controller.

    Credentials are injected through ``requests_wrapper``.

    A top-level "orchestrator" agent invokes the planner and the controller as
    tools, instead of a top-level planner that calls a controller with its
    plan. This keeps the planner simple.

    Args:
        api_spec: Spec passed to both the planner and the controller tool.
        requests_wrapper: Wrapper passed to the controller tool.
        llm: Language model used by the orchestrator and both tools.
        shared_memory: Passed as ``memory`` to the orchestrator's ``LLMChain``.
        callback_manager: Passed to the agent executor.
        verbose: Passed to the agent executor.
        agent_executor_kwargs: Extra keyword arguments for
            ``AgentExecutor.from_agent_and_tools``; ``None`` means none.
        **kwargs: Extra keyword arguments for ``ZeroShotAgent``.

    Returns:
        The result of ``AgentExecutor.from_agent_and_tools``.
    """
    from langchain.agents.agent import AgentExecutor
    from langchain.agents.mrkl.base import ZeroShotAgent
    from langchain.chains.llm import LLMChain

    planner_tool = _create_api_planner_tool(api_spec, llm)
    controller_tool = _create_api_controller_tool(api_spec, requests_wrapper, llm)
    tools = [planner_tool, controller_tool]

    tool_names = [tool.name for tool in tools]
    tool_summaries = [f"{tool.name}: {tool.description}" for tool in tools]

    orchestrator_prompt = PromptTemplate(
        template=API_ORCHESTRATOR_PROMPT,
        input_variables=["input", "agent_scratchpad"],
        partial_variables={
            "tool_names": ", ".join(tool_names),
            "tool_descriptions": "\n".join(tool_summaries),
        },
    )
    orchestrator_chain = LLMChain(
        llm=llm, prompt=orchestrator_prompt, memory=shared_memory
    )
    agent = ZeroShotAgent(
        llm_chain=orchestrator_chain,
        allowed_tools=tool_names,
        **kwargs,
    )

    executor_options = agent_executor_kwargs or {}
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        **executor_options,
    )