You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/tests/unit_tests/tools/test_signatures.py

47 lines
1.7 KiB
Python

"""Test base tool child implementations."""
import inspect
import re
from typing import List, Type
import pytest
from langchain.tools.base import BaseTool
from langchain.tools.playwright.base import BaseBrowserTool
def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
    """Collect the concrete tool classes found below ``cls``.

    The subclass tree is walked depth-first. A class is included unless it
    has a truthy ``__abstract__`` attribute, its name starts with an
    underscore, or it is listed in the explicit exclusion set. Descendants of
    an excluded class are still visited, so concrete tools that derive from
    an abstract intermediate base are returned.

    Args:
        cls: The class whose subclass tree should be searched.

    Returns:
        The matching subclasses, each parent listed before its descendants.
    """
    excluded = {BaseBrowserTool}  # Abstract but not recognized
    found: List[Type[BaseTool]] = []
    for child in cls.__subclasses__():
        is_abstract = bool(getattr(child, "__abstract__", None))
        is_private = child.__name__.startswith("_")
        if not (is_abstract or is_private or child in excluded):
            found.append(child)
        # Recurse whether or not ``child`` itself qualified.
        found += get_non_abstract_subclasses(child)
    return found
@pytest.mark.parametrize("cls", get_non_abstract_subclasses(BaseTool))  # type: ignore
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
    """Test that tools defined in this repo accept a run manager argument.

    For every concrete tool class, each of ``_run`` / ``_arun`` that the class
    overrides must declare a ``run_manager`` parameter that defaults to
    ``None`` and is annotated with the matching callback manager type: the
    synchronous manager for ``_run`` and the async manager for ``_arun``.

    Args:
        cls: The tool class under test, supplied by the parametrization.
    """
    # This wouldn't be necessary if the BaseTool had a strict API.
    # Compare against BaseTool._run (not _arun) so the guard really detects
    # whether the subclass overrides the synchronous entry point.
    if cls._run is not BaseTool._run:
        run_func = cls._run
        params = inspect.signature(run_func).parameters
        assert "run_manager" in params
        # A negative *lookbehind* is required here: a lookahead placed before
        # "CallbackManagerForToolRun" can never see a preceding "Async", so it
        # would also accept the async manager type on the sync method.
        pattern = re.compile(r"(?<!Async)CallbackManagerForToolRun")
        assert bool(re.search(pattern, str(params["run_manager"].annotation)))
        assert params["run_manager"].default is None

    if cls._arun is not BaseTool._arun:
        run_func = cls._arun
        params = inspect.signature(run_func).parameters
        assert "run_manager" in params
        assert "AsyncCallbackManagerForToolRun" in str(params["run_manager"].annotation)
        assert params["run_manager"].default is None