Remove overridden async not implemented method on embeddings filters and add default async implementation for document compressors (#11415)

@nfcampos @eyurtsev @baskaryan

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
pull/5583/head
Jacob Lee 1 year ago committed by GitHub
parent 2f490be09b
commit 71fd6428c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -152,11 +152,6 @@ class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
) )
return [stateful_documents[i] for i in sorted(included_idxs)] return [stateful_documents[i] for i in sorted(included_idxs)]
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
raise NotImplementedError
class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel): class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel):
"""Perform K-means clustering on document vectors. """Perform K-means clustering on document vectors.
@ -211,8 +206,3 @@ class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel):
) )
results = sorted(included_idxs) if self.sorted else included_idxs results = sorted(included_idxs) if self.sorted else included_idxs
return [stateful_documents[i] for i in results] return [stateful_documents[i] for i in results]
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
raise NotImplementedError

@ -1,3 +1,4 @@
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import List, Optional, Sequence, Union from typing import List, Optional, Sequence, Union
@ -19,7 +20,6 @@ class BaseDocumentCompressor(BaseModel, ABC):
) -> Sequence[Document]: ) -> Sequence[Document]:
"""Compress retrieved documents given the query context.""" """Compress retrieved documents given the query context."""
@abstractmethod
async def acompress_documents( async def acompress_documents(
self, self,
documents: Sequence[Document], documents: Sequence[Document],
@ -27,6 +27,9 @@ class BaseDocumentCompressor(BaseModel, ABC):
callbacks: Optional[Callbacks] = None, callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]: ) -> Sequence[Document]:
"""Compress retrieved documents given the query context.""" """Compress retrieved documents given the query context."""
return await asyncio.get_running_loop().run_in_executor(
None, self.compress_documents, documents, query, callbacks
)
class DocumentCompressorPipeline(BaseDocumentCompressor): class DocumentCompressorPipeline(BaseDocumentCompressor):

@ -53,15 +53,6 @@ class LLMChainFilter(BaseDocumentCompressor):
filtered_docs.append(doc) filtered_docs.append(doc)
return filtered_docs return filtered_docs
async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Filter down documents."""
raise NotImplementedError()
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,

@ -82,11 +82,3 @@ class CohereRerank(BaseDocumentCompressor):
doc.metadata["relevance_score"] = r.relevance_score doc.metadata["relevance_score"] = r.relevance_score
final_results.append(doc) final_results.append(doc)
return final_results return final_results
async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
raise NotImplementedError()

@ -68,12 +68,3 @@ class EmbeddingsFilter(BaseDocumentCompressor):
) )
included_idxs = included_idxs[similar_enough] included_idxs = included_idxs[similar_enough]
return [stateful_documents[i] for i in included_idxs] return [stateful_documents[i] for i in included_idxs]
async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Filter down documents."""
raise NotImplementedError()

Loading…
Cancel
Save