diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py
index e5f8b7b23b..90e0b9225a 100644
--- a/langchain/chains/__init__.py
+++ b/langchain/chains/__init__.py
@@ -4,6 +4,7 @@ from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
 from langchain.chains.combine_documents.base import AnalyzeDocumentChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
+from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.combine_documents.refine import RefineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.constitutional_ai.base import ConstitutionalChain
@@ -111,4 +112,5 @@ __all__ = [
     "MapRerankDocumentsChain",
     "MapReduceDocumentsChain",
     "RefineDocumentsChain",
+    "ReduceDocumentsChain",
 ]
diff --git a/langchain/chains/combine_documents/base.py b/langchain/chains/combine_documents/base.py
index 92e7838ff6..0693aac44a 100644
--- a/langchain/chains/combine_documents/base.py
+++ b/langchain/chains/combine_documents/base.py
@@ -11,30 +11,20 @@ from langchain.callbacks.manager import (
 )
 from langchain.chains.base import Chain
 from langchain.docstore.document import Document
-from langchain.schema import BasePromptTemplate
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter


-def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
-    """Format a document into a string based on a prompt template."""
-    base_info = {"page_content": doc.page_content}
-    base_info.update(doc.metadata)
-    missing_metadata = set(prompt.input_variables).difference(base_info)
-    if len(missing_metadata) > 0:
-        required_metadata = [
-            iv for iv in prompt.input_variables if iv != "page_content"
-        ]
-        raise ValueError(
-            f"Document prompt requires documents to have metadata variables: "
-            f"{required_metadata}. Received document with missing metadata: "
-            f"{list(missing_metadata)}."
-        )
-    document_info = {k: base_info[k] for k in prompt.input_variables}
-    return prompt.format(**document_info)
-
-
 class BaseCombineDocumentsChain(Chain, ABC):
-    """Base interface for chains combining documents."""
+    """Base interface for chains combining documents.
+
+    Subclasses of this chain deal with combining documents in a variety of
+    ways. This base class exists to add some uniformity to the interface these types
+    of chains should expose. Namely, they expect an input key related to the documents
+    to use (default `input_documents`), and they also expose a method to calculate
+    the length of a prompt from documents (useful for outside callers to
+    determine whether it's safe to pass a list of documents into this chain or whether
+    that will be longer than the context length).
+    """

     input_key: str = "input_documents"  #: :meta private:
     output_key: str = "output_text"  #: :meta private:
@@ -58,25 +48,57 @@ class BaseCombineDocumentsChain(Chain, ABC):
     def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
         """Return the prompt length given the documents passed in.

-        Returns None if the method does not depend on the prompt length.
+        This can be used by a caller to determine whether passing in a list
+        of documents would exceed a certain prompt length. This is useful when
+        trying to ensure that the size of a prompt remains below a certain
+        context limit.
+
+        Args:
+            docs: List[Document], a list of documents to use to calculate the
+                total prompt length.
+
+        Returns:
+            Returns None if the method does not depend on the prompt length,
+            otherwise the length of the prompt in tokens.
         """
         return None

     @abstractmethod
     def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
-        """Combine documents into a single string."""
+        """Combine documents into a single string.
+
+        Args:
+            docs: List[Document], the documents to combine
+            **kwargs: Other parameters to use in combining documents, often
+                other inputs to the prompt.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """

     @abstractmethod
     async def acombine_docs(
         self, docs: List[Document], **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Combine documents into a single string asynchronously."""
+        """Combine documents into a single string asynchronously.
+
+        Args:
+            docs: List[Document], the documents to combine
+            **kwargs: Other parameters to use in combining documents, often
+                other inputs to the prompt.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """

     def _call(
         self,
         inputs: Dict[str, List[Document]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
+        """Prepare inputs, call combine docs, prepare outputs."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         docs = inputs[self.input_key]
         # Other keys are assumed to be needed for LLM prediction
@@ -92,6 +114,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
         inputs: Dict[str, List[Document]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
+        """Prepare inputs, call combine docs, prepare outputs."""
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         docs = inputs[self.input_key]
         # Other keys are assumed to be needed for LLM prediction
@@ -104,7 +127,12 @@ class BaseCombineDocumentsChain(Chain, ABC):


 class AnalyzeDocumentChain(Chain):
-    """Chain that splits documents, then analyzes it in pieces."""
+    """Chain that splits a document, then analyzes it in pieces.
+
+    This chain is parameterized by a TextSplitter and a CombineDocumentsChain.
+    This chain takes a single document as input, splits it up into chunks,
+    and then passes those chunks to the CombineDocumentsChain.
+    """

     input_key: str = "input_document"  #: :meta private:
     text_splitter: TextSplitter = Field(default_factory=RecursiveCharacterTextSplitter)
@@ -131,6 +159,7 @@ class AnalyzeDocumentChain(Chain):
         inputs: Dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
+        """Split document into chunks and pass to CombineDocumentsChain."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         document = inputs[self.input_key]
         docs = self.text_splitter.create_documents([document])
diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py
index 84e49296a5..6fc30415eb 100644
--- a/langchain/chains/combine_documents/map_reduce.py
+++ b/langchain/chains/combine_documents/map_reduce.py
@@ -2,74 +2,97 @@

 from __future__ import annotations

-from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
+from typing import Any, Dict, List, Tuple

 from pydantic import Extra, root_validator

 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document


-class CombineDocsProtocol(Protocol):
-    """Interface for the combine_docs method."""
+class MapReduceDocumentsChain(BaseCombineDocumentsChain):
+    """Combining documents by mapping a chain over them, then combining results.

-    def __call__(self, docs: List[Document], **kwargs: Any) -> str:
-        """Interface for the combine_docs method."""
+    We first call `llm_chain` on each document individually, passing in the
+    `page_content` and any other kwargs. This is the `map` step.

+    We then process the results of that `map` step in a `reduce` step. This should
+    likely be a ReduceDocumentsChain.

-def _split_list_of_docs(
-    docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
-) -> List[List[Document]]:
-    new_result_doc_list = []
-    _sub_result_docs = []
-    for doc in docs:
-        _sub_result_docs.append(doc)
-        _num_tokens = length_func(_sub_result_docs, **kwargs)
-        if _num_tokens > token_max:
-            if len(_sub_result_docs) == 1:
-                raise ValueError(
-                    "A single document was longer than the context length,"
-                    " we cannot handle this."
-                )
-            if len(_sub_result_docs) == 2:
-                raise ValueError(
-                    "A single document was so long it could not be combined "
-                    "with another document, we cannot handle this."
-                )
-            new_result_doc_list.append(_sub_result_docs[:-1])
-            _sub_result_docs = _sub_result_docs[-1:]
-    new_result_doc_list.append(_sub_result_docs)
-    return new_result_doc_list
-
-
-def _collapse_docs(
-    docs: List[Document],
-    combine_document_func: CombineDocsProtocol,
-    **kwargs: Any,
-) -> Document:
-    result = combine_document_func(docs, **kwargs)
-    combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
-    for doc in docs[1:]:
-        for k, v in doc.metadata.items():
-            if k in combined_metadata:
-                combined_metadata[k] += f", {v}"
-            else:
-                combined_metadata[k] = str(v)
-    return Document(page_content=result, metadata=combined_metadata)
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import (
+                StuffDocumentsChain,
+                LLMChain,
+                ReduceDocumentsChain,
+                MapReduceDocumentsChain,
+            )
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+
+            # This controls how each document will be formatted. Specifically,
+            # it will be passed to `format_document` - see that function for more
+            # details.
+            document_prompt = PromptTemplate(
+                input_variables=["page_content"],
+                template="{page_content}"
+            )
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            prompt = PromptTemplate.from_template(
+                "Summarize this content: {context}"
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            # We now define how to combine these summaries
+            reduce_prompt = PromptTemplate.from_template(
+                "Combine these summaries: {context}"
+            )
+            reduce_llm_chain = LLMChain(llm=llm, prompt=reduce_prompt)
+            combine_documents_chain = StuffDocumentsChain(
+                llm_chain=reduce_llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+            reduce_documents_chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_documents_chain,
+            )
+            chain = MapReduceDocumentsChain(
+                llm_chain=llm_chain,
+                reduce_documents_chain=reduce_documents_chain,
+            )
+            # If we wanted to, we could also pass in collapse_documents_chain
+            # which is specifically aimed at collapsing documents BEFORE
+            # the final call.
+            collapse_prompt = PromptTemplate.from_template(
+                "Collapse this content: {context}"
+            )
+            collapse_llm_chain = LLMChain(llm=llm, prompt=collapse_prompt)
+            collapse_documents_chain = StuffDocumentsChain(
+                llm_chain=collapse_llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+            reduce_documents_chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_documents_chain,
+                collapse_documents_chain=collapse_documents_chain,
+            )
+            chain = MapReduceDocumentsChain(
+                llm_chain=llm_chain,
+                reduce_documents_chain=reduce_documents_chain,
+            )
+    """

     llm_chain: LLMChain
     """Chain to apply to each document individually."""
-    combine_document_chain: BaseCombineDocumentsChain
-    """Chain to use to combine results of applying llm_chain to documents."""
-    collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
-    """Chain to use to collapse intermediary results if needed.
-    If None, will use the combine_document_chain."""
+    reduce_documents_chain: BaseCombineDocumentsChain
+    """Chain to use to reduce the results of applying `llm_chain` to each doc.
+    This is typically either a ReduceDocumentsChain or a StuffDocumentsChain."""
     document_variable_name: str
     """The variable name in the llm_chain to put the documents in.
     If only one variable in the llm_chain, this need not be provided."""
@@ -93,6 +116,29 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
         extra = Extra.forbid
         arbitrary_types_allowed = True

+    @root_validator(pre=True)
+    def get_reduce_chain(cls, values: Dict) -> Dict:
+        """For backwards compatibility."""
+        if "combine_document_chain" in values:
+            if "reduce_documents_chain" in values:
+                raise ValueError(
+                    "Both `reduce_documents_chain` and `combine_document_chain` "
+                    "cannot be provided at the same time. `combine_document_chain` "
+                    "is deprecated; please only provide `reduce_documents_chain`."
+                )
+            combine_chain = values["combine_document_chain"]
+            collapse_chain = values.get("collapse_document_chain")
+            reduce_chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_chain,
+                collapse_documents_chain=collapse_chain,
+            )
+            values["reduce_documents_chain"] = reduce_chain
+            del values["combine_document_chain"]
+            if "collapse_document_chain" in values:
+                del values["collapse_document_chain"]
+
+        return values
+
     @root_validator(pre=True)
     def get_return_intermediate_steps(cls, values: Dict) -> Dict:
         """For backwards compatibility."""
@@ -123,11 +169,31 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
         return values

     @property
-    def _collapse_chain(self) -> BaseCombineDocumentsChain:
-        if self.collapse_document_chain is not None:
-            return self.collapse_document_chain
+    def collapse_document_chain(self) -> BaseCombineDocumentsChain:
+        """Kept for backward compatibility."""
+        if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
+            if self.reduce_documents_chain.collapse_documents_chain:
+                return self.reduce_documents_chain.collapse_documents_chain
+            else:
+                return self.reduce_documents_chain.combine_documents_chain
+        else:
+            raise ValueError(
+                f"`reduce_documents_chain` is of type "
+                f"{type(self.reduce_documents_chain)} so it does not have "
+                f"this attribute."
+            )
+
+    @property
+    def combine_document_chain(self) -> BaseCombineDocumentsChain:
+        """Kept for backward compatibility."""
+        if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
+            return self.reduce_documents_chain.combine_documents_chain
         else:
-            return self.combine_document_chain
+            raise ValueError(
+                f"`reduce_documents_chain` is of type "
+                f"{type(self.reduce_documents_chain)} so it does not have "
+                f"this attribute."
+            )

     def combine_docs(
         self,
@@ -141,14 +207,24 @@

         Combine by mapping first chain over all documents, then reducing the results.
         This reducing can be done recursively if needed (if there are many documents).
         """
-        results = self.llm_chain.apply(
+        map_results = self.llm_chain.apply(
             # FYI - this is parallelized and so it is fast.
             [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
             callbacks=callbacks,
         )
-        return self._process_results(
-            results, docs, token_max, callbacks=callbacks, **kwargs
+        question_result_key = self.llm_chain.output_key
+        result_docs = [
+            Document(page_content=r[question_result_key], metadata=docs[i].metadata)
+            # This uses metadata from the docs, and the textual results from `map_results`
+            for i, r in enumerate(map_results)
+        ]
+        result, extra_return_dict = self.reduce_documents_chain.combine_docs(
+            result_docs, callbacks=callbacks, **kwargs
         )
+        if self.return_intermediate_steps:
+            intermediate_steps = [r[question_result_key] for r in map_results]
+            extra_return_dict["intermediate_steps"] = intermediate_steps
+        return result, extra_return_dict

     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
@@ -158,83 +234,24 @@

         Combine by mapping first chain over all documents, then reducing the results.
         This reducing can be done recursively if needed (if there are many documents).
         """
-        results = await self.llm_chain.aapply(
+        map_results = await self.llm_chain.aapply(
             # FYI - this is parallelized and so it is fast.
             [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
             callbacks=callbacks,
         )
-        return await self._aprocess_results(
-            results, docs, callbacks=callbacks, **kwargs
-        )
-
-    def _process_results_common(
-        self,
-        results: List[Dict],
-        docs: List[Document],
-        token_max: int = 3000,
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Tuple[List[Document], dict]:
         question_result_key = self.llm_chain.output_key
         result_docs = [
             Document(page_content=r[question_result_key], metadata=docs[i].metadata)
             # This uses metadata from the docs, and the textual results from `results`
-            for i, r in enumerate(results)
+            for i, r in enumerate(map_results)
         ]
-        length_func = self.combine_document_chain.prompt_length
-        num_tokens = length_func(result_docs, **kwargs)
-
-        def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
-            return self._collapse_chain.run(
-                input_documents=docs, callbacks=callbacks, **kwargs
-            )
-
-        while num_tokens is not None and num_tokens > token_max:
-            new_result_doc_list = _split_list_of_docs(
-                result_docs, length_func, token_max, **kwargs
-            )
-            result_docs = []
-            for docs in new_result_doc_list:
-                new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
-                result_docs.append(new_doc)
-            num_tokens = length_func(result_docs, **kwargs)
-        if self.return_intermediate_steps:
-            _results = [r[self.llm_chain.output_key] for r in results]
-            extra_return_dict = {"intermediate_steps": _results}
-        else:
-            extra_return_dict = {}
-        return result_docs, extra_return_dict
-
-    def _process_results(
-        self,
-        results: List[Dict],
-        docs: List[Document],
-        token_max: int = 3000,
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Tuple[str, dict]:
-        result_docs, extra_return_dict = self._process_results_common(
-            results, docs, token_max, callbacks=callbacks, **kwargs
+        result, extra_return_dict = await self.reduce_documents_chain.acombine_docs(
+            result_docs, callbacks=callbacks, **kwargs
         )
-        output = self.combine_document_chain.run(
-            input_documents=result_docs, callbacks=callbacks, **kwargs
-        )
-        return output, extra_return_dict
-
-    async def _aprocess_results(
-        self,
-        results: List[Dict],
-        docs: List[Document],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Tuple[str, dict]:
-        result_docs, extra_return_dict = self._process_results_common(
-            results, docs, callbacks=callbacks, **kwargs
-        )
-        output = await self.combine_document_chain.arun(
-            input_documents=result_docs, callbacks=callbacks, **kwargs
-        )
-        return output, extra_return_dict
+        if self.return_intermediate_steps:
+            intermediate_steps = [r[question_result_key] for r in map_results]
+            extra_return_dict["intermediate_steps"] = intermediate_steps
+        return result, extra_return_dict

     @property
     def _chain_type(self) -> str:
diff --git a/langchain/chains/combine_documents/map_rerank.py b/langchain/chains/combine_documents/map_rerank.py
index ad8409c343..e2d656d07e 100644
--- a/langchain/chains/combine_documents/map_rerank.py
+++ b/langchain/chains/combine_documents/map_rerank.py
@@ -14,7 +14,48 @@ from langchain.output_parsers.regex import RegexParser


 class MapRerankDocumentsChain(BaseCombineDocumentsChain):
-    """Combining documents by mapping a chain over them, then reranking results."""
+    """Combining documents by mapping a chain over them, then reranking results.
+
+    This algorithm calls an LLMChain on each input document. The LLMChain is expected
+    to have an OutputParser that parses the result into both an answer (`answer_key`)
+    and a score (`rank_key`). The answer with the highest score is then returned.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import StuffDocumentsChain, LLMChain
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+            from langchain.output_parsers.regex import RegexParser
+
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            # The actual prompt will need to be a lot more complex; this is just
+            # an example.
+            prompt_template = (
+                "Use the following context to tell me the chemical formula "
+                "for water. Output both your answer and a score of how confident "
+                "you are. Context: {context}"
+            )
+            output_parser = RegexParser(
+                regex=r"(.*?)\nScore: (.*)",
+                output_keys=["answer", "score"],
+            )
+            prompt = PromptTemplate(
+                template=prompt_template,
+                input_variables=["context"],
+                output_parser=output_parser,
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            chain = MapRerankDocumentsChain(
+                llm_chain=llm_chain,
+                document_variable_name=document_variable_name,
+                rank_key="score",
+                answer_key="answer",
+            )
+    """

     llm_chain: LLMChain
     """Chain to apply to each document individually."""
@@ -26,7 +67,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
     answer_key: str
     """Key in output of llm_chain to return as answer."""
     metadata_keys: Optional[List[str]] = None
+    """Additional metadata from the chosen document to return."""
     return_intermediate_steps: bool = False
+    """Return intermediate steps.
+    Intermediate steps include the results of calling llm_chain on each document."""

     class Config:
         """Configuration for this pydantic object."""
@@ -96,6 +140,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         """Combine documents in a map rerank manner.

         Combine by mapping first chain over all documents, then reranking the results.
+
+        Args:
+            docs: List of documents to combine
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
         """
         results = self.llm_chain.apply_and_parse(
             # FYI - this is parallelized and so it is fast.
@@ -110,6 +164,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         """Combine documents in a map rerank manner.

         Combine by mapping first chain over all documents, then reranking the results.
+
+        Args:
+            docs: List of documents to combine
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
         """
         results = await self.llm_chain.aapply_and_parse(
             # FYI - this is parallelized and so it is fast.
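The reranking contract above can be exercised in isolation: the RegexParser is what splits each raw LLM response into the `answer_key` and `rank_key` fields that MapRerankDocumentsChain sorts on. A minimal sketch (the sample response string is hypothetical, shaped like what the example prompt asks for):

.. code-block:: python

    from langchain.output_parsers.regex import RegexParser

    output_parser = RegexParser(
        regex=r"(.*?)\nScore: (.*)",
        output_keys=["answer", "score"],
    )
    # Hypothetical raw LLM response in the requested "answer, then score" format.
    parsed = output_parser.parse("The chemical formula for water is H2O.\nScore: 95")
    assert parsed == {"answer": "The chemical formula for water is H2O.", "score": "95"}
    # MapRerankDocumentsChain then sorts the parsed results by their score and
    # returns the answer from the highest-scoring document.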
diff --git a/langchain/chains/combine_documents/reduce.py b/langchain/chains/combine_documents/reduce.py
new file mode 100644
index 0000000000..e4f1e7719f
--- /dev/null
+++ b/langchain/chains/combine_documents/reduce.py
@@ -0,0 +1,277 @@
+"""Combine many documents together by recursively reducing them."""
+
+from __future__ import annotations
+
+from typing import Any, Callable, List, Optional, Protocol, Tuple
+
+from pydantic import Extra
+
+from langchain.callbacks.manager import Callbacks
+from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.docstore.document import Document
+
+
+class CombineDocsProtocol(Protocol):
+    """Interface for the combine_docs method."""
+
+    def __call__(self, docs: List[Document], **kwargs: Any) -> str:
+        """Interface for the combine_docs method."""
+
+
+class AsyncCombineDocsProtocol(Protocol):
+    """Async interface for the combine_docs method."""
+
+    async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
+        """Async interface for the combine_docs method."""
+
+
+def _split_list_of_docs(
+    docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
+) -> List[List[Document]]:
+    new_result_doc_list = []
+    _sub_result_docs = []
+    for doc in docs:
+        _sub_result_docs.append(doc)
+        _num_tokens = length_func(_sub_result_docs, **kwargs)
+        if _num_tokens > token_max:
+            if len(_sub_result_docs) == 1:
+                raise ValueError(
+                    "A single document was longer than the context length,"
+                    " we cannot handle this."
+                )
+            new_result_doc_list.append(_sub_result_docs[:-1])
+            _sub_result_docs = _sub_result_docs[-1:]
+    new_result_doc_list.append(_sub_result_docs)
+    return new_result_doc_list
+
+
+def _collapse_docs(
+    docs: List[Document],
+    combine_document_func: CombineDocsProtocol,
+    **kwargs: Any,
+) -> Document:
+    result = combine_document_func(docs, **kwargs)
+    combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
+    for doc in docs[1:]:
+        for k, v in doc.metadata.items():
+            if k in combined_metadata:
+                combined_metadata[k] += f", {v}"
+            else:
+                combined_metadata[k] = str(v)
+    return Document(page_content=result, metadata=combined_metadata)
+
+
+async def _acollapse_docs(
+    docs: List[Document],
+    combine_document_func: AsyncCombineDocsProtocol,
+    **kwargs: Any,
+) -> Document:
+    result = await combine_document_func(docs, **kwargs)
+    combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
+    for doc in docs[1:]:
+        for k, v in doc.metadata.items():
+            if k in combined_metadata:
+                combined_metadata[k] += f", {v}"
+            else:
+                combined_metadata[k] = str(v)
+    return Document(page_content=result, metadata=combined_metadata)
+
+
+class ReduceDocumentsChain(BaseCombineDocumentsChain):
+    """Combining documents by recursively reducing them.
+
+    This involves
+    - combine_documents_chain
+    - collapse_documents_chain
+
+    `combine_documents_chain` is ALWAYS provided. This is the final chain that is
+    called. We pass all previous results to this chain, and the output of this chain
+    is returned as a final result.
+
+    `collapse_documents_chain` is used if the documents passed in are too many to all
+    be passed to `combine_documents_chain` in one go. In this case,
+    `collapse_documents_chain` is called recursively on groups of documents that are
+    as large as allowed.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import (
+                StuffDocumentsChain, LLMChain, ReduceDocumentsChain
+            )
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+
+            # This controls how each document will be formatted. Specifically,
+            # it will be passed to `format_document` - see that function for more
+            # details.
+            document_prompt = PromptTemplate(
+                input_variables=["page_content"],
+                template="{page_content}"
+            )
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            prompt = PromptTemplate.from_template(
+                "Summarize this content: {context}"
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            combine_documents_chain = StuffDocumentsChain(
+                llm_chain=llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+            chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_documents_chain,
+            )
+            # If we wanted to, we could also pass in collapse_documents_chain
+            # which is specifically aimed at collapsing documents BEFORE
+            # the final call.
+            collapse_prompt = PromptTemplate.from_template(
+                "Collapse this content: {context}"
+            )
+            collapse_llm_chain = LLMChain(llm=llm, prompt=collapse_prompt)
+            collapse_documents_chain = StuffDocumentsChain(
+                llm_chain=collapse_llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+            chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_documents_chain,
+                collapse_documents_chain=collapse_documents_chain,
+            )
+    """
+
+    combine_documents_chain: BaseCombineDocumentsChain
+    """Final chain to call to combine documents.
+    This is typically a StuffDocumentsChain."""
+    collapse_documents_chain: Optional[BaseCombineDocumentsChain] = None
+    """Chain to use to collapse documents if needed until they can all fit.
+    If None, will use the combine_documents_chain.
+    This is typically a StuffDocumentsChain."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @property
+    def _collapse_chain(self) -> BaseCombineDocumentsChain:
+        if self.collapse_documents_chain is not None:
+            return self.collapse_documents_chain
+        else:
+            return self.combine_documents_chain
+
+    def combine_docs(
+        self,
+        docs: List[Document],
+        token_max: int = 3000,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Tuple[str, dict]:
+        """Combine multiple documents recursively.
+
+        Args:
+            docs: List of documents to combine, each one assumed to be less than
+                `token_max` tokens.
+            token_max: Recursively creates groups of documents less than this number
+                of tokens.
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
+        result_docs, extra_return_dict = self._collapse(
+            docs, token_max, callbacks=callbacks, **kwargs
+        )
+        return self.combine_documents_chain.combine_docs(
+            docs=result_docs, callbacks=callbacks, **kwargs
+        )
+
+    async def acombine_docs(
+        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
+    ) -> Tuple[str, dict]:
+        """Combine multiple documents recursively.
+
+        Args:
+            docs: List of documents to combine, each one assumed to be less than
+                `token_max` tokens.
+            token_max: Recursively creates groups of documents less than this number
+                of tokens.
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
+        result_docs, extra_return_dict = await self._acollapse(
+            docs, callbacks=callbacks, **kwargs
+        )
+        return await self.combine_documents_chain.acombine_docs(
+            docs=result_docs, callbacks=callbacks, **kwargs
+        )
+
+    def _collapse(
+        self,
+        docs: List[Document],
+        token_max: int = 3000,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Tuple[List[Document], dict]:
+        result_docs = docs
+        length_func = self.combine_documents_chain.prompt_length
+        num_tokens = length_func(result_docs, **kwargs)
+
+        def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
+            return self._collapse_chain.run(
+                input_documents=docs, callbacks=callbacks, **kwargs
+            )
+
+        while num_tokens is not None and num_tokens > token_max:
+            new_result_doc_list = _split_list_of_docs(
+                result_docs, length_func, token_max, **kwargs
+            )
+            result_docs = []
+            for docs in new_result_doc_list:
+                new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
+                result_docs.append(new_doc)
+            num_tokens = length_func(result_docs, **kwargs)
+        return result_docs, {}
+
+    async def _acollapse(
+        self,
+        docs: List[Document],
+        token_max: int = 3000,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Tuple[List[Document], dict]:
+        result_docs = docs
+        length_func = self.combine_documents_chain.prompt_length
+        num_tokens = length_func(result_docs, **kwargs)
+
+        async def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
+            return await self._collapse_chain.arun(
+                input_documents=docs, callbacks=callbacks, **kwargs
+            )
+
+        while num_tokens is not None and num_tokens > token_max:
+            new_result_doc_list = _split_list_of_docs(
+                result_docs, length_func, token_max, **kwargs
+            )
+            result_docs = []
+            for docs in new_result_doc_list:
+                new_doc = await _acollapse_docs(docs, _collapse_docs_func, **kwargs)
+                result_docs.append(new_doc)
+            num_tokens = length_func(result_docs, **kwargs)
+        return result_docs, {}
+
+    @property
+    def _chain_type(self) -> str:
+        return "reduce_documents_chain"
diff --git a/langchain/chains/combine_documents/refine.py b/langchain/chains/combine_documents/refine.py
index fac234719c..18fb0ed097 100644
--- a/langchain/chains/combine_documents/refine.py
+++ b/langchain/chains/combine_documents/refine.py
@@ -9,12 +9,11 @@ from pydantic import Extra, Field, root_validator
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
     BaseCombineDocumentsChain,
-    format_document,
 )
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BasePromptTemplate
+from langchain.schema import BasePromptTemplate, format_document


 def _get_default_document_prompt() -> PromptTemplate:
@@ -22,7 +21,55 @@


 class RefineDocumentsChain(BaseCombineDocumentsChain):
-    """Combine documents by doing a first pass and then refining on more documents."""
+    """Combine documents by doing a first pass and then refining on more documents.
+
+    This algorithm first calls `initial_llm_chain` on the first document, passing
+    that first document in with the variable name `document_variable_name`, and
+    produces a new variable with the variable name `initial_response_name`.
+
+    Then, it loops over every remaining document. This is called the "refine" step.
+    It calls `refine_llm_chain`,
+    passing in that document with the variable name `document_variable_name`
+    as well as the previous response with the variable name `initial_response_name`.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import RefineDocumentsChain, LLMChain
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+
+            # This controls how each document will be formatted. Specifically,
+            # it will be passed to `format_document` - see that function for more
+            # details.
+            document_prompt = PromptTemplate(
+                input_variables=["page_content"],
+                template="{page_content}"
+            )
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            prompt = PromptTemplate.from_template(
+                "Summarize this content: {context}"
+            )
+            initial_llm_chain = LLMChain(llm=llm, prompt=prompt)
+            initial_response_name = "prev_response"
+            # The prompt here should take as an input variable the
+            # `document_variable_name` as well as `initial_response_name`
+            prompt_refine = PromptTemplate.from_template(
+                "Here's your first summary: {prev_response}. "
+                "Now add to it based on the following context: {context}"
+            )
+            refine_llm_chain = LLMChain(llm=llm, prompt=prompt_refine)
+            chain = RefineDocumentsChain(
+                initial_llm_chain=initial_llm_chain,
+                refine_llm_chain=refine_llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name,
+                initial_response_name=initial_response_name,
+            )
+    """

     initial_llm_chain: LLMChain
     """LLM chain to use on initial document."""
@@ -36,7 +83,7 @@
     document_prompt: BasePromptTemplate = Field(
         default_factory=_get_default_document_prompt
     )
-    """Prompt to use to format each document."""
+    """Prompt to use to format each document; gets passed to `format_document`."""
     return_intermediate_steps: bool = False
     """Return the results of the refine steps in the output."""
@@ -89,7 +136,18 @@
     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Combine by mapping first chain over all, then stuffing into final chain."""
+        """Run the initial chain on the first document, then refine over the rest.
+
+        Args:
+            docs: List of documents to combine
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
         inputs = self._construct_initial_inputs(docs, **kwargs)
         res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
         refine_steps = [res]
@@ -103,7 +161,18 @@
     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Combine by mapping first chain over all, then stuffing into final chain."""
+        """Run the initial chain on the first document, then refine over the rest.
+
+        Args:
+            docs: List of documents to combine
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
         inputs = self._construct_initial_inputs(docs, **kwargs)
         res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
         refine_steps = [res]
diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py
index e4859ff6ca..6712734231 100644
--- a/langchain/chains/combine_documents/stuff.py
+++ b/langchain/chains/combine_documents/stuff.py
@@ -7,12 +7,11 @@ from pydantic import Extra, Field, root_validator
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
     BaseCombineDocumentsChain,
-    format_document,
 )
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BasePromptTemplate
+from langchain.schema import BasePromptTemplate, format_document


 def _get_default_document_prompt() -> PromptTemplate:
@@ -20,14 +19,50 @@


 class StuffDocumentsChain(BaseCombineDocumentsChain):
-    """Chain that combines documents by stuffing into context."""
+    """Chain that combines documents by stuffing into context.
+
+    This chain takes a list of documents and first combines them into a single string.
+    It does this by formatting each document into a string with the `document_prompt`
+    and then joining them together with `document_separator`. It then adds that new
+    string to the inputs with the variable name set by `document_variable_name`.
+    Those inputs are then passed to the `llm_chain`.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import StuffDocumentsChain, LLMChain
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+
+            # This controls how each document will be formatted. Specifically,
+            # it will be passed to `format_document` - see that function for more
+            # details.
+            document_prompt = PromptTemplate(
+                input_variables=["page_content"],
+                template="{page_content}"
+            )
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            prompt = PromptTemplate.from_template(
+                "Summarize this content: {context}"
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            chain = StuffDocumentsChain(
+                llm_chain=llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+    """

     llm_chain: LLMChain
-    """LLM wrapper to use after formatting documents."""
+    """LLM chain which is called with the formatted document string,
+    along with any other inputs."""
     document_prompt: BasePromptTemplate = Field(
         default_factory=_get_default_document_prompt
     )
-    """Prompt to use to format each document."""
+    """Prompt to use to format each document; gets passed to `format_document`."""
     document_variable_name: str
     """The variable name in the llm_chain to put the documents in.
     If only one variable in the llm_chain, this need not be provided."""
@@ -42,7 +77,12 @@

     @root_validator(pre=True)
     def get_default_document_variable_name(cls, values: Dict) -> Dict:
-        """Get default document variable name, if not provided."""
+        """Get default document variable name, if not provided.
+
+        If only one variable is present in the llm_chain.prompt,
+        we can infer that the formatted documents should be passed in
+        with this variable name.
+        """
         llm_chain_variables = values["llm_chain"].prompt.input_variables
         if "document_variable_name" not in values:
             if len(llm_chain_variables) == 1:
@@ -61,6 +101,20 @@
         return values

     def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
+        """Construct inputs from kwargs and docs.
+
+        Format and then join all the documents together into one input with name
+        `self.document_variable_name`. Then pluck any additional variables
+        from **kwargs.
+
+        Args:
+            docs: List of documents to format and then join into single input
+            **kwargs: additional inputs to chain, will pluck any other required
+                arguments from here.
+
+        Returns:
+            dictionary of inputs to LLMChain
+        """
         # Format each document according to the prompt
         doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
         # Join the documents together to put them in the prompt.
@@ -73,7 +127,21 @@
         return inputs

     def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
-        """Get the prompt length by formatting the prompt."""
+        """Return the prompt length given the documents passed in.
+
+        This can be used by a caller to determine whether passing in a list
+        of documents would exceed a certain prompt length. This is useful when
+        trying to ensure that the size of a prompt remains below a certain
+        context limit.
+
+        Args:
+            docs: List[Document], a list of documents to use to calculate the
+                total prompt length.
+
+        Returns:
+            Returns None if the method does not depend on the prompt length,
+            otherwise the length of the prompt in tokens.
+        """
         inputs = self._get_inputs(docs, **kwargs)
         prompt = self.llm_chain.prompt.format(**inputs)
         return self.llm_chain.llm.get_num_tokens(prompt)
@@ -81,7 +149,17 @@
     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Stuff all documents into one prompt and pass to LLM."""
+        """Stuff all documents into one prompt and pass to LLM.
+
+        Args:
+            docs: List of documents to join together into one variable
+            callbacks: Optional callbacks to pass along
+            **kwargs: additional parameters to use to get inputs to LLMChain.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
         inputs = self._get_inputs(docs, **kwargs)
         # Call predict on the LLM.
         return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
@@ -89,7 +167,17 @@
     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Stuff all documents into one prompt and pass to LLM."""
+        """Stuff all documents into one prompt and pass to LLM.
+
+        Args:
+            docs: List of documents to join together into one variable
+            callbacks: Optional callbacks to pass along
+            **kwargs: additional parameters to use to get inputs to LLMChain.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
         inputs = self._get_inputs(docs, **kwargs)
         # Call predict on the LLM.
         return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}
diff --git a/langchain/chains/loading.py b/langchain/chains/loading.py
index a01872fbe8..75b2eaaf8b 100644
--- a/langchain/chains/loading.py
+++ b/langchain/chains/loading.py
@@ -5,6 +5,7 @@ from typing import Any, Union

 import yaml

+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.api.base import APIChain
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -117,9 +118,9 @@

     if "combine_document_chain" in config:
         combine_document_chain_config = config.pop("combine_document_chain")
-        combine_document_chain = load_chain_from_config(combine_document_chain_config)
+        combine_documents_chain = load_chain_from_config(combine_document_chain_config)
     elif "combine_document_chain_path" in config:
-        combine_document_chain = load_chain(config.pop("combine_document_chain_path"))
+        combine_documents_chain = load_chain(config.pop("combine_document_chain_path"))
     else:
         raise ValueError(
             "One of `combine_document_chain` or "
@@ -128,17 +129,24 @@
     if "collapse_document_chain" in config:
         collapse_document_chain_config = config.pop("collapse_document_chain")
         if collapse_document_chain_config is None:
-            collapse_document_chain = None
+            collapse_documents_chain = None
         else:
-            collapse_document_chain = load_chain_from_config(
+            collapse_documents_chain = load_chain_from_config(
                 collapse_document_chain_config
             )
     elif "collapse_document_chain_path" in config:
-        collapse_document_chain = load_chain(config.pop("collapse_document_chain_path"))
+        collapse_documents_chain = load_chain(
+            config.pop("collapse_document_chain_path")
+        )
+    else:
+        collapse_documents_chain = None
+
+    reduce_documents_chain = ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=collapse_documents_chain,
+    )
     return MapReduceDocumentsChain(
         llm_chain=llm_chain,
-        combine_document_chain=combine_document_chain,
-        collapse_document_chain=collapse_document_chain,
+        reduce_documents_chain=reduce_documents_chain,
         **config,
     )
diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py
index 732d489e87..8b217bf023 100644
--- a/langchain/chains/mapreduce.py
+++ b/langchain/chains/mapreduce.py
@@ -11,6 +11,7 @@ from pydantic import Extra

 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -44,14 +45,17 @@
     ) -> MapReduceChain:
         """Construct a map-reduce chain that uses the chain for map and reduce."""
         llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
-        reduce_chain = StuffDocumentsChain(
+        stuff_chain = StuffDocumentsChain(
             llm_chain=llm_chain,
             callbacks=callbacks,
             **(reduce_chain_kwargs if reduce_chain_kwargs else {}),
         )
+        reduce_documents_chain = ReduceDocumentsChain(
+            combine_documents_chain=stuff_chain
+        )
         combine_documents_chain = MapReduceDocumentsChain(
             llm_chain=llm_chain,
-            combine_document_chain=reduce_chain,
+            reduce_documents_chain=reduce_documents_chain,
             callbacks=callbacks,
             **(combine_chain_kwargs if combine_chain_kwargs else {}),
         )
diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py
index ba882310b6..5e1799cc29 100644
--- a/langchain/chains/qa_with_sources/base.py
+++ b/langchain/chains/qa_with_sources/base.py
@@ -14,6 +14,7 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
 )
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -58,13 +59,16 @@
             document_prompt=document_prompt,
             document_variable_name="summaries",
         )
-        combine_document_chain = MapReduceDocumentsChain(
+        reduce_documents_chain = ReduceDocumentsChain(
+            combine_documents_chain=combine_results_chain
+        )
+        combine_documents_chain = MapReduceDocumentsChain(
             llm_chain=llm_question_chain,
-            combine_document_chain=combine_results_chain,
+            reduce_documents_chain=reduce_documents_chain,
             document_variable_name="context",
         )
         return cls(
-            combine_documents_chain=combine_document_chain,
+            combine_documents_chain=combine_documents_chain,
             **kwargs,
         )

@@ -78,10 +82,10 @@
     ) -> BaseQAWithSourcesChain:
         """Load chain from chain type."""
         _chain_kwargs = chain_type_kwargs or {}
-        combine_document_chain = load_qa_with_sources_chain(
+        combine_documents_chain = load_qa_with_sources_chain(
             llm, chain_type=chain_type, **_chain_kwargs
         )
-        return cls(combine_documents_chain=combine_document_chain, **kwargs)
+        return cls(combine_documents_chain=combine_documents_chain, **kwargs)

     class Config:
         """Configuration for this pydantic object."""
@@ -110,7 +114,7 @@

     @root_validator(pre=True)
     def validate_naming(cls, values: Dict) -> Dict:
-        """Fix backwards compatability in naming."""
+        """Fix backwards compatibility in naming."""
         if "combine_document_chain" in values:
             values["combine_documents_chain"] = values.pop("combine_document_chain")
         return values
diff --git a/langchain/chains/qa_with_sources/loading.py b/langchain/chains/qa_with_sources/loading.py
index 57ea76a763..9a5058938c 100644
--- a/langchain/chains/qa_with_sources/loading.py
+++ b/langchain/chains/qa_with_sources/loading.py
@@ -4,6 +4,7 @@ from __future__ import annotations

 from typing import Any, Mapping, Optional, Protocol

 from langchain.base_language import BaseLanguageModel
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -83,7 +84,7 @@
     map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
     _reduce_llm = reduce_llm or llm
     reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
-    combine_document_chain = StuffDocumentsChain(
+    combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
         document_prompt=document_prompt,
@@ -107,11 +108,14 @@
         document_variable_name=combine_document_variable_name,
         document_prompt=document_prompt,
     )
+    reduce_documents_chain = ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=collapse_chain,
+    )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
-        combine_document_chain=combine_document_chain,
+        reduce_documents_chain=reduce_documents_chain,
         document_variable_name=map_reduce_document_variable_name,
-        collapse_document_chain=collapse_chain,
         verbose=verbose,
         **kwargs,
     )
diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py
index a6eb09e232..9fa2d2bffd 100644
--- a/langchain/chains/question_answering/__init__.py
+++ b/langchain/chains/question_answering/__init__.py
@@ -4,6 +4,7 @@ from typing import Any, Mapping, Optional, Protocol

 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -122,7 +123,7 @@
         callbacks=callbacks,
     )
     # TODO: document prompt
-    combine_document_chain = StuffDocumentsChain(
+    combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
         verbose=verbose,
@@ -150,11 +151,14 @@
         verbose=verbose,
         callback_manager=callback_manager,
     )
+    reduce_documents_chain = ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=collapse_chain,
+    )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
-        combine_document_chain=combine_document_chain,
         document_variable_name=map_reduce_document_variable_name,
-        collapse_document_chain=collapse_chain,
+        reduce_documents_chain=reduce_documents_chain,
         verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
diff --git a/langchain/chains/summarize/__init__.py b/langchain/chains/summarize/__init__.py
index fa88cea1ff..211bf701d9 100644
--- a/langchain/chains/summarize/__init__.py
+++ b/langchain/chains/summarize/__init__.py
@@ -2,6 +2,7 @@ from typing import Any, Mapping, Optional, Protocol

 from langchain.base_language import BaseLanguageModel
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.refine import RefineDocumentsChain
@@ -53,7 +54,7 @@
     _reduce_llm = reduce_llm or llm
     reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
     # TODO: document prompt
-    combine_document_chain = StuffDocumentsChain(
+    combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
         verbose=verbose,
@@ -75,11 +76,14 @@
         ),
         document_variable_name=combine_document_variable_name,
     )
+    reduce_documents_chain = ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=collapse_chain,
+    )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
-        combine_document_chain=combine_document_chain,
+        reduce_documents_chain=reduce_documents_chain,
         document_variable_name=map_reduce_document_variable_name,
-        collapse_document_chain=collapse_chain,
         verbose=verbose,
         **kwargs,
     )
diff --git a/langchain/schema/__init__.py b/langchain/schema/__init__.py
index 13983206bf..818ff4113d 100644
--- a/langchain/schema/__init__.py
+++ b/langchain/schema/__init__.py
@@ -28,7 +28,7 @@ from langchain.schema.output_parser import (
     OutputParserException,
 )
 from langchain.schema.prompt import PromptValue
-from langchain.schema.prompt_template import BasePromptTemplate
+from langchain.schema.prompt_template import BasePromptTemplate, format_document
 from langchain.schema.retriever import BaseRetriever

 RUN_KEY = "__run"
@@ -66,4 +66,5 @@
     "BaseOutputParser",
     "BaseLLMOutputParser",
     "BasePromptTemplate",
+    "format_document",
 ]
diff --git a/langchain/schema/prompt_template.py b/langchain/schema/prompt_template.py
index 6ed048df4a..7b340887eb 100644
--- a/langchain/schema/prompt_template.py
+++ b/langchain/schema/prompt_template.py
@@ -9,7 +9,9 @@ import yaml

 from pydantic import Field, root_validator

 from langchain.load.serializable import Serializable
-from langchain.schema import BaseOutputParser, PromptValue
+from langchain.schema.document import Document
+from langchain.schema.output_parser import BaseOutputParser
+from langchain.schema.prompt import PromptValue


 class BasePromptTemplate(Serializable, ABC):
@@ -137,3 +139,48 @@
             yaml.dump(prompt_dict, f, default_flow_style=False)
     else:
         raise ValueError(f"{save_path} must be json or yaml")
+
+
+def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
+    """Format a document into a string based on a prompt template.
+
+    First, this pulls information from the document from two sources:
+
+    1. `page_content`: this takes the information from `document.page_content`
+        and assigns it to a variable named `page_content`.
+    2. metadata: this takes information from `document.metadata` and assigns
+        it to variables of the same name.
+
+    Those variables are then passed into the `prompt` to produce a formatted string.
+
+    Args:
+        doc: Document, the page_content and metadata will be used to create
+            the final string.
+        prompt: BasePromptTemplate, will be used to format the page_content
+            and metadata into the final string.
+
+    Returns:
+        The formatted document as a string.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.schema import Document
+            from langchain.prompts import PromptTemplate
+
+            doc = Document(page_content="This is a joke", metadata={"page": "1"})
+            prompt = PromptTemplate.from_template("Page {page}: {page_content}")
+            >>> format_document(doc, prompt)
+            'Page 1: This is a joke'
+    """
+    base_info = {"page_content": doc.page_content, **doc.metadata}
+    missing_metadata = set(prompt.input_variables).difference(base_info)
+    if len(missing_metadata) > 0:
+        required_metadata = [
+            iv for iv in prompt.input_variables if iv != "page_content"
+        ]
+        raise ValueError(
+            f"Document prompt requires documents to have metadata variables: "
+            f"{required_metadata}. Received document with missing metadata: "
+            f"{list(missing_metadata)}."
+ ) + document_info = {k: base_info[k] for k in prompt.input_variables} + return prompt.format(**document_info) diff --git a/tests/unit_tests/chains/test_combine_documents.py b/tests/unit_tests/chains/test_combine_documents.py index 7377250352..df2212588e 100644 --- a/tests/unit_tests/chains/test_combine_documents.py +++ b/tests/unit_tests/chains/test_combine_documents.py @@ -5,12 +5,12 @@ from typing import Any, List import pytest from langchain import PromptTemplate -from langchain.chains.combine_documents.base import format_document -from langchain.chains.combine_documents.map_reduce import ( +from langchain.chains.combine_documents.reduce import ( _collapse_docs, _split_list_of_docs, ) from langchain.docstore.document import Document +from langchain.schema import format_document def _fake_docs_len_func(docs: List[Document]) -> int: @@ -28,13 +28,6 @@ def test__split_list_long_single_doc() -> None: _split_list_of_docs(docs, _fake_docs_len_func, 100) -def test__split_list_long_pair_doc() -> None: - """Test splitting of a list with two medium docs.""" - docs = [Document(page_content="foo" * 30)] * 2 - with pytest.raises(ValueError): - _split_list_of_docs(docs, _fake_docs_len_func, 100) - - def test__split_list_single_doc() -> None: """Test splitting works with just a single doc.""" docs = [Document(page_content="foo")] diff --git a/tests/unit_tests/test_dependencies.py b/tests/unit_tests/test_dependencies.py index 72ca1b793f..b200bf30b4 100644 --- a/tests/unit_tests/test_dependencies.py +++ b/tests/unit_tests/test_dependencies.py @@ -86,8 +86,8 @@ def test_imports() -> None: from langchain.document_loaders import BSHTMLLoader # noqa: F401 from langchain.embeddings import OpenAIEmbeddings # noqa: F401 from langchain.llms import OpenAI # noqa: F401 - from langchain.prompts import BasePromptTemplate # noqa: F401 from langchain.retrievers import VespaRetriever # noqa: F401 + from langchain.schema import BasePromptTemplate # noqa: F401 from langchain.tools import DuckDuckGoSearchResults # noqa: F401 from langchain.utilities import SerpAPIWrapper # noqa: F401 from langchain.vectorstores import FAISS # noqa: F401