diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py index b7037621ce..c8022d1db0 100644 --- a/libs/community/langchain_community/graphs/neo4j_graph.py +++ b/libs/community/langchain_community/graphs/neo4j_graph.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from langchain_core.utils import get_from_env +from langchain_core.utils import get_from_dict_or_env from langchain_community.graphs.graph_document import GraphDocument from langchain_community.graphs.graph_store import GraphStore @@ -154,7 +154,7 @@ class Neo4jGraph(GraphStore): url: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None, - database: str = "neo4j", + database: Optional[str] = None, timeout: Optional[float] = None, sanitize: bool = False, ) -> None: @@ -167,10 +167,16 @@ class Neo4jGraph(GraphStore): "Please install it with `pip install neo4j`." ) - url = get_from_env("url", "NEO4J_URI", url) - username = get_from_env("username", "NEO4J_USERNAME", username) - password = get_from_env("password", "NEO4J_PASSWORD", password) - database = get_from_env("database", "NEO4J_DATABASE", database) + url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI") + username = get_from_dict_or_env( + {"username": username}, "username", "NEO4J_USERNAME" + ) + password = get_from_dict_or_env( + {"password": password}, "password", "NEO4J_PASSWORD" + ) + database = get_from_dict_or_env( + {"database": database}, "database", "NEO4J_DATABASE", "neo4j" + ) self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) self._database = database diff --git a/libs/community/langchain_community/vectorstores/neo4j_vector.py b/libs/community/langchain_community/vectorstores/neo4j_vector.py index 2ce1366ba6..dfb2f98690 100644 --- a/libs/community/langchain_community/vectorstores/neo4j_vector.py +++ b/libs/community/langchain_community/vectorstores/neo4j_vector.py @@ -17,7 +17,7 @@ from typing import ( from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.utils import get_from_env +from langchain_core.utils import get_from_dict_or_env from langchain_core.vectorstores import VectorStore from langchain_community.vectorstores.utils import DistanceStrategy @@ -155,7 +155,7 @@ class Neo4jVector(VectorStore): password: Optional[str] = None, url: Optional[str] = None, keyword_index_name: Optional[str] = "keyword", - database: str = "neo4j", + database: Optional[str] = None, index_name: str = "vector", node_label: str = "Chunk", embedding_node_property: str = "embedding", @@ -186,11 +186,19 @@ class Neo4jVector(VectorStore): # Handle if the credentials are environment variables # Support URL for backwards compatibility - url = os.environ.get("NEO4J_URL", url) - url = get_from_env("url", "NEO4J_URI", url) - username = get_from_env("username", "NEO4J_USERNAME", username) - password = get_from_env("password", "NEO4J_PASSWORD", password) - database = get_from_env("database", "NEO4J_DATABASE", database) + if not url: + url = os.environ.get("NEO4J_URL") + + url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI") + username = get_from_dict_or_env( + {"username": username}, "username", "NEO4J_USERNAME" + ) + password = get_from_dict_or_env( + {"password": password}, "password", "NEO4J_PASSWORD" + ) + database = get_from_dict_or_env( + {"database": database}, "database", "NEO4J_DATABASE", "neo4j" + ) self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) self._database = database