mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
7d962278f6
Related to #17048
602 lines
20 KiB
Python
602 lines
20 KiB
Python
##
|
|
# Copyright (c) 2024, Chad Juliano, Kinetica DB Inc.
|
|
##
|
|
"""Kinetica SQL generation LLM API."""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
from importlib.metadata import version
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
|
|
|
|
if TYPE_CHECKING:
|
|
import gpudb
|
|
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
BaseMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
)
|
|
from langchain_core.output_parsers.transform import BaseOutputParser
|
|
from langchain_core.outputs import ChatGeneration, ChatResult, Generation
|
|
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
|
|
|
LOG = logging.getLogger(__name__)
|
|
|
|
# Kinetica pydantic API datatypes
|
|
|
|
|
|
class _KdtSuggestContext(BaseModel):
|
|
"""pydantic API request type"""
|
|
|
|
table: Optional[str] = Field(default=None, title="Name of table")
|
|
description: Optional[str] = Field(default=None, title="Table description")
|
|
columns: List[str] = Field(default=None, title="Table columns list")
|
|
rules: Optional[List[str]] = Field(
|
|
default=None, title="Rules that apply to the table."
|
|
)
|
|
samples: Optional[Dict] = Field(
|
|
default=None, title="Samples that apply to the entire context."
|
|
)
|
|
|
|
def to_system_str(self) -> str:
|
|
lines = []
|
|
lines.append(f"CREATE TABLE {self.table} AS")
|
|
lines.append("(")
|
|
|
|
if not self.columns or len(self.columns) == 0:
|
|
ValueError("columns list can't be null.")
|
|
|
|
columns = []
|
|
for column in self.columns:
|
|
column = column.replace('"', "").strip()
|
|
columns.append(f" {column}")
|
|
lines.append(",\n".join(columns))
|
|
lines.append(");")
|
|
|
|
if self.description:
|
|
lines.append(f"COMMENT ON TABLE {self.table} IS '{self.description}';")
|
|
|
|
if self.rules and len(self.rules) > 0:
|
|
lines.append(
|
|
f"-- When querying table {self.table} the following rules apply:"
|
|
)
|
|
for rule in self.rules:
|
|
lines.append(f"-- * {rule}")
|
|
|
|
result = "\n".join(lines)
|
|
return result
|
|
|
|
|
|
class _KdtSuggestPayload(BaseModel):
|
|
"""pydantic API request type"""
|
|
|
|
question: Optional[str]
|
|
context: List[_KdtSuggestContext]
|
|
|
|
def get_system_str(self) -> str:
|
|
lines = []
|
|
for table_context in self.context:
|
|
if table_context.table is None:
|
|
continue
|
|
context_str = table_context.to_system_str()
|
|
lines.append(context_str)
|
|
return "\n\n".join(lines)
|
|
|
|
def get_messages(self) -> List[Dict]:
|
|
messages = []
|
|
for context in self.context:
|
|
if context.samples is None:
|
|
continue
|
|
for question, answer in context.samples.items():
|
|
# unescape double quotes
|
|
answer = answer.replace("''", "'")
|
|
|
|
messages.append(dict(role="user", content=question or ""))
|
|
messages.append(dict(role="assistant", content=answer))
|
|
return messages
|
|
|
|
def to_completion(self) -> Dict:
|
|
messages = []
|
|
messages.append(dict(role="system", content=self.get_system_str()))
|
|
messages.extend(self.get_messages())
|
|
messages.append(dict(role="user", content=self.question or ""))
|
|
response = dict(messages=messages)
|
|
return response
|
|
|
|
|
|
class _KdtoSuggestRequest(BaseModel):
|
|
"""pydantic API request type"""
|
|
|
|
payload: _KdtSuggestPayload
|
|
|
|
|
|
class _KdtMessage(BaseModel):
|
|
"""pydantic API response type"""
|
|
|
|
role: str = Field(default=None, title="One of [user|assistant|system]")
|
|
content: str
|
|
|
|
|
|
class _KdtChoice(BaseModel):
|
|
"""pydantic API response type"""
|
|
|
|
index: int
|
|
message: _KdtMessage = Field(default=None, title="The generated SQL")
|
|
finish_reason: str
|
|
|
|
|
|
class _KdtUsage(BaseModel):
|
|
"""pydantic API response type"""
|
|
|
|
prompt_tokens: int
|
|
completion_tokens: int
|
|
total_tokens: int
|
|
|
|
|
|
class _KdtSqlResponse(BaseModel):
|
|
"""pydantic API response type"""
|
|
|
|
id: str
|
|
object: str
|
|
created: int
|
|
model: str
|
|
choices: List[_KdtChoice]
|
|
usage: _KdtUsage
|
|
prompt: str = Field(default=None, title="The input question")
|
|
|
|
|
|
class _KdtCompletionResponse(BaseModel):
|
|
"""pydantic API response type"""
|
|
|
|
status: str
|
|
data: _KdtSqlResponse
|
|
|
|
|
|
class _KineticaLlmFileContextParser:
|
|
"""Parser for Kinetica LLM context datafiles."""
|
|
|
|
# parse line into a dict containing role and content
|
|
PARSER = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
|
|
|
|
@classmethod
|
|
def _removesuffix(cls, text: str, suffix: str) -> str:
|
|
if suffix and text.endswith(suffix):
|
|
return text[: -len(suffix)]
|
|
return text
|
|
|
|
@classmethod
|
|
def parse_dialogue_file(cls, input_file: os.PathLike) -> Dict:
|
|
path = Path(input_file)
|
|
# schema = path.name.removesuffix(".txt") python 3.9
|
|
schema = cls._removesuffix(path.name, ".txt")
|
|
|
|
lines = open(input_file).read()
|
|
return cls.parse_dialogue(lines, schema)
|
|
|
|
@classmethod
|
|
def parse_dialogue(cls, text: str, schema: str) -> Dict:
|
|
messages = []
|
|
system = None
|
|
|
|
lines = text.split("<|end|>")
|
|
user_message = None
|
|
|
|
for idx, line in enumerate(lines):
|
|
line = line.strip()
|
|
|
|
if len(line) == 0:
|
|
continue
|
|
|
|
match = cls.PARSER.match(line)
|
|
if match is None:
|
|
raise ValueError(f"Could not find starting token in: {line}")
|
|
|
|
groupdict = match.groupdict()
|
|
role = groupdict["role"]
|
|
|
|
if role == "system":
|
|
if system is not None:
|
|
raise ValueError(f"Only one system token allowed in: {line}")
|
|
system = groupdict["content"]
|
|
elif role == "user":
|
|
if user_message is not None:
|
|
raise ValueError(
|
|
f"Found user token without assistant token: {line}"
|
|
)
|
|
user_message = groupdict
|
|
elif role == "assistant":
|
|
if user_message is None:
|
|
raise Exception(f"Found assistant token without user token: {line}")
|
|
messages.append(user_message)
|
|
messages.append(groupdict)
|
|
user_message = None
|
|
else:
|
|
raise ValueError(f"Unknown token: {role}")
|
|
|
|
return {"schema": schema, "system": system, "messages": messages}
|
|
|
|
|
|
class KineticaUtil:
|
|
"""Kinetica utility functions."""
|
|
|
|
@classmethod
|
|
def create_kdbc(
|
|
cls,
|
|
url: Optional[str] = None,
|
|
user: Optional[str] = None,
|
|
passwd: Optional[str] = None,
|
|
) -> "gpudb.GPUdb":
|
|
"""Create a connectica connection object and verify connectivity.
|
|
|
|
If None is passed for one or more of the parameters then an attempt will be made
|
|
to retrieve the value from the related environment variable.
|
|
|
|
Args:
|
|
url: The Kinetica URL or ``KINETICA_URL`` if None.
|
|
user: The Kinetica user or ``KINETICA_USER`` if None.
|
|
passwd: The Kinetica password or ``KINETICA_PASSWD`` if None.
|
|
|
|
Returns:
|
|
The Kinetica connection object.
|
|
"""
|
|
|
|
try:
|
|
import gpudb
|
|
except ModuleNotFoundError:
|
|
raise ImportError(
|
|
"Could not import Kinetica python package. "
|
|
"Please install it with `pip install gpudb`."
|
|
)
|
|
|
|
url = cls._get_env("KINETICA_URL", url)
|
|
user = cls._get_env("KINETICA_USER", user)
|
|
passwd = cls._get_env("KINETICA_PASSWD", passwd)
|
|
|
|
options = gpudb.GPUdb.Options()
|
|
options.username = user
|
|
options.password = passwd
|
|
options.skip_ssl_cert_verification = True
|
|
options.disable_failover = True
|
|
options.logging_level = "INFO"
|
|
kdbc = gpudb.GPUdb(host=url, options=options)
|
|
|
|
LOG.info(
|
|
"Connected to Kinetica: {}. (api={}, server={})".format(
|
|
kdbc.get_url(), version("gpudb"), kdbc.server_version
|
|
)
|
|
)
|
|
|
|
return kdbc
|
|
|
|
@classmethod
|
|
def _get_env(cls, name: str, default: Optional[str]) -> str:
|
|
"""Get an environment variable or use a default."""
|
|
if default is not None:
|
|
return default
|
|
|
|
result = os.getenv(name)
|
|
if result is not None:
|
|
return result
|
|
|
|
raise ValueError(
|
|
f"Parameter was not passed and not found in the environment: {name}"
|
|
)
|
|
|
|
|
|
class ChatKinetica(BaseChatModel):
|
|
"""Kinetica LLM Chat Model API.
|
|
|
|
Prerequisites for using this API:
|
|
|
|
* The ``gpudb`` and ``typeguard`` packages installed.
|
|
* A Kinetica DB instance.
|
|
* Kinetica host specified in ``KINETICA_URL``
|
|
* Kinetica login specified ``KINETICA_USER``, and ``KINETICA_PASSWD``.
|
|
* An LLM context that specifies the tables and samples to use for inferencing.
|
|
|
|
This API is intended to interact with the Kinetica SqlAssist LLM that supports
|
|
generation of SQL from natural language.
|
|
|
|
In the Kinetica LLM workflow you create an LLM context in the database that provides
|
|
information needed for infefencing that includes tables, annotations, rules, and
|
|
samples. Invoking ``load_messages_from_context()`` will retrieve the contxt
|
|
information from the database so that it can be used to create a chat prompt.
|
|
|
|
The chat prompt consists of a ``SystemMessage`` and pairs of
|
|
``HumanMessage``/``AIMessage`` that contain the samples which are question/SQL
|
|
pairs. You can append pairs samples to this list but it is not intended to
|
|
facilitate a typical natural language conversation.
|
|
|
|
When you create a chain from the chat prompt and execute it, the Kinetica LLM will
|
|
generate SQL from the input. Optionally you can use ``KineticaSqlOutputParser`` to
|
|
execute the SQL and return the result as a dataframe.
|
|
|
|
The following example creates an LLM using the environment variables for the
|
|
Kinetica connection. This will fail if the API is unable to connect to the database.
|
|
|
|
Example:
|
|
.. code-block:: python
|
|
|
|
from langchain_community.chat_models.kinetica import KineticaChatLLM
|
|
kinetica_llm = KineticaChatLLM()
|
|
|
|
If you prefer to pass connection information directly then you can create a
|
|
connection using ``KineticaUtil.create_kdbc()``.
|
|
|
|
Example:
|
|
.. code-block:: python
|
|
|
|
from langchain_community.chat_models.kinetica import (
|
|
KineticaChatLLM, KineticaUtil)
|
|
kdbc = KineticaUtil._create_kdbc(url=url, user=user, passwd=passwd)
|
|
kinetica_llm = KineticaChatLLM(kdbc=kdbc)
|
|
"""
|
|
|
|
kdbc: Any = Field(exclude=True)
|
|
""" Kinetica DB connection. """
|
|
|
|
@root_validator()
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
"""Pydantic object validator."""
|
|
|
|
kdbc = values.get("kdbc", None)
|
|
if kdbc is None:
|
|
kdbc = KineticaUtil.create_kdbc()
|
|
values["kdbc"] = kdbc
|
|
return values
|
|
|
|
@property
|
|
def _llm_type(self) -> str:
|
|
return "kinetica-sqlassist"
|
|
|
|
@property
|
|
def _identifying_params(self) -> Dict[str, Any]:
|
|
return dict(
|
|
kinetica_version=str(self.kdbc.server_version), api_version=version("gpudb")
|
|
)
|
|
|
|
def _generate(
|
|
self,
|
|
messages: List[BaseMessage],
|
|
stop: Optional[List[str]] = None,
|
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
**kwargs: Any,
|
|
) -> ChatResult:
|
|
if stop is not None:
|
|
raise ValueError("stop kwargs are not permitted.")
|
|
|
|
dict_messages = [self._convert_message_to_dict(m) for m in messages]
|
|
sql_response = self._submit_completion(dict_messages)
|
|
|
|
response_message = sql_response.choices[0].message
|
|
# generated_dict = response_message.model_dump() # pydantic v2
|
|
generated_dict = response_message.dict()
|
|
|
|
generated_message = self._convert_message_from_dict(generated_dict)
|
|
|
|
llm_output = dict(
|
|
input_tokens=sql_response.usage.prompt_tokens,
|
|
output_tokens=sql_response.usage.completion_tokens,
|
|
model_name=sql_response.model,
|
|
)
|
|
return ChatResult(
|
|
generations=[ChatGeneration(message=generated_message)],
|
|
llm_output=llm_output,
|
|
)
|
|
|
|
def load_messages_from_context(self, context_name: str) -> List:
|
|
"""Load a lanchain prompt from a Kinetica context.
|
|
|
|
A Kinetica Context is an object created with the Kinetica Workbench UI or with
|
|
SQL syntax. This function will convert the data in the context to a list of
|
|
messages that can be used as a prompt. The messages will contain a
|
|
``SystemMessage`` followed by pairs of ``HumanMessage``/``AIMessage`` that
|
|
contain the samples.
|
|
|
|
Args:
|
|
context_name: The name of an LLM context in the database.
|
|
|
|
Returns:
|
|
A list of messages containing the information from the context.
|
|
"""
|
|
|
|
# query kinetica for the prompt
|
|
sql = f"GENERATE PROMPT WITH OPTIONS (CONTEXT_NAMES = '{context_name}')"
|
|
result = self._execute_sql(sql)
|
|
prompt = result["Prompt"]
|
|
prompt_json = json.loads(prompt)
|
|
|
|
# convert the prompt to messages
|
|
# request = SuggestRequest.model_validate(prompt_json) # pydantic v2
|
|
request = _KdtoSuggestRequest.parse_obj(prompt_json)
|
|
payload = request.payload
|
|
|
|
dict_messages = []
|
|
dict_messages.append(dict(role="system", content=payload.get_system_str()))
|
|
dict_messages.extend(payload.get_messages())
|
|
messages = [self._convert_message_from_dict(m) for m in dict_messages]
|
|
return messages
|
|
|
|
def _submit_completion(self, messages: List[Dict]) -> _KdtSqlResponse:
|
|
"""Submit a /chat/completions request to Kinetica."""
|
|
|
|
request = dict(messages=messages)
|
|
request_json = json.dumps(request)
|
|
response_raw = self.kdbc._GPUdb__submit_request_json(
|
|
"/chat/completions", request_json
|
|
)
|
|
response_json = json.loads(response_raw)
|
|
|
|
status = response_json["status"]
|
|
if status != "OK":
|
|
message = response_json["message"]
|
|
match_resp = re.compile(r"response:({.*})")
|
|
result = match_resp.search(message)
|
|
if result is not None:
|
|
response = result.group(1)
|
|
response_json = json.loads(response)
|
|
message = response_json["message"]
|
|
raise ValueError(message)
|
|
|
|
data = response_json["data"]
|
|
# response = CompletionResponse.model_validate(data) # pydantic v2
|
|
response = _KdtCompletionResponse.parse_obj(data)
|
|
if response.status != "OK":
|
|
raise ValueError("SQL Generation failed")
|
|
return response.data
|
|
|
|
def _execute_sql(self, sql: str) -> Dict:
|
|
"""Execute an SQL query and return the result."""
|
|
|
|
response = self.kdbc.execute_sql_and_decode(
|
|
sql, limit=1, get_column_major=False
|
|
)
|
|
|
|
status_info = response["status_info"]
|
|
if status_info["status"] != "OK":
|
|
message = status_info["message"]
|
|
raise ValueError(message)
|
|
|
|
records = response["records"]
|
|
if len(records) != 1:
|
|
raise ValueError("No records returned.")
|
|
|
|
record = records[0]
|
|
response_dict = {}
|
|
for col, val in record.items():
|
|
response_dict[col] = val
|
|
return response_dict
|
|
|
|
@classmethod
|
|
def load_messages_from_datafile(cls, sa_datafile: Path) -> List[BaseMessage]:
|
|
"""Load a lanchain prompt from a Kinetica context datafile."""
|
|
datafile_dict = _KineticaLlmFileContextParser.parse_dialogue_file(sa_datafile)
|
|
messages = cls._convert_dict_to_messages(datafile_dict)
|
|
return messages
|
|
|
|
@classmethod
|
|
def _convert_message_to_dict(cls, message: BaseMessage) -> Dict:
|
|
"""Convert a single message to a BaseMessage."""
|
|
|
|
content = cast(str, message.content)
|
|
if isinstance(message, HumanMessage):
|
|
role = "user"
|
|
elif isinstance(message, AIMessage):
|
|
role = "assistant"
|
|
elif isinstance(message, SystemMessage):
|
|
role = "system"
|
|
else:
|
|
raise ValueError(f"Got unsupported message type: {message}")
|
|
|
|
result_message = dict(role=role, content=content)
|
|
return result_message
|
|
|
|
@classmethod
|
|
def _convert_message_from_dict(cls, message: Dict) -> BaseMessage:
|
|
"""Convert a single message from a BaseMessage."""
|
|
|
|
role = message["role"]
|
|
content = message["content"]
|
|
if role == "user":
|
|
return HumanMessage(content=content)
|
|
elif role == "assistant":
|
|
return AIMessage(content=content)
|
|
elif role == "system":
|
|
return SystemMessage(content=content)
|
|
else:
|
|
raise ValueError(f"Got unsupported role: {role}")
|
|
|
|
@classmethod
|
|
def _convert_dict_to_messages(cls, sa_data: Dict) -> List[BaseMessage]:
|
|
"""Convert a dict to a list of BaseMessages."""
|
|
|
|
schema = sa_data["schema"]
|
|
system = sa_data["system"]
|
|
messages = sa_data["messages"]
|
|
LOG.info(f"Importing prompt for schema: {schema}")
|
|
|
|
result_list: List[BaseMessage] = []
|
|
result_list.append(SystemMessage(content=system))
|
|
result_list.extend([cls._convert_message_from_dict(m) for m in messages])
|
|
return result_list
|
|
|
|
|
|
class KineticaSqlResponse(BaseModel):
|
|
"""Response containing SQL and the fetched data.
|
|
|
|
This object is returned by a chain with ``KineticaSqlOutputParser`` and it contains
|
|
the generated SQL and related Pandas Dataframe fetched from the database.
|
|
"""
|
|
|
|
sql: str = Field(default=None)
|
|
"""The generated SQL."""
|
|
|
|
# dataframe: "pd.DataFrame" = Field(default=None)
|
|
dataframe: Any = Field(default=None)
|
|
"""The Pandas dataframe containing the fetched data."""
|
|
|
|
class Config:
|
|
"""Configuration for this pydantic object."""
|
|
|
|
arbitrary_types_allowed = True
|
|
|
|
|
|
class KineticaSqlOutputParser(BaseOutputParser[KineticaSqlResponse]):
|
|
"""Fetch and return data from the Kinetica LLM.
|
|
|
|
This object is used as the last element of a chain to execute generated SQL and it
|
|
will output a ``KineticaSqlResponse`` containing the SQL and a pandas dataframe with
|
|
the fetched data.
|
|
|
|
Example:
|
|
.. code-block:: python
|
|
|
|
from langchain_community.chat_models.kinetica import (
|
|
KineticaChatLLM, KineticaSqlOutputParser)
|
|
kinetica_llm = KineticaChatLLM()
|
|
|
|
# create chain
|
|
ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)
|
|
ctx_messages.append(("human", "{input}"))
|
|
prompt_template = ChatPromptTemplate.from_messages(ctx_messages)
|
|
chain = (
|
|
prompt_template
|
|
| kinetica_llm
|
|
| KineticaSqlOutputParser(kdbc=kinetica_llm.kdbc)
|
|
)
|
|
sql_response: KineticaSqlResponse = chain.invoke(
|
|
{"input": "What are the female users ordered by username?"}
|
|
)
|
|
|
|
assert isinstance(sql_response, KineticaSqlResponse)
|
|
LOG.info(f"SQL Response: {sql_response.sql}")
|
|
assert isinstance(sql_response.dataframe, pd.DataFrame)
|
|
"""
|
|
|
|
kdbc: Any = Field(exclude=True)
|
|
""" Kinetica DB connection. """
|
|
|
|
class Config:
|
|
"""Configuration for this pydantic object."""
|
|
|
|
arbitrary_types_allowed = True
|
|
|
|
def parse(self, text: str) -> KineticaSqlResponse:
|
|
df = self.kdbc.to_df(text)
|
|
return KineticaSqlResponse(sql=text, dataframe=df)
|
|
|
|
def parse_result(
|
|
self, result: List[Generation], *, partial: bool = False
|
|
) -> KineticaSqlResponse:
|
|
return self.parse(result[0].text)
|
|
|
|
@property
|
|
def _type(self) -> str:
|
|
return "kinetica_sql_output_parser"
|