From 71fd6428c52a496043392dc11f3f95bf271b35c2 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Thu, 5 Oct 2023 07:56:03 -0700 Subject: [PATCH] Remove overridden async not implemented method on embeddings filters and add default async implementation for document compressors (#11415) @nfcampos @eyurtsev @baskaryan --------- Co-authored-by: Nuno Campos --- .../embeddings_redundant_filter.py | 10 ---------- .../langchain/retrievers/document_compressors/base.py | 5 ++++- .../retrievers/document_compressors/chain_filter.py | 9 --------- .../retrievers/document_compressors/cohere_rerank.py | 8 -------- .../document_compressors/embeddings_filter.py | 9 --------- 5 files changed, 4 insertions(+), 37 deletions(-) diff --git a/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py b/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py index c492ce6521..1ef8175f4c 100644 --- a/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py +++ b/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py @@ -152,11 +152,6 @@ class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): ) return [stateful_documents[i] for i in sorted(included_idxs)] - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: - raise NotImplementedError - class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel): """Perform K-means clustering on document vectors. @@ -211,8 +206,3 @@ class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel): ) results = sorted(included_idxs) if self.sorted else included_idxs return [stateful_documents[i] for i in results] - - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/document_compressors/base.py b/libs/langchain/langchain/retrievers/document_compressors/base.py index abb5b02c6c..a468f09701 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/base.py +++ b/libs/langchain/langchain/retrievers/document_compressors/base.py @@ -1,3 +1,4 @@ +import asyncio from abc import ABC, abstractmethod from inspect import signature from typing import List, Optional, Sequence, Union @@ -19,7 +20,6 @@ class BaseDocumentCompressor(BaseModel, ABC): ) -> Sequence[Document]: """Compress retrieved documents given the query context.""" - @abstractmethod async def acompress_documents( self, documents: Sequence[Document], @@ -27,6 +27,9 @@ class BaseDocumentCompressor(BaseModel, ABC): callbacks: Optional[Callbacks] = None, ) -> Sequence[Document]: """Compress retrieved documents given the query context.""" + return await asyncio.get_running_loop().run_in_executor( + None, self.compress_documents, documents, query, callbacks + ) class DocumentCompressorPipeline(BaseDocumentCompressor): diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py index 716909fab8..7d507a6e9f 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py @@ -53,15 +53,6 @@ class LLMChainFilter(BaseDocumentCompressor): filtered_docs.append(doc) return filtered_docs - async def acompress_documents( - self, - documents: Sequence[Document], - query: str, - callbacks: Optional[Callbacks] = None, - ) -> Sequence[Document]: - """Filter down documents.""" - raise NotImplementedError() - @classmethod def from_llm( cls, diff --git a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py index 35de84432e..8199fa46dd 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py @@ -82,11 +82,3 @@ class CohereRerank(BaseDocumentCompressor): doc.metadata["relevance_score"] = r.relevance_score final_results.append(doc) return final_results - - async def acompress_documents( - self, - documents: Sequence[Document], - query: str, - callbacks: Optional[Callbacks] = None, - ) -> Sequence[Document]: - raise NotImplementedError() diff --git a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py index 3547b7d643..001b9494a0 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py @@ -68,12 +68,3 @@ class EmbeddingsFilter(BaseDocumentCompressor): ) included_idxs = included_idxs[similar_enough] return [stateful_documents[i] for i in included_idxs] - - async def acompress_documents( - self, - documents: Sequence[Document], - query: str, - callbacks: Optional[Callbacks] = None, - ) -> Sequence[Document]: - """Filter down documents.""" - raise NotImplementedError()