Harrison/base agent without docs (#2166)

This commit is contained in:
Harrison Chase 2023-03-29 22:11:25 -07:00 committed by GitHub
parent 1b7cfd7222
commit 5c907d9998
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 212 additions and 81 deletions

View File

@ -19,13 +19,183 @@ from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseMessage
from langchain.schema import AgentAction, AgentFinish, BaseMessage, BaseOutputParser
from langchain.tools.base import BaseTool
logger = logging.getLogger()
class Agent(BaseModel):
class BaseSingleActionAgent(BaseModel):
"""Base Agent class."""
@property
def return_values(self) -> List[str]:
"""Return values of the agent."""
return ["output"]
def get_allowed_tools(self) -> Optional[List[str]]:
return None
@abstractmethod
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
@abstractmethod
async def aplan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
@property
@abstractmethod
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations."""
if early_stopping_method == "force":
# `force` just returns a constant string
return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
else:
raise ValueError(
f"Got unsupported early_stopping_method `{early_stopping_method}`"
)
@property
def _agent_type(self) -> str:
"""Return Identifier of agent type."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
_dict["_type"] = self._agent_type
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the agent.
Args:
file_path: Path to file to save the agent to.
Example:
.. code-block:: python
# If working with agent executor
agent.agent.save(file_path="path/agent.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
agent_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(agent_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(agent_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
def tool_run_logging_kwargs(self) -> Dict:
return {}
class AgentOutputParser(BaseOutputParser):
@abstractmethod
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
"""Parse text into agent action/finish."""
class LLMSingleActionAgent(BaseSingleActionAgent):
llm_chain: LLMChain
output_parser: AgentOutputParser
stop: List[str]
@property
def input_keys(self) -> List[str]:
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
output = self.llm_chain.run(
intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
)
return self.output_parser.parse(output)
async def aplan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
output = await self.llm_chain.arun(
intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
)
return self.output_parser.parse(output)
def tool_run_logging_kwargs(self) -> Dict:
return {
"llm_prefix": "",
"observation_prefix": "" if len(self.stop) == 0 else self.stop[0],
}
class Agent(BaseSingleActionAgent):
"""Class responsible for calling the language model and deciding the action.
This is driven by an LLMChain. The prompt in the LLMChain MUST include
@ -35,7 +205,13 @@ class Agent(BaseModel):
llm_chain: LLMChain
allowed_tools: Optional[List[str]] = None
return_values: List[str] = ["output"]
def get_allowed_tools(self) -> Optional[List[str]]:
return self.allowed_tools
@property
def return_values(self) -> List[str]:
return ["output"]
@abstractmethod
def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
@ -248,55 +424,17 @@ class Agent(BaseModel):
f"got {early_stopping_method}"
)
@property
@abstractmethod
def _agent_type(self) -> str:
"""Return Identifier of agent type."""
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
_dict["_type"] = self._agent_type
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the agent.
Args:
file_path: Path to file to save the agent to.
Example:
.. code-block:: python
# If working with agent executor
agent.agent.save(file_path="path/agent.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
agent_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(agent_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(agent_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
def tool_run_logging_kwargs(self) -> Dict:
return {
"llm_prefix": self.llm_prefix,
"observation_prefix": self.observation_prefix,
}
class AgentExecutor(Chain, BaseModel):
"""Consists of an agent using tools."""
agent: Agent
agent: BaseSingleActionAgent
tools: Sequence[BaseTool]
return_intermediate_steps: bool = False
max_iterations: Optional[int] = 15
@ -305,7 +443,7 @@ class AgentExecutor(Chain, BaseModel):
@classmethod
def from_agent_and_tools(
cls,
agent: Agent,
agent: BaseSingleActionAgent,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
@ -320,10 +458,11 @@ class AgentExecutor(Chain, BaseModel):
"""Validate that tools are compatible with agent."""
agent = values["agent"]
tools = values["tools"]
if agent.allowed_tools is not None:
if set(agent.allowed_tools) != set([tool.name for tool in tools]):
allowed_tools = agent.get_allowed_tools()
if allowed_tools is not None:
if set(allowed_tools) != set([tool.name for tool in tools]):
raise ValueError(
f"Allowed tools ({agent.allowed_tools}) different than "
f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})"
)
return values
@ -418,22 +557,17 @@ class AgentExecutor(Chain, BaseModel):
tool = name_to_tool_map[output.tool]
return_direct = tool.return_direct
color = color_mapping[output.tool]
llm_prefix = "" if return_direct else self.agent.llm_prefix
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = tool.run(
output.tool_input,
verbose=self.verbose,
color=color,
llm_prefix=llm_prefix,
observation_prefix=self.agent.observation_prefix,
output.tool_input, verbose=self.verbose, color=color, **tool_run_kwargs
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = InvalidTool().run(
output.tool,
verbose=self.verbose,
color=None,
llm_prefix="",
observation_prefix=self.agent.observation_prefix,
output.tool, verbose=self.verbose, color=None, **tool_run_kwargs
)
return output, observation
@ -467,22 +601,17 @@ class AgentExecutor(Chain, BaseModel):
tool = name_to_tool_map[output.tool]
return_direct = tool.return_direct
color = color_mapping[output.tool]
llm_prefix = "" if return_direct else self.agent.llm_prefix
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = await tool.arun(
output.tool_input,
verbose=self.verbose,
color=color,
llm_prefix=llm_prefix,
observation_prefix=self.agent.observation_prefix,
output.tool_input, verbose=self.verbose, color=color, **tool_run_kwargs
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = await InvalidTool().arun(
output.tool,
verbose=self.verbose,
color=None,
llm_prefix="",
observation_prefix=self.agent.observation_prefix,
output.tool, verbose=self.verbose, color=None, **tool_run_kwargs
)
return_direct = False
return output, observation

View File

@ -74,8 +74,10 @@ class StdOutCallbackHandler(BaseCallbackHandler):
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
if observation_prefix:
print_text(f"\n{observation_prefix}")
print_text(output, color=color if color else self.color)
if llm_prefix:
print_text(f"\n{llm_prefix}")
def on_tool_error(

View File

@ -199,7 +199,7 @@ class Chain(BaseModel, ABC):
"""Call the chain on all inputs in the list."""
return [self(inputs) for inputs in input_list]
def run(self, *args: str, **kwargs: str) -> str:
def run(self, *args: Any, **kwargs: Any) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
@ -220,7 +220,7 @@ class Chain(BaseModel, ABC):
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
async def arun(self, *args: str, **kwargs: str) -> str:
async def arun(self, *args: Any, **kwargs: Any) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(

View File

@ -144,9 +144,9 @@ class BasePromptTemplate(BaseModel, ABC):
"""
@property
@abstractmethod
def _prompt_type(self) -> str:
"""Return the prompt type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt."""

View File

@ -149,7 +149,7 @@ def test_agent_with_callbacks_local() -> None:
callback_manager=manager,
)
agent.agent.llm_chain.verbose = True
agent.agent.llm_chain.verbose = True # type: ignore
output = agent.run("when was langchain made")
assert output == "curses foiled again"
@ -285,8 +285,8 @@ def test_agent_with_new_prefix_suffix() -> None:
)
# avoids "BasePromptTemplate" has no attribute "template" error
assert hasattr(agent.agent.llm_chain.prompt, "template")
prompt_str = agent.agent.llm_chain.prompt.template
assert hasattr(agent.agent.llm_chain.prompt, "template") # type: ignore
prompt_str = agent.agent.llm_chain.prompt.template # type: ignore
assert prompt_str.startswith(prefix), "Prompt does not start with prefix"
assert prompt_str.endswith(suffix), "Prompt does not end with suffix"