forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
152 lines
5.0 KiB
Python
152 lines
5.0 KiB
Python
"""Question answering with sources over documents."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from pydantic import BaseModel, Extra, root_validator
|
|
|
|
from langchain.chains.base import Chain
|
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
|
|
from langchain.chains.qa_with_sources.map_reduce_prompt import (
|
|
COMBINE_PROMPT,
|
|
EXAMPLE_PROMPT,
|
|
QUESTION_PROMPT,
|
|
)
|
|
from langchain.docstore.document import Document
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
from langchain.schema import BaseLanguageModel
|
|
|
|
|
|
class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
    """Question answering with sources over documents.

    Subclasses provide the documents via ``_get_docs``. This base class runs
    ``combine_documents_chain`` over them and splits the produced text into
    an answer part and a sources part at the ``SOURCES:`` marker.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain to use to combine documents."""
    question_key: str = "question"  #: :meta private:
    input_docs_key: str = "docs"  #: :meta private:
    answer_key: str = "answer"  #: :meta private:
    sources_answer_key: str = "sources"  #: :meta private:
    return_source_documents: bool = False
    """Return the source documents."""

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
        question_prompt: BasePromptTemplate = QUESTION_PROMPT,
        combine_prompt: BasePromptTemplate = COMBINE_PROMPT,
        **kwargs: Any,
    ) -> BaseQAWithSourcesChain:
        """Construct the chain from an LLM.

        Builds a ``MapReduceDocumentsChain`` whose per-document step uses
        ``question_prompt`` (documents injected as ``context``) and whose
        combine step is a ``StuffDocumentsChain`` using ``combine_prompt``
        (intermediate results injected as ``summaries``).

        Args:
            llm: Language model used for both the question and combine steps.
            document_prompt: Prompt used to format each document in the
                combine step.
            question_prompt: Prompt for the per-document question step.
            combine_prompt: Prompt for the final combine step.
            **kwargs: Forwarded unchanged to the class constructor.

        Returns:
            A new instance of ``cls``.
        """
        llm_question_chain = LLMChain(llm=llm, prompt=question_prompt)
        llm_combine_chain = LLMChain(llm=llm, prompt=combine_prompt)
        combine_results_chain = StuffDocumentsChain(
            llm_chain=llm_combine_chain,
            document_prompt=document_prompt,
            document_variable_name="summaries",
        )
        combine_document_chain = MapReduceDocumentsChain(
            llm_chain=llm_question_chain,
            combine_document_chain=combine_results_chain,
            document_variable_name="context",
        )
        return cls(
            combine_documents_chain=combine_document_chain,
            **kwargs,
        )

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseQAWithSourcesChain:
        """Load chain from chain type.

        Args:
            llm: Language model passed to ``load_qa_with_sources_chain``.
            chain_type: Chain type name passed to
                ``load_qa_with_sources_chain``.
            chain_type_kwargs: Extra keyword arguments for
                ``load_qa_with_sources_chain``; ``None`` means none.
            **kwargs: Forwarded unchanged to the class constructor.

        Returns:
            A new instance of ``cls``.
        """
        _chain_kwargs = chain_type_kwargs or {}
        combine_document_chain = load_qa_with_sources_chain(
            llm, chain_type=chain_type, **_chain_kwargs
        )
        return cls(combine_documents_chain=combine_document_chain, **kwargs)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.question_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        The answer and sources keys are always present; ``source_documents``
        is appended only when ``return_source_documents`` is set.

        :meta private:
        """
        _output_keys = [self.answer_key, self.sources_answer_key]
        if self.return_source_documents:
            _output_keys = _output_keys + ["source_documents"]
        return _output_keys

    @root_validator(pre=True)
    def validate_naming(cls, values: Dict) -> Dict:
        """Fix backwards compatibility in naming.

        Accepts the older ``combine_document_chain`` keyword and stores it
        under the current field name ``combine_documents_chain``.
        """
        if "combine_document_chain" in values:
            values["combine_documents_chain"] = values.pop("combine_document_chain")
        return values

    @abstractmethod
    def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
        """Get docs to run questioning over."""

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run the combine chain over the documents and split out the sources.

        Args:
            inputs: Chain inputs. After ``_get_docs`` has run, everything
                left in ``inputs`` is forwarded to ``combine_docs`` as
                keyword arguments.

        Returns:
            A dict holding the answer under ``answer_key`` and the sources
            text under ``sources_answer_key`` (empty string when the output
            has no ``SOURCES:`` marker), plus ``source_documents`` when
            ``return_source_documents`` is set.
        """
        docs = self._get_docs(inputs)
        answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs)
        # Split only at the FIRST marker. An unbounded split yields three or
        # more parts when the model repeats "SOURCES:", which used to make the
        # two-name unpacking raise ValueError.
        parts = re.split(r"SOURCES:\s", answer, maxsplit=1)
        if len(parts) == 2:
            answer, sources = parts
        else:
            sources = ""
        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result
|
|
|
|
|
|
class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
    """Question answering with sources over caller-supplied documents.

    The documents are read from the chain inputs under ``input_docs_key``.
    """

    input_docs_key: str = "docs"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Required inputs: the documents key followed by the question key.

        :meta private:
        """
        required = [self.input_docs_key]
        required.append(self.question_key)
        return required

    def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
        """Remove the documents entry from ``inputs`` and return it."""
        docs_key = self.input_docs_key
        # pop (not a plain lookup): the entry is removed from ``inputs``.
        documents = inputs.pop(docs_key)
        return documents

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        chain_type_name = "qa_with_sources_chain"
        return chain_type_name
|