Add document transformer abstraction (#3182)

Add DocumentTransformer abstraction so that in #2915 we don't have to
wrap TextSplitter and RedundantEmbeddingFilter (neither of which uses
the query) in the contextual doc compression abstractions. with this
change, doc filter (doc extractor, whatever we call it) would look
something like
```python
class BaseDocumentFilter(BaseDocumentTransformer[_RetrievedDocument], ABC):
  
  @abstractmethod
  def filter(self, documents: List[_RetrievedDocument], query: str) -> List[_RetrievedDocument]:
    ...
  
  def transform_documents(self, documents: List[_RetrievedDocument], query: Optional[str] = None, **kwargs: Any) -> List[_RetrievedDocument]:
    if query is None:
      raise ValueError("Must pass in non-null query to DocumentFilter")
    return self.filter(documents, query)
```
This commit is contained in:
Davis Chase 2023-04-19 16:05:05 -07:00 committed by GitHub
parent 74342ab209
commit 10e4b32ecb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 1 deletions

View File

@ -392,3 +392,18 @@ class OutputParserException(Exception):
""" """
pass pass
D = TypeVar("D", bound=Document)
class BaseDocumentTransformer(ABC, Generic[D]):
"""Base interface for transforming documents."""
@abstractmethod
def transform_documents(self, documents: List[D], **kwargs: Any) -> List[D]:
"""Transform a list of documents."""
@abstractmethod
async def atransform_documents(self, documents: List[D], **kwargs: Any) -> List[D]:
"""Asynchronously transform a list of documents."""

View File

@ -17,11 +17,12 @@ from typing import (
) )
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.schema import BaseDocumentTransformer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TextSplitter(ABC): class TextSplitter(BaseDocumentTransformer[Document], ABC):
"""Interface for splitting text into chunks.""" """Interface for splitting text into chunks."""
def __init__( def __init__(
@ -171,6 +172,18 @@ class TextSplitter(ABC):
return cls(length_function=_tiktoken_encoder, **kwargs) return cls(length_function=_tiktoken_encoder, **kwargs)
def transform_documents(
self, documents: List[Document], **kwargs: Any
) -> List[Document]:
"""Transform list of documents by splitting them."""
return self.split_documents(documents)
async def atransform_documents(
self, documents: List[Document], **kwargs: Any
) -> List[Document]:
"""Asynchronously transform a list of documents by splitting them."""
raise NotImplementedError
class CharacterTextSplitter(TextSplitter): class CharacterTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at characters.""" """Implementation of splitting text that looks at characters."""