mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
community[patch]: Fix Hybrid Search for non-Databricks managed embeddings (#25590)
Description: Send both the query and query_embedding to the Databricks index for hybrid search. Issue: When using hybrid search with non-Databricks-managed embeddings, we currently don't pass both the embedding and query_text to the index, even though hybrid search requires both of them. This change fixes the issue for both `similarity_search` and `similarity_search_by_vector`. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
bcd5842b5d
commit
583b0449eb
@ -341,7 +341,11 @@ class DatabricksVectorSearch(VectorStore):
|
||||
query_vector = None
|
||||
else:
|
||||
assert self.embeddings is not None, "embedding model is required."
|
||||
query_text = None
|
||||
# The value for `query_text` needs to be specified only for hybrid search.
|
||||
if query_type is not None and query_type.upper() == "HYBRID":
|
||||
query_text = query
|
||||
else:
|
||||
query_text = None
|
||||
query_vector = self.embeddings.embed_query(query)
|
||||
search_resp = self.index.similarity_search(
|
||||
columns=self.columns,
|
||||
@ -487,6 +491,7 @@ class DatabricksVectorSearch(VectorStore):
|
||||
filter: Optional[Any] = None,
|
||||
*,
|
||||
query_type: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to embedding vector.
|
||||
@ -505,6 +510,7 @@ class DatabricksVectorSearch(VectorStore):
|
||||
k=k,
|
||||
filter=filter,
|
||||
query_type=query_type,
|
||||
query=query,
|
||||
**kwargs,
|
||||
)
|
||||
return [doc for doc, _ in docs_with_score]
|
||||
@ -516,6 +522,7 @@ class DatabricksVectorSearch(VectorStore):
|
||||
filter: Optional[Any] = None,
|
||||
*,
|
||||
query_type: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to embedding vector, along with scores.
|
||||
@ -534,9 +541,25 @@ class DatabricksVectorSearch(VectorStore):
|
||||
"`similarity_search_by_vector` is not supported for index with "
|
||||
"Databricks-managed embeddings."
|
||||
)
|
||||
if query_type is not None and query_type.upper() == "HYBRID":
|
||||
if query is None:
|
||||
raise ValueError(
|
||||
"A value for `query` must be specified for hybrid search."
|
||||
)
|
||||
query_text = query
|
||||
else:
|
||||
if query is not None:
|
||||
raise ValueError(
|
||||
(
|
||||
"Cannot specify both `embedding` and "
|
||||
'`query` unless `query_type="HYBRID"'
|
||||
)
|
||||
)
|
||||
query_text = None
|
||||
search_resp = self.index.similarity_search(
|
||||
columns=self.columns,
|
||||
query_vector=embedding,
|
||||
query_text=query_text,
|
||||
filters=filter or _alias_filters(kwargs),
|
||||
num_results=k,
|
||||
query_type=query_type,
|
||||
|
@ -482,7 +482,7 @@ def test_delete_fail_no_ids() -> None:
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
@pytest.mark.parametrize(
|
||||
"index_details, query_type", itertools.product(ALL_INDEXES, ALL_QUERY_TYPES)
|
||||
"index_details, query_type", itertools.product(ALL_INDEXES, [None, "ANN"])
|
||||
)
|
||||
def test_similarity_search(index_details: dict, query_type: Optional[str]) -> None:
|
||||
index = mock_index(index_details)
|
||||
@ -518,6 +518,42 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No
|
||||
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize("index_details", ALL_INDEXES)
def test_similarity_search_hybrid(index_details: dict) -> None:
    """Verify the arguments forwarded to the index by a HYBRID `similarity_search`.

    The query text is expected in the index call for every index type. The
    expected query vector is ``None`` when ``index_details`` equals
    ``DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS`` and the default embedding model's
    embedding of the query otherwise.
    """
    query = "foo"
    filters = {"some filter": True}
    top_k = 7

    mocked_index = mock_index(index_details)
    mocked_index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
    store = default_databricks_vector_search(mocked_index)

    docs = store.similarity_search(
        query, k=top_k, filter=filters, query_type="HYBRID"
    )

    # The two index flavours differ only in the query vector they should get.
    if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS:
        expected_vector = None
    else:
        expected_vector = DEFAULT_EMBEDDING_MODEL.embed_query(query)

    mocked_index.similarity_search.assert_called_once_with(
        columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
        query_text=query,
        query_vector=expected_vector,
        filters=filters,
        num_results=top_k,
        query_type="HYBRID",
    )

    # The mocked response should come back as one document per fake text.
    assert len(docs) == len(fake_texts)
    assert sorted(d.page_content for d in docs) == sorted(fake_texts)
    assert all(DEFAULT_PRIMARY_KEY in d.metadata for d in docs)
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
def test_similarity_search_both_filter_and_filters_passed() -> None:
|
||||
index = mock_index(DIRECT_ACCESS_INDEX)
|
||||
@ -657,9 +693,14 @@ def test_standard_params() -> None:
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
@pytest.mark.parametrize(
|
||||
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
|
||||
"index_details, query_type",
|
||||
itertools.product(
|
||||
[DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX], [None, "ANN"]
|
||||
),
|
||||
)
|
||||
def test_similarity_search_by_vector(index_details: dict) -> None:
|
||||
def test_similarity_search_by_vector(
|
||||
index_details: dict, query_type: Optional[str]
|
||||
) -> None:
|
||||
index = mock_index(index_details)
|
||||
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
|
||||
vectorsearch = default_databricks_vector_search(index)
|
||||
@ -668,14 +709,43 @@ def test_similarity_search_by_vector(index_details: dict) -> None:
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search_by_vector(
|
||||
query_embedding, k=limit, filter=filters
|
||||
query_embedding, k=limit, filter=filters, query_type=query_type
|
||||
)
|
||||
index.similarity_search.assert_called_once_with(
|
||||
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
|
||||
query_vector=query_embedding,
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type=None,
|
||||
query_type=query_type,
|
||||
query_text=None,
|
||||
)
|
||||
assert len(search_result) == len(fake_texts)
|
||||
assert sorted([d.page_content for d in search_result]) == sorted(fake_texts)
|
||||
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
@pytest.mark.parametrize(
|
||||
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
|
||||
)
|
||||
def test_similarity_search_by_vector_hybrid(index_details: dict) -> None:
|
||||
index = mock_index(index_details)
|
||||
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
|
||||
vectorsearch = default_databricks_vector_search(index)
|
||||
query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo")
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search_by_vector(
|
||||
query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo"
|
||||
)
|
||||
index.similarity_search.assert_called_once_with(
|
||||
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
|
||||
query_vector=query_embedding,
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type="HYBRID",
|
||||
query_text="foo",
|
||||
)
|
||||
assert len(search_result) == len(fake_texts)
|
||||
assert sorted([d.page_content for d in search_result]) == sorted(fake_texts)
|
||||
|
Loading…
Reference in New Issue
Block a user