community[patch]: Fix Hybrid Search for non-Databricks managed embeddings (#25590)

Description: Send both the query and query_embedding to the Databricks
index for hybrid search.

Issue: When using hybrid search with non-Databricks managed embedding we
currently don't pass both the embedding and query_text to the index.
Hybrid search requires both of these. This change fixes this issue for
both `similarity_search` and `similarity_search_by_vector`.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Erik Lindgren 2024-08-23 04:57:13 -04:00 committed by GitHub
parent bcd5842b5d
commit 583b0449eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 99 additions and 6 deletions

View File

@ -341,7 +341,11 @@ class DatabricksVectorSearch(VectorStore):
query_vector = None
else:
assert self.embeddings is not None, "embedding model is required."
query_text = None
# The value for `query_text` needs to be specified only for hybrid search.
if query_type is not None and query_type.upper() == "HYBRID":
query_text = query
else:
query_text = None
query_vector = self.embeddings.embed_query(query)
search_resp = self.index.similarity_search(
columns=self.columns,
@ -487,6 +491,7 @@ class DatabricksVectorSearch(VectorStore):
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
query: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to embedding vector.
@ -505,6 +510,7 @@ class DatabricksVectorSearch(VectorStore):
k=k,
filter=filter,
query_type=query_type,
query=query,
**kwargs,
)
return [doc for doc, _ in docs_with_score]
@ -516,6 +522,7 @@ class DatabricksVectorSearch(VectorStore):
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
query: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector, along with scores.
@ -534,9 +541,25 @@ class DatabricksVectorSearch(VectorStore):
"`similarity_search_by_vector` is not supported for index with "
"Databricks-managed embeddings."
)
if query_type is not None and query_type.upper() == "HYBRID":
if query is None:
raise ValueError(
"A value for `query` must be specified for hybrid search."
)
query_text = query
else:
if query is not None:
raise ValueError(
(
"Cannot specify both `embedding` and "
'`query` unless `query_type="HYBRID"'
)
)
query_text = None
search_resp = self.index.similarity_search(
columns=self.columns,
query_vector=embedding,
query_text=query_text,
filters=filter or _alias_filters(kwargs),
num_results=k,
query_type=query_type,

View File

@ -482,7 +482,7 @@ def test_delete_fail_no_ids() -> None:
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"index_details, query_type", itertools.product(ALL_INDEXES, ALL_QUERY_TYPES)
"index_details, query_type", itertools.product(ALL_INDEXES, [None, "ANN"])
)
def test_similarity_search(index_details: dict, query_type: Optional[str]) -> None:
index = mock_index(index_details)
@ -518,6 +518,42 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize("index_details", ALL_INDEXES)
def test_similarity_search_hybrid(index_details: dict) -> None:
index = mock_index(index_details)
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
vectorsearch = default_databricks_vector_search(index)
query = "foo"
filters = {"some filter": True}
limit = 7
search_result = vectorsearch.similarity_search(
query, k=limit, filter=filters, query_type="HYBRID"
)
if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS:
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
query_text=query,
query_vector=None,
filters=filters,
num_results=limit,
query_type="HYBRID",
)
else:
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
query_text=query,
query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query),
filters=filters,
num_results=limit,
query_type="HYBRID",
)
assert len(search_result) == len(fake_texts)
assert sorted([d.page_content for d in search_result]) == sorted(fake_texts)
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
@pytest.mark.requires("databricks", "databricks.vector_search")
def test_similarity_search_both_filter_and_filters_passed() -> None:
index = mock_index(DIRECT_ACCESS_INDEX)
@ -657,9 +693,14 @@ def test_standard_params() -> None:
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
"index_details, query_type",
itertools.product(
[DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX], [None, "ANN"]
),
)
def test_similarity_search_by_vector(index_details: dict) -> None:
def test_similarity_search_by_vector(
index_details: dict, query_type: Optional[str]
) -> None:
index = mock_index(index_details)
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
vectorsearch = default_databricks_vector_search(index)
@ -668,14 +709,43 @@ def test_similarity_search_by_vector(index_details: dict) -> None:
limit = 7
search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filter=filters
query_embedding, k=limit, filter=filters, query_type=query_type
)
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
query_vector=query_embedding,
filters=filters,
num_results=limit,
query_type=None,
query_type=query_type,
query_text=None,
)
assert len(search_result) == len(fake_texts)
assert sorted([d.page_content for d in search_result]) == sorted(fake_texts)
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
)
def test_similarity_search_by_vector_hybrid(index_details: dict) -> None:
index = mock_index(index_details)
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
vectorsearch = default_databricks_vector_search(index)
query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7
search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo"
)
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
query_vector=query_embedding,
filters=filters,
num_results=limit,
query_type="HYBRID",
query_text="foo",
)
assert len(search_result) == len(fake_texts)
assert sorted([d.page_content for d in search_result]) == sorted(fake_texts)