Remove the overridden async methods that only raised NotImplementedError from the embeddings filters, and add a default async implementation for document compressors (#11415)

@nfcampos @eyurtsev @baskaryan

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
pull/5583/head
Jacob Lee 1 year ago committed by GitHub
parent 2f490be09b
commit 71fd6428c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -152,11 +152,6 @@ class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
)
return [stateful_documents[i] for i in sorted(included_idxs)]
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
raise NotImplementedError
class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel):
"""Perform K-means clustering on document vectors.
@ -211,8 +206,3 @@ class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel):
)
results = sorted(included_idxs) if self.sorted else included_idxs
return [stateful_documents[i] for i in results]
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
raise NotImplementedError

@ -1,3 +1,4 @@
import asyncio
from abc import ABC, abstractmethod
from inspect import signature
from typing import List, Optional, Sequence, Union
@ -19,7 +20,6 @@ class BaseDocumentCompressor(BaseModel, ABC):
) -> Sequence[Document]:
"""Compress retrieved documents given the query context."""
@abstractmethod
async def acompress_documents(
self,
documents: Sequence[Document],
@ -27,6 +27,9 @@ class BaseDocumentCompressor(BaseModel, ABC):
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Compress retrieved documents given the query context."""
return await asyncio.get_running_loop().run_in_executor(
None, self.compress_documents, documents, query, callbacks
)
class DocumentCompressorPipeline(BaseDocumentCompressor):

@ -53,15 +53,6 @@ class LLMChainFilter(BaseDocumentCompressor):
filtered_docs.append(doc)
return filtered_docs
async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Filter down documents."""
raise NotImplementedError()
@classmethod
def from_llm(
cls,

@ -82,11 +82,3 @@ class CohereRerank(BaseDocumentCompressor):
doc.metadata["relevance_score"] = r.relevance_score
final_results.append(doc)
return final_results
async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
raise NotImplementedError()

@ -68,12 +68,3 @@ class EmbeddingsFilter(BaseDocumentCompressor):
)
included_idxs = included_idxs[similar_enough]
return [stateful_documents[i] for i in included_idxs]
async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Filter down documents."""
raise NotImplementedError()

Loading…
Cancel
Save