|
|
|
@ -6,7 +6,10 @@ from typing import Any, Dict, List, Tuple
|
|
|
|
|
|
|
|
|
|
from pydantic import Extra, Field, root_validator
|
|
|
|
|
|
|
|
|
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
|
|
|
from langchain.chains.combine_documents.base import (
|
|
|
|
|
BaseCombineDocumentsChain,
|
|
|
|
|
format_document,
|
|
|
|
|
)
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.docstore.document import Document
|
|
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
|
@ -116,14 +119,10 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
|
|
|
|
return res, extra_return_dict
|
|
|
|
|
|
|
|
|
|
def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
|
|
|
|
|
base_info = {"page_content": doc.page_content}
|
|
|
|
|
base_info.update(doc.metadata)
|
|
|
|
|
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
|
|
|
|
|
base_inputs = {
|
|
|
|
|
self.document_variable_name: self.document_prompt.format(**document_info),
|
|
|
|
|
return {
|
|
|
|
|
self.document_variable_name: format_document(doc, self.document_prompt),
|
|
|
|
|
self.initial_response_name: res,
|
|
|
|
|
}
|
|
|
|
|
return base_inputs
|
|
|
|
|
|
|
|
|
|
def _construct_initial_inputs(
|
|
|
|
|
self, docs: List[Document], **kwargs: Any
|
|
|
|
|