mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add document transformer abstraction (#3182)
Add DocumentTransformer abstraction so that in #2915 we don't have to wrap TextSplitter and RedundantEmbeddingFilter (neither of which uses the query) in the contextual doc compression abstractions. With this change, doc filter (doc extractor, whatever we call it) would look something like:

```python
class BaseDocumentFilter(BaseDocumentTransformer[_RetrievedDocument], ABC):
    @abstractmethod
    def filter(self, documents: List[_RetrievedDocument], query: str) -> List[_RetrievedDocument]:
        ...

    def transform_documents(self, documents: List[_RetrievedDocument], query: Optional[str] = None, **kwargs: Any) -> List[_RetrievedDocument]:
        if query is None:
            raise ValueError("Must pass in non-null query to DocumentFilter")
        return self.filter(documents, query)
```
This commit is contained in:
parent
74342ab209
commit
10e4b32ecb
@ -392,3 +392,18 @@ class OutputParserException(Exception):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
D = TypeVar("D", bound=Document)
|
||||
|
||||
|
||||
class BaseDocumentTransformer(ABC, Generic[D]):
    """Abstract contract for objects that map a list of documents to a new list.

    The class is generic over ``D`` so that an implementation can declare the
    concrete document type it accepts and returns. Subclasses must provide
    both the synchronous and the asynchronous entry point.
    """

    @abstractmethod
    def transform_documents(self, documents: List[D], **kwargs: Any) -> List[D]:
        """Transform a list of documents.

        Args:
            documents: The documents to transform.
            **kwargs: Implementation-specific options.

        Returns:
            The transformed documents.
        """
        ...

    @abstractmethod
    async def atransform_documents(
        self, documents: List[D], **kwargs: Any
    ) -> List[D]:
        """Asynchronously transform a list of documents.

        Args:
            documents: The documents to transform.
            **kwargs: Implementation-specific options.

        Returns:
            The transformed documents.
        """
        ...
|
||||
|
@ -17,11 +17,12 @@ from typing import (
|
||||
)
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseDocumentTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextSplitter(ABC):
|
||||
class TextSplitter(BaseDocumentTransformer[Document], ABC):
|
||||
"""Interface for splitting text into chunks."""
|
||||
|
||||
def __init__(
|
||||
@ -171,6 +172,18 @@ class TextSplitter(ABC):
|
||||
|
||||
return cls(length_function=_tiktoken_encoder, **kwargs)
|
||||
|
||||
def transform_documents(
    self, documents: List[Document], **kwargs: Any
) -> List[Document]:
    """Transform a list of documents by splitting them.

    Adapter that exposes the splitter through the document-transformer
    interface by delegating to ``split_documents``.

    Args:
        documents: The documents to split.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Whatever ``split_documents`` returns for ``documents``.
    """
    split_result = self.split_documents(documents)
    return split_result
|
||||
|
||||
async def atransform_documents(
    self, documents: List[Document], **kwargs: Any
) -> List[Document]:
    """Asynchronously transform a list of documents by splitting them.

    Runs the synchronous ``transform_documents`` on the running event loop's
    default executor, so awaiting this method yields the same result as the
    synchronous call without blocking the loop. Previously this override
    raised ``NotImplementedError`` unconditionally, which satisfied the
    abstract-method check but failed as soon as the coroutine was awaited.

    Args:
        documents: The documents to split.
        **kwargs: Forwarded unchanged to ``transform_documents``.

    Returns:
        The value returned by ``transform_documents``.
    """
    # Function-scope imports: the file's import header is not fully visible
    # here, so the block brings its own standard-library names into scope.
    import asyncio
    from functools import partial

    loop = asyncio.get_running_loop()
    # run_in_executor only forwards positional arguments, so the keyword
    # arguments are bound up front with partial.
    return await loop.run_in_executor(
        None, partial(self.transform_documents, documents, **kwargs)
    )
|
||||
|
||||
|
||||
class CharacterTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at characters."""
|
||||
|
Loading…
Reference in New Issue
Block a user