mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
community[patch]: mmr search for Rockset vectorstore integration (#16908)
- **Description:** Adding support for mmr search in the Rockset vectorstore integration. - **Issue:** N/A - **Dependencies:** N/A - **Twitter handle:** `@_morgan_adams_` --------- Co-authored-by: Rockset API Bot <admin@rockset.io> Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
f51e6a35ba
commit
074ad5095f
@ -5,11 +5,14 @@ from copy import deepcopy
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Iterable, List, Optional, Tuple
|
from typing import Any, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.runnables import run_in_executor
|
from langchain_core.runnables import run_in_executor
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
|
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -254,7 +257,12 @@ class Rockset(VectorStore):
|
|||||||
"""Accepts a query_embedding (vector), and returns documents with
|
"""Accepts a query_embedding (vector), and returns documents with
|
||||||
similar embeddings along with their relevance scores."""
|
similar embeddings along with their relevance scores."""
|
||||||
|
|
||||||
q_str = self._build_query_sql(embedding, distance_func, k, where_str)
|
exclude_embeddings = True
|
||||||
|
if "exclude_embeddings" in kwargs:
|
||||||
|
exclude_embeddings = kwargs["exclude_embeddings"]
|
||||||
|
q_str = self._build_query_sql(
|
||||||
|
embedding, distance_func, k, where_str, exclude_embeddings
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
query_response = self._client.Queries.query(sql={"query": q_str})
|
query_response = self._client.Queries.query(sql={"query": q_str})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -290,6 +298,60 @@ class Rockset(VectorStore):
|
|||||||
)
|
)
|
||||||
return finalResult
|
return finalResult
|
||||||
|
|
||||||
|
def max_marginal_relevance_search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 4,
|
||||||
|
fetch_k: int = 20,
|
||||||
|
lambda_mult: float = 0.5,
|
||||||
|
*,
|
||||||
|
where_str: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Return docs selected using the maximal marginal relevance.
|
||||||
|
|
||||||
|
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||||
|
among selected documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||||
|
distance_func (DistanceFunction): how to compute distance between two
|
||||||
|
vectors in Rockset.
|
||||||
|
lambda_mult: Number between 0 and 1 that determines the degree
|
||||||
|
of diversity among the results with 0 corresponding
|
||||||
|
to maximum diversity and 1 to minimum diversity.
|
||||||
|
Defaults to 0.5.
|
||||||
|
where_str: where clause for the sql query
|
||||||
|
Returns:
|
||||||
|
List of Documents selected by maximal marginal relevance.
|
||||||
|
"""
|
||||||
|
query_embedding = self._embeddings.embed_query(query)
|
||||||
|
initial_docs = self.similarity_search_by_vector(
|
||||||
|
query_embedding,
|
||||||
|
k=fetch_k,
|
||||||
|
where_str=where_str,
|
||||||
|
exclude_embeddings=False,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = [doc.metadata[self._embedding_key] for doc in initial_docs]
|
||||||
|
|
||||||
|
selected_indices = maximal_marginal_relevance(
|
||||||
|
np.array(query_embedding),
|
||||||
|
embeddings,
|
||||||
|
lambda_mult=lambda_mult,
|
||||||
|
k=k,
|
||||||
|
)
|
||||||
|
|
||||||
|
# remove embeddings key before returning for cleanup to be consistent with
|
||||||
|
# other search functions
|
||||||
|
for i in selected_indices:
|
||||||
|
del initial_docs[i].metadata[self._embedding_key]
|
||||||
|
|
||||||
|
return [initial_docs[i] for i in selected_indices]
|
||||||
|
|
||||||
# Helper functions
|
# Helper functions
|
||||||
|
|
||||||
def _build_query_sql(
|
def _build_query_sql(
|
||||||
@ -298,6 +360,7 @@ class Rockset(VectorStore):
|
|||||||
distance_func: DistanceFunction,
|
distance_func: DistanceFunction,
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
where_str: Optional[str] = None,
|
where_str: Optional[str] = None,
|
||||||
|
exclude_embeddings: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Builds Rockset SQL query to query similar vectors to query_vector"""
|
"""Builds Rockset SQL query to query similar vectors to query_vector"""
|
||||||
|
|
||||||
@ -305,8 +368,11 @@ class Rockset(VectorStore):
|
|||||||
distance_str = f"""{distance_func.value}({self._embedding_key}, \
|
distance_str = f"""{distance_func.value}({self._embedding_key}, \
|
||||||
[{q_embedding_str}]) as dist"""
|
[{q_embedding_str}]) as dist"""
|
||||||
where_str = f"WHERE {where_str}\n" if where_str else ""
|
where_str = f"WHERE {where_str}\n" if where_str else ""
|
||||||
|
select_embedding = (
|
||||||
|
f" EXCEPT({self._embedding_key})," if exclude_embeddings else ","
|
||||||
|
)
|
||||||
return f"""\
|
return f"""\
|
||||||
SELECT * EXCEPT({self._embedding_key}), {distance_str}
|
SELECT *{select_embedding} {distance_str}
|
||||||
FROM {self._workspace}.{self._collection_name}
|
FROM {self._workspace}.{self._collection_name}
|
||||||
{where_str}\
|
{where_str}\
|
||||||
ORDER BY dist {distance_func.order_by()}
|
ORDER BY dist {distance_func.order_by()}
|
||||||
|
@ -96,16 +96,18 @@ class TestRockset:
|
|||||||
client, embeddings, COLLECTION_NAME, TEXT_KEY, EMBEDDING_KEY, WORKSPACE
|
client, embeddings, COLLECTION_NAME, TEXT_KEY, EMBEDDING_KEY, WORKSPACE
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_rockset_insert_and_search(self) -> None:
|
|
||||||
"""Test end to end vector search in Rockset"""
|
|
||||||
|
|
||||||
texts = ["foo", "bar", "baz"]
|
texts = ["foo", "bar", "baz"]
|
||||||
metadatas = [{"metadata_index": i} for i in range(len(texts))]
|
metadatas = [{"metadata_index": i} for i in range(len(texts))]
|
||||||
ids = self.rockset_vectorstore.add_texts(
|
ids = cls.rockset_vectorstore.add_texts(
|
||||||
texts=texts,
|
texts=texts,
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(ids) == len(texts)
|
assert len(ids) == len(texts)
|
||||||
|
|
||||||
|
def test_rockset_search(self) -> None:
|
||||||
|
"""Test end-to-end vector search in Rockset"""
|
||||||
|
|
||||||
# Test that `foo` is closest to `foo`
|
# Test that `foo` is closest to `foo`
|
||||||
output = self.rockset_vectorstore.similarity_search(
|
output = self.rockset_vectorstore.similarity_search(
|
||||||
query="foo", distance_func=Rockset.DistanceFunction.COSINE_SIM, k=1
|
query="foo", distance_func=Rockset.DistanceFunction.COSINE_SIM, k=1
|
||||||
@ -121,6 +123,26 @@ class TestRockset:
|
|||||||
)
|
)
|
||||||
assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
|
assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
|
||||||
|
|
||||||
|
def test_rockset_mmr_search(self) -> None:
|
||||||
|
"""Test end-to-end mmr search in Rockset"""
|
||||||
|
output = self.rockset_vectorstore.max_marginal_relevance_search(
|
||||||
|
query="foo",
|
||||||
|
distance_func=Rockset.DistanceFunction.COSINE_SIM,
|
||||||
|
fetch_k=1,
|
||||||
|
k=1,
|
||||||
|
)
|
||||||
|
assert output == [Document(page_content="foo", metadata={"metadata_index": 0})]
|
||||||
|
|
||||||
|
# Find closest vector to `foo` which is not `foo`
|
||||||
|
output = self.rockset_vectorstore.max_marginal_relevance_search(
|
||||||
|
query="foo",
|
||||||
|
distance_func=Rockset.DistanceFunction.COSINE_SIM,
|
||||||
|
fetch_k=3,
|
||||||
|
k=1,
|
||||||
|
where_str="metadata_index != 0",
|
||||||
|
)
|
||||||
|
assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
|
||||||
|
|
||||||
def test_add_documents_and_delete(self) -> None:
|
def test_add_documents_and_delete(self) -> None:
|
||||||
""" "add_documents" and "delete" are requirements to support use
|
""" "add_documents" and "delete" are requirements to support use
|
||||||
with RecordManager"""
|
with RecordManager"""
|
||||||
@ -184,5 +206,21 @@ FROM {WORKSPACE}.{COLLECTION_NAME}
|
|||||||
WHERE age >= 10
|
WHERE age >= 10
|
||||||
ORDER BY dist DESC
|
ORDER BY dist DESC
|
||||||
LIMIT 4
|
LIMIT 4
|
||||||
|
"""
|
||||||
|
assert q_str == expected
|
||||||
|
|
||||||
|
def test_build_query_sql_with_select_embeddings(self) -> None:
|
||||||
|
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
||||||
|
q_str = self.rockset_vectorstore._build_query_sql(
|
||||||
|
vector, Rockset.DistanceFunction.COSINE_SIM, 4, "age >= 10", False
|
||||||
|
)
|
||||||
|
vector_str = ",".join(map(str, vector))
|
||||||
|
expected = f"""\
|
||||||
|
SELECT *, \
|
||||||
|
COSINE_SIM({EMBEDDING_KEY}, [{vector_str}]) as dist
|
||||||
|
FROM {WORKSPACE}.{COLLECTION_NAME}
|
||||||
|
WHERE age >= 10
|
||||||
|
ORDER BY dist DESC
|
||||||
|
LIMIT 4
|
||||||
"""
|
"""
|
||||||
assert q_str == expected
|
assert q_str == expected
|
||||||
|
Loading…
Reference in New Issue
Block a user