From 1e80113ac98b286e581c0210d3bbdcdc57f4e87c Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Wed, 17 Jan 2024 22:22:19 +0100 Subject: [PATCH] community[patch]: Add neo4j timeout and value sanitization option (#16138) The timeout function comes in handy when you want to kill longrunning queries. The value sanitization removes all lists that are larger than 128 elements. The idea here is to remove embedding properties from results. --- .../langchain_community/graphs/neo4j_graph.py | 56 ++++++++++++++++++- .../integration_tests/graphs/test_neo4j.py | 19 +++++++ .../unit_tests/graphs/test_neo4j_graph.py | 32 +++++++++++ 3 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 libs/community/tests/unit_tests/graphs/test_neo4j_graph.py diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py index af416b29bc..92efd27f08 100644 --- a/libs/community/langchain_community/graphs/neo4j_graph.py +++ b/libs/community/langchain_community/graphs/neo4j_graph.py @@ -31,8 +31,50 @@ RETURN {start: label, type: property, end: toString(other_node)} AS output """ +def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]: + """ + Sanitizes the input dictionary by removing embedding-like values, + lists with more than 128 elements, that are mostly irrelevant for + generating answers in a LLM context. These properties, if left in + results, can occupy significant context space and detract from + the LLM's performance by introducing unnecessary noise and cost. + """ + LIST_LIMIT = 128 + # Create a new dictionary to avoid changing size during iteration + new_dict = {} + for key, value in d.items(): + if isinstance(value, dict): + # Recurse to handle nested dictionaries + new_dict[key] = value_sanitize(value) + elif isinstance(value, list): + # check if it has less than LIST_LIMIT values + if len(value) < LIST_LIMIT: + # if value is a list, check if it contains dictionaries to clean + cleaned_list = [] + for item in value: + if isinstance(item, dict): + cleaned_list.append(value_sanitize(item)) + else: + cleaned_list.append(item) + new_dict[key] = cleaned_list + else: + new_dict[key] = value + return new_dict + + class Neo4jGraph(GraphStore): - """Neo4j wrapper for graph operations. + """Provides a connection to a Neo4j database for various graph operations. + Parameters: + url (Optional[str]): The URL of the Neo4j database server. + username (Optional[str]): The username for database authentication. + password (Optional[str]): The password for database authentication. + database (str): The name of the database to connect to. Default is 'neo4j'. + timeout (Optional[float]): The timeout for transactions in seconds. + Useful for terminating long-running queries. + By default, there is no timeout set. + sanitize (bool): A flag to indicate whether to remove lists with + more than 128 elements from results. Useful for removing + embedding-like properties from database responses. Default is False. *Security note*: Make sure that the database connection uses credentials that are narrowly-scoped to only include necessary permissions. @@ -52,6 +94,8 @@ class Neo4jGraph(GraphStore): username: Optional[str] = None, password: Optional[str] = None, database: str = "neo4j", + timeout: Optional[float] = None, + sanitize: bool = False, ) -> None: """Create a new Neo4j graph wrapper instance.""" try: @@ -69,6 +113,8 @@ class Neo4jGraph(GraphStore): self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) self._database = database + self.timeout = timeout + self.sanitize = sanitize self.schema: str = "" self.structured_schema: Dict[str, Any] = {} # Verify connection @@ -106,12 +152,16 @@ class Neo4jGraph(GraphStore): def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: """Query Neo4j database.""" + from neo4j import Query from neo4j.exceptions import CypherSyntaxError with self._driver.session(database=self._database) as session: try: - data = session.run(query, params) - return [r.data() for r in data] + data = session.run(Query(text=query, timeout=self.timeout), params) + json_data = [r.data() for r in data] + if self.sanitize: + json_data = value_sanitize(json_data) + return json_data except CypherSyntaxError as e: raise ValueError(f"Generated Cypher Statement is not valid\n{e}") diff --git a/libs/community/tests/integration_tests/graphs/test_neo4j.py b/libs/community/tests/integration_tests/graphs/test_neo4j.py index fd9e3bb36f..948fe4c719 100644 --- a/libs/community/tests/integration_tests/graphs/test_neo4j.py +++ b/libs/community/tests/integration_tests/graphs/test_neo4j.py @@ -69,3 +69,22 @@ def test_cypher_return_correct_schema() -> None: sorted(relationships, key=lambda x: x["output"]["end"]) == expected_relationships ) + + +def test_neo4j_timeout() -> None: + """Test that neo4j uses the timeout correctly.""" + url = os.environ.get("NEO4J_URI") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph(url=url, username=username, password=password, timeout=0.1) + try: + graph.query("UNWIND range(0,100000,1) AS i MERGE (:Foo {id:i})") + except Exception as e: + assert ( + e.code + == "Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration" + ) diff --git a/libs/community/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/community/tests/unit_tests/graphs/test_neo4j_graph.py new file mode 100644 index 0000000000..b352529ba6 --- /dev/null +++ b/libs/community/tests/unit_tests/graphs/test_neo4j_graph.py @@ -0,0 +1,32 @@ +from langchain_community.graphs.neo4j_graph import value_sanitize + + +def test_value_sanitize_with_small_list(): + small_list = list(range(15)) # list size > LIST_LIMIT + input_dict = {"key1": "value1", "small_list": small_list} + expected_output = {"key1": "value1", "small_list": small_list} + assert value_sanitize(input_dict) == expected_output + + +def test_value_sanitize_with_oversized_list(): + oversized_list = list(range(150)) # list size > LIST_LIMIT + input_dict = {"key1": "value1", "oversized_list": oversized_list} + expected_output = { + "key1": "value1" + # oversized_list should not be included + } + assert value_sanitize(input_dict) == expected_output + + +def test_value_sanitize_with_nested_oversized_list(): + oversized_list = list(range(150)) # list size > LIST_LIMIT + input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}} + expected_output = {"key1": "value1", "oversized_list": {}} + assert value_sanitize(input_dict) == expected_output + + +def test_value_sanitize_with_dict_in_list(): + oversized_list = list(range(150)) # list size > LIST_LIMIT + input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]} + expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]} + assert value_sanitize(input_dict) == expected_output