|
|
|
@ -1,4 +1,4 @@
|
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
|
from typing import Any, Dict, List, Type, Union
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
@ -12,7 +12,12 @@ from langchain.memory.prompt import (
|
|
|
|
|
)
|
|
|
|
|
from langchain.memory.utils import get_prompt_input_key
|
|
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
|
|
from langchain.schema import BaseLanguageModel, SystemMessage, get_buffer_string
|
|
|
|
|
from langchain.schema import (
|
|
|
|
|
BaseLanguageModel,
|
|
|
|
|
BaseMessage,
|
|
|
|
|
SystemMessage,
|
|
|
|
|
get_buffer_string,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConversationKGMemory(BaseChatMemory, BaseModel):
|
|
|
|
@ -29,6 +34,7 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
|
|
|
|
|
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
|
|
|
|
|
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
|
|
|
|
llm: BaseLanguageModel
|
|
|
|
|
summary_message_cls: Type[BaseMessage] = SystemMessage
|
|
|
|
|
"""Number of previous utterances to include in the context."""
|
|
|
|
|
memory_key: str = "history" #: :meta private:
|
|
|
|
|
|
|
|
|
@ -40,12 +46,15 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
|
|
|
|
|
knowledge = self.kg.get_entity_knowledge(entity)
|
|
|
|
|
if knowledge:
|
|
|
|
|
summaries[entity] = ". ".join(knowledge) + "."
|
|
|
|
|
context: Union[str, List]
|
|
|
|
|
if summaries:
|
|
|
|
|
summary_strings = [
|
|
|
|
|
f"On {entity}: {summary}" for entity, summary in summaries.items()
|
|
|
|
|
]
|
|
|
|
|
if self.return_messages:
|
|
|
|
|
context: Any = [SystemMessage(content=text) for text in summary_strings]
|
|
|
|
|
context = [
|
|
|
|
|
self.summary_message_cls(content=text) for text in summary_strings
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
context = "\n".join(summary_strings)
|
|
|
|
|
else:
|
|
|
|
|