community[minor]: OpenAPI agent. Add support for PUT, DELETE and PATCH (#22962)

**Description**: Add PUT, DELETE and PATCH tools to tool list for
OpenAPI agent if dangerous requests are allowed.

**Issue**: https://github.com/langchain-ai/langchain/issues/20469
This commit is contained in:
Iurii Umnov 2024-06-21 23:44:23 +03:00 committed by GitHub
parent 3c42bf8d97
commit 3b7b933aa2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,7 @@
import json
import re
from functools import partial
from typing import Any, Callable, Dict, List, Optional, cast
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, cast
import yaml
from langchain_core.callbacks import BaseCallbackManager
@ -45,6 +45,8 @@ from langchain_community.utilities.requests import RequestsWrapper
MAX_RESPONSE_LENGTH = 5000
"""Maximum length of the response to be returned."""
Operation = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any:
from langchain.chains.llm import LLMChain
@ -254,25 +256,56 @@ def _create_api_controller_agent(
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
allow_dangerous_requests: bool,
allowed_operations: Sequence[Operation],
) -> Any:
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
tools: List[BaseTool] = []
if "GET" in allowed_operations:
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
tools: List[BaseTool] = [
tools.append(
RequestsGetToolWithParsing( # type: ignore[call-arg]
requests_wrapper=requests_wrapper,
llm_chain=get_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
),
)
)
if "POST" in allowed_operations:
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
tools.append(
RequestsPostToolWithParsing( # type: ignore[call-arg]
requests_wrapper=requests_wrapper,
llm_chain=post_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
),
]
)
)
if "PUT" in allowed_operations:
put_llm_chain = LLMChain(llm=llm, prompt=PARSING_PUT_PROMPT)
tools.append(
RequestsPutToolWithParsing( # type: ignore[call-arg]
requests_wrapper=requests_wrapper,
llm_chain=put_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
)
)
if "DELETE" in allowed_operations:
delete_llm_chain = LLMChain(llm=llm, prompt=PARSING_DELETE_PROMPT)
RequestsDeleteToolWithParsing( # type: ignore[call-arg]
requests_wrapper=requests_wrapper,
llm_chain=delete_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
)
if "PATCH" in allowed_operations:
patch_llm_chain = LLMChain(llm=llm, prompt=PARSING_PATCH_PROMPT)
RequestsPatchToolWithParsing( # type: ignore[call-arg]
requests_wrapper=requests_wrapper,
llm_chain=patch_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
)
if not tools:
raise ValueError("Tools not found")
prompt = PromptTemplate(
template=API_CONTROLLER_PROMPT,
input_variables=["input", "agent_scratchpad"],
@ -297,6 +330,7 @@ def _create_api_controller_tool(
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
allow_dangerous_requests: bool,
allowed_operations: Sequence[Operation],
) -> Tool:
"""Expose controller as a tool.
@ -308,7 +342,7 @@ def _create_api_controller_tool(
base_url = api_spec.servers[0]["url"] # TODO: do better.
def _create_and_run_api_controller_agent(plan_str: str) -> str:
pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*"
pattern = r"\b(GET|POST|PATCH|DELETE|PUT)\s+(/\S+)*"
matches = re.findall(pattern, plan_str)
endpoint_names = [
"{method} {route}".format(method=method, route=route.split("?")[0])
@ -326,7 +360,12 @@ def _create_api_controller_tool(
raise ValueError(f"{endpoint_name} endpoint does not exist.")
agent = _create_api_controller_agent(
base_url, docs_str, requests_wrapper, llm, allow_dangerous_requests
base_url,
docs_str,
requests_wrapper,
llm,
allow_dangerous_requests,
allowed_operations,
)
return agent.run(plan_str)
@ -346,6 +385,7 @@ def create_openapi_agent(
verbose: bool = True,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
allow_dangerous_requests: bool = False,
allowed_operations: Sequence[Operation] = ("GET", "POST"),
**kwargs: Any,
) -> Any:
"""Construct an OpenAI API planner and controller for a given spec.
@ -371,7 +411,11 @@ def create_openapi_agent(
tools = [
_create_api_planner_tool(api_spec, llm),
_create_api_controller_tool(
api_spec, requests_wrapper, llm, allow_dangerous_requests
api_spec,
requests_wrapper,
llm,
allow_dangerous_requests,
allowed_operations,
),
]
prompt = PromptTemplate(