community[patch]: Implement vector length definition at init time in PGVector for indexing (#16133)

Replace this entire comment with:
- **Description:** allow user to define tVector length in PGVector when
creating the embedding store, this allows for later indexing
  - **Issue:** #16132
  - **Dependencies:** None
This commit is contained in:
Frank995 2024-01-22 23:32:44 +01:00 committed by GitHub
parent a950fa0487
commit 5694728816
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -62,7 +62,7 @@ class BaseModel(Base):
_classes: Any = None _classes: Any = None
def _get_embedding_collection_store() -> Any: def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any:
global _classes global _classes
if _classes is not None: if _classes is not None:
return _classes return _classes
@ -125,7 +125,7 @@ def _get_embedding_collection_store() -> Any:
) )
collection = relationship(CollectionStore, back_populates="embeddings") collection = relationship(CollectionStore, back_populates="embeddings")
embedding: Vector = sqlalchemy.Column(Vector(None)) embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True) document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata = sqlalchemy.Column(JSON, nullable=True) cmetadata = sqlalchemy.Column(JSON, nullable=True)
@ -151,6 +151,10 @@ class PGVector(VectorStore):
connection_string: Postgres connection string. connection_string: Postgres connection string.
embedding_function: Any embedding function implementing embedding_function: Any embedding function implementing
`langchain.embeddings.base.Embeddings` interface. `langchain.embeddings.base.Embeddings` interface.
embedding_length: The length of the embedding vector. (default: None)
NOTE: This is not mandatory. Defining it will prevent vectors of
any other size to be added to the embeddings table but, without it,
the embeddings can't be indexed.
collection_name: The name of the collection to use. (default: langchain) collection_name: The name of the collection to use. (default: langchain)
NOTE: This is not the name of the table, but the name of the collection. NOTE: This is not the name of the table, but the name of the collection.
The tables will be created when initializing the store (if not exists) The tables will be created when initializing the store (if not exists)
@ -183,6 +187,7 @@ class PGVector(VectorStore):
self, self,
connection_string: str, connection_string: str,
embedding_function: Embeddings, embedding_function: Embeddings,
embedding_length: Optional[int] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
collection_metadata: Optional[dict] = None, collection_metadata: Optional[dict] = None,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
@ -195,6 +200,7 @@ class PGVector(VectorStore):
) -> None: ) -> None:
self.connection_string = connection_string self.connection_string = connection_string
self.embedding_function = embedding_function self.embedding_function = embedding_function
self._embedding_length = embedding_length
self.collection_name = collection_name self.collection_name = collection_name
self.collection_metadata = collection_metadata self.collection_metadata = collection_metadata
self._distance_strategy = distance_strategy self._distance_strategy = distance_strategy
@ -211,7 +217,9 @@ class PGVector(VectorStore):
"""Initialize the store.""" """Initialize the store."""
self.create_vector_extension() self.create_vector_extension()
EmbeddingStore, CollectionStore = _get_embedding_collection_store() EmbeddingStore, CollectionStore = _get_embedding_collection_store(
self._embedding_length
)
self.CollectionStore = CollectionStore self.CollectionStore = CollectionStore
self.EmbeddingStore = EmbeddingStore self.EmbeddingStore = EmbeddingStore
self.create_tables_if_not_exists() self.create_tables_if_not_exists()