core[patch], openai[patch]: Chat openai stream logprobs (#16218)

Branch: pull/16287/head
Author: Bagatur (committed via GitHub)
Parent: 6f7a414955
Commit: 84bf5787a7

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal
 from langchain_core.messages import BaseMessage, BaseMessageChunk
 from langchain_core.outputs.generation import Generation
 from langchain_core.pydantic_v1 import root_validator
+from langchain_core.utils._merge import merge_dicts


 class ChatGeneration(Generation):
@@ -53,14 +54,13 @@ class ChatGenerationChunk(ChatGeneration):

     def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
         if isinstance(other, ChatGenerationChunk):
-            generation_info = (
-                {**(self.generation_info or {}), **(other.generation_info or {})}
-                if self.generation_info is not None or other.generation_info is not None
-                else None
+            generation_info = merge_dicts(
+                self.generation_info or {},
+                other.generation_info or {},
             )
             return ChatGenerationChunk(
                 message=self.message + other.message,
-                generation_info=generation_info,
+                generation_info=generation_info or None,
             )
         else:
             raise TypeError(
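Not part of the diff: a minimal sketch of what the switch to merge_dicts changes. Previously `{**left, **right}` let the right-hand chunk overwrite keys wholesale; merge_dicts instead merges nested dicts and concatenates lists, which is what lets streamed logprobs accumulate across chunks:

    from langchain_core.messages import AIMessageChunk
    from langchain_core.outputs import ChatGenerationChunk

    left = ChatGenerationChunk(
        message=AIMessageChunk(content="Hello"),
        generation_info={"logprobs": {"content": [{"token": "Hello"}]}},
    )
    right = ChatGenerationChunk(
        message=AIMessageChunk(content=" world"),
        generation_info={"logprobs": {"content": [{"token": " world"}]}},
    )
    merged = left + right
    # The nested "content" lists are concatenated rather than overwritten:
    assert merged.generation_info == {
        "logprobs": {"content": [{"token": "Hello"}, {"token": " world"}]}
    }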

@@ -3,6 +3,7 @@ from __future__ import annotations

 from typing import Any, Dict, List, Literal, Optional

 from langchain_core.load import Serializable
+from langchain_core.utils._merge import merge_dicts


 class Generation(Serializable):
@@ -40,14 +41,13 @@ class GenerationChunk(Generation):

     def __add__(self, other: GenerationChunk) -> GenerationChunk:
         if isinstance(other, GenerationChunk):
-            generation_info = (
-                {**(self.generation_info or {}), **(other.generation_info or {})}
-                if self.generation_info is not None or other.generation_info is not None
-                else None
+            generation_info = merge_dicts(
+                self.generation_info or {},
+                other.generation_info or {},
             )
             return GenerationChunk(
                 text=self.text + other.text,
-                generation_info=generation_info,
+                generation_info=generation_info or None,
             )
         else:
             raise TypeError(
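Also not part of the diff: because merge_dicts is always called with `or {}` fallbacks, the trailing `or None` is what keeps generation_info as None when neither side had any info. A quick sketch:

    from langchain_core.outputs import GenerationChunk

    combined = GenerationChunk(text="foo") + GenerationChunk(text="bar")
    assert combined.text == "foobar"
    # merge_dicts({}, {}) returns {}, which `or None` normalizes back to None:
    assert combined.generation_info is None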

@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from typing import Any, Dict
+
+
+def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
+    """Merge two dicts, handling specific scenarios where a key exists in both
+    dictionaries but has a value of None in 'left'. In such cases, the method uses the
+    value from 'right' for that key in the merged dictionary.
+
+    Example:
+        If left = {"function_call": {"arguments": None}} and
+        right = {"function_call": {"arguments": "{\n"}}
+        then, after merging, for the key "function_call",
+        the value from 'right' is used,
+        resulting in merged = {"function_call": {"arguments": "{\n"}}.
+    """
+    merged = left.copy()
+    for k, v in right.items():
+        if k not in merged:
+            merged[k] = v
+        elif merged[k] is None and v:
+            merged[k] = v
+        elif v is None:
+            continue
+        elif merged[k] == v:
+            continue
+        elif type(merged[k]) != type(v):
+            raise TypeError(
+                f'additional_kwargs["{k}"] already exists in this message,'
+                " but with a different type."
+            )
+        elif isinstance(merged[k], str):
+            merged[k] += v
+        elif isinstance(merged[k], dict):
+            merged[k] = merge_dicts(merged[k], v)
+        elif isinstance(merged[k], list):
+            merged[k] = merged[k] + v
+        else:
+            raise TypeError(
+                f"Additional kwargs key {k} already exists in left dict and value has "
+                f"unsupported type {type(merged[k])}."
+            )
+    return merged
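An illustrative (non-diff) walkthrough of the merge rules above, one case per branch:

    from langchain_core.utils._merge import merge_dicts

    left = {"a": None, "b": "foo", "c": {"x": 1}, "d": [1], "e": 1}
    right = {"a": "set", "b": "bar", "c": {"y": 2}, "d": [2], "e": 1, "f": "new"}
    assert merge_dicts(left, right) == {
        "a": "set",             # None in left is replaced by right's value
        "b": "foobar",          # strings are concatenated
        "c": {"x": 1, "y": 2},  # dicts are merged recursively
        "d": [1, 2],            # lists are concatenated
        "e": 1,                 # equal values pass through unchanged
        "f": "new",             # keys only in right are added
    }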

@@ -404,15 +404,19 @@ class ChatOpenAI(BaseChatModel):
                 chunk = _convert_delta_to_message_chunk(
                     choice["delta"], default_chunk_class
                 )
-                finish_reason = choice.get("finish_reason")
-                generation_info = (
-                    dict(finish_reason=finish_reason) if finish_reason is not None else None
-                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
                 default_chunk_class = chunk.__class__
-                chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
                 yield chunk
                 if run_manager:
-                    run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)

     def _generate(
         self,

@@ -492,15 +496,21 @@ class ChatOpenAI(BaseChatModel):
                 chunk = _convert_delta_to_message_chunk(
                     choice["delta"], default_chunk_class
                 )
-                finish_reason = choice.get("finish_reason")
-                generation_info = (
-                    dict(finish_reason=finish_reason) if finish_reason is not None else None
-                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
                 default_chunk_class = chunk.__class__
-                chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
                 yield chunk
                 if run_manager:
-                    await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
+                    await run_manager.on_llm_new_token(
+                        token=chunk.text, chunk=chunk, logprobs=logprobs
+                    )

     async def _agenerate(
         self,
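For context (not part of the diff): with logprobs=True forwarded to the OpenAI chat completions API, each generation's generation_info now carries the logprobs payload; in streaming mode it is assembled chunk by chunk via the merge_dicts change above. A usage sketch, assuming OPENAI_API_KEY is set in the environment and that entries follow OpenAI's documented logprobs response format:

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI()
    result = llm.generate(
        [[HumanMessage(content="Tell me a joke.")]], logprobs=True, stream=True
    )
    info = result.generations[0][0].generation_info or {}
    # Each entry has "token", "logprob", and (if requested) "top_logprobs".
    for entry in info["logprobs"]["content"]:
        print(entry["token"], entry["logprob"])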

@@ -391,3 +391,37 @@ def test_invoke() -> None:
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+
+
+def test_logprobs() -> None:
+    llm = ChatOpenAI()
+    result = llm.generate([[HumanMessage(content="I'm PickleRick")]], logprobs=True)
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
+
+
+async def test_async_logprobs() -> None:
+    llm = ChatOpenAI()
+    result = await llm.agenerate(
+        [[HumanMessage(content="I'm PickleRick")]], logprobs=True
+    )
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
+
+
+def test_logprobs_streaming() -> None:
+    llm = ChatOpenAI()
+    result = llm.generate(
+        [[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
+    )
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
+
+
+async def test_async_logprobs_streaming() -> None:
+    llm = ChatOpenAI()
+    result = await llm.agenerate(
+        [[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
+    )
+    assert result.generations[0][0].generation_info
+    assert "content" in result.generations[0][0].generation_info["logprobs"]
