Allow passing custom prompts to GraphIndexCreator (#7381)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Author: Oleg Zabluda (committed by GitHub)
Date: 2023-07-07 22:47:53 -07:00
parent 612a74eb7e
commit 4d697d3f24

@@ -3,6 +3,7 @@ from typing import Optional, Type
 
 from pydantic import BaseModel
 
+from langchain import BasePromptTemplate
 from langchain.chains.llm import LLMChain
 from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples
 from langchain.indexes.prompts.knowledge_triplet_extraction import (
@@ -17,24 +18,28 @@ class GraphIndexCreator(BaseModel):
     llm: Optional[BaseLanguageModel] = None
     graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph
 
-    def from_text(self, text: str) -> NetworkxEntityGraph:
+    def from_text(
+        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
+    ) -> NetworkxEntityGraph:
         """Create graph index from text."""
         if self.llm is None:
             raise ValueError("llm should not be None")
         graph = self.graph_type()
-        chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
+        chain = LLMChain(llm=self.llm, prompt=prompt)
         output = chain.predict(text=text)
         knowledge = parse_triples(output)
         for triple in knowledge:
             graph.add_triple(triple)
         return graph
 
-    async def afrom_text(self, text: str) -> NetworkxEntityGraph:
+    async def afrom_text(
+        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
+    ) -> NetworkxEntityGraph:
         """Create graph index from text asynchronously."""
         if self.llm is None:
             raise ValueError("llm should not be None")
         graph = self.graph_type()
-        chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
+        chain = LLMChain(llm=self.llm, prompt=prompt)
         output = await chain.apredict(text=text)
         knowledge = parse_triples(output)
         for triple in knowledge:
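
Not part of the commit itself: a minimal usage sketch of the new `prompt` parameter, assuming the `langchain.indexes`, `langchain.llms`, and `langchain.prompts` import paths of this release. The custom template here is hypothetical; any replacement prompt must keep a `text` input variable (from_text calls chain.predict(text=...)) and produce output that parse_triples can read.

    from langchain.graphs.networkx_graph import KG_TRIPLE_DELIMITER
    from langchain.indexes import GraphIndexCreator
    from langchain.llms import OpenAI
    from langchain.prompts import PromptTemplate

    # Hypothetical custom prompt: it exposes a `text` input variable and asks
    # for (subject, predicate, object) triples joined by KG_TRIPLE_DELIMITER
    # so that parse_triples can split and parse the model output.
    CUSTOM_PROMPT = PromptTemplate(
        input_variables=["text"],
        template=(
            "Extract up to three (subject, predicate, object) knowledge triples "
            "from the text below, separated by " + KG_TRIPLE_DELIMITER + ".\n\n"
            "Text: {text}\nTriples:"
        ),
    )

    index_creator = GraphIndexCreator(llm=OpenAI(temperature=0))
    graph = index_creator.from_text(
        "Berlin is the capital of Germany.", prompt=CUSTOM_PROMPT
    )
    print(graph.get_triples())

afrom_text accepts the same prompt keyword for the async path; when no prompt is passed, both methods fall back to KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT as before.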