community[minor]: OpenAPI agent. Add support for PUT, DELETE and PATCH (#22962)

**Description**: Add PUT, DELETE and PATCH tools to tool list for
OpenAPI agent if dangerous requests are allowed.

**Issue**: https://github.com/langchain-ai/langchain/issues/20469
This commit is contained in:
Iurii Umnov 2024-06-21 23:44:23 +03:00 committed by GitHub
parent 3c42bf8d97
commit 3b7b933aa2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,7 @@
import json import json
import re import re
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Optional, cast from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, cast
import yaml import yaml
from langchain_core.callbacks import BaseCallbackManager from langchain_core.callbacks import BaseCallbackManager
@ -45,6 +45,8 @@ from langchain_community.utilities.requests import RequestsWrapper
MAX_RESPONSE_LENGTH = 5000 MAX_RESPONSE_LENGTH = 5000
"""Maximum length of the response to be returned.""" """Maximum length of the response to be returned."""
Operation = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any: def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any:
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -254,25 +256,56 @@ def _create_api_controller_agent(
requests_wrapper: RequestsWrapper, requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel, llm: BaseLanguageModel,
allow_dangerous_requests: bool, allow_dangerous_requests: bool,
allowed_operations: Sequence[Operation],
) -> Any: ) -> Any:
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT) tools: List[BaseTool] = []
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT) if "GET" in allowed_operations:
tools: List[BaseTool] = [ get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
RequestsGetToolWithParsing( # type: ignore[call-arg] tools.append(
RequestsGetToolWithParsing( # type: ignore[call-arg]
requests_wrapper=requests_wrapper,
llm_chain=get_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
)
)
if "POST" in allowed_operations:
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
tools.append(
RequestsPostToolWithParsing( # type: ignore[call-arg]
requests_wrapper=requests_wrapper,
llm_chain=post_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
)
)
if "PUT" in allowed_operations:
put_llm_chain = LLMChain(llm=llm, prompt=PARSING_PUT_PROMPT)
tools.append(
RequestsPutToolWithParsing( # type: ignore[call-arg]
requests_wrapper=requests_wrapper,
llm_chain=put_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
)
)
if "DELETE" in allowed_operations:
delete_llm_chain = LLMChain(llm=llm, prompt=PARSING_DELETE_PROMPT)
RequestsDeleteToolWithParsing( # type: ignore[call-arg]
requests_wrapper=requests_wrapper, requests_wrapper=requests_wrapper,
llm_chain=get_llm_chain, llm_chain=delete_llm_chain,
allow_dangerous_requests=allow_dangerous_requests, allow_dangerous_requests=allow_dangerous_requests,
), )
RequestsPostToolWithParsing( # type: ignore[call-arg] if "PATCH" in allowed_operations:
patch_llm_chain = LLMChain(llm=llm, prompt=PARSING_PATCH_PROMPT)
RequestsPatchToolWithParsing( # type: ignore[call-arg]
requests_wrapper=requests_wrapper, requests_wrapper=requests_wrapper,
llm_chain=post_llm_chain, llm_chain=patch_llm_chain,
allow_dangerous_requests=allow_dangerous_requests, allow_dangerous_requests=allow_dangerous_requests,
), )
] if not tools:
raise ValueError("Tools not found")
prompt = PromptTemplate( prompt = PromptTemplate(
template=API_CONTROLLER_PROMPT, template=API_CONTROLLER_PROMPT,
input_variables=["input", "agent_scratchpad"], input_variables=["input", "agent_scratchpad"],
@ -297,6 +330,7 @@ def _create_api_controller_tool(
requests_wrapper: RequestsWrapper, requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel, llm: BaseLanguageModel,
allow_dangerous_requests: bool, allow_dangerous_requests: bool,
allowed_operations: Sequence[Operation],
) -> Tool: ) -> Tool:
"""Expose controller as a tool. """Expose controller as a tool.
@ -308,7 +342,7 @@ def _create_api_controller_tool(
base_url = api_spec.servers[0]["url"] # TODO: do better. base_url = api_spec.servers[0]["url"] # TODO: do better.
def _create_and_run_api_controller_agent(plan_str: str) -> str: def _create_and_run_api_controller_agent(plan_str: str) -> str:
pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*" pattern = r"\b(GET|POST|PATCH|DELETE|PUT)\s+(/\S+)*"
matches = re.findall(pattern, plan_str) matches = re.findall(pattern, plan_str)
endpoint_names = [ endpoint_names = [
"{method} {route}".format(method=method, route=route.split("?")[0]) "{method} {route}".format(method=method, route=route.split("?")[0])
@ -326,7 +360,12 @@ def _create_api_controller_tool(
raise ValueError(f"{endpoint_name} endpoint does not exist.") raise ValueError(f"{endpoint_name} endpoint does not exist.")
agent = _create_api_controller_agent( agent = _create_api_controller_agent(
base_url, docs_str, requests_wrapper, llm, allow_dangerous_requests base_url,
docs_str,
requests_wrapper,
llm,
allow_dangerous_requests,
allowed_operations,
) )
return agent.run(plan_str) return agent.run(plan_str)
@ -346,6 +385,7 @@ def create_openapi_agent(
verbose: bool = True, verbose: bool = True,
agent_executor_kwargs: Optional[Dict[str, Any]] = None, agent_executor_kwargs: Optional[Dict[str, Any]] = None,
allow_dangerous_requests: bool = False, allow_dangerous_requests: bool = False,
allowed_operations: Sequence[Operation] = ("GET", "POST"),
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Construct an OpenAI API planner and controller for a given spec. """Construct an OpenAI API planner and controller for a given spec.
@ -371,7 +411,11 @@ def create_openapi_agent(
tools = [ tools = [
_create_api_planner_tool(api_spec, llm), _create_api_planner_tool(api_spec, llm),
_create_api_controller_tool( _create_api_controller_tool(
api_spec, requests_wrapper, llm, allow_dangerous_requests api_spec,
requests_wrapper,
llm,
allow_dangerous_requests,
allowed_operations,
), ),
] ]
prompt = PromptTemplate( prompt = PromptTemplate(