From 6eb6c45c981dd8d04dfeb7ac6becdb0f6b863728 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 28 Sep 2023 15:40:22 +0100 Subject: [PATCH 1/4] Enable creating Tools from any Runnable --- libs/langchain/langchain/tools/base.py | 22 ++++++++++++++---- .../schema/runnable/test_runnable.py | 23 +++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 2310927ac2..160071be3e 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -734,7 +734,7 @@ class StructuredTool(BaseTool): def tool( - *args: Union[str, Callable], + *args: Union[str, Callable, Runnable], return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, infer_schema: bool = True, @@ -769,21 +769,31 @@ def tool( """ def _make_with_name(tool_name: str) -> Callable: - def _make_tool(dec_func: Callable) -> BaseTool: - if inspect.iscoroutinefunction(dec_func): + def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: + if isinstance(dec_func, Runnable): + coroutine = dec_func.ainvoke + func = dec_func.invoke + schema = dec_func.input_schema + description = repr(dec_func) + elif inspect.iscoroutinefunction(dec_func): coroutine = dec_func func = None + schema = args_schema + description = None else: coroutine = None func = dec_func + schema = args_schema + description = None if infer_schema or args_schema is not None: return StructuredTool.from_function( func, coroutine, name=tool_name, + description=description, return_direct=return_direct, - args_schema=args_schema, + args_schema=schema, infer_schema=infer_schema, ) # If someone doesn't want a schema applied, we must treat it as @@ -803,7 +813,9 @@ def tool( return _make_tool - if len(args) == 1 and isinstance(args[0], str): + if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): + return _make_with_name(args[0])(args[1]) + elif len(args) == 1 and isinstance(args[0], str): # if the argument is a string, then we use the string as the tool name # Example usage: @tool("search", return_direct=True) return _make_with_name(args[0]) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index a9103d3725..d46bd8df8e 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -2,6 +2,7 @@ import sys from operator import itemgetter from typing import Any, Dict, List, Optional, Sequence, Union, cast from uuid import UUID +from langchain.tools.base import BaseTool, tool import pytest from freezegun import freeze_time @@ -2779,3 +2780,25 @@ def test_representation_of_runnables() -> None: " b: RunnableLambda(...)\n" " }" ), "repr where code string contains multiple lambdas gives up" + + +def test_tool_from_runnable() -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + llm = FakeStreamingListLLM(responses=["foo-lish"]) + + chain = prompt | llm | StrOutputParser() + + chain_tool = tool("chain_tool", chain) + + assert isinstance(chain_tool, BaseTool) + assert chain_tool.name == "chain_tool" + assert chain_tool.description.endswith(repr(chain)) + assert chain_tool.args_schema.schema() == chain.input_schema.schema() + assert chain_tool.args_schema.schema() == { + "properties": {"question": {"title": "Question"}}, + "title": "PromptInput", + "type": "object", + } From 8be598f5041c147752f0f28f21221fa8a039fc43 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 28 Sep 2023 15:47:59 +0100 Subject: [PATCH 2/4] Fix invocation --- libs/langchain/langchain/tools/base.py | 17 +++++++++++++++-- .../unit_tests/schema/runnable/test_runnable.py | 9 ++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 160071be3e..b56c1b0882 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -771,8 +771,21 @@ def tool( def _make_with_name(tool_name: str) -> Callable: def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: if isinstance(dec_func, Runnable): - coroutine = dec_func.ainvoke - func = dec_func.invoke + if dec_func.input_schema.schema().get("type") != "object": + raise ValueError("Runnable must have an object schema.") + + async def ainvoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return await dec_func.ainvoke(kwargs, {"callbacks": callbacks}) + + def invoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return dec_func.invoke(kwargs, {"callbacks": callbacks}) + + coroutine = ainvoke_wrapper + func = invoke_wrapper schema = dec_func.input_schema description = repr(dec_func) elif inspect.iscoroutinefunction(dec_func): diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index d46bd8df8e..e51f55d72f 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -2782,7 +2782,8 @@ def test_representation_of_runnables() -> None: ), "repr where code string contains multiple lambdas gives up" -def test_tool_from_runnable() -> None: +@pytest.mark.asyncio +async def test_tool_from_runnable() -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") + "{question}" @@ -2795,6 +2796,12 @@ def test_tool_from_runnable() -> None: assert isinstance(chain_tool, BaseTool) assert chain_tool.name == "chain_tool" + assert chain_tool.run({"question": "What up"}) == chain.invoke( + {"question": "What up"} + ) + assert await chain_tool.arun({"question": "What up"}) == await chain.ainvoke( + {"question": "What up"} + ) assert chain_tool.description.endswith(repr(chain)) assert chain_tool.args_schema.schema() == chain.input_schema.schema() assert chain_tool.args_schema.schema() == { From 7f589ebbc2189f5b50c0e444d79908ed3f718b66 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 28 Sep 2023 15:49:04 +0100 Subject: [PATCH 3/4] Lint --- .../langchain/tests/unit_tests/schema/runnable/test_runnable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index e51f55d72f..b72144102a 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -2,7 +2,6 @@ import sys from operator import itemgetter from typing import Any, Dict, List, Optional, Sequence, Union, cast from uuid import UUID -from langchain.tools.base import BaseTool, tool import pytest from freezegun import freeze_time @@ -47,6 +46,7 @@ from langchain.schema.runnable import ( RunnableSequence, RunnableWithFallbacks, ) +from langchain.tools.base import BaseTool, tool from langchain.tools.json.tool import JsonListKeysTool, JsonSpec From e35ea565d154efc43219ce6392ddf6f838255c11 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 12:00:56 +0100 Subject: [PATCH 4/4] Lint --- libs/langchain/langchain/tools/base.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index b56c1b0882..269e2b4846 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -771,23 +771,25 @@ def tool( def _make_with_name(tool_name: str) -> Callable: def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: if isinstance(dec_func, Runnable): - if dec_func.input_schema.schema().get("type") != "object": + runnable = dec_func + + if runnable.input_schema.schema().get("type") != "object": raise ValueError("Runnable must have an object schema.") async def ainvoke_wrapper( callbacks: Optional[Callbacks] = None, **kwargs: Any ) -> Any: - return await dec_func.ainvoke(kwargs, {"callbacks": callbacks}) + return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) def invoke_wrapper( callbacks: Optional[Callbacks] = None, **kwargs: Any ) -> Any: - return dec_func.invoke(kwargs, {"callbacks": callbacks}) + return runnable.invoke(kwargs, {"callbacks": callbacks}) coroutine = ainvoke_wrapper func = invoke_wrapper - schema = dec_func.input_schema - description = repr(dec_func) + schema: Optional[Type[BaseModel]] = runnable.input_schema + description = repr(runnable) elif inspect.iscoroutinefunction(dec_func): coroutine = dec_func func = None