core[patch]: In unit tests, use _schema() instead of BaseModel.schema() (#24930)

This PR introduces a module with some helper utilities for the pydantic
1 -> 2 migration.

They're meant to be used in the following way:

1) Use the utility code to get unit tests pass without requiring
modification to the unit tests
2) (If desired) upgrade the unit tests to match pydantic 2 output
3) (If desired) stop using the utility code

Currently, this module contains a way to map `schema()` generated by
pydantic 2 to (mostly) match the output from pydantic v1.
This commit is contained in:
Eugene Yurtsev 2024-08-01 11:59:04 -04:00 committed by GitHub
parent 1827bb4042
commit 75776e4a54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 239 additions and 134 deletions

View File

@ -9,6 +9,7 @@ from langchain_core.output_parsers.json import (
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.json import parse_json_markdown, parse_partial_json
from tests.unit_tests.pydantic_utils import _schema
GOOD_JSON = """```json
{
@ -596,10 +597,10 @@ def test_base_model_schema_consistency() -> None:
setup: str
punchline: str
initial_joke_schema = {k: v for k, v in Joke.schema().items()}
initial_joke_schema = {k: v for k, v in _schema(Joke).items()}
SimpleJsonOutputParser(pydantic_object=Joke)
openai_func = convert_to_openai_function(Joke)
retrieved_joke_schema = {k: v for k, v in Joke.schema().items()}
retrieved_joke_schema = {k: v for k, v in _schema(Joke).items()}
assert initial_joke_schema == retrieved_joke_schema
assert openai_func.get("name", None) is not None

View File

@ -29,6 +29,7 @@ from langchain_core.prompts.chat import (
_convert_to_message,
)
from langchain_core.pydantic_v1 import ValidationError
from tests.unit_tests.pydantic_utils import _schema
@pytest.fixture
@ -795,14 +796,14 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
assert prompt_all_required.optional_variables == []
with pytest.raises(ValidationError):
prompt_all_required.input_schema(input="")
assert prompt_all_required.input_schema.schema() == snapshot(name="required")
assert _schema(prompt_all_required.input_schema) == snapshot(name="required")
prompt_optional = ChatPromptTemplate(
messages=[MessagesPlaceholder("history", optional=True), ("user", "${input}")]
)
# input variables only lists required variables
assert set(prompt_optional.input_variables) == {"input"}
prompt_optional.input_schema(input="") # won't raise error
assert prompt_optional.input_schema.schema() == snapshot(name="partial")
assert _schema(prompt_optional.input_schema) == snapshot(name="partial")
def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None:

View File

@ -7,6 +7,7 @@ import pytest
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
from tests.unit_tests.pydantic_utils import _schema
def test_prompt_valid() -> None:
@ -69,7 +70,7 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(foo="bar") == "This is a bar test."
assert prompt.input_variables == ["foo"]
assert prompt.input_schema.schema() == {
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"title": "Foo", "type": "string"}},
@ -80,7 +81,7 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test."
assert prompt.input_variables == ["bar", "foo"]
assert prompt.input_schema.schema() == {
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {
@ -94,7 +95,7 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test bar."
assert prompt.input_variables == ["bar", "foo"]
assert prompt.input_schema.schema() == {
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {
@ -110,7 +111,7 @@ def test_mustache_prompt_from_template() -> None:
"This foo is a bar test baz."
)
assert prompt.input_variables == ["foo", "obj"]
assert prompt.input_schema.schema() == {
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {
@ -134,7 +135,7 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(foo="baz") == ("This {'foo': 'baz'} is a test.")
assert prompt.input_variables == []
assert prompt.input_schema.schema() == {
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {},
@ -151,7 +152,7 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert prompt.input_schema.schema() == {
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
@ -183,7 +184,7 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert prompt.input_schema.schema() == {
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
@ -238,7 +239,7 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert prompt.input_schema.schema() == {
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
@ -286,7 +287,7 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert prompt.input_schema.schema() == {
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
@ -309,7 +310,7 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert prompt.input_schema.schema() == {
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"title": "Foo", "type": "object"}},

View File

@ -0,0 +1,94 @@
"""Helper utilities for pydantic.
This module includes helper utilities to ease the migration from pydantic v1 to v2.
They're meant to be used in the following way:
1) Use utility code to help (selected) unit tests pass without modifications
2) Upgrade the unit tests to match pydantic 2
3) Stop using the utility code
"""
from typing import Any
# Function to replace allOf with $ref
def _replace_all_of_with_ref(schema: Any) -> None:
"""Replace allOf with $ref in the schema."""
if isinstance(schema, dict):
# If the schema has an allOf key with a single item that contains a $ref
if (
"allOf" in schema
and len(schema["allOf"]) == 1
and "$ref" in schema["allOf"][0]
):
schema["$ref"] = schema["allOf"][0]["$ref"]
del schema["allOf"]
if "default" in schema and schema["default"] is None:
del schema["default"]
else:
# Recursively process nested schemas
for key, value in schema.items():
if isinstance(value, (dict, list)):
_replace_all_of_with_ref(value)
elif isinstance(schema, list):
for item in schema:
_replace_all_of_with_ref(item)
def _remove_bad_none_defaults(schema: Any) -> None:
"""Removing all none defaults.
Pydantic v1 did not generate these, but Pydantic v2 does.
The None defaults usually represent **NotRequired** fields, and the None value
is actually **incorrect** as a value since the fields do not allow a None value.
See difference between Optional and NotRequired types in python.
"""
if isinstance(schema, dict):
for key, value in schema.items():
if isinstance(value, dict):
if "default" in value and value["default"] is None:
any_of = value.get("anyOf", [])
for type_ in any_of:
if "type" in type_ and type_["type"] == "null":
break # Null type explicitly defined
else:
del value["default"]
_remove_bad_none_defaults(value)
elif isinstance(value, list):
for item in value:
_remove_bad_none_defaults(item)
elif isinstance(schema, list):
for item in schema:
_remove_bad_none_defaults(item)
def _schema(obj: Any) -> dict:
"""Get the schema of a pydantic model in the pydantic v1 style.
This will attempt to map the schema as close as possible to the pydantic v1 schema.
"""
# Remap to old style schema
if not hasattr(obj, "model_json_schema"): # V1 model
return obj.schema()
# Then we're using V2 models internally.
raise AssertionError(
"Hi there! Looks like you're attempting to upgrade to Pydantic v2. If so: \n"
"1) remove this exception\n"
"2) confirm that the old unit tests pass, and if not look for difference\n"
"3) update the unit tests to match the new schema\n"
"4) remove this utility function\n"
)
schema_ = obj.model_json_schema(ref_template="#/definitions/{model}")
if "$defs" in schema_:
schema_["definitions"] = schema_["$defs"]
del schema_["$defs"]
_replace_all_of_with_ref(schema_)
_remove_bad_none_defaults(schema_)
return schema_

View File

@ -11,6 +11,7 @@ from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableConfig
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.graph_mermaid import _escape_node_label
from tests.unit_tests.pydantic_utils import _schema
def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
@ -18,10 +19,10 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
graph = StrOutputParser().get_graph()
first_node = graph.first_node()
assert first_node is not None
assert first_node.data.schema() == runnable.input_schema.schema() # type: ignore[union-attr]
assert _schema(first_node.data) == _schema(runnable.input_schema) # type: ignore[union-attr]
last_node = graph.last_node()
assert last_node is not None
assert last_node.data.schema() == runnable.output_schema.schema() # type: ignore[union-attr]
assert _schema(last_node.data) == _schema(runnable.output_schema) # type: ignore[union-attr]
assert len(graph.nodes) == 3
assert len(graph.edges) == 2
assert graph.edges[0].source == first_node.id

View File

@ -12,6 +12,7 @@ from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from tests.unit_tests.pydantic_utils import _schema
def test_interfaces() -> None:
@ -434,9 +435,8 @@ def test_get_input_schema_input_dict() -> None:
history_messages_key="history",
output_messages_key="output",
)
assert (
with_history.get_input_schema().schema()
== RunnableWithChatHistoryInput.schema()
assert _schema(with_history.get_input_schema()) == _schema(
RunnableWithChatHistoryInput
)
@ -464,9 +464,8 @@ def test_get_input_schema_input_messages() -> None:
with_history = RunnableWithMessageHistory(
runnable, get_session_history, output_messages_key="output"
)
assert (
with_history.get_input_schema().schema()
== RunnableWithChatHistoryInput.schema()
assert _schema(with_history.get_input_schema()) == _schema(
RunnableWithChatHistoryInput
)

View File

@ -89,6 +89,7 @@ from langchain_core.tracers import (
RunLogPatch,
)
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.pydantic_utils import _schema
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
@ -224,15 +225,15 @@ class FakeRetriever(BaseRetriever):
def test_schemas(snapshot: SnapshotAssertion) -> None:
fake = FakeRunnable() # str -> int
assert fake.input_schema.schema() == {
assert _schema(fake.input_schema) == {
"title": "FakeRunnableInput",
"type": "string",
}
assert fake.output_schema.schema() == {
assert _schema(fake.output_schema) == {
"title": "FakeRunnableOutput",
"type": "integer",
}
assert fake.config_schema(include=["tags", "metadata", "run_name"]).schema() == {
assert _schema(fake.config_schema(include=["tags", "metadata", "run_name"])) == {
"title": "FakeRunnableConfig",
"type": "object",
"properties": {
@ -244,22 +245,22 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
fake_bound = FakeRunnable().bind(a="b") # str -> int
assert fake_bound.input_schema.schema() == {
assert _schema(fake_bound.input_schema) == {
"title": "FakeRunnableInput",
"type": "string",
}
assert fake_bound.output_schema.schema() == {
assert _schema(fake_bound.output_schema) == {
"title": "FakeRunnableOutput",
"type": "integer",
}
fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,)) # str -> int
assert fake_w_fallbacks.input_schema.schema() == {
assert _schema(fake_w_fallbacks.input_schema) == {
"title": "FakeRunnableInput",
"type": "string",
}
assert fake_w_fallbacks.output_schema.schema() == {
assert _schema(fake_w_fallbacks.output_schema) == {
"title": "FakeRunnableOutput",
"type": "integer",
}
@ -269,11 +270,11 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
typed_lambda = RunnableLambda(typed_lambda_impl) # str -> int
assert typed_lambda.input_schema.schema() == {
assert _schema(typed_lambda.input_schema) == {
"title": "typed_lambda_impl_input",
"type": "string",
}
assert typed_lambda.output_schema.schema() == {
assert _schema(typed_lambda.output_schema) == {
"title": "typed_lambda_impl_output",
"type": "integer",
}
@ -283,22 +284,22 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
typed_async_lambda: Runnable = RunnableLambda(typed_async_lambda_impl) # str -> int
assert typed_async_lambda.input_schema.schema() == {
assert _schema(typed_async_lambda.input_schema) == {
"title": "typed_async_lambda_impl_input",
"type": "string",
}
assert typed_async_lambda.output_schema.schema() == {
assert _schema(typed_async_lambda.output_schema) == {
"title": "typed_async_lambda_impl_output",
"type": "integer",
}
fake_ret = FakeRetriever() # str -> List[Document]
assert fake_ret.input_schema.schema() == {
assert _schema(fake_ret.input_schema) == {
"title": "FakeRetrieverInput",
"type": "string",
}
assert fake_ret.output_schema.schema() == {
assert _schema(fake_ret.output_schema) == {
"title": "FakeRetrieverOutput",
"type": "array",
"items": {"$ref": "#/definitions/Document"},
@ -328,16 +329,16 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
assert fake_llm.input_schema.schema() == snapshot
assert fake_llm.output_schema.schema() == {
assert _schema(fake_llm.input_schema) == snapshot
assert _schema(fake_llm.output_schema) == {
"title": "FakeListLLMOutput",
"type": "string",
}
fake_chat = FakeListChatModel(responses=["a"]) # str -> List[List[str]]
assert fake_chat.input_schema.schema() == snapshot
assert fake_chat.output_schema.schema() == snapshot
assert _schema(fake_chat.input_schema) == snapshot
assert _schema(fake_chat.output_schema) == snapshot
chat_prompt = ChatPromptTemplate.from_messages(
[
@ -346,26 +347,26 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
]
)
assert chat_prompt.input_schema.schema() == snapshot(
assert _schema(chat_prompt.input_schema) == snapshot(
name="chat_prompt_input_schema"
)
assert chat_prompt.output_schema.schema() == snapshot(
assert _schema(chat_prompt.output_schema) == snapshot(
name="chat_prompt_output_schema"
)
prompt = PromptTemplate.from_template("Hello, {name}!")
assert prompt.input_schema.schema() == {
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
"required": ["name"],
}
assert prompt.output_schema.schema() == snapshot
assert _schema(prompt.output_schema) == snapshot
prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()
assert prompt_mapper.input_schema.schema() == {
assert _schema(prompt_mapper.input_schema) == {
"definitions": {
"PromptInput": {
"properties": {"name": {"title": "Name", "type": "string"}},
@ -378,12 +379,12 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"type": "array",
"title": "RunnableEach<PromptTemplate>Input",
}
assert prompt_mapper.output_schema.schema() == snapshot
assert _schema(prompt_mapper.output_schema) == snapshot
list_parser = CommaSeparatedListOutputParser()
assert list_parser.input_schema.schema() == snapshot
assert list_parser.output_schema.schema() == {
assert _schema(list_parser.input_schema) == snapshot
assert _schema(list_parser.output_schema) == {
"title": "CommaSeparatedListOutputParserOutput",
"type": "array",
"items": {"type": "string"},
@ -391,13 +392,13 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
seq = prompt | fake_llm | list_parser
assert seq.input_schema.schema() == {
assert _schema(seq.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
"required": ["name"],
}
assert seq.output_schema.schema() == {
assert _schema(seq.output_schema) == {
"type": "array",
"items": {"type": "string"},
"title": "CommaSeparatedListOutputParserOutput",
@ -405,7 +406,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
router: Runnable = RouterRunnable({})
assert router.input_schema.schema() == {
assert _schema(router.input_schema) == {
"title": "RouterRunnableInput",
"$ref": "#/definitions/RouterInput",
"definitions": {
@ -420,7 +421,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
}
},
}
assert router.output_schema.schema() == {"title": "RouterRunnableOutput"}
assert _schema(router.output_schema) == {"title": "RouterRunnableOutput"}
seq_w_map: Runnable = (
prompt
@ -432,13 +433,13 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
}
)
assert seq_w_map.input_schema.schema() == {
assert _schema(seq_w_map.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
"required": ["name"],
}
assert seq_w_map.output_schema.schema() == {
assert _schema(seq_w_map.output_schema) == {
"title": "RunnableParallel<original,as_list,length>Output",
"type": "object",
"properties": {
@ -464,12 +465,12 @@ def test_passthrough_assign_schema() -> None:
| fake_llm
)
assert seq_w_assign.input_schema.schema() == {
assert _schema(seq_w_assign.input_schema) == {
"properties": {"question": {"title": "Question", "type": "string"}},
"title": "RunnableSequenceInput",
"type": "object",
}
assert seq_w_assign.output_schema.schema() == {
assert _schema(seq_w_assign.output_schema) == {
"title": "FakeListLLMOutput",
"type": "string",
}
@ -481,7 +482,7 @@ def test_passthrough_assign_schema() -> None:
# fallback to RunnableAssign.input_schema if next runnable doesn't have
# expected dict input_schema
assert invalid_seq_w_assign.input_schema.schema() == {
assert _schema(invalid_seq_w_assign.input_schema) == {
"properties": {"question": {"title": "Question"}},
"title": "RunnableParallel<context>Input",
"type": "object",
@ -493,14 +494,14 @@ def test_passthrough_assign_schema() -> None:
)
def test_lambda_schemas() -> None:
first_lambda = lambda x: x["hello"] # noqa: E731
assert RunnableLambda(first_lambda).input_schema.schema() == {
assert _schema(RunnableLambda(first_lambda).input_schema) == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}},
}
second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731
assert RunnableLambda(second_lambda).input_schema.schema() == { # type: ignore[arg-type]
assert _schema(RunnableLambda(second_lambda).input_schema) == { # type: ignore[arg-type]
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
@ -509,7 +510,7 @@ def test_lambda_schemas() -> None:
def get_value(input): # type: ignore[no-untyped-def]
return input["variable_name"]
assert RunnableLambda(get_value).input_schema.schema() == {
assert _schema(RunnableLambda(get_value).input_schema) == {
"title": "get_value_input",
"type": "object",
"properties": {"variable_name": {"title": "Variable Name"}},
@ -518,7 +519,7 @@ def test_lambda_schemas() -> None:
async def aget_value(input): # type: ignore[no-untyped-def]
return (input["variable_name"], input.get("another"))
assert RunnableLambda(aget_value).input_schema.schema() == {
assert _schema(RunnableLambda(aget_value).input_schema) == {
"title": "aget_value_input",
"type": "object",
"properties": {
@ -534,7 +535,7 @@ def test_lambda_schemas() -> None:
"byebye": input["yo"],
}
assert RunnableLambda(aget_values).input_schema.schema() == {
assert _schema(RunnableLambda(aget_values).input_schema) == {
"title": "aget_values_input",
"type": "object",
"properties": {
@ -560,9 +561,11 @@ def test_lambda_schemas() -> None:
}
assert (
RunnableLambda(
aget_values_typed # type: ignore[arg-type]
).input_schema.schema()
_schema(
RunnableLambda(
aget_values_typed # type: ignore[arg-type]
).input_schema
)
== {
"title": "aget_values_typed_input",
"$ref": "#/definitions/InputType",
@ -583,7 +586,7 @@ def test_lambda_schemas() -> None:
}
)
assert RunnableLambda(aget_values_typed).output_schema.schema() == { # type: ignore[arg-type]
assert _schema(RunnableLambda(aget_values_typed).output_schema) == { # type: ignore[arg-type]
"title": "aget_values_typed_output",
"$ref": "#/definitions/OutputType",
"definitions": {
@ -640,7 +643,7 @@ def test_schema_complex_seq() -> None:
| StrOutputParser()
)
assert chain2.input_schema.schema() == {
assert _schema(chain2.input_schema) == {
"title": "RunnableParallel<city,language>Input",
"type": "object",
"properties": {
@ -649,17 +652,17 @@ def test_schema_complex_seq() -> None:
},
}
assert chain2.output_schema.schema() == {
assert _schema(chain2.output_schema) == {
"title": "StrOutputParserOutput",
"type": "string",
}
assert chain2.with_types(input_type=str).input_schema.schema() == {
assert _schema(chain2.with_types(input_type=str).input_schema) == {
"title": "RunnableSequenceInput",
"type": "string",
}
assert chain2.with_types(input_type=int).output_schema.schema() == {
assert _schema(chain2.with_types(input_type=int).output_schema) == {
"title": "StrOutputParserOutput",
"type": "string",
}
@ -667,7 +670,7 @@ def test_schema_complex_seq() -> None:
class InputType(BaseModel):
person: str
assert chain2.with_types(input_type=InputType).input_schema.schema() == {
assert _schema(chain2.with_types(input_type=InputType).input_schema) == {
"title": "InputType",
"type": "object",
"properties": {"person": {"title": "Person", "type": "string"}},
@ -690,7 +693,7 @@ def test_configurable_fields() -> None:
assert fake_llm_configurable.invoke("...") == "a"
assert fake_llm_configurable.config_schema().schema() == {
assert _schema(fake_llm_configurable.config_schema()) == {
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@ -733,7 +736,7 @@ def test_configurable_fields() -> None:
text="Hello, John!"
)
assert prompt_configurable.config_schema().schema() == {
assert _schema(prompt_configurable.config_schema()) == {
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@ -761,9 +764,11 @@ def test_configurable_fields() -> None:
text="Hello, John! John!"
)
assert prompt_configurable.with_config(
configurable={"prompt_template": "Hello {name} in {lang}"}
).input_schema.schema() == {
assert _schema(
prompt_configurable.with_config(
configurable={"prompt_template": "Hello {name} in {lang}"}
).input_schema
) == {
"title": "PromptInput",
"type": "object",
"properties": {
@ -777,7 +782,7 @@ def test_configurable_fields() -> None:
assert chain_configurable.invoke({"name": "John"}) == "a"
assert chain_configurable.config_schema().schema() == {
assert _schema(chain_configurable.config_schema()) == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@ -814,12 +819,14 @@ def test_configurable_fields() -> None:
== "c"
)
assert chain_configurable.with_config(
configurable={
"prompt_template": "A very good morning to you, {name} {lang}!",
"llm_responses": ["c"],
}
).input_schema.schema() == {
assert _schema(
chain_configurable.with_config(
configurable={
"prompt_template": "A very good morning to you, {name} {lang}!",
"llm_responses": ["c"],
}
).input_schema
) == {
"title": "PromptInput",
"type": "object",
"properties": {
@ -844,7 +851,7 @@ def test_configurable_fields() -> None:
"llm3": "a",
}
assert chain_with_map_configurable.config_schema().schema() == {
assert _schema(chain_with_map_configurable.config_schema()) == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@ -945,7 +952,7 @@ def test_configurable_fields_prefix_keys() -> None:
chain = prompt | fake_llm
assert chain.config_schema().schema() == {
assert _schema(chain.config_schema()) == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@ -1055,7 +1062,7 @@ def test_configurable_fields_example() -> None:
assert chain_configurable.invoke({"name": "John"}) == "a"
assert chain_configurable.config_schema().schema() == {
assert _schema(chain_configurable.config_schema()) == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
@ -3136,7 +3143,7 @@ def test_map_stream() -> None:
chain_pick_one = chain.pick("llm")
assert chain_pick_one.output_schema.schema() == {
assert _schema(chain_pick_one.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "string",
}
@ -3159,7 +3166,7 @@ def test_map_stream() -> None:
["llm", "hello"]
)
assert chain_pick_two.output_schema.schema() == {
assert _schema(chain_pick_two.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "object",
"properties": {
@ -3524,13 +3531,13 @@ def test_deep_stream_assign() -> None:
chain_with_assign = chain.assign(hello=itemgetter("str") | llm)
assert chain_with_assign.input_schema.schema() == {
assert _schema(chain_with_assign.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question", "type": "string"}},
"required": ["question"],
}
assert chain_with_assign.output_schema.schema() == {
assert _schema(chain_with_assign.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "object",
"properties": {
@ -3575,13 +3582,13 @@ def test_deep_stream_assign() -> None:
hello=itemgetter("str") | llm,
)
assert chain_with_assign_shadow.input_schema.schema() == {
assert _schema(chain_with_assign_shadow.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question", "type": "string"}},
"required": ["question"],
}
assert chain_with_assign_shadow.output_schema.schema() == {
assert _schema(chain_with_assign_shadow.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "object",
"properties": {
@ -3650,13 +3657,13 @@ async def test_deep_astream_assign() -> None:
hello=itemgetter("str") | llm,
)
assert chain_with_assign.input_schema.schema() == {
assert _schema(chain_with_assign.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question", "type": "string"}},
"required": ["question"],
}
assert chain_with_assign.output_schema.schema() == {
assert _schema(chain_with_assign.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "object",
"properties": {
@ -3701,13 +3708,13 @@ async def test_deep_astream_assign() -> None:
hello=itemgetter("str") | llm,
)
assert chain_with_assign_shadow.input_schema.schema() == {
assert _schema(chain_with_assign_shadow.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question", "type": "string"}},
"required": ["question"],
}
assert chain_with_assign_shadow.output_schema.schema() == {
assert _schema(chain_with_assign_shadow.output_schema) == {
"title": "RunnableSequenceOutput",
"type": "object",
"properties": {
@ -4355,7 +4362,7 @@ def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
assert isinstance(body, Runnable)
assert isinstance(runnable.default, Runnable)
assert runnable.input_schema.schema() == {"title": "RunnableBranchInput"}
assert _schema(runnable.input_schema) == {"title": "RunnableBranchInput"}
def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
@ -4702,8 +4709,8 @@ async def test_tool_from_runnable() -> None:
{"question": "What up"}
)
assert chain_tool.description.endswith(repr(chain))
assert chain_tool.args_schema.schema() == chain.input_schema.schema()
assert chain_tool.args_schema.schema() == {
assert _schema(chain_tool.args_schema) == _schema(chain.input_schema)
assert _schema(chain_tool.args_schema) == {
"properties": {"question": {"title": "Question", "type": "string"}},
"title": "PromptInput",
"type": "object",
@ -4721,8 +4728,8 @@ async def test_runnable_gen() -> None:
runnable = RunnableGenerator(gen)
assert runnable.input_schema.schema() == {"title": "gen_input"}
assert runnable.output_schema.schema() == {
assert _schema(runnable.input_schema) == {"title": "gen_input"}
assert _schema(runnable.output_schema) == {
"title": "gen_output",
"type": "integer",
}
@ -4773,8 +4780,8 @@ async def test_runnable_gen_context_config() -> None:
runnable = RunnableGenerator(gen)
assert runnable.input_schema.schema() == {"title": "gen_input"}
assert runnable.output_schema.schema() == {
assert _schema(runnable.input_schema) == {"title": "gen_input"}
assert _schema(runnable.output_schema) == {
"title": "gen_output",
"type": "integer",
}
@ -4907,11 +4914,11 @@ async def test_runnable_iter_context_config() -> None:
yield fake.invoke(input * 2)
yield fake.invoke(input * 3)
assert gen.input_schema.schema() == {
assert _schema(gen.input_schema) == {
"title": "gen_input",
"type": "string",
}
assert gen.output_schema.schema() == {
assert _schema(gen.output_schema) == {
"title": "gen_output",
"type": "integer",
}
@ -4958,11 +4965,11 @@ async def test_runnable_iter_context_config() -> None:
yield await fake.ainvoke(input * 2)
yield await fake.ainvoke(input * 3)
assert agen.input_schema.schema() == {
assert _schema(agen.input_schema) == {
"title": "agen_input",
"type": "string",
}
assert agen.output_schema.schema() == {
assert _schema(agen.output_schema) == {
"title": "agen_output",
"type": "integer",
}
@ -5025,8 +5032,8 @@ async def test_runnable_lambda_context_config() -> None:
output += fake.invoke(input * 3)
return output
assert fun.input_schema.schema() == {"title": "fun_input", "type": "string"}
assert fun.output_schema.schema() == {
assert _schema(fun.input_schema) == {"title": "fun_input", "type": "string"}
assert _schema(fun.output_schema) == {
"title": "fun_output",
"type": "integer",
}
@ -5074,8 +5081,8 @@ async def test_runnable_lambda_context_config() -> None:
output += await fake.ainvoke(input * 3)
return output
assert afun.input_schema.schema() == {"title": "afun_input", "type": "string"}
assert afun.output_schema.schema() == {
assert _schema(afun.input_schema) == {"title": "afun_input", "type": "string"}
assert _schema(afun.output_schema) == {
"title": "afun_output",
"type": "integer",
}
@ -5136,19 +5143,19 @@ async def test_runnable_gen_transform() -> None:
chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one
assert chain.input_schema.schema() == {
assert _schema(chain.input_schema) == {
"title": "gen_indexes_input",
"type": "integer",
}
assert chain.output_schema.schema() == {
assert _schema(chain.output_schema) == {
"title": "plus_one_output",
"type": "integer",
}
assert achain.input_schema.schema() == {
assert _schema(achain.input_schema) == {
"title": "gen_indexes_input",
"type": "integer",
}
assert achain.output_schema.schema() == {
assert _schema(achain.output_schema) == {
"title": "aplus_one_output",
"type": "integer",
}

View File

@ -39,6 +39,7 @@ from langchain_core.tools import (
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.pydantic import _create_subset_model
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
from tests.unit_tests.pydantic_utils import _schema
def test_unnamed_decorator() -> None:
@ -166,8 +167,8 @@ def test_decorated_function_schema_equivalent() -> None:
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.args_schema is not None
assert (
structured_tool_input.args_schema.schema()["properties"]
== _MockSchema.schema()["properties"]
_schema(structured_tool_input.args_schema)["properties"]
== _schema(_MockSchema)["properties"]
== structured_tool_input.args
)
@ -336,7 +337,7 @@ def test_structured_tool_from_function_docstring() -> None:
"baz": {"title": "Baz", "type": "string"},
}
assert structured_tool.args_schema.schema() == {
assert _schema(structured_tool.args_schema) == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
@ -374,7 +375,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
},
}
assert structured_tool.args_schema.schema() == {
assert _schema(structured_tool.args_schema) == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {
@ -479,7 +480,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
"baz": {"title": "Baz", "type": "string"},
}
assert structured_tool.args_schema.schema() == {
assert _schema(structured_tool.args_schema) == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
@ -716,7 +717,7 @@ def test_structured_tool_from_function() -> None:
"baz": {"title": "Baz", "type": "string"},
}
assert structured_tool.args_schema.schema() == {
assert _schema(structured_tool.args_schema) == {
"title": "fooSchema",
"type": "object",
"description": inspect.getdoc(foo),
@ -863,9 +864,9 @@ def test_optional_subset_model_rewrite() -> None:
model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"])
assert "a" not in model2.schema()["required"] # should be optional
assert "b" in model2.schema()["required"] # should be required
assert "c" not in model2.schema()["required"] # should be optional
assert "a" not in _schema(model2)["required"] # should be optional
assert "b" in _schema(model2)["required"] # should be required
assert "c" not in _schema(model2)["required"] # should be optional
@pytest.mark.parametrize(
@ -1043,7 +1044,7 @@ def test_tool_arg_descriptions() -> None:
return bar
foo1 = tool(foo)
args_schema = foo1.args_schema.schema() # type: ignore
args_schema = _schema(foo1.args_schema) # type: ignore
assert args_schema == {
"title": "fooSchema",
"type": "object",
@ -1057,7 +1058,7 @@ def test_tool_arg_descriptions() -> None:
# Test parses docstring
foo2 = tool(foo, parse_docstring=True)
args_schema = foo2.args_schema.schema() # type: ignore
args_schema = _schema(foo2.args_schema) # type: ignore
expected = {
"title": "fooSchema",
"description": "The foo.",
@ -1083,7 +1084,7 @@ def test_tool_arg_descriptions() -> None:
return bar
as_tool = tool(foo3, parse_docstring=True)
args_schema = as_tool.args_schema.schema() # type: ignore
args_schema = _schema(as_tool.args_schema) # type: ignore
assert args_schema["description"] == expected["description"]
assert args_schema["properties"] == expected["properties"]
@ -1094,7 +1095,7 @@ def test_tool_arg_descriptions() -> None:
return "bar"
as_tool = tool(foo4, parse_docstring=True)
args_schema = as_tool.args_schema.schema() # type: ignore
args_schema = _schema(as_tool.args_schema) # type: ignore
assert args_schema["description"] == expected["description"]
def foo5(run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
@ -1102,7 +1103,7 @@ def test_tool_arg_descriptions() -> None:
return "bar"
as_tool = tool(foo5, parse_docstring=True)
args_schema = as_tool.args_schema.schema() # type: ignore
args_schema = _schema(as_tool.args_schema) # type: ignore
assert args_schema["description"] == expected["description"]
@ -1146,7 +1147,7 @@ def test_tool_annotated_descriptions() -> None:
return bar
foo1 = tool(foo)
args_schema = foo1.args_schema.schema() # type: ignore
args_schema = _schema(foo1.args_schema) # type: ignore
assert args_schema == {
"title": "fooSchema",
"type": "object",
@ -1239,7 +1240,7 @@ def test_convert_from_runnable_dict() -> None:
as_tool = runnable.as_tool()
args_schema = as_tool.args_schema
assert args_schema is not None
assert args_schema.schema() == {
assert _schema(args_schema) == {
"title": "f",
"type": "object",
"properties": {
@ -1370,7 +1371,7 @@ def injected_tool_with_schema(x: int, y: str) -> str:
@pytest.mark.parametrize("tool_", [InjectedTool()])
def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
assert tool_.get_input_schema().schema() == {
assert _schema(tool_.get_input_schema()) == {
"title": "fooSchema",
"description": "foo.\n\nArgs:\n x: abc\n y: 123",
"type": "object",
@ -1380,7 +1381,7 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
},
"required": ["x", "y"],
}
assert tool_.tool_call_schema.schema() == {
assert _schema(tool_.tool_call_schema) == {
"title": "foo",
"description": "foo.",
"type": "object",
@ -1413,7 +1414,7 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
[injected_tool, injected_tool_with_schema, InjectedToolWithSchema()],
)
def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
assert tool_.get_input_schema().schema() == {
assert _schema(tool_.get_input_schema()) == {
"title": "fooSchema",
"description": "foo.",
"type": "object",
@ -1423,7 +1424,7 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
},
"required": ["x", "y"],
}
assert tool_.tool_call_schema.schema() == {
assert _schema(tool_.tool_call_schema) == {
"title": "foo",
"description": "foo.",
"type": "object",