Update BaseChatModel.astream to respect generation_info (#9430)

Currently, generation_info is not respected because stream/astream only accumulate the messages from the streamed chunks. Change them to accumulate the generation chunks themselves so that generation_info is merged properly and reaches on_llm_end.
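
A minimal sketch (not part of this commit) of the behavior described above: adding ChatGenerationChunk objects merges both the streamed message content and generation_info, whereas accumulating only chunk.message drops generation_info. The import paths assume the langchain layout around the time of this commit, and the finish_reason value is purely illustrative.

from typing import Optional

from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatGenerationChunk

# Simulated stream: generation_info typically arrives on the final chunk.
chunks = [
    ChatGenerationChunk(message=AIMessageChunk(content="Hello")),
    ChatGenerationChunk(
        message=AIMessageChunk(content="!"),
        generation_info={"finish_reason": "stop"},
    ),
]

# Old behavior: only the messages were accumulated, so generation_info was lost.
message = None
for chunk in chunks:
    message = chunk.message if message is None else message + chunk.message

# New behavior: the chunks themselves are accumulated, preserving generation_info.
generation: Optional[ChatGenerationChunk] = None
for chunk in chunks:
    generation = chunk if generation is None else generation + chunk

assert generation is not None
print(generation.message.content)   # Hello!
print(generation.generation_info)   # expected: {'finish_reason': 'stop'}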

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Kim Minjong 2023-08-23 05:18:24 +07:00 committed by GitHub
parent f29312eb84
commit ca8232a3c1
2 changed files with 42 additions and 12 deletions


@@ -176,22 +176,22 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             dumpd(self), [messages], invocation_params=params, options=options
         )
         try:
-            message: Optional[BaseMessageChunk] = None
+            generation: Optional[ChatGenerationChunk] = None
             for chunk in self._stream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             ):
                 yield chunk.message
-                if message is None:
-                    message = chunk.message
+                if generation is None:
+                    generation = chunk
                 else:
-                    message += chunk.message
-            assert message is not None
+                    generation += chunk
+            assert generation is not None
         except (KeyboardInterrupt, Exception) as e:
             run_manager.on_llm_error(e)
             raise e
         else:
             run_manager.on_llm_end(
-                LLMResult(generations=[[ChatGeneration(message=message)]]),
+                LLMResult(generations=[[generation]]),
             )

     async def astream(
@@ -223,22 +223,22 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             dumpd(self), [messages], invocation_params=params, options=options
         )
         try:
-            message: Optional[BaseMessageChunk] = None
+            generation: Optional[ChatGenerationChunk] = None
             async for chunk in self._astream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             ):
                 yield chunk.message
-                if message is None:
-                    message = chunk.message
+                if generation is None:
+                    generation = chunk
                 else:
-                    message += chunk.message
-            assert message is not None
+                    generation += chunk
+            assert generation is not None
         except (KeyboardInterrupt, Exception) as e:
             await run_manager.on_llm_error(e)
             raise e
         else:
             await run_manager.on_llm_end(
-                LLMResult(generations=[[ChatGeneration(message=message)]]),
+                LLMResult(generations=[[generation]]),
             )

     # --- Custom methods ---
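
For illustration only, a hedged sketch of what a callback observes after a streamed run with this change applied. The CaptureResult handler is made up for this example, running it requires OPENAI_API_KEY, and whether generation_info is populated (and with which keys) depends on the chat model implementation.

from typing import Any

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import LLMResult


class CaptureResult(BaseCallbackHandler):
    """Stores the LLMResult emitted when the streamed run finishes."""

    result: Any = None

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self.result = response


handler = CaptureResult()
chat = ChatOpenAI(max_tokens=2, temperature=0, callbacks=[handler])
list(chat.stream("hi"))  # consume the stream

generation = handler.result.generations[0][0]
print(generation.text)             # the accumulated streamed text
print(generation.generation_info)  # before this change: always None; now merged from the chunks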


@@ -1,6 +1,8 @@
 """Test ChatOpenAI wrapper."""
+from typing import Any
+
 import pytest

 from langchain.callbacks.manager import CallbackManager
@@ -89,6 +91,34 @@ def test_chat_openai_streaming() -> None:
     assert isinstance(response, BaseMessage)


+@pytest.mark.scheduled
+def test_chat_openai_streaming_generation_info() -> None:
+    """Test that generation info is preserved when streaming."""
+
+    class _FakeCallback(FakeCallbackHandler):
+        saved_things: dict = {}
+
+        def on_llm_end(
+            self,
+            *args: Any,
+            **kwargs: Any,
+        ) -> Any:
+            # Save the generation
+            self.saved_things["generation"] = args[0]
+
+    callback = _FakeCallback()
+    callback_manager = CallbackManager([callback])
+    chat = ChatOpenAI(
+        max_tokens=2,
+        temperature=0,
+        callback_manager=callback_manager,
+    )
+    list(chat.stream("hi"))
+    generation = callback.saved_things["generation"]
+    # `Hello!` is two tokens, assert that that is what is returned
+    assert generation.generations[0][0].text == "Hello!"
+
+
 def test_chat_openai_llm_output_contains_model_name() -> None:
     """Test llm_output contains model_name."""
     chat = ChatOpenAI(max_tokens=10)