add optional collapse prompt (#358)

This commit is contained in:
Harrison Chase 2022-12-16 06:25:29 -08:00 committed by GitHub
parent 2dd895d98c
commit 750edfb440
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 44 additions and 6 deletions

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Dict, List from typing import Any, Callable, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
@@ -56,9 +56,12 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
"""Combining documents by mapping a chain over them, then combining results.""" """Combining documents by mapping a chain over them, then combining results."""
llm_chain: LLMChain llm_chain: LLMChain
"""Chain to apply to each document individually..""" """Chain to apply to each document individually."""
combine_document_chain: BaseCombineDocumentsChain combine_document_chain: BaseCombineDocumentsChain
"""Chain to use to combine results of applying llm_chain to documents.""" """Chain to use to combine results of applying llm_chain to documents."""
collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
"""Chain to use to collapse intermediary results if needed.
If None, will use the combine_document_chain."""
document_variable_name: str document_variable_name: str
"""The variable name in the llm_chain to put the documents in. """The variable name in the llm_chain to put the documents in.
If only one variable in the llm_chain, this need not be provided.""" If only one variable in the llm_chain, this need not be provided."""
@@ -90,6 +93,13 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
) )
return values return values
@property
def _collapse_chain(self) -> BaseCombineDocumentsChain:
if self.collapse_document_chain is not None:
return self.collapse_document_chain
else:
return self.combine_document_chain
def combine_docs( def combine_docs(
self, docs: List[Document], token_max: int = 3000, **kwargs: Any self, docs: List[Document], token_max: int = 3000, **kwargs: Any
) -> str: ) -> str:
@@ -117,7 +127,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
result_docs = [] result_docs = []
for docs in new_result_doc_list: for docs in new_result_doc_list:
new_doc = _collapse_docs( new_doc = _collapse_docs(
docs, self.combine_document_chain.combine_docs, **kwargs docs, self._collapse_chain.combine_docs, **kwargs
) )
result_docs.append(new_doc) result_docs.append(new_doc)
num_tokens = self.combine_document_chain.prompt_length( num_tokens = self.combine_document_chain.prompt_length(

View File

@@ -1,5 +1,5 @@
"""Load question answering with sources chains.""" """Load question answering with sources chains."""
from typing import Any, Mapping, Protocol from typing import Any, Mapping, Optional, Protocol
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -44,6 +44,7 @@ def _load_map_reduce_chain(
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT, document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
combine_document_variable_name: str = "summaries", combine_document_variable_name: str = "summaries",
map_reduce_document_variable_name: str = "context", map_reduce_document_variable_name: str = "context",
collapse_prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any, **kwargs: Any,
) -> MapReduceDocumentsChain: ) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=question_prompt) map_chain = LLMChain(llm=llm, prompt=question_prompt)
@@ -53,10 +54,19 @@ def _load_map_reduce_chain(
document_variable_name=combine_document_variable_name, document_variable_name=combine_document_variable_name,
document_prompt=document_prompt, document_prompt=document_prompt,
) )
if collapse_prompt is None:
collapse_chain = None
else:
collapse_chain = StuffDocumentsChain(
llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
document_variable_name=combine_document_variable_name,
document_prompt=document_prompt,
)
return MapReduceDocumentsChain( return MapReduceDocumentsChain(
llm_chain=map_chain, llm_chain=map_chain,
combine_document_chain=combine_document_chain, combine_document_chain=combine_document_chain,
document_variable_name=map_reduce_document_variable_name, document_variable_name=map_reduce_document_variable_name,
collapse_document_chain=collapse_chain,
**kwargs, **kwargs,
) )

View File

@@ -1,5 +1,5 @@
"""Load question answering chains.""" """Load question answering chains."""
from typing import Any, Mapping, Protocol from typing import Any, Mapping, Optional, Protocol
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -41,6 +41,7 @@ def _load_map_reduce_chain(
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT, combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
combine_document_variable_name: str = "summaries", combine_document_variable_name: str = "summaries",
map_reduce_document_variable_name: str = "context", map_reduce_document_variable_name: str = "context",
collapse_prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any, **kwargs: Any,
) -> MapReduceDocumentsChain: ) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=question_prompt) map_chain = LLMChain(llm=llm, prompt=question_prompt)
@@ -49,10 +50,18 @@ def _load_map_reduce_chain(
combine_document_chain = StuffDocumentsChain( combine_document_chain = StuffDocumentsChain(
llm_chain=reduce_chain, document_variable_name=combine_document_variable_name llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
) )
if collapse_prompt is None:
collapse_chain = None
else:
collapse_chain = StuffDocumentsChain(
llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
document_variable_name=combine_document_variable_name,
)
return MapReduceDocumentsChain( return MapReduceDocumentsChain(
llm_chain=map_chain, llm_chain=map_chain,
combine_document_chain=combine_document_chain, combine_document_chain=combine_document_chain,
document_variable_name=map_reduce_document_variable_name, document_variable_name=map_reduce_document_variable_name,
collapse_document_chain=collapse_chain,
**kwargs, **kwargs,
) )

View File

@@ -1,5 +1,5 @@
"""Load summarizing chains.""" """Load summarizing chains."""
from typing import Any, Mapping, Protocol from typing import Any, Mapping, Optional, Protocol
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -37,6 +37,7 @@ def _load_map_reduce_chain(
combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
combine_document_variable_name: str = "text", combine_document_variable_name: str = "text",
map_reduce_document_variable_name: str = "text", map_reduce_document_variable_name: str = "text",
collapse_prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any, **kwargs: Any,
) -> MapReduceDocumentsChain: ) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=map_prompt) map_chain = LLMChain(llm=llm, prompt=map_prompt)
@@ -45,10 +46,18 @@ def _load_map_reduce_chain(
combine_document_chain = StuffDocumentsChain( combine_document_chain = StuffDocumentsChain(
llm_chain=reduce_chain, document_variable_name=combine_document_variable_name llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
) )
if collapse_prompt is None:
collapse_chain = None
else:
collapse_chain = StuffDocumentsChain(
llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
document_variable_name=combine_document_variable_name,
)
return MapReduceDocumentsChain( return MapReduceDocumentsChain(
llm_chain=map_chain, llm_chain=map_chain,
combine_document_chain=combine_document_chain, combine_document_chain=combine_document_chain,
document_variable_name=map_reduce_document_variable_name, document_variable_name=map_reduce_document_variable_name,
collapse_document_chain=collapse_chain,
**kwargs, **kwargs,
) )