forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
376 lines
14 KiB
Python
376 lines
14 KiB
Python
"""Base interface that all chains should implement."""
|
|
import inspect
|
|
import json
|
|
import warnings
|
|
from abc import ABC, abstractmethod
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
import yaml
|
|
from pydantic import Field, root_validator, validator
|
|
|
|
import langchain
|
|
from langchain.callbacks.base import BaseCallbackManager
|
|
from langchain.callbacks.manager import (
|
|
AsyncCallbackManager,
|
|
AsyncCallbackManagerForChainRun,
|
|
CallbackManager,
|
|
CallbackManagerForChainRun,
|
|
Callbacks,
|
|
)
|
|
from langchain.load.dump import dumpd
|
|
from langchain.load.serializable import Serializable
|
|
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
|
|
|
|
|
def _get_verbosity() -> bool:
    """Return the current value of the global ``langchain.verbose`` flag."""
    verbosity = langchain.verbose
    return verbosity
|
|
|
|
|
|
class Chain(Serializable, ABC):
    """Base interface that all chains should implement.

    Subclasses must define ``input_keys``, ``output_keys`` and ``_call``.
    ``_acall`` is optional; the default implementation raises
    ``NotImplementedError``.
    """

    memory: Optional[BaseMemory] = None
    """Optional memory object. Defaults to None.
    Memory is a class that gets called at the start
    and at the end of every chain. At the start, memory loads variables and passes
    them along in the chain. At the end, it saves any returned variables.
    There are many different types of memory - please see memory docs
    for the full catalog."""
    callbacks: Callbacks = Field(default=None, exclude=True)
    """Optional list of callback handlers (or callback manager). Defaults to None.
    Callback handlers are called throughout the lifecycle of a call to a chain,
    starting with on_chain_start, ending with on_chain_end or on_chain_error.
    Each custom chain can optionally call additional callback methods, see Callback docs
    for full details."""
    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
    """Deprecated, use `callbacks` instead."""
    verbose: bool = Field(default_factory=_get_verbosity)
    """Whether or not to run in verbose mode. In verbose mode, some intermediate logs
    will be printed to the console. Defaults to `langchain.verbose` value."""
    tags: Optional[List[str]] = None
    """Optional list of tags associated with the chain. Defaults to None.
    These tags will be associated with each call to this chain,
    and passed as arguments to the handlers defined in `callbacks`.
    You can use these to eg identify a specific instance of a chain with its use case.
    """

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def _chain_type(self) -> str:
        """Type tag stored under ``"_type"`` by :meth:`dict`.

        Raises:
            NotImplementedError: always, unless a subclass overrides it.
        """
        raise NotImplementedError("Saving not supported for this chain type.")

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used.

        When ``callback_manager`` is set, its value is moved into
        ``callbacks`` so the rest of the class only reads ``callbacks``.
        """
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            # Only migrate when the deprecated field was actually supplied;
            # otherwise an explicitly passed `callbacks` must be left alone.
            values["callbacks"] = values.pop("callback_manager", None)
        return values

    @validator("verbose", pre=True, always=True)
    def set_verbose(cls, verbose: Optional[bool]) -> bool:
        """If verbose is None, set it.

        This allows users to pass in None as verbose to access the global setting.
        """
        if verbose is None:
            return _get_verbosity()
        return verbose

    @property
    @abstractmethod
    def input_keys(self) -> List[str]:
        """Input keys this chain expects."""

    @property
    @abstractmethod
    def output_keys(self) -> List[str]:
        """Output keys this chain expects."""

    def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
        """Check that all inputs are present.

        Raises:
            ValueError: if any key in ``input_keys`` is absent from ``inputs``.
        """
        missing_keys = set(self.input_keys).difference(inputs)
        if missing_keys:
            raise ValueError(f"Missing some input keys: {missing_keys}")

    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
        """Check that all outputs are present.

        Raises:
            ValueError: if any key in ``output_keys`` is absent from ``outputs``.
        """
        missing_keys = set(self.output_keys).difference(outputs)
        if missing_keys:
            raise ValueError(f"Missing some output keys: {missing_keys}")

    @abstractmethod
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and return the output."""

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and return the output."""
        raise NotImplementedError("Async call not supported for this chain type.")

    def __call__(
        self,
        inputs: Union[Dict[str, Any], Any],
        return_only_outputs: bool = False,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        include_run_info: bool = False,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and add to output if desired.

        Args:
            inputs: Dictionary of inputs, or single input if chain expects
                only one param.
            return_only_outputs: boolean for whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.
            callbacks: Callbacks to use for this chain run. If not provided, will
                use the callbacks provided to the chain.
            tags: Optional tags for this run; passed to the callback manager
                together with the chain's own ``tags``.
            include_run_info: Whether to include run info in the response. Defaults
                to False.

        Returns:
            The prepared outputs (see :meth:`prep_outputs`), plus a ``RunInfo``
            entry under ``RUN_KEY`` when ``include_run_info`` is True.
        """
        inputs = self.prep_inputs(inputs)
        callback_manager = CallbackManager.configure(
            callbacks, self.callbacks, self.verbose, tags, self.tags
        )
        # Older subclasses may define `_call` without `run_manager`; only pass
        # it when the signature accepts it.
        new_arg_supported = inspect.signature(self._call).parameters.get(
            "run_manager"
        )
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            inputs,
        )
        try:
            outputs = (
                self._call(inputs, run_manager=run_manager)
                if new_arg_supported
                else self._call(inputs)
            )
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise
        run_manager.on_chain_end(outputs)
        final_outputs: Dict[str, Any] = self.prep_outputs(
            inputs, outputs, return_only_outputs
        )
        if include_run_info:
            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
        return final_outputs

    async def acall(
        self,
        inputs: Union[Dict[str, Any], Any],
        return_only_outputs: bool = False,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        include_run_info: bool = False,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and add to output if desired.

        Async counterpart of :meth:`__call__`; awaits ``_acall`` and the
        callback hooks.

        Args:
            inputs: Dictionary of inputs, or single input if chain expects
                only one param.
            return_only_outputs: boolean for whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.
            callbacks: Callbacks to use for this chain run. If not provided, will
                use the callbacks provided to the chain.
            tags: Optional tags for this run; passed to the callback manager
                together with the chain's own ``tags``.
            include_run_info: Whether to include run info in the response. Defaults
                to False.

        Returns:
            The prepared outputs (see :meth:`prep_outputs`), plus a ``RunInfo``
            entry under ``RUN_KEY`` when ``include_run_info`` is True.
        """
        inputs = self.prep_inputs(inputs)
        callback_manager = AsyncCallbackManager.configure(
            callbacks, self.callbacks, self.verbose, tags, self.tags
        )
        # Older subclasses may define `_acall` without `run_manager`; only pass
        # it when the signature accepts it.
        new_arg_supported = inspect.signature(self._acall).parameters.get(
            "run_manager"
        )
        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            inputs,
        )
        try:
            outputs = (
                await self._acall(inputs, run_manager=run_manager)
                if new_arg_supported
                else await self._acall(inputs)
            )
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_chain_error(e)
            raise
        await run_manager.on_chain_end(outputs)
        final_outputs: Dict[str, Any] = self.prep_outputs(
            inputs, outputs, return_only_outputs
        )
        if include_run_info:
            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
        return final_outputs

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
    ) -> Dict[str, str]:
        """Validate and prep outputs.

        Validates ``outputs`` against ``output_keys`` and, when memory is
        configured, saves the inputs/outputs pair to it.

        Returns:
            ``outputs`` alone when ``return_only_outputs`` is True, otherwise
            a new dict merging ``inputs`` and ``outputs`` (outputs win on
            key collisions).
        """
        self._validate_outputs(outputs)
        if self.memory is not None:
            self.memory.save_context(inputs, outputs)
        if return_only_outputs:
            return outputs
        return {**inputs, **outputs}

    def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
        """Validate and prep inputs.

        A non-dict ``inputs`` is wrapped into a dict under the single input
        key that is not supplied by memory. Memory variables, if any, are then
        merged in and the result is validated against ``input_keys``.

        Raises:
            ValueError: if a non-dict input is given but the chain does not
                have exactly one remaining input key, or if required input
                keys are missing after merging.
        """
        if not isinstance(inputs, dict):
            _input_keys = set(self.input_keys)
            if self.memory is not None:
                # If there are multiple input keys, but some get set by memory so that
                # only one is not set, we can still figure out which key it is.
                _input_keys = _input_keys.difference(self.memory.memory_variables)
            if len(_input_keys) != 1:
                raise ValueError(
                    f"A single string input was passed in, but this chain expects "
                    f"multiple inputs ({_input_keys}). When a chain expects "
                    f"multiple inputs, please call it by passing in a dictionary, "
                    "eg `chain({'foo': 1, 'bar': 2})`"
                )
            inputs = {list(_input_keys)[0]: inputs}
        if self.memory is not None:
            external_context = self.memory.load_memory_variables(inputs)
            inputs = dict(inputs, **external_context)
        self._validate_inputs(inputs)
        return inputs

    def apply(
        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
    ) -> List[Dict[str, str]]:
        """Call the chain on all inputs in the list."""
        return [self(inputs, callbacks=callbacks) for inputs in input_list]

    @property
    def _run_output_key(self) -> str:
        """The single output key used by :meth:`run` and :meth:`arun`.

        Raises:
            ValueError: if the chain does not have exactly one output key.
        """
        if len(self.output_keys) != 1:
            raise ValueError(
                f"`run` not supported when there is not exactly "
                f"one output key. Got {self.output_keys}."
            )
        return self.output_keys[0]

    def run(
        self,
        *args: Any,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """Run the chain as text in, text out or multiple variables, text out.

        Accepts either exactly one positional argument or keyword arguments,
        never both, and returns the value stored under the single output key.

        Raises:
            ValueError: if the chain does not have exactly one output key, if
                more than one positional argument is given, if no arguments
                are given, or if positional and keyword arguments are mixed.
        """
        # Run at start to make sure this is possible/defined
        _output_key = self._run_output_key

        if args and not kwargs:
            if len(args) != 1:
                raise ValueError("`run` supports only one positional argument.")
            return self(args[0], callbacks=callbacks, tags=tags)[_output_key]

        if kwargs and not args:
            return self(kwargs, callbacks=callbacks, tags=tags)[_output_key]

        if not kwargs and not args:
            raise ValueError(
                "`run` supported with either positional arguments or keyword arguments,"
                " but none were provided."
            )

        raise ValueError(
            f"`run` supported with either positional arguments or keyword arguments"
            f" but not both. Got args: {args} and kwargs: {kwargs}."
        )

    async def arun(
        self,
        *args: Any,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """Run the chain as text in, text out or multiple variables, text out.

        Async counterpart of :meth:`run`, with the same argument rules and
        the same errors.

        Raises:
            ValueError: if the chain does not have exactly one output key, if
                more than one positional argument is given, if no arguments
                are given, or if positional and keyword arguments are mixed.
        """
        # Run at start to make sure this is possible/defined
        _output_key = self._run_output_key

        if args and not kwargs:
            if len(args) != 1:
                raise ValueError("`run` supports only one positional argument.")
            result = await self.acall(args[0], callbacks=callbacks, tags=tags)
            return result[_output_key]

        if kwargs and not args:
            result = await self.acall(kwargs, callbacks=callbacks, tags=tags)
            return result[_output_key]

        if not kwargs and not args:
            raise ValueError(
                "`run` supported with either positional arguments or keyword arguments,"
                " but none were provided."
            )

        raise ValueError(
            f"`run` supported with either positional arguments or keyword arguments"
            f" but not both. Got args: {args} and kwargs: {kwargs}."
        )

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of chain.

        Args:
            **kwargs: Forwarded unchanged to the parent class's ``dict``.

        Returns:
            The parent's dictionary with an added ``"_type"`` entry holding
            ``_chain_type``.

        Raises:
            ValueError: if the chain has memory attached.
        """
        if self.memory is not None:
            raise ValueError("Saving of memory is not yet supported.")
        _dict = super().dict(**kwargs)
        _dict["_type"] = self._chain_type
        return _dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the chain.

        Args:
            file_path: Path to file to save the chain to. The suffix selects
                the format and must be ``.json`` or ``.yaml``.

        Raises:
            ValueError: if the suffix is neither ``.json`` nor ``.yaml``, or
                if :meth:`dict` refuses to serialise the chain.

        Example:
        .. code-block:: python

            chain.save(file_path="path/chain.yaml")
        """
        # Convert file to Path object.
        save_path = Path(file_path)

        # Fetch dictionary to save
        chain_dict = self.dict()

        # Validate before touching the filesystem so a rejected save does not
        # leave an empty directory behind.
        suffix = save_path.suffix
        if suffix not in (".json", ".yaml"):
            raise ValueError(f"{save_path} must be json or yaml")

        save_path.parent.mkdir(parents=True, exist_ok=True)

        with open(save_path, "w") as f:
            if suffix == ".json":
                json.dump(chain_dict, f, indent=4)
            else:
                yaml.dump(chain_dict, f, default_flow_style=False)
|