feature: chat model for snowflake cortex

pull/21490/head
Shubham Pandey 3 weeks ago
parent aafaf3e193
commit 64563a0d87

@@ -0,0 +1,184 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ChatSnowflakeCortex\n",
"\n",
"[Snowflake Cortex](https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions) gives you instant access to industry-leading large language models (LLMs) trained by researchers at companies like Mistral, Reka, Meta, and Google, including [Snowflake Arctic](https://www.snowflake.com/en/data-cloud/arctic/), an open enterprise-grade model developed by Snowflake.\n",
"\n",
"This example goes over how to use LangChain to interact with ChatSnowflakeCortex."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Installation and setup\n",
"\n",
"Start by installing the `snowflake-snowpark-python` library with the command below. Then configure the credentials for connecting to Snowflake, either by setting environment variables or by passing them directly to the model."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install --upgrade --quiet snowflake-snowpark-python"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"# First, set the environment variables used to connect to Snowflake.\n",
"# Alternatively, pass these credentials directly when instantiating the model.\n",
"\n",
"if os.environ.get(\"SNOWFLAKE_ACCOUNT\") is None:\n",
"    os.environ[\"SNOWFLAKE_ACCOUNT\"] = getpass.getpass(\"Account: \")\n",
"\n",
"if os.environ.get(\"SNOWFLAKE_USERNAME\") is None:\n",
"    os.environ[\"SNOWFLAKE_USERNAME\"] = getpass.getpass(\"Username: \")\n",
"\n",
"if os.environ.get(\"SNOWFLAKE_PASSWORD\") is None:\n",
"    os.environ[\"SNOWFLAKE_PASSWORD\"] = getpass.getpass(\"Password: \")\n",
"\n",
"if os.environ.get(\"SNOWFLAKE_DATABASE\") is None:\n",
"    os.environ[\"SNOWFLAKE_DATABASE\"] = getpass.getpass(\"Database: \")\n",
"\n",
"if os.environ.get(\"SNOWFLAKE_SCHEMA\") is None:\n",
"    os.environ[\"SNOWFLAKE_SCHEMA\"] = getpass.getpass(\"Schema: \")\n",
"\n",
"if os.environ.get(\"SNOWFLAKE_WAREHOUSE\") is None:\n",
"    os.environ[\"SNOWFLAKE_WAREHOUSE\"] = getpass.getpass(\"Warehouse: \")\n",
"\n",
"if os.environ.get(\"SNOWFLAKE_ROLE\") is None:\n",
"    os.environ[\"SNOWFLAKE_ROLE\"] = getpass.getpass(\"Role: \")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models import ChatSnowflakeCortex\n",
"from langchain_core.messages import HumanMessage, SystemMessage\n",
"\n",
"# By default, this uses the Cortex-provided model `snowflake-arctic` with the `complete` function.\n",
"chat = ChatSnowflakeCortex()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above cell assumes that your Snowflake credentials are set in your environment variables. If you would rather manually specify them, use the following code:\n",
"\n",
"```python\n",
"chat = ChatSnowflakeCortex(\n",
"    # change default cortex model and function\n",
"    model=\"snowflake-arctic\",\n",
"    cortex_function=\"complete\",\n",
"\n",
"    # change default generation parameters\n",
"    temperature=0,\n",
"    max_tokens=10,\n",
"    top_p=0.95,\n",
"\n",
"    # specify snowflake credentials\n",
"    account=\"YOUR_SNOWFLAKE_ACCOUNT\",\n",
"    username=\"YOUR_SNOWFLAKE_USERNAME\",\n",
"    password=\"YOUR_SNOWFLAKE_PASSWORD\",\n",
"    database=\"YOUR_SNOWFLAKE_DATABASE\",\n",
"    schema=\"YOUR_SNOWFLAKE_SCHEMA\",\n",
"    role=\"YOUR_SNOWFLAKE_ROLE\",\n",
"    warehouse=\"YOUR_SNOWFLAKE_WAREHOUSE\"\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calling the model\n",
"We can now call the model using the `invoke` or `generate` method.\n",
"\n",
"#### Generation"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\" Large language models are artificial intelligence systems designed to understand, generate, and manipulate human language. These models are typically based on deep learning techniques and are trained on vast amounts of text data to learn patterns and structures in language. They can perform a wide range of language-related tasks, such as language translation, text generation, sentiment analysis, and answering questions. Some well-known large language models include Google's BERT, OpenAI's GPT series, and Facebook's RoBERTa. These models have shown remarkable performance in various natural language processing tasks, and their applications continue to expand as research in AI progresses.\", response_metadata={'completion_tokens': 131, 'prompt_tokens': 29, 'total_tokens': 160}, id='run-5435bd0a-83fd-4295-b237-66cbd1b5c0f3-0')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
"    SystemMessage(\n",
"        content=\"You are a friendly assistant.\"\n",
"    ),\n",
"    HumanMessage(\n",
"        content=\"What are large language models?\"\n",
"    )\n",
"]\n",
"chat.invoke(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Streaming\n",
"`ChatSnowflakeCortex` does not support streaming yet. Streaming support is expected in a later version."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
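The `response_metadata` in the notebook output carries token counts, which the chat model in this change reads from the `usage` field of the JSON returned by Cortex. A minimal sketch of that parsing step, using only the standard library. `parse_cortex_response` is a hypothetical name, and the sample payload is made up to match the fields that `_generate` reads (`choices[0]["messages"]` and `usage`); it is not a captured Snowflake response.

```python
import json
from typing import Any, Dict, Tuple


def parse_cortex_response(raw: str) -> Tuple[str, Dict[str, Any]]:
    """Split a JSON response into message text and token-usage metadata."""
    response = json.loads(raw)
    content = response["choices"][0]["messages"]
    usage = response["usage"]
    return content, usage


# Hypothetical payload, shaped like the fields that `_generate` reads.
sample = json.dumps(
    {
        "choices": [{"messages": " Large language models are ..."}],
        "usage": {
            "completion_tokens": 131,
            "prompt_tokens": 29,
            "total_tokens": 160,
        },
    }
)

content, usage = parse_cortex_response(sample)
print(content.strip())
print(usage["total_tokens"])
```

The leading space in the sample mirrors the notebook output, where the returned content also starts with a space; callers that display the text may want to strip it.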

@@ -138,6 +138,9 @@ if TYPE_CHECKING:
    from langchain_community.chat_models.promptlayer_openai import (
        PromptLayerChatOpenAI,
    )
    from langchain_community.chat_models.snowflake import (
        ChatSnowflakeCortex,
    )
    from langchain_community.chat_models.solar import (
        SolarChat,
    )
@@ -193,6 +196,7 @@ __all__ = [
    "ChatPerplexity",
    "ChatPremAI",
    "ChatSparkLLM",
    "ChatSnowflakeCortex",
    "ChatTongyi",
    "ChatVertexAI",
    "ChatYandexGPT",
@@ -243,6 +247,7 @@ _module_lookup = {
    "ChatOllama": "langchain_community.chat_models.ollama",
    "ChatOpenAI": "langchain_community.chat_models.openai",
    "ChatPerplexity": "langchain_community.chat_models.perplexity",
    "ChatSnowflakeCortex": "langchain_community.chat_models.snowflake",
    "ChatSparkLLM": "langchain_community.chat_models.sparkllm",
    "ChatTongyi": "langchain_community.chat_models.tongyi",
    "ChatVertexAI": "langchain_community.chat_models.vertexai",

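The `_module_lookup` entry maps the class name to the module that defines it, which allows the class to be imported on first access rather than when the package is loaded. A minimal sketch of that pattern, assuming a lookup in the style of a module-level `__getattr__` (PEP 562). The `lazy_getattr` name and the standard-library entries are stand-ins for illustration, not the actual `langchain_community` code.

```python
import importlib
from typing import Any, Dict

# Stand-in lookup table: attribute name -> module that defines it.
_module_lookup: Dict[str, str] = {
    "OrderedDict": "collections",
    "JSONDecoder": "json",
}


def lazy_getattr(name: str) -> Any:
    """Import the owning module only when the attribute is first requested."""
    if name in _module_lookup:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module has no attribute {name!r}")
```

In a package `__init__.py`, a function with this body would be named `__getattr__`, so that `from package import Name` triggers the lookup for names that are not already defined.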
@@ -0,0 +1,232 @@
import json
from typing import Any, Dict, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
    get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs

SUPPORTED_ROLES: List[str] = [
    "system",
    "user",
    "assistant",
]


class ChatSnowflakeCortexError(Exception):
    """Error with Snowpark client."""


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a dictionary.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary.
    """
    message_dict: Dict[str, Any] = {
        "content": message.content,
    }

    # populate role and additional message data
    if isinstance(message, ChatMessage) and message.role in SUPPORTED_ROLES:
        message_dict["role"] = message.role
    elif isinstance(message, SystemMessage):
        message_dict["role"] = "system"
    elif isinstance(message, HumanMessage):
        message_dict["role"] = "user"
    elif isinstance(message, AIMessage):
        message_dict["role"] = "assistant"
    else:
        raise TypeError(f"Got unknown type {message}")
    return message_dict


def _truncate_at_stop_tokens(
    text: str,
    stop: Optional[List[str]],
) -> str:
    """Truncates text at the earliest stop token found."""
    if stop is None:
        return text

    for stop_token in stop:
        stop_token_idx = text.find(stop_token)
        if stop_token_idx != -1:
            text = text[:stop_token_idx]
    return text


class ChatSnowflakeCortex(BaseChatModel):
    """Snowflake Cortex based Chat model

    To use you must have the ``snowflake-snowpark-python`` Python package installed and
    either:

        1. environment variables set with your Snowflake credentials, or
        2. your Snowflake credentials passed directly as kwargs to the
           ChatSnowflakeCortex constructor.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatSnowflakeCortex
            chat = ChatSnowflakeCortex()
    """

    _sp_session: Any
    """Snowpark session object."""

    model: str = "snowflake-arctic"
    """Name of the Snowflake Cortex hosted LLM, defaults to `snowflake-arctic`.
        Refer to docs for more options."""

    cortex_function: str = "complete"
    """Cortex function to use, defaults to `complete`.
        Refer to docs for more options."""

    temperature: float = 0.7
    """Model temperature. Value should be >= 0 and <= 1.0"""

    max_tokens: Optional[int] = None
    """The maximum number of output tokens in the response."""

    top_p: Optional[float] = None
    """top_p adjusts the number of choices for each predicted token based on
        cumulative probabilities. Value should be between 0.0 and 1.0.
    """

    snowflake_username: Optional[str] = Field(default=None, alias="username")
    """Automatically inferred from env var `SNOWFLAKE_USERNAME` if not provided."""
    snowflake_password: Optional[SecretStr] = Field(default=None, alias="password")
    """Automatically inferred from env var `SNOWFLAKE_PASSWORD` if not provided."""
    snowflake_account: Optional[str] = Field(default=None, alias="account")
    """Automatically inferred from env var `SNOWFLAKE_ACCOUNT` if not provided."""
    snowflake_database: Optional[str] = Field(default=None, alias="database")
    """Automatically inferred from env var `SNOWFLAKE_DATABASE` if not provided."""
    snowflake_schema: Optional[str] = Field(default=None, alias="schema")
    """Automatically inferred from env var `SNOWFLAKE_SCHEMA` if not provided."""
    snowflake_warehouse: Optional[str] = Field(default=None, alias="warehouse")
    """Automatically inferred from env var `SNOWFLAKE_WAREHOUSE` if not provided."""
    snowflake_role: Optional[str] = Field(default=None, alias="role")
    """Automatically inferred from env var `SNOWFLAKE_ROLE` if not provided."""

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        values["model_kwargs"] = build_extra_kwargs(
            extra, values, all_required_field_names
        )
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        try:
            from snowflake.snowpark import Session
        except ImportError:
            raise ImportError(
                "`snowflake-snowpark-python` package not found, please install it with "
                "`pip install snowflake-snowpark-python`"
            )

        values["snowflake_username"] = get_from_dict_or_env(
            values, "snowflake_username", "SNOWFLAKE_USERNAME"
        )
        values["snowflake_password"] = convert_to_secret_str(
            get_from_dict_or_env(values, "snowflake_password", "SNOWFLAKE_PASSWORD")
        )
        values["snowflake_account"] = get_from_dict_or_env(
            values, "snowflake_account", "SNOWFLAKE_ACCOUNT"
        )
        values["snowflake_database"] = get_from_dict_or_env(
            values, "snowflake_database", "SNOWFLAKE_DATABASE"
        )
        values["snowflake_schema"] = get_from_dict_or_env(
            values, "snowflake_schema", "SNOWFLAKE_SCHEMA"
        )
        values["snowflake_warehouse"] = get_from_dict_or_env(
            values, "snowflake_warehouse", "SNOWFLAKE_WAREHOUSE"
        )
        values["snowflake_role"] = get_from_dict_or_env(
            values, "snowflake_role", "SNOWFLAKE_ROLE"
        )

        connection_params = {
            "account": values["snowflake_account"],
            "user": values["snowflake_username"],
            "password": values["snowflake_password"].get_secret_value(),
            "database": values["snowflake_database"],
            "schema": values["snowflake_schema"],
            "warehouse": values["snowflake_warehouse"],
            "role": values["snowflake_role"],
        }

        try:
            values["_sp_session"] = Session.builder.configs(connection_params).create()
        except Exception as e:
            raise ChatSnowflakeCortexError(f"Failed to create session: {e}")

        return values

    def __del__(self) -> None:
        if getattr(self, "_sp_session", None) is not None:
            self._sp_session.close()

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return f"snowflake-cortex-{self.model}"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        message_str = str(message_dicts)
        options = {"temperature": self.temperature}
        if self.top_p is not None:
            options["top_p"] = self.top_p
        if self.max_tokens is not None:
            options["max_tokens"] = self.max_tokens
        options_str = str(options)
        sql_stmt = f"""
            select snowflake.cortex.{self.cortex_function}(
                '{self.model}'
                ,{message_str},{options_str}) as llm_response;"""

        try:
            l_rows = self._sp_session.sql(sql_stmt).collect()
        except Exception as e:
            raise ChatSnowflakeCortexError(
                f"Error while making request to Snowflake Cortex via Snowpark: {e}"
            )

        response = json.loads(l_rows[0]["LLM_RESPONSE"])
        ai_message_content = response["choices"][0]["messages"]

        content = _truncate_at_stop_tokens(ai_message_content, stop)

        message = AIMessage(
            content=content,
            response_metadata=response["usage"],
        )
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])
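`_generate` embeds `str(message_dicts)` directly into the SQL text, so the statement depends on how Python happens to quote each string. A minimal sketch, using only the standard library, of how that representation changes when the content contains an apostrophe, next to a JSON-based alternative. `to_sql_string_literal` and `messages_as_sql_literal` are hypothetical helpers, and the escaping applied (doubling backslashes and single quotes) is an assumption about Snowflake's single-quoted string constants that should be checked against the Snowflake documentation before use.

```python
import json
from typing import Dict, List


def to_sql_string_literal(value: str) -> str:
    """Hypothetical helper: single-quote text, doubling backslashes and quotes."""
    escaped = value.replace("\\", "\\\\").replace("'", "''")
    return f"'{escaped}'"


def messages_as_sql_literal(message_dicts: List[Dict[str, str]]) -> str:
    """Serialize messages as JSON, then quote the JSON as one SQL string."""
    return to_sql_string_literal(json.dumps(message_dicts))


plain = [{"role": "user", "content": "Hello"}]
tricky = [{"role": "user", "content": "What's Cortex?"}]

# repr() switches to double quotes when the text contains an apostrophe,
# so the generated SQL changes shape depending on user input.
print(str(plain))
print(str(tricky))

# JSON output always uses double quotes; the apostrophe is escaped once,
# at the SQL layer, by the helper.
print(messages_as_sql_literal(tricky))
```

The quoted JSON could then be handed to a function such as Snowflake's `PARSE_JSON`. Whether `COMPLETE` accepts that form for its arguments is not established by this change and would need to be verified.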

@@ -0,0 +1,59 @@
"""Test ChatSnowflakeCortex

Note: This test must be run with the following environment variables set:
    SNOWFLAKE_ACCOUNT="YOUR_SNOWFLAKE_ACCOUNT",
    SNOWFLAKE_USERNAME="YOUR_SNOWFLAKE_USERNAME",
    SNOWFLAKE_PASSWORD="YOUR_SNOWFLAKE_PASSWORD",
    SNOWFLAKE_DATABASE="YOUR_SNOWFLAKE_DATABASE",
    SNOWFLAKE_SCHEMA="YOUR_SNOWFLAKE_SCHEMA",
    SNOWFLAKE_WAREHOUSE="YOUR_SNOWFLAKE_WAREHOUSE",
    SNOWFLAKE_ROLE="YOUR_SNOWFLAKE_ROLE"
"""

import pytest
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, LLMResult

from langchain_community.chat_models import ChatSnowflakeCortex


@pytest.fixture
def chat() -> ChatSnowflakeCortex:
    return ChatSnowflakeCortex()


def test_chat_snowflake_cortex(chat: ChatSnowflakeCortex) -> None:
    """Test ChatSnowflakeCortex."""
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_snowflake_cortex_system_message(chat: ChatSnowflakeCortex) -> None:
    """Test ChatSnowflakeCortex for system message"""
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat([system_message, human_message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_snowflake_cortex_model() -> None:
    """Test ChatSnowflakeCortex handles the `model` argument."""
    chat = ChatSnowflakeCortex(
        model="foo",
    )
    assert chat.model == "foo"


def test_chat_snowflake_cortex_generate(chat: ChatSnowflakeCortex) -> None:
    """Test ChatSnowflakeCortex with generate."""
    message = HumanMessage(content="Hello")
    response = chat.generate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content
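These integration tests need a live Snowflake connection, so they cannot pass when the variables listed in the module docstring are absent. A minimal sketch of a pre-flight check, using only the standard library. `REQUIRED_ENV_VARS` and `missing_env_vars` are hypothetical names and are not part of this change.

```python
import os
from typing import List, Mapping, Optional

# Variable names taken from the test module docstring.
REQUIRED_ENV_VARS = (
    "SNOWFLAKE_ACCOUNT",
    "SNOWFLAKE_USERNAME",
    "SNOWFLAKE_PASSWORD",
    "SNOWFLAKE_DATABASE",
    "SNOWFLAKE_SCHEMA",
    "SNOWFLAKE_WAREHOUSE",
    "SNOWFLAKE_ROLE",
)


def missing_env_vars(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the required variable names that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_ENV_VARS if not env.get(name)]
```

The result could drive a `pytest.mark.skipif(bool(missing_env_vars()), reason=...)` marker, so the suite skips cleanly instead of failing during fixture setup.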

@@ -50,6 +50,7 @@ EXPECTED_ALL = [
    "QianfanChatEndpoint",
    "VolcEngineMaasChat",
    "ChatOctoAI",
    "ChatSnowflakeCortex",
]

@@ -0,0 +1,24 @@
"""Test ChatSnowflakeCortex."""

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_community.chat_models.snowflake import _convert_message_to_dict


def test_messages_to_prompt_dict_with_valid_messages() -> None:
    messages = [
        SystemMessage(content="System Prompt"),
        HumanMessage(content="User message #1"),
        AIMessage(content="AI message #1"),
        HumanMessage(content="User message #2"),
        AIMessage(content="AI message #2"),
    ]
    result = [_convert_message_to_dict(m) for m in messages]
    expected = [
        {"role": "system", "content": "System Prompt"},
        {"role": "user", "content": "User message #1"},
        {"role": "assistant", "content": "AI message #1"},
        {"role": "user", "content": "User message #2"},
        {"role": "assistant", "content": "AI message #2"},
    ]
    assert result == expected
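The unit test above covers `_convert_message_to_dict` only; the stop-token helper in the same module has no test in this change. A minimal sketch of what such a test might check. The function body is copied from `_truncate_at_stop_tokens` under a different name so that the example runs without LangChain installed, and the test name and sample strings are hypothetical.

```python
from typing import List, Optional


def truncate_at_stop_tokens(text: str, stop: Optional[List[str]]) -> str:
    """Copy of the helper's logic: cut the text at each stop token in turn."""
    if stop is None:
        return text
    for stop_token in stop:
        idx = text.find(stop_token)
        if idx != -1:
            text = text[:idx]
    return text


def test_truncate_at_stop_tokens() -> None:
    sample = "one STOP two END three"
    # No stop list, or no match, leaves the text unchanged.
    assert truncate_at_stop_tokens(sample, None) == sample
    assert truncate_at_stop_tokens(sample, ["xyz"]) == sample
    # For these tokens, the cut lands at the earliest match in either order.
    assert truncate_at_stop_tokens(sample, ["END", "STOP"]) == "one "
    assert truncate_at_stop_tokens(sample, ["STOP", "END"]) == "one "


test_truncate_at_stop_tokens()
```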