forked from Archives/langchain
db45970a66
# Fixes SQLAlchemy truncating the result if you have a big/text column with many chars. SQLAlchemy truncates columns if you try to convert a Row or Sequence to a string directly For comparison: - Before: ```[('Harrison', 'That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio ... (2 characters truncated) ... hat is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio ')]``` - After: ```[('Harrison', 'That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio ')]``` ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: I'm not sure who to tag for chains, maybe @vowelparrot ?
149 lines
4.1 KiB
Python
149 lines
4.1 KiB
Python
# flake8: noqa=E501
|
|
"""Test SQL database wrapper."""
|
|
|
|
from sqlalchemy import (
|
|
Column,
|
|
Integer,
|
|
MetaData,
|
|
String,
|
|
Table,
|
|
Text,
|
|
create_engine,
|
|
insert,
|
|
)
|
|
|
|
from langchain.sql_database import SQLDatabase, truncate_word
|
|
|
|
metadata_obj = MetaData()
|
|
|
|
user = Table(
|
|
"user",
|
|
metadata_obj,
|
|
Column("user_id", Integer, primary_key=True),
|
|
Column("user_name", String(16), nullable=False),
|
|
Column("user_bio", Text, nullable=True),
|
|
)
|
|
|
|
company = Table(
|
|
"company",
|
|
metadata_obj,
|
|
Column("company_id", Integer, primary_key=True),
|
|
Column("company_location", String, nullable=False),
|
|
)
|
|
|
|
|
|
def test_table_info() -> None:
|
|
"""Test that table info is constructed properly."""
|
|
engine = create_engine("sqlite:///:memory:")
|
|
metadata_obj.create_all(engine)
|
|
db = SQLDatabase(engine)
|
|
output = db.table_info
|
|
expected_output = """
|
|
CREATE TABLE user (
|
|
user_id INTEGER NOT NULL,
|
|
user_name VARCHAR(16) NOT NULL,
|
|
user_bio TEXT,
|
|
PRIMARY KEY (user_id)
|
|
)
|
|
/*
|
|
3 rows from user table:
|
|
user_id user_name user_bio
|
|
/*
|
|
|
|
|
|
CREATE TABLE company (
|
|
company_id INTEGER NOT NULL,
|
|
company_location VARCHAR NOT NULL,
|
|
PRIMARY KEY (company_id)
|
|
)
|
|
/*
|
|
3 rows from company table:
|
|
company_id company_location
|
|
*/
|
|
"""
|
|
|
|
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
|
|
|
|
|
|
def test_table_info_w_sample_rows() -> None:
|
|
"""Test that table info is constructed properly."""
|
|
engine = create_engine("sqlite:///:memory:")
|
|
metadata_obj.create_all(engine)
|
|
values = [
|
|
{"user_id": 13, "user_name": "Harrison", "user_bio": "bio"},
|
|
{"user_id": 14, "user_name": "Chase", "user_bio": "bio"},
|
|
]
|
|
stmt = insert(user).values(values)
|
|
with engine.begin() as conn:
|
|
conn.execute(stmt)
|
|
|
|
db = SQLDatabase(engine, sample_rows_in_table_info=2)
|
|
|
|
output = db.table_info
|
|
|
|
expected_output = """
|
|
CREATE TABLE company (
|
|
company_id INTEGER NOT NULL,
|
|
company_location VARCHAR NOT NULL,
|
|
PRIMARY KEY (company_id)
|
|
)
|
|
/*
|
|
2 rows from company table:
|
|
company_id company_location
|
|
*/
|
|
|
|
CREATE TABLE user (
|
|
user_id INTEGER NOT NULL,
|
|
user_name VARCHAR(16) NOT NULL,
|
|
user_bio TEXT,
|
|
PRIMARY KEY (user_id)
|
|
)
|
|
/*
|
|
2 rows from user table:
|
|
user_id user_name user_bio
|
|
13 Harrison bio
|
|
14 Chase bio
|
|
*/
|
|
"""
|
|
|
|
assert sorted(output.split()) == sorted(expected_output.split())
|
|
|
|
|
|
def test_sql_database_run() -> None:
|
|
"""Test that commands can be run successfully and returned in correct format."""
|
|
engine = create_engine("sqlite:///:memory:")
|
|
metadata_obj.create_all(engine)
|
|
stmt = insert(user).values(
|
|
user_id=13, user_name="Harrison", user_bio="That is my Bio " * 24
|
|
)
|
|
with engine.begin() as conn:
|
|
conn.execute(stmt)
|
|
db = SQLDatabase(engine)
|
|
command = "select user_id, user_name, user_bio from user where user_id = 13"
|
|
output = db.run(command)
|
|
user_bio = "That is my Bio " * 19 + "That is my..."
|
|
expected_output = f"[(13, 'Harrison', '{user_bio}')]"
|
|
assert output == expected_output
|
|
|
|
|
|
def test_sql_database_run_update() -> None:
|
|
"""Test commands which return no rows return an empty string."""
|
|
engine = create_engine("sqlite:///:memory:")
|
|
metadata_obj.create_all(engine)
|
|
stmt = insert(user).values(user_id=13, user_name="Harrison")
|
|
with engine.begin() as conn:
|
|
conn.execute(stmt)
|
|
db = SQLDatabase(engine)
|
|
command = "update user set user_name='Updated' where user_id = 13"
|
|
output = db.run(command)
|
|
expected_output = ""
|
|
assert output == expected_output
|
|
|
|
|
|
def test_truncate_word() -> None:
|
|
assert truncate_word("Hello World", length=5) == "He..."
|
|
assert truncate_word("Hello World", length=0) == "Hello World"
|
|
assert truncate_word("Hello World", length=-10) == "Hello World"
|
|
assert truncate_word("Hello World", length=5, suffix="!!!") == "He!!!"
|
|
assert truncate_word("Hello World", length=12, suffix="!!!") == "Hello World"
|