|
|
|
@ -1,9 +1,11 @@
|
|
|
|
|
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
|
|
|
|
|
import json
|
|
|
|
|
import re
|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Callable, List, Optional
|
|
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
|
from pydantic import Field
|
|
|
|
|
|
|
|
|
|
from langchain.agents.agent import AgentExecutor
|
|
|
|
|
from langchain.agents.agent_toolkits.openapi.planner_prompt import (
|
|
|
|
@ -30,6 +32,7 @@ from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.llms.openai import OpenAI
|
|
|
|
|
from langchain.memory import ReadOnlySharedMemory
|
|
|
|
|
from langchain.prompts import PromptTemplate
|
|
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
|
|
from langchain.requests import RequestsWrapper
|
|
|
|
|
from langchain.schema import BaseLanguageModel
|
|
|
|
|
from langchain.tools.base import BaseTool
|
|
|
|
@ -44,13 +47,26 @@ from langchain.tools.requests.tool import BaseRequestsTool
|
|
|
|
|
MAX_RESPONSE_LENGTH = 5000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
|
|
|
|
|
return LLMChain(
|
|
|
|
|
llm=OpenAI(),
|
|
|
|
|
prompt=prompt,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_default_llm_chain_factory(
|
|
|
|
|
prompt: BasePromptTemplate,
|
|
|
|
|
) -> Callable[[], LLMChain]:
|
|
|
|
|
"""Returns a default LLMChain factory."""
|
|
|
|
|
return partial(_get_default_llm_chain, prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
|
|
|
|
name = "requests_get"
|
|
|
|
|
description = REQUESTS_GET_TOOL_DESCRIPTION
|
|
|
|
|
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
|
|
|
|
llm_chain = LLMChain(
|
|
|
|
|
llm=OpenAI(),
|
|
|
|
|
prompt=PARSING_GET_PROMPT,
|
|
|
|
|
llm_chain: LLMChain = Field(
|
|
|
|
|
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _run(self, text: str) -> str:
|
|
|
|
@ -74,9 +90,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
|
|
|
|
description = REQUESTS_POST_TOOL_DESCRIPTION
|
|
|
|
|
|
|
|
|
|
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
|
|
|
|
llm_chain = LLMChain(
|
|
|
|
|
llm=OpenAI(),
|
|
|
|
|
prompt=PARSING_POST_PROMPT,
|
|
|
|
|
llm_chain: LLMChain = Field(
|
|
|
|
|
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _run(self, text: str) -> str:
|
|
|
|
@ -173,9 +188,15 @@ def _create_api_controller_agent(
|
|
|
|
|
requests_wrapper: RequestsWrapper,
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
) -> AgentExecutor:
|
|
|
|
|
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
|
|
|
|
|
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
|
|
|
|
tools: List[BaseTool] = [
|
|
|
|
|
RequestsGetToolWithParsing(requests_wrapper=requests_wrapper),
|
|
|
|
|
RequestsPostToolWithParsing(requests_wrapper=requests_wrapper),
|
|
|
|
|
RequestsGetToolWithParsing(
|
|
|
|
|
requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
|
|
|
|
|
),
|
|
|
|
|
RequestsPostToolWithParsing(
|
|
|
|
|
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
|
|
|
|
|
),
|
|
|
|
|
]
|
|
|
|
|
prompt = PromptTemplate(
|
|
|
|
|
template=API_CONTROLLER_PROMPT,
|
|
|
|
|