diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index 11e20de5..11a75885 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -329,7 +329,7 @@ class VectorStore(ABC): """Return VectorStore initialized from texts and embeddings.""" raise NotImplementedError - def as_retriever(self, **kwargs: Any) -> BaseRetriever: + def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever: return VectorStoreRetriever(vectorstore=self, **kwargs) diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index ba10fb55..0adec6a4 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -23,9 +23,8 @@ from pydantic import BaseModel, root_validator from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings -from langchain.schema import BaseRetriever from langchain.utils import get_from_dict_or_env -from langchain.vectorstores.base import VectorStore +from langchain.vectorstores.base import VectorStore, VectorStoreRetriever logger = logging.getLogger(__name__) @@ -544,11 +543,11 @@ class Redis(VectorStore): **kwargs, ) - def as_retriever(self, **kwargs: Any) -> BaseRetriever: + def as_retriever(self, **kwargs: Any) -> RedisVectorStoreRetriever: return RedisVectorStoreRetriever(vectorstore=self, **kwargs) -class RedisVectorStoreRetriever(BaseRetriever, BaseModel): +class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel): vectorstore: Redis search_type: str = "similarity" k: int = 4