Merge branch 'master' of github.com:hwchase17/langchain

fix_agent_callbacks
Harrison Chase 1 year ago
commit 434d8c4c0e

@ -1,9 +1,11 @@
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
import json
import re
from typing import List, Optional
from functools import partial
from typing import Callable, List, Optional
import yaml
from pydantic import Field
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.openapi.planner_prompt import (
@ -30,6 +32,7 @@ from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.memory import ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.base import BasePromptTemplate
from langchain.requests import RequestsWrapper
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool
@ -44,13 +47,26 @@ from langchain.tools.requests.tool import BaseRequestsTool
MAX_RESPONSE_LENGTH = 5000
def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
    """Construct an ``LLMChain`` that pairs *prompt* with a default ``OpenAI`` LLM.

    Args:
        prompt: The prompt template the chain will render on each call.

    Returns:
        An ``LLMChain`` backed by an ``OpenAI()`` model with default settings.
    """
    default_llm = OpenAI()
    return LLMChain(llm=default_llm, prompt=prompt)
def _get_default_llm_chain_factory(
    prompt: BasePromptTemplate,
) -> Callable[[], LLMChain]:
    """Return a zero-argument factory producing the default ``LLMChain`` for *prompt*.

    Useful as a pydantic ``Field(default_factory=...)`` so the chain (and its
    ``OpenAI`` client) is only instantiated when a default is actually needed.

    Args:
        prompt: Prompt template to bind into the factory.

    Returns:
        A callable taking no arguments that builds the default ``LLMChain``.
    """
    # Bind the prompt now; defer chain construction until the factory is called.
    factory = partial(_get_default_llm_chain, prompt)
    return factory
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
name = "requests_get"
description = REQUESTS_GET_TOOL_DESCRIPTION
response_length: Optional[int] = MAX_RESPONSE_LENGTH
llm_chain = LLMChain(
llm=OpenAI(),
prompt=PARSING_GET_PROMPT,
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
)
def _run(self, text: str) -> str:
@ -74,9 +90,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
description = REQUESTS_POST_TOOL_DESCRIPTION
response_length: Optional[int] = MAX_RESPONSE_LENGTH
llm_chain = LLMChain(
llm=OpenAI(),
prompt=PARSING_POST_PROMPT,
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
)
def _run(self, text: str) -> str:
@ -173,9 +188,15 @@ def _create_api_controller_agent(
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
) -> AgentExecutor:
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
tools: List[BaseTool] = [
RequestsGetToolWithParsing(requests_wrapper=requests_wrapper),
RequestsPostToolWithParsing(requests_wrapper=requests_wrapper),
RequestsGetToolWithParsing(
requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
),
RequestsPostToolWithParsing(
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
),
]
prompt = PromptTemplate(
template=API_CONTROLLER_PROMPT,

@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.0.147"
version = "0.0.148"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"

Loading…
Cancel
Save