langchain: docstrings in agents root (#23561)

Added missed docstrings. Formatted docstrings to the consistent form.
This commit is contained in:
Leonid Ganeline 2024-06-27 12:52:18 -07:00 committed by GitHub
parent b64c4b4750
commit c0fdbaac85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 324 additions and 54 deletions

View File

@ -15,19 +15,20 @@ Agents select and use **Tools** and **Toolkits** for actions.
OpenAIFunctionsAgent
XMLAgent
Agent --> <name>Agent # Examples: ZeroShotAgent, ChatAgent
BaseMultiActionAgent --> OpenAIMultiFunctionsAgent
**Main helpers:**
.. code-block::
AgentType, AgentExecutor, AgentOutputParser, AgentExecutorIterator,
AgentAction, AgentFinish
""" # noqa: E501
from pathlib import Path
from typing import TYPE_CHECKING, Any

View File

@ -77,7 +77,7 @@ class BaseSingleActionAgent(BaseModel):
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with observations.
callbacks: Callbacks to run.
**kwargs: User inputs.
@ -92,11 +92,11 @@ class BaseSingleActionAgent(BaseModel):
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
"""Async given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with observations.
callbacks: Callbacks to run.
**kwargs: User inputs.
@ -118,7 +118,20 @@ class BaseSingleActionAgent(BaseModel):
intermediate_steps: List[Tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations."""
"""Return response when agent has been stopped due to max iterations.
Args:
early_stopping_method: Method to use for early stopping.
intermediate_steps: Steps the LLM has taken to date,
along with observations.
**kwargs: User inputs.
Returns:
AgentFinish: Agent finish object.
Raises:
ValueError: If `early_stopping_method` is not supported.
"""
if early_stopping_method == "force":
# `force` just returns a constant string
return AgentFinish(
@ -137,15 +150,30 @@ class BaseSingleActionAgent(BaseModel):
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> BaseSingleActionAgent:
"""Construct an agent from an LLM and tools.
Args:
llm: Language model to use.
tools: Tools to use.
callback_manager: Callback manager to use.
**kwargs: Additional arguments.
Returns:
BaseSingleActionAgent: Agent object.
"""
raise NotImplementedError
@property
def _agent_type(self) -> str:
"""Return Identifier of agent type."""
"""Return Identifier of an agent type."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent."""
"""Return dictionary representation of agent.
Returns:
Dict: Dictionary representation of agent.
"""
_dict = super().dict()
try:
_type = self._agent_type
@ -193,6 +221,7 @@ class BaseSingleActionAgent(BaseModel):
raise ValueError(f"{save_path} must be json or yaml")
def tool_run_logging_kwargs(self) -> Dict:
"""Return logging kwargs for tool run."""
return {}
@ -205,6 +234,11 @@ class BaseMultiActionAgent(BaseModel):
return ["output"]
def get_allowed_tools(self) -> Optional[List[str]]:
"""Get allowed tools.
Returns:
Optional[List[str]]: Allowed tools.
"""
return None
@abstractmethod
@ -233,7 +267,7 @@ class BaseMultiActionAgent(BaseModel):
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]:
"""Given input, decided what to do.
"""Async given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
@ -259,7 +293,20 @@ class BaseMultiActionAgent(BaseModel):
intermediate_steps: List[Tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations."""
"""Return response when agent has been stopped due to max iterations.
Args:
early_stopping_method: Method to use for early stopping.
intermediate_steps: Steps the LLM has taken to date,
along with observations.
**kwargs: User inputs.
Returns:
AgentFinish: Agent finish object.
Raises:
ValueError: If `early_stopping_method` is not supported.
"""
if early_stopping_method == "force":
# `force` just returns a constant string
return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
@ -270,7 +317,7 @@ class BaseMultiActionAgent(BaseModel):
@property
def _agent_type(self) -> str:
"""Return Identifier of agent type."""
"""Return Identifier of an agent type."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
@ -288,6 +335,10 @@ class BaseMultiActionAgent(BaseModel):
Args:
file_path: Path to file to save the agent to.
Raises:
NotImplementedError: If agent does not support saving.
ValueError: If file_path is not json or yaml.
Example:
.. code-block:: python
@ -318,6 +369,8 @@ class BaseMultiActionAgent(BaseModel):
raise ValueError(f"{save_path} must be json or yaml")
def tool_run_logging_kwargs(self) -> Dict:
"""Return logging kwargs for tool run."""
return {}
@ -332,15 +385,26 @@ class AgentOutputParser(BaseOutputParser[Union[AgentAction, AgentFinish]]):
class MultiActionAgentOutputParser(
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
):
"""Base class for parsing agent output into agent actions/finish."""
"""Base class for parsing agent output into agent actions/finish.
This is used for agents that can return multiple actions.
"""
@abstractmethod
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
"""Parse text into agent actions/finish."""
"""Parse text into agent actions/finish.
Args:
text: Text to parse.
Returns:
Union[List[AgentAction], AgentFinish]:
List of agent actions or agent finish.
"""
class RunnableAgent(BaseSingleActionAgent):
"""Agent powered by runnables."""
"""Agent powered by Runnables."""
runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
"""Runnable to call to get agent action."""
@ -367,6 +431,7 @@ class RunnableAgent(BaseSingleActionAgent):
@property
def input_keys(self) -> List[str]:
"""Return the input keys."""
return self.input_keys_arg
def plan(
@ -414,13 +479,13 @@ class RunnableAgent(BaseSingleActionAgent):
AgentAction,
AgentFinish,
]:
"""Based on past history and current inputs, decide what to do.
"""Async based on past history and current inputs, decide what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with observations.
callbacks: Callbacks to run.
**kwargs: User inputs
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
@ -449,7 +514,7 @@ class RunnableAgent(BaseSingleActionAgent):
class RunnableMultiActionAgent(BaseMultiActionAgent):
"""Agent powered by runnables."""
"""Agent powered by Runnables."""
runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]]
"""Runnable to call to get agent actions."""
@ -531,11 +596,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
List[AgentAction],
AgentFinish,
]:
"""Based on past history and current inputs, decide what to do.
"""Async based on past history and current inputs, decide what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with observations.
callbacks: Callbacks to run.
**kwargs: User inputs.
@ -630,11 +695,11 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
"""Async given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with observations.
callbacks: Callbacks to run.
**kwargs: User inputs.
@ -650,6 +715,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
return self.output_parser.parse(output)
def tool_run_logging_kwargs(self) -> Dict:
"""Return logging kwargs for tool run."""
return {
"llm_prefix": "",
"observation_prefix": "" if len(self.stop) == 0 else self.stop[0],
@ -667,14 +733,17 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
class Agent(BaseSingleActionAgent):
"""Agent that calls the language model and deciding the action.
This is driven by an LLMChain. The prompt in the LLMChain MUST include
This is driven by a LLMChain. The prompt in the LLMChain MUST include
a variable called "agent_scratchpad" where the agent can put its
intermediary work.
"""
llm_chain: LLMChain
"""LLMChain to use for agent."""
output_parser: AgentOutputParser
"""Output parser to use for agent."""
allowed_tools: Optional[List[str]] = None
"""Allowed tools for the agent. If None, all tools are allowed."""
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent."""
@ -683,14 +752,23 @@ class Agent(BaseSingleActionAgent):
return _dict
def get_allowed_tools(self) -> Optional[List[str]]:
"""Get allowed tools."""
return self.allowed_tools
@property
def return_values(self) -> List[str]:
"""Return values of the agent."""
return ["output"]
def _fix_text(self, text: str) -> str:
"""Fix the text."""
"""Fix the text.
Args:
text: Text to fix.
Returns:
str: Fixed text.
"""
raise ValueError("fix_text not implemented for this agent.")
@property
@ -720,7 +798,7 @@ class Agent(BaseSingleActionAgent):
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with observations.
callbacks: Callbacks to run.
**kwargs: User inputs.
@ -737,11 +815,11 @@ class Agent(BaseSingleActionAgent):
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
"""Async given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with observations.
callbacks: Callbacks to run.
**kwargs: User inputs.
@ -756,7 +834,16 @@ class Agent(BaseSingleActionAgent):
def get_full_inputs(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Dict[str, Any]:
"""Create the full inputs for the LLMChain from intermediate steps."""
"""Create the full inputs for the LLMChain from intermediate steps.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations.
**kwargs: User inputs.
Returns:
Dict[str, Any]: Full inputs for the LLMChain.
"""
thoughts = self._construct_scratchpad(intermediate_steps)
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
full_inputs = {**kwargs, **new_inputs}
@ -772,7 +859,18 @@ class Agent(BaseSingleActionAgent):
@root_validator(pre=False, skip_on_failure=True)
def validate_prompt(cls, values: Dict) -> Dict:
"""Validate that prompt matches format."""
"""Validate that prompt matches format.
Args:
values: Values to validate.
Returns:
Dict: Validated values.
Raises:
ValueError: If `agent_scratchpad` is not in prompt.input_variables
and prompt is not a FewShotPromptTemplate or a PromptTemplate.
"""
prompt = values["llm_chain"].prompt
if "agent_scratchpad" not in prompt.input_variables:
logger.warning(
@ -801,11 +899,23 @@ class Agent(BaseSingleActionAgent):
@classmethod
@abstractmethod
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
"""Create a prompt for this class."""
"""Create a prompt for this class.
Args:
tools: Tools to use.
Returns:
BasePromptTemplate: Prompt template.
"""
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
"""Validate that appropriate tools are passed in."""
"""Validate that appropriate tools are passed in.
Args:
tools: Tools to use.
"""
pass
@classmethod
@ -822,7 +932,18 @@ class Agent(BaseSingleActionAgent):
output_parser: Optional[AgentOutputParser] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
"""Construct an agent from an LLM and tools.
Args:
llm: Language model to use.
tools: Tools to use.
callback_manager: Callback manager to use.
output_parser: Output parser to use.
**kwargs: Additional arguments.
Returns:
Agent: Agent object.
"""
cls._validate_tools(tools)
llm_chain = LLMChain(
llm=llm,
@ -844,7 +965,20 @@ class Agent(BaseSingleActionAgent):
intermediate_steps: List[Tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations."""
"""Return response when agent has been stopped due to max iterations.
Args:
early_stopping_method: Method to use for early stopping.
intermediate_steps: Steps the LLM has taken to date,
along with observations.
**kwargs: User inputs.
Returns:
AgentFinish: Agent finish object.
Raises:
ValueError: If `early_stopping_method` is not in ['force', 'generate'].
"""
if early_stopping_method == "force":
# `force` just returns a constant string
return AgentFinish(
@ -881,6 +1015,7 @@ class Agent(BaseSingleActionAgent):
)
def tool_run_logging_kwargs(self) -> Dict:
"""Return logging kwargs for tool run."""
return {
"llm_prefix": self.llm_prefix,
"observation_prefix": self.observation_prefix,
@ -957,6 +1092,9 @@ class AgentExecutor(Chain):
trim_intermediate_steps: Union[
int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]]
] = -1
"""How to trim the intermediate steps before returning them.
Defaults to -1, which means no trimming.
"""
@classmethod
def from_agent_and_tools(
@ -966,7 +1104,17 @@ class AgentExecutor(Chain):
callbacks: Callbacks = None,
**kwargs: Any,
) -> AgentExecutor:
"""Create from agent and tools."""
"""Create from agent and tools.
Args:
agent: Agent to use.
tools: Tools to use.
callbacks: Callbacks to use.
**kwargs: Additional arguments.
Returns:
AgentExecutor: Agent executor object.
"""
return cls(
agent=agent,
tools=tools,
@ -976,7 +1124,17 @@ class AgentExecutor(Chain):
@root_validator(pre=False, skip_on_failure=True)
def validate_tools(cls, values: Dict) -> Dict:
"""Validate that tools are compatible with agent."""
"""Validate that tools are compatible with agent.
Args:
values: Values to validate.
Returns:
Dict: Validated values.
Raises:
ValueError: If allowed tools are different than provided tools.
"""
agent = values["agent"]
tools = values["tools"]
allowed_tools = agent.get_allowed_tools()
@ -990,7 +1148,17 @@ class AgentExecutor(Chain):
@root_validator(pre=False, skip_on_failure=True)
def validate_return_direct_tool(cls, values: Dict) -> Dict:
"""Validate that tools are compatible with agent."""
"""Validate that tools are compatible with agent.
Args:
values: Values to validate.
Returns:
Dict: Validated values.
Raises:
ValueError: If tools that have `return_direct=True` are not allowed.
"""
agent = values["agent"]
tools = values["tools"]
if isinstance(agent, BaseMultiActionAgent):
@ -1004,7 +1172,14 @@ class AgentExecutor(Chain):
@root_validator(pre=True)
def validate_runnable_agent(cls, values: Dict) -> Dict:
"""Convert runnable to agent if passed in."""
"""Convert runnable to agent if passed in.
Args:
values: Values to validate.
Returns:
Dict: Validated values.
"""
agent = values.get("agent")
if agent and isinstance(agent, Runnable):
try:
@ -1026,7 +1201,14 @@ class AgentExecutor(Chain):
return values
def save(self, file_path: Union[Path, str]) -> None:
"""Raise error - saving not supported for Agent Executors."""
"""Raise error - saving not supported for Agent Executors.
Args:
file_path: Path to save to.
Raises:
ValueError: Saving not supported for agent executors.
"""
raise ValueError(
"Saving not supported for agent executors. "
"If you are trying to save the agent, please use the "
@ -1034,7 +1216,11 @@ class AgentExecutor(Chain):
)
def save_agent(self, file_path: Union[Path, str]) -> None:
"""Save the underlying agent."""
"""Save the underlying agent.
Args:
file_path: Path to save to.
"""
return self.agent.save(file_path)
def iter(
@ -1045,7 +1231,17 @@ class AgentExecutor(Chain):
include_run_info: bool = False,
async_: bool = False, # arg kept for backwards compat, but ignored
) -> AgentExecutorIterator:
"""Enables iteration over steps taken to reach final output."""
"""Enables iteration over steps taken to reach final output.
Args:
inputs: Inputs to the agent.
callbacks: Callbacks to run.
include_run_info: Whether to include run info.
async_: Whether to run async. (Ignored)
Returns:
AgentExecutorIterator: Agent executor iterator object.
"""
return AgentExecutorIterator(
self,
inputs,
@ -1074,7 +1270,14 @@ class AgentExecutor(Chain):
return self.agent.return_values
def lookup_tool(self, name: str) -> BaseTool:
"""Lookup tool by name."""
"""Lookup tool by name.
Args:
name: Name of tool.
Returns:
BaseTool: Tool object.
"""
return {tool.name: tool for tool in self.tools}[name]
def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
@ -1463,7 +1666,7 @@ class AgentExecutor(Chain):
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Run text through and get agent response."""
"""Async run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}
# We construct a mapping from each tool to a color, used for logging.
@ -1557,7 +1760,16 @@ class AgentExecutor(Chain):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[AddableDict]:
"""Enables streaming over steps taken to reach final output."""
"""Enables streaming over steps taken to reach final output.
Args:
input: Input to the agent.
config: Config to use.
**kwargs: Additional arguments.
Yields:
AddableDict: Addable dictionary.
"""
config = ensure_config(config)
iterator = AgentExecutorIterator(
self,
@ -1579,7 +1791,17 @@ class AgentExecutor(Chain):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[AddableDict]:
"""Enables streaming over steps taken to reach final output."""
"""Async enables streaming over steps taken to reach final output.
Args:
input: Input to the agent.
config: Config to use.
**kwargs: Additional arguments.
Yields:
AddableDict: Addable dictionary.
"""
config = ensure_config(config)
iterator = AgentExecutorIterator(
self,

View File

@ -62,6 +62,22 @@ class AgentExecutorIterator:
"""
Initialize the AgentExecutorIterator with the given AgentExecutor,
inputs, and optional callbacks.
Args:
agent_executor (AgentExecutor): The AgentExecutor to iterate over.
inputs (Any): The inputs to the AgentExecutor.
callbacks (Callbacks, optional): The callbacks to use during iteration.
Defaults to None.
tags (Optional[list[str]], optional): The tags to use during iteration.
Defaults to None.
metadata (Optional[Dict[str, Any]], optional): The metadata to use
during iteration. Defaults to None.
run_name (Optional[str], optional): The name of the run. Defaults to None.
run_id (Optional[UUID], optional): The ID of the run. Defaults to None.
include_run_info (bool, optional): Whether to include run info
in the output. Defaults to False.
yield_actions (bool, optional): Whether to yield actions as they
are generated. Defaults to False.
"""
self._agent_executor = agent_executor
self.inputs = inputs
@ -85,6 +101,7 @@ class AgentExecutorIterator:
@property
def inputs(self) -> Dict[str, str]:
"""The inputs to the AgentExecutor."""
return self._inputs
@inputs.setter
@ -93,6 +110,7 @@ class AgentExecutorIterator:
@property
def agent_executor(self) -> AgentExecutor:
"""The AgentExecutor to iterate over."""
return self._agent_executor
@agent_executor.setter
@ -103,10 +121,12 @@ class AgentExecutorIterator:
@property
def name_to_tool_map(self) -> Dict[str, BaseTool]:
"""A mapping of tool names to tools."""
return {tool.name: tool for tool in self.agent_executor.tools}
@property
def color_mapping(self) -> Dict[str, str]:
"""A mapping of tool names to colors."""
return get_color_mapping(
[tool.name for tool in self.agent_executor.tools],
excluded_colors=["green", "red"],

View File

@ -1,4 +1,5 @@
"""Module definitions of agent types together with corresponding agents."""
from enum import Enum
from langchain_core._api import deprecated

View File

@ -1,4 +1,5 @@
"""Load agent."""
from typing import Any, Optional, Sequence
from langchain_core._api import deprecated
@ -35,17 +36,24 @@ def initialize_agent(
Args:
tools: List of tools this agent has access to.
llm: Language model to use as the agent.
agent: Agent type to use. If None and agent_path is also None, will default to
AgentType.ZERO_SHOT_REACT_DESCRIPTION.
agent: Agent type to use. If None and agent_path is also None, will default
to AgentType.ZERO_SHOT_REACT_DESCRIPTION. Defaults to None.
callback_manager: CallbackManager to use. Global callback manager is used if
not provided. Defaults to None.
agent_path: Path to serialized agent to use.
agent_kwargs: Additional keyword arguments to pass to the underlying agent
tags: Tags to apply to the traced runs.
**kwargs: Additional keyword arguments passed to the agent executor
agent_path: Path to serialized agent to use. If None and agent is also None,
will default to AgentType.ZERO_SHOT_REACT_DESCRIPTION. Defaults to None.
agent_kwargs: Additional keyword arguments to pass to the underlying agent.
Defaults to None.
tags: Tags to apply to the traced runs. Defaults to None.
**kwargs: Additional keyword arguments passed to the agent executor.
Returns:
An agent executor
An agent executor.
Raises:
ValueError: If both `agent` and `agent_path` are specified.
ValueError: If `agent` is not a valid agent type.
ValueError: If both `agent` and `agent_path` are None.
"""
tags_ = list(tags) if tags else []
if agent is None and agent_path is None:

View File

@ -48,6 +48,9 @@ def load_agent_from_config(
Returns:
An agent executor.
Raises:
ValueError: If agent type is not specified in the config.
"""
if "_type" not in config:
raise ValueError("Must specify an agent Type in config")
@ -99,6 +102,10 @@ def load_agent(
Returns:
An agent executor.
Raises:
RuntimeError: If loading from the deprecated github-based
Hub is attempted.
"""
if isinstance(path, str) and path.startswith("lc://"):
raise RuntimeError(

View File

@ -1,4 +1,5 @@
"""Interface for tools."""
from typing import List, Optional
from langchain_core.callbacks import (
@ -12,7 +13,9 @@ class InvalidTool(BaseTool):
"""Tool that is run when invalid tool name is encountered by agent."""
name: str = "invalid_tool"
"""Name of the tool."""
description: str = "Called when tool name is invalid. Suggests valid tool names."
"""Description of the tool."""
def _run(
self,

View File

@ -4,7 +4,15 @@ from langchain_core.tools import BaseTool
def validate_tools_single_input(class_name: str, tools: Sequence[BaseTool]) -> None:
"""Validate tools for single input."""
"""Validate tools for single input.
Args:
class_name: Name of the class.
tools: List of tools to validate.
Raises:
ValueError: If a multi-input tool is found in tools.
"""
for tool in tools:
if not tool.is_single_input:
raise ValueError(