@ -1,5 +1,6 @@
import asyncio
from typing import Any , List , Optional , Sequence , Type , cast
import json
from typing import Any , Dict , List , Optional , Sequence , Tuple , Type , cast
from langchain_community . graphs . graph_document import GraphDocument , Node , Relationship
from langchain_core . documents import Document
@ -146,16 +147,133 @@ def create_simple_model(
def map_to_base_node(node: Any) -> Node:
    """Map the SimpleNode to the base Node.

    The values are copied verbatim. Title-casing of ids and capitalization
    of types is applied afterwards, in one place, by ``_format_nodes`` (see
    ``_convert_to_graph_document``), so it is intentionally not repeated
    here.

    Args:
        node: An object exposing ``id`` and ``type`` attributes.

    Returns:
        A base ``Node`` carrying the same ``id`` and ``type``.
    """
    return Node(id=node.id, type=node.type)
def map_to_base_relationship(rel: Any) -> Relationship:
    """Map the SimpleRelationship to the base Relationship.

    The values are copied verbatim. Normalisation (title-cased ids,
    capitalized node types, upper-snake-case relationship type) is applied
    afterwards by ``_format_relationships`` (see
    ``_convert_to_graph_document``), so it is intentionally not repeated
    here.

    Args:
        rel: An object exposing ``source_node_id``, ``source_node_type``,
            ``target_node_id``, ``target_node_type`` and ``type``.

    Returns:
        A base ``Relationship`` whose endpoints are freshly built ``Node``
        objects.
    """
    source = Node(id=rel.source_node_id, type=rel.source_node_type)
    target = Node(id=rel.target_node_id, type=rel.target_node_type)
    return Relationship(source=source, target=target, type=rel.type)
def _parse_and_clean_json(
    argument_json: Dict[str, Any],
) -> Tuple[List[Node], List[Relationship]]:
    """Build nodes and relationships from a raw tool-call argument payload.

    Entries missing mandatory fields are dropped instead of failing the
    whole payload:

    * a node needs a truthy ``id``;
    * a relationship needs truthy ``source_node_id``, ``target_node_id``
      and ``type``.

    When a relationship does not state the type of one of its endpoints,
    the type is copied from the first entry of the node list with the same
    id, or left as ``None`` when there is no such entry.

    Args:
        argument_json: Decoded JSON arguments, expected to hold ``nodes``
            and ``relationships`` lists of dicts. A missing or empty list
            is treated as "no entries".

    Returns:
        A ``(nodes, relationships)`` tuple. The input dicts are not
        modified.
    """
    raw_nodes = argument_json.get("nodes") or []
    raw_relationships = argument_json.get("relationships") or []

    def _declared_type(node_id: Any) -> Any:
        # ``.get("id")`` rather than ``["id"]``: entries without an id are
        # tolerated above, so they must not blow up the look-up either.
        return next(
            (el.get("type") for el in raw_nodes if el.get("id") == node_id),
            None,
        )

    nodes = [
        Node(id=node["id"], type=node.get("type"))
        for node in raw_nodes
        if node.get("id")  # Id is mandatory, skip nodes without one
    ]

    relationships = []
    for rel in raw_relationships:
        source_id = rel.get("source_node_id")
        target_id = rel.get("target_node_id")
        rel_type = rel.get("type")
        # Mandatory props
        if not (source_id and target_id and rel_type):
            continue
        # Node type copying if needed from node list
        source_type = rel.get("source_node_type") or _declared_type(source_id)
        target_type = rel.get("target_node_type") or _declared_type(target_id)
        relationships.append(
            Relationship(
                source=Node(id=source_id, type=source_type),
                target=Node(id=target_id, type=target_type),
                type=rel_type,
            )
        )
    return nodes, relationships
def _format_nodes(nodes: List[Node]) -> List[Node]:
    """Return copies of ``nodes`` with normalised ids and types.

    String ids are title-cased and string types are capitalized. Values
    that are not strings are passed through unchanged; in particular a
    ``None`` type (which ``_parse_and_clean_json`` produces when no type
    could be determined) no longer raises ``AttributeError``.

    Args:
        nodes: Nodes to normalise.

    Returns:
        A new list of new ``Node`` objects, in the same order.
    """
    return [
        Node(
            id=el.id.title() if isinstance(el.id, str) else el.id,
            type=el.type.capitalize() if isinstance(el.type, str) else el.type,
        )
        for el in nodes
    ]
def _format_relationships(rels: List[Relationship]) -> List[Relationship]:
    """Return copies of ``rels`` with normalised endpoints and type.

    Both endpoints are passed through ``_format_nodes``; the relationship
    type has its spaces replaced by underscores and is upper-cased.

    Args:
        rels: Relationships to normalise.

    Returns:
        A new list of new ``Relationship`` objects, in the same order.
    """
    formatted = []
    for rel in rels:
        # _format_nodes works on lists, so wrap each endpoint and unwrap it.
        source = _format_nodes([rel.source])[0]
        target = _format_nodes([rel.target])[0]
        rel_type = rel.type.replace(" ", "_").upper()
        formatted.append(Relationship(source=source, target=target, type=rel_type))
    return formatted
def _convert_to_graph_document(
    raw_schema: Dict[Any, Any],
) -> Tuple[List[Node], List[Relationship]]:
    """Turn a structured-output result into formatted nodes and relationships.

    Args:
        raw_schema: Mapping with a ``"parsed"`` entry (the validated object,
            falsy when validation failed) and a ``"raw"`` entry (the raw
            message, whose ``additional_kwargs`` are consulted as a
            fallback).

    Returns:
        A ``(nodes, relationships)`` tuple, already passed through
        ``_format_nodes`` / ``_format_relationships``. ``([], [])`` is
        returned when validation failed and the raw arguments cannot be
        decoded or cleaned either.
    """
    if raw_schema["parsed"]:
        # No validation errors: use the parsed pydantic object
        parsed_schema: _Graph = raw_schema["parsed"]
        nodes = []
        if parsed_schema.nodes:
            nodes = [map_to_base_node(node) for node in parsed_schema.nodes]
        relationships = []
        if parsed_schema.relationships:
            relationships = [
                map_to_base_relationship(rel) for rel in parsed_schema.relationships
            ]
    else:
        # Validation errors: fall back to the raw tool-call arguments
        try:
            try:  # OpenAI type response
                tool_call = raw_schema["raw"].additional_kwargs["tool_calls"][0]
                argument_json = json.loads(tool_call["function"]["arguments"])
            except Exception:  # Google type response
                function_call = raw_schema["raw"].additional_kwargs["function_call"]
                argument_json = json.loads(function_call["arguments"])
            nodes, relationships = _parse_and_clean_json(argument_json)
        except Exception:  # If we can't parse JSON
            return ([], [])

    # Title / Capitalize
    return _format_nodes(nodes), _format_relationships(relationships)
class LLMGraphTransformer :
@ -213,7 +331,7 @@ class LLMGraphTransformer:
# Define chain
schema = create_simple_model ( allowed_nodes , allowed_relationships )
structured_llm = llm . with_structured_output ( schema )
structured_llm = llm . with_structured_output ( schema , include_raw = True )
self . chain = prompt | structured_llm
def process_response(self, document: Document) -> GraphDocument:
    """
    Process a single document, transforming it into a graph document using
    an LLM based on the model's schema and constraints.

    The chain is invoked exactly once; its result is converted with
    ``_convert_to_graph_document``. In strict mode the extracted nodes and
    relationships are then filtered against ``allowed_nodes`` /
    ``allowed_relationships`` using case-insensitive comparison.

    Args:
        document: The document whose ``page_content`` is sent to the chain.

    Returns:
        A ``GraphDocument`` holding the (possibly filtered) nodes and
        relationships, with ``document`` as its source.
    """
    text = document.page_content
    raw_schema = self.chain.invoke({"input": text})
    raw_schema = cast(Dict[Any, Any], raw_schema)
    nodes, relationships = _convert_to_graph_document(raw_schema)

    def _matches(value: Any, allowed_lower: Any) -> bool:
        # A non-string label can never equal an allowed (string) label.
        return isinstance(value, str) and value.lower() in allowed_lower

    # Strict mode filtering
    if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
        if self.allowed_nodes:
            lower_allowed_nodes = {el.lower() for el in self.allowed_nodes}
            nodes = [
                node for node in nodes if _matches(node.type, lower_allowed_nodes)
            ]
            relationships = [
                rel
                for rel in relationships
                if _matches(rel.source.type, lower_allowed_nodes)
                and _matches(rel.target.type, lower_allowed_nodes)
            ]
        if self.allowed_relationships:
            # Lower-case the allow-list once, not once per relationship.
            lower_allowed_relationships = {
                el.lower() for el in self.allowed_relationships
            }
            relationships = [
                rel
                for rel in relationships
                if _matches(rel.type, lower_allowed_relationships)
            ]

    return GraphDocument(nodes=nodes, relationships=relationships, source=document)
@ -273,33 +387,28 @@ class LLMGraphTransformer:
graph document .
"""
text = document . page_content
raw_schema = cast ( _Graph , await self . chain . ainvoke ( { " input " : text } ) )
nodes = (
[ map_to_base_node ( node ) for node in raw_schema . nodes ]
if raw_schema . nodes
else [ ]
)
relationships = (
[ map_to_base_relationship ( rel ) for rel in raw_schema . relationships ]
if raw_schema . relationships
else [ ]
)
raw_schema = await self . chain . ainvoke ( { " input " : text } )
raw_schema = cast ( Dict [ Any , Any ] , raw_schema )
nodes , relationships = _convert_to_graph_document ( raw_schema )
if self . strict_mode and ( self . allowed_nodes or self . allowed_relationships ) :
if self . allowed_nodes :
nodes = [ node for node in nodes if node . type in self . allowed_nodes ]
lower_allowed_nodes = [ el . lower ( ) for el in self . allowed_nodes ]
nodes = [
node for node in nodes if node . type . lower ( ) in lower_allowed_nodes
]
relationships = [
rel
for rel in relationships
if rel . source . type in self . allowed_nodes
and rel . target . type in self . allowed_nodes
if rel . source . type . lower ( ) in lower_ allowed_nodes
and rel . target . type . lower ( ) in lower_ allowed_nodes
]
if self . allowed_relationships :
relationships = [
rel
for rel in relationships
if rel . type in self . allowed_relationships
if rel . type . lower ( )
in [ el . lower ( ) for el in self . allowed_relationships ]
]
return GraphDocument ( nodes = nodes , relationships = relationships , source = document )