Pass Callbacks through load_tools (#4298)

- Update the load_tools method to properly accept `callbacks` arguments.
- Add a deprecation warning when `callback_manager` is passed
- Add two unit tests to check the deprecation warning is raised and to
confirm the callback is passed through.

Closes issue #4096
parallel_dir_loader
Zander Chase 1 year ago committed by GitHub
parent 0870a45a69
commit 35c9e6ab40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,6 +7,7 @@ from mypy_extensions import Arg, KwArg
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
from langchain.chains.api.base import APIChain from langchain.chains.api.base import APIChain
from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.llm_math.base import LLMMathChain
@ -279,10 +280,26 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st
} }
def _handle_callbacks(
callback_manager: Optional[BaseCallbackManager], callbacks: Callbacks
) -> Callbacks:
if callback_manager is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
if callbacks is not None:
raise ValueError(
"Cannot specify both callback_manager and callbacks arguments."
)
return callback_manager
return callbacks
def load_tools( def load_tools(
tool_names: List[str], tool_names: List[str],
llm: Optional[BaseLanguageModel] = None, llm: Optional[BaseLanguageModel] = None,
callback_manager: Optional[BaseCallbackManager] = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> List[BaseTool]: ) -> List[BaseTool]:
"""Load tools based on their name. """Load tools based on their name.
@ -290,13 +307,16 @@ def load_tools(
Args: Args:
tool_names: name of tools to load. tool_names: name of tools to load.
llm: Optional language model, may be needed to initialize certain tools. llm: Optional language model, may be needed to initialize certain tools.
callback_manager: Optional callback manager. If not provided, default global callback manager will be used. callbacks: Optional callback manager or list of callback handlers.
If not provided, default global callback manager will be used.
Returns: Returns:
List of tools. List of tools.
""" """
tools = [] tools = []
callbacks = _handle_callbacks(
callback_manager=kwargs.get("callback_manager"), callbacks=callbacks
)
for name in tool_names: for name in tool_names:
if name == "requests": if name == "requests":
warnings.warn( warnings.warn(
@ -316,8 +336,6 @@ def load_tools(
if llm is None: if llm is None:
raise ValueError(f"Tool {name} requires an LLM to be provided") raise ValueError(f"Tool {name} requires an LLM to be provided")
tool = _LLM_TOOLS[name](llm) tool = _LLM_TOOLS[name](llm)
if callback_manager is not None:
tool.callback_manager = callback_manager
tools.append(tool) tools.append(tool)
elif name in _EXTRA_LLM_TOOLS: elif name in _EXTRA_LLM_TOOLS:
if llm is None: if llm is None:
@ -331,18 +349,17 @@ def load_tools(
) )
sub_kwargs = {k: kwargs[k] for k in extra_keys} sub_kwargs = {k: kwargs[k] for k in extra_keys}
tool = _get_llm_tool_func(llm=llm, **sub_kwargs) tool = _get_llm_tool_func(llm=llm, **sub_kwargs)
if callback_manager is not None:
tool.callback_manager = callback_manager
tools.append(tool) tools.append(tool)
elif name in _EXTRA_OPTIONAL_TOOLS: elif name in _EXTRA_OPTIONAL_TOOLS:
_get_tool_func, extra_keys = _EXTRA_OPTIONAL_TOOLS[name] _get_tool_func, extra_keys = _EXTRA_OPTIONAL_TOOLS[name]
sub_kwargs = {k: kwargs[k] for k in extra_keys if k in kwargs} sub_kwargs = {k: kwargs[k] for k in extra_keys if k in kwargs}
tool = _get_tool_func(**sub_kwargs) tool = _get_tool_func(**sub_kwargs)
if callback_manager is not None:
tool.callback_manager = callback_manager
tools.append(tool) tools.append(tool)
else: else:
raise ValueError(f"Got unknown tool {name}") raise ValueError(f"Got unknown tool {name}")
if callbacks is not None:
for tool in tools:
tool.callbacks = callbacks
return tools return tools

@ -1,9 +1,11 @@
"""Test tool utils.""" """Test tool utils."""
import unittest
from typing import Any, Type from typing import Any, Type
from unittest.mock import MagicMock from unittest.mock import MagicMock, Mock
import pytest import pytest
from langchain.agents import load_tools
from langchain.agents.agent import Agent from langchain.agents.agent import Agent
from langchain.agents.chat.base import ChatAgent from langchain.agents.chat.base import ChatAgent
from langchain.agents.conversational.base import ConversationalAgent from langchain.agents.conversational.base import ConversationalAgent
@ -12,6 +14,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool, tool from langchain.agents.tools import Tool, tool
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -62,3 +65,28 @@ def test_tool_no_args_specified_assumes_str() -> None:
assert some_tool.run({"tool_input": "foobar"}) == "foobar" assert some_tool.run({"tool_input": "foobar"}) == "foobar"
with pytest.raises(ValueError, match="Too many arguments to single-input tool"): with pytest.raises(ValueError, match="Too many arguments to single-input tool"):
some_tool.run({"tool_input": "foobar", "other_input": "bar"}) some_tool.run({"tool_input": "foobar", "other_input": "bar"})
def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None:
"""Test load_tools raises a deprecation for old callback manager kwarg."""
callback_manager = MagicMock()
with pytest.warns(DeprecationWarning, match="callback_manager is deprecated"):
tools = load_tools(["requests_get"], callback_manager=callback_manager)
assert len(tools) == 1
assert tools[0].callbacks == callback_manager
def test_load_tools_with_callbacks_is_called() -> None:
"""Test callbacks are called when provided to load_tools fn."""
callbacks = [FakeCallbackHandler()]
tools = load_tools(["requests_get"], callbacks=callbacks) # type: ignore
assert len(tools) == 1
# Patch the requests.get() method to return a mock response
with unittest.mock.patch(
"langchain.requests.TextRequestsWrapper.get",
return_value=Mock(text="Hello world!"),
):
result = tools[0].run("https://www.google.com")
assert result.text == "Hello world!"
assert callbacks[0].tool_starts == 1
assert callbacks[0].tool_ends == 1

Loading…
Cancel
Save