|
|
|
@ -2,7 +2,7 @@
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Extra, root_validator
|
|
|
|
|
|
|
|
|
@ -11,6 +11,13 @@ from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.docstore.document import Document
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CombineDocsProtocol(Protocol):
|
|
|
|
|
"""Interface for the combine_docs method."""
|
|
|
|
|
|
|
|
|
|
def __call__(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
|
|
|
|
"""Interface for the combine_docs method."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _split_list_of_docs(
|
|
|
|
|
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
|
|
|
|
|
) -> List[List[Document]]:
|
|
|
|
@ -38,10 +45,10 @@ def _split_list_of_docs(
|
|
|
|
|
|
|
|
|
|
def _collapse_docs(
|
|
|
|
|
docs: List[Document],
|
|
|
|
|
combine_document_func: Callable,
|
|
|
|
|
combine_document_func: CombineDocsProtocol,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> Document:
|
|
|
|
|
result = combine_document_func(docs, **kwargs)
|
|
|
|
|
result, _ = combine_document_func(docs, **kwargs)
|
|
|
|
|
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
|
|
|
|
|
for doc in docs[1:]:
|
|
|
|
|
for k, v in doc.metadata.items():
|
|
|
|
|