mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
langchain[patch]: Update unit tests to workaround a pydantic 2 issue (#24886)
This will allow our unit tests to pass when using AnyID() with our pydantic models.
This commit is contained in:
parent
2019e31bc5
commit
7720483432
@ -9,7 +9,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
|
||||
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||
|
||||
|
||||
def test_generic_fake_chat_model_invoke() -> None:
|
||||
@ -17,11 +17,11 @@ def test_generic_fake_chat_model_invoke() -> None:
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="hello")
|
||||
response = model.invoke("kitty")
|
||||
assert response == AIMessage(content="goodbye", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="goodbye")
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="hello")
|
||||
|
||||
|
||||
async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||
@ -29,11 +29,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="hello")
|
||||
response = await model.ainvoke("kitty")
|
||||
assert response == AIMessage(content="goodbye", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="goodbye")
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="hello")
|
||||
|
||||
|
||||
async def test_generic_fake_chat_model_stream() -> None:
|
||||
@ -46,16 +46,16 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="hello"),
|
||||
_AnyIdAIMessageChunk(content=" "),
|
||||
_AnyIdAIMessageChunk(content="goodbye"),
|
||||
]
|
||||
|
||||
chunks = [chunk for chunk in model.stream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="hello"),
|
||||
_AnyIdAIMessageChunk(content=" "),
|
||||
_AnyIdAIMessageChunk(content="goodbye"),
|
||||
]
|
||||
|
||||
# Test streaming of additional kwargs.
|
||||
@ -136,9 +136,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
|
||||
]
|
||||
final = log_patches[-1]
|
||||
assert final.state["streamed_output"] == [
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="hello"),
|
||||
_AnyIdAIMessageChunk(content=" "),
|
||||
_AnyIdAIMessageChunk(content="goodbye"),
|
||||
]
|
||||
|
||||
|
||||
@ -186,8 +186,8 @@ async def test_callback_handlers() -> None:
|
||||
# New model
|
||||
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
|
||||
assert results == [
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="hello"),
|
||||
_AnyIdAIMessageChunk(content=" "),
|
||||
_AnyIdAIMessageChunk(content="goodbye"),
|
||||
]
|
||||
assert tokens == ["hello", " ", "goodbye"]
|
||||
|
@ -1,6 +1,44 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||
|
||||
|
||||
class AnyStr(str):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, str)
|
||||
|
||||
|
||||
# The code below creates version of pydantic models
|
||||
# that will work in unit tests with AnyStr as id field
|
||||
# Please note that the `id` field is assigned AFTER the model is created
|
||||
# to workaround an issue with pydantic ignoring the __eq__ method on
|
||||
# subclassed strings.
|
||||
|
||||
|
||||
def _AnyIdDocument(**kwargs: Any) -> Document:
|
||||
"""Create a document with an id field."""
|
||||
message = Document(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
||||
|
||||
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
|
||||
"""Create ai message with an any id field."""
|
||||
message = AIMessage(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
||||
|
||||
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
|
||||
"""Create ai message with an any id field."""
|
||||
message = AIMessageChunk(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
||||
|
||||
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
|
||||
"""Create a human with an any id field."""
|
||||
message = HumanMessage(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
Loading…
Reference in New Issue
Block a user