mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
feat(llms): improve ERNIE-Bot chat model (#9833)
- Description: improve ERNIE-Bot chat model, add request timeout and more testcases. - Issue: None - Dependencies: None - Tag maintainer: @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
bdccb1215a
commit
ffa5625134
@ -57,12 +57,25 @@ class ErnieBotChat(BaseChatModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
ernie_client_id: Optional[str] = None
|
ernie_client_id: Optional[str] = None
|
||||||
|
"""Baidu application client id"""
|
||||||
|
|
||||||
ernie_client_secret: Optional[str] = None
|
ernie_client_secret: Optional[str] = None
|
||||||
|
"""Baidu application client secret"""
|
||||||
|
|
||||||
access_token: Optional[str] = None
|
access_token: Optional[str] = None
|
||||||
|
"""access token is generated by client id and client secret,
|
||||||
|
setting this value directly will cause an error"""
|
||||||
|
|
||||||
model_name: str = "ERNIE-Bot-turbo"
|
model_name: str = "ERNIE-Bot-turbo"
|
||||||
|
"""model name of ernie, default is `ERNIE-Bot-turbo`.
|
||||||
|
Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`"""
|
||||||
|
|
||||||
|
request_timeout: Optional[int] = 60
|
||||||
|
"""request timeout for chat http requests"""
|
||||||
|
|
||||||
streaming: Optional[bool] = False
|
streaming: Optional[bool] = False
|
||||||
|
"""streaming mode. not supported yet."""
|
||||||
|
|
||||||
top_p: Optional[float] = 0.8
|
top_p: Optional[float] = 0.8
|
||||||
temperature: Optional[float] = 0.95
|
temperature: Optional[float] = 0.95
|
||||||
penalty_score: Optional[float] = 1
|
penalty_score: Optional[float] = 1
|
||||||
@ -93,6 +106,7 @@ class ErnieBotChat(BaseChatModel):
|
|||||||
raise ValueError(f"Got unknown model_name {self.model_name}")
|
raise ValueError(f"Got unknown model_name {self.model_name}")
|
||||||
resp = requests.post(
|
resp = requests.post(
|
||||||
url,
|
url,
|
||||||
|
timeout=self.request_timeout,
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
},
|
},
|
||||||
@ -107,6 +121,7 @@ class ErnieBotChat(BaseChatModel):
|
|||||||
base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
|
base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
|
||||||
resp = requests.post(
|
resp = requests.post(
|
||||||
base_url,
|
base_url,
|
||||||
|
timeout=10,
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
|
@ -45,12 +45,14 @@ def test_extra_kwargs() -> None:
|
|||||||
def test_wrong_temperature_1() -> None:
|
def test_wrong_temperature_1() -> None:
|
||||||
chat = ErnieBotChat()
|
chat = ErnieBotChat()
|
||||||
message = HumanMessage(content="Hello")
|
message = HumanMessage(content="Hello")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError) as e:
|
||||||
chat([message], temperature=1.2)
|
chat([message], temperature=1.2)
|
||||||
|
assert "parameter check failed, temperature range is (0, 1.0]" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_wrong_temperature_2() -> None:
|
def test_wrong_temperature_2() -> None:
|
||||||
chat = ErnieBotChat()
|
chat = ErnieBotChat()
|
||||||
message = HumanMessage(content="Hello")
|
message = HumanMessage(content="Hello")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError) as e:
|
||||||
chat([message], temperature=0)
|
chat([message], temperature=0)
|
||||||
|
assert "parameter check failed, temperature range is (0, 1.0]" in str(e)
|
||||||
|
@ -1,5 +1,12 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
from langchain.chat_models.ernie import _convert_message_to_dict
|
from langchain.chat_models.ernie import _convert_message_to_dict
|
||||||
from langchain.schema.messages import AIMessage, HumanMessage
|
from langchain.schema.messages import (
|
||||||
|
AIMessage,
|
||||||
|
FunctionMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test__convert_dict_to_message_human() -> None:
|
def test__convert_dict_to_message_human() -> None:
|
||||||
@ -14,3 +21,17 @@ def test__convert_dict_to_message_ai() -> None:
|
|||||||
result = _convert_message_to_dict(message)
|
result = _convert_message_to_dict(message)
|
||||||
expected_output = {"role": "assistant", "content": "foo"}
|
expected_output = {"role": "assistant", "content": "foo"}
|
||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_system() -> None:
|
||||||
|
message = SystemMessage(content="foo")
|
||||||
|
with pytest.raises(ValueError) as e:
|
||||||
|
_convert_message_to_dict(message)
|
||||||
|
assert "Got unknown type" in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_function() -> None:
|
||||||
|
message = FunctionMessage(name="foo", content="bar")
|
||||||
|
with pytest.raises(ValueError) as e:
|
||||||
|
_convert_message_to_dict(message)
|
||||||
|
assert "Got unknown type" in str(e)
|
||||||
|
Loading…
Reference in New Issue
Block a user