core: fix batch ordering test (#20952)

pull/19607/head^2
Erick Friis 1 month ago committed by GitHub
parent 8ed150b2fe
commit d4befd0cfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -138,16 +138,16 @@ class FakeTracer(BaseTracer):
"child_execution_order": None, "child_execution_order": None,
"trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None, "trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,
"dotted_order": new_dotted_order, "dotted_order": new_dotted_order,
"inputs": { "inputs": (
k: self._replace_message_id(v) for k, v in run.inputs.items() {k: self._replace_message_id(v) for k, v in run.inputs.items()}
} if isinstance(run.inputs, dict)
if isinstance(run.inputs, dict) else run.inputs
else run.inputs, ),
"outputs": { "outputs": (
k: self._replace_message_id(v) for k, v in run.outputs.items() {k: self._replace_message_id(v) for k, v in run.outputs.items()}
} if isinstance(run.outputs, dict)
if isinstance(run.outputs, dict) else run.outputs
else run.outputs, ),
} }
) )
@ -1652,11 +1652,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert len(spy.call_args_list) == 2 assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list): for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld") call_arg = call.args[0]
if i == 0:
if call_arg == "hello":
assert call_arg == "hello"
assert call.args[1].get("tags") == ["a-tag"] assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {} assert call.args[1].get("metadata") == {}
else: else:
assert call_arg == "wooorld"
assert call.args[1].get("tags") == [] assert call.args[1].get("tags") == []
assert call.args[1].get("metadata") == {"key": "value"} assert call.args[1].get("metadata") == {"key": "value"}
@ -1664,8 +1667,8 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7] assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
assert len(spy.call_args_list) == 2 assert len(spy.call_args_list) == 2
assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
for i, call in enumerate(spy.call_args_list): for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
assert call.args[1].get("tags") == ["a-tag"] assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {} assert call.args[1].get("metadata") == {}
spy.reset_mock() spy.reset_mock()
@ -1686,28 +1689,15 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
5, 5,
7, 7,
] ]
assert spy.call_args_list == [ assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
mocker.call( for call in spy.call_args_list:
"hello", assert call.args[1] == dict(
dict( metadata={"key": "value"},
metadata={"key": "value"}, tags=[],
tags=[], callbacks=None,
callbacks=None, recursion_limit=25,
recursion_limit=25, run_id=None,
run_id=None, )
),
),
mocker.call(
"wooorld",
dict(
metadata={"key": "value"},
tags=[],
callbacks=None,
recursion_limit=25,
run_id=None,
),
),
]
async def test_prompt() -> None: async def test_prompt() -> None:

Loading…
Cancel
Save