diff --git a/docs/extras/modules/data_connection/retrievers/integrations/merger_retriever.ipynb b/docs/extras/modules/data_connection/retrievers/integrations/merger_retriever.ipynb
index 14e0b3ab78..98e4fa39fc 100644
--- a/docs/extras/modules/data_connection/retrievers/integrations/merger_retriever.ipynb
+++ b/docs/extras/modules/data_connection/retrievers/integrations/merger_retriever.ipynb
@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "fc0db1bc",
    "metadata": {},
@@ -25,7 +26,7 @@
     "from langchain.vectorstores import Chroma\n",
     "from langchain.embeddings import HuggingFaceEmbeddings\n",
     "from langchain.embeddings import OpenAIEmbeddings\n",
-    "from langchain.document_transformers import EmbeddingsRedundantFilter\n",
+    "from langchain.document_transformers import EmbeddingsRedundantFilter, EmbeddingsClusteringFilter\n",
     "from langchain.retrievers.document_compressors import DocumentCompressorPipeline\n",
     "from langchain.retrievers import ContextualCompressionRetriever\n",
     "\n",
@@ -70,6 +71,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "c152339d",
    "metadata": {},
@@ -92,6 +94,46 @@
     " base_compressor=pipeline, base_retriever=lotr\n",
     ")"
    ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "c10022fa",
+   "metadata": {},
+   "source": [
+    "## Pick a representative sample of documents from the merged retrievers."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b3885482",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# This filter divides the document vectors into clusters, or \"centers\" of meaning,\n",
+    "# and then picks the document closest to each center for the final results.\n",
+    "# By default the resulting documents are ordered/grouped by cluster.\n",
+    "filter_ordered_cluster = EmbeddingsClusteringFilter(\n",
+    "    embeddings=filter_embeddings,\n",
+    "    num_clusters=10,\n",
+    "    num_closest=1,\n",
+    ")\n",
+    "\n",
+    "# If you want the final documents ordered by the original retriever scores,\n",
+    "# set the \"sorted\" parameter to True.\n",
+    "filter_ordered_by_retriever = EmbeddingsClusteringFilter(\n",
+    "    embeddings=filter_embeddings,\n",
+    "    num_clusters=10,\n",
+    "    num_closest=1,\n",
+    "    sorted=True,\n",
+    ")\n",
+    "\n",
+    "pipeline = DocumentCompressorPipeline(transformers=[filter_ordered_by_retriever])\n",
+    "compression_retriever = ContextualCompressionRetriever(\n",
+    "    base_compressor=pipeline, base_retriever=lotr\n",
+    ")"
+   ]
  }
 ],
 "metadata": {
diff --git a/langchain/document_transformers.py b/langchain/document_transformers.py
index f1290e01a7..3bed8e9b55 100644
--- a/langchain/document_transformers.py
+++ b/langchain/document_transformers.py
@@ -71,6 +71,56 @@ def _get_embeddings_from_stateful_docs(
     return embedded_documents
 
 
+def _filter_cluster_embeddings(
+    embedded_documents: List[List[float]],
+    num_clusters: int,
+    num_closest: int,
+    random_state: int,
+    remove_duplicates: bool,
+) -> List[int]:
+    """Filter documents based on proximity of their embeddings to clusters."""
+
+    try:
+        from sklearn.cluster import KMeans
+    except ImportError:
+        raise ValueError(
+            "sklearn package not found, please install it with "
+            "`pip install scikit-learn`"
+        )
+
+    kmeans = KMeans(n_clusters=num_clusters, random_state=random_state).fit(
+        embedded_documents
+    )
+    closest_indices = []
+
+    # Loop over each cluster center.
+    for i in range(num_clusters):
+        # Get the distance from every document embedding to this cluster center.
+        distances = np.linalg.norm(
+            embedded_documents - kmeans.cluster_centers_[i], axis=1
+        )
+
+        # Find the indices of the num_closest unique nearest embeddings
+        # (argsort returns indices ordered by ascending distance).
+        if remove_duplicates:
+            # Only keep vectors that have not been selected already.
+            closest_indices_sorted = [
+                x
+                for x in np.argsort(distances)[:num_closest]
+                if x not in closest_indices
+            ]
+        else:
+            # Skip duplicates and take the next closest vector instead.
+            closest_indices_sorted = [
+                x for x in np.argsort(distances) if x not in closest_indices
+            ][:num_closest]
+
+        # Append this cluster's selections to the final list.
+        closest_indices.extend(closest_indices_sorted)
+
+    return closest_indices
+
+
 class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
     """Filter that drops redundant documents by comparing their embeddings."""
 
@@ -106,3 +156,63 @@ class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
         self, documents: Sequence[Document], **kwargs: Any
     ) -> Sequence[Document]:
         raise NotImplementedError
+
+
+class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel):
+    """Perform K-means clustering on document vectors.
+    Returns the configured number of documents closest to each center."""
+
+    embeddings: Embeddings
+    """Embeddings to use for embedding document contents."""
+
+    num_clusters: int = 5
+    """Number of clusters. Groups of documents with similar meaning."""
+
+    num_closest: int = 1
+    """The number of closest vectors to return for each cluster center."""
+
+    random_state: int = 42
+    """Controls the random number generator used to initialize the cluster
+    centroids. If you set random_state to None, KMeans seeds its random
+    number generator from the current time, so the clustering results
+    will be different each time you run it."""
+
+    sorted: bool = False
+    """By default results are re-ordered, grouping them by cluster. If sorted
+    is True, results are ordered by their original position in the retriever
+    output."""
+
+    remove_duplicates: bool = False
+    """By default duplicated results are skipped and replaced by the next
+    closest vector in the cluster. If remove_duplicates is True, no
+    replacement is done: this can dramatically reduce the number of results
+    when there is a lot of overlap between clusters.
+ """ + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Filter down documents.""" + stateful_documents = get_stateful_documents(documents) + embedded_documents = _get_embeddings_from_stateful_docs( + self.embeddings, stateful_documents + ) + included_idxs = _filter_cluster_embeddings( + embedded_documents, + self.num_clusters, + self.num_closest, + self.random_state, + self.remove_duplicates, + ) + results = sorted(included_idxs) if self.sorted else included_idxs + return [stateful_documents[i] for i in results] + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + raise NotImplementedError diff --git a/tests/integration_tests/test_document_transformers.py b/tests/integration_tests/test_document_transformers.py index d5a23dba38..b92accd648 100644 --- a/tests/integration_tests/test_document_transformers.py +++ b/tests/integration_tests/test_document_transformers.py @@ -1,5 +1,6 @@ """Integration test for embedding-based redundant doc filtering.""" from langchain.document_transformers import ( + EmbeddingsClusteringFilter, EmbeddingsRedundantFilter, _DocumentWithState, ) @@ -29,3 +30,42 @@ def test_embeddings_redundant_filter_with_state() -> None: redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings) actual = redundant_filter.transform_documents(docs) assert len(actual) == 1 + + +def test_embeddings_clustering_filter() -> None: + texts = [ + "What happened to all of my cookies?", + "A cookie is a small, baked sweet treat and you can find it in the cookie", + "monsters' jar.", + "Cookies are good.", + "I have nightmares about the cookie monster.", + "The most popular pizza styles are: Neapolitan, New York-style and", + "Chicago-style. You can find them on iconic restaurants in major cities.", + "Neapolitan pizza: This is the original pizza style,hailing from Naples,", + "Italy.", + "I wish there were better Italian Pizza restaurants in my neighborhood.", + "New York-style pizza: This is characterized by its large, thin crust, and", + "generous toppings.", + "The first movie to feature a robot was 'A Trip to the Moon' (1902).", + "The first movie to feature a robot that could pass for a human was", + "'Blade Runner' (1982)", + "The first movie to feature a robot that could fall in love with a human", + "was 'Her' (2013)", + "A robot is a machine capable of carrying out complex actions automatically.", + "There are certainly hundreds, if not thousands movies about robots like:", + "'Blade Runner', 'Her' and 'A Trip to the Moon'", + ] + + docs = [Document(page_content=t) for t in texts] + embeddings = OpenAIEmbeddings() + redundant_filter = EmbeddingsClusteringFilter( + embeddings=embeddings, + num_clusters=3, + num_closest=1, + sorted=True, + ) + actual = redundant_filter.transform_documents(docs) + assert len(actual) == 3 + assert texts[1] in [d.page_content for d in actual] + assert texts[4] in [d.page_content for d in actual] + assert texts[11] in [d.page_content for d in actual]