core[patch]: fixed circular dependency with json schema (#18657)

**Description:** Circular dependencies when parsing references lead to a
`RecursionError: maximum recursion depth exceeded` error. This PR
addresses the issue by tracking previously seen refs, as in a typical
DFS, to avoid infinite recursion.

**Issue:** https://github.com/langchain-ai/langchain/issues/12163

 **Twitter handle:** https://twitter.com/theBhulawat 


- [x] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
Naman Jain 2024-03-12 11:12:45 +05:30 committed by GitHub
parent 0bec1f6877
commit 75122646b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 86 additions and 9 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from copy import deepcopy from copy import deepcopy
from typing import Any, List, Optional, Sequence from typing import Any, Dict, List, Optional, Sequence, Set
def _retrieve_ref(path: str, schema: dict) -> dict: def _retrieve_ref(path: str, schema: dict) -> dict:
@ -21,40 +21,66 @@ def _retrieve_ref(path: str, schema: dict) -> dict:
def _dereference_refs_helper(
    obj: Any,
    full_schema: Dict[str, Any],
    skip_keys: Sequence[str],
    processed_refs: Optional[Set[str]] = None,
) -> Any:
    """Return ``obj`` with every ``$ref`` replaced by the schema it points to.

    Args:
        obj: Schema fragment to walk; dicts and lists are traversed, any
            other value is returned unchanged.
        full_schema: Complete schema handed to ``_retrieve_ref`` for lookups.
        skip_keys: Dict keys whose values are passed through as-is, without
            being traversed.
        processed_refs: Refs currently being expanded on the active path.
            A fresh set is created when omitted.

    Returns:
        The dereferenced fragment. A ``$ref`` that is already being expanded
        further up the call chain is dropped from its dict rather than
        followed again, which is what stops cyclic schemas from recursing
        without bound.
    """
    if processed_refs is None:
        processed_refs = set()

    if isinstance(obj, dict):
        obj_out = {}
        for k, v in obj.items():
            if k in skip_keys:
                obj_out[k] = v
            elif k == "$ref":
                if v in processed_refs:
                    # This ref is already being expanded by a caller: skip it
                    # so the cycle is cut here.
                    continue
                processed_refs.add(v)
                try:
                    ref = _retrieve_ref(v, full_schema)
                    # The resolved target replaces the whole dict that held
                    # the "$ref" key.
                    return _dereference_refs_helper(
                        ref, full_schema, skip_keys, processed_refs
                    )
                finally:
                    # Only refs on the *current* path count as seen, so the
                    # marker is released even when resolution raises.
                    processed_refs.discard(v)
            elif isinstance(v, (list, dict)):
                obj_out[k] = _dereference_refs_helper(
                    v, full_schema, skip_keys, processed_refs
                )
            else:
                obj_out[k] = v
        return obj_out
    elif isinstance(obj, list):
        return [
            _dereference_refs_helper(el, full_schema, skip_keys, processed_refs)
            for el in obj
        ]
    else:
        return obj
def _infer_skip_keys(
    obj: Any, full_schema: dict, processed_refs: Optional[Set[str]] = None
) -> List[str]:
    """Collect the schema keys under which ``$ref`` targets are stored.

    For every ``$ref`` string reachable from ``obj`` (following the refs
    themselves through ``_retrieve_ref``), the second ``/``-separated
    component of the ref is recorded, e.g. ``"$defs"`` for
    ``"#/$defs/user"``.

    Args:
        obj: Schema fragment to scan; dicts and lists are traversed.
        full_schema: Complete schema used to resolve each ref.
        processed_refs: Refs already visited during this scan. A fresh set
            is created when omitted. Each ref is followed at most once, so
            cyclic schemas terminate.

    Returns:
        The collected keys, in discovery order; the list may contain
        duplicates.
    """
    if processed_refs is None:
        processed_refs = set()

    keys: List[str] = []
    if isinstance(obj, dict):
        for k, v in obj.items():
            if k == "$ref":
                if v in processed_refs:
                    # Already followed this ref once; doing so again would
                    # only repeat work (or loop forever on a cycle).
                    continue
                processed_refs.add(v)
                ref = _retrieve_ref(v, full_schema)
                keys.append(v.split("/")[1])
                keys.extend(_infer_skip_keys(ref, full_schema, processed_refs))
            elif isinstance(v, (list, dict)):
                keys.extend(_infer_skip_keys(v, full_schema, processed_refs))
    elif isinstance(obj, list):
        for el in obj:
            keys.extend(_infer_skip_keys(el, full_schema, processed_refs))
    return keys

View File

@ -181,3 +181,54 @@ def test_dereference_refs_integer_ref() -> None:
} }
actual = dereference_refs(schema) actual = dereference_refs(schema)
assert actual == expected assert actual == expected
def test_dereference_refs_cyclical_refs() -> None:
    """A self-referencing ``$defs`` entry is expanded once, then cut off."""

    def user_ref() -> dict:
        # Fresh dict per call so no object is shared between schema branches.
        return {"$ref": "#/$defs/user"}

    def user_definition() -> dict:
        # The "user" definition refers back to itself through "friends".
        return {
            "type": "object",
            "properties": {
                "friends": {"type": "array", "items": user_ref()},
            },
        }

    def expanded_user() -> dict:
        return {
            "type": "object",
            "properties": {
                "friends": {
                    "type": "array",
                    "items": {},  # Recursion is broken here
                }
            },
        }

    schema = {
        "type": "object",
        "properties": {
            "user": user_ref(),
            "customer": user_ref(),
        },
        "$defs": {"user": user_definition()},
    }
    expected = {
        "type": "object",
        "properties": {
            "user": expanded_user(),
            "customer": expanded_user(),
        },
        # The definitions section is expected to come back untouched.
        "$defs": {"user": user_definition()},
    }

    actual = dereference_refs(schema)
    assert actual == expected