check memory variables (#411)

can have multiple input keys, if some come from memory
harrison/map-rerank
Harrison Chase 2 years ago committed by GitHub
parent f990395211
commit 20959d8c36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -963,7 +963,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.10.8"
}
},
"nbformat": 4,

@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "91d307ed",
"metadata": {},
"outputs": [],
@ -56,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "10a93bf9",
"metadata": {},
"outputs": [],
@ -77,7 +77,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "fa0f3066",
"metadata": {},
"outputs": [],
@ -96,7 +96,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "8465b4b7",
"metadata": {},
"outputs": [],
@ -107,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"id": "611be801",
"metadata": {},
"outputs": [
@ -143,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"id": "b6255b02",
"metadata": {},
"outputs": [],
@ -153,7 +153,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"id": "ec4eacad",
"metadata": {},
"outputs": [],
@ -163,17 +163,17 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"id": "59c7508d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" The president said that Ketanji Brown Jackson is one of our nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. The president also said that Ketanji Brown Jackson is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds and will continue Justice Breyer's legacy of excellence.\""
]
},
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}

@ -88,14 +88,19 @@ class Chain(BaseModel, ABC):
"""
if not isinstance(inputs, dict):
if len(self.input_keys) != 1:
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({self.input_keys}). When a chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {self.input_keys[0]: inputs}
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)

@ -1,10 +1,31 @@
"""Test logic on base chain class."""
from typing import Dict, List
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.base import Chain, Memory
class FakeMemory(Memory, BaseModel):
"""Fake memory class for testing purposes."""
@property
def memory_variables(self) -> List[str]:
"""Return baz variable."""
return ["baz"]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return baz variable."""
return {"baz": "foo"}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Pass."""
pass
def clear(self) -> None:
"""Pass."""
pass
class FakeChain(Chain, BaseModel):
@ -106,3 +127,9 @@ def test_multiple_output_keys_error() -> None:
chain = FakeChain(the_output_keys=["foo", "bar"])
with pytest.raises(ValueError):
chain.run("bar")
def test_run_arg_with_memory() -> None:
"""Test run method works when arg is passed."""
chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory())
chain.run("bar")

Loading…
Cancel
Save