|
|
@ -1,6 +1,7 @@
|
|
|
|
"""Networkx wrapper for graph operations."""
|
|
|
|
"""Networkx wrapper for graph operations."""
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, NamedTuple, Tuple
|
|
|
|
from typing import Any, List, NamedTuple, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
KG_TRIPLE_DELIMITER = "<|>"
|
|
|
|
KG_TRIPLE_DELIMITER = "<|>"
|
|
|
|
|
|
|
|
|
|
|
@ -48,7 +49,7 @@ def get_entities(entity_str: str) -> List[str]:
|
|
|
|
class NetworkxEntityGraph:
|
|
|
|
class NetworkxEntityGraph:
|
|
|
|
"""Networkx wrapper for entity graph operations."""
|
|
|
|
"""Networkx wrapper for entity graph operations."""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
def __init__(self, graph: Optional[Any] = None) -> None:
|
|
|
|
"""Create a new graph."""
|
|
|
|
"""Create a new graph."""
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
import networkx as nx
|
|
|
|
import networkx as nx
|
|
|
@ -57,8 +58,24 @@ class NetworkxEntityGraph:
|
|
|
|
"Could not import networkx python package. "
|
|
|
|
"Could not import networkx python package. "
|
|
|
|
"Please it install it with `pip install networkx`."
|
|
|
|
"Please it install it with `pip install networkx`."
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
if graph is not None:
|
|
|
|
|
|
|
|
if not isinstance(graph, nx.DiGraph):
|
|
|
|
|
|
|
|
raise ValueError("Passed in graph is not of correct shape")
|
|
|
|
|
|
|
|
self._graph = graph
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self._graph = nx.DiGraph()
|
|
|
|
|
|
|
|
|
|
|
|
self._graph = nx.DiGraph()
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def from_gml(cls, gml_path: str) -> NetworkxEntityGraph:
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
import networkx as nx
|
|
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"Could not import networkx python package. "
|
|
|
|
|
|
|
|
"Please it install it with `pip install networkx`."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
graph = nx.read_gml(gml_path)
|
|
|
|
|
|
|
|
return cls(graph)
|
|
|
|
|
|
|
|
|
|
|
|
def add_triple(self, knowledge_triple: KnowledgeTriple) -> None:
|
|
|
|
def add_triple(self, knowledge_triple: KnowledgeTriple) -> None:
|
|
|
|
"""Add a triple to the graph."""
|
|
|
|
"""Add a triple to the graph."""
|
|
|
@ -97,6 +114,11 @@ class NetworkxEntityGraph:
|
|
|
|
results.append(f"{src} {relation} {sink}")
|
|
|
|
results.append(f"{src} {relation} {sink}")
|
|
|
|
return results
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def write_to_gml(self, path: str) -> None:
|
|
|
|
|
|
|
|
import networkx as nx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nx.write_gml(self._graph, path)
|
|
|
|
|
|
|
|
|
|
|
|
def clear(self) -> None:
|
|
|
|
def clear(self) -> None:
|
|
|
|
"""Clear the graph."""
|
|
|
|
"""Clear the graph."""
|
|
|
|
self._graph.clear()
|
|
|
|
self._graph.clear()
|
|
|
|