feat(llms): support ernie chat (#9114)

Description: support ernie (文心一言) chat model
Related issue: #7990
Dependencies: None
Tag maintainer: @baskaryan
pull/9253/head
axiangcoding 1 year ago committed by GitHub
parent 08a8363fc6
commit 664ff28cba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,87 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ERNIE-Bot Chat\n",
"\n",
"This notebook covers how to get started with Ernie chat models."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ErnieBotChat\n",
"from langchain.schema import AIMessage, HumanMessage, SystemMessage"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"chat = ErnieBotChat(ernie_client_id='YOUR_CLIENT_ID', ernie_client_secret='YOUR_CLIENT_SECRET')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"or you can set `client_id` and `client_secret` in your environment variables\n",
"```bash\n",
"export ERNIE_CLIENT_ID=YOUR_CLIENT_ID\n",
"export ERNIE_CLIENT_SECRET=YOUR_CLIENT_SECRET\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Hello, I am an artificial intelligence language model. My purpose is to help users answer questions or provide information. What can I do for you?', additional_kwargs={}, example=False)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat([\n",
" HumanMessage(content='hello there, who are you?')\n",
"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs.
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.anyscale import ChatAnyscale
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.ernie import ErnieBotChat
from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.human import HumanInputChatModel
@ -43,4 +44,5 @@ __all__ = [
"HumanInputChatModel",
"ChatAnyscale",
"ChatLiteLLM",
"ErnieBotChat",
]

@ -0,0 +1,156 @@
import logging
import threading
from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
ChatMessage,
ChatResult,
HumanMessage,
)
from langchain.utils import get_from_dict_or_env
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a message object into a ``{"role", "content"}`` dict.

    Args:
        message: The message to convert. ``ChatMessage`` keeps its own role,
            ``HumanMessage`` becomes ``"user"`` and ``AIMessage`` becomes
            ``"assistant"``.

    Returns:
        A dict with the keys ``"role"`` and ``"content"``.

    Raises:
        ValueError: If ``message`` is none of the three handled types.
    """
    # ChatMessage carries an explicit role, so it is passed through verbatim.
    if isinstance(message, ChatMessage):
        return {"role": message.role, "content": message.content}

    if isinstance(message, HumanMessage):
        role = "user"
    elif isinstance(message, AIMessage):
        role = "assistant"
    else:
        raise ValueError(f"Got unknown type {message}")
    return {"role": role, "content": message.content}
class ErnieBotChat(BaseChatModel):
    """ErnieBot Chat large language model.

    ERNIE-Bot is a large language model developed by Baidu,
    covering a huge amount of Chinese data.

    To use, either pass `ernie_client_id` and `ernie_client_secret` to the
    constructor, or set the environment variables `ERNIE_CLIENT_ID` and
    `ERNIE_CLIENT_SECRET`.

    `access_token` is generated automatically from the client id and client
    secret, and is regenerated when the API reports error code 111.

    Example:
        .. code-block:: python

            from langchain.chat_models import ErnieBotChat
            chat = ErnieBotChat()
    """

    # Credentials; filled from the environment by `validate_environment`
    # when not passed explicitly.
    ernie_client_id: Optional[str] = None
    ernie_client_secret: Optional[str] = None
    # Token sent with every chat request; fetched on first use when unset.
    access_token: Optional[str] = None
    # Only "ERNIE-Bot-turbo" and "ERNIE-Bot" are accepted by `_chat`.
    model_name: str = "ERNIE-Bot-turbo"
    # Streaming is not implemented; `_generate` raises if this is truthy.
    streaming: Optional[bool] = False
    # Sampling parameters forwarded unchanged in the request payload.
    top_p: Optional[float] = 0.8
    temperature: Optional[float] = 0.95
    penalty_score: Optional[float] = 1

    # Lock held while the access token is being refreshed.
    _lock = threading.Lock()

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the client id and secret from arguments or the environment.

        Args:
            values: The field values collected for the model.

        Returns:
            ``values`` with ``ernie_client_id`` and ``ernie_client_secret``
            populated.
        """
        values["ernie_client_id"] = get_from_dict_or_env(
            values,
            "ernie_client_id",
            "ERNIE_CLIENT_ID",
        )
        values["ernie_client_secret"] = get_from_dict_or_env(
            values,
            "ernie_client_secret",
            "ERNIE_CLIENT_SECRET",
        )
        return values

    def _chat(self, payload: object) -> dict:
        """POST ``payload`` to the chat endpoint selected by ``model_name``.

        Args:
            payload: JSON-serialisable request body.

        Returns:
            The decoded JSON response body.

        Raises:
            ValueError: If ``model_name`` is not a supported model.
        """
        base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
        if self.model_name == "ERNIE-Bot-turbo":
            url = f"{base_url}/eb-instant"
        elif self.model_name == "ERNIE-Bot":
            url = f"{base_url}/completions"
        else:
            raise ValueError(f"Got unknown model_name {self.model_name}")
        resp = requests.post(
            url,
            headers={
                "Content-Type": "application/json",
            },
            params={"access_token": self.access_token},
            json=payload,
        )
        return resp.json()

    def _refresh_access_token_with_lock(self) -> None:
        """Fetch a new access token and store it on ``self.access_token``.

        The request is made while holding ``_lock``.

        Raises:
            ValueError: If the token endpoint response carries no
                ``access_token``.
        """
        with self._lock:
            logger.debug("Refreshing access token")
            base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
            resp = requests.post(
                base_url,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                params={
                    "grant_type": "client_credentials",
                    "client_id": self.ernie_client_id,
                    "client_secret": self.ernie_client_secret,
                },
            )
            body = resp.json()
            token = body.get("access_token")
            # Storing str(None) would leave the truthy string "None" as the
            # token, so the "no token yet" check in `_generate` would never
            # fire again. Fail loudly instead.
            if not token:
                raise ValueError(
                    f"Failed to obtain access token from ErnieChat api: {body}"
                )
            self.access_token = str(token)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send ``messages`` to the chat API and wrap the reply.

        Args:
            messages: Conversation to send, in order.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``ChatResult`` holding the single generated message.

        Raises:
            ValueError: If ``streaming`` is enabled, or if the API response
                (including the one retried after a token refresh) contains
                an ``error_code``.
        """
        if self.streaming:
            raise ValueError("`streaming` option currently unsupported.")

        if not self.access_token:
            self._refresh_access_token_with_lock()
        payload = {
            "messages": [_convert_message_to_dict(m) for m in messages],
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_score": self.penalty_score,
        }
        resp = self._chat(payload)
        # Error code 111 triggers one token refresh followed by one retry.
        if resp.get("error_code") == 111:
            self._refresh_access_token_with_lock()
            resp = self._chat(payload)
        # Checked after the optional retry so a second failure is reported
        # instead of being turned into a result with no content.
        if resp.get("error_code"):
            raise ValueError(f"Error from ErnieChat api response: {resp}")
        return self._create_chat_result(resp)

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        """Build a ``ChatResult`` from a successful API response.

        Args:
            response: Decoded response body; ``"result"`` is used as the
                message content and ``"usage"`` (default ``{}``) as the
                token usage.

        Returns:
            A ``ChatResult`` with one generation and an ``llm_output`` dict
            containing ``token_usage`` and ``model_name``.
        """
        generations = [
            ChatGeneration(message=AIMessage(content=response.get("result")))
        ]
        token_usage = response.get("usage", {})
        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        """Identifier string for this chat model type."""
        return "ernie-chat"

@ -0,0 +1,26 @@
from langchain.chat_models.ernie import ErnieBotChat
from langchain.schema.messages import AIMessage, HumanMessage
def test_chat_ernie_bot() -> None:
    """Call the default ErnieBotChat model and check the reply type.

    The model is built without arguments, so the client id and secret must
    come from the environment.
    """
    reply = ErnieBotChat()([HumanMessage(content="Hello")])
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
def test_chat_ernie_bot_with_model_name() -> None:
    """Call ErnieBotChat with an explicit ``model_name`` and check the reply type."""
    model = ErnieBotChat(model_name="ERNIE-Bot")
    prompt = [HumanMessage(content="Hello")]
    reply = model(prompt)
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
def test_chat_ernie_bot_with_temperature() -> None:
    """Call ErnieBotChat with a custom ``temperature`` and check the reply type."""
    model = ErnieBotChat(model_name="ERNIE-Bot", temperature=1.0)
    prompt = [HumanMessage(content="Hello")]
    reply = model(prompt)
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)

@ -0,0 +1,16 @@
from langchain.chat_models.ernie import _convert_message_to_dict
from langchain.schema.messages import AIMessage, HumanMessage
def test__convert_dict_to_message_human() -> None:
    """A HumanMessage converts to a dict with the ``"user"`` role.

    NOTE(review): despite the name, this exercises
    ``_convert_message_to_dict`` (message -> dict).
    """
    converted = _convert_message_to_dict(HumanMessage(content="foo"))
    assert converted == {"role": "user", "content": "foo"}
def test__convert_dict_to_message_ai() -> None:
    """An AIMessage converts to a dict with the ``"assistant"`` role.

    NOTE(review): despite the name, this exercises
    ``_convert_message_to_dict`` (message -> dict).
    """
    converted = _convert_message_to_dict(AIMessage(content="foo"))
    assert converted == {"role": "assistant", "content": "foo"}
Loading…
Cancel
Save