Add relik transformer config (#25019)

This commit is contained in:
Tomaz Bratanic 2024-08-03 14:41:45 +02:00 committed by GitHub
parent 1dcee68cb8
commit f9a11a9197
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,4 +1,5 @@
from typing import List, Sequence
import logging
from typing import Any, Dict, List, Sequence
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document
@ -22,23 +23,33 @@ class RelikGraphTransformer:
model (str): The name of the pretrained Relik model to use.
Default is "relik-ie/relik-relation-extraction-small-wikipedia".
relationship_confidence_threshold (float): The confidence threshold for
filtering relationships. Default is 0.0.
filtering relationships. Default is 0.1.
model_config (Dict[str, Any]): Additional configuration options for the
Relik model. Default is an empty dictionary.
ignore_self_loops (bool): Whether to ignore relationships where the
source and target nodes are the same. Default is True.
"""
# NOTE(review): this span looks like a rendered diff hunk whose +/- markers and
# indentation were lost; where two near-identical lines follow each other, the
# later one is presumably the post-change version -- confirm against the commit.
def __init__(
self,
model: str = "relik-ie/relik-relation-extraction-small-wikipedia",
relationship_confidence_threshold: float = 0.0,
model: str = "relik-ie/relik-relation-extraction-small",
relationship_confidence_threshold: float = 0.1,
# NOTE(review): `{}` is a mutable default, created once and shared by every
# call that omits the argument. In the lines visible here it is only unpacked
# with `**`, never mutated, but a `None` default replaced by a fresh dict would
# be the safer idiom.
model_config: Dict[str, Any] = {},
ignore_self_loops: bool = True,
) -> None:
try:
# Imported lazily so a missing optional dependency surfaces as the
# install hint raised below.
import relik # type: ignore
# Remove default INFO logging
# (this changes the level of the shared "relik" logger, so it affects every
# user of that logger, not only this instance).
logging.getLogger("relik").setLevel(logging.WARNING)
except ImportError:
raise ImportError(
"Could not import relik python package. "
"Please install it with `pip install relik`."
)
self.relik_model = relik.Relik.from_pretrained(model)
# The entries of model_config are forwarded as keyword arguments to
# Relik.from_pretrained.
self.relik_model = relik.Relik.from_pretrained(model, **model_config)
# Triples whose confidence is below this value are skipped in process_document.
self.relationship_confidence_threshold = relationship_confidence_threshold
# When true, process_document skips triples whose subject and object text match.
self.ignore_self_loops = ignore_self_loops
def process_document(self, document: Document) -> GraphDocument:
relik_out = self.relik_model(document.page_content)
@ -60,6 +71,9 @@ class RelikGraphTransformer:
# Ignore relationship if below confidence threshold
if triple.confidence < self.relationship_confidence_threshold:
continue
# Ignore self loops
if self.ignore_self_loops and triple.subject.text == triple.object.text:
continue
source_node = Node(
id=triple.subject.text,
type=DEFAULT_NODE_TYPE