From 019b6ebe8d6d97227ec48af3d1fe7220b22f5b0e Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Wed, 24 Jan 2024 11:23:46 +0800
Subject: [PATCH] community[minor]: Add iFlyTek Spark LLM chat model support (#13389)

- **Description:** This PR enables LangChain to access iFlyTek's Spark LLM via the chat_models wrapper.
- **Dependencies:** websocket-client ^1.6.1
- **Tag maintainer:** @baskaryan

### SparkLLM chat model usage

Get SparkLLM's app_id, api_key and api_secret from the [iFlyTek SparkLLM API Console](https://console.xfyun.cn/services/bm3) (for more info, see the [iFlyTek SparkLLM Intro](https://xinghuo.xfyun.cn/sparkapi)), then either set the environment variables `IFLYTEK_SPARK_APP_ID`, `IFLYTEK_SPARK_API_KEY` and `IFLYTEK_SPARK_API_SECRET`, or pass them as constructor parameters, as in the demo below:

```python3
from langchain.chat_models.sparkllm import ChatSparkLLM

client = ChatSparkLLM(
    spark_app_id="", spark_api_key="", spark_api_secret=""
)
```
---
 docs/docs/integrations/chat/sparkllm.ipynb    |  99 ++++
 .../chat_models/__init__.py                   |   2 +
 .../chat_models/sparkllm.py                   | 473 ++++++++++++++++++
 .../chat_models/test_sparkllm.py              |  36 ++
 .../unit_tests/chat_models/test_imports.py    |   1 +
 5 files changed, 611 insertions(+)
 create mode 100644 docs/docs/integrations/chat/sparkllm.ipynb
 create mode 100644 libs/community/langchain_community/chat_models/sparkllm.py
 create mode 100644 libs/community/tests/integration_tests/chat_models/test_sparkllm.py

diff --git a/docs/docs/integrations/chat/sparkllm.ipynb b/docs/docs/integrations/chat/sparkllm.ipynb
new file mode 100644
index 0000000000..4fe68f9a2a
--- /dev/null
+++ b/docs/docs/integrations/chat/sparkllm.ipynb
@@ -0,0 +1,99 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "3ddface67cd10a87",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "# SparkLLM Chat\n",
+    "\n",
+    "SparkLLM chat model API provided by iFlyTek. For more information, see [iFlyTek Open Platform](https://www.xfyun.cn/)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Basic use"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "43daa39972d4c533",
+   "metadata": {
+    "collapsed": false,
+    "is_executing": true
+   },
+   "outputs": [],
+   "source": [
+    "\"\"\"For basic init and call\"\"\"\n",
+    "from langchain.chat_models import ChatSparkLLM\n",
+    "from langchain.schema import HumanMessage\n",
+    "\n",
+    "chat = ChatSparkLLM(\n",
+    "    spark_app_id=\"\", spark_api_key=\"\", spark_api_secret=\"\"\n",
+    ")\n",
+    "message = HumanMessage(content=\"Hello\")\n",
+    "chat([message])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "df755f4c5689510",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "- Get SparkLLM's app_id, api_key and api_secret from the [iFlyTek SparkLLM API Console](https://console.xfyun.cn/services/bm3) (for more info, see the [iFlyTek SparkLLM Intro](https://xinghuo.xfyun.cn/sparkapi)), then set the environment variables `IFLYTEK_SPARK_APP_ID`, `IFLYTEK_SPARK_API_KEY` and `IFLYTEK_SPARK_API_SECRET`, or pass them as parameters when creating `ChatSparkLLM`, as in the demo above."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "984e32ee47bc6772",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "## ChatSparkLLM with streaming"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7dc162bd65fec08f",
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "chat = ChatSparkLLM(streaming=True)\n",
+    "for chunk in chat.stream(\"Hello!\"):\n",
+    "    print(chunk.content, end=\"\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py
index 067a57038b..bc35a3c4ef 100644
--- a/libs/community/langchain_community/chat_models/__init__.py
+++ b/libs/community/langchain_community/chat_models/__init__.py
@@ -48,6 +48,7 @@ from langchain_community.chat_models.ollama import ChatOllama
 from langchain_community.chat_models.openai import ChatOpenAI
 from langchain_community.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
 from langchain_community.chat_models.promptlayer_openai import PromptLayerChatOpenAI
+from langchain_community.chat_models.sparkllm import ChatSparkLLM
 from langchain_community.chat_models.tongyi import ChatTongyi
 from langchain_community.chat_models.vertexai import ChatVertexAI
 from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat
@@ -88,6 +89,7 @@ __all__ = [
     "ChatBaichuan",
     "ChatHunyuan",
     "GigaChat",
+    "ChatSparkLLM",
     "VolcEngineMaasChat",
     "GPTRouter",
     "ChatZhipuAI",
diff --git a/libs/community/langchain_community/chat_models/sparkllm.py b/libs/community/langchain_community/chat_models/sparkllm.py
new file mode 100644
index 0000000000..7e84c2e98c
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/sparkllm.py
@@ -0,0 +1,473 @@
+import base64
+import hashlib
+import hmac
+import json
+import logging
+import queue
+import threading
+from datetime import datetime
+from queue import Queue
+from time import mktime
+from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Type
+from urllib.parse import urlencode, urlparse, urlunparse
+from wsgiref.handlers import format_date_time
+
+from langchain_core.callbacks import (
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+    generate_from_stream,
+)
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    BaseMessageChunk,
+    ChatMessage,
+    ChatMessageChunk,
+    HumanMessage,
+    HumanMessageChunk,
+    SystemMessage,
+)
+from langchain_core.outputs import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    ChatResult,
+)
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import (
+    get_from_dict_or_env,
+    get_pydantic_field_names,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+    if isinstance(message, ChatMessage):
+        message_dict = {"role": "user", "content": message.content}
+    elif isinstance(message, HumanMessage):
+        message_dict = {"role": "user", "content": message.content}
+    elif isinstance(message, AIMessage):
+        message_dict = {"role": "assistant", "content": message.content}
"assistant", "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + + return message_dict + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + msg_role = _dict["role"] + msg_content = _dict["content"] + if msg_role == "user": + return HumanMessage(content=msg_content) + elif msg_role == "assistant": + content = msg_content or "" + return AIMessage(content=content) + elif msg_role == "system": + return SystemMessage(content=msg_content) + else: + return ChatMessage(content=msg_content, role=msg_role) + + +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] +) -> BaseMessageChunk: + msg_role = _dict["role"] + msg_content = _dict.get("content", "") + if msg_role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=msg_content) + elif msg_role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk(content=msg_content) + elif msg_role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=msg_content, role=msg_role) + else: + return default_class(content=msg_content) + + +class ChatSparkLLM(BaseChatModel): + """Wrapper around iFlyTek's Spark large language model. + + To use, you should pass `app_id`, `api_key`, `api_secret` + as a named parameter to the constructor OR set environment + variables ``IFLYTEK_SPARK_APP_ID``, ``IFLYTEK_SPARK_API_KEY`` and + ``IFLYTEK_SPARK_API_SECRET`` + + Example: + .. code-block:: python + + client = ChatSparkLLM( + spark_app_id="", + spark_api_key="", + spark_api_secret="" + ) + """ + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this model can be serialized by Langchain.""" + return False + + @property + def lc_secrets(self) -> Dict[str, str]: + return { + "spark_app_id": "IFLYTEK_SPARK_APP_ID", + "spark_api_key": "IFLYTEK_SPARK_API_KEY", + "spark_api_secret": "IFLYTEK_SPARK_API_SECRET", + "spark_api_url": "IFLYTEK_SPARK_API_URL", + "spark_llm_domain": "IFLYTEK_SPARK_LLM_DOMAIN", + } + + client: Any = None #: :meta private: + spark_app_id: Optional[str] = None + spark_api_key: Optional[str] = None + spark_api_secret: Optional[str] = None + spark_api_url: Optional[str] = None + spark_llm_domain: Optional[str] = None + spark_user_id: str = "lc_user" + streaming: bool = False + request_timeout: int = 30 + temperature: float = 0.5 + top_k: int = 4 + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + + @root_validator(pre=True) + def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name in extra: + raise ValueError(f"Found {field_name} supplied twice.") + if field_name not in all_required_field_names: + logger.warning( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. + Please confirm that {field_name} is what you intended.""" + ) + extra[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + if invalid_model_kwargs: + raise ValueError( + f"Parameters {invalid_model_kwargs} should be specified explicitly. " + f"Instead they were passed in as part of `model_kwargs` parameter." 
+ ) + + values["model_kwargs"] = extra + + return values + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + values["spark_app_id"] = get_from_dict_or_env( + values, + "spark_app_id", + "IFLYTEK_SPARK_APP_ID", + ) + values["spark_api_key"] = get_from_dict_or_env( + values, + "spark_api_key", + "IFLYTEK_SPARK_API_KEY", + ) + values["spark_api_secret"] = get_from_dict_or_env( + values, + "spark_api_secret", + "IFLYTEK_SPARK_API_SECRET", + ) + values["spark_app_url"] = get_from_dict_or_env( + values, + "spark_app_url", + "IFLYTEK_SPARK_APP_URL", + "wss://spark-api.xf-yun.com/v3.1/chat", + ) + values["spark_llm_domain"] = get_from_dict_or_env( + values, + "spark_llm_domain", + "IFLYTEK_SPARK_LLM_DOMAIN", + "generalv3", + ) + # put extra params into model_kwargs + values["model_kwargs"]["temperature"] = values["temperature"] or cls.temperature + values["model_kwargs"]["top_k"] = values["top_k"] or cls.top_k + + values["client"] = _SparkLLMClient( + app_id=values["spark_app_id"], + api_key=values["spark_api_key"], + api_secret=values["spark_api_secret"], + api_url=values["spark_api_url"], + spark_domain=values["spark_llm_domain"], + model_kwargs=values["model_kwargs"], + ) + return values + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + default_chunk_class = AIMessageChunk + + self.client.arun( + [_convert_message_to_dict(m) for m in messages], + self.spark_user_id, + self.model_kwargs, + self.streaming, + ) + for content in self.client.subscribe(timeout=self.request_timeout): + if "data" not in content: + continue + delta = content["data"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + yield ChatGenerationChunk(message=chunk) + if run_manager: + run_manager.on_llm_new_token(str(chunk.content)) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + stream_iter = self._stream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) + + self.client.arun( + [_convert_message_to_dict(m) for m in messages], + self.spark_user_id, + self.model_kwargs, + False, + ) + completion = {} + llm_output = {} + for content in self.client.subscribe(timeout=self.request_timeout): + if "usage" in content: + llm_output["token_usage"] = content["usage"] + if "data" not in content: + continue + completion = content["data"] + message = _convert_dict_to_message(completion) + generations = [ChatGeneration(message=message)] + return ChatResult(generations=generations, llm_output=llm_output) + + @property + def _llm_type(self) -> str: + return "spark-llm-chat" + + +class _SparkLLMClient: + """ + Use websocket-client to call the SparkLLM interface provided by Xfyun, + which is the iFlyTek's open platform for AI capabilities + """ + + def __init__( + self, + app_id: str, + api_key: str, + api_secret: str, + api_url: Optional[str] = None, + spark_domain: Optional[str] = None, + model_kwargs: Optional[dict] = None, + ): + try: + import websocket + + self.websocket_client = websocket + except ImportError: + raise ImportError( + "Could not import websocket client python package. " + "Please install it with `pip install websocket-client`." 
+ ) + + self.api_url = ( + "wss://spark-api.xf-yun.com/v3.1/chat" if not api_url else api_url + ) + self.app_id = app_id + self.ws_url = _SparkLLMClient._create_url( + self.api_url, + api_key, + api_secret, + ) + self.model_kwargs = model_kwargs + self.spark_domain = spark_domain or "generalv3" + self.queue: Queue[Dict] = Queue() + self.blocking_message = {"content": "", "role": "assistant"} + + @staticmethod + def _create_url(api_url: str, api_key: str, api_secret: str) -> str: + """ + Generate a request url with an api key and an api secret. + """ + # generate timestamp by RFC1123 + date = format_date_time(mktime(datetime.now().timetuple())) + + # urlparse + parsed_url = urlparse(api_url) + host = parsed_url.netloc + path = parsed_url.path + + signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1" + + # encrypt using hmac-sha256 + signature_sha = hmac.new( + api_secret.encode("utf-8"), + signature_origin.encode("utf-8"), + digestmod=hashlib.sha256, + ).digest() + + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8") + + authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", \ + headers="host date request-line", signature="{signature_sha_base64}"' + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode( + encoding="utf-8" + ) + + # generate url + params_dict = {"authorization": authorization, "date": date, "host": host} + encoded_params = urlencode(params_dict) + url = urlunparse( + ( + parsed_url.scheme, + parsed_url.netloc, + parsed_url.path, + parsed_url.params, + encoded_params, + parsed_url.fragment, + ) + ) + return url + + def run( + self, + messages: List[Dict], + user_id: str, + model_kwargs: Optional[dict] = None, + streaming: bool = False, + ) -> None: + self.websocket_client.enableTrace(False) + ws = self.websocket_client.WebSocketApp( + self.ws_url, + on_message=self.on_message, + on_error=self.on_error, + on_close=self.on_close, + on_open=self.on_open, + ) + ws.messages = messages + ws.user_id = user_id + ws.model_kwargs = self.model_kwargs if model_kwargs is None else model_kwargs + ws.streaming = streaming + ws.run_forever() + + def arun( + self, + messages: List[Dict], + user_id: str, + model_kwargs: Optional[dict] = None, + streaming: bool = False, + ) -> threading.Thread: + ws_thread = threading.Thread( + target=self.run, + args=( + messages, + user_id, + model_kwargs, + streaming, + ), + ) + ws_thread.start() + return ws_thread + + def on_error(self, ws: Any, error: Optional[Any]) -> None: + self.queue.put({"error": error}) + ws.close() + + def on_close(self, ws: Any, close_status_code: int, close_reason: str) -> None: + logger.debug( + { + "log": { + "close_status_code": close_status_code, + "close_reason": close_reason, + } + } + ) + self.queue.put({"done": True}) + + def on_open(self, ws: Any) -> None: + self.blocking_message = {"content": "", "role": "assistant"} + data = json.dumps( + self.gen_params( + messages=ws.messages, user_id=ws.user_id, model_kwargs=ws.model_kwargs + ) + ) + ws.send(data) + + def on_message(self, ws: Any, message: str) -> None: + data = json.loads(message) + code = data["header"]["code"] + if code != 0: + self.queue.put( + {"error": f"Code: {code}, Error: {data['header']['message']}"} + ) + ws.close() + else: + choices = data["payload"]["choices"] + status = choices["status"] + content = choices["text"][0]["content"] + if ws.streaming: + self.queue.put({"data": choices["text"][0]}) + else: + self.blocking_message["content"] += content + if status == 2: + if 
+                    self.queue.put({"data": self.blocking_message})
+                usage_data = (
+                    data.get("payload", {}).get("usage", {}).get("text", {})
+                    if data
+                    else {}
+                )
+                self.queue.put({"usage": usage_data})
+                ws.close()
+
+    def gen_params(
+        self, messages: list, user_id: str, model_kwargs: Optional[dict] = None
+    ) -> dict:
+        data: Dict = {
+            "header": {"app_id": self.app_id, "uid": user_id},
+            "parameter": {"chat": {"domain": self.spark_domain}},
+            "payload": {"message": {"text": messages}},
+        }
+
+        if model_kwargs:
+            data["parameter"]["chat"].update(model_kwargs)
+        logger.debug(f"Spark Request Parameters: {data}")
+        return data
+
+    def subscribe(self, timeout: Optional[int] = 30) -> Generator[Dict, None, None]:
+        while True:
+            try:
+                content = self.queue.get(timeout=timeout)
+            except queue.Empty:
+                raise TimeoutError(
+                    f"SparkLLMClient timed out waiting for the LLM API response "
+                    f"({timeout} seconds)"
+                )
+            if "error" in content:
+                raise ConnectionError(content["error"])
+            if "usage" in content:
+                yield content
+                continue
+            if "done" in content:
+                break
+            if "data" not in content:
+                break
+            yield content
diff --git a/libs/community/tests/integration_tests/chat_models/test_sparkllm.py b/libs/community/tests/integration_tests/chat_models/test_sparkllm.py
new file mode 100644
index 0000000000..a219b88574
--- /dev/null
+++ b/libs/community/tests/integration_tests/chat_models/test_sparkllm.py
@@ -0,0 +1,36 @@
+from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
+
+from langchain_community.chat_models.sparkllm import ChatSparkLLM
+
+
+def test_chat_spark_llm() -> None:
+    chat = ChatSparkLLM()
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+def test_chat_spark_llm_streaming() -> None:
+    chat = ChatSparkLLM(streaming=True)
+    for chunk in chat.stream("Hello!"):
+        assert isinstance(chunk, AIMessageChunk)
+        assert isinstance(chunk.content, str)
+
+
+def test_chat_spark_llm_with_domain() -> None:
+    chat = ChatSparkLLM(spark_llm_domain="generalv3")
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    print(response)
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+def test_chat_spark_llm_with_temperature() -> None:
+    chat = ChatSparkLLM(temperature=0.9, top_k=2)
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    print(response)
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py
index 187459afd5..29d3529f75 100644
--- a/libs/community/tests/unit_tests/chat_models/test_imports.py
+++ b/libs/community/tests/unit_tests/chat_models/test_imports.py
@@ -33,6 +33,7 @@ EXPECTED_ALL = [
     "ChatBaichuan",
     "ChatHunyuan",
    "GigaChat",
+    "ChatSparkLLM",
     "VolcEngineMaasChat",
     "LlamaEdgeChatService",
     "GPTRouter",
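### Appendix: configuration sketch

For completeness, a minimal sketch of the two configuration paths described above, side by side. This assumes the three credentials were already obtained from the iFlyTek console; the placeholder values and the inline `os.environ` assignments are illustrative only:

```python3
import os

from langchain.chat_models.sparkllm import ChatSparkLLM
from langchain.schema import HumanMessage

# Path 1: configure through the environment (placeholder values shown),
# then construct the model with no credential arguments, as the
# integration tests do.
os.environ["IFLYTEK_SPARK_APP_ID"] = "my-app-id"          # hypothetical value
os.environ["IFLYTEK_SPARK_API_KEY"] = "my-api-key"        # hypothetical value
os.environ["IFLYTEK_SPARK_API_SECRET"] = "my-api-secret"  # hypothetical value
chat = ChatSparkLLM()
print(chat([HumanMessage(content="Hello")]).content)

# Path 2: stream incremental chunks instead of blocking on the full reply.
streaming_chat = ChatSparkLLM(streaming=True)
for chunk in streaming_chat.stream("Hello!"):
    print(chunk.content, end="")
```

Both calls go through the same websocket client; `streaming=True` only changes whether partial text deltas are surfaced as they arrive or accumulated into a single final message.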