openai: compatible with other LLM usage metadata (#24500)

- [ ] **PR message**:
    - **Description:** Make streaming usage metadata parsing compatible with
      other OpenAI-compatible LLMs (e.g. deepseek-chat, glm-4); see the usage
      sketch below.
    - **Issue:** N/A
    - **Dependencies:** no new dependencies added
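
Roughly, the end-to-end behavior this enables looks like the minimal sketch below; the `base_url`, `api_key`, and prompt are illustrative placeholders, not part of this change (only `stream_usage=True` and the shape of `usage_metadata` come from this PR and its tests):

```python
from langchain_openai import ChatOpenAI

# Any OpenAI-compatible endpoint; base_url / api_key below are placeholders.
llm = ChatOpenAI(
    model="deepseek-chat",
    base_url="https://api.deepseek.com/v1",  # assumed endpoint
    api_key="sk-...",
    stream_usage=True,  # ask the provider to include token usage while streaming
)

usage = None
for chunk in llm.stream("What is your name? Answer with the name only."):
    # With this change, the chunk that carries usage (an empty `choices` list,
    # or a final delta sent alongside `usage`, as DeepSeek/GLM-4 do) surfaces
    # it as `usage_metadata` instead of being dropped.
    if chunk.usage_metadata is not None:
        usage = chunk.usage_metadata

print(usage)  # e.g. {'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26}
```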


- [ ] **Add tests and docs**: 
Unit tests added in `libs/partners/openai/tests/unit_tests/chat_models/test_base.py`:
```shell
cd libs/partners/openai
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_openai_astream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_openai_stream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_deepseek_astream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_deepseek_stream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_glm4_astream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_glm4_stream
```

---------

Co-authored-by: hyman <hyman@xiaozancloud.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Hyman 2024-08-24 07:59:14 +08:00 committed by GitHub
parent 3dc7d447aa
commit 58e72febeb
2 changed files with 351 additions and 80 deletions

libs/partners/openai/langchain_openai/chat_models/base.py

@@ -284,6 +284,57 @@ def _convert_delta_to_message_chunk(
    return default_class(content=content, id=id_)  # type: ignore


def _convert_chunk_to_generation_chunk(
    chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
) -> Optional[ChatGenerationChunk]:
    token_usage = chunk.get("usage")
    choices = chunk.get("choices", [])

    usage_metadata: Optional[UsageMetadata] = (
        UsageMetadata(
            input_tokens=token_usage.get("prompt_tokens", 0),
            output_tokens=token_usage.get("completion_tokens", 0),
            total_tokens=token_usage.get("total_tokens", 0),
        )
        if token_usage
        else None
    )

    if len(choices) == 0:
        # logprobs is implicitly None
        generation_chunk = ChatGenerationChunk(
            message=default_chunk_class(content="", usage_metadata=usage_metadata)
        )
        return generation_chunk

    choice = choices[0]
    if choice["delta"] is None:
        return None

    message_chunk = _convert_delta_to_message_chunk(
        choice["delta"], default_chunk_class
    )
    generation_info = {**base_generation_info} if base_generation_info else {}

    if finish_reason := choice.get("finish_reason"):
        generation_info["finish_reason"] = finish_reason
        if model_name := chunk.get("model"):
            generation_info["model_name"] = model_name
        if system_fingerprint := chunk.get("system_fingerprint"):
            generation_info["system_fingerprint"] = system_fingerprint

    logprobs = choice.get("logprobs")
    if logprobs:
        generation_info["logprobs"] = logprobs

    if usage_metadata and isinstance(message_chunk, AIMessageChunk):
        message_chunk.usage_metadata = usage_metadata

    generation_chunk = ChatGenerationChunk(
        message=message_chunk, generation_info=generation_info or None
    )
    return generation_chunk


class _FunctionCall(TypedDict):
    name: str
@@ -561,43 +612,15 @@ class BaseChatOpenAI(BaseChatModel):
            for chunk in response:
                if not isinstance(chunk, dict):
                    chunk = chunk.model_dump()
                if len(chunk["choices"]) == 0:
                    if token_usage := chunk.get("usage"):
                        usage_metadata = UsageMetadata(
                            input_tokens=token_usage.get("prompt_tokens", 0),
                            output_tokens=token_usage.get("completion_tokens", 0),
                            total_tokens=token_usage.get("total_tokens", 0),
                        )
                        generation_chunk = ChatGenerationChunk(
                            message=default_chunk_class(  # type: ignore[call-arg]
                                content="", usage_metadata=usage_metadata
                            )
                        )
                        logprobs = None
                    else:
                        continue
                else:
                    choice = chunk["choices"][0]
                    if choice["delta"] is None:
                        continue
                    message_chunk = _convert_delta_to_message_chunk(
                        choice["delta"], default_chunk_class
                    )
                    generation_info = {**base_generation_info} if is_first_chunk else {}
                    if finish_reason := choice.get("finish_reason"):
                        generation_info["finish_reason"] = finish_reason
                        if model_name := chunk.get("model"):
                            generation_info["model_name"] = model_name
                        if system_fingerprint := chunk.get("system_fingerprint"):
                            generation_info["system_fingerprint"] = system_fingerprint
                    logprobs = choice.get("logprobs")
                    if logprobs:
                        generation_info["logprobs"] = logprobs
                    default_chunk_class = message_chunk.__class__
                    generation_chunk = ChatGenerationChunk(
                        message=message_chunk, generation_info=generation_info or None
                    )
                generation_chunk = _convert_chunk_to_generation_chunk(
                    chunk,
                    default_chunk_class,
                    base_generation_info if is_first_chunk else {},
                )
                if generation_chunk is None:
                    continue
                default_chunk_class = generation_chunk.message.__class__
                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
                if run_manager:
                    run_manager.on_llm_new_token(
                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
@@ -744,51 +767,18 @@ class BaseChatOpenAI(BaseChatModel):
            async for chunk in response:
                if not isinstance(chunk, dict):
                    chunk = chunk.model_dump()
                if len(chunk["choices"]) == 0:
                    if token_usage := chunk.get("usage"):
                        usage_metadata = UsageMetadata(
                            input_tokens=token_usage.get("prompt_tokens", 0),
                            output_tokens=token_usage.get("completion_tokens", 0),
                            total_tokens=token_usage.get("total_tokens", 0),
                        )
                        generation_chunk = ChatGenerationChunk(
                            message=default_chunk_class(  # type: ignore[call-arg]
                                content="", usage_metadata=usage_metadata
                            )
                        )
                        logprobs = None
                    else:
                        continue
                else:
                    choice = chunk["choices"][0]
                    if choice["delta"] is None:
                        continue
                    message_chunk = await run_in_executor(
                        None,
                        _convert_delta_to_message_chunk,
                        choice["delta"],
                        default_chunk_class,
                    )
                    generation_info = {**base_generation_info} if is_first_chunk else {}
                    if finish_reason := choice.get("finish_reason"):
                        generation_info["finish_reason"] = finish_reason
                        if model_name := chunk.get("model"):
                            generation_info["model_name"] = model_name
                        if system_fingerprint := chunk.get("system_fingerprint"):
                            generation_info["system_fingerprint"] = system_fingerprint
                    logprobs = choice.get("logprobs")
                    if logprobs:
                        generation_info["logprobs"] = logprobs
                    default_chunk_class = message_chunk.__class__
                    generation_chunk = ChatGenerationChunk(
                        message=message_chunk, generation_info=generation_info or None
                    )
                generation_chunk = _convert_chunk_to_generation_chunk(
                    chunk,
                    default_chunk_class,
                    base_generation_info if is_first_chunk else {},
                )
                if generation_chunk is None:
                    continue
                default_chunk_class = generation_chunk.message.__class__
                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
                if run_manager:
                    await run_manager.on_llm_new_token(
                        token=generation_chunk.text,
                        chunk=generation_chunk,
                        logprobs=logprobs,
                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                    )
                is_first_chunk = False
                yield generation_chunk
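
To make the consolidation in `_stream`/`_astream` concrete, here is a minimal sketch of how the new helper treats the two terminal-chunk shapes, using chunk dicts modeled on the test fixtures below; `_convert_chunk_to_generation_chunk` is a private helper, imported directly only for illustration:

```python
from langchain_core.messages import AIMessageChunk

# Illustration only: a private helper from langchain_openai.chat_models.base.
from langchain_openai.chat_models.base import _convert_chunk_to_generation_chunk

# OpenAI style: usage arrives on a chunk whose `choices` list is empty.
openai_usage_chunk = {
    "choices": [],
    "usage": {"prompt_tokens": 14, "completion_tokens": 3, "total_tokens": 17},
}
gen = _convert_chunk_to_generation_chunk(openai_usage_chunk, AIMessageChunk, {})
assert gen is not None
assert gen.message.usage_metadata is not None
assert gen.message.usage_metadata["total_tokens"] == 17

# GLM-4 / DeepSeek style: usage arrives on the final chunk together with the
# finishing choice, so it is attached to that message chunk instead of being lost.
glm4_final_chunk = {
    "model": "glm-4",
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "delta": {"role": "assistant", "content": ""},
        }
    ],
    "usage": {"prompt_tokens": 13, "completion_tokens": 10, "total_tokens": 23},
}
gen = _convert_chunk_to_generation_chunk(glm4_final_chunk, AIMessageChunk, {})
assert gen is not None
assert gen.message.usage_metadata is not None
assert gen.message.usage_metadata["total_tokens"] == 23
assert (gen.generation_info or {}).get("finish_reason") == "stop"
```

Both `_stream` and `_astream` now delegate per-chunk handling to this helper, so the provider-specific differences in where usage appears live in one place.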

libs/partners/openai/tests/unit_tests/chat_models/test_base.py

@@ -1,12 +1,14 @@
"""Test OpenAI Chat API wrapper."""

import json
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional, Type, Union
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    FunctionMessage,
    HumanMessage,
    InvalidToolCall,
@@ -14,6 +16,7 @@ from langchain_core.messages import (
    ToolCall,
    ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.pydantic_v1 import BaseModel

from langchain_openai import ChatOpenAI
@@ -172,6 +175,284 @@ def test__convert_dict_to_message_tool_call() -> None:
    assert reverted_message_dict == message


class MockAsyncContextManager:
    def __init__(self, chunk_list: list):
        self.current_chunk = 0
        self.chunk_list = chunk_list
        self.chunk_num = len(chunk_list)

    async def __aenter__(self) -> "MockAsyncContextManager":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        pass

    def __aiter__(self) -> "MockAsyncContextManager":
        return self

    async def __anext__(self) -> dict:
        if self.current_chunk < self.chunk_num:
            chunk = self.chunk_list[self.current_chunk]
            self.current_chunk += 1
            return chunk
        else:
            raise StopAsyncIteration


class MockSyncContextManager:
    def __init__(self, chunk_list: list):
        self.current_chunk = 0
        self.chunk_list = chunk_list
        self.chunk_num = len(chunk_list)

    def __enter__(self) -> "MockSyncContextManager":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        pass

    def __iter__(self) -> "MockSyncContextManager":
        return self

    def __next__(self) -> dict:
        if self.current_chunk < self.chunk_num:
            chunk = self.chunk_list[self.current_chunk]
            self.current_chunk += 1
            return chunk
        else:
            raise StopIteration


GLM4_STREAM_META = """{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4eba\u5de5\u667a\u80fd"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u52a9\u624b"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":""}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4f60\u53ef\u4ee5"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u53eb\u6211"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"AI"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u52a9\u624b"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":""}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"finish_reason":"stop","delta":{"role":"assistant","content":""}}],"usage":{"prompt_tokens":13,"completion_tokens":10,"total_tokens":23}}
[DONE]""" # noqa: E501


@pytest.fixture
def mock_glm4_completion() -> list:
    list_chunk_data = GLM4_STREAM_META.split("\n")
    result_list = []
    for msg in list_chunk_data:
        if msg != "[DONE]":
            result_list.append(json.loads(msg))
    return result_list


async def test_glm4_astream(mock_glm4_completion: list) -> None:
    llm_name = "glm-4"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = AsyncMock()

    async def mock_create(*args: Any, **kwargs: Any) -> MockAsyncContextManager:
        return MockAsyncContextManager(mock_glm4_completion)

    mock_client.create = mock_create
    usage_chunk = mock_glm4_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "async_client", mock_client):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None
    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


def test_glm4_stream(mock_glm4_completion: list) -> None:
    llm_name = "glm-4"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = MagicMock()

    def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
        return MockSyncContextManager(mock_glm4_completion)

    mock_client.create = mock_create
    usage_chunk = mock_glm4_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "client", mock_client):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None
    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"index":0,"delta":{"content":"","role":"assistant"},"finish_reason":null,"logprobs":null}],"created":1721630271,"model":"deepseek-chat","system_fingerprint":"fp_7e0991cad4","object":"chat.completion.chunk","usage":null}
{"choices":[{"delta":{"content":"我是","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"Deep","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"Seek","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":" Chat","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"一个","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"深度","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"公司","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"开发的","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"智能","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"助手","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"","role":null},"finish_reason":"stop","index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":{"completion_tokens":15,"prompt_tokens":11,"total_tokens":26}}
[DONE]""" # noqa: E501


@pytest.fixture
def mock_deepseek_completion() -> List[Dict]:
    list_chunk_data = DEEPSEEK_STREAM_DATA.split("\n")
    result_list = []
    for msg in list_chunk_data:
        if msg != "[DONE]":
            result_list.append(json.loads(msg))
    return result_list


async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
    llm_name = "deepseek-chat"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = AsyncMock()

    async def mock_create(*args: Any, **kwargs: Any) -> MockAsyncContextManager:
        return MockAsyncContextManager(mock_deepseek_completion)

    mock_client.create = mock_create
    usage_chunk = mock_deepseek_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "async_client", mock_client):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None
    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


def test_deepseek_stream(mock_deepseek_completion: list) -> None:
    llm_name = "deepseek-chat"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = MagicMock()

    def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
        return MockSyncContextManager(mock_deepseek_completion)

    mock_client.create = mock_create
    usage_chunk = mock_deepseek_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "client", mock_client):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None
    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


OPENAI_STREAM_DATA = """{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null}
{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"content":"我是"},"logprobs":null,"finish_reason":null}],"usage":null}
{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"content":"助手"},"logprobs":null,"finish_reason":null}],"usage":null}
{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"content":""},"logprobs":null,"finish_reason":null}],"usage":null}
{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null}
{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[],"usage":{"prompt_tokens":14,"completion_tokens":3,"total_tokens":17}}
[DONE]""" # noqa: E501


@pytest.fixture
def mock_openai_completion() -> List[Dict]:
    list_chunk_data = OPENAI_STREAM_DATA.split("\n")
    result_list = []
    for msg in list_chunk_data:
        if msg != "[DONE]":
            result_list.append(json.loads(msg))
    return result_list


async def test_openai_astream(mock_openai_completion: list) -> None:
    llm_name = "gpt-4o"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = AsyncMock()

    async def mock_create(*args: Any, **kwargs: Any) -> MockAsyncContextManager:
        return MockAsyncContextManager(mock_openai_completion)

    mock_client.create = mock_create
    usage_chunk = mock_openai_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "async_client", mock_client):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None
    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


def test_openai_stream(mock_openai_completion: list) -> None:
    llm_name = "gpt-4o"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = MagicMock()

    def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
        return MockSyncContextManager(mock_openai_completion)

    mock_client.create = mock_create
    usage_chunk = mock_openai_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "client", mock_client):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None
    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


@pytest.fixture
def mock_completion() -> dict:
    return {