core: fix batch ordering test (#20952)

pull/19607/head^2
Erick Friis 3 weeks ago committed by GitHub
parent 8ed150b2fe
commit d4befd0cfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -138,16 +138,16 @@ class FakeTracer(BaseTracer):
"child_execution_order": None,
"trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,
"dotted_order": new_dotted_order,
"inputs": {
k: self._replace_message_id(v) for k, v in run.inputs.items()
}
if isinstance(run.inputs, dict)
else run.inputs,
"outputs": {
k: self._replace_message_id(v) for k, v in run.outputs.items()
}
if isinstance(run.outputs, dict)
else run.outputs,
"inputs": (
{k: self._replace_message_id(v) for k, v in run.inputs.items()}
if isinstance(run.inputs, dict)
else run.inputs
),
"outputs": (
{k: self._replace_message_id(v) for k, v in run.outputs.items()}
if isinstance(run.outputs, dict)
else run.outputs
),
}
)
@ -1652,11 +1652,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
if i == 0:
call_arg = call.args[0]
if call_arg == "hello":
assert call_arg == "hello"
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
else:
assert call_arg == "wooorld"
assert call.args[1].get("tags") == []
assert call.args[1].get("metadata") == {"key": "value"}
@ -1664,8 +1667,8 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
assert len(spy.call_args_list) == 2
assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
spy.reset_mock()
@ -1686,28 +1689,15 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
5,
7,
]
assert spy.call_args_list == [
mocker.call(
"hello",
dict(
metadata={"key": "value"},
tags=[],
callbacks=None,
recursion_limit=25,
run_id=None,
),
),
mocker.call(
"wooorld",
dict(
metadata={"key": "value"},
tags=[],
callbacks=None,
recursion_limit=25,
run_id=None,
),
),
]
assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
for call in spy.call_args_list:
assert call.args[1] == dict(
metadata={"key": "value"},
tags=[],
callbacks=None,
recursion_limit=25,
run_id=None,
)
async def test_prompt() -> None:

Loading…
Cancel
Save