diff --git a/docs/modules/chains/examples/graph_nebula_qa.ipynb b/docs/modules/chains/examples/graph_nebula_qa.ipynb new file mode 100644 index 00000000..f4a77de2 --- /dev/null +++ b/docs/modules/chains/examples/graph_nebula_qa.ipynb @@ -0,0 +1,270 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "c94240f5", + "metadata": {}, + "source": [ + "# NebulaGraphQAChain\n", + "\n", + "This notebook shows how to use LLMs to provide a natural language interface to a NebulaGraph database." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "dbc0ee68", + "metadata": {}, + "source": [ + "You will need a running NebulaGraph cluster. For a quick start, you can launch a containerized cluster with the following script:\n", + "\n", + "```bash\n", + "curl -fsSL nebula-up.siwei.io/install.sh | bash\n", + "```\n", + "\n", + "Other options are:\n", + "- Install as a [Docker Desktop Extension](https://www.docker.com/blog/distributed-cloud-native-graph-database-nebulagraph-docker-extension/). See [here](https://docs.nebula-graph.io/3.5.0/2.quick-start/1.quick-start-workflow/)\n", + "- NebulaGraph Cloud Service. See [here](https://www.nebula-graph.io/cloud)\n", + "- Deploy from package, source code, or via Kubernetes. See [here](https://docs.nebula-graph.io/)\n", + "\n", + "Once the cluster is running, we can create the SPACE and SCHEMA for the database." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c82f4141", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install ipython-ngql\n", + "%load_ext ngql\n", + "\n", + "# connect ngql jupyter extension to nebulagraph\n", + "%ngql --address 127.0.0.1 --port 9669 --user root --password nebula\n", + "# create a new space\n", + "%ngql CREATE SPACE IF NOT EXISTS langchain(partition_num=1, replica_factor=1, vid_type=fixed_string(128));\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eda0809a", + "metadata": {}, + "outputs": [], + "source": [ + "# Wait for a few seconds for the space to be created.\n", + "%ngql USE langchain;" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "119fe35c", + "metadata": {}, + "source": [ + "Create the schema; for the full dataset, refer [here](https://www.siwei.io/en/nebulagraph-etl-dbt/)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5aa796ee", + "metadata": {}, + "outputs": [], + "source": [ + "%%ngql\n", + "CREATE TAG IF NOT EXISTS movie(name string);\n", + "CREATE TAG IF NOT EXISTS person(name string, birthdate string);\n", + "CREATE EDGE IF NOT EXISTS acted_in();\n", + "CREATE TAG INDEX IF NOT EXISTS person_index ON person(name(128));\n", + "CREATE TAG INDEX IF NOT EXISTS movie_index ON movie(name(128));" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "66e4799a", + "metadata": {}, + "source": [ + "Wait for schema creation to complete, then we can insert some data."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8eea530", + "metadata": {}, + "outputs": [], + "source": [ + "%%ngql\n", + "INSERT VERTEX person(name, birthdate) VALUES \"Al Pacino\":(\"Al Pacino\", \"1940-04-25\");\n", + "INSERT VERTEX movie(name) VALUES \"The Godfather II\":(\"The Godfather II\");\n", + "INSERT VERTEX movie(name) VALUES \"The Godfather Coda: The Death of Michael Corleone\":(\"The Godfather Coda: The Death of Michael Corleone\");\n", + "INSERT EDGE acted_in() VALUES \"Al Pacino\"->\"The Godfather II\":();\n", + "INSERT EDGE acted_in() VALUES \"Al Pacino\"->\"The Godfather Coda: The Death of Michael Corleone\":();" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "62812aad", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.chains import NebulaGraphQAChain\n", + "from langchain.graphs import NebulaGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0928915d", + "metadata": {}, + "outputs": [], + "source": [ + "graph = NebulaGraph(\n", + " space=\"langchain\",\n", + " username=\"root\",\n", + " password=\"nebula\",\n", + " address=\"127.0.0.1\",\n", + " port=9669,\n", + " session_pool_size=30,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "58c1a8ea", + "metadata": {}, + "source": [ + "## Refresh graph schema information\n", + "\n", + "If the schema of the database changes, you can refresh the schema information needed to generate nGQL statements." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e3de44f", + "metadata": {}, + "outputs": [], + "source": [ + "# graph.refresh_schema()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1fe76ccd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Node properties: [{'tag': 'movie', 'properties': [('name', 'string')]}, {'tag': 'person', 'properties': [('name', 'string'), ('birthdate', 'string')]}]\n", + "Edge properties: [{'edge': 'acted_in', 'properties': []}]\n", + "Relationships: ['(:person)-[:acted_in]->(:movie)']\n", + "\n" + ] + } + ], + "source": [ + "print(graph.get_schema)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "68a3c677", + "metadata": {}, + "source": [ + "## Querying the graph\n", + "\n", + "We can now use the graph QA chain to ask questions of the graph." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7476ce98", + "metadata": {}, + "outputs": [], + "source": [ + "chain = NebulaGraphQAChain.from_llm(\n", + " ChatOpenAI(temperature=0), graph=graph, verbose=True\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ef8ee27b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new NebulaGraphQAChain chain...\u001b[0m\n", + "Generated nGQL:\n", + "\u001b[32;1m\u001b[1;3mMATCH (p:`person`)-[:acted_in]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather II'\n", + "RETURN p.`person`.`name`\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m{'p.person.name': ['Al Pacino']}\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'Al Pacino played in The Godfather II.'" + ] + }, + "execution_count": 6, + "metadata": {},
"output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"Who played in The Godfather II?\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index 196803d3..72c46963 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -11,6 +11,7 @@ from langchain.chains.conversational_retrieval.base import ( from langchain.chains.flare.base import FlareChain from langchain.chains.graph_qa.base import GraphQAChain from langchain.chains.graph_qa.cypher import GraphCypherQAChain +from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.base import LLMBashChain @@ -67,4 +68,5 @@ __all__ = [ "ConversationalRetrievalChain", "OpenAPIEndpointChain", "FlareChain", + "NebulaGraphQAChain", ] diff --git a/langchain/chains/graph_qa/nebulagraph.py b/langchain/chains/graph_qa/nebulagraph.py new file mode 100644 index 00000000..ab4048fd --- /dev/null +++ b/langchain/chains/graph_qa/nebulagraph.py @@ -0,0 +1,91 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from pydantic import Field + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.chains.base import Chain +from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT +from langchain.chains.llm import LLMChain +from langchain.graphs.nebula_graph import NebulaGraph +from langchain.prompts.base import BasePromptTemplate + + +class NebulaGraphQAChain(Chain): + """Chain for question-answering against a graph by generating nGQL statements.""" + + graph: NebulaGraph = Field(exclude=True) + ngql_generation_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the output keys. 
+ + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, + ngql_prompt: BasePromptTemplate = NGQL_GENERATION_PROMPT, + **kwargs: Any, + ) -> NebulaGraphQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + ngql_generation_chain = LLMChain(llm=llm, prompt=ngql_prompt) + + return cls( + qa_chain=qa_chain, + ngql_generation_chain=ngql_generation_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """Generate nGQL statement, use it to look up in db and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + question = inputs[self.input_key] + + generated_ngql = self.ngql_generation_chain.run( + {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks + ) + + _run_manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_ngql, color="green", end="\n", verbose=self.verbose + ) + context = self.graph.query(generated_ngql) + + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + + result = self.qa_chain( + {"question": question, "context": context}, + callbacks=callbacks, + ) + return {self.output_key: result[self.qa_chain.output_key]} diff --git a/langchain/chains/graph_qa/prompts.py b/langchain/chains/graph_qa/prompts.py index aefb2489..df3f6f13 100644 --- a/langchain/chains/graph_qa/prompts.py +++ b/langchain/chains/graph_qa/prompts.py @@ -49,6 +49,29 @@ CYPHER_GENERATION_PROMPT = PromptTemplate( input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE ) +NEBULAGRAPH_EXTRA_INSTRUCTIONS = """ +Instructions: + +First, generate cypher then convert it to NebulaGraph Cypher dialect(rather than standard): +1. it requires explicit label specification when referring to node properties: v.`Foo`.name +2. it uses double equals sign for comparison: `==` rather than `=` +For instance: +```diff +< MATCH (p:person)-[:directed]->(m:movie) WHERE m.name = 'The Godfather II' +< RETURN p.name; +--- +> MATCH (p:`person`)-[:directed]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather II' +> RETURN p.`person`.`name`; +```\n""" + +NGQL_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace( + "Generate Cypher", "Generate NebulaGraph Cypher" +).replace("Instructions:", NEBULAGRAPH_EXTRA_INSTRUCTIONS) + +NGQL_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "question"], template=NGQL_GENERATION_TEMPLATE +) + CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers. The information part contains the provided information that you must use to construct an answer. The provided information is authorative, you must never doubt it or try to use your internal knowledge to correct it. 
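As a quick sanity check on the derived template, the prompt can be rendered directly; this is a minimal sketch, with an illustrative schema string and question borrowed from the notebook above rather than produced by this code:

```python
from langchain.chains.graph_qa.prompts import NGQL_GENERATION_PROMPT

# Illustrative values only: in NebulaGraphQAChain, `schema` comes from
# NebulaGraph.get_schema and `question` from the user's query.
schema = (
    "Node properties: [{'tag': 'movie', 'properties': [('name', 'string')]}, "
    "{'tag': 'person', 'properties': [('name', 'string'), ('birthdate', 'string')]}]\n"
    "Edge properties: [{'edge': 'acted_in', 'properties': []}]\n"
    "Relationships: ['(:person)-[:acted_in]->(:movie)']\n"
)

# Prints the full generation prompt, including the NebulaGraph dialect
# instructions spliced in above.
print(
    NGQL_GENERATION_PROMPT.format(
        schema=schema,
        question="Who played in The Godfather II?",
    )
)
```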
diff --git a/langchain/graphs/__init__.py b/langchain/graphs/__init__.py index 72b0976a..138efa79 100644 --- a/langchain/graphs/__init__.py +++ b/langchain/graphs/__init__.py @@ -1,5 +1,6 @@ """Graph implementations.""" +from langchain.graphs.nebula_graph import NebulaGraph from langchain.graphs.neo4j_graph import Neo4jGraph from langchain.graphs.networkx_graph import NetworkxEntityGraph -__all__ = ["NetworkxEntityGraph", "Neo4jGraph"] +__all__ = ["NetworkxEntityGraph", "Neo4jGraph", "NebulaGraph"] diff --git a/langchain/graphs/nebula_graph.py b/langchain/graphs/nebula_graph.py new file mode 100644 index 00000000..cfe07297 --- /dev/null +++ b/langchain/graphs/nebula_graph.py @@ -0,0 +1,201 @@ +import logging +from string import Template +from typing import Any, Dict + +rel_query = Template( + """ +MATCH ()-[e:`$edge_type`]->() + WITH e limit 1 +MATCH (m)-[:`$edge_type`]->(n) WHERE id(m) == src(e) AND id(n) == dst(e) +RETURN "(:" + tags(m)[0] + ")-[:$edge_type]->(:" + tags(n)[0] + ")" AS rels +""" +) + +RETRY_TIMES = 3 + + +class NebulaGraph: + """NebulaGraph wrapper for graph operations. + + It mirrors the interface of the Neo4jGraph wrapper so that graph QA + chains can use either backend. + """ + + def __init__( + self, + space: str, + username: str = "root", + password: str = "nebula", + address: str = "127.0.0.1", + port: int = 9669, + session_pool_size: int = 30, + ) -> None: + """Create a new NebulaGraph wrapper instance.""" + try: + import nebula3 # noqa: F401 + import pandas # noqa: F401 + except ImportError: + raise ValueError( + "Please install NebulaGraph Python client and pandas first: " + "`pip install nebula3-python pandas`" + ) + + self.username = username + self.password = password + self.address = address + self.port = port + self.space = space + self.session_pool_size = session_pool_size + + self.session_pool = self._get_session_pool() + self.schema = "" + # Set schema + try: + self.refresh_schema() + except Exception as e: + raise ValueError(f"Could not refresh schema. Error: {e}") + + def _get_session_pool(self) -> Any: + assert all( + [self.username, self.password, self.address, self.port, self.space] + ), ( + "Please provide all of the following parameters: " + "username, password, address, port, space" + ) + + from nebula3.Config import SessionPoolConfig + from nebula3.Exception import AuthFailedException, InValidHostname + from nebula3.gclient.net.SessionPool import SessionPool + + config = SessionPoolConfig() + config.max_size = self.session_pool_size + + try: + session_pool = SessionPool( + self.username, + self.password, + self.space, + [(self.address, self.port)], + ) + except InValidHostname: + raise ValueError( + "Could not connect to NebulaGraph database. " + "Please ensure that the address and port are correct" + ) + + try: + session_pool.init(config) + except AuthFailedException: + raise ValueError( + "Could not connect to NebulaGraph database. " + "Please ensure that the username and password are correct" + ) + except RuntimeError as e: + raise ValueError(f"Error initializing session pool. Error: {e}") + + return session_pool + + def __del__(self) -> None: + try: + self.session_pool.close() + except Exception as e: + logging.warning(f"Could not close session pool. 
Error: {e}") + + @property + def get_schema(self) -> str: + """Returns the schema of the NebulaGraph database""" + return self.schema + + def execute(self, query: str, params: dict = {}, retry: int = 0) -> Any: + """Query NebulaGraph database.""" + from nebula3.Exception import IOErrorException, NoValidSessionException + from nebula3.fbthrift.transport.TTransport import TTransportException + + try: + result = self.session_pool.execute_parameter(query, params) + if not result.is_succeeded(): + logging.warning( + f"Error executing query to NebulaGraph. " + f"Error: {result.error_msg()}\n" + f"Query: {query} \n" + ) + return result + + except NoValidSessionException: + logging.warning( + f"No valid session found in session pool. " + f"Please consider increasing the session pool size. " + f"Current size: {self.session_pool_size}" + ) + raise ValueError( + f"No valid session found in session pool. " + f"Please consider increasing the session pool size. " + f"Current size: {self.session_pool_size}" + ) + + except RuntimeError as e: + if retry < RETRY_TIMES: + retry += 1 + logging.warning( + f"Error executing query to NebulaGraph. " + f"Retrying ({retry}/{RETRY_TIMES})...\n" + f"query: {query} \n" + f"Error: {e}" + ) + return self.execute(query, params, retry) + else: + raise ValueError(f"Error executing query to NebulaGraph. Error: {e}") + + except (TTransportException, IOErrorException): + # connection issue, try to recreate session pool + if retry < RETRY_TIMES: + retry += 1 + logging.warning( + f"Connection issue with NebulaGraph. " + f"Retrying ({retry}/{RETRY_TIMES})...\n to recreate session pool" + ) + self.session_pool = self._get_session_pool() + return self.execute(query, params, retry) + + def refresh_schema(self) -> None: + """ + Refreshes the NebulaGraph schema information. 
+ """ + tags_schema, edge_types_schema, relationships = [], [], [] + for tag in self.execute("SHOW TAGS").column_values("Name"): + tag_name = tag.cast() + tag_schema = {"tag": tag_name, "properties": []} + r = self.execute(f"DESCRIBE TAG `{tag_name}`") + props, types = r.column_values("Field"), r.column_values("Type") + for i in range(r.row_size()): + tag_schema["properties"].append((props[i].cast(), types[i].cast())) + tags_schema.append(tag_schema) + for edge_type in self.execute("SHOW EDGES").column_values("Name"): + edge_type_name = edge_type.cast() + edge_schema = {"edge": edge_type_name, "properties": []} + r = self.execute(f"DESCRIBE EDGE `{edge_type_name}`") + props, types = r.column_values("Field"), r.column_values("Type") + for i in range(r.row_size()): + edge_schema["properties"].append((props[i].cast(), types[i].cast())) + edge_types_schema.append(edge_schema) + + # build relationships types + r = self.execute( + rel_query.substitute(edge_type=edge_type_name) + ).column_values("rels") + if len(r) > 0: + relationships.append(r[0].cast()) + + self.schema = ( + f"Node properties: {tags_schema}\n" + f"Edge properties: {edge_types_schema}\n" + f"Relationships: {relationships}\n" + ) + + def query(self, query: str, retry: int = 0) -> Dict[str, Any]: + result = self.execute(query, retry=retry) + columns = result.keys() + d: Dict[str, list] = {} + for col_num in range(result.col_size()): + col_name = columns[col_num] + col_list = result.column_values(col_name) + d[col_name] = [x.cast() for x in col_list] + return d diff --git a/poetry.lock b/poetry.lock index de89033b..bef0d8c9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2355,6 +2355,17 @@ smb = ["smbprotocol"] ssh = ["paramiko"] tqdm = ["tqdm"] +[[package]] +name = "future" +version = "0.18.3" +description = "Clean single-source support for Python 3 and 2" +category = "main" +optional = true +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "future-0.18.3.tar.gz", hash = "sha256:34a17436ed1e96697a86f9de3d15a3b0be01d8bc8de9c1dffd59fb8234ed5307"}, +] + [[package]] name = "gast" version = "0.4.0" @@ -5184,6 +5195,24 @@ nbformat = "*" sphinx = ">=1.8" traitlets = ">=5" +[[package]] +name = "nebula3-python" +version = "3.4.0" +description = "Python client for NebulaGraph V3.4" +category = "main" +optional = true +python-versions = "*" +files = [ + {file = "nebula3-python-3.4.0.tar.gz", hash = "sha256:47bd8b1b4bb2c2f0e5122bc147926cb50578a66841acf6a743cae4d0362c9eaa"}, + {file = "nebula3_python-3.4.0-py3-none-any.whl", hash = "sha256:d9d94c6a41712875e6ec866907de0789057f860e64f547f87d9f199439759dd6"}, +] + +[package.dependencies] +future = ">=0.18.0" +httplib2 = ">=0.20.0" +pytz = ">=2021.1" +six = ">=1.16.0" + [[package]] name = "neo4j" version = "5.9.0" @@ -11311,7 +11340,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "pymongo", "weaviate-client", "redis", "google-api-python-client", "google-auth", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", 
"duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "langkit", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "steamship", "pdfminer-six", "lxml", "requests-toolbelt", "neo4j", "openlm", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "momento", "singlestoredb", "tigrisdb"] +all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "pymongo", "weaviate-client", "redis", "google-api-python-client", "google-auth", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "langkit", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "steamship", "pdfminer-six", "lxml", "requests-toolbelt", "neo4j", "openlm", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "momento", "singlestoredb", "nebula3-python"] azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech"] cohere = ["cohere"] docarray = ["docarray"] @@ -11325,4 +11354,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "faeb3cc6feb059096a66ba8b1fd2271cd91e3a9553cb4f05e5ea493610ac3763" +content-hash = "836aca50cdc2300a684e7c039cfe8100b705f21496014bd49f64d92f4a6baa10" diff --git a/pyproject.toml b/pyproject.toml index 8221eeaa..3704f223 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,7 @@ bibtexparser = {version = "^1.4.0", optional = true} singlestoredb = {version = "^0.6.1", optional = true} pyspark = {version = "^3.4.0", optional = true} tigrisdb = {version = "^1.0.0b6", optional = true} +nebula3-python = {version = "^3.4.0", optional = true} langchainplus-sdk = ">=0.0.6" @@ -283,7 +284,7 @@ all = [ "azure-cognitiveservices-speech", "momento", "singlestoredb", - "tigrisdb" + "nebula3-python", ] # An extra used to be able to add extended testing. 
diff --git a/tests/integration_tests/test_nebulagraph.py b/tests/integration_tests/test_nebulagraph.py new file mode 100644 index 00000000..bf10f909 --- /dev/null +++ b/tests/integration_tests/test_nebulagraph.py @@ -0,0 +1,90 @@ +import unittest +from typing import Any +from unittest.mock import MagicMock, patch + +from langchain.graphs import NebulaGraph + + +class TestNebulaGraph(unittest.TestCase): + def setUp(self) -> None: + self.space = "test_space" + self.username = "test_user" + self.password = "test_password" + self.address = "test_address" + self.port = 1234 + self.session_pool_size = 10 + + @patch("nebula3.gclient.net.SessionPool.SessionPool") + def test_init(self, mock_session_pool: Any) -> None: + mock_session_pool.return_value = MagicMock() + nebula_graph = NebulaGraph( + self.space, + self.username, + self.password, + self.address, + self.port, + self.session_pool_size, + ) + self.assertEqual(nebula_graph.space, self.space) + self.assertEqual(nebula_graph.username, self.username) + self.assertEqual(nebula_graph.password, self.password) + self.assertEqual(nebula_graph.address, self.address) + self.assertEqual(nebula_graph.port, self.port) + self.assertEqual(nebula_graph.session_pool_size, self.session_pool_size) + + @patch("nebula3.gclient.net.SessionPool.SessionPool") + def test_get_session_pool(self, mock_session_pool: Any) -> None: + mock_session_pool.return_value = MagicMock() + nebula_graph = NebulaGraph( + self.space, + self.username, + self.password, + self.address, + self.port, + self.session_pool_size, + ) + session_pool = nebula_graph._get_session_pool() + self.assertIsInstance(session_pool, MagicMock) + + @patch("nebula3.gclient.net.SessionPool.SessionPool") + def test_del(self, mock_session_pool: Any) -> None: + mock_session_pool.return_value = MagicMock() + nebula_graph = NebulaGraph( + self.space, + self.username, + self.password, + self.address, + self.port, + self.session_pool_size, + ) + nebula_graph.__del__() + mock_session_pool.return_value.close.assert_called_once() + + @patch("nebula3.gclient.net.SessionPool.SessionPool") + def test_execute(self, mock_session_pool: Any) -> None: + mock_session_pool.return_value = MagicMock() + nebula_graph = NebulaGraph( + self.space, + self.username, + self.password, + self.address, + self.port, + self.session_pool_size, + ) + query = "SELECT * FROM test_table" + result = nebula_graph.execute(query) + self.assertIsInstance(result, MagicMock) + + @patch("nebula3.gclient.net.SessionPool.SessionPool") + def test_refresh_schema(self, mock_session_pool: Any) -> None: + mock_session_pool.return_value = MagicMock() + nebula_graph = NebulaGraph( + self.space, + self.username, + self.password, + self.address, + self.port, + self.session_pool_size, + ) + nebula_graph.refresh_schema() + self.assertNotEqual(nebula_graph.get_schema, "")
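A natural follow-up test, sketched here in the same mocking style, would pin down the column-to-dict shaping in `NebulaGraph.query`; the fake result object below is an assumption standing in for the client's ResultSet:

```python
from unittest.mock import MagicMock, patch

from langchain.graphs import NebulaGraph


@patch("nebula3.gclient.net.SessionPool.SessionPool")
def test_query_shapes_columns(mock_session_pool: MagicMock) -> None:
    mock_session_pool.return_value = MagicMock()
    graph = NebulaGraph(
        "test_space", "test_user", "test_password", "test_address", 1234, 10
    )

    # Fake ResultSet: one column whose single value casts to a Python str.
    fake_value = MagicMock()
    fake_value.cast.return_value = "Al Pacino"
    fake_result = MagicMock()
    fake_result.keys.return_value = ["p.person.name"]
    fake_result.col_size.return_value = 1
    fake_result.column_values.return_value = [fake_value]

    # Bypass the real execute() so only the result-shaping logic is exercised.
    with patch.object(graph, "execute", return_value=fake_result):
        assert graph.query("MATCH ...") == {"p.person.name": ["Al Pacino"]}
```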