mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
9678797625
- Description: call on_llm_new_token before yielding each chunk in _stream/_astream for some chat models, so that all chat models behave consistently. - Issue: N/A - Dependencies: N/A
337 lines
11 KiB
Python
337 lines
11 KiB
Python
"""ZHIPU AI chat models wrapper."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from functools import partial
|
|
from typing import Any, Dict, Iterator, List, Optional, cast
|
|
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
from langchain_core.language_models.chat_models import (
|
|
BaseChatModel,
|
|
generate_from_stream,
|
|
)
|
|
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ref(BaseModel):
    """Reference settings used in CharacterGLM.

    Groups the switch that allows a request to reference external information
    together with the query used for that lookup.
    """

    # Whether external references are enabled; on unless explicitly disabled.
    enable: bool = Field(default=True)
    # Query string used for the external lookup; empty by default.
    search_query: str = Field(default="")
|
|
|
|
|
|
class meta(BaseModel):
    """Role-play metadata sent with CharacterGLM requests.

    Describes both participants of the conversation: the human user and the
    bot character.
    """

    # Free-form description of the user; empty by default.
    user_info: str = Field(default="")
    # Free-form description of the bot character; empty by default.
    bot_info: str = Field(default="")
    # Display name of the bot character; empty by default.
    bot_name: str = Field(default="")
    # Display name of the user; defaults to "User".
    user_name: str = Field(default="User")
|
|
|
|
|
|
class ChatZhipuAI(BaseChatModel):
    """
    `ZHIPU AI` large language chat models API.

    To use, you should have the ``zhipuai`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatZhipuAI

            zhipuai_chat = ChatZhipuAI(
                temperature=0.5,
                api_key="your-api-key",
                model="chatglm_turbo",
            )

    """

    zhipuai: Any
    zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
    """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""

    model: str = Field("chatglm_turbo")
    """
    Model name to use.
    -chatglm_turbo:
        Completes a variety of language tasks from natural-language
        instructions; the SSE or asynchronous call interfaces are recommended.
    -characterglm:
        It supports human-based role-playing, ultra-long multi-round memory,
        and thousands of character dialogues. It is widely used in anthropomorphic
        dialogues or game scenes such as emotional accompaniments, game intelligent
        NPCs, Internet celebrities/stars/movie and TV series IP clones, digital
        people/virtual anchors, and text adventure games.
    """

    temperature: float = Field(0.95)
    """
    What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
    be equal to 0.
    The larger the value, the more random and creative the output; the smaller
    the value, the more stable or certain the output will be.
    You are advised to adjust top_p or temperature parameters based on application
    scenarios, but do not adjust the two parameters at the same time.
    """

    top_p: float = Field(0.7)
    """
    An alternative to sampling with temperature, called nucleus sampling. The
    value ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
    The model only considers the tokens that make up the top_p probability mass.
    For example, 0.1 means that the model decoder only considers tokens from the
    top 10% probability of the candidate set.
    You are advised to adjust top_p or temperature parameters based on application
    scenarios, but do not adjust the two parameters at the same time.
    """

    request_id: Optional[str] = Field(None)
    """
    Unique identifier used to distinguish each request. If the client supplies
    one it must be unique; otherwise the platform generates one by default.
    """

    streaming: bool = Field(False)
    """Whether to stream the results or not."""

    incremental: bool = Field(True)
    """
    When invoked through the SSE interface, controls whether each event carries
    only the newly generated content (incremental) or the full content so far.
    If this parameter is not provided, content is returned incrementally by
    default.
    """

    return_type: str = Field("json_string")
    """
    This parameter is used to control the type of content returned each time.
    - json_string Returns a standard JSON string.
    - text Returns the original text content.
    """

    ref: Optional[ref] = Field(None)
    """
    Controls whether external information may be referenced during the request.
    If this field is empty or absent, referencing is enabled by default with a
    parameter format like:
    {"enable": true, "search_query": "history"}
    NOTE(review): this field is not currently forwarded by invoke/sse_invoke.
    """

    meta: Optional[meta] = Field(None)
    """Role-play metadata; required when ``model`` is ``"characterglm"``."""

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {"model_name": self.model}

    @property
    def _llm_type(self) -> str:
        """Return the type of chat model."""
        return "zhipuai"

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"zhipuai_api_key": "ZHIPUAI_API_KEY"}

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "zhipuai"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}

        if self.model:
            attributes["model"] = self.model

        if self.streaming:
            attributes["streaming"] = self.streaming

        if self.return_type:
            attributes["return_type"] = self.return_type

        return attributes

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Validate fields, import the ``zhipuai`` SDK and configure its API key.

        Raises:
            RuntimeError: If the ``zhipuai`` package cannot be imported.
        """
        import os

        super().__init__(*args, **kwargs)
        try:
            import zhipuai
        except ImportError as e:
            raise RuntimeError(
                "Could not import zhipuai package. "
                "Please install it via 'pip install zhipuai'"
            ) from e

        if self.zhipuai_api_key is None:
            # Honor the documented env-var fallback instead of writing None.
            self.zhipuai_api_key = os.environ.get("ZHIPUAI_API_KEY")
        self.zhipuai = zhipuai
        # NOTE: `api_key` is an attribute of the imported zhipuai module itself,
        # so this setting is shared by every ChatZhipuAI instance in the process.
        self.zhipuai.api_key = self.zhipuai_api_key

    @staticmethod
    def _convert_messages_to_prompt(messages: List[Any]) -> List[Dict[str, Any]]:
        """Convert LangChain messages to the ZhipuAI ``prompt`` format.

        ``AIMessage`` maps to the ``"assistant"`` role; every other message
        (human, system, ...) maps to ``"user"``. Items that are already dicts
        are passed through unchanged, so a pre-built prompt is accepted too.
        """
        prompt: List[Dict[str, Any]] = []
        for message in messages:
            if isinstance(message, dict):
                prompt.append(message)
                continue
            role = "assistant" if isinstance(message, AIMessage) else "user"
            prompt.append({"role": role, "content": message.content})
        return prompt

    def _unsupported_model_error(self) -> ValueError:
        """Build the error raised when ``self.model`` has no request mapping."""
        return ValueError(
            f"Unsupported ZhipuAI model {self.model!r}; "
            "expected 'chatglm_turbo' or 'characterglm'."
        )

    def _meta_dict(self) -> Dict[str, Any]:
        """Return ``self.meta`` as a plain dict for CharacterGLM requests.

        Raises:
            ValueError: If ``meta`` is not set (previously this surfaced as an
                opaque ``AttributeError`` on ``None``).
        """
        if self.meta is None:
            raise ValueError("The 'characterglm' model requires `meta` to be set.")
        return cast(meta, self.meta).dict()

    def invoke(self, prompt: Any) -> Any:  # type: ignore[override]
        """Call the synchronous ZhipuAI model API with ``prompt``.

        Returns:
            The raw SDK response, or ``None`` if ``self.model`` is neither
            ``"chatglm_turbo"`` nor ``"characterglm"``.
        """
        if self.model == "chatglm_turbo":
            return self.zhipuai.model_api.invoke(
                model=self.model,
                prompt=prompt,
                top_p=self.top_p,
                temperature=self.temperature,
                request_id=self.request_id,
                return_type=self.return_type,
            )
        elif self.model == "characterglm":
            return self.zhipuai.model_api.invoke(
                model=self.model,
                meta=self._meta_dict(),
                prompt=prompt,
                request_id=self.request_id,
                return_type=self.return_type,
            )
        return None

    def sse_invoke(self, prompt: Any) -> Any:
        """Call the streaming (SSE) ZhipuAI model API with ``prompt``.

        Returns:
            The raw SDK response (iterated via ``.events()``), or ``None`` if
            ``self.model`` is neither ``"chatglm_turbo"`` nor ``"characterglm"``.
        """
        if self.model == "chatglm_turbo":
            return self.zhipuai.model_api.sse_invoke(
                model=self.model,
                prompt=prompt,
                top_p=self.top_p,
                temperature=self.temperature,
                request_id=self.request_id,
                return_type=self.return_type,
                incremental=self.incremental,
            )
        elif self.model == "characterglm":
            return self.zhipuai.model_api.sse_invoke(
                model=self.model,
                prompt=prompt,
                meta=self._meta_dict(),
                request_id=self.request_id,
                return_type=self.return_type,
                incremental=self.incremental,
            )
        return None

    async def async_invoke(self, prompt: Any) -> Any:
        """Submit ``prompt`` to the asynchronous ZhipuAI API in a worker thread.

        NOTE(review): only ``model`` and ``prompt`` are sent; sampling settings
        and ``meta`` are not forwarded on this path.
        """
        loop = asyncio.get_running_loop()
        partial_func = partial(
            self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt
        )
        response = await loop.run_in_executor(
            None,
            partial_func,
        )
        return response

    async def async_invoke_result(self, task_id: Any) -> Any:
        """Query the result of an asynchronous task in a worker thread."""
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            self.zhipuai.model_api.query_async_invoke_result,
            task_id,
        )
        return response

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat response.

        Streams through :meth:`_stream` when ``stream`` is true (or, when
        ``stream`` is ``None``, when ``self.streaming`` is true); otherwise makes
        one blocking :meth:`invoke` call.

        Raises:
            ValueError: If ``self.model`` is not supported.
            RuntimeError: If the API response code is not 200.
        """
        prompt = self._convert_messages_to_prompt(messages)

        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                prompt=prompt,
                stop=stop,
                run_manager=run_manager,
                **kwargs,
            )
            return generate_from_stream(stream_iter)

        response = self.invoke(prompt)
        if response is None:
            raise self._unsupported_model_error()
        if response["code"] != 200:
            raise RuntimeError(response)

        content = response["data"]["choices"][0]["content"]
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=content))]
        )

    async def _agenerate(  # type: ignore[override]
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = False,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate a chat response.

        Submits the prompt with :meth:`async_invoke`, then polls
        :meth:`async_invoke_result` once per second until the task reports
        ``"SUCCESS"``. ``stop``, ``run_manager`` and ``stream`` are accepted for
        interface compatibility but are not used.

        Raises:
            RuntimeError: If an API call returns a code other than 200, or the
                task reports ``"FAIL"``.
        """
        prompt = self._convert_messages_to_prompt(messages)

        invoke_response = await self.async_invoke(prompt)
        if invoke_response["code"] != 200:
            raise RuntimeError(invoke_response)
        task_id = invoke_response["data"]["task_id"]

        response = await self.async_invoke_result(task_id)
        while True:
            if response["code"] != 200:
                raise RuntimeError(response)
            task_status = response["data"]["task_status"]
            if task_status == "SUCCESS":
                break
            # Without this check a failed task would be polled forever.
            if task_status == "FAIL":
                raise RuntimeError(response)
            await asyncio.sleep(1)
            response = await self.async_invoke_result(task_id)

        content = response["data"]["choices"][0]["content"]
        # NOTE(review): unlike the sync path, the async path JSON-decodes the
        # content; kept as-is to preserve existing outputs.
        content = json.loads(content)
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=content))]
        )

    def _stream(  # type: ignore[override]
        self,
        prompt: List[Any],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the chat response in chunks.

        Args:
            prompt: ZhipuAI prompt dicts (as built by ``_generate``) or LangChain
                ``BaseMessage`` objects. LangChain's ``BaseChatModel.stream``
                passes the message list as the first positional argument, so
                both forms are normalized here.
            stop: Accepted for interface compatibility; not sent to the API.
            run_manager: Notified of each token before its chunk is yielded.

        Raises:
            ValueError: If the model is unsupported, or the stream reports an
                ``error`` or ``interrupted`` event.

        NOTE(review): with ``incremental=False`` each event presumably carries
        the full text so far, so concatenated chunks would repeat content.
        """
        response = self.sse_invoke(self._convert_messages_to_prompt(prompt))
        if response is None:
            raise self._unsupported_model_error()

        for r in response.events():
            if r.event == "add":
                delta = r.data
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    run_manager.on_llm_new_token(delta, chunk=chunk)
                yield chunk
            # An interrupted stream is surfaced rather than silently truncated.
            elif r.event in ("error", "interrupted"):
                raise ValueError(f"Error from ZhipuAI API response: {r.data}")
|