make attrs public (#187)

since they are used outside of the class, should be public
harrison/output_parser
Harrison Chase 2 years ago committed by GitHub
parent ae9c6257fe
commit b913df3774
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -20,11 +20,11 @@ class Memory(BaseModel, ABC):
"""Input keys this memory class will load dynamically."""
@abstractmethod
def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return key-value pairs given the text input to the chain."""
@abstractmethod
def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this model run to memory."""
@ -77,7 +77,7 @@ class Chain(BaseModel, ABC):
"""
if self.memory is not None:
external_context = self.memory._load_dynamic_keys(inputs)
external_context = self.memory.load_dynamic_keys(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
if self.verbose:
@ -87,7 +87,7 @@ class Chain(BaseModel, ABC):
print("\n\033[1m> Finished chain.\033[0m")
self._validate_outputs(outputs)
if self.memory is not None:
self.memory._save_context(inputs, outputs)
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:

@ -24,11 +24,11 @@ class ConversationBufferMemory(Memory, BaseModel):
"""
return [self.dynamic_key]
def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return {self.dynamic_key: self.buffer}
def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
prompt_input_keys = list(set(inputs).difference(self.dynamic_keys))
if len(prompt_input_keys) != 1:
@ -56,7 +56,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
"""
return [self.dynamic_key]
def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return {self.dynamic_key: self.buffer}
@ -72,7 +72,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
)
return values
def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
prompt_input_keys = list(set(inputs).difference(self.dynamic_keys))
if len(prompt_input_keys) != 1:

@ -50,19 +50,19 @@ def test_conversation_memory(memory: Memory) -> None:
good_inputs = {"foo": "bar", "baz": "foo"}
# This is a good output because these is one variable.
good_outputs = {"bar": "foo"}
memory._save_context(good_inputs, good_outputs)
memory.save_context(good_inputs, good_outputs)
# This is a bad input because there are two variables that aren't the same as baz.
bad_inputs = {"foo": "bar", "foo1": "bar"}
with pytest.raises(ValueError):
memory._save_context(bad_inputs, good_outputs)
memory.save_context(bad_inputs, good_outputs)
# This is a bad input because the only variable is the same as baz.
bad_inputs = {"baz": "bar"}
with pytest.raises(ValueError):
memory._save_context(bad_inputs, good_outputs)
memory.save_context(bad_inputs, good_outputs)
# This is a bad output because it is empty.
with pytest.raises(ValueError):
memory._save_context(good_inputs, {})
memory.save_context(good_inputs, {})
# This is a bad output because there are two keys.
bad_outputs = {"foo": "bar", "foo1": "bar"}
with pytest.raises(ValueError):
memory._save_context(good_inputs, bad_outputs)
memory.save_context(good_inputs, bad_outputs)

Loading…
Cancel
Save