From 8d0869c6d3ed63b2b15d4f75ea664e089dcc569d Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Sun, 18 Dec 2022 15:54:56 -0500 Subject: [PATCH] change run to use args and kwargs (#367) Before, `run` was not able to be called with multiple arguments. This expands the functionality. --- .flake8 | 1 + langchain/chains/base.py | 25 ++++++++++------ tests/unit_tests/chains/test_base.py | 45 +++++++++++++++++++++++++++- 3 files changed, 61 insertions(+), 10 deletions(-) diff --git a/.flake8 b/.flake8 index 64a9cd4c..d3ac343b 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,6 @@ [flake8] exclude = + venv .venv __pycache__ notebooks diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 848664e5..ececc8f0 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -119,16 +119,23 @@ class Chain(BaseModel, ABC): """Call the chain on all inputs in the list.""" return [self(inputs) for inputs in input_list] - def run(self, text: str) -> str: - """Run text in, text out (if applicable).""" - if len(self.input_keys) != 1: - raise ValueError( - f"`run` not supported when there is not exactly " - f"one input key, got {self.input_keys}." - ) + def run(self, *args: str, **kwargs: str) -> str: + """Run the chain as text in, text out or multiple variables, text out.""" if len(self.output_keys) != 1: raise ValueError( f"`run` not supported when there is not exactly " - f"one output key, got {self.output_keys}." + f"one output key. Got {self.output_keys}." ) - return self({self.input_keys[0]: text})[self.output_keys[0]] + + if args and not kwargs: + if len(args) != 1: + raise ValueError("`run` supports only one positional argument.") + return self(args[0])[self.output_keys[0]] + + if kwargs and not args: + return self(kwargs)[self.output_keys[0]] + + raise ValueError( + f"`run` supported with either positional arguments or keyword arguments" + f" but not both. Got args: {args} and kwargs: {kwargs}." + ) diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index 8fcfa918..ade2d318 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -12,6 +12,7 @@ class FakeChain(Chain, BaseModel): be_correct: bool = True the_input_keys: List[str] = ["foo"] + the_output_keys: List[str] = ["bar"] @property def input_keys(self) -> List[str]: @@ -21,7 +22,7 @@ class FakeChain(Chain, BaseModel): @property def output_keys(self) -> List[str]: """Output key of bar.""" - return ["bar"] + return self.the_output_keys def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: if self.be_correct: @@ -63,3 +64,45 @@ def test_single_input_error() -> None: chain = FakeChain(the_input_keys=["foo", "bar"]) with pytest.raises(ValueError): chain("bar") + + +def test_run_single_arg() -> None: + """Test run method with single arg.""" + chain = FakeChain() + output = chain.run("bar") + assert output == "baz" + + +def test_run_multiple_args_error() -> None: + """Test run method with multiple args errors as expected.""" + chain = FakeChain() + with pytest.raises(ValueError): + chain.run("bar", "foo") + + +def test_run_kwargs() -> None: + """Test run method with kwargs.""" + chain = FakeChain(the_input_keys=["foo", "bar"]) + output = chain.run(foo="bar", bar="foo") + assert output == "baz" + + +def test_run_kwargs_error() -> None: + """Test run method with kwargs errors as expected.""" + chain = FakeChain(the_input_keys=["foo", "bar"]) + with pytest.raises(ValueError): + chain.run(foo="bar", baz="foo") + + +def test_run_args_and_kwargs_error() -> None: + """Test run method with args and kwargs.""" + chain = FakeChain(the_input_keys=["foo", "bar"]) + with pytest.raises(ValueError): + chain.run("bar", foo="bar") + + +def test_multiple_output_keys_error() -> None: + """Test run with multiple output keys errors as expected.""" + chain = FakeChain(the_output_keys=["foo", "bar"]) + with pytest.raises(ValueError): + chain.run("bar")