Harrison/sql query (#8370)

Co-authored-by: Nuno Campos <nuno@boringbits.io>
pull/8402/head
Harrison Chase 1 year ago committed by GitHub
parent a1a650c743
commit a221a9ced0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -73,6 +73,7 @@ from langchain.chains.sql_database.base import (
SQLDatabaseChain,
SQLDatabaseSequentialChain,
)
from langchain.chains.sql_database.query import create_sql_query_chain
from langchain.chains.transform import TransformChain
__all__ = [
@ -132,4 +133,5 @@ __all__ = [
"create_tagging_chain_pydantic",
"generate_example",
"load_chain",
"create_sql_query_chain",
]

@ -0,0 +1,68 @@
from typing import List, Optional, TypedDict, Union
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output_parser import NoOpOutputParser
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.schema.runnable import RunnableMap, RunnableSequence
from langchain.utilities.sql_database import SQLDatabase
def _strip(text: str) -> str:
return text.strip()
class SQLInput(TypedDict):
"""Input for a SQL Chain."""
question: str
class SQLInputWithTables(TypedDict):
"""Input for a SQL Chain."""
question: str
table_names_to_use: List[str]
def create_sql_query_chain(
llm: BaseLanguageModel,
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5,
) -> RunnableSequence[Union[SQLInput, SQLInputWithTables], str]:
"""Create a chain that generates SQL queries.
Args:
llm: The language model to use
db: The SQLDatabase to generate the query for
prompt: The prompt to use. If none is provided, will choose one
based on dialect. Defaults to None.
k: The number of results per select statement to return. Defaults to 5.
Returns:
A chain that takes in a question and generates a SQL query that answers
that question.
"""
if prompt is not None:
prompt_to_use = prompt
elif db.dialect in SQL_PROMPTS:
prompt_to_use = SQL_PROMPTS[db.dialect]
else:
prompt_to_use = PROMPT
inputs = {
"input": lambda x: x["question"] + "\nSQLQuery: ",
"top_k": lambda _: k,
"table_info": lambda x: db.get_table_info(
table_names=x.get("table_names_to_use")
),
}
if "dialect" in prompt_to_use.input_variables:
inputs["dialect"] = lambda _: (db.dialect, prompt_to_use)
return (
RunnableMap(inputs)
| prompt_to_use
| llm.bind(stop=["\nSQLResult:"])
| NoOpOutputParser()
| _strip
)

@ -72,6 +72,7 @@ class Runnable(Generic[Input, Output], ABC):
self,
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
],
) -> RunnableSequence[Input, Other]:
@ -81,6 +82,7 @@ class Runnable(Generic[Input, Output], ABC):
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
],
) -> RunnableSequence[Other, Output]:
@ -201,6 +203,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
],
) -> RunnableSequence[Input, Other]:
@ -221,6 +224,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
],
) -> RunnableSequence[Other, Output]:

Loading…
Cancel
Save