diff --git a/langchain/retrievers/document_compressors/chain_extract.py b/langchain/retrievers/document_compressors/chain_extract.py
index 71b4bc13..9b7947c4 100644
--- a/langchain/retrievers/document_compressors/chain_extract.py
+++ b/langchain/retrievers/document_compressors/chain_extract.py
@@ -1,6 +1,7 @@
 """DocumentFilter that uses an LLM chain to extract the relevant parts of documents."""
 from __future__ import annotations
 
+import asyncio
 from typing import Any, Callable, Dict, Optional, Sequence
 
 from langchain import LLMChain, PromptTemplate
@@ -62,7 +63,23 @@ class LLMChainExtractor(BaseDocumentCompressor):
     async def acompress_documents(
         self, documents: Sequence[Document], query: str
     ) -> Sequence[Document]:
-        raise NotImplementedError
+        """Asynchronously compress the page content of raw documents.
+
+        One ``apredict_and_parse`` call per document is issued on the LLM
+        chain and all of them are awaited together; documents whose parsed
+        output is empty are left out of the returned list.
+        """
+        extraction_calls = [
+            self.llm_chain.apredict_and_parse(**self.get_input(query, document))
+            for document in documents
+        ]
+        extracted = await asyncio.gather(*extraction_calls)
+        # gather returns results in call order, so they pair up with documents.
+        return [
+            Document(page_content=content, metadata=document.metadata)
+            for document, content in zip(documents, extracted)
+            if len(content) != 0
+        ]
 
     @classmethod
     def from_llm(