mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
256 lines
8.3 KiB
Python
256 lines
8.3 KiB
Python
|
import json
|
||
|
import logging
|
||
|
from typing import Any, Dict, Iterator, List, Mapping, Optional, Union
|
||
|
|
||
|
import requests
|
||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||
|
from langchain_core.language_models.chat_models import (
|
||
|
BaseChatModel,
|
||
|
generate_from_stream,
|
||
|
)
|
||
|
from langchain_core.messages import (
|
||
|
AIMessage,
|
||
|
AIMessageChunk,
|
||
|
BaseMessage,
|
||
|
BaseMessageChunk,
|
||
|
ChatMessage,
|
||
|
ChatMessageChunk,
|
||
|
HumanMessage,
|
||
|
HumanMessageChunk,
|
||
|
)
|
||
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||
|
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||
|
from langchain_core.utils import (
|
||
|
convert_to_secret_str,
|
||
|
get_from_dict_or_env,
|
||
|
)
|
||
|
|
||
|
# Module-level logger, keyed on this module's import path.
logger = logging.getLogger(name=__name__)

# Base URL of the Coze API. ChatCoze falls back to it when neither the
# ``coze_api_base`` field nor the COZE_API_BASE environment variable is set.
DEFAULT_API_BASE: str = "https://api.coze.com"
|
||
|
|
||
|
|
||
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Serialize a LangChain message into a Coze ``chat_history`` entry.

    ``HumanMessage`` instances become ``"user"`` turns; every other message
    type is sent as an ``"assistant"`` turn. The content is forwarded as-is
    and always tagged with ``content_type`` ``"text"``.
    """
    # NOTE(review): non-string (e.g. list/multimodal) content is forwarded
    # unchanged but still labelled "text" — confirm the API accepts that.
    role = "user" if isinstance(message, HumanMessage) else "assistant"
    return {
        "role": role,
        "content": message.content,
        "content_type": "text",
    }
|
||
|
|
||
|
|
||
|
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> Union[BaseMessage, None]:
    """Convert one message entry of a non-streaming Coze response.

    Only entries whose ``type`` is ``"answer"`` are converted; for any other
    type ``None`` is returned so the caller can skip the entry.

    The ``role`` field selects the message class: ``"user"`` ->
    ``HumanMessage``, ``"assistant"`` -> ``AIMessage``, anything else ->
    ``ChatMessage`` carrying that role.
    """
    if _dict["type"] != "answer":
        return None

    role = _dict["role"]
    if role == "assistant":
        # Assistant content may be missing or null; normalize it to "".
        return AIMessage(content=_dict.get("content", "") or "")
    if role == "user":
        return HumanMessage(content=_dict["content"])
    return ChatMessage(content=_dict["content"], role=role)
|
||
|
|
||
|
|
||
|
def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
    """Build a streaming message chunk from one Coze ``message`` payload.

    ``role`` picks the chunk class (``"user"`` -> ``HumanMessageChunk``,
    ``"assistant"`` -> ``AIMessageChunk``, otherwise ``ChatMessageChunk``);
    missing or null ``content`` becomes an empty string.
    """
    text = _dict.get("content") or ""
    role = _dict.get("role")
    if role == "assistant":
        return AIMessageChunk(content=text)
    if role == "user":
        return HumanMessageChunk(content=text)
    # NOTE(review): a payload without "role" passes role=None here, which the
    # chunk class may reject — callers appear to only pass "answer" messages.
    return ChatMessageChunk(content=text, role=role)
|
||
|
|
||
|
|
||
|
class ChatCoze(BaseChatModel):
    """ChatCoze chat models API by coze.com

    Sends the conversation to Coze's ``/open_api/v2/chat`` endpoint, either as
    a single JSON request or as a server-sent-event stream.

    For more information, see https://www.coze.com/open/docs/chat
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map secret fields to the environment variable each is read from."""
        return {
            "coze_api_key": "COZE_API_KEY",
        }

    @property
    def lc_serializable(self) -> bool:
        """Return True."""
        return True

    coze_api_base: str = Field(default=DEFAULT_API_BASE)
    """Coze custom endpoints"""
    coze_api_key: Optional[SecretStr] = None
    """Coze API Key"""
    request_timeout: int = Field(default=60, alias="timeout")
    """request timeout for chat http requests"""
    bot_id: str = Field(default="")
    """The ID of the bot that the API interacts with."""
    conversation_id: str = Field(default="")
    """Indicate which conversation the dialog is taking place in. If there is no need to
    distinguish the context of the conversation(just a question and answer), skip this
    parameter. It will be generated by the system."""
    user: str = Field(default="")
    """The user who calls the API to chat with the bot."""
    streaming: bool = False
    """Whether to stream the response to the client.
    false: if no value is specified or set to false, a non-streaming response is
    returned. "Non-streaming response" means that all responses will be returned at once
    after they are all ready, and the client does not need to concatenate the content.
    true: set to true, partial message deltas will be sent.
    "Streaming response" will provide real-time response of the model to the client, and
    the client needs to assemble the final reply based on the type of message. """

    class Config:
        """Configuration for this pydantic object."""

        # Lets callers pass either ``request_timeout`` or its alias ``timeout``.
        allow_population_by_field_name = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API base URL and API key from init values or env vars.

        ``coze_api_base`` falls back to ``COZE_API_BASE`` and then to
        ``DEFAULT_API_BASE``; ``coze_api_key`` falls back to ``COZE_API_KEY``
        and is stored as a ``SecretStr``.
        """
        values["coze_api_base"] = get_from_dict_or_env(
            values,
            "coze_api_base",
            "COZE_API_BASE",
            DEFAULT_API_BASE,
        )
        values["coze_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "coze_api_key",
                "COZE_API_KEY",
            )
        )

        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Coze API."""
        return {
            "bot_id": self.bot_id,
            "conversation_id": self.conversation_id,
            "user": self.user,
            "streaming": self.streaming,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run one chat turn and collect the bot's answer messages.

        A per-call ``streaming`` keyword argument overrides ``self.streaming``.
        ``stop`` is accepted for interface compatibility but is not sent to
        Coze.

        Raises:
            ValueError: if the HTTP request fails, or Coze reports a non-zero
                ``code`` in the response body.
        """
        if kwargs.get("streaming", self.streaming):
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        # Force a non-streaming request: the whole body is parsed as one JSON
        # document below, which would fail on an SSE stream.
        r = self._chat(messages, **{**kwargs, "streaming": False})
        res = r.json()
        if res.get("code") != 0:
            raise ValueError(
                f"Error from Coze api response: {res.get('code')}: {res.get('msg')}, "
                f"logid: {r.headers.get('X-Tt-Logid')}"
            )

        return self._create_chat_result(res.get("messages") or [])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield answer chunks parsed from Coze's server-sent event stream.

        The request is always made in streaming mode, independent of
        ``self.streaming``, because the body is parsed line by line as SSE.
        The HTTP response is closed when the stream ends, errors, or the
        consumer stops iterating early.

        Raises:
            ValueError: if the HTTP request fails or the stream contains an
                ``error`` event.
        """
        res = self._chat(messages, **{**kwargs, "streaming": True})
        try:
            for raw_line in res.iter_lines():
                line = raw_line.decode("utf-8").strip("\r\n")
                # Only lines containing "data:" carry a JSON payload; blank
                # keep-alive lines and other lines are skipped.
                _, sep, data = line.partition("data:")
                if not sep:
                    continue

                response = json.loads(data)
                event = response["event"]
                if event == "done":
                    break
                if event == "error":
                    # NOTE(review): payload key assumed to be
                    # "error_information"; falls back to the whole event.
                    raise ValueError(
                        "Error from Coze api stream: "
                        f"{response.get('error_information', response)}, "
                        f"logid: {res.headers.get('X-Tt-Logid')}"
                    )
                if event != "message" or response["message"]["type"] != "answer":
                    continue

                message_chunk = _convert_delta_to_message_chunk(response["message"])
                cg_chunk = ChatGenerationChunk(message=message_chunk)
                if run_manager:
                    run_manager.on_llm_new_token(
                        message_chunk.content, chunk=cg_chunk
                    )
                yield cg_chunk
        finally:
            # Release the streamed connection even on break/error/early exit.
            res.close()

    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
        """POST the conversation to Coze and return the raw HTTP response.

        The content of the last ``HumanMessage`` becomes ``query``; every
        message (including that one) is converted into ``chat_history``.
        Keyword arguments override ``_default_params``; keys other than
        ``conversation_id``, ``bot_id``, ``user`` and ``streaming`` are ignored.

        Raises:
            ValueError: if the response status code is not 200.
        """
        parameters = {**self._default_params, **kwargs}

        query = ""
        chat_history = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                query = f"{msg.content}"  # overwrite, to get last user message as query
            # NOTE(review): the message used as ``query`` is also appended to
            # ``chat_history`` — confirm the API does not double-count it.
            chat_history.append(_convert_message_to_dict(msg))

        conversation_id = parameters.pop("conversation_id")
        bot_id = parameters.pop("bot_id")
        user = parameters.pop("user")
        streaming = parameters.pop("streaming")

        payload = {
            "conversation_id": conversation_id,
            "bot_id": bot_id,
            "user": user,
            "query": query,
            "stream": streaming,
        }
        if chat_history:
            payload["chat_history"] = chat_history

        # Strip a trailing slash so a base like "https://host/" does not yield
        # "https://host//open_api/...".
        url = self.coze_api_base.rstrip("/") + "/open_api/v2/chat"
        api_key = self.coze_api_key.get_secret_value() if self.coze_api_key else ""

        res = requests.post(
            url=url,
            timeout=self.request_timeout,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            json=payload,
            stream=streaming,
        )
        if res.status_code != 200:
            logid = res.headers.get("X-Tt-Logid")
            # Include the body: ``{res}`` alone only renders "<Response [N]>".
            raise ValueError(
                f"Error from Coze api response: {res}, logid: {logid}, "
                f"body: {res.text}"
            )
        return res

    def _create_chat_result(self, messages: List[Mapping[str, Any]]) -> ChatResult:
        """Wrap the ``answer`` entries of a Coze response in a ``ChatResult``.

        Entries for which ``_convert_dict_to_message`` returns ``None`` are
        dropped. ``llm_output`` carries empty ``token_usage``/``model`` values.
        """
        generations = []
        for c in messages:
            msg = _convert_dict_to_message(c)
            if msg:
                generations.append(ChatGeneration(message=msg))

        llm_output = {"token_usage": "", "model": ""}
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        """Identifier of this chat model type."""
        return "coze-chat"
|