diff --git a/docs/extras/integrations/vectorstores/sklearn.ipynb b/docs/extras/integrations/vectorstores/sklearn.ipynb
index b93c734a74..ea86d68da1 100644
--- a/docs/extras/integrations/vectorstores/sklearn.ipynb
+++ b/docs/extras/integrations/vectorstores/sklearn.ipynb
@@ -13,7 +13,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -56,7 +56,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -65,7 +65,7 @@
     "from langchain.vectorstores import SKLearnVectorStore\n",
     "from langchain.document_loaders import TextLoader\n",
     "\n",
-    "loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
+    "loader = TextLoader(\"../../../extras/modules/state_of_the_union.txt\")\n",
     "documents = loader.load()\n",
     "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
     "docs = text_splitter.split_documents(documents)\n",
@@ -81,7 +81,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
@@ -100,6 +100,7 @@
    ],
    "source": [
     "import tempfile\n",
+    "import os\n",
     "\n",
     "persist_path = os.path.join(tempfile.gettempdir(), \"union.parquet\")\n",
     "\n",
@@ -184,6 +185,32 @@ "print(docs[0].page_content)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Filter"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1\n"
+     ]
+    }
+   ],
+   "source": [
+    "_filter = {\"id\": \"c53e6eac-0070-403c-8435-a9e528539610\"}\n",
+    "docs = vector_store.similarity_search(query, filter=_filter)\n",
+    "print(len(docs))"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -217,7 +244,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.6"
+   "version": "3.10.1"
   }
  },
 "nbformat": 4,
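The notebook's new "Filter" cell exercises the `filter` keyword that the library change below introduces. A minimal self-contained sketch of the same call, with `FakeEmbeddings` standing in for the notebook's real embeddings so it runs offline (the texts and metadata here are illustrative, not taken from the notebook):

```python
from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import SKLearnVectorStore

store = SKLearnVectorStore.from_texts(
    texts=["foo", "foo", "bar"],
    embedding=FakeEmbeddings(size=16),
    metadatas=[{"page": "0"}, {"page": "1"}, {"page": "2"}],
)

# A scalar filter value must match exactly; a list value means "any of these".
docs = store.similarity_search("foo", k=2, fetch_k=3, filter={"page": ["0", "1"]})
print(sorted(d.metadata["page"] for d in docs))  # ['0', '1']
```

Note that `fetch_k` is handed straight to `NearestNeighbors.kneighbors`, which raises when asked for more neighbors than there are stored texts; that is why this sketch (and the updated tests below) pass an explicit small `fetch_k` instead of relying on the larger `DEFAULT_FETCH_K`.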
diff --git a/libs/langchain/langchain/vectorstores/sklearn.py b/libs/langchain/langchain/vectorstores/sklearn.py
index dcc6237c25..69845a2c70 100644
--- a/libs/langchain/langchain/vectorstores/sklearn.py
+++ b/libs/langchain/langchain/vectorstores/sklearn.py
@@ -233,33 +233,66 @@ class SKLearnVectorStore(VectorStore):
         return list(zip(neigh_idxs[0], neigh_dists[0]))
 
     def similarity_search_with_score(
-        self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
+        self,
+        query: str,
+        *,
+        k: int = DEFAULT_K,
+        fetch_k: int = DEFAULT_FETCH_K,
+        filter: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
         query_embedding = self._embedding_function.embed_query(query)
         indices_dists = self._similarity_index_search_with_score(
-            query_embedding, k=k, **kwargs
+            query_embedding, k=fetch_k, **kwargs
         )
-        return [
-            (
+
+        docs: List[Tuple[Document, float]] = []
+        for idx, dist in indices_dists:
+            doc = (
                 Document(
                     page_content=self._texts[idx],
                     metadata={"id": self._ids[idx], **self._metadatas[idx]},
                 ),
                 dist,
             )
-            for idx, dist in indices_dists
-        ]
+
+            if filter is None:
+                docs.append(doc)
+            else:
+                filter = {
+                    key: [value] if not isinstance(value, list) else value
+                    for key, value in filter.items()
+                }
+                if all(
+                    doc[0].metadata.get(key) in value for key, value in filter.items()
+                ):
+                    docs.append(doc)
+        return docs[:k]
 
     def similarity_search(
-        self, query: str, k: int = DEFAULT_K, **kwargs: Any
+        self,
+        query: str,
+        k: int = DEFAULT_K,
+        fetch_k: int = DEFAULT_FETCH_K,
+        filter: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> List[Document]:
-        docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
+        docs_scores = self.similarity_search_with_score(
+            query, k=k, fetch_k=fetch_k, filter=filter, **kwargs
+        )
         return [doc for doc, _ in docs_scores]
 
     def _similarity_search_with_relevance_scores(
-        self, query: str, k: int = DEFAULT_K, **kwargs: Any
+        self,
+        query: str,
+        k: int = DEFAULT_K,
+        fetch_k: int = DEFAULT_FETCH_K,
+        filter: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
-        docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
+        docs_dists = self.similarity_search_with_score(
+            query, k=k, fetch_k=fetch_k, filter=filter, **kwargs
+        )
         docs, dists = zip(*docs_dists)
         scores = [1 / math.exp(dist) for dist in dists]
         return list(zip(list(docs), scores))
@@ -270,6 +303,7 @@ class SKLearnVectorStore(VectorStore):
         k: int = DEFAULT_K,
         fetch_k: int = DEFAULT_FETCH_K,
         lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Return docs selected using the maximal marginal relevance.
@@ -283,6 +317,7 @@ class SKLearnVectorStore(VectorStore):
                 of diversity among the results with 0 corresponding
                 to maximum diversity and 1 to minimum diversity.
                 Defaults to 0.5.
+            filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
         Returns:
             List of Documents selected by maximal marginal relevance.
         """
@@ -294,17 +329,28 @@ class SKLearnVectorStore(VectorStore):
         mmr_selected = maximal_marginal_relevance(
             self._np.array(embedding, dtype=self._np.float32),
             result_embeddings,
-            k=k,
+            k=fetch_k,
             lambda_mult=lambda_mult,
         )
         mmr_indices = [indices[i] for i in mmr_selected]
-        return [
-            Document(
+
+        docs = []
+        for idx in mmr_indices:
+            doc = Document(
                 page_content=self._texts[idx],
                 metadata={"id": self._ids[idx], **self._metadatas[idx]},
             )
-            for idx in mmr_indices
-        ]
+            if filter is None:
+                docs.append(doc)
+            else:
+                filter = {
+                    key: [value] if not isinstance(value, list) else value
+                    for key, value in filter.items()
+                }
+                if all(doc.metadata.get(key) in value for key, value in filter.items()):
+                    docs.append(doc)
+
+        return docs[:k]
 
     def max_marginal_relevance_search(
         self,
@@ -312,6 +358,7 @@ class SKLearnVectorStore(VectorStore):
         k: int = DEFAULT_K,
         fetch_k: int = DEFAULT_FETCH_K,
         lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Return docs selected using the maximal marginal relevance.
@@ -325,6 +372,7 @@ class SKLearnVectorStore(VectorStore):
                 of diversity among the results with 0 corresponding
                 to maximum diversity and 1 to minimum diversity.
                 Defaults to 0.5.
+            filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
         Returns:
             List of Documents selected by maximal marginal relevance.
         """
@@ -335,7 +383,7 @@ class SKLearnVectorStore(VectorStore):
         embedding = self._embedding_function.embed_query(query)
         docs = self.max_marginal_relevance_search_by_vector(
-            embedding, k, fetch_k, lambda_mul=lambda_mult
+            embedding, k, fetch_k, lambda_mult=lambda_mult, filter=filter, **kwargs
         )
         return docs
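The same normalize-then-match rule now appears twice above (once for scored search, once for MMR). Restated as a standalone helper so the semantics are easy to verify in isolation; the helper name is mine, not part of this diff:

```python
from typing import Any, Dict

def matches_filter(metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
    """A document matches when every filter key matches: scalar values
    require equality, list values require membership."""
    normalized = {
        key: value if isinstance(value, list) else [value]
        for key, value in filter.items()
    }
    return all(metadata.get(key) in values for key, values in normalized.items())

assert matches_filter({"page": "1"}, {"page": "1"})
assert matches_filter({"page": "1"}, {"page": ["0", "1"]})
assert not matches_filter({"page": "2"}, {"page": ["0", "1"]})
assert not matches_filter({}, {"page": "1"})  # a missing key never matches
```

Because filtering happens after the index lookup, `fetch_k` sets the candidate pool and `k` caps the post-filter result, so a selective filter can legitimately return fewer than `k` documents.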
""" @@ -335,7 +383,7 @@ class SKLearnVectorStore(VectorStore): embedding = self._embedding_function.embed_query(query) docs = self.max_marginal_relevance_search_by_vector( - embedding, k, fetch_k, lambda_mul=lambda_mult + embedding, k, fetch_k, lambda_mul=lambda_mult, filter=filter, **kwargs ) return docs diff --git a/libs/langchain/tests/unit_tests/vectorstores/test_sklearn.py b/libs/langchain/tests/unit_tests/vectorstores/test_sklearn.py index 36bfca1e02..b14d3b2db6 100644 --- a/libs/langchain/tests/unit_tests/vectorstores/test_sklearn.py +++ b/libs/langchain/tests/unit_tests/vectorstores/test_sklearn.py @@ -12,7 +12,7 @@ def test_sklearn() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] docsearch = SKLearnVectorStore.from_texts(texts, FakeEmbeddings()) - output = docsearch.similarity_search("foo", k=1) + output = docsearch.similarity_search("foo", k=1, fetch_k=3) assert len(output) == 1 assert output[0].page_content == "foo" @@ -27,10 +27,24 @@ def test_sklearn_with_metadatas() -> None: FakeEmbeddings(), metadatas=metadatas, ) - output = docsearch.similarity_search("foo", k=1) + output = docsearch.similarity_search("foo", k=1, fetch_k=3) assert output[0].metadata["page"] == "0" +@pytest.mark.requires("numpy", "sklearn") +def test_sklearn_with_metadatas_and_filter() -> None: + """Test end to end construction and search.""" + texts = ["foo", "foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = SKLearnVectorStore.from_texts( + texts, + FakeEmbeddings(), + metadatas=metadatas, + ) + output = docsearch.similarity_search("foo", k=1, fetch_k=4, filter={"page": "1"}) + assert output[0].metadata["page"] == "1" + + @pytest.mark.requires("numpy", "sklearn") def test_sklearn_with_metadatas_with_scores() -> None: """Test end to end construction and scored search.""" @@ -41,7 +55,7 @@ def test_sklearn_with_metadatas_with_scores() -> None: FakeEmbeddings(), metadatas=metadatas, ) - output = docsearch.similarity_search_with_relevance_scores("foo", k=1) + output = docsearch.similarity_search_with_relevance_scores("foo", k=1, fetch_k=3) assert len(output) == 1 doc, score = output[0] assert doc.page_content == "foo" @@ -61,7 +75,7 @@ def test_sklearn_with_persistence(tmpdir: Path) -> None: serializer="json", ) - output = docsearch.similarity_search("foo", k=1) + output = docsearch.similarity_search("foo", k=1, fetch_k=3) assert len(output) == 1 assert output[0].page_content == "foo" @@ -71,7 +85,7 @@ def test_sklearn_with_persistence(tmpdir: Path) -> None: docsearch = SKLearnVectorStore( FakeEmbeddings(), persist_path=str(persist_path), serializer="json" ) - output = docsearch.similarity_search("foo", k=1) + output = docsearch.similarity_search("foo", k=1, fetch_k=3) assert len(output) == 1 assert output[0].page_content == "foo" @@ -98,3 +112,19 @@ def test_sklearn_mmr_by_vector() -> None: ) assert len(output) == 1 assert output[0].page_content == "foo" + + +@pytest.mark.requires("numpy", "sklearn") +def test_sklearn_mmr_with_metadata_and_filter() -> None: + """Test end to end construction and search.""" + texts = ["foo", "foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = SKLearnVectorStore.from_texts( + texts, FakeEmbeddings(), metadatas=metadatas + ) + output = docsearch.max_marginal_relevance_search( + "foo", k=1, fetch_k=4, filter={"page": "1"} + ) + assert len(output) == 1 + assert output[0].page_content == "foo" + assert output[0].metadata["page"] == "1"