|
|
|
@ -5,11 +5,14 @@ from copy import deepcopy
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from typing import Any, Iterable, List, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
|
|
|
from langchain_core.runnables import run_in_executor
|
|
|
|
|
from langchain_core.vectorstores import VectorStore
|
|
|
|
|
|
|
|
|
|
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -254,7 +257,12 @@ class Rockset(VectorStore):
|
|
|
|
|
"""Accepts a query_embedding (vector), and returns documents with
|
|
|
|
|
similar embeddings along with their relevance scores."""
|
|
|
|
|
|
|
|
|
|
q_str = self._build_query_sql(embedding, distance_func, k, where_str)
|
|
|
|
|
exclude_embeddings = True
|
|
|
|
|
if "exclude_embeddings" in kwargs:
|
|
|
|
|
exclude_embeddings = kwargs["exclude_embeddings"]
|
|
|
|
|
q_str = self._build_query_sql(
|
|
|
|
|
embedding, distance_func, k, where_str, exclude_embeddings
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
query_response = self._client.Queries.query(sql={"query": q_str})
|
|
|
|
|
except Exception as e:
|
|
|
|
@ -290,6 +298,60 @@ class Rockset(VectorStore):
|
|
|
|
|
)
|
|
|
|
|
return finalResult
|
|
|
|
|
|
|
|
|
|
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    *,
    where_str: Optional[str] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        where_str: where clause for the sql query
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``similarity_search_by_vector`` (for example ``distance_func``,
            how to compute distance between two vectors in Rockset).
            Any ``exclude_embeddings`` value supplied here is ignored,
            because the stored embeddings are required to run MMR.

    Returns:
        List of Documents selected by maximal marginal relevance.
    """
    query_embedding = self._embeddings.embed_query(query)

    # MMR needs the stored vectors, so exclude_embeddings is always forced to
    # False. It is merged into the forwarded kwargs (instead of being passed
    # next to **kwargs) so that a caller who also supplies exclude_embeddings
    # does not trigger "got multiple values for keyword argument".
    search_kwargs = {**kwargs, "exclude_embeddings": False}
    candidates = self.similarity_search_by_vector(
        query_embedding,
        k=fetch_k,
        where_str=where_str,
        **search_kwargs,
    )

    candidate_embeddings = [
        doc.metadata[self._embedding_key] for doc in candidates
    ]

    selected_indices = maximal_marginal_relevance(
        np.array(query_embedding),
        candidate_embeddings,
        lambda_mult=lambda_mult,
        k=k,
    )

    # remove embeddings key before returning for cleanup to be consistent with
    # other search functions
    selected_docs = [candidates[i] for i in selected_indices]
    for doc in selected_docs:
        # pop() with a default keeps the cleanup from raising if the key has
        # already been removed from this document's metadata.
        doc.metadata.pop(self._embedding_key, None)

    return selected_docs
|
|
|
|
|
|
|
|
|
|
# Helper functions
|
|
|
|
|
|
|
|
|
|
def _build_query_sql(
|
|
|
|
@ -298,6 +360,7 @@ class Rockset(VectorStore):
|
|
|
|
|
distance_func: DistanceFunction,
|
|
|
|
|
k: int = 4,
|
|
|
|
|
where_str: Optional[str] = None,
|
|
|
|
|
exclude_embeddings: bool = True,
|
|
|
|
|
) -> str:
|
|
|
|
|
"""Builds Rockset SQL query to query similar vectors to query_vector"""
|
|
|
|
|
|
|
|
|
@ -305,8 +368,11 @@ class Rockset(VectorStore):
|
|
|
|
|
distance_str = f"""{distance_func.value}({self._embedding_key}, \
|
|
|
|
|
[{q_embedding_str}]) as dist"""
|
|
|
|
|
where_str = f"WHERE {where_str}\n" if where_str else ""
|
|
|
|
|
select_embedding = (
|
|
|
|
|
f" EXCEPT({self._embedding_key})," if exclude_embeddings else ","
|
|
|
|
|
)
|
|
|
|
|
return f"""\
|
|
|
|
|
SELECT * EXCEPT({self._embedding_key}), {distance_str}
|
|
|
|
|
SELECT *{select_embedding} {distance_str}
|
|
|
|
|
FROM {self._workspace}.{self._collection_name}
|
|
|
|
|
{where_str}\
|
|
|
|
|
ORDER BY dist {distance_func.order_by()}
|
|
|
|
|