experimental[patch]: Skip pydantic validation for llm graph transformer and fix JSON response where possible (#19915)

LLMs might sometimes return invalid response for LLM graph transformer.
Instead of failing due to pydantic validation, we skip it and manually
check and optionally fix error where we can, so that more information
gets extracted
This commit is contained in:
Tomaz Bratanic 2024-04-12 20:29:25 +02:00 committed by GitHub
parent 20f5cd7c95
commit a1b105ac00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,6 @@
import asyncio
from typing import Any, List, Optional, Sequence, Type, cast
import json
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document
@ -146,16 +147,133 @@ def create_simple_model(
def map_to_base_node(node: Any) -> Node:
"""Map the SimpleNode to the base Node."""
return Node(id=node.id.title(), type=node.type.capitalize())
return Node(id=node.id, type=node.type)
def map_to_base_relationship(rel: Any) -> Relationship:
"""Map the SimpleRelationship to the base Relationship."""
source = Node(id=rel.source_node_id.title(), type=rel.source_node_type.capitalize())
target = Node(id=rel.target_node_id.title(), type=rel.target_node_type.capitalize())
return Relationship(
source=source, target=target, type=rel.type.replace(" ", "_").upper()
)
source = Node(id=rel.source_node_id, type=rel.source_node_type)
target = Node(id=rel.target_node_id, type=rel.target_node_type)
return Relationship(source=source, target=target, type=rel.type)
def _parse_and_clean_json(
argument_json: Dict[str, Any],
) -> Tuple[List[Node], List[Relationship]]:
nodes = []
for node in argument_json["nodes"]:
if not node.get("id"): # Id is mandatory, skip this node
continue
nodes.append(
Node(
id=node["id"],
type=node.get("type"),
)
)
relationships = []
for rel in argument_json["relationships"]:
# Mandatory props
if (
not rel.get("source_node_id")
or not rel.get("target_node_id")
or not rel.get("type")
):
continue
# Node type copying if needed from node list
if not rel.get("source_node_type"):
try:
rel["source_node_type"] = [
el.get("type")
for el in argument_json["nodes"]
if el["id"] == rel["source_node_id"]
][0]
except IndexError:
rel["source_node_type"] = None
if not rel.get("target_node_type"):
try:
rel["target_node_type"] = [
el.get("type")
for el in argument_json["nodes"]
if el["id"] == rel["target_node_id"]
][0]
except IndexError:
rel["target_node_type"] = None
source_node = Node(
id=rel["source_node_id"],
type=rel["source_node_type"],
)
target_node = Node(
id=rel["target_node_id"],
type=rel["target_node_type"],
)
relationships.append(
Relationship(
source=source_node,
target=target_node,
type=rel["type"],
)
)
return nodes, relationships
def _format_nodes(nodes: List[Node]) -> List[Node]:
return [
Node(
id=el.id.title() if isinstance(el.id, str) else el.id,
type=el.type.capitalize(),
)
for el in nodes
]
def _format_relationships(rels: List[Relationship]) -> List[Relationship]:
return [
Relationship(
source=_format_nodes([el.source])[0],
target=_format_nodes([el.target])[0],
type=el.type.replace(" ", "_").upper(),
)
for el in rels
]
def _convert_to_graph_document(
raw_schema: Dict[Any, Any],
) -> Tuple[List[Node], List[Relationship]]:
# If there are validation errors
if not raw_schema["parsed"]:
try:
try: # OpenAI type response
argument_json = json.loads(
raw_schema["raw"].additional_kwargs["tool_calls"][0]["function"][
"arguments"
]
)
except Exception: # Google type response
argument_json = json.loads(
raw_schema["raw"].additional_kwargs["function_call"]["arguments"]
)
nodes, relationships = _parse_and_clean_json(argument_json)
except Exception: # If we can't parse JSON
return ([], [])
else: # If there are no validation errors use parsed pydantic object
parsed_schema: _Graph = raw_schema["parsed"]
nodes = (
[map_to_base_node(node) for node in parsed_schema.nodes]
if parsed_schema.nodes
else []
)
relationships = (
[map_to_base_relationship(rel) for rel in parsed_schema.relationships]
if parsed_schema.relationships
else []
)
# Title / Capitalize
return _format_nodes(nodes), _format_relationships(relationships)
class LLMGraphTransformer:
@ -213,7 +331,7 @@ class LLMGraphTransformer:
# Define chain
schema = create_simple_model(allowed_nodes, allowed_relationships)
structured_llm = llm.with_structured_output(schema)
structured_llm = llm.with_structured_output(schema, include_raw=True)
self.chain = prompt | structured_llm
def process_response(self, document: Document) -> GraphDocument:
@ -222,33 +340,29 @@ class LLMGraphTransformer:
an LLM based on the model's schema and constraints.
"""
text = document.page_content
raw_schema = cast(_Graph, self.chain.invoke({"input": text}))
nodes = (
[map_to_base_node(node) for node in raw_schema.nodes]
if raw_schema.nodes
else []
)
relationships = (
[map_to_base_relationship(rel) for rel in raw_schema.relationships]
if raw_schema.relationships
else []
)
raw_schema = self.chain.invoke({"input": text})
raw_schema = cast(Dict[Any, Any], raw_schema)
nodes, relationships = _convert_to_graph_document(raw_schema)
# Strict mode filtering
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
if self.allowed_nodes:
nodes = [node for node in nodes if node.type in self.allowed_nodes]
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
nodes = [
node for node in nodes if node.type.lower() in lower_allowed_nodes
]
relationships = [
rel
for rel in relationships
if rel.source.type in self.allowed_nodes
and rel.target.type in self.allowed_nodes
if rel.source.type.lower() in lower_allowed_nodes
and rel.target.type.lower() in lower_allowed_nodes
]
if self.allowed_relationships:
relationships = [
rel
for rel in relationships
if rel.type in self.allowed_relationships
if rel.type.lower()
in [el.lower() for el in self.allowed_relationships]
]
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
@ -273,33 +387,28 @@ class LLMGraphTransformer:
graph document.
"""
text = document.page_content
raw_schema = cast(_Graph, await self.chain.ainvoke({"input": text}))
nodes = (
[map_to_base_node(node) for node in raw_schema.nodes]
if raw_schema.nodes
else []
)
relationships = (
[map_to_base_relationship(rel) for rel in raw_schema.relationships]
if raw_schema.relationships
else []
)
raw_schema = await self.chain.ainvoke({"input": text})
raw_schema = cast(Dict[Any, Any], raw_schema)
nodes, relationships = _convert_to_graph_document(raw_schema)
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
if self.allowed_nodes:
nodes = [node for node in nodes if node.type in self.allowed_nodes]
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
nodes = [
node for node in nodes if node.type.lower() in lower_allowed_nodes
]
relationships = [
rel
for rel in relationships
if rel.source.type in self.allowed_nodes
and rel.target.type in self.allowed_nodes
if rel.source.type.lower() in lower_allowed_nodes
and rel.target.type.lower() in lower_allowed_nodes
]
if self.allowed_relationships:
relationships = [
rel
for rel in relationships
if rel.type in self.allowed_relationships
if rel.type.lower()
in [el.lower() for el in self.allowed_relationships]
]
return GraphDocument(nodes=nodes, relationships=relationships, source=document)