From 01ab2918a29d22d2676ed5c286f1b44ccaaba631 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Sun, 28 Jul 2024 15:45:19 -0700 Subject: [PATCH] core[patch]: Respect injected in bound fns (#24733) Since right now you cant use the nice injected arg syntas directly with model.bind_tools() --- libs/core/langchain_core/tools.py | 13 ++++++-- .../langchain_core/utils/function_calling.py | 1 + libs/core/tests/unit_tests/test_tools.py | 30 +++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 633faa7b34..b84e6d5510 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -120,6 +120,7 @@ def _get_filtered_args( func: Callable, *, filter_args: Sequence[str], + include_injected: bool = True, ) -> dict: """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] @@ -127,7 +128,9 @@ def _get_filtered_args( return { k: schema[k] for i, (k, param) in enumerate(valid_keys.items()) - if k not in filter_args and (i > 0 or param.name not in ("self", "cls")) + if k not in filter_args + and (i > 0 or param.name not in ("self", "cls")) + and (include_injected or not _is_injected_arg_type(param.annotation)) } @@ -247,6 +250,7 @@ def create_schema_from_function( filter_args: Optional[Sequence[str]] = None, parse_docstring: bool = False, error_on_invalid_docstring: bool = False, + include_injected: bool = True, ) -> Type[BaseModel]: """Create a pydantic schema from a function's signature. @@ -260,6 +264,9 @@ def create_schema_from_function( error_on_invalid_docstring: if ``parse_docstring`` is provided, configure whether to raise ValueError on invalid Google Style docstrings. Defaults to False. + include_injected: Whether to include injected arguments in the schema. + Defaults to True, since we want to include them in the schema + when *validating* tool inputs. Returns: A pydantic model with the same arguments as the function. @@ -277,7 +284,9 @@ def create_schema_from_function( error_on_invalid_docstring=error_on_invalid_docstring, ) # Pydantic adds placeholder virtual fields we need to strip - valid_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args) + valid_properties = _get_filtered_args( + inferred_model, func, filter_args=filter_args, include_injected=include_injected + ) return _create_subset_model( f"{model_name}Schema", inferred_model, diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index b2f93bc2c3..10a65b2099 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -179,6 +179,7 @@ def convert_python_function_to_openai_function( filter_args=(), parse_docstring=True, error_on_invalid_docstring=False, + include_injected=False, ) return convert_pydantic_to_openai_function( model, diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 489feeb7dc..ca2a04f9dc 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1429,6 +1429,36 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None: } +def _get_parametrized_tools() -> list: + def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str: + """my_tool.""" + return some_tool + + async def my_async_tool( + x: int, y: str, *, some_tool: Annotated[Any, InjectedToolArg] + ) -> str: + """my_tool.""" + return some_tool + + return [my_tool, my_async_tool] + + +@pytest.mark.parametrize("tool_", _get_parametrized_tools()) +def test_fn_injected_arg_with_schema(tool_: Callable) -> None: + assert convert_to_openai_function(tool_) == { + "name": tool_.__name__, + "description": "my_tool.", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer"}, + "y": {"type": "string"}, + }, + "required": ["x", "y"], + }, + } + + def generate_models() -> List[Any]: """Generate a list of base models depending on the pydantic version.""" from pydantic import BaseModel as BaseModelProper # pydantic: ignore