|
|
@ -1,5 +1,5 @@
|
|
|
|
from operator import itemgetter
|
|
|
|
from operator import itemgetter
|
|
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
from typing import Any, Dict, List, Optional, Sequence, Union
|
|
|
|
from uuid import UUID
|
|
|
|
from uuid import UUID
|
|
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import pytest
|
|
|
@ -34,6 +34,7 @@ from langchain.schema.retriever import BaseRetriever
|
|
|
|
from langchain.schema.runnable import (
|
|
|
|
from langchain.schema.runnable import (
|
|
|
|
RouterRunnable,
|
|
|
|
RouterRunnable,
|
|
|
|
Runnable,
|
|
|
|
Runnable,
|
|
|
|
|
|
|
|
RunnableBranch,
|
|
|
|
RunnableConfig,
|
|
|
|
RunnableConfig,
|
|
|
|
RunnableLambda,
|
|
|
|
RunnableLambda,
|
|
|
|
RunnableMap,
|
|
|
|
RunnableMap,
|
|
|
@ -541,7 +542,7 @@ async def test_prompt_with_llm(
|
|
|
|
mocker.stop(prompt_spy)
|
|
|
|
mocker.stop(prompt_spy)
|
|
|
|
mocker.stop(llm_spy)
|
|
|
|
mocker.stop(llm_spy)
|
|
|
|
|
|
|
|
|
|
|
|
# Test stream
|
|
|
|
# Test stream#
|
|
|
|
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
|
|
|
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
|
|
|
llm_spy = mocker.spy(llm.__class__, "astream")
|
|
|
|
llm_spy = mocker.spy(llm.__class__, "astream")
|
|
|
|
tracer = FakeTracer()
|
|
|
|
tracer = FakeTracer()
|
|
|
@ -1816,3 +1817,205 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
|
|
|
assert parent_run_qux.outputs["output"] == "quxaaaa"
|
|
|
|
assert parent_run_qux.outputs["output"] == "quxaaaa"
|
|
|
|
assert len(parent_run_qux.child_runs) == 4
|
|
|
|
assert len(parent_run_qux.child_runs) == 4
|
|
|
|
assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]
|
|
|
|
assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_runnable_branch_init() -> None:
|
|
|
|
|
|
|
|
"""Verify that runnable branch gets initialized properly."""
|
|
|
|
|
|
|
|
add = RunnableLambda(lambda x: x + 1)
|
|
|
|
|
|
|
|
condition = RunnableLambda(lambda x: x > 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test failure with less than 2 branches
|
|
|
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
|
|
|
RunnableBranch((condition, add))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test failure with less than 2 branches
|
|
|
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
|
|
|
RunnableBranch(condition)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
|
|
|
|
|
|
"branches",
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
(RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)),
|
|
|
|
|
|
|
|
RunnableLambda(lambda x: x - 1),
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
(RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)),
|
|
|
|
|
|
|
|
(RunnableLambda(lambda x: x > 5), RunnableLambda(lambda x: x + 1)),
|
|
|
|
|
|
|
|
RunnableLambda(lambda x: x - 1),
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
(lambda x: x > 0, lambda x: x + 1),
|
|
|
|
|
|
|
|
(lambda x: x > 5, lambda x: x + 1),
|
|
|
|
|
|
|
|
lambda x: x - 1,
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
|
|
|
|
|
|
|
|
"""Verify that runnable branch gets initialized properly."""
|
|
|
|
|
|
|
|
runnable = RunnableBranch[int, int](*branches)
|
|
|
|
|
|
|
|
for branch in runnable.branches:
|
|
|
|
|
|
|
|
condition, body = branch
|
|
|
|
|
|
|
|
assert isinstance(condition, Runnable)
|
|
|
|
|
|
|
|
assert isinstance(body, Runnable)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert isinstance(runnable.default, Runnable)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
|
|
|
|
|
|
|
|
"""Verify that runnables are invoked only when necessary."""
|
|
|
|
|
|
|
|
# Test with single branch
|
|
|
|
|
|
|
|
add = RunnableLambda(lambda x: x + 1)
|
|
|
|
|
|
|
|
sub = RunnableLambda(lambda x: x - 1)
|
|
|
|
|
|
|
|
condition = RunnableLambda(lambda x: x > 0)
|
|
|
|
|
|
|
|
spy = mocker.spy(condition, "invoke")
|
|
|
|
|
|
|
|
add_spy = mocker.spy(add, "invoke")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
branch = RunnableBranch[int, int]((condition, add), (condition, add), sub)
|
|
|
|
|
|
|
|
assert spy.call_count == 0
|
|
|
|
|
|
|
|
assert add_spy.call_count == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert branch.invoke(1) == 2
|
|
|
|
|
|
|
|
assert add_spy.call_count == 1
|
|
|
|
|
|
|
|
assert spy.call_count == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert branch.invoke(2) == 3
|
|
|
|
|
|
|
|
assert spy.call_count == 2
|
|
|
|
|
|
|
|
assert add_spy.call_count == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert branch.invoke(-3) == -4
|
|
|
|
|
|
|
|
# Should fall through to default branch with condition being evaluated twice!
|
|
|
|
|
|
|
|
assert spy.call_count == 4
|
|
|
|
|
|
|
|
# Add should not be invoked
|
|
|
|
|
|
|
|
assert add_spy.call_count == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_runnable_branch_invoke() -> None:
|
|
|
|
|
|
|
|
# Test with single branch
|
|
|
|
|
|
|
|
def raise_value_error(x: int) -> int:
|
|
|
|
|
|
|
|
"""Raise a value error."""
|
|
|
|
|
|
|
|
raise ValueError("x is too large")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
branch = RunnableBranch[int, int](
|
|
|
|
|
|
|
|
(lambda x: x > 100, raise_value_error),
|
|
|
|
|
|
|
|
# mypy cannot infer types from the lambda
|
|
|
|
|
|
|
|
(lambda x: x > 0 and x < 5, lambda x: x + 1), # type: ignore[misc]
|
|
|
|
|
|
|
|
(lambda x: x > 5, lambda x: x * 10),
|
|
|
|
|
|
|
|
lambda x: x - 1,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert branch.invoke(1) == 2
|
|
|
|
|
|
|
|
assert branch.invoke(10) == 100
|
|
|
|
|
|
|
|
assert branch.invoke(0) == -1
|
|
|
|
|
|
|
|
# Should raise an exception
|
|
|
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
|
|
|
branch.invoke(1000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_runnable_branch_batch() -> None:
|
|
|
|
|
|
|
|
"""Test batch variant."""
|
|
|
|
|
|
|
|
# Test with single branch
|
|
|
|
|
|
|
|
branch = RunnableBranch[int, int](
|
|
|
|
|
|
|
|
(lambda x: x > 0 and x < 5, lambda x: x + 1),
|
|
|
|
|
|
|
|
(lambda x: x > 5, lambda x: x * 10),
|
|
|
|
|
|
|
|
lambda x: x - 1,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert branch.batch([1, 10, 0]) == [2, 100, -1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
|
|
|
async def test_runnable_branch_ainvoke() -> None:
|
|
|
|
|
|
|
|
"""Test async variant of invoke."""
|
|
|
|
|
|
|
|
branch = RunnableBranch[int, int](
|
|
|
|
|
|
|
|
(lambda x: x > 0 and x < 5, lambda x: x + 1),
|
|
|
|
|
|
|
|
(lambda x: x > 5, lambda x: x * 10),
|
|
|
|
|
|
|
|
lambda x: x - 1,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert await branch.ainvoke(1) == 2
|
|
|
|
|
|
|
|
assert await branch.ainvoke(10) == 100
|
|
|
|
|
|
|
|
assert await branch.ainvoke(0) == -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Verify that the async variant is used if available
|
|
|
|
|
|
|
|
async def condition(x: int) -> bool:
|
|
|
|
|
|
|
|
return x > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def add(x: int) -> int:
|
|
|
|
|
|
|
|
return x + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def sub(x: int) -> int:
|
|
|
|
|
|
|
|
return x - 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
branch = RunnableBranch[int, int]((condition, add), sub)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert await branch.ainvoke(1) == 2
|
|
|
|
|
|
|
|
assert await branch.ainvoke(-10) == -11
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_runnable_branch_invoke_callbacks() -> None:
|
|
|
|
|
|
|
|
"""Verify that callbacks are correctly used in invoke."""
|
|
|
|
|
|
|
|
tracer = FakeTracer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def raise_value_error(x: int) -> int:
|
|
|
|
|
|
|
|
"""Raise a value error."""
|
|
|
|
|
|
|
|
raise ValueError("x is too large")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
branch = RunnableBranch[int, int](
|
|
|
|
|
|
|
|
(lambda x: x > 100, raise_value_error),
|
|
|
|
|
|
|
|
lambda x: x - 1,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert branch.invoke(1, config={"callbacks": [tracer]}) == 0
|
|
|
|
|
|
|
|
assert len(tracer.runs) == 1
|
|
|
|
|
|
|
|
assert tracer.runs[0].error is None
|
|
|
|
|
|
|
|
assert tracer.runs[0].outputs == {"output": 0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Check that the chain on end is invoked
|
|
|
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
|
|
|
branch.invoke(1000, config={"callbacks": [tracer]})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert len(tracer.runs) == 2
|
|
|
|
|
|
|
|
assert tracer.runs[1].error == "ValueError('x is too large')"
|
|
|
|
|
|
|
|
assert tracer.runs[1].outputs is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
|
|
|
async def test_runnable_branch_ainvoke_callbacks() -> None:
|
|
|
|
|
|
|
|
"""Verify that callbacks are invoked correctly in ainvoke."""
|
|
|
|
|
|
|
|
tracer = FakeTracer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def raise_value_error(x: int) -> int:
|
|
|
|
|
|
|
|
"""Raise a value error."""
|
|
|
|
|
|
|
|
raise ValueError("x is too large")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
branch = RunnableBranch[int, int](
|
|
|
|
|
|
|
|
(lambda x: x > 100, raise_value_error),
|
|
|
|
|
|
|
|
lambda x: x - 1,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert await branch.ainvoke(1, config={"callbacks": [tracer]}) == 0
|
|
|
|
|
|
|
|
assert len(tracer.runs) == 1
|
|
|
|
|
|
|
|
assert tracer.runs[0].error is None
|
|
|
|
|
|
|
|
assert tracer.runs[0].outputs == {"output": 0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Check that the chain on end is invoked
|
|
|
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
|
|
|
await branch.ainvoke(1000, config={"callbacks": [tracer]})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert len(tracer.runs) == 2
|
|
|
|
|
|
|
|
assert tracer.runs[1].error == "ValueError('x is too large')"
|
|
|
|
|
|
|
|
assert tracer.runs[1].outputs is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
|
|
|
async def test_runnable_branch_abatch() -> None:
|
|
|
|
|
|
|
|
"""Test async variant of invoke."""
|
|
|
|
|
|
|
|
branch = RunnableBranch[int, int](
|
|
|
|
|
|
|
|
(lambda x: x > 0 and x < 5, lambda x: x + 1),
|
|
|
|
|
|
|
|
(lambda x: x > 5, lambda x: x * 10),
|
|
|
|
|
|
|
|
lambda x: x - 1,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert await branch.abatch([1, 10, 0]) == [2, 100, -1]
|
|
|
|