diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 38119ceb70..d286a3c794 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -21,13 +21,25 @@ from __future__ import annotations import asyncio import inspect +import typing import uuid import warnings from abc import ABC, abstractmethod from contextvars import copy_context from functools import partial from inspect import signature -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Mapping, + Optional, + Tuple, + Type, + Union, +) from langchain_core._api import deprecated from langchain_core.callbacks import ( @@ -76,7 +88,10 @@ class SchemaAnnotationError(TypeError): def _create_subset_model( - name: str, model: Type[BaseModel], field_names: list + name: str, + model: Type[BaseModel], + field_names: list, + descriptions: Optional[Mapping[str, str]] = None, ) -> Type[BaseModel]: """Create a pydantic model with only a subset of model's fields.""" fields = {} @@ -88,6 +103,10 @@ def _create_subset_model( if field.required and not field.allow_none else Optional[field.outer_type_] ) + # Inject the description into the field_info + description = descriptions.get(field_name) if descriptions else None + if description: + field.field_info.description = description fields[field_name] = (t, field.field_info) rtn = create_model(name, **fields) # type: ignore return rtn @@ -103,6 +122,24 @@ def _get_filtered_args( return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} +def _get_description_from_annotation(ann: Any) -> Optional[str]: + possible_descriptions = [ + arg for arg in typing.get_args(ann) if isinstance(arg, str) + ] + return "\n".join(possible_descriptions) if possible_descriptions else None + + +def _get_descriptions(func: Callable) -> Dict[str, str]: + """Get the descriptions from a function's signature.""" + descriptions = {} + for param in inspect.signature(func).parameters.values(): + if param.annotation is not inspect.Parameter.empty: + description = _get_description_from_annotation(param.annotation) + if description: + descriptions[param.name] = description + return descriptions + + class _SchemaConfig: """Configuration for the pydantic model.""" @@ -128,10 +165,16 @@ def create_schema_from_function( del inferred_model.__fields__["run_manager"] if "callbacks" in inferred_model.__fields__: del inferred_model.__fields__["callbacks"] + breakpoint() # Pydantic adds placeholder virtual fields we need to strip valid_properties = _get_filtered_args(inferred_model, func) + # TODO: we could pass through additional metadata here + descriptions = _get_descriptions(func) return _create_subset_model( - f"{model_name}Schema", inferred_model, list(valid_properties) + f"{model_name}Schema", + inferred_model, + list(valid_properties), + descriptions=descriptions, ) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index a424d6ecf6..3e8fae0c3e 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -9,6 +9,7 @@ from functools import partial from typing import Any, Callable, Dict, List, Optional, Type, Union import pytest +from typing_extensions import Annotated from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, @@ -23,6 +24,7 @@ from langchain_core.tools import ( Tool, ToolException, _create_subset_model, + create_schema_from_function, tool, ) from tests.unit_tests.fake.callbacks import FakeCallbackHandler @@ -53,7 +55,12 @@ class _MockStructuredTool(BaseTool): args_schema: Type[BaseModel] = _MockSchema description: str = "A Structured Tool" - def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + def _run( + self, + arg1: int, + arg2: bool, + arg3: Optional[dict] = None, + ) -> str: return f"{arg1} {arg2} {arg3}" async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: @@ -70,6 +77,33 @@ def test_structured_args() -> None: assert structured_api.run(args) == expected_result +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or above") +def test_structured_args_description() -> None: + class _AnnotatedTool(BaseTool): + name: str = "structured_api" + description: str = "A Structured Tool" + + def _run( + self, + arg1: int, + arg2: Annotated[bool, "V important"], + arg3: Optional[dict] = None, + ) -> str: + return f"{arg1} {arg2} {arg3}" + + async def _arun( + self, arg1: int, arg2: bool, arg3: Optional[dict] = None + ) -> str: + raise NotImplementedError + + expected = { + "arg1": {"title": "Arg1", "type": "integer"}, + "arg2": {"title": "Arg2", "type": "boolean", "description": "V important"}, + "arg3": {"title": "Arg3", "type": "object"}, + } + assert _AnnotatedTool().args == expected + + def test_misannotated_base_tool_raises_error() -> None: """Test that a BaseTool with the incorrect typehint raises an exception.""" "" with pytest.raises(SchemaAnnotationError): @@ -876,6 +910,73 @@ def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> No foo.invoke(inputs) # type: ignore +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or above") +def test_create_schema_from_function_with_descriptions() -> None: + def foo(bar: int, baz: str) -> str: + """Docstring + Args: + bar: int + baz: str + """ + raise NotImplementedError() + + schema = create_schema_from_function("foo", foo) + assert schema.schema() == { + "title": "fooSchema", + "type": "object", + "properties": { + "bar": {"title": "Bar", "type": "integer"}, + "baz": {"title": "Baz", "type": "string"}, + }, + "required": ["bar", "baz"], + } + + def foo_annotated( + bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"], + ) -> str: + """Docstring + Args: + bar: int + """ + raise bar + + schema = create_schema_from_function("foo_annotated", foo_annotated) + assert schema.schema() == { + "title": "foo_annotatedSchema", + "type": "object", + "properties": { + "bar": { + "title": "Bar", + "type": "integer", + "description": "This is bar\nit's useful", + }, + }, + "required": ["bar"], + } + + +def test_annotated_tool_typing() -> None: + @tool + def foo(bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"]) -> str: + """The foo.""" + return str(bar) + + assert foo.invoke({"bar": 5}) == "5" # type: ignore + with pytest.raises(ValidationError): + foo.invoke({"bar": 4}) # type: ignore + + +async def test_annotated_async_tool_typing() -> None: + @tool + async def foo(bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"]) -> str: + """The foo.""" + return str(bar) + + assert await foo.ainvoke({"bar": 5}) == "5" # type: ignore + with pytest.raises(ValidationError): + await foo.ainvoke({"bar": 4}) # type: ignore + + def test_tool_pass_context() -> None: @tool def foo(bar: str) -> str: