mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Pass Callbacks through load_tools (#4298)
- Update the load_tools method to properly accept `callbacks` arguments. - Add a deprecation warning when `callback_manager` is passed - Add two unit tests to check the deprecation warning is raised and to confirm the callback is passed through. Closes issue #4096
This commit is contained in:
parent
0870a45a69
commit
35c9e6ab40
@ -7,6 +7,7 @@ from mypy_extensions import Arg, KwArg
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
|
||||
from langchain.chains.api.base import APIChain
|
||||
from langchain.chains.llm_math.base import LLMMathChain
|
||||
@ -279,10 +280,26 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st
|
||||
}
|
||||
|
||||
|
||||
def _handle_callbacks(
|
||||
callback_manager: Optional[BaseCallbackManager], callbacks: Callbacks
|
||||
) -> Callbacks:
|
||||
if callback_manager is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if callbacks is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both callback_manager and callbacks arguments."
|
||||
)
|
||||
return callback_manager
|
||||
return callbacks
|
||||
|
||||
|
||||
def load_tools(
|
||||
tool_names: List[str],
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> List[BaseTool]:
|
||||
"""Load tools based on their name.
|
||||
@ -290,13 +307,16 @@ def load_tools(
|
||||
Args:
|
||||
tool_names: name of tools to load.
|
||||
llm: Optional language model, may be needed to initialize certain tools.
|
||||
callback_manager: Optional callback manager. If not provided, default global callback manager will be used.
|
||||
callbacks: Optional callback manager or list of callback handlers.
|
||||
If not provided, default global callback manager will be used.
|
||||
|
||||
Returns:
|
||||
List of tools.
|
||||
"""
|
||||
tools = []
|
||||
|
||||
callbacks = _handle_callbacks(
|
||||
callback_manager=kwargs.get("callback_manager"), callbacks=callbacks
|
||||
)
|
||||
for name in tool_names:
|
||||
if name == "requests":
|
||||
warnings.warn(
|
||||
@ -316,8 +336,6 @@ def load_tools(
|
||||
if llm is None:
|
||||
raise ValueError(f"Tool {name} requires an LLM to be provided")
|
||||
tool = _LLM_TOOLS[name](llm)
|
||||
if callback_manager is not None:
|
||||
tool.callback_manager = callback_manager
|
||||
tools.append(tool)
|
||||
elif name in _EXTRA_LLM_TOOLS:
|
||||
if llm is None:
|
||||
@ -331,18 +349,17 @@ def load_tools(
|
||||
)
|
||||
sub_kwargs = {k: kwargs[k] for k in extra_keys}
|
||||
tool = _get_llm_tool_func(llm=llm, **sub_kwargs)
|
||||
if callback_manager is not None:
|
||||
tool.callback_manager = callback_manager
|
||||
tools.append(tool)
|
||||
elif name in _EXTRA_OPTIONAL_TOOLS:
|
||||
_get_tool_func, extra_keys = _EXTRA_OPTIONAL_TOOLS[name]
|
||||
sub_kwargs = {k: kwargs[k] for k in extra_keys if k in kwargs}
|
||||
tool = _get_tool_func(**sub_kwargs)
|
||||
if callback_manager is not None:
|
||||
tool.callback_manager = callback_manager
|
||||
tools.append(tool)
|
||||
else:
|
||||
raise ValueError(f"Got unknown tool {name}")
|
||||
if callbacks is not None:
|
||||
for tool in tools:
|
||||
tool.callbacks = callbacks
|
||||
return tools
|
||||
|
||||
|
||||
|
@ -1,9 +1,11 @@
|
||||
"""Test tool utils."""
|
||||
import unittest
|
||||
from typing import Any, Type
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.agents import load_tools
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.agents.chat.base import ChatAgent
|
||||
from langchain.agents.conversational.base import ConversationalAgent
|
||||
@ -12,6 +14,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||
from langchain.agents.tools import Tool, tool
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -62,3 +65,28 @@ def test_tool_no_args_specified_assumes_str() -> None:
|
||||
assert some_tool.run({"tool_input": "foobar"}) == "foobar"
|
||||
with pytest.raises(ValueError, match="Too many arguments to single-input tool"):
|
||||
some_tool.run({"tool_input": "foobar", "other_input": "bar"})
|
||||
|
||||
|
||||
def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None:
|
||||
"""Test load_tools raises a deprecation for old callback manager kwarg."""
|
||||
callback_manager = MagicMock()
|
||||
with pytest.warns(DeprecationWarning, match="callback_manager is deprecated"):
|
||||
tools = load_tools(["requests_get"], callback_manager=callback_manager)
|
||||
assert len(tools) == 1
|
||||
assert tools[0].callbacks == callback_manager
|
||||
|
||||
|
||||
def test_load_tools_with_callbacks_is_called() -> None:
|
||||
"""Test callbacks are called when provided to load_tools fn."""
|
||||
callbacks = [FakeCallbackHandler()]
|
||||
tools = load_tools(["requests_get"], callbacks=callbacks) # type: ignore
|
||||
assert len(tools) == 1
|
||||
# Patch the requests.get() method to return a mock response
|
||||
with unittest.mock.patch(
|
||||
"langchain.requests.TextRequestsWrapper.get",
|
||||
return_value=Mock(text="Hello world!"),
|
||||
):
|
||||
result = tools[0].run("https://www.google.com")
|
||||
assert result.text == "Hello world!"
|
||||
assert callbacks[0].tool_starts == 1
|
||||
assert callbacks[0].tool_ends == 1
|
||||
|
Loading…
Reference in New Issue
Block a user