mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
langchain-robocorp: Fix parsing of Union types (such as Optional). (#22277)
This commit is contained in:
parent
af1f723ada
commit
fc5909ad6f
@ -1,5 +1,6 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Set, Tuple, Union
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
@ -93,26 +94,45 @@ def get_schema(endpoint_spec: dict) -> dict:
|
||||
)
|
||||
|
||||
|
||||
def create_field(schema: dict, required: bool) -> Tuple[Any, Any]:
|
||||
def create_field(
|
||||
schema: dict, required: bool, created_model_names: Set[str]
|
||||
) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Creates a Pydantic field based on the schema definition.
|
||||
"""
|
||||
if "anyOf" in schema:
|
||||
field_types = [
|
||||
create_field(sub_schema, required, created_model_names)[0]
|
||||
for sub_schema in schema["anyOf"]
|
||||
]
|
||||
if len(field_types) == 1:
|
||||
field_type = field_types[0] # Simplified handling
|
||||
else:
|
||||
field_type = Union[tuple(field_types)]
|
||||
else:
|
||||
field_type = type_mapping.get(schema.get("type", "string"), str)
|
||||
|
||||
description = schema.get("description", "")
|
||||
|
||||
# Handle nested objects
|
||||
if schema["type"] == "object":
|
||||
if schema.get("type") == "object":
|
||||
nested_fields = {
|
||||
k: create_field(v, k in schema.get("required", []))
|
||||
k: create_field(v, k in schema.get("required", []), created_model_names)
|
||||
for k, v in schema.get("properties", {}).items()
|
||||
}
|
||||
model_name = schema.get("title", "NestedModel")
|
||||
model_name = schema.get("title", f"NestedModel{time.time()}")
|
||||
if model_name in created_model_names:
|
||||
# needs to be unique
|
||||
model_name = model_name + str(time.time())
|
||||
nested_model = create_model(model_name, **nested_fields) # type: ignore
|
||||
created_model_names.add(model_name)
|
||||
return nested_model, Field(... if required else None, description=description)
|
||||
|
||||
# Handle arrays
|
||||
elif schema["type"] == "array":
|
||||
item_type, _ = create_field(schema["items"], required=True)
|
||||
elif schema.get("type") == "array":
|
||||
item_type, _ = create_field(
|
||||
schema["items"], required=True, created_model_names=created_model_names
|
||||
)
|
||||
return List[item_type], Field( # type: ignore
|
||||
... if required else None, description=description
|
||||
)
|
||||
@ -128,9 +148,10 @@ def get_param_fields(endpoint_spec: dict) -> dict:
|
||||
required_fields = schema.get("required", [])
|
||||
|
||||
fields = {}
|
||||
created_model_names: Set[str] = set()
|
||||
for key, value in properties.items():
|
||||
is_required = key in required_fields
|
||||
field_info = create_field(value, is_required)
|
||||
field_info = create_field(value, is_required, created_model_names)
|
||||
fields[key] = field_info
|
||||
|
||||
return fields
|
||||
|
1891
libs/partners/robocorp/tests/unit_tests/_openapi3.fixture.json
Normal file
1891
libs/partners/robocorp/tests/unit_tests/_openapi3.fixture.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -3,7 +3,10 @@ import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
|
||||
from langchain_robocorp.toolkits import ActionServerToolkit
|
||||
|
||||
@ -118,3 +121,66 @@ Strictly adhere to the schema."""
|
||||
],
|
||||
}
|
||||
assert params["properties"]["rows_to_add"] == expected
|
||||
|
||||
|
||||
def test_get_tools_with_complex_inputs() -> None:
|
||||
toolkit_instance = ActionServerToolkit(
|
||||
url="http://example.com", api_key="dummy_key"
|
||||
)
|
||||
|
||||
fixture_path = Path(__file__).with_name("_openapi3.fixture.json")
|
||||
|
||||
with patch(
|
||||
"langchain_robocorp.toolkits.requests.get"
|
||||
) as mocked_get, fixture_path.open("r") as f:
|
||||
data = json.load(f) # Using json.load directly on the file object
|
||||
mocked_response = MagicMock()
|
||||
mocked_response.json.return_value = data
|
||||
mocked_response.status_code = 200
|
||||
mocked_response.headers = {"Content-Type": "application/json"}
|
||||
mocked_get.return_value = mocked_response
|
||||
|
||||
# Execute
|
||||
tools = toolkit_instance.get_tools()
|
||||
assert len(tools) == 4
|
||||
|
||||
tool = tools[0]
|
||||
assert tool.name == "create_event"
|
||||
assert tool.description == "Creates a new event in the specified calendar."
|
||||
|
||||
all_tools_as_openai_tools = [convert_to_openai_tool(t) for t in tools]
|
||||
openai_tool_spec = all_tools_as_openai_tools[0]["function"]
|
||||
|
||||
assert isinstance(
|
||||
openai_tool_spec, dict
|
||||
), "openai_func_spec should be a dictionary."
|
||||
assert set(openai_tool_spec.keys()) == {
|
||||
"description",
|
||||
"name",
|
||||
"parameters",
|
||||
}, "Top-level keys mismatch."
|
||||
|
||||
assert openai_tool_spec["description"] == tool.description
|
||||
assert openai_tool_spec["name"] == tool.name
|
||||
|
||||
assert isinstance(
|
||||
openai_tool_spec["parameters"], dict
|
||||
), "Parameters should be a dictionary."
|
||||
|
||||
params = openai_tool_spec["parameters"]
|
||||
assert set(params.keys()) == {
|
||||
"type",
|
||||
"properties",
|
||||
"required",
|
||||
}, "Parameters keys mismatch."
|
||||
assert params["type"] == "object", "`type` in parameters should be 'object'."
|
||||
assert isinstance(
|
||||
params["properties"], dict
|
||||
), "`properties` should be a dictionary."
|
||||
assert isinstance(params["required"], list), "`required` should be a list."
|
||||
|
||||
assert set(params["required"]) == {
|
||||
"event",
|
||||
}, "Required fields mismatch."
|
||||
|
||||
assert set(params["properties"].keys()) == {"calendar_id", "event"}
|
||||
|
Loading…
Reference in New Issue
Block a user