Add the extract types to diffbot graph transformer (#21315)

Previously, you could only extract triples (which Diffbot calls facts)
from Diffbot, which avoided isolated nodes. However, isolated nodes can
still be useful, for example for prefiltering, so users can now opt in
to extracting entities as well. The default behaviour is unchanged.
pull/21311/head^2
Tomaz Bratanic 3 weeks ago committed by GitHub
parent c038991590
commit 5b6d1a907d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,3 +1,4 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import requests
@ -6,6 +7,11 @@ from langchain_community.graphs.graph_document import GraphDocument, Node, Relat
from langchain_core.documents import Document
class TypeOption(str, Enum):
    """Kinds of data that can be requested for extraction.

    The members also subclass ``str``, so a list of them can be passed
    straight to ``",".join(...)`` to build a comma-separated field list.
    """

    # Source/target node pairs connected by a relationship type.
    FACTS = "facts"
    # Standalone entities, which may have no relationships attached.
    ENTITIES = "entities"
def format_property_key(s: str) -> str:
"""Formats a string to be used as a property key."""
@ -141,6 +147,7 @@ class DiffbotGraphTransformer:
include_qualifiers: bool = True,
include_evidence: bool = True,
simplified_schema: bool = True,
extract_types: List[TypeOption] = [TypeOption.FACTS],
) -> None:
"""
Initialize the graph transformer with various options.
@ -157,6 +164,11 @@ class DiffbotGraphTransformer:
Whether to include evidence for the relationships.
simplified_schema (bool):
Whether to use a simplified schema for relationships.
extract_types (List[TypeOption]):
A list of data types to extract. Only facts or entities
are supported. By default, the option is set to facts.
A fact represents a combination of source and target
nodes with a relationship type.
"""
self.diffbot_api_key = diffbot_api_key or get_from_env(
"diffbot_api_key", "DIFFBOT_API_KEY"
@ -167,6 +179,13 @@ class DiffbotGraphTransformer:
self.simplified_schema = None
if simplified_schema:
self.simplified_schema = SimplifiedSchema()
if not extract_types:
raise ValueError(
"`extract_types` cannot be an empty array. "
"Allowed values are 'facts', 'entities', or both."
)
self.extract_types = extract_types
def nlp_request(self, text: str) -> Dict[str, Any]:
"""
@ -185,7 +204,7 @@ class DiffbotGraphTransformer:
"lang": "en",
}
FIELDS = "facts"
FIELDS = ",".join(self.extract_types)
HOST = "nl.diffbot.com"
url = (
f"https://{HOST}/v1/?fields={FIELDS}&"
@ -209,77 +228,97 @@ class DiffbotGraphTransformer:
"""
# Return empty result if there are no facts
if "facts" not in payload or not payload["facts"]:
if ("facts" not in payload or not payload["facts"]) and (
"entities" not in payload or not payload["entities"]
):
return GraphDocument(nodes=[], relationships=[], source=document)
# Nodes are a custom class because we need to deduplicate
nodes_list = NodesList()
# Relationships are a list because we don't deduplicate nor anything else
relationships = list()
for record in payload["facts"]:
# Skip if the fact is below the threshold confidence
if record["confidence"] < self.fact_threshold_confidence:
continue
# TODO: It should probably be treated as a node property
if not record["value"]["allTypes"]:
continue
# Define source node
source_id = (
record["entity"]["allUris"][0]
if record["entity"]["allUris"]
else record["entity"]["name"]
)
source_label = record["entity"]["allTypes"][0]["name"].capitalize()
source_name = record["entity"]["name"]
source_node = Node(id=source_id, type=source_label)
nodes_list.add_node_property(
(source_id, source_label), {"name": source_name}
)
# Define target node
target_id = (
record["value"]["allUris"][0]
if record["value"]["allUris"]
else record["value"]["name"]
)
target_label = record["value"]["allTypes"][0]["name"].capitalize()
target_name = record["value"]["name"]
# Some facts are better suited as node properties
if target_label in FACT_TO_PROPERTY_TYPE:
if "entities" in payload and payload["entities"]:
for record in payload["entities"]:
# Ignore if it doesn't have a type
if not record["allTypes"]:
continue
# Define source node
source_id = (
record["allUris"][0] if record["allUris"] else record["name"]
)
source_label = record["allTypes"][0]["name"].capitalize()
source_name = record["name"]
nodes_list.add_node_property(
(source_id, source_label),
{format_property_key(record["property"]["name"]): target_name},
(source_id, source_label), {"name": source_name}
)
else: # Define relationship
# Define target node object
target_node = Node(id=target_id, type=target_label)
relationships = list()
# Relationships are a list because we don't deduplicate nor anything else
if "facts" in payload and payload["facts"]:
for record in payload["facts"]:
# Skip if the fact is below the threshold confidence
if record["confidence"] < self.fact_threshold_confidence:
continue
# TODO: It should probably be treated as a node property
if not record["value"]["allTypes"]:
continue
# Define source node
source_id = (
record["entity"]["allUris"][0]
if record["entity"]["allUris"]
else record["entity"]["name"]
)
source_label = record["entity"]["allTypes"][0]["name"].capitalize()
source_name = record["entity"]["name"]
source_node = Node(id=source_id, type=source_label)
nodes_list.add_node_property(
(target_id, target_label), {"name": target_name}
(source_id, source_label), {"name": source_name}
)
# Define relationship type
rel_type = record["property"]["name"].replace(" ", "_").upper()
if self.simplified_schema:
rel_type = self.simplified_schema.get_type(rel_type)
# Relationship qualifiers/properties
rel_properties = dict()
relationship_evidence = [el["passage"] for el in record["evidence"]][0]
if self.include_evidence:
rel_properties.update({"evidence": relationship_evidence})
if self.include_qualifiers and record.get("qualifiers"):
for property in record["qualifiers"]:
prop_key = format_property_key(property["property"]["name"])
rel_properties[prop_key] = property["value"]["name"]
relationship = Relationship(
source=source_node,
target=target_node,
type=rel_type,
properties=rel_properties,
# Define target node
target_id = (
record["value"]["allUris"][0]
if record["value"]["allUris"]
else record["value"]["name"]
)
relationships.append(relationship)
target_label = record["value"]["allTypes"][0]["name"].capitalize()
target_name = record["value"]["name"]
# Some facts are better suited as node properties
if target_label in FACT_TO_PROPERTY_TYPE:
nodes_list.add_node_property(
(source_id, source_label),
{format_property_key(record["property"]["name"]): target_name},
)
else: # Define relationship
# Define target node object
target_node = Node(id=target_id, type=target_label)
nodes_list.add_node_property(
(target_id, target_label), {"name": target_name}
)
# Define relationship type
rel_type = record["property"]["name"].replace(" ", "_").upper()
if self.simplified_schema:
rel_type = self.simplified_schema.get_type(rel_type)
# Relationship qualifiers/properties
rel_properties = dict()
relationship_evidence = [
el["passage"] for el in record["evidence"]
][0]
if self.include_evidence:
rel_properties.update({"evidence": relationship_evidence})
if self.include_qualifiers and record.get("qualifiers"):
for property in record["qualifiers"]:
prop_key = format_property_key(property["property"]["name"])
rel_properties[prop_key] = property["value"]["name"]
relationship = Relationship(
source=source_node,
target=target_node,
type=rel_type,
properties=rel_properties,
)
relationships.append(relationship)
return GraphDocument(
nodes=nodes_list.return_node_list(),

Loading…
Cancel
Save