diff --git a/libs/community/langchain_community/utilities/sql_database.py b/libs/community/langchain_community/utilities/sql_database.py index d573c8c9d9..1eda15375a 100644 --- a/libs/community/langchain_community/utilities/sql_database.py +++ b/libs/community/langchain_community/utilities/sql_database.py @@ -427,6 +427,7 @@ class SQLDatabase: self, command: str, fetch: Union[Literal["all"], Literal["one"]] = "all", + include_columns: bool = False, ) -> str: """Execute a SQL command and return a string representing the results. @@ -434,12 +435,18 @@ class SQLDatabase: If the statement returns no rows, an empty string is returned. """ result = self._execute(command, fetch) - # Convert columns values to string to avoid issues with sqlalchemy - # truncating text + res = [ - tuple(truncate_word(c, length=self._max_string_length) for c in r.values()) + { + column: truncate_word(value, length=self._max_string_length) + for column, value in r.items() + } for r in result ] + + if not include_columns: + res = [tuple(row.values()) for row in res] + if not res: return "" else: @@ -465,6 +472,7 @@ class SQLDatabase: self, command: str, fetch: Union[Literal["all"], Literal["one"]] = "all", + include_columns: bool = False, ) -> str: """Execute a SQL command and return a string representing the results. @@ -474,7 +482,7 @@ class SQLDatabase: If the statement throws an error, the error message is returned. """ try: - return self.run(command, fetch) + return self.run(command, fetch, include_columns) except SQLAlchemyError as e: """Format the error message""" return f"Error: {e}" diff --git a/libs/langchain/tests/unit_tests/test_sql_database.py b/libs/langchain/tests/unit_tests/test_sql_database.py index a5513be2d0..cd1d049d50 100644 --- a/libs/langchain/tests/unit_tests/test_sql_database.py +++ b/libs/langchain/tests/unit_tests/test_sql_database.py @@ -120,10 +120,16 @@ def test_sql_database_run() -> None: conn.execute(stmt) db = SQLDatabase(engine) command = "select user_id, user_name, user_bio from user where user_id = 13" - output = db.run(command) + partial_output = db.run(command) user_bio = "That is my Bio " * 19 + "That is my..." - expected_output = f"[(13, 'Harrison', '{user_bio}')]" - assert output == expected_output + expected_partial_output = f"[(13, 'Harrison', '{user_bio}')]" + assert partial_output == expected_partial_output + + full_output = db.run(command, include_columns=True) + expected_full_output = ( + "[{'user_id': 13, 'user_name': 'Harrison', 'user_bio': '%s'}]" % user_bio + ) + assert full_output == expected_full_output def test_sql_database_run_update() -> None: