From cb04ba013602a6ef6aa924c33e6c7dd084323856 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 11 Mar 2023 15:44:41 -0800 Subject: [PATCH] Add support for intermediate steps to SQLDatabaseSequentialChain (#1583) (#1601) for https://github.com/hwchase17/langchain/issues/1582 I simply added the `return_intermediate_steps` and changed the `output_keys` function. I added 2 simple tests, 1 for SQLDatabaseSequentialChain without the intermediate steps and 1 with Co-authored-by: brad-nemetski <115185478+brad-nemetski@users.noreply.github.com> --- langchain/chains/sql_database/base.py | 7 ++- .../chains/test_sql_database.py | 49 ++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index ccb3416f..37a3bd17 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -117,6 +117,8 @@ class SQLDatabaseSequentialChain(Chain, BaseModel): This is useful in cases where the number of tables in the database is large. """ + return_intermediate_steps: bool = False + @classmethod def from_llm( cls, @@ -154,7 +156,10 @@ class SQLDatabaseSequentialChain(Chain, BaseModel): :meta private: """ - return [self.output_key] + if not self.return_intermediate_steps: + return [self.output_key] + else: + return [self.output_key, "intermediate_steps"] def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: _table_names = self.sql_chain.database.get_table_names() diff --git a/tests/integration_tests/chains/test_sql_database.py b/tests/integration_tests/chains/test_sql_database.py index 67d82f02..3518866c 100644 --- a/tests/integration_tests/chains/test_sql_database.py +++ b/tests/integration_tests/chains/test_sql_database.py @@ -1,7 +1,10 @@ """Test SQL Database Chain.""" from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert -from langchain.chains.sql_database.base import SQLDatabaseChain +from langchain.chains.sql_database.base import ( + SQLDatabaseChain, + SQLDatabaseSequentialChain, +) from langchain.llms.openai import OpenAI from langchain.sql_database import SQLDatabase @@ -45,3 +48,47 @@ def test_sql_database_run_update() -> None: output = db_chain.run("What company does Harrison work at?") expected_output = " Harrison works at Bar." assert output == expected_output + + +def test_sql_database_sequential_chain_run() -> None: + """Test that commands can be run successfully SEQUENTIALLY + and returned in correct format.""" + engine = create_engine("sqlite:///:memory:") + metadata_obj.create_all(engine) + stmt = insert(user).values(user_id=13, user_name="Harrison", user_company="Foo") + with engine.connect() as conn: + conn.execute(stmt) + db = SQLDatabase(engine) + db_chain = SQLDatabaseSequentialChain.from_llm( + llm=OpenAI(temperature=0), database=db + ) + output = db_chain.run("What company does Harrison work at?") + expected_output = " Harrison works at Foo." + assert output == expected_output + + +def test_sql_database_sequential_chain_intermediate_steps() -> None: + """Test that commands can be run successfully SEQUENTIALLY and returned + in correct format. sWith Intermediate steps""" + engine = create_engine("sqlite:///:memory:") + metadata_obj.create_all(engine) + stmt = insert(user).values(user_id=13, user_name="Harrison", user_company="Foo") + with engine.connect() as conn: + conn.execute(stmt) + db = SQLDatabase(engine) + db_chain = SQLDatabaseSequentialChain.from_llm( + llm=OpenAI(temperature=0), database=db, return_intermediate_steps=True + ) + output = db_chain("What company does Harrison work at?") + expected_output = " Harrison works at Foo." + assert output["result"] == expected_output + + query = output["intermediate_steps"][0] + expected_query = ( + " SELECT user_company FROM user WHERE user_name = 'Harrison' LIMIT 1;" + ) + assert query == expected_query + + query_results = output["intermediate_steps"][1] + expected_query_results = "[('Foo',)]" + assert query_results == expected_query_results