diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py index 237ecc2d..9d0a141c 100644 --- a/langchain/chains/combine_documents/stuff.py +++ b/langchain/chains/combine_documents/stuff.py @@ -40,8 +40,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): @root_validator(pre=True) def get_default_document_variable_name(cls, values: Dict) -> Dict: """Get default document variable name, if not provided.""" + llm_chain_variables = values["llm_chain"].prompt.input_variables if "document_variable_name" not in values: - llm_chain_variables = values["llm_chain"].prompt.input_variables if len(llm_chain_variables) == 1: values["document_variable_name"] = llm_chain_variables[0] else: @@ -50,7 +50,6 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): "multiple llm_chain_variables" ) else: - llm_chain_variables = values["llm_chain"].prompt.input_variables if values["document_variable_name"] not in llm_chain_variables: raise ValueError( f"document_variable_name {values['document_variable_name']} was "