From 82f4c0589c4e658cf59ae80afbf256b33036bcbb Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Fri, 20 Oct 2023 23:43:01 +0200 Subject: [PATCH] Add neo4j graph environment variables (#12080) --- .../langchain/langchain/graphs/neo4j_graph.py | 14 ++++++-- .../chains/test_graph_database.py | 33 +++++++++++++------ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/libs/langchain/langchain/graphs/neo4j_graph.py b/libs/langchain/langchain/graphs/neo4j_graph.py index 4f65674236..dfbf38bb98 100644 --- a/libs/langchain/langchain/graphs/neo4j_graph.py +++ b/libs/langchain/langchain/graphs/neo4j_graph.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from langchain.graphs.graph_document import GraphDocument from langchain.graphs.graph_store import GraphStore +from langchain.utils import get_from_env node_properties_query = """ CALL apoc.meta.data() @@ -45,7 +46,11 @@ class Neo4jGraph(GraphStore): """ def __init__( - self, url: str, username: str, password: str, database: str = "neo4j" + self, + url: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + database: str = "neo4j", ) -> None: """Create a new Neo4j graph wrapper instance.""" try: @@ -56,6 +61,11 @@ class Neo4jGraph(GraphStore): "Please install it with `pip install neo4j`." ) + url = get_from_env("url", "NEO4J_URI", url) + username = get_from_env("username", "NEO4J_USERNAME", username) + password = get_from_env("password", "NEO4J_PASSWORD", password) + database = get_from_env("database", "NEO4J_DATABASE", database) + self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) self._database = database self.schema: str = "" diff --git a/libs/langchain/tests/integration_tests/chains/test_graph_database.py b/libs/langchain/tests/integration_tests/chains/test_graph_database.py index fd67d60f2d..eb40972461 100644 --- a/libs/langchain/tests/integration_tests/chains/test_graph_database.py +++ b/libs/langchain/tests/integration_tests/chains/test_graph_database.py @@ -14,7 +14,7 @@ from langchain.llms.openai import OpenAI def test_connect_neo4j() -> None: """Test that Neo4j database is correctly instantiated and connected.""" - url = os.environ.get("NEO4J_URL") + url = os.environ.get("NEO4J_URI") username = os.environ.get("NEO4J_USERNAME") password = os.environ.get("NEO4J_PASSWORD") assert url is not None @@ -36,9 +36,22 @@ def test_connect_neo4j() -> None: assert output == expected_output +def test_connect_neo4j_env() -> None: + """Test that Neo4j database environment variables.""" + graph = Neo4jGraph() + + output = graph.query( + """ + RETURN "test" AS output + """ + ) + expected_output = [{"output": "test"}] + assert output == expected_output + + def test_cypher_generating_run() -> None: """Test that Cypher statement is correctly generated and executed.""" - url = os.environ.get("NEO4J_URL") + url = os.environ.get("NEO4J_URI") username = os.environ.get("NEO4J_USERNAME") password = os.environ.get("NEO4J_PASSWORD") assert url is not None @@ -68,7 +81,7 @@ def test_cypher_generating_run() -> None: def test_cypher_top_k() -> None: """Test top_k parameter correctly limits the number of results in the context.""" - url = os.environ.get("NEO4J_URL") + url = os.environ.get("NEO4J_URI") username = os.environ.get("NEO4J_USERNAME") password = os.environ.get("NEO4J_PASSWORD") assert url is not None @@ -102,7 +115,7 @@ def test_cypher_top_k() -> None: def test_cypher_intermediate_steps() -> None: """Test the returning of the intermediate steps.""" - url = os.environ.get("NEO4J_URL") + url = os.environ.get("NEO4J_URI") username = os.environ.get("NEO4J_USERNAME") password = os.environ.get("NEO4J_PASSWORD") assert url is not None @@ -146,7 +159,7 @@ def test_cypher_intermediate_steps() -> None: def test_cypher_return_direct() -> None: """Test that chain returns direct results.""" - url = os.environ.get("NEO4J_URL") + url = os.environ.get("NEO4J_URI") username = os.environ.get("NEO4J_USERNAME") password = os.environ.get("NEO4J_PASSWORD") assert url is not None @@ -178,7 +191,7 @@ def test_cypher_return_direct() -> None: def test_cypher_return_correct_schema() -> None: """Test that chain returns direct results.""" - url = os.environ.get("NEO4J_URL") + url = os.environ.get("NEO4J_URI") username = os.environ.get("NEO4J_USERNAME") password = os.environ.get("NEO4J_PASSWORD") assert url is not None @@ -243,7 +256,7 @@ def test_cypher_save_load() -> None: """Test saving and loading.""" FILE_PATH = "cypher.yaml" - url = os.environ.get("NEO4J_URL") + url = os.environ.get("NEO4J_URI") username = os.environ.get("NEO4J_USERNAME") password = os.environ.get("NEO4J_PASSWORD") assert url is not None @@ -267,7 +280,7 @@ def test_cypher_save_load() -> None: def test_exclude_types() -> None: """Test exclude types from schema.""" - url = os.environ.get("NEO4J_URL") + url = os.environ.get("NEO4J_URI") username = os.environ.get("NEO4J_USERNAME") password = os.environ.get("NEO4J_PASSWORD") assert url is not None @@ -307,7 +320,7 @@ def test_exclude_types() -> None: def test_include_types() -> None: """Test include types from schema.""" - url = os.environ.get("NEO4J_URL") + url = os.environ.get("NEO4J_URI") username = os.environ.get("NEO4J_USERNAME") password = os.environ.get("NEO4J_PASSWORD") assert url is not None @@ -347,7 +360,7 @@ def test_include_types() -> None: def test_include_types2() -> None: """Test include types from schema.""" - url = os.environ.get("NEO4J_URL") + url = os.environ.get("NEO4J_URI") username = os.environ.get("NEO4J_USERNAME") password = os.environ.get("NEO4J_PASSWORD") assert url is not None