|
|
|
@ -3,7 +3,7 @@
|
|
|
|
|
import json
|
|
|
|
|
import re
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, cast
|
|
|
|
|
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, cast
|
|
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
|
from langchain_core.callbacks import BaseCallbackManager
|
|
|
|
@ -45,6 +45,8 @@ from langchain_community.utilities.requests import RequestsWrapper
|
|
|
|
|
MAX_RESPONSE_LENGTH = 5000
|
|
|
|
|
"""Maximum length of the response to be returned."""
|
|
|
|
|
|
|
|
|
|
Operation = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any:
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
@ -254,25 +256,56 @@ def _create_api_controller_agent(
|
|
|
|
|
requests_wrapper: RequestsWrapper,
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
allow_dangerous_requests: bool,
|
|
|
|
|
allowed_operations: Sequence[Operation],
|
|
|
|
|
) -> Any:
|
|
|
|
|
from langchain.agents.agent import AgentExecutor
|
|
|
|
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
|
|
|
|
|
|
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
|
|
|
|
|
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
|
|
|
|
tools: List[BaseTool] = [
|
|
|
|
|
RequestsGetToolWithParsing( # type: ignore[call-arg]
|
|
|
|
|
tools: List[BaseTool] = []
|
|
|
|
|
if "GET" in allowed_operations:
|
|
|
|
|
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
|
|
|
|
|
tools.append(
|
|
|
|
|
RequestsGetToolWithParsing( # type: ignore[call-arg]
|
|
|
|
|
requests_wrapper=requests_wrapper,
|
|
|
|
|
llm_chain=get_llm_chain,
|
|
|
|
|
allow_dangerous_requests=allow_dangerous_requests,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
if "POST" in allowed_operations:
|
|
|
|
|
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
|
|
|
|
tools.append(
|
|
|
|
|
RequestsPostToolWithParsing( # type: ignore[call-arg]
|
|
|
|
|
requests_wrapper=requests_wrapper,
|
|
|
|
|
llm_chain=post_llm_chain,
|
|
|
|
|
allow_dangerous_requests=allow_dangerous_requests,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
if "PUT" in allowed_operations:
|
|
|
|
|
put_llm_chain = LLMChain(llm=llm, prompt=PARSING_PUT_PROMPT)
|
|
|
|
|
tools.append(
|
|
|
|
|
RequestsPutToolWithParsing( # type: ignore[call-arg]
|
|
|
|
|
requests_wrapper=requests_wrapper,
|
|
|
|
|
llm_chain=put_llm_chain,
|
|
|
|
|
allow_dangerous_requests=allow_dangerous_requests,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
if "DELETE" in allowed_operations:
|
|
|
|
|
delete_llm_chain = LLMChain(llm=llm, prompt=PARSING_DELETE_PROMPT)
|
|
|
|
|
RequestsDeleteToolWithParsing( # type: ignore[call-arg]
|
|
|
|
|
requests_wrapper=requests_wrapper,
|
|
|
|
|
llm_chain=get_llm_chain,
|
|
|
|
|
llm_chain=delete_llm_chain,
|
|
|
|
|
allow_dangerous_requests=allow_dangerous_requests,
|
|
|
|
|
),
|
|
|
|
|
RequestsPostToolWithParsing( # type: ignore[call-arg]
|
|
|
|
|
)
|
|
|
|
|
if "PATCH" in allowed_operations:
|
|
|
|
|
patch_llm_chain = LLMChain(llm=llm, prompt=PARSING_PATCH_PROMPT)
|
|
|
|
|
RequestsPatchToolWithParsing( # type: ignore[call-arg]
|
|
|
|
|
requests_wrapper=requests_wrapper,
|
|
|
|
|
llm_chain=post_llm_chain,
|
|
|
|
|
llm_chain=patch_llm_chain,
|
|
|
|
|
allow_dangerous_requests=allow_dangerous_requests,
|
|
|
|
|
),
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
if not tools:
|
|
|
|
|
raise ValueError("Tools not found")
|
|
|
|
|
prompt = PromptTemplate(
|
|
|
|
|
template=API_CONTROLLER_PROMPT,
|
|
|
|
|
input_variables=["input", "agent_scratchpad"],
|
|
|
|
@ -297,6 +330,7 @@ def _create_api_controller_tool(
|
|
|
|
|
requests_wrapper: RequestsWrapper,
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
allow_dangerous_requests: bool,
|
|
|
|
|
allowed_operations: Sequence[Operation],
|
|
|
|
|
) -> Tool:
|
|
|
|
|
"""Expose controller as a tool.
|
|
|
|
|
|
|
|
|
@ -308,7 +342,7 @@ def _create_api_controller_tool(
|
|
|
|
|
base_url = api_spec.servers[0]["url"] # TODO: do better.
|
|
|
|
|
|
|
|
|
|
def _create_and_run_api_controller_agent(plan_str: str) -> str:
|
|
|
|
|
pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*"
|
|
|
|
|
pattern = r"\b(GET|POST|PATCH|DELETE|PUT)\s+(/\S+)*"
|
|
|
|
|
matches = re.findall(pattern, plan_str)
|
|
|
|
|
endpoint_names = [
|
|
|
|
|
"{method} {route}".format(method=method, route=route.split("?")[0])
|
|
|
|
@ -326,7 +360,12 @@ def _create_api_controller_tool(
|
|
|
|
|
raise ValueError(f"{endpoint_name} endpoint does not exist.")
|
|
|
|
|
|
|
|
|
|
agent = _create_api_controller_agent(
|
|
|
|
|
base_url, docs_str, requests_wrapper, llm, allow_dangerous_requests
|
|
|
|
|
base_url,
|
|
|
|
|
docs_str,
|
|
|
|
|
requests_wrapper,
|
|
|
|
|
llm,
|
|
|
|
|
allow_dangerous_requests,
|
|
|
|
|
allowed_operations,
|
|
|
|
|
)
|
|
|
|
|
return agent.run(plan_str)
|
|
|
|
|
|
|
|
|
@ -346,6 +385,7 @@ def create_openapi_agent(
|
|
|
|
|
verbose: bool = True,
|
|
|
|
|
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
|
|
|
|
allow_dangerous_requests: bool = False,
|
|
|
|
|
allowed_operations: Sequence[Operation] = ("GET", "POST"),
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> Any:
|
|
|
|
|
"""Construct an OpenAI API planner and controller for a given spec.
|
|
|
|
@ -371,7 +411,11 @@ def create_openapi_agent(
|
|
|
|
|
tools = [
|
|
|
|
|
_create_api_planner_tool(api_spec, llm),
|
|
|
|
|
_create_api_controller_tool(
|
|
|
|
|
api_spec, requests_wrapper, llm, allow_dangerous_requests
|
|
|
|
|
api_spec,
|
|
|
|
|
requests_wrapper,
|
|
|
|
|
llm,
|
|
|
|
|
allow_dangerous_requests,
|
|
|
|
|
allowed_operations,
|
|
|
|
|
),
|
|
|
|
|
]
|
|
|
|
|
prompt = PromptTemplate(
|
|
|
|
|