fix(llms): improve the ernie chat model (#9289)

- Description: improve the Ernie chat model.
  - pass missing kwargs through to the request payload (see the usage sketch below)
  - add new test cases
  - add debug-level logging
  - improve the class and notebook descriptions
- Issue: None
- Dependencies: None
- Tag maintainer: @baskaryan
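
A minimal usage sketch of the kwargs change, assuming Baidu credentials are available via `ERNIE_CLIENT_ID` / `ERNIE_CLIENT_SECRET`; per-call keyword arguments such as `temperature` and `top_p` are now merged into the request payload:

```python
from langchain.chat_models import ErnieBotChat
from langchain.schema import HumanMessage

# Credentials are picked up from ERNIE_CLIENT_ID / ERNIE_CLIENT_SECRET
# when not passed explicitly to the constructor.
chat = ErnieBotChat()

# temperature and top_p supplied at call time are now forwarded into the API payload.
response = chat([HumanMessage(content="Hello")], temperature=0.88, top_p=0.7)
print(response.content)
```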
pull/9296/head
axiangcoding 11 months ago committed by GitHub
parent 1d55141c50
commit 63601551b1

@@ -6,7 +6,8 @@
 "source": [
 "# ERNIE-Bot Chat\n",
 "\n",
-"This notebook covers how to get started with Ernie chat models."
+"[ERNIE-Bot](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11) is a large language model developed by Baidu, covering a huge amount of Chinese data.\n",
+"This notebook covers how to get started with ErnieBot chat models."
 ]
 },
 {
@@ -16,7 +17,7 @@
 "outputs": [],
 "source": [
 "from langchain.chat_models import ErnieBotChat\n",
-"from langchain.schema import AIMessage, HumanMessage, SystemMessage"
+"from langchain.schema import HumanMessage"
 ]
 },
 {
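
For reference, a minimal getting-started sketch matching the revised notebook text; the credential values are placeholders:

```python
from langchain.chat_models import ErnieBotChat
from langchain.schema import HumanMessage

chat = ErnieBotChat(
    ernie_client_id="YOUR_CLIENT_ID",          # placeholder
    ernie_client_secret="YOUR_CLIENT_SECRET",  # placeholder
)
chat([HumanMessage(content="hello there, who are you?")])
```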

@@ -33,23 +33,26 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
 class ErnieBotChat(BaseChatModel):
-    """ErnieBot Chat large language model.
+    """`ERNIE-Bot` large language model.
     ERNIE-Bot is a large language model developed by Baidu,
     covering a huge amount of Chinese data.
-    To use, you should have the `ernie_client_id` and `ernie_client_secret` set.
+    To use, you should have the `ernie_client_id` and `ernie_client_secret` set,
+    or set the environment variable `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`.
     Note:
     access_token will be automatically generated based on client_id and client_secret,
-    and will be regenerated after expiration.
+    and will be regenerated after expiration (30 days).
+    Default model is `ERNIE-Bot-turbo`,
+    currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot`
     Example:
         .. code-block:: python
             from langchain.chat_models import ErnieBotChat
-            chat = ErnieBotChat()
+            chat = ErnieBotChat(model_name='ERNIE-Bot')
     """
@@ -133,10 +136,13 @@ class ErnieBotChat(BaseChatModel):
             "top_p": self.top_p,
             "temperature": self.temperature,
             "penalty_score": self.penalty_score,
+            **kwargs,
         }
+        logger.debug(f"Payload for ernie api is {payload}")
         resp = self._chat(payload)
         if resp.get("error_code"):
             if resp.get("error_code") == 111:
+                logger.debug("access_token expired, refresh it")
                 self._refresh_access_token_with_lock()
                 resp = self._chat(payload)
             else:
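
To see the new debug lines (the payload dump and the access-token refresh message), the module logger has to emit at `DEBUG` level; a sketch, assuming the logger in `langchain/chat_models/ernie.py` is created with `logging.getLogger(__name__)`:

```python
import logging

from langchain.chat_models import ErnieBotChat
from langchain.schema import HumanMessage

logging.basicConfig(level=logging.DEBUG)
# Assumed logger name, following the usual logging.getLogger(__name__) pattern.
logging.getLogger("langchain.chat_models.ernie").setLevel(logging.DEBUG)

chat = ErnieBotChat()
# Logs "Payload for ernie api is {...}" before the request; on error_code 111 the
# access token is refreshed and the request is retried once.
chat([HumanMessage(content="Hello")])
```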
@@ -153,4 +159,4 @@ class ErnieBotChat(BaseChatModel):
     @property
     def _llm_type(self) -> str:
-        return "ernie-chat"
+        return "ernie-bot-chat"

@@ -1,3 +1,5 @@
+import pytest
+
 from langchain.chat_models.ernie import ErnieBotChat
 from langchain.schema.messages import AIMessage, HumanMessage
@@ -24,3 +26,31 @@ def test_chat_ernie_bot_with_temperature() -> None:
     response = chat([message])
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
+
+
+def test_chat_ernie_bot_with_kwargs() -> None:
+    chat = ErnieBotChat()
+    message = HumanMessage(content="Hello")
+    response = chat([message], temperature=0.88, top_p=0.7)
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+def test_extra_kwargs() -> None:
+    chat = ErnieBotChat(temperature=0.88, top_p=0.7)
+    assert chat.temperature == 0.88
+    assert chat.top_p == 0.7
+
+
+def test_wrong_temperature_1() -> None:
+    chat = ErnieBotChat()
+    message = HumanMessage(content="Hello")
+    with pytest.raises(ValueError):
+        chat([message], temperature=1.2)
+
+
+def test_wrong_temperature_2() -> None:
+    chat = ErnieBotChat()
+    message = HumanMessage(content="Hello")
+    with pytest.raises(ValueError):
+        chat([message], temperature=0)
