mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
core[patch]: preserve inspect.iscoroutinefunction with @deprecated decorator (#16295)
Adjusted `deprecate` decorator to make sure decorated async functions are still recognized as "coroutinefunction" by `inspect`. Before change, functions such as `LLMChain.acall` which are decorated as deprecated are not recognized as coroutine functions. After the change, they are recognized: ```python import inspect from langchain import LLMChain # Is false before change but true after. inspect.iscoroutinefunction(LLMChain.acall) ```
This commit is contained in:
parent
01c2f27ffa
commit
1b9001db47
@ -144,6 +144,15 @@ def deprecated(
|
||||
emit_warning()
|
||||
return wrapped(*args, **kwargs)
|
||||
|
||||
async def awarning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Same as warning_emitting_wrapper, but for async functions."""
|
||||
|
||||
nonlocal warned
|
||||
if not warned and not is_caller_internal():
|
||||
warned = True
|
||||
emit_warning()
|
||||
return await wrapped(*args, **kwargs)
|
||||
|
||||
if isinstance(obj, type):
|
||||
if not _obj_type:
|
||||
_obj_type = "class"
|
||||
@ -256,7 +265,10 @@ def deprecated(
|
||||
f" {details}"
|
||||
)
|
||||
|
||||
return finalize(warning_emitting_wrapper, new_doc)
|
||||
if inspect.iscoroutinefunction(obj):
|
||||
return finalize(awarning_emitting_wrapper, new_doc)
|
||||
else:
|
||||
return finalize(warning_emitting_wrapper, new_doc)
|
||||
|
||||
return deprecate
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Dict
|
||||
|
||||
@ -74,6 +75,12 @@ def deprecated_function() -> str:
|
||||
return "This is a deprecated function."
|
||||
|
||||
|
||||
@deprecated(since="2.0.0", removal="3.0.0", pending=False)
|
||||
async def deprecated_async_function() -> str:
|
||||
"""original doc"""
|
||||
return "This is a deprecated async function."
|
||||
|
||||
|
||||
class ClassWithDeprecatedMethods:
|
||||
def __init__(self) -> None:
|
||||
"""original doc"""
|
||||
@ -84,6 +91,11 @@ class ClassWithDeprecatedMethods:
|
||||
"""original doc"""
|
||||
return "This is a deprecated method."
|
||||
|
||||
@deprecated(since="2.0.0", removal="3.0.0")
|
||||
async def deprecated_async_method(self) -> str:
|
||||
"""original doc"""
|
||||
return "This is a deprecated async method."
|
||||
|
||||
@classmethod
|
||||
@deprecated(since="2.0.0", removal="3.0.0")
|
||||
def deprecated_classmethod(cls) -> str:
|
||||
@ -119,6 +131,30 @@ def test_deprecated_function() -> None:
|
||||
assert isinstance(doc, str)
|
||||
assert doc.startswith("[*Deprecated*] original doc")
|
||||
|
||||
assert not inspect.iscoroutinefunction(deprecated_function)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprecated_async_function() -> None:
|
||||
"""Test deprecated async function."""
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
warnings.simplefilter("always")
|
||||
assert (
|
||||
await deprecated_async_function() == "This is a deprecated async function."
|
||||
)
|
||||
assert len(warning_list) == 1
|
||||
warning = warning_list[0].message
|
||||
assert str(warning) == (
|
||||
"The function `deprecated_async_function` was deprecated "
|
||||
"in LangChain 2.0.0 and will be removed in 3.0.0"
|
||||
)
|
||||
|
||||
doc = deprecated_function.__doc__
|
||||
assert isinstance(doc, str)
|
||||
assert doc.startswith("[*Deprecated*] original doc")
|
||||
|
||||
assert inspect.iscoroutinefunction(deprecated_async_function)
|
||||
|
||||
|
||||
def test_deprecated_method() -> None:
|
||||
"""Test deprecated method."""
|
||||
@ -137,6 +173,31 @@ def test_deprecated_method() -> None:
|
||||
assert isinstance(doc, str)
|
||||
assert doc.startswith("[*Deprecated*] original doc")
|
||||
|
||||
assert not inspect.iscoroutinefunction(obj.deprecated_method)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprecated_async_method() -> None:
|
||||
"""Test deprecated async method."""
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
warnings.simplefilter("always")
|
||||
obj = ClassWithDeprecatedMethods()
|
||||
assert (
|
||||
await obj.deprecated_async_method() == "This is a deprecated async method."
|
||||
)
|
||||
assert len(warning_list) == 1
|
||||
warning = warning_list[0].message
|
||||
assert str(warning) == (
|
||||
"The function `deprecated_async_method` was deprecated in "
|
||||
"LangChain 2.0.0 and will be removed in 3.0.0"
|
||||
)
|
||||
|
||||
doc = obj.deprecated_method.__doc__
|
||||
assert isinstance(doc, str)
|
||||
assert doc.startswith("[*Deprecated*] original doc")
|
||||
|
||||
assert inspect.iscoroutinefunction(obj.deprecated_async_method)
|
||||
|
||||
|
||||
def test_deprecated_classmethod() -> None:
|
||||
"""Test deprecated classmethod."""
|
||||
|
Loading…
Reference in New Issue
Block a user