mirror of https://github.com/hwchase17/langchain, synced 2024-11-13 19:10:52 +00:00
core: fix batch race condition in FakeListChatModel (#26924)
fixed #26273
This commit is contained in:
parent
87fc5ce688
commit
ab4dab9a0c
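Background on the race being fixed: FakeListChatModel hands out its canned responses from a shared position in its responses list, and the default Runnable.batch fans the invoke calls out across a thread pool, so which canned response each input ends up consuming is not deterministic. Below is a minimal sketch of that failure mode and of the shape of the fix, using a hypothetical ToyFakeModel and two hypothetical batch helpers rather than the real langchain classes:

from concurrent.futures import ThreadPoolExecutor


class ToyFakeModel:
    """Hypothetical stand-in for a fake chat model with canned responses."""

    def __init__(self, responses: list[str]) -> None:
        self.responses = responses
        self.i = 0  # shared mutable index, touched by every invoke call

    def invoke(self, _prompt: str) -> str:
        # read-then-increment is not atomic: two threads can observe the same
        # index, or commit their increments in a different order than the inputs
        response = self.responses[self.i % len(self.responses)]
        self.i += 1
        return response


def pooled_batch(model: ToyFakeModel, prompts: list[str]) -> list[str]:
    # thread-pool fan-out in the spirit of the default batch: results come back
    # in input order, but which canned response each input consumed can vary
    with ThreadPoolExecutor() as pool:
        return list(pool.map(model.invoke, prompts))


def sequential_batch(model: ToyFakeModel, prompts: list[str]) -> list[str]:
    # the shape of the fix in this commit: a plain loop, no concurrency
    return [model.invoke(p) for p in prompts]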
@@ -13,6 +13,7 @@ from langchain_core.callbacks import (
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.runnables import RunnableConfig


class FakeMessagesListChatModel(BaseChatModel):
@@ -128,6 +129,33 @@ class FakeListChatModel(SimpleChatModel):
    def _identifying_params(self) -> dict[str, Any]:
        return {"responses": self.responses}

+    # manually override batch to preserve batch ordering with no concurrency
+    def batch(
+        self,
+        inputs: list[Any],
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        **kwargs: Any,
+    ) -> list[BaseMessage]:
+        if isinstance(config, list):
+            return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
+        return [self.invoke(m, config, **kwargs) for m in inputs]
+
+    async def abatch(
+        self,
+        inputs: list[Any],
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        **kwargs: Any,
+    ) -> list[BaseMessage]:
+        if isinstance(config, list):
+            # do not use an async iterator here because we need explicit ordering
+            return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)]
+        # do not use an async iterator here because we need explicit ordering
+        return [await self.ainvoke(m, config, **kwargs) for m in inputs]
+


class FakeChatModel(SimpleChatModel):
    """Fake Chat Model wrapper for testing purposes."""
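With the override in place, responses line up with inputs in submission order. A rough usage sketch, mirroring the unit test added further down:

from langchain_core.language_models import FakeListChatModel

fake = FakeListChatModel(responses=["a", "b", "c"])
out = fake.batch(["1", "2", "3"])
# with the sequential override, message contents track the canned responses in order
assert [m.content for m in out] == ["a", "b", "c"]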
@@ -5,7 +5,11 @@ from typing import Any, Optional, Union
from uuid import UUID

from langchain_core.callbacks.base import AsyncCallbackHandler
-from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
+from langchain_core.language_models import (
+    FakeListChatModel,
+    GenericFakeChatModel,
+    ParrotFakeChatModel,
+)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
@@ -205,3 +209,16 @@ def test_chat_model_inputs() -> None:
    assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message(
        content="blah"
    )
+
+
+def test_fake_list_chat_model_batch() -> None:
+    expected = [
+        _any_id_ai_message(content="a"),
+        _any_id_ai_message(content="b"),
+        _any_id_ai_message(content="c"),
+    ]
+    for _ in range(20):
+        # run this 20 times to test race condition in batch
+        fake = FakeListChatModel(responses=["a", "b", "c"])
+        resp = fake.batch(["1", "2", "3"])
+        assert resp == expected
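Not part of this commit: an analogous ordering check for the async path could look like the sketch below, run here as a plain asyncio script (under pytest it would need an async test runner such as pytest-asyncio):

import asyncio

from langchain_core.language_models import FakeListChatModel


async def check_abatch_ordering() -> None:
    fake = FakeListChatModel(responses=["a", "b", "c"])
    resp = await fake.abatch(["1", "2", "3"])
    assert [m.content for m in resp] == ["a", "b", "c"]


asyncio.run(check_abatch_ordering())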
@@ -199,19 +199,13 @@ async def test_global_cache_abatch() -> None:
        assert results[0].content == "hello"
        assert results[1].content == "hello"

-        ## RACE CONDITION -- note behavior is different from sync
        # Now, reset cache and test the race condition
-        # For now we just hard-code the result, if this changes
-        # we can investigate further
        global_cache = InMemoryCache()
        set_llm_cache(global_cache)
        assert global_cache._cache == {}
        results = await chat_model.abatch(["prompt", "prompt"])
-        # suspecting that tasks will be scheduled and executed in order
-        # if this ever fails, we can relax to a set comparison
-        # Cache misses likely guaranteed?

        assert results[0].content == "meow"
-        assert results[1].content == "woof"
+        assert results[1].content == "meow"
    finally:
        set_llm_cache(None)
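Why both assertions now expect "meow": the two batch inputs are the identical string "prompt", so with abatch running the calls sequentially the first call misses the fresh InMemoryCache and consumes the next canned response (here "meow"), while the second call is served from the cache and returns the same content. A toy sketch of that lookup flow, keyed on the prompt alone (the real cache also keys on the serialized model configuration):

cache: dict[str, str] = {}          # fresh cache, analogous to InMemoryCache()
responses = iter(["meow", "woof"])  # remaining canned responses


def cached_call(prompt: str) -> str:
    if prompt in cache:              # the second "prompt" hits this branch
        return cache[prompt]
    result = next(responses)         # the first "prompt" consumes "meow"
    cache[prompt] = result
    return result


assert [cached_call("prompt"), cached_call("prompt")] == ["meow", "meow"]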