From 56ac94e014b9a4c691fdd388e8afe70682917e0f Mon Sep 17 00:00:00 2001 From: Shubham Pandey Date: Mon, 17 Jun 2024 22:17:05 +0530 Subject: [PATCH] community[minor]: add `ChatSnowflakeCortex` chat model (#21490) **Description:** This PR adds a chat model integration for [Snowflake Cortex](https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions), which gives an instant access to industry-leading large language models (LLMs) trained by researchers at companies like Mistral, Reka, Meta, and Google, including [Snowflake Arctic](https://www.snowflake.com/en/data-cloud/arctic/), an open enterprise-grade model developed by Snowflake. **Dependencies:** Snowflake's [snowpark](https://pypi.org/project/snowflake-snowpark-python/) library is required for using this integration. **Twitter handle:** [@gethouseware](https://twitter.com/gethouseware) - [x] **Add tests and docs**: 1. integration tests: `libs/community/tests/integration_tests/chat_models/test_snowflake.py` 2. unit tests: `libs/community/tests/unit_tests/chat_models/test_snowflake.py` 3. example notebook: `docs/docs/integrations/chat/snowflake.ipynb` - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --- docs/docs/integrations/chat/snowflake.ipynb | 180 ++++++++++++++ .../chat_models/__init__.py | 5 + .../chat_models/snowflake.py | 232 ++++++++++++++++++ .../chat_models/test_snowflake.py | 59 +++++ .../unit_tests/chat_models/test_imports.py | 1 + .../unit_tests/chat_models/test_snowflake.py | 24 ++ 6 files changed, 501 insertions(+) create mode 100644 docs/docs/integrations/chat/snowflake.ipynb create mode 100644 libs/community/langchain_community/chat_models/snowflake.py create mode 100644 libs/community/tests/integration_tests/chat_models/test_snowflake.py create mode 100644 libs/community/tests/unit_tests/chat_models/test_snowflake.py diff --git a/docs/docs/integrations/chat/snowflake.ipynb b/docs/docs/integrations/chat/snowflake.ipynb new file mode 100644 index 0000000000..650648ffb7 --- /dev/null +++ b/docs/docs/integrations/chat/snowflake.ipynb @@ -0,0 +1,180 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Snowflake Cortex\n", + "\n", + "[Snowflake Cortex](https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions) gives you instant access to industry-leading large language models (LLMs) trained by researchers at companies like Mistral, Reka, Meta, and Google, including [Snowflake Arctic](https://www.snowflake.com/en/data-cloud/arctic/), an open enterprise-grade model developed by Snowflake.\n", + "\n", + "This example goes over how to use LangChain to interact with Snowflake Cortex." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation and setup\n", + "\n", + "We start by installing the `snowflake-snowpark-python` library, using the command below. Then we configure the credentials for connecting to Snowflake, as environment variables or pass them directly." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install --upgrade --quiet snowflake-snowpark-python" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "# First step is to set up the environment variables, to connect to Snowflake,\n", + "# you can also pass these snowflake credentials while instantiating the model\n", + "\n", + "if os.environ.get(\"SNOWFLAKE_ACCOUNT\") is None:\n", + " os.environ[\"SNOWFLAKE_ACCOUNT\"] = getpass.getpass(\"Account: \")\n", + "\n", + "if os.environ.get(\"SNOWFLAKE_USERNAME\") is None:\n", + " os.environ[\"SNOWFLAKE_USERNAME\"] = getpass.getpass(\"Username: \")\n", + "\n", + "if os.environ.get(\"SNOWFLAKE_PASSWORD\") is None:\n", + " os.environ[\"SNOWFLAKE_PASSWORD\"] = getpass.getpass(\"Password: \")\n", + "\n", + "if os.environ.get(\"SNOWFLAKE_DATABASE\") is None:\n", + " os.environ[\"SNOWFLAKE_DATABASE\"] = getpass.getpass(\"Database: \")\n", + "\n", + "if os.environ.get(\"SNOWFLAKE_SCHEMA\") is None:\n", + " os.environ[\"SNOWFLAKE_SCHEMA\"] = getpass.getpass(\"Schema: \")\n", + "\n", + "if os.environ.get(\"SNOWFLAKE_WAREHOUSE\") is None:\n", + " os.environ[\"SNOWFLAKE_WAREHOUSE\"] = getpass.getpass(\"Warehouse: \")\n", + "\n", + "if os.environ.get(\"SNOWFLAKE_ROLE\") is None:\n", + " os.environ[\"SNOWFLAKE_ROLE\"] = getpass.getpass(\"Role: \")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models import ChatSnowflakeCortex\n", + "from langchain_core.messages import HumanMessage, SystemMessage\n", + "\n", + "# By default, we'll be using the cortex provided model: `snowflake-arctic`, with function: `complete`\n", + "chat = ChatSnowflakeCortex()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above cell assumes that your Snowflake credentials are set in your environment variables. If you would rather manually specify them, use the following code:\n", + "\n", + "```python\n", + "chat = ChatSnowflakeCortex(\n", + " # change default cortex model and function\n", + " model=\"snowflake-arctic\",\n", + " cortex_function=\"complete\",\n", + "\n", + " # change default generation parameters\n", + " temperature=0,\n", + " max_tokens=10,\n", + " top_p=0.95,\n", + "\n", + " # specify snowflake credentials\n", + " account=\"YOUR_SNOWFLAKE_ACCOUNT\",\n", + " username=\"YOUR_SNOWFLAKE_USERNAME\",\n", + " password=\"YOUR_SNOWFLAKE_PASSWORD\",\n", + " database=\"YOUR_SNOWFLAKE_DATABASE\",\n", + " schema=\"YOUR_SNOWFLAKE_SCHEMA\",\n", + " role=\"YOUR_SNOWFLAKE_ROLE\",\n", + " warehouse=\"YOUR_SNOWFLAKE_WAREHOUSE\"\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Calling the model\n", + "We can now call the model using the `invoke` or `generate` method.\n", + "\n", + "#### Generation" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\" Large language models are artificial intelligence systems designed to understand, generate, and manipulate human language. These models are typically based on deep learning techniques and are trained on vast amounts of text data to learn patterns and structures in language. They can perform a wide range of language-related tasks, such as language translation, text generation, sentiment analysis, and answering questions. Some well-known large language models include Google's BERT, OpenAI's GPT series, and Facebook's RoBERTa. These models have shown remarkable performance in various natural language processing tasks, and their applications continue to expand as research in AI progresses.\", response_metadata={'completion_tokens': 131, 'prompt_tokens': 29, 'total_tokens': 160}, id='run-5435bd0a-83fd-4295-b237-66cbd1b5c0f3-0')" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages = [\n", + " SystemMessage(content=\"You are a friendly assistant.\"),\n", + " HumanMessage(content=\"What are large language models?\"),\n", + "]\n", + "chat.invoke(messages)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Streaming\n", + "`ChatSnowflakeCortex` doesn't support streaming as of now. Support for streaming will be coming in the later versions!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index c7632f5aad..7b942a26ca 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -140,6 +140,9 @@ if TYPE_CHECKING: from langchain_community.chat_models.promptlayer_openai import ( PromptLayerChatOpenAI, ) + from langchain_community.chat_models.snowflake import ( + ChatSnowflakeCortex, + ) from langchain_community.chat_models.solar import ( SolarChat, ) @@ -196,6 +199,7 @@ __all__ = [ "ChatPerplexity", "ChatPremAI", "ChatSparkLLM", + "ChatSnowflakeCortex", "ChatTongyi", "ChatVertexAI", "ChatYandexGPT", @@ -247,6 +251,7 @@ _module_lookup = { "ChatOllama": "langchain_community.chat_models.ollama", "ChatOpenAI": "langchain_community.chat_models.openai", "ChatPerplexity": "langchain_community.chat_models.perplexity", + "ChatSnowflakeCortex": "langchain_community.chat_models.snowflake", "ChatSparkLLM": "langchain_community.chat_models.sparkllm", "ChatTongyi": "langchain_community.chat_models.tongyi", "ChatVertexAI": "langchain_community.chat_models.vertexai", diff --git a/libs/community/langchain_community/chat_models/snowflake.py b/libs/community/langchain_community/chat_models/snowflake.py new file mode 100644 index 0000000000..c25d2254f9 --- /dev/null +++ b/libs/community/langchain_community/chat_models/snowflake.py @@ -0,0 +1,232 @@ +import json +from typing import Any, Dict, List, Optional + +from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.utils import ( + convert_to_secret_str, + get_from_dict_or_env, + get_pydantic_field_names, +) +from langchain_core.utils.utils import build_extra_kwargs + +SUPPORTED_ROLES: List[str] = [ + "system", + "user", + "assistant", +] + + +class ChatSnowflakeCortexError(Exception): + """Error with Snowpark client.""" + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + """Convert a LangChain message to a dictionary. + + Args: + message: The LangChain message. + + Returns: + The dictionary. + """ + message_dict: Dict[str, Any] = { + "content": message.content, + } + + # populate role and additional message data + if isinstance(message, ChatMessage) and message.role in SUPPORTED_ROLES: + message_dict["role"] = message.role + elif isinstance(message, SystemMessage): + message_dict["role"] = "system" + elif isinstance(message, HumanMessage): + message_dict["role"] = "user" + elif isinstance(message, AIMessage): + message_dict["role"] = "assistant" + else: + raise TypeError(f"Got unknown type {message}") + return message_dict + + +def _truncate_at_stop_tokens( + text: str, + stop: Optional[List[str]], +) -> str: + """Truncates text at the earliest stop token found.""" + if stop is None: + return text + + for stop_token in stop: + stop_token_idx = text.find(stop_token) + if stop_token_idx != -1: + text = text[:stop_token_idx] + return text + + +class ChatSnowflakeCortex(BaseChatModel): + """Snowflake Cortex based Chat model + + To use you must have the ``snowflake-snowpark-python`` Python package installed and + either: + + 1. environment variables set with your snowflake credentials or + 2. directly passed in as kwargs to the ChatSnowflakeCortex constructor. + + Example: + .. code-block:: python + + from langchain_community.chat_models import ChatSnowflakeCortex + chat = ChatSnowflakeCortex() + """ + + _sp_session: Any = None + """Snowpark session object.""" + + model: str = "snowflake-arctic" + """Snowflake cortex hosted LLM model name, defaulted to `snowflake-arctic`. + Refer to docs for more options.""" + + cortex_function: str = "complete" + """Cortex function to use, defaulted to `complete`. + Refer to docs for more options.""" + + temperature: float = 0.7 + """Model temperature. Value should be >= 0 and <= 1.0""" + + max_tokens: Optional[int] = None + """The maximum number of output tokens in the response.""" + + top_p: Optional[float] = None + """top_p adjusts the number of choices for each predicted tokens based on + cumulative probabilities. Value should be ranging between 0.0 and 1.0. + """ + + snowflake_username: Optional[str] = Field(default=None, alias="username") + """Automatically inferred from env var `SNOWFLAKE_USERNAME` if not provided.""" + snowflake_password: Optional[SecretStr] = Field(default=None, alias="password") + """Automatically inferred from env var `SNOWFLAKE_PASSWORD` if not provided.""" + snowflake_account: Optional[str] = Field(default=None, alias="account") + """Automatically inferred from env var `SNOWFLAKE_ACCOUNT` if not provided.""" + snowflake_database: Optional[str] = Field(default=None, alias="database") + """Automatically inferred from env var `SNOWFLAKE_DATABASE` if not provided.""" + snowflake_schema: Optional[str] = Field(default=None, alias="schema") + """Automatically inferred from env var `SNOWFLAKE_SCHEMA` if not provided.""" + snowflake_warehouse: Optional[str] = Field(default=None, alias="warehouse") + """Automatically inferred from env var `SNOWFLAKE_WAREHOUSE` if not provided.""" + snowflake_role: Optional[str] = Field(default=None, alias="role") + """Automatically inferred from env var `SNOWFLAKE_ROLE` if not provided.""" + + @root_validator(pre=True) + def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + extra = values.get("model_kwargs", {}) + values["model_kwargs"] = build_extra_kwargs( + extra, values, all_required_field_names + ) + return values + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + try: + from snowflake.snowpark import Session + except ImportError: + raise ImportError( + "`snowflake-snowpark-python` package not found, please install it with " + "`pip install snowflake-snowpark-python`" + ) + + values["snowflake_username"] = get_from_dict_or_env( + values, "snowflake_username", "SNOWFLAKE_USERNAME" + ) + values["snowflake_password"] = convert_to_secret_str( + get_from_dict_or_env(values, "snowflake_password", "SNOWFLAKE_PASSWORD") + ) + values["snowflake_account"] = get_from_dict_or_env( + values, "snowflake_account", "SNOWFLAKE_ACCOUNT" + ) + values["snowflake_database"] = get_from_dict_or_env( + values, "snowflake_database", "SNOWFLAKE_DATABASE" + ) + values["snowflake_schema"] = get_from_dict_or_env( + values, "snowflake_schema", "SNOWFLAKE_SCHEMA" + ) + values["snowflake_warehouse"] = get_from_dict_or_env( + values, "snowflake_warehouse", "SNOWFLAKE_WAREHOUSE" + ) + values["snowflake_role"] = get_from_dict_or_env( + values, "snowflake_role", "SNOWFLAKE_ROLE" + ) + + connection_params = { + "account": values["snowflake_account"], + "user": values["snowflake_username"], + "password": values["snowflake_password"].get_secret_value(), + "database": values["snowflake_database"], + "schema": values["snowflake_schema"], + "warehouse": values["snowflake_warehouse"], + "role": values["snowflake_role"], + } + + try: + values["_sp_session"] = Session.builder.configs(connection_params).create() + except Exception as e: + raise ChatSnowflakeCortexError(f"Failed to create session: {e}") + + return values + + def __del__(self) -> None: + if getattr(self, "_sp_session", None) is not None: + self._sp_session.close() + + @property + def _llm_type(self) -> str: + """Get the type of language model used by this chat model.""" + return f"snowflake-cortex-{self.model}" + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + message_dicts = [_convert_message_to_dict(m) for m in messages] + message_str = str(message_dicts) + options = {"temperature": self.temperature} + if self.top_p is not None: + options["top_p"] = self.top_p + if self.max_tokens is not None: + options["max_tokens"] = self.max_tokens + options_str = str(options) + sql_stmt = f""" + select snowflake.cortex.{self.cortex_function}( + '{self.model}' + ,{message_str},{options_str}) as llm_response;""" + + try: + l_rows = self._sp_session.sql(sql_stmt).collect() + except Exception as e: + raise ChatSnowflakeCortexError( + f"Error while making request to Snowflake Cortex via Snowpark: {e}" + ) + + response = json.loads(l_rows[0]["LLM_RESPONSE"]) + ai_message_content = response["choices"][0]["messages"] + + content = _truncate_at_stop_tokens(ai_message_content, stop) + message = AIMessage( + content=content, + response_metadata=response["usage"], + ) + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) diff --git a/libs/community/tests/integration_tests/chat_models/test_snowflake.py b/libs/community/tests/integration_tests/chat_models/test_snowflake.py new file mode 100644 index 0000000000..f3ba87fb35 --- /dev/null +++ b/libs/community/tests/integration_tests/chat_models/test_snowflake.py @@ -0,0 +1,59 @@ +"""Test ChatSnowflakeCortex +Note: This test must be run with the following environment variables set: + SNOWFLAKE_ACCOUNT="YOUR_SNOWFLAKE_ACCOUNT", + SNOWFLAKE_USERNAME="YOUR_SNOWFLAKE_USERNAME", + SNOWFLAKE_PASSWORD="YOUR_SNOWFLAKE_PASSWORD", + SNOWFLAKE_DATABASE="YOUR_SNOWFLAKE_DATABASE", + SNOWFLAKE_SCHEMA="YOUR_SNOWFLAKE_SCHEMA", + SNOWFLAKE_WAREHOUSE="YOUR_SNOWFLAKE_WAREHOUSE" + SNOWFLAKE_ROLE="YOUR_SNOWFLAKE_ROLE", +""" + +import pytest +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.outputs import ChatGeneration, LLMResult + +from langchain_community.chat_models import ChatSnowflakeCortex + + +@pytest.fixture +def chat() -> ChatSnowflakeCortex: + return ChatSnowflakeCortex() + + +def test_chat_snowflake_cortex(chat: ChatSnowflakeCortex) -> None: + """Test ChatSnowflakeCortex.""" + message = HumanMessage(content="Hello") + response = chat([message]) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + +def test_chat_snowflake_cortex_system_message(chat: ChatSnowflakeCortex) -> None: + """Test ChatSnowflakeCortex for system message""" + system_message = SystemMessage(content="You are to chat with the user.") + human_message = HumanMessage(content="Hello") + response = chat([system_message, human_message]) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + +def test_chat_snowflake_cortex_model() -> None: + """Test ChatSnowflakeCortex handles model_name.""" + chat = ChatSnowflakeCortex( + model="foo", + ) + assert chat.model == "foo" + + +def test_chat_snowflake_cortex_generate(chat: ChatSnowflakeCortex) -> None: + """Test ChatSnowflakeCortex with generate.""" + message = HumanMessage(content="Hello") + response = chat.generate([[message], [message]]) + assert isinstance(response, LLMResult) + assert len(response.generations) == 2 + for generations in response.generations: + for generation in generations: + assert isinstance(generation, ChatGeneration) + assert isinstance(generation.text, str) + assert generation.text == generation.message.content diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index bfa7d8c51e..a0e573068c 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -51,6 +51,7 @@ EXPECTED_ALL = [ "QianfanChatEndpoint", "VolcEngineMaasChat", "ChatOctoAI", + "ChatSnowflakeCortex", ] diff --git a/libs/community/tests/unit_tests/chat_models/test_snowflake.py b/libs/community/tests/unit_tests/chat_models/test_snowflake.py new file mode 100644 index 0000000000..9e80179a89 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_snowflake.py @@ -0,0 +1,24 @@ +"""Test ChatSnowflakeCortex.""" + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + +from langchain_community.chat_models.snowflake import _convert_message_to_dict + + +def test_messages_to_prompt_dict_with_valid_messages() -> None: + messages = [ + SystemMessage(content="System Prompt"), + HumanMessage(content="User message #1"), + AIMessage(content="AI message #1"), + HumanMessage(content="User message #2"), + AIMessage(content="AI message #2"), + ] + result = [_convert_message_to_dict(m) for m in messages] + expected = [ + {"role": "system", "content": "System Prompt"}, + {"role": "user", "content": "User message #1"}, + {"role": "assistant", "content": "AI message #1"}, + {"role": "user", "content": "User message #2"}, + {"role": "assistant", "content": "AI message #2"}, + ] + assert result == expected