mirror of https://github.com/hwchase17/langchain
Separate out langchain_core package (#13577)
Co-authored-by: Nuno Campos <nuno@boringbits.io>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
parent 4eec47b191
commit d82cbf5e76
@@ -0,0 +1,52 @@
---
name: libs/langchain core CI

on:
  push:
    branches: [ master ]
  pull_request:
    paths:
      - '.github/actions/poetry_setup/action.yml'
      - '.github/tools/**'
      - '.github/workflows/_lint.yml'
      - '.github/workflows/_test.yml'
      - '.github/workflows/_pydantic_compatibility.yml'
      - '.github/workflows/langchain_core_ci.yml'
      - 'libs/core/**'
  workflow_dispatch:  # Allows triggering the workflow manually from the GitHub UI.

# If another push to the same PR or branch happens while this workflow is still running,
# cancel the earlier run in favor of the next run.
#
# There's no point in testing an outdated version of the code. GitHub only allows
# a limited number of job runners to be active at the same time, so it's better to cancel
# pointless jobs early so that more useful jobs can run sooner.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

env:
  POETRY_VERSION: "1.6.1"
  WORKDIR: "libs/core"

jobs:
  lint:
    uses:
      ./.github/workflows/_lint.yml
    with:
      working-directory: libs/core
    secrets: inherit

  test:
    uses:
      ./.github/workflows/_test.yml
    with:
      working-directory: libs/core
    secrets: inherit

  pydantic-compatibility:
    uses:
      ./.github/workflows/_pydantic_compatibility.yml
    with:
      working-directory: libs/core
    secrets: inherit
@@ -0,0 +1,13 @@
---
name: libs/core Release

on:
  workflow_dispatch:  # Allows triggering the workflow manually from the GitHub UI.

jobs:
  release:
    uses:
      ./.github/workflows/_release.yml
    with:
      working-directory: libs/core
    secrets: inherit
@@ -0,0 +1,54 @@
.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

test:
	poetry run pytest $(TEST_FILE)

tests:
	poetry run pytest $(TEST_FILE)

test_watch:
	poetry run ptw --snapshot-update --now . -- -x tests/unit_tests


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/core --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')

lint lint_diff:
	poetry run ruff .
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
	[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

######################
# HELP
######################

help:
	@echo '----'
	@echo 'format                       - run code formatters'
	@echo 'lint                         - run linters'
	@echo 'test                         - run unit tests'
	@echo 'tests                        - run unit tests'
	@echo 'test TEST_FILE=<test_file>   - run all tests in file'
	@echo 'test_watch                   - run unit tests in watch mode'
@@ -0,0 +1 @@
# langchain-core
@@ -0,0 +1,7 @@
from importlib import metadata

try:
    __version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
    # Case where package metadata is not available.
    __version__ = ""
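Usage sketch (illustrative, not part of this commit): the same importlib.metadata pattern can be exercised directly; the distribution name below is an assumption for the example.

    from importlib import metadata

    try:
        # "langchain-core" is the assumed distribution name here; any
        # installed package behaves the same way.
        version = metadata.version("langchain-core")
    except metadata.PackageNotFoundError:
        # Importable from a source checkout but not installed: no metadata.
        version = ""
    print(version or "no package metadata available")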
@@ -0,0 +1,26 @@
"""Helper functions for managing the LangChain API.

This module is only relevant for LangChain developers, not for users.

.. warning::

    This module and its submodules are for internal use only. Do not use them
    in your own code. We may change the API at any time with no warning.

"""

from .deprecation import (
    LangChainDeprecationWarning,
    deprecated,
    suppress_langchain_deprecation_warning,
    surface_langchain_deprecation_warnings,
    warn_deprecated,
)

__all__ = [
    "deprecated",
    "LangChainDeprecationWarning",
    "suppress_langchain_deprecation_warning",
    "surface_langchain_deprecation_warnings",
    "warn_deprecated",
]
@@ -0,0 +1,341 @@
"""Helper functions for deprecating parts of the LangChain API.

This module was adapted from matplotlib's _api/deprecation.py module:

https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecation.py

.. warning::

    This module is for internal use only. Do not use it in your own code.
    We may change the API at any time with no warning.
"""

import contextlib
import functools
import inspect
import warnings
from typing import Any, Callable, Generator, Type, TypeVar


class LangChainDeprecationWarning(DeprecationWarning):
    """A class for issuing deprecation warnings for LangChain users."""


class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
    """A class for issuing pending deprecation warnings for LangChain users."""


# PUBLIC API


T = TypeVar("T", Type, Callable)


def deprecated(
    since: str,
    *,
    message: str = "",
    name: str = "",
    alternative: str = "",
    pending: bool = False,
    obj_type: str = "",
    addendum: str = "",
    removal: str = "",
) -> Callable[[T], T]:
    """Decorator to mark a function, a class, or a property as deprecated.

    When deprecating a classmethod, a staticmethod, or a property, the
    ``@deprecated`` decorator should go *under* ``@classmethod`` and
    ``@staticmethod`` (i.e., `deprecated` should directly decorate the
    underlying callable), but *over* ``@property``.

    When deprecating a class ``C`` intended to be used as a base class in a
    multiple inheritance hierarchy, ``C`` *must* define an ``__init__`` method
    (if ``C`` instead inherited its ``__init__`` from its own base class, then
    ``@deprecated`` would mess up ``__init__`` inheritance when installing its
    own (deprecation-emitting) ``C.__init__``).

    Parameters are the same as for `warn_deprecated`, except that *obj_type*
    defaults to 'class' if decorating a class, 'attribute' if decorating a
    property, and 'function' otherwise.

    Arguments:
        since : str
            The release at which this API became deprecated.
        message : str, optional
            Override the default deprecation message. The %(since)s,
            %(name)s, %(alternative)s, %(obj_type)s, %(addendum)s,
            and %(removal)s format specifiers will be replaced by the
            values of the respective arguments passed to this function.
        name : str, optional
            The name of the deprecated object.
        alternative : str, optional
            An alternative API that the user may use in place of the
            deprecated API. The deprecation warning will tell the user
            about this alternative if provided.
        pending : bool, optional
            If True, uses a PendingDeprecationWarning instead of a
            DeprecationWarning. Cannot be used together with removal.
        obj_type : str, optional
            The object type being deprecated.
        addendum : str, optional
            Additional text appended directly to the final message.
        removal : str, optional
            The expected removal version. With the default (an empty
            string), a removal version is automatically computed from
            since. Set to other Falsy values to not schedule a removal
            date. Cannot be used together with pending.

    Examples
    --------

    .. code-block:: python

        @deprecated('1.4.0')
        def the_function_to_deprecate():
            pass
    """

    def deprecate(
        obj: T,
        *,
        _obj_type: str = obj_type,
        _name: str = name,
        _message: str = message,
        _alternative: str = alternative,
        _pending: bool = pending,
        _addendum: str = addendum,
    ) -> T:
        """Implementation of the decorator returned by `deprecated`."""
        if isinstance(obj, type):
            if not _obj_type:
                _obj_type = "class"
            wrapped = obj.__init__  # type: ignore
            _name = _name or obj.__name__
            old_doc = obj.__doc__

            def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
                """Finalize the deprecation of a class."""
                try:
                    obj.__doc__ = new_doc
                except AttributeError:  # Can't set on some extension objects.
                    pass
                obj.__init__ = functools.wraps(obj.__init__)(  # type: ignore[misc]
                    wrapper
                )
                return obj

        elif isinstance(obj, property):
            if not _obj_type:
                _obj_type = "attribute"
            wrapped = None
            _name = _name or obj.fget.__name__
            old_doc = obj.__doc__

            class _deprecated_property(type(obj)):  # type: ignore
                """A deprecated property."""

                def __get__(self, instance, owner=None):  # type: ignore
                    if instance is not None or owner is not None:
                        emit_warning()
                    return super().__get__(instance, owner)

                def __set__(self, instance, value):  # type: ignore
                    if instance is not None:
                        emit_warning()
                    return super().__set__(instance, value)

                def __delete__(self, instance):  # type: ignore
                    if instance is not None:
                        emit_warning()
                    return super().__delete__(instance)

                def __set_name__(self, owner, set_name):  # type: ignore
                    nonlocal _name
                    if _name == "<lambda>":
                        _name = set_name

            def finalize(_: Any, new_doc: str) -> Any:  # type: ignore
                """Finalize the property."""
                return _deprecated_property(
                    fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
                )

        else:
            if not _obj_type:
                _obj_type = "function"
            wrapped = obj
            _name = _name or obj.__name__  # type: ignore
            old_doc = wrapped.__doc__

            def finalize(  # type: ignore
                wrapper: Callable[..., Any], new_doc: str
            ) -> T:
                """Wrap the wrapped function using the wrapper and update the docstring.

                Args:
                    wrapper: The wrapper function.
                    new_doc: The new docstring.

                Returns:
                    The wrapped function.
                """
                wrapper = functools.wraps(wrapped)(wrapper)
                wrapper.__doc__ = new_doc
                return wrapper

        def emit_warning() -> None:
            """Emit the warning."""
            warn_deprecated(
                since,
                message=_message,
                name=_name,
                alternative=_alternative,
                pending=_pending,
                obj_type=_obj_type,
                addendum=_addendum,
                removal=removal,
            )

        def warning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:
            """Wrapper for the original wrapped callable that emits a warning.

            Args:
                *args: The positional arguments to the function.
                **kwargs: The keyword arguments to the function.

            Returns:
                The return value of the function being wrapped.
            """
            emit_warning()
            return wrapped(*args, **kwargs)

        old_doc = inspect.cleandoc(old_doc or "").strip("\n")

        if not old_doc:
            new_doc = "[*Deprecated*]"
        else:
            new_doc = f"[*Deprecated*] {old_doc}"

        # Modify the docstring to include a deprecation notice.
        notes_header = "\nNotes\n-----"
        components = [
            message,
            f"Use {alternative} instead." if alternative else "",
            addendum,
        ]
        details = " ".join([component.strip() for component in components if component])
        new_doc += (
            f"[*Deprecated*] {old_doc}\n"
            f"{notes_header if notes_header not in old_doc else ''}\n"
            f".. deprecated:: {since}\n"
            f"    {details}"
        )

        return finalize(warning_emitting_wrapper, new_doc)

    return deprecate


@contextlib.contextmanager
def suppress_langchain_deprecation_warning() -> Generator[None, None, None]:
    """Context manager to suppress LangChainDeprecationWarning."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", LangChainDeprecationWarning)
        warnings.simplefilter("ignore", LangChainPendingDeprecationWarning)
        yield


def warn_deprecated(
    since: str,
    *,
    message: str = "",
    name: str = "",
    alternative: str = "",
    pending: bool = False,
    obj_type: str = "",
    addendum: str = "",
    removal: str = "",
) -> None:
    """Display a standardized deprecation.

    Arguments:
        since : str
            The release at which this API became deprecated.
        message : str, optional
            Override the default deprecation message. The %(since)s,
            %(name)s, %(alternative)s, %(obj_type)s, %(addendum)s,
            and %(removal)s format specifiers will be replaced by the
            values of the respective arguments passed to this function.
        name : str, optional
            The name of the deprecated object.
        alternative : str, optional
            An alternative API that the user may use in place of the
            deprecated API. The deprecation warning will tell the user
            about this alternative if provided.
        pending : bool, optional
            If True, uses a PendingDeprecationWarning instead of a
            DeprecationWarning. Cannot be used together with removal.
        obj_type : str, optional
            The object type being deprecated.
        addendum : str, optional
            Additional text appended directly to the final message.
        removal : str, optional
            The expected removal version. With the default (an empty
            string), a removal version is automatically computed from
            since. Set to other Falsy values to not schedule a removal
            date. Cannot be used together with pending.
    """
    if pending and removal:
        raise ValueError("A pending deprecation cannot have a scheduled removal")

    if not pending:
        if not removal:
            removal = f"in {removal}" if removal else "within ?? minor releases"
            raise NotImplementedError(
                f"Need to determine which default deprecation schedule to use. "
                f"{removal}"
            )
        else:
            removal = f"in {removal}"

    if not message:
        message = ""

        if obj_type:
            message += f"The {obj_type} `{name}`"
        else:
            message += f"`{name}`"

        if pending:
            message += " will be deprecated in a future version"
        else:
            message += f" was deprecated in LangChain {since}"

        if removal:
            message += f" and will be removed {removal}"

        if alternative:
            message += f". Use {alternative} instead."

        if addendum:
            message += f" {addendum}"

    warning_cls = (
        LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning
    )
    warning = warning_cls(message)
    warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2)


def surface_langchain_deprecation_warnings() -> None:
    """Unmute LangChain deprecation warnings."""
    warnings.filterwarnings(
        "default",
        category=LangChainPendingDeprecationWarning,
    )

    warnings.filterwarnings(
        "default",
        category=LangChainDeprecationWarning,
    )
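Usage sketch (illustrative, not part of this commit; the version strings and function name are invented for the example):

    from langchain_core._api import (
        deprecated,
        suppress_langchain_deprecation_warning,
    )

    @deprecated("0.0.1", alternative="new_function", removal="0.2.0")
    def old_function() -> str:
        """Return a greeting."""
        return "hello"

    # Calling the decorated function emits a LangChainDeprecationWarning:
    # "`old_function` was deprecated in LangChain 0.0.1 and will be removed
    # in 0.2.0. Use new_function instead."
    old_function()

    # ...unless warnings are suppressed for a block:
    with suppress_langchain_deprecation_warning():
        old_function()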
@@ -0,0 +1,36 @@
import os
from pathlib import Path
from typing import Optional, Union

HERE = Path(__file__).parent

# Get directory of langchain package
PACKAGE_DIR = HERE.parent
SEPARATOR = os.sep


def get_relative_path(
    file: Union[Path, str], *, relative_to: Path = PACKAGE_DIR
) -> str:
    """Get the path of the file as a relative path to the package directory."""
    if isinstance(file, str):
        file = Path(file)
    return str(file.relative_to(relative_to))


def as_import_path(
    file: Union[Path, str],
    *,
    suffix: Optional[str] = None,
    relative_to: Path = PACKAGE_DIR,
) -> str:
    """Path of the file as a LangChain import, excluding the top-level langchain namespace."""
    if isinstance(file, str):
        file = Path(file)
    path = get_relative_path(file, relative_to=relative_to)
    if file.is_file():
        path = path[: -len(file.suffix)]
    import_path = path.replace(SEPARATOR, ".")
    if suffix:
        import_path += "." + suffix
    return import_path
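Worked example (hypothetical POSIX paths, shown only to illustrate what as_import_path computes):

    from pathlib import Path

    package_dir = Path("/src/libs/core/langchain_core")  # assumed location
    module_file = package_dir / "callbacks" / "stdout.py"

    # get_relative_path(module_file, relative_to=package_dir):
    relative = str(module_file.relative_to(package_dir))  # "callbacks/stdout.py"

    # as_import_path strips the suffix of an existing file and swaps
    # the OS separator for dots:
    dotted = relative[: -len(module_file.suffix)].replace("/", ".")
    print(dotted)  # -> "callbacks.stdout"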
@@ -0,0 +1,598 @@
"""Base callback handler that can be used to handle callbacks in LangChain."""
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
from uuid import UUID

from tenacity import RetryCallState

from langchain_core.schema.agent import AgentAction, AgentFinish
from langchain_core.schema.document import Document
from langchain_core.schema.messages import BaseMessage
from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult


class RetrieverManagerMixin:
    """Mixin for Retriever callbacks."""

    def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever errors."""

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever ends running."""


class LLMManagerMixin:
    """Mixin for LLM callbacks."""

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on new LLM token. Only available when streaming is enabled.

        Args:
            token (str): The new token.
            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
                containing content and other information.
        """

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when LLM ends running."""

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when LLM errors."""


class ChainManagerMixin:
    """Mixin for chain callbacks."""

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when chain ends running."""

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when chain errors."""

    def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on agent action."""

    def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on agent end."""


class ToolManagerMixin:
    """Mixin for tool callbacks."""

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when tool ends running."""

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when tool errors."""


class CallbackManagerMixin:
    """Mixin for callback manager."""

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when LLM starts running."""

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when a chat model starts running."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement `on_chat_model_start`"
        )

    def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever starts running."""

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when chain starts running."""

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when tool starts running."""


class RunManagerMixin:
    """Mixin for run manager."""

    def on_text(
        self,
        text: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on arbitrary text."""

    def on_retry(
        self,
        retry_state: RetryCallState,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on a retry event."""


class BaseCallbackHandler(
    LLMManagerMixin,
    ChainManagerMixin,
    ToolManagerMixin,
    RetrieverManagerMixin,
    CallbackManagerMixin,
    RunManagerMixin,
):
    """Base callback handler that handles callbacks from LangChain."""

    raise_error: bool = False

    run_inline: bool = False

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return False

    @property
    def ignore_retry(self) -> bool:
        """Whether to ignore retry callbacks."""
        return False

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return False

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return False

    @property
    def ignore_retriever(self) -> bool:
        """Whether to ignore retriever callbacks."""
        return False

    @property
    def ignore_chat_model(self) -> bool:
        """Whether to ignore chat model callbacks."""
        return False


class AsyncCallbackHandler(BaseCallbackHandler):
    """Async callback handler that handles callbacks from LangChain."""

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM starts running."""

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when a chat model starts running."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement `on_chat_model_start`"
        )

    async def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""

    async def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM ends running."""

    async def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM errors."""

    async def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when chain starts running."""

    async def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when chain ends running."""

    async def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when chain errors."""

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool starts running."""

    async def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool ends running."""

    async def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool errors."""

    async def on_text(
        self,
        text: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on arbitrary text."""

    async def on_retry(
        self,
        retry_state: RetryCallState,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on a retry event."""

    async def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on agent action."""

    async def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on agent end."""

    async def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on retriever start."""

    async def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on retriever end."""

    async def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on retriever error."""


T = TypeVar("T", bound="BaseCallbackManager")


class BaseCallbackManager(CallbackManagerMixin):
    """Base callback manager that handles callbacks from LangChain."""

    def __init__(
        self,
        handlers: List[BaseCallbackHandler],
        inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
        parent_run_id: Optional[UUID] = None,
        *,
        tags: Optional[List[str]] = None,
        inheritable_tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        inheritable_metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize callback manager."""
        self.handlers: List[BaseCallbackHandler] = handlers
        self.inheritable_handlers: List[BaseCallbackHandler] = (
            inheritable_handlers or []
        )
        self.parent_run_id: Optional[UUID] = parent_run_id
        self.tags = tags or []
        self.inheritable_tags = inheritable_tags or []
        self.metadata = metadata or {}
        self.inheritable_metadata = inheritable_metadata or {}

    def copy(self: T) -> T:
        """Copy the callback manager."""
        return self.__class__(
            handlers=self.handlers,
            inheritable_handlers=self.inheritable_handlers,
            parent_run_id=self.parent_run_id,
            tags=self.tags,
            inheritable_tags=self.inheritable_tags,
            metadata=self.metadata,
            inheritable_metadata=self.inheritable_metadata,
        )

    @property
    def is_async(self) -> bool:
        """Whether the callback manager is async."""
        return False

    def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
        """Add a handler to the callback manager."""
        if handler not in self.handlers:
            self.handlers.append(handler)
        if inherit and handler not in self.inheritable_handlers:
            self.inheritable_handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager."""
        self.handlers.remove(handler)
        self.inheritable_handlers.remove(handler)

    def set_handlers(
        self, handlers: List[BaseCallbackHandler], inherit: bool = True
    ) -> None:
        """Set handlers as the only handlers on the callback manager."""
        self.handlers = []
        self.inheritable_handlers = []
        for handler in handlers:
            self.add_handler(handler, inherit=inherit)

    def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
        """Set handler as the only handler on the callback manager."""
        self.set_handlers([handler], inherit=inherit)

    def add_tags(self, tags: List[str], inherit: bool = True) -> None:
        """Add tags, replacing any that are already present."""
        for tag in tags:
            if tag in self.tags:
                self.remove_tags([tag])
        self.tags.extend(tags)
        if inherit:
            self.inheritable_tags.extend(tags)

    def remove_tags(self, tags: List[str]) -> None:
        """Remove tags from both the local and inheritable tag lists."""
        for tag in tags:
            self.tags.remove(tag)
            self.inheritable_tags.remove(tag)

    def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
        """Merge metadata into the callback manager."""
        self.metadata.update(metadata)
        if inherit:
            self.inheritable_metadata.update(metadata)

    def remove_metadata(self, keys: List[str]) -> None:
        """Remove metadata keys from both metadata mappings."""
        for key in keys:
            self.metadata.pop(key)
            self.inheritable_metadata.pop(key)


Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
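Usage sketch (illustrative, not part of this commit): because every hook above defaults to a no-op, a custom handler only overrides the events it cares about. The class below is hypothetical.

    from typing import Any, Dict, List, Optional
    from uuid import UUID

    from langchain_core.callbacks.base import BaseCallbackHandler


    class TokenCountingHandler(BaseCallbackHandler):
        """Hypothetical handler that counts streamed tokens per run."""

        def __init__(self) -> None:
            self.token_counts: Dict[UUID, int] = {}

        def on_llm_start(
            self,
            serialized: Dict[str, Any],
            prompts: List[str],
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            **kwargs: Any,
        ) -> None:
            self.token_counts[run_id] = 0

        def on_llm_new_token(
            self,
            token: str,
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            **kwargs: Any,
        ) -> None:
            # Only fires when the underlying LLM streams its output.
            self.token_counts[run_id] = self.token_counts.get(run_id, 0) + 1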
File diff suppressed because it is too large.
@@ -0,0 +1,97 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional

from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.schema import AgentAction, AgentFinish, LLMResult
from langchain_core.utils.input import print_text


class StdOutCallbackHandler(BaseCallbackHandler):
    """Callback Handler that prints to std out."""

    def __init__(self, color: Optional[str] = None) -> None:
        """Initialize callback handler."""
        self.color = color

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain."""
        class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain."""
        print("\n\033[1m> Finished chain.\033[0m")

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Run on agent action."""
        print_text(action.log, color=color or self.color)

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """If not the final action, print out observation."""
        if observation_prefix is not None:
            print_text(f"\n{observation_prefix}")
        print_text(output, color=color or self.color)
        if llm_prefix is not None:
            print_text(f"\n{llm_prefix}")

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        """Run on arbitrary text."""
        print_text(text, color=color or self.color, end=end)

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Run on agent end."""
        print_text(finish.log, color=color or self.color, end="\n")
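Usage sketch (illustrative, not part of this commit): handlers are normally passed to a chain or LLM via callbacks=[handler]; here two hooks are driven directly just to show the output format.

    from langchain_core.callbacks.stdout import StdOutCallbackHandler

    handler = StdOutCallbackHandler(color="green")
    handler.on_chain_start({"name": "DemoChain"}, {"input": "hi"})
    # -> "> Entering new DemoChain chain..."
    handler.on_chain_end({"output": "done"})
    # -> "> Finished chain."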
@@ -0,0 +1,67 @@
"""Callback Handler that streams to stdout on new LLM tokens."""
import sys
from typing import Any, Dict, List

from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.schema import AgentAction, AgentFinish, LLMResult
from langchain_core.schema.messages import BaseMessage


class StreamingStdOutCallbackHandler(BaseCallbackHandler):
    """Callback handler for streaming. Only works with LLMs that support streaming."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        """Run when a chat model starts running."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        sys.stdout.write(token)
        sys.stdout.flush()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        pass

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text."""

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run on agent end."""
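Usage sketch (illustrative, not part of this commit): the handler only has an effect when the model actually streams, invoking on_llm_new_token once per token; the loop below stands in for a streaming LLM.

    from langchain_core.callbacks.streaming_stdout import (
        StreamingStdOutCallbackHandler,
    )

    handler = StreamingStdOutCallbackHandler()
    for token in ["Hello", ", ", "world", "!\n"]:
        handler.on_llm_new_token(token)  # each token is written and flushed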
@ -0,0 +1,537 @@
|
||||
"""Base interfaces for tracing runs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
from langchain_core.callbacks.tracers.schemas import Run
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.schema.document import Document
|
||||
from langchain_core.schema.output import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
GenerationChunk,
|
||||
LLMResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TracerException(Exception):
|
||||
"""Base class for exceptions in tracers module."""
|
||||
|
||||
|
||||
class BaseTracer(BaseCallbackHandler, ABC):
|
||||
"""Base interface for tracers."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.run_map: Dict[str, Run] = {}
|
||||
|
||||
@staticmethod
|
||||
def _add_child_run(
|
||||
parent_run: Run,
|
||||
child_run: Run,
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
parent_run.child_runs.append(child_run)
|
||||
|
||||
@abstractmethod
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
|
||||
def _start_trace(self, run: Run) -> None:
|
||||
"""Start a trace for a run."""
|
||||
if run.parent_run_id:
|
||||
parent_run = self.run_map.get(str(run.parent_run_id))
|
||||
if parent_run:
|
||||
self._add_child_run(parent_run, run)
|
||||
parent_run.child_execution_order = max(
|
||||
parent_run.child_execution_order, run.child_execution_order
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
|
||||
self.run_map[str(run.id)] = run
|
||||
self._on_run_create(run)
|
||||
|
||||
def _end_trace(self, run: Run) -> None:
|
||||
"""End a trace for a run."""
|
||||
if not run.parent_run_id:
|
||||
self._persist_run(run)
|
||||
else:
|
||||
parent_run = self.run_map.get(str(run.parent_run_id))
|
||||
if parent_run is None:
|
||||
logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
|
||||
elif (
|
||||
run.child_execution_order is not None
|
||||
and parent_run.child_execution_order is not None
|
||||
and run.child_execution_order > parent_run.child_execution_order
|
||||
):
|
||||
parent_run.child_execution_order = run.child_execution_order
|
||||
self.run_map.pop(str(run.id))
|
||||
self._on_run_update(run)
|
||||
|
||||
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
if parent_run_id is None:
|
||||
return 1
|
||||
|
||||
parent_run = self.run_map.get(parent_run_id)
|
||||
if parent_run is None:
|
||||
logger.debug(f"Parent run with UUID {parent_run_id} not found.")
|
||||
return 1
|
||||
if parent_run.child_execution_order is None:
|
||||
raise TracerException(
|
||||
f"Parent run with UUID {parent_run_id} has no child execution order."
|
||||
)
|
||||
|
||||
return parent_run.child_execution_order + 1
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for an LLM run."""
|
||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||
execution_order = self._get_execution_order(parent_run_id_)
|
||||
start_time = datetime.utcnow()
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
llm_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
serialized=serialized,
|
||||
inputs={"prompts": prompts},
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
run_type="llm",
|
||||
tags=tags or [],
|
||||
name=name,
|
||||
)
|
||||
self._start_trace(llm_run)
|
||||
self._on_llm_start(llm_run)
|
||||
return llm_run
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_llm_new_token callback.")
|
||||
|
||||
run_id_ = str(run_id)
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None or llm_run.run_type != "llm":
|
||||
raise TracerException(f"No LLM Run found to be traced for {run_id}")
|
||||
event_kwargs: Dict[str, Any] = {"token": token}
|
||||
if chunk:
|
||||
event_kwargs["chunk"] = chunk
|
||||
llm_run.events.append(
|
||||
{
|
||||
"name": "new_token",
|
||||
"time": datetime.utcnow(),
|
||||
"kwargs": event_kwargs,
|
||||
},
|
||||
)
|
||||
self._on_llm_new_token(llm_run, token, chunk)
|
||||
return llm_run
|
||||
|
||||
def on_retry(
|
||||
self,
|
||||
retry_state: RetryCallState,
|
||||
*,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_retry callback.")
|
||||
run_id_ = str(run_id)
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None:
|
||||
raise TracerException("No Run found to be traced for on_retry")
|
||||
retry_d: Dict[str, Any] = {
|
||||
"slept": retry_state.idle_for,
|
||||
"attempt": retry_state.attempt_number,
|
||||
}
|
||||
if retry_state.outcome is None:
|
||||
retry_d["outcome"] = "N/A"
|
||||
elif retry_state.outcome.failed:
|
||||
retry_d["outcome"] = "failed"
|
||||
exception = retry_state.outcome.exception()
|
||||
retry_d["exception"] = str(exception)
|
||||
retry_d["exception_type"] = exception.__class__.__name__
|
||||
else:
|
||||
retry_d["outcome"] = "success"
|
||||
retry_d["result"] = str(retry_state.outcome.result())
|
||||
llm_run.events.append(
|
||||
{
|
||||
"name": "retry",
|
||||
"time": datetime.utcnow(),
|
||||
"kwargs": retry_d,
|
||||
},
|
||||
)
|
||||
return llm_run
|
||||
|
||||
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||
"""End a trace for an LLM run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_llm_end callback.")
|
||||
|
||||
run_id_ = str(run_id)
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None or llm_run.run_type != "llm":
|
||||
raise TracerException(f"No LLM Run found to be traced for {run_id}")
|
||||
llm_run.outputs = response.dict()
|
||||
for i, generations in enumerate(response.generations):
|
||||
for j, generation in enumerate(generations):
|
||||
output_generation = llm_run.outputs["generations"][i][j]
|
||||
if "message" in output_generation:
|
||||
output_generation["message"] = dumpd(
|
||||
cast(ChatGeneration, generation).message
|
||||
)
|
||||
llm_run.end_time = datetime.utcnow()
|
||||
llm_run.events.append({"name": "end", "time": llm_run.end_time})
|
||||
self._end_trace(llm_run)
|
||||
self._on_llm_end(llm_run)
|
||||
return llm_run
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Handle an error for an LLM run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_llm_error callback.")
|
||||
|
||||
run_id_ = str(run_id)
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None or llm_run.run_type != "llm":
|
||||
raise TracerException(f"No LLM Run found to be traced for {run_id}")
|
||||
llm_run.error = repr(error)
|
||||
llm_run.end_time = datetime.utcnow()
|
||||
llm_run.events.append({"name": "error", "time": llm_run.end_time})
|
||||
self._end_trace(llm_run)
|
||||
self._on_chain_error(llm_run)
|
||||
return llm_run
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_type: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for a chain run."""
|
||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||
execution_order = self._get_execution_order(parent_run_id_)
|
||||
start_time = datetime.utcnow()
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
chain_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
serialized=serialized,
|
||||
inputs=inputs if isinstance(inputs, dict) else {"input": inputs},
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
child_runs=[],
|
||||
run_type=run_type or "chain",
|
||||
name=name,
|
||||
tags=tags or [],
|
||||
)
|
||||
self._start_trace(chain_run)
|
||||
self._on_chain_start(chain_run)
|
||||
return chain_run
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""End a trace for a chain run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_chain_end callback.")
|
||||
chain_run = self.run_map.get(str(run_id))
|
||||
if chain_run is None:
|
||||
raise TracerException(f"No chain Run found to be traced for {run_id}")
|
||||
|
||||
chain_run.outputs = (
|
||||
outputs if isinstance(outputs, dict) else {"output": outputs}
|
||||
)
|
||||
chain_run.end_time = datetime.utcnow()
|
||||
chain_run.events.append({"name": "end", "time": chain_run.end_time})
|
||||
if inputs is not None:
|
||||
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
|
||||
self._end_trace(chain_run)
|
||||
self._on_chain_end(chain_run)
|
||||
return chain_run
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Handle an error for a chain run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_chain_error callback.")
|
||||
chain_run = self.run_map.get(str(run_id))
|
||||
if chain_run is None:
|
||||
raise TracerException(f"No chain Run found to be traced for {run_id}")
|
||||
|
||||
chain_run.error = repr(error)
|
||||
chain_run.end_time = datetime.utcnow()
|
||||
chain_run.events.append({"name": "error", "time": chain_run.end_time})
|
||||
if inputs is not None:
|
||||
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
|
||||
self._end_trace(chain_run)
|
||||
self._on_chain_error(chain_run)
|
||||
return chain_run
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for a tool run."""
|
||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||
execution_order = self._get_execution_order(parent_run_id_)
|
||||
start_time = datetime.utcnow()
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
tool_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
serialized=serialized,
|
||||
inputs={"input": input_str},
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
child_runs=[],
|
||||
run_type="tool",
|
||||
tags=tags or [],
|
||||
name=name,
|
||||
)
|
||||
self._start_trace(tool_run)
|
||||
self._on_tool_start(tool_run)
|
||||
return tool_run
|
||||
|
||||
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||
"""End a trace for a tool run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_tool_end callback.")
|
||||
tool_run = self.run_map.get(str(run_id))
|
||||
if tool_run is None or tool_run.run_type != "tool":
|
||||
            raise TracerException(f"No tool Run found to be traced for {run_id}")

        tool_run.outputs = {"output": output}
        tool_run.end_time = datetime.utcnow()
        tool_run.events.append({"name": "end", "time": tool_run.end_time})
        self._end_trace(tool_run)
        self._on_tool_end(tool_run)
        return tool_run

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> Run:
        """Handle an error for a tool run."""
        if not run_id:
            raise TracerException("No run_id provided for on_tool_error callback.")
        tool_run = self.run_map.get(str(run_id))
        if tool_run is None or tool_run.run_type != "tool":
            raise TracerException(f"No tool Run found to be traced for {run_id}")

        tool_run.error = repr(error)
        tool_run.end_time = datetime.utcnow()
        tool_run.events.append({"name": "error", "time": tool_run.end_time})
        self._end_trace(tool_run)
        self._on_tool_error(tool_run)
        return tool_run

    def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> Run:
        """Run when Retriever starts running."""
        parent_run_id_ = str(parent_run_id) if parent_run_id else None
        execution_order = self._get_execution_order(parent_run_id_)
        start_time = datetime.utcnow()
        if metadata:
            kwargs.update({"metadata": metadata})
        retrieval_run = Run(
            id=run_id,
            name=name or "Retriever",
            parent_run_id=parent_run_id,
            serialized=serialized,
            inputs={"query": query},
            extra=kwargs,
            events=[{"name": "start", "time": start_time}],
            start_time=start_time,
            execution_order=execution_order,
            child_execution_order=execution_order,
            tags=tags,
            child_runs=[],
            run_type="retriever",
        )
        self._start_trace(retrieval_run)
        self._on_retriever_start(retrieval_run)
        return retrieval_run

    def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> Run:
        """Run when Retriever errors."""
        if not run_id:
            raise TracerException("No run_id provided for on_retriever_error callback.")
        retrieval_run = self.run_map.get(str(run_id))
        if retrieval_run is None or retrieval_run.run_type != "retriever":
            raise TracerException(f"No retriever Run found to be traced for {run_id}")

        retrieval_run.error = repr(error)
        retrieval_run.end_time = datetime.utcnow()
        retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
        self._end_trace(retrieval_run)
        self._on_retriever_error(retrieval_run)
        return retrieval_run

    def on_retriever_end(
        self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
    ) -> Run:
        """Run when Retriever ends running."""
        if not run_id:
            raise TracerException("No run_id provided for on_retriever_end callback.")
        retrieval_run = self.run_map.get(str(run_id))
        if retrieval_run is None or retrieval_run.run_type != "retriever":
            raise TracerException(f"No retriever Run found to be traced for {run_id}")
        retrieval_run.outputs = {"documents": documents}
        retrieval_run.end_time = datetime.utcnow()
        retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
        self._end_trace(retrieval_run)
        self._on_retriever_end(retrieval_run)
        return retrieval_run

    def __deepcopy__(self, memo: dict) -> BaseTracer:
        """Deepcopy the tracer."""
        return self

    def __copy__(self) -> BaseTracer:
        """Copy the tracer."""
        return self

    def _on_run_create(self, run: Run) -> None:
        """Process a run upon creation."""

    def _on_run_update(self, run: Run) -> None:
        """Process a run upon update."""

    def _on_llm_start(self, run: Run) -> None:
        """Process the LLM Run upon start."""

    def _on_llm_new_token(
        self,
        run: Run,
        token: str,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
    ) -> None:
        """Process new LLM token."""

    def _on_llm_end(self, run: Run) -> None:
        """Process the LLM Run."""

    def _on_llm_error(self, run: Run) -> None:
        """Process the LLM Run upon error."""

    def _on_chain_start(self, run: Run) -> None:
        """Process the Chain Run upon start."""

    def _on_chain_end(self, run: Run) -> None:
        """Process the Chain Run."""

    def _on_chain_error(self, run: Run) -> None:
        """Process the Chain Run upon error."""

    def _on_tool_start(self, run: Run) -> None:
        """Process the Tool Run upon start."""

    def _on_tool_end(self, run: Run) -> None:
        """Process the Tool Run."""

    def _on_tool_error(self, run: Run) -> None:
        """Process the Tool Run upon error."""

    def _on_chat_model_start(self, run: Run) -> None:
        """Process the Chat Model Run upon start."""

    def _on_retriever_start(self, run: Run) -> None:
        """Process the Retriever Run upon start."""

    def _on_retriever_end(self, run: Run) -> None:
        """Process the Retriever Run."""

    def _on_retriever_error(self, run: Run) -> None:
        """Process the Retriever Run upon error."""
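    # A minimal sketch of how these hooks are meant to be extended (this
    # subclass is hypothetical; only `_persist_run` is required, the `_on_*`
    # hooks above are optional overrides):
    #
    #     class PrintingTracer(BaseTracer):
    #         def _persist_run(self, run: Run) -> None:
    #             print(f"run finished: {run.name} ({run.run_type})")
    #
    #         def _on_llm_error(self, run: Run) -> None:
    #             print(f"llm run {run.id} errored: {run.error}")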
@ -0,0 +1,223 @@
"""A tracer that runs evaluators over completed runs."""
from __future__ import annotations

import logging
import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from uuid import UUID

import langsmith
from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults

from langchain_core.callbacks import manager
from langchain_core.callbacks.tracers import langchain as langchain_tracer
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.langchain import _get_executor
from langchain_core.callbacks.tracers.schemas import Run

logger = logging.getLogger(__name__)

_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet()


def wait_for_all_evaluators() -> None:
    """Wait for all tracers to finish."""
    global _TRACERS
    for tracer in list(_TRACERS):
        if tracer is not None:
            tracer.wait_for_futures()


class EvaluatorCallbackHandler(BaseTracer):
    """A tracer that runs a run evaluator whenever a run is persisted.

    Parameters
    ----------
    evaluators : Sequence[RunEvaluator]
        The run evaluators to apply to all top-level runs.
    client : LangSmith Client, optional
        The LangSmith client instance to use for evaluating the runs.
        If not specified, a new instance will be created.
    example_id : Union[UUID, str], optional
        The example ID to be associated with the runs.
    project_name : str, optional
        The LangSmith project name under which to organize eval chain runs.

    Attributes
    ----------
    example_id : Union[UUID, None]
        The example ID associated with the runs.
    client : Client
        The LangSmith client instance used for evaluating the runs.
    evaluators : Sequence[RunEvaluator]
        The sequence of run evaluators to be executed.
    executor : ThreadPoolExecutor
        The thread pool executor used for running the evaluators.
    futures : Set[Future]
        The set of futures representing the running evaluators.
    skip_unfinished : bool
        Whether to skip runs that are not finished or raised an error.
    project_name : Optional[str]
        The LangSmith project name under which to organize eval chain runs.
    """

    name = "evaluator_callback_handler"

    def __init__(
        self,
        evaluators: Sequence[langsmith.RunEvaluator],
        client: Optional[langsmith.Client] = None,
        example_id: Optional[Union[UUID, str]] = None,
        skip_unfinished: bool = True,
        project_name: Optional[str] = "evaluators",
        max_concurrency: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.example_id = (
            UUID(example_id) if isinstance(example_id, str) else example_id
        )
        self.client = client or langchain_tracer.get_client()
        self.evaluators = evaluators
        if max_concurrency is None:
            self.executor: Optional[ThreadPoolExecutor] = _get_executor()
        elif max_concurrency > 0:
            self.executor = ThreadPoolExecutor(max_workers=max_concurrency)
            weakref.finalize(
                self,
                lambda: cast(ThreadPoolExecutor, self.executor).shutdown(wait=True),
            )
        else:
            self.executor = None
        self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
        self.skip_unfinished = skip_unfinished
        self.project_name = project_name
        self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {}
        self.lock = threading.Lock()
        global _TRACERS
        _TRACERS.add(self)

    def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
        """Evaluate the run in the project.

        Parameters
        ----------
        run : Run
            The run to be evaluated.
        evaluator : RunEvaluator
            The evaluator to use for evaluating the run.

        """
        try:
            if self.project_name is None:
                eval_result = self.client.evaluate_run(run, evaluator)
                eval_results = [eval_result]
            with manager.tracing_v2_enabled(
                project_name=self.project_name, tags=["eval"], client=self.client
            ) as cb:
                reference_example = (
                    self.client.read_example(run.reference_example_id)
                    if run.reference_example_id
                    else None
                )
                evaluation_result = evaluator.evaluate_run(
                    # This is a subclass, but getting errors for some reason
                    run,  # type: ignore
                    example=reference_example,
                )
                eval_results = self._log_evaluation_feedback(
                    evaluation_result,
                    run,
                    source_run_id=cb.latest_run.id if cb.latest_run else None,
                )
        except Exception as e:
            logger.error(
                f"Error evaluating run {run.id} with "
                f"{evaluator.__class__.__name__}: {repr(e)}",
                exc_info=True,
            )
            raise e
        example_id = str(run.reference_example_id)
        with self.lock:
            for res in eval_results:
                run_id = (
                    str(getattr(res, "target_run_id"))
                    if hasattr(res, "target_run_id")
                    else str(run.id)
                )
                self.logged_eval_results.setdefault((run_id, example_id), []).append(
                    res
                )

    def _select_eval_results(
        self,
        results: Union[EvaluationResult, EvaluationResults],
    ) -> List[EvaluationResult]:
        if isinstance(results, EvaluationResult):
            results_ = [results]
        elif isinstance(results, dict) and "results" in results:
            results_ = cast(List[EvaluationResult], results["results"])
        else:
            raise TypeError(
                f"Invalid evaluation result type {type(results)}."
                " Expected EvaluationResult or EvaluationResults."
            )
        return results_

    def _log_evaluation_feedback(
        self,
        evaluator_response: Union[EvaluationResult, EvaluationResults],
        run: Run,
        source_run_id: Optional[UUID] = None,
    ) -> List[EvaluationResult]:
        results = self._select_eval_results(evaluator_response)
        for res in results:
            source_info_: Dict[str, Any] = {}
            if res.evaluator_info:
                source_info_ = {**res.evaluator_info, **source_info_}
            run_id_ = (
                getattr(res, "target_run_id")
                if hasattr(res, "target_run_id") and res.target_run_id is not None
                else run.id
            )
            self.client.create_feedback(
                run_id_,
                res.key,
                score=res.score,
                value=res.value,
                comment=res.comment,
                correction=res.correction,
                source_info=source_info_,
                source_run_id=res.source_run_id or source_run_id,
                feedback_source_type=langsmith.schemas.FeedbackSourceType.MODEL,
            )
        return results

    def _persist_run(self, run: Run) -> None:
        """Run the evaluator on the run.

        Parameters
        ----------
        run : Run
            The run to be evaluated.

        """
        if self.skip_unfinished and not run.outputs:
            logger.debug(f"Skipping unfinished run {run.id}")
            return
        run_ = run.copy()
        run_.reference_example_id = self.example_id
        for evaluator in self.evaluators:
            if self.executor is None:
                self._evaluate_in_project(run_, evaluator)
            else:
                self.futures.add(
                    self.executor.submit(self._evaluate_in_project, run_, evaluator)
                )

    def wait_for_futures(self) -> None:
        """Wait for all futures to complete."""
        wait(self.futures)
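# A hedged usage sketch for EvaluatorCallbackHandler; `my_evaluator` and
# `my_chain` are placeholders for any langsmith.RunEvaluator and any traced
# runnable, respectively:
#
#     handler = EvaluatorCallbackHandler(evaluators=[my_evaluator])
#     my_chain.invoke({"question": "..."}, config={"callbacks": [handler]})
#     wait_for_all_evaluators()  # block until queued evaluations finish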
@ -0,0 +1,262 @@
"""A Tracer implementation that records to LangChain endpoint."""
from __future__ import annotations

import logging
import weakref
from concurrent.futures import Future, ThreadPoolExecutor, wait
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union
from uuid import UUID

from langsmith import Client
from langsmith import utils as ls_utils
from tenacity import (
    Retrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.env import get_runtime_environment
from langchain_core.load.dump import dumpd
from langchain_core.schema.messages import BaseMessage

logger = logging.getLogger(__name__)
_LOGGED = set()
_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet()
_CLIENT: Optional[Client] = None
_EXECUTOR: Optional[ThreadPoolExecutor] = None


def log_error_once(method: str, exception: Exception) -> None:
    """Log an error once."""
    global _LOGGED
    if (method, type(exception)) in _LOGGED:
        return
    _LOGGED.add((method, type(exception)))
    logger.error(exception)


def wait_for_all_tracers() -> None:
    """Wait for all tracers to finish."""
    global _TRACERS
    for tracer in list(_TRACERS):
        if tracer is not None:
            tracer.wait_for_futures()


def get_client() -> Client:
    """Get the client."""
    global _CLIENT
    if _CLIENT is None:
        _CLIENT = Client()
    return _CLIENT


def _get_executor() -> ThreadPoolExecutor:
    """Get the executor."""
    global _EXECUTOR
    if _EXECUTOR is None:
        _EXECUTOR = ThreadPoolExecutor()
    return _EXECUTOR


def _copy(run: Run) -> Run:
    """Copy a run."""
    try:
        return run.copy(deep=True)
    except TypeError:
        # Fallback in case the object contains a lock or other
        # non-pickleable object
        return run.copy()


class LangChainTracer(BaseTracer):
    """An implementation of the SharedTracer that POSTs to the LangChain endpoint."""

    def __init__(
        self,
        example_id: Optional[Union[UUID, str]] = None,
        project_name: Optional[str] = None,
        client: Optional[Client] = None,
        tags: Optional[List[str]] = None,
        use_threading: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize the LangChain tracer."""
        super().__init__(**kwargs)
        self.example_id = (
            UUID(example_id) if isinstance(example_id, str) else example_id
        )
        self.project_name = project_name or ls_utils.get_tracer_project()
        self.client = client or get_client()
        self._futures: weakref.WeakSet[Future] = weakref.WeakSet()
        self.tags = tags or []
        self.executor = _get_executor() if use_threading else None
        self.latest_run: Optional[Run] = None
        global _TRACERS
        _TRACERS.add(self)

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        tags: Optional[List[str]] = None,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Start a trace for an LLM run."""
        parent_run_id_ = str(parent_run_id) if parent_run_id else None
        execution_order = self._get_execution_order(parent_run_id_)
        start_time = datetime.utcnow()
        if metadata:
            kwargs.update({"metadata": metadata})
        chat_model_run = Run(
            id=run_id,
            parent_run_id=parent_run_id,
            serialized=serialized,
            inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
            extra=kwargs,
            events=[{"name": "start", "time": start_time}],
            start_time=start_time,
            execution_order=execution_order,
            child_execution_order=execution_order,
            run_type="llm",
            tags=tags,
            name=name,
        )
        self._start_trace(chat_model_run)
        self._on_chat_model_start(chat_model_run)

    def _persist_run(self, run: Run) -> None:
        run_ = run.copy()
        run_.reference_example_id = self.example_id
        self.latest_run = run_

    def get_run_url(self) -> str:
        """Get the LangSmith root run URL."""
        if not self.latest_run:
            raise ValueError("No traced run found.")
        # If this is the first run in a project, the project may not yet be created.
        # This method is only really useful for debugging flows, so we will assume
        # there is some tolerance for latency.
        for attempt in Retrying(
            stop=stop_after_attempt(5),
            wait=wait_exponential_jitter(),
            retry=retry_if_exception_type(ls_utils.LangSmithError),
        ):
            with attempt:
                return self.client.get_run_url(
                    run=self.latest_run, project_name=self.project_name
                )
        raise ValueError("Failed to get run URL.")

    def _get_tags(self, run: Run) -> List[str]:
        """Get combined tags for a run."""
        tags = set(run.tags or [])
        tags.update(self.tags or [])
        return list(tags)

    def _persist_run_single(self, run: Run) -> None:
        """Persist a run."""
        run_dict = run.dict(exclude={"child_runs"})
        run_dict["tags"] = self._get_tags(run)
        extra = run_dict.get("extra", {})
        extra["runtime"] = get_runtime_environment()
        run_dict["extra"] = extra
        try:
            self.client.create_run(**run_dict, project_name=self.project_name)
        except Exception as e:
            # Errors are swallowed by the thread executor, so we need to log them here
            log_error_once("post", e)
            raise

    def _update_run_single(self, run: Run) -> None:
        """Update a run."""
        try:
            run_dict = run.dict()
            run_dict["tags"] = self._get_tags(run)
            self.client.update_run(run.id, **run_dict)
        except Exception as e:
            # Errors are swallowed by the thread executor, so we need to log them here
            log_error_once("patch", e)
            raise

    def _submit(self, function: Callable[[Run], None], run: Run) -> None:
        """Submit a function to the executor."""
        if self.executor is None:
            function(run)
        else:
            self._futures.add(self.executor.submit(function, run))

    def _on_llm_start(self, run: Run) -> None:
        """Persist an LLM run."""
        if run.parent_run_id is None:
            run.reference_example_id = self.example_id
        self._submit(self._persist_run_single, _copy(run))

    def _on_chat_model_start(self, run: Run) -> None:
        """Persist an LLM run."""
        if run.parent_run_id is None:
            run.reference_example_id = self.example_id
        self._submit(self._persist_run_single, _copy(run))

    def _on_llm_end(self, run: Run) -> None:
        """Process the LLM Run."""
        self._submit(self._update_run_single, _copy(run))

    def _on_llm_error(self, run: Run) -> None:
        """Process the LLM Run upon error."""
        self._submit(self._update_run_single, _copy(run))

    def _on_chain_start(self, run: Run) -> None:
        """Process the Chain Run upon start."""
        if run.parent_run_id is None:
            run.reference_example_id = self.example_id
        self._submit(self._persist_run_single, _copy(run))

    def _on_chain_end(self, run: Run) -> None:
        """Process the Chain Run."""
        self._submit(self._update_run_single, _copy(run))

    def _on_chain_error(self, run: Run) -> None:
        """Process the Chain Run upon error."""
        self._submit(self._update_run_single, _copy(run))

    def _on_tool_start(self, run: Run) -> None:
        """Process the Tool Run upon start."""
        if run.parent_run_id is None:
            run.reference_example_id = self.example_id
        self._submit(self._persist_run_single, _copy(run))

    def _on_tool_end(self, run: Run) -> None:
        """Process the Tool Run."""
        self._submit(self._update_run_single, _copy(run))

    def _on_tool_error(self, run: Run) -> None:
        """Process the Tool Run upon error."""
        self._submit(self._update_run_single, _copy(run))

    def _on_retriever_start(self, run: Run) -> None:
        """Process the Retriever Run upon start."""
        if run.parent_run_id is None:
            run.reference_example_id = self.example_id
        self._submit(self._persist_run_single, _copy(run))

    def _on_retriever_end(self, run: Run) -> None:
        """Process the Retriever Run."""
        self._submit(self._update_run_single, _copy(run))

    def _on_retriever_error(self, run: Run) -> None:
        """Process the Retriever Run upon error."""
        self._submit(self._update_run_single, _copy(run))

    def wait_for_futures(self) -> None:
        """Wait for the given futures to complete."""
        wait(self._futures)
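# A hedged usage sketch: attach the tracer explicitly rather than via the
# LANGCHAIN_TRACING_V2 environment variable (`my_chain` is a placeholder):
#
#     tracer = LangChainTracer(project_name="my-project")
#     my_chain.invoke("hello", config={"callbacks": [tracer]})
#     tracer.wait_for_futures()    # ensure queued create/update calls finish
#     print(tracer.get_run_url())  # LangSmith URL of the latest root run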
@ -0,0 +1,185 @@
from __future__ import annotations

import logging
import os
from typing import Any, Dict, Optional, Union

import requests

from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import (
    ChainRun,
    LLMRun,
    Run,
    ToolRun,
    TracerSession,
    TracerSessionV1,
    TracerSessionV1Base,
)
from langchain_core.schema.messages import get_buffer_string
from langchain_core.utils import raise_for_status_with_text

logger = logging.getLogger(__name__)


def get_headers() -> Dict[str, Any]:
    """Get the headers for the LangChain API."""
    headers: Dict[str, Any] = {"Content-Type": "application/json"}
    if os.getenv("LANGCHAIN_API_KEY"):
        headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
    return headers


def _get_endpoint() -> str:
    return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")


class LangChainTracerV1(BaseTracer):
    """An implementation of the SharedTracer that POSTs to the LangChain endpoint."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the LangChain tracer."""
        super().__init__(**kwargs)
        self.session: Optional[TracerSessionV1] = None
        self._endpoint = _get_endpoint()
        self._headers = get_headers()

    def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]:
        session = self.session or self.load_default_session()
        if not isinstance(session, TracerSessionV1):
            raise ValueError(
                "LangChainTracerV1 is not compatible with"
                f" session of type {type(session)}"
            )

        if run.run_type == "llm":
            if "prompts" in run.inputs:
                prompts = run.inputs["prompts"]
            elif "messages" in run.inputs:
                prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]]
            else:
                raise ValueError("No prompts found in LLM run inputs")
            return LLMRun(
                uuid=str(run.id) if run.id else None,
                parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
                start_time=run.start_time,
                end_time=run.end_time,
                extra=run.extra,
                execution_order=run.execution_order,
                child_execution_order=run.child_execution_order,
                serialized=run.serialized,
                session_id=session.id,
                error=run.error,
                prompts=prompts,
                response=run.outputs if run.outputs else None,
            )
        if run.run_type == "chain":
            child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
            return ChainRun(
                uuid=str(run.id) if run.id else None,
                parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
                start_time=run.start_time,
                end_time=run.end_time,
                execution_order=run.execution_order,
                child_execution_order=run.child_execution_order,
                serialized=run.serialized,
                session_id=session.id,
                inputs=run.inputs,
                outputs=run.outputs,
                error=run.error,
                extra=run.extra,
                child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
                child_chain_runs=[
                    run for run in child_runs if isinstance(run, ChainRun)
                ],
                child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
            )
        if run.run_type == "tool":
            child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
            return ToolRun(
                uuid=str(run.id) if run.id else None,
                parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
                start_time=run.start_time,
                end_time=run.end_time,
                execution_order=run.execution_order,
                child_execution_order=run.child_execution_order,
                serialized=run.serialized,
                session_id=session.id,
                action=str(run.serialized),
                tool_input=run.inputs.get("input", ""),
                output=None if run.outputs is None else run.outputs.get("output"),
                error=run.error,
                extra=run.extra,
                child_chain_runs=[
                    run for run in child_runs if isinstance(run, ChainRun)
                ],
                child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
                child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
            )
        raise ValueError(f"Unknown run type: {run.run_type}")

    def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
        """Persist a run."""
        if isinstance(run, Run):
            v1_run = self._convert_to_v1_run(run)
        else:
            v1_run = run
        if isinstance(v1_run, LLMRun):
            endpoint = f"{self._endpoint}/llm-runs"
        elif isinstance(v1_run, ChainRun):
            endpoint = f"{self._endpoint}/chain-runs"
        else:
            endpoint = f"{self._endpoint}/tool-runs"

        try:
            response = requests.post(
                endpoint,
                data=v1_run.json(),
                headers=self._headers,
            )
            raise_for_status_with_text(response)
        except Exception as e:
            logger.warning(f"Failed to persist run: {e}")

    def _persist_session(
        self, session_create: TracerSessionV1Base
    ) -> Union[TracerSessionV1, TracerSession]:
        """Persist a session."""
        try:
            r = requests.post(
                f"{self._endpoint}/sessions",
                data=session_create.json(),
                headers=self._headers,
            )
            session = TracerSessionV1(id=r.json()["id"], **session_create.dict())
        except Exception as e:
            logger.warning(f"Failed to create session, using default session: {e}")
            session = TracerSessionV1(id=1, **session_create.dict())
        return session

    def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1:
        """Load a session from the tracer."""
        try:
            url = f"{self._endpoint}/sessions"
            if session_name:
                url += f"?name={session_name}"
            r = requests.get(url, headers=self._headers)

            tracer_session = TracerSessionV1(**r.json()[0])
        except Exception as e:
            session_type = "default" if not session_name else session_name
            logger.warning(
                f"Failed to load {session_type} session, using empty session: {e}"
            )
            tracer_session = TracerSessionV1(id=1)

        self.session = tracer_session
        return tracer_session

    def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]:
        """Load a session with the given name from the tracer."""
        return self._load_session(session_name)

    def load_default_session(self) -> Union[TracerSessionV1, TracerSession]:
        """Load the default tracing session and set it as the Tracer's session."""
        return self._load_session("default")
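# A short configuration sketch for this legacy tracer: the endpoint and the
# optional API key are read from the environment (see _get_endpoint and
# get_headers above), so a typical setup looks like:
#
#     export LANGCHAIN_ENDPOINT="http://localhost:8000"
#     export LANGCHAIN_API_KEY="..."  # optional, sent as the x-api-key header
#
#     tracer = LangChainTracerV1()
#     tracer.load_default_session()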
@ -0,0 +1,313 @@
from __future__ import annotations

import math
import threading
from collections import defaultdict
from typing import (
    Any,
    AsyncIterator,
    Dict,
    List,
    Optional,
    Sequence,
    TypedDict,
    Union,
)
from uuid import UUID

import jsonpatch
from anyio import create_memory_object_stream

from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.load.load import load
from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk


class LogEntry(TypedDict):
    """A single entry in the run log."""

    id: str
    """ID of the sub-run."""
    name: str
    """Name of the object being run."""
    type: str
    """Type of the object being run, e.g. prompt, chain, llm, etc."""
    tags: List[str]
    """List of tags for the run."""
    metadata: Dict[str, Any]
    """Key-value pairs of metadata for the run."""
    start_time: str
    """ISO-8601 timestamp of when the run started."""

    streamed_output_str: List[str]
    """List of LLM tokens streamed by this run, if applicable."""
    final_output: Optional[Any]
    """Final output of this run.
    Only available after the run has finished successfully."""
    end_time: Optional[str]
    """ISO-8601 timestamp of when the run ended.
    Only available after the run has finished."""


class RunState(TypedDict):
    """State of the run."""

    id: str
    """ID of the run."""
    streamed_output: List[Any]
    """List of output chunks streamed by Runnable.stream()."""
    final_output: Optional[Any]
    """Final output of the run, usually the result of aggregating (`+`) streamed_output.
    Only available after the run has finished successfully."""

    logs: Dict[str, LogEntry]
    """Map of run names to sub-runs. If filters were supplied, this map will
    contain only the runs that matched the filters."""


class RunLogPatch:
    """A patch to the run log."""

    ops: List[Dict[str, Any]]
    """List of jsonpatch operations, which describe how to create the run state
    from an empty dict. This is the minimal representation of the log, designed to
    be serialized as JSON and sent over the wire to reconstruct the log on the other
    side. Reconstruction of the state can be done with any jsonpatch-compliant
    library; see https://jsonpatch.com for more information."""

    def __init__(self, *ops: Dict[str, Any]) -> None:
        self.ops = list(ops)

    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog:
        if type(other) == RunLogPatch:
            ops = self.ops + other.ops
            state = jsonpatch.apply_patch(None, ops)
            return RunLog(*ops, state=state)

        raise TypeError(
            f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
        )

    def __repr__(self) -> str:
        from pprint import pformat

        # 1:-1 to get rid of the [] around the list
        return f"RunLogPatch({pformat(self.ops)[1:-1]})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, RunLogPatch) and self.ops == other.ops


class RunLog(RunLogPatch):
    """A run log."""

    state: RunState
    """Current state of the log, obtained from applying all ops in sequence."""

    def __init__(self, *ops: Dict[str, Any], state: RunState) -> None:
        super().__init__(*ops)
        self.state = state

    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog:
        if type(other) == RunLogPatch:
            ops = self.ops + other.ops
            state = jsonpatch.apply_patch(self.state, other.ops)
            return RunLog(*ops, state=state)

        raise TypeError(
            f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
        )

    def __repr__(self) -> str:
        from pprint import pformat

        return f"RunLog({pformat(self.state)})"
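# A minimal, illustrative sketch of how patches aggregate: `+` applies the
# jsonpatch ops and yields a RunLog whose `state` is the accumulated value
# (the op payloads below are made up for the example).
_initial = RunLogPatch(
    {"op": "replace", "path": "", "value": {"streamed_output": [], "logs": {}}}
)
_update = RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": "hello"})
_log = _initial + _update  # RunLog
assert _log.state["streamed_output"] == ["hello"]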
class LogStreamCallbackHandler(BaseTracer):
    """A tracer that streams run logs to a stream."""

    def __init__(
        self,
        *,
        auto_close: bool = True,
        include_names: Optional[Sequence[str]] = None,
        include_types: Optional[Sequence[str]] = None,
        include_tags: Optional[Sequence[str]] = None,
        exclude_names: Optional[Sequence[str]] = None,
        exclude_types: Optional[Sequence[str]] = None,
        exclude_tags: Optional[Sequence[str]] = None,
    ) -> None:
        super().__init__()

        self.auto_close = auto_close
        self.include_names = include_names
        self.include_types = include_types
        self.include_tags = include_tags
        self.exclude_names = exclude_names
        self.exclude_types = exclude_types
        self.exclude_tags = exclude_tags

        send_stream: Any
        receive_stream: Any
        send_stream, receive_stream = create_memory_object_stream(
            math.inf, item_type=RunLogPatch
        )
        self.lock = threading.Lock()
        self.send_stream = send_stream
        self.receive_stream = receive_stream
        self._key_map_by_run_id: Dict[UUID, str] = {}
        self._counter_map_by_name: Dict[str, int] = defaultdict(int)
        self.root_id: Optional[UUID] = None

    def __aiter__(self) -> AsyncIterator[RunLogPatch]:
        return self.receive_stream.__aiter__()

    def include_run(self, run: Run) -> bool:
        if run.id == self.root_id:
            return False

        run_tags = run.tags or []

        if (
            self.include_names is None
            and self.include_types is None
            and self.include_tags is None
        ):
            include = True
        else:
            include = False

        if self.include_names is not None:
            include = include or run.name in self.include_names
        if self.include_types is not None:
            include = include or run.run_type in self.include_types
        if self.include_tags is not None:
            include = include or any(tag in self.include_tags for tag in run_tags)

        if self.exclude_names is not None:
            include = include and run.name not in self.exclude_names
        if self.exclude_types is not None:
            include = include and run.run_type not in self.exclude_types
        if self.exclude_tags is not None:
            include = include and all(tag not in self.exclude_tags for tag in run_tags)

        return include

    def _persist_run(self, run: Run) -> None:
        # This is a legacy method only called once for an entire run tree
        # therefore not useful here
        pass

    def _on_run_create(self, run: Run) -> None:
        """Start a run."""
        if self.root_id is None:
            self.root_id = run.id
            self.send_stream.send_nowait(
                RunLogPatch(
                    {
                        "op": "replace",
                        "path": "",
                        "value": RunState(
                            id=str(run.id),
                            streamed_output=[],
                            final_output=None,
                            logs={},
                        ),
                    }
                )
            )

        if not self.include_run(run):
            return

        # Determine previous index, increment by 1
        with self.lock:
            self._counter_map_by_name[run.name] += 1
            count = self._counter_map_by_name[run.name]
            self._key_map_by_run_id[run.id] = (
                run.name if count == 1 else f"{run.name}:{count}"
            )

        # Add the run to the stream
        self.send_stream.send_nowait(
            RunLogPatch(
                {
                    "op": "add",
                    "path": f"/logs/{self._key_map_by_run_id[run.id]}",
                    "value": LogEntry(
                        id=str(run.id),
                        name=run.name,
                        type=run.run_type,
                        tags=run.tags or [],
                        metadata=(run.extra or {}).get("metadata", {}),
                        start_time=run.start_time.isoformat(timespec="milliseconds"),
                        streamed_output_str=[],
                        final_output=None,
                        end_time=None,
                    ),
                }
            )
        )

    def _on_run_update(self, run: Run) -> None:
        """Finish a run."""
        try:
            index = self._key_map_by_run_id.get(run.id)

            if index is None:
                return

            self.send_stream.send_nowait(
                RunLogPatch(
                    {
                        "op": "add",
                        "path": f"/logs/{index}/final_output",
                        # to undo the dumpd done by some runnables / tracer / etc
                        "value": load(run.outputs),
                    },
                    {
                        "op": "add",
                        "path": f"/logs/{index}/end_time",
                        "value": run.end_time.isoformat(timespec="milliseconds")
                        if run.end_time is not None
                        else None,
                    },
                )
            )
        finally:
            if run.id == self.root_id:
                self.send_stream.send_nowait(
                    RunLogPatch(
                        {
                            "op": "replace",
                            "path": "/final_output",
                            "value": load(run.outputs),
                        }
                    )
                )
                if self.auto_close:
                    self.send_stream.close()

    def _on_llm_new_token(
        self,
        run: Run,
        token: str,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
    ) -> None:
        """Process new LLM token."""
        index = self._key_map_by_run_id.get(run.id)

        if index is None:
            return

        self.send_stream.send_nowait(
            RunLogPatch(
                {
                    "op": "add",
                    "path": f"/logs/{index}/streamed_output_str/-",
                    "value": token,
                }
            )
        )
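# A hedged usage sketch: the handler is an async iterator of RunLogPatch
# objects, so it is consumed alongside an async invocation (`my_chain` is a
# placeholder for any Runnable; Runnable.astream_log wires this up for you):
#
#     async def show_log() -> None:
#         handler = LogStreamCallbackHandler()
#         task = asyncio.create_task(
#             my_chain.ainvoke("hello", config={"callbacks": [handler]})
#         )
#         async for patch in handler:
#             print(patch)  # incremental jsonpatch ops
#         await task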
@ -0,0 +1,54 @@
from typing import Callable, Optional, Union
from uuid import UUID

from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.runnables.config import (
    RunnableConfig,
    call_func_with_variable_args,
)

Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]


class RootListenersTracer(BaseTracer):
    def __init__(
        self,
        *,
        config: RunnableConfig,
        on_start: Optional[Listener],
        on_end: Optional[Listener],
        on_error: Optional[Listener],
    ) -> None:
        super().__init__()

        self.config = config
        self._arg_on_start = on_start
        self._arg_on_end = on_end
        self._arg_on_error = on_error
        self.root_id: Optional[UUID] = None

    def _persist_run(self, run: Run) -> None:
        # This is a legacy method only called once for an entire run tree
        # therefore not useful here
        pass

    def _on_run_create(self, run: Run) -> None:
        if self.root_id is not None:
            return

        self.root_id = run.id

        if self._arg_on_start is not None:
            call_func_with_variable_args(self._arg_on_start, run, self.config)

    def _on_run_update(self, run: Run) -> None:
        if run.id != self.root_id:
            return

        if run.error is None:
            if self._arg_on_end is not None:
                call_func_with_variable_args(self._arg_on_end, run, self.config)
        else:
            if self._arg_on_error is not None:
                call_func_with_variable_args(self._arg_on_error, run, self.config)
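# A hedged usage sketch: this tracer is what powers listener-style callbacks on
# runnables (e.g. Runnable.with_listeners); the listeners receive the root Run,
# and optionally the RunnableConfig (`my_runnable` is a placeholder):
#
#     def log_end(run: Run, config: RunnableConfig) -> None:
#         print("finished", run.name, "tags:", config.get("tags"))
#
#     my_runnable.with_listeners(on_end=log_end).invoke("hello")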
@ -0,0 +1,52 @@
"""A tracer that collects all nested runs in a list."""

from typing import Any, List, Optional, Union
from uuid import UUID

from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run


class RunCollectorCallbackHandler(BaseTracer):
    """
    A tracer that collects all nested runs in a list.

    This tracer is useful for inspection and evaluation purposes.

    Parameters
    ----------
    example_id : Optional[Union[UUID, str]], default=None
        The ID of the example being traced. It can be either a UUID or a string.
    """

    name: str = "run-collector_callback_handler"

    def __init__(
        self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any
    ) -> None:
        """
        Initialize the RunCollectorCallbackHandler.

        Parameters
        ----------
        example_id : Optional[Union[UUID, str]], default=None
            The ID of the example being traced. It can be either a UUID or a string.
        """
        super().__init__(**kwargs)
        self.example_id = (
            UUID(example_id) if isinstance(example_id, str) else example_id
        )
        self.traced_runs: List[Run] = []

    def _persist_run(self, run: Run) -> None:
        """
        Persist a run by adding it to the traced_runs list.

        Parameters
        ----------
        run : Run
            The run to be persisted.
        """
        run_ = run.copy()
        run_.reference_example_id = self.example_id
        self.traced_runs.append(run_)
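# A minimal usage sketch (`my_chain` is a placeholder). Note that
# `_persist_run` fires once per run tree, so each entry in `traced_runs` is a
# root Run with its children nested under `child_runs`:
#
#     collector = RunCollectorCallbackHandler()
#     my_chain.invoke("hello", config={"callbacks": [collector]})
#     for traced in collector.traced_runs:
#         print(traced.name, traced.run_type, len(traced.child_runs))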
@ -0,0 +1,140 @@
"""Schemas for tracers."""
from __future__ import annotations

import datetime
import warnings
from typing import Any, Dict, List, Optional, Type
from uuid import UUID

from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep

from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.schema import LLMResult


def RunTypeEnum() -> Type[RunTypeEnumDep]:
    """RunTypeEnum."""
    warnings.warn(
        "RunTypeEnum is deprecated. Please directly use a string instead"
        " (e.g. 'llm', 'chain', 'tool').",
        DeprecationWarning,
    )
    return RunTypeEnumDep


class TracerSessionV1Base(BaseModel):
    """Base class for TracerSessionV1."""

    start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    name: Optional[str] = None
    extra: Optional[Dict[str, Any]] = None


class TracerSessionV1Create(TracerSessionV1Base):
    """Create class for TracerSessionV1."""


class TracerSessionV1(TracerSessionV1Base):
    """TracerSessionV1 schema."""

    id: int


class TracerSessionBase(TracerSessionV1Base):
    """Base class for TracerSession."""

    tenant_id: UUID


class TracerSession(TracerSessionBase):
    """TracerSession schema for the V2 API."""

    id: UUID


class BaseRun(BaseModel):
    """Base class for Run."""

    uuid: str
    parent_uuid: Optional[str] = None
    start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    extra: Optional[Dict[str, Any]] = None
    execution_order: int
    child_execution_order: int
    serialized: Dict[str, Any]
    session_id: int
    error: Optional[str] = None


class LLMRun(BaseRun):
    """Class for LLMRun."""

    prompts: List[str]
    response: Optional[LLMResult] = None


class ChainRun(BaseRun):
    """Class for ChainRun."""

    inputs: Dict[str, Any]
    outputs: Optional[Dict[str, Any]] = None
    child_llm_runs: List[LLMRun] = Field(default_factory=list)
    child_chain_runs: List[ChainRun] = Field(default_factory=list)
    child_tool_runs: List[ToolRun] = Field(default_factory=list)


class ToolRun(BaseRun):
    """Class for ToolRun."""

    tool_input: str
    output: Optional[str] = None
    action: str
    child_llm_runs: List[LLMRun] = Field(default_factory=list)
    child_chain_runs: List[ChainRun] = Field(default_factory=list)
    child_tool_runs: List[ToolRun] = Field(default_factory=list)


# Begin V2 API Schemas


class Run(BaseRunV2):
    """Run schema for the V2 API in the Tracer."""

    execution_order: int
    child_execution_order: int
    child_runs: List[Run] = Field(default_factory=list)
    tags: Optional[List[str]] = Field(default_factory=list)
    events: List[Dict[str, Any]] = Field(default_factory=list)

    @root_validator(pre=True)
    def assign_name(cls, values: dict) -> dict:
        """Assign a name to the run."""
        if values.get("name") is None:
            if "name" in values["serialized"]:
                values["name"] = values["serialized"]["name"]
            elif "id" in values["serialized"]:
                values["name"] = values["serialized"]["id"][-1]
        if values.get("events") is None:
            values["events"] = []
        return values


ChainRun.update_forward_refs()
ToolRun.update_forward_refs()
Run.update_forward_refs()

__all__ = [
    "BaseRun",
    "ChainRun",
    "LLMRun",
    "Run",
    "RunTypeEnum",
    "ToolRun",
    "TracerSession",
    "TracerSessionBase",
    "TracerSessionV1",
    "TracerSessionV1Base",
    "TracerSessionV1Create",
]
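# A small illustration of the `assign_name` validator above: when no explicit
# name is passed, it falls back to serialized["name"], then to the last
# element of serialized["id"]. For example, a Run created with
# serialized={"id": ["langchain", "llms", "FakeLLM"]} and no name ends up
# with name == "FakeLLM".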
@ -0,0 +1,178 @@
import json
from typing import Any, Callable, List

from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.utils.input import get_bolded_text, get_colored_text


def try_json_stringify(obj: Any, fallback: str) -> str:
    """
    Try to stringify an object to JSON.

    Args:
        obj: Object to stringify.
        fallback: Fallback string to return if the object cannot be stringified.

    Returns:
        A JSON string if the object can be stringified, otherwise the fallback string.
    """
    try:
        return json.dumps(obj, indent=2, ensure_ascii=False)
    except Exception:
        return fallback


def elapsed(run: Any) -> str:
    """Get the elapsed time of a run.

    Args:
        run: any object with a start_time and end_time attribute.

    Returns:
        A string with the elapsed time in seconds, or in
        milliseconds if the time is less than a second.
    """
    elapsed_time = run.end_time - run.start_time
    milliseconds = elapsed_time.total_seconds() * 1000
    if milliseconds < 1000:
        return f"{milliseconds:.0f}ms"
    return f"{(milliseconds / 1000):.2f}s"
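# A quick, illustrative check of elapsed(): any object with datetime
# start_time/end_time attributes works (the _FakeRun class here is made up).
from datetime import datetime, timedelta


class _FakeRun:
    start_time = datetime(2023, 1, 1)
    end_time = datetime(2023, 1, 1) + timedelta(milliseconds=250)


assert elapsed(_FakeRun()) == "250ms"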
class FunctionCallbackHandler(BaseTracer):
    """Tracer that calls a function with a single str parameter."""

    name: str = "function_callback_handler"

    def __init__(self, function: Callable[[str], None], **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.function_callback = function

    def _persist_run(self, run: Run) -> None:
        pass

    def get_parents(self, run: Run) -> List[Run]:
        parents = []
        current_run = run
        while current_run.parent_run_id:
            parent = self.run_map.get(str(current_run.parent_run_id))
            if parent:
                parents.append(parent)
                current_run = parent
            else:
                break
        return parents

    def get_breadcrumbs(self, run: Run) -> str:
        parents = self.get_parents(run)[::-1]
        string = " > ".join(
            f"{parent.execution_order}:{parent.run_type}:{parent.name}"
            for parent in parents + [run]
        )
        return string

    # logging methods
    def _on_chain_start(self, run: Run) -> None:
        crumbs = self.get_breadcrumbs(run)
        run_type = run.run_type.capitalize()
        self.function_callback(
            f"{get_colored_text('[chain/start]', color='green')} "
            + get_bolded_text(f"[{crumbs}] Entering {run_type} run with input:\n")
            + f"{try_json_stringify(run.inputs, '[inputs]')}"
        )

    def _on_chain_end(self, run: Run) -> None:
        crumbs = self.get_breadcrumbs(run)
        run_type = run.run_type.capitalize()
        self.function_callback(
            f"{get_colored_text('[chain/end]', color='blue')} "
            + get_bolded_text(
                f"[{crumbs}] [{elapsed(run)}] Exiting {run_type} run with output:\n"
            )
            + f"{try_json_stringify(run.outputs, '[outputs]')}"
        )

    def _on_chain_error(self, run: Run) -> None:
        crumbs = self.get_breadcrumbs(run)
        run_type = run.run_type.capitalize()
        self.function_callback(
            f"{get_colored_text('[chain/error]', color='red')} "
            + get_bolded_text(
                f"[{crumbs}] [{elapsed(run)}] {run_type} run errored with error:\n"
            )
            + f"{try_json_stringify(run.error, '[error]')}"
        )

    def _on_llm_start(self, run: Run) -> None:
        crumbs = self.get_breadcrumbs(run)
        inputs = (
            {"prompts": [p.strip() for p in run.inputs["prompts"]]}
            if "prompts" in run.inputs
            else run.inputs
        )
        self.function_callback(
            f"{get_colored_text('[llm/start]', color='green')} "
            + get_bolded_text(f"[{crumbs}] Entering LLM run with input:\n")
            + f"{try_json_stringify(inputs, '[inputs]')}"
        )

    def _on_llm_end(self, run: Run) -> None:
        crumbs = self.get_breadcrumbs(run)
        self.function_callback(
            f"{get_colored_text('[llm/end]', color='blue')} "
            + get_bolded_text(
                f"[{crumbs}] [{elapsed(run)}] Exiting LLM run with output:\n"
            )
            + f"{try_json_stringify(run.outputs, '[response]')}"
        )

    def _on_llm_error(self, run: Run) -> None:
        crumbs = self.get_breadcrumbs(run)
        self.function_callback(
            f"{get_colored_text('[llm/error]', color='red')} "
            + get_bolded_text(
                f"[{crumbs}] [{elapsed(run)}] LLM run errored with error:\n"
            )
            + f"{try_json_stringify(run.error, '[error]')}"
        )

    def _on_tool_start(self, run: Run) -> None:
        crumbs = self.get_breadcrumbs(run)
        self.function_callback(
            f'{get_colored_text("[tool/start]", color="green")} '
            + get_bolded_text(f"[{crumbs}] Entering Tool run with input:\n")
            + f'"{run.inputs["input"].strip()}"'
        )

    def _on_tool_end(self, run: Run) -> None:
        crumbs = self.get_breadcrumbs(run)
        if run.outputs:
            self.function_callback(
                f'{get_colored_text("[tool/end]", color="blue")} '
                + get_bolded_text(
                    f"[{crumbs}] [{elapsed(run)}] Exiting Tool run with output:\n"
                )
                + f'"{run.outputs["output"].strip()}"'
            )

    def _on_tool_error(self, run: Run) -> None:
        crumbs = self.get_breadcrumbs(run)
        self.function_callback(
            f"{get_colored_text('[tool/error]', color='red')} "
            + get_bolded_text(f"[{crumbs}] [{elapsed(run)}] ")
            + "Tool run errored with error:\n"
            f"{run.error}"
        )


class ConsoleCallbackHandler(FunctionCallbackHandler):
    """Tracer that prints to the console."""

    name: str = "console_callback_handler"

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(function=print, **kwargs)
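# A minimal usage sketch: because ConsoleCallbackHandler is just
# FunctionCallbackHandler bound to print, verbose run logging is one line
# (`my_chain` is a placeholder):
#
#     my_chain.invoke("hello", config={"callbacks": [ConsoleCallbackHandler()]})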
@ -0,0 +1,735 @@
import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    cast,
)

from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.callbacks.manager import (
    AsyncCallbackManager,
    AsyncCallbackManagerForLLMRun,
    CallbackManager,
    CallbackManagerForLLMRun,
    Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.load.dump import dumpd, dumps
from langchain_core.prompts.base import StringPromptValue
from langchain_core.prompts.chat import ChatPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import RunnableConfig
from langchain_core.schema import (
    ChatGeneration,
    ChatResult,
    LLMResult,
    PromptValue,
    RunInfo,
)
from langchain_core.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain_core.schema.messages import (
    AIMessage,
    AnyMessage,
    BaseMessage,
    BaseMessageChunk,
    HumanMessage,
)
from langchain_core.schema.output import ChatGenerationChunk


def _get_verbosity() -> bool:
    from langchain_core.globals import get_verbose

    return get_verbose()


def _generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
    generation: Optional[ChatGenerationChunk] = None
    for chunk in stream:
        if generation is None:
            generation = chunk
        else:
            generation += chunk
    assert generation is not None
    return ChatResult(generations=[generation])


async def _agenerate_from_stream(
    stream: AsyncIterator[ChatGenerationChunk],
) -> ChatResult:
    generation: Optional[ChatGenerationChunk] = None
    async for chunk in stream:
        if generation is None:
            generation = chunk
        else:
            generation += chunk
    assert generation is not None
    return ChatResult(generations=[generation])
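# A small sketch of the aggregation the helpers above perform: chat generation
# chunks support `+`, which concatenates message content (values illustrative).
from langchain_core.schema.messages import AIMessageChunk

_chunks = [
    ChatGenerationChunk(message=AIMessageChunk(content="Hello")),
    ChatGenerationChunk(message=AIMessageChunk(content=", world")),
]
_result = _generate_from_stream(iter(_chunks))
assert _result.generations[0].message.content == "Hello, world"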
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
    """Base class for Chat models."""

    cache: Optional[bool] = None
    """Whether to cache the response."""
    verbose: bool = Field(default_factory=_get_verbosity)
    """Whether to print out response text."""
    callbacks: Callbacks = Field(default=None, exclude=True)
    """Callbacks to add to the run trace."""
    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
    """Callback manager to add to the run trace."""
    tags: Optional[List[str]] = Field(default=None, exclude=True)
    """Tags to add to the run trace."""
    metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
    """Metadata to add to the run trace."""

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise a deprecation warning if callback_manager is used."""
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            values["callbacks"] = values.pop("callback_manager", None)
        return values

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    # --- Runnable methods ---

    @property
    def OutputType(self) -> Any:
        """Get the output type for this runnable."""
        return AnyMessage

    def _convert_input(self, input: LanguageModelInput) -> PromptValue:
        if isinstance(input, PromptValue):
            return input
        elif isinstance(input, str):
            return StringPromptValue(text=input)
        elif isinstance(input, list):
            return ChatPromptValue(messages=input)
        else:
            raise ValueError(
                f"Invalid input type {type(input)}. "
                "Must be a PromptValue, str, or list of BaseMessages."
            )
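    # A minimal sketch of the three accepted input forms (`MyChatModel` is a
    # hypothetical concrete subclass; `_convert_input` above normalizes each):
    #
    #     model = MyChatModel()
    #     model.invoke("Hello")                          # str -> StringPromptValue
    #     model.invoke([HumanMessage(content="Hello")])  # list -> ChatPromptValue
    #     model.invoke(StringPromptValue(text="Hello"))  # PromptValue passes through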
    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        config = config or {}
        return cast(
            ChatGeneration,
            self.generate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                **kwargs,
            ).generations[0][0],
        ).message

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        config = config or {}
        llm_result = await self.agenerate_prompt(
            [self._convert_input(input)],
            stop=stop,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            **kwargs,
        )
        return cast(ChatGeneration, llm_result.generations[0][0]).message

    def stream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[BaseMessageChunk]:
        if type(self)._stream == BaseChatModel._stream:
            # model doesn't implement streaming, so use default implementation
            yield cast(
                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
            )
        else:
            config = config or {}
            messages = self._convert_input(input).to_messages()
            params = self._get_invocation_params(stop=stop, **kwargs)
            options = {"stop": stop, **kwargs}
            callback_manager = CallbackManager.configure(
                config.get("callbacks"),
                self.callbacks,
                self.verbose,
                config.get("tags"),
                self.tags,
                config.get("metadata"),
                self.metadata,
            )
            (run_manager,) = callback_manager.on_chat_model_start(
                dumpd(self),
                [messages],
                invocation_params=params,
                options=options,
                name=config.get("run_name"),
            )
            try:
                generation: Optional[ChatGenerationChunk] = None
                for chunk in self._stream(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                ):
                    yield chunk.message
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
            except BaseException as e:
                run_manager.on_llm_error(e)
                raise e
            else:
                run_manager.on_llm_end(
                    LLMResult(generations=[[generation]]),
                )

    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[BaseMessageChunk]:
        if type(self)._astream == BaseChatModel._astream:
            # model doesn't implement streaming, so use default implementation
            yield cast(
                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
            )
        else:
            config = config or {}
            messages = self._convert_input(input).to_messages()
            params = self._get_invocation_params(stop=stop, **kwargs)
            options = {"stop": stop, **kwargs}
            callback_manager = AsyncCallbackManager.configure(
                config.get("callbacks"),
                self.callbacks,
                self.verbose,
                config.get("tags"),
                self.tags,
                config.get("metadata"),
                self.metadata,
            )
            (run_manager,) = await callback_manager.on_chat_model_start(
                dumpd(self),
                [messages],
                invocation_params=params,
                options=options,
                name=config.get("run_name"),
            )
            try:
                generation: Optional[ChatGenerationChunk] = None
                async for chunk in self._astream(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                ):
                    yield chunk.message
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
            except BaseException as e:
                await run_manager.on_llm_error(e)
                raise e
            else:
                await run_manager.on_llm_end(
                    LLMResult(generations=[[generation]]),
                )
# --- Custom methods ---
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
return {}
|
||||
|
||||
def _get_invocation_params(
|
||||
self,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
return {**params, **kwargs}
|
||||
|
||||
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
|
||||
if self.is_lc_serializable():
|
||||
params = {**kwargs, **{"stop": stop}}
|
||||
param_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
llm_string = dumps(self)
|
||||
return llm_string + "---" + param_string
|
||||
else:
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
params = {**params, **kwargs}
|
||||
return str(sorted([(k, v) for k, v in params.items()]))
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop}
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
run_managers = callback_manager.on_chat_model_start(
|
||||
dumpd(self),
|
||||
messages,
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=run_name,
|
||||
)
|
||||
results = []
|
||||
for i, m in enumerate(messages):
|
||||
try:
|
||||
results.append(
|
||||
self._generate_with_cache(
|
||||
m,
|
||||
stop=stop,
|
||||
run_manager=run_managers[i] if run_managers else None,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
except BaseException as e:
|
||||
if run_managers:
|
||||
run_managers[i].on_llm_error(e)
|
||||
raise e
|
||||
flattened_outputs = [
|
||||
LLMResult(generations=[res.generations], llm_output=res.llm_output)
|
||||
for res in results
|
||||
]
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
|
||||
generations = [res.generations for res in results]
|
||||
output = LLMResult(generations=generations, llm_output=llm_output)
|
||||
if run_managers:
|
||||
run_infos = []
|
||||
for manager, flattened_output in zip(run_managers, flattened_outputs):
|
||||
manager.on_llm_end(flattened_output)
|
||||
run_infos.append(RunInfo(run_id=manager.run_id))
|
||||
output.run = run_infos
|
||||
return output
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop}
|
||||
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
|
||||
run_managers = await callback_manager.on_chat_model_start(
|
||||
dumpd(self),
|
||||
messages,
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=run_name,
|
||||
)
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
self._agenerate_with_cache(
|
||||
m,
|
||||
stop=stop,
|
||||
run_manager=run_managers[i] if run_managers else None,
|
||||
**kwargs,
|
||||
)
|
||||
for i, m in enumerate(messages)
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
exceptions = []
|
||||
for i, res in enumerate(results):
|
||||
if isinstance(res, BaseException):
|
||||
if run_managers:
|
||||
await run_managers[i].on_llm_error(res)
|
||||
exceptions.append(res)
|
||||
if exceptions:
|
||||
if run_managers:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
run_manager.on_llm_end(
|
||||
LLMResult(
|
||||
generations=[res.generations], llm_output=res.llm_output
|
||||
)
|
||||
)
|
||||
for run_manager, res in zip(run_managers, results)
|
||||
if not isinstance(res, Exception)
|
||||
]
|
||||
)
|
||||
raise exceptions[0]
|
||||
flattened_outputs = [
|
||||
LLMResult(generations=[res.generations], llm_output=res.llm_output)
|
||||
for res in results
|
||||
]
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
|
||||
generations = [res.generations for res in results]
|
||||
output = LLMResult(generations=generations, llm_output=llm_output)
|
||||
await asyncio.gather(
|
||||
*[
|
||||
run_manager.on_llm_end(flattened_output)
|
||||
for run_manager, flattened_output in zip(
|
||||
run_managers, flattened_outputs
|
||||
)
|
||||
]
|
||||
)
|
||||
if run_managers:
|
||||
output.run = [
|
||||
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
|
||||
]
|
||||
return output
|
||||
|
||||
def generate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
prompt_messages = [p.to_messages() for p in prompts]
|
||||
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
|
||||
|
||||
async def agenerate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
prompt_messages = [p.to_messages() for p in prompts]
|
||||
return await self.agenerate(
|
||||
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
|
||||
)
|
||||
|
||||
def _generate_with_cache(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
new_arg_supported = inspect.signature(self._generate).parameters.get(
|
||||
"run_manager"
|
||||
)
|
||||
disregard_cache = self.cache is not None and not self.cache
|
||||
llm_cache = get_llm_cache()
|
||||
if llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
if new_arg_supported:
|
||||
return self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
else:
|
||||
return self._generate(messages, stop=stop, **kwargs)
|
||||
else:
|
||||
llm_string = self._get_llm_string(stop=stop, **kwargs)
|
||||
prompt = dumps(messages)
|
||||
cache_val = llm_cache.lookup(prompt, llm_string)
|
||||
if isinstance(cache_val, list):
|
||||
return ChatResult(generations=cache_val)
|
||||
else:
|
||||
if new_arg_supported:
|
||||
result = self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
else:
|
||||
result = self._generate(messages, stop=stop, **kwargs)
|
||||
llm_cache.update(prompt, llm_string, result.generations)
|
||||
return result
|
||||
|
||||
async def _agenerate_with_cache(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
|
||||
"run_manager"
|
||||
)
|
||||
disregard_cache = self.cache is not None and not self.cache
|
||||
llm_cache = get_llm_cache()
|
||||
if llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
if new_arg_supported:
|
||||
return await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
else:
|
||||
return await self._agenerate(messages, stop=stop, **kwargs)
|
||||
else:
|
||||
llm_string = self._get_llm_string(stop=stop, **kwargs)
|
||||
prompt = dumps(messages)
|
||||
cache_val = llm_cache.lookup(prompt, llm_string)
|
||||
if isinstance(cache_val, list):
|
||||
return ChatResult(generations=cache_val)
|
||||
else:
|
||||
if new_arg_supported:
|
||||
result = await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
else:
|
||||
result = await self._agenerate(messages, stop=stop, **kwargs)
|
||||
llm_cache.update(prompt, llm_string, result.generations)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self._generate, **kwargs), messages, stop, run_manager
|
||||
)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
generation = self.generate(
|
||||
[messages], stop=stop, callbacks=callbacks, **kwargs
|
||||
).generations[0][0]
|
||||
if isinstance(generation, ChatGeneration):
|
||||
return generation.message
|
||||
else:
|
||||
raise ValueError("Unexpected generation type")
|
||||
|
||||
async def _call_async(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
result = await self.agenerate(
|
||||
[messages], stop=stop, callbacks=callbacks, **kwargs
|
||||
)
|
||||
generation = result.generations[0][0]
|
||||
if isinstance(generation, ChatGeneration):
|
||||
return generation.message
|
||||
else:
|
||||
raise ValueError("Unexpected generation type")
|
||||
|
||||
def call_as_llm(
|
||||
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
return self.predict(message, stop=stop, **kwargs)
|
||||
|
||||
def predict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
|
||||
if isinstance(result.content, str):
|
||||
return result.content
|
||||
else:
|
||||
raise ValueError("Cannot use predict when output is not a string.")
|
||||
|
||||
def predict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
return self(messages, stop=_stop, **kwargs)
|
||||
|
||||
async def apredict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
result = await self._call_async(
|
||||
[HumanMessage(content=text)], stop=_stop, **kwargs
|
||||
)
|
||||
if isinstance(result.content, str):
|
||||
return result.content
|
||||
else:
|
||||
raise ValueError("Cannot use predict when output is not a string.")
|
||||
|
||||
async def apredict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
return await self._call_async(messages, stop=_stop, **kwargs)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return a dictionary of the LLM."""
|
||||
starter_dict = dict(self._identifying_params)
|
||||
starter_dict["_type"] = self._llm_type
|
||||
return starter_dict
|
||||
|
||||
|
||||
class SimpleChatModel(BaseChatModel):
|
||||
"""Simple Chat Model."""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
message = AIMessage(content=output_str)
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Simpler interface."""
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
func = partial(
|
||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
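
# Editor's addition: a hedged usage sketch, not part of the original diff. It
# shows the minimal surface a SimpleChatModel subclass must implement: `_call`
# plus the `_llm_type` property. `EchoChatModel` is a hypothetical name; the
# other identifiers are already in scope in this module.


class EchoChatModel(SimpleChatModel):
    """Toy chat model that echoes the last message back."""

    @property
    def _llm_type(self) -> str:
        return "echo-chat"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # A real model would call an API here; we just echo the last content.
        return str(messages[-1].content)


# EchoChatModel().invoke("hello") would return AIMessage(content="hello").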
@ -0,0 +1,17 @@
import platform
from functools import lru_cache


@lru_cache(maxsize=1)
def get_runtime_environment() -> dict:
    """Get information about the LangChain runtime environment."""
    # Lazy import to avoid circular imports
    from langchain_core import __version__

    return {
        "library_version": __version__,
        "library": "langchain",
        "platform": platform.platform(),
        "runtime": "python",
        "runtime_version": platform.python_version(),
    }
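
# Hedged usage sketch (editor's addition): the result is memoized, so repeated
# calls return the same dict without re-probing the platform.
#
#     env = get_runtime_environment()
#     env["runtime"]                    # "python"
#     env is get_runtime_environment()  # True, thanks to lru_cache(maxsize=1)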
@ -0,0 +1,197 @@
# flake8: noqa
"""Global values and configuration that apply to all of LangChain."""
import warnings
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from langchain_core.schema import BaseCache


# DO NOT USE THESE VALUES DIRECTLY!
# Use them only via `get_<X>()` and `set_<X>()` below,
# or else your code may behave unexpectedly with other uses of these global settings:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
_verbose: bool = False
_debug: bool = False
_llm_cache: Optional["BaseCache"] = None


def set_verbose(value: bool) -> None:
    """Set a new value for the `verbose` global setting."""
    try:
        import langchain

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=(
                    "Importing verbose from langchain_core root module is no longer supported"
                ),
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.verbose` is no longer supported, and once all users
            # have migrated to using `set_verbose()` here.
            langchain.verbose = value
    except ImportError:
        pass

    global _verbose
    _verbose = value


def get_verbose() -> bool:
    """Get the value of the `verbose` global setting."""
    try:
        import langchain

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=(
                    "Importing verbose from langchain_core root module is no longer supported"
                ),
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.verbose` is no longer supported, and once all users
            # have migrated to using `set_verbose()` here.
            #
            # In the meantime, the `verbose` setting is considered True if either the old
            # or the new value are True. This accommodates users who haven't migrated
            # to using `set_verbose()` yet. Those users are getting deprecation warnings
            # directing them to use `set_verbose()` when they import `langchain.verbose`.
            old_verbose = langchain.verbose
    except ImportError:
        old_verbose = False

    global _verbose
    return _verbose or old_verbose


def set_debug(value: bool) -> None:
    """Set a new value for the `debug` global setting."""
    try:
        import langchain

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Importing debug from langchain_core root module is no longer supported",
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.debug` is no longer supported, and once all users
            # have migrated to using `set_debug()` here.
            langchain.debug = value
    except ImportError:
        pass

    global _debug
    _debug = value


def get_debug() -> bool:
    """Get the value of the `debug` global setting."""
    try:
        import langchain

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Importing debug from langchain_core root module is no longer supported",
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.debug` is no longer supported, and once all users
            # have migrated to using `set_debug()` here.
            #
            # In the meantime, the `debug` setting is considered True if either the old
            # or the new value are True. This accommodates users who haven't migrated
            # to using `set_debug()` yet. Those users are getting deprecation warnings
            # directing them to use `set_debug()` when they import `langchain.debug`.
            old_debug = langchain.debug
    except ImportError:
        old_debug = False

    global _debug
    return _debug or old_debug


def set_llm_cache(value: Optional["BaseCache"]) -> None:
    """Set a new LLM cache, overwriting the previous value, if any."""
    try:
        import langchain

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=(
                    "Importing llm_cache from langchain_core root module is no longer supported"
                ),
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.llm_cache` is no longer supported, and
            # once all users have migrated to using `set_llm_cache()` here.
            langchain.llm_cache = value
    except ImportError:
        pass

    global _llm_cache
    _llm_cache = value


def get_llm_cache() -> "BaseCache":
    """Get the value of the `llm_cache` global setting."""
    try:
        import langchain

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=(
                    "Importing llm_cache from langchain_core root module is no longer supported"
                ),
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.llm_cache` is no longer supported, and
            # once all users have migrated to using `set_llm_cache()` here.
            #
            # In the meantime, the `llm_cache` setting returns whichever of
            # its two backing sources is truthy (not `None` and non-empty),
            # or the old value if both are falsy. This accommodates users
            # who haven't migrated to using `set_llm_cache()` yet.
            # Those users are getting deprecation warnings directing them
            # to use `set_llm_cache()` when they import `langchain.llm_cache`.
            old_llm_cache = langchain.llm_cache
    except ImportError:
        old_llm_cache = None

    global _llm_cache
    return _llm_cache or old_llm_cache
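
# Hedged usage sketch (editor's addition): the getters fold the legacy
# module-level attributes into the new globals, so a value set through either
# path stays visible through `get_<X>()`.
#
#     set_debug(True)
#     get_debug()    # True
#     set_verbose(False)
#     get_verbose()  # False, unless the legacy `langchain.verbose` is still True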
@ -0,0 +1,6 @@
"""Serialization and deserialization."""
from langchain_core.load.dump import dumpd, dumps
from langchain_core.load.load import load, loads
from langchain_core.load.serializable import Serializable

__all__ = ["dumpd", "dumps", "load", "loads", "Serializable"]
@ -0,0 +1,26 @@
import json
from typing import Any, Dict

from langchain_core.load.serializable import Serializable, to_json_not_implemented


def default(obj: Any) -> Any:
    """Return a default value for a Serializable object or
    a SerializedNotImplemented object."""
    if isinstance(obj, Serializable):
        return obj.to_json()
    else:
        return to_json_not_implemented(obj)


def dumps(obj: Any, *, pretty: bool = False) -> str:
    """Return a json string representation of an object."""
    if pretty:
        return json.dumps(obj, default=default, indent=2)
    else:
        return json.dumps(obj, default=default)


def dumpd(obj: Any) -> Dict[str, Any]:
    """Return a json dict representation of an object."""
    return json.loads(dumps(obj))
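
# Hedged usage sketch (editor's addition): JSON-native values serialize
# directly, while non-Serializable objects degrade to a "not_implemented"
# envelope rather than raising.
#
#     dumps({"a": 1}, pretty=True)   # pretty-printed JSON string
#     dumpd(object())["type"]        # "not_implemented", with a repr() captured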
@ -0,0 +1,130 @@
import importlib
import json
import os
from typing import Any, Dict, List, Optional

from langchain_core.load.serializable import Serializable

DEFAULT_NAMESPACES = ["langchain", "langchain_core"]


class Reviver:
    """Reviver for JSON objects."""

    def __init__(
        self,
        secrets_map: Optional[Dict[str, str]] = None,
        valid_namespaces: Optional[List[str]] = None,
    ) -> None:
        self.secrets_map = secrets_map or dict()
        # By default only support langchain, but user can pass in additional namespaces
        self.valid_namespaces = (
            [*DEFAULT_NAMESPACES, *valid_namespaces]
            if valid_namespaces
            else DEFAULT_NAMESPACES
        )

    def __call__(self, value: Dict[str, Any]) -> Any:
        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "secret"
            and value.get("id", None) is not None
        ):
            [key] = value["id"]
            if key in self.secrets_map:
                return self.secrets_map[key]
            else:
                if key in os.environ and os.environ[key]:
                    return os.environ[key]
                raise KeyError(f'Missing key "{key}" in load(secrets_map)')

        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "not_implemented"
            and value.get("id", None) is not None
        ):
            raise NotImplementedError(
                "Trying to load an object that doesn't implement "
                f"serialization: {value}"
            )

        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "constructor"
            and value.get("id", None) is not None
        ):
            [*namespace, name] = value["id"]

            if namespace[0] not in self.valid_namespaces:
                raise ValueError(f"Invalid namespace: {value}")

            # The root namespace "langchain" is not a valid identifier.
            if len(namespace) == 1 and namespace[0] == "langchain":
                raise ValueError(f"Invalid namespace: {value}")

            mod = importlib.import_module(".".join(namespace))
            cls = getattr(mod, name)

            # The class must be a subclass of Serializable.
            if not issubclass(cls, Serializable):
                raise ValueError(f"Invalid namespace: {value}")

            # We don't need to recurse on kwargs
            # as json.loads will do that for us.
            kwargs = value.get("kwargs", dict())
            return cls(**kwargs)

        return value


def loads(
    text: str,
    *,
    secrets_map: Optional[Dict[str, str]] = None,
    valid_namespaces: Optional[List[str]] = None,
) -> Any:
    """Revive a LangChain class from a JSON string.
    Equivalent to `load(json.loads(text))`.

    Args:
        text: The string to load.
        secrets_map: A map of secrets to load.
        valid_namespaces: A list of additional namespaces (modules)
            to allow to be deserialized.

    Returns:
        Revived LangChain objects.
    """
    return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))


def load(
    obj: Any,
    *,
    secrets_map: Optional[Dict[str, str]] = None,
    valid_namespaces: Optional[List[str]] = None,
) -> Any:
    """Revive a LangChain class from a JSON object. Use this if you already
    have a parsed JSON object, e.g. from `json.load` or `orjson.loads`.

    Args:
        obj: The object to load.
        secrets_map: A map of secrets to load.
        valid_namespaces: A list of additional namespaces (modules)
            to allow to be deserialized.

    Returns:
        Revived LangChain objects.
    """
    reviver = Reviver(secrets_map, valid_namespaces)

    def _load(obj: Any) -> Any:
        if isinstance(obj, dict):
            # Need to revive leaf nodes before reviving this node
            loaded_obj = {k: _load(v) for k, v in obj.items()}
            return reviver(loaded_obj)
        if isinstance(obj, list):
            return [_load(o) for o in obj]
        return obj

    return _load(obj)
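
# Hedged usage sketch (editor's addition): a dumps()/loads() round trip,
# assuming a serializable class such as CommaSeparatedListOutputParser from
# langchain_core.output_parsers.list is importable at load time.
#
#     from langchain_core.load import dumps, loads
#     from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
#
#     text = dumps(CommaSeparatedListOutputParser())
#     parser = loads(text)  # revived via the "constructor" branch of Reviver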
@ -0,0 +1,207 @@
from abc import ABC
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast

from langchain_core.pydantic_v1 import BaseModel, PrivateAttr


class BaseSerialized(TypedDict):
    """Base class for serialized objects."""

    lc: int
    id: List[str]


class SerializedConstructor(BaseSerialized):
    """Serialized constructor."""

    type: Literal["constructor"]
    kwargs: Dict[str, Any]


class SerializedSecret(BaseSerialized):
    """Serialized secret."""

    type: Literal["secret"]


class SerializedNotImplemented(BaseSerialized):
    """Serialized not implemented."""

    type: Literal["not_implemented"]
    repr: Optional[str]


def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
    try:
        return model.__fields__[key].get_default() != value
    except Exception:
        return True


class Serializable(BaseModel, ABC):
    """Serializable base class."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Is this class serializable?"""
        return False

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object.

        For example, if the class is `langchain.llms.openai.OpenAI`, then the
        namespace is ["langchain", "llms", "openai"]
        """
        return cls.__module__.split(".")

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """A map of constructor argument names to secret ids.

        For example,
            {"openai_api_key": "OPENAI_API_KEY"}
        """
        return dict()

    @property
    def lc_attributes(self) -> Dict:
        """List of attribute names that should be included in the serialized kwargs.

        These attributes must be accepted by the constructor.
        """
        return {}

    @classmethod
    def lc_id(cls) -> List[str]:
        """A unique identifier for this class for serialization purposes.

        The unique identifier is a list of strings that describes the path
        to the object.
        """
        return [*cls.get_lc_namespace(), cls.__name__]

    class Config:
        extra = "ignore"

    def __repr_args__(self) -> Any:
        return [
            (k, v)
            for k, v in super().__repr_args__()
            if (k not in self.__fields__ or try_neq_default(v, k, self))
        ]

    _lc_kwargs = PrivateAttr(default_factory=dict)

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._lc_kwargs = kwargs

    def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
        if not self.is_lc_serializable():
            return self.to_json_not_implemented()

        secrets = dict()
        # Get latest values for kwargs if there is an attribute with same name
        lc_kwargs = {
            k: getattr(self, k, v)
            for k, v in self._lc_kwargs.items()
            if not (self.__exclude_fields__ or {}).get(k, False)  # type: ignore
        }

        # Merge the lc_secrets and lc_attributes from every class in the MRO
        for cls in [None, *self.__class__.mro()]:
            # Once we get to Serializable, we're done
            if cls is Serializable:
                break

            if cls:
                deprecated_attributes = [
                    "lc_namespace",
                    "lc_serializable",
                ]

                for attr in deprecated_attributes:
                    if hasattr(cls, attr):
                        raise ValueError(
                            f"Class {self.__class__} has a deprecated "
                            f"attribute {attr}. Please use the corresponding "
                            f"classmethod instead."
                        )

            # Get a reference to self bound to each class in the MRO
            this = cast(Serializable, self if cls is None else super(cls, self))

            secrets.update(this.lc_secrets)
            lc_kwargs.update(this.lc_attributes)

        # include all secrets, even if not specified in kwargs
        # as these secrets may be passed as an environment variable instead
        for key in secrets.keys():
            secret_value = getattr(self, key, None) or lc_kwargs.get(key)
            if secret_value is not None:
                lc_kwargs.update({key: secret_value})

        return {
            "lc": 1,
            "type": "constructor",
            "id": self.lc_id(),
            "kwargs": lc_kwargs
            if not secrets
            else _replace_secrets(lc_kwargs, secrets),
        }

    def to_json_not_implemented(self) -> SerializedNotImplemented:
        return to_json_not_implemented(self)


def _replace_secrets(
    root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
    result = root.copy()
    for path, secret_id in secrets_map.items():
        [*parts, last] = path.split(".")
        current = result
        for part in parts:
            if part not in current:
                break
            current[part] = current[part].copy()
            current = current[part]
        if last in current:
            current[last] = {
                "lc": 1,
                "type": "secret",
                "id": [secret_id],
            }
    return result


def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
    """Serialize a "not implemented" object.

    Args:
        obj: object to serialize

    Returns:
        SerializedNotImplemented
    """
    _id: List[str] = []
    try:
        if hasattr(obj, "__name__"):
            _id = [*obj.__module__.split("."), obj.__name__]
        elif hasattr(obj, "__class__"):
            _id = [*obj.__class__.__module__.split("."), obj.__class__.__name__]
    except Exception:
        pass

    result: SerializedNotImplemented = {
        "lc": 1,
        "type": "not_implemented",
        "id": _id,
        "repr": None,
    }
    try:
        result["repr"] = repr(obj)
    except Exception:
        pass
    return result
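
# Hedged usage sketch (editor's addition): declaring a secret constructor
# argument so to_json() emits a {"type": "secret"} envelope instead of the
# raw value. `MyClient` and `MY_API_KEY` are hypothetical names.
#
#     class MyClient(Serializable):
#         api_key: str
#
#         @classmethod
#         def is_lc_serializable(cls) -> bool:
#             return True
#
#         @property
#         def lc_secrets(self) -> Dict[str, str]:
#             return {"api_key": "MY_API_KEY"}
#
#     MyClient(api_key="sk-...").to_json()["kwargs"]["api_key"]
#     # -> {"lc": 1, "type": "secret", "id": ["MY_API_KEY"]}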
@ -0,0 +1,79 @@
from __future__ import annotations

import re
from abc import abstractmethod
from typing import List

from langchain_core.schema import BaseOutputParser


class ListOutputParser(BaseOutputParser[List[str]]):
    """Parse the output of an LLM call to a list."""

    @property
    def _type(self) -> str:
        return "list"

    @abstractmethod
    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call."""


class CommaSeparatedListOutputParser(ListOutputParser):
    """Parse the output of an LLM call to a comma-separated list."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    def get_format_instructions(self) -> str:
        return (
            "Your response should be a list of comma separated values, "
            "eg: `foo, bar, baz`"
        )

    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call."""
        return text.strip().split(", ")

    @property
    def _type(self) -> str:
        return "comma-separated-list"


class NumberedListOutputParser(ListOutputParser):
    """Parse a numbered list."""

    def get_format_instructions(self) -> str:
        return (
            "Your response should be a numbered list with each item on a new line. "
            "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
        )

    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call."""
        pattern = r"\d+\.\s([^\n]+)"

        # Extract the text of each item
        matches = re.findall(pattern, text)
        return matches

    @property
    def _type(self) -> str:
        return "numbered-list"


class MarkdownListOutputParser(ListOutputParser):
    """Parse a markdown list."""

    def get_format_instructions(self) -> str:
        return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"

    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call."""
        pattern = r"-\s([^\n]+)"
        return re.findall(pattern, text)

    @property
    def _type(self) -> str:
        return "markdown-list"
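
# Editor's addition: a hedged, self-contained demo of the parsers above. It
# runs only when the module is executed directly, so imports are unaffected.
if __name__ == "__main__":
    assert NumberedListOutputParser().parse("1. foo\n2. bar\n3. baz") == [
        "foo",
        "bar",
        "baz",
    ]
    assert CommaSeparatedListOutputParser().parse("foo, bar, baz") == [
        "foo",
        "bar",
        "baz",
    ]
    assert MarkdownListOutputParser().parse("- foo\n- bar") == ["foo", "bar"]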
@ -0,0 +1,75 @@
"""**Prompt** is the input to the model.

Prompt is often constructed
from multiple components. Prompt classes and functions make constructing
and working with prompts easy.

**Class hierarchy:**

.. code-block::

    BasePromptTemplate --> PipelinePromptTemplate
                           StringPromptTemplate --> PromptTemplate
                                                    FewShotPromptTemplate
                                                    FewShotPromptWithTemplates
                           BaseChatPromptTemplate --> AutoGPTPrompt
                                                      ChatPromptTemplate --> AgentScratchPadChatPromptTemplate



    BaseMessagePromptTemplate --> MessagesPlaceholder
                                  BaseStringMessagePromptTemplate --> ChatMessagePromptTemplate
                                                                      HumanMessagePromptTemplate
                                                                      AIMessagePromptTemplate
                                                                      SystemMessagePromptTemplate

    PromptValue --> StringPromptValue
                    ChatPromptValue

"""  # noqa: E501
from langchain_core.prompts.base import StringPromptTemplate
from langchain_core.prompts.chat import (
    AIMessagePromptTemplate,
    BaseChatPromptTemplate,
    ChatMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain_core.prompts.example_selector import (
    LengthBasedExampleSelector,
    MaxMarginalRelevanceExampleSelector,
    SemanticSimilarityExampleSelector,
)
from langchain_core.prompts.few_shot import (
    FewShotChatMessagePromptTemplate,
    FewShotPromptTemplate,
)
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain_core.prompts.loading import load_prompt
from langchain_core.prompts.pipeline import PipelinePromptTemplate
from langchain_core.prompts.prompt import Prompt, PromptTemplate
from langchain_core.schema.prompt_template import BasePromptTemplate

__all__ = [
    "AIMessagePromptTemplate",
    "BaseChatPromptTemplate",
    "BasePromptTemplate",
    "ChatMessagePromptTemplate",
    "ChatPromptTemplate",
    "FewShotPromptTemplate",
    "FewShotPromptWithTemplates",
    "HumanMessagePromptTemplate",
    "LengthBasedExampleSelector",
    "MaxMarginalRelevanceExampleSelector",
    "MessagesPlaceholder",
    "PipelinePromptTemplate",
    "Prompt",
    "PromptTemplate",
    "SemanticSimilarityExampleSelector",
    "StringPromptTemplate",
    "SystemMessagePromptTemplate",
    "load_prompt",
    "FewShotChatMessagePromptTemplate",
]
@ -0,0 +1,173 @@
"""BasePrompt schema definition."""
from __future__ import annotations

import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Literal, Set

from langchain_core.schema.messages import BaseMessage, HumanMessage
from langchain_core.schema.prompt import PromptValue
from langchain_core.schema.prompt_template import BasePromptTemplate
from langchain_core.utils.formatting import formatter


def jinja2_formatter(template: str, **kwargs: Any) -> str:
    """Format a template using jinja2.

    *Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
    SandboxedEnvironment by default. However, this sand-boxing should
    be treated as a best-effort approach rather than a guarantee of security.
    Do not accept jinja2 templates from untrusted sources as they may lead
    to arbitrary Python code execution.

    https://jinja.palletsprojects.com/en/3.1.x/sandbox/
    """
    try:
        from jinja2.sandbox import SandboxedEnvironment
    except ImportError:
        raise ImportError(
            "jinja2 not installed, which is needed to use the jinja2_formatter. "
            "Please install it with `pip install jinja2`. "
            "Please be cautious when using jinja2 templates. "
            "Do not expand jinja2 templates using unverified or user-controlled "
            "inputs as that can result in arbitrary Python code execution."
        )

    # This uses a sandboxed environment to prevent arbitrary code execution.
    # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
    # Please treat this sand-boxing as a best-effort approach rather than
    # a guarantee of security.
    # We recommend never using jinja2 templates with untrusted inputs.
    # https://jinja.palletsprojects.com/en/3.1.x/sandbox/
    return SandboxedEnvironment().from_string(template).render(**kwargs)


def validate_jinja2(template: str, input_variables: List[str]) -> None:
    """
    Validate that the input variables are valid for the template.
    Issues a warning if missing or extra variables are found.

    Args:
        template: The template string.
        input_variables: The input variables.
    """
    input_variables_set = set(input_variables)
    valid_variables = _get_jinja2_variables_from_template(template)
    missing_variables = valid_variables - input_variables_set
    extra_variables = input_variables_set - valid_variables

    warning_message = ""
    if missing_variables:
        warning_message += f"Missing variables: {missing_variables} "

    if extra_variables:
        warning_message += f"Extra variables: {extra_variables}"

    if warning_message:
        warnings.warn(warning_message.strip())


def _get_jinja2_variables_from_template(template: str) -> Set[str]:
    try:
        from jinja2 import Environment, meta
    except ImportError:
        raise ImportError(
            "jinja2 not installed, which is needed to use the jinja2_formatter. "
            "Please install it with `pip install jinja2`."
        )
    env = Environment()
    ast = env.parse(template)
    variables = meta.find_undeclared_variables(ast)
    return variables


DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
    "f-string": formatter.format,
    "jinja2": jinja2_formatter,
}

DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
    "f-string": formatter.validate_input_variables,
    "jinja2": validate_jinja2,
}


def check_valid_template(
    template: str, template_format: str, input_variables: List[str]
) -> None:
    """Check that template string is valid.

    Args:
        template: The template string.
        template_format: The template format. Should be one of "f-string" or "jinja2".
        input_variables: The input variables.

    Raises:
        ValueError: If the template format is not supported.
    """
    if template_format not in DEFAULT_FORMATTER_MAPPING:
        valid_formats = list(DEFAULT_FORMATTER_MAPPING)
        raise ValueError(
            f"Invalid template format. Got `{template_format}`;"
            f" should be one of {valid_formats}"
        )
    try:
        validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
        validator_func(template, input_variables)
    except KeyError as e:
        raise ValueError(
            "Invalid prompt schema; check for mismatched or missing input parameters. "
            + str(e)
        )


def get_template_variables(template: str, template_format: str) -> List[str]:
    """Get the variables from the template.

    Args:
        template: The template string.
        template_format: The template format. Should be one of "f-string" or "jinja2".

    Returns:
        The variables from the template.

    Raises:
        ValueError: If the template format is not supported.
    """
    if template_format == "jinja2":
        # Get the variables for the template
        input_variables = _get_jinja2_variables_from_template(template)
    elif template_format == "f-string":
        input_variables = {
            v for _, v, _, _ in Formatter().parse(template) if v is not None
        }
    else:
        raise ValueError(f"Unsupported template format: {template_format}")

    return sorted(input_variables)


class StringPromptValue(PromptValue):
    """String prompt value."""

    text: str
    """Prompt text."""
    type: Literal["StringPromptValue"] = "StringPromptValue"

    def to_string(self) -> str:
        """Return prompt as string."""
        return self.text

    def to_messages(self) -> List[BaseMessage]:
        """Return prompt as messages."""
        return [HumanMessage(content=self.text)]


class StringPromptTemplate(BasePromptTemplate, ABC):
    """String prompt that exposes the format method, returning a prompt."""

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Format the prompt and return it as a PromptValue."""
        return StringPromptValue(text=self.format(**kwargs))
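
# Hedged usage sketch (editor's addition): extracting and validating template
# variables for both supported formats (the jinja2 path assumes jinja2 is
# installed).
#
#     get_template_variables("Hello {name}, you are {age}", "f-string")
#     # -> ["age", "name"]  (sorted)
#     get_template_variables("Hello {{ name }}", "jinja2")
#     # -> ["name"]
#     check_valid_template("Hello {name}", "f-string", ["name"])  # no error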
@ -0,0 +1,748 @@
"""Chat prompt template."""
from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    overload,
)

from langchain_core._api import deprecated
from langchain_core.load.serializable import Serializable
from langchain_core.prompts.base import StringPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.schema import (
    BasePromptTemplate,
    PromptValue,
)
from langchain_core.schema.messages import (
    AIMessage,
    AnyMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
    get_buffer_string,
)


class BaseMessagePromptTemplate(Serializable, ABC):
    """Base class for message prompt templates."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether or not the class is serializable."""
        return True

    @abstractmethod
    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format messages from kwargs. Should return a list of BaseMessages.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            List of BaseMessages.
        """

    @property
    @abstractmethod
    def input_variables(self) -> List[str]:
        """Input variables for this prompt template.

        Returns:
            List of input variables.
        """

    def __add__(self, other: Any) -> ChatPromptTemplate:
        """Combine two prompt templates.

        Args:
            other: Another prompt template.

        Returns:
            Combined prompt template.
        """
        prompt = ChatPromptTemplate(messages=[self])
        return prompt + other


class MessagesPlaceholder(BaseMessagePromptTemplate):
    """Prompt template that assumes the variable is already a list of messages."""

    variable_name: str
    """Name of variable to use as messages."""

    def __init__(self, variable_name: str, **kwargs: Any):
        return super().__init__(variable_name=variable_name, **kwargs)

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format messages from kwargs.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            List of BaseMessage.
        """
        value = kwargs[self.variable_name]
        if not isinstance(value, list):
            raise ValueError(
                f"variable {self.variable_name} should be a list of base messages, "
                f"got {value}"
            )
        for v in value:
            if not isinstance(v, BaseMessage):
                raise ValueError(
                    f"variable {self.variable_name} should be a list of base messages,"
                    f" got {value}"
                )
        return value

    @property
    def input_variables(self) -> List[str]:
        """Input variables for this prompt template.

        Returns:
            List of input variable names.
        """
        return [self.variable_name]
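
# Hedged usage sketch (editor's addition): a placeholder splices an existing
# message list into the formatted output verbatim.
#
#     placeholder = MessagesPlaceholder(variable_name="history")
#     placeholder.format_messages(history=[HumanMessage(content="hi")])
#     # -> [HumanMessage(content="hi")]
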
MessagePromptTemplateT = TypeVar(
    "MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
)
"""Type variable for message prompt templates."""


class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
    """Base class for message prompt templates that use a string prompt template."""

    prompt: StringPromptTemplate
    """String prompt template."""
    additional_kwargs: dict = Field(default_factory=dict)
    """Additional keyword arguments to pass to the prompt template."""

    @classmethod
    def from_template(
        cls: Type[MessagePromptTemplateT],
        template: str,
        template_format: str = "f-string",
        **kwargs: Any,
    ) -> MessagePromptTemplateT:
        """Create a class from a string template.

        Args:
            template: a template.
            template_format: format of the template.
            **kwargs: keyword arguments to pass to the constructor.

        Returns:
            A new instance of this class.
        """
        prompt = PromptTemplate.from_template(template, template_format=template_format)
        return cls(prompt=prompt, **kwargs)

    @classmethod
    def from_template_file(
        cls: Type[MessagePromptTemplateT],
        template_file: Union[str, Path],
        input_variables: List[str],
        **kwargs: Any,
    ) -> MessagePromptTemplateT:
        """Create a class from a template file.

        Args:
            template_file: path to a template file. String or Path.
            input_variables: list of input variables.
            **kwargs: keyword arguments to pass to the constructor.

        Returns:
            A new instance of this class.
        """
        prompt = PromptTemplate.from_file(template_file, input_variables)
        return cls(prompt=prompt, **kwargs)

    @abstractmethod
    def format(self, **kwargs: Any) -> BaseMessage:
        """Format the prompt template.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            Formatted message.
        """

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format messages from kwargs.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            List of BaseMessages.
        """
        return [self.format(**kwargs)]

    @property
    def input_variables(self) -> List[str]:
        """
        Input variables for this prompt template.

        Returns:
            List of input variable names.
        """
        return self.prompt.input_variables


class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
    """Chat message prompt template."""

    role: str
    """Role of the message."""

    def format(self, **kwargs: Any) -> BaseMessage:
        """Format the prompt template.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            Formatted message.
        """
        text = self.prompt.format(**kwargs)
        return ChatMessage(
            content=text, role=self.role, additional_kwargs=self.additional_kwargs
        )


class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
    """Human message prompt template. This is a message sent from the user."""

    def format(self, **kwargs: Any) -> BaseMessage:
        """Format the prompt template.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            Formatted message.
        """
        text = self.prompt.format(**kwargs)
        return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)


class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
    """AI message prompt template. This is a message sent from the AI."""

    def format(self, **kwargs: Any) -> BaseMessage:
        """Format the prompt template.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            Formatted message.
        """
        text = self.prompt.format(**kwargs)
        return AIMessage(content=text, additional_kwargs=self.additional_kwargs)


class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
    """System message prompt template.
    This is a message that is not sent to the user.
    """

    def format(self, **kwargs: Any) -> BaseMessage:
        """Format the prompt template.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            Formatted message.
        """
        text = self.prompt.format(**kwargs)
        return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
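
# Hedged usage sketch (editor's addition): from_template wires a string
# template into a role-specific message.
#
#     msg_tmpl = HumanMessagePromptTemplate.from_template("Summarize: {doc}")
#     msg_tmpl.format(doc="...")  # -> HumanMessage(content="Summarize: ...")
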

class ChatPromptValue(PromptValue):
    """Chat prompt value.

    A type of a prompt value that is built from messages.
    """

    messages: Sequence[BaseMessage]
    """List of messages."""

    def to_string(self) -> str:
        """Return prompt as string."""
        return get_buffer_string(self.messages)

    def to_messages(self) -> List[BaseMessage]:
        """Return prompt as a list of messages."""
        return list(self.messages)


class ChatPromptValueConcrete(ChatPromptValue):
    """Chat prompt value which explicitly lists out the message types it accepts.
    For use in external schemas."""

    messages: Sequence[AnyMessage]

    type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"


class BaseChatPromptTemplate(BasePromptTemplate, ABC):
    """Base class for chat prompt templates."""

    @property
    def lc_attributes(self) -> Dict:
        """
        Return a list of attribute names that should be included in the
        serialized kwargs. These attributes must be accepted by the
        constructor.
        """
        return {"input_variables": self.input_variables}

    def format(self, **kwargs: Any) -> str:
        """Format the chat template into a string.

        Args:
            **kwargs: keyword arguments to use for filling in template variables
                in all the template messages in this chat template.

        Returns:
            formatted string
        """
        return self.format_prompt(**kwargs).to_string()

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """
        Format prompt. Should return a PromptValue.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            PromptValue.
        """
        messages = self.format_messages(**kwargs)
        return ChatPromptValue(messages=messages)

    @abstractmethod
    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format kwargs into a list of messages."""


MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]

MessageLikeRepresentation = Union[
    MessageLike,
    Tuple[str, str],
    Tuple[Type, str],
    str,
]


class ChatPromptTemplate(BaseChatPromptTemplate):
    """A prompt template for chat models.

    Use to create flexible templated prompts for chat models.

    Examples:

        .. code-block:: python

            from langchain_core.prompts import ChatPromptTemplate

            template = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful AI bot. Your name is {name}."),
                ("human", "Hello, how are you doing?"),
                ("ai", "I'm doing well, thanks!"),
                ("human", "{user_input}"),
            ])

            messages = template.format_messages(
                name="Bob",
                user_input="What is your name?"
            )
    """

    input_variables: List[str]
    """List of input variables in template messages. Used for validation."""
    messages: List[MessageLike]
    """List of messages consisting of either message prompt templates or messages."""
    validate_template: bool = False
    """Whether or not to try validating the template."""

    def __add__(self, other: Any) -> ChatPromptTemplate:
        """Combine two prompt templates.

        Args:
            other: Another prompt template.

        Returns:
            Combined prompt template.
        """
        # Allow for easy combining
        if isinstance(other, ChatPromptTemplate):
            return ChatPromptTemplate(messages=self.messages + other.messages)
        elif isinstance(
|
||||
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
|
||||
):
|
||||
return ChatPromptTemplate(messages=self.messages + [other])
|
||||
elif isinstance(other, (list, tuple)):
|
||||
_other = ChatPromptTemplate.from_messages(other)
|
||||
return ChatPromptTemplate(messages=self.messages + _other.messages)
|
||||
elif isinstance(other, str):
|
||||
prompt = HumanMessagePromptTemplate.from_template(other)
|
||||
return ChatPromptTemplate(messages=self.messages + [prompt])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")

    @root_validator(pre=True)
    def validate_input_variables(cls, values: dict) -> dict:
        """Validate input variables.

        If input_variables is not set, it will be set to the union of
        all input variables in the messages.

        Args:
            values: values to validate.

        Returns:
            Validated values.
        """
        messages = values["messages"]
        input_vars = set()
        input_types: Dict[str, Any] = values.get("input_types", {})
        for message in messages:
            if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
                input_vars.update(message.input_variables)
            if isinstance(message, MessagesPlaceholder):
                if message.variable_name not in input_types:
                    input_types[message.variable_name] = List[AnyMessage]
        if "partial_variables" in values:
            input_vars = input_vars - set(values["partial_variables"])
        if "input_variables" in values and values.get("validate_template"):
            if input_vars != set(values["input_variables"]):
                raise ValueError(
                    "Got mismatched input_variables. "
                    f"Expected: {input_vars}. "
                    f"Got: {values['input_variables']}"
                )
        else:
            values["input_variables"] = sorted(input_vars)
        values["input_types"] = input_types
        return values

    @classmethod
    def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
        """Create a chat prompt template from a template string.

        Creates a chat template consisting of a single message assumed to be from
        the human.

        Args:
            template: template string
            **kwargs: keyword arguments to pass to the constructor.

        Returns:
            A new instance of this class.
        """
        prompt_template = PromptTemplate.from_template(template, **kwargs)
        message = HumanMessagePromptTemplate(prompt=prompt_template)
        return cls.from_messages([message])

    @classmethod
    @deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
    def from_role_strings(
        cls, string_messages: List[Tuple[str, str]]
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a list of (role, template) tuples.

        Args:
            string_messages: list of (role, template) tuples.

        Returns:
            a chat prompt template
        """
        return cls(
            messages=[
                ChatMessagePromptTemplate.from_template(template, role=role)
                for role, template in string_messages
            ]
        )

    @classmethod
    @deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
    def from_strings(
        cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a list of (role class, template) tuples.

        Args:
            string_messages: list of (role class, template) tuples.

        Returns:
            a chat prompt template
        """
        return cls.from_messages(string_messages)

    @classmethod
    def from_messages(
        cls,
        messages: Sequence[MessageLikeRepresentation],
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a variety of message formats.

        Examples:

            Instantiation from a list of message templates:

            .. code-block:: python

                template = ChatPromptTemplate.from_messages([
                    ("human", "Hello, how are you?"),
                    ("ai", "I'm doing well, thanks!"),
                    ("human", "That's good to hear."),
                ])

            Instantiation from mixed message formats:

            .. code-block:: python

                template = ChatPromptTemplate.from_messages([
                    SystemMessage(content="hello"),
                    ("human", "Hello, how are you?"),
                ])

        Args:
            messages: sequence of message representations.
                  A message can be represented using the following formats:
                  (1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of
                  (message type, template); e.g., ("human", "{user_input}"),
                  (4) 2-tuple of (message class, template), (5) a string which is
                  shorthand for ("human", template); e.g., "{user_input}"

        Returns:
            a chat prompt template
        """
        _messages = [_convert_to_message(message) for message in messages]

        # Automatically infer input variables from messages
        input_vars: Set[str] = set()
        for _message in _messages:
            if isinstance(
                _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
            ):
                input_vars.update(_message.input_variables)

        return cls(input_variables=sorted(input_vars), messages=_messages)

    def format(self, **kwargs: Any) -> str:
        """Format the chat template into a string.

        Args:
            **kwargs: keyword arguments to use for filling in template variables
                in all the template messages in this chat template.

        Returns:
            formatted string
        """
        return self.format_prompt(**kwargs).to_string()

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format the chat template into a list of finalized messages.

        Args:
            **kwargs: keyword arguments to use for filling in template variables
                in all the template messages in this chat template.

        Returns:
            list of formatted messages
        """
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        result = []
        for message_template in self.messages:
            if isinstance(message_template, BaseMessage):
                result.extend([message_template])
            elif isinstance(
                message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
            ):
                rel_params = {
                    k: v
                    for k, v in kwargs.items()
                    if k in message_template.input_variables
                }
                message = message_template.format_messages(**rel_params)
                result.extend(message)
            else:
                raise ValueError(f"Unexpected input: {message_template}")
        return result

    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
        """Get a new ChatPromptTemplate with some input variables already filled in.

        Args:
            **kwargs: keyword arguments to use for filling in template variables. Ought
                to be a subset of the input variables.

        Returns:
            A new ChatPromptTemplate.


        Example:

            .. code-block:: python

                from langchain_core.prompts import ChatPromptTemplate

                template = ChatPromptTemplate.from_messages(
                    [
                        ("system", "You are an AI assistant named {name}."),
                        ("human", "Hi I'm {user}"),
                        ("ai", "Hi there, {user}, I'm {name}."),
                        ("human", "{input}"),
                    ]
                )
                template2 = template.partial(user="Lucy", name="R2D2")

                template2.format_messages(input="hello")
        """
        prompt_dict = self.__dict__.copy()
        prompt_dict["input_variables"] = list(
            set(self.input_variables).difference(kwargs)
        )
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def append(self, message: MessageLikeRepresentation) -> None:
        """Append message to the end of the chat template.

        Args:
            message: representation of a message to append.
        """
        self.messages.append(_convert_to_message(message))

    def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
        """Extend the chat template with a sequence of messages."""
        self.messages.extend([_convert_to_message(message) for message in messages])

    @overload
    def __getitem__(self, index: int) -> MessageLike:
        ...

    @overload
    def __getitem__(self, index: slice) -> ChatPromptTemplate:
        ...

    def __getitem__(
        self, index: Union[int, slice]
    ) -> Union[MessageLike, ChatPromptTemplate]:
        """Use to index into the chat template."""
        if isinstance(index, slice):
            start, stop, step = index.indices(len(self.messages))
            messages = self.messages[start:stop:step]
            return ChatPromptTemplate.from_messages(messages)
        else:
            return self.messages[index]

    def __len__(self) -> int:
        """Get the length of the chat template."""
        return len(self.messages)

    @property
    def _prompt_type(self) -> str:
        """Name of prompt type."""
        return "chat"

    def save(self, file_path: Union[Path, str]) -> None:
        """Save prompt to file.

        Args:
            file_path: path to file.
        """
        raise NotImplementedError()


def _create_template_from_message_type(
    message_type: str, template: str
) -> BaseMessagePromptTemplate:
    """Create a message prompt template from a message type and template string.

    Args:
        message_type: str the type of the message template (e.g., "human", "ai", etc.)
        template: str the template string.

    Returns:
        a message prompt template of the appropriate type.
    """
    if message_type in ("human", "user"):
        message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template(
            template
        )
    elif message_type in ("ai", "assistant"):
        message = AIMessagePromptTemplate.from_template(template)
    elif message_type == "system":
        message = SystemMessagePromptTemplate.from_template(template)
    else:
        raise ValueError(
            f"Unexpected message type: {message_type}. Use one of 'human',"
            f" 'user', 'ai', 'assistant', or 'system'."
        )
    return message


def _convert_to_message(
    message: MessageLikeRepresentation,
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
    """Instantiate a message from a variety of message formats.

    The message format can be one of the following:

    - BaseMessagePromptTemplate
    - BaseMessage
    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
    - 2-tuple of (message class, template)
    - string: shorthand for ("human", template); e.g., "{user_input}"

    Args:
        message: a representation of a message in one of the supported formats

    Returns:
        an instance of a message or a message template
    """
    if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
        _message: Union[
            BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate
        ] = message
    elif isinstance(message, BaseMessage):
        _message = message
    elif isinstance(message, str):
        _message = _create_template_from_message_type("human", message)
    elif isinstance(message, tuple):
        if len(message) != 2:
            raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
        message_type_str, template = message
        if isinstance(message_type_str, str):
            _message = _create_template_from_message_type(message_type_str, template)
        else:
            _message = message_type_str(prompt=PromptTemplate.from_template(template))
    else:
        raise NotImplementedError(f"Unsupported message type: {type(message)}")

    return _message
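

# A minimal usage sketch, exercising the message formats documented above and
# `__add__` composition. All variable names and data below are hypothetical,
# added for illustration only.
if __name__ == "__main__":
    base = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a {adjective} assistant."),  # (role, template) tuple
            "{question}",  # bare string is shorthand for ("human", "{question}")
        ]
    )
    # `+` accepts another template, a single message, a string, or a list of
    # message representations (see __add__ above).
    extended = base + [("ai", "Let me think about that.")]
    print(extended.format_messages(adjective="helpful", question="What is 2+2?"))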
@ -0,0 +1,14 @@
"""Logic for selecting examples to include in prompts."""
from langchain_core.prompts.example_selector.length_based import (
    LengthBasedExampleSelector,
)
from langchain_core.prompts.example_selector.semantic_similarity import (
    MaxMarginalRelevanceExampleSelector,
    SemanticSimilarityExampleSelector,
)

__all__ = [
    "LengthBasedExampleSelector",
    "MaxMarginalRelevanceExampleSelector",
    "SemanticSimilarityExampleSelector",
]
@ -0,0 +1,15 @@
"""Interface for selecting examples to include in prompts."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class BaseExampleSelector(ABC):
    """Interface for selecting examples to include in prompts."""

    @abstractmethod
    def add_example(self, example: Dict[str, str]) -> Any:
        """Add new example to store for a key."""

    @abstractmethod
    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on the inputs."""
@ -0,0 +1,63 @@
"""Select examples based on length."""
import re
from typing import Callable, Dict, List

from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, validator


def _get_length_based(text: str) -> int:
    return len(re.split("\n| ", text))


class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
    """Select examples based on length."""

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    get_text_length: Callable[[str], int] = _get_length_based
    """Function to measure prompt length. Defaults to word count."""

    max_length: int = 2048
    """Max length for the prompt, beyond which examples are cut."""

    example_text_lengths: List[int] = []  #: :meta private:

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to list."""
        self.examples.append(example)
        string_example = self.example_prompt.format(**example)
        self.example_text_lengths.append(self.get_text_length(string_example))

    @validator("example_text_lengths", always=True)
    def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]:
        """Calculate text lengths if they don't exist."""
        # Check if text lengths were passed in
        if v:
            return v
        # If they were not, calculate them
        example_prompt = values["example_prompt"]
        get_text_length = values["get_text_length"]
        string_examples = [example_prompt.format(**eg) for eg in values["examples"]]
        return [get_text_length(eg) for eg in string_examples]

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on the input lengths."""
        inputs = " ".join(input_variables.values())
        remaining_length = self.max_length - self.get_text_length(inputs)
        i = 0
        examples = []
        while remaining_length > 0 and i < len(self.examples):
            new_length = remaining_length - self.example_text_lengths[i]
            if new_length < 0:
                break
            else:
                examples.append(self.examples[i])
                remaining_length = new_length
            i += 1
        return examples
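

# A minimal usage sketch: with a small word budget (max_length=6, measured by
# the default word-count function above), only the first example fits next to
# the input, so the second is dropped. The example data is hypothetical.
if __name__ == "__main__":
    selector = LengthBasedExampleSelector(
        examples=[
            {"input": "happy", "output": "sad"},
            {"input": "tall", "output": "short"},
        ],
        example_prompt=PromptTemplate.from_template("Input: {input}\nOutput: {output}"),
        max_length=6,
    )
    print(selector.select_examples({"adjective": "big"}))  # -> first example only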
@ -0,0 +1,165 @@
"""Example selector that selects examples based on SemanticSimilarity."""
from __future__ import annotations

from typing import Any, Dict, List, Optional, Type

from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.schema.embeddings import Embeddings
from langchain_core.schema.vectorstore import VectorStore


def sorted_values(values: Dict[str, str]) -> List[Any]:
    """Return a list of values in dict sorted by key."""
    return [values[val] for val in sorted(values)]


class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
    """Example selector that selects examples based on SemanticSimilarity."""

    vectorstore: VectorStore
    """VectorStore that contains information about examples."""
    k: int = 4
    """Number of examples to select."""
    example_keys: Optional[List[str]] = None
    """Optional keys to filter examples to."""
    input_keys: Optional[List[str]] = None
    """Optional keys to filter input to. If provided, the search is based on
    the input variables instead of all variables."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def add_example(self, example: Dict[str, str]) -> str:
        """Add new example to vectorstore."""
        if self.input_keys:
            string_example = " ".join(
                sorted_values({key: example[key] for key in self.input_keys})
            )
        else:
            string_example = " ".join(sorted_values(example))
        ids = self.vectorstore.add_texts([string_example], metadatas=[example])
        return ids[0]

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on semantic similarity."""
        # Get the docs with the highest similarity.
        if self.input_keys:
            input_variables = {key: input_variables[key] for key in self.input_keys}
        query = " ".join(sorted_values(input_variables))
        example_docs = self.vectorstore.similarity_search(query, k=self.k)
        # Get the examples from the metadata.
        # This assumes that examples are stored in metadata.
        examples = [dict(e.metadata) for e in example_docs]
        # If example keys are provided, filter examples to those keys.
        if self.example_keys:
            examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
        return examples

    @classmethod
    def from_examples(
        cls,
        examples: List[dict],
        embeddings: Embeddings,
        vectorstore_cls: Type[VectorStore],
        k: int = 4,
        input_keys: Optional[List[str]] = None,
        **vectorstore_cls_kwargs: Any,
    ) -> SemanticSimilarityExampleSelector:
        """Create k-shot example selector using example list and embeddings.

        Reshuffles examples dynamically based on query similarity.

        Args:
            examples: List of examples to use in the prompt.
            embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings().
            vectorstore_cls: A vector store DB interface class, e.g. FAISS.
            k: Number of examples to select
            input_keys: If provided, the search is based on the input variables
                instead of all variables.
            vectorstore_cls_kwargs: optional kwargs containing url for vector store

        Returns:
            The ExampleSelector instantiated, backed by a vector store.
        """
        if input_keys:
            string_examples = [
                " ".join(sorted_values({k: eg[k] for k in input_keys}))
                for eg in examples
            ]
        else:
            string_examples = [" ".join(sorted_values(eg)) for eg in examples]
        vectorstore = vectorstore_cls.from_texts(
            string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
        )
        return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)


class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
    """ExampleSelector that selects examples based on Max Marginal Relevance.

    This was shown to improve performance in this paper:
    https://arxiv.org/pdf/2211.13892.pdf
    """

    fetch_k: int = 20
    """Number of examples to fetch to rerank."""

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on semantic similarity."""
        # Get the docs with the highest similarity.
        if self.input_keys:
            input_variables = {key: input_variables[key] for key in self.input_keys}
        query = " ".join(sorted_values(input_variables))
        example_docs = self.vectorstore.max_marginal_relevance_search(
            query, k=self.k, fetch_k=self.fetch_k
        )
        # Get the examples from the metadata.
        # This assumes that examples are stored in metadata.
        examples = [dict(e.metadata) for e in example_docs]
        # If example keys are provided, filter examples to those keys.
        if self.example_keys:
            examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
        return examples

    @classmethod
    def from_examples(
        cls,
        examples: List[dict],
        embeddings: Embeddings,
        vectorstore_cls: Type[VectorStore],
        k: int = 4,
        input_keys: Optional[List[str]] = None,
        fetch_k: int = 20,
        **vectorstore_cls_kwargs: Any,
    ) -> MaxMarginalRelevanceExampleSelector:
        """Create k-shot example selector using example list and embeddings.

        Reshuffles examples dynamically based on query similarity.

        Args:
            examples: List of examples to use in the prompt.
            embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings().
            vectorstore_cls: A vector store DB interface class, e.g. FAISS.
            k: Number of examples to select
            input_keys: If provided, the search is based on the input variables
                instead of all variables.
            vectorstore_cls_kwargs: optional kwargs containing url for vector store

        Returns:
            The ExampleSelector instantiated, backed by a vector store.
        """
        if input_keys:
            string_examples = [
                " ".join(sorted_values({k: eg[k] for k in input_keys}))
                for eg in examples
            ]
        else:
            string_examples = [" ".join(sorted_values(eg)) for eg in examples]
        vectorstore = vectorstore_cls.from_texts(
            string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
        )
        return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys)
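

# A usage sketch for the selectors above. It assumes concrete Embeddings and
# VectorStore implementations are installed (OpenAIEmbeddings and FAISS from
# the companion langchain package are used here as assumptions), so it is left
# commented rather than executable:
#
#     from langchain.embeddings import OpenAIEmbeddings
#     from langchain.vectorstores import FAISS
#
#     selector = SemanticSimilarityExampleSelector.from_examples(
#         examples=[
#             {"input": "happy", "output": "sad"},
#             {"input": "sunny", "output": "rainy"},
#         ],
#         embeddings=OpenAIEmbeddings(),
#         vectorstore_cls=FAISS,
#         k=1,
#     )
#     # Returns the stored example closest to the query in embedding space.
#     print(selector.select_examples({"input": "joyful"}))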
@ -0,0 +1,343 @@
"""Prompt template that contains few shot examples."""
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

from langchain_core.prompts.base import (
    DEFAULT_FORMATTER_MAPPING,
    StringPromptTemplate,
    check_valid_template,
    get_template_variables,
)
from langchain_core.prompts.chat import (
    BaseChatPromptTemplate,
    BaseMessagePromptTemplate,
)
from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.schema.messages import BaseMessage, get_buffer_string


class _FewShotPromptTemplateMixin(BaseModel):
    """Prompt template that contains few shot examples."""

    examples: Optional[List[dict]] = None
    """Examples to format into the prompt.
    Either this or example_selector should be provided."""

    example_selector: Optional[BaseExampleSelector] = None
    """ExampleSelector to choose the examples to format into the prompt.
    Either this or examples should be provided."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def check_examples_and_selector(cls, values: Dict) -> Dict:
        """Check that one and only one of examples/example_selector are provided."""
        examples = values.get("examples", None)
        example_selector = values.get("example_selector", None)
        if examples and example_selector:
            raise ValueError(
                "Only one of 'examples' and 'example_selector' should be provided"
            )

        if examples is None and example_selector is None:
            raise ValueError(
                "One of 'examples' and 'example_selector' should be provided"
            )

        return values

    def _get_examples(self, **kwargs: Any) -> List[dict]:
        """Get the examples to use for formatting the prompt.

        Args:
            **kwargs: Keyword arguments to be passed to the example selector.

        Returns:
            List of examples.
        """
        if self.examples is not None:
            return self.examples
        elif self.example_selector is not None:
            return self.example_selector.select_examples(kwargs)
        else:
            raise ValueError(
                "One of 'examples' and 'example_selector' should be provided"
            )


class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
    """Prompt template that contains few shot examples."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether or not the class is serializable."""
        return False

    validate_template: bool = False
    """Whether or not to try validating the template."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

    example_prompt: PromptTemplate
    """PromptTemplate used to format an individual example."""

    suffix: str
    """A prompt template string to put after the examples."""

    example_separator: str = "\n\n"
    """String separator used to join the prefix, the examples, and suffix."""

    prefix: str = ""
    """A prompt template string to put before the examples."""

    template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""

    @root_validator()
    def template_is_valid(cls, values: Dict) -> Dict:
        """Check that prefix, suffix, and input variables are consistent."""
        if values["validate_template"]:
            check_valid_template(
                values["prefix"] + values["suffix"],
                values["template_format"],
                values["input_variables"] + list(values["partial_variables"]),
            )
        elif values.get("template_format"):
            values["input_variables"] = [
                var
                for var in get_template_variables(
                    values["prefix"] + values["suffix"], values["template_format"]
                )
                if var not in values["partial_variables"]
            ]
        return values

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            **kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        # Get the examples to use.
        examples = self._get_examples(**kwargs)
        examples = [
            {k: e[k] for k in self.example_prompt.input_variables} for e in examples
        ]
        # Format the examples.
        example_strings = [
            self.example_prompt.format(**example) for example in examples
        ]
        # Create the overall template.
        pieces = [self.prefix, *example_strings, self.suffix]
        template = self.example_separator.join([piece for piece in pieces if piece])

        # Format the template with the input variables.
        return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        return "few_shot"

    def save(self, file_path: Union[Path, str]) -> None:
        if self.example_selector:
            raise ValueError("Saving an example selector is not currently supported")
        return super().save(file_path)
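

# A minimal usage sketch for FewShotPromptTemplate: two fixed examples are
# formatted with `example_prompt` and joined with the prefix and suffix. All
# names and data below are hypothetical.
if __name__ == "__main__":
    few_shot = FewShotPromptTemplate(
        examples=[
            {"question": "What is 2+2?", "answer": "4"},
            {"question": "What is 2+3?", "answer": "5"},
        ],
        example_prompt=PromptTemplate.from_template("Q: {question}\nA: {answer}"),
        prefix="Answer the following questions.",
        suffix="Q: {question}\nA:",
        input_variables=["question"],
    )
    print(few_shot.format(question="What is 4+4?"))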


class FewShotChatMessagePromptTemplate(
    BaseChatPromptTemplate, _FewShotPromptTemplateMixin
):
    """Chat prompt template that supports few-shot examples.

    The high-level structure produced by this prompt template is a list of messages
    consisting of prefix message(s), example message(s), and suffix message(s).

    This structure enables creating a conversation with intermediate examples like:

        System: You are a helpful AI Assistant
        Human: What is 2+2?
        AI: 4
        Human: What is 2+3?
        AI: 5
        Human: What is 4+4?

    This prompt template can be used to generate a fixed list of examples or else
    to dynamically select examples based on the input.

    Examples:

        Prompt template with a fixed list of examples (matching the sample
        conversation above):

        .. code-block:: python

            from langchain_core.prompts import (
                FewShotChatMessagePromptTemplate,
                ChatPromptTemplate
            )

            examples = [
                {"input": "2+2", "output": "4"},
                {"input": "2+3", "output": "5"},
            ]

            example_prompt = ChatPromptTemplate.from_messages(
                [('human', '{input}'), ('ai', '{output}')]
            )

            few_shot_prompt = FewShotChatMessagePromptTemplate(
                examples=examples,
                # This is a prompt template used to format each individual example.
                example_prompt=example_prompt,
            )

            final_prompt = ChatPromptTemplate.from_messages(
                [
                    ('system', 'You are a helpful AI Assistant'),
                    few_shot_prompt,
                    ('human', '{input}'),
                ]
            )
            final_prompt.format(input="What is 4+4?")

        Prompt template with dynamically selected examples:

        .. code-block:: python

            from langchain_core.prompts import SemanticSimilarityExampleSelector
            from langchain_core.embeddings import OpenAIEmbeddings
            from langchain_core.vectorstores import Chroma

            examples = [
                {"input": "2+2", "output": "4"},
                {"input": "2+3", "output": "5"},
                {"input": "2+4", "output": "6"},
                # ...
            ]

            to_vectorize = [
                " ".join(example.values())
                for example in examples
            ]
            embeddings = OpenAIEmbeddings()
            vectorstore = Chroma.from_texts(
                to_vectorize, embeddings, metadatas=examples
            )
            example_selector = SemanticSimilarityExampleSelector(
                vectorstore=vectorstore
            )

            from langchain_core.schema import SystemMessage
            from langchain_core.prompts import HumanMessagePromptTemplate
            from langchain_core.prompts.few_shot import FewShotChatMessagePromptTemplate

            few_shot_prompt = FewShotChatMessagePromptTemplate(
                # Which variable(s) will be passed to the example selector.
                input_variables=["input"],
                example_selector=example_selector,
                # Define how each example will be formatted.
                # In this case, each example will become 2 messages:
                # 1 human, and 1 AI
                example_prompt=(
                    HumanMessagePromptTemplate.from_template("{input}")
                    + AIMessagePromptTemplate.from_template("{output}")
                ),
            )
            # Define the overall prompt.
            final_prompt = (
                SystemMessagePromptTemplate.from_template(
                    "You are a helpful AI Assistant"
                )
                + few_shot_prompt
                + HumanMessagePromptTemplate.from_template("{input}")
            )
            # Show the prompt
            print(final_prompt.format_messages(input="What's 3+3?"))

            # Use within an LLM
            from langchain_core.chat_models import ChatAnthropic
            chain = final_prompt | ChatAnthropic()
            chain.invoke({"input": "What's 3+3?"})
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether or not the class is serializable."""
        return False

    input_variables: List[str] = Field(default_factory=list)
    """A list of the names of the variables the prompt template will use
    to pass to the example_selector, if provided."""
    example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
    """The class to format each example."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format kwargs into a list of messages.

        Args:
            **kwargs: keyword arguments to use for filling in templates in messages.

        Returns:
            A list of formatted messages with all template variables filled in.
        """
        # Get the examples to use.
        examples = self._get_examples(**kwargs)
        examples = [
            {k: e[k] for k in self.example_prompt.input_variables} for e in examples
        ]
        # Format the examples.
        messages = [
            message
            for example in examples
            for message in self.example_prompt.format_messages(**example)
        ]
        return messages

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with inputs generating a string.

        Use this method to generate a string representation of a prompt consisting
        of chat messages.

        Useful for feeding into a string-based completion language model or debugging.

        Args:
            **kwargs: keyword arguments to use for formatting.

        Returns:
            A string representation of the prompt
        """
        messages = self.format_messages(**kwargs)
        return get_buffer_string(messages)
@ -0,0 +1,153 @@
"""Prompt template that contains few shot examples."""
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from langchain_core.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate
from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Extra, root_validator


class FewShotPromptWithTemplates(StringPromptTemplate):
    """Prompt template that contains few shot examples."""

    examples: Optional[List[dict]] = None
    """Examples to format into the prompt.
    Either this or example_selector should be provided."""

    example_selector: Optional[BaseExampleSelector] = None
    """ExampleSelector to choose the examples to format into the prompt.
    Either this or examples should be provided."""

    example_prompt: PromptTemplate
    """PromptTemplate used to format an individual example."""

    suffix: StringPromptTemplate
    """A PromptTemplate to put after the examples."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

    example_separator: str = "\n\n"
    """String separator used to join the prefix, the examples, and suffix."""

    prefix: Optional[StringPromptTemplate] = None
    """A PromptTemplate to put before the examples."""

    template_format: str = "f-string"
    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""

    validate_template: bool = False
    """Whether or not to try validating the template."""

    @root_validator(pre=True)
    def check_examples_and_selector(cls, values: Dict) -> Dict:
        """Check that one and only one of examples/example_selector are provided."""
        examples = values.get("examples", None)
        example_selector = values.get("example_selector", None)
        if examples and example_selector:
            raise ValueError(
                "Only one of 'examples' and 'example_selector' should be provided"
            )

        if examples is None and example_selector is None:
            raise ValueError(
                "One of 'examples' and 'example_selector' should be provided"
            )

        return values

    @root_validator()
    def template_is_valid(cls, values: Dict) -> Dict:
        """Check that prefix, suffix, and input variables are consistent."""
        if values["validate_template"]:
            input_variables = values["input_variables"]
            expected_input_variables = set(values["suffix"].input_variables)
            expected_input_variables |= set(values["partial_variables"])
            if values["prefix"] is not None:
                expected_input_variables |= set(values["prefix"].input_variables)
            missing_vars = expected_input_variables.difference(input_variables)
            if missing_vars:
                raise ValueError(
                    f"Got input_variables={input_variables}, but based on "
                    f"prefix/suffix expected {expected_input_variables}"
                )
        else:
            values["input_variables"] = sorted(
                set(values["suffix"].input_variables)
                | set(values["prefix"].input_variables if values["prefix"] else [])
                - set(values["partial_variables"])
            )
        return values

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def _get_examples(self, **kwargs: Any) -> List[dict]:
        if self.examples is not None:
            return self.examples
        elif self.example_selector is not None:
            return self.example_selector.select_examples(kwargs)
        else:
            raise ValueError(
                "One of 'examples' and 'example_selector' should be provided"
            )

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        # Get the examples to use.
        examples = self._get_examples(**kwargs)
        # Format the examples.
        example_strings = [
            self.example_prompt.format(**example) for example in examples
        ]
        # Create the overall prefix.
        if self.prefix is None:
            prefix = ""
        else:
            prefix_kwargs = {
                k: v for k, v in kwargs.items() if k in self.prefix.input_variables
            }
            for k in prefix_kwargs.keys():
                kwargs.pop(k)
            prefix = self.prefix.format(**prefix_kwargs)

        # Create the overall suffix
        suffix_kwargs = {
            k: v for k, v in kwargs.items() if k in self.suffix.input_variables
        }
        for k in suffix_kwargs.keys():
            kwargs.pop(k)
        suffix = self.suffix.format(
            **suffix_kwargs,
        )

        pieces = [prefix, *example_strings, suffix]
        template = self.example_separator.join([piece for piece in pieces if piece])
        # Format the template with the input variables.
        return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        return "few_shot_with_templates"

    def save(self, file_path: Union[Path, str]) -> None:
        if self.example_selector:
            raise ValueError("Saving an example selector is not currently supported")
        return super().save(file_path)
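

# A minimal usage sketch: unlike FewShotPromptTemplate, the prefix and suffix
# here are templates themselves, each consuming its own variables. All names
# and data below are hypothetical.
if __name__ == "__main__":
    prompt = FewShotPromptWithTemplates(
        examples=[{"word": "happy", "antonym": "sad"}],
        example_prompt=PromptTemplate.from_template("{word} -> {antonym}"),
        prefix=PromptTemplate.from_template("Give the antonym, {user}."),
        suffix=PromptTemplate.from_template("{word} ->"),
        input_variables=["user", "word"],
    )
    print(prompt.format(user="Alice", word="tall"))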
@ -0,0 +1,162 @@
"""Load prompts."""
import json
import logging
from pathlib import Path
from typing import Callable, Dict, Union

import yaml

from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.schema import (
    BasePromptTemplate,
    StrOutputParser,
)
from langchain_core.utils.loading import try_load_from_hub

URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
logger = logging.getLogger(__name__)


def load_prompt_from_config(config: dict) -> BasePromptTemplate:
    """Load prompt from Config Dict."""
    if "_type" not in config:
        logger.warning("No `_type` key found, defaulting to `prompt`.")
    config_type = config.pop("_type", "prompt")

    if config_type not in type_to_loader_dict:
        raise ValueError(f"Loading {config_type} prompt not supported")

    prompt_loader = type_to_loader_dict[config_type]
    return prompt_loader(config)


def _load_template(var_name: str, config: dict) -> dict:
    """Load template from the path if applicable."""
    # Check if template_path exists in config.
    if f"{var_name}_path" in config:
        # If it does, make sure template variable doesn't also exist.
        if var_name in config:
            raise ValueError(
                f"Both `{var_name}_path` and `{var_name}` cannot be provided."
            )
        # Pop the template path from the config.
        template_path = Path(config.pop(f"{var_name}_path"))
        # Load the template.
        if template_path.suffix == ".txt":
            with open(template_path) as f:
                template = f.read()
        else:
            raise ValueError(f"Unsupported template file type: {template_path.suffix}")
        # Set the template variable to the extracted variable.
        config[var_name] = template
    return config


def _load_examples(config: dict) -> dict:
    """Load examples if necessary."""
    if isinstance(config["examples"], list):
        pass
    elif isinstance(config["examples"], str):
        with open(config["examples"]) as f:
            if config["examples"].endswith(".json"):
                examples = json.load(f)
            elif config["examples"].endswith((".yaml", ".yml")):
                examples = yaml.safe_load(f)
            else:
                raise ValueError(
                    "Invalid file format. Only json or yaml formats are supported."
                )
        config["examples"] = examples
    else:
        raise ValueError("Invalid examples format. Only list or string are supported.")
    return config


def _load_output_parser(config: dict) -> dict:
    """Load output parser."""
    if "output_parser" in config and config["output_parser"]:
        _config = config.pop("output_parser")
        output_parser_type = _config.pop("_type")
        if output_parser_type == "default":
            output_parser = StrOutputParser(**_config)
        else:
            raise ValueError(f"Unsupported output parser {output_parser_type}")
        config["output_parser"] = output_parser
    return config


def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
    """Load the "few shot" prompt from the config."""
    # Load the suffix and prefix templates.
    config = _load_template("suffix", config)
    config = _load_template("prefix", config)
    # Load the example prompt.
    if "example_prompt_path" in config:
        if "example_prompt" in config:
            raise ValueError(
                "Only one of example_prompt and example_prompt_path should "
                "be specified."
            )
        config["example_prompt"] = load_prompt(config.pop("example_prompt_path"))
    else:
        config["example_prompt"] = load_prompt_from_config(config["example_prompt"])
    # Load the examples.
    config = _load_examples(config)
    config = _load_output_parser(config)
    return FewShotPromptTemplate(**config)


def _load_prompt(config: dict) -> PromptTemplate:
    """Load the prompt template from config."""
    # Load the template from disk if necessary.
    config = _load_template("template", config)
    config = _load_output_parser(config)

    template_format = config.get("template_format", "f-string")
    if template_format == "jinja2":
        # Disabled due to:
        # https://github.com/langchain-ai/langchain/issues/4394
        raise ValueError(
            f"Loading templates with '{template_format}' format is no longer supported "
            f"since it can lead to arbitrary code execution. Please migrate to using "
            f"the 'f-string' template format, which does not suffer from this issue."
        )

    return PromptTemplate(**config)


def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
    """Unified method for loading a prompt from LangChainHub or local fs."""
    if hub_result := try_load_from_hub(
        path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"}
    ):
        return hub_result
    else:
        return _load_prompt_from_file(path)


def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
    """Load prompt from file."""
    # Convert file to a Path object.
    if isinstance(file, str):
        file_path = Path(file)
    else:
        file_path = file
    # Load from either json or yaml.
    if file_path.suffix == ".json":
        with open(file_path) as f:
            config = json.load(f)
    elif file_path.suffix == ".yaml":
        with open(file_path, "r") as f:
            config = yaml.safe_load(f)
    else:
        raise ValueError(f"Got unsupported file type {file_path.suffix}")
    # Load the prompt from the config now.
    return load_prompt_from_config(config)


type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = {
    "prompt": _load_prompt,
    "few_shot": _load_few_shot_prompt,
}
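

# A minimal round-trip sketch: write a "prompt"-typed config to a temporary
# JSON file and load it back through `load_prompt`. The config contents are
# hypothetical, and the temp path is assumed to be treated as a local file
# (not a hub reference) by `try_load_from_hub`.
if __name__ == "__main__":
    import tempfile

    config = {
        "_type": "prompt",
        "input_variables": ["topic"],
        "template": "Tell me a joke about {topic}.",
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(config, f)
    prompt = load_prompt(f.name)
    print(prompt.format(topic="cats"))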
@ -0,0 +1,56 @@
from typing import Any, Dict, List, Tuple

from langchain_core.prompts.chat import BaseChatPromptTemplate
from langchain_core.pydantic_v1 import root_validator
from langchain_core.schema import BasePromptTemplate, PromptValue


def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
    return {k: inputs[k] for k in input_variables}


class PipelinePromptTemplate(BasePromptTemplate):
    """A prompt template for composing multiple prompt templates together.

    This can be useful when you want to reuse parts of prompts.
    A PipelinePrompt consists of two main parts:
    - final_prompt: This is the final prompt that is returned
    - pipeline_prompts: This is a list of tuples, consisting
        of a string (`name`) and a Prompt Template.
        Each PromptTemplate will be formatted and then passed
        to future prompt templates as a variable with
        the same name as `name`
    """

    final_prompt: BasePromptTemplate
    """The final prompt that is returned."""
    pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
    """A list of tuples, consisting of a string (`name`) and a Prompt Template."""

    @root_validator(pre=True)
    def get_input_variables(cls, values: Dict) -> Dict:
        """Get input variables."""
        created_variables = set()
        all_variables = set()
        for k, prompt in values["pipeline_prompts"]:
            created_variables.add(k)
            all_variables.update(prompt.input_variables)
        values["input_variables"] = list(all_variables.difference(created_variables))
        return values

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        for k, prompt in self.pipeline_prompts:
            _inputs = _get_inputs(kwargs, prompt.input_variables)
            if isinstance(prompt, BaseChatPromptTemplate):
                kwargs[k] = prompt.format_messages(**_inputs)
            else:
                kwargs[k] = prompt.format(**_inputs)
        _inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
        return self.final_prompt.format_prompt(**_inputs)

    def format(self, **kwargs: Any) -> str:
        return self.format_prompt(**kwargs).to_string()

    @property
    def _prompt_type(self) -> str:
        raise ValueError
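

# A minimal usage sketch: the "intro" pipeline prompt is rendered first, and
# its output is injected into the final prompt as the {intro} variable. All
# names and data below are hypothetical.
if __name__ == "__main__":
    from langchain_core.prompts.prompt import PromptTemplate

    intro = PromptTemplate.from_template("You are impersonating {person}.")
    final = PromptTemplate.from_template("{intro}\nQ: {question}\nA:")
    pipeline = PipelinePromptTemplate(
        final_prompt=final,
        pipeline_prompts=[("intro", intro)],
    )
    print(pipeline.format(person="Ada Lovelace", question="What is 2+2?"))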
@ -0,0 +1,250 @@
"""Prompt schema definition."""
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

from langchain_core.prompts.base import (
    DEFAULT_FORMATTER_MAPPING,
    StringPromptTemplate,
    check_valid_template,
    get_template_variables,
)
from langchain_core.pydantic_v1 import root_validator


class PromptTemplate(StringPromptTemplate):
    """A prompt template for a language model.

    A prompt template consists of a string template. It accepts a set of parameters
    from the user that can be used to generate a prompt for a language model.

    The template can be formatted using either f-strings (default) or jinja2 syntax.

    *Security warning*: Prefer using `template_format="f-string"` instead of
    `template_format="jinja2"`, or make sure to NEVER accept jinja2 templates
    from untrusted sources as they may lead to arbitrary Python code execution.

    As of LangChain 0.0.329, Jinja2 templates will be rendered using
    Jinja2's SandboxedEnvironment by default. This sand-boxing should
    be treated as a best-effort approach rather than a guarantee of security,
    as it is an opt-out rather than opt-in approach.

    Despite the sand-boxing, we recommend never using jinja2 templates
    from untrusted sources.

    Example:

        .. code-block:: python

            from langchain_core.prompts import PromptTemplate

            # Instantiation using from_template (recommended)
            prompt = PromptTemplate.from_template("Say {foo}")
            prompt.format(foo="bar")

            # Instantiation using initializer
            prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
    """

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        return {
            "template_format": self.template_format,
        }

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

    template: str
    """The prompt template."""

    template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""

    validate_template: bool = False
    """Whether or not to try validating the template."""

    def __add__(self, other: Any) -> PromptTemplate:
        """Override the + operator to allow for combining prompt templates."""
        # Allow for easy combining
        if isinstance(other, PromptTemplate):
            if self.template_format != "f-string":
                raise ValueError(
                    "Adding prompt templates only supported for f-strings."
                )
            if other.template_format != "f-string":
                raise ValueError(
                    "Adding prompt templates only supported for f-strings."
                )
            input_variables = list(
                set(self.input_variables) | set(other.input_variables)
            )
            template = self.template + other.template
            # If any do not want to validate, then don't
            validate_template = self.validate_template and other.validate_template
            partial_variables = {k: v for k, v in self.partial_variables.items()}
            for k, v in other.partial_variables.items():
                if k in partial_variables:
                    raise ValueError("Cannot have same variable partialed twice.")
                else:
                    partial_variables[k] = v
            return PromptTemplate(
                template=template,
                input_variables=input_variables,
                partial_variables=partial_variables,
                template_format="f-string",
                validate_template=validate_template,
            )
        elif isinstance(other, str):
            prompt = PromptTemplate.from_template(other)
            return self + prompt
        else:
            raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        return "prompt"

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)

    @root_validator()
    def template_is_valid(cls, values: Dict) -> Dict:
        """Check that template and input variables are consistent."""
        if values["validate_template"]:
            all_inputs = values["input_variables"] + list(values["partial_variables"])
            check_valid_template(
                values["template"], values["template_format"], all_inputs
            )
        elif values.get("template_format"):
            values["input_variables"] = [
                var
                for var in get_template_variables(
                    values["template"], values["template_format"]
                )
                if var not in values["partial_variables"]
            ]
        return values

    @classmethod
    def from_examples(
        cls,
        examples: List[str],
        suffix: str,
        input_variables: List[str],
        example_separator: str = "\n\n",
        prefix: str = "",
        **kwargs: Any,
    ) -> PromptTemplate:
        """Take examples in list format with prefix and suffix to create a prompt.

        Intended to be used as a way to dynamically create a prompt from examples.

        Args:
            examples: List of examples to use in the prompt.
            suffix: String to go after the list of examples. Should generally
                set up the user's input.
            input_variables: A list of variable names the final prompt template
                will expect.
            example_separator: The separator to use in between examples. Defaults
                to two new line characters.
            prefix: String that should go before any examples. Generally includes
                instructions. Defaults to an empty string.

        Returns:
            The final prompt generated.
        """
        template = example_separator.join([prefix, *examples, suffix])
        return cls(input_variables=input_variables, template=template, **kwargs)

    @classmethod
    def from_file(
        cls, template_file: Union[str, Path], input_variables: List[str], **kwargs: Any
    ) -> PromptTemplate:
        """Load a prompt from a file.

        Args:
            template_file: The path to the file containing the prompt template.
            input_variables: A list of variable names the final prompt template
                will expect.

        Returns:
            The prompt loaded from the file.
        """
        with open(str(template_file), "r") as f:
            template = f.read()
        return cls(input_variables=input_variables, template=template, **kwargs)

    @classmethod
    def from_template(
        cls,
        template: str,
        *,
        template_format: str = "f-string",
        partial_variables: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> PromptTemplate:
        """Load a prompt template from a template.

        *Security warning*: Prefer using `template_format="f-string"` instead of
        `template_format="jinja2"`, or make sure to NEVER accept jinja2 templates
        from untrusted sources as they may lead to arbitrary Python code execution.

        As of LangChain 0.0.329, Jinja2 templates will be rendered using
        Jinja2's SandboxedEnvironment by default. This sand-boxing should
        be treated as a best-effort approach rather than a guarantee of security,
|
||||
as it is an opt-out rather than opt-in approach.
|
||||
|
||||
Despite the sand-boxing, we recommend to never use jinja2 templates
|
||||
from untrusted sources.
|
||||
|
||||
Args:
|
||||
template: The template to load.
|
||||
template_format: The format of the template. Use `jinja2` for jinja2,
|
||||
and `f-string` or None for f-strings.
|
||||
partial_variables: A dictionary of variables that can be used to partially
|
||||
fill in the template. For example, if the template is
|
||||
`"{variable1} {variable2}"`, and `partial_variables` is
|
||||
`{"variable1": "foo"}`, then the final prompt will be
|
||||
`"foo {variable2}"`.
|
||||
|
||||
Returns:
|
||||
The prompt template loaded from the template.
|
||||
"""
|
||||
|
||||
input_variables = get_template_variables(template, template_format)
|
||||
_partial_variables = partial_variables or {}
|
||||
|
||||
if _partial_variables:
|
||||
input_variables = [
|
||||
var for var in input_variables if var not in _partial_variables
|
||||
]
|
||||
|
||||
return cls(
|
||||
input_variables=input_variables,
|
||||
template=template,
|
||||
template_format=template_format,
|
||||
partial_variables=_partial_variables,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# For backwards compatibility.
|
||||
Prompt = PromptTemplate
|
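

# Usage sketch: exercises the two PromptTemplate features documented above --
# partial_variables handling in from_template() and template combination via
# the overloaded "+" operator. The template strings and variable names here
# are invented for illustration.
if __name__ == "__main__":
    base = PromptTemplate.from_template(
        "You are a travel guide for {country}. ",
        partial_variables={"country": "Japan"},
    )
    question = PromptTemplate.from_template("Recommend a dish from {city}.")

    combined = base + question  # only f-string templates may be added
    # "country" is already partialed out, so only "city" remains to fill in.
    assert combined.input_variables == ["city"]
    print(combined.format(city="Osaka"))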
@ -0,0 +1,23 @@
from importlib import metadata

## Create namespaces for pydantic v1 and v2.
# This code must stay at the top of the file before other modules may
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
#
# This hack is done for the following reasons:
# * LangChain will attempt to remain compatible with both pydantic v1 and v2, since
#   both dependencies and dependents may be stuck on either major version.
# * Creating namespaces for pydantic v1 and v2 allows us to write code that
#   unambiguously uses either the v1 or the v2 API.
# * This change is easier to roll out and roll back.

try:
    from pydantic.v1 import *  # noqa: F403 # type: ignore
except ImportError:
    from pydantic import *  # noqa: F403 # type: ignore


try:
    _PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
except metadata.PackageNotFoundError:
    _PYDANTIC_MAJOR_VERSION = 0
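

# Usage sketch: downstream modules import pydantic symbols through this
# namespace so the same code works whether pydantic v1 or v2 is installed.
# The Item model below is invented for illustration.
if __name__ == "__main__":
    from langchain_core.pydantic_v1 import BaseModel, Field

    class Item(BaseModel):
        name: str
        count: int = Field(default=0, ge=0)

    print(Item(name="widget"), "pydantic major version:", _PYDANTIC_MAJOR_VERSION)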
@ -0,0 +1,4 @@
try:
    from pydantic.v1.dataclasses import *  # noqa: F403
except ImportError:
    from pydantic.dataclasses import *  # noqa: F403
@ -0,0 +1,4 @@
try:
    from pydantic.v1.main import *  # noqa: F403
except ImportError:
    from pydantic.main import *  # noqa: F403
@ -0,0 +1,57 @@
"""LangChain **Runnable** and the **LangChain Expression Language (LCEL)**.

The LangChain Expression Language (LCEL) offers a declarative method to build
production-grade programs that harness the power of LLMs.

Programs created using LCEL and LangChain Runnables inherently support
synchronous, asynchronous, batch, and streaming operations.

Support for **async** allows servers hosting LCEL-based programs to scale better
under higher concurrent loads.

**Streaming** of intermediate outputs as they're being generated allows for
creating more responsive UX.

This module contains the schema and implementation of LangChain Runnable primitives.
"""
from langchain_core.runnables.base import (
    Runnable,
    RunnableBinding,
    RunnableGenerator,
    RunnableLambda,
    RunnableMap,
    RunnableParallel,
    RunnableSequence,
    RunnableSerializable,
)
from langchain_core.runnables.branch import RunnableBranch
from langchain_core.runnables.config import RunnableConfig, patch_config
from langchain_core.runnables.fallbacks import RunnableWithFallbacks
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.router import RouterInput, RouterRunnable
from langchain_core.runnables.utils import (
    ConfigurableField,
    ConfigurableFieldMultiOption,
    ConfigurableFieldSingleOption,
)

__all__ = [
    "ConfigurableField",
    "ConfigurableFieldSingleOption",
    "ConfigurableFieldMultiOption",
    "patch_config",
    "RouterInput",
    "RouterRunnable",
    "Runnable",
    "RunnableSerializable",
    "RunnableBinding",
    "RunnableBranch",
    "RunnableConfig",
    "RunnableGenerator",
    "RunnableLambda",
    "RunnableMap",
    "RunnableParallel",
    "RunnablePassthrough",
    "RunnableSequence",
    "RunnableWithFallbacks",
]
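

# Usage sketch: the primitives re-exported above compose with the "|" operator
# into an LCEL sequence that supports invoke/batch/stream (and their async
# variants). The toy functions are invented for illustration.
if __name__ == "__main__":
    from langchain_core.runnables import RunnableLambda, RunnablePassthrough

    chain = RunnablePassthrough() | RunnableLambda(lambda x: x * 2)
    print(chain.invoke(3))  # 6
    print(chain.batch([1, 2]))  # [2, 4]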
File diff suppressed because it is too large
@ -0,0 +1,254 @@
from typing import (
    Any,
    Awaitable,
    Callable,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import (
    Runnable,
    RunnableLike,
    RunnableSerializable,
    coerce_to_runnable,
)
from langchain_core.runnables.config import (
    RunnableConfig,
    ensure_config,
    get_callback_manager_for_config,
    patch_config,
)
from langchain_core.runnables.utils import (
    ConfigurableFieldSpec,
    Input,
    Output,
    get_unique_config_specs,
)


class RunnableBranch(RunnableSerializable[Input, Output]):
    """A Runnable that selects which branch to run based on a condition.

    The runnable is initialized with a list of (condition, runnable) pairs and
    a default branch.

    When operating on an input, the first condition that evaluates to True is
    selected, and the corresponding runnable is run on the input.

    If no condition evaluates to True, the default branch is run on the input.

    Examples:

        .. code-block:: python

            from langchain_core.runnables import RunnableBranch

            branch = RunnableBranch(
                (lambda x: isinstance(x, str), lambda x: x.upper()),
                (lambda x: isinstance(x, int), lambda x: x + 1),
                (lambda x: isinstance(x, float), lambda x: x * 2),
                lambda x: "goodbye",
            )

            branch.invoke("hello")  # "HELLO"
            branch.invoke(None)  # "goodbye"
    """

    branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
    default: Runnable[Input, Output]

    def __init__(
        self,
        *branches: Union[
            Tuple[
                Union[
                    Runnable[Input, bool],
                    Callable[[Input], bool],
                    Callable[[Input], Awaitable[bool]],
                ],
                RunnableLike,
            ],
            RunnableLike,  # To accommodate the default branch
        ],
    ) -> None:
        """A Runnable that runs the first branch whose condition matches the input."""
        if len(branches) < 2:
            raise ValueError("RunnableBranch requires at least two branches")

        default = branches[-1]

        if not isinstance(
            default,
            (Runnable, Callable, Mapping),  # type: ignore[arg-type]
        ):
            raise TypeError(
                "RunnableBranch default must be runnable, callable or mapping."
            )

        default_ = cast(
            Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
        )

        _branches = []

        for branch in branches[:-1]:
            if not isinstance(branch, (tuple, list)):  # type: ignore[arg-type]
                raise TypeError(
                    f"RunnableBranch branches must be "
                    f"tuples or lists, not {type(branch)}"
                )

            if not len(branch) == 2:
                raise ValueError(
                    f"RunnableBranch branches must be "
                    f"tuples or lists of length 2, not {len(branch)}"
                )
            condition, runnable = branch
            condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
            runnable = coerce_to_runnable(runnable)
            _branches.append((condition, runnable))

        super().__init__(branches=_branches, default=default_)

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """RunnableBranch is serializable if all its branches are serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """The namespace of a RunnableBranch is the namespace of its default branch."""
        return cls.__module__.split(".")[:-1]

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        runnables = (
            [self.default]
            + [r for _, r in self.branches]
            + [r for r, _ in self.branches]
        )

        for runnable in runnables:
            if runnable.get_input_schema(config).schema().get("type") is not None:
                return runnable.get_input_schema(config)

        return super().get_input_schema(config)

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        return get_unique_config_specs(
            spec
            for step in (
                [self.default]
                + [r for _, r in self.branches]
                + [r for r, _ in self.branches]
            )
            for spec in step.config_specs
        )

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Evaluate each condition in order and delegate to the first matching
        branch, falling back to the default branch if none match."""
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
        )

        try:
            for idx, branch in enumerate(self.branches):
                condition, runnable = branch

                expression_value = condition.invoke(
                    input,
                    config=patch_config(
                        config,
                        callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
                    ),
                )

                if expression_value:
                    output = runnable.invoke(
                        input,
                        config=patch_config(
                            config,
                            callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
                        ),
                        **kwargs,
                    )
                    break
            else:
                output = self.default.invoke(
                    input,
                    config=patch_config(
                        config, callbacks=run_manager.get_child(tag="branch:default")
                    ),
                    **kwargs,
                )
        except Exception as e:
            run_manager.on_chain_error(e)
            raise
        run_manager.on_chain_end(dumpd(output))
        return output

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Async version of invoke."""
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
        )
        try:
            for idx, branch in enumerate(self.branches):
                condition, runnable = branch

                expression_value = await condition.ainvoke(
                    input,
                    config=patch_config(
                        config,
                        callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
                    ),
                )

                if expression_value:
                    output = await runnable.ainvoke(
                        input,
                        config=patch_config(
                            config,
                            callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
                        ),
                        **kwargs,
                    )
                    break
            else:
                output = await self.default.ainvoke(
                    input,
                    config=patch_config(
                        config, callbacks=run_manager.get_child(tag="branch:default")
                    ),
                    **kwargs,
                )
        except Exception as e:
            run_manager.on_chain_error(e)
            raise
        run_manager.on_chain_end(dumpd(output))
        return output
@ -0,0 +1,401 @@
from __future__ import annotations

from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Union,
    cast,
)

from typing_extensions import TypedDict

from langchain_core.runnables.utils import (
    Input,
    Output,
    accepts_config,
    accepts_run_manager,
)

if TYPE_CHECKING:
    from langchain_core.callbacks.base import BaseCallbackManager, Callbacks
    from langchain_core.callbacks.manager import (
        AsyncCallbackManager,
        AsyncCallbackManagerForChainRun,
        CallbackManager,
        CallbackManagerForChainRun,
    )
else:
    # Pydantic validates through typed dicts, but
    # the callbacks need forward refs updated
    Callbacks = Optional[Union[List, Any]]


class EmptyDict(TypedDict, total=False):
    """Empty dict type."""

    pass


class RunnableConfig(TypedDict, total=False):
    """Configuration for a Runnable."""

    tags: List[str]
    """
    Tags for this call and any sub-calls (e.g. a Chain calling an LLM).
    You can use these to filter calls.
    """

    metadata: Dict[str, Any]
    """
    Metadata for this call and any sub-calls (e.g. a Chain calling an LLM).
    Keys should be strings, values should be JSON-serializable.
    """

    callbacks: Callbacks
    """
    Callbacks for this call and any sub-calls (e.g. a Chain calling an LLM).
    Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
    """

    run_name: str
    """
    Name for the tracer run for this call. Defaults to the name of the class.
    """

    max_concurrency: Optional[int]
    """
    Maximum number of parallel calls to make. If not provided, defaults to
    ThreadPoolExecutor's default.
    """

    recursion_limit: int
    """
    Maximum number of times a call can recurse. If not provided, defaults to 25.
    """

    configurable: Dict[str, Any]
    """
    Runtime values for attributes previously made configurable on this Runnable,
    or sub-Runnables, through .configurable_fields() or .configurable_alternatives().
    Check .output_schema() for a description of the attributes that have been made
    configurable.
    """


def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
    """Ensure that a config is a dict with all keys present.

    Args:
        config (Optional[RunnableConfig], optional): The config to ensure.
            Defaults to None.

    Returns:
        RunnableConfig: The ensured config.
    """
    empty = RunnableConfig(
        tags=[],
        metadata={},
        callbacks=None,
        recursion_limit=25,
    )
    if config is not None:
        empty.update(
            cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
        )
    return empty


def get_config_list(
    config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
    """Get a list of configs from a single config or a list of configs.

    It is useful for subclasses overriding batch() or abatch().

    Args:
        config (Optional[Union[RunnableConfig, List[RunnableConfig]]]):
            The config or list of configs.
        length (int): The length of the list.

    Returns:
        List[RunnableConfig]: The list of configs.

    Raises:
        ValueError: If the length of the list is not equal to the length of the inputs.
    """
    if length < 0:
        raise ValueError(f"length must be >= 0, but got {length}")
    if isinstance(config, list) and len(config) != length:
        raise ValueError(
            f"config must be a list of the same length as inputs, "
            f"but got {len(config)} configs for {length} inputs"
        )

    return (
        list(map(ensure_config, config))
        if isinstance(config, list)
        else [ensure_config(config) for _ in range(length)]
    )


def patch_config(
    config: Optional[RunnableConfig],
    *,
    callbacks: Optional[BaseCallbackManager] = None,
    recursion_limit: Optional[int] = None,
    max_concurrency: Optional[int] = None,
    run_name: Optional[str] = None,
    configurable: Optional[Dict[str, Any]] = None,
) -> RunnableConfig:
    """Patch a config with new values.

    Args:
        config (Optional[RunnableConfig]): The config to patch.
        callbacks (Optional[BaseCallbackManager], optional): The callbacks to set.
            Defaults to None.
        recursion_limit (Optional[int], optional): The recursion limit to set.
            Defaults to None.
        max_concurrency (Optional[int], optional): The max concurrency to set.
            Defaults to None.
        run_name (Optional[str], optional): The run name to set. Defaults to None.
        configurable (Optional[Dict[str, Any]], optional): The configurable to set.
            Defaults to None.

    Returns:
        RunnableConfig: The patched config.
    """
    config = ensure_config(config)
    if callbacks is not None:
        # If we're replacing callbacks, we need to unset run_name,
        # as that should apply only to the same run as the original callbacks.
        config["callbacks"] = callbacks
        if "run_name" in config:
            del config["run_name"]
    if recursion_limit is not None:
        config["recursion_limit"] = recursion_limit
    if max_concurrency is not None:
        config["max_concurrency"] = max_concurrency
    if run_name is not None:
        config["run_name"] = run_name
    if configurable is not None:
        config["configurable"] = {**config.get("configurable", {}), **configurable}
    return config


def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
    """Merge multiple configs into one.

    Args:
        *configs (Optional[RunnableConfig]): The configs to merge.

    Returns:
        RunnableConfig: The merged config.
    """
    base: RunnableConfig = {}
    # Even though the keys aren't literals, this is correct
    # because both dicts are the same type
    for config in (c for c in configs if c is not None):
        for key in config:
            if key == "metadata":
                base[key] = {  # type: ignore
                    **base.get(key, {}),  # type: ignore
                    **(config.get(key) or {}),  # type: ignore
                }
            elif key == "tags":
                base[key] = list(  # type: ignore
                    set(base.get(key, []) + (config.get(key) or [])),  # type: ignore
                )
            elif key == "configurable":
                base[key] = {  # type: ignore
                    **base.get(key, {}),  # type: ignore
                    **(config.get(key) or {}),  # type: ignore
                }
            elif key == "callbacks":
                base_callbacks = base.get("callbacks")
                these_callbacks = config["callbacks"]
                # callbacks can be either None, list[handler] or manager,
                # so merging two callbacks values has 6 cases
                if isinstance(these_callbacks, list):
                    if base_callbacks is None:
                        base["callbacks"] = these_callbacks
                    elif isinstance(base_callbacks, list):
                        base["callbacks"] = base_callbacks + these_callbacks
                    else:
                        # base_callbacks is a manager
                        mngr = base_callbacks.copy()
                        for callback in these_callbacks:
                            mngr.add_handler(callback, inherit=True)
                        base["callbacks"] = mngr
                elif these_callbacks is not None:
                    # these_callbacks is a manager
                    if base_callbacks is None:
                        base["callbacks"] = these_callbacks
                    elif isinstance(base_callbacks, list):
                        mngr = these_callbacks.copy()
                        for callback in base_callbacks:
                            mngr.add_handler(callback, inherit=True)
                        base["callbacks"] = mngr
                    else:
                        # base_callbacks is also a manager
                        base["callbacks"] = base_callbacks.__class__(
                            parent_run_id=base_callbacks.parent_run_id
                            or these_callbacks.parent_run_id,
                            handlers=base_callbacks.handlers + these_callbacks.handlers,
                            inheritable_handlers=base_callbacks.inheritable_handlers
                            + these_callbacks.inheritable_handlers,
                            tags=list(set(base_callbacks.tags + these_callbacks.tags)),
                            inheritable_tags=list(
                                set(
                                    base_callbacks.inheritable_tags
                                    + these_callbacks.inheritable_tags
                                )
                            ),
                            metadata={
                                **base_callbacks.metadata,
                                **these_callbacks.metadata,
                            },
                        )
            else:
                base[key] = config[key] or base.get(key)  # type: ignore
    return base


def call_func_with_variable_args(
    func: Union[
        Callable[[Input], Output],
        Callable[[Input, RunnableConfig], Output],
        Callable[[Input, CallbackManagerForChainRun], Output],
        Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
    ],
    input: Input,
    config: RunnableConfig,
    run_manager: Optional[CallbackManagerForChainRun] = None,
    **kwargs: Any,
) -> Output:
    """Call a function that may optionally accept a run_manager and/or config.

    Args:
        func (Union[Callable[[Input], Output],
            Callable[[Input, CallbackManagerForChainRun], Output],
            Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output]]):
            The function to call.
        input (Input): The input to the function.
        config (RunnableConfig): The config to pass to the function.
        run_manager (CallbackManagerForChainRun): The run manager to
            pass to the function.
        **kwargs (Any): The keyword arguments to pass to the function.

    Returns:
        Output: The output of the function.
    """
    if accepts_config(func):
        if run_manager is not None:
            kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
        else:
            kwargs["config"] = config
    if run_manager is not None and accepts_run_manager(func):
        kwargs["run_manager"] = run_manager
    return func(input, **kwargs)  # type: ignore[call-arg]


async def acall_func_with_variable_args(
    func: Union[
        Callable[[Input], Awaitable[Output]],
        Callable[[Input, RunnableConfig], Awaitable[Output]],
        Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
        Callable[
            [Input, AsyncCallbackManagerForChainRun, RunnableConfig],
            Awaitable[Output],
        ],
    ],
    input: Input,
    config: RunnableConfig,
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    **kwargs: Any,
) -> Output:
    """Call an async function that may optionally accept a run_manager and/or config.

    Args:
        func (Union[Callable[[Input], Awaitable[Output]], Callable[[Input,
            AsyncCallbackManagerForChainRun], Awaitable[Output]], Callable[[Input,
            AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]]]):
            The function to call.
        input (Input): The input to the function.
        config (RunnableConfig): The config to pass to the function.
        run_manager (AsyncCallbackManagerForChainRun): The run manager
            to pass to the function.
        **kwargs (Any): The keyword arguments to pass to the function.

    Returns:
        Output: The output of the function.
    """
    if accepts_config(func):
        if run_manager is not None:
            kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
        else:
            kwargs["config"] = config
    if run_manager is not None and accepts_run_manager(func):
        kwargs["run_manager"] = run_manager
    return await func(input, **kwargs)  # type: ignore[call-arg]


def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
    """Get a callback manager for a config.

    Args:
        config (RunnableConfig): The config.

    Returns:
        CallbackManager: The callback manager.
    """
    from langchain_core.callbacks.manager import CallbackManager

    return CallbackManager.configure(
        inheritable_callbacks=config.get("callbacks"),
        inheritable_tags=config.get("tags"),
        inheritable_metadata=config.get("metadata"),
    )


def get_async_callback_manager_for_config(
    config: RunnableConfig,
) -> AsyncCallbackManager:
    """Get an async callback manager for a config.

    Args:
        config (RunnableConfig): The config.

    Returns:
        AsyncCallbackManager: The async callback manager.
    """
    from langchain_core.callbacks.manager import AsyncCallbackManager

    return AsyncCallbackManager.configure(
        inheritable_callbacks=config.get("callbacks"),
        inheritable_tags=config.get("tags"),
        inheritable_metadata=config.get("metadata"),
    )


@contextmanager
def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
    """Get an executor for a config.

    Args:
        config (RunnableConfig): The config.

    Yields:
        Generator[Executor, None, None]: The executor.
    """
    with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
        yield executor
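

# Usage sketch: patch_config() and merge_configs() as defined above. Note how
# tags are unioned, metadata is dict-merged, and patching fills in the
# ensure_config() defaults. The values are invented for illustration.
if __name__ == "__main__":
    a: RunnableConfig = {"tags": ["x"], "metadata": {"k": 1}}
    b: RunnableConfig = {"tags": ["x", "y"], "metadata": {"k2": 2}}
    merged = merge_configs(a, b)
    assert sorted(merged["tags"]) == ["x", "y"]
    assert merged["metadata"] == {"k": 1, "k2": 2}

    patched = patch_config(merged, recursion_limit=10, run_name="demo")
    assert patched["recursion_limit"] == 10 and patched["run_name"] == "demo"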
@ -0,0 +1,388 @@
from __future__ import annotations

import enum
import threading
from abc import abstractmethod
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Type,
    Union,
    cast,
)
from weakref import WeakValueDictionary

from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
    RunnableConfig,
    get_config_list,
    get_executor_for_config,
)
from langchain_core.runnables.utils import (
    AnyConfigurableField,
    ConfigurableField,
    ConfigurableFieldMultiOption,
    ConfigurableFieldSingleOption,
    ConfigurableFieldSpec,
    Input,
    Output,
    gather_with_concurrency,
    get_unique_config_specs,
)


class DynamicRunnable(RunnableSerializable[Input, Output]):
    """A serializable Runnable that can be dynamically configured."""

    default: RunnableSerializable[Input, Output]

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        return cls.__module__.split(".")[:-1]

    @property
    def InputType(self) -> Type[Input]:
        return self.default.InputType

    @property
    def OutputType(self) -> Type[Output]:
        return self.default.OutputType

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        return self._prepare(config).get_input_schema(config)

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        return self._prepare(config).get_output_schema(config)

    @abstractmethod
    def _prepare(
        self, config: Optional[RunnableConfig] = None
    ) -> Runnable[Input, Output]:
        ...

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        return self._prepare(config).invoke(input, config, **kwargs)

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        return await self._prepare(config).ainvoke(input, config, **kwargs)

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        configs = get_config_list(config, len(inputs))
        prepared = [self._prepare(c) for c in configs]

        if all(p is self.default for p in prepared):
            return self.default.batch(
                inputs, config, return_exceptions=return_exceptions, **kwargs
            )

        if not inputs:
            return []

        def invoke(
            bound: Runnable[Input, Output],
            input: Input,
            config: RunnableConfig,
        ) -> Union[Output, Exception]:
            if return_exceptions:
                try:
                    return bound.invoke(input, config, **kwargs)
                except Exception as e:
                    return e
            else:
                return bound.invoke(input, config, **kwargs)

        # If there's only one input, don't bother with the executor
        if len(inputs) == 1:
            return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])])

        with get_executor_for_config(configs[0]) as executor:
            return cast(
                List[Output], list(executor.map(invoke, prepared, inputs, configs))
            )

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        configs = get_config_list(config, len(inputs))
        prepared = [self._prepare(c) for c in configs]

        if all(p is self.default for p in prepared):
            return await self.default.abatch(
                inputs, config, return_exceptions=return_exceptions, **kwargs
            )

        if not inputs:
            return []

        async def ainvoke(
            bound: Runnable[Input, Output],
            input: Input,
            config: RunnableConfig,
        ) -> Union[Output, Exception]:
            if return_exceptions:
                try:
                    return await bound.ainvoke(input, config, **kwargs)
                except Exception as e:
                    return e
            else:
                return await bound.ainvoke(input, config, **kwargs)

        coros = map(ainvoke, prepared, inputs, configs)
        return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)

    def stream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[Output]:
        return self._prepare(config).stream(input, config, **kwargs)

    async def astream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[Output]:
        async for chunk in self._prepare(config).astream(input, config, **kwargs):
            yield chunk

    def transform(
        self,
        input: Iterator[Input],
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[Output]:
        return self._prepare(config).transform(input, config, **kwargs)

    async def atransform(
        self,
        input: AsyncIterator[Input],
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[Output]:
        async for chunk in self._prepare(config).atransform(input, config, **kwargs):
            yield chunk


class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
    """A Runnable whose fields can be configured at run time."""

    fields: Dict[str, AnyConfigurableField]

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        return get_unique_config_specs(
            [
                ConfigurableFieldSpec(
                    id=spec.id,
                    name=spec.name,
                    description=spec.description
                    or self.default.__fields__[field_name].field_info.description,
                    annotation=spec.annotation
                    or self.default.__fields__[field_name].annotation,
                    default=getattr(self.default, field_name),
                )
                if isinstance(spec, ConfigurableField)
                else make_options_spec(
                    spec, self.default.__fields__[field_name].field_info.description
                )
                for field_name, spec in self.fields.items()
            ]
            + list(self.default.config_specs)
        )

    def configurable_fields(
        self, **kwargs: AnyConfigurableField
    ) -> RunnableSerializable[Input, Output]:
        return self.default.configurable_fields(**{**self.fields, **kwargs})

    def _prepare(
        self, config: Optional[RunnableConfig] = None
    ) -> Runnable[Input, Output]:
        config = config or {}
        specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
        configurable_fields = {
            specs_by_id[k][0]: v
            for k, v in config.get("configurable", {}).items()
            if k in specs_by_id and isinstance(specs_by_id[k][1], ConfigurableField)
        }
        configurable_single_options = {
            k: v.options[(config.get("configurable", {}).get(v.id) or v.default)]
            for k, v in self.fields.items()
            if isinstance(v, ConfigurableFieldSingleOption)
        }
        configurable_multi_options = {
            k: [
                v.options[o]
                for o in config.get("configurable", {}).get(v.id, v.default)
            ]
            for k, v in self.fields.items()
            if isinstance(v, ConfigurableFieldMultiOption)
        }
        configurable = {
            **configurable_fields,
            **configurable_single_options,
            **configurable_multi_options,
        }

        if configurable:
            return self.default.__class__(**{**self.default.__dict__, **configurable})
        else:
            return self.default


# Native StrEnum is not available before Python 3.11.
class StrEnum(str, enum.Enum):
    """A string enum."""

    pass


_enums_for_spec: WeakValueDictionary[
    Union[
        ConfigurableFieldSingleOption, ConfigurableFieldMultiOption, ConfigurableField
    ],
    Type[StrEnum],
] = WeakValueDictionary()

_enums_for_spec_lock = threading.Lock()


class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
    """A Runnable that can be configured at run time to use one of several
    alternative Runnables, or its default."""

    which: ConfigurableField

    alternatives: Dict[
        str,
        Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
    ]

    default_key: str = "default"

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        with _enums_for_spec_lock:
            if which_enum := _enums_for_spec.get(self.which):
                pass
            else:
                which_enum = StrEnum(  # type: ignore[call-overload]
                    self.which.name or self.which.id,
                    (
                        (v, v)
                        for v in list(self.alternatives.keys()) + [self.default_key]
                    ),
                )
                _enums_for_spec[self.which] = cast(Type[StrEnum], which_enum)
        return [
            ConfigurableFieldSpec(
                id=self.which.id,
                name=self.which.name,
                description=self.which.description,
                annotation=which_enum,
                default=self.default_key,
            ),
            *self.default.config_specs,
        ] + [
            s
            for alt in self.alternatives.values()
            if isinstance(alt, RunnableSerializable)
            for s in alt.config_specs
        ]

    def configurable_fields(
        self, **kwargs: AnyConfigurableField
    ) -> RunnableSerializable[Input, Output]:
        return self.__class__(
            which=self.which,
            default=self.default.configurable_fields(**kwargs),
            alternatives=self.alternatives,
        )

    def _prepare(
        self, config: Optional[RunnableConfig] = None
    ) -> Runnable[Input, Output]:
        config = config or {}
        which = config.get("configurable", {}).get(self.which.id, self.default_key)
        if which == self.default_key:
            return self.default
        elif which in self.alternatives:
            alt = self.alternatives[which]
            if isinstance(alt, Runnable):
                return alt
            else:
                return alt()
        else:
            raise ValueError(f"Unknown alternative: {which}")


def make_options_spec(
    spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption],
    description: Optional[str],
) -> ConfigurableFieldSpec:
    """Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or
    ConfigurableFieldMultiOption."""
    with _enums_for_spec_lock:
        if enum := _enums_for_spec.get(spec):
            pass
        else:
            enum = StrEnum(  # type: ignore[call-overload]
                spec.name or spec.id,
                ((v, v) for v in list(spec.options.keys())),
            )
            _enums_for_spec[spec] = cast(Type[StrEnum], enum)
    if isinstance(spec, ConfigurableFieldSingleOption):
        return ConfigurableFieldSpec(
            id=spec.id,
            name=spec.name,
            description=spec.description or description,
            annotation=enum,
            default=spec.default,
        )
    else:
        return ConfigurableFieldSpec(
            id=spec.id,
            name=spec.name,
            description=spec.description or description,
            annotation=Sequence[enum],  # type: ignore[valid-type]
            default=spec.default,
        )
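

# Usage sketch: at run time a configurable field is resolved from
# config["configurable"] keyed by the field id, exactly as in
# RunnableConfigurableFields._prepare() above. The configurable_fields()
# entry point lives on RunnableSerializable in runnables/base.py (whose diff
# is suppressed above); the Greeter model below is invented for illustration.
if __name__ == "__main__":
    class Greeter(RunnableSerializable[str, str]):
        greeting: str = "Hello"

        def invoke(
            self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
        ) -> str:
            return f"{self.greeting}, {input}!"

    g = Greeter().configurable_fields(
        greeting=ConfigurableField(id="greeting", name="Greeting word")
    )
    print(g.invoke("world"))  # "Hello, world!" (default)
    print(g.invoke("world", config={"configurable": {"greeting": "Hi"}}))  # "Hi, world!"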
@ -0,0 +1,344 @@
import asyncio
from typing import (
    TYPE_CHECKING,
    Any,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
    RunnableConfig,
    ensure_config,
    get_async_callback_manager_for_config,
    get_callback_manager_for_config,
    get_config_list,
    patch_config,
)
from langchain_core.runnables.utils import (
    ConfigurableFieldSpec,
    Input,
    Output,
    get_unique_config_specs,
)

if TYPE_CHECKING:
    from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun


class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
    """A Runnable that can fall back to other Runnables if it fails.

    External APIs (e.g., APIs for a language model) may at times experience
    degraded performance or even downtime.

    In these cases, it can be useful to have a fallback runnable that can be
    used in place of the original runnable (e.g., falling back to another LLM
    provider).

    Fallbacks can be defined at the level of a single runnable, or at the level
    of a chain of runnables. Fallbacks are tried in order until one succeeds or
    all fail.

    While you can instantiate a ``RunnableWithFallbacks`` directly, it is usually
    more convenient to use the ``with_fallbacks`` method on a runnable.

    Example:

        .. code-block:: python

            from langchain_core.chat_models.openai import ChatOpenAI
            from langchain_core.chat_models.anthropic import ChatAnthropic

            model = ChatAnthropic().with_fallbacks([ChatOpenAI()])
            # Will usually use ChatAnthropic, but fall back to ChatOpenAI
            # if ChatAnthropic fails.
            model.invoke('hello')

            # And you can also use fallbacks at the level of a chain.
            # Here, if both LLM providers fail, we'll fall back to a good
            # hardcoded response.

            from langchain_core.prompts import PromptTemplate
            from langchain_core.schema.output_parser import StrOutputParser
            from langchain_core.runnables import RunnableLambda

            def when_all_is_lost(inputs):
                return ("Looks like our LLM providers are down. "
                        "Here's a nice 🦜️ emoji for you instead.")

            chain_with_fallback = (
                PromptTemplate.from_template('Tell me a joke about {topic}')
                | model
                | StrOutputParser()
            ).with_fallbacks([RunnableLambda(when_all_is_lost)])
    """

    runnable: Runnable[Input, Output]
    """The runnable to run first."""
    fallbacks: Sequence[Runnable[Input, Output]]
    """A sequence of fallbacks to try."""
    exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
    """The exceptions on which fallbacks should be tried.

    Any exception that is not a subclass of these exceptions will be raised immediately.
    """

    class Config:
        arbitrary_types_allowed = True

    @property
    def InputType(self) -> Type[Input]:
        return self.runnable.InputType

    @property
    def OutputType(self) -> Type[Output]:
        return self.runnable.OutputType

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        return self.runnable.get_input_schema(config)

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        return self.runnable.get_output_schema(config)

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        return get_unique_config_specs(
            spec
            for step in [self.runnable, *self.fallbacks]
            for spec in step.config_specs
        )

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        return cls.__module__.split(".")[:-1]

    @property
    def runnables(self) -> Iterator[Runnable[Input, Output]]:
        yield self.runnable
        yield from self.fallbacks

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        # Set up callbacks
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        # Start the root run
        run_manager = callback_manager.on_chain_start(
            dumpd(self), input, name=config.get("run_name")
        )
        first_error = None
        for runnable in self.runnables:
            try:
                output = runnable.invoke(
                    input,
                    patch_config(config, callbacks=run_manager.get_child()),
                    **kwargs,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
            except BaseException as e:
                run_manager.on_chain_error(e)
                raise e
            else:
                run_manager.on_chain_end(output)
                return output
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        run_manager.on_chain_error(first_error)
        raise first_error

    async def ainvoke(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Output:
        # Set up callbacks
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)
        # Start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self), input, name=config.get("run_name")
        )

        first_error = None
        for runnable in self.runnables:
            try:
                output = await runnable.ainvoke(
                    input,
                    patch_config(config, callbacks=run_manager.get_child()),
                    **kwargs,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
            except BaseException as e:
                await run_manager.on_chain_error(e)
                raise e
            else:
                await run_manager.on_chain_end(output)
                return output
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        await run_manager.on_chain_error(first_error)
        raise first_error

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        from langchain_core.callbacks.manager import CallbackManager

        if return_exceptions:
            raise NotImplementedError()

        if not inputs:
            return []

        # Set up callbacks
        configs = get_config_list(config, len(inputs))
        callback_managers = [
            CallbackManager.configure(
                inheritable_callbacks=config.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=config.get("tags"),
                local_tags=None,
                inheritable_metadata=config.get("metadata"),
                local_metadata=None,
            )
            for config in configs
        ]
        # Start the root runs, one per input
        run_managers = [
            cm.on_chain_start(
                dumpd(self),
                input if isinstance(input, dict) else {"input": input},
                name=config.get("run_name"),
            )
            for cm, input, config in zip(callback_managers, inputs, configs)
        ]

        first_error = None
        for runnable in self.runnables:
            try:
                outputs = runnable.batch(
                    inputs,
                    [
                        # Each step is a child run of the corresponding root run
                        patch_config(config, callbacks=rm.get_child())
                        for rm, config in zip(run_managers, configs)
                    ],
                    return_exceptions=return_exceptions,
                    **kwargs,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
            except BaseException as e:
                for rm in run_managers:
                    rm.on_chain_error(e)
                raise e
            else:
                for rm, output in zip(run_managers, outputs):
                    rm.on_chain_end(output)
                return outputs
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        for rm in run_managers:
            rm.on_chain_error(first_error)
        raise first_error

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        from langchain_core.callbacks.manager import AsyncCallbackManager

        if return_exceptions:
            raise NotImplementedError()

        if not inputs:
            return []

        # Set up callbacks
        configs = get_config_list(config, len(inputs))
        callback_managers = [
            AsyncCallbackManager.configure(
                inheritable_callbacks=config.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=config.get("tags"),
                local_tags=None,
                inheritable_metadata=config.get("metadata"),
                local_metadata=None,
            )
            for config in configs
        ]
        # Start the root runs, one per input
        run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
            *(
                cm.on_chain_start(
                    dumpd(self),
                    input,
                    name=config.get("run_name"),
                )
                for cm, input, config in zip(callback_managers, inputs, configs)
            )
        )

        first_error = None
        for runnable in self.runnables:
            try:
                outputs = await runnable.abatch(
                    inputs,
                    [
                        # Each step is a child run of the corresponding root run
                        patch_config(config, callbacks=rm.get_child())
                        for rm, config in zip(run_managers, configs)
                    ],
                    return_exceptions=return_exceptions,
                    **kwargs,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
            except BaseException as e:
                # Mirror the sync batch(): report the error to every root run,
                # then re-raise rather than trying further fallbacks.
                await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
                raise e
            else:
                await asyncio.gather(
                    *(
                        rm.on_chain_end(output)
                        for rm, output in zip(run_managers, outputs)
                    )
                )
                return outputs
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
        raise first_error
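

# Usage sketch: direct instantiation of the class defined above (the docstring
# recommends the with_fallbacks() helper from runnables/base.py, whose diff is
# suppressed). exceptions_to_handle narrows which errors trigger the fallback;
# the toy runnables are invented for illustration.
if __name__ == "__main__":
    from langchain_core.runnables.base import RunnableLambda

    def flaky(_: str) -> str:
        raise TimeoutError("primary is down")

    chain = RunnableWithFallbacks(
        runnable=RunnableLambda(flaky),
        fallbacks=[RunnableLambda(lambda _: "fallback answer")],
        exceptions_to_handle=(TimeoutError,),
    )
    print(chain.invoke("hi"))  # "fallback answer"
    # A ValueError here would not be handled and would propagate immediately.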
@ -0,0 +1,288 @@
from __future__ import annotations

import asyncio
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Type,
    Union,
)

from langchain_core.load import load
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
    ConfigurableFieldSpec,
    get_unique_config_specs,
)
from langchain_core.schema.chat_history import BaseChatMessageHistory

if TYPE_CHECKING:
    from langchain_core.callbacks.tracers.schemas import Run
    from langchain_core.runnables.config import RunnableConfig
    from langchain_core.schema.messages import BaseMessage

MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]


class RunnableWithMessageHistory(RunnableBindingBase):
    """A runnable that manages chat message history for another runnable.

    The base runnable must have inputs and outputs that can be converted to a
    list of BaseMessages.

    RunnableWithMessageHistory must always be called with a config that contains
    a session_id, e.g. ``{"configurable": {"session_id": "<SESSION_ID>"}}``.

    Example (dict input):
        .. code-block:: python

            from typing import Optional

            from langchain_core.chat_models import ChatAnthropic
            from langchain_core.memory.chat_message_histories import RedisChatMessageHistory
            from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
            from langchain_core.runnables.history import RunnableWithMessageHistory


            prompt = ChatPromptTemplate.from_messages([
                ("system", "You're an assistant who's good at {ability}"),
                MessagesPlaceholder(variable_name="history"),
                ("human", "{question}"),
            ])

            chain = prompt | ChatAnthropic(model="claude-2")

            chain_with_history = RunnableWithMessageHistory(
                chain,
                RedisChatMessageHistory,
                input_messages_key="question",
                history_messages_key="history",
            )

            chain_with_history.invoke(
                {"ability": "math", "question": "What does cosine mean?"},
                config={"configurable": {"session_id": "foo"}}
            )
            # -> "Cosine is ..."
            chain_with_history.invoke(
                {"ability": "math", "question": "What's its inverse?"},
                config={"configurable": {"session_id": "foo"}}
            )
            # -> "The inverse of cosine is called arccosine ..."
    """  # noqa: E501

    get_session_history: GetSessionHistoryCallable
    input_messages_key: Optional[str] = None
    output_messages_key: Optional[str] = None
    history_messages_key: Optional[str] = None

    def __init__(
        self,
        runnable: Runnable[
            MessagesOrDictWithMessages,
            Union[str, BaseMessage, MessagesOrDictWithMessages],
        ],
        get_session_history: GetSessionHistoryCallable,
        *,
        input_messages_key: Optional[str] = None,
        output_messages_key: Optional[str] = None,
        history_messages_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize RunnableWithMessageHistory.

        Args:
            runnable: The base Runnable to be wrapped.

                Must take as input one of:
                - A sequence of BaseMessages
                - A dict with one key for all messages
                - A dict with one key for the current input string/message(s) and
                  a separate key for historical messages. If the input key points
                  to a string, it will be treated as a HumanMessage in history.

                Must return as output one of:
                - A string which can be treated as an AIMessage
                - A BaseMessage or sequence of BaseMessages
                - A dict with a key for a BaseMessage or sequence of BaseMessages

            get_session_history: Function that returns a new BaseChatMessageHistory
                given a session id. Should take a positional string argument
                `session_id` and a named argument `user_id` which can be a string
                or None, e.g.:

                ```python
                def get_session_history(
                    session_id: str,
                    *,
                    user_id: Optional[str] = None,
                ) -> BaseChatMessageHistory:
                    ...
                ```

            input_messages_key: Must be specified if the base runnable accepts a dict
                as input.
            output_messages_key: Must be specified if the base runnable returns a dict
                as output.
            history_messages_key: Must be specified if the base runnable accepts a dict
                as input and expects a separate key for historical messages.
            **kwargs: Arbitrary additional kwargs to pass to the parent class
                ``RunnableBindingBase`` init.
        """  # noqa: E501
        history_chain: Runnable = RunnableLambda(
            self._enter_history, self._aenter_history
        ).with_config(run_name="load_history")
        messages_key = history_messages_key or input_messages_key
        if messages_key:
            history_chain = RunnablePassthrough.assign(
                **{messages_key: history_chain}
            ).with_config(run_name="insert_history")
        bound = (
            history_chain | runnable.with_listeners(on_end=self._exit_history)
        ).with_config(run_name="RunnableWithMessageHistory")
        super().__init__(
            get_session_history=get_session_history,
            input_messages_key=input_messages_key,
            output_messages_key=output_messages_key,
            bound=bound,
            history_messages_key=history_messages_key,
            **kwargs,
        )

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        return get_unique_config_specs(
            super().config_specs
            + [
                ConfigurableFieldSpec(
                    id="session_id",
                    annotation=str,
                    name="Session ID",
                    description="Unique identifier for a session.",
                    default="",
                ),
            ]
        )

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        super_schema = super().get_input_schema(config)
        if super_schema.__custom_root_type__ is not None:
            from langchain_core.schema.messages import BaseMessage

            fields: Dict = {}
            if self.input_messages_key and self.history_messages_key:
                fields[self.input_messages_key] = (
                    Union[str, BaseMessage, Sequence[BaseMessage]],
                    ...,
                )
            elif self.input_messages_key:
                fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
            else:
                fields["__root__"] = (Sequence[BaseMessage], ...)
            if self.history_messages_key:
                fields[self.history_messages_key] = (Sequence[BaseMessage], ...)
            return create_model(  # type: ignore[call-overload]
                "RunnableWithChatHistoryInput",
                **fields,
            )
        else:
            return super_schema

    def _get_input_messages(
        self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]]
    ) -> List[BaseMessage]:
        from langchain_core.schema.messages import BaseMessage
if isinstance(input_val, str):
|
||||
from langchain_core.schema.messages import HumanMessage
|
||||
|
||||
return [HumanMessage(content=input_val)]
|
||||
elif isinstance(input_val, BaseMessage):
|
||||
return [input_val]
|
||||
elif isinstance(input_val, (list, tuple)):
|
||||
return list(input_val)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
|
||||
f"Got {input_val}."
|
||||
)
|
||||
|
||||
def _get_output_messages(
|
||||
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
||||
) -> List[BaseMessage]:
|
||||
from langchain_core.schema.messages import BaseMessage
|
||||
|
||||
if isinstance(output_val, dict):
|
||||
output_val = output_val[self.output_messages_key or "output"]
|
||||
|
||||
if isinstance(output_val, str):
|
||||
from langchain_core.schema.messages import AIMessage
|
||||
|
||||
return [AIMessage(content=output_val)]
|
||||
elif isinstance(output_val, BaseMessage):
|
||||
return [output_val]
|
||||
elif isinstance(output_val, (list, tuple)):
|
||||
return list(output_val)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
|
||||
hist = config["configurable"]["message_history"]
|
||||
# return only historic messages
|
||||
if self.history_messages_key:
|
||||
return hist.messages.copy()
|
||||
# return all messages
|
||||
else:
|
||||
input_val = (
|
||||
input if not self.input_messages_key else input[self.input_messages_key]
|
||||
)
|
||||
return hist.messages.copy() + self._get_input_messages(input_val)
|
||||
|
||||
async def _aenter_history(
|
||||
self, input: Dict[str, Any], config: RunnableConfig
|
||||
) -> List[BaseMessage]:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self._enter_history, input, config
|
||||
)
|
||||
|
||||
def _exit_history(self, run: Run, config: RunnableConfig) -> None:
|
||||
hist = config["configurable"]["message_history"]
|
||||
|
||||
# Get the input messages
|
||||
inputs = load(run.inputs)
|
||||
input_val = inputs[self.input_messages_key or "input"]
|
||||
input_messages = self._get_input_messages(input_val)
|
||||
|
||||
# Get the output messages
|
||||
output_val = load(run.outputs)
|
||||
output_messages = self._get_output_messages(output_val)
|
||||
|
||||
for m in input_messages + output_messages:
|
||||
hist.add_message(m)
|
||||
|
||||
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
config = super()._merge_configs(*configs)
|
||||
# extract session_id
|
||||
if "session_id" not in config.get("configurable", {}):
|
||||
example_input = {self.input_messages_key: "foo"}
|
||||
example_config = {"configurable": {"session_id": "123"}}
|
||||
raise ValueError(
|
||||
"session_id_id is required."
|
||||
" Pass it in as part of the config argument to .invoke() or .stream()"
|
||||
f"\neg. chain.invoke({example_input}, {example_config})"
|
||||
)
|
||||
# attach message_history
|
||||
session_id = config["configurable"]["session_id"]
|
||||
config["configurable"]["message_history"] = self.get_session_history(session_id)
|
||||
return config
|
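For quick reference, here is a minimal sketch of the `get_session_history` contract enforced above, assuming a toy in-memory history and a trivial echo runnable (both illustrative, not part of this module):

from typing import Dict, List

from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.schema.chat_history import BaseChatMessageHistory
from langchain_core.schema.messages import BaseMessage


class InMemoryHistory(BaseChatMessageHistory):
    """Toy history backed by a plain list (illustrative assumption)."""

    def __init__(self) -> None:
        self.messages: List[BaseMessage] = []

    def add_message(self, message: BaseMessage) -> None:
        self.messages.append(message)

    def clear(self) -> None:
        self.messages = []


_store: Dict[str, InMemoryHistory] = {}


def get_session_history(session_id: str) -> InMemoryHistory:
    # One history object per session id, created lazily on first use.
    if session_id not in _store:
        _store[session_id] = InMemoryHistory()
    return _store[session_id]


# Wrap a trivial messages-in, message-out runnable; session_id must then be
# supplied via config, as _merge_configs above requires.
echo = RunnableLambda(lambda messages: messages[-1])
with_history = RunnableWithMessageHistory(echo, get_session_history)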
@ -0,0 +1,453 @@
"""Implementation of the RunnablePassthrough."""
from __future__ import annotations

import asyncio
import inspect
import threading
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Type,
    Union,
    cast,
)

from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.runnables.base import (
    Other,
    Runnable,
    RunnableParallel,
    RunnableSerializable,
)
from langchain_core.runnables.config import (
    RunnableConfig,
    acall_func_with_variable_args,
    call_func_with_variable_args,
    get_executor_for_config,
)
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.iter import safetee


def identity(x: Other) -> Other:
    """An identity function."""
    return x


async def aidentity(x: Other) -> Other:
    """An async identity function."""
    return x


class RunnablePassthrough(RunnableSerializable[Other, Other]):
    """A runnable that passes inputs through unchanged or with additional keys.

    This runnable behaves almost like the identity function, except that it
    can be configured to add additional keys to the output, if the input is a
    dict.

    The examples below demonstrate how this runnable works using a few simple
    chains. The chains rely on simple lambdas to make the examples easy to execute
    and experiment with.

    Examples:

        .. code-block:: python

            from langchain_core.runnables import RunnablePassthrough, RunnableParallel

            runnable = RunnableParallel(
                origin=RunnablePassthrough(),
                modified=lambda x: x+1
            )

            runnable.invoke(1) # {'origin': 1, 'modified': 2}


            def fake_llm(prompt: str) -> str: # Fake LLM for the example
                return "completion"

            chain = RunnableLambda(fake_llm) | {
                'original': RunnablePassthrough(), # Original LLM output
                'parsed': lambda text: text[::-1] # Parsing logic
            }

            chain.invoke('hello') # {'original': 'completion', 'parsed': 'noitelpmoc'}

    In some cases, it may be useful to pass the input through while adding some
    keys to the output. In this case, you can use the `assign` method:

        .. code-block:: python

            from langchain_core.runnables import RunnablePassthrough, RunnableParallel

            def fake_llm(prompt: str) -> str: # Fake LLM for the example
                return "completion"

            runnable = {
                'llm1': fake_llm,
                'llm2': fake_llm,
            } | RunnablePassthrough.assign(
                total_chars=lambda inputs: len(inputs['llm1'] + inputs['llm2'])
            )

            runnable.invoke('hello')
            # {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
    """

    input_type: Optional[Type[Other]] = None

    func: Optional[
        Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]]
    ] = None

    afunc: Optional[
        Union[
            Callable[[Other], Awaitable[None]],
            Callable[[Other, RunnableConfig], Awaitable[None]],
        ]
    ] = None

    def __init__(
        self,
        func: Optional[
            Union[
                Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]],
                Union[
                    Callable[[Other], Awaitable[None]],
                    Callable[[Other, RunnableConfig], Awaitable[None]],
                ],
            ]
        ] = None,
        afunc: Optional[
            Union[
                Callable[[Other], Awaitable[None]],
                Callable[[Other, RunnableConfig], Awaitable[None]],
            ]
        ] = None,
        *,
        input_type: Optional[Type[Other]] = None,
        **kwargs: Any,
    ) -> None:
        if inspect.iscoroutinefunction(func):
            afunc = func
            func = None

        super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        return cls.__module__.split(".")[:-1]

    @property
    def InputType(self) -> Any:
        return self.input_type or Any

    @property
    def OutputType(self) -> Any:
        return self.input_type or Any

    @classmethod
    def assign(
        cls,
        **kwargs: Union[
            Runnable[Dict[str, Any], Any],
            Callable[[Dict[str, Any]], Any],
            Mapping[
                str,
                Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
            ],
        ],
    ) -> RunnableAssign:
        """Merge the dict input with the output produced by the mapping argument.

        Args:
            **kwargs: A mapping from keys to runnables or callables.

        Returns:
            A runnable that merges the dict input with the output produced by the
            mapping argument.
        """
        return RunnableAssign(RunnableParallel(kwargs))

    def invoke(
        self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Other:
        if self.func is not None:
            call_func_with_variable_args(self.func, input, config or {}, **kwargs)
        return self._call_with_config(identity, input, config)

    async def ainvoke(
        self,
        input: Other,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Other:
        if self.afunc is not None:
            await acall_func_with_variable_args(
                self.afunc, input, config or {}, **kwargs
            )
        elif self.func is not None:
            call_func_with_variable_args(self.func, input, config or {}, **kwargs)
        return await self._acall_with_config(aidentity, input, config)

    def transform(
        self,
        input: Iterator[Other],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Other]:
        if self.func is None:
            for chunk in self._transform_stream_with_config(input, identity, config):
                yield chunk
        else:
            final = None

            for chunk in self._transform_stream_with_config(input, identity, config):
                yield chunk
                if final is None:
                    final = chunk
                else:
                    final = final + chunk

            if final is not None:
                call_func_with_variable_args(self.func, final, config or {}, **kwargs)

    async def atransform(
        self,
        input: AsyncIterator[Other],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Other]:
        if self.afunc is None and self.func is None:
            async for chunk in self._atransform_stream_with_config(
                input, identity, config
            ):
                yield chunk
        else:
            final = None

            async for chunk in self._atransform_stream_with_config(
                input, identity, config
            ):
                yield chunk
                if final is None:
                    final = chunk
                else:
                    final = final + chunk

            if final is not None:
                config = config or {}
                if self.afunc is not None:
                    await acall_func_with_variable_args(
                        self.afunc, final, config, **kwargs
                    )
                elif self.func is not None:
                    call_func_with_variable_args(self.func, final, config, **kwargs)

    def stream(
        self,
        input: Other,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Other]:
        return self.transform(iter([input]), config, **kwargs)

    async def astream(
        self,
        input: Other,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Other]:
        async def input_aiter() -> AsyncIterator[Other]:
            yield input

        async for chunk in self.atransform(input_aiter(), config, **kwargs):
            yield chunk


class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
    """A runnable that assigns key-value pairs to Dict[str, Any] inputs."""

    mapper: RunnableParallel[Dict[str, Any]]

    def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None:
        super().__init__(mapper=mapper, **kwargs)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        return cls.__module__.split(".")[:-1]

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        map_input_schema = self.mapper.get_input_schema(config)
        if not map_input_schema.__custom_root_type__:
            # ie. it's a dict
            return map_input_schema

        return super().get_input_schema(config)

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        map_input_schema = self.mapper.get_input_schema(config)
        map_output_schema = self.mapper.get_output_schema(config)
        if (
            not map_input_schema.__custom_root_type__
            and not map_output_schema.__custom_root_type__
        ):
            # ie. both are dicts
            return create_model(  # type: ignore[call-overload]
                "RunnableAssignOutput",
                **{
                    k: (v.type_, v.default)
                    for s in (map_input_schema, map_output_schema)
                    for k, v in s.__fields__.items()
                },
            )
        elif not map_output_schema.__custom_root_type__:
            # ie. only map output is a dict
            # ie. input type is either unknown or inferred incorrectly
            return map_output_schema

        return super().get_output_schema(config)

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        return self.mapper.config_specs

    def invoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        assert isinstance(
            input, dict
        ), "The input to RunnablePassthrough.assign() must be a dict."
        return {
            **input,
            **self.mapper.invoke(input, config, **kwargs),
        }

    async def ainvoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        assert isinstance(
            input, dict
        ), "The input to RunnablePassthrough.assign() must be a dict."
        return {
            **input,
            **await self.mapper.ainvoke(input, config, **kwargs),
        }

    def transform(
        self,
        input: Iterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        # collect mapper keys
        mapper_keys = set(self.mapper.steps.keys())
        # create two streams, one for the map and one for the passthrough
        for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
        # create map output stream
        map_output = self.mapper.transform(for_map, config, **kwargs)
        # get executor to start map output stream in background
        with get_executor_for_config(config or {}) as executor:
            # start map output stream
            first_map_chunk_future = executor.submit(
                next,
                map_output,  # type: ignore
                None,
            )
            # consume passthrough stream
            for chunk in for_passthrough:
                assert isinstance(
                    chunk, dict
                ), "The input to RunnablePassthrough.assign() must be a dict."
                # remove mapper keys from passthrough chunk, to be overwritten by map
                filtered = AddableDict(
                    {k: v for k, v in chunk.items() if k not in mapper_keys}
                )
                if filtered:
                    yield filtered
            # yield map output
            yield cast(Dict[str, Any], first_map_chunk_future.result())
            for chunk in map_output:
                yield chunk

    async def atransform(
        self,
        input: AsyncIterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        # collect mapper keys
        mapper_keys = set(self.mapper.steps.keys())
        # create two streams, one for the map and one for the passthrough
        for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
        # create map output stream
        map_output = self.mapper.atransform(for_map, config, **kwargs)
        # start map output stream
        first_map_chunk_task: asyncio.Task = asyncio.create_task(
            py_anext(map_output, None),  # type: ignore[arg-type]
        )
        # consume passthrough stream
        async for chunk in for_passthrough:
            assert isinstance(
                chunk, dict
            ), "The input to RunnablePassthrough.assign() must be a dict."
            # remove mapper keys from passthrough chunk, to be overwritten by map output
            filtered = AddableDict(
                {k: v for k, v in chunk.items() if k not in mapper_keys}
            )
            if filtered:
                yield filtered
        # yield map output
        yield await first_map_chunk_task
        async for chunk in map_output:
            yield chunk

    def stream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        return self.transform(iter([input]), config, **kwargs)

    async def astream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
            yield input

        async for chunk in self.atransform(input_aiter(), config, **kwargs):
            yield chunk
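A minimal sketch of the assign semantics defined above (the lambda and key names are illustrative): the input dict passes through unchanged and the assigned key is merged on top.

from langchain_core.runnables.passthrough import RunnablePassthrough

# RunnablePassthrough.assign returns a RunnableAssign that merges the mapper
# output into the input dict.
chain = RunnablePassthrough.assign(doubled=lambda d: d["value"] * 2)
assert chain.invoke({"value": 21}) == {"value": 21, "doubled": 42}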
@ -0,0 +1,337 @@
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

from tenacity import (
    AsyncRetrying,
    RetryCallState,
    RetryError,
    Retrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

from langchain_core.runnables.base import Input, Output, RunnableBindingBase
from langchain_core.runnables.config import RunnableConfig, patch_config

if TYPE_CHECKING:
    from langchain_core.callbacks.manager import (
        AsyncCallbackManagerForChainRun,
        CallbackManagerForChainRun,
    )

    T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
U = TypeVar("U")


class RunnableRetry(RunnableBindingBase[Input, Output]):
    """Retry a Runnable if it fails.

    A RunnableRetry can be used to add retry logic to any object
    that subclasses the base Runnable.

    Such retries are especially useful for network calls that may fail
    due to transient errors.

    The RunnableRetry is implemented as a RunnableBinding. The easiest
    way to use it is through the `.with_retry()` method on all Runnables.

    Example:

        Here's an example that uses a RunnableLambda to raise an exception.

        .. code-block:: python

            import time

            def foo(input) -> None:
                '''Fake function that raises an exception.'''
                raise ValueError(f"Invoking foo failed. At time {time.time()}")

            runnable = RunnableLambda(foo)

            runnable_with_retries = runnable.with_retry(
                retry_exception_types=(ValueError,), # Retry only on ValueError
                wait_exponential_jitter=True, # Add jitter to the exponential backoff
                max_attempt_number=2, # Try twice
            )

            # The method invocation above is equivalent to the longer form below:

            runnable_with_retries = RunnableRetry(
                bound=runnable,
                retry_exception_types=(ValueError,),
                max_attempt_number=2,
                wait_exponential_jitter=True
            )

    This logic can be used to retry any Runnable, including a chain of Runnables,
    but in general it's best practice to keep the scope of the retry as small as
    possible. For example, if you have a chain of Runnables, you should only retry
    the Runnable that is likely to fail, not the entire chain.

    Example:

        .. code-block:: python

            from langchain_core.chat_models import ChatOpenAI
            from langchain_core.prompts import PromptTemplate

            template = PromptTemplate.from_template("tell me a joke about {topic}.")
            model = ChatOpenAI(temperature=0.5)

            # Good
            chain = template | model.with_retry()

            # Bad
            chain = template | model
            retryable_chain = chain.with_retry()
    """

    retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,)
    """The exception types to retry on. By default all exceptions are retried.

    In general you should only retry on exceptions that are likely to be
    transient, such as network errors.

    Good exceptions to retry are all server errors (5xx) and selected client
    errors (4xx) such as 429 Too Many Requests.
    """

    wait_exponential_jitter: bool = True
    """Whether to add jitter to the exponential backoff."""

    max_attempt_number: int = 3
    """The maximum number of attempts to retry the runnable."""

    @property
    def _kwargs_retrying(self) -> Dict[str, Any]:
        kwargs: Dict[str, Any] = dict()

        if self.max_attempt_number:
            kwargs["stop"] = stop_after_attempt(self.max_attempt_number)

        if self.wait_exponential_jitter:
            kwargs["wait"] = wait_exponential_jitter()

        if self.retry_exception_types:
            kwargs["retry"] = retry_if_exception_type(self.retry_exception_types)

        return kwargs

    def _sync_retrying(self, **kwargs: Any) -> Retrying:
        return Retrying(**self._kwargs_retrying, **kwargs)

    def _async_retrying(self, **kwargs: Any) -> AsyncRetrying:
        return AsyncRetrying(**self._kwargs_retrying, **kwargs)

    def _patch_config(
        self,
        config: RunnableConfig,
        run_manager: "T",
        retry_state: RetryCallState,
    ) -> RunnableConfig:
        attempt = retry_state.attempt_number
        tag = "retry:attempt:{}".format(attempt) if attempt > 1 else None
        return patch_config(config, callbacks=run_manager.get_child(tag))

    def _patch_config_list(
        self,
        config: List[RunnableConfig],
        run_manager: List["T"],
        retry_state: RetryCallState,
    ) -> List[RunnableConfig]:
        return [
            self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
        ]

    def _invoke(
        self,
        input: Input,
        run_manager: "CallbackManagerForChainRun",
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output:
        for attempt in self._sync_retrying(reraise=True):
            with attempt:
                result = super().invoke(
                    input,
                    self._patch_config(config, run_manager, attempt.retry_state),
                    **kwargs,
                )
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
        return result

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        return self._call_with_config(self._invoke, input, config, **kwargs)

    async def _ainvoke(
        self,
        input: Input,
        run_manager: "AsyncCallbackManagerForChainRun",
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output:
        async for attempt in self._async_retrying(reraise=True):
            with attempt:
                result = await super().ainvoke(
                    input,
                    self._patch_config(config, run_manager, attempt.retry_state),
                    **kwargs,
                )
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
        return result

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        return await self._acall_with_config(self._ainvoke, input, config, **kwargs)

    def _batch(
        self,
        inputs: List[Input],
        run_manager: List["CallbackManagerForChainRun"],
        config: List[RunnableConfig],
        **kwargs: Any,
    ) -> List[Union[Output, Exception]]:
        results_map: Dict[int, Output] = {}

        def pending(iterable: List[U]) -> List[U]:
            return [item for idx, item in enumerate(iterable) if idx not in results_map]

        try:
            for attempt in self._sync_retrying():
                with attempt:
                    # Get the results of the inputs that have not succeeded yet.
                    result = super().batch(
                        pending(inputs),
                        self._patch_config_list(
                            pending(config), pending(run_manager), attempt.retry_state
                        ),
                        return_exceptions=True,
                        **kwargs,
                    )
                    # Register the results of the inputs that have succeeded.
                    first_exception = None
                    for i, r in enumerate(result):
                        if isinstance(r, Exception):
                            if not first_exception:
                                first_exception = r
                            continue
                        results_map[i] = r
                    # If any exception occurred, raise it, to retry the failed ones
                    if first_exception:
                        raise first_exception
                if (
                    attempt.retry_state.outcome
                    and not attempt.retry_state.outcome.failed
                ):
                    attempt.retry_state.set_result(result)
        except RetryError as e:
            try:
                result
            except UnboundLocalError:
                result = cast(List[Output], [e] * len(inputs))

        outputs: List[Union[Output, Exception]] = []
        for idx, _ in enumerate(inputs):
            if idx in results_map:
                outputs.append(results_map[idx])
            else:
                outputs.append(result.pop(0))
        return outputs

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> List[Output]:
        return self._batch_with_config(
            self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
        )

    async def _abatch(
        self,
        inputs: List[Input],
        run_manager: List["AsyncCallbackManagerForChainRun"],
        config: List[RunnableConfig],
        **kwargs: Any,
    ) -> List[Union[Output, Exception]]:
        results_map: Dict[int, Output] = {}

        def pending(iterable: List[U]) -> List[U]:
            return [item for idx, item in enumerate(iterable) if idx not in results_map]

        try:
            async for attempt in self._async_retrying():
                with attempt:
                    # Get the results of the inputs that have not succeeded yet.
                    result = await super().abatch(
                        pending(inputs),
                        self._patch_config_list(
                            pending(config), pending(run_manager), attempt.retry_state
                        ),
                        return_exceptions=True,
                        **kwargs,
                    )
                    # Register the results of the inputs that have succeeded.
                    first_exception = None
                    for i, r in enumerate(result):
                        if isinstance(r, Exception):
                            if not first_exception:
                                first_exception = r
                            continue
                        results_map[i] = r
                    # If any exception occurred, raise it, to retry the failed ones
                    if first_exception:
                        raise first_exception
                if (
                    attempt.retry_state.outcome
                    and not attempt.retry_state.outcome.failed
                ):
                    attempt.retry_state.set_result(result)
        except RetryError as e:
            try:
                result
            except UnboundLocalError:
                result = cast(List[Output], [e] * len(inputs))

        outputs: List[Union[Output, Exception]] = []
        for idx, _ in enumerate(inputs):
            if idx in results_map:
                outputs.append(results_map[idx])
            else:
                outputs.append(result.pop(0))
        return outputs

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> List[Output]:
        return await self._abatch_with_config(
            self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
        )

    # stream() and transform() are not retried because retrying a stream
    # is not very intuitive.
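A minimal sketch of the retry behavior, using the longer constructor form shown in the docstring above and assuming a deliberately flaky function (illustrative, not part of this module):

from langchain_core.runnables.base import RunnableLambda

attempts = {"n": 0}

def flaky(x: int) -> int:
    # Fails on the first call, succeeds afterwards.
    attempts["n"] += 1
    if attempts["n"] < 2:
        raise ValueError("transient failure")
    return x * 2

retrying = RunnableRetry(
    bound=RunnableLambda(flaky),
    retry_exception_types=(ValueError,),
    max_attempt_number=3,
)
# The first attempt raises, the second succeeds, so the caller sees a result.
assert retrying.invoke(21) == 42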
@ -0,0 +1,206 @@
from __future__ import annotations

from typing import (
    Any,
    AsyncIterator,
    Callable,
    Iterator,
    List,
    Mapping,
    Optional,
    Union,
    cast,
)

from typing_extensions import TypedDict

from langchain_core.runnables.base import (
    Input,
    Output,
    Runnable,
    RunnableSerializable,
    coerce_to_runnable,
)
from langchain_core.runnables.config import (
    RunnableConfig,
    get_config_list,
    get_executor_for_config,
)
from langchain_core.runnables.utils import (
    ConfigurableFieldSpec,
    gather_with_concurrency,
    get_unique_config_specs,
)


class RouterInput(TypedDict):
    """A Router input.

    Attributes:
        key: The key to route on.
        input: The input to pass to the selected runnable.
    """

    key: str
    input: Any


class RouterRunnable(RunnableSerializable[RouterInput, Output]):
    """A runnable that routes to a set of runnables based on Input['key'].

    Returns the output of the selected runnable.
    """

    runnables: Mapping[str, Runnable[Any, Output]]

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        return get_unique_config_specs(
            spec for step in self.runnables.values() for spec in step.config_specs
        )

    def __init__(
        self,
        runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
    ) -> None:
        super().__init__(
            runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
        )

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        return cls.__module__.split(".")[:-1]

    def invoke(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Output:
        key = input["key"]
        actual_input = input["input"]
        if key not in self.runnables:
            raise ValueError(f"No runnable associated with key '{key}'")

        runnable = self.runnables[key]
        return runnable.invoke(actual_input, config)

    async def ainvoke(
        self,
        input: RouterInput,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Output:
        key = input["key"]
        actual_input = input["input"]
        if key not in self.runnables:
            raise ValueError(f"No runnable associated with key '{key}'")

        runnable = self.runnables[key]
        return await runnable.ainvoke(actual_input, config)

    def batch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        if not inputs:
            return []

        keys = [input["key"] for input in inputs]
        actual_inputs = [input["input"] for input in inputs]
        if any(key not in self.runnables for key in keys):
            raise ValueError("One or more keys do not have a corresponding runnable")

        def invoke(
            runnable: Runnable, input: Input, config: RunnableConfig
        ) -> Union[Output, Exception]:
            if return_exceptions:
                try:
                    return runnable.invoke(input, config, **kwargs)
                except Exception as e:
                    return e
            else:
                return runnable.invoke(input, config, **kwargs)

        runnables = [self.runnables[key] for key in keys]
        configs = get_config_list(config, len(inputs))
        with get_executor_for_config(configs[0]) as executor:
            return cast(
                List[Output],
                list(executor.map(invoke, runnables, actual_inputs, configs)),
            )

    async def abatch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        if not inputs:
            return []

        keys = [input["key"] for input in inputs]
        actual_inputs = [input["input"] for input in inputs]
        if any(key not in self.runnables for key in keys):
            raise ValueError("One or more keys do not have a corresponding runnable")

        async def ainvoke(
            runnable: Runnable, input: Input, config: RunnableConfig
        ) -> Union[Output, Exception]:
            if return_exceptions:
                try:
                    return await runnable.ainvoke(input, config, **kwargs)
                except Exception as e:
                    return e
            else:
                return await runnable.ainvoke(input, config, **kwargs)

        runnables = [self.runnables[key] for key in keys]
        configs = get_config_list(config, len(inputs))
        return await gather_with_concurrency(
            configs[0].get("max_concurrency"),
            *(
                ainvoke(runnable, input, config)
                for runnable, input, config in zip(runnables, actual_inputs, configs)
            ),
        )

    def stream(
        self,
        input: RouterInput,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[Output]:
        key = input["key"]
        actual_input = input["input"]
        if key not in self.runnables:
            raise ValueError(f"No runnable associated with key '{key}'")

        runnable = self.runnables[key]
        yield from runnable.stream(actual_input, config)

    async def astream(
        self,
        input: RouterInput,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[Output]:
        key = input["key"]
        actual_input = input["input"]
        if key not in self.runnables:
            raise ValueError(f"No runnable associated with key '{key}'")

        runnable = self.runnables[key]
        async for output in runnable.astream(actual_input, config):
            yield output
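A minimal sketch of the routing contract above: the "key" entry selects a runnable (plain callables are coerced via coerce_to_runnable in __init__) and "input" is forwarded to it. The key names and lambdas are illustrative.

router = RouterRunnable(
    runnables={
        "double": lambda x: x * 2,
        "negate": lambda x: -x,
    }
)
# The RouterInput dict carries both the routing key and the payload.
assert router.invoke({"key": "double", "input": 3}) == 6
assert router.invoke({"key": "negate", "input": 3}) == -3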
@ -0,0 +1,327 @@
from __future__ import annotations

import ast
import asyncio
import inspect
import textwrap
from inspect import signature
from itertools import groupby
from typing import (
    Any,
    AsyncIterable,
    Callable,
    Coroutine,
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Protocol,
    Sequence,
    Set,
    TypeVar,
    Union,
)

Input = TypeVar("Input", contravariant=True)
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output", covariant=True)


async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
    """Run a coroutine with a semaphore.

    Args:
        semaphore: The semaphore to use.
        coro: The coroutine to run.

    Returns:
        The result of the coroutine.
    """
    async with semaphore:
        return await coro


async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
    """Gather coroutines with a limit on the number of concurrent coroutines."""
    if n is None:
        return await asyncio.gather(*coros)

    semaphore = asyncio.Semaphore(n)

    return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))


def accepts_run_manager(callable: Callable[..., Any]) -> bool:
    """Check if a callable accepts a run_manager argument."""
    try:
        return signature(callable).parameters.get("run_manager") is not None
    except ValueError:
        return False


def accepts_config(callable: Callable[..., Any]) -> bool:
    """Check if a callable accepts a config argument."""
    try:
        return signature(callable).parameters.get("config") is not None
    except ValueError:
        return False


class IsLocalDict(ast.NodeVisitor):
    """Check if a name is a local dict."""

    def __init__(self, name: str, keys: Set[str]) -> None:
        self.name = name
        self.keys = keys

    def visit_Subscript(self, node: ast.Subscript) -> Any:
        if (
            isinstance(node.ctx, ast.Load)
            and isinstance(node.value, ast.Name)
            and node.value.id == self.name
            and isinstance(node.slice, ast.Constant)
            and isinstance(node.slice.value, str)
        ):
            # we've found a subscript access on the name we're looking for
            self.keys.add(node.slice.value)

    def visit_Call(self, node: ast.Call) -> Any:
        if (
            isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == self.name
            and node.func.attr == "get"
            and len(node.args) in (1, 2)
            and isinstance(node.args[0], ast.Constant)
            and isinstance(node.args[0].value, str)
        ):
            # we've found a .get() call on the name we're looking for
            self.keys.add(node.args[0].value)


class IsFunctionArgDict(ast.NodeVisitor):
    """Check if the first argument of a function is a dict."""

    def __init__(self) -> None:
        self.keys: Set[str] = set()

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        input_arg_name = node.args.args[0].arg
        IsLocalDict(input_arg_name, self.keys).visit(node.body)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        input_arg_name = node.args.args[0].arg
        IsLocalDict(input_arg_name, self.keys).visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        input_arg_name = node.args.args[0].arg
        IsLocalDict(input_arg_name, self.keys).visit(node)


class GetLambdaSource(ast.NodeVisitor):
    """Get the source code of a lambda function."""

    def __init__(self) -> None:
        """Initialize the visitor."""
        self.source: Optional[str] = None
        self.count = 0

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Visit a lambda function."""
        self.count += 1
        if hasattr(ast, "unparse"):
            self.source = ast.unparse(node)


def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
    """Get the keys of the first argument of a function if it is a dict."""
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = IsFunctionArgDict()
        visitor.visit(tree)
        return list(visitor.keys) if visitor.keys else None
    except (SyntaxError, TypeError, OSError):
        return None


def get_lambda_source(func: Callable) -> Optional[str]:
    """Get the source code of a lambda function.

    Args:
        func: a callable that can be a lambda function

    Returns:
        str: the source code of the lambda function
    """
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = GetLambdaSource()
        visitor.visit(tree)
        return visitor.source if visitor.count == 1 else None
    except (SyntaxError, TypeError, OSError):
        return None


def indent_lines_after_first(text: str, prefix: str) -> str:
    """Indent all lines of text after the first line.

    Args:
        text: The text to indent
        prefix: Used to determine the number of spaces to indent

    Returns:
        str: The indented text
    """
    n_spaces = len(prefix)
    spaces = " " * n_spaces
    lines = text.splitlines()
    return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])


class AddableDict(Dict[str, Any]):
    """Dictionary that can be added to another dictionary."""

    def __add__(self, other: AddableDict) -> AddableDict:
        chunk = AddableDict(self)
        for key in other:
            if key not in chunk or chunk[key] is None:
                chunk[key] = other[key]
            elif other[key] is not None:
                try:
                    added = chunk[key] + other[key]
                except TypeError:
                    added = other[key]
                chunk[key] = added
        return chunk

    def __radd__(self, other: AddableDict) -> AddableDict:
        chunk = AddableDict(other)
        for key in self:
            if key not in chunk or chunk[key] is None:
                chunk[key] = self[key]
            elif self[key] is not None:
                try:
                    added = chunk[key] + self[key]
                except TypeError:
                    added = self[key]
                chunk[key] = added
        return chunk


_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)


class SupportsAdd(Protocol[_T_contra, _T_co]):
    """Protocol for objects that support addition."""

    def __add__(self, __x: _T_contra) -> _T_co:
        ...


Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])


def add(addables: Iterable[Addable]) -> Optional[Addable]:
    """Add a sequence of addable objects together."""
    final = None
    for chunk in addables:
        if final is None:
            final = chunk
        else:
            final = final + chunk
    return final


async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
    """Asynchronously add a sequence of addable objects together."""
    final = None
    async for chunk in addables:
        if final is None:
            final = chunk
        else:
            final = final + chunk
    return final


class ConfigurableField(NamedTuple):
    """A field that can be configured by the user."""

    id: str

    name: Optional[str] = None
    description: Optional[str] = None
    annotation: Optional[Any] = None

    def __hash__(self) -> int:
        return hash((self.id, self.annotation))


class ConfigurableFieldSingleOption(NamedTuple):
    """A field that can be configured by the user with a default value."""

    id: str
    options: Mapping[str, Any]
    default: str

    name: Optional[str] = None
    description: Optional[str] = None

    def __hash__(self) -> int:
        return hash((self.id, tuple(self.options.keys()), self.default))


class ConfigurableFieldMultiOption(NamedTuple):
    """A field that can be configured by the user with multiple default values."""

    id: str
    options: Mapping[str, Any]
    default: Sequence[str]

    name: Optional[str] = None
    description: Optional[str] = None

    def __hash__(self) -> int:
        return hash((self.id, tuple(self.options.keys()), tuple(self.default)))


AnyConfigurableField = Union[
    ConfigurableField, ConfigurableFieldSingleOption, ConfigurableFieldMultiOption
]


class ConfigurableFieldSpec(NamedTuple):
    """A field that can be configured by the user. It is a specification of a field."""

    id: str
    name: Optional[str]
    description: Optional[str]

    default: Any
    annotation: Any


def get_unique_config_specs(
    specs: Iterable[ConfigurableFieldSpec],
) -> List[ConfigurableFieldSpec]:
    """Get the unique config specs from a sequence of config specs."""
    grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
    unique: List[ConfigurableFieldSpec] = []
    for id, dupes in grouped:
        first = next(dupes)
        others = list(dupes)
        if len(others) == 0:
            unique.append(first)
        elif all(o == first for o in others):
            unique.append(first)
        else:
            raise ValueError(
                "RunnableSequence contains conflicting config specs "
                f"for {id}: {[first] + others}"
            )
    return unique
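A minimal sketch of the AddableDict merge semantics implemented above, which is how streamed dict chunks are summed: values supporting "+" are concatenated, and a None value is replaced by the other side. The keys and values are illustrative.

chunks = [
    AddableDict({"text": "Hello, ", "done": None}),
    AddableDict({"text": "world", "done": True}),
]
# add() folds the chunks with "+", so strings concatenate and None yields
# to the non-None value.
merged = add(chunks)
assert merged == {"text": "Hello, world", "done": True}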
@ -0,0 +1,78 @@
"""**Schemas** are the LangChain Base Classes and Interfaces."""
from langchain_core.schema.agent import AgentAction, AgentFinish
from langchain_core.schema.cache import BaseCache
from langchain_core.schema.chat_history import BaseChatMessageHistory
from langchain_core.schema.document import BaseDocumentTransformer, Document
from langchain_core.schema.exceptions import LangChainException
from langchain_core.schema.memory import BaseMemory
from langchain_core.schema.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    _message_from_dict,
    _message_to_dict,
    get_buffer_string,
    messages_from_dict,
    messages_to_dict,
)
from langchain_core.schema.output import (
    ChatGeneration,
    ChatResult,
    Generation,
    LLMResult,
    RunInfo,
)
from langchain_core.schema.output_parser import (
    BaseLLMOutputParser,
    BaseOutputParser,
    OutputParserException,
    StrOutputParser,
)
from langchain_core.schema.prompt import PromptValue
from langchain_core.schema.prompt_template import BasePromptTemplate, format_document
from langchain_core.schema.retriever import BaseRetriever
from langchain_core.schema.storage import BaseStore

RUN_KEY = "__run"
Memory = BaseMemory

__all__ = [
    "BaseCache",
    "BaseMemory",
    "BaseStore",
    "AgentFinish",
    "AgentAction",
    "Document",
    "BaseChatMessageHistory",
    "BaseDocumentTransformer",
    "BaseMessage",
    "ChatMessage",
    "FunctionMessage",
    "HumanMessage",
    "AIMessage",
    "SystemMessage",
    "messages_from_dict",
    "messages_to_dict",
    "_message_to_dict",
    "_message_from_dict",
    "get_buffer_string",
    "RunInfo",
    "LLMResult",
    "ChatResult",
    "ChatGeneration",
    "Generation",
    "PromptValue",
    "LangChainException",
    "BaseRetriever",
    "RUN_KEY",
    "Memory",
    "OutputParserException",
    "StrOutputParser",
    "BaseOutputParser",
    "BaseLLMOutputParser",
    "BasePromptTemplate",
    "format_document",
]
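The flat re-export namespace above means downstream code can pull the core interfaces from one place; a brief illustrative usage:

# Any name listed in __all__ above can be imported directly from
# langchain_core.schema.
from langchain_core.schema import Document, HumanMessage, StrOutputParser

doc = Document(page_content="hello", metadata={"source": "example"})
msg = HumanMessage(content="hello")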
@ -0,0 +1,74 @@
from __future__ import annotations

from typing import Any, Literal, Sequence, Union

from langchain_core.load.serializable import Serializable
from langchain_core.schema.messages import BaseMessage


class AgentAction(Serializable):
    """A full description of an action for an ActionAgent to execute."""

    tool: str
    """The name of the Tool to execute."""
    tool_input: Union[str, dict]
    """The input to pass in to the Tool."""
    log: str
    """Additional information to log about the action.
    This log can be used in a few ways. First, it can be used to audit
    what exactly the LLM predicted to lead to this (tool, tool_input).
    Second, it can be used in future iterations to show the LLM's prior
    thoughts. This is useful when (tool, tool_input) does not contain
    full information about the LLM prediction (for example, any `thought`
    before the tool/tool_input)."""
    type: Literal["AgentAction"] = "AgentAction"

    def __init__(
        self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
    ):
        """Override init to support instantiation by position for backward compat."""
        super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether or not the class is serializable."""
        return True


class AgentActionMessageLog(AgentAction):
    message_log: Sequence[BaseMessage]
    """Similar to log, this can be used to pass along extra
    information about what exact messages were predicted by the LLM
    before parsing out the (tool, tool_input). This is again useful
    if (tool, tool_input) cannot be used to fully recreate the LLM
    prediction, and you need that LLM prediction (for future agent iteration).
    Compared to `log`, this is useful when the underlying LLM is a
    ChatModel (and therefore returns messages rather than a string)."""
    # Ignoring type because we're overriding the type from AgentAction.
    # And this is the correct thing to do in this case.
    # The type literal is used for serialization purposes.
    type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog"  # type: ignore


class AgentFinish(Serializable):
    """The final return value of an ActionAgent."""

    return_values: dict
    """Dictionary of return values."""
    log: str
    """Additional information to log about the return value.
    This is used to pass along the full LLM prediction, not just the parsed out
    return value. For example, if the full LLM prediction was
    `Final Answer: 2` you may want to just return `2` as a return value, but pass
    along the full string as a `log` (for debugging or observability purposes).
    """
    type: Literal["AgentFinish"] = "AgentFinish"

    def __init__(self, return_values: dict, log: str, **kwargs: Any):
        """Override init to support instantiation by position for backward compat."""
        super().__init__(return_values=return_values, log=log, **kwargs)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether or not the class is serializable."""
        return True
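A minimal sketch of the positional instantiation the overridden __init__ methods above preserve for backward compatibility (the tool name, input, and log strings are illustrative):

# Positional arguments map to (tool, tool_input, log) and
# (return_values, log) respectively.
action = AgentAction("search", {"query": "weather"}, "Thought: look it up")
finish = AgentFinish({"output": "It is sunny."}, "Final Answer: It is sunny.")
assert action.tool == "search"
assert finish.return_values["output"] == "It is sunny."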
@ -0,0 +1,24 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence

from langchain_core.schema.output import Generation

RETURN_VAL_TYPE = Sequence[Generation]


class BaseCache(ABC):
    """Base interface for cache."""

    @abstractmethod
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""

    @abstractmethod
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""

    @abstractmethod
    def clear(self, **kwargs: Any) -> None:
        """Clear cache that can take additional keyword arguments."""
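A minimal in-memory sketch of the interface above; the class name is illustrative, and keys combine the prompt with the llm_string that identifies the model configuration:

from typing import Any, Dict, Optional, Tuple

from langchain_core.schema.cache import RETURN_VAL_TYPE, BaseCache


class InMemoryCacheSketch(BaseCache):
    """Toy cache keyed on (prompt, llm_string) pairs."""

    def __init__(self) -> None:
        self._data: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        return self._data.get((prompt, llm_string))

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        self._data[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._data = {}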
@ -0,0 +1,13 @@
from typing import Sequence, TypedDict

from langchain_core.schema import BaseMessage


class ChatSession(TypedDict, total=False):
    """Chat Session represents a single
    conversation, channel, or other group of messages."""

    messages: Sequence[BaseMessage]
    """The LangChain chat messages loaded from the source."""
    functions: Sequence[dict]
    """The function calling specs for the messages."""
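Because ChatSession is declared with total=False, either field may be omitted when constructing one; a brief illustrative usage:

from langchain_core.schema.messages import AIMessage, HumanMessage

# "functions" is optional and simply left out here.
session = ChatSession(
    messages=[HumanMessage(content="hi"), AIMessage(content="hello!")]
)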
@ -0,0 +1,67 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List

from langchain_core.schema.messages import AIMessage, BaseMessage, HumanMessage


class BaseChatMessageHistory(ABC):
    """Abstract base class for storing chat message history.

    See `ChatMessageHistory` for default implementation.

    Example:
        .. code-block:: python

            class FileChatMessageHistory(BaseChatMessageHistory):
                storage_path: str
                session_id: str

                @property
                def messages(self):
                    with open(
                        os.path.join(self.storage_path, self.session_id),
                        encoding="utf-8",
                    ) as f:
                        messages = json.loads(f.read())
                    return messages_from_dict(messages)

                def add_message(self, message: BaseMessage) -> None:
                    messages = messages_to_dict(self.messages)
                    messages.append(_message_to_dict(message))
                    with open(
                        os.path.join(self.storage_path, self.session_id), "w"
                    ) as f:
                        json.dump(messages, f)

                def clear(self):
                    with open(
                        os.path.join(self.storage_path, self.session_id), "w"
                    ) as f:
                        f.write("[]")
    """

    messages: List[BaseMessage]
    """A list of Messages stored in-memory."""

    def add_user_message(self, message: str) -> None:
        """Convenience method for adding a human message string to the store.

        Args:
            message: The string contents of a human message.
        """
        self.add_message(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        """Convenience method for adding an AI message string to the store.

        Args:
            message: The string contents of an AI message.
        """
        self.add_message(AIMessage(content=message))

    @abstractmethod
    def add_message(self, message: BaseMessage) -> None:
        """Add a Message object to the store.

        Args:
            message: A BaseMessage object to store.
        """
        raise NotImplementedError()

    @abstractmethod
    def clear(self) -> None:
        """Remove all messages from the store."""
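
# A minimal in-memory subclass sketch of the ABC above (class name is
# illustrative, not something this diff defines).
class InMemoryChatMessageHistorySketch(BaseChatMessageHistory):
    def __init__(self) -> None:
        self.messages: List[BaseMessage] = []

    def add_message(self, message: BaseMessage) -> None:
        self.messages.append(message)

    def clear(self) -> None:
        self.messages = []


# history = InMemoryChatMessageHistorySketch()
# history.add_user_message("hi")    # wraps the string in a HumanMessage
# history.add_ai_message("hello")   # wraps the string in an AIMessage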
@ -0,0 +1,91 @@
from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Literal, Sequence

from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field


class Document(Serializable):
    """Class for storing a piece of text and associated metadata."""

    page_content: str
    """String text."""
    metadata: dict = Field(default_factory=dict)
    """Arbitrary metadata about the page content (e.g., source, relationships to other
    documents, etc.).
    """
    type: Literal["Document"] = "Document"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True


class BaseDocumentTransformer(ABC):
    """Abstract base class for document transformation systems.

    A document transformation system takes a sequence of Documents and returns a
    sequence of transformed Documents.

    Example:
        .. code-block:: python

            class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
                embeddings: Embeddings
                similarity_fn: Callable = cosine_similarity
                similarity_threshold: float = 0.95

                class Config:
                    arbitrary_types_allowed = True

                def transform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    stateful_documents = get_stateful_documents(documents)
                    embedded_documents = _get_embeddings_from_stateful_docs(
                        self.embeddings, stateful_documents
                    )
                    included_idxs = _filter_similar_embeddings(
                        embedded_documents, self.similarity_fn, self.similarity_threshold
                    )
                    return [stateful_documents[i] for i in sorted(included_idxs)]

                async def atransform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    raise NotImplementedError

    """  # noqa: E501

    @abstractmethod
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform a list of documents.

        Args:
            documents: A sequence of Documents to be transformed.

        Returns:
            A list of transformed Documents.
        """

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a list of documents.

        Args:
            documents: A sequence of Documents to be transformed.

        Returns:
            A list of transformed Documents.
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self.transform_documents, **kwargs), documents
        )
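
# A minimal transformer sketch against the ABC above: stamps a fixed tag into
# every document's metadata. Illustrative only; the async variant falls back
# to the thread-pool default defined above.
class MetadataTaggerSketch(BaseDocumentTransformer):
    def __init__(self, tag: str) -> None:
        self.tag = tag

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        return [
            Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, "tag": self.tag},
            )
            for doc in documents
        ]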
@ -0,0 +1,27 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List


class Embeddings(ABC):
    """Interface for embedding models."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs."""
        return await asyncio.get_running_loop().run_in_executor(
            None, self.embed_documents, texts
        )

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text."""
        return await asyncio.get_running_loop().run_in_executor(
            None, self.embed_query, text
        )
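
# A toy implementation sketch of the Embeddings interface above: a
# deterministic "embedding" built from string length and hash, useful only
# for exercising code paths in tests. Name and behavior are illustrative.
class FakeEmbeddingsSketch(Embeddings):
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), float(hash(text) % 100)]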
@ -0,0 +1,2 @@
class LangChainException(Exception):
    """General LangChain exception."""
@ -0,0 +1,291 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from functools import lru_cache
from typing import (
    TYPE_CHECKING,
    Any,
    List,
    Optional,
    Sequence,
    Set,
    TypeVar,
    Union,
)

from typing_extensions import TypeAlias

from langchain_core.runnables import RunnableSerializable
from langchain_core.schema.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain_core.schema.output import LLMResult
from langchain_core.schema.prompt import PromptValue
from langchain_core.utils import get_pydantic_field_names

if TYPE_CHECKING:
    from langchain_core.callbacks.manager import Callbacks


@lru_cache(maxsize=None)  # Cache the tokenizer
def get_tokenizer() -> Any:
    try:
        from transformers import GPT2TokenizerFast
    except ImportError:
        raise ImportError(
            "Could not import transformers python package. "
            "This is needed in order to calculate get_token_ids. "
            "Please install it with `pip install transformers`."
        )
    # create a GPT-2 tokenizer instance
    return GPT2TokenizerFast.from_pretrained("gpt2")


def _get_token_ids_default_method(text: str) -> List[int]:
    """Encode the text into token IDs."""
    # get the cached tokenizer
    tokenizer = get_tokenizer()

    # tokenize the text using the GPT-2 tokenizer
    return tokenizer.encode(text)


LanguageModelInput = Union[PromptValue, str, List[BaseMessage]]
LanguageModelOutput = TypeVar("LanguageModelOutput")


class BaseLanguageModel(
    RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC
):
    """Abstract base class for interfacing with language models.

    All language model wrappers inherit from BaseLanguageModel.

    Exposes three main methods:
    - generate_prompt: generate language model outputs for a sequence of prompt
        values. A prompt value is a model input that can be converted to any language
        model input format (string or messages).
    - predict: pass in a single string to a language model and return a string
        prediction.
    - predict_messages: pass in a sequence of BaseMessages (corresponding to a single
        model call) to a language model and return a BaseMessage prediction.

    Each of these has an equivalent asynchronous method.
    """

    @property
    def InputType(self) -> TypeAlias:
        """Get the input type for this runnable."""
        from langchain_core.prompts.base import StringPromptValue
        from langchain_core.prompts.chat import ChatPromptValueConcrete

        # This is a version of LanguageModelInput which replaces the abstract
        # base class BaseMessage with a union of its subclasses, which makes
        # for a much better schema.
        return Union[
            str,
            Union[StringPromptValue, ChatPromptValueConcrete],
            List[AnyMessage],
        ]

    @abstractmethod
    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Pass a sequence of prompts to the model and return model generations.

        This method should make use of batched calls for models that expose a batched
        API.

        Use this method when you:
            1. want to take advantage of batched calls,
            2. need more output from the model than just the top generated value,
            3. are building chains that are agnostic to the underlying language model
                type (e.g., pure text completion models vs chat models).

        Args:
            prompts: List of PromptValues. A PromptValue is an object that can be
                converted to match the format of any language model (string for pure
                text generation models and BaseMessages for chat models).
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            callbacks: Callbacks to pass through. Used for executing additional
                functionality, such as logging or streaming, throughout generation.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            An LLMResult, which contains a list of candidate Generations for each input
                prompt and additional model provider-specific output.
        """

    @abstractmethod
    async def agenerate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Asynchronously pass a sequence of prompts and return model generations.

        This method should make use of batched calls for models that expose a batched
        API.

        Use this method when you:
            1. want to take advantage of batched calls,
            2. need more output from the model than just the top generated value,
            3. are building chains that are agnostic to the underlying language model
                type (e.g., pure text completion models vs chat models).

        Args:
            prompts: List of PromptValues. A PromptValue is an object that can be
                converted to match the format of any language model (string for pure
                text generation models and BaseMessages for chat models).
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            callbacks: Callbacks to pass through. Used for executing additional
                functionality, such as logging or streaming, throughout generation.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            An LLMResult, which contains a list of candidate Generations for each input
                prompt and additional model provider-specific output.
        """

    @abstractmethod
    def predict(
        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> str:
        """Pass a single string input to the model and return a string prediction.

        Use this method when passing in raw text. If you want to pass in specific
        types of chat messages, use predict_messages.

        Args:
            text: String input to pass to the model.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            Top model prediction as a string.
        """

    @abstractmethod
    def predict_messages(
        self,
        messages: List[BaseMessage],
        *,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """Pass a message sequence to the model and return a message prediction.

        Use this method when passing in chat messages. If you want to pass in raw text,
        use predict.

        Args:
            messages: A sequence of chat messages corresponding to a single model input.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            Top model prediction as a message.
        """

    @abstractmethod
    async def apredict(
        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> str:
        """Asynchronously pass a string to the model and return a string prediction.

        Use this method when calling pure text generation models and only the top
        candidate generation is needed.

        Args:
            text: String input to pass to the model.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            Top model prediction as a string.
        """

    @abstractmethod
    async def apredict_messages(
        self,
        messages: List[BaseMessage],
        *,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """Asynchronously pass messages to the model and return a message prediction.

        Use this method when calling chat models and only the top
        candidate generation is needed.

        Args:
            messages: A sequence of chat messages corresponding to a single model input.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            Top model prediction as a message.
        """

    def get_token_ids(self, text: str) -> List[int]:
        """Return the ordered ids of the tokens in a text.

        Args:
            text: The string input to tokenize.

        Returns:
            A list of ids corresponding to the tokens in the text, in the order they
                occur in the text.
        """
        return _get_token_ids_default_method(text)

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        return len(self.get_token_ids(text))

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Useful for checking if an input will fit in a model's context window.

        Args:
            messages: The message inputs to tokenize.

        Returns:
            The sum of the number of tokens across the messages.
        """
        return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])

    @classmethod
    def _all_required_field_names(cls) -> Set:
        """DEPRECATED: Kept for backwards compatibility.

        Use get_pydantic_field_names.
        """
        return get_pydantic_field_names(cls)
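
# Illustrative helper (not part of this file): checking whether an input fits
# a model's context window with the token-counting methods above. Note that
# the default tokenizer path requires the optional `transformers` package.
def fits_context_window(llm: BaseLanguageModel, text: str, max_tokens: int) -> bool:
    # get_num_tokens delegates to get_token_ids, which defaults to the cached
    # GPT-2 tokenizer defined above.
    return llm.get_num_tokens(text) <= max_tokens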
@ -0,0 +1,59 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List

from langchain_core.load.serializable import Serializable


class BaseMemory(Serializable, ABC):
    """Abstract base class for memory in Chains.

    Memory refers to state in Chains. Memory can be used to store information about
    past executions of a Chain and inject that information into the inputs of
    future executions of the Chain. For example, for conversational Chains Memory
    can be used to store conversations and automatically add them to future model
    prompts so that the model has the necessary context to respond coherently to
    the latest input.

    Example:
        .. code-block:: python

            class SimpleMemory(BaseMemory):
                memories: Dict[str, Any] = dict()

                @property
                def memory_variables(self) -> List[str]:
                    return list(self.memories.keys())

                def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
                    return self.memories

                def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
                    pass

                def clear(self) -> None:
                    pass
    """  # noqa: E501

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def memory_variables(self) -> List[str]:
        """The string keys this memory class will add to chain inputs."""

    @abstractmethod
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return key-value pairs given the text input to the chain."""

    @abstractmethod
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this chain run to memory."""

    @abstractmethod
    def clear(self) -> None:
        """Clear memory contents."""
@ -0,0 +1,415 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union

from typing_extensions import Literal

from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field

if TYPE_CHECKING:
    from langchain_core.prompts.chat import ChatPromptTemplate


def get_buffer_string(
    messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Convert a sequence of Messages to strings and concatenate them into one string.

    Args:
        messages: Messages to be converted to strings.
        human_prefix: The prefix to prepend to contents of HumanMessages.
        ai_prefix: The prefix to prepend to contents of AIMessages.

    Returns:
        A single string concatenation of all input messages.

    Example:
        .. code-block:: python

            from langchain_core.schema import AIMessage, HumanMessage

            messages = [
                HumanMessage(content="Hi, how are you?"),
                AIMessage(content="Good, how are you?"),
            ]
            get_buffer_string(messages)
            # -> "Human: Hi, how are you?\nAI: Good, how are you?"
    """
    string_messages = []
    for m in messages:
        if isinstance(m, HumanMessage):
            role = human_prefix
        elif isinstance(m, AIMessage):
            role = ai_prefix
        elif isinstance(m, SystemMessage):
            role = "System"
        elif isinstance(m, FunctionMessage):
            role = "Function"
        elif isinstance(m, ChatMessage):
            role = m.role
        else:
            raise ValueError(f"Got unsupported message type: {m}")
        message = f"{role}: {m.content}"
        if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
            message += f"{m.additional_kwargs['function_call']}"
        string_messages.append(message)

    return "\n".join(string_messages)


class BaseMessage(Serializable):
    """The base abstract Message class.

    Messages are the inputs and outputs of ChatModels.
    """

    content: Union[str, List[Union[str, Dict]]]
    """The string contents of the message."""

    additional_kwargs: dict = Field(default_factory=dict)
    """Any additional information."""

    type: str

    class Config:
        extra = Extra.allow

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    def __add__(self, other: Any) -> ChatPromptTemplate:
        from langchain_core.prompts.chat import ChatPromptTemplate

        prompt = ChatPromptTemplate(messages=[self])
        return prompt + other


def merge_content(
    first_content: Union[str, List[Union[str, Dict]]],
    second_content: Union[str, List[Union[str, Dict]]],
) -> Union[str, List[Union[str, Dict]]]:
    # If first chunk is a string
    if isinstance(first_content, str):
        # If the second chunk is also a string, then merge them naively
        if isinstance(second_content, str):
            return first_content + second_content
        # If the second chunk is a list, add the first chunk to the start of the list
        else:
            return_list: List[Union[str, Dict]] = [first_content]
            return return_list + second_content
    # If both are lists, merge them naively
    elif isinstance(second_content, List):
        return first_content + second_content
    # If the first content is a list, and the second content is a string
    else:
        # If the last element of the first content is a string
        # Add the second content to the last element
        if isinstance(first_content[-1], str):
            return first_content[:-1] + [first_content[-1] + second_content]
        else:
            # Otherwise, add the second content as a new element of the list
            return first_content + [second_content]


class BaseMessageChunk(BaseMessage):
    """A Message chunk, which can be concatenated with other Message chunks."""

    def _merge_kwargs_dict(
        self, left: Dict[str, Any], right: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Merge additional_kwargs from another BaseMessageChunk into this one."""
        merged = left.copy()
        for k, v in right.items():
            if k not in merged:
                merged[k] = v
            elif type(merged[k]) != type(v):
                raise ValueError(
                    f'additional_kwargs["{k}"] already exists in this message,'
                    " but with a different type."
                )
            elif isinstance(merged[k], str):
                merged[k] += v
            elif isinstance(merged[k], dict):
                merged[k] = self._merge_kwargs_dict(merged[k], v)
            else:
                raise ValueError(
                    f"Additional kwargs key {k} already exists in this message."
                )
        return merged

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        if isinstance(other, BaseMessageChunk):
            # If both are (subclasses of) BaseMessageChunk,
            # concat into a single BaseMessageChunk

            if isinstance(self, ChatMessageChunk):
                return self.__class__(
                    role=self.role,
                    content=merge_content(self.content, other.content),
                    additional_kwargs=self._merge_kwargs_dict(
                        self.additional_kwargs, other.additional_kwargs
                    ),
                )
            return self.__class__(
                content=merge_content(self.content, other.content),
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )
        else:
            raise TypeError(
                'unsupported operand type(s) for +: "'
                f"{self.__class__.__name__}"
                f'" and "{other.__class__.__name__}"'
            )


class HumanMessage(BaseMessage):
    """A Message from a human."""

    example: bool = False
    """Whether this Message is being passed in to the model as part of an example
    conversation.
    """

    type: Literal["human"] = "human"


HumanMessage.update_forward_refs()


class HumanMessageChunk(HumanMessage, BaseMessageChunk):
    """A Human Message chunk."""

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["HumanMessageChunk"] = "HumanMessageChunk"  # type: ignore[assignment] # noqa: E501


class AIMessage(BaseMessage):
    """A Message from an AI."""

    example: bool = False
    """Whether this Message is being passed in to the model as part of an example
    conversation.
    """

    type: Literal["ai"] = "ai"


AIMessage.update_forward_refs()


class AIMessageChunk(AIMessage, BaseMessageChunk):
    """A Message chunk from an AI."""

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["AIMessageChunk"] = "AIMessageChunk"  # type: ignore[assignment] # noqa: E501

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        if isinstance(other, AIMessageChunk):
            if self.example != other.example:
                raise ValueError(
                    "Cannot concatenate AIMessageChunks with different example values."
                )

            return self.__class__(
                example=self.example,
                content=merge_content(self.content, other.content),
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )

        return super().__add__(other)


class SystemMessage(BaseMessage):
    """A Message for priming AI behavior, usually passed in as the first of a sequence
    of input messages.
    """

    type: Literal["system"] = "system"


SystemMessage.update_forward_refs()


class SystemMessageChunk(SystemMessage, BaseMessageChunk):
    """A System Message chunk."""

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["SystemMessageChunk"] = "SystemMessageChunk"  # type: ignore[assignment] # noqa: E501


class FunctionMessage(BaseMessage):
    """A Message for passing the result of executing a function back to a model."""

    name: str
    """The name of the function that was executed."""

    type: Literal["function"] = "function"


FunctionMessage.update_forward_refs()


class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
    """A Function Message chunk."""

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk"  # type: ignore[assignment]

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        if isinstance(other, FunctionMessageChunk):
            if self.name != other.name:
                raise ValueError(
                    "Cannot concatenate FunctionMessageChunks with different names."
                )

            return self.__class__(
                name=self.name,
                content=merge_content(self.content, other.content),
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )

        return super().__add__(other)


class ToolMessage(BaseMessage):
    """A Message for passing the result of executing a tool back to a model."""

    tool_call_id: str
    """Tool call that this message is responding to."""

    type: Literal["tool"] = "tool"


ToolMessage.update_forward_refs()


class ToolMessageChunk(ToolMessage, BaseMessageChunk):
    """A Tool Message chunk."""

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["ToolMessageChunk"] = "ToolMessageChunk"  # type: ignore[assignment]

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        if isinstance(other, ToolMessageChunk):
            if self.tool_call_id != other.tool_call_id:
                raise ValueError(
                    "Cannot concatenate ToolMessageChunks with different tool_call_ids."
                )

            return self.__class__(
                tool_call_id=self.tool_call_id,
                content=merge_content(self.content, other.content),
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )

        return super().__add__(other)


class ChatMessage(BaseMessage):
    """A Message that can be assigned an arbitrary speaker (i.e. role)."""

    role: str
    """The speaker / role of the Message."""

    type: Literal["chat"] = "chat"


ChatMessage.update_forward_refs()


class ChatMessageChunk(ChatMessage, BaseMessageChunk):
    """A Chat Message chunk."""

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["ChatMessageChunk"] = "ChatMessageChunk"  # type: ignore

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        if isinstance(other, ChatMessageChunk):
            if self.role != other.role:
                raise ValueError(
                    "Cannot concatenate ChatMessageChunks with different roles."
                )

            return self.__class__(
                role=self.role,
                content=merge_content(self.content, other.content),
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )

        return super().__add__(other)


AnyMessage = Union[
    AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
]


def _message_to_dict(message: BaseMessage) -> dict:
    return {"type": message.type, "data": message.dict()}


def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
    """Convert a sequence of Messages to a list of dictionaries.

    Args:
        messages: Sequence of messages (as BaseMessages) to convert.

    Returns:
        List of messages as dicts.
    """
    return [_message_to_dict(m) for m in messages]


def _message_from_dict(message: dict) -> BaseMessage:
    _type = message["type"]
    if _type == "human":
        return HumanMessage(**message["data"])
    elif _type == "ai":
        return AIMessage(**message["data"])
    elif _type == "system":
        return SystemMessage(**message["data"])
    elif _type == "chat":
        return ChatMessage(**message["data"])
    elif _type == "function":
        return FunctionMessage(**message["data"])
    elif _type == "tool":
        return ToolMessage(**message["data"])
    else:
        raise ValueError(f"Got unexpected message type: {_type}")


def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
    """Convert a sequence of messages from dicts to Message objects.

    Args:
        messages: Sequence of messages (as dicts) to convert.

    Returns:
        List of messages (BaseMessages).
    """
    return [_message_from_dict(m) for m in messages]
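
# Round-trip sketch for the serialization helpers above: messages convert to
# plain dicts and back without loss.
_msgs = [HumanMessage(content="hi"), AIMessage(content="hello")]
assert messages_from_dict(messages_to_dict(_msgs)) == _msgs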
@ -0,0 +1,175 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any, Dict, List, Literal, Optional
from uuid import UUID

from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.schema.messages import BaseMessage, BaseMessageChunk


class Generation(Serializable):
    """A single text generation output."""

    text: str
    """Generated text output."""

    generation_info: Optional[Dict[str, Any]] = None
    """Raw response from the provider. May include things like the
    reason for finishing or token log probabilities.
    """
    type: Literal["Generation"] = "Generation"
    """Type is used exclusively for serialization purposes."""
    # TODO: add log probs as separate attribute

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True


class GenerationChunk(Generation):
    """A Generation chunk, which can be concatenated with other Generation chunks."""

    def __add__(self, other: GenerationChunk) -> GenerationChunk:
        if isinstance(other, GenerationChunk):
            generation_info = (
                {**(self.generation_info or {}), **(other.generation_info or {})}
                if self.generation_info is not None or other.generation_info is not None
                else None
            )
            return GenerationChunk(
                text=self.text + other.text,
                generation_info=generation_info,
            )
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
            )


class ChatGeneration(Generation):
    """A single chat generation output."""

    text: str = ""
    """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
    message: BaseMessage
    """The message output by the chat model."""
    # Override type to be ChatGeneration, ignore mypy error as this is intentional
    type: Literal["ChatGeneration"] = "ChatGeneration"  # type: ignore[assignment]
    """Type is used exclusively for serialization purposes."""

    @root_validator
    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Set the text attribute to be the contents of the message."""
        try:
            values["text"] = values["message"].content
        except (KeyError, AttributeError) as e:
            raise ValueError("Error while initializing ChatGeneration") from e
        return values


class ChatGenerationChunk(ChatGeneration):
    """A ChatGeneration chunk, which can be concatenated with other
    ChatGeneration chunks.

    Attributes:
        message: The message chunk output by the chat model.
    """

    message: BaseMessageChunk
    # Override type to be ChatGeneration, ignore mypy error as this is intentional
    type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk"  # type: ignore[assignment] # noqa: E501
    """Type is used exclusively for serialization purposes."""

    def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
        if isinstance(other, ChatGenerationChunk):
            generation_info = (
                {**(self.generation_info or {}), **(other.generation_info or {})}
                if self.generation_info is not None or other.generation_info is not None
                else None
            )
            return ChatGenerationChunk(
                message=self.message + other.message,
                generation_info=generation_info,
            )
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
            )


class RunInfo(BaseModel):
    """Class that contains metadata for a single execution of a Chain or model."""

    run_id: UUID
    """A unique identifier for the model or chain run."""


class ChatResult(BaseModel):
    """Class that contains all results for a single chat model call."""

    generations: List[ChatGeneration]
    """List of the chat generations. This is a List because an input can have multiple
    candidate generations.
    """
    llm_output: Optional[dict] = None
    """For arbitrary LLM provider specific output."""


class LLMResult(BaseModel):
    """Class that contains all results for a batched LLM call."""

    generations: List[List[Generation]]
    """List of generated outputs. This is a List[List[]] because
    each input could have multiple candidate generations."""
    llm_output: Optional[dict] = None
    """Arbitrary LLM provider-specific output."""
    run: Optional[List[RunInfo]] = None
    """List of metadata info for model call for each input."""

    def flatten(self) -> List[LLMResult]:
        """Flatten generations into a single list.

        Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
        contains only a single Generation. If token usage information is available,
        it is kept only for the LLMResult corresponding to the top-choice
        Generation, to avoid over-counting of token usage downstream.

        Returns:
            List of LLMResults where each returned LLMResult contains a single
                Generation.
        """
        llm_results = []
        for i, gen_list in enumerate(self.generations):
            # Avoid double counting tokens in OpenAICallback
            if i == 0:
                llm_results.append(
                    LLMResult(
                        generations=[gen_list],
                        llm_output=self.llm_output,
                    )
                )
            else:
                if self.llm_output is not None:
                    llm_output = deepcopy(self.llm_output)
                    llm_output["token_usage"] = dict()
                else:
                    llm_output = None
                llm_results.append(
                    LLMResult(
                        generations=[gen_list],
                        llm_output=llm_output,
                    )
                )
        return llm_results

    def __eq__(self, other: object) -> bool:
        """Check for LLMResult equality by ignoring any metadata related to runs."""
        if not isinstance(other, LLMResult):
            return NotImplemented
        return (
            self.generations == other.generations
            and self.llm_output == other.llm_output
        )
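
# Concatenation sketch for the chunk types above: `+` appends text and merges
# generation_info, which is how streaming output is accumulated.
_a = GenerationChunk(text="Hello, ", generation_info={"model": "x"})
_b = GenerationChunk(text="world", generation_info={"finish_reason": "stop"})
_merged = _a + _b
assert _merged.text == "Hello, world"
assert _merged.generation_info == {"model": "x", "finish_reason": "stop"}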
@ -0,0 +1,475 @@
from __future__ import annotations

import asyncio
import functools
from abc import ABC, abstractmethod
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Generic,
    Iterator,
    List,
    Optional,
    Type,
    TypeVar,
    Union,
)

from typing_extensions import get_args

from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
from langchain_core.schema.output import (
    ChatGeneration,
    ChatGenerationChunk,
    Generation,
    GenerationChunk,
)
from langchain_core.schema.prompt import PromptValue

T = TypeVar("T")


class BaseLLMOutputParser(Generic[T], ABC):
    """Abstract base class for parsing the outputs of a model."""

    @abstractmethod
    def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
        """Parse a list of candidate model Generations into a specific format.

        Args:
            result: A list of Generations to be parsed. The Generations are assumed
                to be different candidate outputs for a single model input.

        Returns:
            Structured output.
        """

    async def aparse_result(
        self, result: List[Generation], *, partial: bool = False
    ) -> T:
        """Parse a list of candidate model Generations into a specific format.

        Args:
            result: A list of Generations to be parsed. The Generations are assumed
                to be different candidate outputs for a single model input.

        Returns:
            Structured output.
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, self.parse_result, result
        )


class BaseGenerationOutputParser(
    BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
    """Base class to parse the output of an LLM call."""

    @property
    def InputType(self) -> Any:
        return Union[str, AnyMessage]

    @property
    def OutputType(self) -> Type[T]:
        # even though mypy complains this isn't valid,
        # it is good enough for pydantic to build the schema from
        return T  # type: ignore[misc]

    def invoke(
        self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
    ) -> T:
        if isinstance(input, BaseMessage):
            return self._call_with_config(
                lambda inner_input: self.parse_result(
                    [ChatGeneration(message=inner_input)]
                ),
                input,
                config,
                run_type="parser",
            )
        else:
            return self._call_with_config(
                lambda inner_input: self.parse_result([Generation(text=inner_input)]),
                input,
                config,
                run_type="parser",
            )

    async def ainvoke(
        self,
        input: str | BaseMessage,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> T:
        if isinstance(input, BaseMessage):
            return await self._acall_with_config(
                lambda inner_input: self.aparse_result(
                    [ChatGeneration(message=inner_input)]
                ),
                input,
                config,
                run_type="parser",
            )
        else:
            return await self._acall_with_config(
                lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
                input,
                config,
                run_type="parser",
            )


class BaseOutputParser(
    BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
    """Base class to parse the output of an LLM call.

    Output parsers help structure language model responses.

    Example:
        .. code-block:: python

            class BooleanOutputParser(BaseOutputParser[bool]):
                true_val: str = "YES"
                false_val: str = "NO"

                def parse(self, text: str) -> bool:
                    cleaned_text = text.strip().upper()
                    if cleaned_text not in (self.true_val.upper(), self.false_val.upper()):
                        raise OutputParserException(
                            f"BooleanOutputParser expected output value to either be "
                            f"{self.true_val} or {self.false_val} (case-insensitive). "
                            f"Received {cleaned_text}."
                        )
                    return cleaned_text == self.true_val.upper()

                @property
                def _type(self) -> str:
                    return "boolean_output_parser"
    """  # noqa: E501

    @property
    def InputType(self) -> Any:
        return Union[str, AnyMessage]

    @property
    def OutputType(self) -> Type[T]:
        for cls in self.__class__.__orig_bases__:  # type: ignore[attr-defined]
            type_args = get_args(cls)
            if type_args and len(type_args) == 1:
                return type_args[0]

        raise TypeError(
            f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
            "Override the OutputType property to specify the output type."
        )

    def invoke(
        self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
    ) -> T:
        if isinstance(input, BaseMessage):
            return self._call_with_config(
                lambda inner_input: self.parse_result(
                    [ChatGeneration(message=inner_input)]
                ),
                input,
                config,
                run_type="parser",
            )
        else:
            return self._call_with_config(
                lambda inner_input: self.parse_result([Generation(text=inner_input)]),
                input,
                config,
                run_type="parser",
            )

    async def ainvoke(
        self,
        input: str | BaseMessage,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> T:
        if isinstance(input, BaseMessage):
            return await self._acall_with_config(
                lambda inner_input: self.aparse_result(
                    [ChatGeneration(message=inner_input)]
                ),
                input,
                config,
                run_type="parser",
            )
        else:
            return await self._acall_with_config(
                lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
                input,
                config,
                run_type="parser",
            )

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
        """Parse a list of candidate model Generations into a specific format.

        The return value is parsed from only the first Generation in the result, which
        is assumed to be the highest-likelihood Generation.

        Args:
            result: A list of Generations to be parsed. The Generations are assumed
                to be different candidate outputs for a single model input.

        Returns:
            Structured output.
        """
        return self.parse(result[0].text)

    @abstractmethod
    def parse(self, text: str) -> T:
        """Parse a single string model output into some structure.

        Args:
            text: String output of a language model.

        Returns:
            Structured output.
        """

    async def aparse_result(
        self, result: List[Generation], *, partial: bool = False
    ) -> T:
        """Parse a list of candidate model Generations into a specific format.

        The return value is parsed from only the first Generation in the result, which
        is assumed to be the highest-likelihood Generation.

        Args:
            result: A list of Generations to be parsed. The Generations are assumed
                to be different candidate outputs for a single model input.

        Returns:
            Structured output.
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, functools.partial(self.parse_result, partial=partial), result
        )

    async def aparse(self, text: str) -> T:
        """Parse a single string model output into some structure.

        Args:
            text: String output of a language model.

        Returns:
            Structured output.
        """
        return await asyncio.get_running_loop().run_in_executor(None, self.parse, text)

    # TODO: rename 'completion' -> 'text'.
    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
        """Parse the output of an LLM call with the input prompt for context.

        The prompt is largely provided in the event the OutputParser wants
        to retry or fix the output in some way, and needs information from
        the prompt to do so.

        Args:
            completion: String output of a language model.
            prompt: Input PromptValue.

        Returns:
            Structured output.
        """
        return self.parse(completion)

    def get_format_instructions(self) -> str:
        """Instructions on how the LLM output should be formatted."""
        raise NotImplementedError

    @property
    def _type(self) -> str:
        """Return the output parser type for serialization."""
        raise NotImplementedError(
            f"_type property is not implemented in class {self.__class__.__name__}."
            " This is required for serialization."
        )

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of output parser."""
        output_parser_dict = super().dict(**kwargs)
        try:
            output_parser_dict["_type"] = self._type
        except NotImplementedError:
            pass
        return output_parser_dict


class BaseTransformOutputParser(BaseOutputParser[T]):
    """Base class for an output parser that can handle streaming input."""

    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
        for chunk in input:
            if isinstance(chunk, BaseMessage):
                yield self.parse_result([ChatGeneration(message=chunk)])
            else:
                yield self.parse_result([Generation(text=chunk)])

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[T]:
        async for chunk in input:
            if isinstance(chunk, BaseMessage):
                yield self.parse_result([ChatGeneration(message=chunk)])
            else:
                yield self.parse_result([Generation(text=chunk)])

    def transform(
        self,
        input: Iterator[Union[str, BaseMessage]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[T]:
        yield from self._transform_stream_with_config(
            input, self._transform, config, run_type="parser"
        )

    async def atransform(
        self,
        input: AsyncIterator[Union[str, BaseMessage]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[T]:
        async for chunk in self._atransform_stream_with_config(
            input, self._atransform, config, run_type="parser"
        ):
            yield chunk


class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
    """Base class for an output parser that can handle streaming input."""

    diff: bool = False
    """In streaming mode, whether to yield diffs between the previous and current
    parsed output, or just the current parsed output.
    """

    def _diff(self, prev: Optional[T], next: T) -> T:
        """Convert parsed outputs into a diff format. The semantics of this are
        up to the output parser."""
        raise NotImplementedError()

    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
        prev_parsed = None
        acc_gen = None
        for chunk in input:
            if isinstance(chunk, BaseMessageChunk):
                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
            elif isinstance(chunk, BaseMessage):
                chunk_gen = ChatGenerationChunk(
                    message=BaseMessageChunk(**chunk.dict())
                )
            else:
                chunk_gen = GenerationChunk(text=chunk)

            if acc_gen is None:
                acc_gen = chunk_gen
            else:
                acc_gen += chunk_gen

            parsed = self.parse_result([acc_gen], partial=True)
            if parsed is not None and parsed != prev_parsed:
                if self.diff:
                    yield self._diff(prev_parsed, parsed)
                else:
                    yield parsed
                prev_parsed = parsed

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[T]:
        prev_parsed = None
        acc_gen = None
        async for chunk in input:
            if isinstance(chunk, BaseMessageChunk):
                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
            elif isinstance(chunk, BaseMessage):
                chunk_gen = ChatGenerationChunk(
                    message=BaseMessageChunk(**chunk.dict())
                )
            else:
                chunk_gen = GenerationChunk(text=chunk)

            if acc_gen is None:
                acc_gen = chunk_gen
            else:
                acc_gen += chunk_gen

            parsed = self.parse_result([acc_gen], partial=True)
            if parsed is not None and parsed != prev_parsed:
                if self.diff:
                    yield self._diff(prev_parsed, parsed)
                else:
                    yield parsed
                prev_parsed = parsed


class StrOutputParser(BaseTransformOutputParser[str]):
    """OutputParser that parses LLMResult into the most likely string."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    @property
    def _type(self) -> str:
        """Return the output parser type for serialization."""
        return "default"

    def parse(self, text: str) -> str:
        """Returns the input text with no changes."""
        return text


# TODO: Deprecate
NoOpOutputParser = StrOutputParser


class OutputParserException(ValueError):
    """Exception that output parsers should raise to signify a parsing error.

    This exists to differentiate parsing errors from other code or execution errors
    that also may arise inside the output parser. OutputParserExceptions will be
    available to catch and handle in ways to fix the parsing error, while other
    errors will be raised.

    Args:
        error: The error that's being re-raised or an error message.
        observation: String explanation of the error which can be passed to a
            model to try and remediate the issue.
        llm_output: String model output which is erroring.
        send_to_llm: Whether to send the observation and llm_output back to an Agent
            after an OutputParserException has been raised. This gives the underlying
            model driving the agent the context that the previous output was improperly
            structured, in the hopes that it will update the output to the correct
            format.
    """

    def __init__(
        self,
        error: Any,
        observation: Optional[str] = None,
        llm_output: Optional[str] = None,
        send_to_llm: bool = False,
    ):
        super(OutputParserException, self).__init__(error)
        if send_to_llm:
            if observation is None or llm_output is None:
                raise ValueError(
                    "Arguments 'observation' & 'llm_output'"
                    " are required if 'send_to_llm' is True"
                )
        self.observation = observation
        self.llm_output = llm_output
        self.send_to_llm = send_to_llm
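
# A minimal concrete parser sketch against BaseOutputParser above; the class
# name and behavior are illustrative, not part of this file.
class CommaSeparatedListSketch(BaseOutputParser[List[str]]):
    @property
    def _type(self) -> str:
        return "comma_separated_list_sketch"

    def parse(self, text: str) -> List[str]:
        # "red, green, blue" -> ["red", "green", "blue"]
        return [part.strip() for part in text.split(",") if part.strip()]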
@ -0,0 +1,28 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List

from langchain_core.load.serializable import Serializable
from langchain_core.schema.messages import BaseMessage


class PromptValue(Serializable, ABC):
    """Base abstract class for inputs to any language model.

    PromptValues can be converted to both LLM (pure text-generation) inputs and
    ChatModel inputs.
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    @abstractmethod
    def to_string(self) -> str:
        """Return prompt value as string."""

    @abstractmethod
    def to_messages(self) -> List[BaseMessage]:
        """Return prompt as a list of Messages."""
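
# A minimal concrete PromptValue sketch (illustrative; the real string and
# chat prompt values live in langchain_core.prompts): wraps a plain string
# and satisfies both conversion targets.
class StringValueSketch(PromptValue):
    text: str

    def to_string(self) -> str:
        return self.text

    def to_messages(self) -> List[BaseMessage]:
        from langchain_core.schema.messages import HumanMessage

        return [HumanMessage(content=self.text)]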
@@ -0,0 +1,228 @@
from __future__ import annotations

import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union

import yaml

from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.schema.document import Document
from langchain_core.schema.output_parser import BaseOutputParser
from langchain_core.schema.prompt import PromptValue


class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
    """Base class for all prompt templates, returning a prompt."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
    input_types: Dict[str, Any] = Field(default_factory=dict)
    """A dictionary of the types of the variables the prompt template expects.
    If not provided, all variables are assumed to be strings."""
    output_parser: Optional[BaseOutputParser] = None
    """How to parse the output of calling an LLM on this formatted prompt."""
    partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
        default_factory=dict
    )

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def OutputType(self) -> Any:
        from langchain_core.prompts.base import StringPromptValue
        from langchain_core.prompts.chat import ChatPromptValueConcrete

        return Union[StringPromptValue, ChatPromptValueConcrete]

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        # This is correct, but pydantic typings/mypy don't think so.
        return create_model(  # type: ignore[call-overload]
            "PromptInput",
            **{k: (self.input_types.get(k, str), None) for k in self.input_variables},
        )

    def invoke(
        self, input: Dict, config: Optional[RunnableConfig] = None
    ) -> PromptValue:
        return self._call_with_config(
            lambda inner_input: self.format_prompt(
                **{key: inner_input[key] for key in self.input_variables}
            ),
            input,
            config,
            run_type="prompt",
        )

    @abstractmethod
    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Create Chat Messages."""

    @root_validator()
    def validate_variable_names(cls, values: Dict) -> Dict:
        """Validate variable names do not include restricted names."""
        if "stop" in values["input_variables"]:
            raise ValueError(
                "Cannot have an input variable named 'stop', as it is used internally,"
                " please rename."
            )
        if "stop" in values["partial_variables"]:
            raise ValueError(
                "Cannot have a partial variable named 'stop', as it is used "
                "internally, please rename."
            )

        overall = set(values["input_variables"]).intersection(
            values["partial_variables"]
        )
        if overall:
            raise ValueError(
                f"Found overlapping input and partial variables: {overall}"
            )
        return values

    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
        """Return a partial of the prompt template."""
        prompt_dict = self.__dict__.copy()
        prompt_dict["input_variables"] = list(
            set(self.input_variables).difference(kwargs)
        )
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
        # Get partial params:
        partial_kwargs = {
            k: v if isinstance(v, str) else v()
            for k, v in self.partial_variables.items()
        }
        return {**partial_kwargs, **kwargs}

    @abstractmethod
    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of prompt."""
        prompt_dict = super().dict(**kwargs)
        try:
            prompt_dict["_type"] = self._prompt_type
        except NotImplementedError:
            pass
        return prompt_dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the prompt.

        Args:
            file_path: Path of the file to save the prompt to.

        Example:
        .. code-block:: python

            prompt.save(file_path="path/prompt.yaml")
        """
        if self.partial_variables:
            raise ValueError("Cannot save prompt with partial variables.")

        # Fetch dictionary to save
        prompt_dict = self.dict()
        if "_type" not in prompt_dict:
            raise NotImplementedError(f"Prompt {self} does not support saving.")

        # Convert file to Path object.
        if isinstance(file_path, str):
            save_path = Path(file_path)
        else:
            save_path = file_path

        directory_path = save_path.parent
        directory_path.mkdir(parents=True, exist_ok=True)

        if save_path.suffix == ".json":
            with open(file_path, "w") as f:
                json.dump(prompt_dict, f, indent=4)
        elif save_path.suffix == ".yaml":
            with open(file_path, "w") as f:
                yaml.dump(prompt_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")
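
# A brief usage sketch of `partial`, assuming the concrete PromptTemplate
# subclass exported from langchain_core.prompts; the template text is
# illustrative.
from langchain_core.prompts import PromptTemplate

_prompt = PromptTemplate.from_template("{greeting}, {name}!")
_hello = _prompt.partial(greeting="Hello")
assert _hello.format(name="world") == "Hello, world!"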


def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
    """Format a document into a string based on a prompt template.

    First, this pulls information from the document from two sources:

    1. `page_content`:
        This takes the information from the `document.page_content`
        and assigns it to a variable named `page_content`.
    2. metadata:
        This takes information from `document.metadata` and assigns
        it to variables of the same name.

    Those variables are then passed into the `prompt` to produce a formatted string.

    Args:
        doc: Document, the page_content and metadata will be used to create
            the final string.
        prompt: BasePromptTemplate, will be used to format the page_content
            and metadata into the final string.

    Returns:
        The formatted document as a string.

    Example:
        .. code-block:: python

            from langchain_core.schema import Document
            from langchain_core.prompts import PromptTemplate

            doc = Document(page_content="This is a joke", metadata={"page": "1"})
            prompt = PromptTemplate.from_template("Page {page}: {page_content}")
            format_document(doc, prompt)
            >>> "Page 1: This is a joke"
    """
    base_info = {"page_content": doc.page_content, **doc.metadata}
    missing_metadata = set(prompt.input_variables).difference(base_info)
    if len(missing_metadata) > 0:
        required_metadata = [
            iv for iv in prompt.input_variables if iv != "page_content"
        ]
        raise ValueError(
            f"Document prompt requires documents to have metadata variables: "
            f"{required_metadata}. Received document with missing metadata: "
            f"{list(missing_metadata)}."
        )
    document_info = {k: base_info[k] for k in prompt.input_variables}
    return prompt.format(**document_info)
@@ -0,0 +1,275 @@
from __future__ import annotations

import asyncio
import warnings
from abc import ABC, abstractmethod
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain_core.load.dump import dumpd
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.schema.document import Document

if TYPE_CHECKING:
    from langchain_core.callbacks.manager import (
        AsyncCallbackManagerForRetrieverRun,
        CallbackManagerForRetrieverRun,
        Callbacks,
    )


class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
    """Abstract base class for a Document retrieval system.

    A retrieval system is defined as something that can take string queries and return
    the most 'relevant' Documents from some source.

    Example:
        .. code-block:: python

            class TFIDFRetriever(BaseRetriever, BaseModel):
                vectorizer: Any
                docs: List[Document]
                tfidf_array: Any
                k: int = 4

                class Config:
                    arbitrary_types_allowed = True

                def get_relevant_documents(self, query: str) -> List[Document]:
                    from sklearn.metrics.pairwise import cosine_similarity

                    # Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
                    query_vec = self.vectorizer.transform([query])
                    # Op -- (n_docs,1) -- Cosine Sim with each doc
                    results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
                    return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
    """  # noqa: E501

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    _new_arg_supported: bool = False
    _expects_other_args: bool = False
    tags: Optional[List[str]] = None
    """Optional list of tags associated with the retriever. Defaults to None.
    These tags will be associated with each call to this retriever,
    and passed as arguments to the handlers defined in `callbacks`.
    You can use these to e.g. identify a specific instance of a retriever with its
    use case.
    """
    metadata: Optional[Dict[str, Any]] = None
    """Optional metadata associated with the retriever. Defaults to None.
    This metadata will be associated with each call to this retriever,
    and passed as arguments to the handlers defined in `callbacks`.
    You can use these to e.g. identify a specific instance of a retriever with its
    use case.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Version upgrade for old retrievers that implemented the public
        # methods directly.
        if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
            warnings.warn(
                "Retrievers must implement abstract `_get_relevant_documents` method"
                " instead of `get_relevant_documents`",
                DeprecationWarning,
            )
            swap = cls.get_relevant_documents
            cls.get_relevant_documents = (  # type: ignore[assignment]
                BaseRetriever.get_relevant_documents
            )
            cls._get_relevant_documents = swap  # type: ignore[assignment]
        if (
            hasattr(cls, "aget_relevant_documents")
            and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
        ):
            warnings.warn(
                "Retrievers must implement abstract `_aget_relevant_documents` method"
                " instead of `aget_relevant_documents`",
                DeprecationWarning,
            )
            aswap = cls.aget_relevant_documents
            cls.aget_relevant_documents = (  # type: ignore[assignment]
                BaseRetriever.aget_relevant_documents
            )
            cls._aget_relevant_documents = aswap  # type: ignore[assignment]
        parameters = signature(cls._get_relevant_documents).parameters
        cls._new_arg_supported = parameters.get("run_manager") is not None
        # If a V1 retriever broke the interface and expects additional arguments
        cls._expects_other_args = (
            len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
        )

    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None
    ) -> List[Document]:
        config = config or {}
        return self.get_relevant_documents(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
        )

    async def ainvoke(
        self,
        input: str,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> List[Document]:
        config = config or {}
        return await self.aget_relevant_documents(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
        )

    @abstractmethod
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.
        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.
        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self._get_relevant_documents, run_manager=run_manager), query
        )

    def get_relevant_documents(
        self,
        query: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve documents relevant to a query.
        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks
            tags: Optional list of tags associated with the retriever. Defaults to None
                These tags will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            metadata: Optional metadata associated with the retriever. Defaults to None
                This metadata will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
        Returns:
            List of relevant documents
        """
        from langchain_core.callbacks.manager import CallbackManager

        callback_manager = CallbackManager.configure(
            callbacks,
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=tags,
            local_tags=self.tags,
            inheritable_metadata=metadata,
            local_metadata=self.metadata,
        )
        run_manager = callback_manager.on_retriever_start(
            dumpd(self),
            query,
            name=run_name,
            **kwargs,
        )
        try:
            _kwargs = kwargs if self._expects_other_args else {}
            if self._new_arg_supported:
                result = self._get_relevant_documents(
                    query, run_manager=run_manager, **_kwargs
                )
            else:
                result = self._get_relevant_documents(query, **_kwargs)
        except Exception as e:
            run_manager.on_retriever_error(e)
            raise e
        else:
            run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result

    async def aget_relevant_documents(
        self,
        query: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.
        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks
            tags: Optional list of tags associated with the retriever. Defaults to None
                These tags will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            metadata: Optional metadata associated with the retriever. Defaults to None
                This metadata will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
        Returns:
            List of relevant documents
        """
        from langchain_core.callbacks.manager import AsyncCallbackManager

        callback_manager = AsyncCallbackManager.configure(
            callbacks,
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=tags,
            local_tags=self.tags,
            inheritable_metadata=metadata,
            local_metadata=self.metadata,
        )
        run_manager = await callback_manager.on_retriever_start(
            dumpd(self),
            query,
            name=run_name,
            **kwargs,
        )
        try:
            _kwargs = kwargs if self._expects_other_args else {}
            if self._new_arg_supported:
                result = await self._aget_relevant_documents(
                    query, run_manager=run_manager, **_kwargs
                )
            else:
                result = await self._aget_relevant_documents(query, **_kwargs)
        except Exception as e:
            await run_manager.on_retriever_error(e)
            raise e
        else:
            await run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result
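
# A minimal sketch of a custom retriever built on the interface above. The
# corpus and containment-matching rule are hypothetical; only BaseRetriever,
# Document, and the run_manager-aware `_get_relevant_documents` hook come
# from this module.
class KeywordRetriever(BaseRetriever):
    docs: List[Document]
    k: int = 4

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        # Naive substring match; real retrievers would rank by a score.
        hits = [d for d in self.docs if query.lower() in d.page_content.lower()]
        return hits[: self.k]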
@@ -0,0 +1,53 @@
from abc import ABC, abstractmethod
from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union

K = TypeVar("K")
V = TypeVar("V")


class BaseStore(Generic[K, V], ABC):
    """Abstract interface for a key-value store."""

    @abstractmethod
    def mget(self, keys: Sequence[K]) -> List[Optional[V]]:
        """Get the values associated with the given keys.

        Args:
            keys (Sequence[K]): A sequence of keys.

        Returns:
            A sequence of optional values associated with the keys.
            If a key is not found, the corresponding value will be None.
        """

    @abstractmethod
    def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
        """Set the values for the given keys.

        Args:
            key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
        """

    @abstractmethod
    def mdelete(self, keys: Sequence[K]) -> None:
        """Delete the given keys and their associated values.

        Args:
            keys (Sequence[K]): A sequence of keys to delete.
        """

    @abstractmethod
    def yield_keys(
        self, *, prefix: Optional[str] = None
    ) -> Union[Iterator[K], Iterator[str]]:
        """Get an iterator over keys that match the given prefix.

        Args:
            prefix (str): The prefix to match.

        Returns:
            Iterator[K | str]: An iterator over keys that match the given prefix.

        This method is allowed to return an iterator over either K or str
        depending on what makes more sense for the given store.
        """
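
# A minimal in-memory BaseStore sketch, string-keyed for simplicity. The dict
# backing and prefix filter are illustrative; only the four abstract methods
# come from the interface above.
class InMemoryStore(BaseStore[str, V]):
    def __init__(self) -> None:
        self._data: dict = {}

    def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
        return [self._data.get(k) for k in keys]

    def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
        for k, v in key_value_pairs:
            self._data[k] = v

    def mdelete(self, keys: Sequence[str]) -> None:
        for k in keys:
            self._data.pop(k, None)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        for k in self._data:
            if prefix is None or k.startswith(prefix):
                yield k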
@@ -0,0 +1,702 @@
from __future__ import annotations

import asyncio
import logging
import math
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
)

from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.schema import BaseRetriever
from langchain_core.schema.document import Document
from langchain_core.schema.embeddings import Embeddings

if TYPE_CHECKING:
    from langchain_core.callbacks.manager import (
        AsyncCallbackManagerForRetrieverRun,
        CallbackManagerForRetrieverRun,
    )

logger = logging.getLogger(__name__)

VST = TypeVar("VST", bound="VectorStore")


class VectorStore(ABC):
    """Interface for vector store."""

    @abstractmethod
    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            kwargs: vectorstore specific parameters

        Returns:
            List of ids from adding the texts into the vectorstore.
        """

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Access the query embedding object if available."""
        logger.debug(
            f"{Embeddings.__name__} is not implemented for {self.__class__.__name__}"
        )
        return None

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete by vector ID or other criteria.

        Args:
            ids: List of ids to delete.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if deletion is successful,
            False otherwise, None if not implemented.
        """

        raise NotImplementedError("delete method must be implemented by subclass.")

    async def adelete(
        self, ids: Optional[List[str]] = None, **kwargs: Any
    ) -> Optional[bool]:
        """Delete by vector ID or other criteria.

        Args:
            ids: List of ids to delete.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if deletion is successful,
            False otherwise, None if not implemented.
        """

        raise NotImplementedError("delete method must be implemented by subclass.")

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore."""
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self.add_texts, **kwargs), texts, metadatas
        )

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Run more documents through the embeddings and add to the vectorstore.

        Args:
            documents (List[Document]): Documents to add to the vectorstore.

        Returns:
            List[str]: List of IDs of the added texts.
        """
        # TODO: Handle the case where the user doesn't provide ids on the Collection
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return self.add_texts(texts, metadatas, **kwargs)

    async def aadd_documents(
        self, documents: List[Document], **kwargs: Any
    ) -> List[str]:
        """Run more documents through the embeddings and add to the vectorstore.

        Args:
            documents (List[Document]): Documents to add to the vectorstore.

        Returns:
            List[str]: List of IDs of the added texts.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return await self.aadd_texts(texts, metadatas, **kwargs)

    def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
        """Return docs most similar to query using specified search type."""
        if search_type == "similarity":
            return self.similarity_search(query, **kwargs)
        elif search_type == "mmr":
            return self.max_marginal_relevance_search(query, **kwargs)
        else:
            raise ValueError(
                f"search_type of {search_type} not allowed. Expected "
                "search_type to be 'similarity' or 'mmr'."
            )

    async def asearch(
        self, query: str, search_type: str, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query using specified search type."""
        if search_type == "similarity":
            return await self.asimilarity_search(query, **kwargs)
        elif search_type == "mmr":
            return await self.amax_marginal_relevance_search(query, **kwargs)
        else:
            raise ValueError(
                f"search_type of {search_type} not allowed. Expected "
                "search_type to be 'similarity' or 'mmr'."
            )

    @abstractmethod
    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query."""

    @staticmethod
    def _euclidean_relevance_score_fn(distance: float) -> float:
        """Return a similarity score on a scale [0, 1]."""
        # The 'correct' relevance function
        # may differ depending on a few things, including:
        # - the distance / similarity metric used by the VectorStore
        # - the scale of your embeddings (OpenAI's are unit normed. Many
        #  others are not!)
        # - embedding dimensionality
        # - etc.
        # This function converts the euclidean norm of normalized embeddings
        # (0 is most similar, sqrt(2) most dissimilar)
        # to a similarity function (0 to 1)
        return 1.0 - distance / math.sqrt(2)

    @staticmethod
    def _cosine_relevance_score_fn(distance: float) -> float:
        """Normalize the distance to a score on a scale [0, 1]."""

        return 1.0 - distance

    @staticmethod
    def _max_inner_product_relevance_score_fn(distance: float) -> float:
        """Normalize the distance to a score on a scale [0, 1]."""
        if distance > 0:
            return 1.0 - distance

        return -1.0 * distance

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """
        The 'correct' relevance function
        may differ depending on a few things, including:
        - the distance / similarity metric used by the VectorStore
        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
        - embedding dimensionality
        - etc.

        Vectorstores should define their own selection-based method of relevance.
        """
        raise NotImplementedError

    def similarity_search_with_score(
        self, *args: Any, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with distance."""
        raise NotImplementedError

    async def asimilarity_search_with_score(
        self, *args: Any, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with distance asynchronously."""

        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(self.similarity_search_with_score, *args, **kwargs)
        return await asyncio.get_event_loop().run_in_executor(None, func)

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """
        Default similarity search with relevance scores. Modify if necessary
        in subclass.
        Return docs and relevance scores in the range [0, 1].

        0 is dissimilar, 1 is most similar.

        Args:
            query: input text
            k: Number of Documents to return. Defaults to 4.
            **kwargs: kwargs to be passed to similarity search. Should include:
                score_threshold: Optional, a floating point value between 0 to 1 to
                    filter the resulting set of retrieved docs

        Returns:
            List of Tuples of (doc, similarity_score)
        """
        relevance_score_fn = self._select_relevance_score_fn()
        docs_and_scores = self.similarity_search_with_score(query, k, **kwargs)
        return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores]

    async def _asimilarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """
        Default async similarity search with relevance scores. Modify if necessary
        in subclass.
        Return docs and relevance scores in the range [0, 1].

        0 is dissimilar, 1 is most similar.

        Args:
            query: input text
            k: Number of Documents to return. Defaults to 4.
            **kwargs: kwargs to be passed to similarity search. Should include:
                score_threshold: Optional, a floating point value between 0 to 1 to
                    filter the resulting set of retrieved docs

        Returns:
            List of Tuples of (doc, similarity_score)
        """
        relevance_score_fn = self._select_relevance_score_fn()
        docs_and_scores = await self.asimilarity_search_with_score(query, k, **kwargs)
        return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores]

    def similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and relevance scores in the range [0, 1].

        0 is dissimilar, 1 is most similar.

        Args:
            query: input text
            k: Number of Documents to return. Defaults to 4.
            **kwargs: kwargs to be passed to similarity search. Should include:
                score_threshold: Optional, a floating point value between 0 to 1 to
                    filter the resulting set of retrieved docs

        Returns:
            List of Tuples of (doc, similarity_score)
        """
        score_threshold = kwargs.pop("score_threshold", None)

        docs_and_similarities = self._similarity_search_with_relevance_scores(
            query, k=k, **kwargs
        )
        if any(
            similarity < 0.0 or similarity > 1.0
            for _, similarity in docs_and_similarities
        ):
            warnings.warn(
                "Relevance scores must be between"
                f" 0 and 1, got {docs_and_similarities}"
            )

        if score_threshold is not None:
            docs_and_similarities = [
                (doc, similarity)
                for doc, similarity in docs_and_similarities
                if similarity >= score_threshold
            ]
            if len(docs_and_similarities) == 0:
                warnings.warn(
                    "No relevant docs were retrieved using the relevance score"
                    f" threshold {score_threshold}"
                )
        return docs_and_similarities

    async def asimilarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and relevance scores in the range [0, 1], asynchronously.

        0 is dissimilar, 1 is most similar.

        Args:
            query: input text
            k: Number of Documents to return. Defaults to 4.
            **kwargs: kwargs to be passed to similarity search. Should include:
                score_threshold: Optional, a floating point value between 0 to 1 to
                    filter the resulting set of retrieved docs

        Returns:
            List of Tuples of (doc, similarity_score)
        """
        score_threshold = kwargs.pop("score_threshold", None)

        docs_and_similarities = await self._asimilarity_search_with_relevance_scores(
            query, k=k, **kwargs
        )
        if any(
            similarity < 0.0 or similarity > 1.0
            for _, similarity in docs_and_similarities
        ):
            warnings.warn(
                "Relevance scores must be between"
                f" 0 and 1, got {docs_and_similarities}"
            )

        if score_threshold is not None:
            docs_and_similarities = [
                (doc, similarity)
                for doc, similarity in docs_and_similarities
                if similarity >= score_threshold
            ]
            if len(docs_and_similarities) == 0:
                warnings.warn(
                    "No relevant docs were retrieved using the relevance score"
                    f" threshold {score_threshold}"
                )
        return docs_and_similarities

    async def asimilarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query."""

        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(self.similarity_search, query, k=k, **kwargs)
        return await asyncio.get_event_loop().run_in_executor(None, func)

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query vector.
        """
        raise NotImplementedError

    async def asimilarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to embedding vector."""

        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
        return await asyncio.get_event_loop().run_in_executor(None, func)

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        raise NotImplementedError

    async def amax_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance."""

        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(
            self.max_marginal_relevance_search,
            query,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            **kwargs,
        )
        return await asyncio.get_event_loop().run_in_executor(None, func)

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        raise NotImplementedError

    async def amax_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance."""
        raise NotImplementedError

    @classmethod
    def from_documents(
        cls: Type[VST],
        documents: List[Document],
        embedding: Embeddings,
        **kwargs: Any,
    ) -> VST:
        """Return VectorStore initialized from documents and embeddings."""
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)

    @classmethod
    async def afrom_documents(
        cls: Type[VST],
        documents: List[Document],
        embedding: Embeddings,
        **kwargs: Any,
    ) -> VST:
        """Return VectorStore initialized from documents and embeddings."""
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        return await cls.afrom_texts(texts, embedding, metadatas=metadatas, **kwargs)

    @classmethod
    @abstractmethod
    def from_texts(
        cls: Type[VST],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> VST:
        """Return VectorStore initialized from texts and embeddings."""

    @classmethod
    async def afrom_texts(
        cls: Type[VST],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> VST:
        """Return VectorStore initialized from texts and embeddings."""
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
        )

    def _get_retriever_tags(self) -> List[str]:
        """Get tags for retriever."""
        tags = [self.__class__.__name__]
        if self.embeddings:
            tags.append(self.embeddings.__class__.__name__)
        return tags

    def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
        """Return VectorStoreRetriever initialized from this VectorStore.

        Args:
            search_type (Optional[str]): Defines the type of search that
                the Retriever should perform.
                Can be "similarity" (default), "mmr", or
                "similarity_score_threshold".
            search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                search function. Can include things like:
                    k: Amount of documents to return (Default: 4)
                    score_threshold: Minimum relevance threshold
                        for similarity_score_threshold
                    fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
                    lambda_mult: Diversity of results returned by MMR;
                        1 for minimum diversity and 0 for maximum. (Default: 0.5)
                    filter: Filter by document metadata

        Returns:
            VectorStoreRetriever: Retriever class for VectorStore.

        Examples:

        .. code-block:: python

            # Retrieve more documents with higher diversity
            # Useful if your dataset has many similar documents
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 6, 'lambda_mult': 0.25}
            )

            # Fetch more documents for the MMR algorithm to consider
            # But only return the top 5
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 5, 'fetch_k': 50}
            )

            # Only retrieve documents that have a relevance score
            # Above a certain threshold
            docsearch.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={'score_threshold': 0.8}
            )

            # Only get the single most similar document from the dataset
            docsearch.as_retriever(search_kwargs={'k': 1})

            # Use a filter to only retrieve documents from a specific paper
            docsearch.as_retriever(
                search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
            )
        """
        tags = kwargs.pop("tags", None) or []
        tags.extend(self._get_retriever_tags())
        return VectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
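
# A short sketch of how a concrete store might pick its relevance scoring,
# assuming its backend returns cosine distances. The class is hypothetical
# (and leaves the other abstract methods to a real implementation); the helper
# it returns is the static method defined above.
class CosineDistanceStore(VectorStore):
    """Hypothetical store whose similarity backend reports cosine distance."""

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        # Cosine distance of 0 means identical vectors, so relevance is
        # 1 - distance, which the base class helper already implements.
        return self._cosine_relevance_score_fn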


class VectorStoreRetriever(BaseRetriever):
    """Base Retriever class for VectorStore."""

    vectorstore: VectorStore
    """VectorStore to use for retrieval."""
    search_type: str = "similarity"
    """Type of search to perform. Defaults to "similarity"."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
    )

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator()
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate search type."""
        search_type = values["search_type"]
        if search_type not in cls.allowed_search_types:
            raise ValueError(
                f"search_type of {search_type} not allowed. Valid values are: "
                f"{cls.allowed_search_types}"
            )
        if search_type == "similarity_score_threshold":
            score_threshold = values["search_kwargs"].get("score_threshold")
            if (score_threshold is None) or (not isinstance(score_threshold, float)):
                raise ValueError(
                    "`score_threshold` is not specified with a float value (0~1) "
                    "in `search_kwargs`."
                )
        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Add documents to vectorstore."""
        return self.vectorstore.add_documents(documents, **kwargs)

    async def aadd_documents(
        self, documents: List[Document], **kwargs: Any
    ) -> List[str]:
        """Add documents to vectorstore."""
        return await self.vectorstore.aadd_documents(documents, **kwargs)
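
# A brief usage sketch tying the two classes together. `docsearch` stands in
# for any concrete VectorStore implementation, and the query string is
# illustrative; only as_retriever and the Runnable `invoke` come from above.
def _example_retrieval(docsearch: VectorStore) -> List[Document]:
    retriever = docsearch.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.8},
    )
    return retriever.invoke("What did the paper conclude?")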
@@ -0,0 +1,845 @@
"""Base implementation for tools or skills."""
from __future__ import annotations

import asyncio
import inspect
import warnings
from abc import abstractmethod
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union

from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.callbacks.manager import (
    AsyncCallbackManager,
    AsyncCallbackManagerForToolRun,
    CallbackManager,
    CallbackManagerForToolRun,
    Callbacks,
)
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    create_model,
    root_validator,
    validate_arguments,
)
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable


class SchemaAnnotationError(TypeError):
    """Raised when 'args_schema' is missing or has an incorrect type annotation."""


def _create_subset_model(
    name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
    """Create a pydantic model with only a subset of model's fields."""
    fields = {}
    for field_name in field_names:
        field = model.__fields__[field_name]
        fields[field_name] = (field.outer_type_, field.field_info)
    return create_model(name, **fields)  # type: ignore


def _get_filtered_args(
    inferred_model: Type[BaseModel],
    func: Callable,
) -> dict:
    """Get the arguments from a function's signature."""
    schema = inferred_model.schema()["properties"]
    valid_keys = signature(func).parameters
    return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}


class _SchemaConfig:
    """Configuration for the pydantic model."""

    extra: Any = Extra.forbid
    arbitrary_types_allowed: bool = True


def create_schema_from_function(
    model_name: str,
    func: Callable,
) -> Type[BaseModel]:
    """Create a pydantic schema from a function's signature.
    Args:
        model_name: Name to assign to the generated pydantic schema
        func: Function to generate the schema from
    Returns:
        A pydantic model with the same arguments as the function
    """
    # https://docs.pydantic.dev/latest/usage/validation_decorator/
    validated = validate_arguments(func, config=_SchemaConfig)  # type: ignore
    inferred_model = validated.model  # type: ignore
    if "run_manager" in inferred_model.__fields__:
        del inferred_model.__fields__["run_manager"]
    if "callbacks" in inferred_model.__fields__:
        del inferred_model.__fields__["callbacks"]
    # Pydantic adds placeholder virtual fields we need to strip
    valid_properties = _get_filtered_args(inferred_model, func)
    return _create_subset_model(
        f"{model_name}Schema", inferred_model, list(valid_properties)
    )
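
# A quick sketch of what create_schema_from_function infers, using a made-up
# function; the field names and types below follow directly from the signature.
def _multiply(a: int, b: int = 2) -> int:
    return a * b


_MultiplySchema = create_schema_from_function("Multiply", _multiply)
# _MultiplySchema is a pydantic model with a required field `a: int` and an
# optional field `b: int = 2`; validation errors surface before the tool runs.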
|
||||
|
||||
class ToolException(Exception):
|
||||
"""An optional exception that tool throws when execution error occurs.
|
||||
|
||||
When this exception is thrown, the agent will not stop working,
|
||||
but will handle the exception according to the handle_tool_error
|
||||
variable of the tool, and the processing result will be returned
|
||||
to the agent as observation, and printed in red on the console.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
"""Create the definition of the new tool class."""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
args_schema_type = cls.__annotations__.get("args_schema", None)
|
||||
|
||||
if args_schema_type is not None:
|
||||
if args_schema_type is None or args_schema_type == BaseModel:
|
||||
# Throw errors for common mis-annotations.
|
||||
# TODO: Use get_args / get_origin and fully
|
||||
# specify valid annotations.
|
||||
typehint_mandate = """
|
||||
class ChildTool(BaseTool):
|
||||
...
|
||||
args_schema: Type[BaseModel] = SchemaClass
|
||||
..."""
|
||||
name = cls.__name__
|
||||
raise SchemaAnnotationError(
|
||||
f"Tool definition for {name} must include valid type annotations"
|
||||
f" for argument 'args_schema' to behave as expected.\n"
|
||||
f"Expected annotation of 'Type[BaseModel]'"
|
||||
f" but got '{args_schema_type}'.\n"
|
||||
f"Expected class looks like:\n"
|
||||
f"{typehint_mandate}"
|
||||
)
|
||||
|
||||
name: str
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
description: str
|
||||
"""Used to tell the model how/when/why to use the tool.
|
||||
|
||||
You can provide few-shot examples as a part of the description.
|
||||
"""
|
||||
args_schema: Optional[Type[BaseModel]] = None
|
||||
"""Pydantic model class to validate and parse the tool's input arguments."""
|
||||
return_direct: bool = False
|
||||
"""Whether to return the tool's output directly. Setting this to True means
|
||||
|
||||
that after the tool is called, the AgentExecutor will stop looping.
|
||||
"""
|
||||
verbose: bool = False
|
||||
"""Whether to log the tool's progress."""
|
||||
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
"""Callbacks to be called during tool execution."""
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""Deprecated. Please use callbacks instead."""
|
||||
tags: Optional[List[str]] = None
|
||||
"""Optional list of tags associated with the tool. Defaults to None
|
||||
These tags will be associated with each call to this tool,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a tool with its use case.
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
"""Optional metadata associated with the tool. Defaults to None
|
||||
This metadata will be associated with each call to this tool,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a tool with its use case.
|
||||
"""
|
||||
|
||||
handle_tool_error: Optional[
|
||||
Union[bool, str, Callable[[ToolException], str]]
|
||||
] = False
|
||||
"""Handle the content of the ToolException thrown."""
|
||||
|
||||
class Config(Serializable.Config):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def is_single_input(self) -> bool:
|
||||
"""Whether the tool only accepts a single input."""
|
||||
keys = {k for k in self.args if k != "kwargs"}
|
||||
return len(keys) == 1
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
if self.args_schema is not None:
|
||||
return self.args_schema.schema()["properties"]
|
||||
else:
|
||||
schema = create_schema_from_function(self.name, self._run)
|
||||
return schema.schema()["properties"]
|
||||
|
||||
# --- Runnable ---
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
"""The tool's input schema."""
|
||||
if self.args_schema is not None:
|
||||
return self.args_schema
|
||||
else:
|
||||
return create_schema_from_function(self.name, self._run)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Union[str, Dict],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
config = config or {}
|
||||
return self.run(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, Dict],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
config = config or {}
|
||||
return await self.arun(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# --- Tool ---
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""Convert tool input to pydantic model."""
|
||||
input_args = self.args_schema
|
||||
if isinstance(tool_input, str):
|
||||
if input_args is not None:
|
||||
key_ = next(iter(input_args.__fields__.keys()))
|
||||
input_args.validate({key_: tool_input})
|
||||
return tool_input
|
||||
else:
|
||||
if input_args is not None:
|
||||
result = input_args.parse_obj(tool_input)
|
||||
return {k: v for k, v in result.dict().items() if k in tool_input}
|
||||
return tool_input
|
||||
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use the tool.
|
||||
|
||||
Add run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
to child implementations to enable tracing,
|
||||
"""
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously.
|
||||
|
||||
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
|
||||
to child implementations to enable tracing,
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
partial(self._run, **kwargs),
|
||||
*args,
|
||||
)
|
||||
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||
# For backwards compatibility, if run_input is a string,
|
||||
# pass as a positional argument.
|
||||
if isinstance(tool_input, str):
|
||||
return (tool_input,), {}
|
||||
else:
|
||||
return (), tool_input
|
||||
|
||||
def run(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool."""
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
verbose_,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
# TODO: maybe also pass through run_manager is _run supports kwargs
|
||||
new_arg_supported = signature(self._run).parameters.get("run_manager")
|
||||
run_manager = callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
color=start_color,
|
||||
name=run_name,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
||||
observation = (
|
||||
self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
|
||||
if new_arg_supported
|
||||
else self._run(*tool_args, **tool_kwargs)
|
||||
)
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
run_manager.on_tool_error(e)
|
||||
raise e
|
||||
elif isinstance(self.handle_tool_error, bool):
|
||||
if e.args:
|
||||
observation = e.args[0]
|
||||
else:
|
||||
observation = "Tool execution error"
|
||||
elif isinstance(self.handle_tool_error, str):
|
||||
observation = self.handle_tool_error
|
||||
elif callable(self.handle_tool_error):
|
||||
observation = self.handle_tool_error(e)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"or callable. Received: {self.handle_tool_error}"
|
||||
)
|
||||
run_manager.on_tool_end(
|
||||
str(observation), color="red", name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
run_manager.on_tool_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_tool_end(
|
||||
str(observation), color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
|
||||
    async def arun(
        self,
        tool_input: Union[str, Dict],
        verbose: Optional[bool] = None,
        start_color: Optional[str] = "green",
        color: Optional[str] = "green",
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the tool asynchronously."""
        parsed_input = self._parse_input(tool_input)
        if not self.verbose and verbose is not None:
            verbose_ = verbose
        else:
            verbose_ = self.verbose
        callback_manager = AsyncCallbackManager.configure(
            callbacks,
            self.callbacks,
            verbose_,
            tags,
            self.tags,
            metadata,
            self.metadata,
        )
        new_arg_supported = signature(self._arun).parameters.get("run_manager")
        run_manager = await callback_manager.on_tool_start(
            {"name": self.name, "description": self.description},
            tool_input if isinstance(tool_input, str) else str(tool_input),
            color=start_color,
            name=run_name,
            **kwargs,
        )
        try:
            # We then call the tool on the tool input to get an observation
            tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
            observation = (
                await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs)
                if new_arg_supported
                else await self._arun(*tool_args, **tool_kwargs)
            )
        except ToolException as e:
            if not self.handle_tool_error:
                await run_manager.on_tool_error(e)
                raise e
            elif isinstance(self.handle_tool_error, bool):
                if e.args:
                    observation = e.args[0]
                else:
                    observation = "Tool execution error"
            elif isinstance(self.handle_tool_error, str):
                observation = self.handle_tool_error
            elif callable(self.handle_tool_error):
                observation = self.handle_tool_error(e)
            else:
                raise ValueError(
                    f"Got unexpected type of `handle_tool_error`. Expected bool, str "
                    f"or callable. Received: {self.handle_tool_error}"
                )
            await run_manager.on_tool_end(
                str(observation), color="red", name=self.name, **kwargs
            )
            return observation
        except (Exception, KeyboardInterrupt) as e:
            await run_manager.on_tool_error(e)
            raise e
        else:
            await run_manager.on_tool_end(
                str(observation), color=color, name=self.name, **kwargs
            )
            return observation

    def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
        """Make tool callable."""
        return self.run(tool_input, callbacks=callbacks)

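# --- Editor's sketch (not part of the original diff): a minimal BaseTool
# subclass showing how run() dispatches to _run and how handle_tool_error
# turns a ToolException into a plain string observation. The class name and
# inputs are hypothetical; BaseTool and ToolException are defined above in
# this module.
def _example_base_tool_usage() -> None:
    class EchoTool(BaseTool):
        name: str = "echo"
        description: str = "Echo the input back, or fail on 'boom'."

        def _run(self, query: str) -> str:
            if query == "boom":
                raise ToolException("simulated failure")
            return query

    echo = EchoTool(handle_tool_error=True)
    assert echo.run("hi") == "hi"
    # The ToolException is caught by run() and returned as the observation.
    assert echo.run("boom") == "simulated failure"

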
class Tool(BaseTool):
    """Tool that takes in function or coroutine directly."""

    description: str = ""
    func: Optional[Callable[..., str]]
    """The function to run when the tool is called."""
    coroutine: Optional[Callable[..., Awaitable[str]]] = None
    """The asynchronous version of the function."""

    # --- Runnable ---

    async def ainvoke(
        self,
        input: Union[str, Dict],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Any:
        if not self.coroutine:
            # If the tool does not implement async, fall back to default implementation
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.invoke, input, config, **kwargs)
            )

        return await super().ainvoke(input, config, **kwargs)

    # --- Tool ---

    @property
    def args(self) -> dict:
        """The tool's input arguments."""
        if self.args_schema is not None:
            return self.args_schema.schema()["properties"]
        # For backwards compatibility, if the function signature is ambiguous,
        # assume it takes a single string input.
        return {"tool_input": {"type": "string"}}

    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
        """Convert tool input to pydantic model."""
        args, kwargs = super()._to_args_and_kwargs(tool_input)
        # For backwards compatibility. The tool must be run with a single input
        all_args = list(args) + list(kwargs.values())
        if len(all_args) != 1:
            raise ToolException(
                f"Too many arguments to single-input tool {self.name}."
                f" Args: {all_args}"
            )
        return tuple(all_args), {}

    def _run(
        self,
        *args: Any,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool."""
        if self.func:
            new_argument_supported = signature(self.func).parameters.get("callbacks")
            return (
                self.func(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else self.func(*args, **kwargs)
            )
        raise NotImplementedError("Tool does not support sync")

    async def _arun(
        self,
        *args: Any,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool asynchronously."""
        if self.coroutine:
            new_argument_supported = signature(self.coroutine).parameters.get(
                "callbacks"
            )
            return (
                await self.coroutine(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else await self.coroutine(*args, **kwargs)
            )
        else:
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self._run, run_manager=run_manager, **kwargs), *args
            )

    # TODO: this is for backwards compatibility, remove in future
    def __init__(
        self, name: str, func: Optional[Callable], description: str, **kwargs: Any
    ) -> None:
        """Initialize tool."""
        super(Tool, self).__init__(
            name=name, func=func, description=description, **kwargs
        )

    @classmethod
    def from_function(
        cls,
        func: Optional[Callable],
        name: str,  # We keep these required to support backwards compatibility
        description: str,
        return_direct: bool = False,
        args_schema: Optional[Type[BaseModel]] = None,
        coroutine: Optional[
            Callable[..., Awaitable[Any]]
        ] = None,  # This is last for compatibility, but should be after func
        **kwargs: Any,
    ) -> Tool:
        """Initialize tool from a function."""
        if func is None and coroutine is None:
            raise ValueError("Function and/or coroutine must be provided")
        return cls(
            name=name,
            func=func,
            coroutine=coroutine,
            description=description,
            return_direct=return_direct,
            args_schema=args_schema,
            **kwargs,
        )

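# --- Editor's sketch (not part of the original diff): Tool.from_function
# wraps a plain single-input callable. No docstring is required here because
# a description is passed explicitly; the function name below is hypothetical.
def _example_tool_from_function() -> None:
    def reverse(text: str) -> str:
        return text[::-1]

    reverse_tool = Tool.from_function(
        func=reverse,
        name="reverse",
        description="Reverse the input string.",
    )
    # Single-input tools accept a bare string.
    assert reverse_tool.run("abc") == "cba"

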
class StructuredTool(BaseTool):
    """Tool that can operate on any number of inputs."""

    description: str = ""
    args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
    """The input arguments' schema."""
    func: Optional[Callable[..., Any]]
    """The function to run when the tool is called."""
    coroutine: Optional[Callable[..., Awaitable[Any]]] = None
    """The asynchronous version of the function."""

    # --- Runnable ---

    async def ainvoke(
        self,
        input: Union[str, Dict],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Any:
        if not self.coroutine:
            # If the tool does not implement async, fall back to default implementation
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.invoke, input, config, **kwargs)
            )

        return await super().ainvoke(input, config, **kwargs)

    # --- Tool ---

    @property
    def args(self) -> dict:
        """The tool's input arguments."""
        return self.args_schema.schema()["properties"]

    def _run(
        self,
        *args: Any,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool."""
        if self.func:
            new_argument_supported = signature(self.func).parameters.get("callbacks")
            return (
                self.func(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else self.func(*args, **kwargs)
            )
        raise NotImplementedError("Tool does not support sync")

    async def _arun(
        self,
        *args: Any,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool asynchronously."""
        if self.coroutine:
            new_argument_supported = signature(self.coroutine).parameters.get(
                "callbacks"
            )
            return (
                await self.coroutine(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else await self.coroutine(*args, **kwargs)
            )
        return await asyncio.get_running_loop().run_in_executor(
            None,
            partial(self._run, run_manager=run_manager, **kwargs),
            *args,
        )

    @classmethod
    def from_function(
        cls,
        func: Optional[Callable] = None,
        coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
        return_direct: bool = False,
        args_schema: Optional[Type[BaseModel]] = None,
        infer_schema: bool = True,
        **kwargs: Any,
    ) -> StructuredTool:
        """Create tool from a given function.

        A classmethod that helps to create a tool from a function.

        Args:
            func: The function from which to create a tool
            coroutine: The async function from which to create a tool
            name: The name of the tool. Defaults to the function name
            description: The description of the tool. Defaults to the function docstring
            return_direct: Whether to return the result directly or as a callback
            args_schema: The schema of the tool's input arguments
            infer_schema: Whether to infer the schema from the function's signature
            **kwargs: Additional arguments to pass to the tool

        Returns:
            The tool

        Examples:

            .. code-block:: python

                def add(a: int, b: int) -> int:
                    \"\"\"Add two numbers\"\"\"
                    return a + b

                tool = StructuredTool.from_function(add)
                tool.run({"a": 1, "b": 2})  # 3
        """
        if func is not None:
            source_function = func
        elif coroutine is not None:
            source_function = coroutine
        else:
            raise ValueError("Function and/or coroutine must be provided")
        name = name or source_function.__name__
        description = description or source_function.__doc__
        if description is None:
            raise ValueError(
                "Function must have a docstring if description not provided."
            )

        # Description example:
        # search_api(query: str) - Searches the API for the query.
        sig = signature(source_function)
        description = f"{name}{sig} - {description.strip()}"
        _args_schema = args_schema
        if _args_schema is None and infer_schema:
            _args_schema = create_schema_from_function(f"{name}Schema", source_function)
        return cls(
            name=name,
            func=func,
            coroutine=coroutine,
            args_schema=_args_schema,
            description=description,
            return_direct=return_direct,
            **kwargs,
        )

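# --- Editor's sketch (not part of the original diff): StructuredTool infers
# an args schema from the function signature, so the resulting tool takes a
# dict input with one key per parameter. The function below is hypothetical.
def _example_structured_tool() -> None:
    def multiply(a: int, b: int) -> int:
        """Multiply two numbers."""
        return a * b

    mul_tool = StructuredTool.from_function(multiply)
    assert mul_tool.run({"a": 3, "b": 4}) == 12
    # The inferred schema is exposed via the `args` property.
    assert set(mul_tool.args) == {"a", "b"}

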
def tool(
    *args: Union[str, Callable, Runnable],
    return_direct: bool = False,
    args_schema: Optional[Type[BaseModel]] = None,
    infer_schema: bool = True,
) -> Callable:
    """Make tools out of functions, can be used with or without arguments.

    Args:
        *args: The arguments to the tool.
        return_direct: Whether to return directly from the tool rather
            than continuing the agent loop.
        args_schema: optional argument schema for user to specify
        infer_schema: Whether to infer the schema of the arguments from
            the function's signature. This also makes the resultant tool
            accept a dictionary input to its `run()` function.

    Requires:
        - Function must be of type (str) -> str
        - Function must have a docstring

    Examples:
        .. code-block:: python

            @tool
            def search_api(query: str) -> str:
                \"\"\"Searches the API for the query.\"\"\"
                return

            @tool("search", return_direct=True)
            def search_api(query: str) -> str:
                \"\"\"Searches the API for the query.\"\"\"
                return
    """

    def _make_with_name(tool_name: str) -> Callable:
        def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
            if isinstance(dec_func, Runnable):
                runnable = dec_func

                if runnable.input_schema.schema().get("type") != "object":
                    raise ValueError("Runnable must have an object schema.")

                async def ainvoke_wrapper(
                    callbacks: Optional[Callbacks] = None, **kwargs: Any
                ) -> Any:
                    return await runnable.ainvoke(kwargs, {"callbacks": callbacks})

                def invoke_wrapper(
                    callbacks: Optional[Callbacks] = None, **kwargs: Any
                ) -> Any:
                    return runnable.invoke(kwargs, {"callbacks": callbacks})

                coroutine = ainvoke_wrapper
                func = invoke_wrapper
                schema: Optional[Type[BaseModel]] = runnable.input_schema
                description = repr(runnable)
            elif inspect.iscoroutinefunction(dec_func):
                coroutine = dec_func
                func = None
                schema = args_schema
                description = None
            else:
                coroutine = None
                func = dec_func
                schema = args_schema
                description = None

            if infer_schema or args_schema is not None:
                return StructuredTool.from_function(
                    func,
                    coroutine,
                    name=tool_name,
                    description=description,
                    return_direct=return_direct,
                    args_schema=schema,
                    infer_schema=infer_schema,
                )
            # If someone doesn't want a schema applied, we must treat it as
            # a simple string->string function
            if func.__doc__ is None:
                raise ValueError(
                    "Function must have a docstring if "
                    "description not provided and infer_schema is False."
                )
            return Tool(
                name=tool_name,
                func=func,
                description=f"{tool_name} tool",
                return_direct=return_direct,
                coroutine=coroutine,
            )

        return _make_tool

    if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable):
        return _make_with_name(args[0])(args[1])
    elif len(args) == 1 and isinstance(args[0], str):
        # if the argument is a string, then we use the string as the tool name
        # Example usage: @tool("search", return_direct=True)
        return _make_with_name(args[0])
    elif len(args) == 1 and callable(args[0]):
        # if the argument is a function, then we use the function name as the tool name
        # Example usage: @tool
        return _make_with_name(args[0].__name__)(args[0])
    elif len(args) == 0:
        # if there are no arguments, then we use the function name as the tool name
        # Example usage: @tool(return_direct=True)
        def _partial(func: Callable[[str], str]) -> BaseTool:
            return _make_with_name(func.__name__)(func)

        return _partial
    else:
        raise ValueError("Too many arguments for tool decorator")
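# --- Editor's sketch (not part of the original diff): the @tool decorator in
# two of its supported spellings. The import path assumes this file lands at
# langchain_core/tools.py; the function bodies are hypothetical.
from langchain_core.tools import tool


@tool
def search_api(query: str) -> str:
    """Search the API for the query."""
    return f"results for {query}"


@tool("search", return_direct=True)
def search_api_named(query: str) -> str:
    """Search the API for the query."""
    return f"results for {query}"


assert search_api.name == "search_api"  # name inferred from the function
assert search_api_named.name == "search"  # name taken from the decorator arg
assert search_api.run({"query": "cats"}) == "results for cats"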
@ -0,0 +1,38 @@
"""
**Utility functions** for LangChain.

These functions do not depend on any other LangChain module.
"""

from langchain_core.utils.formatting import StrictFormatter, formatter
from langchain_core.utils.input import (
    get_bolded_text,
    get_color_mapping,
    get_colored_text,
    print_text,
)
from langchain_core.utils.utils import (
    check_package_version,
    convert_to_secret_str,
    get_pydantic_field_names,
    guard_import,
    mock_now,
    raise_for_status_with_text,
    xor_args,
)

__all__ = [
    "StrictFormatter",
    "check_package_version",
    "convert_to_secret_str",
    "formatter",
    "get_bolded_text",
    "get_color_mapping",
    "get_colored_text",
    "get_pydantic_field_names",
    "guard_import",
    "mock_now",
    "print_text",
    "raise_for_status_with_text",
    "xor_args",
]
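# --- Editor's sketch (not part of the original diff): everything in __all__
# above is importable from the package root.
from langchain_core.utils import formatter, get_bolded_text

assert formatter.format("Hello {name}", name="core") == "Hello core"
print(get_bolded_text("done"))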
@ -0,0 +1,38 @@
"""Utilities for formatting strings."""
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union


class StrictFormatter(Formatter):
    """A subclass of formatter that checks for extra keys."""

    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Check to see if extra parameters are passed."""
        extra = set(kwargs).difference(used_args)
        if extra:
            raise KeyError(extra)

    def vformat(
        self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
    ) -> str:
        """Check that no positional arguments are provided."""
        if len(args) > 0:
            raise ValueError(
                "No arguments should be provided, "
                "everything should be passed as keyword arguments."
            )
        return super().vformat(format_string, args, kwargs)

    def validate_input_variables(
        self, format_string: str, input_variables: List[str]
    ) -> None:
        """Validate that the format string can be formatted with the input variables."""
        dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
        super().format(format_string, **dummy_inputs)


formatter = StrictFormatter()
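# --- Editor's sketch (not part of the original diff): StrictFormatter rejects
# both unused keyword arguments and any positional arguments, which
# str.format would silently accept.
from langchain_core.utils.formatting import formatter

assert formatter.format("{name}", name="x") == "x"
try:
    formatter.format("{name}", name="x", extra="unused")
except KeyError:
    pass  # extra keys raise instead of being ignored
try:
    formatter.format("{0}", "positional")
except ValueError:
    pass  # positional arguments are disallowed outright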
@ -0,0 +1,42 @@
"""Handle chained inputs."""
from typing import Dict, List, Optional, TextIO

_TEXT_COLOR_MAPPING = {
    "blue": "36;1",
    "yellow": "33;1",
    "pink": "38;5;200",
    "green": "32;1",
    "red": "31;1",
}


def get_color_mapping(
    items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
    """Get mapping for items to a supported color."""
    colors = list(_TEXT_COLOR_MAPPING.keys())
    if excluded_colors is not None:
        colors = [c for c in colors if c not in excluded_colors]
    color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
    return color_mapping


def get_colored_text(text: str, color: str) -> str:
    """Get colored text."""
    color_str = _TEXT_COLOR_MAPPING[color]
    return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"


def get_bolded_text(text: str) -> str:
    """Get bolded text."""
    return f"\033[1m{text}\033[0m"


def print_text(
    text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
    """Print text with highlighting and no end characters."""
    text_to_print = get_colored_text(text, color) if color else text
    print(text_to_print, end=end, file=file)
    if file:
        file.flush()  # ensure all printed content is written to the file
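# --- Editor's sketch (not part of the original diff): get_color_mapping
# cycles through the five known colors, skipping any exclusions; print_text
# wraps the text in the matching ANSI escape codes.
from langchain_core.utils.input import get_color_mapping, print_text

mapping = get_color_mapping(["step1", "step2"], excluded_colors=["blue"])
assert mapping["step1"] == "yellow"  # first non-excluded color in the table
print_text("hello", color=mapping["step1"], end="\n")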
@ -0,0 +1,54 @@
"""Utilities for loading configurations from LangChain Hub."""

import os
import re
import tempfile
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Optional, Set, TypeVar, Union
from urllib.parse import urljoin

import requests

DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
URL_BASE = os.environ.get(
    "LANGCHAIN_HUB_URL_BASE",
    "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
)
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")

T = TypeVar("T")


def try_load_from_hub(
    path: Union[str, Path],
    loader: Callable[[str], T],
    valid_prefix: str,
    valid_suffixes: Set[str],
    **kwargs: Any,
) -> Optional[T]:
    """Load configuration from hub. Returns None if path is not a hub path."""
    if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
        return None
    ref, remote_path_str = match.groups()
    ref = ref[1:] if ref else DEFAULT_REF
    remote_path = Path(remote_path_str)
    if remote_path.parts[0] != valid_prefix:
        return None
    if remote_path.suffix[1:] not in valid_suffixes:
        raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")

    # Using Path with URLs is not recommended, because on Windows
    # the backslash is used as the path separator, which can cause issues
    # when working with URLs that use forward slashes as the path separator.
    # Instead, use PurePosixPath to ensure that forward slashes are used as the
    # path separator, regardless of the operating system.
    full_url = urljoin(URL_BASE.format(ref=ref), str(PurePosixPath(remote_path)))

    r = requests.get(full_url, timeout=5)
    if r.status_code != 200:
        raise ValueError(f"Could not find file at {full_url}")
    with tempfile.TemporaryDirectory() as tmpdirname:
        file = Path(tmpdirname) / remote_path.name
        with open(file, "wb") as f:
            f.write(r.content)
        return loader(str(file), **kwargs)
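# --- Editor's sketch (not part of the original diff): hub paths look like
# "lc://prompts/foo.json" or "lc@<ref>://prompts/foo.json"; anything else
# makes try_load_from_hub return None. This only exercises the regex, with no
# network access. Module path langchain_core.utils.loading is assumed.
from langchain_core.utils.loading import HUB_PATH_RE

match = HUB_PATH_RE.match("lc@v0.1://prompts/foo.json")
assert match is not None
assert match.group("ref") == "@v0.1"
assert match.group("path") == "prompts/foo.json"
assert HUB_PATH_RE.match("not-a-hub-path") is None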
@ -0,0 +1,14 @@
"""Utilities for tests."""


def get_pydantic_major_version() -> int:
    """Get the major version of Pydantic."""
    try:
        import pydantic

        return int(pydantic.__version__.split(".")[0])
    except ImportError:
        return 0


PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
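# --- Editor's sketch (not part of the original diff): the computed constant
# lets callers branch on the installed Pydantic major version without
# importing pydantic themselves. Module path langchain_core.utils.pydantic is
# assumed.
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION

if PYDANTIC_MAJOR_VERSION >= 2:
    print("running against Pydantic v2")
else:
    print("running against Pydantic v1 (or pydantic missing)")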
@ -0,0 +1,180 @@
"""Generic utility functions."""
import contextlib
import datetime
import functools
import importlib
import warnings
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

from packaging.version import parse
from requests import HTTPError, Response

from langchain_core.pydantic_v1 import SecretStr


def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
    """Validate specified keyword args are mutually exclusive."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Validate exactly one arg in each group is not None."""
            counts = [
                sum(1 for arg in arg_group if kwargs.get(arg) is not None)
                for arg_group in arg_groups
            ]
            invalid_groups = [i for i, count in enumerate(counts) if count != 1]
            if invalid_groups:
                invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups]
                raise ValueError(
                    "Exactly one argument in each of the following"
                    " groups must be defined:"
                    f" {', '.join(invalid_group_names)}"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator


def raise_for_status_with_text(response: Response) -> None:
    """Raise an error with the response text."""
    try:
        response.raise_for_status()
    except HTTPError as e:
        raise ValueError(response.text) from e


@contextlib.contextmanager
def mock_now(dt_value):  # type: ignore
    """Context manager for mocking out datetime.now() in unit tests.

    Example:
    with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
        assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)
    """

    class MockDateTime(datetime.datetime):
        """Mock datetime.datetime.now() with a fixed datetime."""

        @classmethod
        def now(cls):  # type: ignore
            # Create a copy of dt_value.
            return datetime.datetime(
                dt_value.year,
                dt_value.month,
                dt_value.day,
                dt_value.hour,
                dt_value.minute,
                dt_value.second,
                dt_value.microsecond,
                dt_value.tzinfo,
            )

    real_datetime = datetime.datetime
    datetime.datetime = MockDateTime
    try:
        yield datetime.datetime
    finally:
        datetime.datetime = real_datetime


def guard_import(
    module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any:
    """Dynamically imports a module and raises a helpful exception if the module is not
    installed."""
    try:
        module = importlib.import_module(module_name, package)
    except ImportError:
        raise ImportError(
            f"Could not import {module_name} python package. "
            f"Please install it with `pip install {pip_name or module_name}`."
        )
    return module

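def _example_guard_import_and_mock_now() -> None:
    # Editor's sketch (not part of the original diff): guard_import returns
    # the module on success; on failure it raises an ImportError that names
    # the pip package to install (use pip_name when it differs from the
    # import name, e.g. guard_import("yaml", pip_name="pyyaml")).
    json_module = guard_import("json")
    assert json_module.dumps({"ok": True}) == '{"ok": true}'

    # Inside the mock_now context manager, datetime.datetime.now() returns
    # the fixed value; the real class is restored on exit.
    fixed = datetime.datetime(2011, 2, 3, 10, 11)
    with mock_now(fixed):
        assert datetime.datetime.now() == fixed
    assert datetime.datetime.now() != fixed

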
def check_package_version(
    package: str,
    lt_version: Optional[str] = None,
    lte_version: Optional[str] = None,
    gt_version: Optional[str] = None,
    gte_version: Optional[str] = None,
) -> None:
    """Check the version of a package."""
    imported_version = parse(version(package))
    if lt_version is not None and imported_version >= parse(lt_version):
        raise ValueError(
            f"Expected {package} version to be < {lt_version}. Received "
            f"{imported_version}."
        )
    if lte_version is not None and imported_version > parse(lte_version):
        raise ValueError(
            f"Expected {package} version to be <= {lte_version}. Received "
            f"{imported_version}."
        )
    if gt_version is not None and imported_version <= parse(gt_version):
        raise ValueError(
            f"Expected {package} version to be > {gt_version}. Received "
            f"{imported_version}."
        )
    if gte_version is not None and imported_version < parse(gte_version):
        raise ValueError(
            f"Expected {package} version to be >= {gte_version}. Received "
            f"{imported_version}."
        )


def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
    """Get field names, including aliases, for a pydantic class.

    Args:
        pydantic_cls: Pydantic class.
    """
    all_required_field_names = set()
    for field in pydantic_cls.__fields__.values():
        all_required_field_names.add(field.name)
        if field.has_alias:
            all_required_field_names.add(field.alias)
    return all_required_field_names


def build_extra_kwargs(
    extra_kwargs: Dict[str, Any],
    values: Dict[str, Any],
    all_required_field_names: Set[str],
) -> Dict[str, Any]:
    """Build extra kwargs from values and extra_kwargs.

    Args:
        extra_kwargs: Extra kwargs passed in by user.
        values: Values passed in by user.
        all_required_field_names: All required field names for the pydantic class.
    """
    for field_name in list(values):
        if field_name in extra_kwargs:
            raise ValueError(f"Found {field_name} supplied twice.")
        if field_name not in all_required_field_names:
            warnings.warn(
                f"""WARNING! {field_name} is not a default parameter.
                {field_name} was transferred to model_kwargs.
                Please confirm that {field_name} is what you intended."""
            )
            extra_kwargs[field_name] = values.pop(field_name)

    invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys())
    if invalid_model_kwargs:
        raise ValueError(
            f"Parameters {invalid_model_kwargs} should be specified explicitly. "
            f"Instead they were passed in as part of the `model_kwargs` parameter."
        )

    return extra_kwargs


def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
    """Convert a string to a SecretStr if needed."""
    if isinstance(value, SecretStr):
        return value
    return SecretStr(value)
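# --- Editor's sketch (not part of the original diff): xor_args enforces
# mutually exclusive keyword arguments, and convert_to_secret_str normalizes
# credentials to SecretStr. The connect() function is hypothetical; note the
# decorator only inspects keyword arguments, so the group members must be
# passed by name.
from langchain_core.pydantic_v1 import SecretStr
from langchain_core.utils import convert_to_secret_str, xor_args


@xor_args(("api_key", "api_key_path"))
def connect(api_key=None, api_key_path=None) -> str:
    return "connected"


assert connect(api_key="abc") == "connected"
try:
    connect(api_key="abc", api_key_path="/tmp/key")
except ValueError:
    pass  # exactly one argument in the group may be set

secret = convert_to_secret_str("s3cr3t")
assert isinstance(secret, SecretStr)
assert str(secret) == "**********"  # value is masked when printed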
File diff suppressed because it is too large
@ -0,0 +1,85 @@
[tool.poetry]
name = "langchain-core"
version = "0.0.1"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"


[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
pydantic = ">=1,<3"
langsmith = "~0.0.63"
tenacity = "^8.1.0"
jsonpatch = "^1.33"

[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"

[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
types-pyyaml = "^6.0.12.2"
types-requests = "^2.28.11.5"

[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
setuptools = "^67.6.1"

[tool.poetry.group.test.dependencies]
# The only dependencies that should be added are
# dependencies used for running tests (e.g., pytest, freezegun, responses).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"


[tool.poetry.group.test_integration]
optional = true
dependencies = {}

[tool.ruff]
select = [
    "E",  # pycodestyle
    "F",  # pyflakes
    "I",  # isort
]

[tool.mypy]
ignore_missing_imports = "True"
disallow_untyped_defs = "True"
exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]

[tool.coverage.run]
omit = [
    "tests/*",
]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config       any warnings encountered while parsing the `pytest`
#                       section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused    Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
    "requires: mark tests as requiring a specific library",
    "asyncio: mark tests as requiring asyncio",
    "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
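# --- Editor's sketch (not part of the original diff): with --strict-markers
# enabled above, only registered marks are legal; a test that needs an
# optional library declares it with the custom `requires` mark. The test
# below is hypothetical.
import pytest


@pytest.mark.requires("numpy")
def test_uses_numpy() -> None:
    import numpy as np

    assert np.ones(3).sum() == 3.0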
Some files were not shown because too many files have changed in this diff