diff --git a/libs/langchain/langchain/chains/__init__.py b/libs/langchain/langchain/chains/__init__.py index c806db1d04..ac5297af25 100644 --- a/libs/langchain/langchain/chains/__init__.py +++ b/libs/langchain/langchain/chains/__init__.py @@ -73,6 +73,7 @@ from langchain.chains.sql_database.base import ( SQLDatabaseChain, SQLDatabaseSequentialChain, ) +from langchain.chains.sql_database.query import create_sql_query_chain from langchain.chains.transform import TransformChain __all__ = [ @@ -132,4 +133,5 @@ __all__ = [ "create_tagging_chain_pydantic", "generate_example", "load_chain", + "create_sql_query_chain", ] diff --git a/libs/langchain/langchain/chains/sql_database/query.py b/libs/langchain/langchain/chains/sql_database/query.py new file mode 100644 index 0000000000..e874555e0e --- /dev/null +++ b/libs/langchain/langchain/chains/sql_database/query.py @@ -0,0 +1,68 @@ +from typing import List, Optional, TypedDict, Union + +from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS +from langchain.schema.language_model import BaseLanguageModel +from langchain.schema.output_parser import NoOpOutputParser +from langchain.schema.prompt_template import BasePromptTemplate +from langchain.schema.runnable import RunnableMap, RunnableSequence +from langchain.utilities.sql_database import SQLDatabase + + +def _strip(text: str) -> str: + return text.strip() + + +class SQLInput(TypedDict): + """Input for a SQL Chain.""" + + question: str + + +class SQLInputWithTables(TypedDict): + """Input for a SQL Chain.""" + + question: str + table_names_to_use: List[str] + + +def create_sql_query_chain( + llm: BaseLanguageModel, + db: SQLDatabase, + prompt: Optional[BasePromptTemplate] = None, + k: int = 5, +) -> RunnableSequence[Union[SQLInput, SQLInputWithTables], str]: + """Create a chain that generates SQL queries. + + Args: + llm: The language model to use + db: The SQLDatabase to generate the query for + prompt: The prompt to use. If none is provided, will choose one + based on dialect. Defaults to None. + k: The number of results per select statement to return. Defaults to 5. + + Returns: + A chain that takes in a question and generates a SQL query that answers + that question. + """ + if prompt is not None: + prompt_to_use = prompt + elif db.dialect in SQL_PROMPTS: + prompt_to_use = SQL_PROMPTS[db.dialect] + else: + prompt_to_use = PROMPT + inputs = { + "input": lambda x: x["question"] + "\nSQLQuery: ", + "top_k": lambda _: k, + "table_info": lambda x: db.get_table_info( + table_names=x.get("table_names_to_use") + ), + } + if "dialect" in prompt_to_use.input_variables: + inputs["dialect"] = lambda _: (db.dialect, prompt_to_use) + return ( + RunnableMap(inputs) + | prompt_to_use + | llm.bind(stop=["\nSQLResult:"]) + | NoOpOutputParser() + | _strip + ) diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py index 19a3be143f..0864fe55d0 100644 --- a/libs/langchain/langchain/schema/runnable.py +++ b/libs/langchain/langchain/schema/runnable.py @@ -72,6 +72,7 @@ class Runnable(Generic[Input, Output], ABC): self, other: Union[ Runnable[Any, Other], + Callable[[Any], Other], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]], ], ) -> RunnableSequence[Input, Other]: @@ -81,6 +82,7 @@ class Runnable(Generic[Input, Output], ABC): self, other: Union[ Runnable[Other, Any], + Callable[[Any], Other], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], ], ) -> RunnableSequence[Other, Output]: @@ -201,6 +203,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, other: Union[ Runnable[Any, Other], + Callable[[Any], Other], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]], ], ) -> RunnableSequence[Input, Other]: @@ -221,6 +224,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, other: Union[ Runnable[Other, Any], + Callable[[Any], Other], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], ], ) -> RunnableSequence[Other, Output]: