From 63601551b136d2f13831249afc90005ec884f965 Mon Sep 17 00:00:00 2001 From: axiangcoding <49201354+axiangcoding@users.noreply.github.com> Date: Wed, 16 Aug 2023 15:48:42 +0800 Subject: [PATCH] fix(llms): improve the ernie chat model (#9289) - Description: improve the ernie chat model. - fix missing kwargs to payload - new test cases - add some debug level log - improve description - Issue: None - Dependencies: None - Tag maintainer: @baskaryan --- docs/extras/integrations/chat/ernie.ipynb | 5 ++-- libs/langchain/langchain/chat_models/ernie.py | 18 +++++++---- .../chat_models/test_ernie.py | 30 +++++++++++++++++++ 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/docs/extras/integrations/chat/ernie.ipynb b/docs/extras/integrations/chat/ernie.ipynb index 4b1c4c6db8..b887991f57 100644 --- a/docs/extras/integrations/chat/ernie.ipynb +++ b/docs/extras/integrations/chat/ernie.ipynb @@ -6,7 +6,8 @@ "source": [ "# ERNIE-Bot Chat\n", "\n", - "This notebook covers how to get started with Ernie chat models." + "[ERNIE-Bot](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11) is a large language model developed by Baidu, covering a huge amount of Chinese data.\n", + "This notebook covers how to get started with ErnieBot chat models." ] }, { @@ -16,7 +17,7 @@ "outputs": [], "source": [ "from langchain.chat_models import ErnieBotChat\n", - "from langchain.schema import AIMessage, HumanMessage, SystemMessage" + "from langchain.schema import HumanMessage" ] }, { diff --git a/libs/langchain/langchain/chat_models/ernie.py b/libs/langchain/langchain/chat_models/ernie.py index 50eb751444..463a26e15c 100644 --- a/libs/langchain/langchain/chat_models/ernie.py +++ b/libs/langchain/langchain/chat_models/ernie.py @@ -33,23 +33,26 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: class ErnieBotChat(BaseChatModel): - """ErnieBot Chat large language model. + """`ERNIE-Bot` large language model. ERNIE-Bot is a large language model developed by Baidu, covering a huge amount of Chinese data. - To use, you should have the `ernie_client_id` and `ernie_client_secret` set. - + To use, you should have the `ernie_client_id` and `ernie_client_secret` set, or set the environment variable `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`. + Note: access_token will be automatically generated based on client_id and client_secret, - and will be regenerated after expiration. + and will be regenerated after expiration (30 days). + + Default model is `ERNIE-Bot-turbo`, + currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot` Example: .. code-block:: python from langchain.chat_models import ErnieBotChat - chat = ErnieBotChat() + chat = ErnieBotChat(model_name='ERNIE-Bot') """ @@ -133,10 +136,13 @@ class ErnieBotChat(BaseChatModel): "top_p": self.top_p, "temperature": self.temperature, "penalty_score": self.penalty_score, + **kwargs, } + logger.debug(f"Payload for ernie api is {payload}") resp = self._chat(payload) if resp.get("error_code"): if resp.get("error_code") == 111: + logger.debug("access_token expired, refresh it") self._refresh_access_token_with_lock() resp = self._chat(payload) else: @@ -153,4 +159,4 @@ class ErnieBotChat(BaseChatModel): @property def _llm_type(self) -> str: - return "ernie-chat" + return "ernie-bot-chat" diff --git a/libs/langchain/tests/integration_tests/chat_models/test_ernie.py b/libs/langchain/tests/integration_tests/chat_models/test_ernie.py index cbaba6debb..a8a80ed9bb 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_ernie.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_ernie.py @@ -1,3 +1,5 @@ +import pytest + from langchain.chat_models.ernie import ErnieBotChat from langchain.schema.messages import AIMessage, HumanMessage @@ -24,3 +26,31 @@ def test_chat_ernie_bot_with_temperature() -> None: response = chat([message]) assert isinstance(response, AIMessage) assert isinstance(response.content, str) + + +def test_chat_ernie_bot_with_kwargs() -> None: + chat = ErnieBotChat() + message = HumanMessage(content="Hello") + response = chat([message], temperature=0.88, top_p=0.7) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + + +def test_extra_kwargs() -> None: + chat = ErnieBotChat(temperature=0.88, top_p=0.7) + assert chat.temperature == 0.88 + assert chat.top_p == 0.7 + + +def test_wrong_temperature_1() -> None: + chat = ErnieBotChat() + message = HumanMessage(content="Hello") + with pytest.raises(ValueError): + chat([message], temperature=1.2) + + +def test_wrong_temperature_2() -> None: + chat = ErnieBotChat() + message = HumanMessage(content="Hello") + with pytest.raises(ValueError): + chat([message], temperature=0)