mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
community[minor]: OpenAPI agent. Add support for PUT, DELETE and PATCH (#22962)
**Description**: Add PUT, DELETE and PATCH tools to tool list for OpenAPI agent if dangerous requests are allowed. **Issue**: https://github.com/langchain-ai/langchain/issues/20469
This commit is contained in:
parent
3c42bf8d97
commit
3b7b933aa2
@ -3,7 +3,7 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Dict, List, Optional, cast
|
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, cast
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from langchain_core.callbacks import BaseCallbackManager
|
from langchain_core.callbacks import BaseCallbackManager
|
||||||
@ -45,6 +45,8 @@ from langchain_community.utilities.requests import RequestsWrapper
|
|||||||
MAX_RESPONSE_LENGTH = 5000
|
MAX_RESPONSE_LENGTH = 5000
|
||||||
"""Maximum length of the response to be returned."""
|
"""Maximum length of the response to be returned."""
|
||||||
|
|
||||||
|
Operation = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
|
||||||
|
|
||||||
|
|
||||||
def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any:
|
def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any:
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
@ -254,25 +256,56 @@ def _create_api_controller_agent(
|
|||||||
requests_wrapper: RequestsWrapper,
|
requests_wrapper: RequestsWrapper,
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
allow_dangerous_requests: bool,
|
allow_dangerous_requests: bool,
|
||||||
|
allowed_operations: Sequence[Operation],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
from langchain.agents.agent import AgentExecutor
|
from langchain.agents.agent import AgentExecutor
|
||||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
|
|
||||||
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
|
tools: List[BaseTool] = []
|
||||||
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
if "GET" in allowed_operations:
|
||||||
tools: List[BaseTool] = [
|
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
|
||||||
RequestsGetToolWithParsing( # type: ignore[call-arg]
|
tools.append(
|
||||||
|
RequestsGetToolWithParsing( # type: ignore[call-arg]
|
||||||
|
requests_wrapper=requests_wrapper,
|
||||||
|
llm_chain=get_llm_chain,
|
||||||
|
allow_dangerous_requests=allow_dangerous_requests,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if "POST" in allowed_operations:
|
||||||
|
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
||||||
|
tools.append(
|
||||||
|
RequestsPostToolWithParsing( # type: ignore[call-arg]
|
||||||
|
requests_wrapper=requests_wrapper,
|
||||||
|
llm_chain=post_llm_chain,
|
||||||
|
allow_dangerous_requests=allow_dangerous_requests,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if "PUT" in allowed_operations:
|
||||||
|
put_llm_chain = LLMChain(llm=llm, prompt=PARSING_PUT_PROMPT)
|
||||||
|
tools.append(
|
||||||
|
RequestsPutToolWithParsing( # type: ignore[call-arg]
|
||||||
|
requests_wrapper=requests_wrapper,
|
||||||
|
llm_chain=put_llm_chain,
|
||||||
|
allow_dangerous_requests=allow_dangerous_requests,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if "DELETE" in allowed_operations:
|
||||||
|
delete_llm_chain = LLMChain(llm=llm, prompt=PARSING_DELETE_PROMPT)
|
||||||
|
RequestsDeleteToolWithParsing( # type: ignore[call-arg]
|
||||||
requests_wrapper=requests_wrapper,
|
requests_wrapper=requests_wrapper,
|
||||||
llm_chain=get_llm_chain,
|
llm_chain=delete_llm_chain,
|
||||||
allow_dangerous_requests=allow_dangerous_requests,
|
allow_dangerous_requests=allow_dangerous_requests,
|
||||||
),
|
)
|
||||||
RequestsPostToolWithParsing( # type: ignore[call-arg]
|
if "PATCH" in allowed_operations:
|
||||||
|
patch_llm_chain = LLMChain(llm=llm, prompt=PARSING_PATCH_PROMPT)
|
||||||
|
RequestsPatchToolWithParsing( # type: ignore[call-arg]
|
||||||
requests_wrapper=requests_wrapper,
|
requests_wrapper=requests_wrapper,
|
||||||
llm_chain=post_llm_chain,
|
llm_chain=patch_llm_chain,
|
||||||
allow_dangerous_requests=allow_dangerous_requests,
|
allow_dangerous_requests=allow_dangerous_requests,
|
||||||
),
|
)
|
||||||
]
|
if not tools:
|
||||||
|
raise ValueError("Tools not found")
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template=API_CONTROLLER_PROMPT,
|
template=API_CONTROLLER_PROMPT,
|
||||||
input_variables=["input", "agent_scratchpad"],
|
input_variables=["input", "agent_scratchpad"],
|
||||||
@ -297,6 +330,7 @@ def _create_api_controller_tool(
|
|||||||
requests_wrapper: RequestsWrapper,
|
requests_wrapper: RequestsWrapper,
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
allow_dangerous_requests: bool,
|
allow_dangerous_requests: bool,
|
||||||
|
allowed_operations: Sequence[Operation],
|
||||||
) -> Tool:
|
) -> Tool:
|
||||||
"""Expose controller as a tool.
|
"""Expose controller as a tool.
|
||||||
|
|
||||||
@ -308,7 +342,7 @@ def _create_api_controller_tool(
|
|||||||
base_url = api_spec.servers[0]["url"] # TODO: do better.
|
base_url = api_spec.servers[0]["url"] # TODO: do better.
|
||||||
|
|
||||||
def _create_and_run_api_controller_agent(plan_str: str) -> str:
|
def _create_and_run_api_controller_agent(plan_str: str) -> str:
|
||||||
pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*"
|
pattern = r"\b(GET|POST|PATCH|DELETE|PUT)\s+(/\S+)*"
|
||||||
matches = re.findall(pattern, plan_str)
|
matches = re.findall(pattern, plan_str)
|
||||||
endpoint_names = [
|
endpoint_names = [
|
||||||
"{method} {route}".format(method=method, route=route.split("?")[0])
|
"{method} {route}".format(method=method, route=route.split("?")[0])
|
||||||
@ -326,7 +360,12 @@ def _create_api_controller_tool(
|
|||||||
raise ValueError(f"{endpoint_name} endpoint does not exist.")
|
raise ValueError(f"{endpoint_name} endpoint does not exist.")
|
||||||
|
|
||||||
agent = _create_api_controller_agent(
|
agent = _create_api_controller_agent(
|
||||||
base_url, docs_str, requests_wrapper, llm, allow_dangerous_requests
|
base_url,
|
||||||
|
docs_str,
|
||||||
|
requests_wrapper,
|
||||||
|
llm,
|
||||||
|
allow_dangerous_requests,
|
||||||
|
allowed_operations,
|
||||||
)
|
)
|
||||||
return agent.run(plan_str)
|
return agent.run(plan_str)
|
||||||
|
|
||||||
@ -346,6 +385,7 @@ def create_openapi_agent(
|
|||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
allow_dangerous_requests: bool = False,
|
allow_dangerous_requests: bool = False,
|
||||||
|
allowed_operations: Sequence[Operation] = ("GET", "POST"),
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Construct an OpenAI API planner and controller for a given spec.
|
"""Construct an OpenAI API planner and controller for a given spec.
|
||||||
@ -371,7 +411,11 @@ def create_openapi_agent(
|
|||||||
tools = [
|
tools = [
|
||||||
_create_api_planner_tool(api_spec, llm),
|
_create_api_planner_tool(api_spec, llm),
|
||||||
_create_api_controller_tool(
|
_create_api_controller_tool(
|
||||||
api_spec, requests_wrapper, llm, allow_dangerous_requests
|
api_spec,
|
||||||
|
requests_wrapper,
|
||||||
|
llm,
|
||||||
|
allow_dangerous_requests,
|
||||||
|
allowed_operations,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
|
Loading…
Reference in New Issue
Block a user