Update VectorStore Class Method Typing (#2731)

Avoid using placeholder methods that only perform a `cast()`
operation because the typing would otherwise be inferred to be the
parent `VectorStore` class. This is unnecessary with TypeVar's.
fix_agent_callbacks
vowelparrot 1 year ago committed by GitHub
parent 446c3d586c
commit 0806951c07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
import uuid
from typing import Any, Iterable, List, Optional
from typing import Any, Iterable, List, Optional, Type
import numpy as np
@ -210,7 +210,7 @@ class AtlasDB(VectorStore):
@classmethod
def from_texts(
cls,
cls: Type[AtlasDB],
texts: List[str],
embedding: Optional[Embeddings] = None,
metadatas: Optional[List[dict]] = None,
@ -270,7 +270,7 @@ class AtlasDB(VectorStore):
@classmethod
def from_documents(
cls,
cls: Type[AtlasDB],
documents: List[Document],
embedding: Optional[Embeddings] = None,
ids: Optional[List[str]] = None,

@ -2,7 +2,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar
from pydantic import BaseModel, Field, root_validator
@ -10,6 +10,8 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever
VST = TypeVar("VST", bound="VectorStore")
class VectorStore(ABC):
"""Interface for vector stores."""
@ -153,11 +155,11 @@ class VectorStore(ABC):
@classmethod
def from_documents(
cls,
cls: Type[VST],
documents: List[Document],
embedding: Embeddings,
**kwargs: Any,
) -> VectorStore:
) -> VST:
"""Return VectorStore initialized from documents and embeddings."""
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
@ -165,11 +167,11 @@ class VectorStore(ABC):
@classmethod
async def afrom_documents(
cls,
cls: Type[VST],
documents: List[Document],
embedding: Embeddings,
**kwargs: Any,
) -> VectorStore:
) -> VST:
"""Return VectorStore initialized from documents and embeddings."""
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
@ -178,22 +180,22 @@ class VectorStore(ABC):
@classmethod
@abstractmethod
def from_texts(
cls,
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> VectorStore:
) -> VST:
"""Return VectorStore initialized from texts and embeddings."""
@classmethod
async def afrom_texts(
cls,
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> VectorStore:
) -> VST:
"""Return VectorStore initialized from texts and embeddings."""
raise NotImplementedError

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
import uuid
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type
import numpy as np
@ -269,7 +269,7 @@ class Chroma(VectorStore):
@classmethod
def from_texts(
cls,
cls: Type[Chroma],
texts: List[str],
embedding: Optional[Embeddings] = None,
metadatas: Optional[List[dict]] = None,
@ -307,7 +307,7 @@ class Chroma(VectorStore):
@classmethod
def from_documents(
cls,
cls: Type[Chroma],
documents: List[Document],
embedding: Optional[Embeddings] = None,
ids: Optional[List[str]] = None,

@ -1,7 +1,10 @@
"""VectorStore wrapper around a Postgres/PGVector database."""
from __future__ import annotations
import enum
import logging
import uuid
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
import sqlalchemy
from pgvector.sqlalchemy import Vector
@ -346,7 +349,7 @@ class PGVector(VectorStore):
@classmethod
def from_texts(
cls,
cls: Type[PGVector],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
@ -355,7 +358,7 @@ class PGVector(VectorStore):
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
**kwargs: Any,
) -> "PGVector":
) -> PGVector:
"""
Return VectorStore initialized from texts and embeddings.
Postgres connection string is required
@ -395,7 +398,7 @@ class PGVector(VectorStore):
@classmethod
def from_documents(
cls,
cls: Type[PGVector],
documents: List[Document],
embedding: Embeddings,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
@ -403,7 +406,7 @@ class PGVector(VectorStore):
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
**kwargs: Any,
) -> "PGVector":
) -> PGVector:
"""
Return VectorStore initialized from documents and embeddings.
Postgres connection string is required

@ -1,7 +1,9 @@
"""Wrapper around Qdrant vector database."""
from __future__ import annotations
import uuid
from operator import itemgetter
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
@ -176,55 +178,9 @@ class Qdrant(VectorStore):
for i in mmr_selected
]
@classmethod
def from_documents(
cls,
documents: List[Document],
embedding: Embeddings,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: Optional[bool] = None,
api_key: Optional[str] = None,
prefix: Optional[str] = None,
timeout: Optional[float] = None,
host: Optional[str] = None,
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
**kwargs: Any,
) -> "Qdrant":
return cast(
Qdrant,
super().from_documents(
documents,
embedding,
location=location,
url=url,
port=port,
grpc_port=grpc_port,
prefer_grpc=prefer_grpc,
https=https,
api_key=api_key,
prefix=prefix,
timeout=timeout,
host=host,
path=path,
collection_name=collection_name,
distance_func=distance_func,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
**kwargs,
),
)
@classmethod
def from_texts(
cls,
cls: Type[Qdrant],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
@ -244,7 +200,7 @@ class Qdrant(VectorStore):
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
**kwargs: Any,
) -> "Qdrant":
) -> Qdrant:
"""Construct Qdrant wrapper from raw documents.
Args:

@ -4,7 +4,7 @@ from __future__ import annotations
import json
import logging
import uuid
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type
import numpy as np
from pydantic import BaseModel, root_validator
@ -227,7 +227,7 @@ class Redis(VectorStore):
@classmethod
def from_texts(
cls,
cls: Type[Redis],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,

@ -1,7 +1,7 @@
"""Wrapper around weaviate vector database."""
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Type
from uuid import uuid4
from langchain.docstore.document import Document
@ -104,11 +104,11 @@ class Weaviate(VectorStore):
@classmethod
def from_texts(
cls,
cls: Type[Weaviate],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> VectorStore:
) -> Weaviate:
"""Not implemented for Weaviate yet."""
raise NotImplementedError("weaviate does not currently support `from_texts`.")

Loading…
Cancel
Save