From f307ca094b0d175d71ac424eba3d9f7ef5fc44f1 Mon Sep 17 00:00:00 2001 From: Nir Gazit Date: Thu, 13 Jul 2023 09:47:44 +0300 Subject: [PATCH] fix(memory): allow internal chains to use memory (#6769) Fixed #6768. This is a workaround only. I think a better longer-term solution is for chains to declare how many input variables they *actually* need (as opposed to ones that are in the prompt, where some may be satisfied by the memory). Then, a wrapping chain can check the input match against the actual input variables. @hwchase17 --- langchain/chains/sequential.py | 3 +++ tests/unit_tests/chains/test_sequential.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py index 6877df2b7e..078a0d7fff 100644 --- a/langchain/chains/sequential.py +++ b/langchain/chains/sequential.py @@ -62,6 +62,9 @@ class SequentialChain(Chain): for chain in chains: missing_vars = set(chain.input_keys).difference(known_variables) + if chain.memory: + missing_vars = missing_vars.difference(chain.memory.memory_variables) + if missing_vars: raise ValueError( f"Missing required input keys: {missing_vars}, " diff --git a/tests/unit_tests/chains/test_sequential.py b/tests/unit_tests/chains/test_sequential.py index 19e7df1066..12f72c6fe7 100644 --- a/tests/unit_tests/chains/test_sequential.py +++ b/tests/unit_tests/chains/test_sequential.py @@ -6,6 +6,7 @@ import pytest from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain +from langchain.memory import ConversationBufferMemory from langchain.memory.simple import SimpleMemory @@ -81,6 +82,21 @@ def test_sequential_usage_memory() -> None: ) +def test_sequential_internal_chain_use_memory() -> None: + """Test sequential usage with memory for one of the internal chains.""" + memory = ConversationBufferMemory(memory_key="bla") + memory.save_context({"input": "yo"}, {"output": "ya"}) + chain_1 = FakeChain( + input_variables=["foo", "bla"], output_variables=["bar"], memory=memory + ) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"]) + output = chain({"foo": "123"}) + print("HEYYY OUTPUT", output) + expected_output = {"foo": "123", "baz": "123 Human: yo\nAI: yafoofoo"} + assert output == expected_output + + def test_sequential_usage_multiple_outputs() -> None: """Test sequential usage on multiple output chains.""" chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])