forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
31 lines
1.0 KiB
Python
31 lines
1.0 KiB
Python
"""Graph Index Creator."""
|
|
from typing import Optional, Type
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples
|
|
from langchain.indexes.prompts.knowledge_triplet_extraction import (
|
|
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
|
|
)
|
|
from langchain.llms.base import BaseLLM
|
|
|
|
|
|
class GraphIndexCreator(BaseModel):
    """Build an entity graph out of free text using an LLM.

    The configured ``llm`` is run with the knowledge-triple extraction
    prompt; the parsed triples are added to a fresh ``graph_type`` instance.
    """

    # Language model used for extraction; must be set before ``from_text``.
    llm: Optional[BaseLLM] = None
    # Graph class instantiated (with no arguments) on every ``from_text`` call.
    graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph

    def from_text(self, text: str) -> NetworkxEntityGraph:
        """Extract knowledge triples from ``text`` and return them as a graph.

        Args:
            text: Input passed to the extraction prompt as its ``text``
                variable.

        Returns:
            A new ``graph_type`` instance holding every parsed triple.

        Raises:
            ValueError: If ``llm`` has not been set.
        """
        model = self.llm
        if model is None:
            raise ValueError("llm should not be None")

        # The graph is created before the chain, as in the original ordering.
        index = self.graph_type()
        extractor = LLMChain(llm=model, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
        raw_triples = extractor.predict(text=text)
        for item in parse_triples(raw_triples):
            index.add_triple(item)
        return index
|