langchain/libs/community/tests/integration_tests/chat_models/test_kinetica.py

"""Test Kinetica Chat API wrapper."""
import logging
from typing import TYPE_CHECKING, Generator
import pandas as pd
import pytest
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models.kinetica import (
ChatKinetica,
KineticaSqlOutputParser,
KineticaSqlResponse,
KineticaUtil,
)
if TYPE_CHECKING:
import gpudb
LOG = logging.getLogger(__name__)
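
# pytest-vcr picks up this module-scoped `vcr_config` fixture when recording
# cassettes; replacing the Authorization header keeps real credentials out of the
# recordings used for offline test runs.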
@pytest.fixture(scope="module")
def vcr_config() -> dict:
    return {
        # Replace the Authorization request header with "DUMMY" in cassettes
        "filter_headers": [("authorization", "DUMMY")],
    }


class TestChatKinetica:
    """Integration tests for `Kinetica` chat models.

    You must have the `gpudb`, `typeguard`, and `faker` packages installed to run
    these tests. pytest-vcr cassettes are provided for offline testing.

    For more information see https://docs.kinetica.com/7.1/sql-gpt/concepts/.

    These integration tests follow a workflow:

    1. `test_setup()` creates a table with fake user profiles and a related LLM
       context for the table.
    2. The LLM context is retrieved from the DB and used to create a chat prompt
       template.
    3. A chain is constructed from the chat prompt template.
    4. The chain is executed to generate the SQL and execute the query.
    """

    table_name = "demo.test_profiles"
    context_name = "demo.test_llm_ctx"
    num_records = 100

    @classmethod
    @pytest.mark.vcr()
    def test_setup(cls) -> "gpudb.GPUdb":
        """Create the connection, test table, and LLM context."""
        kdbc = KineticaUtil.create_kdbc()
        cls._create_test_table(kdbc, cls.table_name, cls.num_records)
        cls._create_llm_context(kdbc, cls.context_name)
        return kdbc
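
    # ChatKinetica() and KineticaUtil.create_kdbc() are called without arguments in
    # these tests, so connection details (URL and credentials) are presumably picked
    # up from the environment when running against a live database.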
    @pytest.mark.vcr()
    def test_create_llm(self) -> None:
        """Create an LLM instance."""
        import gpudb

        kinetica_llm = ChatKinetica()
        LOG.info(kinetica_llm._identifying_params)

        assert isinstance(kinetica_llm.kdbc, gpudb.GPUdb)
        assert kinetica_llm._llm_type == "kinetica-sqlassist"
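
    # load_messages_from_context() is expected to return the stored context as chat
    # messages: a SystemMessage describing the table followed by the sample
    # question/SQL pairs, so the second-to-last message should be the sample question
    # defined in _create_llm_context() below.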
    @pytest.mark.vcr()
    def test_load_context(self) -> None:
        """Load the LLM context from the DB."""
        kinetica_llm = ChatKinetica()
        ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)

        system_message = ctx_messages[0]
        assert isinstance(system_message, SystemMessage)

        last_question = ctx_messages[-2]
        assert isinstance(last_question, HumanMessage)
        assert last_question.content == "How many male users are there?"
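
    # The context messages plus a ("human", "{input}") placeholder form a few-shot
    # ChatPromptTemplate; piping it into the model with `|` builds an LCEL chain
    # whose output is an AIMessage containing the generated SQL.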
    @pytest.mark.vcr()
    def test_generate(self) -> None:
        """Generate SQL from a chain."""
        kinetica_llm = ChatKinetica()

        # create chain
        ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)
        ctx_messages.append(("human", "{input}"))
        prompt_template = ChatPromptTemplate.from_messages(ctx_messages)
        chain = prompt_template | kinetica_llm

        resp_message = chain.invoke(
            {"input": "What are the female users ordered by username?"}
        )
        LOG.info(f"SQL Response: {resp_message.content}")
        assert isinstance(resp_message, AIMessage)
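
    # Appending KineticaSqlOutputParser to the chain executes the generated SQL
    # against the given connection and wraps both the SQL string and the result
    # DataFrame in a KineticaSqlResponse.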
    @pytest.mark.vcr()
    def test_full_chain(self) -> None:
        """Generate SQL from a chain and execute the query."""
        kinetica_llm = ChatKinetica()

        # create chain
        ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)
        ctx_messages.append(("human", "{input}"))
        prompt_template = ChatPromptTemplate.from_messages(ctx_messages)
        chain = (
            prompt_template
            | kinetica_llm
            | KineticaSqlOutputParser(kdbc=kinetica_llm.kdbc)
        )

        sql_response: KineticaSqlResponse = chain.invoke(
            {"input": "What are the female users ordered by username?"}
        )

        assert isinstance(sql_response, KineticaSqlResponse)
        LOG.info(f"SQL Response: {sql_response.sql}")
        assert isinstance(sql_response.dataframe, pd.DataFrame)
        users = sql_response.dataframe["username"]
        assert users[0] == "alexander40"
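
    # Seeding Faker makes the generated profiles deterministic, which is why
    # result-dependent assertions such as users[0] == "alexander40" above can stay
    # stable across runs and recorded cassettes.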
    @classmethod
    def _create_fake_records(cls, count: int) -> Generator:
        """Generator for fake records."""
        import faker

        faker.Faker.seed(5467)
        faker_inst = faker.Faker(locale="en-US")
        for id in range(0, count):
            rec = dict(id=id, **faker_inst.simple_profile())
            rec["birthdate"] = pd.Timestamp(rec["birthdate"])
            yield rec
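
    # GPUdbTable.from_df() builds the table from the DataFrame (no type overrides
    # are given via column_types); with clear_table=True, any previous contents of
    # the table are expected to be cleared before the fake records are loaded.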
    @classmethod
    def _create_test_table(
        cls, kinetica_dbc: "gpudb.GPUdb", table_name: str, num_records: int
    ) -> "gpudb.GPUdbTable":
        """Create a table from the fake records generator."""
        import gpudb

        table_df = pd.DataFrame.from_records(
            data=cls._create_fake_records(num_records), index="id"
        )
        LOG.info(f"Creating test table '{table_name}' with {num_records} records...")
        gpudb_table = gpudb.GPUdbTable.from_df(
            table_df,
            db=kinetica_dbc,
            table_name=table_name,
            clear_table=True,
            load_data=True,
            column_types={},
        )
        return gpudb_table

    @classmethod
    def _check_error(cls, response: dict) -> None:
        """Convert a DB error into an exception."""
        status = response["status_info"]["status"]
        if status != "OK":
            message = response["status_info"]["message"]
            raise Exception("[%s]: %s" % (status, message))
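
    # The CREATE OR REPLACE CONTEXT statement pairs a table description with sample
    # question/SQL pairs for SQL-GPT; note the doubled single quotes (''M'') used to
    # escape quotes inside the sample SQL literal.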
    @classmethod
    def _create_llm_context(
        cls, kinetica_dbc: "gpudb.GPUdb", context_name: str
    ) -> None:
        """Create an LLM context for the table."""
        sql = f"""
        CREATE OR REPLACE CONTEXT {context_name}
        (
            TABLE = {cls.table_name}
            COMMENT = 'Contains user profiles.'
        ),
        (
            SAMPLES = (
            'How many male users are there?' =
            'select count(1) as num_users
                from {cls.table_name}
                where sex = ''M'';')
        )
        """

        LOG.info(f"Creating context: {context_name}")
        response = kinetica_dbc.execute_sql(sql)
        cls._check_error(response)