forked from Archives/langchain
add optional collapse prompt (#358)
This commit is contained in:
parent
2dd895d98c
commit
750edfb440
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
@ -56,9 +56,12 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
|||||||
"""Combining documents by mapping a chain over them, then combining results."""
|
"""Combining documents by mapping a chain over them, then combining results."""
|
||||||
|
|
||||||
llm_chain: LLMChain
|
llm_chain: LLMChain
|
||||||
"""Chain to apply to each document individually.."""
|
"""Chain to apply to each document individually."""
|
||||||
combine_document_chain: BaseCombineDocumentsChain
|
combine_document_chain: BaseCombineDocumentsChain
|
||||||
"""Chain to use to combine results of applying llm_chain to documents."""
|
"""Chain to use to combine results of applying llm_chain to documents."""
|
||||||
|
collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
|
||||||
|
"""Chain to use to collapse intermediary results if needed.
|
||||||
|
If None, will use the combine_document_chain."""
|
||||||
document_variable_name: str
|
document_variable_name: str
|
||||||
"""The variable name in the llm_chain to put the documents in.
|
"""The variable name in the llm_chain to put the documents in.
|
||||||
If only one variable in the llm_chain, this need not be provided."""
|
If only one variable in the llm_chain, this need not be provided."""
|
||||||
@ -90,6 +93,13 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _collapse_chain(self) -> BaseCombineDocumentsChain:
|
||||||
|
if self.collapse_document_chain is not None:
|
||||||
|
return self.collapse_document_chain
|
||||||
|
else:
|
||||||
|
return self.combine_document_chain
|
||||||
|
|
||||||
def combine_docs(
|
def combine_docs(
|
||||||
self, docs: List[Document], token_max: int = 3000, **kwargs: Any
|
self, docs: List[Document], token_max: int = 3000, **kwargs: Any
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -117,7 +127,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
|
|||||||
result_docs = []
|
result_docs = []
|
||||||
for docs in new_result_doc_list:
|
for docs in new_result_doc_list:
|
||||||
new_doc = _collapse_docs(
|
new_doc = _collapse_docs(
|
||||||
docs, self.combine_document_chain.combine_docs, **kwargs
|
docs, self._collapse_chain.combine_docs, **kwargs
|
||||||
)
|
)
|
||||||
result_docs.append(new_doc)
|
result_docs.append(new_doc)
|
||||||
num_tokens = self.combine_document_chain.prompt_length(
|
num_tokens = self.combine_document_chain.prompt_length(
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Load question answering with sources chains."""
|
"""Load question answering with sources chains."""
|
||||||
from typing import Any, Mapping, Protocol
|
from typing import Any, Mapping, Optional, Protocol
|
||||||
|
|
||||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||||
@ -44,6 +44,7 @@ def _load_map_reduce_chain(
|
|||||||
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
|
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
|
||||||
combine_document_variable_name: str = "summaries",
|
combine_document_variable_name: str = "summaries",
|
||||||
map_reduce_document_variable_name: str = "context",
|
map_reduce_document_variable_name: str = "context",
|
||||||
|
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> MapReduceDocumentsChain:
|
) -> MapReduceDocumentsChain:
|
||||||
map_chain = LLMChain(llm=llm, prompt=question_prompt)
|
map_chain = LLMChain(llm=llm, prompt=question_prompt)
|
||||||
@ -53,10 +54,19 @@ def _load_map_reduce_chain(
|
|||||||
document_variable_name=combine_document_variable_name,
|
document_variable_name=combine_document_variable_name,
|
||||||
document_prompt=document_prompt,
|
document_prompt=document_prompt,
|
||||||
)
|
)
|
||||||
|
if collapse_prompt is None:
|
||||||
|
collapse_chain = None
|
||||||
|
else:
|
||||||
|
collapse_chain = StuffDocumentsChain(
|
||||||
|
llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
|
||||||
|
document_variable_name=combine_document_variable_name,
|
||||||
|
document_prompt=document_prompt,
|
||||||
|
)
|
||||||
return MapReduceDocumentsChain(
|
return MapReduceDocumentsChain(
|
||||||
llm_chain=map_chain,
|
llm_chain=map_chain,
|
||||||
combine_document_chain=combine_document_chain,
|
combine_document_chain=combine_document_chain,
|
||||||
document_variable_name=map_reduce_document_variable_name,
|
document_variable_name=map_reduce_document_variable_name,
|
||||||
|
collapse_document_chain=collapse_chain,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Load question answering chains."""
|
"""Load question answering chains."""
|
||||||
from typing import Any, Mapping, Protocol
|
from typing import Any, Mapping, Optional, Protocol
|
||||||
|
|
||||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||||
@ -41,6 +41,7 @@ def _load_map_reduce_chain(
|
|||||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
||||||
combine_document_variable_name: str = "summaries",
|
combine_document_variable_name: str = "summaries",
|
||||||
map_reduce_document_variable_name: str = "context",
|
map_reduce_document_variable_name: str = "context",
|
||||||
|
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> MapReduceDocumentsChain:
|
) -> MapReduceDocumentsChain:
|
||||||
map_chain = LLMChain(llm=llm, prompt=question_prompt)
|
map_chain = LLMChain(llm=llm, prompt=question_prompt)
|
||||||
@ -49,10 +50,18 @@ def _load_map_reduce_chain(
|
|||||||
combine_document_chain = StuffDocumentsChain(
|
combine_document_chain = StuffDocumentsChain(
|
||||||
llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
|
llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
|
||||||
)
|
)
|
||||||
|
if collapse_prompt is None:
|
||||||
|
collapse_chain = None
|
||||||
|
else:
|
||||||
|
collapse_chain = StuffDocumentsChain(
|
||||||
|
llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
|
||||||
|
document_variable_name=combine_document_variable_name,
|
||||||
|
)
|
||||||
return MapReduceDocumentsChain(
|
return MapReduceDocumentsChain(
|
||||||
llm_chain=map_chain,
|
llm_chain=map_chain,
|
||||||
combine_document_chain=combine_document_chain,
|
combine_document_chain=combine_document_chain,
|
||||||
document_variable_name=map_reduce_document_variable_name,
|
document_variable_name=map_reduce_document_variable_name,
|
||||||
|
collapse_document_chain=collapse_chain,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Load summarizing chains."""
|
"""Load summarizing chains."""
|
||||||
from typing import Any, Mapping, Protocol
|
from typing import Any, Mapping, Optional, Protocol
|
||||||
|
|
||||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||||
@ -37,6 +37,7 @@ def _load_map_reduce_chain(
|
|||||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
||||||
combine_document_variable_name: str = "text",
|
combine_document_variable_name: str = "text",
|
||||||
map_reduce_document_variable_name: str = "text",
|
map_reduce_document_variable_name: str = "text",
|
||||||
|
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> MapReduceDocumentsChain:
|
) -> MapReduceDocumentsChain:
|
||||||
map_chain = LLMChain(llm=llm, prompt=map_prompt)
|
map_chain = LLMChain(llm=llm, prompt=map_prompt)
|
||||||
@ -45,10 +46,18 @@ def _load_map_reduce_chain(
|
|||||||
combine_document_chain = StuffDocumentsChain(
|
combine_document_chain = StuffDocumentsChain(
|
||||||
llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
|
llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
|
||||||
)
|
)
|
||||||
|
if collapse_prompt is None:
|
||||||
|
collapse_chain = None
|
||||||
|
else:
|
||||||
|
collapse_chain = StuffDocumentsChain(
|
||||||
|
llm_chain=LLMChain(llm=llm, prompt=collapse_prompt),
|
||||||
|
document_variable_name=combine_document_variable_name,
|
||||||
|
)
|
||||||
return MapReduceDocumentsChain(
|
return MapReduceDocumentsChain(
|
||||||
llm_chain=map_chain,
|
llm_chain=map_chain,
|
||||||
combine_document_chain=combine_document_chain,
|
combine_document_chain=combine_document_chain,
|
||||||
document_variable_name=map_reduce_document_variable_name,
|
document_variable_name=map_reduce_document_variable_name,
|
||||||
|
collapse_document_chain=collapse_chain,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user