diff --git a/libs/langchain/langchain/utils/json_schema.py b/libs/langchain/langchain/utils/json_schema.py index c5feab8478..9628f9e521 100644 --- a/libs/langchain/langchain/utils/json_schema.py +++ b/libs/langchain/langchain/utils/json_schema.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Optional, TypeVar, Union, cast +from copy import deepcopy +from typing import Any, List, Optional, Sequence def _retrieve_ref(path: str, schema: dict) -> dict: @@ -13,36 +14,59 @@ def _retrieve_ref(path: str, schema: dict) -> dict: out = schema for component in components[1:]: out = out[component] - return out + return deepcopy(out) -JSON_LIKE = TypeVar("JSON_LIKE", bound=Union[dict, list]) - - -def _dereference_refs_helper(obj: JSON_LIKE, full_schema: dict) -> JSON_LIKE: +def _dereference_refs_helper( + obj: Any, full_schema: dict, skip_keys: Sequence[str] +) -> Any: if isinstance(obj, dict): obj_out = {} for k, v in obj.items(): - if k == "$ref": + if k in skip_keys: + obj_out[k] = v + elif k == "$ref": ref = _retrieve_ref(v, full_schema) - obj_out[k] = _dereference_refs_helper(ref, full_schema) + return _dereference_refs_helper(ref, full_schema, skip_keys) elif isinstance(v, (list, dict)): - obj_out[k] = _dereference_refs_helper(v, full_schema) # type: ignore + obj_out[k] = _dereference_refs_helper(v, full_schema, skip_keys) else: obj_out[k] = v - return cast(JSON_LIKE, obj_out) + return obj_out elif isinstance(obj, list): - return cast( - JSON_LIKE, [_dereference_refs_helper(el, full_schema) for el in obj] - ) + return [_dereference_refs_helper(el, full_schema, skip_keys) for el in obj] else: return obj +def _infer_skip_keys(obj: Any, full_schema: dict) -> List[str]: + keys = [] + if isinstance(obj, dict): + for k, v in obj.items(): + if k == "$ref": + ref = _retrieve_ref(v, full_schema) + keys.append(v.split("/")[1]) + keys += _infer_skip_keys(ref, full_schema) + elif isinstance(v, (list, dict)): + keys += _infer_skip_keys(v, full_schema) + elif isinstance(obj, list): + for el in obj: + keys += _infer_skip_keys(el, full_schema) + return keys + + def dereference_refs( - schema_obj: dict, *, full_schema: Optional[dict] = None -) -> Union[dict, list]: + schema_obj: dict, + *, + full_schema: Optional[dict] = None, + skip_keys: Optional[Sequence[str]] = None, +) -> dict: """Try to substitute $refs in JSON Schema.""" full_schema = full_schema or schema_obj - return _dereference_refs_helper(schema_obj, full_schema) + skip_keys = ( + skip_keys + if skip_keys is not None + else _infer_skip_keys(schema_obj, full_schema) + ) + return _dereference_refs_helper(schema_obj, full_schema, skip_keys) diff --git a/libs/langchain/langchain/utils/openai_functions.py b/libs/langchain/langchain/utils/openai_functions.py index 48c49541dc..cfb1e76d59 100644 --- a/libs/langchain/langchain/utils/openai_functions.py +++ b/libs/langchain/langchain/utils/openai_functions.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Type, TypedDict, cast +from typing import Optional, Type, TypedDict from langchain.pydantic_v1 import BaseModel from langchain.utils.json_schema import dereference_refs @@ -21,7 +21,8 @@ def convert_pydantic_to_openai_function( name: Optional[str] = None, description: Optional[str] = None ) -> FunctionDescription: - schema = cast(Dict, dereference_refs(model.schema())) + schema = dereference_refs(model.schema()) + schema.pop("definitions", None) return { "name": name or schema["title"], "description": description or schema["description"], diff --git a/libs/langchain/tests/unit_tests/utils/test_json_schema.py b/libs/langchain/tests/unit_tests/utils/test_json_schema.py new file mode 100644 index 0000000000..233c467272 --- /dev/null +++ b/libs/langchain/tests/unit_tests/utils/test_json_schema.py @@ -0,0 +1,151 @@ +import pytest + +from langchain.utils.json_schema import dereference_refs + + +def test_dereference_refs_no_refs() -> None: + schema = { + "type": "object", + "properties": { + "first_name": {"type": "string"}, + }, + } + actual = dereference_refs(schema) + assert actual == schema + + +def test_dereference_refs_one_ref() -> None: + schema = { + "type": "object", + "properties": { + "first_name": {"$ref": "#/$defs/name"}, + }, + "$defs": {"name": {"type": "string"}}, + } + expected = { + "type": "object", + "properties": { + "first_name": {"type": "string"}, + }, + "$defs": {"name": {"type": "string"}}, + } + actual = dereference_refs(schema) + assert actual == expected + + +def test_dereference_refs_multiple_refs() -> None: + schema = { + "type": "object", + "properties": { + "first_name": {"$ref": "#/$defs/name"}, + "other": {"$ref": "#/$defs/other"}, + }, + "$defs": { + "name": {"type": "string"}, + "other": {"type": "object", "properties": {"age": "int", "height": "int"}}, + }, + } + expected = { + "type": "object", + "properties": { + "first_name": {"type": "string"}, + "other": {"type": "object", "properties": {"age": "int", "height": "int"}}, + }, + "$defs": { + "name": {"type": "string"}, + "other": {"type": "object", "properties": {"age": "int", "height": "int"}}, + }, + } + actual = dereference_refs(schema) + assert actual == expected + + +def test_dereference_refs_nested_refs_skip() -> None: + schema = { + "type": "object", + "properties": { + "info": {"$ref": "#/$defs/info"}, + }, + "$defs": { + "name": {"type": "string"}, + "info": { + "type": "object", + "properties": {"age": "int", "name": {"$ref": "#/$defs/name"}}, + }, + }, + } + expected = { + "type": "object", + "properties": { + "info": { + "type": "object", + "properties": {"age": "int", "name": {"type": "string"}}, + }, + }, + "$defs": { + "name": {"type": "string"}, + "info": { + "type": "object", + "properties": {"age": "int", "name": {"$ref": "#/$defs/name"}}, + }, + }, + } + actual = dereference_refs(schema) + assert actual == expected + + +def test_dereference_refs_nested_refs_no_skip() -> None: + schema = { + "type": "object", + "properties": { + "info": {"$ref": "#/$defs/info"}, + }, + "$defs": { + "name": {"type": "string"}, + "info": { + "type": "object", + "properties": {"age": "int", "name": {"$ref": "#/$defs/name"}}, + }, + }, + } + expected = { + "type": "object", + "properties": { + "info": { + "type": "object", + "properties": {"age": "int", "name": {"type": "string"}}, + }, + }, + "$defs": { + "name": {"type": "string"}, + "info": { + "type": "object", + "properties": {"age": "int", "name": {"type": "string"}}, + }, + }, + } + actual = dereference_refs(schema, skip_keys=()) + assert actual == expected + + +def test_dereference_refs_missing_ref() -> None: + schema = { + "type": "object", + "properties": { + "first_name": {"$ref": "#/$defs/name"}, + }, + "$defs": {}, + } + with pytest.raises(KeyError): + dereference_refs(schema) + + +def test_dereference_refs_remote_ref() -> None: + schema = { + "type": "object", + "properties": { + "first_name": {"$ref": "https://somewhere/else/name"}, + }, + } + with pytest.raises(ValueError): + dereference_refs(schema) diff --git a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py new file mode 100644 index 0000000000..b5a22d837b --- /dev/null +++ b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py @@ -0,0 +1,79 @@ +from langchain.pydantic_v1 import BaseModel, Field +from langchain.utils.openai_functions import convert_pydantic_to_openai_function + + +def test_convert_pydantic_to_openai_function() -> None: + class Data(BaseModel): + """The data to return.""" + + key: str = Field(..., description="API key") + days: int = Field(default=0, description="Number of days to forecast") + + actual = convert_pydantic_to_openai_function(Data) + expected = { + "name": "Data", + "description": "The data to return.", + "parameters": { + "title": "Data", + "description": "The data to return.", + "type": "object", + "properties": { + "key": {"title": "Key", "description": "API key", "type": "string"}, + "days": { + "title": "Days", + "description": "Number of days to forecast", + "default": 0, + "type": "integer", + }, + }, + "required": ["key"], + }, + } + assert actual == expected + + +def test_convert_pydantic_to_openai_function_nested() -> None: + class Data(BaseModel): + """The data to return.""" + + key: str = Field(..., description="API key") + days: int = Field(default=0, description="Number of days to forecast") + + class Model(BaseModel): + """The model to return.""" + + data: Data + + actual = convert_pydantic_to_openai_function(Model) + expected = { + "name": "Model", + "description": "The model to return.", + "parameters": { + "title": "Model", + "description": "The model to return.", + "type": "object", + "properties": { + "data": { + "title": "Data", + "description": "The data to return.", + "type": "object", + "properties": { + "key": { + "title": "Key", + "description": "API key", + "type": "string", + }, + "days": { + "title": "Days", + "description": "Number of days to forecast", + "default": 0, + "type": "integer", + }, + }, + "required": ["key"], + } + }, + "required": ["data"], + }, + } + assert actual == expected