Harrison/single input (#347)

allow passing of single input into chain

Co-authored-by: thepok <richterthepok@yahoo.de>
harrison/agent_multi_inputs^2
Harrison Chase 2 years ago committed by GitHub
parent 5161ae7e08
commit 8cf62ce06e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
"""Base interface that all chains should implement."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Extra, Field
@ -74,18 +74,28 @@ class Chain(BaseModel, ABC):
"""Run the logic of this chain and return the output."""
def __call__(
self, inputs: Dict[str, Any], return_only_outputs: bool = False
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
) -> Dict[str, str]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs.
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
"""
if not isinstance(inputs, dict):
if len(self.input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({self.input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {self.input_keys[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)

@ -11,11 +11,12 @@ class FakeChain(Chain, BaseModel):
"""Fake chain class for testing purposes."""
be_correct: bool = True
the_input_keys: List[str] = ["foo"]
@property
def input_keys(self) -> List[str]:
"""Input key of foo."""
return ["foo"]
"""Input keys."""
return self.the_input_keys
@property
def output_keys(self) -> List[str]:
@ -48,3 +49,17 @@ def test_correct_call() -> None:
chain = FakeChain()
output = chain({"foo": "bar"})
assert output == {"foo": "bar", "bar": "baz"}
def test_single_input_correct() -> None:
"""Test passing single input works."""
chain = FakeChain()
output = chain("bar")
assert output == {"foo": "bar", "bar": "baz"}
def test_single_input_error() -> None:
"""Test passing single input errors as expected."""
chain = FakeChain(the_input_keys=["foo", "bar"])
with pytest.raises(ValueError):
chain("bar")

Loading…
Cancel
Save