LM Requests Wrapper (#3457)

Co-authored-by: jnmarti <88381891+jnmarti@users.noreply.github.com>
fix_agent_callbacks
Zander Chase 1 year ago committed by GitHub
parent b64c86a25f
commit d06d47bc92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,9 +1,11 @@
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
import json
import re
from typing import List, Optional
from functools import partial
from typing import Callable, List, Optional
import yaml
from pydantic import Field
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.openapi.planner_prompt import (
@ -30,6 +32,7 @@ from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.memory import ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.base import BasePromptTemplate
from langchain.requests import RequestsWrapper
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool
@ -44,13 +47,26 @@ from langchain.tools.requests.tool import BaseRequestsTool
MAX_RESPONSE_LENGTH = 5000
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
return LLMChain(
llm=OpenAI(),
prompt=prompt,
)
def _get_default_llm_chain_factory(
prompt: BasePromptTemplate,
) -> Callable[[], LLMChain]:
"""Returns a default LLMChain factory."""
return partial(_get_default_llm_chain, prompt)
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
name = "requests_get"
description = REQUESTS_GET_TOOL_DESCRIPTION
response_length: Optional[int] = MAX_RESPONSE_LENGTH
llm_chain = LLMChain(
llm=OpenAI(),
prompt=PARSING_GET_PROMPT,
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
)
def _run(self, text: str) -> str:
@ -74,9 +90,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
description = REQUESTS_POST_TOOL_DESCRIPTION
response_length: Optional[int] = MAX_RESPONSE_LENGTH
llm_chain = LLMChain(
llm=OpenAI(),
prompt=PARSING_POST_PROMPT,
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
)
def _run(self, text: str) -> str:
@ -173,9 +188,15 @@ def _create_api_controller_agent(
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
) -> AgentExecutor:
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
tools: List[BaseTool] = [
RequestsGetToolWithParsing(requests_wrapper=requests_wrapper),
RequestsPostToolWithParsing(requests_wrapper=requests_wrapper),
RequestsGetToolWithParsing(
requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
),
RequestsPostToolWithParsing(
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
),
]
prompt = PromptTemplate(
template=API_CONTROLLER_PROMPT,

Loading…
Cancel
Save