core[patch]: add beta decorator (#15589)

This commit is contained in:
Bagatur 2024-01-05 13:16:27 -05:00 committed by GitHub
parent b484d941ae
commit e1fc4d5b95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 496 additions and 0 deletions

View File

@ -0,0 +1,258 @@
"""Helper functions for marking parts of the LangChain API as beta.
This module was loosely adapted from matplotlibs _api/deprecation.py module:
https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecation.py
.. warning::
This module is for internal use only. Do not use it in your own code.
We may change the API at any time with no warning.
"""
import contextlib
import functools
import inspect
import warnings
from typing import Any, Callable, Generator, Type, TypeVar
class LangChainBetaWarning(DeprecationWarning):
"""A class for issuing beta warnings for LangChain users."""
# PUBLIC API
T = TypeVar("T", Type, Callable)
def beta(
*,
message: str = "",
name: str = "",
obj_type: str = "",
addendum: str = "",
) -> Callable[[T], T]:
"""Decorator to mark a function, a class, or a property as beta.
When marking a classmethod, a staticmethod, or a property, the
``@beta`` decorator should go *under* ``@classmethod`` and
``@staticmethod`` (i.e., `beta` should directly decorate the
underlying callable), but *over* ``@property``.
When marking a class ``C`` intended to be used as a base class in a
multiple inheritance hierarchy, ``C`` *must* define an ``__init__`` method
(if ``C`` instead inherited its ``__init__`` from its own base class, then
``@beta`` would mess up ``__init__`` inheritance when installing its
own (annotation-emitting) ``C.__init__``).
Arguments:
message : str, optional
Override the default beta message. The %(since)s,
%(name)s, %(alternative)s, %(obj_type)s, %(addendum)s,
and %(removal)s format specifiers will be replaced by the
values of the respective arguments passed to this function.
name : str, optional
The name of the beta object.
obj_type : str, optional
The object type being beta.
addendum : str, optional
Additional text appended directly to the final message.
Examples
--------
.. code-block:: python
@beta
def the_function_to_annotate():
pass
"""
def beta(
obj: T,
*,
_obj_type: str = obj_type,
_name: str = name,
_message: str = message,
_addendum: str = addendum,
) -> T:
"""Implementation of the decorator returned by `beta`."""
if isinstance(obj, type):
if not _obj_type:
_obj_type = "class"
wrapped = obj.__init__ # type: ignore
_name = _name or obj.__name__
old_doc = obj.__doc__
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
"""Finalize the annotation of a class."""
try:
obj.__doc__ = new_doc
except AttributeError: # Can't set on some extension objects.
pass
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
wrapper
)
return obj
elif isinstance(obj, property):
if not _obj_type:
_obj_type = "attribute"
wrapped = None
_name = _name or obj.fget.__name__
old_doc = obj.__doc__
class _beta_property(type(obj)): # type: ignore
"""A beta property."""
def __get__(self, instance, owner=None): # type: ignore
if instance is not None or owner is not None:
emit_warning()
return super().__get__(instance, owner)
def __set__(self, instance, value): # type: ignore
if instance is not None:
emit_warning()
return super().__set__(instance, value)
def __delete__(self, instance): # type: ignore
if instance is not None:
emit_warning()
return super().__delete__(instance)
def __set_name__(self, owner, set_name): # type: ignore
nonlocal _name
if _name == "<lambda>":
_name = set_name
def finalize(_: Any, new_doc: str) -> Any: # type: ignore
"""Finalize the property."""
return _beta_property(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
)
else:
if not _obj_type:
_obj_type = "function"
wrapped = obj
_name = _name or obj.__name__ # type: ignore
old_doc = wrapped.__doc__
def finalize( # type: ignore
wrapper: Callable[..., Any], new_doc: str
) -> T:
"""Wrap the wrapped function using the wrapper and update the docstring.
Args:
wrapper: The wrapper function.
new_doc: The new docstring.
Returns:
The wrapped function.
"""
wrapper = functools.wraps(wrapped)(wrapper)
wrapper.__doc__ = new_doc
return wrapper
def emit_warning() -> None:
"""Emit the warning."""
warn_beta(
message=_message,
name=_name,
obj_type=_obj_type,
addendum=_addendum,
)
def warning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Wrapper for the original wrapped callable that emits a warning.
Args:
*args: The positional arguments to the function.
**kwargs: The keyword arguments to the function.
Returns:
The return value of the function being wrapped.
"""
emit_warning()
return wrapped(*args, **kwargs)
old_doc = inspect.cleandoc(old_doc or "").strip("\n")
if not old_doc:
new_doc = "[*Beta*]"
else:
new_doc = f"[*Beta*] {old_doc}"
# Modify the docstring to include a beta notice.
notes_header = "\nNotes\n-----"
components = [
message,
addendum,
]
details = " ".join([component.strip() for component in components if component])
new_doc += (
f"[*Beta*] {old_doc}\n"
f"{notes_header if notes_header not in old_doc else ''}\n"
f".. beta::\n"
f" {details}"
)
return finalize(warning_emitting_wrapper, new_doc)
return beta
@contextlib.contextmanager
def suppress_langchain_beta_warning() -> Generator[None, None, None]:
"""Context manager to suppress LangChainDeprecationWarning."""
with warnings.catch_warnings():
warnings.simplefilter("ignore", LangChainBetaWarning)
yield
def warn_beta(
*,
message: str = "",
name: str = "",
obj_type: str = "",
addendum: str = "",
) -> None:
"""Display a standardized beta annotation.
Arguments:
message : str, optional
Override the default beta message. The
%(name)s, %(obj_type)s, %(addendum)s
format specifiers will be replaced by the
values of the respective arguments passed to this function.
name : str, optional
The name of the annotated object.
obj_type : str, optional
The object type being annotated.
addendum : str, optional
Additional text appended directly to the final message.
"""
if not message:
message = ""
if obj_type:
message += f"The {obj_type} `{name}`"
else:
message += f"`{name}`"
message += " is in beta. It is actively being worked on, so the API may change."
if addendum:
message += f" {addendum}"
warning = LangChainBetaWarning(message)
warnings.warn(warning, category=LangChainBetaWarning, stacklevel=2)
def surface_langchain_beta_warnings() -> None:
"""Unmute LangChain beta warnings."""
warnings.filterwarnings(
"default",
category=LangChainBetaWarning,
)

View File

@ -0,0 +1,238 @@
import warnings
from typing import Any, Dict
import pytest
from langchain_core._api.beta import beta, warn_beta
from langchain_core.pydantic_v1 import BaseModel
@pytest.mark.parametrize(
"kwargs, expected_message",
[
(
{
"name": "OldClass",
"obj_type": "class",
},
"The class `OldClass` is in beta. It is actively being worked on, so the "
"API may change.",
),
(
{
"message": "This is a custom message",
"name": "FunctionA",
"obj_type": "",
"addendum": "",
},
"This is a custom message",
),
(
{
"message": "",
"name": "SomeFunction",
"obj_type": "",
"addendum": "Please migrate your code.",
},
"`SomeFunction` is in beta. It is actively being worked on, so the API may "
"change. Please migrate your code.",
),
],
)
def test_warn_beta(kwargs: Dict[str, Any], expected_message: str) -> None:
"""Test warn beta."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
warn_beta(**kwargs)
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == expected_message
@beta()
def beta_function() -> str:
"""original doc"""
return "This is a beta function."
class ClassWithBetaMethods:
def __init__(self) -> None:
"""original doc"""
pass
@beta()
def beta_method(self) -> str:
"""original doc"""
return "This is a beta method."
@classmethod
@beta()
def beta_classmethod(cls) -> str:
"""original doc"""
return "This is a beta classmethod."
@staticmethod
@beta()
def beta_staticmethod() -> str:
"""original doc"""
return "This is a beta staticmethod."
@property
@beta()
def beta_property(self) -> str:
"""original doc"""
return "This is a beta property."
def test_beta_function() -> None:
"""Test beta function."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
assert beta_function() == "This is a beta function."
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `beta_function` is in beta. It is actively being worked on, "
"so the API may change."
)
doc = beta_function.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Beta*] original doc")
def test_beta_method() -> None:
"""Test beta method."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
obj = ClassWithBetaMethods()
assert obj.beta_method() == "This is a beta method."
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `beta_method` is in beta. It is actively being worked on, so "
"the API may change."
)
doc = obj.beta_method.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Beta*] original doc")
def test_beta_classmethod() -> None:
"""Test beta classmethod."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
ClassWithBetaMethods.beta_classmethod()
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `beta_classmethod` is in beta. It is actively being worked "
"on, so the API may change."
)
doc = ClassWithBetaMethods.beta_classmethod.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Beta*] original doc")
def test_beta_staticmethod() -> None:
"""Test beta staticmethod."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
assert (
ClassWithBetaMethods.beta_staticmethod() == "This is a beta staticmethod."
)
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `beta_staticmethod` is in beta. It is actively being worked "
"on, so the API may change."
)
doc = ClassWithBetaMethods.beta_staticmethod.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Beta*] original doc")
def test_beta_property() -> None:
"""Test beta staticmethod."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
obj = ClassWithBetaMethods()
assert obj.beta_property == "This is a beta property."
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `beta_property` is in beta. It is actively being worked on, "
"so the API may change."
)
doc = ClassWithBetaMethods.beta_property.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Beta*] original doc")
def test_whole_class_deprecation() -> None:
"""Test whole class deprecation."""
# Test whole class deprecation
@beta()
class BetaClass:
def __init__(self) -> None:
"""original doc"""
pass
@beta()
def beta_method(self) -> str:
"""original doc"""
return "This is a beta method."
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
obj = BetaClass()
assert obj.beta_method() == "This is a beta method."
assert len(warning_list) == 2
warning = warning_list[0].message
assert str(warning) == (
"The class `BetaClass` is in beta. It is actively being worked on, so the "
"API may change."
)
warning = warning_list[1].message
assert str(warning) == (
"The function `beta_method` is in beta. It is actively being worked on, so "
"the API may change."
)
# Tests with pydantic models
class MyModel(BaseModel):
@beta()
def beta_method(self) -> str:
"""original doc"""
return "This is a beta method."
def test_beta_method_pydantic() -> None:
"""Test beta method."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
obj = MyModel()
assert obj.beta_method() == "This is a beta method."
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `beta_method` is in beta. It is actively being worked on, so "
"the API may change."
)
doc = obj.beta_method.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Beta*] original doc")