add tests

This commit is contained in:
Bagatur 2023-08-30 15:23:02 -07:00
parent 1f5c579ef4
commit e805f8e263
4 changed files with 273 additions and 18 deletions

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, TypeVar, Union, cast from copy import deepcopy
from typing import Any, List, Optional, Sequence
def _retrieve_ref(path: str, schema: dict) -> dict: def _retrieve_ref(path: str, schema: dict) -> dict:
@ -13,36 +14,59 @@ def _retrieve_ref(path: str, schema: dict) -> dict:
out = schema out = schema
for component in components[1:]: for component in components[1:]:
out = out[component] out = out[component]
return out return deepcopy(out)
JSON_LIKE = TypeVar("JSON_LIKE", bound=Union[dict, list]) def _dereference_refs_helper(
obj: Any, full_schema: dict, skip_keys: Sequence[str]
) -> Any:
def _dereference_refs_helper(obj: JSON_LIKE, full_schema: dict) -> JSON_LIKE:
if isinstance(obj, dict): if isinstance(obj, dict):
obj_out = {} obj_out = {}
for k, v in obj.items(): for k, v in obj.items():
if k == "$ref": if k in skip_keys:
obj_out[k] = v
elif k == "$ref":
ref = _retrieve_ref(v, full_schema) ref = _retrieve_ref(v, full_schema)
obj_out[k] = _dereference_refs_helper(ref, full_schema) return _dereference_refs_helper(ref, full_schema, skip_keys)
elif isinstance(v, (list, dict)): elif isinstance(v, (list, dict)):
obj_out[k] = _dereference_refs_helper(v, full_schema) # type: ignore obj_out[k] = _dereference_refs_helper(v, full_schema, skip_keys)
else: else:
obj_out[k] = v obj_out[k] = v
return cast(JSON_LIKE, obj_out) return obj_out
elif isinstance(obj, list): elif isinstance(obj, list):
return cast( return [_dereference_refs_helper(el, full_schema, skip_keys) for el in obj]
JSON_LIKE, [_dereference_refs_helper(el, full_schema) for el in obj]
)
else: else:
return obj return obj
def _infer_skip_keys(obj: Any, full_schema: dict) -> List[str]:
keys = []
if isinstance(obj, dict):
for k, v in obj.items():
if k == "$ref":
ref = _retrieve_ref(v, full_schema)
keys.append(v.split("/")[1])
keys += _infer_skip_keys(ref, full_schema)
elif isinstance(v, (list, dict)):
keys += _infer_skip_keys(v, full_schema)
elif isinstance(obj, list):
for el in obj:
keys += _infer_skip_keys(el, full_schema)
return keys
def dereference_refs( def dereference_refs(
schema_obj: dict, *, full_schema: Optional[dict] = None schema_obj: dict,
) -> Union[dict, list]: *,
full_schema: Optional[dict] = None,
skip_keys: Optional[Sequence[str]] = None,
) -> dict:
"""Try to substitute $refs in JSON Schema.""" """Try to substitute $refs in JSON Schema."""
full_schema = full_schema or schema_obj full_schema = full_schema or schema_obj
return _dereference_refs_helper(schema_obj, full_schema) skip_keys = (
skip_keys
if skip_keys is not None
else _infer_skip_keys(schema_obj, full_schema)
)
return _dereference_refs_helper(schema_obj, full_schema, skip_keys)

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional, Type, TypedDict, cast from typing import Optional, Type, TypedDict
from langchain.pydantic_v1 import BaseModel from langchain.pydantic_v1 import BaseModel
from langchain.utils.json_schema import dereference_refs from langchain.utils.json_schema import dereference_refs
@ -21,7 +21,8 @@ def convert_pydantic_to_openai_function(
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None description: Optional[str] = None
) -> FunctionDescription: ) -> FunctionDescription:
schema = cast(Dict, dereference_refs(model.schema())) schema = dereference_refs(model.schema())
schema.pop("definitions", None)
return { return {
"name": name or schema["title"], "name": name or schema["title"],
"description": description or schema["description"], "description": description or schema["description"],

View File

@ -0,0 +1,151 @@
import pytest
from langchain.utils.json_schema import dereference_refs
def test_dereference_refs_no_refs() -> None:
schema = {
"type": "object",
"properties": {
"first_name": {"type": "string"},
},
}
actual = dereference_refs(schema)
assert actual == schema
def test_dereference_refs_one_ref() -> None:
schema = {
"type": "object",
"properties": {
"first_name": {"$ref": "#/$defs/name"},
},
"$defs": {"name": {"type": "string"}},
}
expected = {
"type": "object",
"properties": {
"first_name": {"type": "string"},
},
"$defs": {"name": {"type": "string"}},
}
actual = dereference_refs(schema)
assert actual == expected
def test_dereference_refs_multiple_refs() -> None:
schema = {
"type": "object",
"properties": {
"first_name": {"$ref": "#/$defs/name"},
"other": {"$ref": "#/$defs/other"},
},
"$defs": {
"name": {"type": "string"},
"other": {"type": "object", "properties": {"age": "int", "height": "int"}},
},
}
expected = {
"type": "object",
"properties": {
"first_name": {"type": "string"},
"other": {"type": "object", "properties": {"age": "int", "height": "int"}},
},
"$defs": {
"name": {"type": "string"},
"other": {"type": "object", "properties": {"age": "int", "height": "int"}},
},
}
actual = dereference_refs(schema)
assert actual == expected
def test_dereference_refs_nested_refs_skip() -> None:
schema = {
"type": "object",
"properties": {
"info": {"$ref": "#/$defs/info"},
},
"$defs": {
"name": {"type": "string"},
"info": {
"type": "object",
"properties": {"age": "int", "name": {"$ref": "#/$defs/name"}},
},
},
}
expected = {
"type": "object",
"properties": {
"info": {
"type": "object",
"properties": {"age": "int", "name": {"type": "string"}},
},
},
"$defs": {
"name": {"type": "string"},
"info": {
"type": "object",
"properties": {"age": "int", "name": {"$ref": "#/$defs/name"}},
},
},
}
actual = dereference_refs(schema)
assert actual == expected
def test_dereference_refs_nested_refs_no_skip() -> None:
schema = {
"type": "object",
"properties": {
"info": {"$ref": "#/$defs/info"},
},
"$defs": {
"name": {"type": "string"},
"info": {
"type": "object",
"properties": {"age": "int", "name": {"$ref": "#/$defs/name"}},
},
},
}
expected = {
"type": "object",
"properties": {
"info": {
"type": "object",
"properties": {"age": "int", "name": {"type": "string"}},
},
},
"$defs": {
"name": {"type": "string"},
"info": {
"type": "object",
"properties": {"age": "int", "name": {"type": "string"}},
},
},
}
actual = dereference_refs(schema, skip_keys=())
assert actual == expected
def test_dereference_refs_missing_ref() -> None:
schema = {
"type": "object",
"properties": {
"first_name": {"$ref": "#/$defs/name"},
},
"$defs": {},
}
with pytest.raises(KeyError):
dereference_refs(schema)
def test_dereference_refs_remote_ref() -> None:
schema = {
"type": "object",
"properties": {
"first_name": {"$ref": "https://somewhere/else/name"},
},
}
with pytest.raises(ValueError):
dereference_refs(schema)

View File

@ -0,0 +1,79 @@
from langchain.pydantic_v1 import BaseModel, Field
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
def test_convert_pydantic_to_openai_function() -> None:
class Data(BaseModel):
"""The data to return."""
key: str = Field(..., description="API key")
days: int = Field(default=0, description="Number of days to forecast")
actual = convert_pydantic_to_openai_function(Data)
expected = {
"name": "Data",
"description": "The data to return.",
"parameters": {
"title": "Data",
"description": "The data to return.",
"type": "object",
"properties": {
"key": {"title": "Key", "description": "API key", "type": "string"},
"days": {
"title": "Days",
"description": "Number of days to forecast",
"default": 0,
"type": "integer",
},
},
"required": ["key"],
},
}
assert actual == expected
def test_convert_pydantic_to_openai_function_nested() -> None:
class Data(BaseModel):
"""The data to return."""
key: str = Field(..., description="API key")
days: int = Field(default=0, description="Number of days to forecast")
class Model(BaseModel):
"""The model to return."""
data: Data
actual = convert_pydantic_to_openai_function(Model)
expected = {
"name": "Model",
"description": "The model to return.",
"parameters": {
"title": "Model",
"description": "The model to return.",
"type": "object",
"properties": {
"data": {
"title": "Data",
"description": "The data to return.",
"type": "object",
"properties": {
"key": {
"title": "Key",
"description": "API key",
"type": "string",
},
"days": {
"title": "Days",
"description": "Number of days to forecast",
"default": 0,
"type": "integer",
},
},
"required": ["key"],
}
},
"required": ["data"],
},
}
assert actual == expected