@ -1,6 +1,8 @@
from __future__ import annotations
import os
import uuid
import warnings
from typing import Any , Iterable , List , Optional
from langchain_core . documents import Document
@ -8,6 +10,17 @@ from langchain_core.embeddings import Embeddings
from langchain_core . vectorstores import VectorStore
def import_lancedb ( ) - > Any :
try :
import lancedb
except ImportError as e :
raise ImportError (
" Could not import pinecone lancedb package. "
" Please install it with `pip install lancedb`. "
) from e
return lancedb
class LanceDB ( VectorStore ) :
""" `LanceDB` vector store.
@ -22,15 +35,15 @@ class LanceDB(VectorStore):
id_key : Key to use for the id in the database . Defaults to ` ` id ` ` .
text_key : Key to use for the text in the database . Defaults to ` ` text ` ` .
table_name : Name of the table to use . Defaults to ` ` vectorstore ` ` .
api_key : API key to use for LanceDB cloud database .
region : Region to use for LanceDB cloud database .
mode : Mode to use for adding data to the table . Defaults to ` ` overwrite ` ` .
Example :
. . code - block : : python
db = lancedb . connect ( ' ./lancedb ' )
table = db . open_table ( ' my_table ' )
vectorstore = LanceDB ( table , embedding_function )
vectorstore = LanceDB ( uri = ' /lancedb ' , embedding_function )
vectorstore . add_texts ( [ ' text1 ' , ' text2 ' ] )
result = vectorstore . similarity_search ( ' text1 ' )
"""
@ -39,38 +52,55 @@ class LanceDB(VectorStore):
self ,
connection : Optional [ Any ] = None ,
embedding : Optional [ Embeddings ] = None ,
uri : Optional [ str ] = " /tmp/lancedb " ,
vector_key : Optional [ str ] = " vector " ,
id_key : Optional [ str ] = " id " ,
text_key : Optional [ str ] = " text " ,
table_name : Optional [ str ] = " vectorstore " ,
api_key : Optional [ str ] = None ,
region : Optional [ str ] = None ,
mode : Optional [ str ] = " overwrite " ,
) :
""" Initialize with Lance DB vectorstore """
try :
import lancedb
except ImportError :
raise ImportError (
" Could not import lancedb python package. "
" Please install it with `pip install lancedb`. "
)
self . lancedb = lancedb
lancedb = import_lancedb ( )
self . _embedding = embedding
self . _vector_key = vector_key
self . _id_key = id_key
self . _text_key = text_key
self . _table_name = table_name
self . api_key = api_key or os . getenv ( " LANCE_API_KEY " ) if api_key != " " else None
self . region = region
self . mode = mode
if isinstance ( uri , str ) and self . api_key is None :
if uri . startswith ( " db:// " ) :
raise ValueError ( " API key is required for LanceDB cloud. " )
if self . _embedding is None :
raise ValueError ( " embedding should be provided " )
raise ValueError ( " embedding object should be provided" )
if connection is not None :
if not isinstance ( connection , lancedb . db . LanceTable ) :
raise ValueError (
" connection should be an instance of lancedb.db.LanceTable, " ,
f " got { type ( connection ) } " ,
)
if isinstance ( connection , lancedb . db . LanceDBConnection ) :
self . _connection = connection
elif isinstance ( connection , ( str , lancedb . db . LanceTable ) ) :
raise ValueError (
" `connection` has to be a lancedb.db.LanceDBConnection object. \
` lancedb . db . LanceTable ` is deprecated . "
)
else :
self . _connection = self . _init_table ( )
if self . api_key is None :
self . _connection = lancedb . connect ( uri )
else :
if isinstance ( uri , str ) :
if uri . startswith ( " db:// " ) :
self . _connection = lancedb . connect (
uri , api_key = self . api_key , region = self . region
)
else :
self . _connection = lancedb . connect ( uri )
warnings . warn (
" api key provided with local uri. \
The data will be stored locally "
)
@property
def embeddings ( self ) - > Optional [ Embeddings ] :
@ -88,7 +118,7 @@ class LanceDB(VectorStore):
Args :
texts : Iterable of strings to add to the vectorstore .
metadatas : Optional list of metadatas associated with the texts .
ids : Optional list of ids to associate w ith the texts .
ids : Optional list of ids to associate w ith the texts .
Returns :
List of ids of the added texts .
@ -99,20 +129,70 @@ class LanceDB(VectorStore):
embeddings = self . _embedding . embed_documents ( list ( texts ) ) # type: ignore
for idx , text in enumerate ( texts ) :
embedding = embeddings [ idx ]
metadata = metadatas [ idx ] if metadatas else { }
metadata = metadatas [ idx ] if metadatas else { " id " : ids [ idx ] }
docs . append (
{
self . _vector_key : embedding ,
self . _id_key : ids [ idx ] ,
self . _text_key : text ,
* * metadata ,
" metadata " : metadata ,
}
)
self . _connection . add ( docs )
if self . _table_name in self . _connection . table_names ( ) :
tbl = self . _connection . open_table ( self . _table_name )
if self . api_key is None :
tbl . add ( docs , mode = self . mode )
else :
tbl . add ( docs )
else :
self . _connection . create_table ( self . _table_name , data = docs )
return ids
def get_table ( self , name : Optional [ str ] = None ) - > Any :
if name is not None :
try :
self . _connection . open_table ( name )
except Exception :
raise ValueError ( f " Table { name } not found in the database " )
else :
return self . _connection . open_table ( self . _table_name )
def create_index (
self ,
col_name : Optional [ str ] = None ,
vector_col : Optional [ str ] = None ,
num_partitions : Optional [ int ] = 256 ,
num_sub_vectors : Optional [ int ] = 96 ,
index_cache_size : Optional [ int ] = None ,
) - > None :
"""
Create a scalar ( for non - vector cols ) or a vector index on a table .
Make sure your vector column has enough data before creating an index on it .
Args :
vector_col : Provide if you want to create index on a vector column .
col_name : Provide if you want to create index on a non - vector column .
metric : Provide the metric to use for vector index . Defaults to ' L2 '
choice of metrics : ' L2 ' , ' dot ' , ' cosine '
Returns :
None
"""
if vector_col :
self . _connection . create_index (
vector_column_name = vector_col ,
num_partitions = num_partitions ,
num_sub_vectors = num_sub_vectors ,
index_cache_size = index_cache_size ,
)
elif col_name :
self . _connection . create_scalar_index ( col_name )
else :
raise ValueError ( " Provide either vector_col or col_name " )
def similarity_search (
self , query : str , k : int = 4 , * * kwargs : Any
self , query : str , k : int = 4 , name : Optional [ str ] = None , * * kwargs : Any
) - > List [ Document ] :
""" Return documents most similar to the query
@ -124,8 +204,9 @@ class LanceDB(VectorStore):
List of documents most similar to the query .
"""
embedding = self . _embedding . embed_query ( query ) # type: ignore
tbl = self . get_table ( name )
docs = (
self . _connection . search ( embedding , vector_column_name = self . _vector_key )
tbl . search ( embedding , vector_column_name = self . _vector_key )
. limit ( k )
. to_arrow ( )
)
@ -155,32 +236,47 @@ class LanceDB(VectorStore):
* * kwargs : Any ,
) - > LanceDB :
instance = LanceDB (
connection ,
embedding ,
vector_key ,
id_key ,
text_key ,
connection = connection ,
embedding = embedding ,
vector_key = vector_key ,
id_key = id_key ,
text_key = text_key ,
)
instance . add_texts ( texts , metadatas = metadatas , * * kwargs )
return instance
def _init_table ( self ) - > Any :
import pyarrow as pa
schema = pa . schema (
[
pa . field (
self . _vector_key ,
pa . list_ (
pa . float32 ( ) ,
len ( self . embeddings . embed_query ( " test " ) ) , # type: ignore
) ,
) ,
pa . field ( self . _id_key , pa . string ( ) ) ,
pa . field ( self . _text_key , pa . string ( ) ) ,
]
)
db = self . lancedb . connect ( " /tmp/lancedb " )
tbl = db . create_table ( self . _table_name , schema = schema , mode = " overwrite " )
return tbl
def delete (
self ,
ids : Optional [ List [ str ] ] = None ,
delete_all : Optional [ bool ] = None ,
filter : Optional [ str ] = None ,
drop_columns : Optional [ List [ str ] ] = None ,
name : Optional [ str ] = None ,
* * kwargs : Any ,
) - > None :
"""
Allows deleting rows by filtering , by ids or drop columns from the table .
Args :
filter : Provide a string SQL expression - " {col} {operation} {value} " .
ids : Provide list of ids to delete from the table .
drop_columns : Provide list of columns to drop from the table .
delete_all : If True , delete all rows from the table .
"""
tbl = self . get_table ( name )
if filter :
tbl . delete ( filter )
elif ids :
tbl . delete ( " id in ( ' {} ' ) " . format ( " , " . join ( ids ) ) )
elif drop_columns :
if self . api_key is not None :
raise NotImplementedError (
" Column operations currently not supported in LanceDB Cloud. "
)
else :
tbl . drop_columns ( drop_columns )
elif delete_all :
tbl . delete ( " true " )
else :
raise ValueError ( " Provide either filter, ids, drop_columns or delete_all " )