diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 2310927ac2..269e2b4846 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -734,7 +734,7 @@ class StructuredTool(BaseTool): def tool( - *args: Union[str, Callable], + *args: Union[str, Callable, Runnable], return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, infer_schema: bool = True, @@ -769,21 +769,46 @@ def tool( """ def _make_with_name(tool_name: str) -> Callable: - def _make_tool(dec_func: Callable) -> BaseTool: - if inspect.iscoroutinefunction(dec_func): + def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: + if isinstance(dec_func, Runnable): + runnable = dec_func + + if runnable.input_schema.schema().get("type") != "object": + raise ValueError("Runnable must have an object schema.") + + async def ainvoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) + + def invoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return runnable.invoke(kwargs, {"callbacks": callbacks}) + + coroutine = ainvoke_wrapper + func = invoke_wrapper + schema: Optional[Type[BaseModel]] = runnable.input_schema + description = repr(runnable) + elif inspect.iscoroutinefunction(dec_func): coroutine = dec_func func = None + schema = args_schema + description = None else: coroutine = None func = dec_func + schema = args_schema + description = None if infer_schema or args_schema is not None: return StructuredTool.from_function( func, coroutine, name=tool_name, + description=description, return_direct=return_direct, - args_schema=args_schema, + args_schema=schema, infer_schema=infer_schema, ) # If someone doesn't want a schema applied, we must treat it as @@ -803,7 +828,9 @@ def tool( return _make_tool - if len(args) == 1 and isinstance(args[0], str): + if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): + return _make_with_name(args[0])(args[1]) + elif len(args) == 1 and isinstance(args[0], str): # if the argument is a string, then we use the string as the tool name # Example usage: @tool("search", return_direct=True) return _make_with_name(args[0]) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index a9103d3725..b72144102a 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -46,6 +46,7 @@ from langchain.schema.runnable import ( RunnableSequence, RunnableWithFallbacks, ) +from langchain.tools.base import BaseTool, tool from langchain.tools.json.tool import JsonListKeysTool, JsonSpec @@ -2779,3 +2780,32 @@ def test_representation_of_runnables() -> None: " b: RunnableLambda(...)\n" " }" ), "repr where code string contains multiple lambdas gives up" + + +@pytest.mark.asyncio +async def test_tool_from_runnable() -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + llm = FakeStreamingListLLM(responses=["foo-lish"]) + + chain = prompt | llm | StrOutputParser() + + chain_tool = tool("chain_tool", chain) + + assert isinstance(chain_tool, BaseTool) + assert chain_tool.name == "chain_tool" + assert chain_tool.run({"question": "What up"}) == chain.invoke( + {"question": "What up"} + ) + assert await chain_tool.arun({"question": "What up"}) == await chain.ainvoke( + {"question": "What up"} + ) + assert chain_tool.description.endswith(repr(chain)) + assert chain_tool.args_schema.schema() == chain.input_schema.schema() + assert chain_tool.args_schema.schema() == { + "properties": {"question": {"title": "Question"}}, + "title": "PromptInput", + "type": "object", + }