community[patch]: fix exception raised by qianfan chat stream calls (#13800)

- **Description:**
`QianfanChatEndpoint` inherits from `BaseChatModel`, whose default `stream()`
implementation concatenates message chunks with `__add__`. Each streamed
Qianfan chunk previously copied the full response metadata into
`additional_kwargs`, so the merge encountered the same keys twice and calling
`stream()` raised a `ValueError` for the duplicated key (a minimal
reproduction follows below). This patch moves that per-chunk metadata into
`generation_info` instead, keeping only `function_call` in `additional_kwargs`.
  - **Issues:**
     * #13546
     * #13548
     * Also merges the two separate Qianfan test files into one.
  - **Dependencies:** no
  - **Tag maintainer:**
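
A minimal reproduction of the failure mode (the key name below is
illustrative; the real Qianfan chunks carry the raw response fields):

```python
from langchain_core.messages import AIMessageChunk

# Before this patch, every streamed chunk placed the raw response metadata
# in additional_kwargs. BaseChatModel.stream() folds chunks together with
# `+`, and merging two chunks that repeat a non-string key fails.
a = AIMessageChunk(content="Hel", additional_kwargs={"created": 1700000001})
b = AIMessageChunk(content="lo", additional_kwargs={"created": 1700000002})

a + b  # raises ValueError for the duplicated "created" key
```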

---------

Co-authored-by: root <liujun45@baidu.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>

@@ -55,17 +55,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": null,
 "metadata": {},
-"outputs": [
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"[INFO] [09-15 20:00:29] logging.py:55 [t:139698882193216]: requesting llm api endpoint: /chat/eb-instant\n"
-]
-}
-],
+"outputs": [],
 "source": [
 "\"\"\"For basic init and call\"\"\"\n",
 "import os\n",
@@ -126,9 +118,7 @@
 "from langchain.schema import HumanMessage\n",
 "from langchain_community.chat_models import QianfanChatEndpoint\n",
 "\n",
-"chatLLM = QianfanChatEndpoint(\n",
-"    streaming=True,\n",
-")\n",
+"chatLLM = QianfanChatEndpoint()\n",
 "res = chatLLM.stream([HumanMessage(content=\"hi\")], streaming=True)\n",
 "for r in res:\n",
 "    print(\"chat resp:\", r)\n",
@@ -260,11 +250,11 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.4"
+"version": "3.11.5"
 },
 "vscode": {
 "interpreter": {
-"hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157"
+"hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb"
 }
 }
 },

@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 import logging
 from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, cast
@@ -244,7 +242,14 @@ class QianfanChatEndpoint(BaseChatModel):
         """
         if self.streaming:
             completion = ""
+            token_usage = {}
+            chat_generation_info: Dict = {}
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
+                chat_generation_info = (
+                    chunk.generation_info
+                    if chunk.generation_info is not None
+                    else chat_generation_info
+                )
                 completion += chunk.text
             lc_msg = AIMessage(content=completion, additional_kwargs={})
             gen = ChatGeneration(
@@ -253,7 +258,10 @@ class QianfanChatEndpoint(BaseChatModel):
             )
             return ChatResult(
                 generations=[gen],
-                llm_output={"token_usage": {}, "model_name": self.model},
+                llm_output={
+                    "token_usage": chat_generation_info.get("usage", {}),
+                    "model_name": self.model,
+                },
             )
         params = self._convert_prompt_msg_params(messages, **kwargs)
         response_payload = self.client.do(**params)
@@ -279,7 +287,13 @@ class QianfanChatEndpoint(BaseChatModel):
         if self.streaming:
             completion = ""
             token_usage = {}
+            chat_generation_info: Dict = {}
             async for chunk in self._astream(messages, stop, run_manager, **kwargs):
+                chat_generation_info = (
+                    chunk.generation_info
+                    if chunk.generation_info is not None
+                    else chat_generation_info
+                )
                 completion += chunk.text
             lc_msg = AIMessage(content=completion, additional_kwargs={})
@@ -289,7 +303,10 @@ class QianfanChatEndpoint(BaseChatModel):
             )
             return ChatResult(
                 generations=[gen],
-                llm_output={"token_usage": {}, "model_name": self.model},
+                llm_output={
+                    "token_usage": chat_generation_info.get("usage", {}),
+                    "model_name": self.model,
+                },
             )
         params = self._convert_prompt_msg_params(messages, **kwargs)
         response_payload = await self.client.ado(**params)
@@ -315,16 +332,19 @@ class QianfanChatEndpoint(BaseChatModel):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         params = self._convert_prompt_msg_params(messages, **kwargs)
+        params["stream"] = True
         for res in self.client.do(**params):
             if res:
                 msg = _convert_dict_to_message(res)
+                additional_kwargs = msg.additional_kwargs.get("function_call", {})
                 chunk = ChatGenerationChunk(
                     text=res["result"],
                     message=AIMessageChunk(
                         content=msg.content,
                         role="assistant",
-                        additional_kwargs=msg.additional_kwargs,
+                        additional_kwargs=additional_kwargs,
                     ),
+                    generation_info=msg.additional_kwargs,
                 )
                 yield chunk
                 if run_manager:
@@ -338,16 +358,19 @@ class QianfanChatEndpoint(BaseChatModel):
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         params = self._convert_prompt_msg_params(messages, **kwargs)
+        params["stream"] = True
         async for res in await self.client.ado(**params):
             if res:
                 msg = _convert_dict_to_message(res)
+                additional_kwargs = msg.additional_kwargs.get("function_call", {})
                 chunk = ChatGenerationChunk(
                     text=res["result"],
                     message=AIMessageChunk(
                         content=msg.content,
                         role="assistant",
-                        additional_kwargs=msg.additional_kwargs,
+                        additional_kwargs=additional_kwargs,
                     ),
+                    generation_info=msg.additional_kwargs,
                 )
                 yield chunk
                 if run_manager:
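
For reference, the streaming branches of `_generate`/`_agenerate` above now
fold chunks with the pattern below, distilled into a standalone sketch (plain
`(text, generation_info)` tuples stand in for `ChatGenerationChunk`):

```python
from typing import Any, Dict, List, Optional, Tuple

def fold_chunks(
    chunks: List[Tuple[str, Optional[Dict[str, Any]]]]
) -> Tuple[str, Dict[str, Any]]:
    """Concatenate chunk text while remembering the latest non-None metadata."""
    completion = ""
    chat_generation_info: Dict[str, Any] = {}
    for text, generation_info in chunks:
        # Keep the most recent non-None metadata; the patch reads the final
        # "usage" block from it when building llm_output.
        chat_generation_info = (
            generation_info if generation_info is not None else chat_generation_info
        )
        completion += text
    return completion, chat_generation_info.get("usage", {})

print(fold_chunks([
    ("Hel", {"usage": {"total_tokens": 1}}),
    ("lo", {"usage": {"total_tokens": 3}}),
]))
# -> ('Hello', {'total_tokens': 3})
```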

@@ -1,53 +0,0 @@
-from typing import cast
-
-from langchain_core.pydantic_v1 import SecretStr
-from pytest import CaptureFixture, MonkeyPatch
-
-from langchain_community.chat_models.baidu_qianfan_endpoint import (
-    QianfanChatEndpoint,
-)
-
-
-def test_qianfan_key_masked_when_passed_from_env(
-    monkeypatch: MonkeyPatch, capsys: CaptureFixture
-) -> None:
-    """Test initialization with an API key provided via an env variable"""
-    monkeypatch.setenv("QIANFAN_AK", "test-api-key")
-    monkeypatch.setenv("QIANFAN_SK", "test-secret-key")
-
-    chat = QianfanChatEndpoint()
-    print(chat.qianfan_ak, end="")
-    captured = capsys.readouterr()
-    assert captured.out == "**********"
-    print(chat.qianfan_sk, end="")
-    captured = capsys.readouterr()
-    assert captured.out == "**********"
-
-
-def test_qianfan_key_masked_when_passed_via_constructor(
-    capsys: CaptureFixture,
-) -> None:
-    """Test initialization with an API key provided via the initializer"""
-    chat = QianfanChatEndpoint(
-        qianfan_ak="test-api-key",
-        qianfan_sk="test-secret-key",
-    )
-    print(chat.qianfan_ak, end="")
-    captured = capsys.readouterr()
-    assert captured.out == "**********"
-    print(chat.qianfan_sk, end="")
-    captured = capsys.readouterr()
-    assert captured.out == "**********"
-
-
-def test_uses_actual_secret_value_from_secret_str() -> None:
-    """Test that actual secret is retrieved using `.get_secret_value()`."""
-    chat = QianfanChatEndpoint(
-        qianfan_ak="test-api-key",
-        qianfan_sk="test-secret-key",
-    )
-    assert cast(SecretStr, chat.qianfan_ak).get_secret_value() == "test-api-key"
-    assert cast(SecretStr, chat.qianfan_sk).get_secret_value() == "test-secret-key"

@@ -1,18 +1,24 @@
 """Test Baidu Qianfan Chat Endpoint."""
-from typing import Any
+from typing import Any, cast

+import pytest
 from langchain_core.callbacks import CallbackManager
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
+    BaseMessageChunk,
     FunctionMessage,
     HumanMessage,
 )
 from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
+from langchain_core.pydantic_v1 import SecretStr
+from pytest import CaptureFixture, MonkeyPatch

-from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
+from langchain_community.chat_models.baidu_qianfan_endpoint import (
+    QianfanChatEndpoint,
+)
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


 _FUNCTIONS: Any = [
@@ -139,6 +145,25 @@ def test_multiple_history() -> None:
     assert isinstance(response.content, str)


+def test_chat_generate() -> None:
+    """Tests chat generate works."""
+    chat = QianfanChatEndpoint()
+    response = chat.generate(
+        [
+            [
+                HumanMessage(content="Hello."),
+                AIMessage(content="Hello!"),
+                HumanMessage(content="How are you doing?"),
+            ]
+        ]
+    )
+    assert isinstance(response, LLMResult)
+    for generations in response.generations:
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+
+
 def test_stream() -> None:
     """Test that stream works."""
     chat = QianfanChatEndpoint(streaming=True)
@@ -156,6 +181,57 @@ def test_stream() -> None:
     assert callback_handler.llm_streams > 0
     assert isinstance(response.content, str)

+    res = chat.stream(
+        [
+            HumanMessage(content="Hello."),
+            AIMessage(content="Hello!"),
+            HumanMessage(content="Who are you?"),
+        ]
+    )
+    assert len(list(res)) >= 1
+
+
+@pytest.mark.asyncio
+async def test_async_invoke() -> None:
+    chat = QianfanChatEndpoint()
+    res = await chat.ainvoke([HumanMessage(content="Hello")])
+    assert isinstance(res, BaseMessage)
+    assert res.content != ""
+
+
+@pytest.mark.asyncio
+async def test_async_generate() -> None:
+    """Tests chat agenerate works."""
+    chat = QianfanChatEndpoint()
+    response = await chat.agenerate(
+        [
+            [
+                HumanMessage(content="Hello."),
+                AIMessage(content="Hello!"),
+                HumanMessage(content="How are you doing?"),
+            ]
+        ]
+    )
+    assert isinstance(response, LLMResult)
+    for generations in response.generations:
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+
+
+@pytest.mark.asyncio
+async def test_async_stream() -> None:
+    chat = QianfanChatEndpoint(streaming=True)
+    async for token in chat.astream(
+        [
+            HumanMessage(content="Hello."),
+            AIMessage(content="Hello!"),
+            HumanMessage(content="Who are you?"),
+        ]
+    ):
+        assert isinstance(token, BaseMessageChunk)
+

 def test_multiple_messages() -> None:
     """Tests multiple messages works."""
@@ -232,3 +308,48 @@ def test_rate_limit() -> None:
     for res in responses:
         assert isinstance(res, BaseMessage)
         assert isinstance(res.content, str)
+
+
+def test_qianfan_key_masked_when_passed_from_env(
+    monkeypatch: MonkeyPatch, capsys: CaptureFixture
+) -> None:
+    """Test initialization with an API key provided via an env variable"""
+    monkeypatch.setenv("QIANFAN_AK", "test-api-key")
+    monkeypatch.setenv("QIANFAN_SK", "test-secret-key")
+
+    chat = QianfanChatEndpoint()
+    print(chat.qianfan_ak, end="")
+    captured = capsys.readouterr()
+    assert captured.out == "**********"
+    print(chat.qianfan_sk, end="")
+    captured = capsys.readouterr()
+    assert captured.out == "**********"
+
+
+def test_qianfan_key_masked_when_passed_via_constructor(
+    capsys: CaptureFixture,
+) -> None:
+    """Test initialization with an API key provided via the initializer"""
+    chat = QianfanChatEndpoint(
+        qianfan_ak="test-api-key",
+        qianfan_sk="test-secret-key",
+    )
+    print(chat.qianfan_ak, end="")
+    captured = capsys.readouterr()
+    assert captured.out == "**********"
+    print(chat.qianfan_sk, end="")
+    captured = capsys.readouterr()
+    assert captured.out == "**********"
+
+
+def test_uses_actual_secret_value_from_secret_str() -> None:
+    """Test that actual secret is retrieved using `.get_secret_value()`."""
+    chat = QianfanChatEndpoint(
+        qianfan_ak="test-api-key",
+        qianfan_sk="test-secret-key",
+    )
+    assert cast(SecretStr, chat.qianfan_ak).get_secret_value() == "test-api-key"
+    assert cast(SecretStr, chat.qianfan_sk).get_secret_value() == "test-secret-key"
