Harrison/base agent without docs (#2166)

doc
Harrison Chase 1 year ago committed by GitHub
parent 1b7cfd7222
commit 5c907d9998
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,13 +19,183 @@ from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseMessage from langchain.schema import AgentAction, AgentFinish, BaseMessage, BaseOutputParser
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
logger = logging.getLogger() logger = logging.getLogger()
class BaseSingleActionAgent(BaseModel):
    """Base Agent class.

    A single-action agent decides one step at a time: ``plan``/``aplan``
    return either an ``AgentAction`` (tool to call, with input) or an
    ``AgentFinish`` (final return values).
    """

    @property
    def return_values(self) -> List[str]:
        """Return values of the agent."""
        return ["output"]

    def get_allowed_tools(self) -> Optional[List[str]]:
        """Return names of tools this agent may use, or None for no restriction."""
        return None

    @abstractmethod
    def plan(
        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """

    @abstractmethod
    async def aplan(
        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """

    @property
    @abstractmethod
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """

    def return_stopped_response(
        self,
        early_stopping_method: str,
        intermediate_steps: List[Tuple[AgentAction, str]],
        **kwargs: Any,
    ) -> AgentFinish:
        """Return response when agent has been stopped due to max iterations."""
        if early_stopping_method == "force":
            # `force` just returns a constant string
            return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
        else:
            raise ValueError(
                f"Got unsupported early_stopping_method `{early_stopping_method}`"
            )

    @property
    def _agent_type(self) -> str:
        """Return Identifier of agent type."""
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of agent."""
        _dict = super().dict()
        _dict["_type"] = self._agent_type
        return _dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the agent.

        Args:
            file_path: Path to file to save the agent to.

        Raises:
            ValueError: If the file suffix is neither ``.json`` nor ``.yaml``.

        Example:
        .. code-block:: python

            # If working with agent executor
            agent.agent.save(file_path="path/agent.yaml")
        """
        # Normalize to a Path object; Path() is idempotent, so this handles
        # both str and Path inputs without an isinstance branch.
        save_path = Path(file_path)
        directory_path = save_path.parent
        directory_path.mkdir(parents=True, exist_ok=True)

        # Fetch dictionary to save
        agent_dict = self.dict()

        # Use the normalized save_path consistently for both suffix checks
        # and the actual writes (previously the writes reused file_path).
        if save_path.suffix == ".json":
            with open(save_path, "w") as f:
                json.dump(agent_dict, f, indent=4)
        elif save_path.suffix == ".yaml":
            with open(save_path, "w") as f:
                yaml.dump(agent_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")

    def tool_run_logging_kwargs(self) -> Dict:
        """Extra keyword args passed to tool runs for logging; empty by default."""
        return {}
class AgentOutputParser(BaseOutputParser):
    """Output parser mapping raw LLM text to an agent decision."""

    @abstractmethod
    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse text into agent action/finish."""
class LLMSingleActionAgent(BaseSingleActionAgent):
    """Single-action agent driven by an LLMChain and an output parser."""

    # Chain that produces the raw text to be parsed into an action.
    llm_chain: LLMChain
    # Parser that turns the chain's text output into AgentAction/AgentFinish.
    output_parser: AgentOutputParser
    # Stop sequences passed to the LLM on every call.
    stop: List[str]

    @property
    def input_keys(self) -> List[str]:
        """Input keys of the underlying chain, excluding `intermediate_steps`."""
        chain_keys = set(self.llm_chain.input_keys)
        return list(chain_keys - {"intermediate_steps"})

    def plan(
        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        raw_output = self.llm_chain.run(
            intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
        )
        return self.output_parser.parse(raw_output)

    async def aplan(
        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        raw_output = await self.llm_chain.arun(
            intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
        )
        return self.output_parser.parse(raw_output)

    def tool_run_logging_kwargs(self) -> Dict:
        """Logging prefixes for tool runs; observation prefix is the first stop token."""
        observation_prefix = self.stop[0] if self.stop else ""
        return {
            "llm_prefix": "",
            "observation_prefix": observation_prefix,
        }
class Agent(BaseSingleActionAgent):
"""Class responsible for calling the language model and deciding the action. """Class responsible for calling the language model and deciding the action.
This is driven by an LLMChain. The prompt in the LLMChain MUST include This is driven by an LLMChain. The prompt in the LLMChain MUST include
@ -35,7 +205,13 @@ class Agent(BaseModel):
llm_chain: LLMChain llm_chain: LLMChain
allowed_tools: Optional[List[str]] = None allowed_tools: Optional[List[str]] = None
return_values: List[str] = ["output"]
def get_allowed_tools(self) -> Optional[List[str]]:
return self.allowed_tools
@property
def return_values(self) -> List[str]:
return ["output"]
@abstractmethod @abstractmethod
def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]: def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
@ -248,55 +424,17 @@ class Agent(BaseModel):
f"got {early_stopping_method}" f"got {early_stopping_method}"
) )
@property def tool_run_logging_kwargs(self) -> Dict:
@abstractmethod return {
def _agent_type(self) -> str: "llm_prefix": self.llm_prefix,
"""Return Identifier of agent type.""" "observation_prefix": self.observation_prefix,
}
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
_dict["_type"] = self._agent_type
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the agent.
Args:
file_path: Path to file to save the agent to.
Example:
.. code-block:: python
# If working with agent executor
agent.agent.save(file_path="path/agent.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
agent_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(agent_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(agent_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
class AgentExecutor(Chain, BaseModel): class AgentExecutor(Chain, BaseModel):
"""Consists of an agent using tools.""" """Consists of an agent using tools."""
agent: Agent agent: BaseSingleActionAgent
tools: Sequence[BaseTool] tools: Sequence[BaseTool]
return_intermediate_steps: bool = False return_intermediate_steps: bool = False
max_iterations: Optional[int] = 15 max_iterations: Optional[int] = 15
@ -305,7 +443,7 @@ class AgentExecutor(Chain, BaseModel):
@classmethod @classmethod
def from_agent_and_tools( def from_agent_and_tools(
cls, cls,
agent: Agent, agent: BaseSingleActionAgent,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any, **kwargs: Any,
@ -320,10 +458,11 @@ class AgentExecutor(Chain, BaseModel):
"""Validate that tools are compatible with agent.""" """Validate that tools are compatible with agent."""
agent = values["agent"] agent = values["agent"]
tools = values["tools"] tools = values["tools"]
if agent.allowed_tools is not None: allowed_tools = agent.get_allowed_tools()
if set(agent.allowed_tools) != set([tool.name for tool in tools]): if allowed_tools is not None:
if set(allowed_tools) != set([tool.name for tool in tools]):
raise ValueError( raise ValueError(
f"Allowed tools ({agent.allowed_tools}) different than " f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})" f"provided tools ({[tool.name for tool in tools]})"
) )
return values return values
@ -418,22 +557,17 @@ class AgentExecutor(Chain, BaseModel):
tool = name_to_tool_map[output.tool] tool = name_to_tool_map[output.tool]
return_direct = tool.return_direct return_direct = tool.return_direct
color = color_mapping[output.tool] color = color_mapping[output.tool]
llm_prefix = "" if return_direct else self.agent.llm_prefix tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
observation = tool.run( observation = tool.run(
output.tool_input, output.tool_input, verbose=self.verbose, color=color, **tool_run_kwargs
verbose=self.verbose,
color=color,
llm_prefix=llm_prefix,
observation_prefix=self.agent.observation_prefix,
) )
else: else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = InvalidTool().run( observation = InvalidTool().run(
output.tool, output.tool, verbose=self.verbose, color=None, **tool_run_kwargs
verbose=self.verbose,
color=None,
llm_prefix="",
observation_prefix=self.agent.observation_prefix,
) )
return output, observation return output, observation
@ -467,22 +601,17 @@ class AgentExecutor(Chain, BaseModel):
tool = name_to_tool_map[output.tool] tool = name_to_tool_map[output.tool]
return_direct = tool.return_direct return_direct = tool.return_direct
color = color_mapping[output.tool] color = color_mapping[output.tool]
llm_prefix = "" if return_direct else self.agent.llm_prefix tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
observation = await tool.arun( observation = await tool.arun(
output.tool_input, output.tool_input, verbose=self.verbose, color=color, **tool_run_kwargs
verbose=self.verbose,
color=color,
llm_prefix=llm_prefix,
observation_prefix=self.agent.observation_prefix,
) )
else: else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = await InvalidTool().arun( observation = await InvalidTool().arun(
output.tool, output.tool, verbose=self.verbose, color=None, **tool_run_kwargs
verbose=self.verbose,
color=None,
llm_prefix="",
observation_prefix=self.agent.observation_prefix,
) )
return_direct = False return_direct = False
return output, observation return output, observation

@ -74,9 +74,11 @@ class StdOutCallbackHandler(BaseCallbackHandler):
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""If not the final action, print out observation.""" """If not the final action, print out observation."""
print_text(f"\n{observation_prefix}") if observation_prefix:
print_text(f"\n{observation_prefix}")
print_text(output, color=color if color else self.color) print_text(output, color=color if color else self.color)
print_text(f"\n{llm_prefix}") if llm_prefix:
print_text(f"\n{llm_prefix}")
def on_tool_error( def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

@ -199,7 +199,7 @@ class Chain(BaseModel, ABC):
"""Call the chain on all inputs in the list.""" """Call the chain on all inputs in the list."""
return [self(inputs) for inputs in input_list] return [self(inputs) for inputs in input_list]
def run(self, *args: str, **kwargs: str) -> str: def run(self, *args: Any, **kwargs: Any) -> str:
"""Run the chain as text in, text out or multiple variables, text out.""" """Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1: if len(self.output_keys) != 1:
raise ValueError( raise ValueError(
@ -220,7 +220,7 @@ class Chain(BaseModel, ABC):
f" but not both. Got args: {args} and kwargs: {kwargs}." f" but not both. Got args: {args} and kwargs: {kwargs}."
) )
async def arun(self, *args: str, **kwargs: str) -> str: async def arun(self, *args: Any, **kwargs: Any) -> str:
"""Run the chain as text in, text out or multiple variables, text out.""" """Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1: if len(self.output_keys) != 1:
raise ValueError( raise ValueError(

@ -144,9 +144,9 @@ class BasePromptTemplate(BaseModel, ABC):
""" """
@property @property
@abstractmethod
def _prompt_type(self) -> str: def _prompt_type(self) -> str:
"""Return the prompt type key.""" """Return the prompt type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt.""" """Return dictionary representation of prompt."""

@ -149,7 +149,7 @@ def test_agent_with_callbacks_local() -> None:
callback_manager=manager, callback_manager=manager,
) )
agent.agent.llm_chain.verbose = True agent.agent.llm_chain.verbose = True # type: ignore
output = agent.run("when was langchain made") output = agent.run("when was langchain made")
assert output == "curses foiled again" assert output == "curses foiled again"
@ -285,8 +285,8 @@ def test_agent_with_new_prefix_suffix() -> None:
) )
# avoids "BasePromptTemplate" has no attribute "template" error # avoids "BasePromptTemplate" has no attribute "template" error
assert hasattr(agent.agent.llm_chain.prompt, "template") assert hasattr(agent.agent.llm_chain.prompt, "template") # type: ignore
prompt_str = agent.agent.llm_chain.prompt.template prompt_str = agent.agent.llm_chain.prompt.template # type: ignore
assert prompt_str.startswith(prefix), "Prompt does not start with prefix" assert prompt_str.startswith(prefix), "Prompt does not start with prefix"
assert prompt_str.endswith(suffix), "Prompt does not end with suffix" assert prompt_str.endswith(suffix), "Prompt does not end with suffix"

Loading…
Cancel
Save