diff --git a/libs/community/langchain_community/vectorstores/pgvector.py b/libs/community/langchain_community/vectorstores/pgvector.py index eff9fac354..fe439d86c5 100644 --- a/libs/community/langchain_community/vectorstores/pgvector.py +++ b/libs/community/langchain_community/vectorstores/pgvector.py @@ -62,7 +62,7 @@ class BaseModel(Base): _classes: Any = None -def _get_embedding_collection_store() -> Any: +def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any: global _classes if _classes is not None: return _classes @@ -125,7 +125,7 @@ def _get_embedding_collection_store() -> Any: ) collection = relationship(CollectionStore, back_populates="embeddings") - embedding: Vector = sqlalchemy.Column(Vector(None)) + embedding: Vector = sqlalchemy.Column(Vector(vector_dimension)) document = sqlalchemy.Column(sqlalchemy.String, nullable=True) cmetadata = sqlalchemy.Column(JSON, nullable=True) @@ -151,6 +151,10 @@ class PGVector(VectorStore): connection_string: Postgres connection string. embedding_function: Any embedding function implementing `langchain.embeddings.base.Embeddings` interface. + embedding_length: The length of the embedding vector. (default: None) + NOTE: This is not mandatory. Defining it will prevent vectors of + any other size to be added to the embeddings table but, without it, + the embeddings can't be indexed. collection_name: The name of the collection to use. (default: langchain) NOTE: This is not the name of the table, but the name of the collection. The tables will be created when initializing the store (if not exists) @@ -183,6 +187,7 @@ class PGVector(VectorStore): self, connection_string: str, embedding_function: Embeddings, + embedding_length: Optional[int] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, @@ -195,6 +200,7 @@ class PGVector(VectorStore): ) -> None: self.connection_string = connection_string self.embedding_function = embedding_function + self._embedding_length = embedding_length self.collection_name = collection_name self.collection_metadata = collection_metadata self._distance_strategy = distance_strategy @@ -211,7 +217,9 @@ class PGVector(VectorStore): """Initialize the store.""" self.create_vector_extension() - EmbeddingStore, CollectionStore = _get_embedding_collection_store() + EmbeddingStore, CollectionStore = _get_embedding_collection_store( + self._embedding_length + ) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore self.create_tables_if_not_exists()