forked from Archives/langchain
Fix SQLAlchemy truncating text when it is too big (#5206)
# Fixes SQLAlchemy truncating the result if you have a big/text column with many chars. SQLAlchemy truncates columns if you try to convert a Row or Sequence to a string directly For comparison: - Before: ```[('Harrison', 'That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio ... (2 characters truncated) ... hat is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio ')]``` - After: ```[('Harrison', 'That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio That is my Bio ')]``` ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: I'm not sure who to tag for chains, maybe @vowelparrot ?
This commit is contained in:
parent
4c572ffe95
commit
db45970a66
@ -5,14 +5,7 @@ import warnings
|
|||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import (
|
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
|
||||||
MetaData,
|
|
||||||
Table,
|
|
||||||
create_engine,
|
|
||||||
inspect,
|
|
||||||
select,
|
|
||||||
text,
|
|
||||||
)
|
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||||
from sqlalchemy.schema import CreateTable
|
from sqlalchemy.schema import CreateTable
|
||||||
@ -27,6 +20,21 @@ def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Shorten a string to at most ``length`` characters, cutting at a space.

    The limit is measured in characters, not words. When ``content`` is too
    long, it is cut so that the kept text plus ``suffix`` fits within
    ``length``, and the trailing (possibly partial) word of the kept text is
    dropped before ``suffix`` is appended.

    Args:
        content: Value to shorten. Anything that is not a ``str`` is returned
            unchanged, so non-text values can be passed through as-is.
        length: Maximum number of characters in the result. A value of zero
            or less disables truncation and returns ``content`` unchanged.
        suffix: Marker appended to show that text was cut off.

    Returns:
        ``content`` itself when it is not a string, when ``length`` is not
        positive, or when it already fits; otherwise the shortened text
        ending in ``suffix`` (or a prefix of ``suffix`` when ``length`` is
        too small to hold any of the original text).
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    # Number of characters of the original text that fit beside the suffix.
    keep = length - len(suffix)
    if keep <= 0:
        # A zero or negative slice bound would otherwise count from the end
        # of the string and yield a result longer than ``length``.
        return suffix[:length]

    # Drop the last (possibly partial) word so the cut lands on a space.
    return content[:keep].rsplit(" ", 1)[0] + suffix
|
||||||
|
|
||||||
|
|
||||||
class SQLDatabase:
|
class SQLDatabase:
|
||||||
"""SQLAlchemy wrapper around a database."""
|
"""SQLAlchemy wrapper around a database."""
|
||||||
|
|
||||||
@ -41,6 +49,7 @@ class SQLDatabase:
|
|||||||
indexes_in_table_info: bool = False,
|
indexes_in_table_info: bool = False,
|
||||||
custom_table_info: Optional[dict] = None,
|
custom_table_info: Optional[dict] = None,
|
||||||
view_support: bool = False,
|
view_support: bool = False,
|
||||||
|
max_string_length: int = 300,
|
||||||
):
|
):
|
||||||
"""Create engine from database URI."""
|
"""Create engine from database URI."""
|
||||||
self._engine = engine
|
self._engine = engine
|
||||||
@ -95,6 +104,8 @@ class SQLDatabase:
|
|||||||
if table in intersection
|
if table in intersection
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._max_string_length = max_string_length
|
||||||
|
|
||||||
self._metadata = metadata or MetaData()
|
self._metadata = metadata or MetaData()
|
||||||
# including view support if view_support = true
|
# including view support if view_support = true
|
||||||
self._metadata.reflect(
|
self._metadata.reflect(
|
||||||
@ -322,6 +333,7 @@ class SQLDatabase:
|
|||||||
|
|
||||||
If the statement returns rows, a string of the results is returned.
|
If the statement returns rows, a string of the results is returned.
|
||||||
If the statement returns no rows, an empty string is returned.
|
If the statement returns no rows, an empty string is returned.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with self._engine.begin() as connection:
|
with self._engine.begin() as connection:
|
||||||
if self._schema is not None:
|
if self._schema is not None:
|
||||||
@ -338,10 +350,28 @@ class SQLDatabase:
|
|||||||
if fetch == "all":
|
if fetch == "all":
|
||||||
result = cursor.fetchall()
|
result = cursor.fetchall()
|
||||||
elif fetch == "one":
|
elif fetch == "one":
|
||||||
result = cursor.fetchone()[0] # type: ignore
|
result = cursor.fetchone() # type: ignore
|
||||||
else:
|
else:
|
||||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||||
return str(result)
|
|
||||||
|
# Convert column values to strings to avoid issues with SQLAlchemy
|
||||||
|
# truncating text
|
||||||
|
if isinstance(result, list):
|
||||||
|
return str(
|
||||||
|
[
|
||||||
|
tuple(
|
||||||
|
truncate_word(c, length=self._max_string_length)
|
||||||
|
for c in r
|
||||||
|
)
|
||||||
|
for r in result
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return str(
|
||||||
|
tuple(
|
||||||
|
truncate_word(c, length=self._max_string_length) for c in result
|
||||||
|
)
|
||||||
|
)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
|
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
|
||||||
|
@ -1,9 +1,18 @@
|
|||||||
# flake8: noqa=E501
|
# flake8: noqa=E501
|
||||||
"""Test SQL database wrapper."""
|
"""Test SQL database wrapper."""
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Integer,
|
||||||
|
MetaData,
|
||||||
|
String,
|
||||||
|
Table,
|
||||||
|
Text,
|
||||||
|
create_engine,
|
||||||
|
insert,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.sql_database import SQLDatabase, truncate_word
|
||||||
|
|
||||||
metadata_obj = MetaData()
|
metadata_obj = MetaData()
|
||||||
|
|
||||||
@ -12,6 +21,7 @@ user = Table(
|
|||||||
metadata_obj,
|
metadata_obj,
|
||||||
Column("user_id", Integer, primary_key=True),
|
Column("user_id", Integer, primary_key=True),
|
||||||
Column("user_name", String(16), nullable=False),
|
Column("user_name", String(16), nullable=False),
|
||||||
|
Column("user_bio", Text, nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
company = Table(
|
company = Table(
|
||||||
@ -32,11 +42,12 @@ def test_table_info() -> None:
|
|||||||
CREATE TABLE user (
|
CREATE TABLE user (
|
||||||
user_id INTEGER NOT NULL,
|
user_id INTEGER NOT NULL,
|
||||||
user_name VARCHAR(16) NOT NULL,
|
user_name VARCHAR(16) NOT NULL,
|
||||||
|
user_bio TEXT,
|
||||||
PRIMARY KEY (user_id)
|
PRIMARY KEY (user_id)
|
||||||
)
|
)
|
||||||
/*
|
/*
|
||||||
3 rows from user table:
|
3 rows from user table:
|
||||||
user_id user_name
|
user_id user_name user_bio
|
||||||
/*
|
/*
|
||||||
|
|
||||||
|
|
||||||
@ -59,8 +70,8 @@ def test_table_info_w_sample_rows() -> None:
|
|||||||
engine = create_engine("sqlite:///:memory:")
|
engine = create_engine("sqlite:///:memory:")
|
||||||
metadata_obj.create_all(engine)
|
metadata_obj.create_all(engine)
|
||||||
values = [
|
values = [
|
||||||
{"user_id": 13, "user_name": "Harrison"},
|
{"user_id": 13, "user_name": "Harrison", "user_bio": "bio"},
|
||||||
{"user_id": 14, "user_name": "Chase"},
|
{"user_id": 14, "user_name": "Chase", "user_bio": "bio"},
|
||||||
]
|
]
|
||||||
stmt = insert(user).values(values)
|
stmt = insert(user).values(values)
|
||||||
with engine.begin() as conn:
|
with engine.begin() as conn:
|
||||||
@ -84,13 +95,14 @@ def test_table_info_w_sample_rows() -> None:
|
|||||||
CREATE TABLE user (
|
CREATE TABLE user (
|
||||||
user_id INTEGER NOT NULL,
|
user_id INTEGER NOT NULL,
|
||||||
user_name VARCHAR(16) NOT NULL,
|
user_name VARCHAR(16) NOT NULL,
|
||||||
|
user_bio TEXT,
|
||||||
PRIMARY KEY (user_id)
|
PRIMARY KEY (user_id)
|
||||||
)
|
)
|
||||||
/*
|
/*
|
||||||
2 rows from user table:
|
2 rows from user table:
|
||||||
user_id user_name
|
user_id user_name user_bio
|
||||||
13 Harrison
|
13 Harrison bio
|
||||||
14 Chase
|
14 Chase bio
|
||||||
*/
|
*/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -101,13 +113,16 @@ def test_sql_database_run() -> None:
|
|||||||
"""Test that commands can be run successfully and returned in correct format."""
|
"""Test that commands can be run successfully and returned in correct format."""
|
||||||
engine = create_engine("sqlite:///:memory:")
|
engine = create_engine("sqlite:///:memory:")
|
||||||
metadata_obj.create_all(engine)
|
metadata_obj.create_all(engine)
|
||||||
stmt = insert(user).values(user_id=13, user_name="Harrison")
|
stmt = insert(user).values(
|
||||||
|
user_id=13, user_name="Harrison", user_bio="That is my Bio " * 24
|
||||||
|
)
|
||||||
with engine.begin() as conn:
|
with engine.begin() as conn:
|
||||||
conn.execute(stmt)
|
conn.execute(stmt)
|
||||||
db = SQLDatabase(engine)
|
db = SQLDatabase(engine)
|
||||||
command = "select user_name from user where user_id = 13"
|
command = "select user_id, user_name, user_bio from user where user_id = 13"
|
||||||
output = db.run(command)
|
output = db.run(command)
|
||||||
expected_output = "[('Harrison',)]"
|
user_bio = "That is my Bio " * 19 + "That is my..."
|
||||||
|
expected_output = f"[(13, 'Harrison', '{user_bio}')]"
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
@ -123,3 +138,11 @@ def test_sql_database_run_update() -> None:
|
|||||||
output = db.run(command)
|
output = db.run(command)
|
||||||
expected_output = ""
|
expected_output = ""
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncate_word() -> None:
    """Check truncate_word with small, zero, negative and over-long limits."""
    sample = "Hello World"
    # Each entry pairs the keyword arguments with the expected result.
    cases = [
        ({"length": 5}, "He..."),
        ({"length": 0}, "Hello World"),
        ({"length": -10}, "Hello World"),
        ({"length": 5, "suffix": "!!!"}, "He!!!"),
        ({"length": 12, "suffix": "!!!"}, "Hello World"),
    ]
    for kwargs, expected in cases:
        assert truncate_word(sample, **kwargs) == expected
|
||||||
|
Loading…
Reference in New Issue
Block a user