langchain-robocorp: Fix parsing of Union types (such as Optional). (#22277)

Mikko Korpela 2024-05-29 19:47:02 +03:00 committed by GitHub
parent af1f723ada
commit fc5909ad6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 1988 additions and 10 deletions
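
Context for the change: before this commit, `create_field` only looked at a schema's `type` key, so an OpenAPI parameter declared with `anyOf` (the way FastAPI emits `Optional[...]` and other unions) silently fell back to `str`. The patch maps each `anyOf` branch and combines the results into a `typing.Union`. Below is a minimal, self-contained sketch of that mapping; the `type_mapping` table and the `resolve_type` helper are illustrative stand-ins, not the library's own code.

from typing import Any, Union

# Illustrative subset of a JSON-schema-to-Python type table (assumed, not verbatim).
type_mapping = {
    "string": str,
    "integer": int,
    "number": float,
    "boolean": bool,
    "null": type(None),
}


def resolve_type(schema: dict) -> Any:
    """Map a JSON-schema fragment to a Python type, honoring anyOf."""
    if "anyOf" in schema:
        options = [resolve_type(sub) for sub in schema["anyOf"]]
        return options[0] if len(options) == 1 else Union[tuple(options)]
    return type_mapping.get(schema.get("type", "string"), str)


# An Optional[str] parameter as FastAPI-style OpenAPI specs emit it:
optional_str = {"anyOf": [{"type": "string"}, {"type": "null"}]}
assert resolve_type(optional_str) == Union[str, type(None)]  # i.e. Optional[str]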


@@ -1,5 +1,6 @@
+import time
 from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Set, Tuple, Union

 from langchain_core.pydantic_v1 import BaseModel, Field, create_model
 from langchain_core.utils.json_schema import dereference_refs
@@ -93,26 +94,45 @@ def get_schema(endpoint_spec: dict) -> dict:
     )


-def create_field(schema: dict, required: bool) -> Tuple[Any, Any]:
+def create_field(
+    schema: dict, required: bool, created_model_names: Set[str]
+) -> Tuple[Any, Any]:
     """
     Creates a Pydantic field based on the schema definition.
     """
-    field_type = type_mapping.get(schema.get("type", "string"), str)
+    if "anyOf" in schema:
+        field_types = [
+            create_field(sub_schema, required, created_model_names)[0]
+            for sub_schema in schema["anyOf"]
+        ]
+        if len(field_types) == 1:
+            field_type = field_types[0]  # Simplified handling
+        else:
+            field_type = Union[tuple(field_types)]
+    else:
+        field_type = type_mapping.get(schema.get("type", "string"), str)
+
     description = schema.get("description", "")

     # Handle nested objects
-    if schema["type"] == "object":
+    if schema.get("type") == "object":
         nested_fields = {
-            k: create_field(v, k in schema.get("required", []))
+            k: create_field(v, k in schema.get("required", []), created_model_names)
             for k, v in schema.get("properties", {}).items()
         }
-        model_name = schema.get("title", "NestedModel")
+        model_name = schema.get("title", f"NestedModel{time.time()}")
+        if model_name in created_model_names:
+            # needs to be unique
+            model_name = model_name + str(time.time())
         nested_model = create_model(model_name, **nested_fields)  # type: ignore
+        created_model_names.add(model_name)
         return nested_model, Field(... if required else None, description=description)

     # Handle arrays
-    elif schema["type"] == "array":
-        item_type, _ = create_field(schema["items"], required=True)
+    elif schema.get("type") == "array":
+        item_type, _ = create_field(
+            schema["items"], required=True, created_model_names=created_model_names
+        )
         return List[item_type], Field(  # type: ignore
             ... if required else None, description=description
         )
@@ -128,9 +148,10 @@ def get_param_fields(endpoint_spec: dict) -> dict:
     required_fields = schema.get("required", [])

     fields = {}
+    created_model_names: Set[str] = set()
     for key, value in properties.items():
         is_required = key in required_fields
-        field_info = create_field(value, is_required)
+        field_info = create_field(value, is_required, created_model_names)
         fields[key] = field_info

     return fields
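
Beyond the Union handling, the hunks above also thread a `created_model_names` set through `create_field` and `get_param_fields`, so nested objects that share a `title` do not try to register two Pydantic models under the same name: repeats get a time-based suffix. A rough standalone sketch of that naming rule follows; `unique_model_name` is a hypothetical helper, not part of the package.

import time
from typing import Optional, Set


def unique_model_name(title: Optional[str], created: Set[str]) -> str:
    """Pick a model name, suffixing repeats so each created model stays distinct."""
    name = title if title else f"NestedModel{time.time()}"
    if name in created:
        # needs to be unique
        name = name + str(time.time())
    created.add(name)
    return name


seen: Set[str] = set()
first = unique_model_name("Event", seen)
second = unique_model_name("Event", seen)
assert first == "Event" and second != first and second.startswith("Event")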

File diff suppressed because it is too large
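
The suppressed diff is the new `_openapi3.fixture.json` spec (the bulk of the 1988 added lines) that the test below loads. Purely as an illustration of the kind of schema that exercises the new `anyOf` path (this fragment is hypothetical, not copied from the fixture), an action body with one required and one optional property might look like:

# Hypothetical OpenAPI-style body schema (not taken from the fixture) with a
# required "event" object and an optional "calendar_id" expressed via anyOf.
create_event_body = {
    "type": "object",
    "required": ["event"],
    "properties": {
        "event": {"type": "object", "title": "Event", "properties": {}},
        "calendar_id": {
            "anyOf": [{"type": "string"}, {"type": "null"}],
            "description": "Calendar to add the event to.",
        },
    },
}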


@@ -3,7 +3,10 @@ import json
 from pathlib import Path
 from unittest.mock import MagicMock, patch

-from langchain_core.utils.function_calling import convert_to_openai_function
+from langchain_core.utils.function_calling import (
+    convert_to_openai_function,
+    convert_to_openai_tool,
+)

 from langchain_robocorp.toolkits import ActionServerToolkit

@@ -118,3 +121,66 @@ Strictly adhere to the schema."""
         ],
     }
     assert params["properties"]["rows_to_add"] == expected
+
+
+def test_get_tools_with_complex_inputs() -> None:
+    toolkit_instance = ActionServerToolkit(
+        url="http://example.com", api_key="dummy_key"
+    )
+
+    fixture_path = Path(__file__).with_name("_openapi3.fixture.json")
+
+    with patch(
+        "langchain_robocorp.toolkits.requests.get"
+    ) as mocked_get, fixture_path.open("r") as f:
+        data = json.load(f)  # Using json.load directly on the file object
+        mocked_response = MagicMock()
+        mocked_response.json.return_value = data
+        mocked_response.status_code = 200
+        mocked_response.headers = {"Content-Type": "application/json"}
+        mocked_get.return_value = mocked_response
+
+        # Execute
+        tools = toolkit_instance.get_tools()
+        assert len(tools) == 4
+
+        tool = tools[0]
+        assert tool.name == "create_event"
+        assert tool.description == "Creates a new event in the specified calendar."
+
+        all_tools_as_openai_tools = [convert_to_openai_tool(t) for t in tools]
+        openai_tool_spec = all_tools_as_openai_tools[0]["function"]
+
+        assert isinstance(
+            openai_tool_spec, dict
+        ), "openai_func_spec should be a dictionary."
+        assert set(openai_tool_spec.keys()) == {
+            "description",
+            "name",
+            "parameters",
+        }, "Top-level keys mismatch."
+
+        assert openai_tool_spec["description"] == tool.description
+        assert openai_tool_spec["name"] == tool.name
+
+        assert isinstance(
+            openai_tool_spec["parameters"], dict
+        ), "Parameters should be a dictionary."
+
+        params = openai_tool_spec["parameters"]
+        assert set(params.keys()) == {
+            "type",
+            "properties",
+            "required",
+        }, "Parameters keys mismatch."
+
+        assert params["type"] == "object", "`type` in parameters should be 'object'."
+        assert isinstance(
+            params["properties"], dict
+        ), "`properties` should be a dictionary."
+        assert isinstance(params["required"], list), "`required` should be a list."
+        assert set(params["required"]) == {
+            "event",
+        }, "Required fields mismatch."
+
+        assert set(params["properties"].keys()) == {"calendar_id", "event"}