Fix RRF and lucene escape characters for neo4j vector store (#14646)

* Remove Lucene special characters (fixes
https://github.com/langchain-ai/langchain/issues/14232)
* Fixes RRF normalization for hybrid search
This commit is contained in:
Tomaz Bratanic 2023-12-13 18:09:50 +01:00 committed by GitHub
parent 7e6ca3c2b9
commit ea2616ae23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 147 additions and 6 deletions

View File

@@ -48,14 +48,19 @@ def _get_search_index_query(search_type: SearchType) -> str:
        "CALL { "
            "CALL db.index.vector.queryNodes($index, $k, $embedding) "
            "YIELD node, score "
-           "RETURN node, score UNION "
+           "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
+           "UNWIND nodes AS n "
+           # We use 0 as min
+           "RETURN n.node AS node, (n.score / max) AS score UNION "
            "CALL db.index.fulltext.queryNodes($keyword_index, $query, {limit: $k}) "
            "YIELD node, score "
            "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
            "UNWIND nodes AS n "
-           "RETURN n.node AS node, (n.score / max) AS score "
+           # We use 0 as min
+           "RETURN n.node AS node, (n.score / max) AS score "
        "} "
-       "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k "
+       # dedup
+       "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k "
        ),
    }
    return type_to_query_map[search_type]
@@ -75,6 +80,34 @@ def sort_by_index_name(
    return sorted(lst, key=lambda x: x.get("index_name") != index_name)
def remove_lucene_chars(text: str) -> str:
    """Strip Lucene query-syntax special characters from ``text``.

    Each reserved character is replaced with a single space (so adjacent
    words stay separated) and the result is stripped of leading/trailing
    whitespace.  This sanitizes user queries before they reach the
    fulltext index, where characters such as ``+`` or ``~`` would
    otherwise be parsed as Lucene operators or raise a syntax error.

    Args:
        text: Raw query string, possibly containing Lucene operators.

    Returns:
        The query with every special character blanked out and the ends
        trimmed.  Interior runs of spaces (e.g. from ``&&``) are kept.
    """
    # Lucene's reserved characters.  The two-character operators "&&" and
    # "||" are covered because their individual characters are replaced.
    special_chars = '+-&|!(){}[]^"~*?:\\'
    # One C-level pass via str.translate instead of one .replace() call
    # (each a full scan) per special character.
    return text.translate(str.maketrans({c: " " for c in special_chars})).strip()
class Neo4jVector(VectorStore):
    """`Neo4j` vector index.
@@ -589,7 +622,7 @@ class Neo4jVector(VectorStore):
            "k": k,
            "embedding": embedding,
            "keyword_index": self.keyword_index_name,
-           "query": kwargs["query"],
+           "query": remove_lucene_chars(kwargs["query"]),
        }
        results = self.query(read_query, params=parameters)

View File

@@ -4,7 +4,11 @@ from typing import List
 from langchain_core.documents import Document
-from langchain_community.vectorstores.neo4j_vector import Neo4jVector, SearchType
+from langchain_community.vectorstores.neo4j_vector import (
+    Neo4jVector,
+    SearchType,
+    _get_search_index_query,
+)
 from langchain_community.vectorstores.utils import DistanceStrategy
 from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
@@ -14,7 +18,7 @@ password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein")
 OS_TOKEN_COUNT = 1536
-texts = ["foo", "bar", "baz"]
+texts = ["foo", "bar", "baz", "It is the end of the world. Take shelter!"]
""" """
cd tests/integration_tests/vectorstores/docker-compose cd tests/integration_tests/vectorstores/docker-compose
@@ -615,3 +619,62 @@ def test_neo4jvector_from_existing_graph_multiple_properties_hybrid() -> None:
    assert output == [Document(page_content="\nname: Foo\nname2: Fooz")]
    drop_vector_indexes(existing)
def test_neo4jvector_special_character() -> None:
    """Hybrid search must not choke on Lucene special characters ('!')."""
    query = "It is the end of the world. Take shelter!"
    embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts)
    docsearch = Neo4jVector.from_embeddings(
        text_embeddings=list(zip(texts, embeddings)),
        embedding=FakeEmbeddingsWithOsDimension(),
        url=url,
        username=username,
        password=password,
        pre_delete_collection=True,
        search_type=SearchType.HYBRID,
    )
    # The '!' would be a Lucene operator if it reached the fulltext index raw.
    output = docsearch.similarity_search(query, k=1)
    assert output == [Document(page_content=query, metadata={})]
    drop_vector_indexes(docsearch)
def test_hybrid_score_normalization() -> None:
    """Test if we can get two 1.0 documents with RRF.

    Both the vector branch and the fulltext branch of the hybrid query must
    normalize their top score to 1.0 for the same document ("foo").
    """
    text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts)
    text_embedding_pairs = list(zip(["foo"], text_embeddings))
    docsearch = Neo4jVector.from_embeddings(
        text_embeddings=text_embedding_pairs,
        embedding=FakeEmbeddingsWithOsDimension(),
        url=url,
        username=username,
        password=password,
        pre_delete_collection=True,
        search_type=SearchType.HYBRID,
    )
    # Remove the deduplication tail of the query so each UNION branch
    # surfaces its own row for "foo".
    # NOTE: str.rstrip(<clause>) would strip a *set of characters*, not a
    # suffix, silently eating any trailing characters that happen to occur
    # in the clause -- remove the exact suffix instead.
    dedup_clause = "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k"
    base_query = _get_search_index_query(SearchType.HYBRID).rstrip()
    if base_query.endswith(dedup_clause):
        base_query = base_query[: -len(dedup_clause)]
    rrf_query = (
        base_query.replace("UNION", "UNION ALL")
        + "RETURN node.text AS text, score LIMIT 2"
    )
    output = docsearch.query(
        rrf_query,
        params={
            "index": "vector",
            "k": 1,
            "embedding": FakeEmbeddingsWithOsDimension().embed_query("foo"),
            "query": "foo",
            "keyword_index": "keyword",
        },
    )
    # Both FT and Vector must return 1.0 score
    assert output == [{"text": "foo", "score": 1.0}, {"text": "foo", "score": 1.0}]
    drop_vector_indexes(docsearch)

View File

@ -0,0 +1,45 @@
"""Test Neo4j functionality."""
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
def test_escaping_lucene() -> None:
    """Test escaping lucene characters.

    Each special character is replaced by a single space and only the
    *ends* of the string are stripped, so the two characters of "&&"
    leave two interior spaces behind.
    """
    shelter = "It is the end of the world. Take shelter"
    # (input, expected) pairs covering \\ & ! ( ) ? ^ + - ~ handling.
    cases = [
        ("Hello+World", "Hello World"),
        ("Hello World\\", "Hello World"),
        (shelter + "!", shelter),
        (shelter + "&&", shelter),
        # "&&" -> two spaces; interior whitespace is NOT collapsed.
        ("Bill&&Melinda Gates Foundation", "Bill  Melinda Gates Foundation"),
        (shelter + "(&&)", shelter),
        (shelter + "??", shelter),
        (shelter + "^", shelter),
        (shelter + "+", shelter),
        (shelter + "-", shelter),
        (shelter + "~", shelter),
    ]
    for raw, expected in cases:
        assert remove_lucene_chars(raw) == expected