feat: implement max marginal relevance for momento vector index (#13619)

**Description**

Implements `max_marginal_relevance_search` and
`max_marginal_relevance_search_by_vector` for the Momento Vector Index
vectorstore.

Additionally bumps the `momento` dependency in the lock file and adds
logging to the implementation.
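For reference, a minimal usage sketch of the new API. The construction of `vector_store` and the helper name `demo_mmr_search` are illustrative only; see the integration test at the bottom of the diff for the indexing details.

```python
from langchain.vectorstores import MomentoVectorIndex


def demo_mmr_search(vector_store: MomentoVectorIndex) -> None:
    # Index a few short documents.
    vector_store.add_texts(["pepperoni pizza", "cheese pizza", "hot dog"])

    # Plain similarity search: returns the two hits closest to the query.
    similar = vector_store.similarity_search("pizza", k=2)

    # MMR search: fetches `fetch_k` candidates, then re-ranks them so the k
    # results balance query similarity against diversity among each other
    # (lambda_mult=1 -> pure similarity, lambda_mult=0 -> maximum diversity).
    diverse = vector_store.max_marginal_relevance_search(
        "pizza", k=2, fetch_k=10, lambda_mult=0.5
    )
    print(similar)
    print(diverse)
```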

**Dependencies**

Updates the `momento` dependency in the lock file.

**Tag maintainer**

@baskaryan 

**Twitter handle**

Please tag @momentohq for Momento Vector Index and @mloml for the
contribution 🙇

Commit e26906c1dc (parent ee9abb6722)
Author: Michael Landis
Date: 2023-12-04 16:50:23 -08:00 (committed via GitHub)
3 changed files with 120 additions and 13 deletions

Changed file: Momento Vector Index vectorstore module

@@ -1,3 +1,4 @@
+import logging
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -11,15 +12,17 @@ from typing import (
 )
 from uuid import uuid4
 
+import numpy as np
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
 
 from langchain.utils import get_from_env
-from langchain.vectorstores.utils import DistanceStrategy
+from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
 
 VST = TypeVar("VST", bound="VectorStore")
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from momento import PreviewVectorIndexClient
@@ -75,9 +78,8 @@ class MomentoVectorIndex(VectorStore):
         index_name (str, optional): The name of the index to store the documents in.
             Defaults to "default".
         distance_strategy (DistanceStrategy, optional): The distance strategy to
-            use. Defaults to DistanceStrategy.COSINE. If you select
-            DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses the squared
-            Euclidean distance.
+            use. If you select DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses
+            the squared Euclidean distance. Defaults to DistanceStrategy.COSINE.
         text_field (str, optional): The name of the metadata field to store the
             original text in. Defaults to "text".
         ensure_index_exists (bool, optional): Whether to ensure that the index
@@ -125,6 +127,7 @@ class MomentoVectorIndex(VectorStore):
         elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
             similarity_metric = SimilarityMetric.EUCLIDEAN_SIMILARITY
         else:
+            logger.error(f"Distance strategy {self.distance_strategy} not implemented.")
             raise ValueError(
                 f"Distance strategy {self.distance_strategy} not implemented."
             )
@@ -137,8 +140,10 @@ class MomentoVectorIndex(VectorStore):
         elif isinstance(response, CreateIndex.IndexAlreadyExists):
             return False
         elif isinstance(response, CreateIndex.Error):
+            logger.error(f"Error creating index: {response.inner_exception}")
             raise response.inner_exception
         else:
+            logger.error(f"Unexpected response: {response}")
             raise Exception(f"Unexpected response: {response}")
 
     def add_texts(
@@ -331,6 +336,87 @@ class MomentoVectorIndex(VectorStore):
         )
         return [doc for doc, _ in results]
 
+    def max_marginal_relevance_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        from momento.requests.vector_index import ALL_METADATA
+        from momento.responses.vector_index import SearchAndFetchVectors
+
+        response = self._client.search_and_fetch_vectors(
+            self.index_name, embedding, top_k=fetch_k, metadata_fields=ALL_METADATA
+        )
+
+        if isinstance(response, SearchAndFetchVectors.Success):
+            pass
+        elif isinstance(response, SearchAndFetchVectors.Error):
+            logger.error(f"Error searching and fetching vectors: {response}")
+            return []
+        else:
+            logger.error(f"Unexpected response: {response}")
+            raise Exception(f"Unexpected response: {response}")
+
+        mmr_selected = maximal_marginal_relevance(
+            query_embedding=np.array([embedding], dtype=np.float32),
+            embedding_list=[hit.vector for hit in response.hits],
+            lambda_mult=lambda_mult,
+            k=k,
+        )
+        selected = [response.hits[i].metadata for i in mmr_selected]
+        return [
+            Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata)  # type: ignore  # noqa: E501
+            for metadata in selected
+        ]
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        embedding = self._embedding.embed_query(query)
+        return self.max_marginal_relevance_search_by_vector(
+            embedding, k, fetch_k, lambda_mult, **kwargs
+        )
+
     @classmethod
     def from_texts(
         cls: Type[VST],
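The re-ranking itself is delegated to `maximal_marginal_relevance` from `langchain.vectorstores.utils`. For readers unfamiliar with MMR, the sketch below illustrates the standard greedy selection that technique performs; it is a conceptual illustration only (the function name and cosine-similarity details are mine), not the library helper.

```python
from typing import List

import numpy as np


def mmr_select(
    query_embedding: np.ndarray,       # shape (d,)
    candidate_embeddings: np.ndarray,  # shape (n, d)
    k: int = 4,
    lambda_mult: float = 0.5,
) -> List[int]:
    """Greedy MMR: repeatedly pick the candidate maximizing
    lambda_mult * sim(query, doc) - (1 - lambda_mult) * max sim(doc, selected)."""

    def cos(v: np.ndarray, m: np.ndarray) -> np.ndarray:
        # Cosine similarity between one vector and each row of a matrix.
        v = v / (np.linalg.norm(v) + 1e-10)
        m = m / (np.linalg.norm(m, axis=1, keepdims=True) + 1e-10)
        return m @ v

    query_sim = cos(query_embedding, candidate_embeddings)
    selected = [int(np.argmax(query_sim))]
    while len(selected) < min(k, len(candidate_embeddings)):
        best_idx, best_score = -1, -np.inf
        for i in range(len(candidate_embeddings)):
            if i in selected:
                continue
            # Redundancy: highest similarity to anything already selected.
            redundancy = float(
                np.max(cos(candidate_embeddings[i], candidate_embeddings[selected]))
            )
            score = lambda_mult * query_sim[i] - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_idx, best_score = i, score
        selected.append(best_idx)
    return selected
```

With `lambda_mult=0.5`, a near-duplicate of an already-selected document is penalized for redundancy, which is why the integration test below expects "pepperoni pizza" followed by "hot dog" rather than the two pizza texts.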

Changed file: poetry.lock

@@ -3936,7 +3936,6 @@ optional = false
 python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
 files = [
     {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
-    {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
 ]
 
 [[package]]
@@ -4958,29 +4957,29 @@ files = [
 
 [[package]]
 name = "momento"
-version = "1.13.0"
+version = "1.14.1"
 description = "SDK for Momento"
 optional = true
 python-versions = ">=3.7,<4.0"
 files = [
-    {file = "momento-1.13.0-py3-none-any.whl", hash = "sha256:dd5ace5b8d679e882afcefaa16bc413973c270b0a7a1c6c45f3eb60b0b9526de"},
-    {file = "momento-1.13.0.tar.gz", hash = "sha256:39419627542b8f5997a777ff91aa3aaf6406b7d76fb83cd84284a0f7d1f9e356"},
+    {file = "momento-1.14.1-py3-none-any.whl", hash = "sha256:241e46669e39c19627396f2b2b027a912861f1b8097fc9f97b05b76b3d90d199"},
+    {file = "momento-1.14.1.tar.gz", hash = "sha256:d200a5e7463f7746a8a611474af1c245183d7ddf9346d9592760b78b6e801560"},
 ]
 
 [package.dependencies]
 grpcio = ">=1.46.0,<2.0.0"
-momento-wire-types = ">=0.91.1,<0.92.0"
+momento-wire-types = ">=0.96.0,<0.97.0"
 pyjwt = ">=2.4.0,<3.0.0"
 
 [[package]]
 name = "momento-wire-types"
-version = "0.91.4"
+version = "0.96.0"
 description = "Momento Client Proto Generated Files"
 optional = true
 python-versions = ">=3.7,<4.0"
 files = [
-    {file = "momento_wire_types-0.91.4-py3-none-any.whl", hash = "sha256:f296249693de2f6c383a397e7616b84dd83dfd466743d34b035b90865000a2a8"},
-    {file = "momento_wire_types-0.91.4.tar.gz", hash = "sha256:de8cd14a12835d95997eb9b753ea47e1a5d2916658ec9320e416da8bd835fdff"},
+    {file = "momento_wire_types-0.96.0-py3-none-any.whl", hash = "sha256:93dc0e3c31bbe1f664ce33974f235bc20e63b5e35ea8e118f0c5e5ed3cda7709"},
+    {file = "momento_wire_types-0.96.0.tar.gz", hash = "sha256:9c6c839c698741c54b9fc3a4fe0f82094ea5102418b02bb271ed6e64ea6d7d9e"},
 ]
 
 [package.dependencies]

Changed file: Momento Vector Index integration tests

@@ -125,7 +125,7 @@ def test_from_texts_with_metadatas(
 
 
 def test_from_texts_with_scores(vector_store: MomentoVectorIndex) -> None:
-    # """Test end to end construction and search with scores and IDs."""
+    """Test end to end construction and search with scores and IDs."""
     texts = ["apple", "orange", "hammer"]
     metadatas = [{"page": f"{i}"} for i in range(len(texts))]
@@ -162,3 +162,25 @@ def test_add_documents_with_ids(vector_store: MomentoVectorIndex) -> None:
     )
     assert isinstance(response, Search.Success)
     assert [hit.id for hit in response.hits] == ids
+
+
+def test_max_marginal_relevance_search(vector_store: MomentoVectorIndex) -> None:
+    """Test max marginal relevance search."""
+    pepperoni_pizza = "pepperoni pizza"
+    cheese_pizza = "cheese pizza"
+    hot_dog = "hot dog"
+
+    vector_store.add_texts([pepperoni_pizza, cheese_pizza, hot_dog])
+    wait()
+
+    search_results = vector_store.similarity_search("pizza", k=2)
+    assert search_results == [
+        Document(page_content=pepperoni_pizza, metadata={}),
+        Document(page_content=cheese_pizza, metadata={}),
+    ]
+
+    search_results = vector_store.max_marginal_relevance_search(query="pizza", k=2)
+    assert search_results == [
+        Document(page_content=pepperoni_pizza, metadata={}),
+        Document(page_content=hot_dog, metadata={}),
+    ]