Fix SQLAlchemy truncating text when it is too big (#5206)

# Fixes SQLAlchemy truncating the result when a text column holds a
long value.

SQLAlchemy truncates long column values when a Row or Sequence is
converted to a string directly (e.g. with `str(result)`).

For comparison:

- Before:
```[('Harrison', 'That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio ... (2 characters truncated) ... hat is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio ')]```

- After:
```[('Harrison', 'That is my Bio That is my Bio That is my Bio That is
my Bio That is my Bio That is my Bio That is my Bio That is my Bio That
is my Bio That is my Bio That is my Bio That is my Bio That is my Bio
That is my Bio That is my Bio That is my Bio That is my Bio That is my
Bio That is my Bio That is my Bio ')]```



## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:

I'm not sure who to tag for chains; maybe @vowelparrot?
This commit is contained in:
Waldecir Santos 2023-06-01 20:33:31 -05:00 committed by GitHub
parent 4c572ffe95
commit db45970a66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 21 deletions

View File

@ -5,14 +5,7 @@ import warnings
from typing import Any, Iterable, List, Optional
import sqlalchemy
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
@ -27,6 +20,21 @@ def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
)
def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Truncate a string on a word boundary, based on a max string length.

    Args:
        content: Value to truncate. Anything that is not a ``str`` is
            returned unchanged.
        length: Target maximum length of the result, including ``suffix``.
            A value of zero or less disables truncation.
        suffix: Marker appended to text that was truncated.

    Returns:
        ``content`` itself when it is not a string, when ``length <= 0``,
        or when it already fits within ``length``. Otherwise the first
        ``length - len(suffix)`` characters, cut back to just before the
        last space in that prefix (the whole prefix is kept if it contains
        no space), followed by ``suffix``. When ``length`` is not larger
        than ``len(suffix)``, only ``suffix`` is returned.
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    # Clamp to zero: with ``length < len(suffix)`` a negative slice end
    # would count from the end of the string and keep almost all of it.
    keep = max(length - len(suffix), 0)
    return content[:keep].rsplit(" ", 1)[0] + suffix
class SQLDatabase:
"""SQLAlchemy wrapper around a database."""
@ -41,6 +49,7 @@ class SQLDatabase:
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None,
view_support: bool = False,
max_string_length: int = 300,
):
"""Create engine from database URI."""
self._engine = engine
@ -95,6 +104,8 @@ class SQLDatabase:
if table in intersection
)
self._max_string_length = max_string_length
self._metadata = metadata or MetaData()
# including view support if view_support = true
self._metadata.reflect(
@ -322,6 +333,7 @@ class SQLDatabase:
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
with self._engine.begin() as connection:
if self._schema is not None:
@ -338,10 +350,28 @@ class SQLDatabase:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
result = cursor.fetchone() # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
return str(result)
# Convert column values to strings to avoid issues with SQLAlchemy
# truncating text
if isinstance(result, list):
return str(
[
tuple(
truncate_word(c, length=self._max_string_length)
for c in r
)
for r in result
]
)
return str(
tuple(
truncate_word(c, length=self._max_string_length) for c in result
)
)
return ""
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:

View File

@ -1,9 +1,18 @@
# flake8: noqa=E501
"""Test SQL database wrapper."""
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
from sqlalchemy import (
Column,
Integer,
MetaData,
String,
Table,
Text,
create_engine,
insert,
)
from langchain.sql_database import SQLDatabase
from langchain.sql_database import SQLDatabase, truncate_word
metadata_obj = MetaData()
@ -12,6 +21,7 @@ user = Table(
metadata_obj,
Column("user_id", Integer, primary_key=True),
Column("user_name", String(16), nullable=False),
Column("user_bio", Text, nullable=True),
)
company = Table(
@ -32,11 +42,12 @@ def test_table_info() -> None:
CREATE TABLE user (
user_id INTEGER NOT NULL,
user_name VARCHAR(16) NOT NULL,
user_bio TEXT,
PRIMARY KEY (user_id)
)
/*
3 rows from user table:
user_id user_name
user_id user_name user_bio
/*
@ -59,8 +70,8 @@ def test_table_info_w_sample_rows() -> None:
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
values = [
{"user_id": 13, "user_name": "Harrison"},
{"user_id": 14, "user_name": "Chase"},
{"user_id": 13, "user_name": "Harrison", "user_bio": "bio"},
{"user_id": 14, "user_name": "Chase", "user_bio": "bio"},
]
stmt = insert(user).values(values)
with engine.begin() as conn:
@ -84,13 +95,14 @@ def test_table_info_w_sample_rows() -> None:
CREATE TABLE user (
user_id INTEGER NOT NULL,
user_name VARCHAR(16) NOT NULL,
user_bio TEXT,
PRIMARY KEY (user_id)
)
/*
2 rows from user table:
user_id user_name
13 Harrison
14 Chase
user_id user_name user_bio
13 Harrison bio
14 Chase bio
*/
"""
@ -101,13 +113,16 @@ def test_sql_database_run() -> None:
"""Test that commands can be run successfully and returned in correct format."""
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison")
stmt = insert(user).values(
user_id=13, user_name="Harrison", user_bio="That is my Bio " * 24
)
with engine.begin() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
command = "select user_name from user where user_id = 13"
command = "select user_id, user_name, user_bio from user where user_id = 13"
output = db.run(command)
expected_output = "[('Harrison',)]"
user_bio = "That is my Bio " * 19 + "That is my..."
expected_output = f"[(13, 'Harrison', '{user_bio}')]"
assert output == expected_output
@ -123,3 +138,11 @@ def test_sql_database_run_update() -> None:
output = db.run(command)
expected_output = ""
assert output == expected_output
def test_truncate_word() -> None:
    """Check truncate_word against a table of (length, kwargs, expected)."""
    sample = "Hello World"
    cases = [
        # Truncated: prefix plus the default suffix fits in 5 characters.
        (5, {}, "He..."),
        # Non-positive lengths leave the text untouched.
        (0, {}, "Hello World"),
        (-10, {}, "Hello World"),
        # A custom suffix replaces the default one.
        (5, {"suffix": "!!!"}, "He!!!"),
        # Text that already fits is returned as is, without a suffix.
        (12, {"suffix": "!!!"}, "Hello World"),
    ]
    for max_len, extra, expected in cases:
        assert truncate_word(sample, length=max_len, **extra) == expected