Return intermediate steps in combine-document chains

harrison/combine-docs-parse
Harrison Chase 1 year ago
parent 9ae1d75318
commit f97db8cc7b

@ -60,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "8dff4f43",
"metadata": {},
"outputs": [],
@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "27989fc4",
"metadata": {},
"outputs": [],
@ -131,33 +131,57 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "ef28e1d4",
"metadata": {},
"outputs": [],
"source": [
"chain = load_summarize_chain(llm, chain_type=\"map_reduce\")"
"chain = load_summarize_chain(llm, chain_type=\"map_reduce\", verbose=True, return_map_steps=True)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"id": "f82c5f9f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new MapReduceDocumentsChain chain...\u001b[0m\n",
"\n",
"\u001b[1m> Finished MapReduceDocumentsChain chain.\u001b[0m\n"
]
}
],
"source": [
"res = chain(docs)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f5a2b653",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" In response to Vladimir Putin's aggression in Ukraine, the US and its allies have taken action to hold him accountable, including economic sanctions, cutting off access to technology, and seizing the assets of Russian oligarchs. They are also providing military, economic, and humanitarian assistance to the Ukrainians, and releasing 60 million barrels of oil from reserves around the world. President Biden has passed several laws to provide economic relief to Americans and create jobs, and is making sure taxpayer dollars support American jobs and businesses.\""
"[{'text': \" In response to Russia's aggression in Ukraine, the United States has united with other freedom-loving nations to impose economic sanctions and hold Putin accountable. The U.S. Department of Justice is also assembling a task force to go after the crimes of Russian oligarchs and seize their ill-gotten gains.\"},\n",
" {'text': ' The United States and its European allies are taking action to punish Russia for its invasion of Ukraine, including seizing assets, closing off airspace, and providing economic and military assistance to Ukraine. The US is also mobilizing forces to protect NATO countries and has released 30 million barrels of oil from its Strategic Petroleum Reserve to help blunt gas prices. The world is uniting in support of Ukraine and democracy, and the US stands with its Ukrainian American citizens.'},\n",
" {'text': ' President Biden and Vice President Harris ran for office with a new economic vision for America, and have since passed the American Rescue Plan and the Bipartisan Infrastructure Law to help working people and rebuild America. These plans will create jobs, modernize roads, airports, ports, and waterways, and provide clean water and high-speed internet for all Americans. The government will also be investing in American products to support American jobs.'}]"
]
},
"execution_count": 9,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(docs)"
"res['map_steps']"
]
},
{

@ -1,7 +1,7 @@
"""Base interface for chains combining documents."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel
@ -39,12 +39,13 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
return None
@abstractmethod
def combine_docs(self, docs: List[Document], **kwargs: Any) -> str:
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Combine documents into a single string, also returning a dict of any extra outputs."""
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
output = self.combine_docs(docs, **other_keys)
return {self.output_key: output}
output, extra_return_dict = self.combine_docs(docs, **other_keys)
extra_return_dict[self.output_key] = output
return extra_return_dict

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple
from pydantic import BaseModel, Extra, root_validator
@ -65,6 +65,19 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
document_variable_name: str
"""The variable name in the llm_chain to put the documents in.
If only one variable in the llm_chain, this need not be provided."""
return_map_steps: bool = False
"""Return the results of the map steps in the output."""
@property
def output_keys(self) -> List[str]:
"""Return the output keys, including map steps if requested.
:meta private:
"""
_output_keys = super().output_keys
if self.return_map_steps:
_output_keys = _output_keys + ["map_steps"]
return _output_keys
class Config:
"""Configuration for this pydantic object."""
@ -102,7 +115,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
def combine_docs(
self, docs: List[Document], token_max: int = 3000, **kwargs: Any
) -> str:
) -> Tuple[str, dict]:
"""Combine documents in a map reduce manner.
Combine by mapping first chain over all documents, then reducing the results.
@ -133,5 +146,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
num_tokens = self.combine_document_chain.prompt_length(
result_docs, **kwargs
)
output = self.combine_document_chain.combine_docs(result_docs, **kwargs)
return output
if self.return_map_steps:
extra_return_dict = {"map_steps": results}
else:
extra_return_dict = {}
output, _ = self.combine_document_chain.combine_docs(result_docs, **kwargs)
return output, extra_return_dict

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple
from pydantic import BaseModel, Extra, Field, root_validator
@ -33,6 +33,19 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
default_factory=_get_default_document_prompt
)
"""Prompt to use to format each document."""
return_refine_steps: bool = False
"""Return the results of the refine steps in the output."""
@property
def output_keys(self) -> List[str]:
"""Return the output keys, including refine steps if requested.
:meta private:
"""
_output_keys = super().output_keys
if self.return_refine_steps:
_output_keys = _output_keys + ["refine_steps"]
return _output_keys
class Config:
"""Configuration for this pydantic object."""
@ -61,7 +74,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
)
return values
def combine_docs(self, docs: List[Document], **kwargs: Any) -> str:
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Combine by running the initial chain on the first document, then iteratively refining the answer over the remaining documents."""
base_info = {"page_content": docs[0].page_content}
base_info.update(docs[0].metadata)
@ -71,6 +84,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
}
inputs = {**base_inputs, **kwargs}
res = self.initial_llm_chain.predict(**inputs)
refine_steps = [res]
for doc in docs[1:]:
base_info = {"page_content": doc.page_content}
base_info.update(doc.metadata)
@ -85,4 +99,9 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
}
inputs = {**base_inputs, **kwargs}
res = self.refine_llm_chain.predict(**inputs)
return res
refine_steps.append(res)
if self.return_refine_steps:
extra_return_dict = {"refine_steps": refine_steps}
else:
extra_return_dict = {}
return res, extra_return_dict

@ -1,6 +1,6 @@
"""Chain that combines documents by stuffing into context."""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Extra, Field, root_validator
@ -78,8 +78,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain.llm.get_num_tokens(prompt)
def combine_docs(self, docs: List[Document], **kwargs: Any) -> str:
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return self.llm_chain.predict(**inputs)
return self.llm_chain.predict(**inputs), {}

@ -70,5 +70,5 @@ class MapReduceChain(Chain, BaseModel):
# Split the larger text into smaller chunks.
texts = self.text_splitter.split_text(inputs[self.input_key])
docs = [Document(page_content=text) for text in texts]
outputs = self.combine_documents_chain.combine_docs(docs)
outputs, _ = self.combine_documents_chain.combine_docs(docs)
return {self.output_key: outputs}

@ -106,7 +106,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
docs = self._get_docs(inputs)
answer = self.combine_document_chain.combine_docs(docs, **inputs)
answer, _ = self.combine_document_chain.combine_docs(docs, **inputs)
if "\nSOURCES: " in answer:
answer, sources = answer.split("\nSOURCES: ")
else:

@ -101,5 +101,5 @@ class VectorDBQA(Chain, BaseModel):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
question = inputs[self.input_key]
docs = self.vectorstore.similarity_search(question, k=self.k)
answer = self.combine_documents_chain.combine_docs(docs, question=question)
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
return {self.output_key: answer}

Loading…
Cancel
Save