mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
feat: implement max marginal relevance for momento vector index (#13619)
**Description** Implements `max_marginal_relevance_search` and `max_marginal_relevance_search_by_vector` for the Momento Vector Index vectorstore. Additionally bumps the `momento` dependency in the lock file and adds logging to the implementation. **Dependencies** ✅ updates `momento` dependency in lock file **Tag maintainer** @baskaryan **Twitter handle** Please tag @momentohq for Momento Vector Index and @mloml for the contribution 🙇 <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
ee9abb6722
commit
e26906c1dc
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -11,15 +12,17 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
from langchain.utils import get_from_env
|
from langchain.utils import get_from_env
|
||||||
from langchain.vectorstores.utils import DistanceStrategy
|
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
|
||||||
|
|
||||||
VST = TypeVar("VST", bound="VectorStore")
|
VST = TypeVar("VST", bound="VectorStore")
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from momento import PreviewVectorIndexClient
|
from momento import PreviewVectorIndexClient
|
||||||
@ -75,9 +78,8 @@ class MomentoVectorIndex(VectorStore):
|
|||||||
index_name (str, optional): The name of the index to store the documents in.
|
index_name (str, optional): The name of the index to store the documents in.
|
||||||
Defaults to "default".
|
Defaults to "default".
|
||||||
distance_strategy (DistanceStrategy, optional): The distance strategy to
|
distance_strategy (DistanceStrategy, optional): The distance strategy to
|
||||||
use. Defaults to DistanceStrategy.COSINE. If you select
|
use. If you select DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses
|
||||||
DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses the squared
|
the squared Euclidean distance. Defaults to DistanceStrategy.COSINE.
|
||||||
Euclidean distance.
|
|
||||||
text_field (str, optional): The name of the metadata field to store the
|
text_field (str, optional): The name of the metadata field to store the
|
||||||
original text in. Defaults to "text".
|
original text in. Defaults to "text".
|
||||||
ensure_index_exists (bool, optional): Whether to ensure that the index
|
ensure_index_exists (bool, optional): Whether to ensure that the index
|
||||||
@ -125,6 +127,7 @@ class MomentoVectorIndex(VectorStore):
|
|||||||
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
|
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
|
||||||
similarity_metric = SimilarityMetric.EUCLIDEAN_SIMILARITY
|
similarity_metric = SimilarityMetric.EUCLIDEAN_SIMILARITY
|
||||||
else:
|
else:
|
||||||
|
logger.error(f"Distance strategy {self.distance_strategy} not implemented.")
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Distance strategy {self.distance_strategy} not implemented."
|
f"Distance strategy {self.distance_strategy} not implemented."
|
||||||
)
|
)
|
||||||
@ -137,8 +140,10 @@ class MomentoVectorIndex(VectorStore):
|
|||||||
elif isinstance(response, CreateIndex.IndexAlreadyExists):
|
elif isinstance(response, CreateIndex.IndexAlreadyExists):
|
||||||
return False
|
return False
|
||||||
elif isinstance(response, CreateIndex.Error):
|
elif isinstance(response, CreateIndex.Error):
|
||||||
|
logger.error(f"Error creating index: {response.inner_exception}")
|
||||||
raise response.inner_exception
|
raise response.inner_exception
|
||||||
else:
|
else:
|
||||||
|
logger.error(f"Unexpected response: {response}")
|
||||||
raise Exception(f"Unexpected response: {response}")
|
raise Exception(f"Unexpected response: {response}")
|
||||||
|
|
||||||
def add_texts(
|
def add_texts(
|
||||||
@ -331,6 +336,87 @@ class MomentoVectorIndex(VectorStore):
|
|||||||
)
|
)
|
||||||
return [doc for doc, _ in results]
|
return [doc for doc, _ in results]
|
||||||
|
|
||||||
|
def max_marginal_relevance_search_by_vector(
|
||||||
|
self,
|
||||||
|
embedding: List[float],
|
||||||
|
k: int = 4,
|
||||||
|
fetch_k: int = 20,
|
||||||
|
lambda_mult: float = 0.5,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Return docs selected using the maximal marginal relevance.
|
||||||
|
|
||||||
|
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||||
|
among selected documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding: Embedding to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||||
|
lambda_mult: Number between 0 and 1 that determines the degree
|
||||||
|
of diversity among the results with 0 corresponding
|
||||||
|
to maximum diversity and 1 to minimum diversity.
|
||||||
|
Defaults to 0.5.
|
||||||
|
Returns:
|
||||||
|
List of Documents selected by maximal marginal relevance.
|
||||||
|
"""
|
||||||
|
from momento.requests.vector_index import ALL_METADATA
|
||||||
|
from momento.responses.vector_index import SearchAndFetchVectors
|
||||||
|
|
||||||
|
response = self._client.search_and_fetch_vectors(
|
||||||
|
self.index_name, embedding, top_k=fetch_k, metadata_fields=ALL_METADATA
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(response, SearchAndFetchVectors.Success):
|
||||||
|
pass
|
||||||
|
elif isinstance(response, SearchAndFetchVectors.Error):
|
||||||
|
logger.error(f"Error searching and fetching vectors: {response}")
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
logger.error(f"Unexpected response: {response}")
|
||||||
|
raise Exception(f"Unexpected response: {response}")
|
||||||
|
|
||||||
|
mmr_selected = maximal_marginal_relevance(
|
||||||
|
query_embedding=np.array([embedding], dtype=np.float32),
|
||||||
|
embedding_list=[hit.vector for hit in response.hits],
|
||||||
|
lambda_mult=lambda_mult,
|
||||||
|
k=k,
|
||||||
|
)
|
||||||
|
selected = [response.hits[i].metadata for i in mmr_selected]
|
||||||
|
return [
|
||||||
|
Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata) # type: ignore # noqa: E501
|
||||||
|
for metadata in selected
|
||||||
|
]
|
||||||
|
|
||||||
|
def max_marginal_relevance_search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 4,
|
||||||
|
fetch_k: int = 20,
|
||||||
|
lambda_mult: float = 0.5,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Return docs selected using the maximal marginal relevance.
|
||||||
|
|
||||||
|
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||||
|
among selected documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||||
|
lambda_mult: Number between 0 and 1 that determines the degree
|
||||||
|
of diversity among the results with 0 corresponding
|
||||||
|
to maximum diversity and 1 to minimum diversity.
|
||||||
|
Defaults to 0.5.
|
||||||
|
Returns:
|
||||||
|
List of Documents selected by maximal marginal relevance.
|
||||||
|
"""
|
||||||
|
embedding = self._embedding.embed_query(query)
|
||||||
|
return self.max_marginal_relevance_search_by_vector(
|
||||||
|
embedding, k, fetch_k, lambda_mult, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls: Type[VST],
|
cls: Type[VST],
|
||||||
|
15
libs/langchain/poetry.lock
generated
15
libs/langchain/poetry.lock
generated
@ -3936,7 +3936,6 @@ optional = false
|
|||||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
|
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
|
||||||
files = [
|
files = [
|
||||||
{file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
|
{file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
|
||||||
{file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -4958,29 +4957,29 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "momento"
|
name = "momento"
|
||||||
version = "1.13.0"
|
version = "1.14.1"
|
||||||
description = "SDK for Momento"
|
description = "SDK for Momento"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7,<4.0"
|
python-versions = ">=3.7,<4.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "momento-1.13.0-py3-none-any.whl", hash = "sha256:dd5ace5b8d679e882afcefaa16bc413973c270b0a7a1c6c45f3eb60b0b9526de"},
|
{file = "momento-1.14.1-py3-none-any.whl", hash = "sha256:241e46669e39c19627396f2b2b027a912861f1b8097fc9f97b05b76b3d90d199"},
|
||||||
{file = "momento-1.13.0.tar.gz", hash = "sha256:39419627542b8f5997a777ff91aa3aaf6406b7d76fb83cd84284a0f7d1f9e356"},
|
{file = "momento-1.14.1.tar.gz", hash = "sha256:d200a5e7463f7746a8a611474af1c245183d7ddf9346d9592760b78b6e801560"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
grpcio = ">=1.46.0,<2.0.0"
|
grpcio = ">=1.46.0,<2.0.0"
|
||||||
momento-wire-types = ">=0.91.1,<0.92.0"
|
momento-wire-types = ">=0.96.0,<0.97.0"
|
||||||
pyjwt = ">=2.4.0,<3.0.0"
|
pyjwt = ">=2.4.0,<3.0.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "momento-wire-types"
|
name = "momento-wire-types"
|
||||||
version = "0.91.4"
|
version = "0.96.0"
|
||||||
description = "Momento Client Proto Generated Files"
|
description = "Momento Client Proto Generated Files"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7,<4.0"
|
python-versions = ">=3.7,<4.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "momento_wire_types-0.91.4-py3-none-any.whl", hash = "sha256:f296249693de2f6c383a397e7616b84dd83dfd466743d34b035b90865000a2a8"},
|
{file = "momento_wire_types-0.96.0-py3-none-any.whl", hash = "sha256:93dc0e3c31bbe1f664ce33974f235bc20e63b5e35ea8e118f0c5e5ed3cda7709"},
|
||||||
{file = "momento_wire_types-0.91.4.tar.gz", hash = "sha256:de8cd14a12835d95997eb9b753ea47e1a5d2916658ec9320e416da8bd835fdff"},
|
{file = "momento_wire_types-0.96.0.tar.gz", hash = "sha256:9c6c839c698741c54b9fc3a4fe0f82094ea5102418b02bb271ed6e64ea6d7d9e"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -125,7 +125,7 @@ def test_from_texts_with_metadatas(
|
|||||||
|
|
||||||
|
|
||||||
def test_from_texts_with_scores(vector_store: MomentoVectorIndex) -> None:
|
def test_from_texts_with_scores(vector_store: MomentoVectorIndex) -> None:
|
||||||
# """Test end to end construction and search with scores and IDs."""
|
"""Test end to end construction and search with scores and IDs."""
|
||||||
texts = ["apple", "orange", "hammer"]
|
texts = ["apple", "orange", "hammer"]
|
||||||
metadatas = [{"page": f"{i}"} for i in range(len(texts))]
|
metadatas = [{"page": f"{i}"} for i in range(len(texts))]
|
||||||
|
|
||||||
@ -162,3 +162,25 @@ def test_add_documents_with_ids(vector_store: MomentoVectorIndex) -> None:
|
|||||||
)
|
)
|
||||||
assert isinstance(response, Search.Success)
|
assert isinstance(response, Search.Success)
|
||||||
assert [hit.id for hit in response.hits] == ids
|
assert [hit.id for hit in response.hits] == ids
|
||||||
|
|
||||||
|
|
||||||
|
def test_max_marginal_relevance_search(vector_store: MomentoVectorIndex) -> None:
|
||||||
|
"""Test max marginal relevance search."""
|
||||||
|
pepperoni_pizza = "pepperoni pizza"
|
||||||
|
cheese_pizza = "cheese pizza"
|
||||||
|
hot_dog = "hot dog"
|
||||||
|
|
||||||
|
vector_store.add_texts([pepperoni_pizza, cheese_pizza, hot_dog])
|
||||||
|
wait()
|
||||||
|
search_results = vector_store.similarity_search("pizza", k=2)
|
||||||
|
|
||||||
|
assert search_results == [
|
||||||
|
Document(page_content=pepperoni_pizza, metadata={}),
|
||||||
|
Document(page_content=cheese_pizza, metadata={}),
|
||||||
|
]
|
||||||
|
|
||||||
|
search_results = vector_store.max_marginal_relevance_search(query="pizza", k=2)
|
||||||
|
assert search_results == [
|
||||||
|
Document(page_content=pepperoni_pizza, metadata={}),
|
||||||
|
Document(page_content=hot_dog, metadata={}),
|
||||||
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user