mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Factor out doc formatting and add validation (#3026)
@cnhhoang850 slightly more generic fix for #2944: it works for whatever metadata keys the document prompt expects, not just `source`.
This commit is contained in:
parent
3453b7457c
commit
19c85aa990
@ -7,9 +7,28 @@ from pydantic import Field
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
    """Format a document into a string based on a prompt template.

    Args:
        doc: Document whose ``page_content`` and ``metadata`` supply the
            values substituted into the prompt.
        prompt: Template whose ``input_variables`` name the values that are
            pulled from the document.

    Returns:
        The result of ``prompt.format`` called with exactly the prompt's
        input variables.

    Raises:
        ValueError: If the prompt names a variable that is neither
            ``page_content`` nor a key of ``doc.metadata``.
    """
    # Metadata is merged after page_content, so a metadata key that is itself
    # named "page_content" takes precedence over the document text.
    base_info = {"page_content": doc.page_content, **doc.metadata}

    # Collect missing variables in prompt order (deduplicated) so the error
    # message is deterministic rather than dependent on set iteration order.
    missing_metadata = [
        iv for iv in dict.fromkeys(prompt.input_variables) if iv not in base_info
    ]
    if missing_metadata:
        required_metadata = [
            iv for iv in prompt.input_variables if iv != "page_content"
        ]
        raise ValueError(
            f"Document prompt requires documents to have metadata variables: "
            f"{required_metadata}. Received document with missing metadata: "
            f"{missing_metadata}."
        )

    # Pass only what the prompt asks for; extra metadata keys are ignored.
    document_info = {k: base_info[k] for k in prompt.input_variables}
    return prompt.format(**document_info)
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""Base interface for chains combining documents."""
|
||||
|
||||
|
@ -6,7 +6,10 @@ from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.base import (
|
||||
BaseCombineDocumentsChain,
|
||||
format_document,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
@ -116,14 +119,10 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
return res, extra_return_dict
|
||||
|
||||
def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
|
||||
base_info = {"page_content": doc.page_content}
|
||||
base_info.update(doc.metadata)
|
||||
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
|
||||
base_inputs = {
|
||||
self.document_variable_name: self.document_prompt.format(**document_info),
|
||||
return {
|
||||
self.document_variable_name: format_document(doc, self.document_prompt),
|
||||
self.initial_response_name: res,
|
||||
}
|
||||
return base_inputs
|
||||
|
||||
def _construct_initial_inputs(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
|
@ -4,7 +4,10 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.base import (
|
||||
BaseCombineDocumentsChain,
|
||||
format_document,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
@ -56,17 +59,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
return values
|
||||
|
||||
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
|
||||
# Get relevant information from each document.
|
||||
doc_dicts = []
|
||||
for doc in docs:
|
||||
base_info = {"page_content": doc.page_content}
|
||||
base_info.update(doc.metadata)
|
||||
document_info = {
|
||||
k: base_info[k] for k in self.document_prompt.input_variables
|
||||
}
|
||||
doc_dicts.append(document_info)
|
||||
# Format each document according to the prompt
|
||||
doc_strings = [self.document_prompt.format(**doc) for doc in doc_dicts]
|
||||
doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
|
@ -4,6 +4,8 @@ from typing import Any, List
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.chains.combine_documents.base import format_document
|
||||
from langchain.chains.combine_documents.map_reduce import (
|
||||
_collapse_docs,
|
||||
_split_list_of_docs,
|
||||
@ -116,3 +118,24 @@ def test__collapse_docs_metadata() -> None:
|
||||
}
|
||||
expected_output = Document(page_content="foobar", metadata=expected_metadata)
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_format_doc_with_metadata() -> None:
    """Check that a document carrying every required key is rendered."""
    template = PromptTemplate(
        input_variables=["page_content", "bar"], template="{page_content}, {bar}"
    )
    document = Document(page_content="foo", metadata={"bar": "baz"})
    # Both the page content and the metadata value must appear in the output.
    assert format_document(document, template) == "foo, baz"
|
||||
|
||||
|
||||
def test_format_doc_missing_metadata() -> None:
    """Check that a document lacking a required metadata key is rejected."""
    template = PromptTemplate(
        input_variables=["page_content", "bar"], template="{page_content}, {bar}"
    )
    # No metadata is supplied, so the "bar" variable cannot be resolved.
    bare_document = Document(page_content="foo")
    with pytest.raises(ValueError):
        format_document(bare_document, template)
|
||||
|
Loading…
Reference in New Issue
Block a user