Harrison/memory check (#2119)

Co-authored-by: JIAQIA <jqq1716@gmail.com>
This commit is contained in:
Harrison Chase 2023-03-28 15:40:36 -07:00 committed by GitHub
parent 3e879b47c1
commit e2c26909f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 2 deletions

View File

@ -46,8 +46,8 @@ class SequentialChain(Chain, BaseModel):
if "memory" in values and values["memory"] is not None:
"""Validate that prompt input variables are consistent."""
memory_keys = values["memory"].memory_variables
if any(input_variables) in memory_keys:
overlapping_keys = input_variables & memory_keys
if set(input_variables).intersection(set(memory_keys)):
overlapping_keys = set(input_variables) & set(memory_keys)
raise ValueError(
f"The the input key(s) {''.join(overlapping_keys)} are found "
f"in the Memory keys ({memory_keys}) - please use input and "

View File

@ -68,6 +68,13 @@ def test_sequential_usage_memory() -> None:
output = chain({"foo": "123"})
expected_output = {"baz": "123foofoo", "foo": "123", "zab": "rab"}
assert output == expected_output
memory = SimpleMemory(memories={"zab": "rab", "foo": "rab"})
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
with pytest.raises(ValueError):
SequentialChain(
memory=memory, chains=[chain_1, chain_2], input_variables=["foo"]
)
def test_sequential_usage_multiple_outputs() -> None: