mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Add base Chain docstrings (#7114)
This commit is contained in:
parent
284d40b7af
commit
87f75cb322
@ -1,6 +1,7 @@
|
||||
"""Base interface that all chains should implement."""
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
@ -22,13 +23,35 @@ from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
|
||||
|
||||
class Chain(Serializable, ABC):
|
||||
"""Base interface that all chains should implement."""
|
||||
"""Abstract base class for creating structured sequences of calls to components.
|
||||
|
||||
Chains should be used to encode a sequence of calls to components like
|
||||
models, document retrievers, other chains, etc., and provide a simple interface
|
||||
to this sequence.
|
||||
|
||||
The Chain interface makes it easy to create apps that are:
|
||||
- Stateful: add Memory to any Chain to give it state,
|
||||
- Observable: pass Callbacks to a Chain to execute additional functionality,
|
||||
like logging, outside the main sequence of component calls,
|
||||
- Composable: the Chain API is flexible enough that it is easy to combine
|
||||
Chains with other components, including other Chains.
|
||||
|
||||
The main methods exposed by chains are:
|
||||
- `__call__`: Chains are callable. The `__call__` method is the primary way to
|
||||
execute a Chain. This takes inputs as a dictionary and returns a
|
||||
dictionary output.
|
||||
- `run`: A convenience method that takes inputs as args/kwargs and returns the
|
||||
output as a string. This method can only be used for a subset of chains and
|
||||
cannot return as rich of an output as `__call__`.
|
||||
"""
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
"""Optional memory object. Defaults to None.
|
||||
@ -71,7 +94,7 @@ class Chain(Serializable, ABC):
|
||||
raise NotImplementedError("Saving not supported for this chain type.")
|
||||
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
@ -83,9 +106,9 @@ class Chain(Serializable, ABC):
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""If verbose is None, set it.
|
||||
"""Set the chain verbosity.
|
||||
|
||||
This allows users to pass in None as verbose to access the global setting.
|
||||
Defaults to the global setting if not specified by the user.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
@ -95,12 +118,12 @@ class Chain(Serializable, ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Input keys this chain expects."""
|
||||
"""Return the keys expected to be in the chain input."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Output keys this chain expects."""
|
||||
"""Return the keys expected to be in the chain output."""
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Check that all inputs are present."""
|
||||
@ -119,14 +142,44 @@ class Chain(Serializable, ABC):
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
"""Execute the chain.
|
||||
|
||||
This is a private method that is not user-facing. It is only called within
|
||||
`Chain.__call__`, which is the user-facing wrapper method that handles
|
||||
callbacks configuration and some input/output processing.
|
||||
|
||||
Args:
|
||||
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
|
||||
specified in `Chain.input_keys`, including any inputs added by memory.
|
||||
run_manager: The callbacks manager that contains the callback handlers for
|
||||
this run of the chain.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
"""Asynchronously execute the chain.
|
||||
|
||||
This is a private method that is not user-facing. It is only called within
|
||||
`Chain.acall`, which is the user-facing wrapper method that handles
|
||||
callbacks configuration and some input/output processing.
|
||||
|
||||
Args:
|
||||
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
|
||||
specified in `Chain.input_keys`, including any inputs added by memory.
|
||||
run_manager: The callbacks manager that contains the callback handlers for
|
||||
this run of the chain.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
raise NotImplementedError("Async call not supported for this chain type.")
|
||||
|
||||
def __call__(
|
||||
@ -139,21 +192,30 @@ class Chain(Serializable, ABC):
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
"""Execute the chain.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param.
|
||||
return_only_outputs: boolean for whether to return only outputs in the
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
return_only_outputs: Whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. If not provided, will
|
||||
use the callbacks provided to the chain.
|
||||
tags: Optional list of tags associated with the chain. Defaults to None
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
metadata: Optional metadata associated with the chain. Defaults to None
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
callback_manager = CallbackManager.configure(
|
||||
@ -197,21 +259,30 @@ class Chain(Serializable, ABC):
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
"""Asynchronously execute the chain.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param.
|
||||
return_only_outputs: boolean for whether to return only outputs in the
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
return_only_outputs: Whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. If not provided, will
|
||||
use the callbacks provided to the chain.
|
||||
tags: Optional list of tags associated with the chain. Defaults to None
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
metadata: Optional metadata associated with the chain. Defaults to None
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
@ -251,7 +322,18 @@ class Chain(Serializable, ABC):
|
||||
outputs: Dict[str, str],
|
||||
return_only_outputs: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prep outputs."""
|
||||
"""Validate and prepare chain outputs, and save info about this run to memory.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of chain inputs, including any inputs added by chain
|
||||
memory.
|
||||
outputs: Dictionary of initial chain outputs.
|
||||
return_only_outputs: Whether to only return the chain outputs. If False,
|
||||
inputs are also added to the final outputs.
|
||||
|
||||
Returns:
|
||||
A dict of the final chain outputs.
|
||||
"""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
@ -261,7 +343,17 @@ class Chain(Serializable, ABC):
|
||||
return {**inputs, **outputs}
|
||||
|
||||
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
|
||||
"""Validate and prep inputs."""
|
||||
"""Validate and prepare chain inputs, including adding inputs from memory.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of raw inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
|
||||
Returns:
|
||||
A dictionary of all inputs, including those added by the chain's memory.
|
||||
"""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
@ -282,12 +374,6 @@ class Chain(Serializable, ABC):
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
|
||||
@property
|
||||
def _run_output_key(self) -> str:
|
||||
if len(self.output_keys) != 1:
|
||||
@ -305,7 +391,46 @@ class Chain(Serializable, ABC):
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
"""Convenience method for executing chain when there's a single string output.
|
||||
|
||||
The main difference between this method and `Chain.__call__` is that this method
|
||||
can only be used for chains that return a single string output. If a Chain
|
||||
has more outputs, a non-string output, or you want to return the inputs/run
|
||||
info along with the outputs, use `Chain.__call__`.
|
||||
|
||||
The other difference is that this method expects inputs to be passed directly in
|
||||
as positional arguments or keyword arguments, whereas `Chain.__call__` expects
|
||||
a single input dictionary with all the inputs.
|
||||
|
||||
Args:
|
||||
*args: If the chain expects a single input, it can be passed in as the
|
||||
sole positional argument.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
**kwargs: If the chain expects multiple inputs, they can be passed in
|
||||
directly as keyword arguments.
|
||||
|
||||
Returns:
|
||||
The chain output as a string.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Suppose we have a single-input chain that takes a 'question' string:
|
||||
chain.run("What's the temperature in Boise, Idaho?")
|
||||
# -> "The temperature in Boise is..."
|
||||
|
||||
# Suppose we have a multi-input chain that takes a 'question' string
|
||||
# and 'context' string:
|
||||
question = "What's the temperature in Boise, Idaho?"
|
||||
context = "Weather report for Boise, Idaho on 07/03/23..."
|
||||
chain.run(question=question, context=context)
|
||||
# -> "The temperature in Boise is..."
|
||||
"""
|
||||
# Run at start to make sure this is possible/defined
|
||||
_output_key = self._run_output_key
|
||||
|
||||
@ -326,7 +451,7 @@ class Chain(Serializable, ABC):
|
||||
"`run` supported with either positional arguments or keyword arguments,"
|
||||
" but none were provided."
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
@ -340,14 +465,52 @@ class Chain(Serializable, ABC):
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
"""Convenience method for executing chain when there's a single string output.
|
||||
|
||||
The main difference between this method and `Chain.__call__` is that this method
|
||||
can only be used for chains that return a single string output. If a Chain
|
||||
has more outputs, a non-string output, or you want to return the inputs/run
|
||||
info along with the outputs, use `Chain.__call__`.
|
||||
|
||||
The other difference is that this method expects inputs to be passed directly in
|
||||
as positional arguments or keyword arguments, whereas `Chain.__call__` expects
|
||||
a single input dictionary with all the inputs.
|
||||
|
||||
Args:
|
||||
*args: If the chain expects a single input, it can be passed in as the
|
||||
sole positional argument.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
**kwargs: If the chain expects multiple inputs, they can be passed in
|
||||
directly as keyword arguments.
|
||||
|
||||
Returns:
|
||||
The chain output as a string.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Suppose we have a single-input chain that takes a 'question' string:
|
||||
await chain.arun("What's the temperature in Boise, Idaho?")
|
||||
# -> "The temperature in Boise is..."
|
||||
|
||||
# Suppose we have a multi-input chain that takes a 'question' string
|
||||
# and 'context' string:
|
||||
question = "What's the temperature in Boise, Idaho?"
|
||||
context = "Weather report for Boise, Idaho on 07/03/23..."
|
||||
await chain.arun(question=question, context=context)
|
||||
# -> "The temperature in Boise is..."
|
||||
"""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
f"`run` not supported when there is not exactly "
|
||||
f"one output key. Got {self.output_keys}."
|
||||
)
|
||||
|
||||
if args and not kwargs:
|
||||
elif args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return (
|
||||
@ -369,16 +532,36 @@ class Chain(Serializable, ABC):
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of chain."""
|
||||
"""Return dictionary representation of chain.
|
||||
|
||||
Expects `Chain._chain_type` property to be implemented and for memory to be
|
||||
null.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments passed to default `pydantic.BaseModel.dict`
|
||||
method.
|
||||
|
||||
Returns:
|
||||
A dictionary representation of the chain.
|
||||
|
||||
Example:
|
||||
..code-block:: python
|
||||
|
||||
chain.dict(exclude_unset=True)
|
||||
# -> {"_type": "foo", "verbose": False, ...}
|
||||
"""
|
||||
if self.memory is not None:
|
||||
raise ValueError("Saving of memory is not yet supported.")
|
||||
_dict = super().dict()
|
||||
_dict = super().dict(**kwargs)
|
||||
_dict["_type"] = self._chain_type
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the chain.
|
||||
|
||||
Expects `Chain._chain_type` property to be implemented and for memory to be
|
||||
null.
|
||||
|
||||
Args:
|
||||
file_path: Path to file to save the chain to.
|
||||
|
||||
@ -407,3 +590,9 @@ class Chain(Serializable, ABC):
|
||||
yaml.dump(chain_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
|
Loading…
Reference in New Issue
Block a user