# Source path: langchain/libs/langserve/tests/unit_tests/test_server_client.py

"""Test the server and client together."""
import asyncio
from asyncio import AbstractEventLoop
from contextlib import asynccontextmanager
from typing import List, Optional, Union
import httpx
import pytest
import pytest_asyncio
from fastapi import FastAPI
from fastapi.testclient import TestClient
from httpx import AsyncClient
from langchain.callbacks.tracers.log_stream import RunLogPatch
from langchain.schema.messages import HumanMessage, SystemMessage
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.runnable.base import RunnableLambda
from pytest_mock import MockerFixture
from langserve.client import RemoteRunnable
from langserve.server import add_routes
@pytest.fixture(scope="session")
def event_loop():
    """Provide a single event loop shared across the whole test session.

    The fixture is session-scoped, so the loop is created once and closed
    only when the session is torn down.

    Yields:
        A newly created asyncio event loop.
    """
    # asyncio.get_event_loop() is deprecated when no loop is running, so the
    # loop is created explicitly rather than obtained implicitly.
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        loop.close()
@pytest.fixture()
def app(event_loop: AbstractEventLoop) -> FastAPI:
    """Build a FastAPI server that exposes one runnable through add_routes.

    The served runnable increments integers and echoes back any other input.
    """

    async def add_one_or_passthrough(
        x: Union[int, HumanMessage]
    ) -> Union[int, HumanMessage]:
        """Return ``x + 1`` for ints; hand anything else back untouched."""
        return x + 1 if isinstance(x, int) else x

    server = FastAPI()
    try:
        add_routes(server, RunnableLambda(func=add_one_or_passthrough))
        yield server
    finally:
        del server
@pytest.fixture()
def client(app: FastAPI) -> RemoteRunnable:
    """Yield a RemoteRunnable that talks to ``app`` through a TestClient."""
    test_client = TestClient(app=app)
    remote = RemoteRunnable(url="http://localhost:9999")
    # Swap in the in-process client so no real network connection is needed
    remote.sync_client = test_client
    yield remote
    test_client.close()
@asynccontextmanager
async def get_async_client(
    server: FastAPI, path: Optional[str] = None
) -> RemoteRunnable:
    """Open a RemoteRunnable bound to ``server`` for the duration of a block.

    Args:
        server: The FastAPI application that requests are sent to.
        path: Optional suffix appended to the base URL.

    Yields:
        A RemoteRunnable whose async client targets ``server``; that client
        is closed when the context exits.
    """
    url = "http://localhost:9999" + (path or "")
    http_client = AsyncClient(app=server, base_url=url)
    remote = RemoteRunnable(url=url)
    remote.async_client = http_client
    try:
        yield remote
    finally:
        await http_client.aclose()
@pytest_asyncio.fixture()
async def async_client(app: FastAPI) -> RemoteRunnable:
    """Yield a RemoteRunnable wired to ``app`` through an async HTTP client."""
    async with get_async_client(app) as remote:
        yield remote
def test_server(app: FastAPI) -> None:
    """Exercise the invoke and batch endpoints with plain HTTP requests."""
    sync_client = TestClient(app=app)
    try:
        # Test invoke
        response = sync_client.post("/invoke", json={"input": 1})
        assert response.json() == {"output": 2}

        # Test batch
        response = sync_client.post("/batch", json={"inputs": [1]})
        assert response.json() == {
            "output": [2],
        }
    finally:
        # Release the client even when an assertion above fails
        sync_client.close()

    # TODO(Team): Fix test. Issue with eventloops right now when using sync client
    ## Test stream
    # response = sync_client.post("/stream", json={"input": 1})
    # assert response.text == "event: data\r\ndata: 2\r\n\r\nevent: end\r\n\r\n"
@pytest.mark.asyncio
async def test_server_async(app: FastAPI) -> None:
    """Exercise invoke, batch and stream with plain async HTTP requests."""
    # The context manager closes the client once the checks are done,
    # including when one of them fails.
    async with AsyncClient(app=app, base_url="http://localhost:9999") as async_client:
        # Test invoke
        response = await async_client.post("/invoke", json={"input": 1})
        assert response.json() == {"output": 2}

        # Test batch
        response = await async_client.post("/batch", json={"inputs": [1]})
        assert response.json() == {
            "output": [2],
        }

        # Test stream
        response = await async_client.post("/stream", json={"input": 1})
        assert response.text == "event: data\r\ndata: 2\r\n\r\nevent: end\r\n\r\n"
def test_invoke(client: RemoteRunnable) -> None:
    """Check synchronous invoke for ints, messages, and calls with config."""
    assert client.invoke(1) == 2

    message = HumanMessage(content="hello")
    assert client.invoke(message) == HumanMessage(content="hello")

    # Same call again, this time carrying a config
    tagged_result = client.invoke(1, config={"tags": ["test"]})
    assert tagged_result == 2
def test_batch(client: RemoteRunnable) -> None:
    """Check synchronous batch with empty, integer, and message inputs."""
    empty_results = client.batch([])
    assert empty_results == []

    int_results = client.batch([1, 2, 3])
    assert int_results == [2, 3, 4]

    message_results = client.batch([HumanMessage(content="hello")])
    assert message_results == [HumanMessage(content="hello")]
@pytest.mark.asyncio
async def test_ainvoke(async_client: RemoteRunnable) -> None:
    """Check asynchronous invoke for an int and for a message."""
    int_result = await async_client.ainvoke(1)
    assert int_result == 2

    message_result = await async_client.ainvoke(HumanMessage(content="hello"))
    assert message_result == HumanMessage(content="hello")
@pytest.mark.asyncio
async def test_abatch(async_client: RemoteRunnable) -> None:
    """Check asynchronous batch with empty, integer, and message inputs."""
    empty_results = await async_client.abatch([])
    assert empty_results == []

    int_results = await async_client.abatch([1, 2, 3])
    assert int_results == [2, 3, 4]

    message_results = await async_client.abatch([HumanMessage(content="hello")])
    assert message_results == [HumanMessage(content="hello")]
# TODO(Team): Determine how to test
# Some issue with event loops
# def test_stream(client: RemoteRunnable) -> None:
# """Test stream."""
# assert list(client.stream(1)) == [2]
@pytest.mark.asyncio
async def test_astream(async_client: RemoteRunnable) -> None:
    """Check that async streaming yields exactly the expected chunk."""
    int_chunks = [chunk async for chunk in async_client.astream(1)]
    assert int_chunks == [2]

    message = HumanMessage(content="hello")
    message_chunks = [chunk async for chunk in async_client.astream(message)]
    assert message_chunks == [message]
@pytest.mark.asyncio
async def test_astream_log(async_client: RemoteRunnable) -> None:
    """Check the patches produced by streaming the run log."""
    patches = [patch async for patch in async_client.astream_log(1)]
    assert len(patches) == 3

    first_op = patches[0].ops[0]
    # Reuse the id found in the op itself; only the other fields are pinned
    run_id = first_op["value"]["id"]
    expected_op = {
        "op": "replace",
        "path": "",
        "value": {
            "final_output": {"output": 2},
            "id": run_id,
            "logs": [],
            "streamed_output": [],
        },
    }
    assert first_op == expected_op
def test_invoke_as_part_of_sequence(client: RemoteRunnable) -> None:
    """Check a remote runnable chained with a local step (invoke and batch)."""
    pipeline = client | RunnableLambda(func=lambda x: x + 1)

    # invoke: first bare, then with config
    assert pipeline.invoke(1) == 3
    assert pipeline.invoke(1, config={"tags": ["test"]}) == 3

    # batch: first bare, then with config
    assert pipeline.batch([1, 2]) == [3, 4]
    assert pipeline.batch([1, 2], config={"tags": ["test"]}) == [3, 4]

    # TODO(Team): Determine how to test some issues with event loops for testing
    # set up
    # without config
    # assert list(runnable.stream([1, 2])) == [3, 4]
    # # with config
    # assert list(runnable.stream([1, 2], config={"tags": ["test"]})) == [3, 4]
@pytest.mark.asyncio
async def test_invoke_as_part_of_sequence_async(async_client: RemoteRunnable) -> None:
    """Test a remote runnable as part of an async sequence.

    This helps to verify that config is handled properly (e.g., callbacks are not
    passed to the server, but other config is).
    """
    runnable = async_client | RunnableLambda(
        func=lambda x: x + 1 if isinstance(x, int) else x
    ).with_config({"run_name": "hello"})

    # without config
    assert await runnable.ainvoke(1) == 3
    # with config
    assert await runnable.ainvoke(1, config={"tags": ["test"]}) == 3

    # without config
    assert await runnable.abatch([1, 2]) == [3, 4]
    # with config
    assert await runnable.abatch([1, 2], config={"tags": ["test"]}) == [3, 4]

    # Verify can pass many configs to batch
    configs = [{"tags": ["test"]}, {"tags": ["test2"]}]
    assert await runnable.abatch([1, 2], config=configs) == [3, 4]

    # Verify that a ValueError is raised when the number of configs does not
    # match the number of inputs. No result is compared here because the call
    # is expected to raise before returning.
    with pytest.raises(ValueError):
        await runnable.abatch([1, 2], config=[configs[0]])

    # Configs may carry different keys for different inputs
    configs = [
        {"tags": ["test"]},
        {"tags": ["test2"], "other": "test"},
    ]
    assert await runnable.abatch([1, 2], config=configs) == [3, 4]

    # Without config
    assert [x async for x in runnable.astream(1)] == [3]

    # With Config
    assert [x async for x in runnable.astream(1, config={"tags": ["test"]})] == [3]

    # With config and LC input data
    assert [
        x
        async for x in runnable.astream(
            HumanMessage(content="hello"), config={"tags": ["test"]}
        )
    ] == [HumanMessage(content="hello")]

    log_patches = [x async for x in runnable.astream_log(1)]
    for log_patch in log_patches:
        assert isinstance(log_patch, RunLogPatch)
    # Only check the first entry (not validating implementation here)
    first_op = log_patches[0].ops[0]
    assert first_op["op"] == "replace"
    assert first_op["path"] == ""

    # Validate with HumanMessage
    log_patches = [x async for x in runnable.astream_log(HumanMessage(content="hello"))]
    for log_patch in log_patches:
        assert isinstance(log_patch, RunLogPatch)
    # Only check the first entry (not validating implementation here)
    first_op = log_patches[0].ops[0]
    assert first_op == {
        "op": "replace",
        "path": "",
        "value": {
            "final_output": None,
            # The id is read back from the op, so only the other fields are pinned
            "id": first_op["value"]["id"],
            "logs": [],
            "streamed_output": [],
        },
    }
@pytest.mark.asyncio
async def test_multiple_runnables(event_loop: AbstractEventLoop) -> None:
    """Check that one app can serve several runnables under distinct paths."""

    async def add_one(x: int) -> int:
        """Return ``x`` incremented by one."""
        return x + 1

    async def mul_2(x: int) -> int:
        """Return ``x`` doubled."""
        return x * 2

    app = FastAPI()
    add_routes(app, RunnableLambda(add_one), path="/add_one")
    add_routes(app, RunnableLambda(mul_2), input_type=int, path="/mul_2")

    async with get_async_client(app, path="/add_one") as adder:
        async with get_async_client(app, path="/mul_2") as doubler:
            assert await adder.ainvoke(1) == 2
            assert await doubler.ainvoke(4) == 8

            # remote add_one followed by remote mul_2: (3 + 1) * 2
            remote_chain = adder | doubler
            assert await remote_chain.ainvoke(3) == 8

            # remote add_one, local add_one, remote mul_2: (3 + 1 + 1) * 2
            mixed_chain = adder | add_one | doubler
            assert await mixed_chain.ainvoke(3) == 10
@pytest.mark.asyncio
async def test_input_validation(
    event_loop: AbstractEventLoop, mocker: MockerFixture
) -> None:
    """Check rejection of badly typed input and forwarding of config keys."""

    async def add_one(x: int) -> int:
        """Return ``x`` incremented by one."""
        return x + 1

    plain_runnable = RunnableLambda(func=add_one, afunc=add_one)
    configurable_runnable = RunnableLambda(func=add_one, afunc=add_one)

    app = FastAPI()
    # Route registered without any config keys
    add_routes(app, plain_runnable, input_type=int, path="/add_one")
    # Route registered with "tags" and "run_name" as config keys
    add_routes(
        app,
        configurable_runnable,
        input_type=int,
        path="/add_one_config",
        config_keys=["tags", "run_name"],
    )

    async with get_async_client(app, path="/add_one") as runnable:
        # A valid int goes through
        assert await runnable.ainvoke(1) == 2

        # A string where an int is declared surfaces as an HTTP error
        with pytest.raises(httpx.HTTPError):
            await runnable.ainvoke("hello")

        with pytest.raises(httpx.HTTPError):
            await runnable.abatch(["hello"])

    config = {"tags": ["test"]}

    plain_spy = mocker.spy(plain_runnable, "ainvoke")
    async with get_async_client(app, path="/add_one") as runnable1:
        # The call succeeds, and the server-side runnable sees an empty config
        assert await runnable1.ainvoke(1, config=config) == 2
        assert plain_spy.call_args[1]["config"] == {}

    configurable_spy = mocker.spy(configurable_runnable, "ainvoke")
    async with get_async_client(app, path="/add_one_config") as runnable2:
        # Here the server-side runnable sees the same config the client sent
        assert await runnable2.ainvoke(1, config=config) == 2
        assert configurable_spy.call_args[1]["config"] == config
@pytest.mark.asyncio
async def test_input_validation_with_lc_types(event_loop: AbstractEventLoop) -> None:
    """Check input validation when the declared input is a list of messages."""
    app = FastAPI()
    # The passthrough is declared to accept only List[HumanMessage]
    add_routes(
        app, RunnablePassthrough(), input_type=List[HumanMessage], config_keys=["tags"]
    )

    # Inputs that do not match List[HumanMessage]
    bad_invoke_inputs = [
        "Hello",
        ["hello"],
        HumanMessage(content="h"),
        [SystemMessage(content="hello")],
    ]
    # Invoke request
    async with get_async_client(app) as passthrough_runnable:
        for bad_input in bad_invoke_inputs:
            with pytest.raises(httpx.HTTPError):
                await passthrough_runnable.ainvoke(bad_input)

        # Valid
        result = await passthrough_runnable.ainvoke([HumanMessage(content="hello")])
        assert isinstance(result, list)
        assert isinstance(result[0], HumanMessage)

    # Batches whose items do not match List[HumanMessage]
    bad_batch_inputs = [
        "Hello",
        ["hello"],
        [[SystemMessage(content="hello")]],
    ]
    # Batch request
    async with get_async_client(app) as passthrough_runnable:
        for bad_batch in bad_batch_inputs:
            with pytest.raises(httpx.HTTPError):
                await passthrough_runnable.abatch(bad_batch)

        # Valid
        result = await passthrough_runnable.abatch([[HumanMessage(content="hello")]])
        assert isinstance(result, list)
        assert isinstance(result[0], list)
        assert isinstance(result[0][0], HumanMessage)
def test_client_close() -> None:
    """Check that deleting a RemoteRunnable closes both of its HTTP clients."""
    remote = RemoteRunnable(url="/dev/null", timeout=1)
    sync_http, async_http = remote.sync_client, remote.async_client

    # Both clients start out open
    assert async_http.is_closed is False
    assert sync_http.is_closed is False

    # Drop the only reference this test holds to the runnable
    del remote

    assert sync_http.is_closed is True
    assert async_http.is_closed is True
@pytest.mark.asyncio
async def test_async_client_close() -> None:
    """Check, from inside a coroutine, that deleting a RemoteRunnable closes its clients."""
    remote = RemoteRunnable(url="/dev/null", timeout=1)
    sync_http, async_http = remote.sync_client, remote.async_client

    # Both clients start out open
    assert async_http.is_closed is False
    assert sync_http.is_closed is False

    # Drop the only reference this test holds to the runnable
    del remote

    assert sync_http.is_closed is True
    assert async_http.is_closed is True
@pytest.mark.asyncio
async def test_openapi_docs_with_identical_runnables(
    event_loop: AbstractEventLoop, mocker: MockerFixture
) -> None:
    """Check that the OpenAPI document is served when two routes wrap the same function."""

    async def add_one(x: int) -> int:
        """Return ``x`` incremented by one."""
        return x + 1

    app = FastAPI()
    # Two separate runnables built from the same function
    add_routes(app, RunnableLambda(func=add_one), path="/1")
    # Add another route that uses the same schema (inferred from runnable input schema)
    add_routes(
        app,
        RunnableLambda(func=add_one),
        path="/2",
        config_keys=["tags", "run_name"],
    )

    async with AsyncClient(app=app, base_url="http://localhost:9999") as http_client:
        openapi_response = await http_client.get("/openapi.json")
        assert openapi_response.status_code == 200