@ -3,112 +3,26 @@ from __future__ import annotations
import logging
import uuid
from typing import Any , Dict , Iterable , List , Optional , Tuple
from typing import Any , Dict , Iterable , List , Optional , Sequence, Tuple, Type
import sqlalchemy
from sqlalchemy import REAL , Index
from sqlalchemy . dialects. postgresql import ARRAY , JSON , UUID
from sqlalchemy import REAL , Column , String , Table , create_engine , insert , text
from sqlalchemy . dialects . postgresql import ARRAY , JSON , TEXT
from sqlalchemy . engine import Row
try :
from sqlalchemy . orm import declarative_base
except ImportError :
from sqlalchemy . ext . declarative import declarative_base
from sqlalchemy . orm import Session , relationship
from sqlalchemy . sql . expression import func
from langchain . docstore . document import Document
from langchain . embeddings . base import Embeddings
from langchain . utils import get_from_dict_or_env
from langchain . vectorstores . base import VectorStore
Base = declarative_base ( ) # type: Any
ADA_TOKEN_COUNT = 1536
_LANGCHAIN_DEFAULT_COLLECTION_NAME = " langchain "
class BaseModel(Base):
    """Abstract declarative base giving every table a UUID primary key.

    ``__abstract__`` keeps SQLAlchemy from emitting a table for this class
    itself; concrete subclasses inherit the ``uuid`` column.
    """

    __abstract__ = True

    # RHS is evaluated before the class attribute is bound, so ``uuid`` here
    # still refers to the stdlib module.
    uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Module-level defaults used by AnalyticDB's constructor and classmethod
# factories. (Hoisted out of the class body, where the mangled diff had
# spliced them; as class attributes they were unreachable from module scope.)
_LANGCHAIN_DEFAULT_EMBEDDING_DIM = 1536
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_document"


class CollectionStore(BaseModel):
    """ORM model for a named collection of embeddings."""

    __tablename__ = "langchain_pg_collection"

    name = sqlalchemy.Column(sqlalchemy.String)
    cmetadata = sqlalchemy.Column(JSON)
    # Child rows are removed by the database (ondelete="CASCADE" on the FK),
    # so let passive_deletes skip ORM-side cascade loading.
    embeddings = relationship(
        "EmbeddingStore",
        back_populates="collection",
        passive_deletes=True,
    )

    @classmethod
    def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
        """Return the collection with the given name, or ``None`` if absent."""
        return session.query(cls).filter(cls.name == name).first()  # type: ignore

    @classmethod
    def get_or_create(
        cls,
        session: Session,
        name: str,
        cmetadata: Optional[dict] = None,
    ) -> Tuple["CollectionStore", bool]:
        """Get or create a collection.

        Returns:
            ``(collection, created)`` where ``created`` is True if the
            collection was newly inserted.
        """
        created = False
        collection = cls.get_by_name(session, name)
        if collection:
            return collection, created

        collection = cls(name=name, cmetadata=cmetadata)
        session.add(collection)
        session.commit()
        created = True
        return collection, created
class EmbeddingStore(BaseModel):
    """ORM model for one embedded document chunk belonging to a collection."""

    __tablename__ = "langchain_pg_embedding"

    collection_id = sqlalchemy.Column(
        UUID(as_uuid=True),
        sqlalchemy.ForeignKey(
            f"{CollectionStore.__tablename__}.uuid",
            ondelete="CASCADE",
        ),
    )
    collection = relationship(CollectionStore, back_populates="embeddings")

    embedding: sqlalchemy.Column = sqlalchemy.Column(ARRAY(REAL))
    document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
    cmetadata = sqlalchemy.Column(JSON, nullable=True)
    # custom_id: any user-defined id
    custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)

    # ANN index named 'langchain_pg_embedding_vector_idx' for fast
    # approximate nearest-neighbor search over the embedding column.
    langchain_pg_embedding_vector_idx = Index(
        "langchain_pg_embedding_vector_idx",
        embedding,
        postgresql_using="ann",
        postgresql_with={
            "distancemeasure": "L2",
            "dim": 1536,
            "pq_segments": 64,
            "hnsw_m": 100,
            "pq_centers": 2048,
        },
    )
class QueryResult:
    """Lightweight pair of a result row and its distance to the query vector."""

    # NOTE(review): attribute annotations only — instances are populated
    # field-by-field by the query path; presumably never constructed directly.
    EmbeddingStore: EmbeddingStore
    distance: float


Base = declarative_base()  # type: Any
class AnalyticDB ( VectorStore ) :
@ -132,15 +46,15 @@ class AnalyticDB(VectorStore):
self ,
connection_string : str ,
embedding_function : Embeddings ,
embedding_dimension : int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM ,
collection_name : str = _LANGCHAIN_DEFAULT_COLLECTION_NAME ,
collection_metadata : Optional [ dict ] = None ,
pre_delete_collection : bool = False ,
logger : Optional [ logging . Logger ] = None ,
) - > None :
self . connection_string = connection_string
self . embedding_function = embedding_function
self . embedding_dimension = embedding_dimension
self . collection_name = collection_name
self . collection_metadata = collection_metadata
self . pre_delete_collection = pre_delete_collection
self . logger = logger or logging . getLogger ( __name__ )
self . __post_init__ ( )
@ -151,47 +65,68 @@ class AnalyticDB(VectorStore):
"""
Initialize the store .
"""
self . _conn = self . connect ( )
self . create_tables_if_not_exists ( )
self . engine = create_engine ( self . connection_string )
self . create_collection ( )
def connect(self) -> sqlalchemy.engine.Connection:
    """Create an engine from the configured connection string and open a
    connection to it. The caller owns (and must close) the connection."""
    engine = sqlalchemy.create_engine(self.connection_string)
    conn = engine.connect()
    return conn
def create_tables_if_not_exists(self) -> None:
    """Create all tables declared on ``Base`` over the open connection
    (no-op for tables that already exist)."""
    Base.metadata.create_all(self._conn)
def drop_tables(self) -> None:
    """Drop every table declared on ``Base``. Destructive; intended for
    tests and full resets."""
    Base.metadata.drop_all(self._conn)
def create_table_if_not_exists(self) -> None:
    """Create the per-collection table and its ANN index if missing.

    The table is registered dynamically on ``Base.metadata`` under
    ``self.collection_name``; ``extend_existing`` allows re-registration
    when the store is constructed more than once in a process.
    """
    # Define the dynamic table.
    Table(
        self.collection_name,
        Base.metadata,
        Column("id", TEXT, primary_key=True, default=uuid.uuid4),
        Column("embedding", ARRAY(REAL)),
        Column("document", String, nullable=True),
        Column("metadata", JSON, nullable=True),
        extend_existing=True,
    )
    with self.engine.connect() as conn:
        # Create the table (no-op if it already exists).
        Base.metadata.create_all(conn)

        # Check whether the ANN index already exists.
        # NOTE(review): collection_name/index_name are interpolated into SQL
        # directly — ensure the collection name is not attacker-controlled.
        index_name = f"{self.collection_name}_embedding_idx"
        index_query = text(
            f"""
            SELECT 1
            FROM pg_indexes
            WHERE indexname = '{index_name}';
            """
        )
        result = conn.execute(index_query).scalar()

        # Create the index if it doesn't exist.
        if not result:
            index_statement = text(
                f"""
                CREATE INDEX {index_name}
                ON {self.collection_name} USING ann(embedding)
                WITH (
                    "dim" = {self.embedding_dimension},
                    "hnsw_m" = 100
                );
                """
            )
            conn.execute(index_statement)
            conn.commit()
def create_collection(self) -> None:
    """Ensure the collection's backing table exists, optionally dropping
    any existing table first when ``pre_delete_collection`` is set."""
    if self.pre_delete_collection:
        self.delete_collection()
    self.create_table_if_not_exists()
def delete_collection(self) -> None:
    """Drop the collection's backing table if it exists."""
    self.logger.debug("Trying to delete collection")
    drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};")
    with self.engine.connect() as conn:
        conn.execute(drop_statement)
        conn.commit()

def get_collection(self, session: Session) -> Optional["CollectionStore"]:
    """Look up this store's collection row by name.

    NOTE(review): legacy ORM lookup kept for the session-based call sites;
    the table-based code path above does not use it — confirm before removal.
    """
    return CollectionStore.get_by_name(session, self.collection_name)
def add_texts (
self ,
texts : Iterable [ str ] ,
metadatas : Optional [ List [ dict ] ] = None ,
ids : Optional [ List [ str ] ] = None ,
batch_size : int = 500 ,
* * kwargs : Any ,
) - > List [ str ] :
""" Run more texts through the embeddings and add to the vectorstore.
@ -212,20 +147,43 @@ class AnalyticDB(VectorStore):
if not metadatas :
metadatas = [ { } for _ in texts ]
with Session ( self . _conn ) as session :
collection = self . get_collection ( session )
if not collection :
raise ValueError ( " Collection not found " )
for text , metadata , embedding , id in zip ( texts , metadatas , embeddings , ids ) :
embedding_store = EmbeddingStore (
embedding = embedding ,
document = text ,
cmetadata = metadata ,
custom_id = id ,
# Define the table schema
chunks_table = Table (
self . collection_name ,
Base . metadata ,
Column ( " id " , TEXT , primary_key = True ) ,
Column ( " embedding " , ARRAY ( REAL ) ) ,
Column ( " document " , String , nullable = True ) ,
Column ( " metadata " , JSON , nullable = True ) ,
extend_existing = True ,
)
chunks_table_data = [ ]
with self . engine . connect ( ) as conn :
for document , metadata , chunk_id , embedding in zip (
texts , metadatas , ids , embeddings
) :
chunks_table_data . append (
{
" id " : chunk_id ,
" embedding " : embedding ,
" document " : document ,
" metadata " : metadata ,
}
)
collection . embeddings . append ( embedding_store )
session . add ( embedding_store )
session . commit ( )
# Execute the batch insert when the batch size is reached
if len ( chunks_table_data ) == batch_size :
conn . execute ( insert ( chunks_table ) . values ( chunks_table_data ) )
# Clear the chunks_table_data list for the next batch
chunks_table_data . clear ( )
# Insert any remaining records that didn't make up a full batch
if chunks_table_data :
conn . execute ( insert ( chunks_table ) . values ( chunks_table_data ) )
# Commit the transaction only once after all records have been inserted
conn . commit ( )
return ids
@ -275,52 +233,69 @@ class AnalyticDB(VectorStore):
)
return docs
def _similarity_search_with_relevance_scores(
    self,
    query: str,
    k: int = 4,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Return docs and relevance scores in the range [0, 1].

    0 is dissimilar, 1 is most similar.

    Args:
        query: input text
        k: Number of Documents to return. Defaults to 4.
        **kwargs: kwargs to be passed to similarity search. Should include:
            score_threshold: Optional, a floating point value between 0 to 1
                to filter the resulting set of retrieved docs

    Returns:
        List of Tuples of (doc, similarity_score)
    """
    return self.similarity_search_with_score(query, k, **kwargs)
def similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
    """Return the k nearest documents to ``embedding`` with their distances.

    Args:
        embedding: query vector.
        k: number of results to return. Defaults to 4.
        filter: optional metadata equality filter; each key/value pair is
            matched against the JSON ``metadata`` column.

    Returns:
        List of ``(Document, distance)`` tuples ordered by ascending
        distance (``None`` distance when no embedding function is set).
    """
    # Build the optional WHERE clause from the metadata filter.
    # NOTE(review): keys/values are interpolated via !r rather than bound
    # parameters — do not pass untrusted filter input.
    filter_condition = ""
    if filter is not None:
        conditions = [
            f"metadata->>{key!r} = {value!r}" for key, value in filter.items()
        ]
        filter_condition = f"WHERE {' AND '.join(conditions)}"

    # Define the base query; the vector itself is bound as a parameter.
    sql_query = f"""
        SELECT *, l2_distance(embedding, :embedding) as distance
        FROM {self.collection_name}
        {filter_condition}
        ORDER BY embedding <-> :embedding
        LIMIT :k
    """

    # Set up the query parameters.
    params = {"embedding": embedding, "k": k}

    # Execute the query and fetch the results.
    with self.engine.connect() as conn:
        results: Sequence[Row] = conn.execute(text(sql_query), params).fetchall()

    documents_with_scores = [
        (
            Document(
                page_content=result.document,
                metadata=result.metadata,
            ),
            result.distance if self.embedding_function is not None else None,
        )
        for result in results
    ]
    return documents_with_scores
def similarity_search_by_vector (
self ,
@ -346,10 +321,11 @@ class AnalyticDB(VectorStore):
@classmethod
def from_texts (
cls ,
cls : Type [ AnalyticDB ] ,
texts : List [ str ] ,
embedding : Embeddings ,
metadatas : Optional [ List [ dict ] ] = None ,
embedding_dimension : int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM ,
collection_name : str = _LANGCHAIN_DEFAULT_COLLECTION_NAME ,
ids : Optional [ List [ str ] ] = None ,
pre_delete_collection : bool = False ,
@ -368,6 +344,7 @@ class AnalyticDB(VectorStore):
connection_string = connection_string ,
collection_name = collection_name ,
embedding_function = embedding ,
embedding_dimension = embedding_dimension ,
pre_delete_collection = pre_delete_collection ,
)
@ -379,7 +356,7 @@ class AnalyticDB(VectorStore):
connection_string : str = get_from_dict_or_env (
data = kwargs ,
key = " connection_string " ,
env_key = " PG VECTOR _CONNECTION_STRING" ,
env_key = " PG _CONNECTION_STRING" ,
)
if not connection_string :
@ -393,9 +370,10 @@ class AnalyticDB(VectorStore):
@classmethod
def from_documents (
cls ,
cls : Type [ AnalyticDB ] ,
documents : List [ Document ] ,
embedding : Embeddings ,
embedding_dimension : int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM ,
collection_name : str = _LANGCHAIN_DEFAULT_COLLECTION_NAME ,
ids : Optional [ List [ str ] ] = None ,
pre_delete_collection : bool = False ,
@ -418,6 +396,7 @@ class AnalyticDB(VectorStore):
texts = texts ,
pre_delete_collection = pre_delete_collection ,
embedding = embedding ,
embedding_dimension = embedding_dimension ,
metadatas = metadatas ,
ids = ids ,
collection_name = collection_name ,