diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index 1832585c..baa3f6dd 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -56,6 +56,7 @@ class BaseConversationalRetrievalChain(Chain): question_generator: LLMChain output_key: str = "answer" return_source_documents: bool = False + return_generated_question: bool = False get_chat_history: Optional[Callable[[CHAT_TURN_TYPE], str]] = None """Return the source documents.""" @@ -80,6 +81,8 @@ class BaseConversationalRetrievalChain(Chain): _output_keys = [self.output_key] if self.return_source_documents: _output_keys = _output_keys + ["source_documents"] + if self.return_generated_question: + _output_keys = _output_keys + ["generated_question"] return _output_keys @abstractmethod @@ -110,10 +113,12 @@ class BaseConversationalRetrievalChain(Chain): answer = self.combine_docs_chain.run( input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs ) + output: Dict[str, Any] = {self.output_key: answer} if self.return_source_documents: - return {self.output_key: answer, "source_documents": docs} - else: - return {self.output_key: answer} + output["source_documents"] = docs + if self.return_generated_question: + output["generated_question"] = new_question + return output @abstractmethod async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: @@ -142,10 +147,12 @@ class BaseConversationalRetrievalChain(Chain): answer = await self.combine_docs_chain.arun( input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs ) + output: Dict[str, Any] = {self.output_key: answer} if self.return_source_documents: - return {self.output_key: answer, "source_documents": docs} - else: - return {self.output_key: answer} + output["source_documents"] = docs + if self.return_generated_question: + output["generated_question"] = new_question + return output def save(self, file_path: Union[Path, str]) -> None: if self.get_chat_history: