mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
Fix RRF and lucene escape characters for neo4j vector store (#14646)
* Remove Lucene special characters (fixes https://github.com/langchain-ai/langchain/issues/14232) * Fixes RRF normalization for hybrid search
This commit is contained in:
parent
7e6ca3c2b9
commit
ea2616ae23
@ -48,14 +48,19 @@ def _get_search_index_query(search_type: SearchType) -> str:
|
|||||||
"CALL { "
|
"CALL { "
|
||||||
"CALL db.index.vector.queryNodes($index, $k, $embedding) "
|
"CALL db.index.vector.queryNodes($index, $k, $embedding) "
|
||||||
"YIELD node, score "
|
"YIELD node, score "
|
||||||
"RETURN node, score UNION "
|
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
|
||||||
|
"UNWIND nodes AS n "
|
||||||
|
# We use 0 as min
|
||||||
|
"RETURN n.node AS node, (n.score / max) AS score UNION "
|
||||||
"CALL db.index.fulltext.queryNodes($keyword_index, $query, {limit: $k}) "
|
"CALL db.index.fulltext.queryNodes($keyword_index, $query, {limit: $k}) "
|
||||||
"YIELD node, score "
|
"YIELD node, score "
|
||||||
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
|
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
|
||||||
"UNWIND nodes AS n "
|
"UNWIND nodes AS n "
|
||||||
"RETURN n.node AS node, (n.score / max) AS score " # We use 0 as min
|
# We use 0 as min
|
||||||
|
"RETURN n.node AS node, (n.score / max) AS score "
|
||||||
"} "
|
"} "
|
||||||
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $k " # dedup
|
# dedup
|
||||||
|
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $k "
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
return type_to_query_map[search_type]
|
return type_to_query_map[search_type]
|
||||||
@ -75,6 +80,34 @@ def sort_by_index_name(
|
|||||||
return sorted(lst, key=lambda x: x.get("index_name") != index_name)
|
return sorted(lst, key=lambda x: x.get("index_name") != index_name)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_lucene_chars(text: str) -> str:
|
||||||
|
"""Remove Lucene special characters"""
|
||||||
|
special_chars = [
|
||||||
|
"+",
|
||||||
|
"-",
|
||||||
|
"&",
|
||||||
|
"|",
|
||||||
|
"!",
|
||||||
|
"(",
|
||||||
|
")",
|
||||||
|
"{",
|
||||||
|
"}",
|
||||||
|
"[",
|
||||||
|
"]",
|
||||||
|
"^",
|
||||||
|
'"',
|
||||||
|
"~",
|
||||||
|
"*",
|
||||||
|
"?",
|
||||||
|
":",
|
||||||
|
"\\",
|
||||||
|
]
|
||||||
|
for char in special_chars:
|
||||||
|
if char in text:
|
||||||
|
text = text.replace(char, " ")
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
class Neo4jVector(VectorStore):
|
class Neo4jVector(VectorStore):
|
||||||
"""`Neo4j` vector index.
|
"""`Neo4j` vector index.
|
||||||
|
|
||||||
@ -589,7 +622,7 @@ class Neo4jVector(VectorStore):
|
|||||||
"k": k,
|
"k": k,
|
||||||
"embedding": embedding,
|
"embedding": embedding,
|
||||||
"keyword_index": self.keyword_index_name,
|
"keyword_index": self.keyword_index_name,
|
||||||
"query": kwargs["query"],
|
"query": remove_lucene_chars(kwargs["query"]),
|
||||||
}
|
}
|
||||||
|
|
||||||
results = self.query(read_query, params=parameters)
|
results = self.query(read_query, params=parameters)
|
||||||
|
@ -4,7 +4,11 @@ from typing import List
|
|||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
from langchain_community.vectorstores.neo4j_vector import Neo4jVector, SearchType
|
from langchain_community.vectorstores.neo4j_vector import (
|
||||||
|
Neo4jVector,
|
||||||
|
SearchType,
|
||||||
|
_get_search_index_query,
|
||||||
|
)
|
||||||
from langchain_community.vectorstores.utils import DistanceStrategy
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||||
|
|
||||||
@ -14,7 +18,7 @@ password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein")
|
|||||||
|
|
||||||
OS_TOKEN_COUNT = 1536
|
OS_TOKEN_COUNT = 1536
|
||||||
|
|
||||||
texts = ["foo", "bar", "baz"]
|
texts = ["foo", "bar", "baz", "It is the end of the world. Take shelter!"]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
cd tests/integration_tests/vectorstores/docker-compose
|
cd tests/integration_tests/vectorstores/docker-compose
|
||||||
@ -615,3 +619,62 @@ def test_neo4jvector_from_existing_graph_multiple_properties_hybrid() -> None:
|
|||||||
assert output == [Document(page_content="\nname: Foo\nname2: Fooz")]
|
assert output == [Document(page_content="\nname: Foo\nname2: Fooz")]
|
||||||
|
|
||||||
drop_vector_indexes(existing)
|
drop_vector_indexes(existing)
|
||||||
|
|
||||||
|
|
||||||
|
def test_neo4jvector_special_character() -> None:
|
||||||
|
"""Test removing lucene."""
|
||||||
|
text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts)
|
||||||
|
text_embedding_pairs = list(zip(texts, text_embeddings))
|
||||||
|
docsearch = Neo4jVector.from_embeddings(
|
||||||
|
text_embeddings=text_embedding_pairs,
|
||||||
|
embedding=FakeEmbeddingsWithOsDimension(),
|
||||||
|
url=url,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
search_type=SearchType.HYBRID,
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search(
|
||||||
|
"It is the end of the world. Take shelter!", k=1
|
||||||
|
)
|
||||||
|
assert output == [
|
||||||
|
Document(page_content="It is the end of the world. Take shelter!", metadata={})
|
||||||
|
]
|
||||||
|
|
||||||
|
drop_vector_indexes(docsearch)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hybrid_score_normalization() -> None:
|
||||||
|
"""Test if we can get two 1.0 documents with RRF"""
|
||||||
|
text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts)
|
||||||
|
text_embedding_pairs = list(zip(["foo"], text_embeddings))
|
||||||
|
docsearch = Neo4jVector.from_embeddings(
|
||||||
|
text_embeddings=text_embedding_pairs,
|
||||||
|
embedding=FakeEmbeddingsWithOsDimension(),
|
||||||
|
url=url,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
search_type=SearchType.HYBRID,
|
||||||
|
)
|
||||||
|
# Remove deduplication part of the query
|
||||||
|
rrf_query = (
|
||||||
|
_get_search_index_query(SearchType.HYBRID)
|
||||||
|
.rstrip("WITH node, max(score) AS score ORDER BY score DESC LIMIT $k")
|
||||||
|
.replace("UNION", "UNION ALL")
|
||||||
|
+ "RETURN node.text AS text, score LIMIT 2"
|
||||||
|
)
|
||||||
|
|
||||||
|
output = docsearch.query(
|
||||||
|
rrf_query,
|
||||||
|
params={
|
||||||
|
"index": "vector",
|
||||||
|
"k": 1,
|
||||||
|
"embedding": FakeEmbeddingsWithOsDimension().embed_query("foo"),
|
||||||
|
"query": "foo",
|
||||||
|
"keyword_index": "keyword",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Both FT and Vector must return 1.0 score
|
||||||
|
assert output == [{"text": "foo", "score": 1.0}, {"text": "foo", "score": 1.0}]
|
||||||
|
drop_vector_indexes(docsearch)
|
||||||
|
45
libs/community/tests/unit_tests/vectorstores/test_neo4j.py
Normal file
45
libs/community/tests/unit_tests/vectorstores/test_neo4j.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
"""Test Neo4j functionality."""
|
||||||
|
|
||||||
|
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
|
||||||
|
|
||||||
|
|
||||||
|
def test_escaping_lucene() -> None:
|
||||||
|
"""Test escaping lucene characters"""
|
||||||
|
assert remove_lucene_chars("Hello+World") == "Hello World"
|
||||||
|
assert remove_lucene_chars("Hello World\\") == "Hello World"
|
||||||
|
assert (
|
||||||
|
remove_lucene_chars("It is the end of the world. Take shelter!")
|
||||||
|
== "It is the end of the world. Take shelter"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
remove_lucene_chars("It is the end of the world. Take shelter&&")
|
||||||
|
== "It is the end of the world. Take shelter"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
remove_lucene_chars("Bill&&Melinda Gates Foundation")
|
||||||
|
== "Bill Melinda Gates Foundation"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
remove_lucene_chars("It is the end of the world. Take shelter(&&)")
|
||||||
|
== "It is the end of the world. Take shelter"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
remove_lucene_chars("It is the end of the world. Take shelter??")
|
||||||
|
== "It is the end of the world. Take shelter"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
remove_lucene_chars("It is the end of the world. Take shelter^")
|
||||||
|
== "It is the end of the world. Take shelter"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
remove_lucene_chars("It is the end of the world. Take shelter+")
|
||||||
|
== "It is the end of the world. Take shelter"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
remove_lucene_chars("It is the end of the world. Take shelter-")
|
||||||
|
== "It is the end of the world. Take shelter"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
remove_lucene_chars("It is the end of the world. Take shelter~")
|
||||||
|
== "It is the end of the world. Take shelter"
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user