You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/chains/mapreduce.py

75 lines
2.5 KiB
Python

"""Map-reduce chain.
Splits up a document, sends the smaller parts to the LLM with one prompt,
then combines the results with another one.
"""
from __future__ import annotations
from typing import Dict, List
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.text_splitter import TextSplitter
class MapReduceChain(Chain, BaseModel):
    """Chain that splits a text into chunks and combines them into one output.

    The text found under ``input_key`` is cut into pieces by ``text_splitter``,
    each piece is wrapped in a ``Document``, and the list of documents is
    passed to ``combine_documents_chain.combine_docs``. Whatever that call
    returns is stored under ``output_key``.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain that receives the chunk documents and produces the output."""
    text_splitter: TextSplitter
    """Splitter used to break the input text into chunks."""
    input_key: str = "input_text"  #: :meta private:
    output_key: str = "output_text"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @classmethod
    def from_params(
        cls, llm: BaseLLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
    ) -> MapReduceChain:
        """Build a map-reduce chain that reuses one LLM chain for both steps.

        A single ``LLMChain`` is created from ``llm`` and ``prompt``. It is
        given to ``MapReduceDocumentsChain`` directly as ``llm_chain`` and,
        wrapped in a ``StuffDocumentsChain``, as ``combine_document_chain``.
        """
        shared_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(
            combine_documents_chain=MapReduceDocumentsChain(
                llm_chain=shared_chain,
                combine_document_chain=StuffDocumentsChain(llm_chain=shared_chain),
            ),
            text_splitter=text_splitter,
        )

    @property
    def input_keys(self) -> List[str]:
        """Return the single input key this chain reads.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the single output key this chain writes.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Split the input text, wrap each chunk in a Document and combine them."""
        splitter = self.text_splitter
        # Break the full text into smaller pieces before combining.
        chunks = splitter.split_text(inputs[self.input_key])
        chunk_docs = []
        for chunk in chunks:
            chunk_docs.append(Document(page_content=chunk))
        combined = self.combine_documents_chain.combine_docs(chunk_docs)
        return {self.output_key: combined}