Make tests stricter, remove old code, fix up pydantic import when using v2 (#11231)

Make tests stricter, remove old code, fix up pydantic import when using v2 (#11231)
pull/11239/head
Eugene Yurtsev 1 year ago committed by GitHub
parent 572968fee3
commit b4354b7694
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,7 +15,7 @@ from langchain.schema.runnable import Runnable
from typing_extensions import Annotated
try:
from pydantic.v1 import BaseModel
from pydantic.v1 import BaseModel, create_model
except ImportError:
from pydantic import BaseModel, create_model

@ -7,6 +7,11 @@ from langchain.schema.messages import (
SystemMessage,
)
try:
from pydantic.v1 import BaseModel
except ImportError:
from pydantic import BaseModel
from langserve.serialization import simple_dumps, simple_loads
@ -120,3 +125,31 @@ def test_serialization(data: Any, expected_json: Any) -> None:
assert json.loads(simple_dumps(data)) == expected_json
# Test decoding
assert simple_loads(json.dumps(expected_json)) == data
# Test full representation are equivalent including the pydantic model classes
assert _get_full_representation(data) == _get_full_representation(
simple_loads(json.dumps(expected_json))
)
def _get_full_representation(data: Any) -> Any:
"""Get the full representation of the data, replacing pydantic models with schema.
Pydantic tests two different models for equality based on equality
of their schema; instead we will rely on the equality of their full
schema representation. This will make sure that both models have the
same name (e.g., HumanMessage vs. HumanMessageChunk).
Args:
data: python primitives + pydantic models
Returns:
data represented entirely with python primitives
"""
if isinstance(data, dict):
return {key: _get_full_representation(value) for key, value in data.items()}
elif isinstance(data, list):
return [_get_full_representation(value) for value in data]
elif isinstance(data, BaseModel):
return data.schema()
else:
return data

@ -231,19 +231,6 @@ def test_invoke_as_part_of_sequence(client: RemoteRunnable) -> None:
# assert list(runnable.stream([1, 2], config={"tags": ["test"]})) == [3, 4]
def test_pydantic_root():
from pydantic import BaseModel
class Model(BaseModel):
__root__: str
class Q(BaseModel):
input: Model
# s = Model(__root__=[23])
Q(input="hello")
@pytest.mark.asyncio
async def test_invoke_as_part_of_sequence_async(async_client: RemoteRunnable) -> None:
"""Test as part of a sequence.

Loading…
Cancel
Save