langchain[minor]: Add stuff docs runnable (#15178)

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/15181/head
Bagatur 6 months ago committed by GitHub
parent 63916cfe35
commit 56fad2e8ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,5 +5,11 @@ from langchain.chains.combine_documents.reduce import (
collapse_docs,
split_list_of_docs,
)
from langchain.chains.combine_documents.stuff import create_stuff_documents_chain
__all__ = ["acollapse_docs", "collapse_docs", "split_list_of_docs"]
__all__ = [
"acollapse_docs",
"collapse_docs",
"split_list_of_docs",
"create_stuff_documents_chain",
]

@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.runnables.config import RunnableConfig
@ -14,6 +15,18 @@ from langchain.callbacks.manager import (
from langchain.chains.base import Chain
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
DEFAULT_DOCUMENT_SEPARATOR = "\n\n"
DOCUMENTS_KEY = "context"
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
def _validate_prompt(prompt: BasePromptTemplate) -> None:
if DOCUMENTS_KEY not in prompt.input_variables:
raise ValueError(
f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt "
f"with input variables: {prompt.input_variables}"
)
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents.

@ -1,21 +1,93 @@
"""Chain that combines documents by stuffing into context."""
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
DOCUMENTS_KEY,
BaseCombineDocumentsChain,
_validate_prompt,
)
from langchain.chains.llm import LLMChain
def _get_default_document_prompt() -> PromptTemplate:
return PromptTemplate(input_variables=["page_content"], template="{page_content}")
def create_stuff_documents_chain(
llm: LanguageModelLike,
prompt: BasePromptTemplate,
*,
output_parser: Optional[BaseOutputParser] = None,
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
) -> Runnable[Dict[str, Any], Any]:
"""Create a chain for passing a list of Documents to a model.
Args:
llm: Language model.
prompt: Prompt template. Must contain input variable "context", which will be
used for passing in the formatted documents.
output_parser: Output parser. Defaults to StrOutputParser.
document_prompt: Prompt used for formatting each document into a string. Input
variables can be "page_content" or any metadata keys that are in all
documents. "page_content" will automatically retrieve the
`Document.page_content`, and all other inputs variables will be
automatically retrieved from the `Document.metadata` dictionary. Default to
a prompt that only contains `Document.page_content`.
document_separator: String separator to use between formatted document strings.
Returns:
An LCEL Runnable. The input is a dictionary that must have a "context" key that
maps to a List[Document], and any other input variables expected in the prompt.
The Runnable return type depends on output_parser used.
Example:
.. code-block:: python
# pip install -U langchain langchain-community
from langchain_community.chat_models import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
prompt = ChatPromptTemplate.from_messages(
[("system", "What are everyone's favorite colors:\n\n{context}")]
)
llm = ChatOpenAI(model_name="gpt-3.5-turbo")
chain = create_stuff_documents_chain(llm, prompt)
docs = [
Document(page_content="Jesse loves red but not yellow"),
Document(page_content = "Jamal loves green but not as much as he loves orange")
]
chain.invoke({"context": docs})
""" # noqa: E501
_validate_prompt(prompt)
_document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
_output_parser = output_parser or StrOutputParser()
def format_docs(inputs: dict) -> str:
return document_separator.join(
format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
)
return (
RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
run_name="format_inputs"
)
| prompt
| llm
| _output_parser
).with_config(run_name="stuff_documents_chain")
class StuffDocumentsChain(BaseCombineDocumentsChain):
@ -60,7 +132,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
"""LLM chain which is called with the formatted document string,
along with any other inputs."""
document_prompt: BasePromptTemplate = Field(
default_factory=_get_default_document_prompt
default_factory=lambda: DEFAULT_DOCUMENT_PROMPT
)
"""Prompt to use to format each document, gets passed to `format_document`."""
document_variable_name: str

Loading…
Cancel
Save