mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Harrison/sql query (#8370)
Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
a1a650c743
commit
a221a9ced0
@ -73,6 +73,7 @@ from langchain.chains.sql_database.base import (
|
||||
SQLDatabaseChain,
|
||||
SQLDatabaseSequentialChain,
|
||||
)
|
||||
from langchain.chains.sql_database.query import create_sql_query_chain
|
||||
from langchain.chains.transform import TransformChain
|
||||
|
||||
__all__ = [
|
||||
@ -132,4 +133,5 @@ __all__ = [
|
||||
"create_tagging_chain_pydantic",
|
||||
"generate_example",
|
||||
"load_chain",
|
||||
"create_sql_query_chain",
|
||||
]
|
||||
|
68
libs/langchain/langchain/chains/sql_database/query.py
Normal file
68
libs/langchain/langchain/chains/sql_database/query.py
Normal file
@ -0,0 +1,68 @@
|
||||
from typing import List, Optional, TypedDict, Union
|
||||
|
||||
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.output_parser import NoOpOutputParser
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
from langchain.schema.runnable import RunnableMap, RunnableSequence
|
||||
from langchain.utilities.sql_database import SQLDatabase
|
||||
|
||||
|
||||
def _strip(text: str) -> str:
|
||||
return text.strip()
|
||||
|
||||
|
||||
class SQLInput(TypedDict):
|
||||
"""Input for a SQL Chain."""
|
||||
|
||||
question: str
|
||||
|
||||
|
||||
class SQLInputWithTables(TypedDict):
|
||||
"""Input for a SQL Chain."""
|
||||
|
||||
question: str
|
||||
table_names_to_use: List[str]
|
||||
|
||||
|
||||
def create_sql_query_chain(
|
||||
llm: BaseLanguageModel,
|
||||
db: SQLDatabase,
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
k: int = 5,
|
||||
) -> RunnableSequence[Union[SQLInput, SQLInputWithTables], str]:
|
||||
"""Create a chain that generates SQL queries.
|
||||
|
||||
Args:
|
||||
llm: The language model to use
|
||||
db: The SQLDatabase to generate the query for
|
||||
prompt: The prompt to use. If none is provided, will choose one
|
||||
based on dialect. Defaults to None.
|
||||
k: The number of results per select statement to return. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
A chain that takes in a question and generates a SQL query that answers
|
||||
that question.
|
||||
"""
|
||||
if prompt is not None:
|
||||
prompt_to_use = prompt
|
||||
elif db.dialect in SQL_PROMPTS:
|
||||
prompt_to_use = SQL_PROMPTS[db.dialect]
|
||||
else:
|
||||
prompt_to_use = PROMPT
|
||||
inputs = {
|
||||
"input": lambda x: x["question"] + "\nSQLQuery: ",
|
||||
"top_k": lambda _: k,
|
||||
"table_info": lambda x: db.get_table_info(
|
||||
table_names=x.get("table_names_to_use")
|
||||
),
|
||||
}
|
||||
if "dialect" in prompt_to_use.input_variables:
|
||||
inputs["dialect"] = lambda _: (db.dialect, prompt_to_use)
|
||||
return (
|
||||
RunnableMap(inputs)
|
||||
| prompt_to_use
|
||||
| llm.bind(stop=["\nSQLResult:"])
|
||||
| NoOpOutputParser()
|
||||
| _strip
|
||||
)
|
@ -72,6 +72,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Any, Other],
|
||||
Callable[[Any], Other],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||
],
|
||||
) -> RunnableSequence[Input, Other]:
|
||||
@ -81,6 +82,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Other, Any],
|
||||
Callable[[Any], Other],
|
||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||
],
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
@ -201,6 +203,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Any, Other],
|
||||
Callable[[Any], Other],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||
],
|
||||
) -> RunnableSequence[Input, Other]:
|
||||
@ -221,6 +224,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Other, Any],
|
||||
Callable[[Any], Other],
|
||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||
],
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
|
Loading…
Reference in New Issue
Block a user