diff --git a/langchain/chains/chat_vector_db/base.py b/langchain/chains/chat_vector_db/base.py index c2bbd4df..d030c6ac 100644 --- a/langchain/chains/chat_vector_db/base.py +++ b/langchain/chains/chat_vector_db/base.py @@ -101,7 +101,7 @@ class ChatVectorDBChain(Chain, BaseModel): else: return {self.output_key: answer} - async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: + async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: question = inputs["question"] chat_history_str = _get_chat_history(inputs["chat_history"]) vectordbkwargs = inputs.get("vectordbkwargs", {}) @@ -119,4 +119,7 @@ class ChatVectorDBChain(Chain, BaseModel): new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str answer, _ = await self.combine_docs_chain.acombine_docs(docs, **new_inputs) - return {self.output_key: answer} + if self.return_source_documents: + return {self.output_key: answer, "source_documents": docs} + else: + return {self.output_key: answer}