def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Shorten a string to at most ``length`` characters on a word boundary.

    The result, including ``suffix``, never exceeds ``length`` characters.

    Args:
        content: Value to shorten. Anything that is not a ``str`` is returned
            unchanged, so mixed-type database rows can be passed through
            cell by cell.
        length: Maximum length of the returned string. A value ``<= 0``
            disables truncation and returns ``content`` unchanged.
        suffix: Marker appended when text was cut off.

    Returns:
        ``content`` itself when no truncation is needed; otherwise the leading
        part of ``content``, cut back to the last space inside the budget,
        followed by ``suffix``. When ``suffix`` alone is as long as or longer
        than ``length``, the suffix cut down to ``length`` characters.
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    # Characters of the original text that fit once the suffix is added.
    keep = length - len(suffix)
    if keep <= 0:
        # A negative slice bound would count from the END of the string and
        # hand back almost the whole text; emit only the (clipped) suffix.
        return suffix[:length]

    # Drop a trailing partial word by cutting at the last space, if any.
    return content[:keep].rsplit(" ", 1)[0] + suffix
+ """ with self._engine.begin() as connection: if self._schema is not None: @@ -338,10 +350,28 @@ class SQLDatabase: if fetch == "all": result = cursor.fetchall() elif fetch == "one": - result = cursor.fetchone()[0] # type: ignore + result = cursor.fetchone() # type: ignore else: raise ValueError("Fetch parameter must be either 'one' or 'all'") - return str(result) + + # Convert column values to string to avoid issues with sqlalchemy + # truncating text + if isinstance(result, list): + return str( + [ + tuple( + truncate_word(c, length=self._max_string_length) + for c in r + ) + for r in result + ] + ) + + return str( + tuple( + truncate_word(c, length=self._max_string_length) for c in result + ) + ) return "" def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str: diff --git a/tests/unit_tests/test_sql_database.py b/tests/unit_tests/test_sql_database.py index 3da40c5a..80f00950 100644 --- a/tests/unit_tests/test_sql_database.py +++ b/tests/unit_tests/test_sql_database.py @@ -1,9 +1,18 @@ # flake8: noqa=E501 """Test SQL database wrapper.""" -from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert +from sqlalchemy import ( + Column, + Integer, + MetaData, + String, + Table, + Text, + create_engine, + insert, +) -from langchain.sql_database import SQLDatabase +from langchain.sql_database import SQLDatabase, truncate_word metadata_obj = MetaData() @@ -12,6 +21,7 @@ user = Table( metadata_obj, Column("user_id", Integer, primary_key=True), Column("user_name", String(16), nullable=False), + Column("user_bio", Text, nullable=True), ) company = Table( @@ -32,11 +42,12 @@ def test_table_info() -> None: CREATE TABLE user ( user_id INTEGER NOT NULL, user_name VARCHAR(16) NOT NULL, + user_bio TEXT, PRIMARY KEY (user_id) ) /* 3 rows from user table: - user_id user_name + user_id user_name user_bio /* @@ -59,8 +70,8 @@ def test_table_info_w_sample_rows() -> None: engine = create_engine("sqlite:///:memory:") 
def test_truncate_word() -> None:
    """Check truncate_word against a table of (keyword arguments, expected) cases."""
    text = "Hello World"
    cases = [
        # Cut to fit the limit, default suffix appended.
        ({"length": 5}, "He..."),
        # Zero and negative limits leave the text untouched.
        ({"length": 0}, "Hello World"),
        ({"length": -10}, "Hello World"),
        # A custom suffix replaces the default one.
        ({"length": 5, "suffix": "!!!"}, "He!!!"),
        # Text already within the limit comes back as-is.
        ({"length": 12, "suffix": "!!!"}, "Hello World"),
    ]
    for kwargs, expected in cases:
        assert truncate_word(text, **kwargs) == expected