From ed8207b2fb898b3f7a65f800be8e381a6e0d9911 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 14 May 2023 18:25:50 -0700 Subject: [PATCH] Harrison/typing of return (#4685) Co-authored-by: OlajideOgun <37077640+OlajideOgun@users.noreply.github.com> --- langchain/vectorstores/base.py | 2 +- langchain/vectorstores/redis.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index 11e20de5..11a75885 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -329,7 +329,7 @@ class VectorStore(ABC): """Return VectorStore initialized from texts and embeddings.""" raise NotImplementedError - def as_retriever(self, **kwargs: Any) -> BaseRetriever: + def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever: return VectorStoreRetriever(vectorstore=self, **kwargs) diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index ba10fb55..0adec6a4 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -23,9 +23,8 @@ from pydantic import BaseModel, root_validator from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings -from langchain.schema import BaseRetriever from langchain.utils import get_from_dict_or_env -from langchain.vectorstores.base import VectorStore +from langchain.vectorstores.base import VectorStore, VectorStoreRetriever logger = logging.getLogger(__name__) @@ -544,11 +543,11 @@ class Redis(VectorStore): **kwargs, ) - def as_retriever(self, **kwargs: Any) -> BaseRetriever: + def as_retriever(self, **kwargs: Any) -> RedisVectorStoreRetriever: return RedisVectorStoreRetriever(vectorstore=self, **kwargs) -class RedisVectorStoreRetriever(BaseRetriever, BaseModel): +class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel): vectorstore: Redis search_type: str = "similarity" k: int = 4