diff --git a/langchain/schema.py b/langchain/schema.py index a2b709f1..65f53094 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -392,3 +392,18 @@ class OutputParserException(Exception): """ pass + + +D = TypeVar("D", bound=Document) + + +class BaseDocumentTransformer(ABC, Generic[D]): + """Base interface for transforming documents.""" + + @abstractmethod + def transform_documents(self, documents: List[D], **kwargs: Any) -> List[D]: + """Transform a list of documents.""" + + @abstractmethod + async def atransform_documents(self, documents: List[D], **kwargs: Any) -> List[D]: + """Asynchronously transform a list of documents.""" diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index f35d3c1e..b9c3c24b 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -17,11 +17,12 @@ from typing import ( ) from langchain.docstore.document import Document +from langchain.schema import BaseDocumentTransformer logger = logging.getLogger(__name__) -class TextSplitter(ABC): +class TextSplitter(BaseDocumentTransformer[Document], ABC): """Interface for splitting text into chunks.""" def __init__( @@ -171,6 +172,18 @@ class TextSplitter(ABC): return cls(length_function=_tiktoken_encoder, **kwargs) + def transform_documents( + self, documents: List[Document], **kwargs: Any + ) -> List[Document]: + """Transform list of documents by splitting them.""" + return self.split_documents(documents) + + async def atransform_documents( + self, documents: List[Document], **kwargs: Any + ) -> List[Document]: + """Asynchronously transform a list of documents by splitting them.""" + raise NotImplementedError + class CharacterTextSplitter(TextSplitter): """Implementation of splitting text that looks at characters."""