Factor out doc formatting and add validation (#3026)

@cnhhoang850 slightly more generic fix for #2944; it works for whatever the
expected metadata keys are, not just `source`.
This commit is contained in:
Davis Chase 2023-04-17 20:28:01 -07:00 committed by GitHub
parent 3453b7457c
commit 19c85aa990
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 53 additions and 18 deletions

View File

@ -7,9 +7,28 @@ from pydantic import Field
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
    """Format a document into a string based on a prompt template.

    The values offered to the prompt are ``page_content`` plus every key of
    ``doc.metadata``; only the keys named in ``prompt.input_variables`` are
    actually passed to ``prompt.format``.

    Args:
        doc: Document supplying ``page_content`` and ``metadata``.
        prompt: Template whose ``input_variables`` select the values to use.

    Returns:
        The result of ``prompt.format`` called with the selected values.

    Raises:
        ValueError: If a prompt input variable is neither ``page_content``
            nor a key of the document metadata.
    """
    # Metadata is merged last, so a "page_content" metadata key takes precedence.
    base_info = {"page_content": doc.page_content}
    base_info.update(doc.metadata)
    # Sorted so the error message is identical from run to run (set order is not).
    missing_metadata = sorted(set(prompt.input_variables).difference(base_info))
    if missing_metadata:
        required_metadata = [
            iv for iv in prompt.input_variables if iv != "page_content"
        ]
        raise ValueError(
            "Document prompt requires documents to have metadata variables: "
            f"{required_metadata}. Received document with missing metadata: "
            f"{missing_metadata}."
        )
    document_info = {k: base_info[k] for k in prompt.input_variables}
    return prompt.format(**document_info)
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents."""

View File

@ -6,7 +6,10 @@ from typing import Any, Dict, List, Tuple
from pydantic import Extra, Field, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
format_document,
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
@ -116,14 +119,10 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
return res, extra_return_dict
def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
base_info = {"page_content": doc.page_content}
base_info.update(doc.metadata)
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
base_inputs = {
self.document_variable_name: self.document_prompt.format(**document_info),
return {
self.document_variable_name: format_document(doc, self.document_prompt),
self.initial_response_name: res,
}
return base_inputs
def _construct_initial_inputs(
self, docs: List[Document], **kwargs: Any

View File

@ -4,7 +4,10 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import Extra, Field, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
format_document,
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
@ -56,17 +59,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return values
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
# Get relevant information from each document.
doc_dicts = []
for doc in docs:
base_info = {"page_content": doc.page_content}
base_info.update(doc.metadata)
document_info = {
k: base_info[k] for k in self.document_prompt.input_variables
}
doc_dicts.append(document_info)
# Format each document according to the prompt
doc_strings = [self.document_prompt.format(**doc) for doc in doc_dicts]
doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
# Join the documents together to put them in the prompt.
inputs = {
k: v

View File

@ -4,6 +4,8 @@ from typing import Any, List
import pytest
from langchain import PromptTemplate
from langchain.chains.combine_documents.base import format_document
from langchain.chains.combine_documents.map_reduce import (
_collapse_docs,
_split_list_of_docs,
@ -116,3 +118,24 @@ def test__collapse_docs_metadata() -> None:
}
expected_output = Document(page_content="foobar", metadata=expected_metadata)
assert output == expected_output
def test_format_doc_with_metadata() -> None:
    """Check that format_document fills a prompt from content and metadata."""
    template = PromptTemplate(
        template="{page_content}, {bar}",
        input_variables=["page_content", "bar"],
    )
    document = Document(page_content="foo", metadata={"bar": "baz"})
    # Both template variables are available, so formatting must succeed.
    assert format_document(document, template) == "foo, baz"
def test_format_doc_missing_metadata() -> None:
    """Check that format_document rejects a document lacking a prompt variable."""
    template = PromptTemplate(
        template="{page_content}, {bar}",
        input_variables=["page_content", "bar"],
    )
    # No metadata is supplied, so the "bar" variable cannot be filled.
    document = Document(page_content="foo")
    with pytest.raises(ValueError):
        format_document(document, template)