Enable creating Tools from any Runnable

pull/11177/head
Nuno Campos 1 year ago
parent 61b5942adf
commit 6eb6c45c98

@ -734,7 +734,7 @@ class StructuredTool(BaseTool):
def tool(
*args: Union[str, Callable],
*args: Union[str, Callable, Runnable],
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
@ -769,21 +769,31 @@ def tool(
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(dec_func: Callable) -> BaseTool:
if inspect.iscoroutinefunction(dec_func):
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
if isinstance(dec_func, Runnable):
coroutine = dec_func.ainvoke
func = dec_func.invoke
schema = dec_func.input_schema
description = repr(dec_func)
elif inspect.iscoroutinefunction(dec_func):
coroutine = dec_func
func = None
schema = args_schema
description = None
else:
coroutine = None
func = dec_func
schema = args_schema
description = None
if infer_schema or args_schema is not None:
return StructuredTool.from_function(
func,
coroutine,
name=tool_name,
description=description,
return_direct=return_direct,
args_schema=args_schema,
args_schema=schema,
infer_schema=infer_schema,
)
# If someone doesn't want a schema applied, we must treat it as
@ -803,7 +813,9 @@ def tool(
return _make_tool
if len(args) == 1 and isinstance(args[0], str):
if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable):
return _make_with_name(args[0])(args[1])
elif len(args) == 1 and isinstance(args[0], str):
# if the argument is a string, then we use the string as the tool name
# Example usage: @tool("search", return_direct=True)
return _make_with_name(args[0])

@ -2,6 +2,7 @@ import sys
from operator import itemgetter
from typing import Any, Dict, List, Optional, Sequence, Union, cast
from uuid import UUID
from langchain.tools.base import BaseTool, tool
import pytest
from freezegun import freeze_time
@ -2779,3 +2780,25 @@ def test_representation_of_runnables() -> None:
" b: RunnableLambda(...)\n"
" }"
), "repr where code string contains multiple lambdas gives up"
def test_tool_from_runnable() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = prompt | llm | StrOutputParser()
chain_tool = tool("chain_tool", chain)
assert isinstance(chain_tool, BaseTool)
assert chain_tool.name == "chain_tool"
assert chain_tool.description.endswith(repr(chain))
assert chain_tool.args_schema.schema() == chain.input_schema.schema()
assert chain_tool.args_schema.schema() == {
"properties": {"question": {"title": "Question"}},
"title": "PromptInput",
"type": "object",
}

Loading…
Cancel
Save