mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
experimental[patch]: Skip pydantic validation for llm graph transformer and fix JSON response where possible (#19915)
LLMs sometimes return an invalid response for the LLM graph transformer. Instead of failing on pydantic validation, we skip validation, check the raw response manually, and fix errors where we can, so that more information gets extracted.
This commit is contained in:
parent
20f5cd7c95
commit
a1b105ac00
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Any, List, Optional, Sequence, Type, cast
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||
from langchain_core.documents import Document
|
||||
@ -146,16 +147,133 @@ def create_simple_model(
|
||||
|
||||
def map_to_base_node(node: Any) -> Node:
    """Map the SimpleNode to the base Node.

    The ``id`` and ``type`` values are copied over unchanged. Casing is
    normalized afterwards by ``_format_nodes``, which
    ``_convert_to_graph_document`` applies to every node it returns.

    Args:
        node: Object exposing ``id`` and ``type`` attributes.

    Returns:
        A ``Node`` built from those two attributes.
    """
    return Node(id=node.id, type=node.type)
|
||||
|
||||
|
||||
def map_to_base_relationship(rel: Any) -> Relationship:
    """Map the SimpleRelationship to the base Relationship.

    All values are copied over unchanged. Casing and the space-to-underscore
    rewrite of the relationship type are applied afterwards by
    ``_format_relationships`` inside ``_convert_to_graph_document``.

    Args:
        rel: Object exposing ``source_node_id``, ``source_node_type``,
            ``target_node_id``, ``target_node_type`` and ``type`` attributes.

    Returns:
        A ``Relationship`` whose source and target are freshly built ``Node``
        objects.
    """
    source = Node(id=rel.source_node_id, type=rel.source_node_type)
    target = Node(id=rel.target_node_id, type=rel.target_node_type)
    return Relationship(source=source, target=target, type=rel.type)
|
||||
|
||||
|
||||
def _parse_and_clean_json(
    argument_json: Dict[str, Any],
) -> Tuple[List[Node], List[Relationship]]:
    """Build nodes and relationships from a raw, unvalidated arguments dict.

    Entries that lack mandatory values are dropped instead of failing the
    whole extraction:

    * a node needs a truthy ``id``;
    * a relationship needs truthy ``source_node_id``, ``target_node_id`` and
      ``type``.

    When a relationship does not carry ``source_node_type`` or
    ``target_node_type``, the type is copied from the first entry in the node
    list with the same id, or left as ``None`` when there is no such entry.

    Args:
        argument_json: Decoded JSON with optional ``"nodes"`` and
            ``"relationships"`` lists of dicts. The dict is not modified.

    Returns:
        A ``(nodes, relationships)`` tuple.
    """
    # A missing or null list is treated as empty so the other list can still
    # be extracted.
    raw_nodes = argument_json.get("nodes") or []
    raw_relationships = argument_json.get("relationships") or []

    def _type_of(node_id: Any) -> Any:
        """Return the type of the first listed node with this id, else None."""
        # `.get("id")` rather than `["id"]`: the raw list may contain entries
        # without an id, and those must not abort the lookup with a KeyError.
        return next(
            (el.get("type") for el in raw_nodes if el.get("id") == node_id),
            None,
        )

    nodes = [
        Node(id=node["id"], type=node.get("type"))
        for node in raw_nodes
        if node.get("id")  # Id is mandatory, skip nodes without one
    ]

    relationships = []
    for rel in raw_relationships:
        # Mandatory props
        has_mandatory = (
            rel.get("source_node_id")
            and rel.get("target_node_id")
            and rel.get("type")
        )
        if not has_mandatory:
            continue

        # Node type copying if needed from node list. Kept in locals so the
        # caller's dicts are left untouched.
        source_type = rel.get("source_node_type") or _type_of(rel["source_node_id"])
        target_type = rel.get("target_node_type") or _type_of(rel["target_node_id"])

        relationships.append(
            Relationship(
                source=Node(id=rel["source_node_id"], type=source_type),
                target=Node(id=rel["target_node_id"], type=target_type),
                type=rel["type"],
            )
        )
    return nodes, relationships
|
||||
|
||||
|
||||
def _format_nodes(nodes: List[Node]) -> List[Node]:
    """Return new nodes with normalized casing.

    String ids are title-cased and string types are capitalized. Values that
    are not strings are passed through unchanged.

    Args:
        nodes: Nodes to normalize; the input objects are not modified.

    Returns:
        A list of new ``Node`` objects, in the same order as ``nodes``.
    """
    return [
        Node(
            id=el.id.title() if isinstance(el.id, str) else el.id,
            # `_parse_and_clean_json` passes a missing type on as None, so
            # guard the type the same way as the id instead of raising
            # AttributeError on `None.capitalize()`.
            # NOTE(review): whether `Node` itself accepts a None type is not
            # visible from here - confirm against the Node model.
            type=el.type.capitalize() if isinstance(el.type, str) else el.type,
        )
        for el in nodes
    ]
|
||||
|
||||
|
||||
def _format_relationships(rels: List[Relationship]) -> List[Relationship]:
    """Return new relationships with normalized endpoints and type.

    Both endpoints are passed through ``_format_nodes``. In the relationship
    type, spaces are replaced by underscores and the result is upper-cased.

    Args:
        rels: Relationships to normalize.

    Returns:
        A list of new ``Relationship`` objects, in the same order as ``rels``.
    """
    formatted = []
    for rel in rels:
        # `_format_nodes` works on lists, so wrap each endpoint and unwrap it.
        source = _format_nodes([rel.source])[0]
        target = _format_nodes([rel.target])[0]
        rel_type = rel.type.replace(" ", "_").upper()
        formatted.append(Relationship(source=source, target=target, type=rel_type))
    return formatted
|
||||
|
||||
|
||||
def _convert_to_graph_document(
    raw_schema: Dict[Any, Any],
) -> Tuple[List[Node], List[Relationship]]:
    """Turn a structured-output result into formatted nodes and relationships.

    Args:
        raw_schema: Mapping with a ``"parsed"`` entry (the validated object,
            falsy when validation failed) and a ``"raw"`` entry whose
            ``additional_kwargs`` hold the unvalidated tool/function call.

    Returns:
        A ``(nodes, relationships)`` tuple, normalized by ``_format_nodes``
        and ``_format_relationships``. When validation failed and the raw
        arguments cannot be recovered either, ``([], [])`` is returned.
    """
    parsed_schema: _Graph = raw_schema["parsed"]

    if parsed_schema:
        # If there are no validation errors use parsed pydantic object
        nodes = []
        if parsed_schema.nodes:
            nodes = [map_to_base_node(node) for node in parsed_schema.nodes]

        relationships = []
        if parsed_schema.relationships:
            relationships = [
                map_to_base_relationship(rel) for rel in parsed_schema.relationships
            ]
    else:
        # If there are validation errors, fall back to the raw arguments
        try:
            try:  # OpenAI type response
                tool_call = raw_schema["raw"].additional_kwargs["tool_calls"][0]
                argument_json = json.loads(tool_call["function"]["arguments"])
            except Exception:  # Google type response
                function_call = raw_schema["raw"].additional_kwargs["function_call"]
                argument_json = json.loads(function_call["arguments"])

            nodes, relationships = _parse_and_clean_json(argument_json)
        except Exception:  # If we can't parse JSON
            return ([], [])

    # Title / Capitalize
    return _format_nodes(nodes), _format_relationships(relationships)
|
||||
|
||||
|
||||
class LLMGraphTransformer:
|
||||
@ -213,7 +331,7 @@ class LLMGraphTransformer:
|
||||
|
||||
# Define chain
|
||||
schema = create_simple_model(allowed_nodes, allowed_relationships)
|
||||
structured_llm = llm.with_structured_output(schema)
|
||||
structured_llm = llm.with_structured_output(schema, include_raw=True)
|
||||
self.chain = prompt | structured_llm
|
||||
|
||||
def process_response(self, document: Document) -> GraphDocument:
|
||||
@ -222,33 +340,29 @@ class LLMGraphTransformer:
|
||||
an LLM based on the model's schema and constraints.
|
||||
"""
|
||||
text = document.page_content
|
||||
raw_schema = cast(_Graph, self.chain.invoke({"input": text}))
|
||||
nodes = (
|
||||
[map_to_base_node(node) for node in raw_schema.nodes]
|
||||
if raw_schema.nodes
|
||||
else []
|
||||
)
|
||||
relationships = (
|
||||
[map_to_base_relationship(rel) for rel in raw_schema.relationships]
|
||||
if raw_schema.relationships
|
||||
else []
|
||||
)
|
||||
raw_schema = self.chain.invoke({"input": text})
|
||||
raw_schema = cast(Dict[Any, Any], raw_schema)
|
||||
nodes, relationships = _convert_to_graph_document(raw_schema)
|
||||
|
||||
# Strict mode filtering
|
||||
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
|
||||
if self.allowed_nodes:
|
||||
nodes = [node for node in nodes if node.type in self.allowed_nodes]
|
||||
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
|
||||
nodes = [
|
||||
node for node in nodes if node.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.source.type in self.allowed_nodes
|
||||
and rel.target.type in self.allowed_nodes
|
||||
if rel.source.type.lower() in lower_allowed_nodes
|
||||
and rel.target.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
if self.allowed_relationships:
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.type in self.allowed_relationships
|
||||
if rel.type.lower()
|
||||
in [el.lower() for el in self.allowed_relationships]
|
||||
]
|
||||
|
||||
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
|
||||
@ -273,33 +387,28 @@ class LLMGraphTransformer:
|
||||
graph document.
|
||||
"""
|
||||
text = document.page_content
|
||||
raw_schema = cast(_Graph, await self.chain.ainvoke({"input": text}))
|
||||
|
||||
nodes = (
|
||||
[map_to_base_node(node) for node in raw_schema.nodes]
|
||||
if raw_schema.nodes
|
||||
else []
|
||||
)
|
||||
relationships = (
|
||||
[map_to_base_relationship(rel) for rel in raw_schema.relationships]
|
||||
if raw_schema.relationships
|
||||
else []
|
||||
)
|
||||
raw_schema = await self.chain.ainvoke({"input": text})
|
||||
raw_schema = cast(Dict[Any, Any], raw_schema)
|
||||
nodes, relationships = _convert_to_graph_document(raw_schema)
|
||||
|
||||
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
|
||||
if self.allowed_nodes:
|
||||
nodes = [node for node in nodes if node.type in self.allowed_nodes]
|
||||
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
|
||||
nodes = [
|
||||
node for node in nodes if node.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.source.type in self.allowed_nodes
|
||||
and rel.target.type in self.allowed_nodes
|
||||
if rel.source.type.lower() in lower_allowed_nodes
|
||||
and rel.target.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
if self.allowed_relationships:
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.type in self.allowed_relationships
|
||||
if rel.type.lower()
|
||||
in [el.lower() for el in self.allowed_relationships]
|
||||
]
|
||||
|
||||
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
|
||||
|
Loading…
Reference in New Issue
Block a user