From 0806951c07179936cdeda5dfdfbdeb1f5c47d1b4 Mon Sep 17 00:00:00 2001 From: vowelparrot <130414180+vowelparrot@users.noreply.github.com> Date: Tue, 11 Apr 2023 14:14:49 -0700 Subject: [PATCH] Update VectorStore Class Method Typing (#2731) Avoid using placeholder methods that only perform a `cast()` operation because the typing would otherwise be inferred to be the parent `VectorStore` class. This is unnecessary with TypeVar's. --- langchain/vectorstores/atlas.py | 6 ++-- langchain/vectorstores/base.py | 20 ++++++----- langchain/vectorstores/chroma.py | 6 ++-- langchain/vectorstores/pgvector.py | 13 ++++--- langchain/vectorstores/qdrant.py | 54 +++--------------------------- langchain/vectorstores/redis.py | 4 +-- langchain/vectorstores/weaviate.py | 6 ++-- 7 files changed, 35 insertions(+), 74 deletions(-) diff --git a/langchain/vectorstores/atlas.py b/langchain/vectorstores/atlas.py index ce2410b5..af7f5557 100644 --- a/langchain/vectorstores/atlas.py +++ b/langchain/vectorstores/atlas.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging import uuid -from typing import Any, Iterable, List, Optional +from typing import Any, Iterable, List, Optional, Type import numpy as np @@ -210,7 +210,7 @@ class AtlasDB(VectorStore): @classmethod def from_texts( - cls, + cls: Type[AtlasDB], texts: List[str], embedding: Optional[Embeddings] = None, metadatas: Optional[List[dict]] = None, @@ -270,7 +270,7 @@ class AtlasDB(VectorStore): @classmethod def from_documents( - cls, + cls: Type[AtlasDB], documents: List[Document], embedding: Optional[Embeddings] = None, ids: Optional[List[str]] = None, diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index 5fdbfad6..92f3a601 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -2,7 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar from pydantic import BaseModel, Field, root_validator @@ -10,6 +10,8 @@ from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.schema import BaseRetriever +VST = TypeVar("VST", bound="VectorStore") + class VectorStore(ABC): """Interface for vector stores.""" @@ -153,11 +155,11 @@ class VectorStore(ABC): @classmethod def from_documents( - cls, + cls: Type[VST], documents: List[Document], embedding: Embeddings, **kwargs: Any, - ) -> VectorStore: + ) -> VST: """Return VectorStore initialized from documents and embeddings.""" texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] @@ -165,11 +167,11 @@ class VectorStore(ABC): @classmethod async def afrom_documents( - cls, + cls: Type[VST], documents: List[Document], embedding: Embeddings, **kwargs: Any, - ) -> VectorStore: + ) -> VST: """Return VectorStore initialized from documents and embeddings.""" texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] @@ -178,22 +180,22 @@ class VectorStore(ABC): @classmethod @abstractmethod def from_texts( - cls, + cls: Type[VST], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, **kwargs: Any, - ) -> VectorStore: + ) -> VST: """Return VectorStore initialized from texts and embeddings.""" @classmethod async def afrom_texts( - cls, + cls: Type[VST], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, **kwargs: Any, - ) -> VectorStore: + ) -> VST: """Return VectorStore initialized from texts and embeddings.""" raise NotImplementedError diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index 903be694..60ecf3de 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging import uuid -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type import numpy as np @@ -269,7 +269,7 @@ class Chroma(VectorStore): @classmethod def from_texts( - cls, + cls: Type[Chroma], texts: List[str], embedding: Optional[Embeddings] = None, metadatas: Optional[List[dict]] = None, @@ -307,7 +307,7 @@ class Chroma(VectorStore): @classmethod def from_documents( - cls, + cls: Type[Chroma], documents: List[Document], embedding: Optional[Embeddings] = None, ids: Optional[List[str]] = None, diff --git a/langchain/vectorstores/pgvector.py b/langchain/vectorstores/pgvector.py index 941a9378..27008eb5 100644 --- a/langchain/vectorstores/pgvector.py +++ b/langchain/vectorstores/pgvector.py @@ -1,7 +1,10 @@ +"""VectorStore wrapper around a Postgres/PGVector database.""" +from __future__ import annotations + import enum import logging import uuid -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type import sqlalchemy from pgvector.sqlalchemy import Vector @@ -346,7 +349,7 @@ class PGVector(VectorStore): @classmethod def from_texts( - cls, + cls: Type[PGVector], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, @@ -355,7 +358,7 @@ class PGVector(VectorStore): ids: Optional[List[str]] = None, pre_delete_collection: bool = False, **kwargs: Any, - ) -> "PGVector": + ) -> PGVector: """ Return VectorStore initialized from texts and embeddings. Postgres connection string is required @@ -395,7 +398,7 @@ class PGVector(VectorStore): @classmethod def from_documents( - cls, + cls: Type[PGVector], documents: List[Document], embedding: Embeddings, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, @@ -403,7 +406,7 @@ class PGVector(VectorStore): ids: Optional[List[str]] = None, pre_delete_collection: bool = False, **kwargs: Any, - ) -> "PGVector": + ) -> PGVector: """ Return VectorStore initialized from documents and embeddings. Postgres connection string is required diff --git a/langchain/vectorstores/qdrant.py b/langchain/vectorstores/qdrant.py index 5b9880ad..cc5fe355 100644 --- a/langchain/vectorstores/qdrant.py +++ b/langchain/vectorstores/qdrant.py @@ -1,7 +1,9 @@ """Wrapper around Qdrant vector database.""" +from __future__ import annotations + import uuid from operator import itemgetter -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings @@ -176,55 +178,9 @@ class Qdrant(VectorStore): for i in mmr_selected ] - @classmethod - def from_documents( - cls, - documents: List[Document], - embedding: Embeddings, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, - distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - **kwargs: Any, - ) -> "Qdrant": - return cast( - Qdrant, - super().from_documents( - documents, - embedding, - location=location, - url=url, - port=port, - grpc_port=grpc_port, - prefer_grpc=prefer_grpc, - https=https, - api_key=api_key, - prefix=prefix, - timeout=timeout, - host=host, - path=path, - collection_name=collection_name, - distance_func=distance_func, - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - **kwargs, - ), - ) - @classmethod def from_texts( - cls, + cls: Type[Qdrant], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, @@ -244,7 +200,7 @@ class Qdrant(VectorStore): content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, **kwargs: Any, - ) -> "Qdrant": + ) -> Qdrant: """Construct Qdrant wrapper from raw documents. Args: diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index 4b2f7452..02a5d19d 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -4,7 +4,7 @@ from __future__ import annotations import json import logging import uuid -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type import numpy as np from pydantic import BaseModel, root_validator @@ -227,7 +227,7 @@ class Redis(VectorStore): @classmethod def from_texts( - cls, + cls: Type[Redis], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, diff --git a/langchain/vectorstores/weaviate.py b/langchain/vectorstores/weaviate.py index 22a9a037..0a105801 100644 --- a/langchain/vectorstores/weaviate.py +++ b/langchain/vectorstores/weaviate.py @@ -1,7 +1,7 @@ """Wrapper around weaviate vector database.""" from __future__ import annotations -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Type from uuid import uuid4 from langchain.docstore.document import Document @@ -104,11 +104,11 @@ class Weaviate(VectorStore): @classmethod def from_texts( - cls, + cls: Type[Weaviate], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, **kwargs: Any, - ) -> VectorStore: + ) -> Weaviate: """Not implemented for Weaviate yet.""" raise NotImplementedError("weaviate does not currently support `from_texts`.")