diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 07ebd37abf..8ed3940cbe 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -138,16 +138,16 @@ class FakeTracer(BaseTracer): "child_execution_order": None, "trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None, "dotted_order": new_dotted_order, - "inputs": { - k: self._replace_message_id(v) for k, v in run.inputs.items() - } - if isinstance(run.inputs, dict) - else run.inputs, - "outputs": { - k: self._replace_message_id(v) for k, v in run.outputs.items() - } - if isinstance(run.outputs, dict) - else run.outputs, + "inputs": ( + {k: self._replace_message_id(v) for k, v in run.inputs.items()} + if isinstance(run.inputs, dict) + else run.inputs + ), + "outputs": ( + {k: self._replace_message_id(v) for k, v in run.outputs.items()} + if isinstance(run.outputs, dict) + else run.outputs + ), } ) @@ -1652,11 +1652,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: assert len(spy.call_args_list) == 2 for i, call in enumerate(spy.call_args_list): - assert call.args[0] == ("hello" if i == 0 else "wooorld") - if i == 0: + call_arg = call.args[0] + + if call_arg == "hello": + assert call_arg == "hello" assert call.args[1].get("tags") == ["a-tag"] assert call.args[1].get("metadata") == {} else: + assert call_arg == "wooorld" assert call.args[1].get("tags") == [] assert call.args[1].get("metadata") == {"key": "value"} @@ -1664,8 +1667,8 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7] assert len(spy.call_args_list) == 2 + assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"} for i, call in enumerate(spy.call_args_list): - assert call.args[0] == ("hello" if i == 0 else "wooorld") assert call.args[1].get("tags") == ["a-tag"] assert call.args[1].get("metadata") == {} spy.reset_mock() @@ -1686,28 +1689,15 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: 5, 7, ] - assert spy.call_args_list == [ - mocker.call( - "hello", - dict( - metadata={"key": "value"}, - tags=[], - callbacks=None, - recursion_limit=25, - run_id=None, - ), - ), - mocker.call( - "wooorld", - dict( - metadata={"key": "value"}, - tags=[], - callbacks=None, - recursion_limit=25, - run_id=None, - ), - ), - ] + assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"} + for call in spy.call_args_list: + assert call.args[1] == dict( + metadata={"key": "value"}, + tags=[], + callbacks=None, + recursion_limit=25, + run_id=None, + ) async def test_prompt() -> None: