refactor chain

ankush/async-llmchain
Ankush Gola 1 year ago
parent a9af287297
commit 103c3046e8

@ -111,6 +111,10 @@ class Chain(BaseModel, ABC):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Run the logic of this chain and return the output."""
async def _async_call(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Run the logic of this chain and return the output."""
raise NotImplementedError("Async call not supported for this chain type.")
def __call__(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
) -> Dict[str, Any]:
@ -125,24 +129,7 @@ class Chain(BaseModel, ABC):
chain will be returned. Defaults to False.
"""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
inputs = self.prep_inputs(inputs)
self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
@ -154,6 +141,37 @@ class Chain(BaseModel, ABC):
self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
return self.prep_outputs(inputs, outputs, return_only_outputs)
async def async_call(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
"""
inputs = self.prep_inputs(inputs)
self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
try:
outputs = await self._async_call(inputs)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
return self.prep_outputs(inputs, outputs, return_only_outputs)
def prep_outputs(self, inputs, outputs, return_only_outputs):
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
@ -162,6 +180,27 @@ class Chain(BaseModel, ABC):
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs):
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs) for inputs in input_list]

Loading…
Cancel
Save