forked from Archives/langchain
add optional collapse prompt (#358)
This commit is contained in:
parent
2dd895d98c
commit
750edfb440
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, List
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
@ -56,9 +56,12 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
||||
"""Combining documents by mapping a chain over them, then combining results."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
"""Chain to apply to each document individually.."""
|
||||
"""Chain to apply to each document individually."""
|
||||
combine_document_chain: BaseCombineDocumentsChain
|
||||
"""Chain to use to combine results of applying llm_chain to documents."""
|
||||
collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
|
||||
"""Chain to use to collapse intermediary results if needed.
|
||||
If None, will use the combine_document_chain."""
|
||||
document_variable_name: str
|
||||
"""The variable name in the llm_chain to put the documents in.
|
||||
If only one variable in the llm_chain, this need not be provided."""
|
||||
@ -90,6 +93,13 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _collapse_chain(self) -> BaseCombineDocumentsChain:
|
||||
if self.collapse_document_chain is not None:
|
||||
return self.collapse_document_chain
|
||||
else:
|
||||
return self.combine_document_chain
|
||||
|
||||
def combine_docs(
|
||||
self, docs: List[Document], token_max: int = 3000, **kwargs: Any
|
||||
) -> str:
|
||||
@ -117,7 +127,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
||||
result_docs = []
|
||||
for docs in new_result_doc_list:
|
||||
new_doc = _collapse_docs(
|
||||
docs, self.combine_document_chain.combine_docs, **kwargs
|
||||
docs, self._collapse_chain.combine_docs, **kwargs
|
||||
)
|
||||
result_docs.append(new_doc)
|
||||
num_tokens = self.combine_document_chain.prompt_length(
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Load question answering with sources chains."""
|
||||
from typing import Any, Mapping, Protocol
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
@ -44,6 +44,7 @@ def _load_map_reduce_chain(
|
||||
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
|
||||
combine_document_variable_name: str = "summaries",
|
||||
map_reduce_document_variable_name: str = "context",
|
||||
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
map_chain = LLMChain(llm=llm, prompt=question_prompt)
|
||||
@ -53,10 +54,19 @@ def _load_map_reduce_chain(
|
||||
document_variable_name=combine_document_variable_name,
|
||||
document_prompt=document_prompt,
|
||||
)
|
||||
if collapse_prompt is None:
|
||||
collapse_chain = None
|
||||
else:
|
||||
collapse_chain = StuffDocumentsChain(
|
||||
llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
|
||||
document_variable_name=combine_document_variable_name,
|
||||
document_prompt=document_prompt,
|
||||
)
|
||||
return MapReduceDocumentsChain(
|
||||
llm_chain=map_chain,
|
||||
combine_document_chain=combine_document_chain,
|
||||
document_variable_name=map_reduce_document_variable_name,
|
||||
collapse_document_chain=collapse_chain,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Load question answering chains."""
|
||||
from typing import Any, Mapping, Protocol
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
@ -41,6 +41,7 @@ def _load_map_reduce_chain(
|
||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
||||
combine_document_variable_name: str = "summaries",
|
||||
map_reduce_document_variable_name: str = "context",
|
||||
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
map_chain = LLMChain(llm=llm, prompt=question_prompt)
|
||||
@ -49,10 +50,18 @@ def _load_map_reduce_chain(
|
||||
combine_document_chain = StuffDocumentsChain(
|
||||
llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
|
||||
)
|
||||
if collapse_prompt is None:
|
||||
collapse_chain = None
|
||||
else:
|
||||
collapse_chain = StuffDocumentsChain(
|
||||
llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
|
||||
document_variable_name=combine_document_variable_name,
|
||||
)
|
||||
return MapReduceDocumentsChain(
|
||||
llm_chain=map_chain,
|
||||
combine_document_chain=combine_document_chain,
|
||||
document_variable_name=map_reduce_document_variable_name,
|
||||
collapse_document_chain=collapse_chain,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Load summarizing chains."""
|
||||
from typing import Any, Mapping, Protocol
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
@ -37,6 +37,7 @@ def _load_map_reduce_chain(
|
||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
||||
combine_document_variable_name: str = "text",
|
||||
map_reduce_document_variable_name: str = "text",
|
||||
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
map_chain = LLMChain(llm=llm, prompt=map_prompt)
|
||||
@ -45,10 +46,18 @@ def _load_map_reduce_chain(
|
||||
combine_document_chain = StuffDocumentsChain(
|
||||
llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
|
||||
)
|
||||
if collapse_prompt is None:
|
||||
collapse_chain = None
|
||||
else:
|
||||
collapse_chain = StuffDocumentsChain(
|
||||
llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
|
||||
document_variable_name=combine_document_variable_name,
|
||||
)
|
||||
return MapReduceDocumentsChain(
|
||||
llm_chain=map_chain,
|
||||
combine_document_chain=combine_document_chain,
|
||||
document_variable_name=map_reduce_document_variable_name,
|
||||
collapse_document_chain=collapse_chain,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user