|
|
|
@ -5,9 +5,67 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
|
|
|
|
|
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
|
from langchain_core.messages import SystemMessage
|
|
|
|
|
from langchain_core.output_parsers import JsonOutputParser
|
|
|
|
|
from langchain_core.prompts import (
|
|
|
|
|
ChatPromptTemplate,
|
|
|
|
|
HumanMessagePromptTemplate,
|
|
|
|
|
PromptTemplate,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
examples = [
|
|
|
|
|
{
|
|
|
|
|
"text": (
|
|
|
|
|
"Adam is a software engineer in Microsoft since 2009, "
|
|
|
|
|
"and last year he got an award as the Best Talent"
|
|
|
|
|
),
|
|
|
|
|
"head": "Adam",
|
|
|
|
|
"head_type": "Person",
|
|
|
|
|
"relation": "WORKS_FOR",
|
|
|
|
|
"tail": "Microsoft",
|
|
|
|
|
"tail_type": "Company",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"text": (
|
|
|
|
|
"Adam is a software engineer in Microsoft since 2009, "
|
|
|
|
|
"and last year he got an award as the Best Talent"
|
|
|
|
|
),
|
|
|
|
|
"head": "Adam",
|
|
|
|
|
"head_type": "Person",
|
|
|
|
|
"relation": "HAS_AWARD",
|
|
|
|
|
"tail": "Best Talent",
|
|
|
|
|
"tail_type": "Award",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"text": (
|
|
|
|
|
"Microsoft is a tech company that provide "
|
|
|
|
|
"several products such as Microsoft Word"
|
|
|
|
|
),
|
|
|
|
|
"head": "Microsoft Word",
|
|
|
|
|
"head_type": "Product",
|
|
|
|
|
"relation": "PRODUCED_BY",
|
|
|
|
|
"tail": "Microsoft",
|
|
|
|
|
"tail_type": "Company",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"text": "Microsoft Word is a lightweight app that accessible offline",
|
|
|
|
|
"head": "Microsoft Word",
|
|
|
|
|
"head_type": "Product",
|
|
|
|
|
"relation": "HAS_CHARACTERISTIC",
|
|
|
|
|
"tail": "lightweight app",
|
|
|
|
|
"tail_type": "Characteristic",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"text": "Microsoft Word is a lightweight app that accessible offline",
|
|
|
|
|
"head": "Microsoft Word",
|
|
|
|
|
"head_type": "Product",
|
|
|
|
|
"relation": "HAS_CHARACTERISTIC",
|
|
|
|
|
"tail": "accessible offline",
|
|
|
|
|
"tail_type": "Characteristic",
|
|
|
|
|
},
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
system_prompt = (
|
|
|
|
|
"# Knowledge Graph Instructions for GPT-4\n"
|
|
|
|
|
"## 1. Overview\n"
|
|
|
|
@ -99,6 +157,103 @@ class _Graph(BaseModel):
|
|
|
|
|
relationships: Optional[List]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnstructuredRelation(BaseModel):
|
|
|
|
|
head: str = Field(
|
|
|
|
|
description=(
|
|
|
|
|
"extracted head entity like Microsoft, Apple, John. "
|
|
|
|
|
"Must use human-readable unique identifier."
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
head_type: str = Field(
|
|
|
|
|
description="type of the extracted head entity like Person, Company, etc"
|
|
|
|
|
)
|
|
|
|
|
relation: str = Field(description="relation between the head and the tail entities")
|
|
|
|
|
tail: str = Field(
|
|
|
|
|
description=(
|
|
|
|
|
"extracted tail entity like Microsoft, Apple, John. "
|
|
|
|
|
"Must use human-readable unique identifier."
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
tail_type: str = Field(
|
|
|
|
|
description="type of the extracted tail entity like Person, Company, etc"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_unstructured_prompt(
|
|
|
|
|
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
|
|
|
|
|
) -> ChatPromptTemplate:
|
|
|
|
|
node_labels_str = str(node_labels) if node_labels else ""
|
|
|
|
|
rel_types_str = str(rel_types) if rel_types else ""
|
|
|
|
|
base_string_parts = [
|
|
|
|
|
"You are a top-tier algorithm designed for extracting information in "
|
|
|
|
|
"structured formats to build a knowledge graph. Your task is to identify "
|
|
|
|
|
"the entities and relations requested with the user prompt from a given "
|
|
|
|
|
"text. You must generate the output in a JSON format containing a list "
|
|
|
|
|
'with JSON objects. Each object should have the keys: "head", '
|
|
|
|
|
'"head_type", "relation", "tail", and "tail_type". The "head" '
|
|
|
|
|
"key must contain the text of the extracted entity with one of the types "
|
|
|
|
|
"from the provided list in the user prompt.",
|
|
|
|
|
f'The "head_type" key must contain the type of the extracted head entity, '
|
|
|
|
|
f"which must be one of the types from {node_labels_str}."
|
|
|
|
|
if node_labels
|
|
|
|
|
else "",
|
|
|
|
|
f'The "relation" key must contain the type of relation between the "head" '
|
|
|
|
|
f'and the "tail", which must be one of the relations from {rel_types_str}.'
|
|
|
|
|
if rel_types
|
|
|
|
|
else "",
|
|
|
|
|
f'The "tail" key must represent the text of an extracted entity which is '
|
|
|
|
|
f'the tail of the relation, and the "tail_type" key must contain the type '
|
|
|
|
|
f"of the tail entity from {node_labels_str}."
|
|
|
|
|
if node_labels
|
|
|
|
|
else "",
|
|
|
|
|
"Attempt to extract as many entities and relations as you can. Maintain "
|
|
|
|
|
"Entity Consistency: When extracting entities, it's vital to ensure "
|
|
|
|
|
'consistency. If an entity, such as "John Doe", is mentioned multiple '
|
|
|
|
|
"times in the text but is referred to by different names or pronouns "
|
|
|
|
|
'(e.g., "Joe", "he"), always use the most complete identifier for '
|
|
|
|
|
"that entity. The knowledge graph should be coherent and easily "
|
|
|
|
|
"understandable, so maintaining consistency in entity references is "
|
|
|
|
|
"crucial.",
|
|
|
|
|
"IMPORTANT NOTES:\n- Don't add any explanation and text.",
|
|
|
|
|
]
|
|
|
|
|
system_prompt = "\n".join(filter(None, base_string_parts))
|
|
|
|
|
|
|
|
|
|
system_message = SystemMessage(content=system_prompt)
|
|
|
|
|
parser = JsonOutputParser(pydantic_object=UnstructuredRelation)
|
|
|
|
|
|
|
|
|
|
human_prompt = PromptTemplate(
|
|
|
|
|
template="""Based on the following example, extract entities and
|
|
|
|
|
relations from the provided text.\n\n
|
|
|
|
|
Use the following entity types, don't use other entity that is not defined below:
|
|
|
|
|
# ENTITY TYPES:
|
|
|
|
|
{node_labels}
|
|
|
|
|
|
|
|
|
|
Use the following relation types, don't use other relation that is not defined below:
|
|
|
|
|
# RELATION TYPES:
|
|
|
|
|
{rel_types}
|
|
|
|
|
|
|
|
|
|
Below are a number of examples of text and their extracted entities and relationships.
|
|
|
|
|
{examples}
|
|
|
|
|
|
|
|
|
|
For the following text, extract entities and relations as in the provided example.
|
|
|
|
|
{format_instructions}\nText: {input}""",
|
|
|
|
|
input_variables=["input"],
|
|
|
|
|
partial_variables={
|
|
|
|
|
"format_instructions": parser.get_format_instructions(),
|
|
|
|
|
"node_labels": node_labels,
|
|
|
|
|
"rel_types": rel_types,
|
|
|
|
|
"examples": examples,
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
human_message_prompt = HumanMessagePromptTemplate(prompt=human_prompt)
|
|
|
|
|
|
|
|
|
|
chat_prompt = ChatPromptTemplate.from_messages(
|
|
|
|
|
[system_message, human_message_prompt]
|
|
|
|
|
)
|
|
|
|
|
return chat_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_simple_model(
|
|
|
|
|
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
|
|
|
|
|
) -> Type[_Graph]:
|
|
|
|
@ -317,22 +472,38 @@ class LLMGraphTransformer:
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
allowed_nodes: List[str] = [],
|
|
|
|
|
allowed_relationships: List[str] = [],
|
|
|
|
|
prompt: ChatPromptTemplate = default_prompt,
|
|
|
|
|
prompt: Optional[ChatPromptTemplate] = None,
|
|
|
|
|
strict_mode: bool = True,
|
|
|
|
|
) -> None:
|
|
|
|
|
if not hasattr(llm, "with_structured_output"):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The specified LLM does not support the 'with_structured_output'. "
|
|
|
|
|
"Please ensure you are using an LLM that supports this feature."
|
|
|
|
|
)
|
|
|
|
|
self.allowed_nodes = allowed_nodes
|
|
|
|
|
self.allowed_relationships = allowed_relationships
|
|
|
|
|
self.strict_mode = strict_mode
|
|
|
|
|
self._function_call = True
|
|
|
|
|
# Check if the LLM really supports structured output
|
|
|
|
|
try:
|
|
|
|
|
llm.with_structured_output(_Graph)
|
|
|
|
|
except NotImplementedError:
|
|
|
|
|
self._function_call = False
|
|
|
|
|
if not self._function_call:
|
|
|
|
|
try:
|
|
|
|
|
import json_repair
|
|
|
|
|
|
|
|
|
|
# Define chain
|
|
|
|
|
schema = create_simple_model(allowed_nodes, allowed_relationships)
|
|
|
|
|
structured_llm = llm.with_structured_output(schema, include_raw=True)
|
|
|
|
|
self.chain = prompt | structured_llm
|
|
|
|
|
self.json_repair = json_repair
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"Could not import json_repair python package. "
|
|
|
|
|
"Please install it with `pip install json-repair`."
|
|
|
|
|
)
|
|
|
|
|
prompt = prompt or create_unstructured_prompt(
|
|
|
|
|
allowed_nodes, allowed_relationships
|
|
|
|
|
)
|
|
|
|
|
self.chain = prompt | llm
|
|
|
|
|
else:
|
|
|
|
|
# Define chain
|
|
|
|
|
schema = create_simple_model(allowed_nodes, allowed_relationships)
|
|
|
|
|
structured_llm = llm.with_structured_output(schema, include_raw=True)
|
|
|
|
|
prompt = prompt or default_prompt
|
|
|
|
|
self.chain = prompt | structured_llm
|
|
|
|
|
|
|
|
|
|
def process_response(self, document: Document) -> GraphDocument:
|
|
|
|
|
"""
|
|
|
|
@ -341,8 +512,27 @@ class LLMGraphTransformer:
|
|
|
|
|
"""
|
|
|
|
|
text = document.page_content
|
|
|
|
|
raw_schema = self.chain.invoke({"input": text})
|
|
|
|
|
raw_schema = cast(Dict[Any, Any], raw_schema)
|
|
|
|
|
nodes, relationships = _convert_to_graph_document(raw_schema)
|
|
|
|
|
if self._function_call:
|
|
|
|
|
raw_schema = cast(Dict[Any, Any], raw_schema)
|
|
|
|
|
nodes, relationships = _convert_to_graph_document(raw_schema)
|
|
|
|
|
else:
|
|
|
|
|
nodes_set = set()
|
|
|
|
|
relationships = []
|
|
|
|
|
parsed_json = self.json_repair.loads(raw_schema.content)
|
|
|
|
|
for rel in parsed_json:
|
|
|
|
|
# Nodes need to be deduplicated using a set
|
|
|
|
|
nodes_set.add((rel["head"], rel["head_type"]))
|
|
|
|
|
nodes_set.add((rel["tail"], rel["tail_type"]))
|
|
|
|
|
|
|
|
|
|
source_node = Node(id=rel["head"], type=rel["head_type"])
|
|
|
|
|
target_node = Node(id=rel["tail"], type=rel["tail_type"])
|
|
|
|
|
relationships.append(
|
|
|
|
|
Relationship(
|
|
|
|
|
source=source_node, target=target_node, type=rel["relation"]
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
# Create nodes list
|
|
|
|
|
nodes = [Node(id=el[0], type=el[1]) for el in list(nodes_set)]
|
|
|
|
|
|
|
|
|
|
# Strict mode filtering
|
|
|
|
|
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
|
|
|
|
|