mirror of https://github.com/hwchase17/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
240 lines
8.3 KiB
Python
240 lines
8.3 KiB
Python
"""Module implements an agent that uses OpenAI's function-calling API."""
|
|
from typing import Any, List, Optional, Sequence, Tuple, Union
|
|
|
|
from langchain.agents import BaseSingleActionAgent
|
|
from langchain.agents.format_scratchpad.openai_functions import (
|
|
format_to_openai_functions,
|
|
)
|
|
from langchain.agents.output_parsers.openai_functions import (
|
|
OpenAIFunctionsAgentOutputParser,
|
|
)
|
|
from langchain.callbacks.base import BaseCallbackManager
|
|
from langchain.callbacks.manager import Callbacks
|
|
from langchain.chat_models.openai import ChatOpenAI
|
|
from langchain.prompts.chat import (
|
|
BaseMessagePromptTemplate,
|
|
ChatPromptTemplate,
|
|
HumanMessagePromptTemplate,
|
|
MessagesPlaceholder,
|
|
)
|
|
from langchain.pydantic_v1 import root_validator
|
|
from langchain.schema import (
|
|
AgentAction,
|
|
AgentFinish,
|
|
BasePromptTemplate,
|
|
)
|
|
from langchain.schema.language_model import BaseLanguageModel
|
|
from langchain.schema.messages import (
|
|
BaseMessage,
|
|
SystemMessage,
|
|
)
|
|
from langchain.tools.base import BaseTool
|
|
from langchain.tools.render import format_tool_to_openai_function
|
|
|
|
|
|
class OpenAIFunctionsAgent(BaseSingleActionAgent):
    """An Agent driven by OpenAI's function-calling API.

    Args:
        llm: This should be an instance of ChatOpenAI, specifically a model
            that supports using `functions`.
        tools: The tools this agent has access to.
        prompt: The prompt for this agent, should support agent_scratchpad as one
            of the variables. For an easy way to construct this prompt, use
            `OpenAIFunctionsAgent.create_prompt(...)`
    """

    llm: BaseLanguageModel
    tools: Sequence[BaseTool]
    prompt: BasePromptTemplate

    def get_allowed_tools(self) -> List[str]:
        """Get allowed tools (the names of every tool in ``self.tools``)."""
        return [t.name for t in self.tools]

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        """Reject any ``llm`` that is not a ``ChatOpenAI`` instance.

        Post root validators still run when a field failed its own validation,
        and in that case the field is missing from ``values``. Indexing it
        directly would raise a bare ``KeyError`` that pydantic does not turn
        into a validation error, hiding the real problem, so the check is
        skipped when ``llm`` is absent.

        Raises:
            ValueError: If ``llm`` is present but is not a ``ChatOpenAI``.
        """
        llm = values.get("llm")
        if llm is not None and not isinstance(llm, ChatOpenAI):
            raise ValueError("Only supported with ChatOpenAI models.")
        return values

    @root_validator
    def validate_prompt(cls, values: dict) -> dict:
        """Require the prompt to expose an ``agent_scratchpad`` variable.

        Like ``validate_llm``, this is a no-op when ``prompt`` already failed
        field validation and is therefore absent from ``values``.

        Raises:
            ValueError: If the prompt has no ``agent_scratchpad`` input variable.
        """
        prompt: Optional[BasePromptTemplate] = values.get("prompt")
        if prompt is None:
            return values
        if "agent_scratchpad" not in prompt.input_variables:
            raise ValueError(
                "`agent_scratchpad` should be one of the variables in the prompt, "
                f"got {prompt.input_variables}"
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Get input keys. Input refers to user input here."""
        return ["input"]

    @property
    def functions(self) -> List[dict]:
        """The tools rendered as OpenAI function definitions."""
        return [dict(format_tool_to_openai_function(t)) for t in self.tools]

    def _construct_messages(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        inputs: dict,
    ) -> List[BaseMessage]:
        """Render the prompt for one planning step into chat messages.

        Every prompt variable except ``agent_scratchpad`` is taken from
        ``inputs`` (a missing one raises ``KeyError``). ``agent_scratchpad``
        is filled with the intermediate steps formatted for OpenAI functions.
        """
        agent_scratchpad = format_to_openai_functions(intermediate_steps)
        selected_inputs = {
            k: inputs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        return prompt.to_messages()

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        with_functions: bool = True,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with
                observations.
            callbacks: Callbacks forwarded to the LLM call.
            with_functions: If True (default), ``self.functions`` is passed to
                the LLM; if False, the LLM is called without ``functions``.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use, or an AgentFinish.
        """
        messages = self._construct_messages(intermediate_steps, kwargs)
        if with_functions:
            predicted_message = self.llm.predict_messages(
                messages,
                functions=self.functions,
                callbacks=callbacks,
            )
        else:
            predicted_message = self.llm.predict_messages(
                messages,
                callbacks=callbacks,
            )
        return OpenAIFunctionsAgentOutputParser._parse_ai_message(predicted_message)

    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        with_functions: bool = True,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Async version of ``plan``: given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations.
            callbacks: Callbacks forwarded to the LLM call.
            with_functions: If True (default), ``self.functions`` is passed to
                the LLM; if False, the LLM is called without ``functions``.
                Mirrors the parameter of the same name on ``plan``.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use, or an AgentFinish.
        """
        messages = self._construct_messages(intermediate_steps, kwargs)
        if with_functions:
            predicted_message = await self.llm.apredict_messages(
                messages, functions=self.functions, callbacks=callbacks
            )
        else:
            predicted_message = await self.llm.apredict_messages(
                messages, callbacks=callbacks
            )
        return OpenAIFunctionsAgentOutputParser._parse_ai_message(predicted_message)

    def return_stopped_response(
        self,
        early_stopping_method: str,
        intermediate_steps: List[Tuple[AgentAction, str]],
        **kwargs: Any,
    ) -> AgentFinish:
        """Return response when agent has been stopped due to max iterations.

        Args:
            early_stopping_method: ``"force"`` returns a fixed message;
                ``"generate"`` runs one final ``plan`` call without functions.
            intermediate_steps: Steps taken so far, passed to ``plan``.
            **kwargs: User inputs, passed to ``plan``.

        Raises:
            ValueError: If ``"generate"`` yields an action instead of a finish,
                or if ``early_stopping_method`` is not recognised.
        """
        if early_stopping_method == "force":
            # `force` just returns a constant string
            return AgentFinish(
                {"output": "Agent stopped due to iteration limit or time limit."}, ""
            )
        elif early_stopping_method == "generate":
            # Generate does one final forward pass, with no functions offered
            agent_decision = self.plan(
                intermediate_steps, with_functions=False, **kwargs
            )
            if isinstance(agent_decision, AgentFinish):
                return agent_decision
            else:
                raise ValueError(
                    f"got AgentAction with no functions provided: {agent_decision}"
                )
        else:
            raise ValueError(
                "early_stopping_method should be one of `force` or `generate`, "
                f"got {early_stopping_method}"
            )

    @classmethod
    def create_prompt(
        cls,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        """Create prompt for this agent.

        Args:
            system_message: Message to use as the system message that will be the
                first in the prompt. A falsy value (e.g. None) omits it.
            extra_prompt_messages: Prompt messages that will be placed between the
                system message and the new human input.

        Returns:
            A prompt template to pass into this agent, ending with the human
            ``{input}`` message followed by the ``agent_scratchpad`` placeholder.
        """
        _prompts = extra_prompt_messages or []
        messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
        if system_message:
            messages = [system_message]
        else:
            messages = []

        messages.extend(
            [
                *_prompts,
                HumanMessagePromptTemplate.from_template("{input}"),
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )
        return ChatPromptTemplate(messages=messages)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        """Construct an agent from an LLM and tools.

        The prompt is built with ``create_prompt``. Any extra ``kwargs`` are
        forwarded to the agent constructor.

        Raises:
            ValueError: If ``llm`` is not a ``ChatOpenAI`` instance.
        """
        if not isinstance(llm, ChatOpenAI):
            raise ValueError("Only supported with ChatOpenAI models.")
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            llm=llm,
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
|