mirror of https://github.com/hwchase17/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
498 lines
16 KiB
Python
498 lines
16 KiB
Python
"""Test the server and client together."""
|
|
import asyncio
from asyncio import AbstractEventLoop
from contextlib import asynccontextmanager
from typing import AsyncIterator, Iterator, List, Optional, Union

import httpx
import pytest
import pytest_asyncio
from fastapi import FastAPI
from fastapi.testclient import TestClient
from httpx import AsyncClient
from langchain.callbacks.tracers.log_stream import RunLogPatch
from langchain.schema.messages import HumanMessage, SystemMessage
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.runnable.base import RunnableLambda
from pytest_mock import MockerFixture

from langserve.client import RemoteRunnable
from langserve.server import add_routes
|
|
|
|
|
|
@pytest.fixture(scope="session")
def event_loop() -> Iterator[AbstractEventLoop]:
    """Create one event loop shared by every test in the session.

    Presumably overrides pytest-asyncio's default ``event_loop`` fixture so
    that all async tests and fixtures run on the same loop. The loop is
    closed once the session finishes.
    """
    # ``asyncio.get_event_loop()`` is deprecated when no loop is running, so
    # create a fresh loop explicitly from the current policy.
    loop = asyncio.get_event_loop_policy().new_event_loop()
    try:
        yield loop
    finally:
        loop.close()
|
|
|
|
|
|
@pytest.fixture()
def app(event_loop: AbstractEventLoop) -> Iterator[FastAPI]:
    """Yield a FastAPI app that serves a single RunnableLambda at the root path.

    The wrapped function adds one to ``int`` inputs and returns any other
    input (e.g. a ``HumanMessage``) unchanged.

    ``event_loop`` is not used in the body; presumably it is requested so the
    session-scoped loop exists before the app is built.
    """

    async def add_one_or_passthrough(
        x: Union[int, HumanMessage]
    ) -> Union[int, HumanMessage]:
        """Add one to int or passthrough."""
        if isinstance(x, int):
            return x + 1
        return x

    runnable_lambda = RunnableLambda(func=add_one_or_passthrough)
    app = FastAPI()
    add_routes(app, runnable_lambda)
    yield app
|
|
|
|
|
|
@pytest.fixture()
def client(app: FastAPI) -> Iterator[RemoteRunnable]:
    """Yield a RemoteRunnable whose sync HTTP client is a TestClient for ``app``.

    No server is started here; requests are sent through the injected
    ``TestClient``. The URL is presumably a placeholder. The ``TestClient``
    is closed on teardown.
    """
    remote_runnable_client = RemoteRunnable(url="http://localhost:9999")
    sync_client = TestClient(app=app)
    # Replace the runnable's own sync client with one bound to the app.
    remote_runnable_client.sync_client = sync_client
    try:
        yield remote_runnable_client
    finally:
        sync_client.close()
|
|
|
|
|
|
@asynccontextmanager
async def get_async_client(
    server: FastAPI, path: Optional[str] = None
) -> AsyncIterator[RemoteRunnable]:
    """Yield a RemoteRunnable whose async HTTP client targets ``server``.

    Args:
        server: The FastAPI app that requests are dispatched to.
        path: Optional route prefix appended to the base URL, e.g. ``"/add_one"``.
            Empty or ``None`` means the root path.

    Yields:
        A RemoteRunnable with its ``async_client`` replaced by an ``AsyncClient``
        bound to ``server``. That client is closed when the context exits.
    """
    url = "http://localhost:9999"
    if path:
        url += path
    remote_runnable_client = RemoteRunnable(url=url)
    # Route requests to the app object rather than a live server.
    async_client = AsyncClient(app=server, base_url=url)
    remote_runnable_client.async_client = async_client
    try:
        yield remote_runnable_client
    finally:
        await async_client.aclose()
|
|
|
|
|
|
@pytest_asyncio.fixture()
async def async_client(app: FastAPI) -> AsyncIterator[RemoteRunnable]:
    """Yield a RemoteRunnable whose async client targets the ``app`` fixture."""
    async with get_async_client(app) as remote:
        yield remote
|
|
|
|
|
|
def test_server(app: FastAPI) -> None:
    """Test the /invoke and /batch endpoints directly via sync HTTP requests."""
    sync_client = TestClient(app=app)
    try:
        # Test invoke
        response = sync_client.post("/invoke", json={"input": 1})
        assert response.json() == {"output": 2}

        # Test batch
        response = sync_client.post("/batch", json={"inputs": [1]})
        assert response.json() == {
            "output": [2],
        }
    finally:
        # Release the client even when an assertion fails.
        sync_client.close()

    # TODO(Team): Fix test. Issue with eventloops right now when using sync client
    ## Test stream
    # response = sync_client.post("/stream", json={"input": 1})
    # assert response.text == "event: data\r\ndata: 2\r\n\r\nevent: end\r\n\r\n"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_server_async(app: FastAPI) -> None:
    """Test the invoke, batch and stream endpoints via async HTTP requests."""
    # Use the client as a context manager so it is closed even on failure.
    async with AsyncClient(app=app, base_url="http://localhost:9999") as async_client:
        # Test invoke
        response = await async_client.post("/invoke", json={"input": 1})
        assert response.json() == {"output": 2}

        # Test batch
        response = await async_client.post("/batch", json={"inputs": [1]})
        assert response.json() == {
            "output": [2],
        }

        # Test stream: one data event carrying the output, then an end event.
        response = await async_client.post("/stream", json={"input": 1})
        assert response.text == "event: data\r\ndata: 2\r\n\r\nevent: end\r\n\r\n"
|
|
|
|
|
|
def test_invoke(client: RemoteRunnable) -> None:
    """Round-trip single inputs through the sync ``invoke`` of the remote client."""
    cases = [
        (1, 2),
        (HumanMessage(content="hello"), HumanMessage(content="hello")),
    ]
    for given, expected in cases:
        assert client.invoke(given) == expected

    # A config may be supplied alongside the input.
    assert client.invoke(1, config={"tags": ["test"]}) == 2
|
|
|
|
|
|
def test_batch(client: RemoteRunnable) -> None:
    """Round-trip lists of inputs (including an empty list) through sync ``batch``."""
    cases = [
        ([], []),
        ([1, 2, 3], [2, 3, 4]),
        ([HumanMessage(content="hello")], [HumanMessage(content="hello")]),
    ]
    for inputs, expected in cases:
        assert client.batch(inputs) == expected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ainvoke(async_client: RemoteRunnable) -> None:
    """Round-trip single inputs through the async ``ainvoke`` of the remote client."""
    cases = [
        (1, 2),
        (HumanMessage(content="hello"), HumanMessage(content="hello")),
    ]
    for given, expected in cases:
        result = await async_client.ainvoke(given)
        assert result == expected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_abatch(async_client: RemoteRunnable) -> None:
    """Round-trip lists of inputs (including an empty list) through async ``abatch``."""
    cases = [
        ([], []),
        ([1, 2, 3], [2, 3, 4]),
        ([HumanMessage(content="hello")], [HumanMessage(content="hello")]),
    ]
    for inputs, expected in cases:
        result = await async_client.abatch(inputs)
        assert result == expected
|
|
|
|
|
|
# TODO(Team): Determine how to test
|
|
# Some issue with event loops
|
|
# def test_stream(client: RemoteRunnable) -> None:
|
|
# """Test stream."""
|
|
# assert list(client.stream(1)) == [2]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_astream(async_client: RemoteRunnable) -> None:
    """Check that ``astream`` yields exactly one chunk holding the full output."""
    chunks = [piece async for piece in async_client.astream(1)]
    assert chunks == [2]

    message = HumanMessage(content="hello")
    chunks = [piece async for piece in async_client.astream(message)]
    assert chunks == [message]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_astream_log(async_client: RemoteRunnable) -> None:
    """Test async ``astream_log``: three patches, the first replacing the root state."""
    outputs = [chunk async for chunk in async_client.astream_log(1)]

    assert len(outputs) == 3

    op = outputs[0].ops[0]
    # The run id is not known in advance, so echo it back from the patch itself.
    uuid = op["value"]["id"]
    assert op == {
        "op": "replace",
        "path": "",
        "value": {
            "final_output": {"output": 2},
            "id": uuid,
            "logs": [],
            "streamed_output": [],
        },
    }
|
|
|
|
|
|
def test_invoke_as_part_of_sequence(client: RemoteRunnable) -> None:
    """Chain the remote client with a local +1 step and call it synchronously."""
    runnable = client | RunnableLambda(func=lambda x: x + 1)

    # Each entry point is exercised first without, then with, a config.
    for config in (None, {"tags": ["test"]}):
        assert runnable.invoke(1, config=config) == 3
    for config in (None, {"tags": ["test"]}):
        assert runnable.batch([1, 2], config=config) == [3, 4]

    # TODO(Team): Determine how to test some issues with event loops for testing
    # set up
    # without config
    # assert list(runnable.stream([1, 2])) == [3, 4]
    # # with config
    # assert list(runnable.stream([1, 2], config={"tags": ["test"]})) == [3, 4]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invoke_as_part_of_sequence_async(async_client: RemoteRunnable) -> None:
    """Test the remote client as the first step of a local sequence.

    This helps to verify that config is handled properly (e.g., callbacks are not
    passed to the server, but other config is).
    """
    runnable = async_client | RunnableLambda(
        func=lambda x: x + 1 if isinstance(x, int) else x
    ).with_config({"run_name": "hello"})
    # without config
    assert await runnable.ainvoke(1) == 3
    # with config
    assert await runnable.ainvoke(1, config={"tags": ["test"]}) == 3
    # without config
    assert await runnable.abatch([1, 2]) == [3, 4]
    # with config
    assert await runnable.abatch([1, 2], config={"tags": ["test"]}) == [3, 4]

    # Verify that one config per input can be passed to batch
    configs = [{"tags": ["test"]}, {"tags": ["test2"]}]
    assert await runnable.abatch([1, 2], config=configs) == [3, 4]

    # Verify that a config list whose length differs from the inputs raises
    # ValueError (no result is produced, so there is nothing to compare).
    with pytest.raises(ValueError):
        await runnable.abatch([1, 2], config=[configs[0]])

    # Batching with a matching config list still works after the rejected call
    configs = [{"tags": ["test"]}, {"tags": ["test2"]}]
    assert await runnable.abatch([1, 2], config=configs) == [3, 4]

    # Configs may carry keys beyond the standard ones
    configs = [
        {"tags": ["test"]},
        {"tags": ["test2"], "other": "test"},
    ]
    assert await runnable.abatch([1, 2], config=configs) == [3, 4]

    # Without config
    assert [x async for x in runnable.astream(1)] == [3]

    # With Config
    assert [x async for x in runnable.astream(1, config={"tags": ["test"]})] == [3]

    # With config and LC input data
    assert [
        x
        async for x in runnable.astream(
            HumanMessage(content="hello"), config={"tags": ["test"]}
        )
    ] == [HumanMessage(content="hello")]

    log_patches = [x async for x in runnable.astream_log(1)]
    for log_patch in log_patches:
        assert isinstance(log_patch, RunLogPatch)
    # Only check the first entry (not validating implementation here)
    first_op = log_patches[0].ops[0]
    assert first_op["op"] == "replace"
    assert first_op["path"] == ""

    # Validate with HumanMessage
    log_patches = [x async for x in runnable.astream_log(HumanMessage(content="hello"))]
    for log_patch in log_patches:
        assert isinstance(log_patch, RunLogPatch)
    # Only check the first entry (not validating implementation here)
    first_op = log_patches[0].ops[0]
    assert first_op == {
        "op": "replace",
        "path": "",
        "value": {
            "final_output": None,
            "id": first_op["value"]["id"],
            "logs": [],
            "streamed_output": [],
        },
    }
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_multiple_runnables(event_loop: AbstractEventLoop) -> None:
    """Test serving multiple runnables on one app and composing their clients."""

    async def add_one(x: int) -> int:
        """Add one to simulate a valid function"""
        return x + 1

    async def mul_2(x: int) -> int:
        """Multiply by two to simulate a valid function"""
        return x * 2

    app = FastAPI()
    add_routes(app, RunnableLambda(add_one), path="/add_one")
    add_routes(
        app,
        RunnableLambda(mul_2),
        input_type=int,
        path="/mul_2",
    )

    async with get_async_client(app, path="/add_one") as runnable:
        async with get_async_client(app, path="/mul_2") as runnable2:
            assert await runnable.ainvoke(1) == 2
            assert await runnable2.ainvoke(4) == 8

            # Remote add_one then remote mul_2: (3 + 1) * 2
            composite_runnable = runnable | runnable2
            assert await composite_runnable.ainvoke(3) == 8

            # Invoke runnable (remote add_one), local add_one, remote mul_2
            composite_runnable_2 = runnable | add_one | runnable2
            assert await composite_runnable_2.ainvoke(3) == 10
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_input_validation(
    event_loop: AbstractEventLoop, mocker: MockerFixture
) -> None:
    """Test rejection of invalid inputs and per-route handling of client config.

    ``/add_one`` is registered with ``input_type=int`` and no ``config_keys``.
    ``/add_one_config`` additionally lists ``tags`` and ``run_name`` as
    ``config_keys``.
    """

    async def add_one(x: int) -> int:
        """Add one to simulate a valid function"""
        return x + 1

    server_runnable = RunnableLambda(func=add_one, afunc=add_one)
    server_runnable2 = RunnableLambda(func=add_one, afunc=add_one)

    app = FastAPI()
    add_routes(
        app,
        server_runnable,
        input_type=int,
        path="/add_one",
    )

    add_routes(
        app,
        server_runnable2,
        input_type=int,
        path="/add_one_config",
        config_keys=["tags", "run_name"],
    )

    async with get_async_client(app, path="/add_one") as runnable:
        # Verify that can be invoked with valid input
        assert await runnable.ainvoke(1) == 2
        # Non-int input is rejected, surfacing as an httpx.HTTPError
        with pytest.raises(httpx.HTTPError):
            await runnable.ainvoke("hello")

        with pytest.raises(httpx.HTTPError):
            await runnable.abatch(["hello"])

    config = {"tags": ["test"]}

    invoke_spy_1 = mocker.spy(server_runnable, "ainvoke")
    # Verify config is handled correctly
    async with get_async_client(app, path="/add_one") as runnable1:
        # Config is dropped: this route was registered without config_keys
        assert await runnable1.ainvoke(1, config=config) == 2
        assert invoke_spy_1.call_args[1]["config"] == {}

    invoke_spy_2 = mocker.spy(server_runnable2, "ainvoke")
    async with get_async_client(app, path="/add_one_config") as runnable2:
        # Config accepted for runnable2
        assert await runnable2.ainvoke(1, config=config) == 2
        # Config is forwarded unchanged to the server-side runnable
        assert invoke_spy_2.call_args[1]["config"] == config
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_input_validation_with_lc_types(event_loop: AbstractEventLoop) -> None:
    """Check that a ``List[HumanMessage]`` input type is enforced for invoke and batch."""
    app = FastAPI()
    # Serve a passthrough that only accepts lists of HumanMessage objects.
    add_routes(
        app, RunnablePassthrough(), input_type=List[HumanMessage], config_keys=["tags"]
    )

    # Invoke: every one of these inputs violates List[HumanMessage].
    async with get_async_client(app) as passthrough_runnable:
        invalid_inputs = (
            "Hello",
            ["hello"],
            HumanMessage(content="h"),
            [SystemMessage(content="hello")],
        )
        for bad_input in invalid_inputs:
            with pytest.raises(httpx.HTTPError):
                await passthrough_runnable.ainvoke(bad_input)

        # A well-formed list round-trips through the passthrough.
        echoed = await passthrough_runnable.ainvoke([HumanMessage(content="hello")])
        assert isinstance(echoed, list)
        assert isinstance(echoed[0], HumanMessage)

    # Batch: every one of these batches contains an invalid element.
    async with get_async_client(app) as passthrough_runnable:
        invalid_batches = (
            "Hello",
            ["hello"],
            [[SystemMessage(content="hello")]],
        )
        for bad_batch in invalid_batches:
            with pytest.raises(httpx.HTTPError):
                await passthrough_runnable.abatch(bad_batch)

        echoed = await passthrough_runnable.abatch([[HumanMessage(content="hello")]])
        assert isinstance(echoed, list)
        assert isinstance(echoed[0], list)
        assert isinstance(echoed[0][0], HumanMessage)
|
|
|
|
|
|
def test_client_close() -> None:
    """Test that a RemoteRunnable's HTTP clients are closed when it is deleted."""
    runnable = RemoteRunnable(url="/dev/null", timeout=1)
    sync_client = runnable.sync_client
    async_client = runnable.async_client
    assert async_client.is_closed is False
    assert sync_client.is_closed is False
    # Dropping the last reference should trigger closing of both clients.
    del runnable
    assert sync_client.is_closed is True
    assert async_client.is_closed is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_client_close() -> None:
    """Test that a RemoteRunnable's HTTP clients are closed on deletion inside a running loop."""
    runnable = RemoteRunnable(url="/dev/null", timeout=1)
    sync_client = runnable.sync_client
    async_client = runnable.async_client
    assert async_client.is_closed is False
    assert sync_client.is_closed is False
    # Dropping the last reference should trigger closing of both clients.
    del runnable
    assert sync_client.is_closed is True
    assert async_client.is_closed is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_openapi_docs_with_identical_runnables(
    event_loop: AbstractEventLoop, mocker: MockerFixture
) -> None:
    """Test that /openapi.json is served when two routes share an input schema.

    Both routes wrap the same function, so their inferred input schemas are
    identical; generating the OpenAPI document must still succeed.
    """

    async def add_one(x: int) -> int:
        """Add one to simulate a valid function"""
        return x + 1

    server_runnable = RunnableLambda(func=add_one)
    server_runnable2 = RunnableLambda(func=add_one)

    app = FastAPI()
    add_routes(
        app,
        server_runnable,
        path="/1",
    )
    # Add another route that uses the same schema (inferred from runnable input schema)
    add_routes(
        app,
        server_runnable2,
        path="/2",
        config_keys=["tags", "run_name"],
    )

    async with AsyncClient(app=app, base_url="http://localhost:9999") as async_client:
        response = await async_client.get("/openapi.json")
        assert response.status_code == 200
|