mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
25fbe356b4
This PR upgrades community to a recent version of mypy. It inserts type: ignore on all existing failures.
202 lines
6.4 KiB
Python
202 lines
6.4 KiB
Python
"""Test Kinetica Chat API wrapper."""
|
|
|
|
import logging
|
|
from typing import TYPE_CHECKING, Generator
|
|
|
|
import pandas as pd
|
|
import pytest
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
)
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
from langchain_community.chat_models.kinetica import (
|
|
ChatKinetica,
|
|
KineticaSqlOutputParser,
|
|
KineticaSqlResponse,
|
|
KineticaUtil,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
import gpudb
|
|
|
|
|
|
# Module-level logger, named after this module, used by the tests for diagnostics.
LOG: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def vcr_config() -> dict:
    """Provide the module-wide pytest-vcr configuration.

    The ``Authorization`` request header is filtered so that its value is
    recorded as ``"DUMMY"`` in cassettes instead of the real credential.
    """
    scrubbed_headers = [("authorization", "DUMMY")]
    return dict(filter_headers=scrubbed_headers)
|
|
|
|
|
|
class TestChatKinetica:
    """Integration tests for `Kinetica` chat models.

    You must have `gpudb`, `typeguard`, and `faker` packages installed to run these
    tests. pytest-vcr cassettes are provided for offline testing.

    For more information see https://docs.kinetica.com/7.1/sql-gpt/concepts/.

    These integration tests follow a workflow:

    1. The `test_setup()` will create a table with fake user profiles and a related
       LLM context for the table.

    2. The LLM context is retrieved from the DB and used to create a chat prompt
       template.

    3. A chain is constructed from the chat prompt template.

    4. The chain is executed to generate the SQL and execute the query.

    The tests after `test_setup()` read the table and context it creates, so they
    assume pytest runs the methods of this class in definition order.
    """

    # Fully qualified names of the DB objects created by `test_setup()`.
    table_name = "demo.test_profiles"
    context_name = "demo.test_llm_ctx"
    # Number of fake user-profile rows inserted into `table_name`.
    num_records = 100

    @classmethod
    @pytest.mark.vcr()
    def test_setup(cls) -> "gpudb.GPUdb":
        """Create the connection, test table, and LLM context.

        Returns:
            The Kinetica connection that was used to create the test objects.
        """
        # NOTE(review): pytest >= 7.2 emits PytestReturnNotNoneWarning for tests
        # that return a value; the return is kept for backward compatibility.
        kdbc = KineticaUtil.create_kdbc()
        cls._create_test_table(kdbc, cls.table_name, cls.num_records)
        cls._create_llm_context(kdbc, cls.context_name)
        return kdbc

    @pytest.mark.vcr()
    def test_create_llm(self) -> None:
        """Create an LLM instance."""
        import gpudb

        kinetica_llm = ChatKinetica()  # type: ignore[call-arg]
        LOG.info(kinetica_llm._identifying_params)

        assert isinstance(kinetica_llm.kdbc, gpudb.GPUdb)
        assert kinetica_llm._llm_type == "kinetica-sqlassist"

    @pytest.mark.vcr()
    def test_load_context(self) -> None:
        """Load the LLM context from the DB."""
        kinetica_llm = ChatKinetica()  # type: ignore[call-arg]
        ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)

        system_message = ctx_messages[0]
        assert isinstance(system_message, SystemMessage)

        # The context created in `_create_llm_context` ends with one sample
        # question/SQL pair, so the second-to-last message should be that sample
        # question (presumably followed by the sample SQL answer).
        last_question = ctx_messages[-2]
        assert isinstance(last_question, HumanMessage)
        assert last_question.content == "How many male users are there?"

    @pytest.mark.vcr()
    def test_generate(self) -> None:
        """Generate SQL from a chain."""
        kinetica_llm = ChatKinetica()  # type: ignore[call-arg]

        # create chain: context messages followed by a templated user question
        ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)
        ctx_messages.append(("human", "{input}"))
        prompt_template = ChatPromptTemplate.from_messages(ctx_messages)
        chain = prompt_template | kinetica_llm

        resp_message = chain.invoke(
            {"input": "What are the female users ordered by username?"}
        )
        LOG.info("SQL Response: %s", resp_message.content)
        assert isinstance(resp_message, AIMessage)

    @pytest.mark.vcr()
    def test_full_chain(self) -> None:
        """Generate SQL from a chain and execute the query."""
        kinetica_llm = ChatKinetica()  # type: ignore[call-arg]

        # create chain: prompt -> LLM -> parser that executes the generated SQL
        ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)
        ctx_messages.append(("human", "{input}"))
        prompt_template = ChatPromptTemplate.from_messages(ctx_messages)
        chain = (
            prompt_template
            | kinetica_llm
            | KineticaSqlOutputParser(kdbc=kinetica_llm.kdbc)
        )
        sql_response: KineticaSqlResponse = chain.invoke(
            {"input": "What are the female users ordered by username?"}
        )

        assert isinstance(sql_response, KineticaSqlResponse)
        LOG.info("SQL Response: %s", sql_response.sql)
        assert isinstance(sql_response.dataframe, pd.DataFrame)
        users = sql_response.dataframe["username"]
        # Positional access: `users[0]` is a *label* lookup that only hits the
        # first row when the result has a default 0-based integer index.
        assert users.iloc[0] == "alexander40"

    @classmethod
    def _create_fake_records(cls, count: int) -> Generator[dict, None, None]:
        """Yield `count` reproducible fake user-profile records.

        Args:
            count: Number of records to generate.

        Yields:
            A dict with an integer ``id`` (``0`` to ``count - 1``) plus the fields
            of a ``faker`` simple profile, with ``birthdate`` converted to a
            ``pd.Timestamp``.
        """
        import faker

        # Fixed seed so the generated data is identical on every run (presumably
        # what lets `test_full_chain` assert a specific username).
        faker.Faker.seed(5467)
        faker_inst = faker.Faker(locale="en-US")
        for rec_id in range(count):
            rec = dict(id=rec_id, **faker_inst.simple_profile())
            rec["birthdate"] = pd.Timestamp(rec["birthdate"])
            yield rec

    @classmethod
    def _create_test_table(
        cls, kinetica_dbc: "gpudb.GPUdb", table_name: str, num_records: int
    ) -> "gpudb.GPUdbTable":
        """Create a table from the fake records generator.

        Args:
            kinetica_dbc: Kinetica connection used to create the table.
            table_name: Fully qualified name of the table to (re)create.
            num_records: Number of fake records to load.

        Returns:
            The ``GPUdbTable`` created from the records; any existing data is
            cleared first (``clear_table=True``).
        """
        import gpudb

        table_df = pd.DataFrame.from_records(
            data=cls._create_fake_records(num_records), index="id"
        )

        LOG.info(
            "Creating test table '%s' with %d records...", table_name, num_records
        )
        gpudb_table = gpudb.GPUdbTable.from_df(
            table_df,
            db=kinetica_dbc,
            table_name=table_name,
            clear_table=True,
            load_data=True,
            column_types={},
        )
        return gpudb_table

    @classmethod
    def _check_error(cls, response: dict) -> None:
        """Convert a DB error into an exception.

        Args:
            response: DB response dict; ``response["status_info"]["status"]`` is
                read, and ``response["status_info"]["message"]`` on failure.

        Raises:
            Exception: with message ``"[<status>]: <message>"`` when the status
                is not ``"OK"``.
        """
        status_info = response["status_info"]
        status = status_info["status"]
        if status != "OK":
            raise Exception(f"[{status}]: {status_info['message']}")

    @classmethod
    def _create_llm_context(
        cls, kinetica_dbc: "gpudb.GPUdb", context_name: str
    ) -> None:
        """Create an LLM context for the table.

        The context describes `cls.table_name` and contains one sample
        question/SQL pair.

        Args:
            kinetica_dbc: Kinetica connection used to execute the DDL.
            context_name: Fully qualified name of the context to (re)create.

        Raises:
            Exception: if the DB reports a non-OK status (via `_check_error`).
        """
        # Names interpolated here are the trusted class constants, not user input.
        sql = f"""
        CREATE OR REPLACE CONTEXT {context_name}
        (
            TABLE = {cls.table_name}
            COMMENT = 'Contains user profiles.'
        ),
        (
            SAMPLES = (
            'How many male users are there?' =
            'select count(1) as num_users
                from {cls.table_name}
                where sex = ''M'';')
        )
        """
        LOG.info("Creating context: %s", context_name)
        response = kinetica_dbc.execute_sql(sql)
        cls._check_error(response)
|