From a47c8618ec067dad7af7ecfd0dad4cdc5aa46e4e Mon Sep 17 00:00:00 2001 From: felpigeon Date: Mon, 5 Jun 2023 19:10:12 -0400 Subject: [PATCH] Add class attribute "return_generated_question" to class "BaseConversationalRetrievalChain" (#5749) Adding a class attribute "return_generated_question" to class "BaseConversationalRetrievalChain". If set to `True`, the chain's output has a key "generated_question" with the question generated by the sub-chain `question_generator` as the value. This way the generated question can be logged. #### Who can review? @dev2049 @vowelparrot --- .../chains/conversational_retrieval/base.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index 1832585c..baa3f6dd 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -56,6 +56,7 @@ class BaseConversationalRetrievalChain(Chain): question_generator: LLMChain output_key: str = "answer" return_source_documents: bool = False + return_generated_question: bool = False get_chat_history: Optional[Callable[[CHAT_TURN_TYPE], str]] = None """Return the source documents.""" @@ -80,6 +81,8 @@ class BaseConversationalRetrievalChain(Chain): _output_keys = [self.output_key] if self.return_source_documents: _output_keys = _output_keys + ["source_documents"] + if self.return_generated_question: + _output_keys = _output_keys + ["generated_question"] return _output_keys @abstractmethod @@ -110,10 +113,12 @@ class BaseConversationalRetrievalChain(Chain): answer = self.combine_docs_chain.run( input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs ) + output: Dict[str, Any] = {self.output_key: answer} if self.return_source_documents: - return {self.output_key: answer, "source_documents": docs} - else: - return {self.output_key: answer} + output["source_documents"] = docs + if self.return_generated_question: + output["generated_question"] = new_question + return output @abstractmethod async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: @@ -142,10 +147,12 @@ class BaseConversationalRetrievalChain(Chain): answer = await self.combine_docs_chain.arun( input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs ) + output: Dict[str, Any] = {self.output_key: answer} if self.return_source_documents: - return {self.output_key: answer, "source_documents": docs} - else: - return {self.output_key: answer} + output["source_documents"] = docs + if self.return_generated_question: + output["generated_question"] = new_question + return output def save(self, file_path: Union[Path, str]) -> None: if self.get_chat_history: