From 8f21605d71c447c50d484302cde833734ad92761 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 7 Mar 2023 21:09:36 -0800 Subject: [PATCH] add return source docs (#1515) --- langchain/chains/qa_with_sources/base.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index d87db00b..3cc15147 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -32,6 +32,8 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): input_docs_key: str = "docs" #: :meta private: answer_key: str = "answer" #: :meta private: sources_answer_key: str = "sources" #: :meta private: + return_source_documents: bool = False + """Return the source documents.""" @classmethod def from_llm( @@ -95,7 +97,10 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): :meta private: """ - return [self.answer_key, self.sources_answer_key] + _output_keys = [self.answer_key, self.sources_answer_key] + if self.return_source_documents: + _output_keys = _output_keys + ["source_documents"] + return _output_keys @root_validator(pre=True) def validate_naming(cls, values: Dict) -> Dict: @@ -108,14 +113,20 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]: """Get docs to run questioning over.""" - def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: docs = self._get_docs(inputs) answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs) if "SOURCES: " in answer: answer, sources = answer.split("SOURCES: ") else: sources = "" - return {self.answer_key: answer, self.sources_answer_key: sources} + result: Dict[str, Any] = { + self.answer_key: answer, + self.sources_answer_key: sources, + } + if self.return_source_documents: + result["source_documents"] = docs + return result class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):