mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Add ChatGLM for llm and chat_model by using ChatGLM API (#9797)
using sample: ``` endpoint_url = API URL ChatGLM_llm = ChatGLM( endpoint_url=endpoint_url, api_key=Your API Key by ChatGLM ) print(ChatGLM_llm("hello")) ``` ``` model = ChatChatGLM( chatglm_api_key="api_key", chatglm_api_base="api_base_url", model_name="model_name" ) chain = LLMChain(llm=model) ``` Description: The call of ChatGLM has been adapted. Issue: The call of ChatGLM has been adapted. Dependencies: Need python package `zhipuai` and `aiostream` Tag maintainer: @baskaryan Twitter handle: None I remove the compatibility test for pydantic version 2, because pydantic v2 can't not pickle classmethod,but BaseModel use @root_validator is a classmethod decorator. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
d60145229b
commit
d6df288380
@ -22,6 +22,7 @@ from langchain.chat_models.anyscale import ChatAnyscale
|
|||||||
from langchain.chat_models.azure_openai import AzureChatOpenAI
|
from langchain.chat_models.azure_openai import AzureChatOpenAI
|
||||||
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
|
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
|
||||||
from langchain.chat_models.bedrock import BedrockChat
|
from langchain.chat_models.bedrock import BedrockChat
|
||||||
|
from langchain.chat_models.chatchatglm import ChatChatGLM
|
||||||
from langchain.chat_models.ernie import ErnieBotChat
|
from langchain.chat_models.ernie import ErnieBotChat
|
||||||
from langchain.chat_models.fake import FakeListChatModel
|
from langchain.chat_models.fake import FakeListChatModel
|
||||||
from langchain.chat_models.google_palm import ChatGooglePalm
|
from langchain.chat_models.google_palm import ChatGooglePalm
|
||||||
@ -51,6 +52,7 @@ __all__ = [
|
|||||||
"ChatAnyscale",
|
"ChatAnyscale",
|
||||||
"ChatLiteLLM",
|
"ChatLiteLLM",
|
||||||
"ErnieBotChat",
|
"ErnieBotChat",
|
||||||
|
"ChatChatGLM",
|
||||||
"ChatKonko",
|
"ChatKonko",
|
||||||
"QianfanChatEndpoint",
|
"QianfanChatEndpoint",
|
||||||
]
|
]
|
||||||
|
595
libs/langchain/langchain/chat_models/chatchatglm.py
Normal file
595
libs/langchain/langchain/chat_models/chatchatglm.py
Normal file
@ -0,0 +1,595 @@
|
|||||||
|
"""ChatGLM chat wrapper."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.llms.base import create_base_retry_decorator
|
||||||
|
from langchain.pydantic_v1 import Field, root_validator
|
||||||
|
from langchain.schema import ChatGeneration, ChatResult
|
||||||
|
from langchain.schema.messages import (
|
||||||
|
AIMessage,
|
||||||
|
AIMessageChunk,
|
||||||
|
BaseMessage,
|
||||||
|
BaseMessageChunk,
|
||||||
|
ChatMessage,
|
||||||
|
ChatMessageChunk,
|
||||||
|
FunctionMessage,
|
||||||
|
FunctionMessageChunk,
|
||||||
|
HumanMessage,
|
||||||
|
HumanMessageChunk,
|
||||||
|
SystemMessage,
|
||||||
|
SystemMessageChunk,
|
||||||
|
)
|
||||||
|
from langchain.schema.output import ChatGenerationChunk
|
||||||
|
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _import_tiktoken() -> Any:
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import tiktoken python package. "
|
||||||
|
"This is needed in order to calculate get_token_ids. "
|
||||||
|
"Please install it with `pip install tiktoken`."
|
||||||
|
)
|
||||||
|
return tiktoken
|
||||||
|
|
||||||
|
|
||||||
|
def _create_retry_decorator(
|
||||||
|
llm: ChatChatGLM,
|
||||||
|
run_manager: Optional[
|
||||||
|
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||||
|
] = None,
|
||||||
|
) -> Callable[[Any], Any]:
|
||||||
|
errors = [
|
||||||
|
BaseException,
|
||||||
|
]
|
||||||
|
return create_base_retry_decorator(
|
||||||
|
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def acompletion_with_retry(
|
||||||
|
llm: ChatChatGLM,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""Use tenacity to retry the async completion call."""
|
||||||
|
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||||
|
# Use ChatGLM's async api
|
||||||
|
# https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_pro/invoke
|
||||||
|
m_kwargs = copy.deepcopy(kwargs)
|
||||||
|
m_kwargs["prompt"] = kwargs["messages"]
|
||||||
|
if len(m_kwargs["prompt"]) // 2 == 0:
|
||||||
|
raise ValueError("The length of the Prompt must be an odd number.")
|
||||||
|
if m_kwargs.get("streaming") or m_kwargs.get("stream"):
|
||||||
|
try:
|
||||||
|
from aiostream.stream import list as alist
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Streaming with ChatChatGLMrequires optional dependency aiostream. "
|
||||||
|
"To install please run `pip install aiostream`."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
async def async_gen(**m_kwargs: Any) -> Any:
|
||||||
|
for event in llm.client.sse_invoke(**m_kwargs).events():
|
||||||
|
yield event.data
|
||||||
|
|
||||||
|
return alist(async_gen(**m_kwargs))
|
||||||
|
else:
|
||||||
|
return llm.client.invoke(**m_kwargs)
|
||||||
|
|
||||||
|
return await _completion_with_retry(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_delta_to_message_chunk(
|
||||||
|
_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
||||||
|
) -> BaseMessageChunk:
|
||||||
|
role = _dict.get("role")
|
||||||
|
content = _dict.get("content") or ""
|
||||||
|
if _dict.get("function_call"):
|
||||||
|
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||||
|
else:
|
||||||
|
additional_kwargs = {}
|
||||||
|
|
||||||
|
if role == "user" or default_class == HumanMessageChunk:
|
||||||
|
return HumanMessageChunk(content=content)
|
||||||
|
elif role == "assistant" or default_class == AIMessageChunk:
|
||||||
|
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||||
|
elif role == "system" or default_class == SystemMessageChunk:
|
||||||
|
return SystemMessageChunk(content=content)
|
||||||
|
elif role == "function" or default_class == FunctionMessageChunk:
|
||||||
|
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||||
|
elif role or default_class == ChatMessageChunk:
|
||||||
|
return ChatMessageChunk(content=content, role=role)
|
||||||
|
else:
|
||||||
|
return default_class(content=content)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||||
|
role = _dict["role"]
|
||||||
|
if role == "user":
|
||||||
|
return HumanMessage(content=_dict["content"])
|
||||||
|
elif role == "assistant":
|
||||||
|
# Fix for azure
|
||||||
|
# Also ChatGLM returns None for tool invocations
|
||||||
|
content = _dict.get("content", "") or ""
|
||||||
|
if _dict.get("function_call"):
|
||||||
|
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||||
|
else:
|
||||||
|
additional_kwargs = {}
|
||||||
|
return AIMessage(content=eval(content), additional_kwargs=additional_kwargs)
|
||||||
|
elif role == "system":
|
||||||
|
return SystemMessage(content=_dict["content"])
|
||||||
|
elif role == "function":
|
||||||
|
return FunctionMessage(content=_dict["content"], name=_dict["name"])
|
||||||
|
else:
|
||||||
|
return ChatMessage(content=_dict["content"], role=role)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_chatglm_messages(messages: List[dict]) -> List[BaseMessage]:
|
||||||
|
"""Convert dictionaries representing ChatGLM messages to LangChain format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of dictionaries representing ChatGLM messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LangChain BaseMessage objects.
|
||||||
|
"""
|
||||||
|
return [_convert_dict_to_message(m) for m in messages]
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||||
|
if isinstance(message, ChatMessage):
|
||||||
|
message_dict = {"role": message.role, "content": message.content}
|
||||||
|
elif isinstance(message, HumanMessage):
|
||||||
|
message_dict = {"role": "user", "content": message.content}
|
||||||
|
elif isinstance(message, AIMessage):
|
||||||
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
|
if "function_call" in message.additional_kwargs:
|
||||||
|
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||||
|
elif isinstance(message, SystemMessage):
|
||||||
|
message_dict = {"role": "system", "content": message.content}
|
||||||
|
elif isinstance(message, FunctionMessage):
|
||||||
|
message_dict = {
|
||||||
|
"role": "function",
|
||||||
|
"content": message.content,
|
||||||
|
"name": message.name,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
if "name" in message.additional_kwargs:
|
||||||
|
message_dict["name"] = message.additional_kwargs["name"]
|
||||||
|
return message_dict
|
||||||
|
|
||||||
|
|
||||||
|
class ChatChatGLM(BaseChatModel):
|
||||||
|
"""Wrapper around ChatGLM Chat large language models.
|
||||||
|
|
||||||
|
To use, you should have the ``zhipuai`` python package installed, and the
|
||||||
|
environment variable ``CHATGLM_API_KEY`` set with your API key.
|
||||||
|
|
||||||
|
Any parameters that are valid to be passed to the chatglm.create call can be passed
|
||||||
|
in, even if not explicitly saved on this class.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.chat_models import ChatChatGLM
|
||||||
|
chatglm = ChatChatGLM(model_name="gpt-3.5-turbo")
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
|
return {"chatglm_api_key": "CHATGLM_API_KEY"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
client: Any = None #: :meta private:
|
||||||
|
model_name: str = Field(default="chatglm_pro", alias="model")
|
||||||
|
"""Model name to use."""
|
||||||
|
temperature: float = 0.7
|
||||||
|
"""What sampling temperature to use."""
|
||||||
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||||
|
chatglm_api_key: Optional[str] = None
|
||||||
|
"""Base URL path for API requests,
|
||||||
|
leave blank if not using a proxy or service emulator."""
|
||||||
|
chatglm_api_base: Optional[str] = None
|
||||||
|
chatglm_organization: Optional[str] = None
|
||||||
|
# to support explicit proxy for ChatGLM
|
||||||
|
chatglm_proxy: Optional[str] = None
|
||||||
|
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
||||||
|
"""Timeout for requests to ChatGLM completion API. Default is 600 seconds."""
|
||||||
|
max_retries: int = 6
|
||||||
|
"""Maximum number of retries to make when generating."""
|
||||||
|
streaming: bool = False
|
||||||
|
"""Whether to stream the results or not."""
|
||||||
|
n: int = 1
|
||||||
|
"""Number of chat completions to generate for each prompt."""
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
"""Maximum number of tokens to generate."""
|
||||||
|
tiktoken_model_name: Optional[str] = None
|
||||||
|
"""The model name to pass to tiktoken when using this class.
|
||||||
|
Tiktoken is used to count the number of tokens in documents to constrain
|
||||||
|
them to be under a certain limit. By default, when set to None, this will
|
||||||
|
be the same as the embedding model name. However, there are some cases
|
||||||
|
where you may want to use this Embedding class with a model name not
|
||||||
|
supported by tiktoken. This can include when using Azure embeddings or
|
||||||
|
when using one of the many model providers that expose an ChatGLM-like
|
||||||
|
API but with different models. In those cases, in order to avoid erroring
|
||||||
|
when tiktoken is called, you can specify a model name to use here."""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Build extra kwargs from additional params that were passed in."""
|
||||||
|
all_required_field_names = get_pydantic_field_names(cls)
|
||||||
|
extra = values.get("model_kwargs", {})
|
||||||
|
for field_name in list(values):
|
||||||
|
if field_name in extra:
|
||||||
|
raise ValueError(f"Found {field_name} supplied twice.")
|
||||||
|
if field_name not in all_required_field_names:
|
||||||
|
logger.warning(
|
||||||
|
f"""WARNING! {field_name} is not default parameter.
|
||||||
|
{field_name} was transferred to model_kwargs.
|
||||||
|
Please confirm that {field_name} is what you intended."""
|
||||||
|
)
|
||||||
|
extra[field_name] = values.pop(field_name)
|
||||||
|
|
||||||
|
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||||
|
if invalid_model_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||||
|
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||||
|
)
|
||||||
|
|
||||||
|
values["model_kwargs"] = extra
|
||||||
|
return values
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate that api key and python package exists in environment."""
|
||||||
|
values["chatglm_api_key"] = get_from_dict_or_env(
|
||||||
|
values, "chatglm_api_key", "CHATGLM_API_KEY"
|
||||||
|
)
|
||||||
|
values["chatglm_organization"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"chatglm_organization",
|
||||||
|
"CHATGLM_ORGANIZATION",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
values["chatglm_api_base"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"chatglm_api_base",
|
||||||
|
"CHATGLM_API_BASE",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
values["chatglm_proxy"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"chatglm_proxy",
|
||||||
|
"CHATGLM_PROXY",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
import zhipuai
|
||||||
|
|
||||||
|
zhipuai.api_key = values["chatglm_api_key"]
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import zhipuai python package. "
|
||||||
|
"Please install it with `pip install zhipuai`."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
values["client"] = zhipuai.model_api
|
||||||
|
except AttributeError:
|
||||||
|
raise ValueError(
|
||||||
|
"`zhipuai` has no `model_api` attribute, this is likely "
|
||||||
|
"due to an old version of the zhipuai package. Try upgrading it "
|
||||||
|
"with `pip install --upgrade zhipuai`."
|
||||||
|
)
|
||||||
|
if values["n"] < 1:
|
||||||
|
raise ValueError("n must be at least 1.")
|
||||||
|
if values["n"] > 1 and values["streaming"]:
|
||||||
|
raise ValueError("n must be 1 when streaming.")
|
||||||
|
return values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the default parameters for calling ChatGLM API."""
|
||||||
|
return {
|
||||||
|
"model": self.model_name,
|
||||||
|
"request_timeout": self.request_timeout,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"stream": self.streaming,
|
||||||
|
"n": self.n,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
**self.model_kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
def completion_with_retry(
|
||||||
|
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Use tenacity to retry the completion call."""
|
||||||
|
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||||
|
m_kwargs = copy.deepcopy(kwargs)
|
||||||
|
m_kwargs["prompt"] = kwargs["messages"]
|
||||||
|
if len(m_kwargs["prompt"]) // 2 == 0:
|
||||||
|
raise ValueError("The length of the Prompt must be an odd number.")
|
||||||
|
if m_kwargs.get("streaming") or m_kwargs.get("stream"):
|
||||||
|
return self.client.sse_invoke(**m_kwargs)
|
||||||
|
else:
|
||||||
|
return self.client.invoke(**m_kwargs)
|
||||||
|
|
||||||
|
return _completion_with_retry(**kwargs)
|
||||||
|
|
||||||
|
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||||
|
overall_token_usage: dict = {}
|
||||||
|
for output in llm_outputs:
|
||||||
|
if output is None:
|
||||||
|
# Happens in streaming
|
||||||
|
continue
|
||||||
|
token_usage = output["token_usage"]
|
||||||
|
for k, v in token_usage.items():
|
||||||
|
if k in overall_token_usage:
|
||||||
|
overall_token_usage[k] += v
|
||||||
|
else:
|
||||||
|
overall_token_usage[k] = v
|
||||||
|
return {"token_usage": overall_token_usage, "model_name": self.model_name}
|
||||||
|
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
|
params = {**params, **kwargs, "stream": True}
|
||||||
|
|
||||||
|
default_chunk_class = AIMessageChunk
|
||||||
|
for event in self.completion_with_retry(
|
||||||
|
messages=message_dicts, run_manager=run_manager, **params
|
||||||
|
).events():
|
||||||
|
delta = {"role": "assistant", "content": event.data}
|
||||||
|
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||||
|
yield ChatGenerationChunk(message=chunk)
|
||||||
|
if run_manager:
|
||||||
|
run_manager.on_llm_new_token(chunk.content)
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
if stream if stream is not None else self.streaming:
|
||||||
|
generation: Optional[ChatGenerationChunk] = None
|
||||||
|
for chunk in self._stream(
|
||||||
|
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
|
):
|
||||||
|
if generation is None:
|
||||||
|
generation = chunk
|
||||||
|
else:
|
||||||
|
generation += chunk
|
||||||
|
assert generation is not None
|
||||||
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
|
params = {**params, **kwargs}
|
||||||
|
response = self.completion_with_retry(
|
||||||
|
messages=message_dicts, run_manager=run_manager, **params
|
||||||
|
)
|
||||||
|
return self._create_chat_result(response)
|
||||||
|
|
||||||
|
def _create_message_dicts(
|
||||||
|
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||||
|
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||||
|
params = self._client_params
|
||||||
|
if stop is not None:
|
||||||
|
if "stop" in params:
|
||||||
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
|
params["stop"] = stop
|
||||||
|
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||||
|
return message_dicts, params
|
||||||
|
|
||||||
|
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||||
|
generations = []
|
||||||
|
for res in response["data"]["choices"]:
|
||||||
|
message = _convert_dict_to_message(res)
|
||||||
|
gen = ChatGeneration(
|
||||||
|
message=message,
|
||||||
|
generation_info=dict(finish_reason=res.get("finish_reason")),
|
||||||
|
)
|
||||||
|
generations.append(gen)
|
||||||
|
token_usage = response.get("usage", {})
|
||||||
|
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
|
||||||
|
return ChatResult(generations=generations, llm_output=llm_output)
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
|
params = {**params, **kwargs, "stream": True}
|
||||||
|
|
||||||
|
default_chunk_class = AIMessageChunk
|
||||||
|
async for event in await acompletion_with_retry(
|
||||||
|
self, messages=message_dicts, run_manager=run_manager, **params
|
||||||
|
):
|
||||||
|
if len(event):
|
||||||
|
delta = {"role": "assistant", "content": event[-1]}
|
||||||
|
else:
|
||||||
|
delta = {"role": "assistant", "content": ""}
|
||||||
|
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||||
|
yield ChatGenerationChunk(message=chunk)
|
||||||
|
if run_manager:
|
||||||
|
await run_manager.on_llm_new_token(chunk.content)
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
if stream if stream is not None else self.streaming:
|
||||||
|
generation: Optional[ChatGenerationChunk] = None
|
||||||
|
async for chunk in self._astream(
|
||||||
|
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
|
):
|
||||||
|
if generation is None:
|
||||||
|
generation = chunk
|
||||||
|
else:
|
||||||
|
generation += chunk
|
||||||
|
assert generation is not None
|
||||||
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
|
params = {**params, **kwargs}
|
||||||
|
response = await acompletion_with_retry(
|
||||||
|
self, messages=message_dicts, run_manager=run_manager, **params
|
||||||
|
)
|
||||||
|
return self._create_chat_result(response)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {**{"model_name": self.model_name}, **self._default_params}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _client_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the parameters used for the chatglm client."""
|
||||||
|
chatglm_creds: Dict[str, Any] = {
|
||||||
|
"api_key": self.chatglm_api_key,
|
||||||
|
"api_base": self.chatglm_api_base,
|
||||||
|
"organization": self.chatglm_organization,
|
||||||
|
"model": self.model_name,
|
||||||
|
}
|
||||||
|
if self.chatglm_proxy:
|
||||||
|
import zhipuai
|
||||||
|
|
||||||
|
zhipuai.api_key = self.chatglm_api_key
|
||||||
|
# zhipuai.proxy = {"http": self.chatglm_proxy, "https": self.chatglm_proxy}
|
||||||
|
# type: ignore[assignment] # noqa: E501
|
||||||
|
return {**self._default_params, **chatglm_creds}
|
||||||
|
|
||||||
|
def _get_invocation_params(
|
||||||
|
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Get the parameters used to invoke the model."""
|
||||||
|
return {
|
||||||
|
"model": self.model_name,
|
||||||
|
**super()._get_invocation_params(stop=stop),
|
||||||
|
**self._default_params,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of chat model."""
|
||||||
|
return "chatglm-chat"
|
||||||
|
|
||||||
|
def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
|
||||||
|
tiktoken_ = _import_tiktoken()
|
||||||
|
if self.tiktoken_model_name is not None:
|
||||||
|
model = self.tiktoken_model_name
|
||||||
|
else:
|
||||||
|
model = self.model_name
|
||||||
|
# Returns the number of tokens used by a list of messages.
|
||||||
|
try:
|
||||||
|
encoding = tiktoken_.encoding_for_model(model)
|
||||||
|
except KeyError:
|
||||||
|
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||||
|
model = "cl100k_base"
|
||||||
|
encoding = tiktoken_.get_encoding(model)
|
||||||
|
return model, encoding
|
||||||
|
|
||||||
|
def get_token_ids(self, text: str) -> List[int]:
|
||||||
|
"""Get the tokens present in the text with tiktoken package."""
|
||||||
|
# tiktoken NOT supported for Python 3.7 or below
|
||||||
|
if sys.version_info[1] <= 7:
|
||||||
|
return super().get_token_ids(text)
|
||||||
|
_, encoding_model = self._get_encoding_model()
|
||||||
|
return encoding_model.encode(text)
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||||
|
|
||||||
|
Official documentation: https://open.bigmodel.cn/dev/api
|
||||||
|
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||||
|
if sys.version_info[1] <= 7:
|
||||||
|
return super().get_num_tokens_from_messages(messages)
|
||||||
|
model, encoding = self._get_encoding_model()
|
||||||
|
if model.startswith("chatglm_pro"):
|
||||||
|
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||||
|
tokens_per_message = 4
|
||||||
|
# if there's a name, the role is omitted
|
||||||
|
tokens_per_name = -1
|
||||||
|
elif model.startswith("chatglm_std") or model.startswith("chatglm_lite"):
|
||||||
|
tokens_per_message = 3
|
||||||
|
tokens_per_name = 1
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"get_num_tokens_from_messages() is not presently implemented "
|
||||||
|
f"for model {model}."
|
||||||
|
"See https://open.bigmodel.cn/dev/api for "
|
||||||
|
"information on how messages are converted to tokens."
|
||||||
|
)
|
||||||
|
num_tokens = 0
|
||||||
|
messages_dict = [_convert_message_to_dict(m) for m in messages]
|
||||||
|
for message in messages_dict:
|
||||||
|
num_tokens += tokens_per_message
|
||||||
|
for key, value in message.items():
|
||||||
|
# Cast str(value) in case the message value is not a string
|
||||||
|
# This occurs with function messages
|
||||||
|
num_tokens += len(encoding.encode(str(value)))
|
||||||
|
if key == "name":
|
||||||
|
num_tokens += tokens_per_name
|
||||||
|
# every reply is primed with <im_start>assistant
|
||||||
|
num_tokens += 3
|
||||||
|
return num_tokens
|
@ -26,6 +26,7 @@ class ChatGLM(LLM):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
endpoint_url: str = "http://127.0.0.1:8000/"
|
endpoint_url: str = "http://127.0.0.1:8000/"
|
||||||
|
api_key: str = ""
|
||||||
"""Endpoint URL to use."""
|
"""Endpoint URL to use."""
|
||||||
model_kwargs: Optional[dict] = None
|
model_kwargs: Optional[dict] = None
|
||||||
"""Key word arguments to pass to the model."""
|
"""Key word arguments to pass to the model."""
|
||||||
@ -78,8 +79,21 @@ class ChatGLM(LLM):
|
|||||||
_model_kwargs = self.model_kwargs or {}
|
_model_kwargs = self.model_kwargs or {}
|
||||||
|
|
||||||
# HTTP headers for authorization
|
# HTTP headers for authorization
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Content-Type": "application/json; charset=UTF-8",
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
from zhipuai.utils import jwt_token
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception("Must install zhipuai, use`pip install zhipuai`", e)
|
||||||
|
if not self.api_key:
|
||||||
|
raise Exception(
|
||||||
|
"api_key not provided, you could provide it with "
|
||||||
|
"`shell: export API_KEY=xxx` or `code: zhipuai.api_key=xxx`"
|
||||||
|
)
|
||||||
|
jwt_api_key_ = jwt_token.generate_token(self.api_key)
|
||||||
|
headers.update({"Authorization": jwt_api_key_})
|
||||||
payload = {
|
payload = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
@ -105,12 +119,11 @@ class ChatGLM(LLM):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
parsed_response = response.json()
|
parsed_response = response.json()
|
||||||
|
|
||||||
# Check if response content does exists
|
# Check if response content does exists
|
||||||
if isinstance(parsed_response, dict):
|
if isinstance(parsed_response, dict):
|
||||||
content_keys = "response"
|
content_keys = "data"
|
||||||
if content_keys in parsed_response:
|
if content_keys in parsed_response:
|
||||||
text = parsed_response[content_keys]
|
text = eval(parsed_response[content_keys]["choices"][0]["content"])
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"No content in response : {parsed_response}")
|
raise ValueError(f"No content in response : {parsed_response}")
|
||||||
else:
|
else:
|
||||||
@ -125,5 +138,6 @@ class ChatGLM(LLM):
|
|||||||
if stop is not None:
|
if stop is not None:
|
||||||
text = enforce_stop_tokens(text, stop)
|
text = enforce_stop_tokens(text, stop)
|
||||||
if self.with_history:
|
if self.with_history:
|
||||||
self.history = self.history + [[None, parsed_response["response"]]]
|
self.history = self.history + [[None, parsed_response["data"]["choices"]]]
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
Loading…
Reference in New Issue
Block a user