From d06d47bc9240a3c8b948961d6737139973b30bea Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Mon, 24 Apr 2023 11:12:47 -0700 Subject: [PATCH] LM Requests Wrapper (#3457) Co-authored-by: jnmarti <88381891+jnmarti@users.noreply.github.com> --- .../agents/agent_toolkits/openapi/planner.py | 39 ++++++++++++++----- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/langchain/agents/agent_toolkits/openapi/planner.py b/langchain/agents/agent_toolkits/openapi/planner.py index 8865bc42..7fada246 100644 --- a/langchain/agents/agent_toolkits/openapi/planner.py +++ b/langchain/agents/agent_toolkits/openapi/planner.py @@ -1,9 +1,11 @@ """Agent that interacts with OpenAPI APIs via a hierarchical planning approach.""" import json import re -from typing import List, Optional +from functools import partial +from typing import Callable, List, Optional import yaml +from pydantic import Field from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.openapi.planner_prompt import ( @@ -30,6 +32,7 @@ from langchain.chains.llm import LLMChain from langchain.llms.openai import OpenAI from langchain.memory import ReadOnlySharedMemory from langchain.prompts import PromptTemplate +from langchain.prompts.base import BasePromptTemplate from langchain.requests import RequestsWrapper from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool @@ -44,13 +47,26 @@ from langchain.tools.requests.tool import BaseRequestsTool MAX_RESPONSE_LENGTH = 5000 +def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain: + return LLMChain( + llm=OpenAI(), + prompt=prompt, + ) + + +def _get_default_llm_chain_factory( + prompt: BasePromptTemplate, +) -> Callable[[], LLMChain]: + """Returns a default LLMChain factory.""" + return partial(_get_default_llm_chain, prompt) + + class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool): name = "requests_get" description = REQUESTS_GET_TOOL_DESCRIPTION response_length: Optional[int] = MAX_RESPONSE_LENGTH - llm_chain = LLMChain( - llm=OpenAI(), - prompt=PARSING_GET_PROMPT, + llm_chain: LLMChain = Field( + default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT) ) def _run(self, text: str) -> str: @@ -74,9 +90,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool): description = REQUESTS_POST_TOOL_DESCRIPTION response_length: Optional[int] = MAX_RESPONSE_LENGTH - llm_chain = LLMChain( - llm=OpenAI(), - prompt=PARSING_POST_PROMPT, + llm_chain: LLMChain = Field( + default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT) ) def _run(self, text: str) -> str: @@ -173,9 +188,15 @@ def _create_api_controller_agent( requests_wrapper: RequestsWrapper, llm: BaseLanguageModel, ) -> AgentExecutor: + get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT) + post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT) tools: List[BaseTool] = [ - RequestsGetToolWithParsing(requests_wrapper=requests_wrapper), - RequestsPostToolWithParsing(requests_wrapper=requests_wrapper), + RequestsGetToolWithParsing( + requests_wrapper=requests_wrapper, llm_chain=get_llm_chain + ), + RequestsPostToolWithParsing( + requests_wrapper=requests_wrapper, llm_chain=post_llm_chain + ), ] prompt = PromptTemplate( template=API_CONTROLLER_PROMPT,