mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
The Fellowship of the Vectors: New Embeddings Filter using clustering. (#7015)
Continuing with Tolkien inspired series of langchain tools. I bring to you: **The Fellowship of the Vectors**, AKA EmbeddingsClusteringFilter. This document filter uses embeddings to group vectors together into clusters, then allows you to pick an arbitrary number of documents vector based on proximity to the cluster centers. That's a representative sample of the cluster. The original idea is from [Greg Kamradt](https://github.com/gkamradt) from this video (Level4): https://www.youtube.com/watch?v=qaPMdcCqtWk&t=365s I added few tricks to make it a bit more versatile, so you can parametrize what to do with duplicate documents in case of cluster overlap: replace the duplicates with the next closest document or remove it. This allow you to use it as an special kind of redundant filter too. Additionally you can choose 2 diff orders: grouped by cluster or respecting the original retriever scores. In my use case I was using the docs grouped by cluster to run refine chains per cluster to generate summarization over a large corpus of documents. Let me know if you want to change anything! @rlancemartin, @eyurtsev, @hwchase17, --------- Co-authored-by: rlm <pexpresss31@gmail.com>
This commit is contained in:
parent
b489466488
commit
3ce4e46c8c
@ -1,6 +1,7 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "fc0db1bc",
|
"id": "fc0db1bc",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -25,7 +26,7 @@
|
|||||||
"from langchain.vectorstores import Chroma\n",
|
"from langchain.vectorstores import Chroma\n",
|
||||||
"from langchain.embeddings import HuggingFaceEmbeddings\n",
|
"from langchain.embeddings import HuggingFaceEmbeddings\n",
|
||||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||||
"from langchain.document_transformers import EmbeddingsRedundantFilter\n",
|
"from langchain.document_transformers import EmbeddingsRedundantFilter,EmbeddingsClusteringFilter\n",
|
||||||
"from langchain.retrievers.document_compressors import DocumentCompressorPipeline\n",
|
"from langchain.retrievers.document_compressors import DocumentCompressorPipeline\n",
|
||||||
"from langchain.retrievers import ContextualCompressionRetriever\n",
|
"from langchain.retrievers import ContextualCompressionRetriever\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -70,6 +71,7 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "c152339d",
|
"id": "c152339d",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -92,6 +94,46 @@
|
|||||||
" base_compressor=pipeline, base_retriever=lotr\n",
|
" base_compressor=pipeline, base_retriever=lotr\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c10022fa",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Pick a representative sample of documents from the merged retrievers."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "b3885482",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# This filter will divide the documents vectors into clusters or \"centers\" of meaning.\n",
|
||||||
|
"# Then it will pick the closest document to that center for the final results.\n",
|
||||||
|
"# By default the result document will be ordered/grouped by clusters.\n",
|
||||||
|
"filter_ordered_cluster = EmbeddingsClusteringFilter(\n",
|
||||||
|
" embeddings=filter_embeddings,\n",
|
||||||
|
" num_clusters=10,\n",
|
||||||
|
" num_closest=1,\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"# If you want the final document to be ordered by the original retriever scores\n",
|
||||||
|
"# you need to add the \"sorted\" parameter.\n",
|
||||||
|
"filter_ordered_by_retriever = EmbeddingsClusteringFilter(\n",
|
||||||
|
" embeddings=filter_embeddings,\n",
|
||||||
|
" num_clusters=10,\n",
|
||||||
|
" num_closest=1,\n",
|
||||||
|
" sorted = True,\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"pipeline = DocumentCompressorPipeline(transformers=[filter_ordered_by_retriever])\n",
|
||||||
|
"compression_retriever = ContextualCompressionRetriever(\n",
|
||||||
|
" base_compressor=pipeline, base_retriever=lotr\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@ -71,6 +71,56 @@ def _get_embeddings_from_stateful_docs(
|
|||||||
return embedded_documents
|
return embedded_documents
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_cluster_embeddings(
|
||||||
|
embedded_documents: List[List[float]],
|
||||||
|
num_clusters: int,
|
||||||
|
num_closest: int,
|
||||||
|
random_state: int,
|
||||||
|
remove_duplicates: bool,
|
||||||
|
) -> List[int]:
|
||||||
|
"""Filter documents based on proximity of their embeddings to clusters."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sklearn.cluster import KMeans
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"sklearn package not found, please install it with "
|
||||||
|
"`pip install scikit-learn`"
|
||||||
|
)
|
||||||
|
|
||||||
|
kmeans = KMeans(n_clusters=num_clusters, random_state=random_state).fit(
|
||||||
|
embedded_documents
|
||||||
|
)
|
||||||
|
closest_indices = []
|
||||||
|
|
||||||
|
# Loop through the number of clusters you have
|
||||||
|
for i in range(num_clusters):
|
||||||
|
# Get the list of distances from that particular cluster center
|
||||||
|
distances = np.linalg.norm(
|
||||||
|
embedded_documents - kmeans.cluster_centers_[i], axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find the indices of the two unique closest ones
|
||||||
|
# (using argsort to find the smallest 2 distances)
|
||||||
|
if remove_duplicates:
|
||||||
|
# Only add not duplicated vectors.
|
||||||
|
closest_indices_sorted = [
|
||||||
|
x
|
||||||
|
for x in np.argsort(distances)[:num_closest]
|
||||||
|
if x not in closest_indices
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Skip duplicates and add the next closest vector.
|
||||||
|
closest_indices_sorted = [
|
||||||
|
x for x in np.argsort(distances) if x not in closest_indices
|
||||||
|
][:num_closest]
|
||||||
|
|
||||||
|
# Append that position closest indices list
|
||||||
|
closest_indices.extend(closest_indices_sorted)
|
||||||
|
|
||||||
|
return closest_indices
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
|
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
|
||||||
"""Filter that drops redundant documents by comparing their embeddings."""
|
"""Filter that drops redundant documents by comparing their embeddings."""
|
||||||
|
|
||||||
@ -106,3 +156,63 @@ class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
|
|||||||
self, documents: Sequence[Document], **kwargs: Any
|
self, documents: Sequence[Document], **kwargs: Any
|
||||||
) -> Sequence[Document]:
|
) -> Sequence[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel):
|
||||||
|
"""Perform K-means clustering on document vectors.
|
||||||
|
Returns an arbitrary number of documents closest to center."""
|
||||||
|
|
||||||
|
embeddings: Embeddings
|
||||||
|
"""Embeddings to use for embedding document contents."""
|
||||||
|
|
||||||
|
num_clusters: int = 5
|
||||||
|
"""Number of clusters. Groups of documents with similar meaning."""
|
||||||
|
|
||||||
|
num_closest: int = 1
|
||||||
|
"""The number of closest vectors to return for each cluster center."""
|
||||||
|
|
||||||
|
random_state: int = 42
|
||||||
|
"""Controls the random number generator used to initialize the cluster centroids.
|
||||||
|
If you set the random_state parameter to None, the KMeans algorithm will use a
|
||||||
|
random number generator that is seeded with the current time. This means
|
||||||
|
that the results of the KMeans algorithm will be different each time you
|
||||||
|
run it."""
|
||||||
|
|
||||||
|
sorted: bool = False
|
||||||
|
"""By default results are re-ordered "grouping" them by cluster, if sorted is true
|
||||||
|
result will be ordered by the original position from the retriever"""
|
||||||
|
|
||||||
|
remove_duplicates = False
|
||||||
|
""" By default duplicated results are skipped and replaced by the next closest
|
||||||
|
vector in the cluster. If remove_duplicates is true no replacement will be done:
|
||||||
|
This could dramatically reduce results when there is a lot of overlap beetween
|
||||||
|
clusters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
def transform_documents(
|
||||||
|
self, documents: Sequence[Document], **kwargs: Any
|
||||||
|
) -> Sequence[Document]:
|
||||||
|
"""Filter down documents."""
|
||||||
|
stateful_documents = get_stateful_documents(documents)
|
||||||
|
embedded_documents = _get_embeddings_from_stateful_docs(
|
||||||
|
self.embeddings, stateful_documents
|
||||||
|
)
|
||||||
|
included_idxs = _filter_cluster_embeddings(
|
||||||
|
embedded_documents,
|
||||||
|
self.num_clusters,
|
||||||
|
self.num_closest,
|
||||||
|
self.random_state,
|
||||||
|
self.remove_duplicates,
|
||||||
|
)
|
||||||
|
results = sorted(included_idxs) if self.sorted else included_idxs
|
||||||
|
return [stateful_documents[i] for i in results]
|
||||||
|
|
||||||
|
async def atransform_documents(
|
||||||
|
self, documents: Sequence[Document], **kwargs: Any
|
||||||
|
) -> Sequence[Document]:
|
||||||
|
raise NotImplementedError
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Integration test for embedding-based redundant doc filtering."""
|
"""Integration test for embedding-based redundant doc filtering."""
|
||||||
from langchain.document_transformers import (
|
from langchain.document_transformers import (
|
||||||
|
EmbeddingsClusteringFilter,
|
||||||
EmbeddingsRedundantFilter,
|
EmbeddingsRedundantFilter,
|
||||||
_DocumentWithState,
|
_DocumentWithState,
|
||||||
)
|
)
|
||||||
@ -29,3 +30,42 @@ def test_embeddings_redundant_filter_with_state() -> None:
|
|||||||
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
|
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
|
||||||
actual = redundant_filter.transform_documents(docs)
|
actual = redundant_filter.transform_documents(docs)
|
||||||
assert len(actual) == 1
|
assert len(actual) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_embeddings_clustering_filter() -> None:
|
||||||
|
texts = [
|
||||||
|
"What happened to all of my cookies?",
|
||||||
|
"A cookie is a small, baked sweet treat and you can find it in the cookie",
|
||||||
|
"monsters' jar.",
|
||||||
|
"Cookies are good.",
|
||||||
|
"I have nightmares about the cookie monster.",
|
||||||
|
"The most popular pizza styles are: Neapolitan, New York-style and",
|
||||||
|
"Chicago-style. You can find them on iconic restaurants in major cities.",
|
||||||
|
"Neapolitan pizza: This is the original pizza style,hailing from Naples,",
|
||||||
|
"Italy.",
|
||||||
|
"I wish there were better Italian Pizza restaurants in my neighborhood.",
|
||||||
|
"New York-style pizza: This is characterized by its large, thin crust, and",
|
||||||
|
"generous toppings.",
|
||||||
|
"The first movie to feature a robot was 'A Trip to the Moon' (1902).",
|
||||||
|
"The first movie to feature a robot that could pass for a human was",
|
||||||
|
"'Blade Runner' (1982)",
|
||||||
|
"The first movie to feature a robot that could fall in love with a human",
|
||||||
|
"was 'Her' (2013)",
|
||||||
|
"A robot is a machine capable of carrying out complex actions automatically.",
|
||||||
|
"There are certainly hundreds, if not thousands movies about robots like:",
|
||||||
|
"'Blade Runner', 'Her' and 'A Trip to the Moon'",
|
||||||
|
]
|
||||||
|
|
||||||
|
docs = [Document(page_content=t) for t in texts]
|
||||||
|
embeddings = OpenAIEmbeddings()
|
||||||
|
redundant_filter = EmbeddingsClusteringFilter(
|
||||||
|
embeddings=embeddings,
|
||||||
|
num_clusters=3,
|
||||||
|
num_closest=1,
|
||||||
|
sorted=True,
|
||||||
|
)
|
||||||
|
actual = redundant_filter.transform_documents(docs)
|
||||||
|
assert len(actual) == 3
|
||||||
|
assert texts[1] in [d.page_content for d in actual]
|
||||||
|
assert texts[4] in [d.page_content for d in actual]
|
||||||
|
assert texts[11] in [d.page_content for d in actual]
|
||||||
|
Loading…
Reference in New Issue
Block a user