diff --git a/docs/examples/chains/qa_with_sources.ipynb b/docs/examples/chains/qa_with_sources.ipynb index 3fe6f2d709..1135e86c1f 100644 --- a/docs/examples/chains/qa_with_sources.ipynb +++ b/docs/examples/chains/qa_with_sources.ipynb @@ -197,17 +197,17 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "f60875c6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'output_text': \"\\n\\nThe president did not mention Justice Breyer in his speech to the European Parliament, which focused on building a coalition of freedom-loving nations to confront Putin, unifying European allies, countering Russia's lies with truth, and enforcing powerful economic sanctions. Source: 2\"}" + "{'output_text': \"\\n\\nThe president did not mention Justice Breyer in his speech to the European Parliament. He discussed the situation in Ukraine, the NATO Alliance, and the United States' response to Putin's attack on Ukraine. He spoke about the extensive preparation and coalition building that was done in advance of the attack, and the unified response from the European Union, Canada, Japan, Korea, Australia, New Zealand, and many other countries. He also discussed the economic sanctions that have been imposed on Russia, and the effects they have had on Putin's war fund. Source: 1, 2\"}" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } diff --git a/docs/examples/chains/question_answering.ipynb b/docs/examples/chains/question_answering.ipynb index 8da082076b..68ac550f4c 100644 --- a/docs/examples/chains/question_answering.ipynb +++ b/docs/examples/chains/question_answering.ipynb @@ -195,7 +195,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "d8b5286e", "metadata": {}, "outputs": [ @@ -205,7 +205,7 @@ "{'output_text': \"\\n\\nThe president did not mention Justice Breyer in his speech to the European Parliament about building a coalition of freedom-loving nations to confront Putin, unifying European allies, countering Russia's lies with truth, and enforcing powerful economic sanctions.\"}" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } diff --git a/docs/examples/chains/summarize.ipynb b/docs/examples/chains/summarize.ipynb index e9aae85920..cd2fd71649 100644 --- a/docs/examples/chains/summarize.ipynb +++ b/docs/examples/chains/summarize.ipynb @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "27989fc4", "metadata": {}, "outputs": [], @@ -100,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "da4d9801", "metadata": {}, "outputs": [ @@ -110,7 +110,7 @@ "' In his speech, President Biden addressed the ongoing conflict between Russia and Ukraine, and the need for the United States and its allies to stand with Ukraine. He also discussed the American Rescue Plan, the Bipartisan Infrastructure Law, and the Bipartisan Innovation Act, which will help to create jobs, modernize infrastructure, and level the playing field with China. He also emphasized the importance of buying American products to support American jobs.'" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -131,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "ef28e1d4", "metadata": {}, "outputs": [], @@ -141,17 +141,17 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "f82c5f9f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "' In response to Russian aggression in Ukraine, the US and its allies have imposed economic sanctions, cut off access to technology, seized assets of Russian oligarchs, and closed American airspace to Russian flights. The US is also providing military, economic, and humanitarian assistance to Ukraine, mobilizing ground forces, air squadrons, and ship deployments, and releasing 30 million barrels of oil from its Strategic Petroleum Reserve. President Biden has also passed the American Rescue Plan, Bipartisan Infrastructure Law, and Bipartisan Innovation Act to provide economic relief and rebuild America.'" + "' In response to Russian aggression in Ukraine, the US and its allies have imposed economic sanctions, cut off access to technology, seized assets of Russian oligarchs, and closed American airspace to Russian flights. The US is also providing military, economic, and humanitarian assistance to Ukraine, mobilizing ground forces, air squadrons, and ship deployments, and releasing 30 million barrels of oil from its Strategic Petroleum Reserve. President Biden has also passed the American Rescue Plan, Bipartisan Infrastructure Law, and Bipartisan Innovation Act to provide economic relief and create jobs.'" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -172,7 +172,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "3bcbe31e", "metadata": {}, "outputs": [], @@ -182,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "c8cad866", "metadata": {}, "outputs": [ @@ -192,7 +192,7 @@ "\"\\nIn this speech, the speaker addresses the American people and their allies, discussing the recent aggression of Russia's Vladimir Putin in Ukraine. The speaker outlines the actions taken by the United States and its allies to hold Putin accountable, including economic sanctions, cutting off access to technology, and seizing the assets of Russian oligarchs. The speaker also announces the closing of American airspace to Russian flights, further isolating Russia and adding an additional squeeze on their economy. The Russian stock market has lost 40% of its value and trading remains suspended. Together with our allies, the United States is providing military, economic, and humanitarian assistance to Ukraine, and has mobilized forces to protect NATO countries. The speaker also announces the release of 60 million barrels of oil from reserves around the world, with the United States releasing 30 million barrels from its own Strategic Petroleum Reserve. The speaker emphasizes that the United States and its allies will defend every inch of NATO territory and that Putin will pay a high price for his aggression. The speaker also acknowledges the hardships faced by the American people due to the pandemic and the American Rescue Plan, which has provided immediate economic relief for tens of millions of Americans, helped put food on their table, keep a roof over their heads, and cut the cost of health insurance. The speaker\"" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } diff --git a/langchain/chains/qa_with_sources/__init__.py b/langchain/chains/qa_with_sources/__init__.py index 5651e60306..159abd7610 100644 --- a/langchain/chains/qa_with_sources/__init__.py +++ b/langchain/chains/qa_with_sources/__init__.py @@ -1,4 +1,6 @@ """Load question answering with sources chains.""" +from typing import Any, Mapping, Protocol + from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.refine import RefineDocumentsChain @@ -6,50 +8,82 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.qa_with_sources import ( map_reduce_prompt, - refine_prompt, + refine_prompts, stuff_prompt, ) from langchain.llms.base import LLM +from langchain.prompts.base import BasePromptTemplate + + +class LoadingCallable(Protocol): + """Interface for loading the combine documents chain.""" + + def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain: + """Callable to load the combine documents chain.""" -def _load_stuff_chain(llm: LLM) -> StuffDocumentsChain: - llm_chain = LLMChain(llm=llm, prompt=stuff_prompt.PROMPT) +def _load_stuff_chain( + llm: LLM, + prompt: BasePromptTemplate = stuff_prompt.PROMPT, + document_variable_name: str = "summaries", + **kwargs: Any, +) -> StuffDocumentsChain: + llm_chain = LLMChain(llm=llm, prompt=prompt) return StuffDocumentsChain( llm_chain=llm_chain, - document_variable_name="summaries", + document_variable_name=document_variable_name, document_prompt=stuff_prompt.EXAMPLE_PROMPT, + **kwargs, ) -def _load_map_reduce_chain(llm: LLM) -> MapReduceDocumentsChain: - map_chain = LLMChain(llm=llm, prompt=map_reduce_prompt.QUESTION_PROMPT) - reduce_chain = LLMChain(llm=llm, prompt=map_reduce_prompt.COMBINE_PROMPT) +def _load_map_reduce_chain( + llm: LLM, + question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT, + combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT, + document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT, + combine_document_variable_name: str = "summaries", + map_reduce_document_variable_name: str = "context", + **kwargs: Any, +) -> MapReduceDocumentsChain: + map_chain = LLMChain(llm=llm, prompt=question_prompt) + reduce_chain = LLMChain(llm=llm, prompt=combine_prompt) combine_document_chain = StuffDocumentsChain( llm_chain=reduce_chain, - document_variable_name="summaries", - document_prompt=map_reduce_prompt.EXAMPLE_PROMPT, + document_variable_name=combine_document_variable_name, + document_prompt=document_prompt, ) return MapReduceDocumentsChain( llm_chain=map_chain, combine_document_chain=combine_document_chain, - document_variable_name="context", + document_variable_name=map_reduce_document_variable_name, + **kwargs, ) -def _load_refine_chain(llm: LLM) -> RefineDocumentsChain: - initial_chain = LLMChain(llm=llm, prompt=refine_prompt.DEFAULT_TEXT_QA_PROMPT) - refine_chain = LLMChain(llm=llm, prompt=refine_prompt.DEFAULT_REFINE_PROMPT) +def _load_refine_chain( + llm: LLM, + question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT, + refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT, + document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT, + document_variable_name: str = "context_str", + initial_response_name: str = "existing_answer", + **kwargs: Any, +) -> RefineDocumentsChain: + initial_chain = LLMChain(llm=llm, prompt=question_prompt) + refine_chain = LLMChain(llm=llm, prompt=refine_prompt) return RefineDocumentsChain( initial_llm_chain=initial_chain, refine_llm_chain=refine_chain, - document_variable_name="context_str", - initial_response_name="existing_answer", - document_prompt=refine_prompt.EXAMPLE_PROMPT, + document_variable_name=document_variable_name, + initial_response_name=initial_response_name, + document_prompt=document_prompt, + **kwargs, ) def load_qa_with_sources_chain( - llm: LLM, chain_type: str = "stuff" + llm: LLM, chain_type: str = "stuff", **kwargs: Any ) -> BaseCombineDocumentsChain: """Load question answering with sources chain. @@ -61,7 +95,7 @@ def load_qa_with_sources_chain( Returns: A chain to use for question answering with sources. """ - loader_mapping = { + loader_mapping: Mapping[str, LoadingCallable] = { "stuff": _load_stuff_chain, "map_reduce": _load_map_reduce_chain, "refine": _load_refine_chain, @@ -71,4 +105,5 @@ def load_qa_with_sources_chain( f"Got unsupported chain type: {chain_type}. " f"Should be one of {loader_mapping.keys()}" ) - return loader_mapping[chain_type](llm) + _func: LoadingCallable = loader_mapping[chain_type] + return _func(llm, **kwargs) diff --git a/langchain/chains/qa_with_sources/refine_prompt.py b/langchain/chains/qa_with_sources/refine_prompts.py similarity index 100% rename from langchain/chains/qa_with_sources/refine_prompt.py rename to langchain/chains/qa_with_sources/refine_prompts.py diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py index 89f698bb42..1591054caa 100644 --- a/langchain/chains/question_answering/__init__.py +++ b/langchain/chains/question_answering/__init__.py @@ -1,4 +1,6 @@ """Load question answering chains.""" +from typing import Any, Mapping, Protocol + from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.refine import RefineDocumentsChain @@ -6,44 +8,77 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.question_answering import ( map_reduce_prompt, - refine_prompt, + refine_prompts, stuff_prompt, ) from langchain.llms.base import LLM +from langchain.prompts.base import BasePromptTemplate + + +class LoadingCallable(Protocol): + """Interface for loading the combine documents chain.""" + def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain: + """Callable to load the combine documents chain.""" -def _load_stuff_chain(llm: LLM) -> StuffDocumentsChain: - llm_chain = LLMChain(llm=llm, prompt=stuff_prompt.PROMPT) + +def _load_stuff_chain( + llm: LLM, + prompt: BasePromptTemplate = stuff_prompt.PROMPT, + document_variable_name: str = "context", + **kwargs: Any, +) -> StuffDocumentsChain: + llm_chain = LLMChain(llm=llm, prompt=prompt) # TODO: document prompt - return StuffDocumentsChain(llm_chain=llm_chain, document_variable_name="context") + return StuffDocumentsChain( + llm_chain=llm_chain, document_variable_name=document_variable_name, **kwargs + ) -def _load_map_reduce_chain(llm: LLM) -> MapReduceDocumentsChain: - map_chain = LLMChain(llm=llm, prompt=map_reduce_prompt.QUESTION_PROMPT) - reduce_chain = LLMChain(llm=llm, prompt=map_reduce_prompt.COMBINE_PROMPT) +def _load_map_reduce_chain( + llm: LLM, + question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT, + combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT, + combine_document_variable_name: str = "summaries", + map_reduce_document_variable_name: str = "context", + **kwargs: Any, +) -> MapReduceDocumentsChain: + map_chain = LLMChain(llm=llm, prompt=question_prompt) + reduce_chain = LLMChain(llm=llm, prompt=combine_prompt) # TODO: document prompt combine_document_chain = StuffDocumentsChain( - llm_chain=reduce_chain, document_variable_name="summaries" + llm_chain=reduce_chain, document_variable_name=combine_document_variable_name ) return MapReduceDocumentsChain( llm_chain=map_chain, combine_document_chain=combine_document_chain, - document_variable_name="context", + document_variable_name=map_reduce_document_variable_name, + **kwargs, ) -def _load_refine_chain(llm: LLM) -> RefineDocumentsChain: - initial_chain = LLMChain(llm=llm, prompt=refine_prompt.DEFAULT_TEXT_QA_PROMPT) - refine_chain = LLMChain(llm=llm, prompt=refine_prompt.DEFAULT_REFINE_PROMPT) +def _load_refine_chain( + llm: LLM, + question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT, + refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT, + document_variable_name: str = "context_str", + initial_response_name: str = "existing_answer", + **kwargs: Any, +) -> RefineDocumentsChain: + initial_chain = LLMChain(llm=llm, prompt=question_prompt) + refine_chain = LLMChain(llm=llm, prompt=refine_prompt) return RefineDocumentsChain( initial_llm_chain=initial_chain, refine_llm_chain=refine_chain, - document_variable_name="context_str", - initial_response_name="existing_answer", + document_variable_name=document_variable_name, + initial_response_name=initial_response_name, + **kwargs, ) -def load_qa_chain(llm: LLM, chain_type: str = "stuff") -> BaseCombineDocumentsChain: +def load_qa_chain( + llm: LLM, chain_type: str = "stuff", **kwargs: Any +) -> BaseCombineDocumentsChain: """Load question answering chain. Args: @@ -54,7 +89,7 @@ def load_qa_chain(llm: LLM, chain_type: str = "stuff") -> BaseCombineDocumentsCh Returns: A chain to use for question answering. """ - loader_mapping = { + loader_mapping: Mapping[str, LoadingCallable] = { "stuff": _load_stuff_chain, "map_reduce": _load_map_reduce_chain, "refine": _load_refine_chain, @@ -64,4 +99,4 @@ def load_qa_chain(llm: LLM, chain_type: str = "stuff") -> BaseCombineDocumentsCh f"Got unsupported chain type: {chain_type}. " f"Should be one of {loader_mapping.keys()}" ) - return loader_mapping[chain_type](llm) + return loader_mapping[chain_type](llm, **kwargs) diff --git a/langchain/chains/question_answering/refine_prompt.py b/langchain/chains/question_answering/refine_prompts.py similarity index 100% rename from langchain/chains/question_answering/refine_prompt.py rename to langchain/chains/question_answering/refine_prompts.py diff --git a/langchain/chains/summarize/__init__.py b/langchain/chains/summarize/__init__.py index a1743fc0e3..ab90c70b57 100644 --- a/langchain/chains/summarize/__init__.py +++ b/langchain/chains/summarize/__init__.py @@ -1,46 +1,79 @@ """Load summarizing chains.""" +from typing import Any, Mapping, Protocol + from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.refine import RefineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain -from langchain.chains.summarize import map_reduce_prompt, refine_prompt, stuff_prompt +from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt from langchain.llms.base import LLM +from langchain.prompts.base import BasePromptTemplate + + +class LoadingCallable(Protocol): + """Interface for loading the combine documents chain.""" + def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain: + """Callable to load the combine documents chain.""" -def _load_stuff_chain(llm: LLM) -> StuffDocumentsChain: - llm_chain = LLMChain(llm=llm, prompt=stuff_prompt.PROMPT) + +def _load_stuff_chain( + llm: LLM, + prompt: BasePromptTemplate = stuff_prompt.PROMPT, + document_variable_name: str = "text", + **kwargs: Any, +) -> StuffDocumentsChain: + llm_chain = LLMChain(llm=llm, prompt=prompt) # TODO: document prompt - return StuffDocumentsChain(llm_chain=llm_chain, document_variable_name="text") + return StuffDocumentsChain( + llm_chain=llm_chain, document_variable_name=document_variable_name, **kwargs + ) -def _load_map_reduce_chain(llm: LLM) -> MapReduceDocumentsChain: - map_chain = LLMChain(llm=llm, prompt=map_reduce_prompt.PROMPT) - reduce_chain = LLMChain(llm=llm, prompt=map_reduce_prompt.PROMPT) +def _load_map_reduce_chain( + llm: LLM, + map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, + combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, + combine_document_variable_name: str = "text", + map_reduce_document_variable_name: str = "text", + **kwargs: Any, +) -> MapReduceDocumentsChain: + map_chain = LLMChain(llm=llm, prompt=map_prompt) + reduce_chain = LLMChain(llm=llm, prompt=combine_prompt) # TODO: document prompt combine_document_chain = StuffDocumentsChain( - llm_chain=reduce_chain, document_variable_name="text" + llm_chain=reduce_chain, document_variable_name=combine_document_variable_name ) return MapReduceDocumentsChain( llm_chain=map_chain, combine_document_chain=combine_document_chain, - document_variable_name="text", + document_variable_name=map_reduce_document_variable_name, + **kwargs, ) -def _load_refine_chain(llm: LLM) -> RefineDocumentsChain: - initial_chain = LLMChain(llm=llm, prompt=refine_prompt.PROMPT) - refine_chain = LLMChain(llm=llm, prompt=refine_prompt.REFINE_PROMPT) +def _load_refine_chain( + llm: LLM, + question_prompt: BasePromptTemplate = refine_prompts.PROMPT, + refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT, + document_variable_name: str = "text", + initial_response_name: str = "existing_answer", + **kwargs: Any, +) -> RefineDocumentsChain: + initial_chain = LLMChain(llm=llm, prompt=question_prompt) + refine_chain = LLMChain(llm=llm, prompt=refine_prompt) return RefineDocumentsChain( initial_llm_chain=initial_chain, refine_llm_chain=refine_chain, - document_variable_name="text", - initial_response_name="existing_answer", + document_variable_name=document_variable_name, + initial_response_name=initial_response_name, + **kwargs, ) def load_summarize_chain( - llm: LLM, chain_type: str = "stuff" + llm: LLM, chain_type: str = "stuff", **kwargs: Any ) -> BaseCombineDocumentsChain: """Load summarizing chain. @@ -52,7 +85,7 @@ def load_summarize_chain( Returns: A chain to use for summarizing. """ - loader_mapping = { + loader_mapping: Mapping[str, LoadingCallable] = { "stuff": _load_stuff_chain, "map_reduce": _load_map_reduce_chain, "refine": _load_refine_chain, @@ -62,4 +95,4 @@ def load_summarize_chain( f"Got unsupported chain type: {chain_type}. " f"Should be one of {loader_mapping.keys()}" ) - return loader_mapping[chain_type](llm) + return loader_mapping[chain_type](llm, **kwargs) diff --git a/langchain/chains/summarize/refine_prompt.py b/langchain/chains/summarize/refine_prompts.py similarity index 100% rename from langchain/chains/summarize/refine_prompt.py rename to langchain/chains/summarize/refine_prompts.py