core[patch]: preserve inspect.iscoroutinefunction with @deprecated decorator (#16295)

Adjusted `deprecate` decorator to make sure decorated async functions
are still recognized as "coroutinefunction" by `inspect`.

Before change, functions such as `LLMChain.acall` which are decorated as
deprecated are not recognized as coroutine functions. After the change,
they are recognized:

```python
import inspect
from langchain import LLMChain

# Is false before change but true after.
inspect.iscoroutinefunction(LLMChain.acall)
```
This commit is contained in:
Piotr Mardziel 2024-01-22 11:34:13 -08:00 committed by GitHub
parent 01c2f27ffa
commit 1b9001db47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 74 additions and 1 deletions

View File

@ -144,6 +144,15 @@ def deprecated(
emit_warning()
return wrapped(*args, **kwargs)
async def awarning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Same as warning_emitting_wrapper, but for async functions."""
nonlocal warned
if not warned and not is_caller_internal():
warned = True
emit_warning()
return await wrapped(*args, **kwargs)
if isinstance(obj, type):
if not _obj_type:
_obj_type = "class"
@ -256,7 +265,10 @@ def deprecated(
f" {details}"
)
return finalize(warning_emitting_wrapper, new_doc)
if inspect.iscoroutinefunction(obj):
return finalize(awarning_emitting_wrapper, new_doc)
else:
return finalize(warning_emitting_wrapper, new_doc)
return deprecate

View File

@ -1,3 +1,4 @@
import inspect
import warnings
from typing import Any, Dict
@ -74,6 +75,12 @@ def deprecated_function() -> str:
return "This is a deprecated function."
@deprecated(since="2.0.0", removal="3.0.0", pending=False)
async def deprecated_async_function() -> str:
"""original doc"""
return "This is a deprecated async function."
class ClassWithDeprecatedMethods:
def __init__(self) -> None:
"""original doc"""
@ -84,6 +91,11 @@ class ClassWithDeprecatedMethods:
"""original doc"""
return "This is a deprecated method."
@deprecated(since="2.0.0", removal="3.0.0")
async def deprecated_async_method(self) -> str:
"""original doc"""
return "This is a deprecated async method."
@classmethod
@deprecated(since="2.0.0", removal="3.0.0")
def deprecated_classmethod(cls) -> str:
@ -119,6 +131,30 @@ def test_deprecated_function() -> None:
assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc")
assert not inspect.iscoroutinefunction(deprecated_function)
@pytest.mark.asyncio
async def test_deprecated_async_function() -> None:
"""Test deprecated async function."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
assert (
await deprecated_async_function() == "This is a deprecated async function."
)
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `deprecated_async_function` was deprecated "
"in LangChain 2.0.0 and will be removed in 3.0.0"
)
doc = deprecated_function.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc")
assert inspect.iscoroutinefunction(deprecated_async_function)
def test_deprecated_method() -> None:
"""Test deprecated method."""
@ -137,6 +173,31 @@ def test_deprecated_method() -> None:
assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc")
assert not inspect.iscoroutinefunction(obj.deprecated_method)
@pytest.mark.asyncio
async def test_deprecated_async_method() -> None:
"""Test deprecated async method."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
obj = ClassWithDeprecatedMethods()
assert (
await obj.deprecated_async_method() == "This is a deprecated async method."
)
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `deprecated_async_method` was deprecated in "
"LangChain 2.0.0 and will be removed in 3.0.0"
)
doc = obj.deprecated_method.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc")
assert inspect.iscoroutinefunction(obj.deprecated_async_method)
def test_deprecated_classmethod() -> None:
"""Test deprecated classmethod."""