Separate out langchain_core package (#13577)

Co-authored-by: Nuno Campos <nuno@boringbits.io>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
pull/13649/head
Harrison Chase committed 7 months ago (via GitHub)
parent 4eec47b191
commit d82cbf5e76

@ -7,6 +7,10 @@ on:
required: true
type: string
description: "From which folder this pipeline executes"
langchain-core-location:
required: false
type: string
description: "Relative path to the langchain core library folder"
env:
POETRY_VERSION: "1.6.1"
@ -40,6 +44,14 @@ jobs:
shell: bash
run: poetry install --with=test_integration
- name: Install langchain core editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-core-location }}
env:
LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
run: |
poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"
- name: Check integration tests compile
shell: bash
run: poetry run pytest -m compile tests/integration_tests

@ -11,6 +11,10 @@ on:
required: false
type: string
description: "Relative path to the langchain library folder"
langchain-core-location:
required: false
type: string
description: "Relative path to the langchain core library folder"
env:
POETRY_VERSION: "1.6.1"
@ -76,7 +80,15 @@ jobs:
env:
LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
run: |
pip install -e "$LANGCHAIN_LOCATION"
poetry run pip install -e "$LANGCHAIN_LOCATION"
- name: Install langchain core editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-core-location }}
env:
LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
run: |
poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"
- name: Get .mypy_cache to speed up mypy
uses: actions/cache@v3

@ -7,6 +7,14 @@ on:
required: true
type: string
description: "From which folder this pipeline executes"
langchain-location:
required: false
type: string
description: "Relative path to the langchain library folder"
langchain-core-location:
required: false
type: string
description: "Relative path to the langchain core library folder"
env:
POETRY_VERSION: "1.6.1"
@ -40,6 +48,22 @@ jobs:
shell: bash
run: poetry install
- name: Install langchain editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-location }}
env:
LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
run: |
poetry run pip install -e "$LANGCHAIN_LOCATION"
- name: Install langchain core editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-core-location }}
env:
LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
run: |
poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"
- name: Install the opposite major version of pydantic
# If normal tests use pydantic v1, here we'll use v2, and vice versa.
shell: bash

@ -7,6 +7,14 @@ on:
required: true
type: string
description: "From which folder this pipeline executes"
langchain-location:
required: false
type: string
description: "Relative path to the langchain library folder"
langchain-core-location:
required: false
type: string
description: "Relative path to the langchain core library folder"
env:
POETRY_VERSION: "1.6.1"
@ -40,9 +48,26 @@ jobs:
shell: bash
run: poetry install
- name: Install langchain editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-location }}
env:
LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
run: |
poetry run pip install -e "$LANGCHAIN_LOCATION"
- name: Install langchain core editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-core-location }}
env:
LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
run: |
poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"
- name: Run core tests
shell: bash
run: make test
run: |
make test
- name: Ensure the tests did not create any additional files
shell: bash

@ -36,6 +36,7 @@ jobs:
./.github/workflows/_lint.yml
with:
working-directory: libs/langchain
langchain-core-location: ../core
secrets: inherit
test:
@ -43,6 +44,7 @@ jobs:
./.github/workflows/_test.yml
with:
working-directory: libs/langchain
langchain-core-location: ../core
secrets: inherit
compile-integration-tests:
@ -50,6 +52,7 @@ jobs:
./.github/workflows/_compile_integration_test.yml
with:
working-directory: libs/langchain
langchain-core-location: ../core
secrets: inherit
pydantic-compatibility:
@ -57,6 +60,7 @@ jobs:
./.github/workflows/_pydantic_compatibility.yml
with:
working-directory: libs/langchain
langchain-core-location: ../core
secrets: inherit
extended-tests:
@ -89,6 +93,11 @@ jobs:
echo "Running extended tests, installing dependencies with poetry..."
poetry install -E extended_testing
- name: Install langchain core editable
shell: bash
run: |
poetry run pip install -e ../core
- name: Run extended tests
run: make extended_tests

@ -0,0 +1,52 @@
---
name: libs/core CI
on:
push:
branches: [ master ]
pull_request:
paths:
- '.github/actions/poetry_setup/action.yml'
- '.github/tools/**'
- '.github/workflows/_lint.yml'
- '.github/workflows/_test.yml'
- '.github/workflows/_pydantic_compatibility.yml'
- '.github/workflows/langchain_core_ci.yml'
- 'libs/core/**'
workflow_dispatch: # Allows triggering the workflow manually from the GitHub UI
# If another push to the same PR or branch happens while this workflow is still running,
# cancel the earlier run in favor of the next run.
#
# There's no point in testing an outdated version of the code. GitHub only allows
# a limited number of job runners to be active at the same time, so it's better to cancel
# pointless jobs early so that more useful jobs can run sooner.
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env:
POETRY_VERSION: "1.6.1"
WORKDIR: "libs/core"
jobs:
lint:
uses:
./.github/workflows/_lint.yml
with:
working-directory: libs/core
secrets: inherit
test:
uses:
./.github/workflows/_test.yml
with:
working-directory: libs/core
secrets: inherit
pydantic-compatibility:
uses:
./.github/workflows/_pydantic_compatibility.yml
with:
working-directory: libs/core
secrets: inherit

@ -0,0 +1,13 @@
---
name: libs/core Release
on:
workflow_dispatch: # Allows triggering the workflow manually from the GitHub UI
jobs:
release:
uses:
./.github/workflows/_release.yml
with:
working-directory: libs/core
secrets: inherit

@ -36,6 +36,7 @@ jobs:
with:
working-directory: libs/experimental
langchain-location: ../langchain
langchain-core-location: ../core
secrets: inherit
test:
@ -43,6 +44,8 @@ jobs:
./.github/workflows/_test.yml
with:
working-directory: libs/experimental
langchain-location: ../langchain
langchain-core-location: ../core
secrets: inherit
compile-integration-tests:
@ -88,6 +91,7 @@ jobs:
echo "Editably installing langchain outside of poetry, to avoid messing up lockfile..."
poetry run pip install -e ../langchain
poetry run pip install -e ../core
- name: Run tests
run: make test

@ -13,8 +13,10 @@ HERE = Path(__file__).parent
PKG_DIR = ROOT_DIR / "libs" / "langchain" / "langchain"
EXP_DIR = ROOT_DIR / "libs" / "experimental" / "langchain_experimental"
CORE_DIR = ROOT_DIR / "libs" / "core" / "langchain_core"
WRITE_FILE = HERE / "api_reference.rst"
EXP_WRITE_FILE = HERE / "experimental_api_reference.rst"
CORE_WRITE_FILE = HERE / "core_api_reference.rst"
ClassKind = Literal["TypedDict", "Regular", "Pydantic", "enum"]
@ -292,6 +294,17 @@ def _document_langchain_experimental() -> None:
def _document_langchain_core() -> None:
"""Document the langchain_core package."""
# Generate core_api_reference.rst
core_members = _load_package_modules(CORE_DIR)
core_doc = ".. _core_api_reference:\n\n" + _construct_doc(
"langchain_core", core_members
)
with open(CORE_WRITE_FILE, "w") as f:
f.write(core_doc)
def _document_langchain() -> None:
"""Document the main langchain package."""
# load top level module members
lc_members = _load_package_modules(PKG_DIR)
@ -306,7 +319,6 @@ def _document_langchain_core() -> None:
"agents.output_parsers": agents["output_parsers"],
"agents.format_scratchpad": agents["format_scratchpad"],
"tools.render": tools["render"],
"schema.runnable": schema["runnable"],
}
)
@ -318,8 +330,9 @@ def _document_langchain_core() -> None:
def main() -> None:
"""Generate the reference.rst file for each package."""
_document_langchain_core()
_document_langchain()
_document_langchain_experimental()
_document_langchain_core()
if __name__ == "__main__":

@ -34,6 +34,9 @@
<li class="nav-item">
<a class="sk-nav-link nav-link" href="{{ pathto('api_reference') }}">API</a>
</li>
<li class="nav-item">
<a class="sk-nav-link nav-link" href="{{ pathto('core_api_reference') }}">Core</a>
</li>
<li class="nav-item">
<a class="sk-nav-link nav-link" href="{{ pathto('experimental_api_reference') }}">Experimental</a>
</li>

@ -0,0 +1,54 @@
.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
test:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
test_watch:
poetry run ptw --snapshot-update --now . -- -x tests/unit_tests
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/core --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint lint_diff:
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
######################
# HELP
######################
help:
@echo '----'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'
@echo 'test_watch - run unit tests in watch mode'

@ -0,0 +1 @@
# langchain-core

@ -0,0 +1,7 @@
from importlib import metadata
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""

@ -0,0 +1,26 @@
"""Helper functions for managing the LangChain API.
This module is only relevant for LangChain developers, not for users.
.. warning::
This module and its submodules are for internal use only. Do not use them
in your own code. We may change the API at any time with no warning.
"""
from .deprecation import (
LangChainDeprecationWarning,
deprecated,
suppress_langchain_deprecation_warning,
surface_langchain_deprecation_warnings,
warn_deprecated,
)
__all__ = [
"deprecated",
"LangChainDeprecationWarning",
"suppress_langchain_deprecation_warning",
"surface_langchain_deprecation_warnings",
"warn_deprecated",
]
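For illustration, a hedged sketch of calling the exported warn_deprecated helper directly; the names and version numbers below are placeholders, not real APIs:

from langchain_core._api import warn_deprecated

warn_deprecated(
    "0.0.1",                   # placeholder release
    name="old_helper",         # hypothetical deprecated symbol
    alternative="new_helper",  # hypothetical replacement
    removal="0.2.0",           # placeholder removal version
)
# Emits a LangChainDeprecationWarning with a standardized message.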

@ -0,0 +1,341 @@
"""Helper functions for deprecating parts of the LangChain API.
This module was adapted from matplotlib's _api/deprecation.py module:
https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecation.py
.. warning::
This module is for internal use only. Do not use it in your own code.
We may change the API at any time with no warning.
"""
import contextlib
import functools
import inspect
import warnings
from typing import Any, Callable, Generator, Type, TypeVar
class LangChainDeprecationWarning(DeprecationWarning):
"""A class for issuing deprecation warnings for LangChain users."""
class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
"""A class for issuing deprecation warnings for LangChain users."""
# PUBLIC API
T = TypeVar("T", Type, Callable)
def deprecated(
since: str,
*,
message: str = "",
name: str = "",
alternative: str = "",
pending: bool = False,
obj_type: str = "",
addendum: str = "",
removal: str = "",
) -> Callable[[T], T]:
"""Decorator to mark a function, a class, or a property as deprecated.
When deprecating a classmethod, a staticmethod, or a property, the
``@deprecated`` decorator should go *under* ``@classmethod`` and
``@staticmethod`` (i.e., `deprecated` should directly decorate the
underlying callable), but *over* ``@property``.
When deprecating a class ``C`` intended to be used as a base class in a
multiple inheritance hierarchy, ``C`` *must* define an ``__init__`` method
(if ``C`` instead inherited its ``__init__`` from its own base class, then
``@deprecated`` would mess up ``__init__`` inheritance when installing its
own (deprecation-emitting) ``C.__init__``).
Parameters are the same as for `warn_deprecated`, except that *obj_type*
defaults to 'class' if decorating a class, 'attribute' if decorating a
property, and 'function' otherwise.
Arguments:
since : str
The release at which this API became deprecated.
message : str, optional
Override the default deprecation message. The %(since)s,
%(name)s, %(alternative)s, %(obj_type)s, %(addendum)s,
and %(removal)s format specifiers will be replaced by the
values of the respective arguments passed to this function.
name : str, optional
The name of the deprecated object.
alternative : str, optional
An alternative API that the user may use in place of the
deprecated API. The deprecation warning will tell the user
about this alternative if provided.
pending : bool, optional
If True, uses a PendingDeprecationWarning instead of a
DeprecationWarning. Cannot be used together with removal.
obj_type : str, optional
The object type being deprecated.
addendum : str, optional
Additional text appended directly to the final message.
removal : str, optional
The expected removal version. With the default (an empty
string), a removal version is automatically computed from
since. Set to other Falsy values to not schedule a removal
date. Cannot be used together with pending.
Examples
--------
.. code-block:: python
@deprecated('1.4.0')
def the_function_to_deprecate():
pass
"""
def deprecate(
obj: T,
*,
_obj_type: str = obj_type,
_name: str = name,
_message: str = message,
_alternative: str = alternative,
_pending: bool = pending,
_addendum: str = addendum,
) -> T:
"""Implementation of the decorator returned by `deprecated`."""
if isinstance(obj, type):
if not _obj_type:
_obj_type = "class"
wrapped = obj.__init__ # type: ignore
_name = _name or obj.__name__
old_doc = obj.__doc__
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
"""Finalize the deprecation of a class."""
try:
obj.__doc__ = new_doc
except AttributeError: # Can't set on some extension objects.
pass
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
wrapper
)
return obj
elif isinstance(obj, property):
if not _obj_type:
_obj_type = "attribute"
wrapped = None
_name = _name or obj.fget.__name__
old_doc = obj.__doc__
class _deprecated_property(type(obj)): # type: ignore
"""A deprecated property."""
def __get__(self, instance, owner=None): # type: ignore
if instance is not None or owner is not None:
emit_warning()
return super().__get__(instance, owner)
def __set__(self, instance, value): # type: ignore
if instance is not None:
emit_warning()
return super().__set__(instance, value)
def __delete__(self, instance): # type: ignore
if instance is not None:
emit_warning()
return super().__delete__(instance)
def __set_name__(self, owner, set_name): # type: ignore
nonlocal _name
if _name == "<lambda>":
_name = set_name
def finalize(_: Any, new_doc: str) -> Any: # type: ignore
"""Finalize the property."""
return _deprecated_property(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
)
else:
if not _obj_type:
_obj_type = "function"
wrapped = obj
_name = _name or obj.__name__ # type: ignore
old_doc = wrapped.__doc__
def finalize( # type: ignore
wrapper: Callable[..., Any], new_doc: str
) -> T:
"""Wrap the wrapped function using the wrapper and update the docstring.
Args:
wrapper: The wrapper function.
new_doc: The new docstring.
Returns:
The wrapped function.
"""
wrapper = functools.wraps(wrapped)(wrapper)
wrapper.__doc__ = new_doc
return wrapper
def emit_warning() -> None:
"""Emit the warning."""
warn_deprecated(
since,
message=_message,
name=_name,
alternative=_alternative,
pending=_pending,
obj_type=_obj_type,
addendum=_addendum,
removal=removal,
)
def warning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Wrapper for the original wrapped callable that emits a warning.
Args:
*args: The positional arguments to the function.
**kwargs: The keyword arguments to the function.
Returns:
The return value of the function being wrapped.
"""
emit_warning()
return wrapped(*args, **kwargs)
old_doc = inspect.cleandoc(old_doc or "").strip("\n")
if not old_doc:
new_doc = "[*Deprecated*]"
else:
new_doc = f"[*Deprecated*] {old_doc}"
# Modify the docstring to include a deprecation notice.
notes_header = "\nNotes\n-----"
components = [
message,
f"Use {alternative} instead." if alternative else "",
addendum,
]
details = " ".join([component.strip() for component in components if component])
new_doc += (
f"\n{notes_header if notes_header not in old_doc else ''}\n"
f".. deprecated:: {since}\n"
f" {details}"
)
return finalize(warning_emitting_wrapper, new_doc)
return deprecate
@contextlib.contextmanager
def suppress_langchain_deprecation_warning() -> Generator[None, None, None]:
"""Context manager to suppress LangChainDeprecationWarning."""
with warnings.catch_warnings():
warnings.simplefilter("ignore", LangChainDeprecationWarning)
warnings.simplefilter("ignore", LangChainPendingDeprecationWarning)
yield
def warn_deprecated(
since: str,
*,
message: str = "",
name: str = "",
alternative: str = "",
pending: bool = False,
obj_type: str = "",
addendum: str = "",
removal: str = "",
) -> None:
"""Display a standardized deprecation.
Arguments:
since : str
The release at which this API became deprecated.
message : str, optional
Override the default deprecation message. The %(since)s,
%(name)s, %(alternative)s, %(obj_type)s, %(addendum)s,
and %(removal)s format specifiers will be replaced by the
values of the respective arguments passed to this function.
name : str, optional
The name of the deprecated object.
alternative : str, optional
An alternative API that the user may use in place of the
deprecated API. The deprecation warning will tell the user
about this alternative if provided.
pending : bool, optional
If True, uses a PendingDeprecationWarning instead of a
DeprecationWarning. Cannot be used together with removal.
obj_type : str, optional
The object type being deprecated.
addendum : str, optional
Additional text appended directly to the final message.
removal : str, optional
The expected removal version. With the default (an empty
string), a removal version is automatically computed from
since. Set to other Falsy values to not schedule a removal
date. Cannot be used together with pending.
"""
if pending and removal:
raise ValueError("A pending deprecation cannot have a scheduled removal")
if not pending:
if not removal:
removal = f"in {removal}" if removal else "within ?? minor releases"
raise NotImplementedError(
f"Need to determine which default deprecation schedule to use. "
f"{removal}"
)
else:
removal = f"in {removal}"
if not message:
message = ""
if obj_type:
message += f"The {obj_type} `{name}`"
else:
message += f"`{name}`"
if pending:
message += " will be deprecated in a future version"
else:
message += f" was deprecated in LangChain {since}"
if removal:
message += f" and will be removed {removal}"
if alternative:
message += f". Use {alternative} instead."
if addendum:
message += f" {addendum}"
warning_cls = (
LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning
)
warning = warning_cls(message)
warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2)
def surface_langchain_deprecation_warnings() -> None:
"""Unmute LangChain deprecation warnings."""
warnings.filterwarnings(
"default",
category=LangChainPendingDeprecationWarning,
)
warnings.filterwarnings(
"default",
category=LangChainDeprecationWarning,
)
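To make the decorator concrete, a minimal usage sketch built only from the functions defined above (the function name and version numbers are illustrative):

import warnings

from langchain_core._api.deprecation import (
    deprecated,
    suppress_langchain_deprecation_warning,
)


@deprecated("0.0.1", alternative="shiny_new_function", removal="0.2.0")
def crusty_old_function(x: int) -> int:
    """Double a number."""
    return 2 * x


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    crusty_old_function(3)  # emits a LangChainDeprecationWarning
print(len(caught))  # 1

with suppress_langchain_deprecation_warning():
    crusty_old_function(3)  # the warning is silenced inside this context

Note that either removal or pending=True must be supplied: with neither, warn_deprecated raises NotImplementedError rather than guessing a removal schedule.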

@ -0,0 +1,36 @@
import os
from pathlib import Path
from typing import Optional, Union
HERE = Path(__file__).parent
# Get directory of langchain package
PACKAGE_DIR = HERE.parent
SEPARATOR = os.sep
def get_relative_path(
file: Union[Path, str], *, relative_to: Path = PACKAGE_DIR
) -> str:
"""Get the path of the file as a relative path to the package directory."""
if isinstance(file, str):
file = Path(file)
return str(file.relative_to(relative_to))
def as_import_path(
file: Union[Path, str],
*,
suffix: Optional[str] = None,
relative_to: Path = PACKAGE_DIR,
) -> str:
"""Path of the file as a LangChain import exclude langchain top namespace."""
if isinstance(file, str):
file = Path(file)
path = get_relative_path(file, relative_to=relative_to)
if file.is_file():
path = path[: -len(file.suffix)]
import_path = path.replace(SEPARATOR, ".")
if suffix:
import_path += "." + suffix
return import_path
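A hedged sketch of how these helpers behave, assuming the module lands at langchain_core/_api/path.py as the diff layout suggests; the paths are illustrative, and the file suffix is only stripped when the file actually exists on disk:

from pathlib import Path

from langchain_core._api.path import as_import_path, get_relative_path  # assumed module path

pkg_root = Path("/repo/libs/core/langchain_core")     # illustrative location
module_file = pkg_root / "callbacks" / "base.py"

get_relative_path(module_file, relative_to=pkg_root)  # "callbacks/base.py" (OS-specific separator)
as_import_path(module_file, relative_to=pkg_root)     # "callbacks.base" when the file exists
as_import_path(pkg_root / "callbacks", suffix="base", relative_to=pkg_root)  # "callbacks.base"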

@ -0,0 +1,598 @@
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
from uuid import UUID
from tenacity import RetryCallState
from langchain_core.schema.agent import AgentAction, AgentFinish
from langchain_core.schema.document import Document
from langchain_core.schema.messages import BaseMessage
from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult
class RetrieverManagerMixin:
"""Mixin for Retriever callbacks."""
def on_retriever_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever errors."""
def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever ends running."""
class LLMManagerMixin:
"""Mixin for LLM callbacks."""
def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on new LLM token. Only available when streaming is enabled.
Args:
token (str): The new token.
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
containing content and other information.
"""
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM ends running."""
def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM errors."""
class ChainManagerMixin:
"""Mixin for chain callbacks."""
def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain ends running."""
def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain errors."""
def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent action."""
def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent end."""
class ToolManagerMixin:
"""Mixin for tool callbacks."""
def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool ends running."""
def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool errors."""
class CallbackManagerMixin:
"""Mixin for callback manager."""
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM starts running."""
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)
def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when chain starts running."""
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when tool starts running."""
class RunManagerMixin:
"""Mixin for run manager."""
def on_text(
self,
text: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on arbitrary text."""
def on_retry(
self,
retry_state: RetryCallState,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on a retry event."""
class BaseCallbackHandler(
LLMManagerMixin,
ChainManagerMixin,
ToolManagerMixin,
RetrieverManagerMixin,
CallbackManagerMixin,
RunManagerMixin,
):
"""Base callback handler that handles callbacks from LangChain."""
raise_error: bool = False
run_inline: bool = False
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return False
@property
def ignore_retry(self) -> bool:
"""Whether to ignore retry callbacks."""
return False
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return False
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return False
@property
def ignore_retriever(self) -> bool:
"""Whether to ignore retriever callbacks."""
return False
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return False
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that handles callbacks from LangChain."""
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)
async def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
async def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM ends running."""
async def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM errors."""
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when chain starts running."""
async def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain ends running."""
async def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain errors."""
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when tool starts running."""
async def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
async def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
async def on_text(
self,
text: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on arbitrary text."""
async def on_retry(
self,
retry_state: RetryCallState,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on a retry event."""
async def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on agent action."""
async def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on agent end."""
async def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever start."""
async def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever end."""
async def on_retriever_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever error."""
T = TypeVar("T", bound="BaseCallbackManager")
class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager that handles callbacks from LangChain."""
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
parent_run_id: Optional[UUID] = None,
*,
tags: Optional[List[str]] = None,
inheritable_tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
self.inheritable_handlers: List[BaseCallbackHandler] = (
inheritable_handlers or []
)
self.parent_run_id: Optional[UUID] = parent_run_id
self.tags = tags or []
self.inheritable_tags = inheritable_tags or []
self.metadata = metadata or {}
self.inheritable_metadata = inheritable_metadata or {}
def copy(self: T) -> T:
"""Copy the callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
@property
def is_async(self) -> bool:
"""Whether the callback manager is async."""
return False
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Add a handler to the callback manager."""
if handler not in self.handlers:
self.handlers.append(handler)
if inherit and handler not in self.inheritable_handlers:
self.inheritable_handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
self.inheritable_handlers.remove(handler)
def set_handlers(
self, handlers: List[BaseCallbackHandler], inherit: bool = True
) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = []
self.inheritable_handlers = []
for handler in handlers:
self.add_handler(handler, inherit=inherit)
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Set handler as the only handler on the callback manager."""
self.set_handlers([handler], inherit=inherit)
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
for tag in tags:
if tag in self.tags:
self.remove_tags([tag])
self.tags.extend(tags)
if inherit:
self.inheritable_tags.extend(tags)
def remove_tags(self, tags: List[str]) -> None:
for tag in tags:
self.tags.remove(tag)
self.inheritable_tags.remove(tag)
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
self.metadata.update(metadata)
if inherit:
self.inheritable_metadata.update(metadata)
def remove_metadata(self, keys: List[str]) -> None:
for key in keys:
self.metadata.pop(key)
self.inheritable_metadata.pop(key)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
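As a usage sketch (not part of the diff), a custom handler only overrides the hooks it needs; everything else is inherited as a no-op. The handler below is illustrative:

from typing import Any
from uuid import uuid4

from langchain_core.callbacks.base import BaseCallbackHandler, BaseCallbackManager


class TokenCountingHandler(BaseCallbackHandler):
    """Illustrative handler that counts streamed LLM tokens."""

    def __init__(self) -> None:
        self.tokens = 0

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.tokens += 1


handler = TokenCountingHandler()
manager = BaseCallbackManager(handlers=[handler])
manager.add_tags(["demo"])
# A model integration would normally drive the handler; calling the hook
# directly here just shows it firing.
handler.on_llm_new_token("hello", run_id=uuid4())
print(handler.tokens)  # 1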

File diff suppressed because it is too large.

@ -0,0 +1,97 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.schema import AgentAction, AgentFinish, LLMResult
from langchain_core.utils.input import print_text
class StdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def __init__(self, color: Optional[str] = None) -> None:
"""Initialize callback handler."""
self.color = color
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing."""
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
print_text(action.log, color=color or self.color)
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
if observation_prefix is not None:
print_text(f"\n{observation_prefix}")
print_text(output, color=color or self.color)
if llm_prefix is not None:
print_text(f"\n{llm_prefix}")
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Any,
) -> None:
"""Run when agent ends."""
print_text(text, color=color or self.color, end=end)
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
print_text(finish.log, color=color or self.color, end="\n")
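A brief hedged sketch of the handler in isolation (the module path langchain_core/callbacks/stdout.py is assumed from the diff layout); in practice it is passed in a callbacks list, but driving the hooks by hand shows the output it produces:

from langchain_core.callbacks.stdout import StdOutCallbackHandler  # assumed module path

handler = StdOutCallbackHandler(color="green")
handler.on_chain_start({"name": "DemoChain"}, {"input": "hi"})  # prints the "Entering new ..." banner
handler.on_text("intermediate step", end="\n")                  # colored text via print_text
handler.on_chain_end({"output": "done"})                        # prints "Finished chain."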

@ -0,0 +1,67 @@
"""Callback Handler streams to stdout on new llm token."""
import sys
from typing import Any, Dict, List
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.schema import AgentAction, AgentFinish, LLMResult
from langchain_core.schema.messages import BaseMessage
class StreamingStdOutCallbackHandler(BaseCallbackHandler):
"""Callback handler for streaming. Only works with LLMs that support streaming."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
sys.stdout.write(token)
sys.stdout.flush()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""

@ -0,0 +1,537 @@
"""Base interfaces for tracing runs."""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Union, cast
from uuid import UUID
from tenacity import RetryCallState
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.load.dump import dumpd
from langchain_core.schema.document import Document
from langchain_core.schema.output import (
ChatGeneration,
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
logger = logging.getLogger(__name__)
class TracerException(Exception):
"""Base class for exceptions in tracers module."""
class BaseTracer(BaseCallbackHandler, ABC):
"""Base interface for tracers."""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.run_map: Dict[str, Run] = {}
@staticmethod
def _add_child_run(
parent_run: Run,
child_run: Run,
) -> None:
"""Add child run to a chain run or tool run."""
parent_run.child_runs.append(child_run)
@abstractmethod
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
def _start_trace(self, run: Run) -> None:
"""Start a trace for a run."""
if run.parent_run_id:
parent_run = self.run_map.get(str(run.parent_run_id))
if parent_run:
self._add_child_run(parent_run, run)
parent_run.child_execution_order = max(
parent_run.child_execution_order, run.child_execution_order
)
else:
logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
self.run_map[str(run.id)] = run
self._on_run_create(run)
def _end_trace(self, run: Run) -> None:
"""End a trace for a run."""
if not run.parent_run_id:
self._persist_run(run)
else:
parent_run = self.run_map.get(str(run.parent_run_id))
if parent_run is None:
logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
elif (
run.child_execution_order is not None
and parent_run.child_execution_order is not None
and run.child_execution_order > parent_run.child_execution_order
):
parent_run.child_execution_order = run.child_execution_order
self.run_map.pop(str(run.id))
self._on_run_update(run)
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
"""Get the execution order for a run."""
if parent_run_id is None:
return 1
parent_run = self.run_map.get(parent_run_id)
if parent_run is None:
logger.debug(f"Parent run with UUID {parent_run_id} not found.")
return 1
if parent_run.child_execution_order is None:
raise TracerException(
f"Parent run with UUID {parent_run_id} has no child execution order."
)
return parent_run.child_execution_order + 1
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for an LLM run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
llm_run = Run(
id=run_id,
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"prompts": prompts},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
execution_order=execution_order,
child_execution_order=execution_order,
run_type="llm",
tags=tags or [],
name=name,
)
self._start_trace(llm_run)
self._on_llm_start(llm_run)
return llm_run
def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Run:
"""Run on new LLM token. Only available when streaming is enabled."""
if not run_id:
raise TracerException("No run_id provided for on_llm_new_token callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException(f"No LLM Run found to be traced for {run_id}")
event_kwargs: Dict[str, Any] = {"token": token}
if chunk:
event_kwargs["chunk"] = chunk
llm_run.events.append(
{
"name": "new_token",
"time": datetime.utcnow(),
"kwargs": event_kwargs,
},
)
self._on_llm_new_token(llm_run, token, chunk)
return llm_run
def on_retry(
self,
retry_state: RetryCallState,
*,
run_id: UUID,
**kwargs: Any,
) -> Run:
if not run_id:
raise TracerException("No run_id provided for on_retry callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None:
raise TracerException("No Run found to be traced for on_retry")
retry_d: Dict[str, Any] = {
"slept": retry_state.idle_for,
"attempt": retry_state.attempt_number,
}
if retry_state.outcome is None:
retry_d["outcome"] = "N/A"
elif retry_state.outcome.failed:
retry_d["outcome"] = "failed"
exception = retry_state.outcome.exception()
retry_d["exception"] = str(exception)
retry_d["exception_type"] = exception.__class__.__name__
else:
retry_d["outcome"] = "success"
retry_d["result"] = str(retry_state.outcome.result())
llm_run.events.append(
{
"name": "retry",
"time": datetime.utcnow(),
"kwargs": retry_d,
},
)
return llm_run
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for an LLM run."""
if not run_id:
raise TracerException("No run_id provided for on_llm_end callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException(f"No LLM Run found to be traced for {run_id}")
llm_run.outputs = response.dict()
for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations):
output_generation = llm_run.outputs["generations"][i][j]
if "message" in output_generation:
output_generation["message"] = dumpd(
cast(ChatGeneration, generation).message
)
llm_run.end_time = datetime.utcnow()
llm_run.events.append({"name": "end", "time": llm_run.end_time})
self._end_trace(llm_run)
self._on_llm_end(llm_run)
return llm_run
def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Handle an error for an LLM run."""
if not run_id:
raise TracerException("No run_id provided for on_llm_error callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException(f"No LLM Run found to be traced for {run_id}")
llm_run.error = repr(error)
llm_run.end_time = datetime.utcnow()
llm_run.events.append({"name": "error", "time": llm_run.end_time})
self._end_trace(llm_run)
self._on_llm_error(llm_run)
return llm_run
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
run_type: Optional[str] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for a chain run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
chain_run = Run(
id=run_id,
parent_run_id=parent_run_id,
serialized=serialized,
inputs=inputs if isinstance(inputs, dict) else {"input": inputs},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
run_type=run_type or "chain",
name=name,
tags=tags or [],
)
self._start_trace(chain_run)
self._on_chain_start(chain_run)
return chain_run
def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""End a trace for a chain run."""
if not run_id:
raise TracerException("No run_id provided for on_chain_end callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None:
raise TracerException(f"No chain Run found to be traced for {run_id}")
chain_run.outputs = (
outputs if isinstance(outputs, dict) else {"output": outputs}
)
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "end", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
self._end_trace(chain_run)
self._on_chain_end(chain_run)
return chain_run
def on_chain_error(
self,
error: BaseException,
*,
inputs: Optional[Dict[str, Any]] = None,
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Handle an error for a chain run."""
if not run_id:
raise TracerException("No run_id provided for on_chain_error callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None:
raise TracerException(f"No chain Run found to be traced for {run_id}")
chain_run.error = repr(error)
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "error", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
self._end_trace(chain_run)
self._on_chain_error(chain_run)
return chain_run
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for a tool run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
tool_run = Run(
id=run_id,
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"input": input_str},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
run_type="tool",
tags=tags or [],
name=name,
)
self._start_trace(tool_run)
self._on_tool_start(tool_run)
return tool_run
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for a tool run."""
if not run_id:
raise TracerException("No run_id provided for on_tool_end callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != "tool":
raise TracerException(f"No tool Run found to be traced for {run_id}")
tool_run.outputs = {"output": output}
tool_run.end_time = datetime.utcnow()
tool_run.events.append({"name": "end", "time": tool_run.end_time})
self._end_trace(tool_run)
self._on_tool_end(tool_run)
return tool_run
def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Handle an error for a tool run."""
if not run_id:
raise TracerException("No run_id provided for on_tool_error callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != "tool":
raise TracerException(f"No tool Run found to be traced for {run_id}")
tool_run.error = repr(error)
tool_run.end_time = datetime.utcnow()
tool_run.events.append({"name": "error", "time": tool_run.end_time})
self._end_trace(tool_run)
self._on_tool_error(tool_run)
return tool_run
def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Run when Retriever starts running."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
retrieval_run = Run(
id=run_id,
name=name or "Retriever",
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"query": query},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
execution_order=execution_order,
child_execution_order=execution_order,
tags=tags,
child_runs=[],
run_type="retriever",
)
self._start_trace(retrieval_run)
self._on_retriever_start(retrieval_run)
return retrieval_run
def on_retriever_error(
self,
error: BaseException,
*,
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Run when Retriever errors."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_error callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException(f"No retriever Run found to be traced for {run_id}")
retrieval_run.error = repr(error)
retrieval_run.end_time = datetime.utcnow()
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
self._end_trace(retrieval_run)
self._on_retriever_error(retrieval_run)
return retrieval_run
def on_retriever_end(
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> Run:
"""Run when Retriever ends running."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_end callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException(f"No retriever Run found to be traced for {run_id}")
retrieval_run.outputs = {"documents": documents}
retrieval_run.end_time = datetime.utcnow()
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
self._end_trace(retrieval_run)
self._on_retriever_end(retrieval_run)
return retrieval_run
def __deepcopy__(self, memo: dict) -> BaseTracer:
"""Deepcopy the tracer."""
return self
def __copy__(self) -> BaseTracer:
"""Copy the tracer."""
return self
def _on_run_create(self, run: Run) -> None:
"""Process a run upon creation."""
def _on_run_update(self, run: Run) -> None:
"""Process a run upon update."""
def _on_llm_start(self, run: Run) -> None:
"""Process the LLM Run upon start."""
def _on_llm_new_token(
self,
run: Run,
token: str,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
) -> None:
"""Process new LLM token."""
def _on_llm_end(self, run: Run) -> None:
"""Process the LLM Run."""
def _on_llm_error(self, run: Run) -> None:
"""Process the LLM Run upon error."""
def _on_chain_start(self, run: Run) -> None:
"""Process the Chain Run upon start."""
def _on_chain_end(self, run: Run) -> None:
"""Process the Chain Run."""
def _on_chain_error(self, run: Run) -> None:
"""Process the Chain Run upon error."""
def _on_tool_start(self, run: Run) -> None:
"""Process the Tool Run upon start."""
def _on_tool_end(self, run: Run) -> None:
"""Process the Tool Run."""
def _on_tool_error(self, run: Run) -> None:
"""Process the Tool Run upon error."""
def _on_chat_model_start(self, run: Run) -> None:
"""Process the Chat Model Run upon start."""
def _on_retriever_start(self, run: Run) -> None:
"""Process the Retriever Run upon start."""
def _on_retriever_end(self, run: Run) -> None:
"""Process the Retriever Run."""
def _on_retriever_error(self, run: Run) -> None:
"""Process the Retriever Run upon error."""

@ -0,0 +1,223 @@
"""A tracer that runs evaluators over completed runs."""
from __future__ import annotations
import logging
import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from uuid import UUID
import langsmith
from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults
from langchain_core.callbacks import manager
from langchain_core.callbacks.tracers import langchain as langchain_tracer
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.langchain import _get_executor
from langchain_core.callbacks.tracers.schemas import Run
logger = logging.getLogger(__name__)
_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet()
def wait_for_all_evaluators() -> None:
"""Wait for all tracers to finish."""
global _TRACERS
for tracer in list(_TRACERS):
if tracer is not None:
tracer.wait_for_futures()
class EvaluatorCallbackHandler(BaseTracer):
"""A tracer that runs a run evaluator whenever a run is persisted.
Parameters
----------
evaluators : Sequence[RunEvaluator]
The run evaluators to apply to all top level runs.
client : LangSmith Client, optional
The LangSmith client instance to use for evaluating the runs.
If not specified, a new instance will be created.
example_id : Union[UUID, str], optional
The example ID to be associated with the runs.
project_name : str, optional
The LangSmith project name to organize eval chain runs under.
Attributes
----------
example_id : Union[UUID, None]
The example ID associated with the runs.
client : Client
The LangSmith client instance used for evaluating the runs.
evaluators : Sequence[RunEvaluator]
The sequence of run evaluators to be executed.
executor : ThreadPoolExecutor
The thread pool executor used for running the evaluators.
futures : Set[Future]
The set of futures representing the running evaluators.
skip_unfinished : bool
Whether to skip runs that are not finished or raised
an error.
project_name : Optional[str]
The LangSmith project name to organize eval chain runs under.
"""
name = "evaluator_callback_handler"
def __init__(
self,
evaluators: Sequence[langsmith.RunEvaluator],
client: Optional[langsmith.Client] = None,
example_id: Optional[Union[UUID, str]] = None,
skip_unfinished: bool = True,
project_name: Optional[str] = "evaluators",
max_concurrency: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.client = client or langchain_tracer.get_client()
self.evaluators = evaluators
if max_concurrency is None:
self.executor: Optional[ThreadPoolExecutor] = _get_executor()
elif max_concurrency > 0:
self.executor = ThreadPoolExecutor(max_workers=max_concurrency)
weakref.finalize(
self,
lambda: cast(ThreadPoolExecutor, self.executor).shutdown(wait=True),
)
else:
self.executor = None
self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
self.skip_unfinished = skip_unfinished
self.project_name = project_name
self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {}
self.lock = threading.Lock()
global _TRACERS
_TRACERS.add(self)
def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
"""Evaluate the run in the project.
Parameters
----------
run : Run
The run to be evaluated.
evaluator : RunEvaluator
The evaluator to use for evaluating the run.
"""
try:
if self.project_name is None:
eval_result = self.client.evaluate_run(run, evaluator)
eval_results = [eval_result]
with manager.tracing_v2_enabled(
project_name=self.project_name, tags=["eval"], client=self.client
) as cb:
reference_example = (
self.client.read_example(run.reference_example_id)
if run.reference_example_id
else None
)
evaluation_result = evaluator.evaluate_run(
# This is a subclass, but the type checker flags it for some reason
run, # type: ignore
example=reference_example,
)
eval_results = self._log_evaluation_feedback(
evaluation_result,
run,
source_run_id=cb.latest_run.id if cb.latest_run else None,
)
except Exception as e:
logger.error(
f"Error evaluating run {run.id} with "
f"{evaluator.__class__.__name__}: {repr(e)}",
exc_info=True,
)
raise e
example_id = str(run.reference_example_id)
with self.lock:
for res in eval_results:
run_id = (
str(getattr(res, "target_run_id"))
if hasattr(res, "target_run_id")
else str(run.id)
)
self.logged_eval_results.setdefault((run_id, example_id), []).append(
res
)
def _select_eval_results(
self,
results: Union[EvaluationResult, EvaluationResults],
) -> List[EvaluationResult]:
if isinstance(results, EvaluationResult):
results_ = [results]
elif isinstance(results, dict) and "results" in results:
results_ = cast(List[EvaluationResult], results["results"])
else:
raise TypeError(
f"Invalid evaluation result type {type(results)}."
" Expected EvaluationResult or EvaluationResults."
)
return results_
def _log_evaluation_feedback(
self,
evaluator_response: Union[EvaluationResult, EvaluationResults],
run: Run,
source_run_id: Optional[UUID] = None,
) -> List[EvaluationResult]:
results = self._select_eval_results(evaluator_response)
for res in results:
source_info_: Dict[str, Any] = {}
if res.evaluator_info:
source_info_ = {**res.evaluator_info, **source_info_}
run_id_ = (
getattr(res, "target_run_id")
if hasattr(res, "target_run_id") and res.target_run_id is not None
else run.id
)
self.client.create_feedback(
run_id_,
res.key,
score=res.score,
value=res.value,
comment=res.comment,
correction=res.correction,
source_info=source_info_,
source_run_id=res.source_run_id or source_run_id,
feedback_source_type=langsmith.schemas.FeedbackSourceType.MODEL,
)
return results
def _persist_run(self, run: Run) -> None:
"""Run the evaluator on the run.
Parameters
----------
run : Run
The run to be evaluated.
"""
if self.skip_unfinished and not run.outputs:
logger.debug(f"Skipping unfinished run {run.id}")
return
run_ = run.copy()
run_.reference_example_id = self.example_id
for evaluator in self.evaluators:
if self.executor is None:
self._evaluate_in_project(run_, evaluator)
else:
self.futures.add(
self.executor.submit(self._evaluate_in_project, run_, evaluator)
)
def wait_for_futures(self) -> None:
"""Wait for all futures to complete."""
wait(self.futures)
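A minimal usage sketch, not part of this diff (the module path langchain_core.callbacks.tracers.evaluation is assumed, and `my_evaluator` stands in for any object implementing langsmith.RunEvaluator):

from langchain_core.callbacks.tracers.evaluation import (
    EvaluatorCallbackHandler,
    wait_for_all_evaluators,
)

# `my_evaluator` is a hypothetical langsmith.RunEvaluator instance.
handler = EvaluatorCallbackHandler(evaluators=[my_evaluator], max_concurrency=2)
# chain.invoke(..., config={"callbacks": [handler]})  # attach to any traced call
wait_for_all_evaluators()           # block until every queued evaluation finishes
print(handler.logged_eval_results)  # {(run_id, example_id): [EvaluationResult, ...]}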

@ -0,0 +1,262 @@
"""A Tracer implementation that records to LangChain endpoint."""
from __future__ import annotations
import logging
import weakref
from concurrent.futures import Future, ThreadPoolExecutor, wait
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union
from uuid import UUID
from langsmith import Client
from langsmith import utils as ls_utils
from tenacity import (
Retrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.env import get_runtime_environment
from langchain_core.load.dump import dumpd
from langchain_core.schema.messages import BaseMessage
logger = logging.getLogger(__name__)
_LOGGED = set()
_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet()
_CLIENT: Optional[Client] = None
_EXECUTOR: Optional[ThreadPoolExecutor] = None
def log_error_once(method: str, exception: Exception) -> None:
"""Log an error once."""
global _LOGGED
if (method, type(exception)) in _LOGGED:
return
_LOGGED.add((method, type(exception)))
logger.error(exception)
def wait_for_all_tracers() -> None:
"""Wait for all tracers to finish."""
global _TRACERS
for tracer in list(_TRACERS):
if tracer is not None:
tracer.wait_for_futures()
def get_client() -> Client:
"""Get the client."""
global _CLIENT
if _CLIENT is None:
_CLIENT = Client()
return _CLIENT
def _get_executor() -> ThreadPoolExecutor:
"""Get the executor."""
global _EXECUTOR
if _EXECUTOR is None:
_EXECUTOR = ThreadPoolExecutor()
return _EXECUTOR
def _copy(run: Run) -> Run:
"""Copy a run."""
try:
return run.copy(deep=True)
except TypeError:
# Fallback in case the object contains a lock or other
# non-pickleable object
return run.copy()
class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(
self,
example_id: Optional[Union[UUID, str]] = None,
project_name: Optional[str] = None,
client: Optional[Client] = None,
tags: Optional[List[str]] = None,
use_threading: bool = True,
**kwargs: Any,
) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.project_name = project_name or ls_utils.get_tracer_project()
self.client = client or get_client()
self._futures: weakref.WeakSet[Future] = weakref.WeakSet()
self.tags = tags or []
self.executor = _get_executor() if use_threading else None
self.latest_run: Optional[Run] = None
global _TRACERS
_TRACERS.add(self)
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
chat_model_run = Run(
id=run_id,
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
execution_order=execution_order,
child_execution_order=execution_order,
run_type="llm",
tags=tags,
name=name,
)
self._start_trace(chat_model_run)
self._on_chat_model_start(chat_model_run)
def _persist_run(self, run: Run) -> None:
run_ = run.copy()
run_.reference_example_id = self.example_id
self.latest_run = run_
def get_run_url(self) -> str:
"""Get the LangSmith root run URL"""
if not self.latest_run:
raise ValueError("No traced run found.")
# If this is the first run in a project, the project may not yet be created.
# This method is only really useful for debugging flows, so we will assume
# there is some tolerance for latency.
for attempt in Retrying(
stop=stop_after_attempt(5),
wait=wait_exponential_jitter(),
retry=retry_if_exception_type(ls_utils.LangSmithError),
):
with attempt:
return self.client.get_run_url(
run=self.latest_run, project_name=self.project_name
)
raise ValueError("Failed to get run URL.")
def _get_tags(self, run: Run) -> List[str]:
"""Get combined tags for a run."""
tags = set(run.tags or [])
tags.update(self.tags or [])
return list(tags)
def _persist_run_single(self, run: Run) -> None:
"""Persist a run."""
run_dict = run.dict(exclude={"child_runs"})
run_dict["tags"] = self._get_tags(run)
extra = run_dict.get("extra", {})
extra["runtime"] = get_runtime_environment()
run_dict["extra"] = extra
try:
self.client.create_run(**run_dict, project_name=self.project_name)
except Exception as e:
# Errors are swallowed by the thread executor so we need to log them here
log_error_once("post", e)
raise
def _update_run_single(self, run: Run) -> None:
"""Update a run."""
try:
run_dict = run.dict()
run_dict["tags"] = self._get_tags(run)
self.client.update_run(run.id, **run_dict)
except Exception as e:
# Errors are swallowed by the thread executor so we need to log them here
log_error_once("patch", e)
raise
def _submit(self, function: Callable[[Run], None], run: Run) -> None:
"""Submit a function to the executor."""
if self.executor is None:
function(run)
else:
self._futures.add(self.executor.submit(function, run))
def _on_llm_start(self, run: Run) -> None:
"""Persist an LLM run."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._submit(self._persist_run_single, _copy(run))
def _on_chat_model_start(self, run: Run) -> None:
"""Persist an LLM run."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._submit(self._persist_run_single, _copy(run))
def _on_llm_end(self, run: Run) -> None:
"""Process the LLM Run."""
self._submit(self._update_run_single, _copy(run))
def _on_llm_error(self, run: Run) -> None:
"""Process the LLM Run upon error."""
self._submit(self._update_run_single, _copy(run))
def _on_chain_start(self, run: Run) -> None:
"""Process the Chain Run upon start."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._submit(self._persist_run_single, _copy(run))
def _on_chain_end(self, run: Run) -> None:
"""Process the Chain Run."""
self._submit(self._update_run_single, _copy(run))
def _on_chain_error(self, run: Run) -> None:
"""Process the Chain Run upon error."""
self._submit(self._update_run_single, _copy(run))
def _on_tool_start(self, run: Run) -> None:
"""Process the Tool Run upon start."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._submit(self._persist_run_single, _copy(run))
def _on_tool_end(self, run: Run) -> None:
"""Process the Tool Run."""
self._submit(self._update_run_single, _copy(run))
def _on_tool_error(self, run: Run) -> None:
"""Process the Tool Run upon error."""
self._submit(self._update_run_single, _copy(run))
def _on_retriever_start(self, run: Run) -> None:
"""Process the Retriever Run upon start."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._submit(self._persist_run_single, _copy(run))
def _on_retriever_end(self, run: Run) -> None:
"""Process the Retriever Run."""
self._submit(self._update_run_single, _copy(run))
def _on_retriever_error(self, run: Run) -> None:
"""Process the Retriever Run upon error."""
self._submit(self._update_run_single, _copy(run))
def wait_for_futures(self) -> None:
"""Wait for the given futures to complete."""
wait(self._futures)
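A minimal usage sketch for the new tracer, not part of this diff (the chain call is illustrative only):

from langchain_core.callbacks.tracers.langchain import (
    LangChainTracer,
    wait_for_all_tracers,
)

tracer = LangChainTracer(project_name="my-project", tags=["demo"])
# chain.invoke(..., config={"callbacks": [tracer]})  # attach to any traced call
wait_for_all_tracers()       # flush the queued create/update requests
print(tracer.get_run_url())  # URL of the most recently persisted root run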

@ -0,0 +1,185 @@
from __future__ import annotations
import logging
import os
from typing import Any, Dict, Optional, Union
import requests
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import (
ChainRun,
LLMRun,
Run,
ToolRun,
TracerSession,
TracerSessionV1,
TracerSessionV1Base,
)
from langchain_core.schema.messages import get_buffer_string
from langchain_core.utils import raise_for_status_with_text
logger = logging.getLogger(__name__)
def get_headers() -> Dict[str, Any]:
"""Get the headers for the LangChain API."""
headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
return headers
def _get_endpoint() -> str:
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
class LangChainTracerV1(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self.session: Optional[TracerSessionV1] = None
self._endpoint = _get_endpoint()
self._headers = get_headers()
def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]:
session = self.session or self.load_default_session()
if not isinstance(session, TracerSessionV1):
raise ValueError(
"LangChainTracerV1 is not compatible with"
f" session of type {type(session)}"
)
if run.run_type == "llm":
if "prompts" in run.inputs:
prompts = run.inputs["prompts"]
elif "messages" in run.inputs:
prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]]
else:
raise ValueError("No prompts found in LLM run inputs")
return LLMRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
extra=run.extra,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
error=run.error,
prompts=prompts,
response=run.outputs if run.outputs else None,
)
if run.run_type == "chain":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ChainRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
inputs=run.inputs,
outputs=run.outputs,
error=run.error,
extra=run.extra,
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
child_chain_runs=[
run for run in child_runs if isinstance(run, ChainRun)
],
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
)
if run.run_type == "tool":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ToolRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
action=str(run.serialized),
tool_input=run.inputs.get("input", ""),
output=None if run.outputs is None else run.outputs.get("output"),
error=run.error,
extra=run.extra,
child_chain_runs=[
run for run in child_runs if isinstance(run, ChainRun)
],
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
)
raise ValueError(f"Unknown run type: {run.run_type}")
def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
if isinstance(run, Run):
v1_run = self._convert_to_v1_run(run)
else:
v1_run = run
if isinstance(v1_run, LLMRun):
endpoint = f"{self._endpoint}/llm-runs"
elif isinstance(v1_run, ChainRun):
endpoint = f"{self._endpoint}/chain-runs"
else:
endpoint = f"{self._endpoint}/tool-runs"
try:
response = requests.post(
endpoint,
data=v1_run.json(),
headers=self._headers,
)
raise_for_status_with_text(response)
except Exception as e:
logger.warning(f"Failed to persist run: {e}")
def _persist_session(
self, session_create: TracerSessionV1Base
) -> Union[TracerSessionV1, TracerSession]:
"""Persist a session."""
try:
r = requests.post(
f"{self._endpoint}/sessions",
data=session_create.json(),
headers=self._headers,
)
session = TracerSessionV1(id=r.json()["id"], **session_create.dict())
except Exception as e:
logger.warning(f"Failed to create session, using default session: {e}")
session = TracerSessionV1(id=1, **session_create.dict())
return session
def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1:
"""Load a session from the tracer."""
try:
url = f"{self._endpoint}/sessions"
if session_name:
url += f"?name={session_name}"
r = requests.get(url, headers=self._headers)
tracer_session = TracerSessionV1(**r.json()[0])
except Exception as e:
session_type = "default" if not session_name else session_name
logger.warning(
f"Failed to load {session_type} session, using empty session: {e}"
)
tracer_session = TracerSessionV1(id=1)
self.session = tracer_session
return tracer_session
def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]:
"""Load a session with the given name from the tracer."""
return self._load_session(session_name)
def load_default_session(self) -> Union[TracerSessionV1, TracerSession]:
"""Load the default tracing session and set it as the Tracer's session."""
return self._load_session("default")
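A configuration sketch for the legacy V1 tracer, not part of this diff (the module path langchain_core.callbacks.tracers.langchain_v1 and the endpoint/API key values are assumptions):

import os

os.environ["LANGCHAIN_ENDPOINT"] = "http://localhost:8000"  # default used above
os.environ["LANGCHAIN_API_KEY"] = "my-api-key"              # sent as the x-api-key header

from langchain_core.callbacks.tracers.langchain_v1 import LangChainTracerV1

tracer = LangChainTracerV1()   # reads the endpoint and headers at init
tracer.load_default_session()  # loads the "default" session, falling back to id=1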

@ -0,0 +1,313 @@
from __future__ import annotations
import math
import threading
from collections import defaultdict
from typing import (
Any,
AsyncIterator,
Dict,
List,
Optional,
Sequence,
TypedDict,
Union,
)
from uuid import UUID
import jsonpatch
from anyio import create_memory_object_stream
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.load.load import load
from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk
class LogEntry(TypedDict):
"""A single entry in the run log."""
id: str
"""ID of the sub-run."""
name: str
"""Name of the object being run."""
type: str
"""Type of the object being run, eg. prompt, chain, llm, etc."""
tags: List[str]
"""List of tags for the run."""
metadata: Dict[str, Any]
"""Key-value pairs of metadata for the run."""
start_time: str
"""ISO-8601 timestamp of when the run started."""
streamed_output_str: List[str]
"""List of LLM tokens streamed by this run, if applicable."""
final_output: Optional[Any]
"""Final output of this run.
Only available after the run has finished successfully."""
end_time: Optional[str]
"""ISO-8601 timestamp of when the run ended.
Only available after the run has finished."""
class RunState(TypedDict):
"""State of the run."""
id: str
"""ID of the run."""
streamed_output: List[Any]
"""List of output chunks streamed by Runnable.stream()"""
final_output: Optional[Any]
"""Final output of the run, usually the result of aggregating (`+`) streamed_output.
Only available after the run has finished successfully."""
logs: Dict[str, LogEntry]
"""Map of run names to sub-runs. If filters were supplied, this list will
contain only the runs that matched the filters."""
class RunLogPatch:
"""A patch to the run log."""
ops: List[Dict[str, Any]]
"""List of jsonpatch operations, which describe how to create the run state
from an empty dict. This is the minimal representation of the log, designed to
be serialized as JSON and sent over the wire to reconstruct the log on the other
side. Reconstruction of the state can be done with any jsonpatch-compliant library,
see https://jsonpatch.com for more information."""
def __init__(self, *ops: Dict[str, Any]) -> None:
self.ops = list(ops)
def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog:
if type(other) == RunLogPatch:
ops = self.ops + other.ops
state = jsonpatch.apply_patch(None, ops)
return RunLog(*ops, state=state)
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
def __repr__(self) -> str:
from pprint import pformat
# 1:-1 to get rid of the [] around the list
return f"RunLogPatch({pformat(self.ops)[1:-1]})"
def __eq__(self, other: object) -> bool:
return isinstance(other, RunLogPatch) and self.ops == other.ops
class RunLog(RunLogPatch):
"""A run log."""
state: RunState
"""Current state of the log, obtained from applying all ops in sequence."""
def __init__(self, *ops: Dict[str, Any], state: RunState) -> None:
super().__init__(*ops)
self.state = state
def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog:
if type(other) == RunLogPatch:
ops = self.ops + other.ops
state = jsonpatch.apply_patch(self.state, other.ops)
return RunLog(*ops, state=state)
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
def __repr__(self) -> str:
from pprint import pformat
return f"RunLog({pformat(self.state)})"
class LogStreamCallbackHandler(BaseTracer):
"""A tracer that streams run logs to a stream."""
def __init__(
self,
*,
auto_close: bool = True,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
) -> None:
super().__init__()
self.auto_close = auto_close
self.include_names = include_names
self.include_types = include_types
self.include_tags = include_tags
self.exclude_names = exclude_names
self.exclude_types = exclude_types
self.exclude_tags = exclude_tags
send_stream: Any
receive_stream: Any
send_stream, receive_stream = create_memory_object_stream(
math.inf, item_type=RunLogPatch
)
self.lock = threading.Lock()
self.send_stream = send_stream
self.receive_stream = receive_stream
self._key_map_by_run_id: Dict[UUID, str] = {}
self._counter_map_by_name: Dict[str, int] = defaultdict(int)
self.root_id: Optional[UUID] = None
def __aiter__(self) -> AsyncIterator[RunLogPatch]:
return self.receive_stream.__aiter__()
def include_run(self, run: Run) -> bool:
if run.id == self.root_id:
return False
run_tags = run.tags or []
if (
self.include_names is None
and self.include_types is None
and self.include_tags is None
):
include = True
else:
include = False
if self.include_names is not None:
include = include or run.name in self.include_names
if self.include_types is not None:
include = include or run.run_type in self.include_types
if self.include_tags is not None:
include = include or any(tag in self.include_tags for tag in run_tags)
if self.exclude_names is not None:
include = include and run.name not in self.exclude_names
if self.exclude_types is not None:
include = include and run.run_type not in self.exclude_types
if self.exclude_tags is not None:
include = include and all(tag not in self.exclude_tags for tag in run_tags)
return include
def _persist_run(self, run: Run) -> None:
# This is a legacy method only called once for an entire run tree
# therefore not useful here
pass
def _on_run_create(self, run: Run) -> None:
"""Start a run."""
if self.root_id is None:
self.root_id = run.id
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "replace",
"path": "",
"value": RunState(
id=str(run.id),
streamed_output=[],
final_output=None,
logs={},
),
}
)
)
if not self.include_run(run):
return
# Determine previous index, increment by 1
with self.lock:
self._counter_map_by_name[run.name] += 1
count = self._counter_map_by_name[run.name]
self._key_map_by_run_id[run.id] = (
run.name if count == 1 else f"{run.name}:{count}"
)
# Add the run to the stream
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "add",
"path": f"/logs/{self._key_map_by_run_id[run.id]}",
"value": LogEntry(
id=str(run.id),
name=run.name,
type=run.run_type,
tags=run.tags or [],
metadata=(run.extra or {}).get("metadata", {}),
start_time=run.start_time.isoformat(timespec="milliseconds"),
streamed_output_str=[],
final_output=None,
end_time=None,
),
}
)
)
def _on_run_update(self, run: Run) -> None:
"""Finish a run."""
try:
index = self._key_map_by_run_id.get(run.id)
if index is None:
return
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "add",
"path": f"/logs/{index}/final_output",
# to undo the dumpd done by some runnables / tracer / etc
"value": load(run.outputs),
},
{
"op": "add",
"path": f"/logs/{index}/end_time",
"value": run.end_time.isoformat(timespec="milliseconds")
if run.end_time is not None
else None,
},
)
)
finally:
if run.id == self.root_id:
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "replace",
"path": "/final_output",
"value": load(run.outputs),
}
)
)
if self.auto_close:
self.send_stream.close()
def _on_llm_new_token(
self,
run: Run,
token: str,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
) -> None:
"""Process new LLM token."""
index = self._key_map_by_run_id.get(run.id)
if index is None:
return
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "add",
"path": f"/logs/{index}/streamed_output_str/-",
"value": token,
}
)
)
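A small sketch of how consumers rebuild state from the streamed ops, not part of this diff (the module path langchain_core.callbacks.tracers.log_stream is assumed):

from langchain_core.callbacks.tracers.log_stream import RunLogPatch

# The first patch replaces the empty root with a RunState-shaped dict;
# later patches append to it. Adding two patches yields a RunLog with state.
first = RunLogPatch(
    {
        "op": "replace",
        "path": "",
        "value": {"id": "run-1", "streamed_output": [], "final_output": None, "logs": {}},
    }
)
second = RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": "hello"})
log = first + second                 # RunLogPatch + RunLogPatch -> RunLog
print(log.state["streamed_output"])  # ['hello']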

@ -0,0 +1,54 @@
from typing import Callable, Optional, Union
from uuid import UUID
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.runnables.config import (
RunnableConfig,
call_func_with_variable_args,
)
Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
class RootListenersTracer(BaseTracer):
def __init__(
self,
*,
config: RunnableConfig,
on_start: Optional[Listener],
on_end: Optional[Listener],
on_error: Optional[Listener],
) -> None:
super().__init__()
self.config = config
self._arg_on_start = on_start
self._arg_on_end = on_end
self._arg_on_error = on_error
self.root_id: Optional[UUID] = None
def _persist_run(self, run: Run) -> None:
# This is a legacy method only called once for an entire run tree
# therefore not useful here
pass
def _on_run_create(self, run: Run) -> None:
if self.root_id is not None:
return
self.root_id = run.id
if self._arg_on_start is not None:
call_func_with_variable_args(self._arg_on_start, run, self.config)
def _on_run_update(self, run: Run) -> None:
if run.id != self.root_id:
return
if run.error is None:
if self._arg_on_end is not None:
call_func_with_variable_args(self._arg_on_end, run, self.config)
else:
if self._arg_on_error is not None:
call_func_with_variable_args(self._arg_on_error, run, self.config)
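A sketch of wiring root listeners directly, not part of this diff (the module path langchain_core.callbacks.tracers.root_listeners is assumed); listeners may accept either (run) or (run, config), which call_func_with_variable_args dispatches on:

from langchain_core.callbacks.tracers.root_listeners import RootListenersTracer

def log_start(run):  # single-argument listener
    print("started", run.id)

def log_end(run, config):  # listener that also receives the RunnableConfig
    print("finished", run.id, "tags:", config.get("tags"))

tracer = RootListenersTracer(
    config={"tags": ["demo"]},
    on_start=log_start,
    on_end=log_end,
    on_error=None,
)
# chain.invoke(..., config={"callbacks": [tracer]})  # only the root run fires the listeners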

@ -0,0 +1,52 @@
"""A tracer that collects all nested runs in a list."""
from typing import Any, List, Optional, Union
from uuid import UUID
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
class RunCollectorCallbackHandler(BaseTracer):
"""
A tracer that collects all nested runs in a list.
This tracer is useful for inspection and evaluation purposes.
Parameters
----------
example_id : Optional[Union[UUID, str]], default=None
The ID of the example being traced. It can be either a UUID or a string.
"""
name: str = "run-collector_callback_handler"
def __init__(
self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any
) -> None:
"""
Initialize the RunCollectorCallbackHandler.
Parameters
----------
example_id : Optional[Union[UUID, str]], default=None
The ID of the example being traced. It can be either a UUID or a string.
"""
super().__init__(**kwargs)
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.traced_runs: List[Run] = []
def _persist_run(self, run: Run) -> None:
"""
Persist a run by adding it to the traced_runs list.
Parameters
----------
run : Run
The run to be persisted.
"""
run_ = run.copy()
run_.reference_example_id = self.example_id
self.traced_runs.append(run_)
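A usage sketch, not part of this diff (the module path langchain_core.callbacks.tracers.run_collector is assumed):

from langchain_core.callbacks.tracers.run_collector import RunCollectorCallbackHandler

collector = RunCollectorCallbackHandler()
# chain.invoke(..., config={"callbacks": [collector]})  # attach to any traced call
for run in collector.traced_runs:  # one entry per completed root run tree
    print(run.id, run.run_type, len(run.child_runs))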

@ -0,0 +1,140 @@
"""Schemas for tracers."""
from __future__ import annotations
import datetime
import warnings
from typing import Any, Dict, List, Optional, Type
from uuid import UUID
from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.schema import LLMResult
def RunTypeEnum() -> Type[RunTypeEnumDep]:
"""RunTypeEnum."""
warnings.warn(
"RunTypeEnum is deprecated. Please directly use a string instead"
" (e.g. 'llm', 'chain', 'tool').",
DeprecationWarning,
)
return RunTypeEnumDep
class TracerSessionV1Base(BaseModel):
"""Base class for TracerSessionV1."""
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
name: Optional[str] = None
extra: Optional[Dict[str, Any]] = None
class TracerSessionV1Create(TracerSessionV1Base):
"""Create class for TracerSessionV1."""
class TracerSessionV1(TracerSessionV1Base):
"""TracerSessionV1 schema."""
id: int
class TracerSessionBase(TracerSessionV1Base):
"""Base class for TracerSession."""
tenant_id: UUID
class TracerSession(TracerSessionBase):
"""TracerSessionV1 schema for the V2 API."""
id: UUID
class BaseRun(BaseModel):
"""Base class for Run."""
uuid: str
parent_uuid: Optional[str] = None
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
extra: Optional[Dict[str, Any]] = None
execution_order: int
child_execution_order: int
serialized: Dict[str, Any]
session_id: int
error: Optional[str] = None
class LLMRun(BaseRun):
"""Class for LLMRun."""
prompts: List[str]
response: Optional[LLMResult] = None
class ChainRun(BaseRun):
"""Class for ChainRun."""
inputs: Dict[str, Any]
outputs: Optional[Dict[str, Any]] = None
child_llm_runs: List[LLMRun] = Field(default_factory=list)
child_chain_runs: List[ChainRun] = Field(default_factory=list)
child_tool_runs: List[ToolRun] = Field(default_factory=list)
class ToolRun(BaseRun):
"""Class for ToolRun."""
tool_input: str
output: Optional[str] = None
action: str
child_llm_runs: List[LLMRun] = Field(default_factory=list)
child_chain_runs: List[ChainRun] = Field(default_factory=list)
child_tool_runs: List[ToolRun] = Field(default_factory=list)
# Begin V2 API Schemas
class Run(BaseRunV2):
"""Run schema for the V2 API in the Tracer."""
execution_order: int
child_execution_order: int
child_runs: List[Run] = Field(default_factory=list)
tags: Optional[List[str]] = Field(default_factory=list)
events: List[Dict[str, Any]] = Field(default_factory=list)
@root_validator(pre=True)
def assign_name(cls, values: dict) -> dict:
"""Assign name to the run."""
if values.get("name") is None:
if "name" in values["serialized"]:
values["name"] = values["serialized"]["name"]
elif "id" in values["serialized"]:
values["name"] = values["serialized"]["id"][-1]
if values.get("events") is None:
values["events"] = []
return values
ChainRun.update_forward_refs()
ToolRun.update_forward_refs()
Run.update_forward_refs()
__all__ = [
"BaseRun",
"ChainRun",
"LLMRun",
"Run",
"RunTypeEnum",
"ToolRun",
"TracerSession",
"TracerSessionBase",
"TracerSessionV1",
"TracerSessionV1Base",
"TracerSessionV1Create",
]
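A brief sketch of the deprecation shim and the V1 session schema, not part of this diff (values are illustrative):

import warnings

from langchain_core.callbacks.tracers.schemas import RunTypeEnum, TracerSessionV1

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    RunTypeEnum()  # returns the langsmith enum type but emits a DeprecationWarning
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

session = TracerSessionV1(id=1, name="default")  # V1 sessions use integer ids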

@ -0,0 +1,178 @@
import json
from typing import Any, Callable, List
from langchain_core.callbacks.tracers.base import BaseTracer
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.utils.input import get_bolded_text, get_colored_text
def try_json_stringify(obj: Any, fallback: str) -> str:
"""
Try to stringify an object to JSON.
Args:
obj: Object to stringify.
fallback: Fallback string to return if the object cannot be stringified.
Returns:
A JSON string if the object can be stringified, otherwise the fallback string.
"""
try:
return json.dumps(obj, indent=2, ensure_ascii=False)
except Exception:
return fallback
def elapsed(run: Any) -> str:
"""Get the elapsed time of a run.
Args:
run: any object with a start_time and end_time attribute.
Returns:
A string with the elapsed time in seconds or
milliseconds if time is less than a second.
"""
elapsed_time = run.end_time - run.start_time
milliseconds = elapsed_time.total_seconds() * 1000
if milliseconds < 1000:
return f"{milliseconds:.0f}ms"
return f"{(milliseconds / 1000):.2f}s"
class FunctionCallbackHandler(BaseTracer):
"""Tracer that calls a function with a single str parameter."""
name: str = "function_callback_handler"
def __init__(self, function: Callable[[str], None], **kwargs: Any) -> None:
super().__init__(**kwargs)
self.function_callback = function
def _persist_run(self, run: Run) -> None:
pass
def get_parents(self, run: Run) -> List[Run]:
parents = []
current_run = run
while current_run.parent_run_id:
parent = self.run_map.get(str(current_run.parent_run_id))
if parent:
parents.append(parent)
current_run = parent
else:
break
return parents
def get_breadcrumbs(self, run: Run) -> str:
parents = self.get_parents(run)[::-1]
string = " > ".join(
f"{parent.execution_order}:{parent.run_type}:{parent.name}"
if i != len(parents) - 1
else f"{parent.execution_order}:{parent.run_type}:{parent.name}"
for i, parent in enumerate(parents + [run])
)
return string
# logging methods
def _on_chain_start(self, run: Run) -> None:
crumbs = self.get_breadcrumbs(run)
run_type = run.run_type.capitalize()
self.function_callback(
f"{get_colored_text('[chain/start]', color='green')} "
+ get_bolded_text(f"[{crumbs}] Entering {run_type} run with input:\n")
+ f"{try_json_stringify(run.inputs, '[inputs]')}"
)
def _on_chain_end(self, run: Run) -> None:
crumbs = self.get_breadcrumbs(run)
run_type = run.run_type.capitalize()
self.function_callback(
f"{get_colored_text('[chain/end]', color='blue')} "
+ get_bolded_text(
f"[{crumbs}] [{elapsed(run)}] Exiting {run_type} run with output:\n"
)
+ f"{try_json_stringify(run.outputs, '[outputs]')}"
)
def _on_chain_error(self, run: Run) -> None:
crumbs = self.get_breadcrumbs(run)
run_type = run.run_type.capitalize()
self.function_callback(
f"{get_colored_text('[chain/error]', color='red')} "
+ get_bolded_text(
f"[{crumbs}] [{elapsed(run)}] {run_type} run errored with error:\n"
)
+ f"{try_json_stringify(run.error, '[error]')}"
)
def _on_llm_start(self, run: Run) -> None:
crumbs = self.get_breadcrumbs(run)
inputs = (
{"prompts": [p.strip() for p in run.inputs["prompts"]]}
if "prompts" in run.inputs
else run.inputs
)
self.function_callback(
f"{get_colored_text('[llm/start]', color='green')} "
+ get_bolded_text(f"[{crumbs}] Entering LLM run with input:\n")
+ f"{try_json_stringify(inputs, '[inputs]')}"
)
def _on_llm_end(self, run: Run) -> None:
crumbs = self.get_breadcrumbs(run)
self.function_callback(
f"{get_colored_text('[llm/end]', color='blue')} "
+ get_bolded_text(
f"[{crumbs}] [{elapsed(run)}] Exiting LLM run with output:\n"
)
+ f"{try_json_stringify(run.outputs, '[response]')}"
)
def _on_llm_error(self, run: Run) -> None:
crumbs = self.get_breadcrumbs(run)
self.function_callback(
f"{get_colored_text('[llm/error]', color='red')} "
+ get_bolded_text(
f"[{crumbs}] [{elapsed(run)}] LLM run errored with error:\n"
)
+ f"{try_json_stringify(run.error, '[error]')}"
)
def _on_tool_start(self, run: Run) -> None:
crumbs = self.get_breadcrumbs(run)
self.function_callback(
f'{get_colored_text("[tool/start]", color="green")} '
+ get_bolded_text(f"[{crumbs}] Entering Tool run with input:\n")
+ f'"{run.inputs["input"].strip()}"'
)
def _on_tool_end(self, run: Run) -> None:
crumbs = self.get_breadcrumbs(run)
if run.outputs:
self.function_callback(
f'{get_colored_text("[tool/end]", color="blue")} '
+ get_bolded_text(
f"[{crumbs}] [{elapsed(run)}] Exiting Tool run with output:\n"
)
+ f'"{run.outputs["output"].strip()}"'
)
def _on_tool_error(self, run: Run) -> None:
crumbs = self.get_breadcrumbs(run)
self.function_callback(
f"{get_colored_text('[tool/error]', color='red')} "
+ get_bolded_text(f"[{crumbs}] [{elapsed(run)}] ")
+ f"Tool run errored with error:\n"
f"{run.error}"
)
class ConsoleCallbackHandler(FunctionCallbackHandler):
"""Tracer that prints to the console."""
name: str = "console_callback_handler"
def __init__(self, **kwargs: Any) -> None:
super().__init__(function=print, **kwargs)
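A usage sketch, not part of this diff (the module path langchain_core.callbacks.tracers.stdout is assumed):

import logging

from langchain_core.callbacks.tracers.stdout import (
    ConsoleCallbackHandler,
    FunctionCallbackHandler,
)

console = ConsoleCallbackHandler()  # prints colored [chain/llm/tool] events to stdout
# chain.invoke(..., config={"callbacks": [console]})

# The same formatted lines can be routed anywhere that accepts a string:
to_log = FunctionCallbackHandler(function=logging.getLogger("trace").info)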

@ -0,0 +1,735 @@
import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
cast,
)
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.load.dump import dumpd, dumps
from langchain_core.prompts.base import StringPromptValue
from langchain_core.prompts.chat import ChatPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import RunnableConfig
from langchain_core.schema import (
ChatGeneration,
ChatResult,
LLMResult,
PromptValue,
RunInfo,
)
from langchain_core.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain_core.schema.messages import (
AIMessage,
AnyMessage,
BaseMessage,
BaseMessageChunk,
HumanMessage,
)
from langchain_core.schema.output import ChatGenerationChunk
def _get_verbosity() -> bool:
from langchain_core.globals import get_verbose
return get_verbose()
def _generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
generation: Optional[ChatGenerationChunk] = None
for chunk in stream:
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
async def _agenerate_from_stream(
stream: AsyncIterator[ChatGenerationChunk],
) -> ChatResult:
generation: Optional[ChatGenerationChunk] = None
async for chunk in stream:
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for Chat models."""
cache: Optional[bool] = None
"""Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""Callback manager to add to the run trace."""
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
# --- Runnable methods ---
@property
def OutputType(self) -> Any:
"""Get the output type for this runnable."""
return AnyMessage
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
elif isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, list):
return ChatPromptValue(messages=input)
else:
raise ValueError(
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = config or {}
return cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
).generations[0][0],
).message
async def ainvoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = config or {}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
return cast(ChatGeneration, llm_result.generations[0][0]).message
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if type(self)._stream == BaseChatModel._stream:
# model doesn't implement streaming, so use default implementation
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = callback_manager.on_chat_model_start(
dumpd(self),
[messages],
invocation_params=params,
options=options,
name=config.get("run_name"),
)
try:
generation: Optional[ChatGenerationChunk] = None
for chunk in self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.message
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
except BaseException as e:
run_manager.on_llm_error(e)
raise e
else:
run_manager.on_llm_end(
LLMResult(generations=[[generation]]),
)
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if type(self)._astream == BaseChatModel._astream:
# model doesn't implement streaming, so use default implementation
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = await callback_manager.on_chat_model_start(
dumpd(self),
[messages],
invocation_params=params,
options=options,
name=config.get("run_name"),
)
try:
generation: Optional[ChatGenerationChunk] = None
async for chunk in self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.message
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
except BaseException as e:
await run_manager.on_llm_error(e)
raise e
else:
await run_manager.on_llm_end(
LLMResult(generations=[[generation]]),
)
# --- Custom methods ---
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return {}
def _get_invocation_params(
self,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> dict:
params = self.dict()
params["stop"] = stop
return {**params, **kwargs}
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
if self.is_lc_serializable():
params = {**kwargs, **{"stop": stop}}
param_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = dumps(self)
return llm_string + "---" + param_string
else:
params = self._get_invocation_params(stop=stop, **kwargs)
params = {**params, **kwargs}
return str(sorted([(k, v) for k, v in params.items()]))
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop}
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
run_managers = callback_manager.on_chat_model_start(
dumpd(self),
messages,
invocation_params=params,
options=options,
name=run_name,
)
results = []
for i, m in enumerate(messages):
try:
results.append(
self._generate_with_cache(
m,
stop=stop,
run_manager=run_managers[i] if run_managers else None,
**kwargs,
)
)
except BaseException as e:
if run_managers:
run_managers[i].on_llm_error(e)
raise e
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
if run_managers:
run_infos = []
for manager, flattened_output in zip(run_managers, flattened_outputs):
manager.on_llm_end(flattened_output)
run_infos.append(RunInfo(run_id=manager.run_id))
output.run = run_infos
return output
async def agenerate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
run_managers = await callback_manager.on_chat_model_start(
dumpd(self),
messages,
invocation_params=params,
options=options,
name=run_name,
)
results = await asyncio.gather(
*[
self._agenerate_with_cache(
m,
stop=stop,
run_manager=run_managers[i] if run_managers else None,
**kwargs,
)
for i, m in enumerate(messages)
],
return_exceptions=True,
)
exceptions = []
for i, res in enumerate(results):
if isinstance(res, BaseException):
if run_managers:
await run_managers[i].on_llm_error(res)
exceptions.append(res)
if exceptions:
if run_managers:
await asyncio.gather(
*[
run_manager.on_llm_end(
LLMResult(
generations=[res.generations], llm_output=res.llm_output
)
)
for run_manager, res in zip(run_managers, results)
if not isinstance(res, Exception)
]
)
raise exceptions[0]
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)
for run_manager, flattened_output in zip(
run_managers, flattened_outputs
)
]
)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
return await self.agenerate(
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
)
def _generate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
llm_cache = get_llm_cache()
if llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return self._generate(messages, stop=stop, **kwargs)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = self._generate(messages, stop=stop, **kwargs)
llm_cache.update(prompt, llm_string, result.generations)
return result
async def _agenerate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
llm_cache = get_llm_cache()
if llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return await self._agenerate(messages, stop=stop, **kwargs)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
llm_cache.update(prompt, llm_string, result.generations)
return result
@abstractmethod
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, **kwargs), messages, stop, run_manager
)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
raise NotImplementedError()
def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
raise NotImplementedError()
def __call__(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
generation = self.generate(
[messages], stop=stop, callbacks=callbacks, **kwargs
).generations[0][0]
if isinstance(generation, ChatGeneration):
return generation.message
else:
raise ValueError("Unexpected generation type")
async def _call_async(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
result = await self.agenerate(
[messages], stop=stop, callbacks=callbacks, **kwargs
)
generation = result.generations[0][0]
if isinstance(generation, ChatGeneration):
return generation.message
else:
raise ValueError("Unexpected generation type")
def call_as_llm(
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
return self.predict(message, stop=stop, **kwargs)
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
if stop is None:
_stop = None
else:
_stop = list(stop)
result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
if isinstance(result.content, str):
return result.content
else:
raise ValueError("Cannot use predict when output is not a string.")
def predict_messages(
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
if stop is None:
_stop = None
else:
_stop = list(stop)
return self(messages, stop=_stop, **kwargs)
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
if stop is None:
_stop = None
else:
_stop = list(stop)
result = await self._call_async(
[HumanMessage(content=text)], stop=_stop, **kwargs
)
if isinstance(result.content, str):
return result.content
else:
raise ValueError("Cannot use predict when output is not a string.")
async def apredict_messages(
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
if stop is None:
_stop = None
else:
_stop = list(stop)
return await self._call_async(messages, stop=_stop, **kwargs)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {}
@property
@abstractmethod
def _llm_type(self) -> str:
"""Return type of chat model."""
def dict(self, **kwargs: Any) -> Dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
return starter_dict
class SimpleChatModel(BaseChatModel):
"""Simple Chat Model."""
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@abstractmethod
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Simpler interface."""
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)
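For illustration, a minimal concrete subclass, not part of this diff: an echo model implementing only the abstract `_call` and `_llm_type` members (the import path for SimpleChatModel is an assumption; it refers to the class defined above):

from typing import Any, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.chat_models.base import SimpleChatModel  # assumed path of the module above
from langchain_core.schema.messages import BaseMessage

class EchoChatModel(SimpleChatModel):
    """Toy chat model that echoes the last input message back."""

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return str(messages[-1].content)

    @property
    def _llm_type(self) -> str:
        return "echo-chat-model"

# EchoChatModel().invoke("hello") -> AIMessage(content="hello")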

@ -0,0 +1,17 @@
import platform
from functools import lru_cache
@lru_cache(maxsize=1)
def get_runtime_environment() -> dict:
"""Get information about the LangChain runtime environment."""
# Lazy import to avoid circular imports
from langchain_core import __version__
return {
"library_version": __version__,
"library": "langchain",
"platform": platform.platform(),
"runtime": "python",
"runtime_version": platform.python_version(),
}
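Illustrative output, not part of this diff (values vary by machine):

from langchain_core.env import get_runtime_environment

env = get_runtime_environment()
# e.g. {"library_version": "0.0.x", "library": "langchain", "platform": "Linux-...",
#       "runtime": "python", "runtime_version": "3.11.x"}
assert env is get_runtime_environment()  # lru_cache(maxsize=1): same dict on every call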

@ -0,0 +1,197 @@
# flake8: noqa
"""Global values and configuration that apply to all of LangChain."""
import warnings
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from langchain_core.schema import BaseCache
# DO NOT USE THESE VALUES DIRECTLY!
# Use them only via `get_<X>()` and `set_<X>()` below,
# or else your code may behave unexpectedly with other uses of these global settings:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
_verbose: bool = False
_debug: bool = False
_llm_cache: Optional["BaseCache"] = None
def set_verbose(value: bool) -> None:
"""Set a new value for the `verbose` global setting."""
try:
import langchain
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=(
"Importing verbose from langchain_core root module is no longer supported"
),
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.verbose` is no longer supported, and once all users
# have migrated to using `set_verbose()` here.
langchain.verbose = value
except ImportError:
pass
global _verbose
_verbose = value
def get_verbose() -> bool:
"""Get the value of the `verbose` global setting."""
try:
import langchain
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=(
"Importing verbose from langchain_core root module is no longer supported"
),
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.verbose` is no longer supported, and once all users
# have migrated to using `set_verbose()` here.
#
# In the meantime, the `verbose` setting is considered True if either the old
# or the new value are True. This accommodates users who haven't migrated
# to using `set_verbose()` yet. Those users are getting deprecation warnings
# directing them to use `set_verbose()` when they import `langchain.verbose`.
old_verbose = langchain.verbose
except ImportError:
old_verbose = False
global _verbose
return _verbose or old_verbose
def set_debug(value: bool) -> None:
"""Set a new value for the `debug` global setting."""
try:
import langchain
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Importing debug from langchain_core root module is no longer supported",
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.debug` is no longer supported, and once all users
# have migrated to using `set_debug()` here.
langchain.debug = value
except ImportError:
pass
global _debug
_debug = value
def get_debug() -> bool:
"""Get the value of the `debug` global setting."""
try:
import langchain
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Importing debug from langchain_core root module is no longer supported",
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.debug` is no longer supported, and once all users
# have migrated to using `set_debug()` here.
#
# In the meantime, the `debug` setting is considered True if either the old
# or the new value are True. This accommodates users who haven't migrated
# to using `set_debug()` yet. Those users are getting deprecation warnings
# directing them to use `set_debug()` when they import `langchain.debug`.
old_debug = langchain.debug
except ImportError:
old_debug = False
global _debug
return _debug or old_debug
def set_llm_cache(value: Optional["BaseCache"]) -> None:
"""Set a new LLM cache, overwriting the previous value, if any."""
try:
import langchain
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=(
"Importing llm_cache from langchain_core root module is no longer supported"
),
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.llm_cache` is no longer supported, and
# once all users have migrated to using `set_llm_cache()` here.
langchain.llm_cache = value
except ImportError:
pass
global _llm_cache
_llm_cache = value
def get_llm_cache() -> "BaseCache":
"""Get the value of the `llm_cache` global setting."""
try:
import langchain
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=(
"Importing llm_cache from langchain_core root module is no longer supported"
),
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.llm_cache` is no longer supported, and
# once all users have migrated to using `set_llm_cache()` here.
#
# In the meantime, the `llm_cache` setting returns whichever of
# its two backing sources is truthy (not `None` and non-empty),
# or the old value if both are falsy. This accommodates users
# who haven't migrated to using `set_llm_cache()` yet.
# Those users are getting deprecation warnings directing them
# to use `set_llm_cache()` when they import `langchain.llm_cache`.
old_llm_cache = langchain.llm_cache
except ImportError:
old_llm_cache = None
global _llm_cache
return _llm_cache or old_llm_cache
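
A minimal usage sketch of the accessors above, assuming this module is importable as `langchain_core.globals` (the path is not shown in this hunk):

from langchain_core.globals import get_debug, get_verbose, set_debug, set_verbose

set_verbose(True)   # turn on verbose logging for all LangChain components
set_debug(False)    # leave debug tracing off
assert get_verbose() is True
assert get_debug() is False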

File diff suppressed because it is too large

@ -0,0 +1,6 @@
"""Serialization and deserialization."""
from langchain_core.load.dump import dumpd, dumps
from langchain_core.load.load import load, loads
from langchain_core.load.serializable import Serializable
__all__ = ["dumpd", "dumps", "load", "loads", "Serializable"]

@ -0,0 +1,26 @@
import json
from typing import Any, Dict
from langchain_core.load.serializable import Serializable, to_json_not_implemented
def default(obj: Any) -> Any:
"""Return a default value for a Serializable object or
a SerializedNotImplemented object."""
if isinstance(obj, Serializable):
return obj.to_json()
else:
return to_json_not_implemented(obj)
def dumps(obj: Any, *, pretty: bool = False) -> str:
"""Return a json string representation of an object."""
if pretty:
return json.dumps(obj, default=default, indent=2)
else:
return json.dumps(obj, default=default)
def dumpd(obj: Any) -> Dict[str, Any]:
"""Return a json dict representation of an object."""
return json.loads(dumps(obj))
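
A quick sketch of `dumps`/`dumpd` on plain objects; anything that is not `Serializable` falls back to the "not_implemented" representation:

from langchain_core.load.dump import dumpd, dumps

print(dumps({"answer": 42}, pretty=True))   # ordinary JSON, pretty-printed
blob = dumpd(object())                      # non-Serializable fallback
assert blob["type"] == "not_implemented"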

@ -0,0 +1,130 @@
import importlib
import json
import os
from typing import Any, Dict, List, Optional
from langchain_core.load.serializable import Serializable
DEFAULT_NAMESPACES = ["langchain", "langchain_core"]
class Reviver:
"""Reviver for JSON objects."""
def __init__(
self,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
) -> None:
self.secrets_map = secrets_map or dict()
# By default, only the langchain and langchain_core namespaces are supported;
# users can pass in additional namespaces.
self.valid_namespaces = (
[*DEFAULT_NAMESPACES, *valid_namespaces]
if valid_namespaces
else DEFAULT_NAMESPACES
)
def __call__(self, value: Dict[str, Any]) -> Any:
if (
value.get("lc", None) == 1
and value.get("type", None) == "secret"
and value.get("id", None) is not None
):
[key] = value["id"]
if key in self.secrets_map:
return self.secrets_map[key]
else:
if key in os.environ and os.environ[key]:
return os.environ[key]
raise KeyError(f'Missing key "{key}" in load(secrets_map)')
if (
value.get("lc", None) == 1
and value.get("type", None) == "not_implemented"
and value.get("id", None) is not None
):
raise NotImplementedError(
"Trying to load an object that doesn't implement "
f"serialization: {value}"
)
if (
value.get("lc", None) == 1
and value.get("type", None) == "constructor"
and value.get("id", None) is not None
):
[*namespace, name] = value["id"]
if namespace[0] not in self.valid_namespaces:
raise ValueError(f"Invalid namespace: {value}")
# The root namespace "langchain" is not a valid identifier.
if len(namespace) == 1 and namespace[0] == "langchain":
raise ValueError(f"Invalid namespace: {value}")
mod = importlib.import_module(".".join(namespace))
cls = getattr(mod, name)
# The class must be a subclass of Serializable.
if not issubclass(cls, Serializable):
raise ValueError(f"Invalid namespace: {value}")
# We don't need to recurse on kwargs
# as json.loads will do that for us.
kwargs = value.get("kwargs", dict())
return cls(**kwargs)
return value
def loads(
text: str,
*,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
) -> Any:
"""Revive a LangChain class from a JSON string.
Equivalent to `load(json.loads(text))`.
Args:
text: The string to load.
secrets_map: A map of secrets to load.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
Returns:
Revived LangChain objects.
"""
return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))
def load(
obj: Any,
*,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
) -> Any:
"""Revive a LangChain class from a JSON object. Use this if you already
have a parsed JSON object, e.g. from `json.load` or `orjson.loads`.
Args:
obj: The object to load.
secrets_map: A map of secrets to load.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
Returns:
Revived LangChain objects.
"""
reviver = Reviver(secrets_map, valid_namespaces)
def _load(obj: Any) -> Any:
if isinstance(obj, dict):
# Need to revive leaf nodes before reviving this node
loaded_obj = {k: _load(v) for k, v in obj.items()}
return reviver(loaded_obj)
if isinstance(obj, list):
return [_load(o) for o in obj]
return obj
return _load(obj)
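
An illustrative round trip through the secret branch of `Reviver`; the key name `MY_API_KEY` is made up for this example:

import json
from langchain_core.load.load import loads

blob = json.dumps({"lc": 1, "type": "secret", "id": ["MY_API_KEY"]})
assert loads(blob, secrets_map={"MY_API_KEY": "sk-example"}) == "sk-example"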

@ -0,0 +1,207 @@
from abc import ABC
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
class BaseSerialized(TypedDict):
"""Base class for serialized objects."""
lc: int
id: List[str]
class SerializedConstructor(BaseSerialized):
"""Serialized constructor."""
type: Literal["constructor"]
kwargs: Dict[str, Any]
class SerializedSecret(BaseSerialized):
"""Serialized secret."""
type: Literal["secret"]
class SerializedNotImplemented(BaseSerialized):
"""Serialized not implemented."""
type: Literal["not_implemented"]
repr: Optional[str]
def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
try:
return model.__fields__[key].get_default() != value
except Exception:
return True
class Serializable(BaseModel, ABC):
"""Serializable base class."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Is this class serializable?"""
return False
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.
For example, if the class is `langchain.llms.openai.OpenAI`, then the
namespace is ["langchain", "llms", "openai"]
"""
return cls.__module__.split(".")
@property
def lc_secrets(self) -> Dict[str, str]:
"""A map of constructor argument names to secret ids.
For example,
{"openai_api_key": "OPENAI_API_KEY"}
"""
return dict()
@property
def lc_attributes(self) -> Dict:
"""List of attribute names that should be included in the serialized kwargs.
These attributes must be accepted by the constructor.
"""
return {}
@classmethod
def lc_id(cls) -> List[str]:
"""A unique identifier for this class for serialization purposes.
The unique identifier is a list of strings that describes the path
to the object.
"""
return [*cls.get_lc_namespace(), cls.__name__]
class Config:
extra = "ignore"
def __repr_args__(self) -> Any:
return [
(k, v)
for k, v in super().__repr_args__()
if (k not in self.__fields__ or try_neq_default(v, k, self))
]
_lc_kwargs = PrivateAttr(default_factory=dict)
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._lc_kwargs = kwargs
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
if not self.is_lc_serializable():
return self.to_json_not_implemented()
secrets = dict()
# Get latest values for kwargs if there is an attribute with same name
lc_kwargs = {
k: getattr(self, k, v)
for k, v in self._lc_kwargs.items()
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
}
# Merge the lc_secrets and lc_attributes from every class in the MRO
for cls in [None, *self.__class__.mro()]:
# Once we get to Serializable, we're done
if cls is Serializable:
break
if cls:
deprecated_attributes = [
"lc_namespace",
"lc_serializable",
]
for attr in deprecated_attributes:
if hasattr(cls, attr):
raise ValueError(
f"Class {self.__class__} has a deprecated "
f"attribute {attr}. Please use the corresponding "
f"classmethod instead."
)
# Get a reference to self bound to each class in the MRO
this = cast(Serializable, self if cls is None else super(cls, self))
secrets.update(this.lc_secrets)
lc_kwargs.update(this.lc_attributes)
# include all secrets, even if not specified in kwargs
# as these secrets may be passed as an environment variable instead
for key in secrets.keys():
secret_value = getattr(self, key, None) or lc_kwargs.get(key)
if secret_value is not None:
lc_kwargs.update({key: secret_value})
return {
"lc": 1,
"type": "constructor",
"id": self.lc_id(),
"kwargs": lc_kwargs
if not secrets
else _replace_secrets(lc_kwargs, secrets),
}
def to_json_not_implemented(self) -> SerializedNotImplemented:
return to_json_not_implemented(self)
def _replace_secrets(
root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
result = root.copy()
for path, secret_id in secrets_map.items():
[*parts, last] = path.split(".")
current = result
for part in parts:
if part not in current:
break
current[part] = current[part].copy()
current = current[part]
if last in current:
current[last] = {
"lc": 1,
"type": "secret",
"id": [secret_id],
}
return result
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
"""Serialize a "not implemented" object.
Args:
obj: object to serialize
Returns:
SerializedNotImplemented
"""
_id: List[str] = []
try:
if hasattr(obj, "__name__"):
_id = [*obj.__module__.split("."), obj.__name__]
elif hasattr(obj, "__class__"):
_id = [*obj.__class__.__module__.split("."), obj.__class__.__name__]
except Exception:
pass
result: SerializedNotImplemented = {
"lc": 1,
"type": "not_implemented",
"id": _id,
"repr": None,
}
try:
result["repr"] = repr(obj)
except Exception:
pass
return result
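
A hedged sketch of how `lc_secrets` interacts with `to_json`; `FakeClient` and the `FAKE_API_KEY` id are hypothetical and exist only for this example:

from langchain_core.load.dump import dumpd
from langchain_core.load.serializable import Serializable

class FakeClient(Serializable):
    api_key: str
    model: str = "demo"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def lc_secrets(self) -> dict:
        # Constructor arg -> secret id; replaced by a secret stub when serialized.
        return {"api_key": "FAKE_API_KEY"}

serialized = dumpd(FakeClient(api_key="not-a-real-key"))
assert serialized["kwargs"]["api_key"] == {"lc": 1, "type": "secret", "id": ["FAKE_API_KEY"]}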

@ -0,0 +1,79 @@
from __future__ import annotations
import re
from abc import abstractmethod
from typing import List
from langchain_core.schema import BaseOutputParser
class ListOutputParser(BaseOutputParser[List[str]]):
"""Parse the output of an LLM call to a list."""
@property
def _type(self) -> str:
return "list"
@abstractmethod
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse the output of an LLM call to a comma-separated list."""
@classmethod
def is_lc_serializable(cls) -> bool:
return True
def get_format_instructions(self) -> str:
return (
"Your response should be a list of comma separated values, "
"eg: `foo, bar, baz`"
)
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
return text.strip().split(", ")
@property
def _type(self) -> str:
return "comma-separated-list"
class NumberedListOutputParser(ListOutputParser):
"""Parse a numbered list."""
def get_format_instructions(self) -> str:
return (
"Your response should be a numbered list with each item on a new line. "
"For example: \n\n1. foo\n\n2. bar\n\n3. baz"
)
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
pattern = r"\d+\.\s([^\n]+)"
# Extract the text of each item
matches = re.findall(pattern, text)
return matches
@property
def _type(self) -> str:
return "numbered-list"
class MarkdownListOutputParser(ListOutputParser):
"""Parse a markdown list."""
def get_format_instructions(self) -> str:
return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
pattern = r"-\s([^\n]+)"
return re.findall(pattern, text)
@property
def _type(self) -> str:
return "markdown-list"

@ -0,0 +1,75 @@
"""**Prompt** is the input to the model.
Prompt is often constructed
from multiple components. Prompt classes and functions make constructing
and working with prompts easy.
**Class hierarchy:**
.. code-block::
BasePromptTemplate --> PipelinePromptTemplate
StringPromptTemplate --> PromptTemplate
FewShotPromptTemplate
FewShotPromptWithTemplates
BaseChatPromptTemplate --> AutoGPTPrompt
ChatPromptTemplate --> AgentScratchPadChatPromptTemplate
BaseMessagePromptTemplate --> MessagesPlaceholder
BaseStringMessagePromptTemplate --> ChatMessagePromptTemplate
HumanMessagePromptTemplate
AIMessagePromptTemplate
SystemMessagePromptTemplate
PromptValue --> StringPromptValue
ChatPromptValue
""" # noqa: E501
from langchain_core.prompts.base import StringPromptTemplate
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
BaseChatPromptTemplate,
ChatMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain_core.prompts.example_selector import (
LengthBasedExampleSelector,
MaxMarginalRelevanceExampleSelector,
SemanticSimilarityExampleSelector,
)
from langchain_core.prompts.few_shot import (
FewShotChatMessagePromptTemplate,
FewShotPromptTemplate,
)
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain_core.prompts.loading import load_prompt
from langchain_core.prompts.pipeline import PipelinePromptTemplate
from langchain_core.prompts.prompt import Prompt, PromptTemplate
from langchain_core.schema.prompt_template import BasePromptTemplate
__all__ = [
"AIMessagePromptTemplate",
"BaseChatPromptTemplate",
"BasePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"FewShotPromptTemplate",
"FewShotPromptWithTemplates",
"HumanMessagePromptTemplate",
"LengthBasedExampleSelector",
"MaxMarginalRelevanceExampleSelector",
"MessagesPlaceholder",
"PipelinePromptTemplate",
"Prompt",
"PromptTemplate",
"SemanticSimilarityExampleSelector",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"FewShotChatMessagePromptTemplate",
]

@ -0,0 +1,173 @@
"""BasePrompt schema definition."""
from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Literal, Set
from langchain_core.schema.messages import BaseMessage, HumanMessage
from langchain_core.schema.prompt import PromptValue
from langchain_core.schema.prompt_template import BasePromptTemplate
from langchain_core.utils.formatting import formatter
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2.
*Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
SandboxedEnvironment by default. However, this sand-boxing should
be treated as a best-effort approach rather than a guarantee of security.
Do not accept jinja2 templates from untrusted sources as they may lead
to arbitrary Python code execution.
https://jinja.palletsprojects.com/en/3.1.x/sandbox/
"""
try:
from jinja2.sandbox import SandboxedEnvironment
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
"Please be cautious when using jinja2 templates. "
"Do not expand jinja2 templates using unverified or user-controlled "
"inputs as that can result in arbitrary Python code execution."
)
# This uses a sandboxed environment to prevent arbitrary code execution.
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
# Please treat this sand-boxing as a best-effort approach rather than
# a guarantee of security.
# We recommend never using jinja2 templates with untrusted inputs.
# https://jinja.palletsprojects.com/en/3.1.x/sandbox/
return SandboxedEnvironment().from_string(template).render(**kwargs)
def validate_jinja2(template: str, input_variables: List[str]) -> None:
"""
Validate that the input variables are valid for the template.
Issues a warning if missing or extra variables are found.
Args:
template: The template string.
input_variables: The input variables.
"""
input_variables_set = set(input_variables)
valid_variables = _get_jinja2_variables_from_template(template)
missing_variables = valid_variables - input_variables_set
extra_variables = input_variables_set - valid_variables
warning_message = ""
if missing_variables:
warning_message += f"Missing variables: {missing_variables} "
if extra_variables:
warning_message += f"Extra variables: {extra_variables}"
if warning_message:
warnings.warn(warning_message.strip())
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
try:
from jinja2 import Environment, meta
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format,
"jinja2": jinja2_formatter,
}
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
"f-string": formatter.validate_input_variables,
"jinja2": validate_jinja2,
}
def check_valid_template(
template: str, template_format: str, input_variables: List[str]
) -> None:
"""Check that template string is valid.
Args:
template: The template string.
template_format: The template format. Should be one of "f-string" or "jinja2".
input_variables: The input variables.
Raises:
ValueError: If the template format is not supported.
"""
if template_format not in DEFAULT_FORMATTER_MAPPING:
valid_formats = list(DEFAULT_FORMATTER_MAPPING)
raise ValueError(
f"Invalid template format. Got `{template_format}`;"
f" should be one of {valid_formats}"
)
try:
validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
validator_func(template, input_variables)
except KeyError as e:
raise ValueError(
"Invalid prompt schema; check for mismatched or missing input parameters. "
+ str(e)
)
def get_template_variables(template: str, template_format: str) -> List[str]:
"""Get the variables from the template.
Args:
template: The template string.
template_format: The template format. Should be one of "f-string" or "jinja2".
Returns:
The variables from the template.
Raises:
ValueError: If the template format is not supported.
"""
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
else:
raise ValueError(f"Unsupported template format: {template_format}")
return sorted(input_variables)
class StringPromptValue(PromptValue):
"""String prompt value."""
text: str
"""Prompt text."""
type: Literal["StringPromptValue"] = "StringPromptValue"
def to_string(self) -> str:
"""Return prompt as string."""
return self.text
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt that exposes the format method, returning a prompt."""
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs))
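
A small sketch of the helpers above; `jinja2_formatter` additionally requires `pip install jinja2`:

from langchain_core.prompts.base import get_template_variables, jinja2_formatter

# f-string variables are discovered with string.Formatter and returned sorted.
assert get_template_variables("Tell me a {adjective} joke about {topic}.", "f-string") == ["adjective", "topic"]

# jinja2 templates are rendered inside a SandboxedEnvironment.
print(jinja2_formatter("Hello {{ name }}!", name="world"))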

@ -0,0 +1,748 @@
"""Chat prompt template."""
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from langchain_core._api import deprecated
from langchain_core.load.serializable import Serializable
from langchain_core.prompts.base import StringPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.schema import (
BasePromptTemplate,
PromptValue,
)
from langchain_core.schema.messages import (
AIMessage,
AnyMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
get_buffer_string,
)
class BaseMessagePromptTemplate(Serializable, ABC):
"""Base class for message prompt templates."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return True
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs. Should return a list of BaseMessages.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
@property
@abstractmethod
def input_variables(self) -> List[str]:
"""Input variables for this prompt template.
Returns:
List of input variables.
"""
def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
Args:
other: Another prompt template.
Returns:
Combined prompt template.
"""
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Prompt template that assumes variable is already list of messages."""
variable_name: str
"""Name of variable to use as messages."""
def __init__(self, variable_name: str, **kwargs: Any):
return super().__init__(variable_name=variable_name, **kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessage.
"""
value = kwargs[self.variable_name]
if not isinstance(value, list):
raise ValueError(
f"variable {self.variable_name} should be a list of base messages, "
f"got {value}"
)
for v in value:
if not isinstance(v, BaseMessage):
raise ValueError(
f"variable {self.variable_name} should be a list of base messages,"
f" got {value}"
)
return value
@property
def input_variables(self) -> List[str]:
"""Input variables for this prompt template.
Returns:
List of input variable names.
"""
return [self.variable_name]
MessagePromptTemplateT = TypeVar(
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
)
"""Type variable for message prompt templates."""
class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
"""Base class for message prompt templates that use a string prompt template."""
prompt: StringPromptTemplate
"""String prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template."""
@classmethod
def from_template(
cls: Type[MessagePromptTemplateT],
template: str,
template_format: str = "f-string",
**kwargs: Any,
) -> MessagePromptTemplateT:
"""Create a class from a string template.
Args:
template: a template.
template_format: format of the template.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt = PromptTemplate.from_template(template, template_format=template_format)
return cls(prompt=prompt, **kwargs)
@classmethod
def from_template_file(
cls: Type[MessagePromptTemplateT],
template_file: Union[str, Path],
input_variables: List[str],
**kwargs: Any,
) -> MessagePromptTemplateT:
"""Create a class from a template file.
Args:
template_file: path to a template file. String or Path.
input_variables: list of input variables.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt = PromptTemplate.from_file(template_file, input_variables)
return cls(prompt=prompt, **kwargs)
@abstractmethod
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
return [self.format(**kwargs)]
@property
def input_variables(self) -> List[str]:
"""
Input variables for this prompt template.
Returns:
List of input variable names.
"""
return self.prompt.input_variables
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Chat message prompt template."""
role: str
"""Role of the message."""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return ChatMessage(
content=text, role=self.role, additional_kwargs=self.additional_kwargs
)
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""AI message prompt template. This is a message sent from the AI."""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""System message prompt template.
This is a message that is not sent to the user.
"""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
class ChatPromptValue(PromptValue):
"""Chat prompt value.
A type of a prompt value that is built from messages.
"""
messages: Sequence[BaseMessage]
"""List of messages."""
def to_string(self) -> str:
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of messages."""
return list(self.messages)
class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas."""
messages: Sequence[AnyMessage]
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Base class for chat prompt templates."""
@property
def lc_attributes(self) -> Dict:
"""
Return a list of attribute names that should be included in the
serialized kwargs. These attributes must be accepted by the
constructor.
"""
return {"input_variables": self.input_variables}
def format(self, **kwargs: Any) -> str:
"""Format the chat template into a string.
Args:
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
formatted string
"""
return self.format_prompt(**kwargs).to_string()
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""
Format prompt. Should return a PromptValue.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
PromptValue.
"""
messages = self.format_messages(**kwargs)
return ChatPromptValue(messages=messages)
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages."""
MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
MessageLikeRepresentation = Union[
MessageLike,
Tuple[str, str],
Tuple[Type, str],
str,
]
class ChatPromptTemplate(BaseChatPromptTemplate):
"""A prompt template for chat models.
Use to create flexible templated prompts for chat models.
Examples:
.. code-block:: python
from langchain_core.prompts import ChatPromptTemplate
template = ChatPromptTemplate.from_messages([
("system", "You are a helpful AI bot. Your name is {name}."),
("human", "Hello, how are you doing?"),
("ai", "I'm doing well, thanks!"),
("human", "{user_input}"),
])
messages = template.format_messages(
name="Bob",
user_input="What is your name?"
)
"""
input_variables: List[str]
"""List of input variables in template messages. Used for validation."""
messages: List[MessageLike]
"""List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""
def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
Args:
other: Another prompt template.
Returns:
Combined prompt template.
"""
# Allow for easy combining
if isinstance(other, ChatPromptTemplate):
return ChatPromptTemplate(messages=self.messages + other.messages)
elif isinstance(
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
):
return ChatPromptTemplate(messages=self.messages + [other])
elif isinstance(other, (list, tuple)):
_other = ChatPromptTemplate.from_messages(other)
return ChatPromptTemplate(messages=self.messages + _other.messages)
elif isinstance(other, str):
prompt = HumanMessagePromptTemplate.from_template(other)
return ChatPromptTemplate(messages=self.messages + [prompt])
else:
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
@root_validator(pre=True)
def validate_input_variables(cls, values: dict) -> dict:
"""Validate input variables.
If input_variables is not set, it will be set to the union of
all input variables in the messages.
Args:
values: values to validate.
Returns:
Validated values.
"""
messages = values["messages"]
input_vars = set()
input_types: Dict[str, Any] = values.get("input_types", {})
for message in messages:
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
input_vars.update(message.input_variables)
if isinstance(message, MessagesPlaceholder):
if message.variable_name not in input_types:
input_types[message.variable_name] = List[AnyMessage]
if "partial_variables" in values:
input_vars = input_vars - set(values["partial_variables"])
if "input_variables" in values and values.get("validate_template"):
if input_vars != set(values["input_variables"]):
raise ValueError(
"Got mismatched input_variables. "
f"Expected: {input_vars}. "
f"Got: {values['input_variables']}"
)
else:
values["input_variables"] = sorted(input_vars)
values["input_types"] = input_types
return values
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
"""Create a chat prompt template from a template string.
Creates a chat template consisting of a single message assumed to be from
the human.
Args:
template: template string
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt_template = PromptTemplate.from_template(template, **kwargs)
message = HumanMessagePromptTemplate(prompt=prompt_template)
return cls.from_messages([message])
@classmethod
@deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
def from_role_strings(
cls, string_messages: List[Tuple[str, str]]
) -> ChatPromptTemplate:
"""Create a chat prompt template from a list of (role, template) tuples.
Args:
string_messages: list of (role, template) tuples.
Returns:
a chat prompt template
"""
return cls(
messages=[
ChatMessagePromptTemplate.from_template(template, role=role)
for role, template in string_messages
]
)
@classmethod
@deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
def from_strings(
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
) -> ChatPromptTemplate:
"""Create a chat prompt template from a list of (role class, template) tuples.
Args:
string_messages: list of (role class, template) tuples.
Returns:
a chat prompt template
"""
return cls.from_messages(string_messages)
@classmethod
def from_messages(
cls,
messages: Sequence[MessageLikeRepresentation],
) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats.
Examples:
Instantiation from a list of message templates:
.. code-block:: python
template = ChatPromptTemplate.from_messages([
("human", "Hello, how are you?"),
("ai", "I'm doing well, thanks!"),
("human", "That's good to hear."),
])
Instantiation from mixed message formats:
.. code-block:: python
template = ChatPromptTemplate.from_messages([
SystemMessage(content="hello"),
("human", "Hello, how are you?"),
])
Args:
messages: sequence of message representations.
A message can be represented using the following formats:
(1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of
(message type, template); e.g., ("human", "{user_input}"),
(4) 2-tuple of (message class, template), or (5) a string which is
shorthand for ("human", template); e.g., "{user_input}"
Returns:
a chat prompt template
"""
_messages = [_convert_to_message(message) for message in messages]
# Automatically infer input variables from messages
input_vars: Set[str] = set()
for _message in _messages:
if isinstance(
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
):
input_vars.update(_message.input_variables)
return cls(input_variables=sorted(input_vars), messages=_messages)
def format(self, **kwargs: Any) -> str:
"""Format the chat template into a string.
Args:
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
formatted string
"""
return self.format_prompt(**kwargs).to_string()
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the chat template into a list of finalized messages.
Args:
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
list of formatted messages
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
result = []
for message_template in self.messages:
if isinstance(message_template, BaseMessage):
result.extend([message_template])
elif isinstance(
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
):
rel_params = {
k: v
for k, v in kwargs.items()
if k in message_template.input_variables
}
message = message_template.format_messages(**rel_params)
result.extend(message)
else:
raise ValueError(f"Unexpected input: {message_template}")
return result
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
"""Get a new ChatPromptTemplate with some input variables already filled in.
Args:
**kwargs: keyword arguments to use for filling in template variables. Ought
to be a subset of the input variables.
Returns:
A new ChatPromptTemplate.
Example:
.. code-block:: python
from langchain_core.prompts import ChatPromptTemplate
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
("human", "Hi I'm {user}"),
("ai", "Hi there, {user}, I'm {name}."),
("human", "{input}"),
]
)
template2 = template.partial(user="Lucy", name="R2D2")
template2.format_messages(input="hello")
"""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def append(self, message: MessageLikeRepresentation) -> None:
"""Append message to the end of the chat template.
Args:
message: representation of a message to append.
"""
self.messages.append(_convert_to_message(message))
def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
"""Extend the chat template with a sequence of messages."""
self.messages.extend([_convert_to_message(message) for message in messages])
@overload
def __getitem__(self, index: int) -> MessageLike:
...
@overload
def __getitem__(self, index: slice) -> ChatPromptTemplate:
...
def __getitem__(
self, index: Union[int, slice]
) -> Union[MessageLike, ChatPromptTemplate]:
"""Use to index into the chat template."""
if isinstance(index, slice):
start, stop, step = index.indices(len(self.messages))
messages = self.messages[start:stop:step]
return ChatPromptTemplate.from_messages(messages)
else:
return self.messages[index]
def __len__(self) -> int:
"""Get the length of the chat template."""
return len(self.messages)
@property
def _prompt_type(self) -> str:
"""Name of prompt type."""
return "chat"
def save(self, file_path: Union[Path, str]) -> None:
"""Save prompt to file.
Args:
file_path: path to file.
"""
raise NotImplementedError()
def _create_template_from_message_type(
message_type: str, template: str
) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string.
Args:
message_type: str the type of the message template (e.g., "human", "ai", etc.)
template: str the template string.
Returns:
a message prompt template of the appropriate type.
"""
if message_type in ("human", "user"):
message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template(
template
)
elif message_type in ("ai", "assistant"):
message = AIMessagePromptTemplate.from_template(template)
elif message_type == "system":
message = SystemMessagePromptTemplate.from_template(template)
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
f" 'user', 'ai', 'assistant', or 'system'."
)
return message
def _convert_to_message(
message: MessageLikeRepresentation,
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
- BaseMessagePromptTemplate
- BaseMessage
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
- 2-tuple of (message class, template)
- string: shorthand for ("human", template); e.g., "{user_input}"
Args:
message: a representation of a message in one of the supported formats
Returns:
an instance of a message or a message template
"""
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
_message: Union[
BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate
] = message
elif isinstance(message, BaseMessage):
_message = message
elif isinstance(message, str):
_message = _create_template_from_message_type("human", message)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
else:
_message = message_type_str(prompt=PromptTemplate.from_template(template))
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
return _message
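
A usage sketch combining the message formats accepted by `from_messages` with a `MessagesPlaceholder`:

from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.schema.messages import AIMessage, HumanMessage

template = ChatPromptTemplate.from_messages([
    ("system", "You answer in {language}."),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{question}"),
])
messages = template.format_messages(
    language="French",
    history=[HumanMessage(content="Hi"), AIMessage(content="Bonjour")],
    question="How are you?",
)
# -> [SystemMessage(...), HumanMessage("Hi"), AIMessage("Bonjour"), HumanMessage(...)]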

@ -0,0 +1,14 @@
"""Logic for selecting examples to include in prompts."""
from langchain_core.prompts.example_selector.length_based import (
LengthBasedExampleSelector,
)
from langchain_core.prompts.example_selector.semantic_similarity import (
MaxMarginalRelevanceExampleSelector,
SemanticSimilarityExampleSelector,
)
__all__ = [
"LengthBasedExampleSelector",
"MaxMarginalRelevanceExampleSelector",
"SemanticSimilarityExampleSelector",
]

@ -0,0 +1,15 @@
"""Interface for selecting examples to include in prompts."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List
class BaseExampleSelector(ABC):
"""Interface for selecting examples to include in prompts."""
@abstractmethod
def add_example(self, example: Dict[str, str]) -> Any:
"""Add new example to store for a key."""
@abstractmethod
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on the inputs."""

@ -0,0 +1,63 @@
"""Select examples based on length."""
import re
from typing import Callable, Dict, List
from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, validator
def _get_length_based(text: str) -> int:
return len(re.split("\n| ", text))
class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
"""Select examples based on length."""
examples: List[dict]
"""A list of the examples that the prompt template expects."""
example_prompt: PromptTemplate
"""Prompt template used to format the examples."""
get_text_length: Callable[[str], int] = _get_length_based
"""Function to measure prompt length. Defaults to word count."""
max_length: int = 2048
"""Max length for the prompt, beyond which examples are cut."""
example_text_lengths: List[int] = [] #: :meta private:
def add_example(self, example: Dict[str, str]) -> None:
"""Add new example to list."""
self.examples.append(example)
string_example = self.example_prompt.format(**example)
self.example_text_lengths.append(self.get_text_length(string_example))
@validator("example_text_lengths", always=True)
def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]:
"""Calculate text lengths if they don't exist."""
# Check if text lengths were passed in
if v:
return v
# If they were not, calculate them
example_prompt = values["example_prompt"]
get_text_length = values["get_text_length"]
string_examples = [example_prompt.format(**eg) for eg in values["examples"]]
return [get_text_length(eg) for eg in string_examples]
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on the input lengths."""
inputs = " ".join(input_variables.values())
remaining_length = self.max_length - self.get_text_length(inputs)
i = 0
examples = []
while remaining_length > 0 and i < len(self.examples):
new_length = remaining_length - self.example_text_lengths[i]
if new_length < 0:
break
else:
examples.append(self.examples[i])
remaining_length = new_length
i += 1
return examples
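
A quick sketch of the selector above with a small word budget; the antonym examples are made up:

from langchain_core.prompts.example_selector.length_based import LengthBasedExampleSelector
from langchain_core.prompts.prompt import PromptTemplate

example_prompt = PromptTemplate.from_template("Input: {input}\nOutput: {output}")
selector = LengthBasedExampleSelector(
    examples=[
        {"input": "happy", "output": "sad"},
        {"input": "tall", "output": "short"},
    ],
    example_prompt=example_prompt,
    max_length=6,  # budget measured in whitespace-separated words by default
)
# Only as many examples as fit in the remaining budget are returned.
print(selector.select_examples({"adjective": "big"}))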

@ -0,0 +1,165 @@
"""Example selector that selects examples based on SemanticSimilarity."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Type
from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.schema.embeddings import Embeddings
from langchain_core.schema.vectorstore import VectorStore
def sorted_values(values: Dict[str, str]) -> List[Any]:
"""Return a list of values in dict sorted by key."""
return [values[val] for val in sorted(values)]
class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
"""Example selector that selects examples based on SemanticSimilarity."""
vectorstore: VectorStore
"""VectorStore than contains information about examples."""
k: int = 4
"""Number of examples to select."""
example_keys: Optional[List[str]] = None
"""Optional keys to filter examples to."""
input_keys: Optional[List[str]] = None
"""Optional keys to filter input to. If provided, the search is based on
the input variables instead of all variables."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def add_example(self, example: Dict[str, str]) -> str:
"""Add new example to vectorstore."""
if self.input_keys:
string_example = " ".join(
sorted_values({key: example[key] for key in self.input_keys})
)
else:
string_example = " ".join(sorted_values(example))
ids = self.vectorstore.add_texts([string_example], metadatas=[example])
return ids[0]
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on semantic similarity."""
# Get the docs with the highest similarity.
if self.input_keys:
input_variables = {key: input_variables[key] for key in self.input_keys}
query = " ".join(sorted_values(input_variables))
example_docs = self.vectorstore.similarity_search(query, k=self.k)
# Get the examples from the metadata.
# This assumes that examples are stored in metadata.
examples = [dict(e.metadata) for e in example_docs]
# If example keys are provided, filter examples to those keys.
if self.example_keys:
examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
return examples
@classmethod
def from_examples(
cls,
examples: List[dict],
embeddings: Embeddings,
vectorstore_cls: Type[VectorStore],
k: int = 4,
input_keys: Optional[List[str]] = None,
**vectorstore_cls_kwargs: Any,
) -> SemanticSimilarityExampleSelector:
"""Create k-shot example selector using example list and embeddings.
Reshuffles examples dynamically based on query similarity.
Args:
examples: List of examples to use in the prompt.
embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings().
vectorstore_cls: A vector store DB interface class, e.g. FAISS.
k: Number of examples to select
input_keys: If provided, the search is based on the input variables
instead of all variables.
vectorstore_cls_kwargs: optional kwargs containing url for vector store
Returns:
The ExampleSelector instantiated, backed by a vector store.
"""
if input_keys:
string_examples = [
" ".join(sorted_values({k: eg[k] for k in input_keys}))
for eg in examples
]
else:
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
vectorstore = vectorstore_cls.from_texts(
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
)
return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)
class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
"""ExampleSelector that selects examples based on Max Marginal Relevance.
This was shown to improve performance in this paper:
https://arxiv.org/pdf/2211.13892.pdf
"""
fetch_k: int = 20
"""Number of examples to fetch to rerank."""
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on semantic similarity."""
# Get the docs with the highest similarity.
if self.input_keys:
input_variables = {key: input_variables[key] for key in self.input_keys}
query = " ".join(sorted_values(input_variables))
example_docs = self.vectorstore.max_marginal_relevance_search(
query, k=self.k, fetch_k=self.fetch_k
)
# Get the examples from the metadata.
# This assumes that examples are stored in metadata.
examples = [dict(e.metadata) for e in example_docs]
# If example keys are provided, filter examples to those keys.
if self.example_keys:
examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
return examples
@classmethod
def from_examples(
cls,
examples: List[dict],
embeddings: Embeddings,
vectorstore_cls: Type[VectorStore],
k: int = 4,
input_keys: Optional[List[str]] = None,
fetch_k: int = 20,
**vectorstore_cls_kwargs: Any,
) -> MaxMarginalRelevanceExampleSelector:
"""Create k-shot example selector using example list and embeddings.
Reshuffles examples dynamically based on query similarity.
Args:
examples: List of examples to use in the prompt.
embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings().
vectorstore_cls: A vector store DB interface class, e.g. FAISS.
k: Number of examples to select
input_keys: If provided, the search is based on the input variables
instead of all variables.
vectorstore_cls_kwargs: optional kwargs containing url for vector store
Returns:
The ExampleSelector instantiated, backed by a vector store.
"""
if input_keys:
string_examples = [
" ".join(sorted_values({k: eg[k] for k in input_keys}))
for eg in examples
]
else:
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
vectorstore = vectorstore_cls.from_texts(
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
)
return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys)
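
A hedged sketch of `from_examples`; `OpenAIEmbeddings` and `FAISS` follow the docstring's suggestions and live in the companion `langchain` package (import paths assumed), so this needs extra dependencies and an API key:

from langchain.embeddings import OpenAIEmbeddings  # assumed import path
from langchain.vectorstores import FAISS           # assumed import path

from langchain_core.prompts.example_selector.semantic_similarity import (
    SemanticSimilarityExampleSelector,
)

selector = SemanticSimilarityExampleSelector.from_examples(
    examples=[
        {"input": "happy", "output": "sad"},
        {"input": "tall", "output": "short"},
    ],
    embeddings=OpenAIEmbeddings(),
    vectorstore_cls=FAISS,
    k=1,
)
print(selector.select_examples({"input": "enthusiastic"}))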

@ -0,0 +1,343 @@
"""Prompt template that contains few shot examples."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from langchain_core.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
check_valid_template,
get_template_variables,
)
from langchain_core.prompts.chat import (
BaseChatPromptTemplate,
BaseMessagePromptTemplate,
)
from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.schema.messages import BaseMessage, get_buffer_string
class _FewShotPromptTemplateMixin(BaseModel):
"""Prompt template that contains few shot examples."""
examples: Optional[List[dict]] = None
"""Examples to format into the prompt.
Either this or example_selector should be provided."""
example_selector: Optional[BaseExampleSelector] = None
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
"""Check that one and only one of examples/example_selector are provided."""
examples = values.get("examples", None)
example_selector = values.get("example_selector", None)
if examples and example_selector:
raise ValueError(
"Only one of 'examples' and 'example_selector' should be provided"
)
if examples is None and example_selector is None:
raise ValueError(
"One of 'examples' and 'example_selector' should be provided"
)
return values
def _get_examples(self, **kwargs: Any) -> List[dict]:
"""Get the examples to use for formatting the prompt.
Args:
**kwargs: Keyword arguments to be passed to the example selector.
Returns:
List of examples.
"""
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
return self.example_selector.select_examples(kwargs)
else:
raise ValueError(
"One of 'examples' and 'example_selector' should be provided"
)
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Prompt template that contains few shot examples."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return False
validate_template: bool = False
"""Whether or not to try validating the template."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
example_prompt: PromptTemplate
"""PromptTemplate used to format an individual example."""
suffix: str
"""A prompt template string to put after the examples."""
example_separator: str = "\n\n"
"""String separator used to join the prefix, the examples, and suffix."""
prefix: str = ""
"""A prompt template string to put before the examples."""
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:
check_valid_template(
values["prefix"] + values["suffix"],
values["template_format"],
values["input_variables"] + list(values["partial_variables"]),
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["prefix"] + values["suffix"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
**kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
# Get the examples to use.
examples = self._get_examples(**kwargs)
examples = [
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
]
# Format the examples.
example_strings = [
self.example_prompt.format(**example) for example in examples
]
# Create the overall template.
pieces = [self.prefix, *example_strings, self.suffix]
template = self.example_separator.join([piece for piece in pieces if piece])
# Format the template with the input variables.
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "few_shot"
def save(self, file_path: Union[Path, str]) -> None:
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
return super().save(file_path)
class FewShotChatMessagePromptTemplate(
BaseChatPromptTemplate, _FewShotPromptTemplateMixin
):
"""Chat prompt template that supports few-shot examples.
The high-level structure produced by this prompt template is a list of messages
consisting of prefix message(s), example message(s), and suffix message(s).
This structure enables creating a conversation with intermediate examples like:
System: You are a helpful AI Assistant
Human: What is 2+2?
AI: 4
Human: What is 2+3?
AI: 5
Human: What is 4+4?
This prompt template can be used to generate a fixed list of examples or else
to dynamically select examples based on the input.
Examples:
Prompt template with a fixed list of examples (matching the sample
conversation above):
.. code-block:: python
from langchain_core.prompts import (
FewShotChatMessagePromptTemplate,
ChatPromptTemplate
)
examples = [
{"input": "2+2", "output": "4"},
{"input": "2+3", "output": "5"},
]
example_prompt = ChatPromptTemplate.from_messages(
[('human', '{input}'), ('ai', '{output}')]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
examples=examples,
# This is a prompt template used to format each individual example.
example_prompt=example_prompt,
)
final_prompt = ChatPromptTemplate.from_messages(
[
('system', 'You are a helpful AI Assistant'),
few_shot_prompt,
('human', '{input}'),
]
)
final_prompt.format(input="What is 4+4?")
Prompt template with dynamically selected examples:
.. code-block:: python
from langchain_core.prompts import SemanticSimilarityExampleSelector
from langchain_core.embeddings import OpenAIEmbeddings
from langchain_core.vectorstores import Chroma
examples = [
{"input": "2+2", "output": "4"},
{"input": "2+3", "output": "5"},
{"input": "2+4", "output": "6"},
# ...
]
to_vectorize = [
" ".join(example.values())
for example in examples
]
embeddings = OpenAIEmbeddings()
vectorstore = Chroma.from_texts(
to_vectorize, embeddings, metadatas=examples
)
example_selector = SemanticSimilarityExampleSelector(
vectorstore=vectorstore
)
from langchain_core.schema import SystemMessage
from langchain_core.prompts import HumanMessagePromptTemplate
from langchain_core.prompts.few_shot import FewShotChatMessagePromptTemplate
few_shot_prompt = FewShotChatMessagePromptTemplate(
# Which variable(s) will be passed to the example selector.
input_variables=["input"],
example_selector=example_selector,
# Define how each example will be formatted.
# In this case, each example will become 2 messages:
# 1 human, and 1 AI
example_prompt=(
HumanMessagePromptTemplate.from_template("{input}")
+ AIMessagePromptTemplate.from_template("{output}")
),
)
# Define the overall prompt.
final_prompt = (
SystemMessagePromptTemplate.from_template(
"You are a helpful AI Assistant"
)
+ few_shot_prompt
+ HumanMessagePromptTemplate.from_template("{input}")
)
# Show the prompt
print(final_prompt.format_messages(input="What's 3+3?"))
# Use within an LLM
from langchain_core.chat_models import ChatAnthropic
chain = final_prompt | ChatAnthropic()
chain.invoke({"input": "What's 3+3?"})
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return False
input_variables: List[str] = Field(default_factory=list)
"""A list of the names of the variables the prompt template will use
to pass to the example_selector, if provided."""
example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
"""The class to format each example."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages.
Args:
**kwargs: keyword arguments to use for filling in templates in messages.
Returns:
A list of formatted messages with all template variables filled in.
"""
# Get the examples to use.
examples = self._get_examples(**kwargs)
examples = [
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
]
# Format the examples.
messages = [
message
for example in examples
for message in self.example_prompt.format_messages(**example)
]
return messages
def format(self, **kwargs: Any) -> str:
"""Format the prompt with inputs generating a string.
Use this method to generate a string representation of a prompt consisting
of chat messages.
Useful for feeding into a string based completion language model or debugging.
Args:
**kwargs: keyword arguments to use for formatting.
Returns:
A string representation of the prompt
"""
messages = self.format_messages(**kwargs)
return get_buffer_string(messages)

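A minimal usage sketch (not part of the diff) contrasting ``format_messages()`` and ``format()`` on the fixed-example template from the docstring above; ``format()`` simply flattens the messages through ``get_buffer_string``:

from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate

examples = [
    {"input": "2+2", "output": "4"},
    {"input": "2+3", "output": "5"},
]
example_prompt = ChatPromptTemplate.from_messages(
    [("human", "{input}"), ("ai", "{output}")]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
    examples=examples,
    example_prompt=example_prompt,
)

# format_messages() returns one Human/AI message pair per example ...
for message in few_shot_prompt.format_messages():
    print(type(message).__name__, message.content)

# ... while format() flattens those same messages into a single buffer string.
print(few_shot_prompt.format())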
@ -0,0 +1,153 @@
"""Prompt template that contains few shot examples."""
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langchain_core.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate
from langchain_core.prompts.example_selector.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Extra, root_validator
class FewShotPromptWithTemplates(StringPromptTemplate):
"""Prompt template that contains few shot examples."""
examples: Optional[List[dict]] = None
"""Examples to format into the prompt.
Either this or example_selector should be provided."""
example_selector: Optional[BaseExampleSelector] = None
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""
example_prompt: PromptTemplate
"""PromptTemplate used to format an individual example."""
suffix: StringPromptTemplate
"""A PromptTemplate to put after the examples."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
example_separator: str = "\n\n"
"""String separator used to join the prefix, the examples, and suffix."""
prefix: Optional[StringPromptTemplate] = None
"""A PromptTemplate to put before the examples."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = False
"""Whether or not to try validating the template."""
@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
"""Check that one and only one of examples/example_selector are provided."""
examples = values.get("examples", None)
example_selector = values.get("example_selector", None)
if examples and example_selector:
raise ValueError(
"Only one of 'examples' and 'example_selector' should be provided"
)
if examples is None and example_selector is None:
raise ValueError(
"One of 'examples' and 'example_selector' should be provided"
)
return values
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:
input_variables = values["input_variables"]
expected_input_variables = set(values["suffix"].input_variables)
expected_input_variables |= set(values["partial_variables"])
if values["prefix"] is not None:
expected_input_variables |= set(values["prefix"].input_variables)
missing_vars = expected_input_variables.difference(input_variables)
if missing_vars:
raise ValueError(
f"Got input_variables={input_variables}, but based on "
f"prefix/suffix expected {expected_input_variables}"
)
else:
values["input_variables"] = sorted(
set(values["suffix"].input_variables)
| set(values["prefix"].input_variables if values["prefix"] else [])
- set(values["partial_variables"])
)
return values
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def _get_examples(self, **kwargs: Any) -> List[dict]:
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
return self.example_selector.select_examples(kwargs)
else:
raise ValueError
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
# Get the examples to use.
examples = self._get_examples(**kwargs)
# Format the examples.
example_strings = [
self.example_prompt.format(**example) for example in examples
]
# Create the overall prefix.
if self.prefix is None:
prefix = ""
else:
prefix_kwargs = {
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
}
for k in prefix_kwargs.keys():
kwargs.pop(k)
prefix = self.prefix.format(**prefix_kwargs)
# Create the overall suffix
suffix_kwargs = {
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
}
for k in suffix_kwargs.keys():
kwargs.pop(k)
suffix = self.suffix.format(
**suffix_kwargs,
)
pieces = [prefix, *example_strings, suffix]
template = self.example_separator.join([piece for piece in pieces if piece])
# Format the template with the input variables.
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "few_shot_with_templates"
def save(self, file_path: Union[Path, str]) -> None:
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
return super().save(file_path)

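A minimal sketch of FewShotPromptWithTemplates, where the prefix and suffix are themselves prompt templates. The import path for the class is an assumption (this hunk does not show the file name); ``PromptTemplate`` comes from the path used in the imports above:

from langchain_core.prompts.few_shot_with_templates import (  # assumed module path
    FewShotPromptWithTemplates,
)
from langchain_core.prompts.prompt import PromptTemplate

example_prompt = PromptTemplate.from_template("Q: {question}\nA: {answer}")
prefix = PromptTemplate.from_template("You are answering questions about {topic}.")
suffix = PromptTemplate.from_template("Q: {question}\nA:")

prompt = FewShotPromptWithTemplates(
    examples=[{"question": "2+2", "answer": "4"}],
    example_prompt=example_prompt,
    prefix=prefix,
    suffix=suffix,
    # Recomputed by the template_is_valid validator when validate_template is False.
    input_variables=["topic", "question"],
)
print(prompt.format(topic="arithmetic", question="2+3"))
# You are answering questions about arithmetic.
#
# Q: 2+2
# A: 4
#
# Q: 2+3
# A: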
@ -0,0 +1,162 @@
"""Load prompts."""
import json
import logging
from pathlib import Path
from typing import Callable, Dict, Union
import yaml
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.schema import (
BasePromptTemplate,
StrOutputParser,
)
from langchain_core.utils.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
logger = logging.getLogger(__name__)
def load_prompt_from_config(config: dict) -> BasePromptTemplate:
"""Load prompt from Config Dict."""
if "_type" not in config:
logger.warning("No `_type` key found, defaulting to `prompt`.")
config_type = config.pop("_type", "prompt")
if config_type not in type_to_loader_dict:
raise ValueError(f"Loading {config_type} prompt not supported")
prompt_loader = type_to_loader_dict[config_type]
return prompt_loader(config)
def _load_template(var_name: str, config: dict) -> dict:
"""Load template from the path if applicable."""
# Check if template_path exists in config.
if f"{var_name}_path" in config:
# If it does, make sure template variable doesn't also exist.
if var_name in config:
raise ValueError(
f"Both `{var_name}_path` and `{var_name}` cannot be provided."
)
# Pop the template path from the config.
template_path = Path(config.pop(f"{var_name}_path"))
# Load the template.
if template_path.suffix == ".txt":
with open(template_path) as f:
template = f.read()
else:
raise ValueError
# Set the template variable to the extracted variable.
config[var_name] = template
return config
def _load_examples(config: dict) -> dict:
"""Load examples if necessary."""
if isinstance(config["examples"], list):
pass
elif isinstance(config["examples"], str):
with open(config["examples"]) as f:
if config["examples"].endswith(".json"):
examples = json.load(f)
elif config["examples"].endswith((".yaml", ".yml")):
examples = yaml.safe_load(f)
else:
raise ValueError(
"Invalid file format. Only json or yaml formats are supported."
)
config["examples"] = examples
else:
raise ValueError("Invalid examples format. Only list or string are supported.")
return config
def _load_output_parser(config: dict) -> dict:
"""Load output parser."""
if "output_parser" in config and config["output_parser"]:
_config = config.pop("output_parser")
output_parser_type = _config.pop("_type")
if output_parser_type == "default":
output_parser = StrOutputParser(**_config)
else:
raise ValueError(f"Unsupported output parser {output_parser_type}")
config["output_parser"] = output_parser
return config
def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
"""Load the "few shot" prompt from the config."""
# Load the suffix and prefix templates.
config = _load_template("suffix", config)
config = _load_template("prefix", config)
# Load the example prompt.
if "example_prompt_path" in config:
if "example_prompt" in config:
raise ValueError(
"Only one of example_prompt and example_prompt_path should "
"be specified."
)
config["example_prompt"] = load_prompt(config.pop("example_prompt_path"))
else:
config["example_prompt"] = load_prompt_from_config(config["example_prompt"])
# Load the examples.
config = _load_examples(config)
config = _load_output_parser(config)
return FewShotPromptTemplate(**config)
def _load_prompt(config: dict) -> PromptTemplate:
"""Load the prompt template from config."""
# Load the template from disk if necessary.
config = _load_template("template", config)
config = _load_output_parser(config)
template_format = config.get("template_format", "f-string")
if template_format == "jinja2":
# Disabled due to:
# https://github.com/langchain-ai/langchain/issues/4394
raise ValueError(
f"Loading templates with '{template_format}' format is no longer supported "
f"since it can lead to arbitrary code execution. Please migrate to using "
f"the 'f-string' template format, which does not suffer from this issue."
)
return PromptTemplate(**config)
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
"""Unified method for loading a prompt from LangChainHub or local fs."""
if hub_result := try_load_from_hub(
path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"}
):
return hub_result
else:
return _load_prompt_from_file(path)
def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
"""Load prompt from file."""
# Convert file to a Path object.
if isinstance(file, str):
file_path = Path(file)
else:
file_path = file
# Load from either json or yaml.
if file_path.suffix == ".json":
with open(file_path) as f:
config = json.load(f)
elif file_path.suffix == ".yaml":
with open(file_path, "r") as f:
config = yaml.safe_load(f)
else:
raise ValueError(f"Got unsupported file type {file_path.suffix}")
# Load the prompt from the config now.
return load_prompt_from_config(config)
type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = {
"prompt": _load_prompt,
"few_shot": _load_few_shot_prompt,
}

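A small round-trip sketch for ``load_prompt``: write a ``prompt``-typed JSON config to disk and load it back. The module path ``langchain_core.prompts.loading`` is assumed from the imports this file uses:

import json
from pathlib import Path

from langchain_core.prompts.loading import load_prompt  # assumed module path

config = {
    "_type": "prompt",
    "input_variables": ["product"],
    "template": "What is a good name for a company that makes {product}?",
}
path = Path("company_name_prompt.json")
path.write_text(json.dumps(config))

# .json suffix -> json.load -> _load_prompt -> PromptTemplate(**config)
prompt = load_prompt(path)
print(prompt.format(product="colorful socks"))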
@ -0,0 +1,56 @@
from typing import Any, Dict, List, Tuple
from langchain_core.prompts.chat import BaseChatPromptTemplate
from langchain_core.pydantic_v1 import root_validator
from langchain_core.schema import BasePromptTemplate, PromptValue
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
return {k: inputs[k] for k in input_variables}
class PipelinePromptTemplate(BasePromptTemplate):
"""A prompt template for composing multiple prompt templates together.
This can be useful when you want to reuse parts of prompts.
A PipelinePrompt consists of two main parts:
- final_prompt: This is the final prompt that is returned
- pipeline_prompts: This is a list of tuples, consisting
of a string (`name`) and a Prompt Template.
Each PromptTemplate will be formatted and then passed
to future prompt templates as a variable with
the same name as `name`
"""
final_prompt: BasePromptTemplate
"""The final prompt that is returned."""
pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
@root_validator(pre=True)
def get_input_variables(cls, values: Dict) -> Dict:
"""Get input variables."""
created_variables = set()
all_variables = set()
for k, prompt in values["pipeline_prompts"]:
created_variables.add(k)
all_variables.update(prompt.input_variables)
values["input_variables"] = list(all_variables.difference(created_variables))
return values
def format_prompt(self, **kwargs: Any) -> PromptValue:
for k, prompt in self.pipeline_prompts:
_inputs = _get_inputs(kwargs, prompt.input_variables)
if isinstance(prompt, BaseChatPromptTemplate):
kwargs[k] = prompt.format_messages(**_inputs)
else:
kwargs[k] = prompt.format(**_inputs)
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
return self.final_prompt.format_prompt(**_inputs)
def format(self, **kwargs: Any) -> str:
return self.format_prompt(**kwargs).to_string()
@property
def _prompt_type(self) -> str:
raise ValueError

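A composition sketch for PipelinePromptTemplate: each named sub-prompt is formatted first, and its output becomes a variable of the final prompt. The module path ``langchain_core.prompts.pipeline`` is assumed:

from langchain_core.prompts.pipeline import PipelinePromptTemplate  # assumed module path
from langchain_core.prompts.prompt import PromptTemplate

final_prompt = PromptTemplate.from_template("{introduction}\n\n{example}\n\n{start}")
introduction = PromptTemplate.from_template("You are impersonating {person}.")
example = PromptTemplate.from_template("Here's an example:\nQ: {example_q}\nA: {example_a}")
start = PromptTemplate.from_template("Now, do this for real!\nQ: {input}\nA:")

pipeline_prompt = PipelinePromptTemplate(
    final_prompt=final_prompt,
    pipeline_prompts=[
        ("introduction", introduction),
        ("example", example),
        ("start", start),
    ],
)
# input_variables is derived by the root validator: all sub-prompt variables
# minus the ones produced by the pipeline itself (order may vary).
print(pipeline_prompt.input_variables)
print(
    pipeline_prompt.format(
        person="a pirate",
        example_q="What's your favorite food?",
        example_a="Grog and hardtack, arr!",
        input="What's your favorite color?",
    )
)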
@ -0,0 +1,250 @@
"""Prompt schema definition."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from langchain_core.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
check_valid_template,
get_template_variables,
)
from langchain_core.pydantic_v1 import root_validator
class PromptTemplate(StringPromptTemplate):
"""A prompt template for a language model.
A prompt template consists of a string template. It accepts a set of parameters
from the user that can be used to generate a prompt for a language model.
The template can be formatted using either f-strings (default) or jinja2 syntax.
*Security warning*: Prefer using `template_format="f-string"` instead of
`template_format="jinja2"`, or make sure to NEVER accept jinja2 templates
from untrusted sources as they may lead to arbitrary Python code execution.
As of LangChain 0.0.329, Jinja2 templates will be rendered using
Jinja2's SandboxedEnvironment by default. This sand-boxing should
be treated as a best-effort approach rather than a guarantee of security,
as it is an opt-out rather than opt-in approach.
Despite the sand-boxing, we recommend never using jinja2 templates
from untrusted sources.
Example:
.. code-block:: python
from langchain_core.prompts import PromptTemplate
# Instantiation using from_template (recommended)
prompt = PromptTemplate.from_template("Say {foo}")
prompt.format(foo="bar")
# Instantiation using initializer
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
"""
@property
def lc_attributes(self) -> Dict[str, Any]:
return {
"template_format": self.template_format,
}
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
template: str
"""The prompt template."""
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = False
"""Whether or not to try validating the template."""
def __add__(self, other: Any) -> PromptTemplate:
"""Override the + operator to allow for combining prompt templates."""
# Allow for easy combining
if isinstance(other, PromptTemplate):
if self.template_format != "f-string":
raise ValueError(
"Adding prompt templates only supported for f-strings."
)
if other.template_format != "f-string":
raise ValueError(
"Adding prompt templates only supported for f-strings."
)
input_variables = list(
set(self.input_variables) | set(other.input_variables)
)
template = self.template + other.template
# If any do not want to validate, then don't
validate_template = self.validate_template and other.validate_template
partial_variables = {k: v for k, v in self.partial_variables.items()}
for k, v in other.partial_variables.items():
if k in partial_variables:
raise ValueError("Cannot have same variable partialed twice.")
else:
partial_variables[k] = v
return PromptTemplate(
template=template,
input_variables=input_variables,
partial_variables=partial_variables,
template_format="f-string",
validate_template=validate_template,
)
elif isinstance(other, str):
prompt = PromptTemplate.from_template(other)
return self + prompt
else:
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "prompt"
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent."""
if values["validate_template"]:
all_inputs = values["input_variables"] + list(values["partial_variables"])
check_valid_template(
values["template"], values["template_format"], all_inputs
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
@classmethod
def from_examples(
cls,
examples: List[str],
suffix: str,
input_variables: List[str],
example_separator: str = "\n\n",
prefix: str = "",
**kwargs: Any,
) -> PromptTemplate:
"""Take examples in list format with prefix and suffix to create a prompt.
Intended to be used as a way to dynamically create a prompt from examples.
Args:
examples: List of examples to use in the prompt.
suffix: String to go after the list of examples. Should generally
set up the user's input.
input_variables: A list of variable names the final prompt template
will expect.
example_separator: The separator to use in between examples. Defaults
to two newline characters.
prefix: String that should go before any examples. Generally includes
instructions. Defaults to an empty string.
Returns:
The final prompt generated.
"""
template = example_separator.join([prefix, *examples, suffix])
return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod
def from_file(
cls, template_file: Union[str, Path], input_variables: List[str], **kwargs: Any
) -> PromptTemplate:
"""Load a prompt from a file.
Args:
template_file: The path to the file containing the prompt template.
input_variables: A list of variable names the final prompt template
will expect.
Returns:
The prompt loaded from the file.
"""
with open(str(template_file), "r") as f:
template = f.read()
return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod
def from_template(
cls,
template: str,
*,
template_format: str = "f-string",
partial_variables: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> PromptTemplate:
"""Load a prompt template from a template.
*Security warning*: Prefer using `template_format="f-string"` instead of
`template_format="jinja2"`, or make sure to NEVER accept jinja2 templates
from untrusted sources as they may lead to arbitrary Python code execution.
As of LangChain 0.0.329, Jinja2 templates will be rendered using
Jinja2's SandboxedEnvironment by default. This sand-boxing should
be treated as a best-effort approach rather than a guarantee of security,
as it is an opt-out rather than opt-in approach.
Despite the sand-boxing, we recommend never using jinja2 templates
from untrusted sources.
Args:
template: The template to load.
template_format: The format of the template. Use `jinja2` for jinja2,
and `f-string` or None for f-strings.
partial_variables: A dictionary of variables that can be used to partially
fill in the template. For example, if the template is
`"{variable1} {variable2}"`, and `partial_variables` is
`{"variable1": "foo"}`, then the final prompt will be
`"foo {variable2}"`.
Returns:
The prompt template loaded from the template.
"""
input_variables = get_template_variables(template, template_format)
_partial_variables = partial_variables or {}
if _partial_variables:
input_variables = [
var for var in input_variables if var not in _partial_variables
]
return cls(
input_variables=input_variables,
template=template,
template_format=template_format,
partial_variables=_partial_variables,
**kwargs,
)
# For backwards compatibility.
Prompt = PromptTemplate

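Two quick sketches of behaviors defined above: prompt composition via the overridden ``+`` operator (f-string templates only) and the handling of ``partial_variables`` in ``from_template``:

from langchain_core.prompts import PromptTemplate

# __add__ concatenates templates and unions their input variables.
prompt = (
    PromptTemplate.from_template("Tell me a joke about {topic}")
    + ", make it funny"
    + "\n\nand in {language}"
)
print(sorted(prompt.input_variables))  # ['language', 'topic']
print(prompt.format(topic="sports", language="spanish"))

# Partial variables are excluded from input_variables and merged in at format time.
partial = PromptTemplate.from_template(
    "{greeting}, {name}!", partial_variables={"greeting": "Hello"}
)
print(partial.input_variables)       # ['name']
print(partial.format(name="World"))  # Hello, World!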
@ -0,0 +1,23 @@
from importlib import metadata
## Create namespaces for pydantic v1 and v2.
# This code must stay at the top of the file before other modules may
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
#
# This hack is done for the following reasons:
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since
# both dependencies and dependents may be stuck on either version of v1 or v2.
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
# unambiguously uses either v1 or v2 API.
# * This change is easier to roll out and roll back.
try:
from pydantic.v1 import * # noqa: F403 # type: ignore
except ImportError:
from pydantic import * # noqa: F403 # type: ignore
try:
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
except metadata.PackageNotFoundError:
_PYDANTIC_MAJOR_VERSION = 0

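The shim above lets downstream code import a pydantic-v1-style API regardless of which major version of pydantic is installed, which is how the other files in this diff use it (``from langchain_core.pydantic_v1 import ...``). A tiny sketch:

from langchain_core.pydantic_v1 import BaseModel, Field

class Joke(BaseModel):
    """Works the same whether pydantic v1 or v2 is installed."""

    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline of the joke")

print(Joke(setup="Why did the chicken cross the road?", punchline="To get to the other side."))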
@ -0,0 +1,4 @@
try:
from pydantic.v1.dataclasses import * # noqa: F403
except ImportError:
from pydantic.dataclasses import * # noqa: F403

@ -0,0 +1,4 @@
try:
from pydantic.v1.main import * # noqa: F403
except ImportError:
from pydantic.main import * # noqa: F403

@ -0,0 +1,57 @@
"""LangChain **Runnable** and the **LangChain Expression Language (LCEL)**.
The LangChain Expression Language (LCEL) offers a declarative method to build
production-grade programs that harness the power of LLMs.
Programs created using LCEL and LangChain Runnables inherently support
synchronous, asynchronous, batch, and streaming operations.
Support for **async** allows servers hosting LCEL based programs to scale better
for higher concurrent loads.
**Streaming** of intermediate outputs as they're being generated allows for
creating more responsive UX.
This module contains schema and implementation of LangChain Runnables primitives.
"""
from langchain_core.runnables.base import (
Runnable,
RunnableBinding,
RunnableGenerator,
RunnableLambda,
RunnableMap,
RunnableParallel,
RunnableSequence,
RunnableSerializable,
)
from langchain_core.runnables.branch import RunnableBranch
from langchain_core.runnables.config import RunnableConfig, patch_config
from langchain_core.runnables.fallbacks import RunnableWithFallbacks
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.router import RouterInput, RouterRunnable
from langchain_core.runnables.utils import (
ConfigurableField,
ConfigurableFieldMultiOption,
ConfigurableFieldSingleOption,
)
__all__ = [
"ConfigurableField",
"ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption",
"patch_config",
"RouterInput",
"RouterRunnable",
"Runnable",
"RunnableSerializable",
"RunnableBinding",
"RunnableBranch",
"RunnableConfig",
"RunnableGenerator",
"RunnableLambda",
"RunnableMap",
"RunnableParallel",
"RunnablePassthrough",
"RunnableSequence",
"RunnableWithFallbacks",
]

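A minimal LCEL sketch using only primitives exported above: composing two runnables with ``|`` yields a RunnableSequence that supports ``invoke``/``batch``/``stream`` (and their async counterparts) without extra code:

from langchain_core.runnables import RunnableLambda, RunnablePassthrough

chain = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2)

print(chain.invoke(1))         # 4
print(chain.batch([1, 2, 3]))  # [4, 6, 8]

# RunnablePassthrough forwards its input unchanged, which is handy when
# building maps/parallel branches that need the original input.
print((RunnablePassthrough() | RunnableLambda(str)).invoke(41))  # '41'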
File diff suppressed because it is too large

@ -0,0 +1,254 @@
from typing import (
Any,
Awaitable,
Callable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import (
Runnable,
RunnableLike,
RunnableSerializable,
coerce_to_runnable,
)
from langchain_core.runnables.config import (
RunnableConfig,
ensure_config,
get_callback_manager_for_config,
patch_config,
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
Input,
Output,
get_unique_config_specs,
)
class RunnableBranch(RunnableSerializable[Input, Output]):
"""A Runnable that selects which branch to run based on a condition.
The runnable is initialized with a list of (condition, runnable) pairs and
a default branch.
When operating on an input, the first condition that evaluates to True is
selected, and the corresponding runnable is run on the input.
If no condition evaluates to True, the default branch is run on the input.
Examples:
.. code-block:: python
from langchain_core.runnables import RunnableBranch
branch = RunnableBranch(
(lambda x: isinstance(x, str), lambda x: x.upper()),
(lambda x: isinstance(x, int), lambda x: x + 1),
(lambda x: isinstance(x, float), lambda x: x * 2),
lambda x: "goodbye",
)
branch.invoke("hello") # "HELLO"
branch.invoke(None) # "goodbye"
"""
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
default: Runnable[Input, Output]
def __init__(
self,
*branches: Union[
Tuple[
Union[
Runnable[Input, bool],
Callable[[Input], bool],
Callable[[Input], Awaitable[bool]],
],
RunnableLike,
],
RunnableLike, # To accommodate the default branch
],
) -> None:
"""A Runnable that runs one of two branches based on a condition."""
if len(branches) < 2:
raise ValueError("RunnableBranch requires at least two branches")
default = branches[-1]
if not isinstance(
default,
(Runnable, Callable, Mapping), # type: ignore[arg-type]
):
raise TypeError(
"RunnableBranch default must be runnable, callable or mapping."
)
default_ = cast(
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
)
_branches = []
for branch in branches[:-1]:
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
raise TypeError(
f"RunnableBranch branches must be "
f"tuples or lists, not {type(branch)}"
)
if not len(branch) == 2:
raise ValueError(
f"RunnableBranch branches must be "
f"tuples or lists of length 2, not {len(branch)}"
)
condition, runnable = branch
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
runnable = coerce_to_runnable(runnable)
_branches.append((condition, runnable))
super().__init__(branches=_branches, default=default_)
class Config:
arbitrary_types_allowed = True
@classmethod
def is_lc_serializable(cls) -> bool:
"""RunnableBranch is serializable if all its branches are serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)
for runnable in runnables:
if runnable.get_input_schema(config).schema().get("type") is not None:
return runnable.get_input_schema(config)
return super().get_input_schema(config)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec
for step in (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)
for spec in step.config_specs
)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""First evaluates the condition, then delegate to true or false branch."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = condition.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = runnable.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
**kwargs,
)
break
else:
output = self.default.invoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Async version of invoke."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = await condition.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = await runnable.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
**kwargs,
)
break
else:
output = await self.default.ainvoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output

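A routing sketch that puts RunnableBranch inside a larger sequence; the raw lambdas are coerced to RunnableLambda by the constructor, and the last positional argument is the default branch:

from langchain_core.runnables import RunnableBranch, RunnableLambda

branch = RunnableBranch(
    (lambda x: x["topic"] == "math", RunnableLambda(lambda x: f"math question: {x['question']}")),
    (lambda x: x["topic"] == "history", RunnableLambda(lambda x: f"history question: {x['question']}")),
    RunnableLambda(lambda x: f"general question: {x['question']}"),
)

# A toy classifier feeding the branch.
classify = RunnableLambda(
    lambda q: {"topic": "math" if "+" in q else "general", "question": q}
)
chain = classify | branch

print(chain.invoke("What is 2+2?"))     # math question: What is 2+2?
print(chain.invoke("Who was Caesar?"))  # general question: Who was Caesar?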
@ -0,0 +1,401 @@
from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
List,
Optional,
Union,
cast,
)
from typing_extensions import TypedDict
from langchain_core.runnables.utils import (
Input,
Output,
accepts_config,
accepts_run_manager,
)
if TYPE_CHECKING:
from langchain_core.callbacks.base import BaseCallbackManager, Callbacks
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
)
else:
# Pydantic validates through typed dicts, but
# the callbacks need forward refs updated
Callbacks = Optional[Union[List, Any]]
class EmptyDict(TypedDict, total=False):
"""Empty dict type."""
pass
class RunnableConfig(TypedDict, total=False):
"""Configuration for a Runnable."""
tags: List[str]
"""
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
You can use these to filter calls.
"""
metadata: Dict[str, Any]
"""
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
Keys should be strings, values should be JSON-serializable.
"""
callbacks: Callbacks
"""
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
"""
run_name: str
"""
Name for the tracer run for this call. Defaults to the name of the class.
"""
max_concurrency: Optional[int]
"""
Maximum number of parallel calls to make. If not provided, defaults to
ThreadPoolExecutor's default.
"""
recursion_limit: int
"""
Maximum number of times a call can recurse. If not provided, defaults to 25.
"""
configurable: Dict[str, Any]
"""
Runtime values for attributes previously made configurable on this Runnable,
or sub-Runnables, through .configurable_fields() or .configurable_alternatives().
Check .output_schema() for a description of the attributes that have been made
configurable.
"""
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
"""Ensure that a config is a dict with all keys present.
Args:
config (Optional[RunnableConfig], optional): The config to ensure.
Defaults to None.
Returns:
RunnableConfig: The ensured config.
"""
empty = RunnableConfig(
tags=[],
metadata={},
callbacks=None,
recursion_limit=25,
)
if config is not None:
empty.update(
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
)
return empty
def get_config_list(
config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
"""Get a list of configs from a single config or a list of configs.
It is useful for subclasses overriding batch() or abatch().
Args:
config (Optional[Union[RunnableConfig, List[RunnableConfig]]]):
The config or list of configs.
length (int): The length of the list.
Returns:
List[RunnableConfig]: The list of configs.
Raises:
ValueError: If the length of the list is not equal to the length of the inputs.
"""
if length < 0:
raise ValueError(f"length must be >= 0, but got {length}")
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "
f"but got {len(config)} configs for {length} inputs"
)
return (
list(map(ensure_config, config))
if isinstance(config, list)
else [ensure_config(config) for _ in range(length)]
)
def patch_config(
config: Optional[RunnableConfig],
*,
callbacks: Optional[BaseCallbackManager] = None,
recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None,
run_name: Optional[str] = None,
configurable: Optional[Dict[str, Any]] = None,
) -> RunnableConfig:
"""Patch a config with new values.
Args:
config (Optional[RunnableConfig]): The config to patch.
callbacks (Optional[BaseCallbackManager], optional): The callbacks to set.
Defaults to None.
recursion_limit (Optional[int], optional): The recursion limit to set.
Defaults to None.
max_concurrency (Optional[int], optional): The max concurrency to set.
Defaults to None.
run_name (Optional[str], optional): The run name to set. Defaults to None.
configurable (Optional[Dict[str, Any]], optional): The configurable to set.
Defaults to None.
Returns:
RunnableConfig: The patched config.
"""
config = ensure_config(config)
if callbacks is not None:
# If we're replacing callbacks, we need to unset run_name
# As that should apply only to the same run as the original callbacks
config["callbacks"] = callbacks
if "run_name" in config:
del config["run_name"]
if recursion_limit is not None:
config["recursion_limit"] = recursion_limit
if max_concurrency is not None:
config["max_concurrency"] = max_concurrency
if run_name is not None:
config["run_name"] = run_name
if configurable is not None:
config["configurable"] = {**config.get("configurable", {}), **configurable}
return config
def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
"""Merge multiple configs into one.
Args:
*configs (Optional[RunnableConfig]): The configs to merge.
Returns:
RunnableConfig: The merged config.
"""
base: RunnableConfig = {}
# Even though the keys aren't literals, this is correct
# because both dicts are the same type
for config in (c for c in configs if c is not None):
for key in config:
if key == "metadata":
base[key] = { # type: ignore
**base.get(key, {}), # type: ignore
**(config.get(key) or {}), # type: ignore
}
elif key == "tags":
base[key] = list( # type: ignore
set(base.get(key, []) + (config.get(key) or [])), # type: ignore
)
elif key == "configurable":
base[key] = { # type: ignore
**base.get(key, {}), # type: ignore
**(config.get(key) or {}), # type: ignore
}
elif key == "callbacks":
base_callbacks = base.get("callbacks")
these_callbacks = config["callbacks"]
# callbacks can be either None, list[handler] or manager
# so merging two callbacks values has 6 cases
if isinstance(these_callbacks, list):
if base_callbacks is None:
base["callbacks"] = these_callbacks
elif isinstance(base_callbacks, list):
base["callbacks"] = base_callbacks + these_callbacks
else:
# base_callbacks is a manager
mngr = base_callbacks.copy()
for callback in these_callbacks:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
elif these_callbacks is not None:
# these_callbacks is a manager
if base_callbacks is None:
base["callbacks"] = these_callbacks
elif isinstance(base_callbacks, list):
mngr = these_callbacks.copy()
for callback in base_callbacks:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
else:
# base_callbacks is also a manager
base["callbacks"] = base_callbacks.__class__(
parent_run_id=base_callbacks.parent_run_id
or these_callbacks.parent_run_id,
handlers=base_callbacks.handlers + these_callbacks.handlers,
inheritable_handlers=base_callbacks.inheritable_handlers
+ these_callbacks.inheritable_handlers,
tags=list(set(base_callbacks.tags + these_callbacks.tags)),
inheritable_tags=list(
set(
base_callbacks.inheritable_tags
+ these_callbacks.inheritable_tags
)
),
metadata={
**base_callbacks.metadata,
**these_callbacks.metadata,
},
)
else:
base[key] = config[key] or base.get(key) # type: ignore
return base
def call_func_with_variable_args(
func: Union[
Callable[[Input], Output],
Callable[[Input, RunnableConfig], Output],
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
],
input: Input,
config: RunnableConfig,
run_manager: Optional[CallbackManagerForChainRun] = None,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config.
Args:
func (Union[Callable[[Input], Output],
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output]]):
The function to call.
input (Input): The input to the function.
run_manager (CallbackManagerForChainRun): The run manager to
pass to the function.
config (RunnableConfig): The config to pass to the function.
**kwargs (Any): The keyword arguments to pass to the function.
Returns:
Output: The output of the function.
"""
if accepts_config(func):
if run_manager is not None:
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
else:
kwargs["config"] = config
if run_manager is not None and accepts_run_manager(func):
kwargs["run_manager"] = run_manager
return func(input, **kwargs) # type: ignore[call-arg]
async def acall_func_with_variable_args(
func: Union[
Callable[[Input], Awaitable[Output]],
Callable[[Input, RunnableConfig], Awaitable[Output]],
Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
Callable[
[Input, AsyncCallbackManagerForChainRun, RunnableConfig],
Awaitable[Output],
],
],
input: Input,
config: RunnableConfig,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config.
Args:
func (Union[Callable[[Input], Awaitable[Output]], Callable[[Input,
AsyncCallbackManagerForChainRun], Awaitable[Output]], Callable[[Input,
AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]]]):
The function to call.
input (Input): The input to the function.
run_manager (AsyncCallbackManagerForChainRun): The run manager
to pass to the function.
config (RunnableConfig): The config to pass to the function.
**kwargs (Any): The keyword arguments to pass to the function.
Returns:
Output: The output of the function.
"""
if accepts_config(func):
if run_manager is not None:
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
else:
kwargs["config"] = config
if run_manager is not None and accepts_run_manager(func):
kwargs["run_manager"] = run_manager
return await func(input, **kwargs) # type: ignore[call-arg]
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
"""Get a callback manager for a config.
Args:
config (RunnableConfig): The config.
Returns:
CallbackManager: The callback manager.
"""
from langchain_core.callbacks.manager import CallbackManager
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
def get_async_callback_manager_for_config(
config: RunnableConfig,
) -> AsyncCallbackManager:
"""Get an async callback manager for a config.
Args:
config (RunnableConfig): The config.
Returns:
AsyncCallbackManager: The async callback manager.
"""
from langchain_core.callbacks.manager import AsyncCallbackManager
return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
@contextmanager
def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
"""Get an executor for a config.
Args:
config (RunnableConfig): The config.
Yields:
Generator[Executor, None, None]: The executor.
"""
with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
yield executor

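These helpers are mostly called by the Runnable machinery itself, but they can be exercised directly. A short sketch of ensure_config / merge_configs / patch_config:

from langchain_core.runnables.config import ensure_config, merge_configs, patch_config

base = ensure_config({"tags": ["prod"], "metadata": {"user": "alice"}})
per_call = {"tags": ["experiment"], "metadata": {"request_id": "abc"}, "run_name": "my_run"}

merged = merge_configs(base, per_call)
print(sorted(merged["tags"]))  # ['experiment', 'prod'] (tags are de-duplicated via a set)
print(merged["metadata"])      # {'user': 'alice', 'request_id': 'abc'}
print(merged["run_name"])      # my_run

patched = patch_config(merged, recursion_limit=10, configurable={"session_id": "123"})
print(patched["recursion_limit"], patched["configurable"])  # 10 {'session_id': '123'}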
@ -0,0 +1,388 @@
from __future__ import annotations
import enum
import threading
from abc import abstractmethod
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Type,
Union,
cast,
)
from weakref import WeakValueDictionary
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
get_config_list,
get_executor_for_config,
)
from langchain_core.runnables.utils import (
AnyConfigurableField,
ConfigurableField,
ConfigurableFieldMultiOption,
ConfigurableFieldSingleOption,
ConfigurableFieldSpec,
Input,
Output,
gather_with_concurrency,
get_unique_config_specs,
)
class DynamicRunnable(RunnableSerializable[Input, Output]):
"""A Serializable Runnable that can be dynamically configured."""
default: RunnableSerializable[Input, Output]
class Config:
arbitrary_types_allowed = True
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
@property
def InputType(self) -> Type[Input]:
return self.default.InputType
@property
def OutputType(self) -> Type[Output]:
return self.default.OutputType
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self._prepare(config).get_input_schema(config)
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self._prepare(config).get_output_schema(config)
@abstractmethod
def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]:
...
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
return self._prepare(config).invoke(input, config, **kwargs)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
return await self._prepare(config).ainvoke(input, config, **kwargs)
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs]
if all(p is self.default for p in prepared):
return self.default.batch(
inputs, config, return_exceptions=return_exceptions, **kwargs
)
if not inputs:
return []
configs = get_config_list(config, len(inputs))
def invoke(
bound: Runnable[Input, Output],
input: Input,
config: RunnableConfig,
) -> Union[Output, Exception]:
if return_exceptions:
try:
return bound.invoke(input, config, **kwargs)
except Exception as e:
return e
else:
return bound.invoke(input, config, **kwargs)
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])])
with get_executor_for_config(configs[0]) as executor:
return cast(
List[Output], list(executor.map(invoke, prepared, inputs, configs))
)
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs]
if all(p is self.default for p in prepared):
return await self.default.abatch(
inputs, config, return_exceptions=return_exceptions, **kwargs
)
if not inputs:
return []
configs = get_config_list(config, len(inputs))
async def ainvoke(
bound: Runnable[Input, Output],
input: Input,
config: RunnableConfig,
) -> Union[Output, Exception]:
if return_exceptions:
try:
return await bound.ainvoke(input, config, **kwargs)
except Exception as e:
return e
else:
return await bound.ainvoke(input, config, **kwargs)
coros = map(ainvoke, prepared, inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
def stream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
return self._prepare(config).stream(input, config, **kwargs)
async def astream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
async for chunk in self._prepare(config).astream(input, config, **kwargs):
yield chunk
def transform(
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
return self._prepare(config).transform(input, config, **kwargs)
async def atransform(
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
async for chunk in self._prepare(config).atransform(input, config, **kwargs):
yield chunk
class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
"""A Runnable that can be dynamically configured."""
fields: Dict[str, AnyConfigurableField]
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
[
ConfigurableFieldSpec(
id=spec.id,
name=spec.name,
description=spec.description
or self.default.__fields__[field_name].field_info.description,
annotation=spec.annotation
or self.default.__fields__[field_name].annotation,
default=getattr(self.default, field_name),
)
if isinstance(spec, ConfigurableField)
else make_options_spec(
spec, self.default.__fields__[field_name].field_info.description
)
for field_name, spec in self.fields.items()
]
+ list(self.default.config_specs)
)
def configurable_fields(
self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]:
return self.default.configurable_fields(**{**self.fields, **kwargs})
def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]:
config = config or {}
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
configurable_fields = {
specs_by_id[k][0]: v
for k, v in config.get("configurable", {}).items()
if k in specs_by_id and isinstance(specs_by_id[k][1], ConfigurableField)
}
configurable_single_options = {
k: v.options[(config.get("configurable", {}).get(v.id) or v.default)]
for k, v in self.fields.items()
if isinstance(v, ConfigurableFieldSingleOption)
}
configurable_multi_options = {
k: [
v.options[o]
for o in config.get("configurable", {}).get(v.id, v.default)
]
for k, v in self.fields.items()
if isinstance(v, ConfigurableFieldMultiOption)
}
configurable = {
**configurable_fields,
**configurable_single_options,
**configurable_multi_options,
}
if configurable:
return self.default.__class__(**{**self.default.__dict__, **configurable})
else:
return self.default
# Before Python 3.11, the native StrEnum is not available.
class StrEnum(str, enum.Enum):
"""A string enum."""
pass
_enums_for_spec: WeakValueDictionary[
Union[
ConfigurableFieldSingleOption, ConfigurableFieldMultiOption, ConfigurableField
],
Type[StrEnum],
] = WeakValueDictionary()
_enums_for_spec_lock = threading.Lock()
class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
"""A Runnable that can be dynamically configured."""
which: ConfigurableField
alternatives: Dict[
str,
Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
]
default_key: str = "default"
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
with _enums_for_spec_lock:
if which_enum := _enums_for_spec.get(self.which):
pass
else:
which_enum = StrEnum( # type: ignore[call-overload]
self.which.name or self.which.id,
(
(v, v)
for v in list(self.alternatives.keys()) + [self.default_key]
),
)
_enums_for_spec[self.which] = cast(Type[StrEnum], which_enum)
return [
ConfigurableFieldSpec(
id=self.which.id,
name=self.which.name,
description=self.which.description,
annotation=which_enum,
default=self.default_key,
),
*self.default.config_specs,
] + [
s
for alt in self.alternatives.values()
if isinstance(alt, RunnableSerializable)
for s in alt.config_specs
]
def configurable_fields(
self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]:
return self.__class__(
which=self.which,
default=self.default.configurable_fields(**kwargs),
alternatives=self.alternatives,
)
def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]:
config = config or {}
which = config.get("configurable", {}).get(self.which.id, self.default_key)
if which == self.default_key:
return self.default
elif which in self.alternatives:
alt = self.alternatives[which]
if isinstance(alt, Runnable):
return alt
else:
return alt()
else:
raise ValueError(f"Unknown alternative: {which}")
def make_options_spec(
spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption],
description: Optional[str],
) -> ConfigurableFieldSpec:
"""Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or
ConfigurableFieldMultiOption."""
with _enums_for_spec_lock:
if enum := _enums_for_spec.get(spec):
pass
else:
enum = StrEnum( # type: ignore[call-overload]
spec.name or spec.id,
((v, v) for v in list(spec.options.keys())),
)
_enums_for_spec[spec] = cast(Type[StrEnum], enum)
if isinstance(spec, ConfigurableFieldSingleOption):
return ConfigurableFieldSpec(
id=spec.id,
name=spec.name,
description=spec.description or description,
annotation=enum,
default=spec.default,
)
else:
return ConfigurableFieldSpec(
id=spec.id,
name=spec.name,
description=spec.description or description,
annotation=Sequence[enum], # type: ignore[valid-type]
default=spec.default,
)

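A toy sketch of RunnableConfigurableFields driven through the ``configurable_fields()`` hook referenced above; the ``Greeter`` runnable is invented here purely for illustration:

from typing import Any, Optional

from langchain_core.runnables import ConfigurableField, RunnableSerializable
from langchain_core.runnables.config import RunnableConfig


class Greeter(RunnableSerializable[str, str]):
    """Invented toy runnable with one field made configurable at runtime."""

    greeting: str = "Hello"

    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> str:
        return f"{self.greeting}, {input}!"


configurable_greeter = Greeter().configurable_fields(
    greeting=ConfigurableField(
        id="greeting",
        name="Greeting",
        description="Word used to greet the caller.",
    )
)

print(configurable_greeter.invoke("world"))  # Hello, world!
# _prepare() rebuilds the default runnable with the overridden field value.
print(
    configurable_greeter.invoke("world", config={"configurable": {"greeting": "Bonjour"}})
)  # Bonjour, world!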
@ -0,0 +1,344 @@
import asyncio
from typing import (
TYPE_CHECKING,
Any,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
get_config_list,
patch_config,
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
Input,
Output,
get_unique_config_specs,
)
if TYPE_CHECKING:
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
"""A Runnable that can fallback to other Runnables if it fails.
External APIs (e.g., APIs for a language model) may at times experience
degraded performance or even downtime.
In these cases, it can be useful to have a fallback runnable that can be
used in place of the original runnable (e.g., falling back to another LLM provider).
Fallbacks can be defined at the level of a single runnable, or at the level
of a chain of runnables. Fallbacks are tried in order until one succeeds or
all fail.
While you can instantiate a ``RunnableWithFallbacks`` directly, it is usually
more convenient to use the ``with_fallbacks`` method on a runnable.
Example:
.. code-block:: python
from langchain_core.chat_models.openai import ChatOpenAI
from langchain_core.chat_models.anthropic import ChatAnthropic
model = ChatAnthropic().with_fallbacks([ChatOpenAI()])
# Will usually use ChatAnthropic, but fall back to ChatOpenAI
# if ChatAnthropic fails.
model.invoke('hello')
# And you can also use fallbacks at the level of a chain.
# Here, if both LLM providers fail, we'll fall back to a good hardcoded
# response.
from langchain_core.prompts import PromptTemplate
from langchain_core.schema.output_parser import StrOutputParser
from langchain_core.runnables import RunnableLambda
def when_all_is_lost(inputs):
return ("Looks like our LLM providers are down. "
"Here's a nice 🦜️ emoji for you instead.")
chain_with_fallback = (
PromptTemplate.from_template('Tell me a joke about {topic}')
| model
| StrOutputParser()
).with_fallbacks([RunnableLambda(when_all_is_lost)])
"""
runnable: Runnable[Input, Output]
"""The runnable to run first."""
fallbacks: Sequence[Runnable[Input, Output]]
"""A sequence of fallbacks to try."""
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
"""The exceptions on which fallbacks should be tried.
Any exception that is not a subclass of these exceptions will be raised immediately.
"""
class Config:
arbitrary_types_allowed = True
@property
def InputType(self) -> Type[Input]:
return self.runnable.InputType
@property
def OutputType(self) -> Type[Output]:
return self.runnable.OutputType
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.runnable.get_input_schema(config)
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.runnable.get_output_schema(config)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec
for step in [self.runnable, *self.fallbacks]
for spec in step.config_specs
)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
@property
def runnables(self) -> Iterator[Runnable[Input, Output]]:
yield self.runnable
yield from self.fallbacks
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
# setup callbacks
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None
for runnable in self.runnables:
try:
output = runnable.invoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
run_manager.on_chain_error(e)
raise e
else:
run_manager.on_chain_end(output)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
run_manager.on_chain_error(first_error)
raise first_error
async def ainvoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None
for runnable in self.runnables:
try:
output = await runnable.ainvoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
else:
await run_manager.on_chain_end(output)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await run_manager.on_chain_error(first_error)
raise first_error
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain_core.callbacks.manager import CallbackManager
if return_exceptions:
raise NotImplementedError()
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers = [
cm.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
]
first_error = None
for runnable in self.runnables:
try:
outputs = runnable.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
for rm in run_managers:
rm.on_chain_error(e)
raise e
else:
for rm, output in zip(run_managers, outputs):
rm.on_chain_end(output)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
for rm in run_managers:
rm.on_chain_error(first_error)
raise first_error
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain_core.callbacks.manager import AsyncCallbackManager
if return_exceptions:
raise NotImplementedError()
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
)
)
first_error = None
for runnable in self.runnables:
try:
outputs = await runnable.abatch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
                await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
                raise e
else:
await asyncio.gather(
*(
rm.on_chain_end(output)
for rm, output in zip(run_managers, outputs)
)
)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
raise first_error
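
# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not part of the library): how the
# fallback flow above is typically used. Assumes Runnable.with_fallbacks()
# constructs a RunnableWithFallbacks; the two lambdas are hypothetical
# stand-ins for real runnables.
def _fallbacks_example() -> int:
    from langchain_core.runnables.base import RunnableLambda

    def _flaky(x: int) -> int:
        # Hypothetical primary runnable that always fails.
        raise ValueError("primary runnable failed")

    chain = RunnableLambda(_flaky).with_fallbacks(
        [RunnableLambda(lambda x: x + 1)],
        exceptions_to_handle=(ValueError,),
    )
    # The ValueError is caught as first_error and the fallback returns 1 + 1.
    return chain.invoke(1)  # -> 2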

@ -0,0 +1,288 @@
from __future__ import annotations
import asyncio
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
Union,
)
from langchain_core.load import load
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
get_unique_config_specs,
)
from langchain_core.schema.chat_history import BaseChatMessageHistory
if TYPE_CHECKING:
from langchain_core.callbacks.tracers.schemas import Run
from langchain_core.runnables.config import RunnableConfig
from langchain_core.schema.messages import BaseMessage
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
class RunnableWithMessageHistory(RunnableBindingBase):
"""A runnable that manages chat message history for another runnable.
Base runnable must have inputs and outputs that can be converted to a list of
BaseMessages.
RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.:
``{"configurable": {"session_id": "<SESSION_ID>"}}``
Example (dict input):
.. code-block:: python
from typing import Optional
from langchain_core.chat_models import ChatAnthropic
from langchain_core.memory.chat_message_histories import RedisChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
prompt = ChatPromptTemplate.from_messages([
("system", "You're an assistant who's good at {ability}"),
MessagesPlaceholder(variable_name="history"),
("human", "{question}"),
])
chain = prompt | ChatAnthropic(model="claude-2")
chain_with_history = RunnableWithMessageHistory(
chain,
RedisChatMessageHistory,
input_messages_key="question",
history_messages_key="history",
)
chain_with_history.invoke(
{"ability": "math", "question": "What does cosine mean?"},
config={"configurable": {"session_id": "foo"}}
)
# -> "Cosine is ..."
chain_with_history.invoke(
{"ability": "math", "question": "What's its inverse"},
config={"configurable": {"session_id": "foo"}}
)
# -> "The inverse of cosine is called arccosine ..."
""" # noqa: E501
get_session_history: GetSessionHistoryCallable
input_messages_key: Optional[str] = None
output_messages_key: Optional[str] = None
history_messages_key: Optional[str] = None
def __init__(
self,
runnable: Runnable[
MessagesOrDictWithMessages,
Union[str, BaseMessage, MessagesOrDictWithMessages],
],
get_session_history: GetSessionHistoryCallable,
*,
input_messages_key: Optional[str] = None,
output_messages_key: Optional[str] = None,
history_messages_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize RunnableWithMessageHistory.
Args:
runnable: The base Runnable to be wrapped.
Must take as input one of:
- A sequence of BaseMessages
- A dict with one key for all messages
- A dict with one key for the current input string/message(s) and
a separate key for historical messages. If the input key points
to a string, it will be treated as a HumanMessage in history.
Must return as output one of:
- A string which can be treated as an AIMessage
- A BaseMessage or sequence of BaseMessages
- A dict with a key for a BaseMessage or sequence of BaseMessages
get_session_history: Function that returns a new BaseChatMessageHistory
given a session id. Should take a single
positional argument `session_id` which is a string and a named argument
`user_id` which can be a string or None. e.g.:
```python
def get_session_history(
session_id: str,
*,
user_id: Optional[str]=None
) -> BaseChatMessageHistory:
...
```
input_messages_key: Must be specified if the base runnable accepts a dict
as input.
output_messages_key: Must be specified if the base runnable returns a dict
as output.
history_messages_key: Must be specified if the base runnable accepts a dict
as input and expects a separate key for historical messages.
**kwargs: Arbitrary additional kwargs to pass to parent class
``RunnableBindingBase`` init.
""" # noqa: E501
history_chain: Runnable = RunnableLambda(
self._enter_history, self._aenter_history
).with_config(run_name="load_history")
messages_key = history_messages_key or input_messages_key
if messages_key:
history_chain = RunnablePassthrough.assign(
**{messages_key: history_chain}
).with_config(run_name="insert_history")
bound = (
history_chain | runnable.with_listeners(on_end=self._exit_history)
).with_config(run_name="RunnableWithMessageHistory")
super().__init__(
get_session_history=get_session_history,
input_messages_key=input_messages_key,
output_messages_key=output_messages_key,
bound=bound,
history_messages_key=history_messages_key,
**kwargs,
)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
super().config_specs
+ [
ConfigurableFieldSpec(
id="session_id",
annotation=str,
name="Session ID",
description="Unique identifier for a session.",
default="",
),
]
)
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
super_schema = super().get_input_schema(config)
if super_schema.__custom_root_type__ is not None:
from langchain_core.schema.messages import BaseMessage
fields: Dict = {}
if self.input_messages_key and self.history_messages_key:
fields[self.input_messages_key] = (
Union[str, BaseMessage, Sequence[BaseMessage]],
...,
)
elif self.input_messages_key:
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
else:
fields["__root__"] = (Sequence[BaseMessage], ...)
if self.history_messages_key:
fields[self.history_messages_key] = (Sequence[BaseMessage], ...)
return create_model( # type: ignore[call-overload]
"RunnableWithChatHistoryInput",
**fields,
)
else:
return super_schema
def _get_input_messages(
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]]
) -> List[BaseMessage]:
from langchain_core.schema.messages import BaseMessage
if isinstance(input_val, str):
from langchain_core.schema.messages import HumanMessage
return [HumanMessage(content=input_val)]
elif isinstance(input_val, BaseMessage):
return [input_val]
elif isinstance(input_val, (list, tuple)):
return list(input_val)
else:
raise ValueError(
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
f"Got {input_val}."
)
def _get_output_messages(
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> List[BaseMessage]:
from langchain_core.schema.messages import BaseMessage
if isinstance(output_val, dict):
output_val = output_val[self.output_messages_key or "output"]
if isinstance(output_val, str):
from langchain_core.schema.messages import AIMessage
return [AIMessage(content=output_val)]
elif isinstance(output_val, BaseMessage):
return [output_val]
elif isinstance(output_val, (list, tuple)):
return list(output_val)
else:
            raise ValueError(
                f"Expected str, BaseMessage, List[BaseMessage], or dict. "
                f"Got {output_val}."
            )
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
hist = config["configurable"]["message_history"]
# return only historic messages
if self.history_messages_key:
return hist.messages.copy()
# return all messages
else:
input_val = (
input if not self.input_messages_key else input[self.input_messages_key]
)
return hist.messages.copy() + self._get_input_messages(input_val)
async def _aenter_history(
self, input: Dict[str, Any], config: RunnableConfig
) -> List[BaseMessage]:
return await asyncio.get_running_loop().run_in_executor(
None, self._enter_history, input, config
)
def _exit_history(self, run: Run, config: RunnableConfig) -> None:
hist = config["configurable"]["message_history"]
# Get the input messages
inputs = load(run.inputs)
input_val = inputs[self.input_messages_key or "input"]
input_messages = self._get_input_messages(input_val)
# Get the output messages
output_val = load(run.outputs)
output_messages = self._get_output_messages(output_val)
for m in input_messages + output_messages:
hist.add_message(m)
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = super()._merge_configs(*configs)
# extract session_id
if "session_id" not in config.get("configurable", {}):
example_input = {self.input_messages_key: "foo"}
example_config = {"configurable": {"session_id": "123"}}
raise ValueError(
"session_id_id is required."
" Pass it in as part of the config argument to .invoke() or .stream()"
f"\neg. chain.invoke({example_input}, {example_config})"
)
# attach message_history
session_id = config["configurable"]["session_id"]
config["configurable"]["message_history"] = self.get_session_history(session_id)
return config
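
# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not part of the library): wiring
# RunnableWithMessageHistory to a hypothetical in-memory history store. The
# echo lambda stands in for a real chat model; underscore-prefixed names are
# made up for this example.
def _message_history_example() -> None:
    from langchain_core.runnables.base import RunnableLambda
    from langchain_core.schema.messages import AIMessage, HumanMessage

    class _InMemoryHistory(BaseChatMessageHistory):
        def __init__(self) -> None:
            self.messages = []

        def add_message(self, message: BaseMessage) -> None:
            self.messages.append(message)

        def clear(self) -> None:
            self.messages = []

    sessions: Dict[str, _InMemoryHistory] = {}

    def get_history(session_id: str) -> _InMemoryHistory:
        # One history object per session id, created lazily.
        return sessions.setdefault(session_id, _InMemoryHistory())

    echo = RunnableLambda(
        lambda messages: AIMessage(content=f"seen {len(messages)} messages")
    )
    chain = RunnableWithMessageHistory(echo, get_history)
    chain.invoke(
        [HumanMessage(content="hi")],
        config={"configurable": {"session_id": "demo"}},
    )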

@ -0,0 +1,453 @@
"""Implementation of the RunnablePassthrough."""
from __future__ import annotations
import asyncio
import inspect
import threading
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Type,
Union,
cast,
)
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.runnables.base import (
Other,
Runnable,
RunnableParallel,
RunnableSerializable,
)
from langchain_core.runnables.config import (
RunnableConfig,
acall_func_with_variable_args,
call_func_with_variable_args,
get_executor_for_config,
)
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.iter import safetee
def identity(x: Other) -> Other:
"""An identity function"""
return x
async def aidentity(x: Other) -> Other:
"""An async identity function"""
return x
class RunnablePassthrough(RunnableSerializable[Other, Other]):
"""A runnable to passthrough inputs unchanged or with additional keys.
This runnable behaves almost like the identity function, except that it
can be configured to add additional keys to the output, if the input is a
dict.
The examples below demonstrate this runnable works using a few simple
chains. The chains rely on simple lambdas to make the examples easy to execute
and experiment with.
Examples:
.. code-block:: python
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
runnable = RunnableParallel(
origin=RunnablePassthrough(),
modified=lambda x: x+1
)
runnable.invoke(1) # {'origin': 1, 'modified': 2}
def fake_llm(prompt: str) -> str: # Fake LLM for the example
return "completion"
chain = RunnableLambda(fake_llm) | {
'original': RunnablePassthrough(), # Original LLM output
'parsed': lambda text: text[::-1] # Parsing logic
}
chain.invoke('hello') # {'original': 'completion', 'parsed': 'noitelpmoc'}
In some cases, it may be useful to pass the input through while adding some
keys to the output. In this case, you can use the `assign` method:
.. code-block:: python
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
def fake_llm(prompt: str) -> str: # Fake LLM for the example
return "completion"
runnable = {
'llm1': fake_llm,
'llm2': fake_llm,
}
| RunnablePassthrough.assign(
total_chars=lambda inputs: len(inputs['llm1'] + inputs['llm2'])
)
runnable.invoke('hello')
# {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
"""
input_type: Optional[Type[Other]] = None
func: Optional[
Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]]
] = None
afunc: Optional[
Union[
Callable[[Other], Awaitable[None]],
Callable[[Other, RunnableConfig], Awaitable[None]],
]
] = None
def __init__(
self,
func: Optional[
Union[
Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]],
Union[
Callable[[Other], Awaitable[None]],
Callable[[Other, RunnableConfig], Awaitable[None]],
],
]
] = None,
afunc: Optional[
Union[
Callable[[Other], Awaitable[None]],
Callable[[Other, RunnableConfig], Awaitable[None]],
]
] = None,
*,
input_type: Optional[Type[Other]] = None,
**kwargs: Any,
) -> None:
if inspect.iscoroutinefunction(func):
afunc = func
func = None
super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
@property
def InputType(self) -> Any:
return self.input_type or Any
@property
def OutputType(self) -> Any:
return self.input_type or Any
@classmethod
def assign(
cls,
**kwargs: Union[
Runnable[Dict[str, Any], Any],
Callable[[Dict[str, Any]], Any],
Mapping[
str,
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
],
],
) -> RunnableAssign:
"""Merge the Dict input with the output produced by the mapping argument.
Args:
mapping: A mapping from keys to runnables or callables.
Returns:
A runnable that merges the Dict input with the output produced by the
mapping argument.
"""
return RunnableAssign(RunnableParallel(kwargs))
def invoke(
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Other:
if self.func is not None:
call_func_with_variable_args(self.func, input, config or {}, **kwargs)
return self._call_with_config(identity, input, config)
async def ainvoke(
self,
input: Other,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Other:
if self.afunc is not None:
await acall_func_with_variable_args(
self.afunc, input, config or {}, **kwargs
)
elif self.func is not None:
call_func_with_variable_args(self.func, input, config or {}, **kwargs)
return await self._acall_with_config(aidentity, input, config)
def transform(
self,
input: Iterator[Other],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Other]:
if self.func is None:
for chunk in self._transform_stream_with_config(input, identity, config):
yield chunk
else:
final = None
for chunk in self._transform_stream_with_config(input, identity, config):
yield chunk
if final is None:
final = chunk
else:
final = final + chunk
if final is not None:
call_func_with_variable_args(self.func, final, config or {}, **kwargs)
async def atransform(
self,
input: AsyncIterator[Other],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Other]:
if self.afunc is None and self.func is None:
async for chunk in self._atransform_stream_with_config(
input, identity, config
):
yield chunk
else:
final = None
async for chunk in self._atransform_stream_with_config(
input, identity, config
):
yield chunk
if final is None:
final = chunk
else:
final = final + chunk
if final is not None:
config = config or {}
if self.afunc is not None:
await acall_func_with_variable_args(
self.afunc, final, config, **kwargs
)
elif self.func is not None:
call_func_with_variable_args(self.func, final, config, **kwargs)
def stream(
self,
input: Other,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Other]:
return self.transform(iter([input]), config, **kwargs)
async def astream(
self,
input: Other,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Other]:
async def input_aiter() -> AsyncIterator[Other]:
yield input
async for chunk in self.atransform(input_aiter(), config, **kwargs):
yield chunk
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
"""
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
"""
mapper: RunnableParallel[Dict[str, Any]]
def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None:
super().__init__(mapper=mapper, **kwargs)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
if not map_input_schema.__custom_root_type__:
# ie. it's a dict
return map_input_schema
return super().get_input_schema(config)
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
map_output_schema = self.mapper.get_output_schema(config)
if (
not map_input_schema.__custom_root_type__
and not map_output_schema.__custom_root_type__
):
# ie. both are dicts
return create_model( # type: ignore[call-overload]
"RunnableAssignOutput",
**{
k: (v.type_, v.default)
for s in (map_input_schema, map_output_schema)
for k, v in s.__fields__.items()
},
)
elif not map_output_schema.__custom_root_type__:
# ie. only map output is a dict
# ie. input type is either unknown or inferred incorrectly
return map_output_schema
return super().get_output_schema(config)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.mapper.config_specs
def invoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
assert isinstance(
input, dict
), "The input to RunnablePassthrough.assign() must be a dict."
return {
**input,
**self.mapper.invoke(input, config, **kwargs),
}
async def ainvoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
assert isinstance(
input, dict
), "The input to RunnablePassthrough.assign() must be a dict."
return {
**input,
**await self.mapper.ainvoke(input, config, **kwargs),
}
def transform(
self,
input: Iterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps.keys())
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
# create map output stream
map_output = self.mapper.transform(for_map, config, **kwargs)
# get executor to start map output stream in background
with get_executor_for_config(config or {}) as executor:
# start map output stream
first_map_chunk_future = executor.submit(
next,
map_output, # type: ignore
None,
)
# consume passthrough stream
for chunk in for_passthrough:
assert isinstance(
chunk, dict
), "The input to RunnablePassthrough.assign() must be a dict."
# remove mapper keys from passthrough chunk, to be overwritten by map
filtered = AddableDict(
{k: v for k, v in chunk.items() if k not in mapper_keys}
)
if filtered:
yield filtered
# yield map output
yield cast(Dict[str, Any], first_map_chunk_future.result())
for chunk in map_output:
yield chunk
async def atransform(
self,
input: AsyncIterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps.keys())
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
# create map output stream
map_output = self.mapper.atransform(for_map, config, **kwargs)
# start map output stream
first_map_chunk_task: asyncio.Task = asyncio.create_task(
py_anext(map_output, None), # type: ignore[arg-type]
)
# consume passthrough stream
async for chunk in for_passthrough:
assert isinstance(
chunk, dict
), "The input to RunnablePassthrough.assign() must be a dict."
# remove mapper keys from passthrough chunk, to be overwritten by map output
filtered = AddableDict(
{k: v for k, v in chunk.items() if k not in mapper_keys}
)
if filtered:
yield filtered
# yield map output
yield await first_map_chunk_task
async for chunk in map_output:
yield chunk
def stream(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
return self.transform(iter([input]), config, **kwargs)
async def astream(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
yield input
async for chunk in self.atransform(input_aiter(), config, **kwargs):
yield chunk
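
# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not part of the library):
# RunnablePassthrough.assign merges the input dict with the mapper's output
# via RunnableAssign above. The lambda is a placeholder for a real runnable.
def _assign_example() -> Dict[str, Any]:
    enrich = RunnablePassthrough.assign(doubled=lambda d: d["value"] * 2)
    # {'value': 3} is passed through unchanged and 'doubled' is added on top.
    return enrich.invoke({"value": 3})  # -> {"value": 3, "doubled": 6}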

@ -0,0 +1,337 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from tenacity import (
AsyncRetrying,
RetryCallState,
RetryError,
Retrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from langchain_core.runnables.base import Input, Output, RunnableBindingBase
from langchain_core.runnables.config import RunnableConfig, patch_config
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
U = TypeVar("U")
class RunnableRetry(RunnableBindingBase[Input, Output]):
"""Retry a Runnable if it fails.
    A RunnableRetry can be used to add retry logic to any object
that subclasses the base Runnable.
Such retries are especially useful for network calls that may fail
due to transient errors.
The RunnableRetry is implemented as a RunnableBinding. The easiest
way to use it is through the `.with_retry()` method on all Runnables.
Example:
Here's an example that uses a RunnableLambda to raise an exception
.. code-block:: python
import time
def foo(input) -> None:
'''Fake function that raises an exception.'''
raise ValueError("Invoking foo failed. At time {time.time()}")
runnable = RunnableLambda(foo)
runnable_with_retries = runnable.with_retry(
retry_exception_types=(ValueError,), # Retry only on ValueError
wait_exponential_jitter=True, # Add jitter to the exponential backoff
max_attempt_number=2, # Try twice
)
# The method invocation above is equivalent to the longer form below:
runnable_with_retries = RunnableRetry(
bound=runnable,
retry_exception_types=(ValueError,),
max_attempt_number=2,
wait_exponential_jitter=True
)
This logic can be used to retry any Runnable, including a chain of Runnables,
but in general it's best practice to keep the scope of the retry as small as
possible. For example, if you have a chain of Runnables, you should only retry
the Runnable that is likely to fail, not the entire chain.
Example:
.. code-block:: python
from langchain_core.chat_models import ChatOpenAI
from langchain_core.prompts import PromptTemplate
template = PromptTemplate.from_template("tell me a joke about {topic}.")
model = ChatOpenAI(temperature=0.5)
# Good
chain = template | model.with_retry()
# Bad
chain = template | model
retryable_chain = chain.with_retry()
"""
retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,)
"""The exception types to retry on. By default all exceptions are retried.
In general you should only retry on exceptions that are likely to be
transient, such as network errors.
Good exceptions to retry are all server errors (5xx) and selected client
errors (4xx) such as 429 Too Many Requests.
"""
wait_exponential_jitter: bool = True
"""Whether to add jitter to the exponential backoff."""
max_attempt_number: int = 3
"""The maximum number of attempts to retry the runnable."""
@property
def _kwargs_retrying(self) -> Dict[str, Any]:
kwargs: Dict[str, Any] = dict()
if self.max_attempt_number:
kwargs["stop"] = stop_after_attempt(self.max_attempt_number)
if self.wait_exponential_jitter:
kwargs["wait"] = wait_exponential_jitter()
if self.retry_exception_types:
kwargs["retry"] = retry_if_exception_type(self.retry_exception_types)
return kwargs
def _sync_retrying(self, **kwargs: Any) -> Retrying:
return Retrying(**self._kwargs_retrying, **kwargs)
def _async_retrying(self, **kwargs: Any) -> AsyncRetrying:
return AsyncRetrying(**self._kwargs_retrying, **kwargs)
def _patch_config(
self,
config: RunnableConfig,
run_manager: "T",
retry_state: RetryCallState,
) -> RunnableConfig:
attempt = retry_state.attempt_number
tag = "retry:attempt:{}".format(attempt) if attempt > 1 else None
return patch_config(config, callbacks=run_manager.get_child(tag))
def _patch_config_list(
self,
config: List[RunnableConfig],
run_manager: List["T"],
retry_state: RetryCallState,
) -> List[RunnableConfig]:
return [
self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
]
def _invoke(
self,
input: Input,
run_manager: "CallbackManagerForChainRun",
config: RunnableConfig,
**kwargs: Any,
) -> Output:
for attempt in self._sync_retrying(reraise=True):
with attempt:
result = super().invoke(
input,
self._patch_config(config, run_manager, attempt.retry_state),
**kwargs,
)
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
return self._call_with_config(self._invoke, input, config, **kwargs)
async def _ainvoke(
self,
input: Input,
run_manager: "AsyncCallbackManagerForChainRun",
config: RunnableConfig,
**kwargs: Any,
) -> Output:
async for attempt in self._async_retrying(reraise=True):
with attempt:
result = await super().ainvoke(
input,
self._patch_config(config, run_manager, attempt.retry_state),
**kwargs,
)
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
def _batch(
self,
inputs: List[Input],
run_manager: List["CallbackManagerForChainRun"],
config: List[RunnableConfig],
**kwargs: Any,
) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {}
def pending(iterable: List[U]) -> List[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]
try:
for attempt in self._sync_retrying():
with attempt:
# Get the results of the inputs that have not succeeded yet.
result = super().batch(
pending(inputs),
self._patch_config_list(
pending(config), pending(run_manager), attempt.retry_state
),
return_exceptions=True,
**kwargs,
)
# Register the results of the inputs that have succeeded.
first_exception = None
for i, r in enumerate(result):
if isinstance(r, Exception):
if not first_exception:
first_exception = r
continue
results_map[i] = r
# If any exception occurred, raise it, to retry the failed ones
if first_exception:
raise first_exception
if (
attempt.retry_state.outcome
and not attempt.retry_state.outcome.failed
):
attempt.retry_state.set_result(result)
except RetryError as e:
try:
result
except UnboundLocalError:
result = cast(List[Output], [e] * len(inputs))
outputs: List[Union[Output, Exception]] = []
for idx, _ in enumerate(inputs):
if idx in results_map:
outputs.append(results_map[idx])
else:
outputs.append(result.pop(0))
return outputs
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[Output]:
return self._batch_with_config(
self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
)
async def _abatch(
self,
inputs: List[Input],
run_manager: List["AsyncCallbackManagerForChainRun"],
config: List[RunnableConfig],
**kwargs: Any,
) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {}
def pending(iterable: List[U]) -> List[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]
try:
async for attempt in self._async_retrying():
with attempt:
# Get the results of the inputs that have not succeeded yet.
result = await super().abatch(
pending(inputs),
self._patch_config_list(
pending(config), pending(run_manager), attempt.retry_state
),
return_exceptions=True,
**kwargs,
)
# Register the results of the inputs that have succeeded.
first_exception = None
for i, r in enumerate(result):
if isinstance(r, Exception):
if not first_exception:
first_exception = r
continue
results_map[i] = r
# If any exception occurred, raise it, to retry the failed ones
if first_exception:
raise first_exception
if (
attempt.retry_state.outcome
and not attempt.retry_state.outcome.failed
):
attempt.retry_state.set_result(result)
except RetryError as e:
try:
result
except UnboundLocalError:
result = cast(List[Output], [e] * len(inputs))
outputs: List[Union[Output, Exception]] = []
for idx, _ in enumerate(inputs):
if idx in results_map:
outputs.append(results_map[idx])
else:
outputs.append(result.pop(0))
return outputs
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[Output]:
return await self._abatch_with_config(
self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
)
# stream() and transform() are not retried because retrying a stream
# is not very intuitive.
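
# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not part of the library): retrying a
# flaky runnable using the fields declared on RunnableRetry above. The call
# counter and lambdas are hypothetical; a real use case would wrap e.g. a
# network-bound model call.
def _retry_example() -> int:
    from langchain_core.runnables.base import RunnableLambda

    calls = {"n": 0}

    def flaky(x: int) -> int:
        calls["n"] += 1
        if calls["n"] < 2:
            raise ValueError("transient failure")
        return x * 10

    retrying = RunnableRetry(
        bound=RunnableLambda(flaky),
        retry_exception_types=(ValueError,),
        max_attempt_number=3,
    )
    # The first attempt raises; the second succeeds and returns 40.
    return retrying.invoke(4)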

@ -0,0 +1,206 @@
from __future__ import annotations
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
List,
Mapping,
Optional,
Union,
cast,
)
from typing_extensions import TypedDict
from langchain_core.runnables.base import (
Input,
Output,
Runnable,
RunnableSerializable,
coerce_to_runnable,
)
from langchain_core.runnables.config import (
RunnableConfig,
get_config_list,
get_executor_for_config,
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
gather_with_concurrency,
get_unique_config_specs,
)
class RouterInput(TypedDict):
"""A Router input.
Attributes:
key: The key to route on.
input: The input to pass to the selected runnable.
"""
key: str
input: Any
class RouterRunnable(RunnableSerializable[RouterInput, Output]):
"""
A runnable that routes to a set of runnables based on Input['key'].
Returns the output of the selected runnable.
"""
runnables: Mapping[str, Runnable[Any, Output]]
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.runnables.values() for spec in step.config_specs
)
def __init__(
self,
runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
) -> None:
super().__init__(
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
)
class Config:
arbitrary_types_allowed = True
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Output:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
return runnable.invoke(actual_input, config)
async def ainvoke(
self,
input: RouterInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
return await runnable.ainvoke(actual_input, config)
def batch(
self,
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
if not inputs:
return []
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
def invoke(
runnable: Runnable, input: Input, config: RunnableConfig
) -> Union[Output, Exception]:
if return_exceptions:
try:
return runnable.invoke(input, config, **kwargs)
except Exception as e:
return e
else:
return runnable.invoke(input, config, **kwargs)
runnables = [self.runnables[key] for key in keys]
configs = get_config_list(config, len(inputs))
with get_executor_for_config(configs[0]) as executor:
return cast(
List[Output],
list(executor.map(invoke, runnables, actual_inputs, configs)),
)
async def abatch(
self,
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
if not inputs:
return []
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
async def ainvoke(
runnable: Runnable, input: Input, config: RunnableConfig
) -> Union[Output, Exception]:
if return_exceptions:
try:
return await runnable.ainvoke(input, config, **kwargs)
except Exception as e:
return e
else:
return await runnable.ainvoke(input, config, **kwargs)
runnables = [self.runnables[key] for key in keys]
configs = get_config_list(config, len(inputs))
return await gather_with_concurrency(
configs[0].get("max_concurrency"),
*(
ainvoke(runnable, input, config)
for runnable, input, config in zip(runnables, actual_inputs, configs)
),
)
def stream(
self,
input: RouterInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
yield from runnable.stream(actual_input, config)
async def astream(
self,
input: RouterInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
async for output in runnable.astream(actual_input, config):
yield output
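
# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not part of the library): routing on
# input["key"] as implemented above. The two lambdas are placeholder
# runnables coerced by the constructor.
def _router_example() -> List[int]:
    router = RouterRunnable(
        {
            "add": lambda x: x + 1,
            "square": lambda x: x * x,
        }
    )
    router.invoke({"key": "square", "input": 4})  # -> 16
    return router.batch(
        [{"key": "add", "input": 1}, {"key": "square", "input": 3}]
    )  # -> [2, 9]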

@ -0,0 +1,327 @@
from __future__ import annotations
import ast
import asyncio
import inspect
import textwrap
from inspect import signature
from itertools import groupby
from typing import (
Any,
AsyncIterable,
Callable,
Coroutine,
Dict,
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Protocol,
Sequence,
Set,
TypeVar,
Union,
)
Input = TypeVar("Input", contravariant=True)
# Output type should support addition (the + operator), as e.g. str, list and AddableDict do
Output = TypeVar("Output", covariant=True)
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
"""Run a coroutine with a semaphore.
Args:
semaphore: The semaphore to use.
coro: The coroutine to run.
Returns:
The result of the coroutine.
"""
async with semaphore:
return await coro
async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
"""Gather coroutines with a limit on the number of concurrent coroutines."""
if n is None:
return await asyncio.gather(*coros)
semaphore = asyncio.Semaphore(n)
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
"""Check if a callable accepts a run_manager argument."""
try:
return signature(callable).parameters.get("run_manager") is not None
except ValueError:
return False
def accepts_config(callable: Callable[..., Any]) -> bool:
"""Check if a callable accepts a config argument."""
try:
return signature(callable).parameters.get("config") is not None
except ValueError:
return False
class IsLocalDict(ast.NodeVisitor):
"""Check if a name is a local dict."""
def __init__(self, name: str, keys: Set[str]) -> None:
self.name = name
self.keys = keys
def visit_Subscript(self, node: ast.Subscript) -> Any:
if (
isinstance(node.ctx, ast.Load)
and isinstance(node.value, ast.Name)
and node.value.id == self.name
and isinstance(node.slice, ast.Constant)
and isinstance(node.slice.value, str)
):
# we've found a subscript access on the name we're looking for
self.keys.add(node.slice.value)
def visit_Call(self, node: ast.Call) -> Any:
if (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == self.name
and node.func.attr == "get"
and len(node.args) in (1, 2)
and isinstance(node.args[0], ast.Constant)
and isinstance(node.args[0].value, str)
):
# we've found a .get() call on the name we're looking for
self.keys.add(node.args[0].value)
class IsFunctionArgDict(ast.NodeVisitor):
"""Check if the first argument of a function is a dict."""
def __init__(self) -> None:
self.keys: Set[str] = set()
def visit_Lambda(self, node: ast.Lambda) -> Any:
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node.body)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node)
class GetLambdaSource(ast.NodeVisitor):
"""Get the source code of a lambda function."""
def __init__(self) -> None:
"""Initialize the visitor."""
self.source: Optional[str] = None
self.count = 0
def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function."""
self.count += 1
if hasattr(ast, "unparse"):
self.source = ast.unparse(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
"""Get the keys of the first argument of a function if it is a dict."""
try:
code = inspect.getsource(func)
tree = ast.parse(textwrap.dedent(code))
visitor = IsFunctionArgDict()
visitor.visit(tree)
return list(visitor.keys) if visitor.keys else None
except (SyntaxError, TypeError, OSError):
return None
def get_lambda_source(func: Callable) -> Optional[str]:
"""Get the source code of a lambda function.
Args:
func: a callable that can be a lambda function
Returns:
str: the source code of the lambda function
"""
try:
code = inspect.getsource(func)
tree = ast.parse(textwrap.dedent(code))
visitor = GetLambdaSource()
visitor.visit(tree)
return visitor.source if visitor.count == 1 else None
except (SyntaxError, TypeError, OSError):
return None
def indent_lines_after_first(text: str, prefix: str) -> str:
"""Indent all lines of text after the first line.
Args:
text: The text to indent
prefix: Used to determine the number of spaces to indent
Returns:
str: The indented text
"""
n_spaces = len(prefix)
spaces = " " * n_spaces
lines = text.splitlines()
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
class AddableDict(Dict[str, Any]):
"""
Dictionary that can be added to another dictionary.
"""
def __add__(self, other: AddableDict) -> AddableDict:
chunk = AddableDict(self)
for key in other:
if key not in chunk or chunk[key] is None:
chunk[key] = other[key]
elif other[key] is not None:
try:
added = chunk[key] + other[key]
except TypeError:
added = other[key]
chunk[key] = added
return chunk
def __radd__(self, other: AddableDict) -> AddableDict:
chunk = AddableDict(other)
for key in self:
if key not in chunk or chunk[key] is None:
chunk[key] = self[key]
elif self[key] is not None:
try:
added = chunk[key] + self[key]
except TypeError:
added = self[key]
chunk[key] = added
return chunk
_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)
class SupportsAdd(Protocol[_T_contra, _T_co]):
"""Protocol for objects that support addition."""
def __add__(self, __x: _T_contra) -> _T_co:
...
Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
def add(addables: Iterable[Addable]) -> Optional[Addable]:
"""Add a sequence of addable objects together."""
final = None
for chunk in addables:
if final is None:
final = chunk
else:
final = final + chunk
return final
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
"""Asynchronously add a sequence of addable objects together."""
final = None
async for chunk in addables:
if final is None:
final = chunk
else:
final = final + chunk
return final
class ConfigurableField(NamedTuple):
"""A field that can be configured by the user."""
id: str
name: Optional[str] = None
description: Optional[str] = None
annotation: Optional[Any] = None
def __hash__(self) -> int:
return hash((self.id, self.annotation))
class ConfigurableFieldSingleOption(NamedTuple):
"""A field that can be configured by the user with a default value."""
id: str
options: Mapping[str, Any]
default: str
name: Optional[str] = None
description: Optional[str] = None
def __hash__(self) -> int:
return hash((self.id, tuple(self.options.keys()), self.default))
class ConfigurableFieldMultiOption(NamedTuple):
"""A field that can be configured by the user with multiple default values."""
id: str
options: Mapping[str, Any]
default: Sequence[str]
name: Optional[str] = None
description: Optional[str] = None
def __hash__(self) -> int:
return hash((self.id, tuple(self.options.keys()), tuple(self.default)))
AnyConfigurableField = Union[
ConfigurableField, ConfigurableFieldSingleOption, ConfigurableFieldMultiOption
]
class ConfigurableFieldSpec(NamedTuple):
"""A field that can be configured by the user. It is a specification of a field."""
id: str
name: Optional[str]
description: Optional[str]
default: Any
annotation: Any
def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec],
) -> List[ConfigurableFieldSpec]:
"""Get the unique config specs from a sequence of config specs."""
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
unique: List[ConfigurableFieldSpec] = []
for id, dupes in grouped:
first = next(dupes)
others = list(dupes)
if len(others) == 0:
unique.append(first)
elif all(o == first for o in others):
unique.append(first)
else:
raise ValueError(
"RunnableSequence contains conflicting config specs"
f"for {id}: {[first] + others}"
)
return unique
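
# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not part of the library): AddableDict
# merges streamed chunks key-wise, and add() folds an iterable of addable
# chunks into a single value, as defined above.
def _addable_example() -> None:
    merged = AddableDict({"a": "foo"}) + AddableDict({"a": "bar", "b": 1})
    assert merged == {"a": "foobar", "b": 1}
    assert add(["str", "eam"]) == "stream"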

@ -0,0 +1,78 @@
"""**Schemas** are the LangChain Base Classes and Interfaces."""
from langchain_core.schema.agent import AgentAction, AgentFinish
from langchain_core.schema.cache import BaseCache
from langchain_core.schema.chat_history import BaseChatMessageHistory
from langchain_core.schema.document import BaseDocumentTransformer, Document
from langchain_core.schema.exceptions import LangChainException
from langchain_core.schema.memory import BaseMemory
from langchain_core.schema.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
_message_from_dict,
_message_to_dict,
get_buffer_string,
messages_from_dict,
messages_to_dict,
)
from langchain_core.schema.output import (
ChatGeneration,
ChatResult,
Generation,
LLMResult,
RunInfo,
)
from langchain_core.schema.output_parser import (
BaseLLMOutputParser,
BaseOutputParser,
OutputParserException,
StrOutputParser,
)
from langchain_core.schema.prompt import PromptValue
from langchain_core.schema.prompt_template import BasePromptTemplate, format_document
from langchain_core.schema.retriever import BaseRetriever
from langchain_core.schema.storage import BaseStore
RUN_KEY = "__run"
Memory = BaseMemory
__all__ = [
"BaseCache",
"BaseMemory",
"BaseStore",
"AgentFinish",
"AgentAction",
"Document",
"BaseChatMessageHistory",
"BaseDocumentTransformer",
"BaseMessage",
"ChatMessage",
"FunctionMessage",
"HumanMessage",
"AIMessage",
"SystemMessage",
"messages_from_dict",
"messages_to_dict",
"_message_to_dict",
"_message_from_dict",
"get_buffer_string",
"RunInfo",
"LLMResult",
"ChatResult",
"ChatGeneration",
"Generation",
"PromptValue",
"LangChainException",
"BaseRetriever",
"RUN_KEY",
"Memory",
"OutputParserException",
"StrOutputParser",
"BaseOutputParser",
"BaseLLMOutputParser",
"BasePromptTemplate",
"format_document",
]

@ -0,0 +1,74 @@
from __future__ import annotations
from typing import Any, Literal, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.schema.messages import BaseMessage
class AgentAction(Serializable):
"""A full description of an action for an ActionAgent to execute."""
tool: str
"""The name of the Tool to execute."""
tool_input: Union[str, dict]
"""The input to pass in to the Tool."""
log: str
"""Additional information to log about the action.
This log can be used in a few ways. First, it can be used to audit
what exactly the LLM predicted to lead to this (tool, tool_input).
Second, it can be used in future iterations to show the LLMs prior
thoughts. This is useful when (tool, tool_input) does not contain
full information about the LLM prediction (for example, any `thought`
before the tool/tool_input)."""
type: Literal["AgentAction"] = "AgentAction"
def __init__(
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
):
"""Override init to support instantiation by position for backward compat."""
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return True
class AgentActionMessageLog(AgentAction):
message_log: Sequence[BaseMessage]
"""Similar to log, this can be used to pass along extra
information about what exact messages were predicted by the LLM
before parsing out the (tool, tool_input). This is again useful
if (tool, tool_input) cannot be used to fully recreate the LLM
prediction, and you need that LLM prediction (for future agent iteration).
Compared to `log`, this is useful when the underlying LLM is a
ChatModel (and therefore returns messages rather than a string)."""
# Ignoring type because we're overriding the type from AgentAction.
# And this is the correct thing to do in this case.
# The type literal is used for serialization purposes.
type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore
class AgentFinish(Serializable):
"""The final return value of an ActionAgent."""
return_values: dict
"""Dictionary of return values."""
log: str
"""Additional information to log about the return value.
This is used to pass along the full LLM prediction, not just the parsed out
return value. For example, if the full LLM prediction was
`Final Answer: 2` you may want to just return `2` as a return value, but pass
along the full string as a `log` (for debugging or observability purposes).
"""
type: Literal["AgentFinish"] = "AgentFinish"
def __init__(self, return_values: dict, log: str, **kwargs: Any):
"""Override init to support instantiation by position for backward compat."""
super().__init__(return_values=return_values, log=log, **kwargs)
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return True
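
# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not part of the library): positional
# construction supported by the __init__ overrides above; the tool name and
# values are made up.
def _agent_schema_example() -> None:
    action = AgentAction("search", {"query": "weather"}, "Thought: I should search.\n")
    finish = AgentFinish({"output": "It is sunny."}, "Final Answer: It is sunny.")
    assert action.tool == "search"
    assert finish.return_values["output"] == "It is sunny."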

@ -0,0 +1,24 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence
from langchain_core.schema.output import Generation
RETURN_VAL_TYPE = Sequence[Generation]
class BaseCache(ABC):
"""Base interface for cache."""
@abstractmethod
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
@abstractmethod
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
@abstractmethod
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""

@ -0,0 +1,13 @@
from typing import Sequence, TypedDict
from langchain_core.schema import BaseMessage
class ChatSession(TypedDict, total=False):
"""Chat Session represents a single
conversation, channel, or other group of messages."""
messages: Sequence[BaseMessage]
"""The LangChain chat messages loaded from the source."""
functions: Sequence[dict]
"""The function calling specs for the messages."""

@ -0,0 +1,67 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List
from langchain_core.schema.messages import AIMessage, BaseMessage, HumanMessage
class BaseChatMessageHistory(ABC):
"""Abstract base class for storing chat message history.
See `ChatMessageHistory` for default implementation.
Example:
.. code-block:: python
class FileChatMessageHistory(BaseChatMessageHistory):
storage_path: str
session_id: str
@property
def messages(self):
                    with open(os.path.join(storage_path, session_id), 'r', encoding='utf-8') as f:
messages = json.loads(f.read())
return messages_from_dict(messages)
def add_message(self, message: BaseMessage) -> None:
                    all_messages = list(self.messages)  # existing messages
                    all_messages.append(message)  # add the new message
                    serialized = [_message_to_dict(m) for m in all_messages]
                    with open(os.path.join(storage_path, session_id), 'w') as f:
                        json.dump(serialized, f)
def clear(self):
with open(os.path.join(storage_path, session_id), 'w') as f:
f.write("[]")
"""
messages: List[BaseMessage]
"""A list of Messages stored in-memory."""
def add_user_message(self, message: str) -> None:
"""Convenience method for adding a human message string to the store.
Args:
message: The string contents of a human message.
"""
self.add_message(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
"""Convenience method for adding an AI message string to the store.
Args:
message: The string contents of an AI message.
"""
self.add_message(AIMessage(content=message))
@abstractmethod
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the store.
Args:
message: A BaseMessage object to store.
"""
raise NotImplementedError()
@abstractmethod
def clear(self) -> None:
"""Remove all messages from the store"""

@ -0,0 +1,91 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Literal, Sequence
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field
class Document(Serializable):
"""Class for storing a piece of text and associated metadata."""
page_content: str
"""String text."""
metadata: dict = Field(default_factory=dict)
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
type: Literal["Document"] = "Document"
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems.
A document transformation system takes a sequence of Documents and returns a
sequence of transformed Documents.
Example:
.. code-block:: python
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
embeddings: Embeddings
similarity_fn: Callable = cosine_similarity
similarity_threshold: float = 0.95
class Config:
arbitrary_types_allowed = True
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
stateful_documents = get_stateful_documents(documents)
embedded_documents = _get_embeddings_from_stateful_docs(
self.embeddings, stateful_documents
)
included_idxs = _filter_similar_embeddings(
embedded_documents, self.similarity_fn, self.similarity_threshold
)
return [stateful_documents[i] for i in sorted(included_idxs)]
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
raise NotImplementedError
""" # noqa: E501
@abstractmethod
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Transform a list of documents.
Args:
documents: A sequence of Documents to be transformed.
Returns:
A list of transformed Documents.
"""
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a list of documents.
Args:
documents: A sequence of Documents to be transformed.
Returns:
A list of transformed Documents.
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
)
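
# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not part of the library): a trivial
# transformer that upper-cases page content, showing the sync hook that the
# async default above delegates to.
class _UpperCaseTransformer(BaseDocumentTransformer):
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        return [
            Document(page_content=d.page_content.upper(), metadata=dict(d.metadata))
            for d in documents
        ]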

@ -0,0 +1,27 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List
class Embeddings(ABC):
"""Interface for embedding models."""
@abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
@abstractmethod
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_documents, texts
)
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_query, text
)
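
# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not part of the library): a
# deterministic fake embedder implementing the interface above; the
# 4-dimensional vectors are arbitrary and only useful for tests.
class _FakeEmbeddings(Embeddings):
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        # Derive a small, stable vector from simple properties of the text.
        return [float(len(text)), float(sum(map(ord, text)) % 97), 0.0, 1.0]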

@ -0,0 +1,2 @@
class LangChainException(Exception):
"""General LangChain exception."""

@ -0,0 +1,291 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
List,
Optional,
Sequence,
Set,
TypeVar,
Union,
)
from typing_extensions import TypeAlias
from langchain_core.runnables import RunnableSerializable
from langchain_core.schema.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain_core.schema.output import LLMResult
from langchain_core.schema.prompt import PromptValue
from langchain_core.utils import get_pydantic_field_names
if TYPE_CHECKING:
from langchain_core.callbacks.manager import Callbacks
@lru_cache(maxsize=None) # Cache the tokenizer
def get_tokenizer() -> Any:
try:
from transformers import GPT2TokenizerFast
except ImportError:
raise ImportError(
"Could not import transformers python package. "
"This is needed in order to calculate get_token_ids. "
"Please install it with `pip install transformers`."
)
# create a GPT-2 tokenizer instance
return GPT2TokenizerFast.from_pretrained("gpt2")
def _get_token_ids_default_method(text: str) -> List[int]:
"""Encode the text into token IDs."""
# get the cached tokenizer
tokenizer = get_tokenizer()
# tokenize the text using the GPT-2 tokenizer
return tokenizer.encode(text)
LanguageModelInput = Union[PromptValue, str, List[BaseMessage]]
LanguageModelOutput = TypeVar("LanguageModelOutput")
class BaseLanguageModel(
RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC
):
"""Abstract base class for interfacing with language models.
All language model wrappers inherit from BaseLanguageModel.
Exposes three main methods:
- generate_prompt: generate language model outputs for a sequence of prompt
values. A prompt value is a model input that can be converted to any language
model input format (string or messages).
- predict: pass in a single string to a language model and return a string
prediction.
- predict_messages: pass in a sequence of BaseMessages (corresponding to a single
model call) to a language model and return a BaseMessage prediction.
Each of these has an equivalent asynchronous method.
"""
@property
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
from langchain_core.prompts.base import StringPromptValue
from langchain_core.prompts.chat import ChatPromptValueConcrete
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage],
]
@abstractmethod
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""Pass a sequence of prompts to the model and return model generations.
This method should make use of batched calls for models that expose a batched
API.
Use this method when you want to:
1. take advantage of batched calls,
2. need more output from the model than just the top generated value,
3. are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
Args:
prompts: List of PromptValues. A PromptValue is an object that can be
converted to match the format of any language model (string for pure
text generation models and BaseMessages for chat models).
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
@abstractmethod
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""Asynchronously pass a sequence of prompts and return model generations.
This method should make use of batched calls for models that expose a batched
API.
Use this method when you want to:
1. take advantage of batched calls,
2. get more output from the model than just the top generated value,
3. build chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
Args:
prompts: List of PromptValues. A PromptValue is an object that can be
converted to match the format of any language model (string for pure
text generation models and BaseMessages for chat models).
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
@abstractmethod
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
"""Pass a single string input to the model and return a string prediction.
Use this method when passing in raw text. If you want to pass in specific
types of chat messages, use predict_messages.
Args:
text: String input to pass to the model.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
Top model prediction as a string.
"""
@abstractmethod
def predict_messages(
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
"""Pass a message sequence to the model and return a message prediction.
Use this method when passing in chat messages. If you want to pass in raw text,
use predict.
Args:
messages: A sequence of chat messages corresponding to a single model input.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
Top model prediction as a message.
"""
@abstractmethod
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
"""Asynchronously pass a string to the model and return a string prediction.
Use this method when calling pure text generation models and only the top
candidate generation is needed.
Args:
text: String input to pass to the model.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
Top model prediction as a string.
"""
@abstractmethod
async def apredict_messages(
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
"""Asynchronously pass messages to the model and return a message prediction.
Use this method when calling chat models and only the top
candidate generation is needed.
Args:
messages: A sequence of chat messages corresponding to a single model input.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
Top model prediction as a message.
"""
def get_token_ids(self, text: str) -> List[int]:
"""Return the ordered ids of the tokens in a text.
Args:
text: The string input to tokenize.
Returns:
A list of ids corresponding to the tokens in the text, in the order they
occur in the text.
"""
return _get_token_ids_default_method(text)
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
return len(self.get_token_ids(text))
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input will fit in a model's context window.
Args:
messages: The message inputs to tokenize.
Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
@classmethod
def _all_required_field_names(cls) -> Set:
"""DEPRECATED: Kept for backwards compatibility.
Use get_pydantic_field_names.
"""
return get_pydantic_field_names(cls)
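# Illustrative sketch (not part of the module): the default token counting
# helpers above rely on the optional `transformers` package. Exact token ids
# and counts depend on the GPT-2 vocabulary, so the values are indicative only.
if __name__ == "__main__":
    ids = _get_token_ids_default_method("hello world")
    print(len(ids))  # number of GPT-2 tokens in "hello world"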

@ -0,0 +1,59 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from langchain_core.load.serializable import Serializable
class BaseMemory(Serializable, ABC):
"""Abstract base class for memory in Chains.
Memory refers to state in Chains. Memory can be used to store information about
past executions of a Chain and inject that information into the inputs of
future executions of the Chain. For example, for conversational Chains Memory
can be used to store conversations and automatically add them to future model
prompts so that the model has the necessary context to respond coherently to
the latest input.
Example:
.. code-block:: python
class SimpleMemory(BaseMemory):
memories: Dict[str, Any] = dict()
@property
def memory_variables(self) -> List[str]:
return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
pass
def clear(self) -> None:
pass
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
@abstractmethod
def memory_variables(self) -> List[str]:
"""The string keys this memory class will add to chain inputs."""
@abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
@abstractmethod
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this chain run to memory."""
@abstractmethod
def clear(self) -> None:
"""Clear memory contents."""

@ -0,0 +1,415 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
from typing_extensions import Literal
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field
if TYPE_CHECKING:
from langchain_core.prompts.chat import ChatPromptTemplate
def get_buffer_string(
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
"""Convert sequence of Messages to strings and concatenate them into one string.
Args:
messages: Messages to be converted to strings.
human_prefix: The prefix to prepend to contents of HumanMessages.
ai_prefix: The prefix to prepend to contents of AIMessages.
Returns:
A single string concatenation of all input messages.
Example:
.. code-block:: python
from langchain_core.schema import AIMessage, HumanMessage
messages = [
HumanMessage(content="Hi, how are you?"),
AIMessage(content="Good, how are you?"),
]
get_buffer_string(messages)
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
"""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
role = human_prefix
elif isinstance(m, AIMessage):
role = ai_prefix
elif isinstance(m, SystemMessage):
role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ChatMessage):
role = m.role
else:
raise ValueError(f"Got unsupported message type: {m}")
message = f"{role}: {m.content}"
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
message += f"{m.additional_kwargs['function_call']}"
string_messages.append(message)
return "\n".join(string_messages)
class BaseMessage(Serializable):
"""The base abstract Message class.
Messages are the inputs and outputs of ChatModels.
"""
content: Union[str, List[Union[str, Dict]]]
"""The string contents of the message."""
additional_kwargs: dict = Field(default_factory=dict)
"""Any additional information."""
type: str
class Config:
extra = Extra.allow
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain_core.prompts.chat import ChatPromptTemplate
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
def merge_content(
first_content: Union[str, List[Union[str, Dict]]],
second_content: Union[str, List[Union[str, Dict]]],
) -> Union[str, List[Union[str, Dict]]]:
# If first chunk is a string
if isinstance(first_content, str):
# If the second chunk is also a string, then merge them naively
if isinstance(second_content, str):
return first_content + second_content
# If the second chunk is a list, add the first chunk to the start of the list
else:
return_list: List[Union[str, Dict]] = [first_content]
return return_list + second_content
# If both are lists, merge them naively
elif isinstance(second_content, List):
return first_content + second_content
# If the first content is a list, and the second content is a string
else:
# If the last element of the first content is a string
# Add the second content to the last element
if isinstance(first_content[-1], str):
return first_content[:-1] + [first_content[-1] + second_content]
else:
# Otherwise, add the second content as a new element of the list
return first_content + [second_content]
class BaseMessageChunk(BaseMessage):
"""A Message chunk, which can be concatenated with other Message chunks."""
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:
"""Merge additional_kwargs from another BaseMessageChunk into this one."""
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif type(merged[k]) != type(v):
raise ValueError(
f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = self._merge_kwargs_dict(merged[k], v)
else:
raise ValueError(
f"Additional kwargs key {k} already exists in this message."
)
return merged
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, BaseMessageChunk):
# If both are (subclasses of) BaseMessageChunk,
# concat into a single BaseMessageChunk
if isinstance(self, ChatMessageChunk):
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return self.__class__(
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
else:
raise TypeError(
'unsupported operand type(s) for +: "'
f"{self.__class__.__name__}"
f'" and "{other.__class__.__name__}"'
)
class HumanMessage(BaseMessage):
"""A Message from a human."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
type: Literal["human"] = "human"
HumanMessage.update_forward_refs()
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
"""A Human Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] # noqa: E501
class AIMessage(BaseMessage):
"""A Message from an AI."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
type: Literal["ai"] = "ai"
AIMessage.update_forward_refs()
class AIMessageChunk(AIMessage, BaseMessageChunk):
"""A Message chunk from an AI."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
raise ValueError(
"Cannot concatenate AIMessageChunks with different example values."
)
return self.__class__(
example=self.example,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
class SystemMessage(BaseMessage):
"""A Message for priming AI behavior, usually passed in as the first of a sequence
of input messages.
"""
type: Literal["system"] = "system"
SystemMessage.update_forward_refs()
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
"""A System Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] # noqa: E501
class FunctionMessage(BaseMessage):
"""A Message for passing the result of executing a function back to a model."""
name: str
"""The name of the function that was executed."""
type: Literal["function"] = "function"
FunctionMessage.update_forward_refs()
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""A Function Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, FunctionMessageChunk):
if self.name != other.name:
raise ValueError(
"Cannot concatenate FunctionMessageChunks with different names."
)
return self.__class__(
name=self.name,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
class ToolMessage(BaseMessage):
"""A Message for passing the result of executing a tool back to a model."""
tool_call_id: str
"""Tool call that this message is responding to."""
type: Literal["tool"] = "tool"
ToolMessage.update_forward_refs()
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
"""A Tool Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ToolMessageChunk):
if self.tool_call_id != other.tool_call_id:
raise ValueError(
"Cannot concatenate ToolMessageChunks with different names."
)
return self.__class__(
tool_call_id=self.tool_call_id,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
class ChatMessage(BaseMessage):
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
role: str
"""The speaker / role of the Message."""
type: Literal["chat"] = "chat"
ChatMessage.update_forward_refs()
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""A Chat Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ChatMessageChunk):
if self.role != other.role:
raise ValueError(
"Cannot concatenate ChatMessageChunks with different roles."
)
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
AnyMessage = Union[
AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
]
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
"""Convert a sequence of Messages to a list of dictionaries.
Args:
messages: Sequence of messages (as BaseMessages) to convert.
Returns:
List of messages as dicts.
"""
return [_message_to_dict(m) for m in messages]
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
elif _type == "function":
return FunctionMessage(**message["data"])
elif _type == "tool":
return ToolMessage(**message["data"])
else:
raise ValueError(f"Got unexpected message type: {_type}")
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
"""Convert a sequence of messages from dicts to Message objects.
Args:
messages: Sequence of messages (as dicts) to convert.
Returns:
List of messages (BaseMessages).
"""
return [_message_from_dict(m) for m in messages]
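# Illustrative sketch (not part of the module): message chunks accumulate with
# `+`, which is how streaming chat output is stitched back together, and the
# dict helpers round-trip messages through plain dictionaries.
if __name__ == "__main__":
    chunk = AIMessageChunk(content="Hello, ") + AIMessageChunk(content="world!")
    print(chunk.content)  # -> "Hello, world!"
    restored = messages_from_dict(messages_to_dict([HumanMessage(content="hi")]))
    print(restored[0].content)  # -> "hi"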

@ -0,0 +1,175 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any, Dict, List, Literal, Optional
from uuid import UUID
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.schema.messages import BaseMessage, BaseMessageChunk
class Generation(Serializable):
"""A single text generation output."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw response from the provider. May include things like the
reason for finishing or token log probabilities.
"""
type: Literal["Generation"] = "Generation"
"""Type is used exclusively for serialization purposes."""
# TODO: add log probs as separate attribute
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
class GenerationChunk(Generation):
"""A Generation chunk, which can be concatenated with other Generation chunks."""
def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return GenerationChunk(
text=self.text + other.text,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
class ChatGeneration(Generation):
"""A single chat generation output."""
text: str = ""
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
message: BaseMessage
"""The message output by the chat model."""
# Override type to be ChatGeneration, ignore mypy error as this is intentional
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
@root_validator
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set the text attribute to be the contents of the message."""
try:
values["text"] = values["message"].content
except (KeyError, AttributeError) as e:
raise ValueError("Error while initializing ChatGeneration") from e
return values
class ChatGenerationChunk(ChatGeneration):
"""A ChatGeneration chunk, which can be concatenated with other
ChatGeneration chunks.
Attributes:
message: The message chunk output by the chat model.
"""
message: BaseMessageChunk
# Override type to be ChatGeneration, ignore mypy error as this is intentional
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
"""Type is used exclusively for serialization purposes."""
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return ChatGenerationChunk(
message=self.message + other.message,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
class RunInfo(BaseModel):
"""Class that contains metadata for a single execution of a Chain or model."""
run_id: UUID
"""A unique identifier for the model or chain run."""
class ChatResult(BaseModel):
"""Class that contains all results for a single chat model call."""
generations: List[ChatGeneration]
"""List of the chat generations. This is a List because an input can have multiple
candidate generations.
"""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
class LLMResult(BaseModel):
"""Class that contains all results for a batched LLM call."""
generations: List[List[Generation]]
"""List of generated outputs. This is a List[List[]] because
each input could have multiple candidate generations."""
llm_output: Optional[dict] = None
"""Arbitrary LLM provider-specific output."""
run: Optional[List[RunInfo]] = None
"""List of metadata info for model call for each input."""
def flatten(self) -> List[LLMResult]:
"""Flatten generations into a single list.
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
contains only a single Generation. If token usage information is available,
it is kept only for the LLMResult corresponding to the top-choice
Generation, to avoid over-counting of token usage downstream.
Returns:
List of LLMResults where each returned LLMResult contains a single
Generation.
"""
llm_results = []
for i, gen_list in enumerate(self.generations):
# Avoid double counting tokens in OpenAICallback
if i == 0:
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=self.llm_output,
)
)
else:
if self.llm_output is not None:
llm_output = deepcopy(self.llm_output)
llm_output["token_usage"] = dict()
else:
llm_output = None
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=llm_output,
)
)
return llm_results
def __eq__(self, other: object) -> bool:
"""Check for LLMResult equality by ignoring any metadata related to runs."""
if not isinstance(other, LLMResult):
return NotImplemented
return (
self.generations == other.generations
and self.llm_output == other.llm_output
)
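# Illustrative sketch (not part of the module): GenerationChunks concatenate
# with `+`, and LLMResult.flatten() splits a batched result into one LLMResult
# per input prompt. The values below are made up for illustration.
if __name__ == "__main__":
    merged = GenerationChunk(text="Hel") + GenerationChunk(text="lo")
    print(merged.text)  # -> "Hello"
    batched = LLMResult(generations=[[Generation(text="a")], [Generation(text="b")]])
    print(len(batched.flatten()))  # -> 2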

@ -0,0 +1,475 @@
from __future__ import annotations
import asyncio
import functools
from abc import ABC, abstractmethod
from typing import (
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
List,
Optional,
Type,
TypeVar,
Union,
)
from typing_extensions import get_args
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
from langchain_core.schema.output import (
ChatGeneration,
ChatGenerationChunk,
Generation,
GenerationChunk,
)
from langchain_core.schema.prompt import PromptValue
T = TypeVar("T")
class BaseLLMOutputParser(Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result
)
class BaseGenerationOutputParser(
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call."""
@property
def InputType(self) -> Any:
return Union[str, AnyMessage]
@property
def OutputType(self) -> Type[T]:
# even though mypy complains this isn't valid,
# it is good enough for pydantic to build the schema from
return T # type: ignore[misc]
def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(
lambda inner_input: self.parse_result(
[ChatGeneration(message=inner_input)]
),
input,
config,
run_type="parser",
)
else:
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
async def ainvoke(
self,
input: str | BaseMessage,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> T:
if isinstance(input, BaseMessage):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(
[ChatGeneration(message=inner_input)]
),
input,
config,
run_type="parser",
)
else:
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
class BaseOutputParser(
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call.
Output parsers help structure language model responses.
Example:
.. code-block:: python
class BooleanOutputParser(BaseOutputParser[bool]):
true_val: str = "YES"
false_val: str = "NO"
def parse(self, text: str) -> bool:
cleaned_text = text.strip().upper()
if cleaned_text not in (self.true_val.upper(), self.false_val.upper()):
raise OutputParserException(
f"BooleanOutputParser expected output value to either be "
f"{self.true_val} or {self.false_val} (case-insensitive). "
f"Received {cleaned_text}."
)
return cleaned_text == self.true_val.upper()
@property
def _type(self) -> str:
return "boolean_output_parser"
""" # noqa: E501
@property
def InputType(self) -> Any:
return Union[str, AnyMessage]
@property
def OutputType(self) -> Type[T]:
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 1:
return type_args[0]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
"Override the OutputType property to specify the output type."
)
def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(
lambda inner_input: self.parse_result(
[ChatGeneration(message=inner_input)]
),
input,
config,
run_type="parser",
)
else:
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
async def ainvoke(
self,
input: str | BaseMessage,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> T:
if isinstance(input, BaseMessage):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(
[ChatGeneration(message=inner_input)]
),
input,
config,
run_type="parser",
)
else:
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
is assumed to be the highest-likelihood Generation.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return self.parse(result[0].text)
@abstractmethod
def parse(self, text: str) -> T:
"""Parse a single string model output into some structure.
Args:
text: String output of a language model.
Returns:
Structured output.
"""
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
is assumed to be the highest-likelihood Generation.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return await asyncio.get_running_loop().run_in_executor(
None, functools.partial(self.parse_result, partial=partial), result
)
async def aparse(self, text: str) -> T:
"""Parse a single string model output into some structure.
Args:
text: String output of a language model.
Returns:
Structured output.
"""
return await asyncio.get_running_loop().run_in_executor(None, self.parse, text)
# TODO: rename 'completion' -> 'text'.
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
"""Parse the output of an LLM call with the input prompt for context.
The prompt is largely provided in the event the OutputParser wants
to retry or fix the output in some way, and needs information from
the prompt to do so.
Args:
completion: String output of a language model.
prompt: Input PromptValue.
Returns:
Structured output
"""
return self.parse(completion)
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""
raise NotImplementedError
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
raise NotImplementedError(
f"_type property is not implemented in class {self.__class__.__name__}."
" This is required for serialization."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict(**kwargs)
try:
output_parser_dict["_type"] = self._type
except NotImplementedError:
pass
return output_parser_dict
class BaseTransformOutputParser(BaseOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
async for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
def transform(
self,
input: Iterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[T]:
yield from self._transform_stream_with_config(
input, self._transform, config, run_type="parser"
)
async def atransform(
self,
input: AsyncIterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[T]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, run_type="parser"
):
yield chunk
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
diff: bool = False
"""In streaming mode, whether to yield diffs between the previous and current
parsed output, or just the current parsed output.
"""
def _diff(self, prev: Optional[T], next: T) -> T:
"""Convert parsed outputs into a diff format. The semantics of this are
up to the output parser."""
raise NotImplementedError()
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
prev_parsed = None
acc_gen = None
for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)
if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
prev_parsed = None
acc_gen = None
async for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)
if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed
class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
return "default"
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
# TODO: Deprecate
NoOpOutputParser = StrOutputParser
class OutputParserException(ValueError):
"""Exception that output parsers should raise to signify a parsing error.
This exists to differentiate parsing errors from other code or execution errors
that also may arise inside the output parser. OutputParserExceptions will be
available to catch and handle in ways to fix the parsing error, while other
errors will be raised.
Args:
error: The error that's being re-raised or an error message.
observation: String explanation of error which can be passed to a
model to try and remediate the issue.
llm_output: The string model output that failed to parse.
send_to_llm: Whether to send the observation and llm_output back to an Agent
after an OutputParserException has been raised. This gives the underlying
model driving the agent the context that the previous output was improperly
structured, in the hopes that it will update the output to the correct
format.
"""
def __init__(
self,
error: Any,
observation: Optional[str] = None,
llm_output: Optional[str] = None,
send_to_llm: bool = False,
):
super(OutputParserException, self).__init__(error)
if send_to_llm:
if observation is None or llm_output is None:
raise ValueError(
"Arguments 'observation' & 'llm_output'"
" are required if 'send_to_llm' is True"
)
self.observation = observation
self.llm_output = llm_output
self.send_to_llm = send_to_llm
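# Illustrative sketch (not part of the module): a minimal BaseOutputParser
# subclass. The class is hypothetical and only shows which pieces a concrete
# parser has to provide.
class _CommaSeparatedListParser(BaseOutputParser[List[str]]):
    def parse(self, text: str) -> List[str]:
        # Split the raw model output on commas and strip surrounding whitespace.
        return [part.strip() for part in text.split(",")]

    @property
    def _type(self) -> str:
        return "comma_separated_list"


if __name__ == "__main__":
    print(_CommaSeparatedListParser().parse("a, b, c"))  # -> ['a', 'b', 'c']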

@ -0,0 +1,28 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List
from langchain_core.load.serializable import Serializable
from langchain_core.schema.messages import BaseMessage
class PromptValue(Serializable, ABC):
"""Base abstract class for inputs to any language model.
PromptValues can be converted to both LLM (pure text-generation) inputs and
ChatModel inputs.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@abstractmethod
def to_string(self) -> str:
"""Return prompt value as string."""
@abstractmethod
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""

@ -0,0 +1,228 @@
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
import yaml
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.schema.document import Document
from langchain_core.schema.output_parser import BaseOutputParser
from langchain_core.schema.prompt import PromptValue
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
input_types: Dict[str, Any] = Field(default_factory=dict)
"""A dictionary of the types of the variables the prompt template expects.
If not provided, all variables are assumed to be strings."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
default_factory=dict
)
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def OutputType(self) -> Any:
from langchain_core.prompts.base import StringPromptValue
from langchain_core.prompts.chat import ChatPromptValueConcrete
return Union[StringPromptValue, ChatPromptValueConcrete]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"PromptInput",
**{k: (self.input_types.get(k, str), None) for k in self.input_variables},
)
def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None
) -> PromptValue:
return self._call_with_config(
lambda inner_input: self.format_prompt(
**{key: inner_input[key] for key in self.input_variables}
),
input,
config,
run_type="prompt",
)
@abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not include restricted names."""
if "stop" in values["input_variables"]:
raise ValueError(
"Cannot have an input variable named 'stop', as it is used internally,"
" please rename."
)
if "stop" in values["partial_variables"]:
raise ValueError(
"Cannot have an partial variable named 'stop', as it is used "
"internally, please rename."
)
overall = set(values["input_variables"]).intersection(
values["partial_variables"]
)
if overall:
raise ValueError(
f"Found overlapping input and partial variables: {overall}"
)
return values
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
"""Return a partial of the prompt template."""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if isinstance(v, str) else v()
for k, v in self.partial_variables.items()
}
return {**partial_kwargs, **kwargs}
@abstractmethod
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt."""
prompt_dict = super().dict(**kwargs)
try:
prompt_dict["_type"] = self._prompt_type
except NotImplementedError:
pass
return prompt_dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the prompt.
Args:
file_path: Path of the file to save the prompt to.
Example:
.. code-block:: python
prompt.save(file_path="path/prompt.yaml")
"""
if self.partial_variables:
raise ValueError("Cannot save prompt with partial variables.")
# Fetch dictionary to save
prompt_dict = self.dict()
if "_type" not in prompt_dict:
raise NotImplementedError(f"Prompt {self} does not support saving.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
"""Format a document into a string based on a prompt template.
First, this pulls information from the document from two sources:
1. `page_content`:
This takes the information from the `document.page_content`
and assigns it to a variable named `page_content`.
2. metadata:
This takes information from `document.metadata` and assigns
it to variables of the same name.
Those variables are then passed into the `prompt` to produce a formatted string.
Args:
doc: Document, the page_content and metadata will be used to create
the final string.
prompt: BasePromptTemplate, will be used to format the page_content
and metadata into the final string.
Returns:
The document formatted as a string.
Example:
.. code-block:: python
from langchain_core.schema import Document
from langchain_core.prompts import PromptTemplate
doc = Document(page_content="This is a joke", metadata={"page": "1"})
prompt = PromptTemplate.from_template("Page {page}: {page_content}")
format_document(doc, prompt)
>>> "Page 1: This is a joke"
"""
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)
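# Illustrative sketch (not part of the module): partial() pre-fills some input
# variables so the remaining ones can be supplied later. PromptTemplate is the
# concrete subclass assumed to live in langchain_core.prompts.
if __name__ == "__main__":
    from langchain_core.prompts import PromptTemplate

    prompt = PromptTemplate.from_template("{greeting}, {name}!")
    partial_prompt = prompt.partial(greeting="Hello")
    print(partial_prompt.format(name="world"))  # -> "Hello, world!"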

@ -0,0 +1,275 @@
from __future__ import annotations
import asyncio
import warnings
from abc import ABC, abstractmethod
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.load.dump import dumpd
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.schema.document import Document
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
Callbacks,
)
class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return
the most 'relevant' Documents from some source.
Example:
.. code-block:: python
class TFIDFRetriever(BaseRetriever, BaseModel):
vectorizer: Any
docs: List[Document]
tfidf_array: Any
k: int = 4
class Config:
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity
# Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
query_vec = self.vectorizer.transform([query])
# Op -- (n_docs,1) -- Cosine Sim with each doc
results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
_new_arg_supported: bool = False
_expects_other_args: bool = False
tags: Optional[List[str]] = None
"""Optional list of tags associated with the retriever. Defaults to None
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a retriever with its
use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the retriever. Defaults to None
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a retriever with its
use case.
"""
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
# Version upgrade for old retrievers that implemented the public
# methods directly.
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
warnings.warn(
"Retrievers must implement abstract `_get_relevant_documents` method"
" instead of `get_relevant_documents`",
DeprecationWarning,
)
swap = cls.get_relevant_documents
cls.get_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.get_relevant_documents
)
cls._get_relevant_documents = swap # type: ignore[assignment]
if (
hasattr(cls, "aget_relevant_documents")
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
):
warnings.warn(
"Retrievers must implement abstract `_aget_relevant_documents` method"
" instead of `aget_relevant_documents`",
DeprecationWarning,
)
aswap = cls.aget_relevant_documents
cls.aget_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.aget_relevant_documents
)
cls._aget_relevant_documents = aswap # type: ignore[assignment]
parameters = signature(cls._get_relevant_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
# If a V1 retriever broke the interface and expects additional arguments
cls._expects_other_args = (
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
)
def invoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
config = config or {}
return self.get_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
)
async def ainvoke(
self,
input: str,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> List[Document]:
config = config or {}
return await self.aget_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
)
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._get_relevant_documents, run_manager=run_manager), query
)
def get_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Retrieve documents relevant to a query.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
tags: Optional list of tags associated with the retriever. Defaults to None.
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever. Defaults to None.
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
Returns:
List of relevant documents
"""
from langchain_core.callbacks.manager import CallbackManager
callback_manager = CallbackManager.configure(
callbacks,
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=tags,
local_tags=self.tags,
inheritable_metadata=metadata,
local_metadata=self.metadata,
)
run_manager = callback_manager.on_retriever_start(
dumpd(self),
query,
name=run_name,
**kwargs,
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
result = self._get_relevant_documents(
query, run_manager=run_manager, **_kwargs
)
else:
result = self._get_relevant_documents(query, **_kwargs)
except Exception as e:
run_manager.on_retriever_error(e)
raise e
else:
run_manager.on_retriever_end(
result,
**kwargs,
)
return result
async def aget_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
tags: Optional list of tags associated with the retriever. Defaults to None.
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever. Defaults to None.
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
Returns:
List of relevant documents
"""
from langchain_core.callbacks.manager import AsyncCallbackManager
callback_manager = AsyncCallbackManager.configure(
callbacks,
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=tags,
local_tags=self.tags,
inheritable_metadata=metadata,
local_metadata=self.metadata,
)
run_manager = await callback_manager.on_retriever_start(
dumpd(self),
query,
name=run_name,
**kwargs,
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
result = await self._aget_relevant_documents(
query, run_manager=run_manager, **_kwargs
)
else:
result = await self._aget_relevant_documents(query, **_kwargs)
except Exception as e:
await run_manager.on_retriever_error(e)
raise e
else:
await run_manager.on_retriever_end(
result,
**kwargs,
)
return result
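# Illustrative sketch (not part of the module): a minimal BaseRetriever that
# returns canned documents, showing the single abstract method a retriever has
# to implement. The class is hypothetical.
class _StaticRetriever(BaseRetriever):
    docs: List[Document] = []

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        # Ignore the query and return the configured documents.
        return self.docs


# Usage (illustrative):
#   _StaticRetriever(docs=[Document(page_content="hi")]).get_relevant_documents("q")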

@ -0,0 +1,53 @@
from abc import ABC, abstractmethod
from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union
K = TypeVar("K")
V = TypeVar("V")
class BaseStore(Generic[K, V], ABC):
"""Abstract interface for a key-value store."""
@abstractmethod
def mget(self, keys: Sequence[K]) -> List[Optional[V]]:
"""Get the values associated with the given keys.
Args:
keys (Sequence[K]): A sequence of keys.
Returns:
A sequence of optional values associated with the keys.
If a key is not found, the corresponding value will be None.
"""
@abstractmethod
def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
"""
@abstractmethod
def mdelete(self, keys: Sequence[K]) -> None:
"""Delete the given keys and their associated values.
Args:
keys (Sequence[K]): A sequence of keys to delete.
"""
@abstractmethod
def yield_keys(
self, *, prefix: Optional[str] = None
) -> Union[Iterator[K], Iterator[str]]:
"""Get an iterator over keys that match the given prefix.
Args:
prefix (str): The prefix to match.
Returns:
Iterator[K | str]: An iterator over keys that match the given prefix.
This method is allowed to return an iterator over either K or str
depending on what makes more sense for the given store.
"""

@ -0,0 +1,702 @@
from __future__ import annotations
import asyncio
import logging
import math
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
)
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.schema import BaseRetriever
from langchain_core.schema.document import Document
from langchain_core.schema.embeddings import Embeddings
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
logger = logging.getLogger(__name__)
VST = TypeVar("VST", bound="VectorStore")
class VectorStore(ABC):
"""Interface for vector store."""
@abstractmethod
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
kwargs: vectorstore specific parameters
Returns:
List of ids from adding the texts into the vectorstore.
"""
@property
def embeddings(self) -> Optional[Embeddings]:
"""Access the query embedding object if available."""
logger.debug(
f"{Embeddings.__name__} is not implemented for {self.__class__.__name__}"
)
return None
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete by vector ID or other criteria.
Args:
ids: List of ids to delete.
**kwargs: Other keyword arguments that subclasses might use.
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
raise NotImplementedError("delete method must be implemented by subclass.")
async def adelete(
self, ids: Optional[List[str]] = None, **kwargs: Any
) -> Optional[bool]:
"""Delete by vector ID or other criteria.
Args:
ids: List of ids to delete.
**kwargs: Other keyword arguments that subclasses might use.
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
raise NotImplementedError("delete method must be implemented by subclass.")
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore."""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.add_texts, **kwargs), texts, metadatas
)
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore.
Args:
documents (List[Document]): Documents to add to the vectorstore.
Returns:
List[str]: List of IDs of the added texts.
"""
# TODO: Handle the case where the user doesn't provide ids on the Collection
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return self.add_texts(texts, metadatas, **kwargs)
async def aadd_documents(
self, documents: List[Document], **kwargs: Any
) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore.
Args:
documents (List[Document]): Documents to add to the vectorstore.
Returns:
List[str]: List of IDs of the added texts.
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return await self.aadd_texts(texts, metadatas, **kwargs)
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
"""Return docs most similar to query using specified search type."""
if search_type == "similarity":
return self.similarity_search(query, **kwargs)
elif search_type == "mmr":
return self.max_marginal_relevance_search(query, **kwargs)
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity' or 'mmr'."
)
async def asearch(
self, query: str, search_type: str, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query using specified search type."""
if search_type == "similarity":
return await self.asimilarity_search(query, **kwargs)
elif search_type == "mmr":
return await self.amax_marginal_relevance_search(query, **kwargs)
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity' or 'mmr'."
)
@abstractmethod
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query."""
@staticmethod
def _euclidean_relevance_score_fn(distance: float) -> float:
"""Return a similarity score on a scale [0, 1]."""
# The 'correct' relevance function
# may differ depending on a few things, including:
# - the distance / similarity metric used by the VectorStore
# - the scale of your embeddings (OpenAI's are unit normed. Many
# others are not!)
# - embedding dimensionality
# - etc.
# This function converts the euclidean norm of normalized embeddings
# (0 is most similar, sqrt(2) most dissimilar)
# to a similarity function (0 to 1)
return 1.0 - distance / math.sqrt(2)
@staticmethod
def _cosine_relevance_score_fn(distance: float) -> float:
"""Normalize the distance to a score on a scale [0, 1]."""
return 1.0 - distance
@staticmethod
def _max_inner_product_relevance_score_fn(distance: float) -> float:
"""Normalize the distance to a score on a scale [0, 1]."""
if distance > 0:
return 1.0 - distance
return -1.0 * distance
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
The 'correct' relevance function
may differ depending on a few things, including:
- the distance / similarity metric used by the VectorStore
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
- embedding dimensionality
- etc.
Vector stores should define their own method for selecting the relevance score function.
"""
raise NotImplementedError
def similarity_search_with_score(
self, *args: Any, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""Run similarity search with distance."""
raise NotImplementedError
async def asimilarity_search_with_score(
self, *args: Any, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""Run similarity search with distance asynchronously."""
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search_with_score, *args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
def _similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""
Default similarity search with relevance scores. Modify if necessary
in subclass.
Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
Args:
query: input text
k: Number of Documents to return. Defaults to 4.
**kwargs: kwargs to be passed to similarity search. Should include:
score_threshold: Optional, a floating point value between 0 and 1 to
filter the resulting set of retrieved docs
Returns:
List of Tuples of (doc, similarity_score)
"""
relevance_score_fn = self._select_relevance_score_fn()
docs_and_scores = self.similarity_search_with_score(query, k, **kwargs)
return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores]
async def _asimilarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""
Default async similarity search with relevance scores. Modify if necessary
in subclass.
Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
Args:
query: input text
k: Number of Documents to return. Defaults to 4.
**kwargs: kwargs to be passed to similarity search. Should include:
score_threshold: Optional, a floating point value between 0 and 1 to
filter the resulting set of retrieved docs
Returns:
List of Tuples of (doc, similarity_score)
"""
relevance_score_fn = self._select_relevance_score_fn()
docs_and_scores = await self.asimilarity_search_with_score(query, k, **kwargs)
return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores]
def similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
Args:
query: input text
k: Number of Documents to return. Defaults to 4.
**kwargs: kwargs to be passed to similarity search. Should include:
score_threshold: Optional, a floating point value between 0 and 1 to
filter the resulting set of retrieved docs
Returns:
List of Tuples of (doc, similarity_score)
"""
score_threshold = kwargs.pop("score_threshold", None)
docs_and_similarities = self._similarity_search_with_relevance_scores(
query, k=k, **kwargs
)
if any(
similarity < 0.0 or similarity > 1.0
for _, similarity in docs_and_similarities
):
warnings.warn(
"Relevance scores must be between"
f" 0 and 1, got {docs_and_similarities}"
)
if score_threshold is not None:
docs_and_similarities = [
(doc, similarity)
for doc, similarity in docs_and_similarities
if similarity >= score_threshold
]
if len(docs_and_similarities) == 0:
warnings.warn(
"No relevant docs were retrieved using the relevance score"
f" threshold {score_threshold}"
)
return docs_and_similarities
async def asimilarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and relevance scores in the range [0, 1], asynchronously.
0 is dissimilar, 1 is most similar.
Args:
query: input text
k: Number of Documents to return. Defaults to 4.
**kwargs: kwargs to be passed to similarity search. Should include:
score_threshold: Optional, a floating point value between 0 and 1 to
filter the resulting set of retrieved docs
Returns:
List of Tuples of (doc, similarity_score)
"""
score_threshold = kwargs.pop("score_threshold", None)
docs_and_similarities = await self._asimilarity_search_with_relevance_scores(
query, k=k, **kwargs
)
if any(
similarity < 0.0 or similarity > 1.0
for _, similarity in docs_and_similarities
):
warnings.warn(
"Relevance scores must be between"
f" 0 and 1, got {docs_and_similarities}"
)
if score_threshold is not None:
docs_and_similarities = [
(doc, similarity)
for doc, similarity in docs_and_similarities
if similarity >= score_threshold
]
if len(docs_and_similarities) == 0:
warnings.warn(
"No relevant docs were retrieved using the relevance score"
f" threshold {score_threshold}"
)
return docs_and_similarities
async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query."""
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search, query, k=k, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query vector.
"""
raise NotImplementedError
async def asimilarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to embedding vector."""
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
raise NotImplementedError
async def amax_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance."""
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(
self.max_marginal_relevance_search,
query,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
**kwargs,
)
return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
raise NotImplementedError
async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance."""
raise NotImplementedError
@classmethod
def from_documents(
cls: Type[VST],
documents: List[Document],
embedding: Embeddings,
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from documents and embeddings."""
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)
@classmethod
async def afrom_documents(
cls: Type[VST],
documents: List[Document],
embedding: Embeddings,
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from documents and embeddings."""
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
return await cls.afrom_texts(texts, embedding, metadatas=metadatas, **kwargs)
@classmethod
@abstractmethod
def from_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from texts and embeddings."""
@classmethod
async def afrom_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from texts and embeddings."""
return await asyncio.get_running_loop().run_in_executor(
None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
)
def _get_retriever_tags(self) -> List[str]:
"""Get tags for retriever."""
tags = [self.__class__.__name__]
if self.embeddings:
tags.append(self.embeddings.__class__.__name__)
return tags
def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
"""Return VectorStoreRetriever initialized from this VectorStore.
Args:
search_type (Optional[str]): Defines the type of search that
the Retriever should perform.
Can be "similarity" (default), "mmr", or
"similarity_score_threshold".
search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like:
k: Amount of documents to return (Default: 4)
score_threshold: Minimum relevance threshold
for similarity_score_threshold
fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
lambda_mult: Diversity of results returned by MMR;
1 for minimum diversity and 0 for maximum. (Default: 0.5)
filter: Filter by document metadata
Returns:
VectorStoreRetriever: Retriever class for VectorStore.
Examples:
.. code-block:: python
# Retrieve more documents with higher diversity
# Useful if your dataset has many similar documents
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 6, 'lambda_mult': 0.25}
)
# Fetch more documents for the MMR algorithm to consider
# But only return the top 5
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 5, 'fetch_k': 50}
)
# Only retrieve documents that have a relevance score
# Above a certain threshold
docsearch.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={'score_threshold': 0.8}
)
# Only get the single most similar document from the dataset
docsearch.as_retriever(search_kwargs={'k': 1})
# Use a filter to only retrieve documents from a specific paper
docsearch.as_retriever(
search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
)
"""
tags = kwargs.pop("tags", None) or []
tags.extend(self._get_retriever_tags())
return VectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
class VectorStoreRetriever(BaseRetriever):
"""Base Retriever class for VectorStore."""
vectorstore: VectorStore
"""VectorStore to use for retrieval."""
search_type: str = "similarity"
"""Type of search to perform. Defaults to "similarity"."""
search_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments to pass to the search function."""
allowed_search_types: ClassVar[Collection[str]] = (
"similarity",
"similarity_score_threshold",
"mmr",
)
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator()
def validate_search_type(cls, values: Dict) -> Dict:
"""Validate search type."""
search_type = values["search_type"]
if search_type not in cls.allowed_search_types:
raise ValueError(
f"search_type of {search_type} not allowed. Valid values are: "
f"{cls.allowed_search_types}"
)
if search_type == "similarity_score_threshold":
score_threshold = values["search_kwargs"].get("score_threshold")
if (score_threshold is None) or (not isinstance(score_threshold, float)):
raise ValueError(
"`score_threshold` is not specified with a float value (0~1) "
"in `search_kwargs`."
)
return values
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
elif self.search_type == "similarity_score_threshold":
docs_and_similarities = (
self.vectorstore.similarity_search_with_relevance_scores(
query, **self.search_kwargs
)
)
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search(
query, **self.search_kwargs
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs
)
elif self.search_type == "similarity_score_threshold":
docs_and_similarities = (
await self.vectorstore.asimilarity_search_with_relevance_scores(
query, **self.search_kwargs
)
)
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = await self.vectorstore.amax_marginal_relevance_search(
query, **self.search_kwargs
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Add documents to vectorstore."""
return self.vectorstore.add_documents(documents, **kwargs)
async def aadd_documents(
self, documents: List[Document], **kwargs: Any
) -> List[str]:
"""Add documents to vectorstore."""
return await self.vectorstore.aadd_documents(documents, **kwargs)
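# A minimal illustrative sketch of a concrete VectorStore: an in-memory store
# implementing the abstract members above. It assumes the Embeddings interface
# exposes embed_documents/embed_query and that Document is imported earlier in
# this module; the naive dot-product ranking is only for the sketch, not a
# recommendation for real stores.
class _TinyVectorStore(VectorStore):
    def __init__(self, embedding: Embeddings) -> None:
        self._embedding = embedding
        self._texts: List[str] = []
        self._metadatas: List[dict] = []
        self._vectors: List[List[float]] = []

    @property
    def embeddings(self) -> Optional[Embeddings]:
        return self._embedding

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        texts = list(texts)
        start = len(self._texts)
        self._texts.extend(texts)
        self._metadatas.extend(metadatas or [{} for _ in texts])
        self._vectors.extend(self._embedding.embed_documents(texts))
        return [str(i) for i in range(start, len(self._texts))]

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        query_vector = self._embedding.embed_query(query)
        # Rank by dot product; production stores delegate this to an index.
        ranked = sorted(
            zip(self._texts, self._metadatas, self._vectors),
            key=lambda item: -sum(a * b for a, b in zip(query_vector, item[2])),
        )
        return [
            Document(page_content=text, metadata=meta)
            for text, meta, _ in ranked[:k]
        ]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "_TinyVectorStore":
        store = cls(embedding)
        store.add_texts(texts, metadatas)
        return store

# Typical wiring, assuming `embedder` implements the Embeddings interface:
# retriever = _TinyVectorStore.from_texts(["a", "b"], embedder).as_retriever(
#     search_kwargs={"k": 1}
# )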

@ -0,0 +1,845 @@
"""Base implementation for tools or skills."""
from __future__ import annotations
import asyncio
import inspect
import warnings
from abc import abstractmethod
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForToolRun,
CallbackManager,
CallbackManagerForToolRun,
Callbacks,
)
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
create_model,
root_validator,
validate_arguments,
)
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
def _create_subset_model(
name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
fields[field_name] = (field.outer_type_, field.field_info)
return create_model(name, **fields) # type: ignore
def _get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
class _SchemaConfig:
"""Configuration for the pydantic model."""
extra: Any = Extra.forbid
arbitrary_types_allowed: bool = True
def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydantic schema
func: Function to generate the schema from
Returns:
A pydantic model with the same arguments as the function
"""
# https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
inferred_model = validated.model # type: ignore
if "run_manager" in inferred_model.__fields__:
del inferred_model.__fields__["run_manager"]
if "callbacks" in inferred_model.__fields__:
del inferred_model.__fields__["callbacks"]
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(valid_properties)
)
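# A small illustrative sketch of the inference above (the `multiply` function is
# an assumed example): create_schema_from_function derives a pydantic model from
# a plain signature and keeps only the real parameters.
def _example_schema_inference() -> None:
    def multiply(a: int, b: int = 2) -> int:
        """Multiply two integers."""
        return a * b

    schema = create_schema_from_function("Multiply", multiply)
    # run_manager/callbacks would have been stripped if they were present.
    assert set(schema.schema()["properties"]) == {"a", "b"}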
class ToolException(Exception):
"""An optional exception that a tool can raise when an execution error occurs.
When this exception is raised, the agent does not stop working;
instead, it handles the exception according to the tool's
handle_tool_error setting, and the handling result is returned
to the agent as an observation and printed in red on the console.
"""
pass
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
"""Interface LangChain tools must implement."""
def __init_subclass__(cls, **kwargs: Any) -> None:
"""Create the definition of the new tool class."""
super().__init_subclass__(**kwargs)
args_schema_type = cls.__annotations__.get("args_schema", None)
if args_schema_type is not None:
if args_schema_type is None or args_schema_type == BaseModel:
# Throw errors for common mis-annotations.
# TODO: Use get_args / get_origin and fully
# specify valid annotations.
typehint_mandate = """
class ChildTool(BaseTool):
...
args_schema: Type[BaseModel] = SchemaClass
..."""
name = cls.__name__
raise SchemaAnnotationError(
f"Tool definition for {name} must include valid type annotations"
f" for argument 'args_schema' to behave as expected.\n"
f"Expected annotation of 'Type[BaseModel]'"
f" but got '{args_schema_type}'.\n"
f"Expected class looks like:\n"
f"{typehint_mandate}"
)
name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Used to tell the model how/when/why to use the tool.
You can provide few-shot examples as a part of the description.
"""
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
"""Whether to return the tool's output directly. Setting this to True means
that after the tool is called, the AgentExecutor will stop looping.
"""
verbose: bool = False
"""Whether to log the tool's progress."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to be called during tool execution."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""Deprecated. Please use callbacks instead."""
tags: Optional[List[str]] = None
"""Optional list of tags associated with the tool. Defaults to None
These tags will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a tool with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the tool. Defaults to None
This metadata will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a tool with its use case.
"""
handle_tool_error: Optional[
Union[bool, str, Callable[[ToolException], str]]
] = False
"""Handle the content of the ToolException thrown."""
class Config(Serializable.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def is_single_input(self) -> bool:
"""Whether the tool only accepts a single input."""
keys = {k for k in self.args if k != "kwargs"}
return len(keys) == 1
@property
def args(self) -> dict:
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
schema = create_schema_from_function(self.name, self._run)
return schema.schema()["properties"]
# --- Runnable ---
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
"""The tool's input schema."""
if self.args_schema is not None:
return self.args_schema
else:
return create_schema_from_function(self.name, self._run)
def invoke(
self,
input: Union[str, Dict],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
config = config or {}
return self.run(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
async def ainvoke(
self,
input: Union[str, Dict],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
config = config or {}
return await self.arun(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
# --- Tool ---
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
"""Convert tool input to pydantic model."""
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
key_ = next(iter(input_args.__fields__.keys()))
input_args.validate({key_: tool_input})
return tool_input
else:
if input_args is not None:
result = input_args.parse_obj(tool_input)
return {k: v for k, v in result.dict().items() if k in tool_input}
return tool_input
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@abstractmethod
def _run(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None
to child implementations to enable tracing,
"""
async def _arun(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Use the tool asynchronously.
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
to child implementations to enable tracing,
"""
return await asyncio.get_running_loop().run_in_executor(
None,
partial(self._run, **kwargs),
*args,
)
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
else:
return (), tool_input
def run(
self,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
parsed_input = self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
verbose_,
tags,
self.tags,
metadata,
self.metadata,
)
# TODO: maybe also pass through run_manager if _run supports kwargs
new_arg_supported = signature(self._run).parameters.get("run_manager")
run_manager = callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
name=run_name,
**kwargs,
)
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
if new_arg_supported
else self._run(*tool_args, **tool_kwargs)
)
except ToolException as e:
if not self.handle_tool_error:
run_manager.on_tool_error(e)
raise e
elif isinstance(self.handle_tool_error, bool):
if e.args:
observation = e.args[0]
else:
observation = "Tool execution error"
elif isinstance(self.handle_tool_error, str):
observation = self.handle_tool_error
elif callable(self.handle_tool_error):
observation = self.handle_tool_error(e)
else:
raise ValueError(
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {self.handle_tool_error}"
)
run_manager.on_tool_end(
str(observation), color="red", name=self.name, **kwargs
)
return observation
except (Exception, KeyboardInterrupt) as e:
run_manager.on_tool_error(e)
raise e
else:
run_manager.on_tool_end(
str(observation), color=color, name=self.name, **kwargs
)
return observation
async def arun(
self,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
parsed_input = self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
verbose_,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = signature(self._arun).parameters.get("run_manager")
run_manager = await callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
name=run_name,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs)
if new_arg_supported
else await self._arun(*tool_args, **tool_kwargs)
)
except ToolException as e:
if not self.handle_tool_error:
await run_manager.on_tool_error(e)
raise e
elif isinstance(self.handle_tool_error, bool):
if e.args:
observation = e.args[0]
else:
observation = "Tool execution error"
elif isinstance(self.handle_tool_error, str):
observation = self.handle_tool_error
elif callable(self.handle_tool_error):
observation = self.handle_tool_error(e)
else:
raise ValueError(
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {self.handle_tool_error}"
)
await run_manager.on_tool_end(
str(observation), color="red", name=self.name, **kwargs
)
return observation
except (Exception, KeyboardInterrupt) as e:
await run_manager.on_tool_error(e)
raise e
else:
await run_manager.on_tool_end(
str(observation), color=color, name=self.name, **kwargs
)
return observation
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
"""Make tool callable."""
return self.run(tool_input, callbacks=callbacks)
class Tool(BaseTool):
"""Tool that takes in function or coroutine directly."""
description: str = ""
func: Optional[Callable[..., str]]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function."""
# --- Runnable ---
async def ainvoke(
self,
input: Union[str, Dict],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
if not self.coroutine:
# If the tool does not implement async, fall back to default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs)
)
return await super().ainvoke(input, config, **kwargs)
# --- Tool ---
@property
def args(self) -> dict:
"""The tool's input arguments."""
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
# For backwards compatibility, if the function signature is ambiguous,
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
"""Convert tool input to pydantic model."""
args, kwargs = super()._to_args_and_kwargs(tool_input)
# For backwards compatibility. The tool must be run with a single input
all_args = list(args) + list(kwargs.values())
if len(all_args) != 1:
raise ToolException(
f"Too many arguments to single-input tool {self.name}."
f" Args: {all_args}"
)
return tuple(all_args), {}
def _run(
self,
*args: Any,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any,
) -> Any:
"""Use the tool."""
if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync")
async def _arun(
self,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any,
) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
new_argument_supported = signature(self.coroutine).parameters.get(
"callbacks"
)
return (
await self.coroutine(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
else:
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._run, run_manager=run_manager, **kwargs), *args
)
# TODO: this is for backwards compatibility, remove in future
def __init__(
self, name: str, func: Optional[Callable], description: str, **kwargs: Any
) -> None:
"""Initialize tool."""
super(Tool, self).__init__(
name=name, func=func, description=description, **kwargs
)
@classmethod
def from_function(
cls,
func: Optional[Callable],
name: str, # We keep these required to support backwards compatibility
description: str,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
coroutine: Optional[
Callable[..., Awaitable[Any]]
] = None, # This is last for compatibility, but should be after func
**kwargs: Any,
) -> Tool:
"""Initialize tool from a function."""
if func is None and coroutine is None:
raise ValueError("Function and/or coroutine must be provided")
return cls(
name=name,
func=func,
coroutine=coroutine,
description=description,
return_direct=return_direct,
args_schema=args_schema,
**kwargs,
)
class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""
description: str = ""
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
"""The input arguments' schema."""
func: Optional[Callable[..., Any]]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
"""The asynchronous version of the function."""
# --- Runnable ---
async def ainvoke(
self,
input: Union[str, Dict],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
if not self.coroutine:
# If the tool does not implement async, fall back to default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs)
)
return await super().ainvoke(input, config, **kwargs)
# --- Tool ---
@property
def args(self) -> dict:
"""The tool's input arguments."""
return self.args_schema.schema()["properties"]
def _run(
self,
*args: Any,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any,
) -> Any:
"""Use the tool."""
if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync")
async def _arun(
self,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any,
) -> str:
"""Use the tool asynchronously."""
if self.coroutine:
new_argument_supported = signature(self.coroutine).parameters.get(
"callbacks"
)
return (
await self.coroutine(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
return await asyncio.get_running_loop().run_in_executor(
None,
partial(self._run, run_manager=run_manager, **kwargs),
*args,
)
@classmethod
def from_function(
cls,
func: Optional[Callable] = None,
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
**kwargs: Any,
) -> StructuredTool:
"""Create tool from a given function.
A classmethod that helps to create a tool from a function.
Args:
func: The function from which to create a tool
coroutine: The async function from which to create a tool
name: The name of the tool. Defaults to the function name
description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the result directly or as a callback
args_schema: The schema of the tool's input arguments
infer_schema: Whether to infer the schema from the function's signature
**kwargs: Additional arguments to pass to the tool
Returns:
The tool
Examples:
.. code-block:: python
def add(a: int, b: int) -> int:
\"\"\"Add two numbers\"\"\"
return a + b
tool = StructuredTool.from_function(add)
tool.run({"a": 1, "b": 2}) # 3
"""
if func is not None:
source_function = func
elif coroutine is not None:
source_function = coroutine
else:
raise ValueError("Function and/or coroutine must be provided")
name = name or source_function.__name__
description = description or source_function.__doc__
if description is None:
raise ValueError(
"Function must have a docstring if description not provided."
)
# Description example:
# search_api(query: str) - Searches the API for the query.
sig = signature(source_function)
description = f"{name}{sig} - {description.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{name}Schema", source_function)
return cls(
name=name,
func=func,
coroutine=coroutine,
args_schema=_args_schema,
description=description,
return_direct=return_direct,
**kwargs,
)
def tool(
*args: Union[str, Callable, Runnable],
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
) -> Callable:
"""Make tools out of functions; can be used with or without arguments.
Args:
*args: The arguments to the tool.
return_direct: Whether to return directly from the tool rather
than continuing the agent loop.
args_schema: optional argument schema for user to specify
infer_schema: Whether to infer the schema of the arguments from
the function's signature. This also makes the resultant tool
accept a dictionary input to its `run()` function.
Requires:
- Function must be of type (str) -> str
- Function must have a docstring
Examples:
.. code-block:: python
@tool
def search_api(query: str) -> str:
\"\"\"Searches the API for the query.\"\"\"
return
@tool("search", return_direct=True)
def search_api(query: str) -> str:
\"\"\"Searches the API for the query.\"\"\"
return
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
if isinstance(dec_func, Runnable):
runnable = dec_func
if runnable.input_schema.schema().get("type") != "object":
raise ValueError("Runnable must have an object schema.")
async def ainvoke_wrapper(
callbacks: Optional[Callbacks] = None, **kwargs: Any
) -> Any:
return await runnable.ainvoke(kwargs, {"callbacks": callbacks})
def invoke_wrapper(
callbacks: Optional[Callbacks] = None, **kwargs: Any
) -> Any:
return runnable.invoke(kwargs, {"callbacks": callbacks})
coroutine = ainvoke_wrapper
func = invoke_wrapper
schema: Optional[Type[BaseModel]] = runnable.input_schema
description = repr(runnable)
elif inspect.iscoroutinefunction(dec_func):
coroutine = dec_func
func = None
schema = args_schema
description = None
else:
coroutine = None
func = dec_func
schema = args_schema
description = None
if infer_schema or args_schema is not None:
return StructuredTool.from_function(
func,
coroutine,
name=tool_name,
description=description,
return_direct=return_direct,
args_schema=schema,
infer_schema=infer_schema,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
if func.__doc__ is None:
raise ValueError(
"Function must have a docstring if "
"description not provided and infer_schema is False."
)
return Tool(
name=tool_name,
func=func,
description=f"{tool_name} tool",
return_direct=return_direct,
coroutine=coroutine,
)
return _make_tool
if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable):
return _make_with_name(args[0])(args[1])
elif len(args) == 1 and isinstance(args[0], str):
# if the argument is a string, then we use the string as the tool name
# Example usage: @tool("search", return_direct=True)
return _make_with_name(args[0])
elif len(args) == 1 and callable(args[0]):
# if the argument is a function, then we use the function name as the tool name
# Example usage: @tool
return _make_with_name(args[0].__name__)(args[0])
elif len(args) == 0:
# if there are no arguments, then we use the function name as the tool name
# Example usage: @tool(return_direct=True)
def _partial(func: Callable[[str], str]) -> BaseTool:
return _make_with_name(func.__name__)(func)
return _partial
else:
raise ValueError("Too many arguments for tool decorator")
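# Illustrative usage sketch; `word_count` is an assumed example function, not
# part of the library. With no explicit args_schema, the decorator infers one
# from the signature and uses the docstring as the description.
@tool
def word_count(text: str) -> str:
    """Count the number of words in the given text."""
    return str(len(text.split()))

# The decorated object is a BaseTool and behaves like any Runnable:
# word_count.invoke({"text": "hello there"})  -> "2"
# word_count.run("hello there")               -> "2" (single string input)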

@ -0,0 +1,38 @@
"""
**Utility functions** for LangChain.
These functions do not depend on any other LangChain module.
"""
from langchain_core.utils.formatting import StrictFormatter, formatter
from langchain_core.utils.input import (
get_bolded_text,
get_color_mapping,
get_colored_text,
print_text,
)
from langchain_core.utils.utils import (
check_package_version,
convert_to_secret_str,
get_pydantic_field_names,
guard_import,
mock_now,
raise_for_status_with_text,
xor_args,
)
__all__ = [
"StrictFormatter",
"check_package_version",
"convert_to_secret_str",
"formatter",
"get_bolded_text",
"get_color_mapping",
"get_colored_text",
"get_pydantic_field_names",
"guard_import",
"mock_now",
"print_text",
"raise_for_status_with_text",
"xor_args",
]

@ -0,0 +1,209 @@
"""
Adapted from
https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py
MIT License
"""
from collections import deque
from typing import (
Any,
AsyncContextManager,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Deque,
Generic,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
overload,
)
T = TypeVar("T")
_no_default = object()
# https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54
# before 3.10, the builtin anext() was not available
def py_anext(
iterator: AsyncIterator[T], default: Union[T, Any] = _no_default
) -> Awaitable[Union[T, None, Any]]:
"""Pure-Python implementation of anext() for testing purposes.
Closely matches the builtin anext() C implementation.
Can be used to compare the built-in implementation of the inner
coroutines machinery to C-implementation of __anext__() and send()
or throw() on the returned generator.
"""
try:
__anext__ = cast(
Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
)
except AttributeError:
raise TypeError(f"{iterator!r} is not an async iterator")
if default is _no_default:
return __anext__(iterator)
async def anext_impl() -> Union[T, Any]:
try:
# The C code is way more low-level than this, as it implements
# all methods of the iterator protocol. In this implementation
# we're relying on higher-level coroutine concepts, but that's
# exactly what we want -- crosstest pure-Python high-level
# implementation and low-level C anext() iterators.
return await __anext__(iterator)
except StopAsyncIteration:
return default
return anext_impl()
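# Illustrative usage of py_anext (the helper below is an assumed example): the
# optional default suppresses StopAsyncIteration, mirroring the builtin anext().
async def _drain(iterator: AsyncIterator[int]) -> List[int]:
    items: List[int] = []
    while (item := await py_anext(iterator, default=None)) is not None:
        items.append(item)
    return items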
class NoLock:
"""Dummy lock that provides the proper interface but no protection"""
async def __aenter__(self) -> None:
pass
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
return False
async def tee_peer(
iterator: AsyncIterator[T],
# the buffer specific to this peer
buffer: Deque[T],
# the buffers of all peers, including our own
peers: List[Deque[T]],
lock: AsyncContextManager[Any],
) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`"""
try:
while True:
if not buffer:
async with lock:
# Another peer produced an item while we were waiting for the lock.
# Proceed with the next loop iteration to yield the item.
if buffer:
continue
try:
item = await iterator.__anext__()
except StopAsyncIteration:
break
else:
# Append to all buffers, including our own. We'll fetch our
# item from the buffer again, instead of yielding it directly.
# This ensures the proper item ordering if any of our peers
# are fetching items concurrently. They may have buffered their
# item already.
for peer_buffer in peers:
peer_buffer.append(item)
yield buffer.popleft()
finally:
async with lock:
# this peer is done, remove its buffer
for idx, peer_buffer in enumerate(peers): # pragma: no branch
if peer_buffer is buffer:
peers.pop(idx)
break
# if we are the last peer, try and close the iterator
if not peers and hasattr(iterator, "aclose"):
await iterator.aclose()
class Tee(Generic[T]):
"""
Create ``n`` separate asynchronous iterators over ``iterable``
This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order.
All child iterators may advance separately but share the same items
from ``iterable`` -- when the most advanced iterator retrieves an item,
it is buffered until the least advanced iterator has yielded it as well.
A ``tee`` works lazily and can handle an infinite ``iterable``, provided
that all iterators advance.
.. code-block:: python3
async def derivative(sensor_data):
previous, current = a.tee(sensor_data, n=2)
await a.anext(previous) # advance one iterator
return a.map(operator.sub, previous, current)
Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
immediately closes all children, and it can be used in an ``async with`` context
for the same effect.
If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
provide these items. Also, ``tee`` must internally buffer each item until the
last iterator has yielded it; if the most and least advanced iterator differ
by most data, using a :py:class:`list` is more efficient (but not lazy).
If the underlying iterable is concurrency safe (``anext`` may be awaited
concurrently) the resulting iterators are concurrency safe as well. Otherwise,
the iterators are safe if there is only ever one single "most advanced" iterator.
To enforce sequential use of ``anext``, provide a ``lock``
- e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
and access is automatically synchronised.
"""
def __init__(
self,
iterable: AsyncIterator[T],
n: int = 2,
*,
lock: Optional[AsyncContextManager[Any]] = None,
):
self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
self._children = tuple(
tee_peer(
iterator=self._iterator,
buffer=buffer,
peers=self._buffers,
lock=lock if lock is not None else NoLock(),
)
for buffer in self._buffers
)
def __len__(self) -> int:
return len(self._children)
@overload
def __getitem__(self, item: int) -> AsyncIterator[T]:
...
@overload
def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]:
...
def __getitem__(
self, item: Union[int, slice]
) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]:
return self._children[item]
def __iter__(self) -> Iterator[AsyncIterator[T]]:
yield from self._children
async def __aenter__(self) -> "Tee[T]":
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
await self.aclose()
return False
async def aclose(self) -> None:
for child in self._children:
await child.aclose()
atee = Tee
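# A minimal usage sketch (assumed example): splitting one async generator into
# two independent iterators that each see every item exactly once.
async def _numbers() -> AsyncGenerator[int, None]:
    for i in range(3):
        yield i

async def _demo_tee() -> None:
    first, second = Tee(_numbers(), n=2)
    assert [x async for x in first] == [0, 1, 2]
    assert [x async for x in second] == [0, 1, 2]
# Run with: asyncio.run(_demo_tee())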

@ -0,0 +1,38 @@
"""Utilities for formatting strings."""
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union
class StrictFormatter(Formatter):
"""A subclass of formatter that checks for extra keys."""
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Check to see if extra parameters are passed."""
extra = set(kwargs).difference(used_args)
if extra:
raise KeyError(extra)
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:
"""Check that no arguments are provided."""
if len(args) > 0:
raise ValueError(
"No arguments should be provided, "
"everything should be passed as keyword arguments."
)
return super().vformat(format_string, args, kwargs)
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
formatter = StrictFormatter()
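# Brief illustration (assumed examples) of how StrictFormatter differs from
# string.Formatter: extra keyword arguments and positional arguments both fail
# loudly instead of being silently ignored.
def _formatter_examples() -> None:
    assert formatter.format("{name} v{v}", name="langchain-core", v="0.0.1") == (
        "langchain-core v0.0.1"
    )
    try:
        formatter.format("{name}", name="x", extra="y")
    except KeyError:
        pass  # unused kwargs raise KeyError({'extra'})
    try:
        formatter.format("{0}", "positional")
    except ValueError:
        pass  # positional arguments are rejected outright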

@ -0,0 +1,42 @@
"""Handle chained inputs."""
from typing import Dict, List, Optional, TextIO
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def get_bolded_text(text: str) -> str:
"""Get bolded text."""
return f"\033[1m{text}\033[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush() # ensure all printed content is written to the file
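# A small usage sketch (assumed chain names): assign each item a distinct color
# and print highlighted, bolded output with the helpers above.
def _color_demo() -> None:
    mapping = get_color_mapping(["llm", "retriever", "parser"], excluded_colors=["red"])
    for name, color in mapping.items():
        print_text(get_bolded_text(name) + ": ", end="")
        print_text("ok", color=color, end="\n")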

@ -0,0 +1,175 @@
from collections import deque
from itertools import islice
from typing import (
Any,
ContextManager,
Deque,
Generator,
Generic,
Iterable,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
overload,
)
from typing_extensions import Literal
T = TypeVar("T")
class NoLock:
"""Dummy lock that provides the proper interface but no protection"""
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
return False
def tee_peer(
iterator: Iterator[T],
# the buffer specific to this peer
buffer: Deque[T],
# the buffers of all peers, including our own
peers: List[Deque[T]],
lock: ContextManager[Any],
) -> Generator[T, None, None]:
"""An individual iterator of a :py:func:`~.tee`"""
try:
while True:
if not buffer:
with lock:
# Another peer produced an item while we were waiting for the lock.
# Proceed with the next loop iteration to yield the item.
if buffer:
continue
try:
item = next(iterator)
except StopIteration:
break
else:
# Append to all buffers, including our own. We'll fetch our
# item from the buffer again, instead of yielding it directly.
# This ensures the proper item ordering if any of our peers
# are fetching items concurrently. They may have buffered their
# item already.
for peer_buffer in peers:
peer_buffer.append(item)
yield buffer.popleft()
finally:
with lock:
# this peer is done, remove its buffer
for idx, peer_buffer in enumerate(peers): # pragma: no branch
if peer_buffer is buffer:
peers.pop(idx)
break
# if we are the last peer, try and close the iterator
if not peers and hasattr(iterator, "close"):
iterator.close()
class Tee(Generic[T]):
"""
Create ``n`` separate iterators over ``iterable``
This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order.
All child iterators may advance separately but share the same items
from ``iterable`` -- when the most advanced iterator retrieves an item,
it is buffered until the least advanced iterator has yielded it as well.
A ``tee`` works lazily and can handle an infinite ``iterable``, provided
that all iterators advance.
.. code-block:: python3
def derivative(sensor_data):
previous, current = tee(sensor_data, n=2)
next(previous) # advance one iterator
return map(operator.sub, previous, current)
Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
to get the child iterators. In addition, its :py:meth:`~.tee.close` method
immediately closes all children, and it can be used in a ``with`` context
for the same effect.
If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
provide these items. Also, ``tee`` must internally buffer each item until the
last iterator has yielded it; if the most and least advanced iterators differ
by most data, using a :py:class:`list` is more efficient (but not lazy).
If the underlying iterable is thread safe (``next`` may be called
concurrently) the resulting iterators are thread safe as well. Otherwise,
the iterators are safe if there is only ever one single "most advanced" iterator.
To enforce sequential use of ``next``, provide a ``lock``
- e.g. a :py:class:`threading.Lock` instance in a multi-threaded application -
and access is automatically synchronised.
"""
def __init__(
self,
iterable: Iterator[T],
n: int = 2,
*,
lock: Optional[ContextManager[Any]] = None,
):
self._iterator = iter(iterable)
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
self._children = tuple(
tee_peer(
iterator=self._iterator,
buffer=buffer,
peers=self._buffers,
lock=lock if lock is not None else NoLock(),
)
for buffer in self._buffers
)
def __len__(self) -> int:
return len(self._children)
@overload
def __getitem__(self, item: int) -> Iterator[T]:
...
@overload
def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]:
...
def __getitem__(
self, item: Union[int, slice]
) -> Union[Iterator[T], Tuple[Iterator[T], ...]]:
return self._children[item]
def __iter__(self) -> Iterator[Iterator[T]]:
yield from self._children
def __enter__(self) -> "Tee[T]":
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
self.close()
return False
def close(self) -> None:
for child in self._children:
child.close()
# Why this is needed https://stackoverflow.com/a/44638570
safetee = Tee
def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
"""Utility batching function."""
it = iter(iterable)
while True:
chunk = list(islice(it, size))
if not chunk:
return
yield chunk
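# Illustrative usage: batch_iterate chunks any iterable into fixed-size lists,
# which is convenient for batched embedding or upsert calls.
def _batch_demo() -> None:
    assert list(batch_iterate(2, range(5))) == [[0, 1], [2, 3], [4]]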

@ -0,0 +1,54 @@
"""Utilities for loading configurations from langchain-hub."""
import os
import re
import tempfile
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Optional, Set, TypeVar, Union
from urllib.parse import urljoin
import requests
DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
URL_BASE = os.environ.get(
"LANGCHAIN_HUB_URL_BASE",
"https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
)
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")
T = TypeVar("T")
def try_load_from_hub(
path: Union[str, Path],
loader: Callable[[str], T],
valid_prefix: str,
valid_suffixes: Set[str],
**kwargs: Any,
) -> Optional[T]:
"""Load configuration from hub. Returns None if path is not a hub path."""
if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
return None
ref, remote_path_str = match.groups()
ref = ref[1:] if ref else DEFAULT_REF
remote_path = Path(remote_path_str)
if remote_path.parts[0] != valid_prefix:
return None
if remote_path.suffix[1:] not in valid_suffixes:
raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")
# Using Path with URLs is not recommended, because on Windows
# the backslash is used as the path separator, which can cause issues
# when working with URLs that use forward slashes as the path separator.
# Instead, use PurePosixPath to ensure that forward slashes are used as the
# path separator, regardless of the operating system.
full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__())
r = requests.get(full_url, timeout=5)
if r.status_code != 200:
raise ValueError(f"Could not find file at {full_url}")
with tempfile.TemporaryDirectory() as tmpdirname:
file = Path(tmpdirname) / remote_path.name
with open(file, "wb") as f:
f.write(r.content)
return loader(str(file), **kwargs)
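# A hedged sketch of a caller (the loader, prefix, and suffixes below are
# assumptions for illustration): route "lc://..." paths through the hub and
# fall back to local loading for anything else.
def _load_json_config(path: str) -> Any:
    import json

    def _json_loader(file_path: str, **kwargs: Any) -> Any:
        with open(file_path) as f:
            return json.load(f)

    result = try_load_from_hub(
        path, _json_loader, valid_prefix="prompts", valid_suffixes={"json"}
    )
    if result is None:
        # Not a hub path; treat it as a local file instead.
        result = _json_loader(path)
    return result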

@ -0,0 +1,14 @@
"""Utilities for tests."""
def get_pydantic_major_version() -> int:
"""Get the major version of Pydantic."""
try:
import pydantic
return int(pydantic.__version__.split(".")[0])
except ImportError:
return 0
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()

@ -0,0 +1,180 @@
"""Generic utility functions."""
import contextlib
import datetime
import functools
import importlib
import warnings
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from packaging.version import parse
from requests import HTTPError, Response
from langchain_core.pydantic_v1 import SecretStr
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive."""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
"""Validate exactly one arg in each group is not None."""
counts = [
sum(1 for arg in arg_group if kwargs.get(arg) is not None)
for arg_group in arg_groups
]
invalid_groups = [i for i, count in enumerate(counts) if count != 1]
if invalid_groups:
invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups]
raise ValueError(
"Exactly one argument in each of the following"
" groups must be defined:"
f" {', '.join(invalid_group_names)}"
)
return func(*args, **kwargs)
return wrapper
return decorator
def raise_for_status_with_text(response: Response) -> None:
"""Raise an error with the response text."""
try:
response.raise_for_status()
except HTTPError as e:
raise ValueError(response.text) from e
@contextlib.contextmanager
def mock_now(dt_value): # type: ignore
"""Context manager for mocking out datetime.now() in unit tests.
Example:
with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)
"""
class MockDateTime(datetime.datetime):
"""Mock datetime.datetime.now() with a fixed datetime."""
@classmethod
def now(cls): # type: ignore
# Create a copy of dt_value.
return datetime.datetime(
dt_value.year,
dt_value.month,
dt_value.day,
dt_value.hour,
dt_value.minute,
dt_value.second,
dt_value.microsecond,
dt_value.tzinfo,
)
real_datetime = datetime.datetime
datetime.datetime = MockDateTime
try:
yield datetime.datetime
finally:
datetime.datetime = real_datetime
def guard_import(
module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any:
"""Dynamically imports a module and raises a helpful exception if the module is not
installed."""
try:
module = importlib.import_module(module_name, package)
except ImportError:
raise ImportError(
f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name or module_name}`."
)
return module
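
# Illustrative sketch (not part of the diff): import optional dependencies lazily
# and fail with an install hint; the package names here are examples only.
np = guard_import("numpy")                      # error would suggest `pip install numpy`
yaml = guard_import("yaml", pip_name="pyyaml")  # import name differs from the pip name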
def check_package_version(
package: str,
lt_version: Optional[str] = None,
lte_version: Optional[str] = None,
gt_version: Optional[str] = None,
gte_version: Optional[str] = None,
) -> None:
"""Check the version of a package."""
imported_version = parse(version(package))
if lt_version is not None and imported_version >= parse(lt_version):
raise ValueError(
f"Expected {package} version to be < {lt_version}. Received "
f"{imported_version}."
)
if lte_version is not None and imported_version > parse(lte_version):
raise ValueError(
f"Expected {package} version to be <= {lte_version}. Received "
f"{imported_version}."
)
if gt_version is not None and imported_version <= parse(gt_version):
raise ValueError(
f"Expected {package} version to be > {gt_version}. Received "
f"{imported_version}."
)
if gte_version is not None and imported_version < parse(gte_version):
raise ValueError(
f"Expected {package} version to be >= {gte_version}. Received "
f"{imported_version}."
)
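
# Illustrative sketch (not part of the diff): pin an optional dependency to 2.x.
# The package name and version bounds are placeholders.
check_package_version("requests", gte_version="2.0.0", lt_version="3.0.0")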
def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
"""Get field names, including aliases, for a pydantic class.
Args:
pydantic_cls: Pydantic class."""
all_required_field_names = set()
for field in pydantic_cls.__fields__.values():
all_required_field_names.add(field.name)
if field.has_alias:
all_required_field_names.add(field.alias)
return all_required_field_names
def build_extra_kwargs(
extra_kwargs: Dict[str, Any],
values: Dict[str, Any],
all_required_field_names: Set[str],
) -> Dict[str, Any]:
"""Build extra kwargs from values and extra_kwargs.
Args:
extra_kwargs: Extra kwargs passed in by user.
values: Values passed in by user.
all_required_field_names: All required field names for the pydantic class.
"""
for field_name in list(values):
if field_name in extra_kwargs:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
            warnings.warn(
                f"""WARNING! {field_name} is not a default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
            )
extra_kwargs[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys())
if invalid_model_kwargs:
        raise ValueError(
            f"Parameters {invalid_model_kwargs} should be specified explicitly. "
            f"Instead they were passed in as part of the `model_kwargs` parameter."
        )
return extra_kwargs
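
# Illustrative sketch (not part of the diff): get_pydantic_field_names and
# build_extra_kwargs are commonly combined in a pre root_validator that sweeps
# unrecognized constructor kwargs into a `model_kwargs` field. The class below is
# a made-up example; langchain_core.pydantic_v1 is assumed to re-export the
# pydantic v1 API (BaseModel, Field, root_validator).
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator

class _FakeModel(BaseModel):
    temperature: float = 0.7
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    @root_validator(pre=True)
    def _build_model_kwargs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        extra = values.get("model_kwargs", {})
        field_names = get_pydantic_field_names(cls)
        values["model_kwargs"] = build_extra_kwargs(extra, values, field_names)
        return values

_model = _FakeModel(temperature=0.2, top_p=0.9)  # warns and moves top_p
assert _model.model_kwargs == {"top_p": 0.9}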
def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
if isinstance(value, SecretStr):
return value
return SecretStr(value)

libs/core/poetry.lock (generated, 2689 lines): file diff suppressed because it is too large.

@ -0,0 +1,85 @@
[tool.poetry]
name = "langchain-core"
version = "0.0.1"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
pydantic = ">=1,<3"
langsmith = "~0.0.63"
tenacity = "^8.1.0"
jsonpatch = "^1.33"
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
types-pyyaml = "^6.0.12.2"
types-requests = "^2.28.11.5"
[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
setuptools = "^67.6.1"
[tool.poetry.group.test.dependencies]
# The only dependencies that should be added are
# dependencies used for running tests (e.g., pytest, freezegun, responses).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
[tool.poetry.group.test_integration]
optional = true
dependencies = {}
[tool.ruff]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
]
[tool.mypy]
ignore_missing_imports = "True"
disallow_untyped_defs = "True"
exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]
[tool.coverage.run]
omit = [
"tests/*",
]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config: any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused prints a warning on unused snapshots rather than failing the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
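
For illustration (not part of the diff): a sketch of how a test file might use the markers registered above; test and package names are placeholders. With asyncio_mode = "auto", plain `async def` tests are also collected without an explicit asyncio marker.

import pytest

@pytest.mark.requires("some_optional_package")  # placeholder package name
def test_needs_optional_dependency() -> None:
    ...

@pytest.mark.compile
def test_placeholder() -> None:
    """Collected by `pytest -m compile` to verify integration tests import cleanly."""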

@ -3,8 +3,8 @@ from typing import Any, Dict
import pytest
from langchain._api.deprecation import deprecated, warn_deprecated
from langchain.pydantic_v1 import BaseModel
from langchain_core._api.deprecation import deprecated, warn_deprecated
from langchain_core.pydantic_v1 import BaseModel
@pytest.mark.parametrize(

Some files were not shown because too many files have changed in this diff.
