Mirror of https://github.com/hwchase17/langchain
Update BaseChatModel.astream to respect generation_info (#9430)
Currently, generation_info is not respected, because only the messages from the streamed chunks are accumulated. Change the accumulation to add the generation chunks themselves, so that their generation_info is merged properly.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
parent f29312eb84
commit ca8232a3c1
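For context, here is a minimal sketch (not part of the commit) of why accumulating `ChatGenerationChunk` objects preserves `generation_info` while accumulating bare messages drops it. The import paths and the assumption that `ChatGenerationChunk.__add__` concatenates messages and merges the `generation_info` dicts are based on the langchain version this diff targets, not on the diff itself.

```python
# Minimal sketch, not part of the commit. Assumes these import paths and that
# ChatGenerationChunk.__add__ concatenates messages and merges generation_info.
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatGenerationChunk

chunks = [
    ChatGenerationChunk(message=AIMessageChunk(content="Hello")),
    ChatGenerationChunk(
        message=AIMessageChunk(content="!"),
        # Providers typically attach info such as finish_reason to the last chunk.
        generation_info={"finish_reason": "stop"},
    ),
]

# Old accumulation: only the messages are added, so generation_info is lost.
message = None
for chunk in chunks:
    message = chunk.message if message is None else message + chunk.message
print(message.content)  # "Hello!" -- finish_reason is gone

# New accumulation: the generation chunks themselves are added, so the merged
# chunk handed to on_llm_end still carries the generation_info (assumed merge).
generation = None
for chunk in chunks:
    generation = chunk if generation is None else generation + chunk
print(generation.message.content)  # "Hello!"
print(generation.generation_info)  # {'finish_reason': 'stop'}
```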
@@ -176,22 +176,22 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                 dumpd(self), [messages], invocation_params=params, options=options
             )
             try:
-                message: Optional[BaseMessageChunk] = None
+                generation: Optional[ChatGenerationChunk] = None
                 for chunk in self._stream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
                     yield chunk.message
-                    if message is None:
-                        message = chunk.message
+                    if generation is None:
+                        generation = chunk
                     else:
-                        message += chunk.message
-                assert message is not None
+                        generation += chunk
+                assert generation is not None
             except (KeyboardInterrupt, Exception) as e:
                 run_manager.on_llm_error(e)
                 raise e
             else:
                 run_manager.on_llm_end(
-                    LLMResult(generations=[[ChatGeneration(message=message)]]),
+                    LLMResult(generations=[[generation]]),
                 )
 
     async def astream(
@@ -223,22 +223,22 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                 dumpd(self), [messages], invocation_params=params, options=options
             )
             try:
-                message: Optional[BaseMessageChunk] = None
+                generation: Optional[ChatGenerationChunk] = None
                 async for chunk in self._astream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
                     yield chunk.message
-                    if message is None:
-                        message = chunk.message
+                    if generation is None:
+                        generation = chunk
                     else:
-                        message += chunk.message
-                assert message is not None
+                        generation += chunk
+                assert generation is not None
             except (KeyboardInterrupt, Exception) as e:
                 await run_manager.on_llm_error(e)
                 raise e
             else:
                 await run_manager.on_llm_end(
-                    LLMResult(generations=[[ChatGeneration(message=message)]]),
+                    LLMResult(generations=[[generation]]),
                 )
 
     # --- Custom methods ---
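To show the producer side that benefits from this change, here is a hedged sketch (not from the diff) of a toy chat model whose `_stream` attaches `generation_info` such as a finish reason to its final `ChatGenerationChunk`. With this commit, that information survives into the `LLMResult` passed to `on_llm_end` instead of being discarded when the generation was rebuilt from the concatenated message alone. The class name `EchoChatModel` and the import paths are assumptions for illustration.

```python
# Hypothetical toy model, not from the diff. Names (EchoChatModel) and import
# paths are assumptions for the langchain version this commit targets.
from typing import Any, Iterator, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult


class EchoChatModel(BaseChatModel):
    """Toy chat model that streams two chunks; the last one carries generation_info."""

    @property
    def _llm_type(self) -> str:
        return "echo-chat"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Non-streaming fallback.
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content="Hello!"))]
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        yield ChatGenerationChunk(message=AIMessageChunk(content="Hello"))
        yield ChatGenerationChunk(
            message=AIMessageChunk(content="!"),
            generation_info={"finish_reason": "stop"},
        )


model = EchoChatModel()
for message_chunk in model.stream("hi"):
    print(message_chunk.content, end="")
# With this commit, on_llm_end callbacks receive the merged ChatGenerationChunk,
# whose generation_info still contains {"finish_reason": "stop"}; previously they
# received a ChatGeneration rebuilt from the concatenated message alone.
```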
@@ -1,6 +1,8 @@
 """Test ChatOpenAI wrapper."""
 
 
+from typing import Any
+
 import pytest
 
 from langchain.callbacks.manager import CallbackManager
@@ -89,6 +91,34 @@ def test_chat_openai_streaming() -> None:
     assert isinstance(response, BaseMessage)
 
 
+@pytest.mark.scheduled
+def test_chat_openai_streaming_generation_info() -> None:
+    """Test that generation info is preserved when streaming."""
+
+    class _FakeCallback(FakeCallbackHandler):
+        saved_things: dict = {}
+
+        def on_llm_end(
+            self,
+            *args: Any,
+            **kwargs: Any,
+        ) -> Any:
+            # Save the generation
+            self.saved_things["generation"] = args[0]
+
+    callback = _FakeCallback()
+    callback_manager = CallbackManager([callback])
+    chat = ChatOpenAI(
+        max_tokens=2,
+        temperature=0,
+        callback_manager=callback_manager,
+    )
+    list(chat.stream("hi"))
+    generation = callback.saved_things["generation"]
+    # `Hello!` is two tokens, assert that that is what is returned
+    assert generation.generations[0][0].text == "Hello!"
+
+
 def test_chat_openai_llm_output_contains_model_name() -> None:
     """Test llm_output contains model_name."""
     chat = ChatOpenAI(max_tokens=10)