[Core] Add support for inferring Annotated types (#23284)

in bind_tools() / convert_to_openai_function
wfh/add_list_support
William FH 3 months ago committed by GitHub
parent 9ac302cb97
commit efb4c12abe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import inspect import inspect
import logging
import uuid import uuid
from types import FunctionType, MethodType from types import FunctionType, MethodType
from typing import ( from typing import (
@ -19,7 +20,7 @@ from typing import (
cast, cast,
) )
from typing_extensions import TypedDict from typing_extensions import Annotated, TypedDict, get_args, get_origin
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.messages import ( from langchain_core.messages import (
@ -33,7 +34,7 @@ from langchain_core.utils.json_schema import dereference_refs
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
logger = logging.getLogger(__name__)
PYTHON_TO_JSON_TYPES = { PYTHON_TO_JSON_TYPES = {
"str": "string", "str": "string",
"int": "integer", "int": "integer",
@ -160,6 +161,10 @@ def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
return description, arg_descriptions return description, arg_descriptions
def _is_annotated_type(typ: Type[Any]) -> bool:
return get_origin(typ) is Annotated
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict: def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
"""Get JsonSchema describing a Python functions arguments. """Get JsonSchema describing a Python functions arguments.
@ -171,10 +176,27 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
for arg, arg_type in annotations.items(): for arg, arg_type in annotations.items():
if arg == "return": if arg == "return":
continue continue
if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):
# Mypy error: if _is_annotated_type(arg_type):
# "type" has no attribute "schema" annotated_args = get_args(arg_type)
properties[arg] = arg_type.schema() # type: ignore[attr-defined] arg_type = annotated_args[0]
if len(annotated_args) > 1:
for annotation in annotated_args[1:]:
if isinstance(annotation, str):
arg_descriptions[arg] = annotation
break
if (
isinstance(arg_type, type)
and hasattr(arg_type, "model_json_schema")
and callable(arg_type.model_json_schema)
):
properties[arg] = arg_type.model_json_schema()
elif (
isinstance(arg_type, type)
and hasattr(arg_type, "schema")
and callable(arg_type.schema)
):
properties[arg] = arg_type.schema()
elif ( elif (
hasattr(arg_type, "__name__") hasattr(arg_type, "__name__")
and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES
@ -185,13 +207,20 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
and getattr(arg_type, "__dict__").get("__origin__", None) == Literal and getattr(arg_type, "__dict__").get("__origin__", None) == Literal
): ):
properties[arg] = { properties[arg] = {
"enum": list(arg_type.__args__), # type: ignore "enum": list(arg_type.__args__),
"type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__], # type: ignore "type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__],
} }
else:
logger.warning(
f"Argument {arg} of type {arg_type} from function {function.__name__} "
"could not be not be converted to a JSON schema."
)
if arg in arg_descriptions: if arg in arg_descriptions:
if arg not in properties: if arg not in properties:
properties[arg] = {} properties[arg] = {}
properties[arg]["description"] = arg_descriptions[arg] properties[arg]["description"] = arg_descriptions[arg]
return properties return properties

@ -1,6 +1,10 @@
# mypy: disable-error-code="annotation-unchecked"
from typing import Any, Callable, Dict, List, Literal, Optional, Type from typing import Any, Callable, Dict, List, Literal, Optional, Type
import pytest import pytest
from pydantic import BaseModel as BaseModelV2Maybe # pydantic: ignore
from pydantic import Field as FieldV2Maybe # pydantic: ignore
from typing_extensions import Annotated
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
@ -22,6 +26,18 @@ def pydantic() -> Type[BaseModel]:
return dummy_function return dummy_function
@pytest.fixture()
def annotated_function() -> Callable:
def dummy_function(
arg1: Annotated[int, "foo"],
arg2: Annotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
) -> None:
"""dummy function"""
pass
return dummy_function
@pytest.fixture() @pytest.fixture()
def function() -> Callable: def function() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None: def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
@ -53,6 +69,30 @@ def dummy_tool() -> BaseTool:
return DummyFunction() return DummyFunction()
@pytest.fixture()
def dummy_pydantic() -> Type[BaseModel]:
class dummy_function(BaseModel):
"""dummy function"""
arg1: int = Field(..., description="foo")
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
return dummy_function
@pytest.fixture()
def dummy_pydantic_v2() -> Type[BaseModelV2Maybe]:
class dummy_function(BaseModelV2Maybe):
"""dummy function"""
arg1: int = FieldV2Maybe(..., description="foo")
arg2: Literal["bar", "baz"] = FieldV2Maybe(
..., description="one of 'bar', 'baz'"
)
return dummy_function
@pytest.fixture() @pytest.fixture()
def json_schema() -> Dict: def json_schema() -> Dict:
return { return {
@ -99,6 +139,8 @@ def test_convert_to_openai_function(
function: Callable, function: Callable,
dummy_tool: BaseTool, dummy_tool: BaseTool,
json_schema: Dict, json_schema: Dict,
annotated_function: Callable,
dummy_pydantic: Type[BaseModel],
) -> None: ) -> None:
expected = { expected = {
"name": "dummy_function", "name": "dummy_function",
@ -125,11 +167,69 @@ def test_convert_to_openai_function(
expected, expected,
Dummy.dummy_function, Dummy.dummy_function,
DummyWithClassMethod.dummy_function, DummyWithClassMethod.dummy_function,
annotated_function,
dummy_pydantic,
): ):
actual = convert_to_openai_function(fn) # type: ignore actual = convert_to_openai_function(fn) # type: ignore
assert actual == expected assert actual == expected
def test_convert_to_openai_function_nested() -> None:
class Nested(BaseModel):
nested_arg1: int = Field(..., description="foo")
nested_arg2: Literal["bar", "baz"] = Field(
..., description="one of 'bar', 'baz'"
)
class NestedV2(BaseModelV2Maybe):
nested_v2_arg1: int = FieldV2Maybe(..., description="foo")
nested_v2_arg2: Literal["bar", "baz"] = FieldV2Maybe(
..., description="one of 'bar', 'baz'"
)
def my_function(arg1: Nested, arg2: NestedV2) -> None:
"""dummy function"""
pass
expected = {
"name": "my_function",
"description": "dummy function",
"parameters": {
"type": "object",
"properties": {
"arg1": {
"type": "object",
"properties": {
"nested_arg1": {"type": "integer", "description": "foo"},
"nested_arg2": {
"type": "string",
"enum": ["bar", "baz"],
"description": "one of 'bar', 'baz'",
},
},
"required": ["nested_arg1", "nested_arg2"],
},
"arg2": {
"type": "object",
"properties": {
"nested_v2_arg1": {"type": "integer", "description": "foo"},
"nested_v2_arg2": {
"type": "string",
"enum": ["bar", "baz"],
"description": "one of 'bar', 'baz'",
},
},
"required": ["nested_v2_arg1", "nested_v2_arg2"],
},
},
"required": ["arg1", "arg2"],
},
}
actual = convert_to_openai_function(my_function)
assert actual == expected
@pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()") @pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()")
def test_function_optional_param() -> None: def test_function_optional_param() -> None:
@tool @tool

Loading…
Cancel
Save