From 4d7fdb8957fdaefced80efef75f6ea98ac14ca9c Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 14 Mar 2023 20:00:22 -0700 Subject: [PATCH] Harrison/gml save (#1676) Co-authored-by: Satoru Sakamoto <51464932+satoru814@users.noreply.github.com> --- .../indexes/chain_examples/graph_qa.ipynb | 78 +++++++++++++++++-- langchain/graphs/networkx_graph.py | 28 ++++++- 2 files changed, 97 insertions(+), 9 deletions(-) diff --git a/docs/modules/indexes/chain_examples/graph_qa.ipynb b/docs/modules/indexes/chain_examples/graph_qa.ipynb index c33af3db..59447024 100644 --- a/docs/modules/indexes/chain_examples/graph_qa.ipynb +++ b/docs/modules/indexes/chain_examples/graph_qa.ipynb @@ -178,16 +178,16 @@ "text": [ "\n", "\n", - "\u001B[1m> Entering new GraphQAChain chain...\u001B[0m\n", + "\u001b[1m> Entering new GraphQAChain chain...\u001b[0m\n", "Entities Extracted:\n", - "\u001B[32;1m\u001B[1;3m Intel\u001B[0m\n", + "\u001b[32;1m\u001b[1;3m Intel\u001b[0m\n", "Full Context:\n", - "\u001B[32;1m\u001B[1;3mIntel is going to build $20 billion semiconductor \"mega site\"\n", + "\u001b[32;1m\u001b[1;3mIntel is going to build $20 billion semiconductor \"mega site\"\n", "Intel is building state-of-the-art factories\n", "Intel is creating 10,000 new good-paying jobs\n", - "Intel is helping build Silicon Valley\u001B[0m\n", + "Intel is helping build Silicon Valley\u001b[0m\n", "\n", - "\u001B[1m> Finished chain.\u001B[0m\n" + "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { @@ -205,10 +205,76 @@ "chain.run(\"what is Intel going to build?\")" ] }, + { + "cell_type": "markdown", + "id": "410aafa0", + "metadata": {}, + "source": [ + "## Save the graph\n", + "We can also save and load the graph." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "bc72cca0", + "metadata": {}, + "outputs": [], + "source": [ + "graph.write_to_gml(\"graph.gml\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "652760ad", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.indexes.graph import NetworkxEntityGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "eae591fe", + "metadata": {}, + "outputs": [], + "source": [ + "loaded_graph = NetworkxEntityGraph.from_gml(\"graph.gml\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9439d419", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('Intel', '$20 billion semiconductor \"mega site\"', 'is going to build'),\n", + " ('Intel', 'state-of-the-art factories', 'is building'),\n", + " ('Intel', '10,000 new good-paying jobs', 'is creating'),\n", + " ('Intel', 'Silicon Valley', 'is helping build'),\n", + " ('Field of dreams',\n", + " \"America's future will be built\",\n", + " 'is the ground on which')]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded_graph.get_triples()" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "f70b9ada", + "id": "045796cf", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/graphs/networkx_graph.py b/langchain/graphs/networkx_graph.py index 66dd5bb2..c49a1ad9 100644 --- a/langchain/graphs/networkx_graph.py +++ b/langchain/graphs/networkx_graph.py @@ -1,6 +1,7 @@ """Networkx wrapper for graph operations.""" +from __future__ import annotations -from typing import List, NamedTuple, Tuple +from typing import Any, List, NamedTuple, Optional, Tuple KG_TRIPLE_DELIMITER = "<|>" @@ -48,7 +49,7 @@ def get_entities(entity_str: str) -> List[str]: class NetworkxEntityGraph: """Networkx wrapper for entity graph operations.""" - def __init__(self) -> None: + def __init__(self, graph: Optional[Any] = None) -> None: """Create a new graph.""" try: import networkx as nx @@ -57,8 +58,24 @@ class NetworkxEntityGraph: "Could not import networkx python package. " "Please it install it with `pip install networkx`." ) + if graph is not None: + if not isinstance(graph, nx.DiGraph): + raise ValueError("Passed in graph is not of correct shape") + self._graph = graph + else: + self._graph = nx.DiGraph() - self._graph = nx.DiGraph() + @classmethod + def from_gml(cls, gml_path: str) -> NetworkxEntityGraph: + try: + import networkx as nx + except ImportError: + raise ValueError( + "Could not import networkx python package. " + "Please it install it with `pip install networkx`." + ) + graph = nx.read_gml(gml_path) + return cls(graph) def add_triple(self, knowledge_triple: KnowledgeTriple) -> None: """Add a triple to the graph.""" @@ -97,6 +114,11 @@ class NetworkxEntityGraph: results.append(f"{src} {relation} {sink}") return results + def write_to_gml(self, path: str) -> None: + import networkx as nx + + nx.write_gml(self._graph, path) + def clear(self) -> None: """Clear the graph.""" self._graph.clear()