mirror of https://github.com/hwchase17/langchain
Harrison/nebula graph (#5865)
Co-authored-by: Wey Gu <weyl.gu@gmail.com> Co-authored-by: chenweisomebody <chenweisomebody@gmail.com>pull/5866/head
parent
658f8bdee7
commit
35cfd25db3
@ -0,0 +1,270 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "c94240f5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# NebulaGraphQAChain\n",
|
||||
"\n",
|
||||
"This notebook shows how to use LLMs to provide a natural language interface to NebulaGraph database."
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "dbc0ee68",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You will need to have a running NebulaGraph cluster, for which you can run a containerized cluster by running the following script:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"curl -fsSL nebula-up.siwei.io/install.sh | bash\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Other options are:\n",
|
||||
"- Install as a [Docker Desktop Extension](https://www.docker.com/blog/distributed-cloud-native-graph-database-nebulagraph-docker-extension/). See [here](https://docs.nebula-graph.io/3.5.0/2.quick-start/1.quick-start-workflow/)\n",
|
||||
"- NebulaGraph Cloud Service. See [here](https://www.nebula-graph.io/cloud)\n",
|
||||
"- Deploy from package, source code, or via Kubernetes. See [here](https://docs.nebula-graph.io/)\n",
|
||||
"\n",
|
||||
"Once the cluster is running, we could create the SPACE and SCHEMA for the database."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c82f4141",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install ipython-ngql\n",
|
||||
"%load_ext ngql\n",
|
||||
"\n",
|
||||
"# connect ngql jupyter extension to nebulagraph\n",
|
||||
"%ngql --address 127.0.0.1 --port 9669 --user root --password nebula\n",
|
||||
"# create a new space\n",
|
||||
"%ngql CREATE SPACE IF NOT EXISTS langchain(partition_num=1, replica_factor=1, vid_type=fixed_string(128));\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eda0809a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Wait for a few seconds for the space to be created.\n",
|
||||
"%ngql USE langchain;"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "119fe35c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create the schema, for full dataset, refer [here](https://www.siwei.io/en/nebulagraph-etl-dbt/)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5aa796ee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%ngql\n",
|
||||
"CREATE TAG IF NOT EXISTS movie(name string);\n",
|
||||
"CREATE TAG IF NOT EXISTS person(name string, birthdate string);\n",
|
||||
"CREATE EDGE IF NOT EXISTS acted_in();\n",
|
||||
"CREATE TAG INDEX IF NOT EXISTS person_index ON person(name(128));\n",
|
||||
"CREATE TAG INDEX IF NOT EXISTS movie_index ON movie(name(128));"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "66e4799a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Wait for schema creation to complete, then we can insert some data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "d8eea530",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"UsageError: Cell magic `%%ngql` not found.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%ngql\n",
|
||||
"INSERT VERTEX person(name, birthdate) VALUES \"Al Pacino\":(\"Al Pacino\", \"1940-04-25\");\n",
|
||||
"INSERT VERTEX movie(name) VALUES \"The Godfather II\":(\"The Godfather II\");\n",
|
||||
"INSERT VERTEX movie(name) VALUES \"The Godfather Coda: The Death of Michael Corleone\":(\"The Godfather Coda: The Death of Michael Corleone\");\n",
|
||||
"INSERT EDGE acted_in() VALUES \"Al Pacino\"->\"The Godfather II\":();\n",
|
||||
"INSERT EDGE acted_in() VALUES \"Al Pacino\"->\"The Godfather Coda: The Death of Michael Corleone\":();"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "62812aad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.chains import NebulaGraphQAChain\n",
|
||||
"from langchain.graphs import NebulaGraph"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "0928915d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"graph = NebulaGraph(\n",
|
||||
" space=\"langchain\",\n",
|
||||
" username=\"root\",\n",
|
||||
" password=\"nebula\",\n",
|
||||
" address=\"127.0.0.1\",\n",
|
||||
" port=9669,\n",
|
||||
" session_pool_size=30,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "58c1a8ea",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Refresh graph schema information\n",
|
||||
"\n",
|
||||
"If the schema of database changes, you can refresh the schema information needed to generate nGQL statements."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4e3de44f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# graph.refresh_schema()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "1fe76ccd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Node properties: [{'tag': 'movie', 'properties': [('name', 'string')]}, {'tag': 'person', 'properties': [('name', 'string'), ('birthdate', 'string')]}]\n",
|
||||
"Edge properties: [{'edge': 'acted_in', 'properties': []}]\n",
|
||||
"Relationships: ['(:person)-[:acted_in]->(:movie)']\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(graph.get_schema)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "68a3c677",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Querying the graph\n",
|
||||
"\n",
|
||||
"We can now use the graph cypher QA chain to ask question of the graph"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "7476ce98",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = NebulaGraphQAChain.from_llm(\n",
|
||||
" ChatOpenAI(temperature=0), graph=graph, verbose=True\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "ef8ee27b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new NebulaGraphQAChain chain...\u001b[0m\n",
|
||||
"Generated nGQL:\n",
|
||||
"\u001b[32;1m\u001b[1;3mMATCH (p:`person`)-[:acted_in]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather II'\n",
|
||||
"RETURN p.`person`.`name`\u001b[0m\n",
|
||||
"Full Context:\n",
|
||||
"\u001b[32;1m\u001b[1;3m{'p.person.name': ['Al Pacino']}\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Al Pacino played in The Godfather II.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.run(\"Who played in The Godfather II?\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,91 @@
|
||||
"""Question answering over a graph."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs.nebula_graph import NebulaGraph
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
class NebulaGraphQAChain(Chain):
|
||||
"""Chain for question-answering against a graph by generating nGQL statements."""
|
||||
|
||||
graph: NebulaGraph = Field(exclude=True)
|
||||
ngql_generation_chain: LLMChain
|
||||
qa_chain: LLMChain
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the output keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
_output_keys = [self.output_key]
|
||||
return _output_keys
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
*,
|
||||
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
|
||||
ngql_prompt: BasePromptTemplate = NGQL_GENERATION_PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> NebulaGraphQAChain:
|
||||
"""Initialize from LLM."""
|
||||
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
||||
ngql_generation_chain = LLMChain(llm=llm, prompt=ngql_prompt)
|
||||
|
||||
return cls(
|
||||
qa_chain=qa_chain,
|
||||
ngql_generation_chain=ngql_generation_chain,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""Generate nGQL statement, use it to look up in db and answer question."""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
question = inputs[self.input_key]
|
||||
|
||||
generated_ngql = self.ngql_generation_chain.run(
|
||||
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
|
||||
)
|
||||
|
||||
_run_manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
generated_ngql, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
context = self.graph.query(generated_ngql)
|
||||
|
||||
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
str(context), color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
|
||||
result = self.qa_chain(
|
||||
{"question": question, "context": context},
|
||||
callbacks=callbacks,
|
||||
)
|
||||
return {self.output_key: result[self.qa_chain.output_key]}
|
@ -1,5 +1,6 @@
|
||||
"""Graph implementations."""
|
||||
from langchain.graphs.nebula_graph import NebulaGraph
|
||||
from langchain.graphs.neo4j_graph import Neo4jGraph
|
||||
from langchain.graphs.networkx_graph import NetworkxEntityGraph
|
||||
|
||||
__all__ = ["NetworkxEntityGraph", "Neo4jGraph"]
|
||||
__all__ = ["NetworkxEntityGraph", "Neo4jGraph", "NebulaGraph"]
|
||||
|
@ -0,0 +1,201 @@
|
||||
import logging
|
||||
from string import Template
|
||||
from typing import Any, Dict
|
||||
|
||||
rel_query = Template(
|
||||
"""
|
||||
MATCH ()-[e:`$edge_type`]->()
|
||||
WITH e limit 1
|
||||
MATCH (m)-[:`$edge_type`]->(n) WHERE id(m) == src(e) AND id(n) == dst(e)
|
||||
RETURN "(:" + tags(m)[0] + ")-[:$edge_type]->(:" + tags(n)[0] + ")" AS rels
|
||||
"""
|
||||
)
|
||||
|
||||
RETRY_TIMES = 3
|
||||
|
||||
|
||||
class NebulaGraph:
|
||||
"""NebulaGraph wrapper for graph operations
|
||||
NebulaGraph inherits methods from Neo4jGraph to bring ease to the user space.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
space: str,
|
||||
username: str = "root",
|
||||
password: str = "nebula",
|
||||
address: str = "127.0.0.1",
|
||||
port: int = 9669,
|
||||
session_pool_size: int = 30,
|
||||
) -> None:
|
||||
"""Create a new NebulaGraph wrapper instance."""
|
||||
try:
|
||||
import nebula3 # noqa: F401
|
||||
import pandas # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Please install NebulaGraph Python client and pandas first: "
|
||||
"`pip install nebula3-python pandas`"
|
||||
)
|
||||
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.address = address
|
||||
self.port = port
|
||||
self.space = space
|
||||
self.session_pool_size = session_pool_size
|
||||
|
||||
self.session_pool = self._get_session_pool()
|
||||
self.schema = ""
|
||||
# Set schema
|
||||
try:
|
||||
self.refresh_schema()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not refresh schema. Error: {e}")
|
||||
|
||||
def _get_session_pool(self) -> Any:
|
||||
assert all(
|
||||
[self.username, self.password, self.address, self.port, self.space]
|
||||
), (
|
||||
"Please provide all of the following parameters: "
|
||||
"username, password, address, port, space"
|
||||
)
|
||||
|
||||
from nebula3.Config import SessionPoolConfig
|
||||
from nebula3.Exception import AuthFailedException, InValidHostname
|
||||
from nebula3.gclient.net.SessionPool import SessionPool
|
||||
|
||||
config = SessionPoolConfig()
|
||||
config.max_size = self.session_pool_size
|
||||
|
||||
try:
|
||||
session_pool = SessionPool(
|
||||
self.username,
|
||||
self.password,
|
||||
self.space,
|
||||
[(self.address, self.port)],
|
||||
)
|
||||
except InValidHostname:
|
||||
raise ValueError(
|
||||
"Could not connect to NebulaGraph database. "
|
||||
"Please ensure that the address and port are correct"
|
||||
)
|
||||
|
||||
try:
|
||||
session_pool.init(config)
|
||||
except AuthFailedException:
|
||||
raise ValueError(
|
||||
"Could not connect to NebulaGraph database. "
|
||||
"Please ensure that the username and password are correct"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
raise ValueError(f"Error initializing session pool. Error: {e}")
|
||||
|
||||
return session_pool
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
self.session_pool.close()
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not close session pool. Error: {e}")
|
||||
|
||||
@property
|
||||
def get_schema(self) -> str:
|
||||
"""Returns the schema of the NebulaGraph database"""
|
||||
return self.schema
|
||||
|
||||
def execute(self, query: str, params: dict = {}, retry: int = 0) -> Any:
|
||||
"""Query NebulaGraph database."""
|
||||
from nebula3.Exception import IOErrorException, NoValidSessionException
|
||||
from nebula3.fbthrift.transport.TTransport import TTransportException
|
||||
|
||||
try:
|
||||
result = self.session_pool.execute_parameter(query, params)
|
||||
if not result.is_succeeded():
|
||||
logging.warning(
|
||||
f"Error executing query to NebulaGraph. "
|
||||
f"Error: {result.error_msg()}\n"
|
||||
f"Query: {query} \n"
|
||||
)
|
||||
return result
|
||||
|
||||
except NoValidSessionException:
|
||||
logging.warning(
|
||||
f"No valid session found in session pool. "
|
||||
f"Please consider increasing the session pool size. "
|
||||
f"Current size: {self.session_pool_size}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"No valid session found in session pool. "
|
||||
f"Please consider increasing the session pool size. "
|
||||
f"Current size: {self.session_pool_size}"
|
||||
)
|
||||
|
||||
except RuntimeError as e:
|
||||
if retry < RETRY_TIMES:
|
||||
retry += 1
|
||||
logging.warning(
|
||||
f"Error executing query to NebulaGraph. "
|
||||
f"Retrying ({retry}/{RETRY_TIMES})...\n"
|
||||
f"query: {query} \n"
|
||||
f"Error: {e}"
|
||||
)
|
||||
return self.execute(query, params, retry)
|
||||
else:
|
||||
raise ValueError(f"Error executing query to NebulaGraph. Error: {e}")
|
||||
|
||||
except (TTransportException, IOErrorException):
|
||||
# connection issue, try to recreate session pool
|
||||
if retry < RETRY_TIMES:
|
||||
retry += 1
|
||||
logging.warning(
|
||||
f"Connection issue with NebulaGraph. "
|
||||
f"Retrying ({retry}/{RETRY_TIMES})...\n to recreate session pool"
|
||||
)
|
||||
self.session_pool = self._get_session_pool()
|
||||
return self.execute(query, params, retry)
|
||||
|
||||
def refresh_schema(self) -> None:
|
||||
"""
|
||||
Refreshes the NebulaGraph schema information.
|
||||
"""
|
||||
tags_schema, edge_types_schema, relationships = [], [], []
|
||||
for tag in self.execute("SHOW TAGS").column_values("Name"):
|
||||
tag_name = tag.cast()
|
||||
tag_schema = {"tag": tag_name, "properties": []}
|
||||
r = self.execute(f"DESCRIBE TAG `{tag_name}`")
|
||||
props, types = r.column_values("Field"), r.column_values("Type")
|
||||
for i in range(r.row_size()):
|
||||
tag_schema["properties"].append((props[i].cast(), types[i].cast()))
|
||||
tags_schema.append(tag_schema)
|
||||
for edge_type in self.execute("SHOW EDGES").column_values("Name"):
|
||||
edge_type_name = edge_type.cast()
|
||||
edge_schema = {"edge": edge_type_name, "properties": []}
|
||||
r = self.execute(f"DESCRIBE EDGE `{edge_type_name}`")
|
||||
props, types = r.column_values("Field"), r.column_values("Type")
|
||||
for i in range(r.row_size()):
|
||||
edge_schema["properties"].append((props[i].cast(), types[i].cast()))
|
||||
edge_types_schema.append(edge_schema)
|
||||
|
||||
# build relationships types
|
||||
r = self.execute(
|
||||
rel_query.substitute(edge_type=edge_type_name)
|
||||
).column_values("rels")
|
||||
if len(r) > 0:
|
||||
relationships.append(r[0].cast())
|
||||
|
||||
self.schema = (
|
||||
f"Node properties: {tags_schema}\n"
|
||||
f"Edge properties: {edge_types_schema}\n"
|
||||
f"Relationships: {relationships}\n"
|
||||
)
|
||||
|
||||
def query(self, query: str, retry: int = 0) -> Dict[str, Any]:
|
||||
result = self.execute(query, retry=retry)
|
||||
columns = result.keys()
|
||||
d: Dict[str, list] = {}
|
||||
for col_num in range(result.col_size()):
|
||||
col_name = columns[col_num]
|
||||
col_list = result.column_values(col_name)
|
||||
d[col_name] = [x.cast() for x in col_list]
|
||||
return d
|
@ -0,0 +1,90 @@
|
||||
import unittest
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain.graphs import NebulaGraph
|
||||
|
||||
|
||||
class TestNebulaGraph(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.space = "test_space"
|
||||
self.username = "test_user"
|
||||
self.password = "test_password"
|
||||
self.address = "test_address"
|
||||
self.port = 1234
|
||||
self.session_pool_size = 10
|
||||
|
||||
@patch("nebula3.gclient.net.SessionPool.SessionPool")
|
||||
def test_init(self, mock_session_pool: Any) -> None:
|
||||
mock_session_pool.return_value = MagicMock()
|
||||
nebula_graph = NebulaGraph(
|
||||
self.space,
|
||||
self.username,
|
||||
self.password,
|
||||
self.address,
|
||||
self.port,
|
||||
self.session_pool_size,
|
||||
)
|
||||
self.assertEqual(nebula_graph.space, self.space)
|
||||
self.assertEqual(nebula_graph.username, self.username)
|
||||
self.assertEqual(nebula_graph.password, self.password)
|
||||
self.assertEqual(nebula_graph.address, self.address)
|
||||
self.assertEqual(nebula_graph.port, self.port)
|
||||
self.assertEqual(nebula_graph.session_pool_size, self.session_pool_size)
|
||||
|
||||
@patch("nebula3.gclient.net.SessionPool.SessionPool")
|
||||
def test_get_session_pool(self, mock_session_pool: Any) -> None:
|
||||
mock_session_pool.return_value = MagicMock()
|
||||
nebula_graph = NebulaGraph(
|
||||
self.space,
|
||||
self.username,
|
||||
self.password,
|
||||
self.address,
|
||||
self.port,
|
||||
self.session_pool_size,
|
||||
)
|
||||
session_pool = nebula_graph._get_session_pool()
|
||||
self.assertIsInstance(session_pool, MagicMock)
|
||||
|
||||
@patch("nebula3.gclient.net.SessionPool.SessionPool")
|
||||
def test_del(self, mock_session_pool: Any) -> None:
|
||||
mock_session_pool.return_value = MagicMock()
|
||||
nebula_graph = NebulaGraph(
|
||||
self.space,
|
||||
self.username,
|
||||
self.password,
|
||||
self.address,
|
||||
self.port,
|
||||
self.session_pool_size,
|
||||
)
|
||||
nebula_graph.__del__()
|
||||
mock_session_pool.return_value.close.assert_called_once()
|
||||
|
||||
@patch("nebula3.gclient.net.SessionPool.SessionPool")
|
||||
def test_execute(self, mock_session_pool: Any) -> None:
|
||||
mock_session_pool.return_value = MagicMock()
|
||||
nebula_graph = NebulaGraph(
|
||||
self.space,
|
||||
self.username,
|
||||
self.password,
|
||||
self.address,
|
||||
self.port,
|
||||
self.session_pool_size,
|
||||
)
|
||||
query = "SELECT * FROM test_table"
|
||||
result = nebula_graph.execute(query)
|
||||
self.assertIsInstance(result, MagicMock)
|
||||
|
||||
@patch("nebula3.gclient.net.SessionPool.SessionPool")
|
||||
def test_refresh_schema(self, mock_session_pool: Any) -> None:
|
||||
mock_session_pool.return_value = MagicMock()
|
||||
nebula_graph = NebulaGraph(
|
||||
self.space,
|
||||
self.username,
|
||||
self.password,
|
||||
self.address,
|
||||
self.port,
|
||||
self.session_pool_size,
|
||||
)
|
||||
nebula_graph.refresh_schema()
|
||||
self.assertNotEqual(nebula_graph.get_schema, "")
|
Loading…
Reference in New Issue