Harrison/gml save (#1676)

Co-authored-by: Satoru Sakamoto <51464932+satoru814@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-03-14 20:00:22 -07:00 committed by GitHub
parent 656efe6ef3
commit 4d7fdb8957
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 97 additions and 9 deletions

View File

@ -178,16 +178,16 @@
"text": [ "text": [
"\n", "\n",
"\n", "\n",
"\u001B[1m> Entering new GraphQAChain chain...\u001B[0m\n", "\u001b[1m> Entering new GraphQAChain chain...\u001b[0m\n",
"Entities Extracted:\n", "Entities Extracted:\n",
"\u001B[32;1m\u001B[1;3m Intel\u001B[0m\n", "\u001b[32;1m\u001b[1;3m Intel\u001b[0m\n",
"Full Context:\n", "Full Context:\n",
"\u001B[32;1m\u001B[1;3mIntel is going to build $20 billion semiconductor \"mega site\"\n", "\u001b[32;1m\u001b[1;3mIntel is going to build $20 billion semiconductor \"mega site\"\n",
"Intel is building state-of-the-art factories\n", "Intel is building state-of-the-art factories\n",
"Intel is creating 10,000 new good-paying jobs\n", "Intel is creating 10,000 new good-paying jobs\n",
"Intel is helping build Silicon Valley\u001B[0m\n", "Intel is helping build Silicon Valley\u001b[0m\n",
"\n", "\n",
"\u001B[1m> Finished chain.\u001B[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
}, },
{ {
@ -205,10 +205,76 @@
"chain.run(\"what is Intel going to build?\")" "chain.run(\"what is Intel going to build?\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "410aafa0",
"metadata": {},
"source": [
"## Save the graph\n",
"We can also save and load the graph."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bc72cca0",
"metadata": {},
"outputs": [],
"source": [
"graph.write_to_gml(\"graph.gml\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "652760ad",
"metadata": {},
"outputs": [],
"source": [
"from langchain.indexes.graph import NetworkxEntityGraph"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "eae591fe",
"metadata": {},
"outputs": [],
"source": [
"loaded_graph = NetworkxEntityGraph.from_gml(\"graph.gml\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9439d419",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Intel', '$20 billion semiconductor \"mega site\"', 'is going to build'),\n",
" ('Intel', 'state-of-the-art factories', 'is building'),\n",
" ('Intel', '10,000 new good-paying jobs', 'is creating'),\n",
" ('Intel', 'Silicon Valley', 'is helping build'),\n",
" ('Field of dreams',\n",
" \"America's future will be built\",\n",
" 'is the ground on which')]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loaded_graph.get_triples()"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "f70b9ada", "id": "045796cf",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []

View File

@ -1,6 +1,7 @@
"""Networkx wrapper for graph operations.""" """Networkx wrapper for graph operations."""
from __future__ import annotations
from typing import List, NamedTuple, Tuple from typing import Any, List, NamedTuple, Optional, Tuple
KG_TRIPLE_DELIMITER = "<|>" KG_TRIPLE_DELIMITER = "<|>"
@ -48,7 +49,7 @@ def get_entities(entity_str: str) -> List[str]:
class NetworkxEntityGraph: class NetworkxEntityGraph:
"""Networkx wrapper for entity graph operations.""" """Networkx wrapper for entity graph operations."""
def __init__(self) -> None: def __init__(self, graph: Optional[Any] = None) -> None:
"""Create a new graph.""" """Create a new graph."""
try: try:
import networkx as nx import networkx as nx
@ -57,8 +58,24 @@ class NetworkxEntityGraph:
"Could not import networkx python package. " "Could not import networkx python package. "
"Please it install it with `pip install networkx`." "Please it install it with `pip install networkx`."
) )
if graph is not None:
if not isinstance(graph, nx.DiGraph):
raise ValueError("Passed in graph is not of correct shape")
self._graph = graph
else:
self._graph = nx.DiGraph()
self._graph = nx.DiGraph() @classmethod
def from_gml(cls, gml_path: str) -> NetworkxEntityGraph:
try:
import networkx as nx
except ImportError:
raise ValueError(
"Could not import networkx python package. "
"Please it install it with `pip install networkx`."
)
graph = nx.read_gml(gml_path)
return cls(graph)
def add_triple(self, knowledge_triple: KnowledgeTriple) -> None: def add_triple(self, knowledge_triple: KnowledgeTriple) -> None:
"""Add a triple to the graph.""" """Add a triple to the graph."""
@ -97,6 +114,11 @@ class NetworkxEntityGraph:
results.append(f"{src} {relation} {sink}") results.append(f"{src} {relation} {sink}")
return results return results
def write_to_gml(self, path: str) -> None:
import networkx as nx
nx.write_gml(self._graph, path)
def clear(self) -> None: def clear(self) -> None:
"""Clear the graph.""" """Clear the graph."""
self._graph.clear() self._graph.clear()