mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
community[minor]: OpenAPI agent. Add support for PUT, DELETE and PATCH (#22962)
**Description**: Add PUT, DELETE and PATCH tools to tool list for OpenAPI agent if dangerous requests are allowed. **Issue**: https://github.com/langchain-ai/langchain/issues/20469
This commit is contained in:
parent
3c42bf8d97
commit
3b7b933aa2
@ -3,7 +3,7 @@
|
||||
import json
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, cast
|
||||
|
||||
import yaml
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
@ -45,6 +45,8 @@ from langchain_community.utilities.requests import RequestsWrapper
|
||||
MAX_RESPONSE_LENGTH = 5000
|
||||
"""Maximum length of the response to be returned."""
|
||||
|
||||
Operation = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
|
||||
|
||||
|
||||
def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any:
|
||||
from langchain.chains.llm import LLMChain
|
||||
@ -254,25 +256,56 @@ def _create_api_controller_agent(
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
allow_dangerous_requests: bool,
|
||||
allowed_operations: Sequence[Operation],
|
||||
) -> Any:
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
tools: List[BaseTool] = []
|
||||
if "GET" in allowed_operations:
|
||||
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
|
||||
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
||||
tools: List[BaseTool] = [
|
||||
tools.append(
|
||||
RequestsGetToolWithParsing( # type: ignore[call-arg]
|
||||
requests_wrapper=requests_wrapper,
|
||||
llm_chain=get_llm_chain,
|
||||
allow_dangerous_requests=allow_dangerous_requests,
|
||||
),
|
||||
)
|
||||
)
|
||||
if "POST" in allowed_operations:
|
||||
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
||||
tools.append(
|
||||
RequestsPostToolWithParsing( # type: ignore[call-arg]
|
||||
requests_wrapper=requests_wrapper,
|
||||
llm_chain=post_llm_chain,
|
||||
allow_dangerous_requests=allow_dangerous_requests,
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
if "PUT" in allowed_operations:
|
||||
put_llm_chain = LLMChain(llm=llm, prompt=PARSING_PUT_PROMPT)
|
||||
tools.append(
|
||||
RequestsPutToolWithParsing( # type: ignore[call-arg]
|
||||
requests_wrapper=requests_wrapper,
|
||||
llm_chain=put_llm_chain,
|
||||
allow_dangerous_requests=allow_dangerous_requests,
|
||||
)
|
||||
)
|
||||
if "DELETE" in allowed_operations:
|
||||
delete_llm_chain = LLMChain(llm=llm, prompt=PARSING_DELETE_PROMPT)
|
||||
RequestsDeleteToolWithParsing( # type: ignore[call-arg]
|
||||
requests_wrapper=requests_wrapper,
|
||||
llm_chain=delete_llm_chain,
|
||||
allow_dangerous_requests=allow_dangerous_requests,
|
||||
)
|
||||
if "PATCH" in allowed_operations:
|
||||
patch_llm_chain = LLMChain(llm=llm, prompt=PARSING_PATCH_PROMPT)
|
||||
RequestsPatchToolWithParsing( # type: ignore[call-arg]
|
||||
requests_wrapper=requests_wrapper,
|
||||
llm_chain=patch_llm_chain,
|
||||
allow_dangerous_requests=allow_dangerous_requests,
|
||||
)
|
||||
if not tools:
|
||||
raise ValueError("Tools not found")
|
||||
prompt = PromptTemplate(
|
||||
template=API_CONTROLLER_PROMPT,
|
||||
input_variables=["input", "agent_scratchpad"],
|
||||
@ -297,6 +330,7 @@ def _create_api_controller_tool(
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
allow_dangerous_requests: bool,
|
||||
allowed_operations: Sequence[Operation],
|
||||
) -> Tool:
|
||||
"""Expose controller as a tool.
|
||||
|
||||
@ -308,7 +342,7 @@ def _create_api_controller_tool(
|
||||
base_url = api_spec.servers[0]["url"] # TODO: do better.
|
||||
|
||||
def _create_and_run_api_controller_agent(plan_str: str) -> str:
|
||||
pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*"
|
||||
pattern = r"\b(GET|POST|PATCH|DELETE|PUT)\s+(/\S+)*"
|
||||
matches = re.findall(pattern, plan_str)
|
||||
endpoint_names = [
|
||||
"{method} {route}".format(method=method, route=route.split("?")[0])
|
||||
@ -326,7 +360,12 @@ def _create_api_controller_tool(
|
||||
raise ValueError(f"{endpoint_name} endpoint does not exist.")
|
||||
|
||||
agent = _create_api_controller_agent(
|
||||
base_url, docs_str, requests_wrapper, llm, allow_dangerous_requests
|
||||
base_url,
|
||||
docs_str,
|
||||
requests_wrapper,
|
||||
llm,
|
||||
allow_dangerous_requests,
|
||||
allowed_operations,
|
||||
)
|
||||
return agent.run(plan_str)
|
||||
|
||||
@ -346,6 +385,7 @@ def create_openapi_agent(
|
||||
verbose: bool = True,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
allow_dangerous_requests: bool = False,
|
||||
allowed_operations: Sequence[Operation] = ("GET", "POST"),
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Construct an OpenAI API planner and controller for a given spec.
|
||||
@ -371,7 +411,11 @@ def create_openapi_agent(
|
||||
tools = [
|
||||
_create_api_planner_tool(api_spec, llm),
|
||||
_create_api_controller_tool(
|
||||
api_spec, requests_wrapper, llm, allow_dangerous_requests
|
||||
api_spec,
|
||||
requests_wrapper,
|
||||
llm,
|
||||
allow_dangerous_requests,
|
||||
allowed_operations,
|
||||
),
|
||||
]
|
||||
prompt = PromptTemplate(
|
||||
|
Loading…
Reference in New Issue
Block a user