From efb4c12abedd839c1d45f2dc0230aa89d0cd49e4 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 21 Jun 2024 15:16:30 -0700 Subject: [PATCH] [Core] Add support for inferring Annotated types (#23284) in bind_tools() / convert_to_openai_function --- .../langchain_core/utils/function_calling.py | 45 ++++++-- .../unit_tests/utils/test_function_calling.py | 100 ++++++++++++++++++ 2 files changed, 137 insertions(+), 8 deletions(-) diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 92e78fab7d..e8b1e32d1c 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -3,6 +3,7 @@ from __future__ import annotations import inspect +import logging import uuid from types import FunctionType, MethodType from typing import ( @@ -19,7 +20,7 @@ from typing import ( cast, ) -from typing_extensions import TypedDict +from typing_extensions import Annotated, TypedDict, get_args, get_origin from langchain_core._api import deprecated from langchain_core.messages import ( @@ -33,7 +34,7 @@ from langchain_core.utils.json_schema import dereference_refs if TYPE_CHECKING: from langchain_core.tools import BaseTool - +logger = logging.getLogger(__name__) PYTHON_TO_JSON_TYPES = { "str": "string", "int": "integer", @@ -160,6 +161,10 @@ def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]: return description, arg_descriptions +def _is_annotated_type(typ: Type[Any]) -> bool: + return get_origin(typ) is Annotated + + def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict: """Get JsonSchema describing a Python functions arguments. @@ -171,10 +176,27 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) - for arg, arg_type in annotations.items(): if arg == "return": continue - if isinstance(arg_type, type) and issubclass(arg_type, BaseModel): - # Mypy error: - # "type" has no attribute "schema" - properties[arg] = arg_type.schema() # type: ignore[attr-defined] + + if _is_annotated_type(arg_type): + annotated_args = get_args(arg_type) + arg_type = annotated_args[0] + if len(annotated_args) > 1: + for annotation in annotated_args[1:]: + if isinstance(annotation, str): + arg_descriptions[arg] = annotation + break + if ( + isinstance(arg_type, type) + and hasattr(arg_type, "model_json_schema") + and callable(arg_type.model_json_schema) + ): + properties[arg] = arg_type.model_json_schema() + elif ( + isinstance(arg_type, type) + and hasattr(arg_type, "schema") + and callable(arg_type.schema) + ): + properties[arg] = arg_type.schema() elif ( hasattr(arg_type, "__name__") and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES @@ -185,13 +207,20 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) - and getattr(arg_type, "__dict__").get("__origin__", None) == Literal ): properties[arg] = { - "enum": list(arg_type.__args__), # type: ignore - "type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__], # type: ignore + "enum": list(arg_type.__args__), + "type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__], } + else: + logger.warning( + f"Argument {arg} of type {arg_type} from function {function.__name__} " + "could not be not be converted to a JSON schema." + ) + if arg in arg_descriptions: if arg not in properties: properties[arg] = {} properties[arg]["description"] = arg_descriptions[arg] + return properties diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 9973d97220..ddf19e87e9 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -1,6 +1,10 @@ +# mypy: disable-error-code="annotation-unchecked" from typing import Any, Callable, Dict, List, Literal, Optional, Type import pytest +from pydantic import BaseModel as BaseModelV2Maybe # pydantic: ignore +from pydantic import Field as FieldV2Maybe # pydantic: ignore +from typing_extensions import Annotated from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.pydantic_v1 import BaseModel, Field @@ -22,6 +26,18 @@ def pydantic() -> Type[BaseModel]: return dummy_function +@pytest.fixture() +def annotated_function() -> Callable: + def dummy_function( + arg1: Annotated[int, "foo"], + arg2: Annotated[Literal["bar", "baz"], "one of 'bar', 'baz'"], + ) -> None: + """dummy function""" + pass + + return dummy_function + + @pytest.fixture() def function() -> Callable: def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None: @@ -53,6 +69,30 @@ def dummy_tool() -> BaseTool: return DummyFunction() +@pytest.fixture() +def dummy_pydantic() -> Type[BaseModel]: + class dummy_function(BaseModel): + """dummy function""" + + arg1: int = Field(..., description="foo") + arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'") + + return dummy_function + + +@pytest.fixture() +def dummy_pydantic_v2() -> Type[BaseModelV2Maybe]: + class dummy_function(BaseModelV2Maybe): + """dummy function""" + + arg1: int = FieldV2Maybe(..., description="foo") + arg2: Literal["bar", "baz"] = FieldV2Maybe( + ..., description="one of 'bar', 'baz'" + ) + + return dummy_function + + @pytest.fixture() def json_schema() -> Dict: return { @@ -99,6 +139,8 @@ def test_convert_to_openai_function( function: Callable, dummy_tool: BaseTool, json_schema: Dict, + annotated_function: Callable, + dummy_pydantic: Type[BaseModel], ) -> None: expected = { "name": "dummy_function", @@ -125,11 +167,69 @@ def test_convert_to_openai_function( expected, Dummy.dummy_function, DummyWithClassMethod.dummy_function, + annotated_function, + dummy_pydantic, ): actual = convert_to_openai_function(fn) # type: ignore assert actual == expected +def test_convert_to_openai_function_nested() -> None: + class Nested(BaseModel): + nested_arg1: int = Field(..., description="foo") + nested_arg2: Literal["bar", "baz"] = Field( + ..., description="one of 'bar', 'baz'" + ) + + class NestedV2(BaseModelV2Maybe): + nested_v2_arg1: int = FieldV2Maybe(..., description="foo") + nested_v2_arg2: Literal["bar", "baz"] = FieldV2Maybe( + ..., description="one of 'bar', 'baz'" + ) + + def my_function(arg1: Nested, arg2: NestedV2) -> None: + """dummy function""" + pass + + expected = { + "name": "my_function", + "description": "dummy function", + "parameters": { + "type": "object", + "properties": { + "arg1": { + "type": "object", + "properties": { + "nested_arg1": {"type": "integer", "description": "foo"}, + "nested_arg2": { + "type": "string", + "enum": ["bar", "baz"], + "description": "one of 'bar', 'baz'", + }, + }, + "required": ["nested_arg1", "nested_arg2"], + }, + "arg2": { + "type": "object", + "properties": { + "nested_v2_arg1": {"type": "integer", "description": "foo"}, + "nested_v2_arg2": { + "type": "string", + "enum": ["bar", "baz"], + "description": "one of 'bar', 'baz'", + }, + }, + "required": ["nested_v2_arg1", "nested_v2_arg2"], + }, + }, + "required": ["arg1", "arg2"], + }, + } + + actual = convert_to_openai_function(my_function) + assert actual == expected + + @pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()") def test_function_optional_param() -> None: @tool