langchain/libs/langserve/tests/unit_tests/test_validation.py
Eugene Yurtsev b05bb9e136
LangServe (#11046)
Adds LangServe package

* Integrate Runnables with FastAPI, creating a server and a RemoteRunnable
client
* Support multiple runnables for a given server
* Support sync/async/batch/abatch/stream/astream/astream_log on the
client side (using async implementations on server)
* Adds validation using annotations (relying on pydantic under the hood);
this still has some rough edges -- e.g., OpenAPI docs do NOT
generate correctly at the moment
* Uses pydantic v1 namespace

Known issues: the type translation code doesn't handle many types (e.g.,
TypedDicts)

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
2023-09-28 10:52:44 +01:00

237 lines
6.3 KiB
Python

import typing
from typing import Optional
import pytest
from langchain.load.dump import dumpd
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from typing_extensions import TypedDict
try:
from pydantic.v1 import BaseModel, ValidationError
except ImportError:
from pydantic import BaseModel, ValidationError
from langserve.validation import (
create_batch_request_model,
create_invoke_request_model,
create_runnable_config_model,
replace_lc_object_types,
)
@pytest.mark.parametrize(
    "test_case",
    [
        {
            "input": {"a": "qqq"},
            "kwargs": {},
            "valid": False,
        },
        {
            "input": {"a": 2},
            "kwargs": "hello",
            "valid": False,
        },
        {
            "input": {"a": 2},
            "config": "hello",
            "valid": False,
        },
        {
            "input": {"b": "hello"},
            "valid": False,
        },
        {
            "input": {"a": 2, "b": "hello"},
            "config": "hello",
            "valid": False,
        },
        {
            "input": {"a": 2, "b": "hello"},
            "valid": True,
        },
        {
            # Valid input together with an explicit, empty kwargs mapping.
            "input": {"a": 2, "b": "hello"},
            "kwargs": {},
            "valid": True,
        },
        {
            "input": {"a": 2},
            "valid": True,
        },
    ],
)
def test_create_invoke_and_batch_models(test_case: dict) -> None:
    """Test that the invoke and batch request models validate requests.

    Each case is a request payload plus a ``valid`` flag. The payload is first
    checked against the invoke request model as-is, then against the batch
    request model with ``input`` moved into a one-element ``inputs`` list.
    """
    # pytest hands the same dict object to every run of a parametrized case;
    # copy it so the pops/inserts below do not corrupt the shared parameter
    # (e.g. when the test is re-run).
    test_case = dict(test_case)

    class Input(BaseModel):
        """Test input."""

        a: int
        b: Optional[str] = None

    valid = test_case.pop("valid")
    config = create_runnable_config_model("test", ["tags"])
    model = create_invoke_request_model("namespace", Input, config)
    if valid:
        model(**test_case)
    else:
        with pytest.raises(ValidationError):
            model(**test_case)

    # Validate batch request:
    # same structure as the invoke request, but
    # 'input' is a list of inputs and is called 'inputs'.
    batch_model = create_batch_request_model("namespace", Input, config)
    test_case["inputs"] = [test_case.pop("input")]
    if valid:
        batch_model(**test_case)
    else:
        with pytest.raises(ValidationError):
            batch_model(**test_case)
@pytest.mark.parametrize(
    "test_case",
    [
        {
            "type": int,
            "input": 1,
            "valid": True,
        },
        {
            "type": float,
            "input": "name",
            "valid": False,
        },
        {
            "type": float,
            "input": [3.2],
            "valid": False,
        },
        {
            "type": float,
            "input": 1.1,
            "valid": True,
        },
        {
            "type": Optional[float],
            "valid": True,
            "input": None,
        },
    ],
)
def test_validation(test_case: dict) -> None:
    """Test that the invoke request model validates ``input`` against its type.

    Each case gives the input type used to build the request model, an
    ``input`` value, and whether that value should pass validation.
    """
    # Copy so the shared parametrized dict is not mutated by the pops below.
    test_case = dict(test_case)
    input_type = test_case.pop("type")
    # Pop the expectation flag so it is not forwarded to the request model
    # as an unrelated ``valid`` field.
    valid = test_case.pop("valid")

    config = create_runnable_config_model("test", [])
    model = create_invoke_request_model("namespace", input_type, config)
    if valid:
        model(**test_case)
    else:
        with pytest.raises(ValidationError):
            model(**test_case)
def test_replace_lc_object_types() -> None:
    """Only serialized lists of HumanMessage pass after type replacement."""
    request_model = create_invoke_request_model(
        "namespace",
        replace_lc_object_types(typing.List[HumanMessage]),
        create_runnable_config_model("test", []),
    )

    # Accepted: every element is a serialized HumanMessage.
    request_model(
        input=dumpd(
            [
                HumanMessage(content="Hello, world!"),
                HumanMessage(content="Hello, world 2!"),
            ]
        )
    )

    # Rejected: an AIMessage on its own, and an AIMessage mixed with a
    # HumanMessage.
    rejected_inputs = (
        [dumpd(AIMessage(content="Hello, world!"))],
        dumpd(
            [
                AIMessage(content="Hello, world!"),
                HumanMessage(content="Hello, world!"),
            ]
        ),
    )
    for bad_input in rejected_inputs:
        with pytest.raises(ValidationError):
            request_model(input=bad_input)
def test_batch_request_with_lc_serialization() -> None:
    """Batch requests accept only a list of serialized HumanMessage lists."""
    batch_model = create_batch_request_model(
        "namespace",
        replace_lc_object_types(typing.List[HumanMessage]),
        create_runnable_config_model("test", []),
    )

    # Each payload is malformed differently: a SystemMessage where a
    # HumanMessage is required, a bare message instead of a list of inputs,
    # and a flat list where a list of per-input lists is expected.
    malformed_inputs = (
        dumpd([[SystemMessage(content="Hello, world!")]]),
        dumpd(HumanMessage(content="Hello, world!")),
        dumpd([HumanMessage(content="Hello, world!")]),
    )
    for inputs in malformed_inputs:
        with pytest.raises(ValidationError):
            batch_model(inputs=inputs)

    batch_model(inputs=dumpd([[HumanMessage(content="Hello, world!")]]))
# TypedDict with an int field ``x`` and a HumanMessage field ``z``.
# NOTE(review): not referenced by the tests visible here; presumably a
# placeholder for TypedDict support in type translation — confirm.
PlaceHolderTypedDict = TypedDict(
    "PlaceHolderTypedDict",
    {"x": int, "z": HumanMessage},
)
@pytest.mark.parametrize(
    "type_,input,is_valid",
    [
        (None, None, True),
        (str, "hello", True),
        (str, 123.0, True),
        (float, "qwe", False),
        (int, 1, True),
        (int, "qwe", False),
        (typing.Union[str, int], "hello", True),
        (typing.Union[str, int], 3, True),
        (typing.List[str], ["a", "b"], True),
        (typing.List[str], ["a", None], False),
        (typing.List[HumanMessage], [HumanMessage(content="hello, world!")], True),
        (typing.List[HumanMessage], [SystemMessage(content="hello, world!")], False),
        (
            typing.List[typing.Union[HumanMessage, SystemMessage]],
            [HumanMessage(content="he"), SystemMessage(content="hello, world!")],
            True,
        ),
        (
            typing.List[typing.Union[HumanMessage, SystemMessage]],
            HumanMessage(content="hello"),
            False,
        ),
        (
            typing.Union[
                typing.List[typing.Union[SystemMessage, HumanMessage, str]], str
            ],
            ["hello", "world"],
            True,
        ),
    ],
)
def test_replace_lc_object_type(
    type_: typing.Any, input: typing.Any, is_valid: bool
) -> None:
    """Check serialized values against a type passed through replacement.

    ``type_`` is run through ``replace_lc_object_types`` and used as the type
    of a single-field pydantic model; the ``dumpd``-serialized ``input`` must
    validate exactly when ``is_valid`` is true.
    """
    validator_type = replace_lc_object_types(type_)

    class Model(BaseModel):
        input_: validator_type

    serialized = dumpd(input)
    if not is_valid:
        with pytest.raises(ValidationError):
            Model(input_=serialized)
        return
    Model(input_=serialized)