|
|
|
@ -32,6 +32,8 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
|
|
|
|
input_docs_key: str = "docs" #: :meta private:
|
|
|
|
|
answer_key: str = "answer" #: :meta private:
|
|
|
|
|
sources_answer_key: str = "sources" #: :meta private:
|
|
|
|
|
return_source_documents: bool = False
|
|
|
|
|
"""Return the source documents."""
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_llm(
|
|
|
|
@ -95,7 +97,10 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
|
|
|
|
|
|
|
|
|
:meta private:
|
|
|
|
|
"""
|
|
|
|
|
return [self.answer_key, self.sources_answer_key]
|
|
|
|
|
_output_keys = [self.answer_key, self.sources_answer_key]
|
|
|
|
|
if self.return_source_documents:
|
|
|
|
|
_output_keys = _output_keys + ["source_documents"]
|
|
|
|
|
return _output_keys
|
|
|
|
|
|
|
|
|
|
@root_validator(pre=True)
|
|
|
|
|
def validate_naming(cls, values: Dict) -> Dict:
|
|
|
|
@ -108,14 +113,20 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
|
|
|
|
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
|
|
|
|
"""Get docs to run questioning over."""
|
|
|
|
|
|
|
|
|
|
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
|
|
|
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
docs = self._get_docs(inputs)
|
|
|
|
|
answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs)
|
|
|
|
|
if "SOURCES: " in answer:
|
|
|
|
|
answer, sources = answer.split("SOURCES: ")
|
|
|
|
|
else:
|
|
|
|
|
sources = ""
|
|
|
|
|
return {self.answer_key: answer, self.sources_answer_key: sources}
|
|
|
|
|
result: Dict[str, Any] = {
|
|
|
|
|
self.answer_key: answer,
|
|
|
|
|
self.sources_answer_key: sources,
|
|
|
|
|
}
|
|
|
|
|
if self.return_source_documents:
|
|
|
|
|
result["source_documents"] = docs
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
|
|
|
|