Add neo4j graph environment variables (#12080)

pull/12099/head
Tomaz Bratanic 9 months ago committed by GitHub
parent d5400f6502
commit 82f4c0589c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,8 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from langchain.graphs.graph_document import GraphDocument
from langchain.graphs.graph_store import GraphStore
from langchain.utils import get_from_env
node_properties_query = """
CALL apoc.meta.data()
@ -45,7 +46,11 @@ class Neo4jGraph(GraphStore):
"""
def __init__(
self, url: str, username: str, password: str, database: str = "neo4j"
self,
url: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
database: str = "neo4j",
) -> None:
"""Create a new Neo4j graph wrapper instance."""
try:
@ -56,6 +61,11 @@ class Neo4jGraph(GraphStore):
"Please install it with `pip install neo4j`."
)
url = get_from_env("url", "NEO4J_URI", url)
username = get_from_env("username", "NEO4J_USERNAME", username)
password = get_from_env("password", "NEO4J_PASSWORD", password)
database = get_from_env("database", "NEO4J_DATABASE", database)
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
self._database = database
self.schema: str = ""

@ -14,7 +14,7 @@ from langchain.llms.openai import OpenAI
def test_connect_neo4j() -> None:
"""Test that Neo4j database is correctly instantiated and connected."""
url = os.environ.get("NEO4J_URL")
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
@ -36,9 +36,22 @@ def test_connect_neo4j() -> None:
assert output == expected_output
def test_connect_neo4j_env() -> None:
"""Test that Neo4j database environment variables."""
graph = Neo4jGraph()
output = graph.query(
"""
RETURN "test" AS output
"""
)
expected_output = [{"output": "test"}]
assert output == expected_output
def test_cypher_generating_run() -> None:
"""Test that Cypher statement is correctly generated and executed."""
url = os.environ.get("NEO4J_URL")
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
@ -68,7 +81,7 @@ def test_cypher_generating_run() -> None:
def test_cypher_top_k() -> None:
"""Test top_k parameter correctly limits the number of results in the context."""
url = os.environ.get("NEO4J_URL")
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
@ -102,7 +115,7 @@ def test_cypher_top_k() -> None:
def test_cypher_intermediate_steps() -> None:
"""Test the returning of the intermediate steps."""
url = os.environ.get("NEO4J_URL")
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
@ -146,7 +159,7 @@ def test_cypher_intermediate_steps() -> None:
def test_cypher_return_direct() -> None:
"""Test that chain returns direct results."""
url = os.environ.get("NEO4J_URL")
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
@ -178,7 +191,7 @@ def test_cypher_return_direct() -> None:
def test_cypher_return_correct_schema() -> None:
"""Test that chain returns direct results."""
url = os.environ.get("NEO4J_URL")
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
@ -243,7 +256,7 @@ def test_cypher_save_load() -> None:
"""Test saving and loading."""
FILE_PATH = "cypher.yaml"
url = os.environ.get("NEO4J_URL")
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
@ -267,7 +280,7 @@ def test_cypher_save_load() -> None:
def test_exclude_types() -> None:
"""Test exclude types from schema."""
url = os.environ.get("NEO4J_URL")
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
@ -307,7 +320,7 @@ def test_exclude_types() -> None:
def test_include_types() -> None:
"""Test include types from schema."""
url = os.environ.get("NEO4J_URL")
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
@ -347,7 +360,7 @@ def test_include_types() -> None:
def test_include_types2() -> None:
"""Test include types from schema."""
url = os.environ.get("NEO4J_URL")
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None

Loading…
Cancel
Save