diff --git a/libs/community/langchain_community/agent_toolkits/openapi/planner.py b/libs/community/langchain_community/agent_toolkits/openapi/planner.py index cd4a83ff92..87f2530348 100644 --- a/libs/community/langchain_community/agent_toolkits/openapi/planner.py +++ b/libs/community/langchain_community/agent_toolkits/openapi/planner.py @@ -3,7 +3,7 @@ import json import re from functools import partial -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, cast import yaml from langchain_core.callbacks import BaseCallbackManager @@ -45,6 +45,8 @@ from langchain_community.utilities.requests import RequestsWrapper MAX_RESPONSE_LENGTH = 5000 """Maximum length of the response to be returned.""" +Operation = Literal["GET", "POST", "PUT", "DELETE", "PATCH"] + def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any: from langchain.chains.llm import LLMChain @@ -254,25 +256,56 @@ def _create_api_controller_agent( requests_wrapper: RequestsWrapper, llm: BaseLanguageModel, allow_dangerous_requests: bool, + allowed_operations: Sequence[Operation], ) -> Any: from langchain.agents.agent import AgentExecutor from langchain.agents.mrkl.base import ZeroShotAgent from langchain.chains.llm import LLMChain - get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT) - post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT) - tools: List[BaseTool] = [ - RequestsGetToolWithParsing( # type: ignore[call-arg] + tools: List[BaseTool] = [] + if "GET" in allowed_operations: + get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT) + tools.append( + RequestsGetToolWithParsing( # type: ignore[call-arg] + requests_wrapper=requests_wrapper, + llm_chain=get_llm_chain, + allow_dangerous_requests=allow_dangerous_requests, + ) + ) + if "POST" in allowed_operations: + post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT) + tools.append( + RequestsPostToolWithParsing( # type: ignore[call-arg] + requests_wrapper=requests_wrapper, + llm_chain=post_llm_chain, + allow_dangerous_requests=allow_dangerous_requests, + ) + ) + if "PUT" in allowed_operations: + put_llm_chain = LLMChain(llm=llm, prompt=PARSING_PUT_PROMPT) + tools.append( + RequestsPutToolWithParsing( # type: ignore[call-arg] + requests_wrapper=requests_wrapper, + llm_chain=put_llm_chain, + allow_dangerous_requests=allow_dangerous_requests, + ) + ) + if "DELETE" in allowed_operations: + delete_llm_chain = LLMChain(llm=llm, prompt=PARSING_DELETE_PROMPT) + RequestsDeleteToolWithParsing( # type: ignore[call-arg] requests_wrapper=requests_wrapper, - llm_chain=get_llm_chain, + llm_chain=delete_llm_chain, allow_dangerous_requests=allow_dangerous_requests, - ), - RequestsPostToolWithParsing( # type: ignore[call-arg] + ) + if "PATCH" in allowed_operations: + patch_llm_chain = LLMChain(llm=llm, prompt=PARSING_PATCH_PROMPT) + RequestsPatchToolWithParsing( # type: ignore[call-arg] requests_wrapper=requests_wrapper, - llm_chain=post_llm_chain, + llm_chain=patch_llm_chain, allow_dangerous_requests=allow_dangerous_requests, - ), - ] + ) + if not tools: + raise ValueError("Tools not found") prompt = PromptTemplate( template=API_CONTROLLER_PROMPT, input_variables=["input", "agent_scratchpad"], @@ -297,6 +330,7 @@ def _create_api_controller_tool( requests_wrapper: RequestsWrapper, llm: BaseLanguageModel, allow_dangerous_requests: bool, + allowed_operations: Sequence[Operation], ) -> Tool: """Expose controller as a tool. @@ -308,7 +342,7 @@ def _create_api_controller_tool( base_url = api_spec.servers[0]["url"] # TODO: do better. def _create_and_run_api_controller_agent(plan_str: str) -> str: - pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*" + pattern = r"\b(GET|POST|PATCH|DELETE|PUT)\s+(/\S+)*" matches = re.findall(pattern, plan_str) endpoint_names = [ "{method} {route}".format(method=method, route=route.split("?")[0]) @@ -326,7 +360,12 @@ def _create_api_controller_tool( raise ValueError(f"{endpoint_name} endpoint does not exist.") agent = _create_api_controller_agent( - base_url, docs_str, requests_wrapper, llm, allow_dangerous_requests + base_url, + docs_str, + requests_wrapper, + llm, + allow_dangerous_requests, + allowed_operations, ) return agent.run(plan_str) @@ -346,6 +385,7 @@ def create_openapi_agent( verbose: bool = True, agent_executor_kwargs: Optional[Dict[str, Any]] = None, allow_dangerous_requests: bool = False, + allowed_operations: Sequence[Operation] = ("GET", "POST"), **kwargs: Any, ) -> Any: """Construct an OpenAI API planner and controller for a given spec. @@ -371,7 +411,11 @@ def create_openapi_agent( tools = [ _create_api_planner_tool(api_spec, llm), _create_api_controller_tool( - api_spec, requests_wrapper, llm, allow_dangerous_requests + api_spec, + requests_wrapper, + llm, + allow_dangerous_requests, + allowed_operations, ), ] prompt = PromptTemplate(