From 23c22fcbc90232a1792b457384eb848c634eaf18 Mon Sep 17 00:00:00 2001 From: Philippe PRADOS Date: Wed, 12 Jun 2024 18:27:26 +0200 Subject: [PATCH] langchain[minor]: Make EmbeddingsFilters async (#22737) Add native async implementation for EmbeddingsFilter --- .../embeddings_redundant_filter.py | 14 ++++++++ .../document_compressors/test_base.py | 21 +++++++++++ .../test_embeddings_filter.py | 34 ++++++++++++++++++ .../retrievers/test_contextual_compression.py | 23 ++++++++++++ .../document_compressors/embeddings_filter.py | 35 +++++++++++++++++++ 5 files changed, 127 insertions(+) diff --git a/libs/community/langchain_community/document_transformers/embeddings_redundant_filter.py b/libs/community/langchain_community/document_transformers/embeddings_redundant_filter.py index 8e6cc89dc0..9352925e6a 100644 --- a/libs/community/langchain_community/document_transformers/embeddings_redundant_filter.py +++ b/libs/community/langchain_community/document_transformers/embeddings_redundant_filter.py @@ -75,6 +75,20 @@ def _get_embeddings_from_stateful_docs( return embedded_documents +async def _aget_embeddings_from_stateful_docs( + embeddings: Embeddings, documents: Sequence[_DocumentWithState] +) -> List[List[float]]: + if len(documents) and "embedded_doc" in documents[0].state: + embedded_documents = [doc.state["embedded_doc"] for doc in documents] + else: + embedded_documents = await embeddings.aembed_documents( + [d.page_content for d in documents] + ) + for doc, embedding in zip(documents, embedded_documents): + doc.state["embedded_doc"] = embedding + return embedded_documents + + def _filter_cluster_embeddings( embedded_documents: List[List[float]], num_clusters: int, diff --git a/libs/community/tests/integration_tests/retrievers/document_compressors/test_base.py b/libs/community/tests/integration_tests/retrievers/document_compressors/test_base.py index 500e32a392..3326c844dd 100644 --- a/libs/community/tests/integration_tests/retrievers/document_compressors/test_base.py +++ b/libs/community/tests/integration_tests/retrievers/document_compressors/test_base.py @@ -27,3 +27,24 @@ def test_document_compressor_pipeline() -> None: actual = pipeline_filter.compress_documents(docs, "Tell me about farm animals") assert len(actual) == 1 assert actual[0].page_content in texts[:2] + + +async def test_adocument_compressor_pipeline() -> None: + embeddings = OpenAIEmbeddings() + splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0, separator=". ") + redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings) + relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.8) + pipeline_filter = DocumentCompressorPipeline( + transformers=[splitter, redundant_filter, relevant_filter] + ) + texts = [ + "This sentence is about cows", + "This sentence was about cows", + "foo bar baz", + ] + docs = [Document(page_content=". ".join(texts))] + actual = await pipeline_filter.acompress_documents( + docs, "Tell me about farm animals" + ) + assert len(actual) == 1 + assert actual[0].page_content in texts[:2] diff --git a/libs/community/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py b/libs/community/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py index d90f09ff31..1dd0da33a1 100644 --- a/libs/community/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py +++ b/libs/community/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py @@ -23,6 +23,20 @@ def test_embeddings_filter() -> None: assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2 +async def atest_embeddings_filter() -> None: + texts = [ + "What happened to all of my cookies?", + "I wish there were better Italian restaurants in my neighborhood.", + "My favorite color is green", + ] + docs = [Document(page_content=t) for t in texts] + embeddings = OpenAIEmbeddings() + relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75) + actual = relevant_filter.compress_documents(docs, "What did I say about food?") + assert len(actual) == 2 + assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2 + + def test_embeddings_filter_with_state() -> None: texts = [ "What happened to all of my cookies?", @@ -41,3 +55,23 @@ def test_embeddings_filter_with_state() -> None: actual = relevant_filter.compress_documents(docs, query) assert len(actual) == 1 assert texts[-1] == actual[0].page_content + + +async def test_aembeddings_filter_with_state() -> None: + texts = [ + "What happened to all of my cookies?", + "I wish there were better Italian restaurants in my neighborhood.", + "My favorite color is green", + ] + query = "What did I say about food?" + embeddings = OpenAIEmbeddings() + embedded_query = embeddings.embed_query(query) + state = {"embedded_doc": np.zeros(len(embedded_query))} + docs = [_DocumentWithState(page_content=t, state=state) for t in texts] + docs[-1].state = {"embedded_doc": embedded_query} + relevant_filter = EmbeddingsFilter( # type: ignore[call-arg] + embeddings=embeddings, similarity_threshold=0.75, return_similarity_scores=True + ) + actual = relevant_filter.compress_documents(docs, query) + assert len(actual) == 1 + assert texts[-1] == actual[0].page_content diff --git a/libs/community/tests/integration_tests/retrievers/test_contextual_compression.py b/libs/community/tests/integration_tests/retrievers/test_contextual_compression.py index 203cd222d4..fb643a81c2 100644 --- a/libs/community/tests/integration_tests/retrievers/test_contextual_compression.py +++ b/libs/community/tests/integration_tests/retrievers/test_contextual_compression.py @@ -1,3 +1,4 @@ +import pytest from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain.retrievers.document_compressors import EmbeddingsFilter @@ -24,3 +25,25 @@ def test_contextual_compression_retriever_get_relevant_docs() -> None: actual = retriever.invoke("Tell me about the Celtics") assert len(actual) == 2 assert texts[-1] not in [d.page_content for d in actual] + + +@pytest.mark.asyncio +async def test_acontextual_compression_retriever_get_relevant_docs() -> None: + """Test get_relevant_docs.""" + texts = [ + "This is a document about the Boston Celtics", + "The Boston Celtics won the game by 20 points", + "I simply love going to the movies", + ] + embeddings = OpenAIEmbeddings() + base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75) + base_retriever = FAISS.from_texts(texts, embedding=embeddings).as_retriever( + search_kwargs={"k": len(texts)} + ) + retriever = ContextualCompressionRetriever( + base_compressor=base_compressor, base_retriever=base_retriever + ) + + actual = retriever.invoke("Tell me about the Celtics") + assert len(actual) == 2 + assert texts[-1] not in [d.page_content for d in actual] diff --git a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py index 270d315def..12ef05e869 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py @@ -86,3 +86,38 @@ class EmbeddingsFilter(BaseDocumentCompressor): for i in included_idxs: stateful_documents[i].state["query_similarity_score"] = similarity[i] return [stateful_documents[i] for i in included_idxs] + + async def acompress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, + ) -> Sequence[Document]: + """Filter documents based on similarity of their embeddings to the query.""" + try: + from langchain_community.document_transformers.embeddings_redundant_filter import ( # noqa: E501 + _aget_embeddings_from_stateful_docs, + get_stateful_documents, + ) + except ImportError: + raise ImportError( + "To use please install langchain-community " + "with `pip install langchain-community`." + ) + stateful_documents = get_stateful_documents(documents) + embedded_documents = await _aget_embeddings_from_stateful_docs( + self.embeddings, stateful_documents + ) + embedded_query = await self.embeddings.aembed_query(query) + similarity = self.similarity_fn([embedded_query], embedded_documents)[0] + included_idxs = np.arange(len(embedded_documents)) + if self.k is not None: + included_idxs = np.argsort(similarity)[::-1][: self.k] + if self.similarity_threshold is not None: + similar_enough = np.where( + similarity[included_idxs] > self.similarity_threshold + ) + included_idxs = included_idxs[similar_enough] + for i in included_idxs: + stateful_documents[i].state["query_similarity_score"] = similarity[i] + return [stateful_documents[i] for i in included_idxs]