diff --git a/libs/langchain/langchain/chains/combine_documents/__init__.py b/libs/langchain/langchain/chains/combine_documents/__init__.py index 9c66d93432..6b038ec101 100644 --- a/libs/langchain/langchain/chains/combine_documents/__init__.py +++ b/libs/langchain/langchain/chains/combine_documents/__init__.py @@ -5,5 +5,11 @@ from langchain.chains.combine_documents.reduce import ( collapse_docs, split_list_of_docs, ) +from langchain.chains.combine_documents.stuff import create_stuff_documents_chain -__all__ = ["acollapse_docs", "collapse_docs", "split_list_of_docs"] +__all__ = [ + "acollapse_docs", + "collapse_docs", + "split_list_of_docs", + "create_stuff_documents_chain", +] diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index 9dd964db71..019d9a6e2b 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, Type from langchain_core.documents import Document +from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.pydantic_v1 import BaseModel, Field, create_model from langchain_core.runnables.config import RunnableConfig @@ -14,6 +15,18 @@ from langchain.callbacks.manager import ( from langchain.chains.base import Chain from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter +DEFAULT_DOCUMENT_SEPARATOR = "\n\n" +DOCUMENTS_KEY = "context" +DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}") + + +def _validate_prompt(prompt: BasePromptTemplate) -> None: + if DOCUMENTS_KEY not in prompt.input_variables: + raise ValueError( + f"Prompt must accept {DOCUMENTS_KEY} as an input variable. 
Received prompt " + f"with input variables: {prompt.input_variables}" + ) + class BaseCombineDocumentsChain(Chain, ABC): """Base interface for chains combining documents. diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index a30d4a0e90..ccd228b070 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -1,21 +1,93 @@ """Chain that combines documents by stuffing into context.""" - from typing import Any, Dict, List, Optional, Tuple from langchain_core.documents import Document +from langchain_core.language_models import LanguageModelLike +from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.prompts import BasePromptTemplate, format_document -from langchain_core.prompts.prompt import PromptTemplate from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.runnables import Runnable, RunnablePassthrough from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import ( + DEFAULT_DOCUMENT_PROMPT, + DEFAULT_DOCUMENT_SEPARATOR, + DOCUMENTS_KEY, BaseCombineDocumentsChain, + _validate_prompt, ) from langchain.chains.llm import LLMChain -def _get_default_document_prompt() -> PromptTemplate: - return PromptTemplate(input_variables=["page_content"], template="{page_content}") +def create_stuff_documents_chain( + llm: LanguageModelLike, + prompt: BasePromptTemplate, + *, + output_parser: Optional[BaseOutputParser] = None, + document_prompt: Optional[BasePromptTemplate] = None, + document_separator: str = DEFAULT_DOCUMENT_SEPARATOR, +) -> Runnable[Dict[str, Any], Any]: + """Create a chain for passing a list of Documents to a model. + + Args: + llm: Language model. + prompt: Prompt template. Must contain input variable "context", which will be + used for passing in the formatted documents. 
+ output_parser: Output parser. Defaults to StrOutputParser. + document_prompt: Prompt used for formatting each document into a string. Input + variables can be "page_content" or any metadata keys that are in all + documents. "page_content" will automatically retrieve the + `Document.page_content`, and all other input variables will be + automatically retrieved from the `Document.metadata` dictionary. Defaults to + a prompt that only contains `Document.page_content`. + document_separator: String separator to use between formatted document strings. + + Returns: + An LCEL Runnable. The input is a dictionary that must have a "context" key that + maps to a List[Document], and any other input variables expected in the prompt. + The Runnable return type depends on the output_parser used. + + Example: + .. code-block:: python + + # pip install -U langchain langchain-community + + from langchain_community.chat_models import ChatOpenAI + from langchain_core.documents import Document + from langchain_core.prompts import ChatPromptTemplate + from langchain.chains.combine_documents import create_stuff_documents_chain + + prompt = ChatPromptTemplate.from_messages( + [("system", "What are everyone's favorite colors:\n\n{context}")] + ) + llm = ChatOpenAI(model_name="gpt-3.5-turbo") + chain = create_stuff_documents_chain(llm, prompt) + + docs = [ + Document(page_content="Jesse loves red but not yellow"), + Document(page_content = "Jamal loves green but not as much as he loves orange") + ] + + chain.invoke({"context": docs}) + """ # noqa: E501 + + _validate_prompt(prompt) + _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT + _output_parser = output_parser or StrOutputParser() + + def format_docs(inputs: dict) -> str: + return document_separator.join( + format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY] + ) + + return ( + RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config( + run_name="format_inputs" + ) + | prompt + | llm + | 
_output_parser + ).with_config(run_name="stuff_documents_chain") class StuffDocumentsChain(BaseCombineDocumentsChain): @@ -60,7 +132,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): """LLM chain which is called with the formatted document string, along with any other inputs.""" document_prompt: BasePromptTemplate = Field( - default_factory=_get_default_document_prompt + default_factory=lambda: DEFAULT_DOCUMENT_PROMPT ) """Prompt to use to format each document, gets passed to `format_document`.""" document_variable_name: str