"""Unit tests for SQLDatabaseChain / SQLDatabaseSequentialChain.

The chains run against a FakeLLM and a small in-memory SQLite database.
"""

from langchain.memory import ConversationBufferMemory
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.sql_database import SQLDatabase
from langchain_core.prompts import PromptTemplate

from langchain_experimental.sql.base import SQLDatabaseChain, SQLDatabaseSequentialChain
from tests.unit_tests.fake_llm import FakeLLM

# Value of every scripted FakeLLM response, and therefore the final answer
# each chain under test is expected to return.
_EXPECTED_SQL = "SELECT baaz from foo"

# Query prompt shared by both memory tests. Besides the chain's standard
# inputs it declares a ``{history}`` variable for conversation memory.
_QUERY_PROMPT_WITH_HISTORY = """
Only use the following tables:
{table_info}
Question: {input}

Given an input question, first create a syntactically correct
{dialect} query to run.
Always limit your query to at most {top_k} results.

Relevant pieces of previous conversation:
{history}

(You do not need to use these pieces of information
if not relevant)
"""

# Decider prompt for the sequential chain; its output is parsed as a
# comma-separated list of table names.
_DECIDER_PROMPT = """Given the below input question and list of potential
tables, output a comma separated list of the
table names that may be necessary to answer this question.

Question: {query}

Table Names: {table_names}

Relevant Table Names:"""

# Fake db to test SQL-Chain
db = SQLDatabase.from_uri("sqlite:///:memory:")


def create_fake_db(db: SQLDatabase) -> SQLDatabase:
    """Create and populate table ``foo`` in *db* for the SQL-chain tests.

    Args:
        db: Database to seed.

    Returns:
        The same ``db`` instance, after the statements have run.
    """
    # One statement per call: SQLite drivers typically refuse to execute
    # several statements in a single call.
    db.run("CREATE TABLE foo (baaz TEXT);")
    db.run("INSERT INTO foo (baaz) VALUES ('baaz');")
    return db


db = create_fake_db(db)


def _fake_llm(num_responses: int) -> FakeLLM:
    """Return a FakeLLM scripted with *num_responses* copies of ``_EXPECTED_SQL``.

    The keys follow the ``foo``, ``foo2``, ``foo3``, ... naming. With
    ``sequential_responses=True`` the fake presumably replays the values in
    order, one per LLM call (assumption from the flag name — verify against
    FakeLLM).
    """
    keys = ["foo"] + [f"foo{i}" for i in range(2, num_responses + 1)]
    return FakeLLM(
        queries=dict.fromkeys(keys, _EXPECTED_SQL),
        sequential_responses=True,
    )


def _query_prompt_with_history() -> PromptTemplate:
    """Build a fresh query prompt that also accepts a ``history`` variable."""
    return PromptTemplate(
        input_variables=["input", "table_info", "dialect", "top_k", "history"],
        template=_QUERY_PROMPT_WITH_HISTORY,
    )


def test_sql_chain_without_memory() -> None:
    # Two scripted responses: presumably one for SQL generation and one for
    # the final answer.
    db_chain = SQLDatabaseChain.from_llm(_fake_llm(2), db, verbose=True)
    assert db_chain.run("hello") == _EXPECTED_SQL


def test_sql_chain_sequential_without_memory() -> None:
    # One extra scripted response, presumably consumed by the table decider.
    db_chain = SQLDatabaseSequentialChain.from_llm(_fake_llm(3), db, verbose=True)
    assert db_chain.run("hello") == _EXPECTED_SQL


def test_sql_chain_with_memory() -> None:
    memory = ConversationBufferMemory()
    db_chain = SQLDatabaseChain.from_llm(
        _fake_llm(2),
        db,
        memory=memory,
        prompt=_query_prompt_with_history(),
        verbose=True,
    )
    assert db_chain.run("hello") == _EXPECTED_SQL


def test_sql_chain_sequential_with_memory() -> None:
    decider_prompt = PromptTemplate(
        input_variables=["query", "table_names"],
        template=_DECIDER_PROMPT,
        output_parser=CommaSeparatedListOutputParser(),
    )
    # input_key is pinned explicitly. Presumably the memory sees more than one
    # non-memory input here (the query plus the selected table names) and
    # cannot infer which to record — verify against SQLDatabaseSequentialChain.
    memory = ConversationBufferMemory(memory_key="history", input_key="query")
    db_chain = SQLDatabaseSequentialChain.from_llm(
        _fake_llm(3),
        db,
        memory=memory,
        decider_prompt=decider_prompt,
        query_prompt=_query_prompt_with_history(),
        verbose=True,
    )
    assert db_chain.run("hello") == _EXPECTED_SQL