community[minor]: add coze chat model (#20770)

add coze chat model, to call coze.com apis
pull/15890/head^2
chyroc 2 months ago committed by GitHub
parent 29493bb598
commit 3e241956d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,181 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"sidebar_label: Coze Chat\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Chat with Coze Bot\n",
"\n",
"ChatCoze chat models API by coze.com. For more information, see [https://www.coze.com/open/docs/chat](https://www.coze.com/open/docs/chat)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-25T15:14:24.186131Z",
"start_time": "2024-04-25T15:14:23.831767Z"
}
},
"outputs": [],
"source": [
"from langchain_community.chat_models import ChatCoze\n",
"from langchain_core.messages import HumanMessage"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-25T15:14:24.191123Z",
"start_time": "2024-04-25T15:14:24.186330Z"
}
},
"outputs": [],
"source": [
"chat = ChatCoze(\n",
" coze_api_base=\"YOUR_API_BASE\",\n",
" coze_api_key=\"YOUR_API_KEY\",\n",
" bot_id=\"YOUR_BOT_ID\",\n",
" user=\"YOUR_USER_ID\",\n",
" conversation_id=\"YOUR_CONVERSATION_ID\",\n",
" streaming=False,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, you can set your API key and API base with:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"COZE_API_KEY\"] = \"YOUR_API_KEY\"\n",
"os.environ[\"COZE_API_BASE\"] = \"YOUR_API_BASE\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-25T15:14:25.853218Z",
"start_time": "2024-04-25T15:14:24.192408Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='为你找到关于coze的信息如下\n\nCoze是一个由字节跳动推出的AI聊天机器人和应用程序编辑开发平台。\n\n用户无论是否有编程经验都可以通过该平台快速创建各种类型的聊天机器人、智能体、AI应用和插件并将其部署在社交平台和即时聊天应用程序中。\n\n国际版使用的模型比国内版更强大。')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat([HumanMessage(content=\"什么是扣子(coze)\")])"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## Chat with Coze Streaming"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-25T15:14:25.870044Z",
"start_time": "2024-04-25T15:14:25.863381Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"chat = ChatCoze(\n",
" coze_api_base=\"YOUR_API_BASE\",\n",
" coze_api_key=\"YOUR_API_KEY\",\n",
" bot_id=\"YOUR_BOT_ID\",\n",
" user=\"YOUR_USER_ID\",\n",
" conversation_id=\"YOUR_CONVERSATION_ID\",\n",
" streaming=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-25T15:14:27.153546Z",
"start_time": "2024-04-25T15:14:25.868470Z"
},
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessageChunk(content='为你查询到Coze是一个由字节跳动推出的AI聊天机器人和应用程序编辑开发平台。')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat([HumanMessage(content=\"什么是扣子(coze)\")])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -42,6 +42,9 @@ if TYPE_CHECKING:
from langchain_community.chat_models.cohere import (
ChatCohere, # noqa: F401
)
from langchain_community.chat_models.coze import (
ChatCoze, # noqa: F401
)
from langchain_community.chat_models.databricks import (
ChatDatabricks, # noqa: F401
)
@ -167,6 +170,7 @@ __all__ = [
"ChatAnyscale",
"ChatBaichuan",
"ChatCohere",
"ChatCoze",
"ChatDatabricks",
"ChatDeepInfra",
"ChatEverlyAI",
@ -217,6 +221,7 @@ _module_lookup = {
"ChatAnyscale": "langchain_community.chat_models.anyscale",
"ChatBaichuan": "langchain_community.chat_models.baichuan",
"ChatCohere": "langchain_community.chat_models.cohere",
"ChatCoze": "langchain_community.chat_models.coze",
"ChatDatabricks": "langchain_community.chat_models.databricks",
"ChatDeepInfra": "langchain_community.chat_models.deepinfra",
"ChatEverlyAI": "langchain_community.chat_models.everlyai",

@ -0,0 +1,255 @@
import json
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional, Union
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
)
logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "https://api.coze.com"
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, HumanMessage):
message_dict = {
"role": "user",
"content": message.content,
"content_type": "text",
}
else:
message_dict = {
"role": "assistant",
"content": message.content,
"content_type": "text",
}
return message_dict
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> Union[BaseMessage, None]:
msg_type = _dict["type"]
if msg_type != "answer":
return None
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict.get("content", "") or "")
else:
return ChatMessage(content=_dict["content"], role=role)
def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if role == "user":
return HumanMessageChunk(content=content)
elif role == "assistant":
return AIMessageChunk(content=content)
else:
return ChatMessageChunk(content=content, role=role)
class ChatCoze(BaseChatModel):
"""ChatCoze chat models API by coze.com
For more information, see https://www.coze.com/open/docs/chat
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {
"coze_api_key": "COZE_API_KEY",
}
@property
def lc_serializable(self) -> bool:
return True
coze_api_base: str = Field(default=DEFAULT_API_BASE)
"""Coze custom endpoints"""
coze_api_key: Optional[SecretStr] = None
"""Coze API Key"""
request_timeout: int = Field(default=60, alias="timeout")
"""request timeout for chat http requests"""
bot_id: str = Field(default="")
"""The ID of the bot that the API interacts with."""
conversation_id: str = Field(default="")
"""Indicate which conversation the dialog is taking place in. If there is no need to
distinguish the context of the conversation(just a question and answer), skip this
parameter. It will be generated by the system."""
user: str = Field(default="")
"""The user who calls the API to chat with the bot."""
streaming: bool = False
"""Whether to stream the response to the client.
false: if no value is specified or set to false, a non-streaming response is
returned. "Non-streaming response" means that all responses will be returned at once
after they are all ready, and the client does not need to concatenate the content.
true: set to true, partial message deltas will be sent .
"Streaming response" will provide real-time response of the model to the client, and
the client needs to assemble the final reply based on the type of message. """
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
values["coze_api_base"] = get_from_dict_or_env(
values,
"coze_api_base",
"COZE_API_BASE",
DEFAULT_API_BASE,
)
values["coze_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"coze_api_key",
"COZE_API_KEY",
)
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Coze API."""
return {
"bot_id": self.bot_id,
"conversation_id": self.conversation_id,
"user": self.user,
"streaming": self.streaming,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
r = self._chat(messages, **kwargs)
res = r.json()
if res["code"] != 0:
raise ValueError(
f"Error from Coze api response: {res['code']}: {res['msg']}, "
f"logid: {r.headers.get('X-Tt-Logid')}"
)
return self._create_chat_result(res.get("messages") or [])
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
res = self._chat(messages, **kwargs)
for chunk in res.iter_lines():
chunk = chunk.decode("utf-8").strip("\r\n")
parts = chunk.split("data:", 1)
chunk = parts[1] if len(parts) > 1 else None
if chunk is None:
continue
response = json.loads(chunk)
if response["event"] == "done":
break
elif (
response["event"] != "message"
or response["message"]["type"] != "answer"
):
continue
chunk = _convert_delta_to_message_chunk(response["message"])
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
parameters = {**self._default_params, **kwargs}
query = ""
chat_history = []
for msg in messages:
if isinstance(msg, HumanMessage):
query = f"{msg.content}" # overwrite, to get last user message as query
chat_history.append(_convert_message_to_dict(msg))
conversation_id = parameters.pop("conversation_id")
bot_id = parameters.pop("bot_id")
user = parameters.pop("user")
streaming = parameters.pop("streaming")
payload = {
"conversation_id": conversation_id,
"bot_id": bot_id,
"user": user,
"query": query,
"stream": streaming,
}
if chat_history:
payload["chat_history"] = chat_history
url = self.coze_api_base + "/open_api/v2/chat"
api_key = ""
if self.coze_api_key:
api_key = self.coze_api_key.get_secret_value()
res = requests.post(
url=url,
timeout=self.request_timeout,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
},
json=payload,
stream=streaming,
)
if res.status_code != 200:
logid = res.headers.get("X-Tt-Logid")
raise ValueError(f"Error from Coze api response: {res}, logid: {logid}")
return res
def _create_chat_result(self, messages: List[Mapping[str, Any]]) -> ChatResult:
generations = []
for c in messages:
msg = _convert_dict_to_message(c)
if msg:
generations.append(ChatGeneration(message=msg))
llm_output = {"token_usage": "", "model": ""}
return ChatResult(generations=generations, llm_output=llm_output)
@property
def _llm_type(self) -> str:
return "coze-chat"

@ -0,0 +1,36 @@
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.chat_models.coze import ChatCoze
# For testing, run:
# TEST_FILE=tests/integration_tests/chat_models/test_coze.py make test
def test_chat_coze_default() -> None:
chat = ChatCoze(
coze_api_base="https://api.coze.com",
coze_api_key="pat_...",
bot_id="7....",
user="123",
conversation_id="",
streaming=True,
)
message = HumanMessage(content="请完整背诵将进酒背诵5遍")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_chat_coze_default_non_streaming() -> None:
chat = ChatCoze(
coze_api_base="https://api.coze.com",
coze_api_key="pat_...",
bot_id="7....",
user="123",
conversation_id="",
streaming=False,
)
message = HumanMessage(content="请完整背诵将进酒背诵5遍")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)

@ -7,6 +7,7 @@ EXPECTED_ALL = [
"ChatAnyscale",
"ChatBaichuan",
"ChatCohere",
"ChatCoze",
"ChatDatabricks",
"ChatDeepInfra",
"ChatEverlyAI",

Loading…
Cancel
Save