From e2c26909f2f78f35bb45121c3a10ec28d2a7f448 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 28 Mar 2023 15:40:36 -0700 Subject: [PATCH] Harrison/memory check (#2119) Co-authored-by: JIAQIA --- langchain/chains/sequential.py | 4 ++-- tests/unit_tests/chains/test_sequential.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py index 9d5d66be..9b2d2dd4 100644 --- a/langchain/chains/sequential.py +++ b/langchain/chains/sequential.py @@ -46,8 +46,8 @@ class SequentialChain(Chain, BaseModel): if "memory" in values and values["memory"] is not None: """Validate that prompt input variables are consistent.""" memory_keys = values["memory"].memory_variables - if any(input_variables) in memory_keys: - overlapping_keys = input_variables & memory_keys + if set(input_variables).intersection(set(memory_keys)): + overlapping_keys = set(input_variables) & set(memory_keys) raise ValueError( f"The the input key(s) {''.join(overlapping_keys)} are found " f"in the Memory keys ({memory_keys}) - please use input and " diff --git a/tests/unit_tests/chains/test_sequential.py b/tests/unit_tests/chains/test_sequential.py index 74947f9f..c6021a1c 100644 --- a/tests/unit_tests/chains/test_sequential.py +++ b/tests/unit_tests/chains/test_sequential.py @@ -68,6 +68,13 @@ def test_sequential_usage_memory() -> None: output = chain({"foo": "123"}) expected_output = {"baz": "123foofoo", "foo": "123", "zab": "rab"} assert output == expected_output + memory = SimpleMemory(memories={"zab": "rab", "foo": "rab"}) + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + with pytest.raises(ValueError): + SequentialChain( + memory=memory, chains=[chain_1, chain_2], input_variables=["foo"] + ) def test_sequential_usage_multiple_outputs() -> None: