forked from Archives/langchain
Harrison/memory check (#2119)
Co-authored-by: JIAQIA <jqq1716@gmail.com>
This commit is contained in:
parent
3e879b47c1
commit
e2c26909f2
@ -46,8 +46,8 @@ class SequentialChain(Chain, BaseModel):
|
||||
if "memory" in values and values["memory"] is not None:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
memory_keys = values["memory"].memory_variables
|
||||
if any(input_variables) in memory_keys:
|
||||
overlapping_keys = input_variables & memory_keys
|
||||
if set(input_variables).intersection(set(memory_keys)):
|
||||
overlapping_keys = set(input_variables) & set(memory_keys)
|
||||
raise ValueError(
|
||||
f"The the input key(s) {''.join(overlapping_keys)} are found "
|
||||
f"in the Memory keys ({memory_keys}) - please use input and "
|
||||
|
@ -68,6 +68,13 @@ def test_sequential_usage_memory() -> None:
|
||||
output = chain({"foo": "123"})
|
||||
expected_output = {"baz": "123foofoo", "foo": "123", "zab": "rab"}
|
||||
assert output == expected_output
|
||||
memory = SimpleMemory(memories={"zab": "rab", "foo": "rab"})
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
SequentialChain(
|
||||
memory=memory, chains=[chain_1, chain_2], input_variables=["foo"]
|
||||
)
|
||||
|
||||
|
||||
def test_sequential_usage_multiple_outputs() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user