Add memory to sql chain (#8597)

continuation of PR #8550

@hwchase17 please see and merge. And also close the PR #8550.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
pull/11361/head
Mohammad Mohtashim 9 months ago committed by GitHub
parent feabf2e0d5
commit 3bddd708f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -191,6 +191,112 @@ result["intermediate_steps"]
</CodeOutputBlock>
## Adding Memory
How to add memory to a SQLDatabaseChain:
```python
from langchain.llms import OpenAI
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
```
Set up the SQLDatabase and LLM
```python
db = SQLDatabase.from_uri("sqlite:///../../../../notebooks/Chinook.db")
llm = OpenAI(temperature=0, verbose=True)
```
Set up the memory
```python
from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory()
```
Now we need to add a place for memory in the prompt template
```python
from langchain.prompts import PromptTemplate
PROMPT_SUFFIX = """Only use the following tables:
{table_info}
Previous Conversation:
{history}
Question: {input}"""
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Use the following format:
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
"""
PROMPT = PromptTemplate.from_template(
_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
)
```
Now let's create and run out chain
```python
db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True, memory=memory)
db_chain.run("name one employee")
```
<CodeOutputBlock lang="python">
```
> Entering new SQLDatabaseChain chain...
name one employee
SQLQuery:SELECT FirstName, LastName FROM Employee LIMIT 1
SQLResult: [('Andrew', 'Adams')]
Answer:Andrew Adams
> Finished chain.
'Andrew Adams'
```
</CodeOutputBlock>
```python
db_chain.run("how many letters in their name?")
```
<CodeOutputBlock lang="python">
```
> Entering new SQLDatabaseChain chain...
how many letters in their name?
SQLQuery:SELECT LENGTH(FirstName) + LENGTH(LastName) AS 'NameLength' FROM Employee WHERE FirstName = 'Andrew' AND LastName = 'Adams'
SQLResult: [(11,)]
Answer:Andrew Adams has 11 letters in their name.
> Finished chain.
'Andrew Adams has 11 letters in their name.'
```
</CodeOutputBlock>
## Choosing how to limit the number of rows returned
If you are querying for several rows of a table you can select the maximum number of results you want to get by using the 'top_k' parameter (default is 10). This is useful for avoiding query results that exceed the prompt max length or consume tokens unnecessarily.

@ -122,6 +122,9 @@ class SQLDatabaseChain(Chain):
"table_info": table_info,
"stop": ["\nSQLResult:"],
}
if self.memory is not None:
for k in self.memory.memory_variables:
llm_inputs[k] = inputs[k]
intermediate_steps: List = []
try:
intermediate_steps.append(llm_inputs) # input: sql generation

@ -0,0 +1,128 @@
from langchain.memory import ConversationBufferMemory
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql.base import SQLDatabaseChain, SQLDatabaseSequentialChain
from tests.unit_tests.fake_llm import FakeLLM
# Fake db to test SQL-Chain
db = SQLDatabase.from_uri("sqlite:///:memory:")
def create_fake_db(db: SQLDatabase) -> SQLDatabase:
"""Create a table in fake db to test SQL-Chain"""
db.run(
"""
CREATE TABLE foo (baaz TEXT);
"""
)
db.run(
"""
INSERT INTO foo (baaz)
VALUES ('baaz');
"""
)
return db
db = create_fake_db(db)
def test_sql_chain_without_memory() -> None:
queries = {"foo": "SELECT baaz from foo", "foo2": "SELECT baaz from foo"}
llm = FakeLLM(queries=queries, sequential_responses=True)
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
assert db_chain.run("hello") == "SELECT baaz from foo"
def test_sql_chain_sequential_without_memory() -> None:
queries = {
"foo": "SELECT baaz from foo",
"foo2": "SELECT baaz from foo",
"foo3": "SELECT baaz from foo",
}
llm = FakeLLM(queries=queries, sequential_responses=True)
db_chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True)
assert db_chain.run("hello") == "SELECT baaz from foo"
def test_sql_chain_with_memory() -> None:
valid_prompt_with_history = """
Only use the following tables:
{table_info}
Question: {input}
Given an input question, first create a syntactically correct
{dialect} query to run.
Always limit your query to at most {top_k} results.
Relevant pieces of previous conversation:
{history}
(You do not need to use these pieces of information if not relevant)
"""
prompt = PromptTemplate(
input_variables=["input", "table_info", "dialect", "top_k", "history"],
template=valid_prompt_with_history,
)
queries = {"foo": "SELECT baaz from foo", "foo2": "SELECT baaz from foo"}
llm = FakeLLM(queries=queries, sequential_responses=True)
memory = ConversationBufferMemory()
db_chain = SQLDatabaseChain.from_llm(
llm, db, memory=memory, prompt=prompt, verbose=True
)
assert db_chain.run("hello") == "SELECT baaz from foo"
def test_sql_chain_sequential_with_memory() -> None:
valid_query_prompt_str = """
Only use the following tables:
{table_info}
Question: {input}
Given an input question, first create a syntactically correct
{dialect} query to run.
Always limit your query to at most {top_k} results.
Relevant pieces of previous conversation:
{history}
(You do not need to use these pieces of information
if not relevant)
"""
valid_decider_prompt_str = """Given the below input question and list of potential
tables, output a comma separated list of the
table names that may be necessary to answer this question.
Question: {query}
Table Names: {table_names}
Relevant Table Names:"""
valid_query_prompt = PromptTemplate(
input_variables=["input", "table_info", "dialect", "top_k", "history"],
template=valid_query_prompt_str,
)
valid_decider_prompt = PromptTemplate(
input_variables=["query", "table_names"],
template=valid_decider_prompt_str,
output_parser=CommaSeparatedListOutputParser(),
)
queries = {
"foo": "SELECT baaz from foo",
"foo2": "SELECT baaz from foo",
"foo3": "SELECT baaz from foo",
}
llm = FakeLLM(queries=queries, sequential_responses=True)
memory = ConversationBufferMemory(memory_key="history", input_key="query")
db_chain = SQLDatabaseSequentialChain.from_llm(
llm,
db,
memory=memory,
decider_prompt=valid_decider_prompt,
query_prompt=valid_query_prompt,
verbose=True,
)
assert db_chain.run("hello") == "SELECT baaz from foo"
Loading…
Cancel
Save