mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
add tests
This commit is contained in:
parent
1f5c579ef4
commit
e805f8e263
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional, TypeVar, Union, cast
|
from copy import deepcopy
|
||||||
|
from typing import Any, List, Optional, Sequence
|
||||||
|
|
||||||
|
|
||||||
def _retrieve_ref(path: str, schema: dict) -> dict:
|
def _retrieve_ref(path: str, schema: dict) -> dict:
|
||||||
@ -13,36 +14,59 @@ def _retrieve_ref(path: str, schema: dict) -> dict:
|
|||||||
out = schema
|
out = schema
|
||||||
for component in components[1:]:
|
for component in components[1:]:
|
||||||
out = out[component]
|
out = out[component]
|
||||||
return out
|
return deepcopy(out)
|
||||||
|
|
||||||
|
|
||||||
JSON_LIKE = TypeVar("JSON_LIKE", bound=Union[dict, list])
|
def _dereference_refs_helper(
|
||||||
|
obj: Any, full_schema: dict, skip_keys: Sequence[str]
|
||||||
|
) -> Any:
|
||||||
def _dereference_refs_helper(obj: JSON_LIKE, full_schema: dict) -> JSON_LIKE:
|
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
obj_out = {}
|
obj_out = {}
|
||||||
for k, v in obj.items():
|
for k, v in obj.items():
|
||||||
if k == "$ref":
|
if k in skip_keys:
|
||||||
|
obj_out[k] = v
|
||||||
|
elif k == "$ref":
|
||||||
ref = _retrieve_ref(v, full_schema)
|
ref = _retrieve_ref(v, full_schema)
|
||||||
obj_out[k] = _dereference_refs_helper(ref, full_schema)
|
return _dereference_refs_helper(ref, full_schema, skip_keys)
|
||||||
elif isinstance(v, (list, dict)):
|
elif isinstance(v, (list, dict)):
|
||||||
obj_out[k] = _dereference_refs_helper(v, full_schema) # type: ignore
|
obj_out[k] = _dereference_refs_helper(v, full_schema, skip_keys)
|
||||||
else:
|
else:
|
||||||
obj_out[k] = v
|
obj_out[k] = v
|
||||||
return cast(JSON_LIKE, obj_out)
|
return obj_out
|
||||||
elif isinstance(obj, list):
|
elif isinstance(obj, list):
|
||||||
return cast(
|
return [_dereference_refs_helper(el, full_schema, skip_keys) for el in obj]
|
||||||
JSON_LIKE, [_dereference_refs_helper(el, full_schema) for el in obj]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_skip_keys(obj: Any, full_schema: dict) -> List[str]:
|
||||||
|
keys = []
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
for k, v in obj.items():
|
||||||
|
if k == "$ref":
|
||||||
|
ref = _retrieve_ref(v, full_schema)
|
||||||
|
keys.append(v.split("/")[1])
|
||||||
|
keys += _infer_skip_keys(ref, full_schema)
|
||||||
|
elif isinstance(v, (list, dict)):
|
||||||
|
keys += _infer_skip_keys(v, full_schema)
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
for el in obj:
|
||||||
|
keys += _infer_skip_keys(el, full_schema)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
|
||||||
def dereference_refs(
|
def dereference_refs(
|
||||||
schema_obj: dict, *, full_schema: Optional[dict] = None
|
schema_obj: dict,
|
||||||
) -> Union[dict, list]:
|
*,
|
||||||
|
full_schema: Optional[dict] = None,
|
||||||
|
skip_keys: Optional[Sequence[str]] = None,
|
||||||
|
) -> dict:
|
||||||
"""Try to substitute $refs in JSON Schema."""
|
"""Try to substitute $refs in JSON Schema."""
|
||||||
|
|
||||||
full_schema = full_schema or schema_obj
|
full_schema = full_schema or schema_obj
|
||||||
return _dereference_refs_helper(schema_obj, full_schema)
|
skip_keys = (
|
||||||
|
skip_keys
|
||||||
|
if skip_keys is not None
|
||||||
|
else _infer_skip_keys(schema_obj, full_schema)
|
||||||
|
)
|
||||||
|
return _dereference_refs_helper(schema_obj, full_schema, skip_keys)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, Optional, Type, TypedDict, cast
|
from typing import Optional, Type, TypedDict
|
||||||
|
|
||||||
from langchain.pydantic_v1 import BaseModel
|
from langchain.pydantic_v1 import BaseModel
|
||||||
from langchain.utils.json_schema import dereference_refs
|
from langchain.utils.json_schema import dereference_refs
|
||||||
@ -21,7 +21,8 @@ def convert_pydantic_to_openai_function(
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
) -> FunctionDescription:
|
) -> FunctionDescription:
|
||||||
schema = cast(Dict, dereference_refs(model.schema()))
|
schema = dereference_refs(model.schema())
|
||||||
|
schema.pop("definitions", None)
|
||||||
return {
|
return {
|
||||||
"name": name or schema["title"],
|
"name": name or schema["title"],
|
||||||
"description": description or schema["description"],
|
"description": description or schema["description"],
|
||||||
|
151
libs/langchain/tests/unit_tests/utils/test_json_schema.py
Normal file
151
libs/langchain/tests/unit_tests/utils/test_json_schema.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.utils.json_schema import dereference_refs
|
||||||
|
|
||||||
|
|
||||||
|
def test_dereference_refs_no_refs() -> None:
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"first_name": {"type": "string"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
actual = dereference_refs(schema)
|
||||||
|
assert actual == schema
|
||||||
|
|
||||||
|
|
||||||
|
def test_dereference_refs_one_ref() -> None:
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"first_name": {"$ref": "#/$defs/name"},
|
||||||
|
},
|
||||||
|
"$defs": {"name": {"type": "string"}},
|
||||||
|
}
|
||||||
|
expected = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"first_name": {"type": "string"},
|
||||||
|
},
|
||||||
|
"$defs": {"name": {"type": "string"}},
|
||||||
|
}
|
||||||
|
actual = dereference_refs(schema)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_dereference_refs_multiple_refs() -> None:
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"first_name": {"$ref": "#/$defs/name"},
|
||||||
|
"other": {"$ref": "#/$defs/other"},
|
||||||
|
},
|
||||||
|
"$defs": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"other": {"type": "object", "properties": {"age": "int", "height": "int"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
expected = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"first_name": {"type": "string"},
|
||||||
|
"other": {"type": "object", "properties": {"age": "int", "height": "int"}},
|
||||||
|
},
|
||||||
|
"$defs": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"other": {"type": "object", "properties": {"age": "int", "height": "int"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
actual = dereference_refs(schema)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_dereference_refs_nested_refs_skip() -> None:
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"info": {"$ref": "#/$defs/info"},
|
||||||
|
},
|
||||||
|
"$defs": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"info": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"age": "int", "name": {"$ref": "#/$defs/name"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
expected = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"info": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"age": "int", "name": {"type": "string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"$defs": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"info": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"age": "int", "name": {"$ref": "#/$defs/name"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
actual = dereference_refs(schema)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_dereference_refs_nested_refs_no_skip() -> None:
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"info": {"$ref": "#/$defs/info"},
|
||||||
|
},
|
||||||
|
"$defs": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"info": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"age": "int", "name": {"$ref": "#/$defs/name"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
expected = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"info": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"age": "int", "name": {"type": "string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"$defs": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"info": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"age": "int", "name": {"type": "string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
actual = dereference_refs(schema, skip_keys=())
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_dereference_refs_missing_ref() -> None:
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"first_name": {"$ref": "#/$defs/name"},
|
||||||
|
},
|
||||||
|
"$defs": {},
|
||||||
|
}
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
dereference_refs(schema)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dereference_refs_remote_ref() -> None:
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"first_name": {"$ref": "https://somewhere/else/name"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
dereference_refs(schema)
|
@ -0,0 +1,79 @@
|
|||||||
|
from langchain.pydantic_v1 import BaseModel, Field
|
||||||
|
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_pydantic_to_openai_function() -> None:
|
||||||
|
class Data(BaseModel):
|
||||||
|
"""The data to return."""
|
||||||
|
|
||||||
|
key: str = Field(..., description="API key")
|
||||||
|
days: int = Field(default=0, description="Number of days to forecast")
|
||||||
|
|
||||||
|
actual = convert_pydantic_to_openai_function(Data)
|
||||||
|
expected = {
|
||||||
|
"name": "Data",
|
||||||
|
"description": "The data to return.",
|
||||||
|
"parameters": {
|
||||||
|
"title": "Data",
|
||||||
|
"description": "The data to return.",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"key": {"title": "Key", "description": "API key", "type": "string"},
|
||||||
|
"days": {
|
||||||
|
"title": "Days",
|
||||||
|
"description": "Number of days to forecast",
|
||||||
|
"default": 0,
|
||||||
|
"type": "integer",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["key"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_pydantic_to_openai_function_nested() -> None:
|
||||||
|
class Data(BaseModel):
|
||||||
|
"""The data to return."""
|
||||||
|
|
||||||
|
key: str = Field(..., description="API key")
|
||||||
|
days: int = Field(default=0, description="Number of days to forecast")
|
||||||
|
|
||||||
|
class Model(BaseModel):
|
||||||
|
"""The model to return."""
|
||||||
|
|
||||||
|
data: Data
|
||||||
|
|
||||||
|
actual = convert_pydantic_to_openai_function(Model)
|
||||||
|
expected = {
|
||||||
|
"name": "Model",
|
||||||
|
"description": "The model to return.",
|
||||||
|
"parameters": {
|
||||||
|
"title": "Model",
|
||||||
|
"description": "The model to return.",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"data": {
|
||||||
|
"title": "Data",
|
||||||
|
"description": "The data to return.",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"key": {
|
||||||
|
"title": "Key",
|
||||||
|
"description": "API key",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"days": {
|
||||||
|
"title": "Days",
|
||||||
|
"description": "Number of days to forecast",
|
||||||
|
"default": 0,
|
||||||
|
"type": "integer",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["key"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["data"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert actual == expected
|
Loading…
Reference in New Issue
Block a user