forked from Archives/langchain
LM Requests Wrapper (#3457)
Co-authored-by: jnmarti <88381891+jnmarti@users.noreply.github.com>
This commit is contained in:
parent
b64c86a25f
commit
d06d47bc92
@ -1,9 +1,11 @@
|
|||||||
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
|
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional
|
from functools import partial
|
||||||
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
from langchain.agents.agent import AgentExecutor
|
from langchain.agents.agent import AgentExecutor
|
||||||
from langchain.agents.agent_toolkits.openapi.planner_prompt import (
|
from langchain.agents.agent_toolkits.openapi.planner_prompt import (
|
||||||
@ -30,6 +32,7 @@ from langchain.chains.llm import LLMChain
|
|||||||
from langchain.llms.openai import OpenAI
|
from langchain.llms.openai import OpenAI
|
||||||
from langchain.memory import ReadOnlySharedMemory
|
from langchain.memory import ReadOnlySharedMemory
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.requests import RequestsWrapper
|
from langchain.requests import RequestsWrapper
|
||||||
from langchain.schema import BaseLanguageModel
|
from langchain.schema import BaseLanguageModel
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
@ -44,13 +47,26 @@ from langchain.tools.requests.tool import BaseRequestsTool
|
|||||||
MAX_RESPONSE_LENGTH = 5000
|
MAX_RESPONSE_LENGTH = 5000
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
|
||||||
|
return LLMChain(
|
||||||
|
llm=OpenAI(),
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_llm_chain_factory(
|
||||||
|
prompt: BasePromptTemplate,
|
||||||
|
) -> Callable[[], LLMChain]:
|
||||||
|
"""Returns a default LLMChain factory."""
|
||||||
|
return partial(_get_default_llm_chain, prompt)
|
||||||
|
|
||||||
|
|
||||||
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
||||||
name = "requests_get"
|
name = "requests_get"
|
||||||
description = REQUESTS_GET_TOOL_DESCRIPTION
|
description = REQUESTS_GET_TOOL_DESCRIPTION
|
||||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||||
llm_chain = LLMChain(
|
llm_chain: LLMChain = Field(
|
||||||
llm=OpenAI(),
|
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
|
||||||
prompt=PARSING_GET_PROMPT,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run(self, text: str) -> str:
|
def _run(self, text: str) -> str:
|
||||||
@ -74,9 +90,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
|||||||
description = REQUESTS_POST_TOOL_DESCRIPTION
|
description = REQUESTS_POST_TOOL_DESCRIPTION
|
||||||
|
|
||||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||||
llm_chain = LLMChain(
|
llm_chain: LLMChain = Field(
|
||||||
llm=OpenAI(),
|
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
|
||||||
prompt=PARSING_POST_PROMPT,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run(self, text: str) -> str:
|
def _run(self, text: str) -> str:
|
||||||
@ -173,9 +188,15 @@ def _create_api_controller_agent(
|
|||||||
requests_wrapper: RequestsWrapper,
|
requests_wrapper: RequestsWrapper,
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
) -> AgentExecutor:
|
) -> AgentExecutor:
|
||||||
|
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
|
||||||
|
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
||||||
tools: List[BaseTool] = [
|
tools: List[BaseTool] = [
|
||||||
RequestsGetToolWithParsing(requests_wrapper=requests_wrapper),
|
RequestsGetToolWithParsing(
|
||||||
RequestsPostToolWithParsing(requests_wrapper=requests_wrapper),
|
requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
|
||||||
|
),
|
||||||
|
RequestsPostToolWithParsing(
|
||||||
|
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
|
||||||
|
),
|
||||||
]
|
]
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template=API_CONTROLLER_PROMPT,
|
template=API_CONTROLLER_PROMPT,
|
||||||
|
Loading…
Reference in New Issue
Block a user