forked from Archives/langchain
14099f1b93
We shouldn't be calling a constructor for a default value - should use default_factory instead. This is especially ad in this case since it requires an optional dependency and an API key to be set. Resolves #5361
302 lines
10 KiB
Python
302 lines
10 KiB
Python
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
|
|
import json
|
|
import re
|
|
from functools import partial
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
import yaml
|
|
from pydantic import Field
|
|
|
|
from langchain.agents.agent import AgentExecutor
|
|
from langchain.agents.agent_toolkits.openapi.planner_prompt import (
|
|
API_CONTROLLER_PROMPT,
|
|
API_CONTROLLER_TOOL_DESCRIPTION,
|
|
API_CONTROLLER_TOOL_NAME,
|
|
API_ORCHESTRATOR_PROMPT,
|
|
API_PLANNER_PROMPT,
|
|
API_PLANNER_TOOL_DESCRIPTION,
|
|
API_PLANNER_TOOL_NAME,
|
|
PARSING_DELETE_PROMPT,
|
|
PARSING_GET_PROMPT,
|
|
PARSING_PATCH_PROMPT,
|
|
PARSING_POST_PROMPT,
|
|
REQUESTS_DELETE_TOOL_DESCRIPTION,
|
|
REQUESTS_GET_TOOL_DESCRIPTION,
|
|
REQUESTS_PATCH_TOOL_DESCRIPTION,
|
|
REQUESTS_POST_TOOL_DESCRIPTION,
|
|
)
|
|
from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec
|
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
|
from langchain.agents.tools import Tool
|
|
from langchain.base_language import BaseLanguageModel
|
|
from langchain.callbacks.base import BaseCallbackManager
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.llms.openai import OpenAI
|
|
from langchain.memory import ReadOnlySharedMemory
|
|
from langchain.prompts import PromptTemplate
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
from langchain.requests import RequestsWrapper
|
|
from langchain.tools.base import BaseTool
|
|
from langchain.tools.requests.tool import BaseRequestsTool
|
|
|
|
#
|
|
# Requests tools with LLM-instructed extraction of truncated responses.
|
|
#
|
|
# Of course, truncating so bluntly may lose a lot of valuable
|
|
# information in the response.
|
|
# However, the goal for now is to have only a single inference step.
|
|
MAX_RESPONSE_LENGTH = 5000
|
|
|
|
|
|
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
|
|
return LLMChain(
|
|
llm=OpenAI(),
|
|
prompt=prompt,
|
|
)
|
|
|
|
|
|
def _get_default_llm_chain_factory(
|
|
prompt: BasePromptTemplate,
|
|
) -> Callable[[], LLMChain]:
|
|
"""Returns a default LLMChain factory."""
|
|
return partial(_get_default_llm_chain, prompt)
|
|
|
|
|
|
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
|
name = "requests_get"
|
|
description = REQUESTS_GET_TOOL_DESCRIPTION
|
|
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
|
llm_chain: LLMChain = Field(
|
|
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
|
|
)
|
|
|
|
def _run(self, text: str) -> str:
|
|
try:
|
|
data = json.loads(text)
|
|
except json.JSONDecodeError as e:
|
|
raise e
|
|
data_params = data.get("params")
|
|
response = self.requests_wrapper.get(data["url"], params=data_params)
|
|
response = response[: self.response_length]
|
|
return self.llm_chain.predict(
|
|
response=response, instructions=data["output_instructions"]
|
|
).strip()
|
|
|
|
async def _arun(self, text: str) -> str:
|
|
raise NotImplementedError()
|
|
|
|
|
|
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
|
name = "requests_post"
|
|
description = REQUESTS_POST_TOOL_DESCRIPTION
|
|
|
|
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
|
llm_chain: LLMChain = Field(
|
|
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
|
|
)
|
|
|
|
def _run(self, text: str) -> str:
|
|
try:
|
|
data = json.loads(text)
|
|
except json.JSONDecodeError as e:
|
|
raise e
|
|
response = self.requests_wrapper.post(data["url"], data["data"])
|
|
response = response[: self.response_length]
|
|
return self.llm_chain.predict(
|
|
response=response, instructions=data["output_instructions"]
|
|
).strip()
|
|
|
|
async def _arun(self, text: str) -> str:
|
|
raise NotImplementedError()
|
|
|
|
|
|
class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
|
|
name = "requests_patch"
|
|
description = REQUESTS_PATCH_TOOL_DESCRIPTION
|
|
|
|
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
|
llm_chain: LLMChain = Field(
|
|
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
|
|
)
|
|
|
|
def _run(self, text: str) -> str:
|
|
try:
|
|
data = json.loads(text)
|
|
except json.JSONDecodeError as e:
|
|
raise e
|
|
response = self.requests_wrapper.patch(data["url"], data["data"])
|
|
response = response[: self.response_length]
|
|
return self.llm_chain.predict(
|
|
response=response, instructions=data["output_instructions"]
|
|
).strip()
|
|
|
|
async def _arun(self, text: str) -> str:
|
|
raise NotImplementedError()
|
|
|
|
|
|
class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
|
|
name = "requests_delete"
|
|
description = REQUESTS_DELETE_TOOL_DESCRIPTION
|
|
|
|
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
|
llm_chain: LLMChain = Field(
|
|
default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT)
|
|
)
|
|
|
|
def _run(self, text: str) -> str:
|
|
try:
|
|
data = json.loads(text)
|
|
except json.JSONDecodeError as e:
|
|
raise e
|
|
response = self.requests_wrapper.delete(data["url"])
|
|
response = response[: self.response_length]
|
|
return self.llm_chain.predict(
|
|
response=response, instructions=data["output_instructions"]
|
|
).strip()
|
|
|
|
async def _arun(self, text: str) -> str:
|
|
raise NotImplementedError()
|
|
|
|
|
|
#
|
|
# Orchestrator, planner, controller.
|
|
#
|
|
def _create_api_planner_tool(
|
|
api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel
|
|
) -> Tool:
|
|
endpoint_descriptions = [
|
|
f"{name} {description}" for name, description, _ in api_spec.endpoints
|
|
]
|
|
prompt = PromptTemplate(
|
|
template=API_PLANNER_PROMPT,
|
|
input_variables=["query"],
|
|
partial_variables={"endpoints": "- " + "- ".join(endpoint_descriptions)},
|
|
)
|
|
chain = LLMChain(llm=llm, prompt=prompt)
|
|
tool = Tool(
|
|
name=API_PLANNER_TOOL_NAME,
|
|
description=API_PLANNER_TOOL_DESCRIPTION,
|
|
func=chain.run,
|
|
)
|
|
return tool
|
|
|
|
|
|
def _create_api_controller_agent(
|
|
api_url: str,
|
|
api_docs: str,
|
|
requests_wrapper: RequestsWrapper,
|
|
llm: BaseLanguageModel,
|
|
) -> AgentExecutor:
|
|
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
|
|
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
|
tools: List[BaseTool] = [
|
|
RequestsGetToolWithParsing(
|
|
requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
|
|
),
|
|
RequestsPostToolWithParsing(
|
|
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
|
|
),
|
|
]
|
|
prompt = PromptTemplate(
|
|
template=API_CONTROLLER_PROMPT,
|
|
input_variables=["input", "agent_scratchpad"],
|
|
partial_variables={
|
|
"api_url": api_url,
|
|
"api_docs": api_docs,
|
|
"tool_names": ", ".join([tool.name for tool in tools]),
|
|
"tool_descriptions": "\n".join(
|
|
[f"{tool.name}: {tool.description}" for tool in tools]
|
|
),
|
|
},
|
|
)
|
|
agent = ZeroShotAgent(
|
|
llm_chain=LLMChain(llm=llm, prompt=prompt),
|
|
allowed_tools=[tool.name for tool in tools],
|
|
)
|
|
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
|
|
|
|
|
def _create_api_controller_tool(
|
|
api_spec: ReducedOpenAPISpec,
|
|
requests_wrapper: RequestsWrapper,
|
|
llm: BaseLanguageModel,
|
|
) -> Tool:
|
|
"""Expose controller as a tool.
|
|
|
|
The tool is invoked with a plan from the planner, and dynamically
|
|
creates a controller agent with relevant documentation only to
|
|
constrain the context.
|
|
"""
|
|
|
|
base_url = api_spec.servers[0]["url"] # TODO: do better.
|
|
|
|
def _create_and_run_api_controller_agent(plan_str: str) -> str:
|
|
pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*"
|
|
matches = re.findall(pattern, plan_str)
|
|
endpoint_names = [
|
|
"{method} {route}".format(method=method, route=route.split("?")[0])
|
|
for method, route in matches
|
|
]
|
|
endpoint_docs_by_name = {name: docs for name, _, docs in api_spec.endpoints}
|
|
docs_str = ""
|
|
for endpoint_name in endpoint_names:
|
|
docs = endpoint_docs_by_name.get(endpoint_name)
|
|
if not docs:
|
|
raise ValueError(f"{endpoint_name} endpoint does not exist.")
|
|
docs_str += f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n"
|
|
|
|
agent = _create_api_controller_agent(base_url, docs_str, requests_wrapper, llm)
|
|
return agent.run(plan_str)
|
|
|
|
return Tool(
|
|
name=API_CONTROLLER_TOOL_NAME,
|
|
func=_create_and_run_api_controller_agent,
|
|
description=API_CONTROLLER_TOOL_DESCRIPTION,
|
|
)
|
|
|
|
|
|
def create_openapi_agent(
|
|
api_spec: ReducedOpenAPISpec,
|
|
requests_wrapper: RequestsWrapper,
|
|
llm: BaseLanguageModel,
|
|
shared_memory: Optional[ReadOnlySharedMemory] = None,
|
|
callback_manager: Optional[BaseCallbackManager] = None,
|
|
verbose: bool = True,
|
|
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
|
**kwargs: Dict[str, Any],
|
|
) -> AgentExecutor:
|
|
"""Instantiate API planner and controller for a given spec.
|
|
|
|
Inject credentials via requests_wrapper.
|
|
|
|
We use a top-level "orchestrator" agent to invoke the planner and controller,
|
|
rather than a top-level planner
|
|
that invokes a controller with its plan. This is to keep the planner simple.
|
|
"""
|
|
tools = [
|
|
_create_api_planner_tool(api_spec, llm),
|
|
_create_api_controller_tool(api_spec, requests_wrapper, llm),
|
|
]
|
|
prompt = PromptTemplate(
|
|
template=API_ORCHESTRATOR_PROMPT,
|
|
input_variables=["input", "agent_scratchpad"],
|
|
partial_variables={
|
|
"tool_names": ", ".join([tool.name for tool in tools]),
|
|
"tool_descriptions": "\n".join(
|
|
[f"{tool.name}: {tool.description}" for tool in tools]
|
|
),
|
|
},
|
|
)
|
|
agent = ZeroShotAgent(
|
|
llm_chain=LLMChain(llm=llm, prompt=prompt, memory=shared_memory),
|
|
allowed_tools=[tool.name for tool in tools],
|
|
**kwargs,
|
|
)
|
|
return AgentExecutor.from_agent_and_tools(
|
|
agent=agent,
|
|
tools=tools,
|
|
callback_manager=callback_manager,
|
|
verbose=verbose,
|
|
**(agent_executor_kwargs or {}),
|
|
)
|