diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py index e2c1fddc..06fe1107 100644 --- a/langchain/memory/entity.py +++ b/langchain/memory/entity.py @@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.prompt import ( ENTITY_EXTRACTION_PROMPT, @@ -11,7 +10,7 @@ from langchain.memory.prompt import ( ) from langchain.memory.utils import get_buffer_string, get_prompt_input_key from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseMessage +from langchain.schema import BaseLanguageModel, BaseMessage class ConversationEntityMemory(BaseChatMemory, BaseModel): @@ -19,7 +18,7 @@ class ConversationEntityMemory(BaseChatMemory, BaseModel): human_prefix: str = "Human" ai_prefix: str = "AI" - llm: BaseLLM + llm: BaseLanguageModel entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT store: Dict[str, Optional[str]] = {} diff --git a/langchain/memory/kg.py b/langchain/memory/kg.py index f8036de0..a6a23402 100644 --- a/langchain/memory/kg.py +++ b/langchain/memory/kg.py @@ -5,7 +5,6 @@ from pydantic import BaseModel, Field from langchain.chains.llm import LLMChain from langchain.graphs import NetworkxEntityGraph from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples -from langchain.llms.base import BaseLLM from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.prompt import ( ENTITY_EXTRACTION_PROMPT, @@ -13,7 +12,7 @@ from langchain.memory.prompt import ( ) from langchain.memory.utils import get_buffer_string, get_prompt_input_key from langchain.prompts.base import BasePromptTemplate -from langchain.schema import SystemMessage +from langchain.schema import BaseLanguageModel, SystemMessage class ConversationKGMemory(BaseChatMemory, BaseModel): @@ -29,7 +28,7 @@ class ConversationKGMemory(BaseChatMemory, BaseModel): kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph) knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT - llm: BaseLLM + llm: BaseLanguageModel """Number of previous utterances to include in the context.""" memory_key: str = "history" #: :meta private: diff --git a/langchain/memory/summary.py b/langchain/memory/summary.py index 7056274f..136c4f73 100644 --- a/langchain/memory/summary.py +++ b/langchain/memory/summary.py @@ -3,18 +3,17 @@ from typing import Any, Dict, List from pydantic import BaseModel, root_validator from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.prompt import SUMMARY_PROMPT from langchain.memory.utils import get_buffer_string from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseMessage, SystemMessage +from langchain.schema import BaseLanguageModel, BaseMessage, SystemMessage class SummarizerMixin(BaseModel): human_prefix: str = "Human" ai_prefix: str = "AI" - llm: BaseLLM + llm: BaseLanguageModel prompt: BasePromptTemplate = SUMMARY_PROMPT def predict_new_summary(