forked from Archives/langchain
LM Requests Wrapper (#3457)
Co-authored-by: jnmarti <88381891+jnmarti@users.noreply.github.com>
This commit is contained in:
parent
b64c86a25f
commit
d06d47bc92
@ -1,9 +1,11 @@
|
||||
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional
|
||||
from functools import partial
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_toolkits.openapi.planner_prompt import (
|
||||
@ -30,6 +32,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.memory import ReadOnlySharedMemory
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.requests import RequestsWrapper
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
@ -44,13 +47,26 @@ from langchain.tools.requests.tool import BaseRequestsTool
|
||||
MAX_RESPONSE_LENGTH = 5000
|
||||
|
||||
|
||||
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
|
||||
return LLMChain(
|
||||
llm=OpenAI(),
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
def _get_default_llm_chain_factory(
|
||||
prompt: BasePromptTemplate,
|
||||
) -> Callable[[], LLMChain]:
|
||||
"""Returns a default LLMChain factory."""
|
||||
return partial(_get_default_llm_chain, prompt)
|
||||
|
||||
|
||||
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
name = "requests_get"
|
||||
description = REQUESTS_GET_TOOL_DESCRIPTION
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
llm_chain = LLMChain(
|
||||
llm=OpenAI(),
|
||||
prompt=PARSING_GET_PROMPT,
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
|
||||
)
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
@ -74,9 +90,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
description = REQUESTS_POST_TOOL_DESCRIPTION
|
||||
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
llm_chain = LLMChain(
|
||||
llm=OpenAI(),
|
||||
prompt=PARSING_POST_PROMPT,
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
|
||||
)
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
@ -173,9 +188,15 @@ def _create_api_controller_agent(
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
) -> AgentExecutor:
|
||||
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
|
||||
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
|
||||
tools: List[BaseTool] = [
|
||||
RequestsGetToolWithParsing(requests_wrapper=requests_wrapper),
|
||||
RequestsPostToolWithParsing(requests_wrapper=requests_wrapper),
|
||||
RequestsGetToolWithParsing(
|
||||
requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
|
||||
),
|
||||
RequestsPostToolWithParsing(
|
||||
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
|
||||
),
|
||||
]
|
||||
prompt = PromptTemplate(
|
||||
template=API_CONTROLLER_PROMPT,
|
||||
|
Loading…
Reference in New Issue
Block a user