pull/20912/head
William Fu-Hinthorn 1 month ago
parent 4c437ebb9c
commit e4d8ef659c

@ -21,13 +21,25 @@ from __future__ import annotations
import asyncio import asyncio
import inspect import inspect
import typing
import uuid import uuid
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextvars import copy_context from contextvars import copy_context
from functools import partial from functools import partial
from inspect import signature from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -76,7 +88,10 @@ class SchemaAnnotationError(TypeError):
def _create_subset_model( def _create_subset_model(
name: str, model: Type[BaseModel], field_names: list name: str,
model: Type[BaseModel],
field_names: list,
descriptions: Optional[Mapping[str, str]] = None,
) -> Type[BaseModel]: ) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields.""" """Create a pydantic model with only a subset of model's fields."""
fields = {} fields = {}
@ -88,6 +103,10 @@ def _create_subset_model(
if field.required and not field.allow_none if field.required and not field.allow_none
else Optional[field.outer_type_] else Optional[field.outer_type_]
) )
# Inject the description into the field_info
description = descriptions.get(field_name) if descriptions else None
if description:
field.field_info.description = description
fields[field_name] = (t, field.field_info) fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore rtn = create_model(name, **fields) # type: ignore
return rtn return rtn
@ -103,6 +122,24 @@ def _get_filtered_args(
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
def _get_description_from_annotation(ann: Any) -> Optional[str]:
possible_descriptions = [
arg for arg in typing.get_args(ann) if isinstance(arg, str)
]
return "\n".join(possible_descriptions) if possible_descriptions else None
def _get_descriptions(func: Callable) -> Dict[str, str]:
"""Get the descriptions from a function's signature."""
descriptions = {}
for param in inspect.signature(func).parameters.values():
if param.annotation is not inspect.Parameter.empty:
description = _get_description_from_annotation(param.annotation)
if description:
descriptions[param.name] = description
return descriptions
class _SchemaConfig: class _SchemaConfig:
"""Configuration for the pydantic model.""" """Configuration for the pydantic model."""
@ -128,10 +165,16 @@ def create_schema_from_function(
del inferred_model.__fields__["run_manager"] del inferred_model.__fields__["run_manager"]
if "callbacks" in inferred_model.__fields__: if "callbacks" in inferred_model.__fields__:
del inferred_model.__fields__["callbacks"] del inferred_model.__fields__["callbacks"]
breakpoint()
# Pydantic adds placeholder virtual fields we need to strip # Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func) valid_properties = _get_filtered_args(inferred_model, func)
# TODO: we could pass through additional metadata here
descriptions = _get_descriptions(func)
return _create_subset_model( return _create_subset_model(
f"{model_name}Schema", inferred_model, list(valid_properties) f"{model_name}Schema",
inferred_model,
list(valid_properties),
descriptions=descriptions,
) )

@ -9,6 +9,7 @@ from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union from typing import Any, Callable, Dict, List, Optional, Type, Union
import pytest import pytest
from typing_extensions import Annotated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,
@ -23,6 +24,7 @@ from langchain_core.tools import (
Tool, Tool,
ToolException, ToolException,
_create_subset_model, _create_subset_model,
create_schema_from_function,
tool, tool,
) )
from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@ -53,7 +55,12 @@ class _MockStructuredTool(BaseTool):
args_schema: Type[BaseModel] = _MockSchema args_schema: Type[BaseModel] = _MockSchema
description: str = "A Structured Tool" description: str = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: def _run(
self,
arg1: int,
arg2: bool,
arg3: Optional[dict] = None,
) -> str:
return f"{arg1} {arg2} {arg3}" return f"{arg1} {arg2} {arg3}"
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
@ -70,6 +77,33 @@ def test_structured_args() -> None:
assert structured_api.run(args) == expected_result assert structured_api.run(args) == expected_result
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or above")
def test_structured_args_description() -> None:
class _AnnotatedTool(BaseTool):
name: str = "structured_api"
description: str = "A Structured Tool"
def _run(
self,
arg1: int,
arg2: Annotated[bool, "V important"],
arg3: Optional[dict] = None,
) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
expected = {
"arg1": {"title": "Arg1", "type": "integer"},
"arg2": {"title": "Arg2", "type": "boolean", "description": "V important"},
"arg3": {"title": "Arg3", "type": "object"},
}
assert _AnnotatedTool().args == expected
def test_misannotated_base_tool_raises_error() -> None: def test_misannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool with the incorrect typehint raises an exception.""" "" """Test that a BaseTool with the incorrect typehint raises an exception.""" ""
with pytest.raises(SchemaAnnotationError): with pytest.raises(SchemaAnnotationError):
@ -876,6 +910,73 @@ def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> No
foo.invoke(inputs) # type: ignore foo.invoke(inputs) # type: ignore
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or above")
def test_create_schema_from_function_with_descriptions() -> None:
def foo(bar: int, baz: str) -> str:
"""Docstring
Args:
bar: int
baz: str
"""
raise NotImplementedError()
schema = create_schema_from_function("foo", foo)
assert schema.schema() == {
"title": "fooSchema",
"type": "object",
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"required": ["bar", "baz"],
}
def foo_annotated(
bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"],
) -> str:
"""Docstring
Args:
bar: int
"""
raise bar
schema = create_schema_from_function("foo_annotated", foo_annotated)
assert schema.schema() == {
"title": "foo_annotatedSchema",
"type": "object",
"properties": {
"bar": {
"title": "Bar",
"type": "integer",
"description": "This is bar\nit's useful",
},
},
"required": ["bar"],
}
def test_annotated_tool_typing() -> None:
@tool
def foo(bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"]) -> str:
"""The foo."""
return str(bar)
assert foo.invoke({"bar": 5}) == "5" # type: ignore
with pytest.raises(ValidationError):
foo.invoke({"bar": 4}) # type: ignore
async def test_annotated_async_tool_typing() -> None:
@tool
async def foo(bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"]) -> str:
"""The foo."""
return str(bar)
assert await foo.ainvoke({"bar": 5}) == "5" # type: ignore
with pytest.raises(ValidationError):
await foo.ainvoke({"bar": 4}) # type: ignore
def test_tool_pass_context() -> None: def test_tool_pass_context() -> None:
@tool @tool
def foo(bar: str) -> str: def foo(bar: str) -> str:

Loading…
Cancel
Save