Add support for intermediate steps to SQLDatabaseSequentialChain (#1583) (#1601)

for https://github.com/hwchase17/langchain/issues/1582

I simply added the `return_intermediate_steps` and changed the
`output_keys` function.

I added 2 simple tests, 1 for SQLDatabaseSequentialChain without the
intermediate steps and 1 with

Co-authored-by: brad-nemetski <115185478+brad-nemetski@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-03-11 15:44:41 -08:00 committed by GitHub
parent 5903a93f3d
commit cb04ba0136
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 2 deletions

View File

@ -117,6 +117,8 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
This is useful in cases where the number of tables in the database is large. This is useful in cases where the number of tables in the database is large.
""" """
return_intermediate_steps: bool = False
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
@ -154,7 +156,10 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
:meta private: :meta private:
""" """
return [self.output_key] if not self.return_intermediate_steps:
return [self.output_key]
else:
return [self.output_key, "intermediate_steps"]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
_table_names = self.sql_chain.database.get_table_names() _table_names = self.sql_chain.database.get_table_names()

View File

@ -1,7 +1,10 @@
"""Test SQL Database Chain.""" """Test SQL Database Chain."""
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
from langchain.chains.sql_database.base import SQLDatabaseChain from langchain.chains.sql_database.base import (
SQLDatabaseChain,
SQLDatabaseSequentialChain,
)
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
from langchain.sql_database import SQLDatabase from langchain.sql_database import SQLDatabase
@ -45,3 +48,47 @@ def test_sql_database_run_update() -> None:
output = db_chain.run("What company does Harrison work at?") output = db_chain.run("What company does Harrison work at?")
expected_output = " Harrison works at Bar." expected_output = " Harrison works at Bar."
assert output == expected_output assert output == expected_output
def test_sql_database_sequential_chain_run() -> None:
"""Test that commands can be run successfully SEQUENTIALLY
and returned in correct format."""
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison", user_company="Foo")
with engine.connect() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
db_chain = SQLDatabaseSequentialChain.from_llm(
llm=OpenAI(temperature=0), database=db
)
output = db_chain.run("What company does Harrison work at?")
expected_output = " Harrison works at Foo."
assert output == expected_output
def test_sql_database_sequential_chain_intermediate_steps() -> None:
"""Test that commands can be run successfully SEQUENTIALLY and returned
in correct format. sWith Intermediate steps"""
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison", user_company="Foo")
with engine.connect() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
db_chain = SQLDatabaseSequentialChain.from_llm(
llm=OpenAI(temperature=0), database=db, return_intermediate_steps=True
)
output = db_chain("What company does Harrison work at?")
expected_output = " Harrison works at Foo."
assert output["result"] == expected_output
query = output["intermediate_steps"][0]
expected_query = (
" SELECT user_company FROM user WHERE user_name = 'Harrison' LIMIT 1;"
)
assert query == expected_query
query_results = output["intermediate_steps"][1]
expected_query_results = "[('Foo',)]"
assert query_results == expected_query_results