From bc875a9df16d17db531f9e363c18ed8b5ebbc047 Mon Sep 17 00:00:00 2001 From: Chandan Routray Date: Sun, 4 Jun 2023 03:11:03 +0530 Subject: [PATCH] Fixed multi input prompt for MapReduceChain (#4979) # Fixed multi input prompt for MapReduceChain Added `kwargs` support for inner chains of `MapReduceChain` via `from_params` method Currently the `from_method` method of intialising `MapReduceChain` chain doesn't work if prompt has multiple inputs. It happens because it uses `StuffDocumentsChain` and `MapReduceDocumentsChain` underneath, both of them require specifying `document_variable_name` if `prompt` of their `llm_chain` has more than one `input`. With this PR, I have added support for passing their respective `kwargs` via the `from_params` method. ## Fixes https://github.com/hwchase17/langchain/issues/4752 ## Who can review? @dev2049 @hwchase17 @agola11 --------- Co-authored-by: imeckr --- .../chains/index_examples/summarize.ipynb | 139 +++++++++++++++++- langchain/chains/mapreduce.py | 20 ++- 2 files changed, 153 insertions(+), 6 deletions(-) diff --git a/docs/modules/chains/index_examples/summarize.ipynb b/docs/modules/chains/index_examples/summarize.ipynb index 6b5357c4..429a9fbb 100644 --- a/docs/modules/chains/index_examples/summarize.ipynb +++ b/docs/modules/chains/index_examples/summarize.ipynb @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "e9db25f3", "metadata": {}, "outputs": [], @@ -318,6 +318,141 @@ "chain({\"input_documents\": docs}, return_only_outputs=True)" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "b882e209", + "metadata": {}, + "source": [ + "## The custom `MapReduceChain`\n", + "\n", + "**Multi input prompt**\n", + "\n", + "You can also use prompt with multi input. In this example, we will use a MapReduce chain to answer specifc question about our code." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f7ad9ee2", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain\n", + "from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n", + "\n", + "map_template_string = \"\"\"Give the following python code information, generate a description that explains what the code does and also mention the time complexity.\n", + "Code:\n", + "{code}\n", + "\n", + "Return the the description in the following format:\n", + "name of the function: description of the function\n", + "\"\"\"\n", + "\n", + "\n", + "reduce_template_string = \"\"\"Give the following following python fuctions name and their descritpion, answer the following question\n", + "{code_description}\n", + "Question: {question}\n", + "Answer:\n", + "\"\"\"\n", + "\n", + "MAP_PROMPT = PromptTemplate(input_variables=[\"code\"], template=map_template_string)\n", + "REDUCE_PROMPT = PromptTemplate(input_variables=[\"code_description\", \"question\"], template=reduce_template_string)\n", + "\n", + "llm = OpenAI()\n", + "\n", + "map_llm_chain = LLMChain(llm=llm, prompt=MAP_PROMPT)\n", + "reduce_llm_chain = LLMChain(llm=llm, prompt=REDUCE_PROMPT)\n", + "\n", + "generative_result_reduce_chain = StuffDocumentsChain(\n", + " llm_chain=reduce_llm_chain,\n", + " document_variable_name=\"code_description\",\n", + ")\n", + "\n", + "combine_documents = MapReduceDocumentsChain(\n", + " llm_chain=map_llm_chain,\n", + " combine_document_chain=generative_result_reduce_chain,\n", + " document_variable_name=\"code\",\n", + ")\n", + "\n", + "map_reduce = MapReduceChain(\n", + " combine_documents_chain=combine_documents,\n", + " text_splitter=CharacterTextSplitter(separator=\"\\n##\\n\", chunk_size=100, chunk_overlap=0),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0d4caccb", + "metadata": {}, + "outputs": [], + "source": [ + "code = \"\"\"\n", + "def bubblesort(list):\n", + " for iter_num in range(len(list)-1,0,-1):\n", + " for idx in range(iter_num):\n", + " if list[idx]>list[idx+1]:\n", + " temp = list[idx]\n", + " list[idx] = list[idx+1]\n", + " list[idx+1] = temp\n", + " return list\n", + "##\n", + "def insertion_sort(InputList):\n", + " for i in range(1, len(InputList)):\n", + " j = i-1\n", + " nxt_element = InputList[i]\n", + " while (InputList[j] > nxt_element) and (j >= 0):\n", + " InputList[j+1] = InputList[j]\n", + " j=j-1\n", + " InputList[j+1] = nxt_element\n", + " return InputList\n", + "##\n", + "def shellSort(input_list):\n", + " gap = len(input_list) // 2\n", + " while gap > 0:\n", + " for i in range(gap, len(input_list)):\n", + " temp = input_list[i]\n", + " j = i\n", + " while j >= gap and input_list[j - gap] > temp:\n", + " input_list[j] = input_list[j - gap]\n", + " j = j-gap\n", + " input_list[j] = temp\n", + " gap = gap//2\n", + " return input_list\n", + "\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "d5a9a35b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Created a chunk of size 247, which is longer than the specified 100\n", + "Created a chunk of size 267, which is longer than the specified 100\n" + ] + }, + { + "data": { + "text/plain": [ + "'shellSort has a better time complexity than both bubblesort and insertion_sort, as it has a time complexity of O(n^2), while the other two have a time complexity of O(n^2).'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "map_reduce.run(input_text=code, question=\"Which function has a better time complexity?\")" + ] + }, { "cell_type": "markdown", "id": "f61350f9", @@ -470,7 +605,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.8.16" }, "vscode": { "interpreter": { diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index 768342d1..5474fc79 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -5,7 +5,7 @@ then combines the results with another one. """ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Mapping, Optional from pydantic import Extra @@ -38,15 +38,22 @@ class MapReduceChain(Chain): prompt: BasePromptTemplate, text_splitter: TextSplitter, callbacks: Callbacks = None, + combine_chain_kwargs: Optional[Mapping[str, Any]] = None, + reduce_chain_kwargs: Optional[Mapping[str, Any]] = None, **kwargs: Any, ) -> MapReduceChain: """Construct a map-reduce chain that uses the chain for map and reduce.""" llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks) - reduce_chain = StuffDocumentsChain(llm_chain=llm_chain, callbacks=callbacks) + reduce_chain = StuffDocumentsChain( + llm_chain=llm_chain, + callbacks=callbacks, + **(reduce_chain_kwargs if reduce_chain_kwargs else {}), + ) combine_documents_chain = MapReduceDocumentsChain( llm_chain=llm_chain, combine_document_chain=reduce_chain, callbacks=callbacks, + **(combine_chain_kwargs if combine_chain_kwargs else {}), ) return cls( combine_documents_chain=combine_documents_chain, @@ -84,9 +91,14 @@ class MapReduceChain(Chain): ) -> Dict[str, str]: _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() # Split the larger text into smaller chunks. - texts = self.text_splitter.split_text(inputs[self.input_key]) + doc_text = inputs.pop(self.input_key) + texts = self.text_splitter.split_text(doc_text) docs = [Document(page_content=text) for text in texts] + _inputs: Dict[str, Any] = { + **inputs, + self.combine_documents_chain.input_key: docs, + } outputs = self.combine_documents_chain.run( - input_documents=docs, callbacks=_run_manager.get_child() + _inputs, callbacks=_run_manager.get_child() ) return {self.output_key: outputs}