Pass Callbacks through load_tools (#4298)

- Update the load_tools method to properly accept `callbacks` arguments.
- Add a deprecation warning when `callback_manager` is passed
- Add two unit tests to check the deprecation warning is raised and to
confirm the callback is passed through.

Closes issue #4096
This commit is contained in:
Zander Chase 2023-05-08 08:44:26 -07:00 committed by GitHub
parent 0870a45a69
commit 35c9e6ab40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 10 deletions

View File

@ -7,6 +7,7 @@ from mypy_extensions import Arg, KwArg
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
from langchain.chains.api.base import APIChain
from langchain.chains.llm_math.base import LLMMathChain
@ -279,10 +280,26 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st
}
def _handle_callbacks(
callback_manager: Optional[BaseCallbackManager], callbacks: Callbacks
) -> Callbacks:
if callback_manager is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
if callbacks is not None:
raise ValueError(
"Cannot specify both callback_manager and callbacks arguments."
)
return callback_manager
return callbacks
def load_tools(
tool_names: List[str],
llm: Optional[BaseLanguageModel] = None,
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> List[BaseTool]:
"""Load tools based on their name.
@ -290,13 +307,16 @@ def load_tools(
Args:
tool_names: name of tools to load.
llm: Optional language model, may be needed to initialize certain tools.
callback_manager: Optional callback manager. If not provided, default global callback manager will be used.
callbacks: Optional callback manager or list of callback handlers.
If not provided, default global callback manager will be used.
Returns:
List of tools.
"""
tools = []
callbacks = _handle_callbacks(
callback_manager=kwargs.get("callback_manager"), callbacks=callbacks
)
for name in tool_names:
if name == "requests":
warnings.warn(
@ -316,8 +336,6 @@ def load_tools(
if llm is None:
raise ValueError(f"Tool {name} requires an LLM to be provided")
tool = _LLM_TOOLS[name](llm)
if callback_manager is not None:
tool.callback_manager = callback_manager
tools.append(tool)
elif name in _EXTRA_LLM_TOOLS:
if llm is None:
@ -331,18 +349,17 @@ def load_tools(
)
sub_kwargs = {k: kwargs[k] for k in extra_keys}
tool = _get_llm_tool_func(llm=llm, **sub_kwargs)
if callback_manager is not None:
tool.callback_manager = callback_manager
tools.append(tool)
elif name in _EXTRA_OPTIONAL_TOOLS:
_get_tool_func, extra_keys = _EXTRA_OPTIONAL_TOOLS[name]
sub_kwargs = {k: kwargs[k] for k in extra_keys if k in kwargs}
tool = _get_tool_func(**sub_kwargs)
if callback_manager is not None:
tool.callback_manager = callback_manager
tools.append(tool)
else:
raise ValueError(f"Got unknown tool {name}")
if callbacks is not None:
for tool in tools:
tool.callbacks = callbacks
return tools

View File

@ -1,9 +1,11 @@
"""Test tool utils."""
import unittest
from typing import Any, Type
from unittest.mock import MagicMock
from unittest.mock import MagicMock, Mock
import pytest
from langchain.agents import load_tools
from langchain.agents.agent import Agent
from langchain.agents.chat.base import ChatAgent
from langchain.agents.conversational.base import ConversationalAgent
@ -12,6 +14,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool, tool
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@pytest.mark.parametrize(
@ -62,3 +65,28 @@ def test_tool_no_args_specified_assumes_str() -> None:
assert some_tool.run({"tool_input": "foobar"}) == "foobar"
with pytest.raises(ValueError, match="Too many arguments to single-input tool"):
some_tool.run({"tool_input": "foobar", "other_input": "bar"})
def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None:
"""Test load_tools raises a deprecation for old callback manager kwarg."""
callback_manager = MagicMock()
with pytest.warns(DeprecationWarning, match="callback_manager is deprecated"):
tools = load_tools(["requests_get"], callback_manager=callback_manager)
assert len(tools) == 1
assert tools[0].callbacks == callback_manager
def test_load_tools_with_callbacks_is_called() -> None:
"""Test callbacks are called when provided to load_tools fn."""
callbacks = [FakeCallbackHandler()]
tools = load_tools(["requests_get"], callbacks=callbacks) # type: ignore
assert len(tools) == 1
# Patch the requests.get() method to return a mock response
with unittest.mock.patch(
"langchain.requests.TextRequestsWrapper.get",
return_value=Mock(text="Hello world!"),
):
result = tools[0].run("https://www.google.com")
assert result.text == "Hello world!"
assert callbacks[0].tool_starts == 1
assert callbacks[0].tool_ends == 1