From 57ade13b2b9c75d2967eb13c91417c356e2c805d Mon Sep 17 00:00:00 2001 From: Qihui Xie Date: Fri, 6 Oct 2023 12:32:08 +0800 Subject: [PATCH] fix llm_inputs duplication problem in intermediate_steps in SQLDatabaseChain (#10279) Use `.copy()` to fix the bug that the first `llm_inputs` element is overwritten by the second `llm_inputs` element in `intermediate_steps`. ***Problem description:*** In [line 127]( https://github.com/langchain-ai/langchain/blob/c732d8fffd39d2b02bdc393c37d2ccdd48f7626d/libs/experimental/langchain_experimental/sql/base.py#L127C17-L127C17), the `llm_inputs` of the sql generation step is appended as the first element of `intermediate_steps`: ``` intermediate_steps.append(llm_inputs) # input: sql generation ``` However, `llm_inputs` is a mutable dict, it is updated in [line 179](https://github.com/langchain-ai/langchain/blob/master/libs/experimental/langchain_experimental/sql/base.py#L179) for the final answer step: ``` llm_inputs["input"] = input_text ``` Then, the updated `llm_inputs` is appended as another element of `intermediate_steps` in [line 180](https://github.com/langchain-ai/langchain/blob/c732d8fffd39d2b02bdc393c37d2ccdd48f7626d/libs/experimental/langchain_experimental/sql/base.py#L180): ``` intermediate_steps.append(llm_inputs) # input: final answer ``` As a result, the final `intermediate_steps` returned in [line 189](https://github.com/langchain-ai/langchain/blob/c732d8fffd39d2b02bdc393c37d2ccdd48f7626d/libs/experimental/langchain_experimental/sql/base.py#L189C43-L189C43) actually contains two same `llm_inputs` elements, i.e., the `llm_inputs` for the sql generation step overwritten by the one for final answer step by mistake. Users are not able to get the actual `llm_inputs` for the sql generation step from `intermediate_steps` Simply calling `.copy()` when appending `llm_inputs` to `intermediate_steps` can solve this problem. --- libs/experimental/langchain_experimental/sql/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/experimental/langchain_experimental/sql/base.py b/libs/experimental/langchain_experimental/sql/base.py index 5b220b5eb0..e14989db83 100644 --- a/libs/experimental/langchain_experimental/sql/base.py +++ b/libs/experimental/langchain_experimental/sql/base.py @@ -127,7 +127,7 @@ class SQLDatabaseChain(Chain): llm_inputs[k] = inputs[k] intermediate_steps: List = [] try: - intermediate_steps.append(llm_inputs) # input: sql generation + intermediate_steps.append(llm_inputs.copy()) # input: sql generation sql_cmd = self.llm_chain.predict( callbacks=_run_manager.get_child(), **llm_inputs, @@ -180,7 +180,7 @@ class SQLDatabaseChain(Chain): _run_manager.on_text("\nAnswer:", verbose=self.verbose) input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:" llm_inputs["input"] = input_text - intermediate_steps.append(llm_inputs) # input: final answer + intermediate_steps.append(llm_inputs.copy()) # input: final answer final_result = self.llm_chain.predict( callbacks=_run_manager.get_child(), **llm_inputs,