Add base Chain docstrings (#7114)

pull/6990/head^2
Bagatur 1 year ago committed by GitHub
parent 284d40b7af
commit 87f75cb322
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,7 @@
"""Base interface that all chains should implement."""
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
@ -22,13 +23,35 @@ from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
logger = logging.getLogger(__name__)
def _get_verbosity() -> bool:
return langchain.verbose
class Chain(Serializable, ABC):
"""Base interface that all chains should implement."""
"""Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like
models, document retrievers, other chains, etc., and provide a simple interface
to this sequence.
The Chain interface makes it easy to create apps that are:
- Stateful: add Memory to any Chain to give it state,
- Observable: pass Callbacks to a Chain to execute additional functionality,
like logging, outside the main sequence of component calls,
- Composable: the Chain API is flexible enough that it is easy to combine
Chains with other components, including other Chains.
The main methods exposed by chains are:
- `__call__`: Chains are callable. The `__call__` method is the primary way to
execute a Chain. This takes inputs as a dictionary and returns a
dictionary output.
- `run`: A convenience method that takes inputs as args/kwargs and returns the
output as a string. This method can only be used for a subset of chains and
cannot return as rich of an output as `__call__`.
"""
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.
@ -71,7 +94,7 @@ class Chain(Serializable, ABC):
raise NotImplementedError("Saving not supported for this chain type.")
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
@ -83,9 +106,9 @@ class Chain(Serializable, ABC):
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
"""Set the chain verbosity.
This allows users to pass in None as verbose to access the global setting.
Defaults to the global setting if not specified by the user.
"""
if verbose is None:
return _get_verbosity()
@ -95,12 +118,12 @@ class Chain(Serializable, ABC):
@property
@abstractmethod
def input_keys(self) -> List[str]:
"""Input keys this chain expects."""
"""Return the keys expected to be in the chain input."""
@property
@abstractmethod
def output_keys(self) -> List[str]:
"""Output keys this chain expects."""
"""Return the keys expected to be in the chain output."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Check that all inputs are present."""
@ -119,14 +142,44 @@ class Chain(Serializable, ABC):
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
"""Execute the chain.
This is a private method that is not user-facing. It is only called within
`Chain.__call__`, which is the user-facing wrapper method that handles
callbacks configuration and some input/output processing.
Args:
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
specified in `Chain.input_keys`, including any inputs added by memory.
run_manager: The callbacks manager that contains the callback handlers for
this run of the chain.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
"""Asynchronously execute the chain.
This is a private method that is not user-facing. It is only called within
`Chain.acall`, which is the user-facing wrapper method that handles
callbacks configuration and some input/output processing.
Args:
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
specified in `Chain.input_keys`, including any inputs added by memory.
run_manager: The callbacks manager that contains the callback handlers for
this run of the chain.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
raise NotImplementedError("Async call not supported for this chain type.")
def __call__(
@ -139,21 +192,30 @@ class Chain(Serializable, ABC):
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
"""Execute the chain.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
return_only_outputs: Whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
tags: Optional list of tags associated with the chain. Defaults to None
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
metadata: Optional metadata associated with the chain. Defaults to None
include_run_info: Whether to include run info in the response. Defaults
to False.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure(
@ -197,21 +259,30 @@ class Chain(Serializable, ABC):
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
"""Asynchronously execute the chain.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
return_only_outputs: Whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
tags: Optional list of tags associated with the chain. Defaults to None
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
metadata: Optional metadata associated with the chain. Defaults to None
include_run_info: Whether to include run info in the response. Defaults
to False.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
inputs = self.prep_inputs(inputs)
callback_manager = AsyncCallbackManager.configure(
@ -251,7 +322,18 @@ class Chain(Serializable, ABC):
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Validate and prep outputs."""
"""Validate and prepare chain outputs, and save info about this run to memory.
Args:
inputs: Dictionary of chain inputs, including any inputs added by chain
memory.
outputs: Dictionary of initial chain outputs.
return_only_outputs: Whether to only return the chain outputs. If False,
inputs are also added to the final outputs.
Returns:
A dict of the final chain outputs.
"""
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
@ -261,7 +343,17 @@ class Chain(Serializable, ABC):
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
"""Validate and prep inputs."""
"""Validate and prepare chain inputs, including adding inputs from memory.
Args:
inputs: Dictionary of raw inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
Returns:
A dictionary of all inputs, including those added by the chain's memory.
"""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
@ -282,12 +374,6 @@ class Chain(Serializable, ABC):
self._validate_inputs(inputs)
return inputs
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
@property
def _run_output_key(self) -> str:
if len(self.output_keys) != 1:
@ -305,7 +391,46 @@ class Chain(Serializable, ABC):
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
"""Convenience method for executing chain when there's a single string output.
The main difference between this method and `Chain.__call__` is that this method
can only be used for chains that return a single string output. If a Chain
has more outputs, a non-string output, or you want to return the inputs/run
info along with the outputs, use `Chain.__call__`.
The other difference is that this method expects inputs to be passed directly in
as positional arguments or keyword arguments, whereas `Chain.__call__` expects
a single input dictionary with all the inputs.
Args:
*args: If the chain expects a single input, it can be passed in as the
sole positional argument.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
**kwargs: If the chain expects multiple inputs, they can be passed in
directly as keyword arguments.
Returns:
The chain output as a string.
Example:
.. code-block:: python
# Suppose we have a single-input chain that takes a 'question' string:
chain.run("What's the temperature in Boise, Idaho?")
# -> "The temperature in Boise is..."
# Suppose we have a multi-input chain that takes a 'question' string
# and 'context' string:
question = "What's the temperature in Boise, Idaho?"
context = "Weather report for Boise, Idaho on 07/03/23..."
chain.run(question=question, context=context)
# -> "The temperature in Boise is..."
"""
# Run at start to make sure this is possible/defined
_output_key = self._run_output_key
@ -326,11 +451,11 @@ class Chain(Serializable, ABC):
"`run` supported with either positional arguments or keyword arguments,"
" but none were provided."
)
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
else:
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
async def arun(
self,
@ -340,14 +465,52 @@ class Chain(Serializable, ABC):
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
"""Convenience method for executing chain when there's a single string output.
The main difference between this method and `Chain.__call__` is that this method
can only be used for chains that return a single string output. If a Chain
has more outputs, a non-string output, or you want to return the inputs/run
info along with the outputs, use `Chain.__call__`.
The other difference is that this method expects inputs to be passed directly in
as positional arguments or keyword arguments, whereas `Chain.__call__` expects
a single input dictionary with all the inputs.
Args:
*args: If the chain expects a single input, it can be passed in as the
sole positional argument.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
**kwargs: If the chain expects multiple inputs, they can be passed in
directly as keyword arguments.
Returns:
The chain output as a string.
Example:
.. code-block:: python
# Suppose we have a single-input chain that takes a 'question' string:
await chain.arun("What's the temperature in Boise, Idaho?")
# -> "The temperature in Boise is..."
# Suppose we have a multi-input chain that takes a 'question' string
# and 'context' string:
question = "What's the temperature in Boise, Idaho?"
context = "Weather report for Boise, Idaho on 07/03/23..."
await chain.arun(question=question, context=context)
# -> "The temperature in Boise is..."
"""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
if args and not kwargs:
elif args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return (
@ -369,23 +532,43 @@ class Chain(Serializable, ABC):
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of chain."""
"""Return dictionary representation of chain.
Expects `Chain._chain_type` property to be implemented and for memory to be
null.
Args:
**kwargs: Keyword arguments passed to default `pydantic.BaseModel.dict`
method.
Returns:
A dictionary representation of the chain.
Example:
..code-block:: python
chain.dict(exclude_unset=True)
# -> {"_type": "foo", "verbose": False, ...}
"""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict()
_dict = super().dict(**kwargs)
_dict["_type"] = self._chain_type
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the chain.
Expects `Chain._chain_type` property to be implemented and for memory to be
null.
Args:
file_path: Path to file to save the chain to.
Example:
.. code-block:: python
.. code-block:: python
chain.save(file_path="path/chain.yaml")
chain.save(file_path="path/chain.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
@ -407,3 +590,9 @@ class Chain(Serializable, ABC):
yaml.dump(chain_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]

Loading…
Cancel
Save