add apply functionality (#150)

harrison/save_metadatas
Harrison Chase 2 years ago committed by GitHub
parent 47e35d7d0e
commit d775ddd749
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -49,6 +49,10 @@ class Chain(BaseModel, ABC):
self._validate_outputs(outputs)
return {**inputs, **outputs}
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs) for inputs in input_list]
def run(self, text: str) -> str:
"""Run text in, text out (if applicable)."""
if len(self.input_keys) != 1:

@ -60,13 +60,12 @@ class MapReduceChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
# Split the larger text into smaller chunks.
docs = self.text_splitter.split_text(inputs[self.input_key])
# Now that we have the chunks, we send them to the LLM and track results.
# This is the "map" part.
summaries = []
for d in docs:
inputs = {self.map_llm.prompt.input_variables[0]: d}
res = self.map_llm.predict(**inputs)
summaries.append(res)
input_list = [{self.map_llm.prompt.input_variables[0]: d} for d in docs]
summary_results = self.map_llm.apply(input_list)
summaries = [res[self.map_llm.output_key] for res in summary_results]
# We then need to combine these individual parts into one.
# This is the reduce part.

Loading…
Cancel
Save