add return source docs (#1515)

fix-searx
Harrison Chase 1 year ago committed by GitHub
parent 064741db58
commit 8f21605d71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -32,6 +32,8 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
input_docs_key: str = "docs" #: :meta private:
answer_key: str = "answer" #: :meta private:
sources_answer_key: str = "sources" #: :meta private:
return_source_documents: bool = False
"""Return the source documents."""
@classmethod
def from_llm(
@ -95,7 +97,10 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
:meta private:
"""
return [self.answer_key, self.sources_answer_key]
_output_keys = [self.answer_key, self.sources_answer_key]
if self.return_source_documents:
_output_keys = _output_keys + ["source_documents"]
return _output_keys
@root_validator(pre=True)
def validate_naming(cls, values: Dict) -> Dict:
@ -108,14 +113,20 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
"""Get docs to run questioning over."""
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
docs = self._get_docs(inputs)
answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs)
if "SOURCES: " in answer:
answer, sources = answer.split("SOURCES: ")
else:
sources = ""
return {self.answer_key: answer, self.sources_answer_key: sources}
result: Dict[str, Any] = {
self.answer_key: answer,
self.sources_answer_key: sources,
}
if self.return_source_documents:
result["source_documents"] = docs
return result
class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):

Loading…
Cancel
Save