forked from Archives/langchain
harrison/agent-refactor
parent
ac208f85c8
commit
f646c94bc1
@ -1,61 +0,0 @@
|
|||||||
"""Input manager for agents."""
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import langchain
|
|
||||||
from langchain.schema import AgentAction
|
|
||||||
|
|
||||||
|
|
||||||
class ChainedInput:
|
|
||||||
"""Class for working with input that is the result of chains."""
|
|
||||||
|
|
||||||
def __init__(self, text: str, verbose: bool = False):
|
|
||||||
"""Initialize with verbose flag and initial text."""
|
|
||||||
self._verbose = verbose
|
|
||||||
if self._verbose:
|
|
||||||
langchain.logger.log_agent_start(text)
|
|
||||||
self._input = text
|
|
||||||
self._intermediate_actions: List[AgentAction] = []
|
|
||||||
self._intermediate_observations: List[str] = []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def intermediate_steps(self) -> List:
|
|
||||||
"""Return intermediate steps the agent took."""
|
|
||||||
steps = []
|
|
||||||
for i, action in enumerate(self._intermediate_actions):
|
|
||||||
step = {
|
|
||||||
"log": action.log,
|
|
||||||
"tool": action.tool,
|
|
||||||
"tool_input": action.tool_input,
|
|
||||||
"observation": self._intermediate_observations[i],
|
|
||||||
}
|
|
||||||
steps.append(step)
|
|
||||||
return steps
|
|
||||||
|
|
||||||
def add_action(self, action: AgentAction, color: Optional[str] = None) -> None:
|
|
||||||
"""Add text to input, print if in verbose mode."""
|
|
||||||
|
|
||||||
self._input += action.log
|
|
||||||
self._intermediate_actions.append(action)
|
|
||||||
|
|
||||||
def add_observation(
|
|
||||||
self,
|
|
||||||
observation: str,
|
|
||||||
observation_prefix: str,
|
|
||||||
llm_prefix: str,
|
|
||||||
color: Optional[str],
|
|
||||||
) -> None:
|
|
||||||
"""Add observation to input, print if in verbose mode."""
|
|
||||||
if self._verbose:
|
|
||||||
langchain.logger.log_agent_observation(
|
|
||||||
observation,
|
|
||||||
color=color,
|
|
||||||
observation_prefix=observation_prefix,
|
|
||||||
llm_prefix=llm_prefix,
|
|
||||||
)
|
|
||||||
self._input += f"\n{observation_prefix}{observation}\n{llm_prefix}"
|
|
||||||
self._intermediate_observations.append(observation)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input(self) -> str:
|
|
||||||
"""Return the accumulated input."""
|
|
||||||
return self._input
|
|
@ -1,75 +0,0 @@
|
|||||||
"""Test input manipulating logic."""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from io import StringIO
|
|
||||||
|
|
||||||
from langchain.agents.input import ChainedInput
|
|
||||||
from langchain.input import get_color_mapping
|
|
||||||
|
|
||||||
|
|
||||||
def test_chained_input_not_verbose() -> None:
|
|
||||||
"""Test chained input logic."""
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input = ChainedInput("foo")
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == ""
|
|
||||||
assert chained_input.input == "foo"
|
|
||||||
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input.add_observation("bar", "1", "2", None)
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == ""
|
|
||||||
assert chained_input.input == "foo\n1bar\n2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_chained_input_verbose() -> None:
|
|
||||||
"""Test chained input logic, making sure verbose doesn't mess it up."""
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input = ChainedInput("foo", verbose=True)
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == "foo"
|
|
||||||
assert chained_input.input == "foo"
|
|
||||||
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input.add_observation("bar", "1", "2", None)
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == "\n1bar\n2"
|
|
||||||
assert chained_input.input == "foo\n1bar\n2"
|
|
||||||
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input.add_observation("baz", "3", "4", "blue")
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == "\n3\x1b[36;1m\x1b[1;3mbaz\x1b[0m\n4"
|
|
||||||
assert chained_input.input == "foo\n1bar\n2\n3baz\n4"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_color_mapping() -> None:
|
|
||||||
"""Test getting of color mapping."""
|
|
||||||
# Test on few inputs.
|
|
||||||
items = ["foo", "bar"]
|
|
||||||
output = get_color_mapping(items)
|
|
||||||
expected_output = {"foo": "blue", "bar": "yellow"}
|
|
||||||
assert output == expected_output
|
|
||||||
|
|
||||||
# Test on a lot of inputs.
|
|
||||||
items = [f"foo-{i}" for i in range(20)]
|
|
||||||
output = get_color_mapping(items)
|
|
||||||
assert len(output) == 20
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_color_mapping_excluded_colors() -> None:
|
|
||||||
"""Test getting of color mapping with excluded colors."""
|
|
||||||
items = ["foo", "bar"]
|
|
||||||
output = get_color_mapping(items, excluded_colors=["blue"])
|
|
||||||
expected_output = {"foo": "yellow", "bar": "pink"}
|
|
||||||
assert output == expected_output
|
|
Loading…
Reference in New Issue