From 10e4b32ecbde0a939146ee48130361c5c23dbc37 Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Wed, 19 Apr 2023 16:05:05 -0700 Subject: [PATCH] Add document transformer abstraction (#3182) Add DocumentTransformer abstraction so that in #2915 we don't have to wrap TextSplitter and RedundantEmbeddingFilter (neither of which uses the query) in the contextual doc compression abstractions. with this change, doc filter (doc extractor, whatever we call it) would look something like ```python class BaseDocumentFilter(BaseDocumentTransformer[_RetrievedDocument], ABC): @abstractmethod def filter(self, documents: List[_RetrievedDocument], query: str) -> List[_RetrievedDocument]: ... def transform_documents(self, documents: List[_RetrievedDocument], query: Optional[str] = None, **kwargs: Any) -> List[_RetrievedDocument]: if query is None: raise ValueError("Must pass in non-null query to DocumentFilter") return self.filter(documents, query) ``` --- langchain/schema.py | 15 +++++++++++++++ langchain/text_splitter.py | 15 ++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/langchain/schema.py b/langchain/schema.py index a2b709f138..65f530948c 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -392,3 +392,18 @@ class OutputParserException(Exception): """ pass + + +D = TypeVar("D", bound=Document) + + +class BaseDocumentTransformer(ABC, Generic[D]): + """Base interface for transforming documents.""" + + @abstractmethod + def transform_documents(self, documents: List[D], **kwargs: Any) -> List[D]: + """Transform a list of documents.""" + + @abstractmethod + async def atransform_documents(self, documents: List[D], **kwargs: Any) -> List[D]: + """Asynchronously transform a list of documents.""" diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index f35d3c1eb9..b9c3c24b59 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -17,11 +17,12 @@ from typing import ( ) from langchain.docstore.document import Document +from langchain.schema import BaseDocumentTransformer logger = logging.getLogger(__name__) -class TextSplitter(ABC): +class TextSplitter(BaseDocumentTransformer[Document], ABC): """Interface for splitting text into chunks.""" def __init__( @@ -171,6 +172,18 @@ class TextSplitter(ABC): return cls(length_function=_tiktoken_encoder, **kwargs) + def transform_documents( + self, documents: List[Document], **kwargs: Any + ) -> List[Document]: + """Transform list of documents by splitting them.""" + return self.split_documents(documents) + + async def atransform_documents( + self, documents: List[Document], **kwargs: Any + ) -> List[Document]: + """Asynchronously transform a list of documents by splitting them.""" + raise NotImplementedError + class CharacterTextSplitter(TextSplitter): """Implementation of splitting text that looks at characters."""