diff --git a/libs/community/langchain_community/vectorstores/neo4j_vector.py b/libs/community/langchain_community/vectorstores/neo4j_vector.py index 7ccc2e7d10..bb8f1b9b30 100644 --- a/libs/community/langchain_community/vectorstores/neo4j_vector.py +++ b/libs/community/langchain_community/vectorstores/neo4j_vector.py @@ -48,14 +48,19 @@ def _get_search_index_query(search_type: SearchType) -> str: "CALL { " "CALL db.index.vector.queryNodes($index, $k, $embedding) " "YIELD node, score " - "RETURN node, score UNION " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + # We use 0 as min + "RETURN n.node AS node, (n.score / max) AS score UNION " "CALL db.index.fulltext.queryNodes($keyword_index, $query, {limit: $k}) " "YIELD node, score " "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " "UNWIND nodes AS n " - "RETURN n.node AS node, (n.score / max) AS score " # We use 0 as min + # We use 0 as min + "RETURN n.node AS node, (n.score / max) AS score " "} " - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k " # dedup + # dedup + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k " ), } return type_to_query_map[search_type] @@ -75,6 +80,34 @@ def sort_by_index_name( return sorted(lst, key=lambda x: x.get("index_name") != index_name) +def remove_lucene_chars(text: str) -> str: + """Remove Lucene special characters""" + special_chars = [ + "+", + "-", + "&", + "|", + "!", + "(", + ")", + "{", + "}", + "[", + "]", + "^", + '"', + "~", + "*", + "?", + ":", + "\\", + ] + for char in special_chars: + if char in text: + text = text.replace(char, " ") + return text.strip() + + class Neo4jVector(VectorStore): """`Neo4j` vector index. @@ -589,7 +622,7 @@ class Neo4jVector(VectorStore): "k": k, "embedding": embedding, "keyword_index": self.keyword_index_name, - "query": kwargs["query"], + "query": remove_lucene_chars(kwargs["query"]), } results = self.query(read_query, params=parameters) diff --git a/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py b/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py index f7ba418b47..cb0c79a3a0 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py +++ b/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py @@ -4,7 +4,11 @@ from typing import List from langchain_core.documents import Document -from langchain_community.vectorstores.neo4j_vector import Neo4jVector, SearchType +from langchain_community.vectorstores.neo4j_vector import ( + Neo4jVector, + SearchType, + _get_search_index_query, +) from langchain_community.vectorstores.utils import DistanceStrategy from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings @@ -14,7 +18,7 @@ password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") OS_TOKEN_COUNT = 1536 -texts = ["foo", "bar", "baz"] +texts = ["foo", "bar", "baz", "It is the end of the world. Take shelter!"] """ cd tests/integration_tests/vectorstores/docker-compose @@ -615,3 +619,62 @@ def test_neo4jvector_from_existing_graph_multiple_properties_hybrid() -> None: assert output == [Document(page_content="\nname: Foo\nname2: Fooz")] drop_vector_indexes(existing) + + +def test_neo4jvector_special_character() -> None: + """Test removing lucene.""" + text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + docsearch = Neo4jVector.from_embeddings( + text_embeddings=text_embedding_pairs, + embedding=FakeEmbeddingsWithOsDimension(), + url=url, + username=username, + password=password, + pre_delete_collection=True, + search_type=SearchType.HYBRID, + ) + output = docsearch.similarity_search( + "It is the end of the world. Take shelter!", k=1 + ) + assert output == [ + Document(page_content="It is the end of the world. Take shelter!", metadata={}) + ] + + drop_vector_indexes(docsearch) + + +def test_hybrid_score_normalization() -> None: + """Test if we can get two 1.0 documents with RRF""" + text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) + text_embedding_pairs = list(zip(["foo"], text_embeddings)) + docsearch = Neo4jVector.from_embeddings( + text_embeddings=text_embedding_pairs, + embedding=FakeEmbeddingsWithOsDimension(), + url=url, + username=username, + password=password, + pre_delete_collection=True, + search_type=SearchType.HYBRID, + ) + # Remove deduplication part of the query + rrf_query = ( + _get_search_index_query(SearchType.HYBRID) + .rstrip("WITH node, max(score) AS score ORDER BY score DESC LIMIT $k") + .replace("UNION", "UNION ALL") + + "RETURN node.text AS text, score LIMIT 2" + ) + + output = docsearch.query( + rrf_query, + params={ + "index": "vector", + "k": 1, + "embedding": FakeEmbeddingsWithOsDimension().embed_query("foo"), + "query": "foo", + "keyword_index": "keyword", + }, + ) + # Both FT and Vector must return 1.0 score + assert output == [{"text": "foo", "score": 1.0}, {"text": "foo", "score": 1.0}] + drop_vector_indexes(docsearch) diff --git a/libs/community/tests/unit_tests/vectorstores/test_neo4j.py b/libs/community/tests/unit_tests/vectorstores/test_neo4j.py new file mode 100644 index 0000000000..280334283e --- /dev/null +++ b/libs/community/tests/unit_tests/vectorstores/test_neo4j.py @@ -0,0 +1,45 @@ +"""Test Neo4j functionality.""" + +from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars + + +def test_escaping_lucene() -> None: + """Test escaping lucene characters""" + assert remove_lucene_chars("Hello+World") == "Hello World" + assert remove_lucene_chars("Hello World\\") == "Hello World" + assert ( + remove_lucene_chars("It is the end of the world. Take shelter!") + == "It is the end of the world. Take shelter" + ) + assert ( + remove_lucene_chars("It is the end of the world. Take shelter&&") + == "It is the end of the world. Take shelter" + ) + assert ( + remove_lucene_chars("Bill&&Melinda Gates Foundation") + == "Bill Melinda Gates Foundation" + ) + assert ( + remove_lucene_chars("It is the end of the world. Take shelter(&&)") + == "It is the end of the world. Take shelter" + ) + assert ( + remove_lucene_chars("It is the end of the world. Take shelter??") + == "It is the end of the world. Take shelter" + ) + assert ( + remove_lucene_chars("It is the end of the world. Take shelter^") + == "It is the end of the world. Take shelter" + ) + assert ( + remove_lucene_chars("It is the end of the world. Take shelter+") + == "It is the end of the world. Take shelter" + ) + assert ( + remove_lucene_chars("It is the end of the world. Take shelter-") + == "It is the end of the world. Take shelter" + ) + assert ( + remove_lucene_chars("It is the end of the world. Take shelter~") + == "It is the end of the world. Take shelter" + )