core[patch]: Tool accept RunnableConfig (#24143)

Relies on #24038

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Bagatur 2024-07-11 15:13:17 -07:00 committed by GitHub
parent 5fd1e67808
commit 8d100c58de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 117 additions and 9 deletions

View File

@ -20,6 +20,7 @@ tool for the job.
from __future__ import annotations
import asyncio
import functools
import inspect
import json
import textwrap
@ -548,6 +549,9 @@ class ChildTool(BaseTool):
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_raw_output":
if not isinstance(response, tuple) or len(response) != 2:
@ -627,10 +631,14 @@ class ChildTool(BaseTool):
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
if self.__class__._arun is BaseTool._arun or signature(
self._arun
).parameters.get("run_manager"):
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs)
if accepts_context(asyncio.create_task):
response = await asyncio.create_task(coro, context=context) # type: ignore
@ -724,6 +732,7 @@ class Tool(BaseTool):
def _run(
self,
*args: Any,
config: RunnableConfig,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any,
) -> Any:
@ -731,12 +740,15 @@ class Tool(BaseTool):
if self.func:
if run_manager and signature(self.func).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.func):
kwargs[config_param] = config
return self.func(*args, **kwargs)
raise NotImplementedError("Tool does not support sync invocation.")
async def _arun(
self,
*args: Any,
config: RunnableConfig,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any,
) -> Any:
@ -744,11 +756,15 @@ class Tool(BaseTool):
if self.coroutine:
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.coroutine):
kwargs[config_param] = config
return await self.coroutine(*args, **kwargs)
# NOTE: this code is unreachable since _arun is only called if coroutine is not
# None.
return await super()._arun(*args, run_manager=run_manager, **kwargs)
return await super()._arun(
*args, config=config, run_manager=run_manager, **kwargs
)
# TODO: this is for backwards compatibility, remove in future
def __init__(
@ -822,6 +838,7 @@ class StructuredTool(BaseTool):
def _run(
self,
*args: Any,
config: RunnableConfig,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any,
) -> Any:
@ -829,12 +846,15 @@ class StructuredTool(BaseTool):
if self.func:
if run_manager and signature(self.func).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.func):
kwargs[config_param] = config
return self.func(*args, **kwargs)
raise NotImplementedError("StructuredTool does not support sync invocation.")
async def _arun(
self,
*args: Any,
config: RunnableConfig,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any,
) -> Any:
@ -842,11 +862,15 @@ class StructuredTool(BaseTool):
if self.coroutine:
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.coroutine):
kwargs[config_param] = config
return await self.coroutine(*args, **kwargs)
# NOTE: this code is unreachable since _arun is only called if coroutine is not
# None.
return await super()._arun(*args, run_manager=run_manager, **kwargs)
return await super()._arun(
*args, config=config, run_manager=run_manager, **kwargs
)
@classmethod
def from_function(
@ -923,12 +947,21 @@ class StructuredTool(BaseTool):
description_ = f"{description_.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
if config_param := _get_runnable_config_param(source_function):
filter_args: Tuple[str, ...] = (
config_param,
"run_manager",
"callbacks",
)
else:
filter_args = ("run_manager", "callbacks")
# schema name is appended within function
_args_schema = create_schema_from_function(
name,
source_function,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
filter_args=filter_args,
)
return cls(
name=name,
@ -1112,7 +1145,7 @@ def tool(
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
if func.__doc__ is None:
if dec_func.__doc__ is None:
raise ValueError(
"Function must have a docstring if "
"description not provided and infer_schema is False."
@ -1447,3 +1480,17 @@ def convert_runnable_to_tool(
description=description,
args_schema=args_schema,
)
def _get_runnable_config_param(func: Callable) -> Optional[str]:
if isinstance(func, functools.partial):
func = func.func
try:
type_hints = get_type_hints(func)
except Exception:
return None
else:
for name, type_ in type_hints.items():
if type_ is RunnableConfig:
return name
return None

View File

@ -1,6 +1,5 @@
"""Test the base tool implementation."""
import asyncio
import inspect
import json
import sys
@ -19,7 +18,12 @@ from langchain_core.callbacks import (
)
from langchain_core.messages import ToolMessage
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableLambda,
ensure_config,
)
from langchain_core.tools import (
BaseTool,
SchemaAnnotationError,
@ -914,7 +918,6 @@ async def test_async_tool_pass_context() -> None:
@tool
async def foo(bar: str) -> str:
"""The foo."""
await asyncio.sleep(0.0001)
config = ensure_config()
assert config["configurable"]["foo"] == "not-bar"
assert bar == "baz"
@ -925,6 +928,64 @@ async def test_async_tool_pass_context() -> None:
)
def assert_bar(bar: Any, bar_config: RunnableConfig) -> Any:
assert bar_config["configurable"]["foo"] == "not-bar"
assert bar == "baz"
return bar
@tool
def foo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
@tool
async def afoo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
@tool(infer_schema=False)
def simple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
@tool(infer_schema=False)
async def asimple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
class FooBase(BaseTool):
name: str = "Foo"
description: str = "Foo"
def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
return assert_bar(bar, bar_config)
class AFooBase(FooBase):
async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
return assert_bar(bar, bar_config)
@pytest.mark.parametrize("tool", [foo, simple_foo, FooBase(), AFooBase()])
def test_tool_pass_config(tool: BaseTool) -> None:
assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"
@pytest.mark.parametrize(
"tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]
)
async def test_async_tool_pass_config(tool: BaseTool) -> None:
assert (
await tool.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}})
== "baz"
)
def test_tool_description() -> None:
def foo(bar: str) -> str:
"""The foo."""