|
|
|
@ -4,7 +4,7 @@ from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.graphs import NetworkxEntityGraph
|
|
|
|
|
from langchain.graphs.networkx_graph import get_entities, parse_triples
|
|
|
|
|
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
|
|
|
|
|
from langchain.llms.base import BaseLLM
|
|
|
|
|
from langchain.memory.chat_memory import BaseChatMemory
|
|
|
|
|
from langchain.memory.prompt import (
|
|
|
|
@ -78,9 +78,7 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
|
|
|
|
|
return list(outputs.keys())[0]
|
|
|
|
|
return self.output_key
|
|
|
|
|
|
|
|
|
|
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
|
|
|
|
|
"""Get the current entities in the conversation."""
|
|
|
|
|
prompt_input_key = self._get_prompt_input_key(inputs)
|
|
|
|
|
def get_current_entities(self, input_string: str) -> List[str]:
|
|
|
|
|
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
|
|
|
|
|
buffer_string = get_buffer_string(
|
|
|
|
|
self.chat_memory.messages[-self.k * 2 :],
|
|
|
|
@ -89,14 +87,17 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
|
|
|
|
|
)
|
|
|
|
|
output = chain.predict(
|
|
|
|
|
history=buffer_string,
|
|
|
|
|
input=inputs[prompt_input_key],
|
|
|
|
|
input=input_string,
|
|
|
|
|
)
|
|
|
|
|
return get_entities(output)
|
|
|
|
|
|
|
|
|
|
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
|
|
|
|
|
"""Get and update knowledge graph from the conversation history."""
|
|
|
|
|
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
|
|
|
|
|
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
|
|
|
|
|
"""Get the current entities in the conversation."""
|
|
|
|
|
prompt_input_key = self._get_prompt_input_key(inputs)
|
|
|
|
|
return self.get_current_entities(inputs[prompt_input_key])
|
|
|
|
|
|
|
|
|
|
def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
|
|
|
|
|
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
|
|
|
|
|
buffer_string = get_buffer_string(
|
|
|
|
|
self.chat_memory.messages[-self.k * 2 :],
|
|
|
|
|
human_prefix=self.human_prefix,
|
|
|
|
@ -104,10 +105,16 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
|
|
|
|
|
)
|
|
|
|
|
output = chain.predict(
|
|
|
|
|
history=buffer_string,
|
|
|
|
|
input=inputs[prompt_input_key],
|
|
|
|
|
input=input_string,
|
|
|
|
|
verbose=True,
|
|
|
|
|
)
|
|
|
|
|
knowledge = parse_triples(output)
|
|
|
|
|
return knowledge
|
|
|
|
|
|
|
|
|
|
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
|
|
|
|
|
"""Get and update knowledge graph from the conversation history."""
|
|
|
|
|
prompt_input_key = self._get_prompt_input_key(inputs)
|
|
|
|
|
knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
|
|
|
|
|
for triple in knowledge:
|
|
|
|
|
self.kg.add_triple(triple)
|
|
|
|
|
|
|
|
|
|