From 84bf5787a7036a33745e44de44aa89a074dcec55 Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Fri, 19 Jan 2024 09:16:09 -0800
Subject: [PATCH] core[patch], openai[patch]: Chat openai stream logprobs
 (#16218)

---
 .../langchain_core/outputs/chat_generation.py | 10 ++---
 .../core/langchain_core/outputs/generation.py | 10 ++---
 libs/core/langchain_core/utils/_merge.py      | 44 +++++++++++++++++++
 .../langchain_openai/chat_models/base.py      | 34 +++++++++-----
 .../chat_models/test_base.py                  | 34 ++++++++++++++
 5 files changed, 110 insertions(+), 22 deletions(-)
 create mode 100644 libs/core/langchain_core/utils/_merge.py

diff --git a/libs/core/langchain_core/outputs/chat_generation.py b/libs/core/langchain_core/outputs/chat_generation.py
index fa5041c348..b7bd6042a2 100644
--- a/libs/core/langchain_core/outputs/chat_generation.py
+++ b/libs/core/langchain_core/outputs/chat_generation.py
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal
 from langchain_core.messages import BaseMessage, BaseMessageChunk
 from langchain_core.outputs.generation import Generation
 from langchain_core.pydantic_v1 import root_validator
+from langchain_core.utils._merge import merge_dicts


 class ChatGeneration(Generation):
@@ -53,14 +54,13 @@ class ChatGenerationChunk(ChatGeneration):

     def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
         if isinstance(other, ChatGenerationChunk):
-            generation_info = (
-                {**(self.generation_info or {}), **(other.generation_info or {})}
-                if self.generation_info is not None or other.generation_info is not None
-                else None
+            generation_info = merge_dicts(
+                self.generation_info or {},
+                other.generation_info or {},
             )
             return ChatGenerationChunk(
                 message=self.message + other.message,
-                generation_info=generation_info,
+                generation_info=generation_info or None,
             )
         else:
             raise TypeError(
diff --git a/libs/core/langchain_core/outputs/generation.py b/libs/core/langchain_core/outputs/generation.py
index 3ede28f9fc..3f0a79ecb1 100644
--- a/libs/core/langchain_core/outputs/generation.py
+++ b/libs/core/langchain_core/outputs/generation.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 from typing import Any, Dict, List, Literal, Optional

 from langchain_core.load import Serializable
+from langchain_core.utils._merge import merge_dicts


 class Generation(Serializable):
@@ -40,14 +41,13 @@ class GenerationChunk(Generation):

     def __add__(self, other: GenerationChunk) -> GenerationChunk:
         if isinstance(other, GenerationChunk):
-            generation_info = (
-                {**(self.generation_info or {}), **(other.generation_info or {})}
-                if self.generation_info is not None or other.generation_info is not None
-                else None
+            generation_info = merge_dicts(
+                self.generation_info or {},
+                other.generation_info or {},
             )
             return GenerationChunk(
                 text=self.text + other.text,
-                generation_info=generation_info,
+                generation_info=generation_info or None,
             )
         else:
             raise TypeError(
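The two `__add__` changes above are what make streamed `logprobs` survive chunk concatenation: the old shallow `{**left, **right}` merge let a later chunk's `generation_info` overwrite an earlier one's, while `merge_dicts` combines values by type. A minimal sketch of the resulting behavior (the token values are illustrative, not from the patch):

```python
from langchain_core.outputs import GenerationChunk

# With merge_dicts, dict-valued generation_info fields are merged recursively
# and list-valued fields are concatenated, so per-token logprobs accumulate.
left = GenerationChunk(
    text="Hello",
    generation_info={"logprobs": {"content": [{"token": "Hello"}]}},
)
right = GenerationChunk(
    text="!",
    generation_info={
        "logprobs": {"content": [{"token": "!"}]},
        "finish_reason": "stop",
    },
)

merged = left + right
assert merged.text == "Hello!"
# Both tokens survive the merge instead of the second dict clobbering the first.
tokens = [t["token"] for t in merged.generation_info["logprobs"]["content"]]
assert tokens == ["Hello", "!"]
assert merged.generation_info["finish_reason"] == "stop"
```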
diff --git a/libs/core/langchain_core/utils/_merge.py b/libs/core/langchain_core/utils/_merge.py
new file mode 100644
index 0000000000..e21fdd9662
--- /dev/null
+++ b/libs/core/langchain_core/utils/_merge.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from typing import Any, Dict
+
+
+def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
+    """Merge two dicts, handling the case where a key exists in both
+    dictionaries but has a value of None in 'left'. In that case, the value
+    from 'right' is used for that key in the merged dictionary.
+
+    Example:
+        If left = {"function_call": {"arguments": None}} and
+        right = {"function_call": {"arguments": "{\n"}}
+        then, after merging, for the key "function_call",
+        the value from 'right' is used,
+        resulting in merged = {"function_call": {"arguments": "{\n"}}.
+    """
+    merged = left.copy()
+    for k, v in right.items():
+        if k not in merged:
+            merged[k] = v
+        elif merged[k] is None and v:
+            merged[k] = v
+        elif v is None:
+            continue
+        elif merged[k] == v:
+            continue
+        elif type(merged[k]) != type(v):
+            raise TypeError(
+                f'additional_kwargs["{k}"] already exists in this message,'
+                " but with a different type."
+            )
+        elif isinstance(merged[k], str):
+            merged[k] += v
+        elif isinstance(merged[k], dict):
+            merged[k] = merge_dicts(merged[k], v)
+        elif isinstance(merged[k], list):
+            merged[k] = merged[k] + v
+        else:
+            raise TypeError(
+                f"Additional kwargs key {k} already exists in left dict and value has "
+                f"unsupported type {type(merged[k])}."
+            )
+    return merged
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 06bd22485c..10d1dadf22 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -404,15 +404,19 @@ class ChatOpenAI(BaseChatModel):
             chunk = _convert_delta_to_message_chunk(
                 choice["delta"], default_chunk_class
             )
-            finish_reason = choice.get("finish_reason")
-            generation_info = (
-                dict(finish_reason=finish_reason) if finish_reason is not None else None
-            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
             yield chunk
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)

     def _generate(
         self,
@@ -492,15 +496,21 @@ class ChatOpenAI(BaseChatModel):
             chunk = _convert_delta_to_message_chunk(
                 choice["delta"], default_chunk_class
             )
-            finish_reason = choice.get("finish_reason")
-            generation_info = (
-                dict(finish_reason=finish_reason) if finish_reason is not None else None
-            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
             yield chunk
             if run_manager:
-                await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
+                await run_manager.on_llm_new_token(
+                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                )

     async def _agenerate(
         self,
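Beyond storing `logprobs` in `generation_info`, both `_stream` and `_astream` now forward the raw payload to `run_manager.on_llm_new_token`, so callback handlers can observe per-token log-probabilities as tokens arrive. A hedged sketch of consuming that hook (the handler class and prompt are illustrative; it assumes the callback manager passes extra keyword arguments through to handlers):

```python
from typing import Any

from langchain_core.callbacks import BaseCallbackHandler
from langchain_openai import ChatOpenAI


class LogprobsHandler(BaseCallbackHandler):
    """Illustrative handler that prints the logprobs attached to each token."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # _stream/_astream pass logprobs as an extra keyword argument; it is
        # None for chunks that carry no logprobs payload.
        logprobs = kwargs.get("logprobs")
        if logprobs:
            print(token, logprobs["content"])


# logprobs=True is forwarded to the OpenAI chat completions API.
llm = ChatOpenAI(callbacks=[LogprobsHandler()])
for _ in llm.stream("Say hello", logprobs=True):
    pass
```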
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index c86112891f..33006a624c 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -391,3 +391,37 @@ def test_invoke() -> None:
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+
+
+def test_logprobs() -> None:
+    llm = ChatOpenAI()
+    result = llm.generate([[HumanMessage(content="I'm PickleRick")]], logprobs=True)
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
+
+
+async def test_async_logprobs() -> None:
+    llm = ChatOpenAI()
+    result = await llm.agenerate(
+        [[HumanMessage(content="I'm PickleRick")]], logprobs=True
+    )
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
+
+
+def test_logprobs_streaming() -> None:
+    llm = ChatOpenAI()
+    result = llm.generate(
+        [[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
+    )
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
+
+
+async def test_async_logprobs_streaming() -> None:
+    llm = ChatOpenAI()
+    result = await llm.agenerate(
+        [[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
+    )
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
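The streaming variants above only pass because of the `merge_dicts` change in core: each chunk's `logprobs["content"]` list is concatenated during aggregation, so the merged `generation_info` ends up with one entry per sampled token. Roughly the shape those assertions rely on (values are illustrative of the OpenAI response format, not captured output):

```python
# Approximate shape of generation_info after a logprobs=True call; the
# concrete tokens and numbers are made up for illustration.
generation_info = {
    "finish_reason": "stop",
    "logprobs": {
        "content": [
            {"token": "I", "logprob": -0.12, "bytes": [73], "top_logprobs": []},
            {"token": "'m", "logprob": -0.45, "bytes": [39, 109], "top_logprobs": []},
        ]
    },
}
assert "content" in generation_info["logprobs"]  # mirrors the test assertions
```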