map rerank chain (#516)

add a chain that applies a prompt to each input and returns not only an answer but also a score for that answer

add examples for question answering and question answering with sources
pull/561/head
Harrison Chase 1 year ago committed by GitHub
parent 948eee9fe1
commit 8dfad874a2

@ -6,7 +6,7 @@ For more information on specific use cases as well as different methods for **fe
This documentation now picks up from after you've fetched your documents - now what?
How do you pass them to the language model in a format it can understand?
There are a few different methods, or chains, for doing so. LangChain supports three of the more common ones - and
There are a few different methods, or chains, for doing so. LangChain supports four of the more common ones - and
we are actively looking to include more, so if you have any ideas please reach out! Note that there is not
one best method - the decision of which one to use is often very context specific. In order from simplest to
most complex:
@ -39,3 +39,13 @@ asking the LLM to refine the output based on the new document.
**Pros:** Can pull in more relevant context, and may be less lossy than `MapReduceDocumentsChain`.
**Cons:** Requires many more calls to the LLM than `StuffDocumentsChain`. The calls are also NOT independent, meaning they cannot be parallelized like `MapReduceDocumentsChain`. There are also some potential dependencies on the ordering of the documents.
## Map-Rerank
This method involves running an initial prompt on each chunk of data that not only tries to complete a
task but also gives a score for how certain it is in its answer. The responses are then
ranked according to this score, and the highest-scoring response is returned.
**Pros:** Similar pros to `MapReduceDocumentsChain`, but it requires fewer calls to the LLM.
**Cons:** Cannot combine information between documents. This means it is most useful when you expect there to be a single simple answer in a single document.
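For example, here is a minimal sketch of loading a question answering chain that uses this method (it assumes you already have a list of `docs`, a `query`, and an OpenAI API key configured):

```python
from langchain.llms import OpenAI
from langchain.chains.question_answering import load_qa_chain

# Each document is answered and scored independently; the highest-scoring answer is returned.
chain = load_qa_chain(OpenAI(temperature=0), chain_type="map_rerank", return_intermediate_steps=True)
result = chain({"input_documents": docs, "question": query}, return_only_outputs=True)
result["output_text"]          # highest-scoring answer
result["intermediate_steps"]   # per-document answers and scores
```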

@ -7,7 +7,7 @@
"source": [
"# Question Answering with Sources\n",
"\n",
"This notebook walks through how to use LangChain for question answering with sources over a list of documents. It covers three different chain types: `stuff`, `map_reduce`, and `refine`. For a more in depth explanation of what these chain types are, see [here](../combine_docs.md)."
"This notebook walks through how to use LangChain for question answering with sources over a list of documents. It covers four different chain types: `stuff`, `map_reduce`, `refine`,`map-rerank`. For a more in depth explanation of what these chain types are, see [here](../combine_docs.md)."
]
},
{
@ -259,7 +259,7 @@
"source": [
"**Intermediate Steps**\n",
"\n",
"We can also return the intermediate steps for `refine` chains, should we want to inspect them. This is done with the `return_refine_steps` variable."
"We can also return the intermediate steps for `refine` chains, should we want to inspect them. This is done with the `return_intermediate_steps` variable."
]
},
{
@ -297,10 +297,87 @@
"chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
]
},
{
"cell_type": "markdown",
"id": "07ff756e",
"metadata": {},
"source": [
"## The `map-rerank` Chain\n",
"\n",
"This sections shows results of using the `map-rerank` Chain to do question answering with sources."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "46b52ef9",
"metadata": {},
"outputs": [],
"source": [
"chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type=\"map_rerank\", metadata_keys=['source'], return_intermediate_steps=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7ce2da04",
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about Justice Breyer\"\n",
"result = chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cbdcd3c5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' The President thanked Justice Breyer for his service and honored him for dedicating his life to serve the country.'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result[\"output_text\"]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6f0b3d03",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'answer': ' The President thanked Justice Breyer for his service and honored him for dedicating his life to serve the country.',\n",
" 'score': '100'},\n",
" {'answer': ' This document does not answer the question', 'score': '0'},\n",
" {'answer': ' This document does not answer the question', 'score': '0'},\n",
" {'answer': ' This document does not answer the question', 'score': '0'}]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result[\"intermediate_steps\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa2b8db9",
"id": "e66b8160",
"metadata": {},
"outputs": [],
"source": []

@ -7,7 +7,7 @@
"source": [
"# Question Answering\n",
"\n",
"This notebook walks through how to use LangChain for question answering over a list of documents. It covers three different types of chaings: `stuff`, `map_reduce`, and `refine`. For a more in depth explanation of what these chain types are, see [here](../combine_docs.md)."
"This notebook walks through how to use LangChain for question answering over a list of documents. It covers four different types of chaings: `stuff`, `map_reduce`, `refine`, `map-rerank`. For a more in depth explanation of what these chain types are, see [here](../combine_docs.md)."
]
},
{
@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 11,
"id": "17fcbc0f",
"metadata": {},
"outputs": [],
@ -34,7 +34,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 12,
"id": "291f0117",
"metadata": {},
"outputs": [],
@ -49,7 +49,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 13,
"id": "fd9666a9",
"metadata": {},
"outputs": [],
@ -59,7 +59,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 14,
"id": "d1eaf6e6",
"metadata": {},
"outputs": [],
@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 15,
"id": "a16e3453",
"metadata": {},
"outputs": [],
@ -294,6 +294,94 @@
"source": [
"chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
]
},
{
"cell_type": "markdown",
"id": "521a77cb",
"metadata": {},
"source": [
"## The `map-rerank` Chain\n",
"\n",
"This sections shows results of using the `map-rerank` Chain to do question answering with sources."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e2bfe203",
"metadata": {},
"outputs": [],
"source": [
"chain = load_qa_chain(OpenAI(temperature=0), chain_type=\"map_rerank\", return_intermediate_steps=True)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "5c28880c",
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about Justice Breyer\"\n",
"results = chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "80ac2db3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' The president thanked Justice Breyer for his service and honored him for dedicating his life to serving the country. '"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results[\"output_text\"]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "b428fcb9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'answer': ' The president thanked Justice Breyer for his service and honored him for dedicating his life to serving the country. ',\n",
" 'score': '100'},\n",
" {'answer': \" The president said that Justice Breyer is a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that since she's been nominated, she's received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans, and that she is a consensus builder.\",\n",
" 'score': '100'},\n",
" {'answer': ' The president did not mention Justice Breyer in this context.',\n",
" 'score': '0'},\n",
" {'answer': ' The president did not mention Justice Breyer in the given context. ',\n",
" 'score': '0'}]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results[\"intermediate_steps\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4f86521",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

@ -0,0 +1,113 @@
"""Combining documents by mapping a chain over them first, then reranking results."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, cast
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.base import RegexParser
class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel):
"""Combining documents by mapping a chain over them, then reranking results."""
llm_chain: LLMChain
"""Chain to apply to each document individually."""
document_variable_name: str
"""The variable name in the llm_chain to put the documents in.
If only one variable in the llm_chain, this need not be provided."""
rank_key: str
"""Key in output of llm_chain to rank on."""
answer_key: str
"""Key in output of llm_chain to return as answer."""
metadata_keys: Optional[List[str]] = None
return_intermediate_steps: bool = False
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def output_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
_output_keys = super().output_keys
if self.return_intermediate_steps:
_output_keys = _output_keys + ["intermediate_steps"]
if self.metadata_keys is not None:
_output_keys += self.metadata_keys
return _output_keys
@root_validator()
def validate_llm_output(cls, values: Dict) -> Dict:
"""Validate that the combine chain outputs a dictionary."""
output_parser = values["llm_chain"].prompt.output_parser
if not isinstance(output_parser, RegexParser):
raise ValueError(
"Output parser of llm_chain should be a RegexParser,"
f" got {output_parser}"
)
output_keys = output_parser.output_keys
if values["rank_key"] not in output_keys:
raise ValueError(
f"Got {values['rank_key']} as key to rank on, but did not find "
f"it in the llm_chain output keys ({output_keys})"
)
if values["answer_key"] not in output_keys:
raise ValueError(
f"Got {values['answer_key']} as key to return, but did not find "
f"it in the llm_chain output keys ({output_keys})"
)
return values
@root_validator(pre=True)
def get_default_document_variable_name(cls, values: Dict) -> Dict:
"""Get default document variable name, if not provided."""
if "document_variable_name" not in values:
llm_chain_variables = values["llm_chain"].prompt.input_variables
if len(llm_chain_variables) == 1:
values["document_variable_name"] = llm_chain_variables[0]
else:
raise ValueError(
"document_variable_name must be provided if there are "
"multiple llm_chain input_variables"
)
else:
llm_chain_variables = values["llm_chain"].prompt.input_variables
if values["document_variable_name"] not in llm_chain_variables:
raise ValueError(
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
return values
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Combine documents in a map rerank manner.
Combine by mapping first chain over all documents, then reranking the results.
"""
results = self.llm_chain.apply_and_parse(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
)
typed_results = cast(List[dict], results)
sorted_res = sorted(
zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key])
)
output, document = sorted_res[0]
extra_info = {}
if self.metadata_keys is not None:
for key in self.metadata_keys:
extra_info[key] = document.metadata[key]
if self.return_intermediate_steps:
extra_info["intermediate_steps"] = results
return output[self.answer_key], extra_info
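To show how the pieces above fit together, here is a hedged usage sketch (not part of this commit) that wires up a `MapRerankDocumentsChain` by hand, essentially what the `_load_map_rerank_chain` loaders below do; the example document and question are illustrative:

from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import map_rerank_prompt
from langchain.docstore.document import Document
from langchain.llms import OpenAI

# The prompt added in this commit already carries a RegexParser with "answer" and "score" keys.
llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=map_rerank_prompt.PROMPT)
chain = MapRerankDocumentsChain(
    llm_chain=llm_chain,
    rank_key="score",
    answer_key="answer",
    document_variable_name="context",
)
docs = [Document(page_content="Apples are red", metadata={"source": "fruit-facts"})]
# Each document is answered and scored; the highest-scoring answer is returned.
answer, extra = chain.combine_docs(docs, question="What color are apples?")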

@ -3,6 +3,7 @@ from typing import Any, Mapping, Optional, Protocol
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
@ -11,6 +12,7 @@ from langchain.chains.qa_with_sources import (
refine_prompts,
stuff_prompt,
)
from langchain.chains.question_answering import map_rerank_prompt
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
@ -22,6 +24,25 @@ class LoadingCallable(Protocol):
"""Callable to load the combine documents chain."""
def _load_map_rerank_chain(
llm: BaseLLM,
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
verbose: bool = False,
document_variable_name: str = "context",
rank_key: str = "score",
answer_key: str = "answer",
**kwargs: Any,
) -> MapRerankDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
return MapRerankDocumentsChain(
llm_chain=llm_chain,
rank_key=rank_key,
answer_key=answer_key,
document_variable_name=document_variable_name,
**kwargs,
)
def _load_stuff_chain(
llm: BaseLLM,
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
@ -137,6 +158,7 @@ def load_qa_with_sources_chain(
"stuff": _load_stuff_chain,
"map_reduce": _load_map_reduce_chain,
"refine": _load_refine_chain,
"map_rerank": _load_map_rerank_chain,
}
if chain_type not in loader_mapping:
raise ValueError(
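A hedged sketch of using the with-sources loader registered above, mirroring the notebook in this commit (`docs` and `query` are assumed to already exist):

from langchain.llms import OpenAI
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain

# With map_rerank, sources are not combined in text; passing metadata_keys=["source"] should
# surface the metadata of the top-ranked document alongside the answer.
chain = load_qa_with_sources_chain(
    OpenAI(temperature=0),
    chain_type="map_rerank",
    metadata_keys=["source"],
    return_intermediate_steps=True,
)
result = chain({"input_documents": docs, "question": query}, return_only_outputs=True)
result["output_text"], result["source"], result["intermediate_steps"]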

@ -3,11 +3,13 @@ from typing import Any, Mapping, Optional, Protocol
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import (
map_reduce_prompt,
map_rerank_prompt,
refine_prompts,
stuff_prompt,
)
@ -22,6 +24,25 @@ class LoadingCallable(Protocol):
"""Callable to load the combine documents chain."""
def _load_map_rerank_chain(
llm: BaseLLM,
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
verbose: bool = False,
document_variable_name: str = "context",
rank_key: str = "score",
answer_key: str = "answer",
**kwargs: Any,
) -> MapRerankDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
return MapRerankDocumentsChain(
llm_chain=llm_chain,
rank_key=rank_key,
answer_key=answer_key,
document_variable_name=document_variable_name,
**kwargs,
)
def _load_stuff_chain(
llm: BaseLLM,
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
@ -132,6 +153,7 @@ def load_qa_chain(
"stuff": _load_stuff_chain,
"map_reduce": _load_map_reduce_chain,
"refine": _load_refine_chain,
"map_rerank": _load_map_rerank_chain,
}
if chain_type not in loader_mapping:
raise ValueError(

@ -0,0 +1,66 @@
# flake8: noqa
from langchain.prompts import PromptTemplate
from langchain.prompts.base import RegexParser
output_parser = RegexParser(
regex=r"(.*?)\nScore: (.*)",
output_keys=["answer", "score"],
)
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
In addition to giving an answer, also return a score of how fully it answered the user's question. This should be in the following format:
Question: [question here]
Helpful Answer: [answer here]
Score: [score between 0 and 100]
How to determine the score:
- Higher is a better answer
- Better responds fully to the asked question, with sufficient level of detail
- If you do not know the answer based on the context, that should be a score of 0
- Don't be overconfident!
Example #1
Context:
---------
Apples are red
---------
Question: what color are apples?
Helpful Answer: red
Score: 100
Example #2
Context:
---------
it was night and the witness forgot his glasses. he was not sure if it was a sports car or an suv
---------
Question: what type was the car?
Helpful Answer: a sports car or an suv
Score: 60
Example #3
Context:
---------
Pears are either red or orange
---------
Question: what color are apples?
Helpful Answer: This document does not answer the question
Score: 0
Begin!
Context:
---------
{context}
---------
Question: {question}
Helpful Answer:"""
PROMPT = PromptTemplate(
template=prompt_template,
input_variables=["context", "question"],
output_parser=output_parser,
)
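For reference, a small sketch of how the parser above is expected to split a completion into its output keys (this assumes `RegexParser` exposes a `parse` method that maps regex groups onto `output_keys`):

# A completion in the format the prompt requests.
completion = "red\nScore: 100"
parsed = output_parser.parse(completion)
# parsed == {"answer": "red", "score": "100"}
# MapRerankDocumentsChain then sorts the parsed results by -int(parsed["score"]).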