mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add class attribute "return_generated_question" to class "BaseConversationalRetrievalChain" (#5749)
<!-- Thank you for contributing to LangChain! Your PR will appear in our release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. Finally, we'd love to show appreciation for your contribution - if you'd like us to shout you out on Twitter, please also include your handle! --> Adding a class attribute "return_generated_question" to class "BaseConversationalRetrievalChain". If set to `True`, the chain's output has a key "generated_question" with the question generated by the sub-chain `question_generator` as the value. This way the generated question can be logged. <!-- If you're adding a new integration, please include: 1. a test for the integration - favor unit tests that does not rely on network access. 2. an example notebook showing its use See contribution guidelines for more information on how to write tests, lint etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> #### Who can review? <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @vowelparrot VectorStores / Retrievers / Memory - @dev2049 --> @dev2049 @vowelparrot
This commit is contained in:
parent
87ad4fc4b2
commit
a47c8618ec
@ -56,6 +56,7 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
question_generator: LLMChain
|
||||
output_key: str = "answer"
|
||||
return_source_documents: bool = False
|
||||
return_generated_question: bool = False
|
||||
get_chat_history: Optional[Callable[[CHAT_TURN_TYPE], str]] = None
|
||||
"""Return the source documents."""
|
||||
|
||||
@ -80,6 +81,8 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
_output_keys = [self.output_key]
|
||||
if self.return_source_documents:
|
||||
_output_keys = _output_keys + ["source_documents"]
|
||||
if self.return_generated_question:
|
||||
_output_keys = _output_keys + ["generated_question"]
|
||||
return _output_keys
|
||||
|
||||
@abstractmethod
|
||||
@ -110,10 +113,12 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
answer = self.combine_docs_chain.run(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
|
||||
)
|
||||
output: Dict[str, Any] = {self.output_key: answer}
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
output["source_documents"] = docs
|
||||
if self.return_generated_question:
|
||||
output["generated_question"] = new_question
|
||||
return output
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
@ -142,10 +147,12 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
answer = await self.combine_docs_chain.arun(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
|
||||
)
|
||||
output: Dict[str, Any] = {self.output_key: answer}
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
output["source_documents"] = docs
|
||||
if self.return_generated_question:
|
||||
output["generated_question"] = new_question
|
||||
return output
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
if self.get_chat_history:
|
||||
|
Loading…
Reference in New Issue
Block a user