From a682f0d12bdffa95db0ed424172db9815d76374d Mon Sep 17 00:00:00 2001
From: Simon Kelly
Date: Tue, 9 Apr 2024 19:40:16 +0200
Subject: [PATCH] openai[patch]: wrap stream code in context manager blocks
 (#18013)

**Description:** Use the `Stream` context managers in the `ChatOpenAI`
`stream` and `astream` methods. Using the context manager returned by the
OpenAI client makes it possible to terminate the stream early, since the
response connection is closed when the context manager exits.

**Issue:** #5340

**Twitter handle:** @snopoke
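A minimal usage sketch of what this enables (the model name and prompt are illustrative only, not part of this patch): a caller can stop consuming the stream early, and closing the generator lets the `with` block inside `_stream` exit, releasing the underlying response connection.

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")  # any chat model; illustrative

stream = llm.stream("Write a very long story about streaming APIs.")
for i, chunk in enumerate(stream):
    print(chunk.content, end="", flush=True)
    if i >= 4:  # stop early after a few chunks
        break

# Closing the generator (explicitly here, or when it is garbage collected)
# lets the `with self.client.create(...)` block exit and close the connection.
stream.close()
```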
---------

Co-authored-by: Bagatur
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
---
 .../langchain_openai/chat_models/base.py      | 105 +++++++++---------
 1 file changed, 54 insertions(+), 51 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index b7e103d9f1..7a108e7b64 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -457,30 +457,33 @@ class ChatOpenAI(BaseChatModel):
         params = {**params, **kwargs, "stream": True}

         default_chunk_class = AIMessageChunk
-        for chunk in self.client.create(messages=message_dicts, **params):
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
-            if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
-            yield chunk
+        with self.client.create(messages=message_dicts, **params) as response:
+            for chunk in response:
+                if not isinstance(chunk, dict):
+                    chunk = chunk.model_dump()
+                if len(chunk["choices"]) == 0:
+                    continue
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        chunk.text, chunk=chunk, logprobs=logprobs
+                    )
+                yield chunk

     def _generate(
         self,
@@ -553,34 +556,34 @@ class ChatOpenAI(BaseChatModel):
         params = {**params, **kwargs, "stream": True}

         default_chunk_class = AIMessageChunk
-        async for chunk in await self.async_client.create(
-            messages=message_dicts, **params
-        ):
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
-            if run_manager:
-                await run_manager.on_llm_new_token(
-                    token=chunk.text, chunk=chunk, logprobs=logprobs
+        response = await self.async_client.create(messages=message_dicts, **params)
+        async with response:
+            async for chunk in response:
+                if not isinstance(chunk, dict):
+                    chunk = chunk.model_dump()
+                if len(chunk["choices"]) == 0:
+                    continue
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
                 )
-            yield chunk
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        token=chunk.text, chunk=chunk, logprobs=logprobs
+                    )
+                yield chunk

     async def _agenerate(
         self,