feat(llms): improve ERNIE-Bot chat model (#9833)

- Description: improve ERNIE-Bot chat model, add request timeout and
more testcases.
  - Issue: None
  - Dependencies: None
  - Tag maintainer: @baskaryan

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
axiangcoding 2023-08-30 09:20:06 +08:00 committed by GitHub
parent bdccb1215a
commit ffa5625134
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 3 deletions

View File

@ -57,12 +57,25 @@ class ErnieBotChat(BaseChatModel):
""" """
ernie_client_id: Optional[str] = None ernie_client_id: Optional[str] = None
"""Baidu application client id"""
ernie_client_secret: Optional[str] = None ernie_client_secret: Optional[str] = None
"""Baidu application client secret"""
access_token: Optional[str] = None access_token: Optional[str] = None
"""access token is generated by client id and client secret,
setting this value directly will cause an error"""
model_name: str = "ERNIE-Bot-turbo" model_name: str = "ERNIE-Bot-turbo"
"""model name of ernie, default is `ERNIE-Bot-turbo`.
Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`"""
request_timeout: Optional[int] = 60
"""request timeout for chat http requests"""
streaming: Optional[bool] = False streaming: Optional[bool] = False
"""streaming mode. not supported yet."""
top_p: Optional[float] = 0.8 top_p: Optional[float] = 0.8
temperature: Optional[float] = 0.95 temperature: Optional[float] = 0.95
penalty_score: Optional[float] = 1 penalty_score: Optional[float] = 1
@ -93,6 +106,7 @@ class ErnieBotChat(BaseChatModel):
raise ValueError(f"Got unknown model_name {self.model_name}") raise ValueError(f"Got unknown model_name {self.model_name}")
resp = requests.post( resp = requests.post(
url, url,
timeout=self.request_timeout,
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
@ -107,6 +121,7 @@ class ErnieBotChat(BaseChatModel):
base_url: str = "https://aip.baidubce.com/oauth/2.0/token" base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
resp = requests.post( resp = requests.post(
base_url, base_url,
timeout=10,
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"Accept": "application/json", "Accept": "application/json",

View File

@ -45,12 +45,14 @@ def test_extra_kwargs() -> None:
def test_wrong_temperature_1() -> None: def test_wrong_temperature_1() -> None:
chat = ErnieBotChat() chat = ErnieBotChat()
message = HumanMessage(content="Hello") message = HumanMessage(content="Hello")
with pytest.raises(ValueError): with pytest.raises(ValueError) as e:
chat([message], temperature=1.2) chat([message], temperature=1.2)
assert "parameter check failed, temperature range is (0, 1.0]" in str(e)
def test_wrong_temperature_2() -> None: def test_wrong_temperature_2() -> None:
chat = ErnieBotChat() chat = ErnieBotChat()
message = HumanMessage(content="Hello") message = HumanMessage(content="Hello")
with pytest.raises(ValueError): with pytest.raises(ValueError) as e:
chat([message], temperature=0) chat([message], temperature=0)
assert "parameter check failed, temperature range is (0, 1.0]" in str(e)

View File

@ -1,5 +1,12 @@
import pytest
from langchain.chat_models.ernie import _convert_message_to_dict from langchain.chat_models.ernie import _convert_message_to_dict
from langchain.schema.messages import AIMessage, HumanMessage from langchain.schema.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
def test__convert_dict_to_message_human() -> None: def test__convert_dict_to_message_human() -> None:
@ -14,3 +21,17 @@ def test__convert_dict_to_message_ai() -> None:
result = _convert_message_to_dict(message) result = _convert_message_to_dict(message)
expected_output = {"role": "assistant", "content": "foo"} expected_output = {"role": "assistant", "content": "foo"}
assert result == expected_output assert result == expected_output
def test__convert_dict_to_message_system() -> None:
message = SystemMessage(content="foo")
with pytest.raises(ValueError) as e:
_convert_message_to_dict(message)
assert "Got unknown type" in str(e)
def test__convert_dict_to_message_function() -> None:
message = FunctionMessage(name="foo", content="bar")
with pytest.raises(ValueError) as e:
_convert_message_to_dict(message)
assert "Got unknown type" in str(e)