|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
import time
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from typing import Any, Dict, List, Tuple, Union
|
|
|
|
|
from typing import Any, Dict, List, Set, Tuple, Union
|
|
|
|
|
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
|
|
|
|
from langchain_core.utils.json_schema import dereference_refs
|
|
|
|
@ -93,26 +94,45 @@ def get_schema(endpoint_spec: dict) -> dict:
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_field(schema: dict, required: bool) -> Tuple[Any, Any]:
|
|
|
|
|
def create_field(
|
|
|
|
|
schema: dict, required: bool, created_model_names: Set[str]
|
|
|
|
|
) -> Tuple[Any, Any]:
|
|
|
|
|
"""
|
|
|
|
|
Creates a Pydantic field based on the schema definition.
|
|
|
|
|
"""
|
|
|
|
|
field_type = type_mapping.get(schema.get("type", "string"), str)
|
|
|
|
|
if "anyOf" in schema:
|
|
|
|
|
field_types = [
|
|
|
|
|
create_field(sub_schema, required, created_model_names)[0]
|
|
|
|
|
for sub_schema in schema["anyOf"]
|
|
|
|
|
]
|
|
|
|
|
if len(field_types) == 1:
|
|
|
|
|
field_type = field_types[0] # Simplified handling
|
|
|
|
|
else:
|
|
|
|
|
field_type = Union[tuple(field_types)]
|
|
|
|
|
else:
|
|
|
|
|
field_type = type_mapping.get(schema.get("type", "string"), str)
|
|
|
|
|
|
|
|
|
|
description = schema.get("description", "")
|
|
|
|
|
|
|
|
|
|
# Handle nested objects
|
|
|
|
|
if schema["type"] == "object":
|
|
|
|
|
if schema.get("type") == "object":
|
|
|
|
|
nested_fields = {
|
|
|
|
|
k: create_field(v, k in schema.get("required", []))
|
|
|
|
|
k: create_field(v, k in schema.get("required", []), created_model_names)
|
|
|
|
|
for k, v in schema.get("properties", {}).items()
|
|
|
|
|
}
|
|
|
|
|
model_name = schema.get("title", "NestedModel")
|
|
|
|
|
model_name = schema.get("title", f"NestedModel{time.time()}")
|
|
|
|
|
if model_name in created_model_names:
|
|
|
|
|
# needs to be unique
|
|
|
|
|
model_name = model_name + str(time.time())
|
|
|
|
|
nested_model = create_model(model_name, **nested_fields) # type: ignore
|
|
|
|
|
created_model_names.add(model_name)
|
|
|
|
|
return nested_model, Field(... if required else None, description=description)
|
|
|
|
|
|
|
|
|
|
# Handle arrays
|
|
|
|
|
elif schema["type"] == "array":
|
|
|
|
|
item_type, _ = create_field(schema["items"], required=True)
|
|
|
|
|
elif schema.get("type") == "array":
|
|
|
|
|
item_type, _ = create_field(
|
|
|
|
|
schema["items"], required=True, created_model_names=created_model_names
|
|
|
|
|
)
|
|
|
|
|
return List[item_type], Field( # type: ignore
|
|
|
|
|
... if required else None, description=description
|
|
|
|
|
)
|
|
|
|
@ -128,9 +148,10 @@ def get_param_fields(endpoint_spec: dict) -> dict:
|
|
|
|
|
required_fields = schema.get("required", [])
|
|
|
|
|
|
|
|
|
|
fields = {}
|
|
|
|
|
created_model_names: Set[str] = set()
|
|
|
|
|
for key, value in properties.items():
|
|
|
|
|
is_required = key in required_fields
|
|
|
|
|
field_info = create_field(value, is_required)
|
|
|
|
|
field_info = create_field(value, is_required, created_model_names)
|
|
|
|
|
fields[key] = field_info
|
|
|
|
|
|
|
|
|
|
return fields
|
|
|
|
|