diff --git a/langchain/indexes/graph.py b/langchain/indexes/graph.py
index ddc4e00400..64536e003c 100644
--- a/langchain/indexes/graph.py
+++ b/langchain/indexes/graph.py
@@ -3,6 +3,7 @@ from typing import Optional, Type
 
 from pydantic import BaseModel
 
+from langchain import BasePromptTemplate
 from langchain.chains.llm import LLMChain
 from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples
 from langchain.indexes.prompts.knowledge_triplet_extraction import (
@@ -17,24 +18,28 @@ class GraphIndexCreator(BaseModel):
     llm: Optional[BaseLanguageModel] = None
     graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph
 
-    def from_text(self, text: str) -> NetworkxEntityGraph:
+    def from_text(
+        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
+    ) -> NetworkxEntityGraph:
         """Create graph index from text."""
         if self.llm is None:
             raise ValueError("llm should not be None")
         graph = self.graph_type()
-        chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
+        chain = LLMChain(llm=self.llm, prompt=prompt)
         output = chain.predict(text=text)
         knowledge = parse_triples(output)
         for triple in knowledge:
             graph.add_triple(triple)
         return graph
 
-    async def afrom_text(self, text: str) -> NetworkxEntityGraph:
+    async def afrom_text(
+        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
+    ) -> NetworkxEntityGraph:
         """Create graph index from text asynchronously."""
         if self.llm is None:
             raise ValueError("llm should not be None")
         graph = self.graph_type()
-        chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
+        chain = LLMChain(llm=self.llm, prompt=prompt)
         output = await chain.apredict(text=text)
         knowledge = parse_triples(output)
         for triple in knowledge:
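
For reference, a minimal usage sketch of the new optional `prompt` argument. The custom template and example text below are illustrative only (not part of the PR); note that any custom prompt must accept a `text` input variable and should produce triples that `parse_triples` can read, i.e. in the same delimiter-separated form as the default `KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT`.

```python
from langchain.graphs.networkx_graph import KG_TRIPLE_DELIMITER
from langchain.indexes import GraphIndexCreator
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

# Illustrative custom prompt: it takes a `text` variable and asks for triples
# in "(subject, predicate, object)" form separated by KG_TRIPLE_DELIMITER,
# so parse_triples can still parse the model output.
_CUSTOM_TEMPLATE = (
    "Extract up to three knowledge triples from the text below. "
    "Return them as (subject, predicate, object), separated by "
    f"'{KG_TRIPLE_DELIMITER}'.\n\nText: {{text}}\nTriples:"
)
CUSTOM_PROMPT = PromptTemplate(input_variables=["text"], template=_CUSTOM_TEMPLATE)

index_creator = GraphIndexCreator(llm=OpenAI(temperature=0))

# Default behaviour is unchanged; the prompt argument is optional.
graph = index_creator.from_text("Alice works at Acme and lives in Paris.")

# Override KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT with the custom prompt.
graph = index_creator.from_text(
    "Alice works at Acme and lives in Paris.", prompt=CUSTOM_PROMPT
)
print(graph.get_triples())
```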