add optional collapse prompt (#358)

Harrison Chase authored 2022-12-16 06:25:29 -08:00, committed by GitHub
parent 2dd895d98c
commit 750edfb440
4 changed files with 44 additions and 6 deletions

View File

@@ -2,7 +2,7 @@
 from __future__ import annotations
 
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
 from pydantic import BaseModel, Extra, root_validator
@@ -56,9 +56,12 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
     """Combining documents by mapping a chain over them, then combining results."""
 
     llm_chain: LLMChain
-    """Chain to apply to each document individually.."""
+    """Chain to apply to each document individually."""
     combine_document_chain: BaseCombineDocumentsChain
     """Chain to use to combine results of applying llm_chain to documents."""
+    collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
+    """Chain to use to collapse intermediary results if needed.
+    If None, will use the combine_document_chain."""
     document_variable_name: str
     """The variable name in the llm_chain to put the documents in.
     If only one variable in the llm_chain, this need not be provided."""
@@ -90,6 +93,13 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
                 )
         return values
 
+    @property
+    def _collapse_chain(self) -> BaseCombineDocumentsChain:
+        if self.collapse_document_chain is not None:
+            return self.collapse_document_chain
+        else:
+            return self.combine_document_chain
+
     def combine_docs(
         self, docs: List[Document], token_max: int = 3000, **kwargs: Any
     ) -> str:
@@ -117,7 +127,7 @@
             result_docs = []
             for docs in new_result_doc_list:
                 new_doc = _collapse_docs(
-                    docs, self.combine_document_chain.combine_docs, **kwargs
+                    docs, self._collapse_chain.combine_docs, **kwargs
                 )
                 result_docs.append(new_doc)
             num_tokens = self.combine_document_chain.prompt_length(
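
Note: a minimal usage sketch of the new field follows. Only the MapReduceDocumentsChain / StuffDocumentsChain constructor arguments come from this diff; the OpenAI model choice, the prompt texts, and the variable names are illustrative assumptions. When collapse_document_chain is left as None, _collapse_chain falls back to combine_document_chain, so existing behavior is unchanged.

```python
# Sketch only (not part of the diff): wiring an explicit collapse chain.
# Imports assume the langchain package layout at the time of this commit.
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

llm = OpenAI(temperature=0)  # placeholder model choice

# "Map" step: applied to each document individually.
map_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["context"], template="Summarize this:\n{context}"
    ),
)

# "Reduce" step: combines the mapped results into a final answer.
combine_chain = StuffDocumentsChain(
    llm_chain=LLMChain(
        llm=llm,
        prompt=PromptTemplate(
            input_variables=["summaries"],
            template="Combine these summaries:\n{summaries}",
        ),
    ),
    document_variable_name="summaries",
)

# New: a dedicated chain used only to collapse intermediate results that are
# still over token_max. Leaving this as None reuses combine_chain for collapsing.
collapse_chain = StuffDocumentsChain(
    llm_chain=LLMChain(
        llm=llm,
        prompt=PromptTemplate(
            input_variables=["summaries"],
            template="Condense these summaries:\n{summaries}",
        ),
    ),
    document_variable_name="summaries",
)

chain = MapReduceDocumentsChain(
    llm_chain=map_chain,
    combine_document_chain=combine_chain,
    collapse_document_chain=collapse_chain,  # optional, new in this commit
    document_variable_name="context",
)
```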

View File

@@ -1,5 +1,5 @@
 """Load question answering with sources chains."""
-from typing import Any, Mapping, Protocol
+from typing import Any, Mapping, Optional, Protocol
 
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -44,6 +44,7 @@ def _load_map_reduce_chain(
     document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
     combine_document_variable_name: str = "summaries",
     map_reduce_document_variable_name: str = "context",
+    collapse_prompt: Optional[BasePromptTemplate] = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=question_prompt)
@@ -53,10 +54,19 @@
         document_variable_name=combine_document_variable_name,
         document_prompt=document_prompt,
     )
+    if collapse_prompt is None:
+        collapse_chain = None
+    else:
+        collapse_chain = StuffDocumentsChain(
+            llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
+            document_variable_name=combine_document_variable_name,
+            document_prompt=document_prompt,
+        )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
         combine_document_chain=combine_document_chain,
         document_variable_name=map_reduce_document_variable_name,
+        collapse_document_chain=collapse_chain,
         **kwargs,
     )
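
Note: the new collapse_prompt keyword reaches this loader through **kwargs on the public load_qa_with_sources_chain entry point. A hedged sketch follows; the import path, model, and prompt wording are assumptions based on the library layout of this era, not taken from the diff.

```python
# Sketch only: passing collapse_prompt through the map_reduce QA-with-sources loader.
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

# The collapse prompt must use the same variable as combine_document_variable_name
# ("summaries" by default in this loader).
collapse_prompt = PromptTemplate(
    input_variables=["summaries"],
    template="Condense these extracted parts, keeping their sources:\n{summaries}",
)
chain = load_qa_with_sources_chain(
    OpenAI(temperature=0),
    chain_type="map_reduce",
    collapse_prompt=collapse_prompt,
)
```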

View File

@@ -1,5 +1,5 @@
 """Load question answering chains."""
-from typing import Any, Mapping, Protocol
+from typing import Any, Mapping, Optional, Protocol
 
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -41,6 +41,7 @@ def _load_map_reduce_chain(
     combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
     combine_document_variable_name: str = "summaries",
     map_reduce_document_variable_name: str = "context",
+    collapse_prompt: Optional[BasePromptTemplate] = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=question_prompt)
@@ -49,10 +50,18 @@
     combine_document_chain = StuffDocumentsChain(
         llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
     )
+    if collapse_prompt is None:
+        collapse_chain = None
+    else:
+        collapse_chain = StuffDocumentsChain(
+            llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
+            document_variable_name=combine_document_variable_name,
+        )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
         combine_document_chain=combine_document_chain,
         document_variable_name=map_reduce_document_variable_name,
+        collapse_document_chain=collapse_chain,
         **kwargs,
     )
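
Note: same pattern for the plain question-answering loader; collapse_prompt flows through load_qa_chain's **kwargs into _load_map_reduce_chain. A hedged sketch with a placeholder model and prompt.

```python
# Sketch only: collapse_prompt with the map_reduce question-answering chain.
# Model choice and prompt wording are placeholders.
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

# Must template on "summaries", the combine_document_variable_name in this loader.
collapse_prompt = PromptTemplate(
    input_variables=["summaries"],
    template="Collapse these partial answers into a shorter set:\n{summaries}",
)
chain = load_qa_chain(
    OpenAI(temperature=0),
    chain_type="map_reduce",
    collapse_prompt=collapse_prompt,
)
```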

View File

@@ -1,5 +1,5 @@
 """Load summarizing chains."""
-from typing import Any, Mapping, Protocol
+from typing import Any, Mapping, Optional, Protocol
 
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -37,6 +37,7 @@ def _load_map_reduce_chain(
     combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
     combine_document_variable_name: str = "text",
     map_reduce_document_variable_name: str = "text",
+    collapse_prompt: Optional[BasePromptTemplate] = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=map_prompt)
@@ -45,10 +46,18 @@
     combine_document_chain = StuffDocumentsChain(
         llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
     )
+    if collapse_prompt is None:
+        collapse_chain = None
+    else:
+        collapse_chain = StuffDocumentsChain(
+            llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
+            document_variable_name=combine_document_variable_name,
+        )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
         combine_document_chain=combine_document_chain,
         document_variable_name=map_reduce_document_variable_name,
+        collapse_document_chain=collapse_chain,
         **kwargs,
    )
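
Note: for summarization the combine/collapse variable is "text" rather than "summaries", so a custom collapse prompt should template on {text}. A hedged sketch, assuming the load_summarize_chain entry point of this era, with placeholder model and prompt text.

```python
# Sketch only: collapse_prompt with the map_reduce summarization chain.
# Model choice and prompt wording are placeholders.
from langchain.chains.summarize import load_summarize_chain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

collapse_prompt = PromptTemplate(
    input_variables=["text"],
    template="Condense the following summaries into one shorter summary:\n{text}",
)
chain = load_summarize_chain(
    OpenAI(temperature=0),
    chain_type="map_reduce",
    collapse_prompt=collapse_prompt,
)
```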