LM Requests Wrapper (#3457)

Co-authored-by: jnmarti <88381891+jnmarti@users.noreply.github.com>
This commit is contained in:
Zander Chase 2023-04-24 11:12:47 -07:00 committed by GitHub
parent b64c86a25f
commit d06d47bc92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,9 +1,11 @@
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach.""" """Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
import json import json
import re import re
from typing import List, Optional from functools import partial
from typing import Callable, List, Optional
import yaml import yaml
from pydantic import Field
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.openapi.planner_prompt import ( from langchain.agents.agent_toolkits.openapi.planner_prompt import (
@ -30,6 +32,7 @@ from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
from langchain.memory import ReadOnlySharedMemory from langchain.memory import ReadOnlySharedMemory
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.prompts.base import BasePromptTemplate
from langchain.requests import RequestsWrapper from langchain.requests import RequestsWrapper
from langchain.schema import BaseLanguageModel from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
@ -44,13 +47,26 @@ from langchain.tools.requests.tool import BaseRequestsTool
MAX_RESPONSE_LENGTH = 5000 MAX_RESPONSE_LENGTH = 5000
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
return LLMChain(
llm=OpenAI(),
prompt=prompt,
)
def _get_default_llm_chain_factory(
prompt: BasePromptTemplate,
) -> Callable[[], LLMChain]:
"""Returns a default LLMChain factory."""
return partial(_get_default_llm_chain, prompt)
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool): class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
name = "requests_get" name = "requests_get"
description = REQUESTS_GET_TOOL_DESCRIPTION description = REQUESTS_GET_TOOL_DESCRIPTION
response_length: Optional[int] = MAX_RESPONSE_LENGTH response_length: Optional[int] = MAX_RESPONSE_LENGTH
llm_chain = LLMChain( llm_chain: LLMChain = Field(
llm=OpenAI(), default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
prompt=PARSING_GET_PROMPT,
) )
def _run(self, text: str) -> str: def _run(self, text: str) -> str:
@ -74,9 +90,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
description = REQUESTS_POST_TOOL_DESCRIPTION description = REQUESTS_POST_TOOL_DESCRIPTION
response_length: Optional[int] = MAX_RESPONSE_LENGTH response_length: Optional[int] = MAX_RESPONSE_LENGTH
llm_chain = LLMChain( llm_chain: LLMChain = Field(
llm=OpenAI(), default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
prompt=PARSING_POST_PROMPT,
) )
def _run(self, text: str) -> str: def _run(self, text: str) -> str:
@ -173,9 +188,15 @@ def _create_api_controller_agent(
requests_wrapper: RequestsWrapper, requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel, llm: BaseLanguageModel,
) -> AgentExecutor: ) -> AgentExecutor:
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
tools: List[BaseTool] = [ tools: List[BaseTool] = [
RequestsGetToolWithParsing(requests_wrapper=requests_wrapper), RequestsGetToolWithParsing(
RequestsPostToolWithParsing(requests_wrapper=requests_wrapper), requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
),
RequestsPostToolWithParsing(
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
),
] ]
prompt = PromptTemplate( prompt = PromptTemplate(
template=API_CONTROLLER_PROMPT, template=API_CONTROLLER_PROMPT,