Fixed multi input prompt for MapReduceChain (#4979)

# Fixed multi input prompt for MapReduceChain

Added `kwargs` support for inner chains of `MapReduceChain` via
`from_params` method
Currently the `from_method` method of intialising `MapReduceChain` chain
doesn't work if prompt has multiple inputs. It happens because it uses
`StuffDocumentsChain` and `MapReduceDocumentsChain` underneath, both of
them require specifying `document_variable_name` if `prompt` of their
`llm_chain` has more than one `input`.

With this PR, I have added support for passing their respective `kwargs`
via the `from_params` method.

## Fixes https://github.com/hwchase17/langchain/issues/4752

## Who can review? 
@dev2049 @hwchase17 @agola11

---------

Co-authored-by: imeckr <chandanroutray2012@gmail.com>
searx_updates
Chandan Routray 12 months ago committed by GitHub
parent a97e4252e3
commit bc875a9df1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "e9db25f3",
"metadata": {},
"outputs": [],
@ -318,6 +318,141 @@
"chain({\"input_documents\": docs}, return_only_outputs=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b882e209",
"metadata": {},
"source": [
"## The custom `MapReduceChain`\n",
"\n",
"**Multi input prompt**\n",
"\n",
"You can also use prompt with multi input. In this example, we will use a MapReduce chain to answer specifc question about our code."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f7ad9ee2",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain\n",
"from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
"\n",
"map_template_string = \"\"\"Give the following python code information, generate a description that explains what the code does and also mention the time complexity.\n",
"Code:\n",
"{code}\n",
"\n",
"Return the the description in the following format:\n",
"name of the function: description of the function\n",
"\"\"\"\n",
"\n",
"\n",
"reduce_template_string = \"\"\"Give the following following python fuctions name and their descritpion, answer the following question\n",
"{code_description}\n",
"Question: {question}\n",
"Answer:\n",
"\"\"\"\n",
"\n",
"MAP_PROMPT = PromptTemplate(input_variables=[\"code\"], template=map_template_string)\n",
"REDUCE_PROMPT = PromptTemplate(input_variables=[\"code_description\", \"question\"], template=reduce_template_string)\n",
"\n",
"llm = OpenAI()\n",
"\n",
"map_llm_chain = LLMChain(llm=llm, prompt=MAP_PROMPT)\n",
"reduce_llm_chain = LLMChain(llm=llm, prompt=REDUCE_PROMPT)\n",
"\n",
"generative_result_reduce_chain = StuffDocumentsChain(\n",
" llm_chain=reduce_llm_chain,\n",
" document_variable_name=\"code_description\",\n",
")\n",
"\n",
"combine_documents = MapReduceDocumentsChain(\n",
" llm_chain=map_llm_chain,\n",
" combine_document_chain=generative_result_reduce_chain,\n",
" document_variable_name=\"code\",\n",
")\n",
"\n",
"map_reduce = MapReduceChain(\n",
" combine_documents_chain=combine_documents,\n",
" text_splitter=CharacterTextSplitter(separator=\"\\n##\\n\", chunk_size=100, chunk_overlap=0),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0d4caccb",
"metadata": {},
"outputs": [],
"source": [
"code = \"\"\"\n",
"def bubblesort(list):\n",
" for iter_num in range(len(list)-1,0,-1):\n",
" for idx in range(iter_num):\n",
" if list[idx]>list[idx+1]:\n",
" temp = list[idx]\n",
" list[idx] = list[idx+1]\n",
" list[idx+1] = temp\n",
" return list\n",
"##\n",
"def insertion_sort(InputList):\n",
" for i in range(1, len(InputList)):\n",
" j = i-1\n",
" nxt_element = InputList[i]\n",
" while (InputList[j] > nxt_element) and (j >= 0):\n",
" InputList[j+1] = InputList[j]\n",
" j=j-1\n",
" InputList[j+1] = nxt_element\n",
" return InputList\n",
"##\n",
"def shellSort(input_list):\n",
" gap = len(input_list) // 2\n",
" while gap > 0:\n",
" for i in range(gap, len(input_list)):\n",
" temp = input_list[i]\n",
" j = i\n",
" while j >= gap and input_list[j - gap] > temp:\n",
" input_list[j] = input_list[j - gap]\n",
" j = j-gap\n",
" input_list[j] = temp\n",
" gap = gap//2\n",
" return input_list\n",
"\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d5a9a35b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Created a chunk of size 247, which is longer than the specified 100\n",
"Created a chunk of size 267, which is longer than the specified 100\n"
]
},
{
"data": {
"text/plain": [
"'shellSort has a better time complexity than both bubblesort and insertion_sort, as it has a time complexity of O(n^2), while the other two have a time complexity of O(n^2).'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"map_reduce.run(input_text=code, question=\"Which function has a better time complexity?\")"
]
},
{
"cell_type": "markdown",
"id": "f61350f9",
@ -470,7 +605,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.8.16"
},
"vscode": {
"interpreter": {

@ -5,7 +5,7 @@ then combines the results with another one.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra
@ -38,15 +38,22 @@ class MapReduceChain(Chain):
prompt: BasePromptTemplate,
text_splitter: TextSplitter,
callbacks: Callbacks = None,
combine_chain_kwargs: Optional[Mapping[str, Any]] = None,
reduce_chain_kwargs: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> MapReduceChain:
"""Construct a map-reduce chain that uses the chain for map and reduce."""
llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
reduce_chain = StuffDocumentsChain(llm_chain=llm_chain, callbacks=callbacks)
reduce_chain = StuffDocumentsChain(
llm_chain=llm_chain,
callbacks=callbacks,
**(reduce_chain_kwargs if reduce_chain_kwargs else {}),
)
combine_documents_chain = MapReduceDocumentsChain(
llm_chain=llm_chain,
combine_document_chain=reduce_chain,
callbacks=callbacks,
**(combine_chain_kwargs if combine_chain_kwargs else {}),
)
return cls(
combine_documents_chain=combine_documents_chain,
@ -84,9 +91,14 @@ class MapReduceChain(Chain):
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
# Split the larger text into smaller chunks.
texts = self.text_splitter.split_text(inputs[self.input_key])
doc_text = inputs.pop(self.input_key)
texts = self.text_splitter.split_text(doc_text)
docs = [Document(page_content=text) for text in texts]
_inputs: Dict[str, Any] = {
**inputs,
self.combine_documents_chain.input_key: docs,
}
outputs = self.combine_documents_chain.run(
input_documents=docs, callbacks=_run_manager.get_child()
_inputs, callbacks=_run_manager.get_child()
)
return {self.output_key: outputs}

Loading…
Cancel
Save