|
|
|
@ -60,13 +60,12 @@ class MapReduceChain(Chain, BaseModel):
|
|
|
|
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
|
|
|
|
# Split the larger text into smaller chunks.
|
|
|
|
|
docs = self.text_splitter.split_text(inputs[self.input_key])
|
|
|
|
|
|
|
|
|
|
# Now that we have the chunks, we send them to the LLM and track results.
|
|
|
|
|
# This is the "map" part.
|
|
|
|
|
summaries = []
|
|
|
|
|
for d in docs:
|
|
|
|
|
inputs = {self.map_llm.prompt.input_variables[0]: d}
|
|
|
|
|
res = self.map_llm.predict(**inputs)
|
|
|
|
|
summaries.append(res)
|
|
|
|
|
input_list = [{self.map_llm.prompt.input_variables[0]: d} for d in docs]
|
|
|
|
|
summary_results = self.map_llm.apply(input_list)
|
|
|
|
|
summaries = [res[self.map_llm.output_key] for res in summary_results]
|
|
|
|
|
|
|
|
|
|
# We then need to combine these individual parts into one.
|
|
|
|
|
# This is the reduce part.
|
|
|
|
|