langchain[patch]: Update unit tests to workaround a pydantic 2 issue (#24886)

This will allow our unit tests to pass when using AnyID() with our pydantic models.
This commit is contained in:
Eugene Yurtsev 2024-07-31 14:09:40 -04:00 committed by GitHub
parent 2019e31bc5
commit 7720483432
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 57 additions and 19 deletions

View File

@ -9,7 +9,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
from tests.unit_tests.stubs import AnyStr
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
def test_generic_fake_chat_model_invoke() -> None:
@ -17,11 +17,11 @@ def test_generic_fake_chat_model_invoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == AIMessage(content="hello", id=AnyStr())
assert response == _AnyIdAIMessage(content="hello")
response = model.invoke("kitty")
assert response == AIMessage(content="goodbye", id=AnyStr())
assert response == _AnyIdAIMessage(content="goodbye")
response = model.invoke("meow")
assert response == AIMessage(content="hello", id=AnyStr())
assert response == _AnyIdAIMessage(content="hello")
async def test_generic_fake_chat_model_ainvoke() -> None:
@ -29,11 +29,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello", id=AnyStr())
assert response == _AnyIdAIMessage(content="hello")
response = await model.ainvoke("kitty")
assert response == AIMessage(content="goodbye", id=AnyStr())
assert response == _AnyIdAIMessage(content="goodbye")
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello", id=AnyStr())
assert response == _AnyIdAIMessage(content="hello")
async def test_generic_fake_chat_model_stream() -> None:
@ -46,16 +46,16 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=infinite_cycle)
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
]
chunks = [chunk for chunk in model.stream("meow")]
assert chunks == [
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
]
# Test streaming of additional kwargs.
@ -136,9 +136,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
]
final = log_patches[-1]
assert final.state["streamed_output"] == [
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
]
@ -186,8 +186,8 @@ async def test_callback_handlers() -> None:
# New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
assert results == [
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
]
assert tokens == ["hello", " ", "goodbye"]

View File

@ -1,6 +1,44 @@
from typing import Any
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
class AnyStr(str):
def __eq__(self, other: Any) -> bool:
return isinstance(other, str)
# The code below creates version of pydantic models
# that will work in unit tests with AnyStr as id field
# Please note that the `id` field is assigned AFTER the model is created
# to workaround an issue with pydantic ignoring the __eq__ method on
# subclassed strings.
def _AnyIdDocument(**kwargs: Any) -> Document:
"""Create a document with an id field."""
message = Document(**kwargs)
message.id = AnyStr()
return message
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
"""Create ai message with an any id field."""
message = AIMessage(**kwargs)
message.id = AnyStr()
return message
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
"""Create ai message with an any id field."""
message = AIMessageChunk(**kwargs)
message.id = AnyStr()
return message
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
"""Create a human with an any id field."""
message = HumanMessage(**kwargs)
message.id = AnyStr()
return message