Docs combine document chain (#6994)

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Harrison Chase 2023-07-04 11:51:04 -07:00 committed by GitHub
parent 81eebc4070
commit 0ad984fa27
17 changed files with 820 additions and 205 deletions

View File

@@ -4,6 +4,7 @@ from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
from langchain.chains.combine_documents.base import AnalyzeDocumentChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.constitutional_ai.base import ConstitutionalChain
@@ -111,4 +112,5 @@ __all__ = [
"MapRerankDocumentsChain",
"MapReduceDocumentsChain",
"RefineDocumentsChain",
"ReduceDocumentsChain",
]

View File

@@ -11,30 +11,20 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.schema import BasePromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
"""Format a document into a string based on a prompt template."""
base_info = {"page_content": doc.page_content}
base_info.update(doc.metadata)
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents."""
"""Base interface for chains combining documents.
Subclasses of this chain deal with combining documents in a variety of
ways. This base class exists to add some uniformity in the interface these types
of chains should expose. Namely, they expect an input key related to the documents
to use (default `input_documents`), and then also expose a method to calculate
the length of a prompt from documents (useful for outside callers to use to
determine whether it's safe to pass a list of documents into this chain or whether
that will be longer than the context length).
"""
input_key: str = "input_documents" #: :meta private:
output_key: str = "output_text" #: :meta private:
@@ -58,25 +48,57 @@ class BaseCombineDocumentsChain(Chain, ABC):
def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
"""Return the prompt length given the documents passed in.
Returns None if the method does not depend on the prompt length.
This can be used by a caller to determine whether passing in a list
of documents would exceed a certain prompt length. This is useful when
trying to ensure that the size of a prompt remains below a certain
context limit.
Args:
docs: List[Document], a list of documents to use to calculate the
total prompt length.
Returns:
Returns None if the method does not depend on the prompt length,
otherwise the length of the prompt in tokens.
"""
return None
@abstractmethod
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Combine documents into a single string."""
"""Combine documents into a single string.
Args:
docs: List[Document], the documents to combine
**kwargs: Other parameters to use in combining documents, often
other inputs to the prompt.
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
@abstractmethod
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
"""Combine documents into a single string asynchronously."""
"""Combine documents into a single string.
Args:
docs: List[Document], the documents to combine
**kwargs: Other parameters to use in combining documents, often
other inputs to the prompt.
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
def _call(
self,
inputs: Dict[str, List[Document]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Prepare inputs, call combine docs, prepare outputs."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
@@ -92,6 +114,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
inputs: Dict[str, List[Document]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Prepare inputs, call combine docs, prepare outputs."""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
@@ -104,7 +127,12 @@ class BaseCombineDocumentsChain(Chain, ABC):
class AnalyzeDocumentChain(Chain):
"""Chain that splits documents, then analyzes it in pieces."""
"""Chain that splits documents, then analyzes it in pieces.
This chain is parameterized by a TextSplitter and a CombineDocumentsChain.
This chain takes a single document as input, splits it up into chunks,
and then passes those chunks to the CombineDocumentsChain.
"""
input_key: str = "input_document" #: :meta private:
text_splitter: TextSplitter = Field(default_factory=RecursiveCharacterTextSplitter)
@@ -131,6 +159,7 @@ class AnalyzeDocumentChain(Chain):
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Split document into chunks and pass to CombineDocumentsChain."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
document = inputs[self.input_key]
docs = self.text_splitter.create_documents([document])

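For orientation, a minimal usage sketch of the AnalyzeDocumentChain described above. This is a sketch, not part of the commit: it assumes an OpenAI key is configured and uses a summarize chain purely as an example combine chain.

from langchain.chains import AnalyzeDocumentChain
from langchain.chains.summarize import load_summarize_chain
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
# Any BaseCombineDocumentsChain works here; summarize is one choice.
summarize_chain = load_summarize_chain(llm, chain_type="map_reduce")
chain = AnalyzeDocumentChain(combine_docs_chain=summarize_chain)
# The text splitter chunks the raw string, then the combine chain runs.
summary = chain.run(input_document="<some long document text>")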
View File

@@ -2,74 +2,97 @@
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
from typing import Any, Dict, List, Tuple
from pydantic import Extra, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
class CombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""
def __call__(self, docs: List[Document], **kwargs: Any) -> str:
"""Interface for the combine_docs method."""
def _split_list_of_docs(
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
) -> List[List[Document]]:
new_result_doc_list = []
_sub_result_docs = []
for doc in docs:
_sub_result_docs.append(doc)
_num_tokens = length_func(_sub_result_docs, **kwargs)
if _num_tokens > token_max:
if len(_sub_result_docs) == 1:
raise ValueError(
"A single document was longer than the context length,"
" we cannot handle this."
)
if len(_sub_result_docs) == 2:
raise ValueError(
"A single document was so long it could not be combined "
"with another document, we cannot handle this."
)
new_result_doc_list.append(_sub_result_docs[:-1])
_sub_result_docs = _sub_result_docs[-1:]
new_result_doc_list.append(_sub_result_docs)
return new_result_doc_list
def _collapse_docs(
docs: List[Document],
combine_document_func: CombineDocsProtocol,
**kwargs: Any,
) -> Document:
result = combine_document_func(docs, **kwargs)
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
for doc in docs[1:]:
for k, v in doc.metadata.items():
if k in combined_metadata:
combined_metadata[k] += f", {v}"
else:
combined_metadata[k] = str(v)
return Document(page_content=result, metadata=combined_metadata)
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
"""Combining documents by mapping a chain over them, then combining results."""
"""Combining documents by mapping a chain over them, then combining results.
We first call `llm_chain` on each document individually, passing in the
`page_content` and any other kwargs. This is the `map` step.
We then process the results of that `map` step in a `reduce` step, which
should typically be a ReduceDocumentsChain.
Example:
.. code-block:: python
from langchain.chains import (
StuffDocumentsChain,
LLMChain,
ReduceDocumentsChain,
MapReduceDocumentsChain,
)
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
# This controls how each document will be formatted. Specifically,
# it will be passed to `format_document` - see that function for more
# details.
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template(
"Summarize this content: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
# We now define how to combine these summaries
reduce_prompt = PromptTemplate.from_template(
"Combine these summaries: {context}"
)
reduce_llm_chain = LLMChain(llm=llm, prompt=reduce_prompt)
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
)
chain = MapReduceDocumentsChain(
llm_chain=llm_chain,
reduce_documents_chain=reduce_documents_chain,
)
# If we wanted to, we could also pass in collapse_documents_chain
# which is specifically aimed at collapsing documents BEFORE
# the final call.
prompt = PromptTemplate.from_template(
"Collapse this content: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
collapse_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_documents_chain,
)
chain = MapReduceDocumentsChain(
llm_chain=llm_chain,
reduce_documents_chain=reduce_documents_chain,
)
"""
llm_chain: LLMChain
"""Chain to apply to each document individually."""
combine_document_chain: BaseCombineDocumentsChain
"""Chain to use to combine results of applying llm_chain to documents."""
collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
"""Chain to use to collapse intermediary results if needed.
If None, will use the combine_document_chain."""
reduce_documents_chain: BaseCombineDocumentsChain
"""Chain to use to reduce the results of applying `llm_chain` to each doc.
This typically either a ReduceDocumentChain or StuffDocumentChain."""
document_variable_name: str
"""The variable name in the llm_chain to put the documents in.
If only one variable in the llm_chain, this need not be provided."""
@@ -93,6 +116,29 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def get_reduce_chain(cls, values: Dict) -> Dict:
"""For backwards compatibility."""
if "combine_document_chain" in values:
if "reduce_documents_chain" in values:
raise ValueError(
"Both `reduce_documents_chain` and `combine_document_chain` "
"cannot be provided at the same time. `combine_document_chain` "
"is deprecated, please only provide `reduce_documents_chain`"
)
combine_chain = values["combine_document_chain"]
collapse_chain = values.get("collapse_document_chain")
reduce_chain = ReduceDocumentsChain(
combine_documents_chain=combine_chain,
collapse_documents_chain=collapse_chain,
)
values["reduce_documents_chain"] = reduce_chain
del values["combine_document_chain"]
if "collapse_document_chain" in values:
del values["collapse_document_chain"]
return values
@root_validator(pre=True)
def get_return_intermediate_steps(cls, values: Dict) -> Dict:
"""For backwards compatibility."""
@@ -123,11 +169,31 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
return values
@property
def _collapse_chain(self) -> BaseCombineDocumentsChain:
if self.collapse_document_chain is not None:
return self.collapse_document_chain
def collapse_document_chain(self) -> BaseCombineDocumentsChain:
"""Kept for backward compatibility."""
if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
if self.reduce_documents_chain.collapse_documents_chain:
return self.reduce_documents_chain.collapse_documents_chain
else:
return self.reduce_documents_chain.combine_documents_chain
else:
return self.combine_document_chain
raise ValueError(
f"`reduce_documents_chain` is of type "
f"{type(self.reduce_documents_chain)} so it does not have "
f"this attribute."
)
@property
def combine_document_chain(self) -> BaseCombineDocumentsChain:
"""Kept for backward compatibility."""
if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
return self.reduce_documents_chain.combine_documents_chain
else:
raise ValueError(
f"`reduce_documents_chain` is of type "
f"{type(self.reduce_documents_chain)} so it does not have "
f"this attribute."
)
def combine_docs(
self,
@@ -141,14 +207,24 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
Combine by mapping first chain over all documents, then reducing the results.
This reducing can be done recursively if needed (if there are many documents).
"""
results = self.llm_chain.apply(
map_results = self.llm_chain.apply(
# FYI - this is parallelized and so it is fast.
[{self.document_variable_name: d.page_content, **kwargs} for d in docs],
callbacks=callbacks,
)
return self._process_results(
results, docs, token_max, callbacks=callbacks, **kwargs
question_result_key = self.llm_chain.output_key
result_docs = [
Document(page_content=r[question_result_key], metadata=docs[i].metadata)
# This uses metadata from the docs, and the textual results from `map_results`
for i, r in enumerate(map_results)
]
result, extra_return_dict = self.reduce_documents_chain.combine_docs(
result_docs, callbacks=callbacks, **kwargs
)
if self.return_intermediate_steps:
intermediate_steps = [r[question_result_key] for r in map_results]
extra_return_dict["intermediate_steps"] = intermediate_steps
return result, extra_return_dict
async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
@@ -158,83 +234,24 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
Combine by mapping first chain over all documents, then reducing the results.
This reducing can be done recursively if needed (if there are many documents).
"""
results = await self.llm_chain.aapply(
map_results = await self.llm_chain.aapply(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
callbacks=callbacks,
)
return await self._aprocess_results(
results, docs, callbacks=callbacks, **kwargs
)
def _process_results_common(
self,
results: List[Dict],
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
question_result_key = self.llm_chain.output_key
result_docs = [
Document(page_content=r[question_result_key], metadata=docs[i].metadata)
# This uses metadata from the docs, and the textual results from `results`
for i, r in enumerate(results)
for i, r in enumerate(map_results)
]
length_func = self.combine_document_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
return self._collapse_chain.run(
input_documents=docs, callbacks=callbacks, **kwargs
)
while num_tokens is not None and num_tokens > token_max:
new_result_doc_list = _split_list_of_docs(
result_docs, length_func, token_max, **kwargs
)
result_docs = []
for docs in new_result_doc_list:
new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
result_docs.append(new_doc)
num_tokens = length_func(result_docs, **kwargs)
result, extra_return_dict = await self.reduce_documents_chain.acombine_docs(
result_docs, callbacks=callbacks, **kwargs
)
if self.return_intermediate_steps:
_results = [r[self.llm_chain.output_key] for r in results]
extra_return_dict = {"intermediate_steps": _results}
else:
extra_return_dict = {}
return result_docs, extra_return_dict
def _process_results(
self,
results: List[Dict],
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
result_docs, extra_return_dict = self._process_results_common(
results, docs, token_max, callbacks=callbacks, **kwargs
)
output = self.combine_document_chain.run(
input_documents=result_docs, callbacks=callbacks, **kwargs
)
return output, extra_return_dict
async def _aprocess_results(
self,
results: List[Dict],
docs: List[Document],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
result_docs, extra_return_dict = self._process_results_common(
results, docs, callbacks=callbacks, **kwargs
)
output = await self.combine_document_chain.arun(
input_documents=result_docs, callbacks=callbacks, **kwargs
)
return output, extra_return_dict
intermediate_steps = [r[question_result_key] for r in map_results]
extra_return_dict["intermediate_steps"] = intermediate_steps
return result, extra_return_dict
@property
def _chain_type(self) -> str:

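A hedged sketch of the backward-compatibility validator above: the deprecated `combine_document_chain` keyword is rewritten into a `reduce_documents_chain`. Here `llm_chain` and `combine_documents_chain` are assumed to be the objects built in the docstring example.

from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain

# Old spelling, still accepted thanks to the pre root_validator:
legacy_chain = MapReduceDocumentsChain(
    llm_chain=llm_chain,
    combine_document_chain=combine_documents_chain,
)
# The validator wrapped the combine chain in a ReduceDocumentsChain:
assert isinstance(legacy_chain.reduce_documents_chain, ReduceDocumentsChain)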
View File

@@ -14,7 +14,48 @@ from langchain.output_parsers.regex import RegexParser
class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"""Combining documents by mapping a chain over them, then reranking results."""
"""Combining documents by mapping a chain over them, then reranking results.
This algorithm calls an LLMChain on each input document. The LLMChain is expected
to have an OutputParser that parses the result into both an answer (`answer_key`)
and a score (`rank_key`). The answer with the highest score is then returned.
Example:
.. code-block:: python
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
from langchain.output_parsers.regex import RegexParser
document_variable_name = "context"
llm = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
# The actual prompt will need to be a lot more complex, this is just
# an example.
prompt_template = (
"Use the following context to tell me the chemical formula "
"for water. Output both your answer and a score of how confident "
"you are. Context: {content}"
)
output_parser = RegexParser(
regex=r"(.*?)\nScore: (.*)",
output_keys=["answer", "score"],
)
prompt = PromptTemplate(
template=prompt_template,
input_variables=["context"],
output_parser=output_parser,
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
chain = MapRerankDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
rank_key="score",
answer_key="answer",
)
"""
llm_chain: LLMChain
"""Chain to apply to each document individually."""
@@ -26,7 +67,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
answer_key: str
"""Key in output of llm_chain to return as answer."""
metadata_keys: Optional[List[str]] = None
"""Additional metadata from the chosen document to return."""
return_intermediate_steps: bool = False
"""Return intermediate steps.
Intermediate steps include the results of calling llm_chain on each document."""
class Config:
"""Configuration for this pydantic object."""
@@ -96,6 +140,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"""Combine documents in a map rerank manner.
Combine by mapping first chain over all documents, then reranking the results.
Args:
docs: List of documents to combine
callbacks: Callbacks to be passed through
**kwargs: additional parameters to be passed to LLM calls (like other
input variables besides the documents)
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
results = self.llm_chain.apply_and_parse(
# FYI - this is parallelized and so it is fast.
@@ -110,6 +164,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"""Combine documents in a map rerank manner.
Combine by mapping first chain over all documents, then reranking the results.
Args:
docs: List of documents to combine
callbacks: Callbacks to be passed through
**kwargs: additional parameters to be passed to LLM calls (like other
input variables besides the documents)
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
results = await self.llm_chain.aapply_and_parse(
# FYI - this is parallelized and so it is fast.

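To make the rerank flow concrete, a small usage sketch assuming `chain` is the MapRerankDocumentsChain from the docstring example; the documents are illustrative.

from langchain.docstore.document import Document

docs = [
    Document(page_content="Water is H2O."),
    Document(page_content="Salt is NaCl."),
]
# Each document is answered and scored via the RegexParser; the
# highest-scoring answer is returned.
answer = chain.run(input_documents=docs)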
View File

@@ -0,0 +1,277 @@
"""Combine many documents together by recursively reducing them."""
from __future__ import annotations
from typing import Any, Callable, List, Optional, Protocol, Tuple
from pydantic import Extra
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.docstore.document import Document
class CombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""
def __call__(self, docs: List[Document], **kwargs: Any) -> str:
"""Interface for the combine_docs method."""
class AsyncCombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""
async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
"""Async nterface for the combine_docs method."""
def _split_list_of_docs(
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
) -> List[List[Document]]:
new_result_doc_list = []
_sub_result_docs = []
for doc in docs:
_sub_result_docs.append(doc)
_num_tokens = length_func(_sub_result_docs, **kwargs)
if _num_tokens > token_max:
if len(_sub_result_docs) == 1:
raise ValueError(
"A single document was longer than the context length,"
" we cannot handle this."
)
new_result_doc_list.append(_sub_result_docs[:-1])
_sub_result_docs = _sub_result_docs[-1:]
new_result_doc_list.append(_sub_result_docs)
return new_result_doc_list
def _collapse_docs(
docs: List[Document],
combine_document_func: CombineDocsProtocol,
**kwargs: Any,
) -> Document:
result = combine_document_func(docs, **kwargs)
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
for doc in docs[1:]:
for k, v in doc.metadata.items():
if k in combined_metadata:
combined_metadata[k] += f", {v}"
else:
combined_metadata[k] = str(v)
return Document(page_content=result, metadata=combined_metadata)
async def _acollapse_docs(
docs: List[Document],
combine_document_func: AsyncCombineDocsProtocol,
**kwargs: Any,
) -> Document:
result = await combine_document_func(docs, **kwargs)
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
for doc in docs[1:]:
for k, v in doc.metadata.items():
if k in combined_metadata:
combined_metadata[k] += f", {v}"
else:
combined_metadata[k] = str(v)
return Document(page_content=result, metadata=combined_metadata)
class ReduceDocumentsChain(BaseCombineDocumentsChain):
"""Combining documents by recursively reducing them.
This involves
- combine_documents_chain
- collapse_documents_chain
`combine_documents_chain` is ALWAYS provided. This is the final chain that is called.
We pass all previous results to this chain, and the output of this chain is
returned as a final result.
`collapse_documents_chain` is used if the documents passed in are too many to all
be passed to `combine_documents_chain` in one go. In this case,
`collapse_documents_chain` is called recursively on groups of documents that are
as large as the token limit allows.
Example:
.. code-block:: python
from langchain.chains import (
StuffDocumentsChain, LLMChain, ReduceDocumentsChain
)
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
# This controls how each document will be formatted. Specifically,
# it will be passed to `format_document` - see that function for more
# details.
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template(
"Summarize this content: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
combine_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
)
# If we wanted to, we could also pass in collapse_documents_chain
# which is specifically aimed at collapsing documents BEFORE
# the final call.
prompt = PromptTemplate.from_template(
"Collapse this content: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
collapse_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_documents_chain,
)
"""
combine_documents_chain: BaseCombineDocumentsChain
"""Final chain to call to combine documents.
This is typically a StuffDocumentsChain."""
collapse_documents_chain: Optional[BaseCombineDocumentsChain] = None
"""Chain to use to collapse documents if needed until they can all fit.
If None, will use the combine_documents_chain.
This is typically a StuffDocumentsChain."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def _collapse_chain(self) -> BaseCombineDocumentsChain:
if self.collapse_documents_chain is not None:
return self.collapse_documents_chain
else:
return self.combine_documents_chain
def combine_docs(
self,
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
"""Combine multiple documents recursively.
Args:
docs: List of documents to combine; each is assumed to be shorter
than `token_max` tokens.
token_max: Documents are recursively grouped and collapsed so that each
group stays under this number of tokens.
callbacks: Callbacks to be passed through
**kwargs: additional parameters to be passed to LLM calls (like other
input variables besides the documents)
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
result_docs, extra_return_dict = self._collapse(
docs, token_max, callbacks=callbacks, **kwargs
)
return self.combine_documents_chain.combine_docs(
docs=result_docs, callbacks=callbacks, **kwargs
)
async def acombine_docs(
self,
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
"""Combine multiple documents recursively.
Args:
docs: List of documents to combine; each is assumed to be shorter
than `token_max` tokens.
token_max: Documents are recursively grouped and collapsed so that each
group stays under this number of tokens.
callbacks: Callbacks to be passed through
**kwargs: additional parameters to be passed to LLM calls (like other
input variables besides the documents)
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
result_docs, extra_return_dict = await self._acollapse(
docs, token_max, callbacks=callbacks, **kwargs
)
return await self.combine_documents_chain.acombine_docs(
docs=result_docs, callbacks=callbacks, **kwargs
)
def _collapse(
self,
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
result_docs = docs
length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
return self._collapse_chain.run(
input_documents=docs, callbacks=callbacks, **kwargs
)
while num_tokens is not None and num_tokens > token_max:
new_result_doc_list = _split_list_of_docs(
result_docs, length_func, token_max, **kwargs
)
result_docs = []
for docs in new_result_doc_list:
new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
result_docs.append(new_doc)
num_tokens = length_func(result_docs, **kwargs)
return result_docs, {}
async def _acollapse(
self,
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
result_docs = docs
length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)
async def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
return await self._collapse_chain.arun(
input_documents=docs, callbacks=callbacks, **kwargs
)
while num_tokens is not None and num_tokens > token_max:
new_result_doc_list = _split_list_of_docs(
result_docs, length_func, token_max, **kwargs
)
result_docs = []
for docs in new_result_doc_list:
new_doc = await _acollapse_docs(docs, _collapse_docs_func, **kwargs)
result_docs.append(new_doc)
num_tokens = length_func(result_docs, **kwargs)
return result_docs, {}
@property
def _chain_type(self) -> str:
return "reduce_documents_chain"

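A sketch of the collapse loop in action, assuming `chain` is the ReduceDocumentsChain from the docstring example; `token_max` is forwarded through the chain inputs to `combine_docs`.

from langchain.docstore.document import Document

docs = [Document(page_content="some text " * 200) for _ in range(10)]
# Documents are grouped with _split_list_of_docs, each group is collapsed
# with the collapse chain, and the total is re-measured until it fits.
output = chain.run(input_documents=docs, token_max=1000)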
View File

@@ -9,12 +9,11 @@ from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
format_document,
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema import BasePromptTemplate, format_document
def _get_default_document_prompt() -> PromptTemplate:
@@ -22,7 +21,55 @@ def _get_default_document_prompt() -> PromptTemplate:
class RefineDocumentsChain(BaseCombineDocumentsChain):
"""Combine documents by doing a first pass and then refining on more documents."""
"""Combine documents by doing a first pass and then refining on more documents.
This algorithm first calls `initial_llm_chain` on the first document, passing
that first document in with the variable name `document_variable_name`, and
produces a new variable with the variable name `initial_response_name`.
Then, it loops over every remaining document. This is called the "refine" step.
It calls `refine_llm_chain`,
passing in that document with the variable name `document_variable_name`
as well as the previous response with the variable name `initial_response_name`.
Example:
.. code-block:: python
from langchain.chains import RefineDocumentsChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
# This controls how each document will be formatted. Specifically,
# it will be passed to `format_document` - see that function for more
# details.
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template(
"Summarize this content: {context}"
)
initial_llm_chain = LLMChain(llm=llm, prompt=prompt)
initial_response_name = "prev_response"
# The prompt here should take as an input variable the
# `document_variable_name` as well as `initial_response_name`
prompt_refine = PromptTemplate.from_template(
"Here's your first summary: {prev_response}. "
"Now add to it based on the following context: {context}"
)
refine_llm_chain = LLMChain(llm=llm, prompt=prompt_refine)
chain = RefineDocumentsChain(
initial_llm_chain=initial_llm_chain,
refine_llm_chain=refine_llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name,
initial_response_name=initial_response_name,
)
"""
initial_llm_chain: LLMChain
"""LLM chain to use on initial document."""
@@ -36,7 +83,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
document_prompt: BasePromptTemplate = Field(
default_factory=_get_default_document_prompt
)
"""Prompt to use to format each document."""
"""Prompt to use to format each document, gets passed to `format_document`."""
return_intermediate_steps: bool = False
"""Return the results of the refine steps in the output."""
@@ -89,7 +136,18 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Combine by mapping first chain over all, then stuffing into final chain."""
"""Combine by mapping first chain over all, then stuffing into final chain.
Args:
docs: List of documents to combine
callbacks: Callbacks to be passed through
**kwargs: additional parameters to be passed to LLM calls (like other
input variables besides the documents)
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
inputs = self._construct_initial_inputs(docs, **kwargs)
res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
refine_steps = [res]
@@ -103,7 +161,18 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Combine by mapping first chain over all, then stuffing into final chain."""
"""Combine by mapping first chain over all, then stuffing into final chain.
Args:
docs: List of documents to combine
callbacks: Callbacks to be passed through
**kwargs: additional parameters to be passed to LLM calls (like other
input variables besides the documents)
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
inputs = self._construct_initial_inputs(docs, **kwargs)
res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
refine_steps = [res]

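A short usage sketch of the refine loop, assuming `chain` is the RefineDocumentsChain from the docstring example; the documents are illustrative.

from langchain.docstore.document import Document

docs = [Document(page_content=t) for t in ("part one", "part two", "part three")]
# The first document seeds the initial response; each later document
# refines it via refine_llm_chain.
summary = chain.run(input_documents=docs)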
View File

@@ -7,12 +7,11 @@ from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
format_document,
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema import BasePromptTemplate, format_document
def _get_default_document_prompt() -> PromptTemplate:
@@ -20,14 +19,50 @@ def _get_default_document_prompt() -> PromptTemplate:
class StuffDocumentsChain(BaseCombineDocumentsChain):
"""Chain that combines documents by stuffing into context."""
"""Chain that combines documents by stuffing into context.
This chain takes a list of documents and first combines them into a single string.
It does this by formatting each document into a string with the `document_prompt`
and then joining them together with `document_separator`. It then adds that new
string to the inputs with the variable name set by `document_variable_name`.
Those inputs are then passed to the `llm_chain`.
Example:
.. code-block:: python
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
# This controls how each document will be formatted. Specifically,
# it will be passed to `format_document` - see that function for more
# details.
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template(
"Summarize this content: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
"""
llm_chain: LLMChain
"""LLM wrapper to use after formatting documents."""
"""LLM chain which is called with the formatted document string,
along with any other inputs."""
document_prompt: BasePromptTemplate = Field(
default_factory=_get_default_document_prompt
)
"""Prompt to use to format each document."""
"""Prompt to use to format each document, gets passed to `format_document`."""
document_variable_name: str
"""The variable name in the llm_chain to put the documents in.
If only one variable in the llm_chain, this need not be provided."""
@@ -42,7 +77,12 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
@root_validator(pre=True)
def get_default_document_variable_name(cls, values: Dict) -> Dict:
"""Get default document variable name, if not provided."""
"""Get default document variable name, if not provided.
If only one variable is present in the llm_chain.prompt,
we can infer that the formatted documents should be passed in
with this variable name.
"""
llm_chain_variables = values["llm_chain"].prompt.input_variables
if "document_variable_name" not in values:
if len(llm_chain_variables) == 1:
@@ -61,6 +101,20 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return values
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
"""Construct inputs from kwargs and docs.
Format and then join all the documents together into one input with the
name `self.document_variable_name`. Also pluck any additional variables
from **kwargs.
Args:
docs: List of documents to format and then join into single input
**kwargs: additional inputs to chain, will pluck any other required
arguments from here.
Returns:
dictionary of inputs to LLMChain
"""
# Format each document according to the prompt
doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
# Join the documents together to put them in the prompt.
@@ -73,7 +127,21 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return inputs
def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
"""Get the prompt length by formatting the prompt."""
"""Return the prompt length given the documents passed in.
This can be used by a caller to determine whether passing in a list
of documents would exceed a certain prompt length. This is useful when
trying to ensure that the size of a prompt remains below a certain
context limit.
Args:
docs: List[Document], a list of documents to use to calculate the
total prompt length.
Returns:
Returns None if the method does not depend on the prompt length,
otherwise the length of the prompt in tokens.
"""
inputs = self._get_inputs(docs, **kwargs)
prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain.llm.get_num_tokens(prompt)
@@ -81,7 +149,17 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""
"""Stuff all documents into one prompt and pass to LLM.
Args:
docs: List of documents to join together into one variable
callbacks: Optional callbacks to pass along
**kwargs: additional parameters to use to get inputs to LLMChain.
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
@@ -89,7 +167,17 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""
"""Stuff all documents into one prompt and pass to LLM.
Args:
docs: List of documents to join together into one variable
callbacks: Optional callbacks to pass along
**kwargs: additional parameters to use to get inputs to LLMChain.
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}

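One way a caller can use `prompt_length` before invoking the chain, assuming `chain` is the StuffDocumentsChain from the docstring example; the 4000-token budget is an assumed context limit, not part of the API.

from langchain.docstore.document import Document

docs = [Document(page_content="some text " * 50) for _ in range(5)]
length = chain.prompt_length(docs)
if length is not None and length > 4000:  # assumed context budget
    raise ValueError("Too long to stuff; collapse or split the docs first.")
output = chain.run(input_documents=docs)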
View File

@@ -5,6 +5,7 @@ from typing import Any, Union
import yaml
from langchain.chains import ReduceDocumentsChain
from langchain.chains.api.base import APIChain
from langchain.chains.base import Chain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -117,9 +118,9 @@ def _load_map_reduce_documents_chain(
if "combine_document_chain" in config:
combine_document_chain_config = config.pop("combine_document_chain")
combine_document_chain = load_chain_from_config(combine_document_chain_config)
combine_documents_chain = load_chain_from_config(combine_document_chain_config)
elif "combine_document_chain_path" in config:
combine_document_chain = load_chain(config.pop("combine_document_chain_path"))
combine_documents_chain = load_chain(config.pop("combine_document_chain_path"))
else:
raise ValueError(
"One of `combine_document_chain` or "
@@ -128,17 +129,24 @@ def _load_map_reduce_documents_chain(
if "collapse_document_chain" in config:
collapse_document_chain_config = config.pop("collapse_document_chain")
if collapse_document_chain_config is None:
collapse_document_chain = None
collapse_documents_chain = None
else:
collapse_document_chain = load_chain_from_config(
collapse_documents_chain = load_chain_from_config(
collapse_document_chain_config
)
elif "collapse_document_chain_path" in config:
collapse_document_chain = load_chain(config.pop("collapse_document_chain_path"))
collapse_documents_chain = load_chain(
config.pop("collapse_document_chain_path")
)
else:
collapse_documents_chain = None
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_documents_chain,
)
return MapReduceDocumentsChain(
llm_chain=llm_chain,
combine_document_chain=combine_document_chain,
collapse_document_chain=collapse_document_chain,
reduce_documents_chain=reduce_documents_chain,
**config,
)

View File

@@ -11,6 +11,7 @@ from pydantic import Extra
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains import ReduceDocumentsChain
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -44,14 +45,17 @@ class MapReduceChain(Chain):
) -> MapReduceChain:
"""Construct a map-reduce chain that uses the chain for map and reduce."""
llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
reduce_chain = StuffDocumentsChain(
stuff_chain = StuffDocumentsChain(
llm_chain=llm_chain,
callbacks=callbacks,
**(reduce_chain_kwargs if reduce_chain_kwargs else {}),
)
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=stuff_chain
)
combine_documents_chain = MapReduceDocumentsChain(
llm_chain=llm_chain,
combine_document_chain=reduce_chain,
reduce_documents_chain=reduce_documents_chain,
callbacks=callbacks,
**(combine_chain_kwargs if combine_chain_kwargs else {}),
)

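Sketch of the refactored constructor path: `from_params` now nests the stuff chain inside a ReduceDocumentsChain before handing it to MapReduceDocumentsChain, while the public signature is unchanged. The prompt below is illustrative.

from langchain.chains.mapreduce import MapReduceChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter

prompt = PromptTemplate.from_template("Summarize this: {text}")
chain = MapReduceChain.from_params(
    llm=OpenAI(temperature=0),
    prompt=prompt,
    text_splitter=CharacterTextSplitter(),
)
output = chain.run(input_text="<some long document text>")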
View File

@@ -14,6 +14,7 @@ from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains import ReduceDocumentsChain
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -58,13 +59,16 @@ class BaseQAWithSourcesChain(Chain, ABC):
document_prompt=document_prompt,
document_variable_name="summaries",
)
combine_document_chain = MapReduceDocumentsChain(
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_results_chain
)
combine_documents_chain = MapReduceDocumentsChain(
llm_chain=llm_question_chain,
combine_document_chain=combine_results_chain,
reduce_documents_chain=reduce_documents_chain,
document_variable_name="context",
)
return cls(
combine_documents_chain=combine_document_chain,
combine_documents_chain=combine_documents_chain,
**kwargs,
)
@@ -78,10 +82,10 @@
) -> BaseQAWithSourcesChain:
"""Load chain from chain type."""
_chain_kwargs = chain_type_kwargs or {}
combine_document_chain = load_qa_with_sources_chain(
combine_documents_chain = load_qa_with_sources_chain(
llm, chain_type=chain_type, **_chain_kwargs
)
return cls(combine_documents_chain=combine_document_chain, **kwargs)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
class Config:
"""Configuration for this pydantic object."""
@@ -110,7 +114,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
@root_validator(pre=True)
def validate_naming(cls, values: Dict) -> Dict:
"""Fix backwards compatability in naming."""
"""Fix backwards compatibility in naming."""
if "combine_document_chain" in values:
values["combine_documents_chain"] = values.pop("combine_document_chain")
return values

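The public entry points are unchanged by this refactor; a minimal sketch (the model choice is illustrative):

from langchain.chains import QAWithSourcesChain
from langchain.llms import OpenAI

# Internally, the map step's results now flow through a ReduceDocumentsChain.
chain = QAWithSourcesChain.from_chain_type(
    OpenAI(temperature=0), chain_type="map_reduce"
)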
View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
from langchain.chains import ReduceDocumentsChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -83,7 +84,7 @@ def _load_map_reduce_chain(
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
_reduce_llm = reduce_llm or llm
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
combine_document_chain = StuffDocumentsChain(
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain,
document_variable_name=combine_document_variable_name,
document_prompt=document_prompt,
@@ -107,11 +108,14 @@ def _load_map_reduce_chain(
document_variable_name=combine_document_variable_name,
document_prompt=document_prompt,
)
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_chain,
)
return MapReduceDocumentsChain(
llm_chain=map_chain,
combine_document_chain=combine_document_chain,
reduce_documents_chain=reduce_documents_chain,
document_variable_name=map_reduce_document_variable_name,
collapse_document_chain=collapse_chain,
verbose=verbose,
**kwargs,
)

View File

@@ -4,6 +4,7 @@ from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chains import ReduceDocumentsChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -122,7 +123,7 @@ def _load_map_reduce_chain(
callbacks=callbacks,
)
# TODO: document prompt
combine_document_chain = StuffDocumentsChain(
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain,
document_variable_name=combine_document_variable_name,
verbose=verbose,
@@ -150,11 +151,14 @@ def _load_map_reduce_chain(
verbose=verbose,
callback_manager=callback_manager,
)
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_chain,
)
return MapReduceDocumentsChain(
llm_chain=map_chain,
combine_document_chain=combine_document_chain,
document_variable_name=map_reduce_document_variable_name,
collapse_document_chain=collapse_chain,
reduce_documents_chain=reduce_documents_chain,
verbose=verbose,
callback_manager=callback_manager,
callbacks=callbacks,

View File

@@ -2,6 +2,7 @@
from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
from langchain.chains import ReduceDocumentsChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
@@ -53,7 +54,7 @@ def _load_map_reduce_chain(
_reduce_llm = reduce_llm or llm
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
# TODO: document prompt
combine_document_chain = StuffDocumentsChain(
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain,
document_variable_name=combine_document_variable_name,
verbose=verbose,
@@ -75,11 +76,14 @@ def _load_map_reduce_chain(
),
document_variable_name=combine_document_variable_name,
)
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_chain,
)
return MapReduceDocumentsChain(
llm_chain=map_chain,
combine_document_chain=combine_document_chain,
reduce_documents_chain=reduce_documents_chain,
document_variable_name=map_reduce_document_variable_name,
collapse_document_chain=collapse_chain,
verbose=verbose,
**kwargs,
)

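The question-answering and summarize loaders above get the same treatment; a usage sketch showing the unchanged public API:

from langchain.chains.summarize import load_summarize_chain
from langchain.llms import OpenAI

chain = load_summarize_chain(OpenAI(temperature=0), chain_type="map_reduce")
# chain.reduce_documents_chain is now a ReduceDocumentsChain wrapping the
# StuffDocumentsChain that performs the final combine.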
View File

@@ -28,7 +28,7 @@ from langchain.schema.output_parser import (
OutputParserException,
)
from langchain.schema.prompt import PromptValue
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.schema.prompt_template import BasePromptTemplate, format_document
from langchain.schema.retriever import BaseRetriever
RUN_KEY = "__run"
@@ -66,4 +66,5 @@ __all__ = [
"BaseOutputParser",
"BaseLLMOutputParser",
"BasePromptTemplate",
"format_document",
]

View File

@@ -9,7 +9,9 @@ import yaml
from pydantic import Field, root_validator
from langchain.load.serializable import Serializable
from langchain.schema import BaseOutputParser, PromptValue
from langchain.schema.document import Document
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue
class BasePromptTemplate(Serializable, ABC):
@@ -137,3 +139,48 @@ class BasePromptTemplate(Serializable, ABC):
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
"""Format a document into a string based on a prompt template.
First, this pulls information from the document from two sources:
1. `page_content`: this takes the information from the `document.page_content`
and assigns it to a variable named `page_content`.
2. metadata: this takes information from `document.metadata` and assigns
it to variables of the same name.
Those variables are then passed into the `prompt` to produce a formatted string.
Args:
doc: Document, the page_content and metadata will be used to create
the final string.
prompt: BasePromptTemplate, will be used to format the page_content
and metadata into the final string.
Returns:
The document formatted into a string.
Example:
.. code-block:: python
from langchain.schema import Document
from langchain.prompts import PromptTemplate
doc = Document(page_content="This is a joke", metadata={"page": "1"})
prompt = PromptTemplate.from_template("Page {page}: {page_content}")
format_document(doc, prompt)
>>> "Page 1: This is a joke"
"""
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)

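A hedged sketch of the validation branch above: a prompt that references a metadata key the document lacks raises the ValueError rather than failing inside `prompt.format`.

from langchain.prompts import PromptTemplate
from langchain.schema import Document, format_document

doc = Document(page_content="This is a joke")  # note: no "page" metadata
prompt = PromptTemplate.from_template("Page {page}: {page_content}")
try:
    format_document(doc, prompt)
except ValueError as err:
    print(err)  # names the missing metadata variables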
View File

@@ -5,12 +5,12 @@ from typing import Any, List
import pytest
from langchain import PromptTemplate
from langchain.chains.combine_documents.base import format_document
from langchain.chains.combine_documents.map_reduce import (
from langchain.chains.combine_documents.reduce import (
_collapse_docs,
_split_list_of_docs,
)
from langchain.docstore.document import Document
from langchain.schema import format_document
def _fake_docs_len_func(docs: List[Document]) -> int:
@@ -28,13 +28,6 @@ def test__split_list_long_single_doc() -> None:
_split_list_of_docs(docs, _fake_docs_len_func, 100)
def test__split_list_long_pair_doc() -> None:
"""Test splitting of a list with two medium docs."""
docs = [Document(page_content="foo" * 30)] * 2
with pytest.raises(ValueError):
_split_list_of_docs(docs, _fake_docs_len_func, 100)
def test__split_list_single_doc() -> None:
"""Test splitting works with just a single doc."""
docs = [Document(page_content="foo")]

View File

@@ -86,8 +86,8 @@ def test_imports() -> None:
from langchain.document_loaders import BSHTMLLoader # noqa: F401
from langchain.embeddings import OpenAIEmbeddings # noqa: F401
from langchain.llms import OpenAI # noqa: F401
from langchain.prompts import BasePromptTemplate # noqa: F401
from langchain.retrievers import VespaRetriever # noqa: F401
from langchain.schema import BasePromptTemplate # noqa: F401
from langchain.tools import DuckDuckGoSearchResults # noqa: F401
from langchain.utilities import SerpAPIWrapper # noqa: F401
from langchain.vectorstores import FAISS # noqa: F401