Mirror of https://github.com/hwchase17/langchain
Synced 2024-11-10 01:10:59 +00:00
openai: compatible with other LLM usage metadata (#24500)
**Description:** Make streamed usage metadata compatible with other LLMs (e.g. deepseek-chat, glm-4).
**Issue:** N/A
**Dependencies:** No new dependencies added.
**Tests and docs:** libs/partners/openai/tests/unit_tests/chat_models/test_base.py

```shell
cd libs/partners/openai
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_openai_astream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_openai_stream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_deepseek_astream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_deepseek_stream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_glm4_astream
poetry run pytest tests/unit_tests/chat_models/test_base.py::test_glm4_stream
```

Co-authored-by: hyman <hyman@xiaozancloud.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
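For illustration, a minimal usage sketch of what this change enables. The model name, endpoint URL, and API key below are assumptions for an arbitrary OpenAI-compatible provider, not something this PR ships:

```python
from langchain_openai import ChatOpenAI

# Hypothetical OpenAI-compatible provider; any endpoint that reports token
# usage on its streamed chunks (e.g. deepseek-chat or glm-4) should behave
# the same way.
llm = ChatOpenAI(
    model="deepseek-chat",                   # assumed model name
    base_url="https://api.deepseek.com/v1",  # assumed endpoint URL
    api_key="sk-...",                        # placeholder credential
    stream_usage=True,
)

usage = None
for chunk in llm.stream("What is your name?"):
    # With this change, the usage block the provider attaches to its final
    # streamed chunk is surfaced as chunk.usage_metadata instead of dropped.
    if chunk.usage_metadata is not None:
        usage = chunk.usage_metadata

print(usage)  # e.g. {'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26}
```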
This commit is contained in:
parent 3dc7d447aa
commit 58e72febeb
@@ -284,6 +284,57 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content, id=id_)  # type: ignore
 
 
+def _convert_chunk_to_generation_chunk(
+    chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
+) -> Optional[ChatGenerationChunk]:
+    token_usage = chunk.get("usage")
+    choices = chunk.get("choices", [])
+    usage_metadata: Optional[UsageMetadata] = (
+        UsageMetadata(
+            input_tokens=token_usage.get("prompt_tokens", 0),
+            output_tokens=token_usage.get("completion_tokens", 0),
+            total_tokens=token_usage.get("total_tokens", 0),
+        )
+        if token_usage
+        else None
+    )
+
+    if len(choices) == 0:
+        # logprobs is implicitly None
+        generation_chunk = ChatGenerationChunk(
+            message=default_chunk_class(content="", usage_metadata=usage_metadata)
+        )
+        return generation_chunk
+
+    choice = choices[0]
+    if choice["delta"] is None:
+        return None
+
+    message_chunk = _convert_delta_to_message_chunk(
+        choice["delta"], default_chunk_class
+    )
+    generation_info = {**base_generation_info} if base_generation_info else {}
+
+    if finish_reason := choice.get("finish_reason"):
+        generation_info["finish_reason"] = finish_reason
+    if model_name := chunk.get("model"):
+        generation_info["model_name"] = model_name
+    if system_fingerprint := chunk.get("system_fingerprint"):
+        generation_info["system_fingerprint"] = system_fingerprint
+
+    logprobs = choice.get("logprobs")
+    if logprobs:
+        generation_info["logprobs"] = logprobs
+
+    if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+        message_chunk.usage_metadata = usage_metadata
+
+    generation_chunk = ChatGenerationChunk(
+        message=message_chunk, generation_info=generation_info or None
+    )
+    return generation_chunk
+
+
 class _FunctionCall(TypedDict):
     name: str
 
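The following is an illustrative sketch of the helper above, not part of the diff. It assumes the private `_convert_chunk_to_generation_chunk` can be imported directly from `langchain_openai.chat_models.base`; the chunk dict mirrors the GLM-4 fixture in the tests further down, where usage arrives on the same chunk as the final choice rather than on a separate choice-less chunk:

```python
from langchain_core.messages import AIMessageChunk
from langchain_openai.chat_models.base import _convert_chunk_to_generation_chunk

# Final GLM-4-style chunk: a stop choice and the usage block together.
final_chunk = {
    "model": "glm-4",
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "delta": {"role": "assistant", "content": ""},
        }
    ],
    "usage": {"prompt_tokens": 13, "completion_tokens": 10, "total_tokens": 23},
}

gen = _convert_chunk_to_generation_chunk(final_chunk, AIMessageChunk, None)
assert gen is not None

# The provider's usage block is attached to the message as standard usage_metadata.
print(gen.message.usage_metadata)
# expected: input_tokens=13, output_tokens=10, total_tokens=23
print(gen.generation_info)
# expected: {'finish_reason': 'stop', 'model_name': 'glm-4'}
```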
@@ -561,43 +612,15 @@ class BaseChatOpenAI(BaseChatModel):
             for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                if len(chunk["choices"]) == 0:
-                    if token_usage := chunk.get("usage"):
-                        usage_metadata = UsageMetadata(
-                            input_tokens=token_usage.get("prompt_tokens", 0),
-                            output_tokens=token_usage.get("completion_tokens", 0),
-                            total_tokens=token_usage.get("total_tokens", 0),
-                        )
-                        generation_chunk = ChatGenerationChunk(
-                            message=default_chunk_class(  # type: ignore[call-arg]
-                                content="", usage_metadata=usage_metadata
-                            )
-                        )
-                        logprobs = None
-                    else:
-                        continue
-                else:
-                    choice = chunk["choices"][0]
-                    if choice["delta"] is None:
-                        continue
-                    message_chunk = _convert_delta_to_message_chunk(
-                        choice["delta"], default_chunk_class
-                    )
-                    generation_info = {**base_generation_info} if is_first_chunk else {}
-                    if finish_reason := choice.get("finish_reason"):
-                        generation_info["finish_reason"] = finish_reason
-                    if model_name := chunk.get("model"):
-                        generation_info["model_name"] = model_name
-                    if system_fingerprint := chunk.get("system_fingerprint"):
-                        generation_info["system_fingerprint"] = system_fingerprint
-
-                    logprobs = choice.get("logprobs")
-                    if logprobs:
-                        generation_info["logprobs"] = logprobs
-                    default_chunk_class = message_chunk.__class__
-                    generation_chunk = ChatGenerationChunk(
-                        message=message_chunk, generation_info=generation_info or None
-                    )
+                generation_chunk = _convert_chunk_to_generation_chunk(
+                    chunk,
+                    default_chunk_class,
+                    base_generation_info if is_first_chunk else {},
+                )
+                if generation_chunk is None:
+                    continue
+                default_chunk_class = generation_chunk.message.__class__
+                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
                 if run_manager:
                     run_manager.on_llm_new_token(
                         generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
@@ -744,51 +767,18 @@ class BaseChatOpenAI(BaseChatModel):
             async for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
-                if len(chunk["choices"]) == 0:
-                    if token_usage := chunk.get("usage"):
-                        usage_metadata = UsageMetadata(
-                            input_tokens=token_usage.get("prompt_tokens", 0),
-                            output_tokens=token_usage.get("completion_tokens", 0),
-                            total_tokens=token_usage.get("total_tokens", 0),
-                        )
-                        generation_chunk = ChatGenerationChunk(
-                            message=default_chunk_class(  # type: ignore[call-arg]
-                                content="", usage_metadata=usage_metadata
-                            )
-                        )
-                        logprobs = None
-                    else:
-                        continue
-                else:
-                    choice = chunk["choices"][0]
-                    if choice["delta"] is None:
-                        continue
-                    message_chunk = await run_in_executor(
-                        None,
-                        _convert_delta_to_message_chunk,
-                        choice["delta"],
-                        default_chunk_class,
-                    )
-                    generation_info = {**base_generation_info} if is_first_chunk else {}
-                    if finish_reason := choice.get("finish_reason"):
-                        generation_info["finish_reason"] = finish_reason
-                    if model_name := chunk.get("model"):
-                        generation_info["model_name"] = model_name
-                    if system_fingerprint := chunk.get("system_fingerprint"):
-                        generation_info["system_fingerprint"] = system_fingerprint
-
-                    logprobs = choice.get("logprobs")
-                    if logprobs:
-                        generation_info["logprobs"] = logprobs
-                    default_chunk_class = message_chunk.__class__
-                    generation_chunk = ChatGenerationChunk(
-                        message=message_chunk, generation_info=generation_info or None
-                    )
+                generation_chunk = _convert_chunk_to_generation_chunk(
+                    chunk,
+                    default_chunk_class,
+                    base_generation_info if is_first_chunk else {},
+                )
+                if generation_chunk is None:
+                    continue
+                default_chunk_class = generation_chunk.message.__class__
+                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
                 if run_manager:
                     await run_manager.on_llm_new_token(
-                        token=generation_chunk.text,
-                        chunk=generation_chunk,
-                        logprobs=logprobs,
+                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                     )
                 is_first_chunk = False
                 yield generation_chunk
@@ -1,12 +1,14 @@
"""Test OpenAI Chat API wrapper."""

import json
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional, Type, Union
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    FunctionMessage,
    HumanMessage,
    InvalidToolCall,
@@ -14,6 +16,7 @@ from langchain_core.messages import (
    ToolCall,
    ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.pydantic_v1 import BaseModel

from langchain_openai import ChatOpenAI
@@ -172,6 +175,284 @@ def test__convert_dict_to_message_tool_call() -> None:
    assert reverted_message_dict == message


class MockAsyncContextManager:
    def __init__(self, chunk_list: list):
        self.current_chunk = 0
        self.chunk_list = chunk_list
        self.chunk_num = len(chunk_list)

    async def __aenter__(self) -> "MockAsyncContextManager":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        pass

    def __aiter__(self) -> "MockAsyncContextManager":
        return self

    async def __anext__(self) -> dict:
        if self.current_chunk < self.chunk_num:
            chunk = self.chunk_list[self.current_chunk]
            self.current_chunk += 1
            return chunk
        else:
            raise StopAsyncIteration


class MockSyncContextManager:
    def __init__(self, chunk_list: list):
        self.current_chunk = 0
        self.chunk_list = chunk_list
        self.chunk_num = len(chunk_list)

    def __enter__(self) -> "MockSyncContextManager":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        pass

    def __iter__(self) -> "MockSyncContextManager":
        return self

    def __next__(self) -> dict:
        if self.current_chunk < self.chunk_num:
            chunk = self.chunk_list[self.current_chunk]
            self.current_chunk += 1
            return chunk
        else:
            raise StopIteration


GLM4_STREAM_META = """{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4eba\u5de5\u667a\u80fd"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u52a9\u624b"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":","}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4f60\u53ef\u4ee5"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u53eb\u6211"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"AI"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"\u52a9\u624b"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"。"}}]}
{"id":"20240722102053e7277a4f94e848248ff9588ed37fb6e6","created":1721614853,"model":"glm-4","choices":[{"index":0,"finish_reason":"stop","delta":{"role":"assistant","content":""}}],"usage":{"prompt_tokens":13,"completion_tokens":10,"total_tokens":23}}
[DONE]""" # noqa: E501


@pytest.fixture
def mock_glm4_completion() -> list:
    list_chunk_data = GLM4_STREAM_META.split("\n")
    result_list = []
    for msg in list_chunk_data:
        if msg != "[DONE]":
            result_list.append(json.loads(msg))

    return result_list


async def test_glm4_astream(mock_glm4_completion: list) -> None:
    llm_name = "glm-4"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = AsyncMock()

    async def mock_create(*args: Any, **kwargs: Any) -> MockAsyncContextManager:
        return MockAsyncContextManager(mock_glm4_completion)

    mock_client.create = mock_create
    usage_chunk = mock_glm4_completion[-1]

    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "async_client", mock_client):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None

    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


def test_glm4_stream(mock_glm4_completion: list) -> None:
    llm_name = "glm-4"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = MagicMock()

    def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
        return MockSyncContextManager(mock_glm4_completion)

    mock_client.create = mock_create
    usage_chunk = mock_glm4_completion[-1]

    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "client", mock_client):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None

    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


DEEPSEEK_STREAM_DATA = """{"id":"d3610c24e6b42518a7883ea57c3ea2c3","choices":[{"index":0,"delta":{"content":"","role":"assistant"},"finish_reason":null,"logprobs":null}],"created":1721630271,"model":"deepseek-chat","system_fingerprint":"fp_7e0991cad4","object":"chat.completion.chunk","usage":null}
{"choices":[{"delta":{"content":"我是","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"Deep","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"Seek","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":" Chat","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":",","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"一个","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"由","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"深度","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"求","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"索","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"公司","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"开发的","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"智能","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"助手","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"。","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":null}
{"choices":[{"delta":{"content":"","role":null},"finish_reason":"stop","index":0,"logprobs":null}],"created":1721630271,"id":"d3610c24e6b42518a7883ea57c3ea2c3","model":"deepseek-chat","object":"chat.completion.chunk","system_fingerprint":"fp_7e0991cad4","usage":{"completion_tokens":15,"prompt_tokens":11,"total_tokens":26}}
[DONE]""" # noqa: E501


@pytest.fixture
def mock_deepseek_completion() -> List[Dict]:
    list_chunk_data = DEEPSEEK_STREAM_DATA.split("\n")
    result_list = []
    for msg in list_chunk_data:
        if msg != "[DONE]":
            result_list.append(json.loads(msg))

    return result_list


async def test_deepseek_astream(mock_deepseek_completion: list) -> None:
    llm_name = "deepseek-chat"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = AsyncMock()

    async def mock_create(*args: Any, **kwargs: Any) -> MockAsyncContextManager:
        return MockAsyncContextManager(mock_deepseek_completion)

    mock_client.create = mock_create
    usage_chunk = mock_deepseek_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "async_client", mock_client):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None

    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


def test_deepseek_stream(mock_deepseek_completion: list) -> None:
    llm_name = "deepseek-chat"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = MagicMock()

    def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
        return MockSyncContextManager(mock_deepseek_completion)

    mock_client.create = mock_create
    usage_chunk = mock_deepseek_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "client", mock_client):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None

    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


OPENAI_STREAM_DATA = """{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null}
{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"content":"我是"},"logprobs":null,"finish_reason":null}],"usage":null}
{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"content":"助手"},"logprobs":null,"finish_reason":null}],"usage":null}
{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{"content":"。"},"logprobs":null,"finish_reason":null}],"usage":null}
{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null}
{"id":"chatcmpl-9nhARrdUiJWEMd5plwV1Gc9NCjb9M","object":"chat.completion.chunk","created":1721631035,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_18cc0f1fa0","choices":[],"usage":{"prompt_tokens":14,"completion_tokens":3,"total_tokens":17}}
[DONE]""" # noqa: E501


@pytest.fixture
def mock_openai_completion() -> List[Dict]:
    list_chunk_data = OPENAI_STREAM_DATA.split("\n")
    result_list = []
    for msg in list_chunk_data:
        if msg != "[DONE]":
            result_list.append(json.loads(msg))

    return result_list


async def test_openai_astream(mock_openai_completion: list) -> None:
    llm_name = "gpt-4o"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = AsyncMock()

    async def mock_create(*args: Any, **kwargs: Any) -> MockAsyncContextManager:
        return MockAsyncContextManager(mock_openai_completion)

    mock_client.create = mock_create
    usage_chunk = mock_openai_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "async_client", mock_client):
        async for chunk in llm.astream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None

    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


def test_openai_stream(mock_openai_completion: list) -> None:
    llm_name = "gpt-4o"
    llm = ChatOpenAI(model=llm_name, stream_usage=True)
    mock_client = MagicMock()

    def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
        return MockSyncContextManager(mock_openai_completion)

    mock_client.create = mock_create
    usage_chunk = mock_openai_completion[-1]
    usage_metadata: Optional[UsageMetadata] = None
    with patch.object(llm, "client", mock_client):
        for chunk in llm.stream("你的名字叫什么?只回答名字"):
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                usage_metadata = chunk.usage_metadata

    assert usage_metadata is not None

    assert usage_metadata["input_tokens"] == usage_chunk["usage"]["prompt_tokens"]
    assert usage_metadata["output_tokens"] == usage_chunk["usage"]["completion_tokens"]
    assert usage_metadata["total_tokens"] == usage_chunk["usage"]["total_tokens"]


@pytest.fixture
def mock_completion() -> dict:
    return {