Docs combine document chain (#6994)
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>

parent 81eebc4070
commit 0ad984fa27
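
The headline change: `MapReduceDocumentsChain` no longer takes `combine_document_chain` / `collapse_document_chain` directly; those are now wrapped in the new `ReduceDocumentsChain` (a root validator keeps the old constructor arguments working). A minimal before/after sketch, reusing the LLM and prompt setup from the docstring examples in the diff below:

    from langchain.chains import (
        LLMChain,
        MapReduceDocumentsChain,
        ReduceDocumentsChain,
        StuffDocumentsChain,
    )
    from langchain.llms import OpenAI
    from langchain.prompts import PromptTemplate

    llm = OpenAI()
    map_chain = LLMChain(
        llm=llm, prompt=PromptTemplate.from_template("Summarize this content: {context}")
    )
    combine_chain = StuffDocumentsChain(
        llm_chain=LLMChain(
            llm=llm,
            prompt=PromptTemplate.from_template("Combine these summaries: {context}"),
        ),
        document_variable_name="context",
    )

    # Deprecated style (still accepted; converted by the get_reduce_chain validator):
    # chain = MapReduceDocumentsChain(
    #     llm_chain=map_chain, combine_document_chain=combine_chain
    # )

    # New style introduced by this commit:
    chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=ReduceDocumentsChain(combine_documents_chain=combine_chain),
        document_variable_name="context",
    )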
@@ -4,6 +4,7 @@ from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
from langchain.chains.combine_documents.base import AnalyzeDocumentChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.constitutional_ai.base import ConstitutionalChain
@@ -111,4 +112,5 @@ __all__ = [
    "MapRerankDocumentsChain",
    "MapReduceDocumentsChain",
    "RefineDocumentsChain",
    "ReduceDocumentsChain",
]
@@ -11,30 +11,20 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.schema import BasePromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter


def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
    """Format a document into a string based on a prompt template."""
    base_info = {"page_content": doc.page_content}
    base_info.update(doc.metadata)
    missing_metadata = set(prompt.input_variables).difference(base_info)
    if len(missing_metadata) > 0:
        required_metadata = [
            iv for iv in prompt.input_variables if iv != "page_content"
        ]
        raise ValueError(
            f"Document prompt requires documents to have metadata variables: "
            f"{required_metadata}. Received document with missing metadata: "
            f"{list(missing_metadata)}."
        )
    document_info = {k: base_info[k] for k in prompt.input_variables}
    return prompt.format(**document_info)


class BaseCombineDocumentsChain(Chain, ABC):
    """Base interface for chains combining documents."""
    """Base interface for chains combining documents.

    Subclasses of this chain deal with combining documents in a variety of
    ways. This base class exists to add some uniformity in the interface these types
    of chains should expose. Namely, they expect an input key related to the documents
    to use (default `input_documents`), and then also expose a method to calculate
    the length of a prompt from documents (useful for outside callers to use to
    determine whether it's safe to pass a list of documents into this chain or whether
    that will be longer than the context length).
    """

    input_key: str = "input_documents"  #: :meta private:
    output_key: str = "output_text"  #: :meta private:
@@ -58,25 +48,57 @@ class BaseCombineDocumentsChain(Chain, ABC):
    def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
        """Return the prompt length given the documents passed in.

        Returns None if the method does not depend on the prompt length.
        This can be used by a caller to determine whether passing in a list
        of documents would exceed a certain prompt length. This is useful when
        trying to ensure that the size of a prompt remains below a certain
        context limit.

        Args:
            docs: List[Document], a list of documents to use to calculate the
                total prompt length.

        Returns:
            Returns None if the method does not depend on the prompt length,
            otherwise the length of the prompt in tokens.
        """
        return None

    @abstractmethod
    def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
        """Combine documents into a single string."""
        """Combine documents into a single string.

        Args:
            docs: List[Document], the documents to combine
            **kwargs: Other parameters to use in combining documents, often
                other inputs to the prompt.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """

    @abstractmethod
    async def acombine_docs(
        self, docs: List[Document], **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine documents into a single string asynchronously."""
        """Combine documents into a single string.

        Args:
            docs: List[Document], the documents to combine
            **kwargs: Other parameters to use in combining documents, often
                other inputs to the prompt.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """

    def _call(
        self,
        inputs: Dict[str, List[Document]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Prepare inputs, call combine docs, prepare outputs."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        docs = inputs[self.input_key]
        # Other keys are assumed to be needed for LLM prediction
@@ -92,6 +114,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
        inputs: Dict[str, List[Document]],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Prepare inputs, call combine docs, prepare outputs."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        docs = inputs[self.input_key]
        # Other keys are assumed to be needed for LLM prediction
@@ -104,7 +127,12 @@ class BaseCombineDocumentsChain(Chain, ABC):


class AnalyzeDocumentChain(Chain):
    """Chain that splits documents, then analyzes it in pieces."""
    """Chain that splits a document, then analyzes it in pieces.

    This chain is parameterized by a TextSplitter and a CombineDocumentsChain.
    This chain takes a single document as input, splits it up into chunks,
    and then passes those chunks to the CombineDocumentsChain.
    """

    input_key: str = "input_document"  #: :meta private:
    text_splitter: TextSplitter = Field(default_factory=RecursiveCharacterTextSplitter)
@@ -131,6 +159,7 @@ class AnalyzeDocumentChain(Chain):
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Split document into chunks and pass to CombineDocumentsChain."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        document = inputs[self.input_key]
        docs = self.text_splitter.create_documents([document])
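
A quick usage sketch for `AnalyzeDocumentChain` as documented above. The `combine_docs_chain` field name is an assumption from the surrounding release, not shown in this hunk:

    from langchain.chains import AnalyzeDocumentChain

    # `summarize_chain` is any BaseCombineDocumentsChain (e.g. a StuffDocumentsChain);
    # `combine_docs_chain` is assumed here, not confirmed by this diff.
    analyze_chain = AnalyzeDocumentChain(combine_docs_chain=summarize_chain)
    # The raw text is split by the default RecursiveCharacterTextSplitter into
    # Documents, which are then handed to the combine-documents chain.
    summary = analyze_chain.run(input_document=long_text)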
@@ -2,74 +2,97 @@

from __future__ import annotations

from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
from typing import Any, Dict, List, Tuple

from pydantic import Extra, root_validator

from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document


class CombineDocsProtocol(Protocol):
    """Interface for the combine_docs method."""

    def __call__(self, docs: List[Document], **kwargs: Any) -> str:
        """Interface for the combine_docs method."""


def _split_list_of_docs(
    docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
) -> List[List[Document]]:
    new_result_doc_list = []
    _sub_result_docs = []
    for doc in docs:
        _sub_result_docs.append(doc)
        _num_tokens = length_func(_sub_result_docs, **kwargs)
        if _num_tokens > token_max:
            if len(_sub_result_docs) == 1:
                raise ValueError(
                    "A single document was longer than the context length,"
                    " we cannot handle this."
                )
            if len(_sub_result_docs) == 2:
                raise ValueError(
                    "A single document was so long it could not be combined "
                    "with another document, we cannot handle this."
                )
            new_result_doc_list.append(_sub_result_docs[:-1])
            _sub_result_docs = _sub_result_docs[-1:]
    new_result_doc_list.append(_sub_result_docs)
    return new_result_doc_list


def _collapse_docs(
    docs: List[Document],
    combine_document_func: CombineDocsProtocol,
    **kwargs: Any,
) -> Document:
    result = combine_document_func(docs, **kwargs)
    combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
    for doc in docs[1:]:
        for k, v in doc.metadata.items():
            if k in combined_metadata:
                combined_metadata[k] += f", {v}"
            else:
                combined_metadata[k] = str(v)
    return Document(page_content=result, metadata=combined_metadata)


class MapReduceDocumentsChain(BaseCombineDocumentsChain):
    """Combining documents by mapping a chain over them, then combining results."""
    """Combining documents by mapping a chain over them, then combining results.

    We first call `llm_chain` on each document individually, passing in the
    `page_content` and any other kwargs. This is the `map` step.

    We then process the results of that `map` step in a `reduce` step. This should
    likely be a ReduceDocumentsChain.

    Example:
        .. code-block:: python

            from langchain.chains import (
                StuffDocumentsChain,
                LLMChain,
                ReduceDocumentsChain,
                MapReduceDocumentsChain,
            )
            from langchain.prompts import PromptTemplate
            from langchain.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            # We now define how to combine these summaries
            reduce_prompt = PromptTemplate.from_template(
                "Combine these summaries: {context}"
            )
            reduce_llm_chain = LLMChain(llm=llm, prompt=reduce_prompt)
            combine_documents_chain = StuffDocumentsChain(
                llm_chain=reduce_llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
            reduce_documents_chain = ReduceDocumentsChain(
                combine_documents_chain=combine_documents_chain,
            )
            chain = MapReduceDocumentsChain(
                llm_chain=llm_chain,
                reduce_documents_chain=reduce_documents_chain,
            )
            # If we wanted to, we could also pass in collapse_documents_chain
            # which is specifically aimed at collapsing documents BEFORE
            # the final call.
            prompt = PromptTemplate.from_template(
                "Collapse this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            collapse_documents_chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
            reduce_documents_chain = ReduceDocumentsChain(
                combine_documents_chain=combine_documents_chain,
                collapse_documents_chain=collapse_documents_chain,
            )
            chain = MapReduceDocumentsChain(
                llm_chain=llm_chain,
                reduce_documents_chain=reduce_documents_chain,
            )
    """

    llm_chain: LLMChain
    """Chain to apply to each document individually."""
    combine_document_chain: BaseCombineDocumentsChain
    """Chain to use to combine results of applying llm_chain to documents."""
    collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
    """Chain to use to collapse intermediary results if needed.
    If None, will use the combine_document_chain."""
    reduce_documents_chain: BaseCombineDocumentsChain
    """Chain to use to reduce the results of applying `llm_chain` to each doc.
    This is typically either a ReduceDocumentsChain or a StuffDocumentsChain."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
@@ -93,6 +116,29 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def get_reduce_chain(cls, values: Dict) -> Dict:
        """For backwards compatibility."""
        if "combine_document_chain" in values:
            if "reduce_documents_chain" in values:
                raise ValueError(
                    "Both `reduce_documents_chain` and `combine_document_chain` "
                    "cannot be provided at the same time. `combine_document_chain` "
                    "is deprecated, please only provide `reduce_documents_chain`"
                )
            combine_chain = values["combine_document_chain"]
            collapse_chain = values.get("collapse_document_chain")
            reduce_chain = ReduceDocumentsChain(
                combine_documents_chain=combine_chain,
                collapse_documents_chain=collapse_chain,
            )
            values["reduce_documents_chain"] = reduce_chain
            del values["combine_document_chain"]
            if "collapse_document_chain" in values:
                del values["collapse_document_chain"]

        return values

    @root_validator(pre=True)
    def get_return_intermediate_steps(cls, values: Dict) -> Dict:
        """For backwards compatibility."""
@@ -123,11 +169,31 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
        return values

    @property
    def _collapse_chain(self) -> BaseCombineDocumentsChain:
        if self.collapse_document_chain is not None:
            return self.collapse_document_chain
    def collapse_document_chain(self) -> BaseCombineDocumentsChain:
        """Kept for backward compatibility."""
        if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
            if self.reduce_documents_chain.collapse_documents_chain:
                return self.reduce_documents_chain.collapse_documents_chain
            else:
                return self.reduce_documents_chain.combine_documents_chain
        else:
            return self.combine_document_chain
            raise ValueError(
                f"`reduce_documents_chain` is of type "
                f"{type(self.reduce_documents_chain)} so it does not have "
                f"this attribute."
            )

    @property
    def combine_document_chain(self) -> BaseCombineDocumentsChain:
        """Kept for backward compatibility."""
        if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
            return self.reduce_documents_chain.combine_documents_chain
        else:
            raise ValueError(
                f"`reduce_documents_chain` is of type "
                f"{type(self.reduce_documents_chain)} so it does not have "
                f"this attribute."
            )

    def combine_docs(
        self,
@@ -141,14 +207,24 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
        Combine by mapping first chain over all documents, then reducing the results.
        This reducing can be done recursively if needed (if there are many documents).
        """
        results = self.llm_chain.apply(
        map_results = self.llm_chain.apply(
            # FYI - this is parallelized and so it is fast.
            [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        return self._process_results(
            results, docs, token_max, callbacks=callbacks, **kwargs
        question_result_key = self.llm_chain.output_key
        result_docs = [
            Document(page_content=r[question_result_key], metadata=docs[i].metadata)
            # This uses metadata from the docs, and the textual results from `results`
            for i, r in enumerate(map_results)
        ]
        result, extra_return_dict = self.reduce_documents_chain.combine_docs(
            result_docs, callbacks=callbacks, **kwargs
        )
        if self.return_intermediate_steps:
            intermediate_steps = [r[question_result_key] for r in map_results]
            extra_return_dict["intermediate_steps"] = intermediate_steps
        return result, extra_return_dict

    async def acombine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
@@ -158,83 +234,24 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
        Combine by mapping first chain over all documents, then reducing the results.
        This reducing can be done recursively if needed (if there are many documents).
        """
        results = await self.llm_chain.aapply(
        map_results = await self.llm_chain.aapply(
            # FYI - this is parallelized and so it is fast.
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        return await self._aprocess_results(
            results, docs, callbacks=callbacks, **kwargs
        )

    def _process_results_common(
        self,
        results: List[Dict],
        docs: List[Document],
        token_max: int = 3000,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[List[Document], dict]:
        question_result_key = self.llm_chain.output_key
        result_docs = [
            Document(page_content=r[question_result_key], metadata=docs[i].metadata)
            # This uses metadata from the docs, and the textual results from `results`
            for i, r in enumerate(results)
            for i, r in enumerate(map_results)
        ]
        length_func = self.combine_document_chain.prompt_length
        num_tokens = length_func(result_docs, **kwargs)

        def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
            return self._collapse_chain.run(
                input_documents=docs, callbacks=callbacks, **kwargs
            )

        while num_tokens is not None and num_tokens > token_max:
            new_result_doc_list = _split_list_of_docs(
                result_docs, length_func, token_max, **kwargs
            )
            result_docs = []
            for docs in new_result_doc_list:
                new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
                result_docs.append(new_doc)
            num_tokens = length_func(result_docs, **kwargs)
        result, extra_return_dict = await self.reduce_documents_chain.acombine_docs(
            result_docs, callbacks=callbacks, **kwargs
        )
        if self.return_intermediate_steps:
            _results = [r[self.llm_chain.output_key] for r in results]
            extra_return_dict = {"intermediate_steps": _results}
        else:
            extra_return_dict = {}
        return result_docs, extra_return_dict

    def _process_results(
        self,
        results: List[Dict],
        docs: List[Document],
        token_max: int = 3000,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        result_docs, extra_return_dict = self._process_results_common(
            results, docs, token_max, callbacks=callbacks, **kwargs
        )
        output = self.combine_document_chain.run(
            input_documents=result_docs, callbacks=callbacks, **kwargs
        )
        return output, extra_return_dict

    async def _aprocess_results(
        self,
        results: List[Dict],
        docs: List[Document],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        result_docs, extra_return_dict = self._process_results_common(
            results, docs, callbacks=callbacks, **kwargs
        )
        output = await self.combine_document_chain.arun(
            input_documents=result_docs, callbacks=callbacks, **kwargs
        )
        return output, extra_return_dict
        intermediate_steps = [r[question_result_key] for r in map_results]
        extra_return_dict["intermediate_steps"] = intermediate_steps
        return result, extra_return_dict

    @property
    def _chain_type(self) -> str:
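
With the rework above, intermediate map results now flow through the reduce step's extra return dict rather than a separate `_process_results` path. A hedged usage sketch, with the chain construction taken from the docstring example above:

    chain = MapReduceDocumentsChain(
        llm_chain=llm_chain,
        reduce_documents_chain=reduce_documents_chain,
        return_intermediate_steps=True,
    )
    output = chain({"input_documents": docs})
    output["output_text"]          # the final reduced answer
    output["intermediate_steps"]   # one map result per input document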
@@ -14,7 +14,48 @@ from langchain.output_parsers.regex import RegexParser


class MapRerankDocumentsChain(BaseCombineDocumentsChain):
    """Combining documents by mapping a chain over them, then reranking results."""
    """Combining documents by mapping a chain over them, then reranking results.

    This algorithm calls an LLMChain on each input document. The LLMChain is expected
    to have an OutputParser that parses the result into both an answer (`answer_key`)
    and a score (`rank_key`). The answer with the highest score is then returned.

    Example:
        .. code-block:: python

            from langchain.chains import StuffDocumentsChain, LLMChain
            from langchain.prompts import PromptTemplate
            from langchain.llms import OpenAI
            from langchain.output_parsers.regex import RegexParser

            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            # The actual prompt will need to be a lot more complex, this is just
            # an example.
            prompt_template = (
                "Use the following context to tell me the chemical formula "
                "for water. Output both your answer and a score of how confident "
                "you are. Context: {context}"
            )
            output_parser = RegexParser(
                regex=r"(.*?)\nScore: (.*)",
                output_keys=["answer", "score"],
            )
            prompt = PromptTemplate(
                template=prompt_template,
                input_variables=["context"],
                output_parser=output_parser,
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            chain = MapRerankDocumentsChain(
                llm_chain=llm_chain,
                document_variable_name=document_variable_name,
                rank_key="score",
                answer_key="answer",
            )
    """

    llm_chain: LLMChain
    """Chain to apply to each document individually."""
@@ -26,7 +67,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
    answer_key: str
    """Key in output of llm_chain to return as answer."""
    metadata_keys: Optional[List[str]] = None
    """Additional metadata from the chosen document to return."""
    return_intermediate_steps: bool = False
    """Return intermediate steps.
    Intermediate steps include the results of calling llm_chain on each document."""

    class Config:
        """Configuration for this pydantic object."""
@@ -96,6 +140,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
        """Combine documents in a map rerank manner.

        Combine by mapping first chain over all documents, then reranking the results.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        results = self.llm_chain.apply_and_parse(
            # FYI - this is parallelized and so it is fast.
@@ -110,6 +164,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
        """Combine documents in a map rerank manner.

        Combine by mapping first chain over all documents, then reranking the results.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        results = await self.llm_chain.aapply_and_parse(
            # FYI - this is parallelized and so it is fast.
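
For the map-rerank example above to work, each raw completion has to match the RegexParser pattern. A sketch of the parse step with purely illustrative values:

    # Hypothetical completion from the LLM for one document:
    #     "H2O\nScore: 95"
    # RegexParser(regex=r"(.*?)\nScore: (.*)") turns it into:
    #     {"answer": "H2O", "score": "95"}
    # MapRerankDocumentsChain then ranks the parsed results by the numeric value
    # of `rank_key` ("score") and returns the `answer_key` ("answer") of the
    # highest-scoring document.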
langchain/chains/combine_documents/reduce.py (new file)
@@ -0,0 +1,277 @@
"""Combine many documents together by recursively reducing them."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, List, Optional, Protocol, Tuple
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
class CombineDocsProtocol(Protocol):
|
||||
"""Interface for the combine_docs method."""
|
||||
|
||||
def __call__(self, docs: List[Document], **kwargs: Any) -> str:
|
||||
"""Interface for the combine_docs method."""
|
||||
|
||||
|
||||
class AsyncCombineDocsProtocol(Protocol):
|
||||
"""Interface for the combine_docs method."""
|
||||
|
||||
async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
|
||||
"""Async nterface for the combine_docs method."""
|
||||
|
||||
|
||||
def _split_list_of_docs(
|
||||
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
|
||||
) -> List[List[Document]]:
|
||||
new_result_doc_list = []
|
||||
_sub_result_docs = []
|
||||
for doc in docs:
|
||||
_sub_result_docs.append(doc)
|
||||
_num_tokens = length_func(_sub_result_docs, **kwargs)
|
||||
if _num_tokens > token_max:
|
||||
if len(_sub_result_docs) == 1:
|
||||
raise ValueError(
|
||||
"A single document was longer than the context length,"
|
||||
" we cannot handle this."
|
||||
)
|
||||
new_result_doc_list.append(_sub_result_docs[:-1])
|
||||
_sub_result_docs = _sub_result_docs[-1:]
|
||||
new_result_doc_list.append(_sub_result_docs)
|
||||
return new_result_doc_list
|
||||
|
||||
|
||||
def _collapse_docs(
|
||||
docs: List[Document],
|
||||
combine_document_func: CombineDocsProtocol,
|
||||
**kwargs: Any,
|
||||
) -> Document:
|
||||
result = combine_document_func(docs, **kwargs)
|
||||
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
|
||||
for doc in docs[1:]:
|
||||
for k, v in doc.metadata.items():
|
||||
if k in combined_metadata:
|
||||
combined_metadata[k] += f", {v}"
|
||||
else:
|
||||
combined_metadata[k] = str(v)
|
||||
return Document(page_content=result, metadata=combined_metadata)
|
||||
|
||||
|
||||
async def _acollapse_docs(
|
||||
docs: List[Document],
|
||||
combine_document_func: AsyncCombineDocsProtocol,
|
||||
**kwargs: Any,
|
||||
) -> Document:
|
||||
result = await combine_document_func(docs, **kwargs)
|
||||
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
|
||||
for doc in docs[1:]:
|
||||
for k, v in doc.metadata.items():
|
||||
if k in combined_metadata:
|
||||
combined_metadata[k] += f", {v}"
|
||||
else:
|
||||
combined_metadata[k] = str(v)
|
||||
return Document(page_content=result, metadata=combined_metadata)
|
||||
|
||||
|
||||
class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
"""Combining documents by recursively reducing them.
|
||||
|
||||
This involves
|
||||
- combine_documents_chain
|
||||
- collapse_documents_chain
|
||||
|
||||
`combine_documents_chain` is ALWAYS provided. This is final chain that is called.
|
||||
We pass all previous results to this chain, and the output of this chain is
|
||||
returned as a final result.
|
||||
|
||||
`collapse_documents_chain` is used if the documents passed in are too many to all
|
||||
be passed to `combine_documents_chain` in one go. In this case,
|
||||
`collapse_documents_chain` is called recursively on as big of groups of documents
|
||||
as are allowed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chains import (
|
||||
StuffDocumentsChain, LLMChain, ReduceDocumentsChain
|
||||
)
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.llms import OpenAI
|
||||
|
||||
# This controls how each document will be formatted. Specifically,
|
||||
# it will be passed to `format_document` - see that function for more
|
||||
# details.
|
||||
document_prompt = PromptTemplate(
|
||||
input_variables=["page_content"],
|
||||
template="{page_content}"
|
||||
)
|
||||
document_variable_name = "context"
|
||||
llm = OpenAI()
|
||||
# The prompt here should take as an input variable the
|
||||
# `document_variable_name`
|
||||
prompt = PromptTemplate.from_template(
|
||||
"Summarize this content: {context}"
|
||||
)
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
combine_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_prompt=document_prompt,
|
||||
document_variable_name=document_variable_name
|
||||
)
|
||||
chain = ReduceDocumentsChain(
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
)
|
||||
# If we wanted to, we could also pass in collapse_documents_chain
|
||||
# which is specifically aimed at collapsing documents BEFORE
|
||||
# the final call.
|
||||
prompt = PromptTemplate.from_template(
|
||||
"Collapse this content: {context}"
|
||||
)
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
collapse_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_prompt=document_prompt,
|
||||
document_variable_name=document_variable_name
|
||||
)
|
||||
chain = ReduceDocumentsChain(
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
collapse_documents_chain=collapse_documents_chain,
|
||||
)
|
||||
"""
|
||||
|
||||
combine_documents_chain: BaseCombineDocumentsChain
|
||||
"""Final chain to call to combine documents.
|
||||
This is typically a StuffDocumentsChain."""
|
||||
collapse_documents_chain: Optional[BaseCombineDocumentsChain] = None
|
||||
"""Chain to use to collapse documents if needed until they can all fit.
|
||||
If None, will use the combine_documents_chain.
|
||||
This is typically a StuffDocumentsChain."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def _collapse_chain(self) -> BaseCombineDocumentsChain:
|
||||
if self.collapse_documents_chain is not None:
|
||||
return self.collapse_documents_chain
|
||||
else:
|
||||
return self.combine_documents_chain
|
||||
|
||||
def combine_docs(
|
||||
self,
|
||||
docs: List[Document],
|
||||
token_max: int = 3000,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine multiple documents recursively.
|
||||
|
||||
Args:
|
||||
docs: List of documents to combine, assumed that each one is less than
|
||||
`token_max`.
|
||||
token_max: Recursively creates groups of documents less than this number
|
||||
of tokens.
|
||||
callbacks: Callbacks to be passed through
|
||||
**kwargs: additional parameters to be passed to LLM calls (like other
|
||||
input variables besides the documents)
|
||||
|
||||
Returns:
|
||||
The first element returned is the single string output. The second
|
||||
element returned is a dictionary of other keys to return.
|
||||
"""
|
||||
result_docs, extra_return_dict = self._collapse(
|
||||
docs, token_max, callbacks=callbacks, **kwargs
|
||||
)
|
||||
return self.combine_documents_chain.combine_docs(
|
||||
docs=result_docs, callbacks=callbacks, **kwargs
|
||||
)
|
||||
|
||||
async def acombine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine multiple documents recursively.
|
||||
|
||||
Args:
|
||||
docs: List of documents to combine, assumed that each one is less than
|
||||
`token_max`.
|
||||
token_max: Recursively creates groups of documents less than this number
|
||||
of tokens.
|
||||
callbacks: Callbacks to be passed through
|
||||
**kwargs: additional parameters to be passed to LLM calls (like other
|
||||
input variables besides the documents)
|
||||
|
||||
Returns:
|
||||
The first element returned is the single string output. The second
|
||||
element returned is a dictionary of other keys to return.
|
||||
"""
|
||||
result_docs, extra_return_dict = await self._acollapse(
|
||||
docs, callbacks=callbacks, **kwargs
|
||||
)
|
||||
return await self.combine_documents_chain.acombine_docs(
|
||||
docs=result_docs, callbacks=callbacks, **kwargs
|
||||
)
|
||||
|
||||
def _collapse(
|
||||
self,
|
||||
docs: List[Document],
|
||||
token_max: int = 3000,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[List[Document], dict]:
|
||||
result_docs = docs
|
||||
length_func = self.combine_documents_chain.prompt_length
|
||||
num_tokens = length_func(result_docs, **kwargs)
|
||||
|
||||
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
|
||||
return self._collapse_chain.run(
|
||||
input_documents=docs, callbacks=callbacks, **kwargs
|
||||
)
|
||||
|
||||
while num_tokens is not None and num_tokens > token_max:
|
||||
new_result_doc_list = _split_list_of_docs(
|
||||
result_docs, length_func, token_max, **kwargs
|
||||
)
|
||||
result_docs = []
|
||||
for docs in new_result_doc_list:
|
||||
new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
|
||||
result_docs.append(new_doc)
|
||||
num_tokens = length_func(result_docs, **kwargs)
|
||||
return result_docs, {}
|
||||
|
||||
async def _acollapse(
|
||||
self,
|
||||
docs: List[Document],
|
||||
token_max: int = 3000,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[List[Document], dict]:
|
||||
result_docs = docs
|
||||
length_func = self.combine_documents_chain.prompt_length
|
||||
num_tokens = length_func(result_docs, **kwargs)
|
||||
|
||||
async def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
|
||||
return await self._collapse_chain.arun(
|
||||
input_documents=docs, callbacks=callbacks, **kwargs
|
||||
)
|
||||
|
||||
while num_tokens is not None and num_tokens > token_max:
|
||||
new_result_doc_list = _split_list_of_docs(
|
||||
result_docs, length_func, token_max, **kwargs
|
||||
)
|
||||
result_docs = []
|
||||
for docs in new_result_doc_list:
|
||||
new_doc = await _acollapse_docs(docs, _collapse_docs_func, **kwargs)
|
||||
result_docs.append(new_doc)
|
||||
num_tokens = length_func(result_docs, **kwargs)
|
||||
return result_docs, {}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "reduce_documents_chain"
|
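
A small sketch of what the `_split_list_of_docs` / `_collapse` loop above does, using a toy length function in the spirit of `_fake_docs_len_func` from the tests at the bottom of this diff:

    from langchain.docstore.document import Document

    def char_len(docs, **kwargs):
        # Toy "token" counter: total characters across the documents.
        return len("".join(d.page_content for d in docs))

    docs = [Document(page_content="foo" * 10) for _ in range(10)]  # 30 chars each
    # Greedily packs documents into groups whose length stays <= token_max;
    # each group is then collapsed into a single new Document, and the loop
    # repeats until the whole list fits in one call.
    groups = _split_list_of_docs(docs, char_len, 100)
    assert all(char_len(g) <= 100 for g in groups)  # groups of 3, 3, 3, 1 docs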
@@ -9,12 +9,11 @@ from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
    BaseCombineDocumentsChain,
    format_document,
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema import BasePromptTemplate, format_document


def _get_default_document_prompt() -> PromptTemplate:
@@ -22,7 +21,55 @@ def _get_default_document_prompt() -> PromptTemplate:


class RefineDocumentsChain(BaseCombineDocumentsChain):
    """Combine documents by doing a first pass and then refining on more documents."""
    """Combine documents by doing a first pass and then refining on more documents.

    This algorithm first calls `initial_llm_chain` on the first document, passing
    that first document in with the variable name `document_variable_name`, and
    produces a new variable with the variable name `initial_response_name`.

    Then, it loops over every remaining document. This is called the "refine" step.
    It calls `refine_llm_chain`,
    passing in that document with the variable name `document_variable_name`
    as well as the previous response with the variable name `initial_response_name`.

    Example:
        .. code-block:: python

            from langchain.chains import RefineDocumentsChain, LLMChain
            from langchain.prompts import PromptTemplate
            from langchain.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            initial_response_name = "prev_response"
            # The prompt here should take as an input variable the
            # `document_variable_name` as well as `initial_response_name`
            prompt_refine = PromptTemplate.from_template(
                "Here's your first summary: {prev_response}. "
                "Now add to it based on the following context: {context}"
            )
            llm_chain_refine = LLMChain(llm=llm, prompt=prompt_refine)
            chain = RefineDocumentsChain(
                initial_llm_chain=llm_chain,
                refine_llm_chain=llm_chain_refine,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name,
                initial_response_name=initial_response_name,
            )
    """

    initial_llm_chain: LLMChain
    """LLM chain to use on initial document."""
@@ -36,7 +83,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
    document_prompt: BasePromptTemplate = Field(
        default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document."""
    """Prompt to use to format each document, gets passed to `format_document`."""
    return_intermediate_steps: bool = False
    """Return the results of the refine steps in the output."""

@@ -89,7 +136,18 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine by mapping first chain over all, then stuffing into final chain."""
        """Combine by mapping first chain over all, then stuffing into final chain.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._construct_initial_inputs(docs, **kwargs)
        res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
        refine_steps = [res]
@@ -103,7 +161,18 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
    async def acombine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine by mapping first chain over all, then stuffing into final chain."""
        """Combine by mapping first chain over all, then stuffing into final chain.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._construct_initial_inputs(docs, **kwargs)
        res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
        refine_steps = [res]
@@ -7,12 +7,11 @@ from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
    BaseCombineDocumentsChain,
    format_document,
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema import BasePromptTemplate, format_document


def _get_default_document_prompt() -> PromptTemplate:
@@ -20,14 +19,50 @@ def _get_default_document_prompt() -> PromptTemplate:


class StuffDocumentsChain(BaseCombineDocumentsChain):
    """Chain that combines documents by stuffing into context."""
    """Chain that combines documents by stuffing into context.

    This chain takes a list of documents and first combines them into a single string.
    It does this by formatting each document into a string with the `document_prompt`
    and then joining them together with `document_separator`. It then adds that new
    string to the inputs with the variable name set by `document_variable_name`.
    Those inputs are then passed to the `llm_chain`.

    Example:
        .. code-block:: python

            from langchain.chains import StuffDocumentsChain, LLMChain
            from langchain.prompts import PromptTemplate
            from langchain.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
    """

    llm_chain: LLMChain
    """LLM wrapper to use after formatting documents."""
    """LLM chain which is called with the formatted document string,
    along with any other inputs."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document."""
    """Prompt to use to format each document, gets passed to `format_document`."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
@@ -42,7 +77,12 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):

    @root_validator(pre=True)
    def get_default_document_variable_name(cls, values: Dict) -> Dict:
        """Get default document variable name, if not provided."""
        """Get default document variable name, if not provided.

        If only one variable is present in the llm_chain.prompt,
        we can infer that the formatted documents should be passed in
        with this variable name.
        """
        llm_chain_variables = values["llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
@@ -61,6 +101,20 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
        return values

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Construct inputs from kwargs and docs.

        Format and then join all the documents together into one input with the
        name `self.document_variable_name`. Then pluck any additional variables
        from **kwargs.

        Args:
            docs: List of documents to format and then join into single input
            **kwargs: additional inputs to chain, will pluck any other required
                arguments from here.

        Returns:
            dictionary of inputs to LLMChain
        """
        # Format each document according to the prompt
        doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
        # Join the documents together to put them in the prompt.
@@ -73,7 +127,21 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
        return inputs

    def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
        """Get the prompt length by formatting the prompt."""
        """Return the prompt length given the documents passed in.

        This can be used by a caller to determine whether passing in a list
        of documents would exceed a certain prompt length. This is useful when
        trying to ensure that the size of a prompt remains below a certain
        context limit.

        Args:
            docs: List[Document], a list of documents to use to calculate the
                total prompt length.

        Returns:
            Returns None if the method does not depend on the prompt length,
            otherwise the length of the prompt in tokens.
        """
        inputs = self._get_inputs(docs, **kwargs)
        prompt = self.llm_chain.prompt.format(**inputs)
        return self.llm_chain.llm.get_num_tokens(prompt)
@@ -81,7 +149,17 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM."""
        """Stuff all documents into one prompt and pass to LLM.

        Args:
            docs: List of documents to join together into one variable
            callbacks: Optional callbacks to pass along
            **kwargs: additional parameters to use to get inputs to LLMChain.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
@@ -89,7 +167,17 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
    async def acombine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM."""
        """Stuff all documents into one prompt and pass to LLM.

        Args:
            docs: List of documents to join together into one variable
            callbacks: Optional callbacks to pass along
            **kwargs: additional parameters to use to get inputs to LLMChain.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}
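
The `prompt_length` override above is what `ReduceDocumentsChain._collapse` polls in its collapse loop. A hedged sketch of checking it by hand before running the chain (the context-window number is illustrative and model-dependent):

    num_tokens = chain.prompt_length(docs)  # `chain` as in the docstring example
    max_context = 4097  # illustrative limit for an OpenAI completion model
    if num_tokens is not None and num_tokens > max_context:
        ...  # too big to stuff in one call: collapse first, or use map-reduce
    else:
        output = chain.run(input_documents=docs)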
@@ -5,6 +5,7 @@ from typing import Any, Union

import yaml

from langchain.chains import ReduceDocumentsChain
from langchain.chains.api.base import APIChain
from langchain.chains.base import Chain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -117,9 +118,9 @@ def _load_map_reduce_documents_chain(

    if "combine_document_chain" in config:
        combine_document_chain_config = config.pop("combine_document_chain")
        combine_document_chain = load_chain_from_config(combine_document_chain_config)
        combine_documents_chain = load_chain_from_config(combine_document_chain_config)
    elif "combine_document_chain_path" in config:
        combine_document_chain = load_chain(config.pop("combine_document_chain_path"))
        combine_documents_chain = load_chain(config.pop("combine_document_chain_path"))
    else:
        raise ValueError(
            "One of `combine_document_chain` or "
@@ -128,17 +129,24 @@ def _load_map_reduce_documents_chain(
    if "collapse_document_chain" in config:
        collapse_document_chain_config = config.pop("collapse_document_chain")
        if collapse_document_chain_config is None:
            collapse_document_chain = None
            collapse_documents_chain = None
        else:
            collapse_document_chain = load_chain_from_config(
            collapse_documents_chain = load_chain_from_config(
                collapse_document_chain_config
            )
    elif "collapse_document_chain_path" in config:
        collapse_document_chain = load_chain(config.pop("collapse_document_chain_path"))
        collapse_documents_chain = load_chain(
            config.pop("collapse_document_chain_path")
        )
    else:
        collapse_documents_chain = None
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_documents_chain,
    )
    return MapReduceDocumentsChain(
        llm_chain=llm_chain,
        combine_document_chain=combine_document_chain,
        collapse_document_chain=collapse_document_chain,
        reduce_documents_chain=reduce_documents_chain,
        **config,
    )
@@ -11,6 +11,7 @@ from pydantic import Extra

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains import ReduceDocumentsChain
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -44,14 +45,17 @@ class MapReduceChain(Chain):
    ) -> MapReduceChain:
        """Construct a map-reduce chain that uses the chain for map and reduce."""
        llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
        reduce_chain = StuffDocumentsChain(
        stuff_chain = StuffDocumentsChain(
            llm_chain=llm_chain,
            callbacks=callbacks,
            **(reduce_chain_kwargs if reduce_chain_kwargs else {}),
        )
        reduce_documents_chain = ReduceDocumentsChain(
            combine_documents_chain=stuff_chain
        )
        combine_documents_chain = MapReduceDocumentsChain(
            llm_chain=llm_chain,
            combine_document_chain=reduce_chain,
            reduce_documents_chain=reduce_documents_chain,
            callbacks=callbacks,
            **(combine_chain_kwargs if combine_chain_kwargs else {}),
        )
@@ -14,6 +14,7 @@ from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains import ReduceDocumentsChain
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -58,13 +59,16 @@ class BaseQAWithSourcesChain(Chain, ABC):
            document_prompt=document_prompt,
            document_variable_name="summaries",
        )
        combine_document_chain = MapReduceDocumentsChain(
        reduce_documents_chain = ReduceDocumentsChain(
            combine_documents_chain=combine_results_chain
        )
        combine_documents_chain = MapReduceDocumentsChain(
            llm_chain=llm_question_chain,
            combine_document_chain=combine_results_chain,
            reduce_documents_chain=reduce_documents_chain,
            document_variable_name="context",
        )
        return cls(
            combine_documents_chain=combine_document_chain,
            combine_documents_chain=combine_documents_chain,
            **kwargs,
        )

@@ -78,10 +82,10 @@ class BaseQAWithSourcesChain(Chain, ABC):
    ) -> BaseQAWithSourcesChain:
        """Load chain from chain type."""
        _chain_kwargs = chain_type_kwargs or {}
        combine_document_chain = load_qa_with_sources_chain(
        combine_documents_chain = load_qa_with_sources_chain(
            llm, chain_type=chain_type, **_chain_kwargs
        )
        return cls(combine_documents_chain=combine_document_chain, **kwargs)
        return cls(combine_documents_chain=combine_documents_chain, **kwargs)

    class Config:
        """Configuration for this pydantic object."""
@@ -110,7 +114,7 @@ class BaseQAWithSourcesChain(Chain, ABC):

    @root_validator(pre=True)
    def validate_naming(cls, values: Dict) -> Dict:
        """Fix backwards compatability in naming."""
        """Fix backwards compatibility in naming."""
        if "combine_document_chain" in values:
            values["combine_documents_chain"] = values.pop("combine_document_chain")
        return values
@@ -4,6 +4,7 @@ from __future__ import annotations
from typing import Any, Mapping, Optional, Protocol

from langchain.base_language import BaseLanguageModel
from langchain.chains import ReduceDocumentsChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -83,7 +84,7 @@ def _load_map_reduce_chain(
    map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
    _reduce_llm = reduce_llm or llm
    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
    combine_document_chain = StuffDocumentsChain(
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name=combine_document_variable_name,
        document_prompt=document_prompt,
@@ -107,11 +108,14 @@ def _load_map_reduce_chain(
        document_variable_name=combine_document_variable_name,
        document_prompt=document_prompt,
    )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_chain,
    )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        combine_document_chain=combine_document_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name=map_reduce_document_variable_name,
        collapse_document_chain=collapse_chain,
        verbose=verbose,
        **kwargs,
    )
@@ -4,6 +4,7 @@ from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chains import ReduceDocumentsChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -122,7 +123,7 @@ def _load_map_reduce_chain(
        callbacks=callbacks,
    )
    # TODO: document prompt
    combine_document_chain = StuffDocumentsChain(
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,
@@ -150,11 +151,14 @@ def _load_map_reduce_chain(
        verbose=verbose,
        callback_manager=callback_manager,
    )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_chain,
    )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        combine_document_chain=combine_document_chain,
        document_variable_name=map_reduce_document_variable_name,
        collapse_document_chain=collapse_chain,
        reduce_documents_chain=reduce_documents_chain,
        verbose=verbose,
        callback_manager=callback_manager,
        callbacks=callbacks,
@@ -2,6 +2,7 @@
from typing import Any, Mapping, Optional, Protocol

from langchain.base_language import BaseLanguageModel
from langchain.chains import ReduceDocumentsChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
@@ -53,7 +54,7 @@ def _load_map_reduce_chain(
    _reduce_llm = reduce_llm or llm
    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
    # TODO: document prompt
    combine_document_chain = StuffDocumentsChain(
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,
@@ -75,11 +76,14 @@ def _load_map_reduce_chain(
        ),
        document_variable_name=combine_document_variable_name,
    )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_chain,
    )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        combine_document_chain=combine_document_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name=map_reduce_document_variable_name,
        collapse_document_chain=collapse_chain,
        verbose=verbose,
        **kwargs,
    )
@@ -28,7 +28,7 @@ from langchain.schema.output_parser import (
    OutputParserException,
)
from langchain.schema.prompt import PromptValue
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.schema.prompt_template import BasePromptTemplate, format_document
from langchain.schema.retriever import BaseRetriever

RUN_KEY = "__run"
@@ -66,4 +66,5 @@ __all__ = [
    "BaseOutputParser",
    "BaseLLMOutputParser",
    "BasePromptTemplate",
    "format_document",
]
@@ -9,7 +9,9 @@ import yaml
from pydantic import Field, root_validator

from langchain.load.serializable import Serializable
from langchain.schema import BaseOutputParser, PromptValue
from langchain.schema.document import Document
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue


class BasePromptTemplate(Serializable, ABC):
@@ -137,3 +139,48 @@ class BasePromptTemplate(Serializable, ABC):
            yaml.dump(prompt_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")


def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
    """Format a document into a string based on a prompt template.

    First, this pulls information from the document from two sources:

    1. `page_content`: this takes the information from the `document.page_content`
        and assigns it to a variable named `page_content`.
    2. metadata: this takes information from `document.metadata` and assigns
        it to variables of the same name.

    Those variables are then passed into the `prompt` to produce a formatted string.

    Args:
        doc: Document, the page_content and metadata will be used to create
            the final string.
        prompt: BasePromptTemplate, will be used to format the page_content
            and metadata into the final string.

    Returns:
        The document formatted as a string.

    Example:
        .. code-block:: python

            from langchain.schema import Document
            from langchain.prompts import PromptTemplate

            doc = Document(page_content="This is a joke", metadata={"page": "1"})
            prompt = PromptTemplate.from_template("Page {page}: {page_content}")
            format_document(doc, prompt)
            >>> "Page 1: This is a joke"
    """
    base_info = {"page_content": doc.page_content, **doc.metadata}
    missing_metadata = set(prompt.input_variables).difference(base_info)
    if len(missing_metadata) > 0:
        required_metadata = [
            iv for iv in prompt.input_variables if iv != "page_content"
        ]
        raise ValueError(
            f"Document prompt requires documents to have metadata variables: "
            f"{required_metadata}. Received document with missing metadata: "
            f"{list(missing_metadata)}."
        )
    document_info = {k: base_info[k] for k in prompt.input_variables}
    return prompt.format(**document_info)
@@ -5,12 +5,12 @@ from typing import Any, List

import pytest

from langchain import PromptTemplate
from langchain.chains.combine_documents.base import format_document
from langchain.chains.combine_documents.map_reduce import (
from langchain.chains.combine_documents.reduce import (
    _collapse_docs,
    _split_list_of_docs,
)
from langchain.docstore.document import Document
from langchain.schema import format_document


def _fake_docs_len_func(docs: List[Document]) -> int:
@@ -28,13 +28,6 @@ def test__split_list_long_single_doc() -> None:
        _split_list_of_docs(docs, _fake_docs_len_func, 100)


def test__split_list_long_pair_doc() -> None:
    """Test splitting of a list with two medium docs."""
    docs = [Document(page_content="foo" * 30)] * 2
    with pytest.raises(ValueError):
        _split_list_of_docs(docs, _fake_docs_len_func, 100)


def test__split_list_single_doc() -> None:
    """Test splitting works with just a single doc."""
    docs = [Document(page_content="foo")]
@@ -86,8 +86,8 @@ def test_imports() -> None:
    from langchain.document_loaders import BSHTMLLoader  # noqa: F401
    from langchain.embeddings import OpenAIEmbeddings  # noqa: F401
    from langchain.llms import OpenAI  # noqa: F401
    from langchain.prompts import BasePromptTemplate  # noqa: F401
    from langchain.retrievers import VespaRetriever  # noqa: F401
    from langchain.schema import BasePromptTemplate  # noqa: F401
    from langchain.tools import DuckDuckGoSearchResults  # noqa: F401
    from langchain.utilities import SerpAPIWrapper  # noqa: F401
    from langchain.vectorstores import FAISS  # noqa: F401