fix(memory): allow internal chains to use memory (#6769)

Fixed #6768.

This is a workaround only. I think a better longer-term solution is for
chains to declare how many input variables they *actually* need (as
opposed to ones that are in the prompt, where some may be satisfied by
the memory). Then, a wrapping chain can check the input match against
the actual input variables.

@hwchase17
pull/7643/head
Nir Gazit 1 year ago committed by GitHub
parent 488d2d5da9
commit f307ca094b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -62,6 +62,9 @@ class SequentialChain(Chain):
for chain in chains: for chain in chains:
missing_vars = set(chain.input_keys).difference(known_variables) missing_vars = set(chain.input_keys).difference(known_variables)
if chain.memory:
missing_vars = missing_vars.difference(chain.memory.memory_variables)
if missing_vars: if missing_vars:
raise ValueError( raise ValueError(
f"Missing required input keys: {missing_vars}, " f"Missing required input keys: {missing_vars}, "

@ -6,6 +6,7 @@ import pytest
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.memory import ConversationBufferMemory
from langchain.memory.simple import SimpleMemory from langchain.memory.simple import SimpleMemory
@ -81,6 +82,21 @@ def test_sequential_usage_memory() -> None:
) )
def test_sequential_internal_chain_use_memory() -> None:
"""Test sequential usage with memory for one of the internal chains."""
memory = ConversationBufferMemory(memory_key="bla")
memory.save_context({"input": "yo"}, {"output": "ya"})
chain_1 = FakeChain(
input_variables=["foo", "bla"], output_variables=["bar"], memory=memory
)
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
output = chain({"foo": "123"})
print("HEYYY OUTPUT", output)
expected_output = {"foo": "123", "baz": "123 Human: yo\nAI: yafoofoo"}
assert output == expected_output
def test_sequential_usage_multiple_outputs() -> None: def test_sequential_usage_multiple_outputs() -> None:
"""Test sequential usage on multiple output chains.""" """Test sequential usage on multiple output chains."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"]) chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])

Loading…
Cancel
Save