diff --git a/docs/snippets/modules/chains/popular/sqlite.mdx b/docs/snippets/modules/chains/popular/sqlite.mdx index 6319a0917c..fcd58eb5e7 100644 --- a/docs/snippets/modules/chains/popular/sqlite.mdx +++ b/docs/snippets/modules/chains/popular/sqlite.mdx @@ -191,6 +191,112 @@ result["intermediate_steps"] +## Adding Memory + +How to add memory to a SQLDatabaseChain: + +```python +from langchain.llms import OpenAI +from langchain.utilities import SQLDatabase +from langchain_experimental.sql import SQLDatabaseChain +``` + +Set up the SQLDatabase and LLM + +```python +db = SQLDatabase.from_uri("sqlite:///../../../../notebooks/Chinook.db") +llm = OpenAI(temperature=0, verbose=True) +``` + +Set up the memory + +```python +from langchain.memory import ConversationBufferMemory +memory = ConversationBufferMemory() +``` + +Now we need to add a place for memory in the prompt template + +```python +from langchain.prompts import PromptTemplate +PROMPT_SUFFIX = """Only use the following tables: +{table_info} + +Previous Conversation: +{history} + +Question: {input}""" + +_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database. + +Never query for all the columns from a specific table, only ask for a the few relevant columns given the question. + +Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. + +Use the following format: + +Question: Question here +SQLQuery: SQL Query to run +SQLResult: Result of the SQLQuery +Answer: Final answer here + +""" + +PROMPT = PromptTemplate.from_template( + _DEFAULT_TEMPLATE + PROMPT_SUFFIX, +) +``` + +Now let's create and run out chain + +```python +db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True, memory=memory) +db_chain.run("name one employee") +``` + + + +``` + > Entering new SQLDatabaseChain chain... + name one employee + SQLQuery:SELECT FirstName, LastName FROM Employee LIMIT 1 + SQLResult: [('Andrew', 'Adams')] + Answer:Andrew Adams + > Finished chain. + + + + + + 'Andrew Adams' +``` + + + +```python +db_chain.run("how many letters in their name?") +``` + + + +``` + > Entering new SQLDatabaseChain chain... + how many letters in their name? + SQLQuery:SELECT LENGTH(FirstName) + LENGTH(LastName) AS 'NameLength' FROM Employee WHERE FirstName = 'Andrew' AND LastName = 'Adams' + SQLResult: [(11,)] + Answer:Andrew Adams has 11 letters in their name. + > Finished chain. + + + + + + 'Andrew Adams has 11 letters in their name.' +``` + + + + ## Choosing how to limit the number of rows returned If you are querying for several rows of a table you can select the maximum number of results you want to get by using the 'top_k' parameter (default is 10). This is useful for avoiding query results that exceed the prompt max length or consume tokens unnecessarily. diff --git a/libs/experimental/langchain_experimental/sql/base.py b/libs/experimental/langchain_experimental/sql/base.py index 9144a705e3..5b220b5eb0 100644 --- a/libs/experimental/langchain_experimental/sql/base.py +++ b/libs/experimental/langchain_experimental/sql/base.py @@ -122,6 +122,9 @@ class SQLDatabaseChain(Chain): "table_info": table_info, "stop": ["\nSQLResult:"], } + if self.memory is not None: + for k in self.memory.memory_variables: + llm_inputs[k] = inputs[k] intermediate_steps: List = [] try: intermediate_steps.append(llm_inputs) # input: sql generation diff --git a/libs/experimental/tests/unit_tests/test_sql.py b/libs/experimental/tests/unit_tests/test_sql.py new file mode 100644 index 0000000000..06589804bb --- /dev/null +++ b/libs/experimental/tests/unit_tests/test_sql.py @@ -0,0 +1,128 @@ +from langchain.memory import ConversationBufferMemory +from langchain.output_parsers.list import CommaSeparatedListOutputParser +from langchain.prompts import PromptTemplate +from langchain.sql_database import SQLDatabase + +from langchain_experimental.sql.base import SQLDatabaseChain, SQLDatabaseSequentialChain +from tests.unit_tests.fake_llm import FakeLLM + +# Fake db to test SQL-Chain +db = SQLDatabase.from_uri("sqlite:///:memory:") + + +def create_fake_db(db: SQLDatabase) -> SQLDatabase: + """Create a table in fake db to test SQL-Chain""" + db.run( + """ + CREATE TABLE foo (baaz TEXT); + """ + ) + db.run( + """ + INSERT INTO foo (baaz) + VALUES ('baaz'); + """ + ) + return db + + +db = create_fake_db(db) + + +def test_sql_chain_without_memory() -> None: + queries = {"foo": "SELECT baaz from foo", "foo2": "SELECT baaz from foo"} + llm = FakeLLM(queries=queries, sequential_responses=True) + db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True) + assert db_chain.run("hello") == "SELECT baaz from foo" + + +def test_sql_chain_sequential_without_memory() -> None: + queries = { + "foo": "SELECT baaz from foo", + "foo2": "SELECT baaz from foo", + "foo3": "SELECT baaz from foo", + } + llm = FakeLLM(queries=queries, sequential_responses=True) + db_chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True) + assert db_chain.run("hello") == "SELECT baaz from foo" + + +def test_sql_chain_with_memory() -> None: + valid_prompt_with_history = """ + Only use the following tables: + {table_info} + Question: {input} + + Given an input question, first create a syntactically correct + {dialect} query to run. + Always limit your query to at most {top_k} results. + + Relevant pieces of previous conversation: + {history} + + (You do not need to use these pieces of information if not relevant) + """ + prompt = PromptTemplate( + input_variables=["input", "table_info", "dialect", "top_k", "history"], + template=valid_prompt_with_history, + ) + queries = {"foo": "SELECT baaz from foo", "foo2": "SELECT baaz from foo"} + llm = FakeLLM(queries=queries, sequential_responses=True) + memory = ConversationBufferMemory() + db_chain = SQLDatabaseChain.from_llm( + llm, db, memory=memory, prompt=prompt, verbose=True + ) + assert db_chain.run("hello") == "SELECT baaz from foo" + + +def test_sql_chain_sequential_with_memory() -> None: + valid_query_prompt_str = """ + Only use the following tables: + {table_info} + Question: {input} + + Given an input question, first create a syntactically correct + {dialect} query to run. + Always limit your query to at most {top_k} results. + + Relevant pieces of previous conversation: + {history} + + (You do not need to use these pieces of information + if not relevant) + """ + valid_decider_prompt_str = """Given the below input question and list of potential + tables, output a comma separated list of the + table names that may be necessary to answer this question. + + Question: {query} + + Table Names: {table_names} + + Relevant Table Names:""" + + valid_query_prompt = PromptTemplate( + input_variables=["input", "table_info", "dialect", "top_k", "history"], + template=valid_query_prompt_str, + ) + valid_decider_prompt = PromptTemplate( + input_variables=["query", "table_names"], + template=valid_decider_prompt_str, + output_parser=CommaSeparatedListOutputParser(), + ) + queries = { + "foo": "SELECT baaz from foo", + "foo2": "SELECT baaz from foo", + "foo3": "SELECT baaz from foo", + } + llm = FakeLLM(queries=queries, sequential_responses=True) + memory = ConversationBufferMemory(memory_key="history", input_key="query") + db_chain = SQLDatabaseSequentialChain.from_llm( + llm, + db, + memory=memory, + decider_prompt=valid_decider_prompt, + query_prompt=valid_query_prompt, + verbose=True, + ) + assert db_chain.run("hello") == "SELECT baaz from foo"