Mirror of https://github.com/hwchase17/langchain, synced 2024-11-16 06:13:16 +00:00
core[patch]: Pass sync run manager for sync stream fallback in astream (#19280)
This PR patches the fallback in chat models and language models so that the appropriate version of the run manager (sync vs. async) is passed in: when astream falls back to a sync _stream implementation, the async run manager is converted with run_manager.get_sync() before being handed to the sync code.
This commit is contained in:
parent d314acb2d5
commit 4b3dd34544
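To make the change concrete before the diff: an async streaming entry point may end up driving a synchronous _stream implementation, and that sync code expects a synchronous run manager. Below is a minimal, self-contained sketch of that pattern; the classes are hypothetical stand-ins rather than LangChain's actual BaseChatModel or callback managers, and only the run_manager.get_sync() conversion mirrors the real change.

# Minimal, self-contained sketch of the idea behind this patch. The classes
# below are hypothetical stand-ins, not LangChain's chat model or callback
# managers; only the run_manager.get_sync() conversion mirrors the real change.
import asyncio
from typing import AsyncIterator, Iterator


class SyncRunManager:
    """Stand-in for a synchronous run manager."""

    def on_token(self, token: str) -> None:
        print(f"[sync callback] {token}")


class AsyncRunManager:
    """Stand-in for an async run manager that can produce a sync counterpart."""

    def get_sync(self) -> SyncRunManager:
        return SyncRunManager()

    async def on_token(self, token: str) -> None:
        print(f"[async callback] {token}")


class FakeModel:
    def _stream(self, prompt: str, run_manager: SyncRunManager) -> Iterator[str]:
        # Sync streaming implementation: it only knows how to call sync callbacks.
        for token in prompt.split():
            run_manager.on_token(token)
            yield token

    async def astream(self, prompt: str) -> AsyncIterator[str]:
        run_manager = AsyncRunManager()
        # Pretend only the sync _stream is implemented, as in the fallback branch.
        using_sync_stream = True
        # The fix: hand the sync implementation a sync run manager.
        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
        for token in self._stream(prompt, run_manager=run_manager_):
            yield token


async def main() -> None:
    async for token in FakeModel().astream("hello world"):
        print(token)


asyncio.run(main())

The real code additionally wraps the sync generator so it can be consumed as an async iterator without blocking the event loop, which is what _as_async_iterator(self._stream) does in the hunks below.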
@@ -273,6 +273,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if type(self)._astream is not BaseChatModel._astream:
             # Then astream is implemented
             _stream_implementation = self._astream
+            using_sync_stream = False
         elif type(self)._stream is not BaseChatModel._stream:
             # Then stream is implemented, so we can create an async iterator from it
             # The typing is hard to type correctly with mypy here, so we cast
@@ -289,6 +290,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 ],
                 _as_async_iterator(self._stream),
             )
+            using_sync_stream = True
         else:  # No async or sync stream is implemented, so fall back to ainvoke
             yield cast(
                 BaseMessageChunk,
@@ -318,10 +320,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
+
+        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[ChatGenerationChunk] = None
         try:
             async for chunk in _stream_implementation(
-                messages, stop=stop, run_manager=run_manager, **kwargs
+                messages,
+                stop=stop,
+                run_manager=run_manager_,  # type: ignore[arg-type]
+                **kwargs,
             ):
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 yield chunk.message
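As a side note on the _as_async_iterator(self._stream) branch above: below is a rough, hedged sketch of how a sync generator function can be exposed as an async iterator. The helper name as_async_iterator is hypothetical; this is an illustrative stand-in, not the library's implementation.

import asyncio
from typing import Any, AsyncIterator, Callable, Iterator


def as_async_iterator(
    sync_gen_func: Callable[..., Iterator[Any]]
) -> Callable[..., AsyncIterator[Any]]:
    """Wrap a sync generator function so it can be consumed with async for."""

    async def _agen(*args: Any, **kwargs: Any) -> AsyncIterator[Any]:
        iterator = sync_gen_func(*args, **kwargs)
        done = object()  # sentinel marking exhaustion
        while True:
            # Pull the next item in a worker thread so the event loop stays free.
            item = await asyncio.get_running_loop().run_in_executor(
                None, lambda: next(iterator, done)
            )
            if item is done:
                break
            yield item

    return _agen


async def main() -> None:
    def count(n: int) -> Iterator[int]:
        yield from range(n)

    async for value in as_async_iterator(count)(3):
        print(value)


asyncio.run(main())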
@@ -11,7 +11,6 @@ from langchain_core.callbacks import (
 from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import run_in_executor
 
 
 class FakeMessagesListChatModel(BaseChatModel):
@@ -279,25 +278,6 @@ class GenericFakeChatModel(BaseChatModel):
                 )
                 yield chunk
 
-    async def _astream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        """Stream the output of the model."""
-        result = await run_in_executor(
-            None,
-            self._stream,
-            messages,
-            stop=stop,
-            run_manager=run_manager.get_sync() if run_manager else None,
-            **kwargs,
-        )
-        for chunk in result:
-            yield chunk
-
     @property
     def _llm_type(self) -> str:
         return "generic-fake-chat-model"
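The override deleted above delegated to the sync _stream via run_in_executor and converted the run manager itself; with the base-class fallback now handling that conversion, the override and its run_in_executor import presumably become redundant. For reference, a small example of the run_in_executor helper it relied on (assumes langchain_core is installed; slow_add is just a made-up blocking function):

import asyncio
import time

from langchain_core.runnables import run_in_executor


def slow_add(a: int, b: int) -> int:
    time.sleep(0.1)  # stand-in for blocking work
    return a + b


async def main() -> None:
    # First argument is an optional executor/config; None uses the default executor.
    result = await run_in_executor(None, slow_add, 2, 3)
    print(result)  # 5


asyncio.run(main())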
@@ -463,6 +463,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         if type(self)._astream is not BaseLLM._astream:
             # model doesn't implement streaming, so use default implementation
             _stream_implementation = self._astream
+            using_sync_stream = False
         elif type(self)._stream is not BaseLLM._stream:
             # Then stream is implemented, so we can create an async iterator from it
             # The typing is hard to type correctly with mypy here, so we cast
@@ -479,6 +480,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 ],
                 _as_async_iterator(self._stream),
             )
+            using_sync_stream = True
         else:
             yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
             return
@@ -507,10 +509,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
+        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[GenerationChunk] = None
         try:
             async for chunk in _stream_implementation(
-                prompt, stop=stop, run_manager=run_manager, **kwargs
+                prompt,
+                stop=stop,
+                run_manager=run_manager_,  # type: ignore[arg-type]
+                **kwargs,
             ):
                 yield chunk.text
                 if generation is None:
@@ -312,6 +312,68 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
     ]
 
 
+async def test_astream_events_from_model() -> None:
+    """Test the output of a model."""
+    infinite_cycle = cycle(
+        [AIMessage(content="hello world!"), AIMessage(content="goodbye world!")]
+    )
+    # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces
+    model = (
+        GenericFakeChatModel(messages=infinite_cycle)
+        .with_config(
+            {
+                "metadata": {"a": "b"},
+                "tags": ["my_model"],
+                "run_name": "my_model",
+            }
+        )
+        .bind(stop="<stop_token>")
+    )
+    events = await _collect_events(model.astream_events("hello", version="v1"))
+    assert events == [
+        {
+            "data": {"input": "hello"},
+            "event": "on_chat_model_start",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="hello")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content=" ")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="world!")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"output": AIMessageChunk(content="hello world!")},
+            "event": "on_chat_model_end",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+    ]
+
+
 async def test_event_stream_with_simple_chain() -> None:
     """Test as event stream."""
     template = ChatPromptTemplate.from_messages(
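The new test above exercises the model through astream_events. A quick, hedged usage sketch outside the test harness (assumes langchain_core is installed; RunnableLambda stands in for a model, and the printed keys are only a subset of each event dict):

import asyncio

from langchain_core.runnables import RunnableLambda


async def main() -> None:
    # A trivial runnable that reverses its string input.
    reverse = RunnableLambda(lambda s: s[::-1]).with_config({"run_name": "reverse"})
    async for event in reverse.astream_events("hello", version="v1"):
        print(event["event"], event.get("name"), event.get("data"))


asyncio.run(main())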