mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core[patch]: Tool accept RunnableConfig (#24143)
Relies on #24038 --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
5fd1e67808
commit
8d100c58de
@ -20,6 +20,7 @@ tool for the job.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
@ -548,6 +549,9 @@ class ChildTool(BaseTool):
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||
if signature(self._run).parameters.get("run_manager"):
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
|
||||
if config_param := _get_runnable_config_param(self._run):
|
||||
tool_kwargs[config_param] = config
|
||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||
if self.response_format == "content_and_raw_output":
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
@ -627,10 +631,14 @@ class ChildTool(BaseTool):
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
if self.__class__._arun is BaseTool._arun or signature(
|
||||
self._arun
|
||||
).parameters.get("run_manager"):
|
||||
func_to_check = (
|
||||
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
||||
)
|
||||
if signature(func_to_check).parameters.get("run_manager"):
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
if config_param := _get_runnable_config_param(func_to_check):
|
||||
tool_kwargs[config_param] = config
|
||||
|
||||
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
||||
if accepts_context(asyncio.create_task):
|
||||
response = await asyncio.create_task(coro, context=context) # type: ignore
|
||||
@ -724,6 +732,7 @@ class Tool(BaseTool):
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
config: RunnableConfig,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -731,12 +740,15 @@ class Tool(BaseTool):
|
||||
if self.func:
|
||||
if run_manager and signature(self.func).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
if config_param := _get_runnable_config_param(self.func):
|
||||
kwargs[config_param] = config
|
||||
return self.func(*args, **kwargs)
|
||||
raise NotImplementedError("Tool does not support sync invocation.")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
config: RunnableConfig,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -744,11 +756,15 @@ class Tool(BaseTool):
|
||||
if self.coroutine:
|
||||
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
if config_param := _get_runnable_config_param(self.coroutine):
|
||||
kwargs[config_param] = config
|
||||
return await self.coroutine(*args, **kwargs)
|
||||
|
||||
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
||||
# None.
|
||||
return await super()._arun(*args, run_manager=run_manager, **kwargs)
|
||||
return await super()._arun(
|
||||
*args, config=config, run_manager=run_manager, **kwargs
|
||||
)
|
||||
|
||||
# TODO: this is for backwards compatibility, remove in future
|
||||
def __init__(
|
||||
@ -822,6 +838,7 @@ class StructuredTool(BaseTool):
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
config: RunnableConfig,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -829,12 +846,15 @@ class StructuredTool(BaseTool):
|
||||
if self.func:
|
||||
if run_manager and signature(self.func).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
if config_param := _get_runnable_config_param(self.func):
|
||||
kwargs[config_param] = config
|
||||
return self.func(*args, **kwargs)
|
||||
raise NotImplementedError("StructuredTool does not support sync invocation.")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
config: RunnableConfig,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -842,11 +862,15 @@ class StructuredTool(BaseTool):
|
||||
if self.coroutine:
|
||||
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
if config_param := _get_runnable_config_param(self.coroutine):
|
||||
kwargs[config_param] = config
|
||||
return await self.coroutine(*args, **kwargs)
|
||||
|
||||
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
||||
# None.
|
||||
return await super()._arun(*args, run_manager=run_manager, **kwargs)
|
||||
return await super()._arun(
|
||||
*args, config=config, run_manager=run_manager, **kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
@ -923,12 +947,21 @@ class StructuredTool(BaseTool):
|
||||
description_ = f"{description_.strip()}"
|
||||
_args_schema = args_schema
|
||||
if _args_schema is None and infer_schema:
|
||||
if config_param := _get_runnable_config_param(source_function):
|
||||
filter_args: Tuple[str, ...] = (
|
||||
config_param,
|
||||
"run_manager",
|
||||
"callbacks",
|
||||
)
|
||||
else:
|
||||
filter_args = ("run_manager", "callbacks")
|
||||
# schema name is appended within function
|
||||
_args_schema = create_schema_from_function(
|
||||
name,
|
||||
source_function,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
filter_args=filter_args,
|
||||
)
|
||||
return cls(
|
||||
name=name,
|
||||
@ -1112,7 +1145,7 @@ def tool(
|
||||
)
|
||||
# If someone doesn't want a schema applied, we must treat it as
|
||||
# a simple string->string function
|
||||
if func.__doc__ is None:
|
||||
if dec_func.__doc__ is None:
|
||||
raise ValueError(
|
||||
"Function must have a docstring if "
|
||||
"description not provided and infer_schema is False."
|
||||
@ -1447,3 +1480,17 @@ def convert_runnable_to_tool(
|
||||
description=description,
|
||||
args_schema=args_schema,
|
||||
)
|
||||
|
||||
|
||||
def _get_runnable_config_param(func: Callable) -> Optional[str]:
|
||||
if isinstance(func, functools.partial):
|
||||
func = func.func
|
||||
try:
|
||||
type_hints = get_type_hints(func)
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
for name, type_ in type_hints.items():
|
||||
if type_ is RunnableConfig:
|
||||
return name
|
||||
return None
|
||||
|
@ -1,6 +1,5 @@
|
||||
"""Test the base tool implementation."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import sys
|
||||
@ -19,7 +18,12 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
ensure_config,
|
||||
)
|
||||
from langchain_core.tools import (
|
||||
BaseTool,
|
||||
SchemaAnnotationError,
|
||||
@ -914,7 +918,6 @@ async def test_async_tool_pass_context() -> None:
|
||||
@tool
|
||||
async def foo(bar: str) -> str:
|
||||
"""The foo."""
|
||||
await asyncio.sleep(0.0001)
|
||||
config = ensure_config()
|
||||
assert config["configurable"]["foo"] == "not-bar"
|
||||
assert bar == "baz"
|
||||
@ -925,6 +928,64 @@ async def test_async_tool_pass_context() -> None:
|
||||
)
|
||||
|
||||
|
||||
def assert_bar(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||
assert bar_config["configurable"]["foo"] == "not-bar"
|
||||
assert bar == "baz"
|
||||
return bar
|
||||
|
||||
|
||||
@tool
|
||||
def foo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||
"""The foo."""
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
@tool
|
||||
async def afoo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||
"""The foo."""
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def simple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||
"""The foo."""
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
@tool(infer_schema=False)
|
||||
async def asimple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||
"""The foo."""
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
class FooBase(BaseTool):
|
||||
name: str = "Foo"
|
||||
description: str = "Foo"
|
||||
|
||||
def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
class AFooBase(FooBase):
|
||||
async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool", [foo, simple_foo, FooBase(), AFooBase()])
|
||||
def test_tool_pass_config(tool: BaseTool) -> None:
|
||||
assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]
|
||||
)
|
||||
async def test_async_tool_pass_config(tool: BaseTool) -> None:
|
||||
assert (
|
||||
await tool.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}})
|
||||
== "baz"
|
||||
)
|
||||
|
||||
|
||||
def test_tool_description() -> None:
|
||||
def foo(bar: str) -> str:
|
||||
"""The foo."""
|
||||
|
Loading…
Reference in New Issue
Block a user