Update VectorStore Class Method Typing (#2731)

Avoid using placeholder methods that only perform a `cast()`
operation because the typing would otherwise be inferred to be the
parent `VectorStore` class. This is unnecessary with TypeVar's.
fix_agent_callbacks
vowelparrot 1 year ago committed by GitHub
parent 446c3d586c
commit 0806951c07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
import uuid import uuid
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional, Type
import numpy as np import numpy as np
@ -210,7 +210,7 @@ class AtlasDB(VectorStore):
@classmethod @classmethod
def from_texts( def from_texts(
cls, cls: Type[AtlasDB],
texts: List[str], texts: List[str],
embedding: Optional[Embeddings] = None, embedding: Optional[Embeddings] = None,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
@ -270,7 +270,7 @@ class AtlasDB(VectorStore):
@classmethod @classmethod
def from_documents( def from_documents(
cls, cls: Type[AtlasDB],
documents: List[Document], documents: List[Document],
embedding: Optional[Embeddings] = None, embedding: Optional[Embeddings] = None,
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Optional from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar
from pydantic import BaseModel, Field, root_validator from pydantic import BaseModel, Field, root_validator
@ -10,6 +10,8 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever from langchain.schema import BaseRetriever
VST = TypeVar("VST", bound="VectorStore")
class VectorStore(ABC): class VectorStore(ABC):
"""Interface for vector stores.""" """Interface for vector stores."""
@ -153,11 +155,11 @@ class VectorStore(ABC):
@classmethod @classmethod
def from_documents( def from_documents(
cls, cls: Type[VST],
documents: List[Document], documents: List[Document],
embedding: Embeddings, embedding: Embeddings,
**kwargs: Any, **kwargs: Any,
) -> VectorStore: ) -> VST:
"""Return VectorStore initialized from documents and embeddings.""" """Return VectorStore initialized from documents and embeddings."""
texts = [d.page_content for d in documents] texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents] metadatas = [d.metadata for d in documents]
@ -165,11 +167,11 @@ class VectorStore(ABC):
@classmethod @classmethod
async def afrom_documents( async def afrom_documents(
cls, cls: Type[VST],
documents: List[Document], documents: List[Document],
embedding: Embeddings, embedding: Embeddings,
**kwargs: Any, **kwargs: Any,
) -> VectorStore: ) -> VST:
"""Return VectorStore initialized from documents and embeddings.""" """Return VectorStore initialized from documents and embeddings."""
texts = [d.page_content for d in documents] texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents] metadatas = [d.metadata for d in documents]
@ -178,22 +180,22 @@ class VectorStore(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def from_texts( def from_texts(
cls, cls: Type[VST],
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
**kwargs: Any, **kwargs: Any,
) -> VectorStore: ) -> VST:
"""Return VectorStore initialized from texts and embeddings.""" """Return VectorStore initialized from texts and embeddings."""
@classmethod @classmethod
async def afrom_texts( async def afrom_texts(
cls, cls: Type[VST],
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
**kwargs: Any, **kwargs: Any,
) -> VectorStore: ) -> VST:
"""Return VectorStore initialized from texts and embeddings.""" """Return VectorStore initialized from texts and embeddings."""
raise NotImplementedError raise NotImplementedError

@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
import uuid import uuid
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type
import numpy as np import numpy as np
@ -269,7 +269,7 @@ class Chroma(VectorStore):
@classmethod @classmethod
def from_texts( def from_texts(
cls, cls: Type[Chroma],
texts: List[str], texts: List[str],
embedding: Optional[Embeddings] = None, embedding: Optional[Embeddings] = None,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
@ -307,7 +307,7 @@ class Chroma(VectorStore):
@classmethod @classmethod
def from_documents( def from_documents(
cls, cls: Type[Chroma],
documents: List[Document], documents: List[Document],
embedding: Optional[Embeddings] = None, embedding: Optional[Embeddings] = None,
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,

@ -1,7 +1,10 @@
"""VectorStore wrapper around a Postgres/PGVector database."""
from __future__ import annotations
import enum import enum
import logging import logging
import uuid import uuid
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
import sqlalchemy import sqlalchemy
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
@ -346,7 +349,7 @@ class PGVector(VectorStore):
@classmethod @classmethod
def from_texts( def from_texts(
cls, cls: Type[PGVector],
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
@ -355,7 +358,7 @@ class PGVector(VectorStore):
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,
pre_delete_collection: bool = False, pre_delete_collection: bool = False,
**kwargs: Any, **kwargs: Any,
) -> "PGVector": ) -> PGVector:
""" """
Return VectorStore initialized from texts and embeddings. Return VectorStore initialized from texts and embeddings.
Postgres connection string is required Postgres connection string is required
@ -395,7 +398,7 @@ class PGVector(VectorStore):
@classmethod @classmethod
def from_documents( def from_documents(
cls, cls: Type[PGVector],
documents: List[Document], documents: List[Document],
embedding: Embeddings, embedding: Embeddings,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
@ -403,7 +406,7 @@ class PGVector(VectorStore):
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,
pre_delete_collection: bool = False, pre_delete_collection: bool = False,
**kwargs: Any, **kwargs: Any,
) -> "PGVector": ) -> PGVector:
""" """
Return VectorStore initialized from documents and embeddings. Return VectorStore initialized from documents and embeddings.
Postgres connection string is required Postgres connection string is required

@ -1,7 +1,9 @@
"""Wrapper around Qdrant vector database.""" """Wrapper around Qdrant vector database."""
from __future__ import annotations
import uuid import uuid
from operator import itemgetter from operator import itemgetter
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
@ -176,55 +178,9 @@ class Qdrant(VectorStore):
for i in mmr_selected for i in mmr_selected
] ]
@classmethod
def from_documents(
cls,
documents: List[Document],
embedding: Embeddings,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: Optional[bool] = None,
api_key: Optional[str] = None,
prefix: Optional[str] = None,
timeout: Optional[float] = None,
host: Optional[str] = None,
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
**kwargs: Any,
) -> "Qdrant":
return cast(
Qdrant,
super().from_documents(
documents,
embedding,
location=location,
url=url,
port=port,
grpc_port=grpc_port,
prefer_grpc=prefer_grpc,
https=https,
api_key=api_key,
prefix=prefix,
timeout=timeout,
host=host,
path=path,
collection_name=collection_name,
distance_func=distance_func,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
**kwargs,
),
)
@classmethod @classmethod
def from_texts( def from_texts(
cls, cls: Type[Qdrant],
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
@ -244,7 +200,7 @@ class Qdrant(VectorStore):
content_payload_key: str = CONTENT_KEY, content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY, metadata_payload_key: str = METADATA_KEY,
**kwargs: Any, **kwargs: Any,
) -> "Qdrant": ) -> Qdrant:
"""Construct Qdrant wrapper from raw documents. """Construct Qdrant wrapper from raw documents.
Args: Args:

@ -4,7 +4,7 @@ from __future__ import annotations
import json import json
import logging import logging
import uuid import uuid
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type
import numpy as np import numpy as np
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
@ -227,7 +227,7 @@ class Redis(VectorStore):
@classmethod @classmethod
def from_texts( def from_texts(
cls, cls: Type[Redis],
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,

@ -1,7 +1,7 @@
"""Wrapper around weaviate vector database.""" """Wrapper around weaviate vector database."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional from typing import Any, Dict, Iterable, List, Optional, Type
from uuid import uuid4 from uuid import uuid4
from langchain.docstore.document import Document from langchain.docstore.document import Document
@ -104,11 +104,11 @@ class Weaviate(VectorStore):
@classmethod @classmethod
def from_texts( def from_texts(
cls, cls: Type[Weaviate],
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
**kwargs: Any, **kwargs: Any,
) -> VectorStore: ) -> Weaviate:
"""Not implemented for Weaviate yet.""" """Not implemented for Weaviate yet."""
raise NotImplementedError("weaviate does not currently support `from_texts`.") raise NotImplementedError("weaviate does not currently support `from_texts`.")

Loading…
Cancel
Save