openai[patch]: wrap stream code in context manager blocks (#18013)

**Description:**
Use the `Stream` context managers in the `ChatOpenAI` `stream` and `astream`
methods.

Using the context manager returned by the OpenAI client makes it
possible to terminate the stream early, since the response connection
is closed when the context manager exits.
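
For example, a caller can now break out of the stream and have the underlying connection released. A minimal sketch (the model name, prompt, and chunk cutoff are illustrative; an `OPENAI_API_KEY` is assumed to be set in the environment):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")

# Stop consuming the stream after a few chunks. With the stream wrapped in the
# client's context manager, closing the generator (e.g. after this break) also
# closes the HTTP response instead of leaving the connection open until the
# server finishes sending the full completion.
for i, chunk in enumerate(llm.stream("Write a very long story about a lighthouse.")):
    print(chunk.content, end="", flush=True)
    if i >= 5:
        break
```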

**Issue:** #5340
**Twitter handle:** @snopoke

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Simon Kelly committed by GitHub
parent 6c11c8dac6
commit a682f0d12b

@@ -457,30 +457,33 @@ class ChatOpenAI(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
         default_chunk_class = AIMessageChunk
-        for chunk in self.client.create(messages=message_dicts, **params):
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
-            if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
-            yield chunk
+        with self.client.create(messages=message_dicts, **params) as response:
+            for chunk in response:
+                if not isinstance(chunk, dict):
+                    chunk = chunk.model_dump()
+                if len(chunk["choices"]) == 0:
+                    continue
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        chunk.text, chunk=chunk, logprobs=logprobs
+                    )
+                yield chunk
 
     def _generate(
         self,
@@ -553,34 +556,34 @@ class ChatOpenAI(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
         default_chunk_class = AIMessageChunk
-        async for chunk in await self.async_client.create(
-            messages=message_dicts, **params
-        ):
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
-            if run_manager:
-                await run_manager.on_llm_new_token(
-                    token=chunk.text, chunk=chunk, logprobs=logprobs
-                )
-            yield chunk
+        response = await self.async_client.create(messages=message_dicts, **params)
+        async with response:
+            async for chunk in response:
+                if not isinstance(chunk, dict):
+                    chunk = chunk.model_dump()
+                if len(chunk["choices"]) == 0:
+                    continue
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        token=chunk.text, chunk=chunk, logprobs=logprobs
+                    )
+                yield chunk
 
     async def _agenerate(
         self,
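
For context, the streaming objects returned by the OpenAI Python client when `stream=True` can be used as (async) context managers, and exiting the block closes the underlying HTTP response; that is the behaviour the wrapped `stream`/`astream` code above relies on. A minimal sketch against the raw client (the model, prompt, and chunk cutoff are illustrative):

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# The object returned for stream=True works as a context manager; leaving the
# with-block closes the response connection even when iteration stops early,
# which is what terminating the stream early depends on.
with client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Count to one hundred, slowly."}],
    stream=True,
) as stream:
    for i, chunk in enumerate(stream):
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)
        if i >= 3:
            break  # the connection is released when the with-block exits
```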
