From 4d697d3f2492c99fbf367810574d157f42a37b0d Mon Sep 17 00:00:00 2001 From: Oleg Zabluda Date: Fri, 7 Jul 2023 22:47:53 -0700 Subject: [PATCH] Allow passing custom prompts to GraphIndexCreator (#7381) --------- Co-authored-by: Bagatur --- langchain/indexes/graph.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/langchain/indexes/graph.py b/langchain/indexes/graph.py index ddc4e00400..64536e003c 100644 --- a/langchain/indexes/graph.py +++ b/langchain/indexes/graph.py @@ -3,6 +3,7 @@ from typing import Optional, Type from pydantic import BaseModel +from langchain import BasePromptTemplate from langchain.chains.llm import LLMChain from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples from langchain.indexes.prompts.knowledge_triplet_extraction import ( @@ -17,24 +18,28 @@ class GraphIndexCreator(BaseModel): llm: Optional[BaseLanguageModel] = None graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph - def from_text(self, text: str) -> NetworkxEntityGraph: + def from_text( + self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT + ) -> NetworkxEntityGraph: """Create graph index from text.""" if self.llm is None: raise ValueError("llm should not be None") graph = self.graph_type() - chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT) + chain = LLMChain(llm=self.llm, prompt=prompt) output = chain.predict(text=text) knowledge = parse_triples(output) for triple in knowledge: graph.add_triple(triple) return graph - async def afrom_text(self, text: str) -> NetworkxEntityGraph: + async def afrom_text( + self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT + ) -> NetworkxEntityGraph: """Create graph index from text asynchronously.""" if self.llm is None: raise ValueError("llm should not be None") graph = self.graph_type() - chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT) + chain = LLMChain(llm=self.llm, prompt=prompt) output = await chain.apredict(text=text) knowledge = parse_triples(output) for triple in knowledge: