fix llm_inputs duplication problem in intermediate_steps in SQLDatabaseChain (#10279)

Use `.copy()` to fix the bug that the first `llm_inputs` element is
overwritten by the second `llm_inputs` element in `intermediate_steps`.

***Problem description:***
In [line 127](

c732d8fffd/libs/experimental/langchain_experimental/sql/base.py (L127C17-L127C17)),
the `llm_inputs` of the sql generation step is appended as the first
element of `intermediate_steps`:
```
            intermediate_steps.append(llm_inputs)  # input: sql generation
```

However, `llm_inputs` is a mutable dict, it is updated in [line
179](https://github.com/langchain-ai/langchain/blob/master/libs/experimental/langchain_experimental/sql/base.py#L179)
for the final answer step:
```
                llm_inputs["input"] = input_text
```
Then, the updated `llm_inputs` is appended as another element of
`intermediate_steps` in [line
180](c732d8fffd/libs/experimental/langchain_experimental/sql/base.py (L180)):
```
                intermediate_steps.append(llm_inputs)  # input: final answer
```

As a result, the final `intermediate_steps` returned in [line
189](c732d8fffd/libs/experimental/langchain_experimental/sql/base.py (L189C43-L189C43))
actually contains two same `llm_inputs` elements, i.e., the `llm_inputs`
for the sql generation step overwritten by the one for final answer step
by mistake. Users are not able to get the actual `llm_inputs` for the
sql generation step from `intermediate_steps`

Simply calling `.copy()` when appending `llm_inputs` to
`intermediate_steps` can solve this problem.
pull/11473/head
Qihui Xie 9 months ago committed by GitHub
parent d78f418c0d
commit 57ade13b2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -127,7 +127,7 @@ class SQLDatabaseChain(Chain):
llm_inputs[k] = inputs[k]
intermediate_steps: List = []
try:
intermediate_steps.append(llm_inputs) # input: sql generation
intermediate_steps.append(llm_inputs.copy()) # input: sql generation
sql_cmd = self.llm_chain.predict(
callbacks=_run_manager.get_child(),
**llm_inputs,
@ -180,7 +180,7 @@ class SQLDatabaseChain(Chain):
_run_manager.on_text("\nAnswer:", verbose=self.verbose)
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
llm_inputs["input"] = input_text
intermediate_steps.append(llm_inputs) # input: final answer
intermediate_steps.append(llm_inputs.copy()) # input: final answer
final_result = self.llm_chain.predict(
callbacks=_run_manager.get_child(),
**llm_inputs,

Loading…
Cancel
Save