From 19c85aa9907765c0a2dbe7c46e9d5dd2d6df0f30 Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Mon, 17 Apr 2023 20:28:01 -0700 Subject: [PATCH] Factor out doc formatting and add validation (#3026) @cnhhoang850 slightly more generic fix for #2944, works for whatever the expected metadata keys are not just `source` --- langchain/chains/combine_documents/base.py | 19 +++++++++++++++ langchain/chains/combine_documents/refine.py | 13 +++++------ langchain/chains/combine_documents/stuff.py | 16 ++++--------- .../chains/test_combine_documents.py | 23 +++++++++++++++++++ 4 files changed, 53 insertions(+), 18 deletions(-) diff --git a/langchain/chains/combine_documents/base.py b/langchain/chains/combine_documents/base.py index dbf03a438e..ad6c0b6628 100644 --- a/langchain/chains/combine_documents/base.py +++ b/langchain/chains/combine_documents/base.py @@ -7,9 +7,28 @@ from pydantic import Field from langchain.chains.base import Chain from langchain.docstore.document import Document +from langchain.prompts.base import BasePromptTemplate from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter +def format_document(doc: Document, prompt: BasePromptTemplate) -> str: + """Format a document into a string based on a prompt template.""" + base_info = {"page_content": doc.page_content} + base_info.update(doc.metadata) + missing_metadata = set(prompt.input_variables).difference(base_info) + if len(missing_metadata) > 0: + required_metadata = [ + iv for iv in prompt.input_variables if iv != "page_content" + ] + raise ValueError( + f"Document prompt requires documents to have metadata variables: " + f"{required_metadata}. Received document with missing metadata: " + f"{list(missing_metadata)}." + ) + document_info = {k: base_info[k] for k in prompt.input_variables} + return prompt.format(**document_info) + + class BaseCombineDocumentsChain(Chain, ABC): """Base interface for chains combining documents.""" diff --git a/langchain/chains/combine_documents/refine.py b/langchain/chains/combine_documents/refine.py index 6aba3bc792..7d1ae7ff03 100644 --- a/langchain/chains/combine_documents/refine.py +++ b/langchain/chains/combine_documents/refine.py @@ -6,7 +6,10 @@ from typing import Any, Dict, List, Tuple from pydantic import Extra, Field, root_validator -from langchain.chains.combine_documents.base import BaseCombineDocumentsChain +from langchain.chains.combine_documents.base import ( + BaseCombineDocumentsChain, + format_document, +) from langchain.chains.llm import LLMChain from langchain.docstore.document import Document from langchain.prompts.base import BasePromptTemplate @@ -116,14 +119,10 @@ class RefineDocumentsChain(BaseCombineDocumentsChain): return res, extra_return_dict def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]: - base_info = {"page_content": doc.page_content} - base_info.update(doc.metadata) - document_info = {k: base_info[k] for k in self.document_prompt.input_variables} - base_inputs = { - self.document_variable_name: self.document_prompt.format(**document_info), + return { + self.document_variable_name: format_document(doc, self.document_prompt), self.initial_response_name: res, } - return base_inputs def _construct_initial_inputs( self, docs: List[Document], **kwargs: Any diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py index efb5686290..237ecc2d4f 100644 --- a/langchain/chains/combine_documents/stuff.py +++ b/langchain/chains/combine_documents/stuff.py @@ -4,7 +4,10 @@ from typing import Any, Dict, List, Optional, Tuple from pydantic import Extra, Field, root_validator -from langchain.chains.combine_documents.base import BaseCombineDocumentsChain +from langchain.chains.combine_documents.base import ( + BaseCombineDocumentsChain, + format_document, +) from langchain.chains.llm import LLMChain from langchain.docstore.document import Document from langchain.prompts.base import BasePromptTemplate @@ -56,17 +59,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): return values def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: - # Get relevant information from each document. - doc_dicts = [] - for doc in docs: - base_info = {"page_content": doc.page_content} - base_info.update(doc.metadata) - document_info = { - k: base_info[k] for k in self.document_prompt.input_variables - } - doc_dicts.append(document_info) # Format each document according to the prompt - doc_strings = [self.document_prompt.format(**doc) for doc in doc_dicts] + doc_strings = [format_document(doc, self.document_prompt) for doc in docs] # Join the documents together to put them in the prompt. inputs = { k: v diff --git a/tests/unit_tests/chains/test_combine_documents.py b/tests/unit_tests/chains/test_combine_documents.py index 095f216d17..7377250352 100644 --- a/tests/unit_tests/chains/test_combine_documents.py +++ b/tests/unit_tests/chains/test_combine_documents.py @@ -4,6 +4,8 @@ from typing import Any, List import pytest +from langchain import PromptTemplate +from langchain.chains.combine_documents.base import format_document from langchain.chains.combine_documents.map_reduce import ( _collapse_docs, _split_list_of_docs, @@ -116,3 +118,24 @@ def test__collapse_docs_metadata() -> None: } expected_output = Document(page_content="foobar", metadata=expected_metadata) assert output == expected_output + + +def test_format_doc_with_metadata() -> None: + """Test format doc on a valid document.""" + doc = Document(page_content="foo", metadata={"bar": "baz"}) + prompt = PromptTemplate( + input_variables=["page_content", "bar"], template="{page_content}, {bar}" + ) + expected_output = "foo, baz" + output = format_document(doc, prompt) + assert output == expected_output + + +def test_format_doc_missing_metadata() -> None: + """Test format doc on a document with missing metadata.""" + doc = Document(page_content="foo") + prompt = PromptTemplate( + input_variables=["page_content", "bar"], template="{page_content}, {bar}" + ) + with pytest.raises(ValueError): + format_document(doc, prompt)