diff --git a/docs/modules/agents/agents/examples/mrkl.ipynb b/docs/modules/agents/agents/examples/mrkl.ipynb index 3a099c38..1b02e163 100644 --- a/docs/modules/agents/agents/examples/mrkl.ipynb +++ b/docs/modules/agents/agents/examples/mrkl.ipynb @@ -42,7 +42,7 @@ "search = SerpAPIWrapper()\n", "llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n", "db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n", - "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n", + "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)\n", "tools = [\n", " Tool(\n", " name = \"Search\",\n", diff --git a/docs/modules/agents/agents/examples/mrkl_chat.ipynb b/docs/modules/agents/agents/examples/mrkl_chat.ipynb index 44b53b43..669e061f 100644 --- a/docs/modules/agents/agents/examples/mrkl_chat.ipynb +++ b/docs/modules/agents/agents/examples/mrkl_chat.ipynb @@ -44,7 +44,7 @@ "search = SerpAPIWrapper()\n", "llm_math_chain = LLMMathChain(llm=llm1, verbose=True)\n", "db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n", - "db_chain = SQLDatabaseChain(llm=llm1, database=db, verbose=True)\n", + "db_chain = SQLDatabaseChain.from_llm(llm1, db, verbose=True)\n", "tools = [\n", " Tool(\n", " name = \"Search\",\n", diff --git a/docs/use_cases/evaluation/sql_qa_benchmarking_chinook.ipynb b/docs/use_cases/evaluation/sql_qa_benchmarking_chinook.ipynb index 21a05540..317bc2f2 100644 --- a/docs/use_cases/evaluation/sql_qa_benchmarking_chinook.ipynb +++ b/docs/use_cases/evaluation/sql_qa_benchmarking_chinook.ipynb @@ -213,7 +213,7 @@ "metadata": {}, "outputs": [], "source": [ - "chain = SQLDatabaseChain(llm=llm, database=db, input_key=\"question\")" + "chain = SQLDatabaseChain.from_llm(llm, db, input_key=\"question\")" ] }, { @@ -415,7 +415,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.11.3" } }, "nbformat": 4, diff --git a/langchain/chains/loading.py b/langchain/chains/loading.py index 5b9b78f9..b7e2c2a6 100644 --- a/langchain/chains/loading.py +++ b/langchain/chains/loading.py @@ -307,7 +307,9 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> SQLDatabaseChain: if "prompt" in config: prompt_config = config.pop("prompt") prompt = load_prompt_from_config(prompt_config) - return SQLDatabaseChain(database=database, llm=llm, prompt=prompt, **config) + else: + prompt = None + return SQLDatabaseChain.from_llm(llm, database, prompt=prompt, **config) def _load_vector_db_qa_with_sources_chain( diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index fe08b82c..843a041b 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -223,8 +223,8 @@ class SQLDatabaseSequentialChain(Chain): **kwargs: Any, ) -> SQLDatabaseSequentialChain: """Load the necessary chains.""" - sql_chain = SQLDatabaseChain( - llm=llm, database=database, prompt=query_prompt, **kwargs + sql_chain = SQLDatabaseChain.from_llm( + llm, database, prompt=query_prompt, **kwargs ) decider_chain = LLMChain( llm=llm, prompt=decider_prompt, output_key="table_names"