core[patch]: In unit tests, use _schema() instead of BaseModel.schema() (#24930)
This PR introduces a module with helper utilities for the pydantic 1 -> 2 migration. They are meant to be used in the following way:

1) Use the utility code to get the unit tests to pass without modifying them.
2) (If desired) upgrade the unit tests to match the pydantic 2 output.
3) (If desired) stop using the utility code.

Currently, this module contains a way to map the `schema()` output generated by pydantic 2 so that it (mostly) matches the output from pydantic v1.
This commit is contained in:
parent 1827bb4042
commit 75776e4a54
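For reference, a minimal sketch of the intended test-side usage (this snippet is not part of the diff; the `Joke` model mirrors the one in `test_base_model_schema_consistency` below). While `langchain_core.pydantic_v1` still resolves to pydantic v1, `_schema()` simply forwards to `.schema()`; once the internals move to pydantic v2, it remaps the `model_json_schema()` output toward the v1 shape.

```python
from langchain_core.pydantic_v1 import BaseModel

from tests.unit_tests.pydantic_utils import _schema


class Joke(BaseModel):
    setup: str
    punchline: str


# Before this change the tests called Joke.schema() directly;
# now they go through the helper instead.
assert _schema(Joke)["properties"] == {
    "setup": {"title": "Setup", "type": "string"},
    "punchline": {"title": "Punchline", "type": "string"},
}
```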
@@ -9,6 +9,7 @@ from langchain_core.output_parsers.json import (
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.json import parse_json_markdown, parse_partial_json
+from tests.unit_tests.pydantic_utils import _schema

GOOD_JSON = """```json
{
@@ -596,10 +597,10 @@ def test_base_model_schema_consistency() -> None:
setup: str
punchline: str

-initial_joke_schema = {k: v for k, v in Joke.schema().items()}
+initial_joke_schema = {k: v for k, v in _schema(Joke).items()}
SimpleJsonOutputParser(pydantic_object=Joke)
openai_func = convert_to_openai_function(Joke)
-retrieved_joke_schema = {k: v for k, v in Joke.schema().items()}
+retrieved_joke_schema = {k: v for k, v in _schema(Joke).items()}

assert initial_joke_schema == retrieved_joke_schema
assert openai_func.get("name", None) is not None
@@ -29,6 +29,7 @@ from langchain_core.prompts.chat import (
_convert_to_message,
)
from langchain_core.pydantic_v1 import ValidationError
+from tests.unit_tests.pydantic_utils import _schema


@pytest.fixture
@@ -795,14 +796,14 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
assert prompt_all_required.optional_variables == []
with pytest.raises(ValidationError):
prompt_all_required.input_schema(input="")
-assert prompt_all_required.input_schema.schema() == snapshot(name="required")
+assert _schema(prompt_all_required.input_schema) == snapshot(name="required")
prompt_optional = ChatPromptTemplate(
messages=[MessagesPlaceholder("history", optional=True), ("user", "${input}")]
)
# input variables only lists required variables
assert set(prompt_optional.input_variables) == {"input"}
prompt_optional.input_schema(input="") # won't raise error
-assert prompt_optional.input_schema.schema() == snapshot(name="partial")
+assert _schema(prompt_optional.input_schema) == snapshot(name="partial")


def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None:
@@ -7,6 +7,7 @@ import pytest

from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
+from tests.unit_tests.pydantic_utils import _schema


def test_prompt_valid() -> None:
@@ -69,7 +70,7 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(foo="bar") == "This is a bar test."
assert prompt.input_variables == ["foo"]
-assert prompt.input_schema.schema() == {
+assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"title": "Foo", "type": "string"}},
@@ -80,7 +81,7 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test."
assert prompt.input_variables == ["bar", "foo"]
-assert prompt.input_schema.schema() == {
+assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {
@@ -94,7 +95,7 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test bar."
assert prompt.input_variables == ["bar", "foo"]
-assert prompt.input_schema.schema() == {
+assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {
@@ -110,7 +111,7 @@ def test_mustache_prompt_from_template() -> None:
"This foo is a bar test baz."
)
assert prompt.input_variables == ["foo", "obj"]
-assert prompt.input_schema.schema() == {
+assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {
@@ -134,7 +135,7 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(foo="baz") == ("This {'foo': 'baz'} is a test.")
assert prompt.input_variables == []
-assert prompt.input_schema.schema() == {
+assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {},
@@ -151,7 +152,7 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
-assert prompt.input_schema.schema() == {
+assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
@@ -183,7 +184,7 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
-assert prompt.input_schema.schema() == {
+assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
@@ -238,7 +239,7 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
-assert prompt.input_schema.schema() == {
+assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
@@ -286,7 +287,7 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
-assert prompt.input_schema.schema() == {
+assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
@@ -309,7 +310,7 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
-assert prompt.input_schema.schema() == {
+assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"title": "Foo", "type": "object"}},
libs/core/tests/unit_tests/pydantic_utils.py (new file, 94 lines)
@@ -0,0 +1,94 @@
"""Helper utilities for pydantic.

This module includes helper utilities to ease the migration from pydantic v1 to v2.

They're meant to be used in the following way:

1) Use utility code to help (selected) unit tests pass without modifications
2) Upgrade the unit tests to match pydantic 2
3) Stop using the utility code
"""

from typing import Any


# Function to replace allOf with $ref
def _replace_all_of_with_ref(schema: Any) -> None:
    """Replace allOf with $ref in the schema."""
    if isinstance(schema, dict):
        # If the schema has an allOf key with a single item that contains a $ref
        if (
            "allOf" in schema
            and len(schema["allOf"]) == 1
            and "$ref" in schema["allOf"][0]
        ):
            schema["$ref"] = schema["allOf"][0]["$ref"]
            del schema["allOf"]
            if "default" in schema and schema["default"] is None:
                del schema["default"]
        else:
            # Recursively process nested schemas
            for key, value in schema.items():
                if isinstance(value, (dict, list)):
                    _replace_all_of_with_ref(value)
    elif isinstance(schema, list):
        for item in schema:
            _replace_all_of_with_ref(item)


def _remove_bad_none_defaults(schema: Any) -> None:
    """Removing all none defaults.

    Pydantic v1 did not generate these, but Pydantic v2 does.

    The None defaults usually represent **NotRequired** fields, and the None value
    is actually **incorrect** as a value since the fields do not allow a None value.

    See difference between Optional and NotRequired types in python.
    """
    if isinstance(schema, dict):
        for key, value in schema.items():
            if isinstance(value, dict):
                if "default" in value and value["default"] is None:
                    any_of = value.get("anyOf", [])
                    for type_ in any_of:
                        if "type" in type_ and type_["type"] == "null":
                            break  # Null type explicitly defined
                    else:
                        del value["default"]
                _remove_bad_none_defaults(value)
            elif isinstance(value, list):
                for item in value:
                    _remove_bad_none_defaults(item)
    elif isinstance(schema, list):
        for item in schema:
            _remove_bad_none_defaults(item)


def _schema(obj: Any) -> dict:
    """Get the schema of a pydantic model in the pydantic v1 style.

    This will attempt to map the schema as close as possible to the pydantic v1 schema.
    """
    # Remap to old style schema
    if not hasattr(obj, "model_json_schema"):  # V1 model
        return obj.schema()

    # Then we're using V2 models internally.
    raise AssertionError(
        "Hi there! Looks like you're attempting to upgrade to Pydantic v2. If so: \n"
        "1) remove this exception\n"
        "2) confirm that the old unit tests pass, and if not look for difference\n"
        "3) update the unit tests to match the new schema\n"
        "4) remove this utility function\n"
    )

    schema_ = obj.model_json_schema(ref_template="#/definitions/{model}")
    if "$defs" in schema_:
        schema_["definitions"] = schema_["$defs"]
        del schema_["$defs"]

    _replace_all_of_with_ref(schema_)
    _remove_bad_none_defaults(schema_)

    return schema_
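To make the remapping concrete, here is a small illustration (not part of the diff) of what `_replace_all_of_with_ref` does to a pydantic-v2-style property entry that wraps a reference in `allOf` and carries a spurious `None` default:

```python
from tests.unit_tests.pydantic_utils import _replace_all_of_with_ref

schema = {
    "properties": {
        "config": {"allOf": [{"$ref": "#/definitions/Config"}], "default": None}
    }
}
_replace_all_of_with_ref(schema)
# The allOf wrapper is collapsed to a plain $ref and the None default is dropped,
# which matches the shape pydantic v1 used to emit.
assert schema == {"properties": {"config": {"$ref": "#/definitions/Config"}}}
```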
@@ -11,6 +11,7 @@ from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableConfig
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.graph_mermaid import _escape_node_label
+from tests.unit_tests.pydantic_utils import _schema


def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
@@ -18,10 +19,10 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
graph = StrOutputParser().get_graph()
first_node = graph.first_node()
assert first_node is not None
-assert first_node.data.schema() == runnable.input_schema.schema() # type: ignore[union-attr]
+assert _schema(first_node.data) == _schema(runnable.input_schema) # type: ignore[union-attr]
last_node = graph.last_node()
assert last_node is not None
-assert last_node.data.schema() == runnable.output_schema.schema() # type: ignore[union-attr]
+assert _schema(last_node.data) == _schema(runnable.output_schema) # type: ignore[union-attr]
assert len(graph.nodes) == 3
assert len(graph.edges) == 2
assert graph.edges[0].source == first_node.id
@@ -12,6 +12,7 @@ from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
+from tests.unit_tests.pydantic_utils import _schema


def test_interfaces() -> None:
@@ -434,9 +435,8 @@ def test_get_input_schema_input_dict() -> None:
history_messages_key="history",
output_messages_key="output",
)
-assert (
-with_history.get_input_schema().schema()
-== RunnableWithChatHistoryInput.schema()
+assert _schema(with_history.get_input_schema()) == _schema(
+RunnableWithChatHistoryInput
)
@@ -464,9 +464,8 @@ def test_get_input_schema_input_messages() -> None:
with_history = RunnableWithMessageHistory(
runnable, get_session_history, output_messages_key="output"
)
-assert (
-with_history.get_input_schema().schema()
-== RunnableWithChatHistoryInput.schema()
+assert _schema(with_history.get_input_schema()) == _schema(
+RunnableWithChatHistoryInput
)
@@ -89,6 +89,7 @@ from langchain_core.tracers import (
RunLogPatch,
)
from langchain_core.tracers.context import collect_runs
+from tests.unit_tests.pydantic_utils import _schema
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
@@ -224,15 +225,15 @@ class FakeRetriever(BaseRetriever):
def test_schemas(snapshot: SnapshotAssertion) -> None:
fake = FakeRunnable() # str -> int

-assert fake.input_schema.schema() == {
+assert _schema(fake.input_schema) == {
"title": "FakeRunnableInput",
"type": "string",
}
-assert fake.output_schema.schema() == {
+assert _schema(fake.output_schema) == {
"title": "FakeRunnableOutput",
"type": "integer",
}
-assert fake.config_schema(include=["tags", "metadata", "run_name"]).schema() == {
+assert _schema(fake.config_schema(include=["tags", "metadata", "run_name"])) == {
"title": "FakeRunnableConfig",
"type": "object",
"properties": {
@@ -244,22 +245,22 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:

fake_bound = FakeRunnable().bind(a="b") # str -> int

-assert fake_bound.input_schema.schema() == {
+assert _schema(fake_bound.input_schema) == {
"title": "FakeRunnableInput",
"type": "string",
}
-assert fake_bound.output_schema.schema() == {
+assert _schema(fake_bound.output_schema) == {
"title": "FakeRunnableOutput",
"type": "integer",
}

fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,)) # str -> int

-assert fake_w_fallbacks.input_schema.schema() == {
+assert _schema(fake_w_fallbacks.input_schema) == {
"title": "FakeRunnableInput",
"type": "string",
}
-assert fake_w_fallbacks.output_schema.schema() == {
+assert _schema(fake_w_fallbacks.output_schema) == {
"title": "FakeRunnableOutput",
"type": "integer",
}
@@ -269,11 +270,11 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:

typed_lambda = RunnableLambda(typed_lambda_impl) # str -> int

-assert typed_lambda.input_schema.schema() == {
+assert _schema(typed_lambda.input_schema) == {
"title": "typed_lambda_impl_input",
"type": "string",
}
-assert typed_lambda.output_schema.schema() == {
+assert _schema(typed_lambda.output_schema) == {
"title": "typed_lambda_impl_output",
"type": "integer",
}
@@ -283,22 +284,22 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:

typed_async_lambda: Runnable = RunnableLambda(typed_async_lambda_impl) # str -> int

-assert typed_async_lambda.input_schema.schema() == {
+assert _schema(typed_async_lambda.input_schema) == {
"title": "typed_async_lambda_impl_input",
"type": "string",
}
-assert typed_async_lambda.output_schema.schema() == {
+assert _schema(typed_async_lambda.output_schema) == {
"title": "typed_async_lambda_impl_output",
"type": "integer",
}

fake_ret = FakeRetriever() # str -> List[Document]

-assert fake_ret.input_schema.schema() == {
+assert _schema(fake_ret.input_schema) == {
"title": "FakeRetrieverInput",
"type": "string",
}
-assert fake_ret.output_schema.schema() == {
+assert _schema(fake_ret.output_schema) == {
"title": "FakeRetrieverOutput",
"type": "array",
"items": {"$ref": "#/definitions/Document"},
@@ -328,16 +329,16 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:

fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]

-assert fake_llm.input_schema.schema() == snapshot
-assert fake_llm.output_schema.schema() == {
+assert _schema(fake_llm.input_schema) == snapshot
+assert _schema(fake_llm.output_schema) == {
"title": "FakeListLLMOutput",
"type": "string",
}

fake_chat = FakeListChatModel(responses=["a"]) # str -> List[List[str]]

-assert fake_chat.input_schema.schema() == snapshot
-assert fake_chat.output_schema.schema() == snapshot
+assert _schema(fake_chat.input_schema) == snapshot
+assert _schema(fake_chat.output_schema) == snapshot

chat_prompt = ChatPromptTemplate.from_messages(
[
@@ -346,26 +347,26 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
]
)

-assert chat_prompt.input_schema.schema() == snapshot(
+assert _schema(chat_prompt.input_schema) == snapshot(
name="chat_prompt_input_schema"
)
-assert chat_prompt.output_schema.schema() == snapshot(
+assert _schema(chat_prompt.output_schema) == snapshot(
name="chat_prompt_output_schema"
)

prompt = PromptTemplate.from_template("Hello, {name}!")

-assert prompt.input_schema.schema() == {
+assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
"required": ["name"],
}
-assert prompt.output_schema.schema() == snapshot
+assert _schema(prompt.output_schema) == snapshot

prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()

-assert prompt_mapper.input_schema.schema() == {
+assert _schema(prompt_mapper.input_schema) == {
"definitions": {
"PromptInput": {
"properties": {"name": {"title": "Name", "type": "string"}},
@@ -378,12 +379,12 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"type": "array",
"title": "RunnableEach<PromptTemplate>Input",
}
-assert prompt_mapper.output_schema.schema() == snapshot
+assert _schema(prompt_mapper.output_schema) == snapshot

list_parser = CommaSeparatedListOutputParser()

-assert list_parser.input_schema.schema() == snapshot
-assert list_parser.output_schema.schema() == {
+assert _schema(list_parser.input_schema) == snapshot
+assert _schema(list_parser.output_schema) == {
"title": "CommaSeparatedListOutputParserOutput",
"type": "array",
"items": {"type": "string"},
@@ -391,13 +392,13 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:

seq = prompt | fake_llm | list_parser

-assert seq.input_schema.schema() == {
+assert _schema(seq.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
"required": ["name"],
}
-assert seq.output_schema.schema() == {
+assert _schema(seq.output_schema) == {
"type": "array",
"items": {"type": "string"},
"title": "CommaSeparatedListOutputParserOutput",
@@ -405,7 +406,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:

router: Runnable = RouterRunnable({})

-assert router.input_schema.schema() == {
+assert _schema(router.input_schema) == {
"title": "RouterRunnableInput",
"$ref": "#/definitions/RouterInput",
"definitions": {
@@ -420,7 +421,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
}
},
}
-assert router.output_schema.schema() == {"title": "RouterRunnableOutput"}
+assert _schema(router.output_schema) == {"title": "RouterRunnableOutput"}

seq_w_map: Runnable = (
prompt
@@ -432,13 +433,13 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
}
)

-assert seq_w_map.input_schema.schema() == {
+assert _schema(seq_w_map.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
"required": ["name"],
}
-assert seq_w_map.output_schema.schema() == {
+assert _schema(seq_w_map.output_schema) == {
"title": "RunnableParallel<original,as_list,length>Output",
"type": "object",
"properties": {
@@ -464,12 +465,12 @@ def test_passthrough_assign_schema() -> None:
| fake_llm
)

-assert seq_w_assign.input_schema.schema() == {
+assert _schema(seq_w_assign.input_schema) == {
"properties": {"question": {"title": "Question", "type": "string"}},
"title": "RunnableSequenceInput",
"type": "object",
}
-assert seq_w_assign.output_schema.schema() == {
+assert _schema(seq_w_assign.output_schema) == {
"title": "FakeListLLMOutput",
"type": "string",
}
@@ -481,7 +482,7 @@ def test_passthrough_assign_schema() -> None:

# fallback to RunnableAssign.input_schema if next runnable doesn't have
# expected dict input_schema
-assert invalid_seq_w_assign.input_schema.schema() == {
+assert _schema(invalid_seq_w_assign.input_schema) == {
"properties": {"question": {"title": "Question"}},
"title": "RunnableParallel<context>Input",
"type": "object",
@@ -493,14 +494,14 @@ def test_passthrough_assign_schema() -> None:
)
def test_lambda_schemas() -> None:
first_lambda = lambda x: x["hello"] # noqa: E731
-assert RunnableLambda(first_lambda).input_schema.schema() == {
+assert _schema(RunnableLambda(first_lambda).input_schema) == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}},
}

second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731
-assert RunnableLambda(second_lambda).input_schema.schema() == { # type: ignore[arg-type]
+assert _schema(RunnableLambda(second_lambda).input_schema) == { # type: ignore[arg-type]
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
@@ -509,7 +510,7 @@ def test_lambda_schemas() -> None:
def get_value(input): # type: ignore[no-untyped-def]
return input["variable_name"]

-assert RunnableLambda(get_value).input_schema.schema() == {
+assert _schema(RunnableLambda(get_value).input_schema) == {
"title": "get_value_input",
"type": "object",
"properties": {"variable_name": {"title": "Variable Name"}},
@@ -518,7 +519,7 @@ def test_lambda_schemas() -> None:
async def aget_value(input): # type: ignore[no-untyped-def]
return (input["variable_name"], input.get("another"))

-assert RunnableLambda(aget_value).input_schema.schema() == {
+assert _schema(RunnableLambda(aget_value).input_schema) == {
"title": "aget_value_input",
"type": "object",
"properties": {
@@ -534,7 +535,7 @@ def test_lambda_schemas() -> None:
"byebye": input["yo"],
}

-assert RunnableLambda(aget_values).input_schema.schema() == {
+assert _schema(RunnableLambda(aget_values).input_schema) == {
"title": "aget_values_input",
"type": "object",
"properties": {
@@ -560,9 +561,11 @@ def test_lambda_schemas() -> None:
}

assert (
-RunnableLambda(
-aget_values_typed # type: ignore[arg-type]
-).input_schema.schema()
+_schema(
+RunnableLambda(
+aget_values_typed # type: ignore[arg-type]
+).input_schema
+)
== {
"title": "aget_values_typed_input",
"$ref": "#/definitions/InputType",
@@ -583,7 +586,7 @@ def test_lambda_schemas() -> None:
}
)

-assert RunnableLambda(aget_values_typed).output_schema.schema() == { # type: ignore[arg-type]
+assert _schema(RunnableLambda(aget_values_typed).output_schema) == { # type: ignore[arg-type]
"title": "aget_values_typed_output",
"$ref": "#/definitions/OutputType",
"definitions": {
@@ -640,7 +643,7 @@ def test_schema_complex_seq() -> None:
| StrOutputParser()
)

-assert chain2.input_schema.schema() == {
+assert _schema(chain2.input_schema) == {
"title": "RunnableParallel<city,language>Input",
"type": "object",
"properties": {
@@ -649,17 +652,17 @@ def test_schema_complex_seq() -> None:
},
}

-assert chain2.output_schema.schema() == {
+assert _schema(chain2.output_schema) == {
"title": "StrOutputParserOutput",
"type": "string",
}

-assert chain2.with_types(input_type=str).input_schema.schema() == {
+assert _schema(chain2.with_types(input_type=str).input_schema) == {
"title": "RunnableSequenceInput",
"type": "string",
}

-assert chain2.with_types(input_type=int).output_schema.schema() == {
+assert _schema(chain2.with_types(input_type=int).output_schema) == {
"title": "StrOutputParserOutput",
"type": "string",
}
@@ -667,7 +670,7 @@ def test_schema_complex_seq() -> None:
class InputType(BaseModel):
person: str

-assert chain2.with_types(input_type=InputType).input_schema.schema() == {
+assert _schema(chain2.with_types(input_type=InputType).input_schema) == {
"title": "InputType",
"type": "object",
"properties": {"person": {"title": "Person", "type": "string"}},
@@ -690,7 +693,7 @@ def test_configurable_fields() -> None:

assert fake_llm_configurable.invoke("...") == "a"

-assert fake_llm_configurable.config_schema().schema() == {
+assert _schema(fake_llm_configurable.config_schema()) == {
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@@ -733,7 +736,7 @@ def test_configurable_fields() -> None:
text="Hello, John!"
)

-assert prompt_configurable.config_schema().schema() == {
+assert _schema(prompt_configurable.config_schema()) == {
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@@ -761,9 +764,11 @@ def test_configurable_fields() -> None:
text="Hello, John! John!"
)

-assert prompt_configurable.with_config(
-configurable={"prompt_template": "Hello {name} in {lang}"}
-).input_schema.schema() == {
+assert _schema(
+prompt_configurable.with_config(
+configurable={"prompt_template": "Hello {name} in {lang}"}
+).input_schema
+) == {
"title": "PromptInput",
"type": "object",
"properties": {
@@ -777,7 +782,7 @@ def test_configurable_fields() -> None:

assert chain_configurable.invoke({"name": "John"}) == "a"

-assert chain_configurable.config_schema().schema() == {
+assert _schema(chain_configurable.config_schema()) == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@@ -814,12 +819,14 @@ def test_configurable_fields() -> None:
== "c"
)

-assert chain_configurable.with_config(
-configurable={
-"prompt_template": "A very good morning to you, {name} {lang}!",
-"llm_responses": ["c"],
-}
-).input_schema.schema() == {
+assert _schema(
+chain_configurable.with_config(
+configurable={
+"prompt_template": "A very good morning to you, {name} {lang}!",
+"llm_responses": ["c"],
+}
+).input_schema
+) == {
"title": "PromptInput",
"type": "object",
"properties": {
@@ -844,7 +851,7 @@ def test_configurable_fields() -> None:
"llm3": "a",
}

-assert chain_with_map_configurable.config_schema().schema() == {
+assert _schema(chain_with_map_configurable.config_schema()) == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@@ -945,7 +952,7 @@ def test_configurable_fields_prefix_keys() -> None:

chain = prompt | fake_llm

-assert chain.config_schema().schema() == {
+assert _schema(chain.config_schema()) == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@@ -1055,7 +1062,7 @@ def test_configurable_fields_example() -> None:

assert chain_configurable.invoke({"name": "John"}) == "a"

-assert chain_configurable.config_schema().schema() == {
+assert _schema(chain_configurable.config_schema()) == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@@ -3136,7 +3143,7 @@ def test_map_stream() -> None:

chain_pick_one = chain.pick("llm")

-assert chain_pick_one.output_schema.schema() == {
+assert _schema(chain_pick_one.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "string",
}
@@ -3159,7 +3166,7 @@ def test_map_stream() -> None:
["llm", "hello"]
)

-assert chain_pick_two.output_schema.schema() == {
+assert _schema(chain_pick_two.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "object",
"properties": {
@@ -3524,13 +3531,13 @@ def test_deep_stream_assign() -> None:

chain_with_assign = chain.assign(hello=itemgetter("str") | llm)

-assert chain_with_assign.input_schema.schema() == {
+assert _schema(chain_with_assign.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question", "type": "string"}},
"required": ["question"],
}
-assert chain_with_assign.output_schema.schema() == {
+assert _schema(chain_with_assign.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "object",
"properties": {
@@ -3575,13 +3582,13 @@ def test_deep_stream_assign() -> None:
hello=itemgetter("str") | llm,
)

-assert chain_with_assign_shadow.input_schema.schema() == {
+assert _schema(chain_with_assign_shadow.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question", "type": "string"}},
"required": ["question"],
}
-assert chain_with_assign_shadow.output_schema.schema() == {
+assert _schema(chain_with_assign_shadow.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "object",
"properties": {
@@ -3650,13 +3657,13 @@ async def test_deep_astream_assign() -> None:
hello=itemgetter("str") | llm,
)

-assert chain_with_assign.input_schema.schema() == {
+assert _schema(chain_with_assign.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question", "type": "string"}},
"required": ["question"],
}
-assert chain_with_assign.output_schema.schema() == {
+assert _schema(chain_with_assign.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "object",
"properties": {
@@ -3701,13 +3708,13 @@ async def test_deep_astream_assign() -> None:
hello=itemgetter("str") | llm,
)

-assert chain_with_assign_shadow.input_schema.schema() == {
+assert _schema(chain_with_assign_shadow.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question", "type": "string"}},
"required": ["question"],
}
-assert chain_with_assign_shadow.output_schema.schema() == {
+assert _schema(chain_with_assign_shadow.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "object",
"properties": {
@@ -4355,7 +4362,7 @@ def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
assert isinstance(body, Runnable)

assert isinstance(runnable.default, Runnable)
-assert runnable.input_schema.schema() == {"title": "RunnableBranchInput"}
+assert _schema(runnable.input_schema) == {"title": "RunnableBranchInput"}


def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
@@ -4702,8 +4709,8 @@ async def test_tool_from_runnable() -> None:
{"question": "What up"}
)
assert chain_tool.description.endswith(repr(chain))
-assert chain_tool.args_schema.schema() == chain.input_schema.schema()
-assert chain_tool.args_schema.schema() == {
+assert _schema(chain_tool.args_schema) == _schema(chain.input_schema)
+assert _schema(chain_tool.args_schema) == {
"properties": {"question": {"title": "Question", "type": "string"}},
"title": "PromptInput",
"type": "object",
@@ -4721,8 +4728,8 @@ async def test_runnable_gen() -> None:

runnable = RunnableGenerator(gen)

-assert runnable.input_schema.schema() == {"title": "gen_input"}
-assert runnable.output_schema.schema() == {
+assert _schema(runnable.input_schema) == {"title": "gen_input"}
+assert _schema(runnable.output_schema) == {
"title": "gen_output",
"type": "integer",
}
@@ -4773,8 +4780,8 @@ async def test_runnable_gen_context_config() -> None:

runnable = RunnableGenerator(gen)

-assert runnable.input_schema.schema() == {"title": "gen_input"}
-assert runnable.output_schema.schema() == {
+assert _schema(runnable.input_schema) == {"title": "gen_input"}
+assert _schema(runnable.output_schema) == {
"title": "gen_output",
"type": "integer",
}
@@ -4907,11 +4914,11 @@ async def test_runnable_iter_context_config() -> None:
yield fake.invoke(input * 2)
yield fake.invoke(input * 3)

-assert gen.input_schema.schema() == {
+assert _schema(gen.input_schema) == {
"title": "gen_input",
"type": "string",
}
-assert gen.output_schema.schema() == {
+assert _schema(gen.output_schema) == {
"title": "gen_output",
"type": "integer",
}
@@ -4958,11 +4965,11 @@ async def test_runnable_iter_context_config() -> None:
yield await fake.ainvoke(input * 2)
yield await fake.ainvoke(input * 3)

-assert agen.input_schema.schema() == {
+assert _schema(agen.input_schema) == {
"title": "agen_input",
"type": "string",
}
-assert agen.output_schema.schema() == {
+assert _schema(agen.output_schema) == {
"title": "agen_output",
"type": "integer",
}
@@ -5025,8 +5032,8 @@ async def test_runnable_lambda_context_config() -> None:
output += fake.invoke(input * 3)
return output

-assert fun.input_schema.schema() == {"title": "fun_input", "type": "string"}
-assert fun.output_schema.schema() == {
+assert _schema(fun.input_schema) == {"title": "fun_input", "type": "string"}
+assert _schema(fun.output_schema) == {
"title": "fun_output",
"type": "integer",
}
@@ -5074,8 +5081,8 @@ async def test_runnable_lambda_context_config() -> None:
output += await fake.ainvoke(input * 3)
return output

-assert afun.input_schema.schema() == {"title": "afun_input", "type": "string"}
-assert afun.output_schema.schema() == {
+assert _schema(afun.input_schema) == {"title": "afun_input", "type": "string"}
+assert _schema(afun.output_schema) == {
"title": "afun_output",
"type": "integer",
}
@@ -5136,19 +5143,19 @@ async def test_runnable_gen_transform() -> None:
chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one

-assert chain.input_schema.schema() == {
+assert _schema(chain.input_schema) == {
"title": "gen_indexes_input",
"type": "integer",
}
-assert chain.output_schema.schema() == {
+assert _schema(chain.output_schema) == {
"title": "plus_one_output",
"type": "integer",
}
-assert achain.input_schema.schema() == {
+assert _schema(achain.input_schema) == {
"title": "gen_indexes_input",
"type": "integer",
}
-assert achain.output_schema.schema() == {
+assert _schema(achain.output_schema) == {
"title": "aplus_one_output",
"type": "integer",
}
@@ -39,6 +39,7 @@ from langchain_core.tools import (
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.pydantic import _create_subset_model
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
+from tests.unit_tests.pydantic_utils import _schema


def test_unnamed_decorator() -> None:
@@ -166,8 +167,8 @@ def test_decorated_function_schema_equivalent() -> None:
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.args_schema is not None
assert (
-structured_tool_input.args_schema.schema()["properties"]
-== _MockSchema.schema()["properties"]
+_schema(structured_tool_input.args_schema)["properties"]
+== _schema(_MockSchema)["properties"]
== structured_tool_input.args
)
@@ -336,7 +337,7 @@ def test_structured_tool_from_function_docstring() -> None:
"baz": {"title": "Baz", "type": "string"},
}

-assert structured_tool.args_schema.schema() == {
+assert _schema(structured_tool.args_schema) == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
@@ -374,7 +375,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
},
}

-assert structured_tool.args_schema.schema() == {
+assert _schema(structured_tool.args_schema) == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {
@@ -479,7 +480,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
"baz": {"title": "Baz", "type": "string"},
}

-assert structured_tool.args_schema.schema() == {
+assert _schema(structured_tool.args_schema) == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
@@ -716,7 +717,7 @@ def test_structured_tool_from_function() -> None:
"baz": {"title": "Baz", "type": "string"},
}

-assert structured_tool.args_schema.schema() == {
+assert _schema(structured_tool.args_schema) == {
"title": "fooSchema",
"type": "object",
"description": inspect.getdoc(foo),
@@ -863,9 +864,9 @@ def test_optional_subset_model_rewrite() -> None:

model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"])

-assert "a" not in model2.schema()["required"] # should be optional
-assert "b" in model2.schema()["required"] # should be required
-assert "c" not in model2.schema()["required"] # should be optional
+assert "a" not in _schema(model2)["required"] # should be optional
+assert "b" in _schema(model2)["required"] # should be required
+assert "c" not in _schema(model2)["required"] # should be optional


@pytest.mark.parametrize(
@@ -1043,7 +1044,7 @@ def test_tool_arg_descriptions() -> None:
return bar

foo1 = tool(foo)
-args_schema = foo1.args_schema.schema() # type: ignore
+args_schema = _schema(foo1.args_schema) # type: ignore
assert args_schema == {
"title": "fooSchema",
"type": "object",
@@ -1057,7 +1058,7 @@ def test_tool_arg_descriptions() -> None:

# Test parses docstring
foo2 = tool(foo, parse_docstring=True)
-args_schema = foo2.args_schema.schema() # type: ignore
+args_schema = _schema(foo2.args_schema) # type: ignore
expected = {
"title": "fooSchema",
"description": "The foo.",
@@ -1083,7 +1084,7 @@ def test_tool_arg_descriptions() -> None:
return bar

as_tool = tool(foo3, parse_docstring=True)
-args_schema = as_tool.args_schema.schema() # type: ignore
+args_schema = _schema(as_tool.args_schema) # type: ignore
assert args_schema["description"] == expected["description"]
assert args_schema["properties"] == expected["properties"]
@@ -1094,7 +1095,7 @@ def test_tool_arg_descriptions() -> None:
return "bar"

as_tool = tool(foo4, parse_docstring=True)
-args_schema = as_tool.args_schema.schema() # type: ignore
+args_schema = _schema(as_tool.args_schema) # type: ignore
assert args_schema["description"] == expected["description"]

def foo5(run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
@@ -1102,7 +1103,7 @@ def test_tool_arg_descriptions() -> None:
return "bar"

as_tool = tool(foo5, parse_docstring=True)
-args_schema = as_tool.args_schema.schema() # type: ignore
+args_schema = _schema(as_tool.args_schema) # type: ignore
assert args_schema["description"] == expected["description"]
@@ -1146,7 +1147,7 @@ def test_tool_annotated_descriptions() -> None:
return bar

foo1 = tool(foo)
-args_schema = foo1.args_schema.schema() # type: ignore
+args_schema = _schema(foo1.args_schema) # type: ignore
assert args_schema == {
"title": "fooSchema",
"type": "object",
@@ -1239,7 +1240,7 @@ def test_convert_from_runnable_dict() -> None:
as_tool = runnable.as_tool()
args_schema = as_tool.args_schema
assert args_schema is not None
-assert args_schema.schema() == {
+assert _schema(args_schema) == {
"title": "f",
"type": "object",
"properties": {
@@ -1370,7 +1371,7 @@ def injected_tool_with_schema(x: int, y: str) -> str:

@pytest.mark.parametrize("tool_", [InjectedTool()])
def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
-assert tool_.get_input_schema().schema() == {
+assert _schema(tool_.get_input_schema()) == {
"title": "fooSchema",
"description": "foo.\n\nArgs:\n x: abc\n y: 123",
"type": "object",
@@ -1380,7 +1381,7 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
},
"required": ["x", "y"],
}
-assert tool_.tool_call_schema.schema() == {
+assert _schema(tool_.tool_call_schema) == {
"title": "foo",
"description": "foo.",
"type": "object",
@@ -1413,7 +1414,7 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
[injected_tool, injected_tool_with_schema, InjectedToolWithSchema()],
)
def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
-assert tool_.get_input_schema().schema() == {
+assert _schema(tool_.get_input_schema()) == {
"title": "fooSchema",
"description": "foo.",
"type": "object",
@@ -1423,7 +1424,7 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
},
"required": ["x", "y"],
}
-assert tool_.tool_call_schema.schema() == {
+assert _schema(tool_.tool_call_schema) == {
"title": "foo",
"description": "foo.",
"type": "object",