From c9da300e4d4602e4faa283d0878b416bb109406e Mon Sep 17 00:00:00 2001
From: Vic Cao <133215350+viccao95@users.noreply.github.com>
Date: Mon, 7 Aug 2023 17:18:30 +0800
Subject: [PATCH] fix: overwrite stream for ChatOpenAI in runtime (#8288)

@hwchase17, @baskaryan

---------

Co-authored-by: Bagatur
Co-authored-by: Nuno Campos
---
 libs/langchain/langchain/chat_models/openai.py        |  6 ++++--
 .../langchain/chat_models/promptlayer_openai.py       | 10 ++++++++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/openai.py b/libs/langchain/langchain/chat_models/openai.py
index 7e1932165b..7e321c8962 100644
--- a/libs/langchain/langchain/chat_models/openai.py
+++ b/libs/langchain/langchain/chat_models/openai.py
@@ -381,9 +381,10 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if self.streaming:
+        if stream if stream is not None else self.streaming:
             generation: Optional[ChatGenerationChunk] = None
             for chunk in self._stream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
@@ -454,9 +455,10 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if self.streaming:
+        if stream if stream is not None else self.streaming:
             generation: Optional[ChatGenerationChunk] = None
             async for chunk in self._astream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
diff --git a/libs/langchain/langchain/chat_models/promptlayer_openai.py b/libs/langchain/langchain/chat_models/promptlayer_openai.py
index 2780888d85..092fdaee52 100644
--- a/libs/langchain/langchain/chat_models/promptlayer_openai.py
+++ b/libs/langchain/langchain/chat_models/promptlayer_openai.py
@@ -43,13 +43,16 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI generate and then call PromptLayer API to log the request."""
         from promptlayer.utils import get_api_key, promptlayer_api_request
 
         request_start_time = datetime.datetime.now().timestamp()
-        generated_responses = super()._generate(messages, stop, run_manager, **kwargs)
+        generated_responses = super()._generate(
+            messages, stop, run_manager, stream=stream, **kwargs
+        )
         request_end_time = datetime.datetime.now().timestamp()
         message_dicts, params = super()._create_message_dicts(messages, stop)
         for i, generation in enumerate(generated_responses.generations):
@@ -82,13 +85,16 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI agenerate and then call PromptLayer to log."""
         from promptlayer.utils import get_api_key, promptlayer_api_request_async
 
         request_start_time = datetime.datetime.now().timestamp()
-        generated_responses = await super()._agenerate(messages, stop, run_manager)
+        generated_responses = await super()._agenerate(
+            messages, stop, run_manager, stream=stream, **kwargs
+        )
         request_end_time = datetime.datetime.now().timestamp()
         message_dicts, params = super()._create_message_dicts(messages, stop)
         for i, generation in enumerate(generated_responses.generations):
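Reviewer note (not part of the patch): the core of the change is the per-call override. An explicit stream keyword takes precedence over the instance-level streaming flag, and a value of None falls back to that flag. Below is a minimal standalone sketch of that resolution logic; the function name resolve_streaming is illustrative only and does not exist in the patched files.

    from typing import Optional

    def resolve_streaming(stream: Optional[bool], default_streaming: bool) -> bool:
        """Mirror the patch's conditional: an explicit per-call `stream`
        overrides the constructor-level `streaming` flag; None defers to it."""
        return stream if stream is not None else default_streaming

    # Constructor default is non-streaming, but a caller can flip it per request.
    assert resolve_streaming(None, False) is False   # no override -> use the default
    assert resolve_streaming(True, False) is True    # explicit True turns streaming on
    assert resolve_streaming(False, True) is False   # explicit False turns it off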