pull/20912/head
William Fu-Hinthorn 3 weeks ago
parent 4c437ebb9c
commit e4d8ef659c

@ -21,13 +21,25 @@ from __future__ import annotations
import asyncio
import inspect
import typing
import uuid
import warnings
from abc import ABC, abstractmethod
from contextvars import copy_context
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -76,7 +88,10 @@ class SchemaAnnotationError(TypeError):
def _create_subset_model(
name: str, model: Type[BaseModel], field_names: list
name: str,
model: Type[BaseModel],
field_names: list,
descriptions: Optional[Mapping[str, str]] = None,
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {}
@ -88,6 +103,10 @@ def _create_subset_model(
if field.required and not field.allow_none
else Optional[field.outer_type_]
)
# Inject the description into the field_info
description = descriptions.get(field_name) if descriptions else None
if description:
field.field_info.description = description
fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore
return rtn
@ -103,6 +122,24 @@ def _get_filtered_args(
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
def _get_description_from_annotation(ann: Any) -> Optional[str]:
possible_descriptions = [
arg for arg in typing.get_args(ann) if isinstance(arg, str)
]
return "\n".join(possible_descriptions) if possible_descriptions else None
def _get_descriptions(func: Callable) -> Dict[str, str]:
"""Get the descriptions from a function's signature."""
descriptions = {}
for param in inspect.signature(func).parameters.values():
if param.annotation is not inspect.Parameter.empty:
description = _get_description_from_annotation(param.annotation)
if description:
descriptions[param.name] = description
return descriptions
class _SchemaConfig:
"""Configuration for the pydantic model."""
@ -128,10 +165,16 @@ def create_schema_from_function(
del inferred_model.__fields__["run_manager"]
if "callbacks" in inferred_model.__fields__:
del inferred_model.__fields__["callbacks"]
breakpoint()
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func)
# TODO: we could pass through additional metadata here
descriptions = _get_descriptions(func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(valid_properties)
f"{model_name}Schema",
inferred_model,
list(valid_properties),
descriptions=descriptions,
)

@ -9,6 +9,7 @@ from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union
import pytest
from typing_extensions import Annotated
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
@ -23,6 +24,7 @@ from langchain_core.tools import (
Tool,
ToolException,
_create_subset_model,
create_schema_from_function,
tool,
)
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@ -53,7 +55,12 @@ class _MockStructuredTool(BaseTool):
args_schema: Type[BaseModel] = _MockSchema
description: str = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
def _run(
self,
arg1: int,
arg2: bool,
arg3: Optional[dict] = None,
) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
@ -70,6 +77,33 @@ def test_structured_args() -> None:
assert structured_api.run(args) == expected_result
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or above")
def test_structured_args_description() -> None:
class _AnnotatedTool(BaseTool):
name: str = "structured_api"
description: str = "A Structured Tool"
def _run(
self,
arg1: int,
arg2: Annotated[bool, "V important"],
arg3: Optional[dict] = None,
) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
expected = {
"arg1": {"title": "Arg1", "type": "integer"},
"arg2": {"title": "Arg2", "type": "boolean", "description": "V important"},
"arg3": {"title": "Arg3", "type": "object"},
}
assert _AnnotatedTool().args == expected
def test_misannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool with the incorrect typehint raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
@ -876,6 +910,73 @@ def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> No
foo.invoke(inputs) # type: ignore
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or above")
def test_create_schema_from_function_with_descriptions() -> None:
def foo(bar: int, baz: str) -> str:
"""Docstring
Args:
bar: int
baz: str
"""
raise NotImplementedError()
schema = create_schema_from_function("foo", foo)
assert schema.schema() == {
"title": "fooSchema",
"type": "object",
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"required": ["bar", "baz"],
}
def foo_annotated(
bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"],
) -> str:
"""Docstring
Args:
bar: int
"""
raise bar
schema = create_schema_from_function("foo_annotated", foo_annotated)
assert schema.schema() == {
"title": "foo_annotatedSchema",
"type": "object",
"properties": {
"bar": {
"title": "Bar",
"type": "integer",
"description": "This is bar\nit's useful",
},
},
"required": ["bar"],
}
def test_annotated_tool_typing() -> None:
@tool
def foo(bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"]) -> str:
"""The foo."""
return str(bar)
assert foo.invoke({"bar": 5}) == "5" # type: ignore
with pytest.raises(ValidationError):
foo.invoke({"bar": 4}) # type: ignore
async def test_annotated_async_tool_typing() -> None:
@tool
async def foo(bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"]) -> str:
"""The foo."""
return str(bar)
assert await foo.ainvoke({"bar": 5}) == "5" # type: ignore
with pytest.raises(ValidationError):
await foo.ainvoke({"bar": 4}) # type: ignore
def test_tool_pass_context() -> None:
@tool
def foo(bar: str) -> str:

Loading…
Cancel
Save