Harrison/pg vector move (#7580)

This commit is contained in:
Harrison Chase 2023-07-11 23:22:34 -07:00 committed by GitHub
parent 2667ddc686
commit 641fd74baa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 50 deletions

View File

@ -1,9 +1,50 @@
from typing import Optional, Tuple
import sqlalchemy
from pgvector.sqlalchemy import Vector
from sqlalchemy.dialects.postgresql import JSON, UUID
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session, relationship
from langchain.vectorstores.pgvector import BaseModel, CollectionStore
from langchain.vectorstores.pgvector import BaseModel
class CollectionStore(BaseModel):
    """ORM model for rows of the ``langchain_pg_collection`` table.

    A collection has a name and optional JSON metadata, and is linked to
    ``EmbeddingStore`` rows through the ``embeddings`` relationship.
    """

    __tablename__ = "langchain_pg_collection"

    # Name used by ``get_by_name`` to look a collection up.
    name = sqlalchemy.Column(sqlalchemy.String)
    # JSON metadata stored alongside the collection.
    cmetadata = sqlalchemy.Column(JSON)

    embeddings = relationship(
        "EmbeddingStore",
        back_populates="collection",
        passive_deletes=True,
    )

    @classmethod
    def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
        """Return the first collection whose name equals ``name``.

        Args:
            session: Session used to run the query.
            name: Collection name to match.

        Returns:
            The first matching row, or ``None`` when the query yields nothing.
        """
        matching = session.query(cls).filter(cls.name == name)
        return matching.first()  # type: ignore

    @classmethod
    def get_or_create(
        cls,
        session: Session,
        name: str,
        cmetadata: Optional[dict] = None,
    ) -> Tuple["CollectionStore", bool]:
        """Fetch the collection called ``name``, creating it when absent.

        Args:
            session: Session used for the lookup and, if needed, the insert.
            name: Collection name to look up or create.
            cmetadata: Metadata stored on a newly created collection; it is
                not applied to a collection that already exists.

        Returns:
            A ``(collection, created)`` pair where ``created`` is True only
            if a new row was added and committed by this call.
        """
        existing = cls.get_by_name(session, name)
        if existing:
            return existing, False

        # Nothing found under that name: insert a new row and commit it.
        fresh = cls(name=name, cmetadata=cmetadata)
        session.add(fresh)
        session.commit()
        return fresh, True
class EmbeddingStore(BaseModel):

View File

@ -4,17 +4,30 @@ from __future__ import annotations
import enum
import logging
import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
)
import sqlalchemy
from sqlalchemy.dialects.postgresql import JSON, UUID
from sqlalchemy.orm import Session, declarative_base, relationship
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Session, declarative_base
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore
if TYPE_CHECKING:
from langchain.vectorstores._pgvector_data_models import CollectionStore
class DistanceStrategy(str, enum.Enum):
"""Enumerator of the Distance strategies."""
@ -37,45 +50,6 @@ class BaseModel(Base):
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
class CollectionStore(BaseModel):
    """Declarative model mapped to the ``langchain_pg_collection`` table.

    Holds a collection's name and JSON metadata; related ``EmbeddingStore``
    rows are reachable through ``embeddings``.
    """

    __tablename__ = "langchain_pg_collection"

    # Collection name; ``get_by_name`` filters on this column.
    name = sqlalchemy.Column(sqlalchemy.String)
    # JSON metadata attached to the collection.
    cmetadata = sqlalchemy.Column(JSON)

    embeddings = relationship(
        "EmbeddingStore",
        back_populates="collection",
        passive_deletes=True,
    )

    @classmethod
    def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
        """Look up a collection by name.

        Args:
            session: Session the query is issued on.
            name: Name to filter by.

        Returns:
            The first row with that name, or ``None`` if there is none.
        """
        by_name = session.query(cls).filter(cls.name == name)
        return by_name.first()  # type: ignore

    @classmethod
    def get_or_create(
        cls,
        session: Session,
        name: str,
        cmetadata: Optional[dict] = None,
    ) -> Tuple["CollectionStore", bool]:
        """Return the collection named ``name``, inserting it if missing.

        Args:
            session: Session used for both the lookup and the insert.
            name: Name of the collection.
            cmetadata: Metadata for a newly inserted collection; ignored
                when a collection with that name is found.

        Returns:
            ``(collection, created)``; ``created`` is True when this call
            added and committed a new row, False when one was found.
        """
        found = cls.get_by_name(session, name)
        if found:
            return found, False

        # No row with this name yet, so add one and commit.
        new_collection = cls(name=name, cmetadata=cmetadata)
        session.add(new_collection)
        session.commit()
        return new_collection, True
class PGVector(VectorStore):
"""VectorStore implementation using Postgres and pgvector.
@ -141,8 +115,12 @@ class PGVector(VectorStore):
"""
self._conn = self.connect()
# self.create_vector_extension()
from langchain.vectorstores._pgvector_data_models import EmbeddingStore
from langchain.vectorstores._pgvector_data_models import (
CollectionStore,
EmbeddingStore,
)
self.CollectionStore = CollectionStore
self.EmbeddingStore = EmbeddingStore
self.create_tables_if_not_exists()
self.create_collection()
@ -173,7 +151,7 @@ class PGVector(VectorStore):
if self.pre_delete_collection:
self.delete_collection()
with Session(self._conn) as session:
CollectionStore.get_or_create(
self.CollectionStore.get_or_create(
session, self.collection_name, cmetadata=self.collection_metadata
)
@ -188,7 +166,7 @@ class PGVector(VectorStore):
session.commit()
def get_collection(self, session: Session) -> Optional["CollectionStore"]:
return CollectionStore.get_by_name(session, self.collection_name)
return self.CollectionStore.get_by_name(session, self.collection_name)
@classmethod
def __from(
@ -200,6 +178,7 @@ class PGVector(VectorStore):
ids: Optional[List[str]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
connection_string: Optional[str] = None,
pre_delete_collection: bool = False,
**kwargs: Any,
) -> PGVector:
@ -208,7 +187,8 @@ class PGVector(VectorStore):
if not metadatas:
metadatas = [{} for _ in texts]
connection_string = cls.get_connection_string(kwargs)
if connection_string is None:
connection_string = cls.get_connection_string(kwargs)
store = cls(
connection_string=connection_string,
@ -389,8 +369,8 @@ class PGVector(VectorStore):
.filter(filter_by)
.order_by(sqlalchemy.asc("distance"))
.join(
CollectionStore,
self.EmbeddingStore.collection_id == CollectionStore.uuid,
self.CollectionStore,
self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
)
.limit(k)
.all()