mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Allow passing custom prompts to GraphIndexCreator (#7381)
--------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
612a74eb7e
commit
4d697d3f24
@ -3,6 +3,7 @@ from typing import Optional, Type
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain import BasePromptTemplate
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples
|
from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples
|
||||||
from langchain.indexes.prompts.knowledge_triplet_extraction import (
|
from langchain.indexes.prompts.knowledge_triplet_extraction import (
|
||||||
@ -17,24 +18,28 @@ class GraphIndexCreator(BaseModel):
|
|||||||
llm: Optional[BaseLanguageModel] = None
|
llm: Optional[BaseLanguageModel] = None
|
||||||
graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph
|
graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph
|
||||||
|
|
||||||
def from_text(self, text: str) -> NetworkxEntityGraph:
|
def from_text(
|
||||||
|
self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
|
||||||
|
) -> NetworkxEntityGraph:
|
||||||
"""Create graph index from text."""
|
"""Create graph index from text."""
|
||||||
if self.llm is None:
|
if self.llm is None:
|
||||||
raise ValueError("llm should not be None")
|
raise ValueError("llm should not be None")
|
||||||
graph = self.graph_type()
|
graph = self.graph_type()
|
||||||
chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
|
chain = LLMChain(llm=self.llm, prompt=prompt)
|
||||||
output = chain.predict(text=text)
|
output = chain.predict(text=text)
|
||||||
knowledge = parse_triples(output)
|
knowledge = parse_triples(output)
|
||||||
for triple in knowledge:
|
for triple in knowledge:
|
||||||
graph.add_triple(triple)
|
graph.add_triple(triple)
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
async def afrom_text(self, text: str) -> NetworkxEntityGraph:
|
async def afrom_text(
|
||||||
|
self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
|
||||||
|
) -> NetworkxEntityGraph:
|
||||||
"""Create graph index from text asynchronously."""
|
"""Create graph index from text asynchronously."""
|
||||||
if self.llm is None:
|
if self.llm is None:
|
||||||
raise ValueError("llm should not be None")
|
raise ValueError("llm should not be None")
|
||||||
graph = self.graph_type()
|
graph = self.graph_type()
|
||||||
chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
|
chain = LLMChain(llm=self.llm, prompt=prompt)
|
||||||
output = await chain.apredict(text=text)
|
output = await chain.apredict(text=text)
|
||||||
knowledge = parse_triples(output)
|
knowledge = parse_triples(output)
|
||||||
for triple in knowledge:
|
for triple in knowledge:
|
||||||
|
Loading…
Reference in New Issue
Block a user