mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
Fix RRF and Lucene escape characters for Neo4j vector store (#14646)
* Remove Lucene special characters (fixes https://github.com/langchain-ai/langchain/issues/14232) * Fixes RRF normalization for hybrid search
This commit is contained in:
parent
7e6ca3c2b9
commit
ea2616ae23
@ -48,14 +48,19 @@ def _get_search_index_query(search_type: SearchType) -> str:
|
||||
"CALL { "
|
||||
"CALL db.index.vector.queryNodes($index, $k, $embedding) "
|
||||
"YIELD node, score "
|
||||
"RETURN node, score UNION "
|
||||
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
|
||||
"UNWIND nodes AS n "
|
||||
# We use 0 as min
|
||||
"RETURN n.node AS node, (n.score / max) AS score UNION "
|
||||
"CALL db.index.fulltext.queryNodes($keyword_index, $query, {limit: $k}) "
|
||||
"YIELD node, score "
|
||||
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
|
||||
"UNWIND nodes AS n "
|
||||
"RETURN n.node AS node, (n.score / max) AS score " # We use 0 as min
|
||||
# We use 0 as min
|
||||
"RETURN n.node AS node, (n.score / max) AS score "
|
||||
"} "
|
||||
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $k " # dedup
|
||||
# dedup
|
||||
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $k "
|
||||
),
|
||||
}
|
||||
return type_to_query_map[search_type]
|
||||
@ -75,6 +80,34 @@ def sort_by_index_name(
|
||||
return sorted(lst, key=lambda x: x.get("index_name") != index_name)
|
||||
|
||||
|
||||
# Lucene query-parser operator characters; each one is mapped to a single
# space.  The table is built once at import time so every call is a single
# C-level pass over the input instead of one .replace() (and full string
# copy) per special character.
_LUCENE_SPECIAL_CHARS = '+-&|!(){}[]^"~*?:\\'
_LUCENE_TRANSLATION = str.maketrans(
    _LUCENE_SPECIAL_CHARS, " " * len(_LUCENE_SPECIAL_CHARS)
)


def remove_lucene_chars(text: str) -> str:
    """Remove Lucene query-syntax special characters from ``text``.

    The keyword side of hybrid search feeds the user query to
    ``db.index.fulltext.queryNodes``, which parses it with Lucene; characters
    such as ``+ - & | ! ( ) { } [ ] ^ " ~ * ? : \\`` carry operator meaning
    there and can turn ordinary user input into a syntax error.

    Each special character is replaced by a single space (so ``&&`` becomes
    two spaces) and the result is stripped of leading/trailing whitespace —
    identical output to the previous per-character ``str.replace`` loop.

    Args:
        text: Raw user query destined for the full-text index.

    Returns:
        The sanitized query string.
    """
    return text.translate(_LUCENE_TRANSLATION).strip()
|
||||
|
||||
|
||||
class Neo4jVector(VectorStore):
|
||||
"""`Neo4j` vector index.
|
||||
|
||||
@ -589,7 +622,7 @@ class Neo4jVector(VectorStore):
|
||||
"k": k,
|
||||
"embedding": embedding,
|
||||
"keyword_index": self.keyword_index_name,
|
||||
"query": kwargs["query"],
|
||||
"query": remove_lucene_chars(kwargs["query"]),
|
||||
}
|
||||
|
||||
results = self.query(read_query, params=parameters)
|
||||
|
@ -4,7 +4,11 @@ from typing import List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.vectorstores.neo4j_vector import Neo4jVector, SearchType
|
||||
from langchain_community.vectorstores.neo4j_vector import (
|
||||
Neo4jVector,
|
||||
SearchType,
|
||||
_get_search_index_query,
|
||||
)
|
||||
from langchain_community.vectorstores.utils import DistanceStrategy
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
|
||||
@ -14,7 +18,7 @@ password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein")
|
||||
|
||||
OS_TOKEN_COUNT = 1536
|
||||
|
||||
texts = ["foo", "bar", "baz"]
|
||||
texts = ["foo", "bar", "baz", "It is the end of the world. Take shelter!"]
|
||||
|
||||
"""
|
||||
cd tests/integration_tests/vectorstores/docker-compose
|
||||
@ -615,3 +619,62 @@ def test_neo4jvector_from_existing_graph_multiple_properties_hybrid() -> None:
|
||||
assert output == [Document(page_content="\nname: Foo\nname2: Fooz")]
|
||||
|
||||
drop_vector_indexes(existing)
|
||||
|
||||
|
||||
def test_neo4jvector_special_character() -> None:
    """Hybrid search must tolerate Lucene operator characters in the query."""
    doc_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts)
    pairs = list(zip(texts, doc_embeddings))
    store = Neo4jVector.from_embeddings(
        text_embeddings=pairs,
        embedding=FakeEmbeddingsWithOsDimension(),
        url=url,
        username=username,
        password=password,
        pre_delete_collection=True,
        search_type=SearchType.HYBRID,
    )
    # "!" is a Lucene operator; without sanitization the keyword index
    # query would raise a syntax error instead of matching the document.
    query = "It is the end of the world. Take shelter!"
    output = store.similarity_search(query, k=1)
    assert output == [Document(page_content=query, metadata={})]

    drop_vector_indexes(store)
|
||||
|
||||
|
||||
def test_hybrid_score_normalization() -> None:
    """Test if we can get two 1.0 documents with RRF.

    Both the vector subquery and the full-text subquery of the hybrid
    search should normalize their best hit for "foo" to a score of 1.0.
    """
    text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts)
    text_embedding_pairs = list(zip(["foo"], text_embeddings))
    docsearch = Neo4jVector.from_embeddings(
        text_embeddings=text_embedding_pairs,
        embedding=FakeEmbeddingsWithOsDimension(),
        url=url,
        username=username,
        password=password,
        pre_delete_collection=True,
        search_type=SearchType.HYBRID,
    )
    # Strip the deduplication tail of the hybrid query so the rows from both
    # subqueries survive, then surface text + score directly.
    # NOTE: the previous code used str.rstrip(suffix), which strips a *set of
    # characters* rather than a suffix and only removed the clause by
    # coincidence; remove the exact suffix instead.
    dedup_clause = "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k "
    base_query = _get_search_index_query(SearchType.HYBRID)
    if base_query.endswith(dedup_clause):
        base_query = base_query[: -len(dedup_clause)]
    rrf_query = (
        base_query.replace("UNION", "UNION ALL")
        + "RETURN node.text AS text, score LIMIT 2"
    )

    output = docsearch.query(
        rrf_query,
        params={
            "index": "vector",
            "k": 1,
            "embedding": FakeEmbeddingsWithOsDimension().embed_query("foo"),
            "query": "foo",
            "keyword_index": "keyword",
        },
    )
    # Both FT and Vector must return 1.0 score
    assert output == [{"text": "foo", "score": 1.0}, {"text": "foo", "score": 1.0}]
    drop_vector_indexes(docsearch)
|
||||
|
45
libs/community/tests/unit_tests/vectorstores/test_neo4j.py
Normal file
45
libs/community/tests/unit_tests/vectorstores/test_neo4j.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""Test Neo4j functionality."""
|
||||
|
||||
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
|
||||
|
||||
|
||||
def test_escaping_lucene() -> None:
    """Test escaping lucene characters."""
    # (raw input, expected sanitized output) pairs covering every operator
    # position: infix, trailing, doubled, and grouped.
    cases = [
        ("Hello+World", "Hello World"),
        ("Hello World\\", "Hello World"),
        (
            "It is the end of the world. Take shelter!",
            "It is the end of the world. Take shelter",
        ),
        (
            "It is the end of the world. Take shelter&&",
            "It is the end of the world. Take shelter",
        ),
        (
            "Bill&&Melinda Gates Foundation",
            "Bill Melinda Gates Foundation",
        ),
        (
            "It is the end of the world. Take shelter(&&)",
            "It is the end of the world. Take shelter",
        ),
        (
            "It is the end of the world. Take shelter??",
            "It is the end of the world. Take shelter",
        ),
        (
            "It is the end of the world. Take shelter^",
            "It is the end of the world. Take shelter",
        ),
        (
            "It is the end of the world. Take shelter+",
            "It is the end of the world. Take shelter",
        ),
        (
            "It is the end of the world. Take shelter-",
            "It is the end of the world. Take shelter",
        ),
        (
            "It is the end of the world. Take shelter~",
            "It is the end of the world. Take shelter",
        ),
    ]
    for raw, expected in cases:
        assert remove_lucene_chars(raw) == expected
|
Loading…
Reference in New Issue
Block a user