diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py index a7808e15..27f903b1 100644 --- a/langchain/memory/entity.py +++ b/langchain/memory/entity.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from itertools import islice from typing import Any, Dict, Iterable, List, Optional -from pydantic import Field +from pydantic import BaseModel, Field from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain @@ -19,7 +19,7 @@ from langchain.schema import BaseMessage, get_buffer_string logger = logging.getLogger(__name__) -class BaseEntityStore(ABC): +class BaseEntityStore(BaseModel, ABC): @abstractmethod def get(self, key: str, default: Optional[str] = None) -> Optional[str]: """Get entity value from store."""