# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# langchain/libs/community/langchain_community/vectorstores/epsilla.py
#
# 376 lines
# 14 KiB
# Python

"""Wrapper around Epsilla vector database."""
from __future__ import annotations
import logging
import uuid
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
if TYPE_CHECKING:
    # Only needed for type annotations; not imported at runtime.
    from pyepsilla import vectordb

# Module-scoped logger (instead of the root logger) so applications can
# filter or configure this module's output independently.
logger = logging.getLogger(__name__)
class Epsilla(VectorStore):
    """Wrapper around Epsilla vector database.

    As a prerequisite, you need to install ``pyepsilla`` package
    and have a running Epsilla vector database (for example, through our docker image)
    See the following documentation for how to run an Epsilla vector database:
    https://epsilla-inc.gitbook.io/epsilladb/quick-start

    Args:
        client (Any): Epsilla client to connect to.
        embeddings (Embeddings): Function used to embed the texts.
        db_path (Optional[str]): The path where the database will be persisted.
            Defaults to "/tmp/langchain-epsilla".
        db_name (Optional[str]): Give a name to the loaded database.
            Defaults to "langchain_store".

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Epsilla
            from pyepsilla import vectordb

            client = vectordb.Client()
            embeddings = OpenAIEmbeddings()
            db_path = "/tmp/vectorstore"
            db_name = "langchain_store"
            epsilla = Epsilla(client, embeddings, db_path, db_name)
    """

    _LANGCHAIN_DEFAULT_DB_NAME = "langchain_store"
    _LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla"
    _LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection"

    def __init__(
        self,
        client: Any,
        embeddings: Embeddings,
        db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
        db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
    ):
        """Initialize with necessary components.

        Loads the database at ``db_path`` under ``db_name`` and selects it
        on the client.

        Raises:
            ImportError: If ``pyepsilla`` is not installed.
            TypeError: If ``client`` is not a ``pyepsilla.vectordb.Client``.
        """
        try:
            import pyepsilla
        except ImportError as e:
            raise ImportError(
                "Could not import pyepsilla python package. "
                "Please install pyepsilla package with `pip install pyepsilla`."
            ) from e

        if not isinstance(client, pyepsilla.vectordb.Client):
            raise TypeError(
                f"client should be an instance of pyepsilla.vectordb.Client, "
                f"got {type(client)}"
            )

        self._client: vectordb.Client = client
        self._db_name = db_name
        self._embeddings = embeddings
        self._collection_name = Epsilla._LANGCHAIN_DEFAULT_TABLE_NAME
        self._client.load_db(db_name=db_name, db_path=db_path)
        self._client.use_db(db_name=db_name)

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Return the embedding function this store was built with."""
        return self._embeddings

    def use_collection(self, collection_name: str) -> None:
        """Set default collection to use.

        Args:
            collection_name (str): The name of the collection.
        """
        self._collection_name = collection_name

    def clear_data(self, collection_name: str = "") -> None:
        """Clear data in a collection by dropping its table.

        Args:
            collection_name (Optional[str]): The name of the collection.
                If not provided, the default collection will be used.
        """
        if not collection_name:
            collection_name = self._collection_name
        self._client.drop_table(collection_name)

    def get(
        self, collection_name: str = "", response_fields: Optional[List[str]] = None
    ) -> List[dict]:
        """Get the collection.

        Args:
            collection_name (Optional[str]): The name of the collection
                to retrieve data from.
                If not provided, the default collection will be used.
            response_fields (Optional[List[str]]): List of field names in the result.
                If not specified, all available fields will be responded.

        Returns:
            A list of the retrieved data.

        Raises:
            Exception: If the client reports a status code other than 200.
        """
        if not collection_name:
            collection_name = self._collection_name

        status_code, response = self._client.get(
            table_name=collection_name, response_fields=response_fields
        )
        if status_code != 200:
            logger.error("Failed to get records: %s", response["message"])
            raise Exception(f"Error: {response['message']}.")
        return response["result"]

    @staticmethod
    def _metadata_data_type(key: str, value: Any) -> str:
        """Map a metadata value to the table data type used for its field.

        Raises:
            ValueError: If the value is not a str, bool, int or float.
        """
        if isinstance(value, str):
            return "STRING"
        # bool must be tested before int: bool is a subclass of int, so an
        # int-first check would classify True/False as "INT".
        if isinstance(value, bool):
            return "BOOL"
        if isinstance(value, int):
            return "INT"
        if isinstance(value, float):
            return "FLOAT"
        raise ValueError(f"Unsupported data type for {key}.")

    def _create_collection(
        self, table_name: str, embeddings: list, metadatas: Optional[list[dict]] = None
    ) -> None:
        """Create the table backing a collection.

        The schema always has ``id``, ``text`` and ``embeddings`` fields; the
        vector dimension is taken from the first embedding. Every metadata key
        seen in ``metadatas`` adds one more field, typed from the first value
        encountered for that key.

        Args:
            table_name (str): Name of the table to create.
            embeddings (list): Embedding vectors; must be non-empty.
            metadatas (Optional[list[dict]]): Metadata dicts used to derive
                additional fields. Defaults to None.

        Raises:
            ValueError: If ``embeddings`` is empty or a metadata value has an
                unsupported type.
            Exception: If table creation fails with a status other than 200
                or 409.
        """
        if not embeddings:
            raise ValueError("Embeddings list is empty.")

        dim = len(embeddings[0])
        fields: List[dict] = [
            {"name": "id", "dataType": "INT"},
            {"name": "text", "dataType": "STRING"},
            {"name": "embeddings", "dataType": "VECTOR_FLOAT", "dimensions": dim},
        ]

        if metadatas is not None:
            # Tracks declared field names so each key is added once and the
            # built-in fields are never redefined by metadata.
            seen_names = {field["name"] for field in fields}
            for metadata in metadatas:
                for key, value in metadata.items():
                    if key in seen_names:
                        continue
                    fields.append(
                        {
                            "name": key,
                            "dataType": self._metadata_data_type(key, value),
                        }
                    )
                    seen_names.add(key)

        status_code, response = self._client.create_table(
            table_name, table_fields=fields
        )
        if status_code == 200:
            return
        if status_code == 409:
            # Treated as "table already exists": keep using it.
            logger.info("Continuing with the existing table %s.", table_name)
            return
        logger.error(
            "Failed to create collection %s: %s", table_name, response["message"]
        )
        raise Exception(f"Error: {response['message']}.")

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        collection_name: Optional[str] = "",
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> List[str]:
        """Embed texts and add them to the database.

        Args:
            texts (Iterable[str]): The texts to embed.
            metadatas (Optional[List[dict]]): Metadata dicts
                attached to each of the texts. Defaults to None.
            collection_name (Optional[str]): Which collection to use.
                Defaults to "langchain_collection".
                If provided, default collection name will be set as well.
            drop_old (Optional[bool]): Whether to drop the previous collection
                and create a new one. Defaults to False.

        Returns:
            List of ids of the added texts.

        Raises:
            Exception: If creating the table or inserting the records fails.
        """
        if not collection_name:
            collection_name = self._collection_name
        else:
            self._collection_name = collection_name

        if drop_old:
            # Drop the collection's table (not a database): this is the same
            # call clear_data() uses, and its result is likewise ignored.
            self._client.drop_table(collection_name)

        texts = list(texts)
        try:
            embeddings = self._embeddings.embed_documents(texts)
        except NotImplementedError:
            # Fall back to per-text embedding when batch embedding is not
            # implemented.
            embeddings = [self._embeddings.embed_query(x) for x in texts]

        if not embeddings:
            logger.debug("Nothing to insert, skipping.")
            return []

        self._create_collection(
            table_name=collection_name, embeddings=embeddings, metadatas=metadatas
        )

        # The table's "id" field is declared INT, so ids are integers derived
        # from random UUIDs.
        ids = [hash(uuid.uuid4()) for _ in texts]
        records = []
        for index, record_id in enumerate(ids):
            record = {
                "id": record_id,
                "text": texts[index],
                "embeddings": embeddings[index],
            }
            if metadatas is not None:
                record.update(metadatas[index])
            records.append(record)

        status_code, response = self._client.insert(
            table_name=collection_name, records=records
        )
        if status_code != 200:
            logger.error(
                "Failed to add records to %s: %s",
                collection_name,
                response["message"],
            )
            raise Exception(f"Error: {response['message']}.")
        return [str(record_id) for record_id in ids]

    def similarity_search(
        self, query: str, k: int = 4, collection_name: str = "", **kwargs: Any
    ) -> List[Document]:
        """Return the documents that are semantically most relevant to the query.

        Args:
            query (str): String to query the vectorstore with.
            k (int): Number of documents to return. Defaults to 4.
            collection_name (Optional[str]): Collection to use.
                Defaults to "langchain_collection" or the one provided before.

        Returns:
            List of documents that are semantically most relevant to the query

        Raises:
            Exception: If the client reports a status code other than 200.
        """
        if not collection_name:
            collection_name = self._collection_name

        query_vector = self._embeddings.embed_query(query)
        status_code, response = self._client.query(
            table_name=collection_name,
            query_field="embeddings",
            query_vector=query_vector,
            limit=k,
        )
        if status_code != 200:
            logger.error("Search failed: %s.", response["message"])
            raise Exception(f"Error: {response['message']}.")

        # Everything except the built-in fields is returned as metadata.
        exclude_keys = ("id", "text", "embeddings")
        return [
            Document(
                page_content=item["text"],
                metadata={
                    key: item[key] for key in item if key not in exclude_keys
                },
            )
            for item in response["result"]
        ]

    @classmethod
    def from_texts(
        cls: Type[Epsilla],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        client: Any = None,
        db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
        db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
        collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> Epsilla:
        """Create an Epsilla vectorstore from raw documents.

        Args:
            texts (List[str]): List of text data to be inserted.
            embedding (Embeddings): Embedding function.
            metadatas (Optional[List[dict]]): Metadata for each text.
                Defaults to None.
            client (pyepsilla.vectordb.Client): Epsilla client to connect to.
            db_path (Optional[str]): The path where the database will be persisted.
                Defaults to "/tmp/langchain-epsilla".
            db_name (Optional[str]): Give a name to the loaded database.
                Defaults to "langchain_store".
            collection_name (Optional[str]): Which collection to use.
                Defaults to "langchain_collection".
                If provided, default collection name will be set as well.
            drop_old (Optional[bool]): Whether to drop the previous collection
                and create a new one. Defaults to False.

        Returns:
            Epsilla: Epsilla vector store.
        """
        # Build through ``cls`` so subclasses get an instance of their own type.
        instance = cls(client, embedding, db_path=db_path, db_name=db_name)
        instance.add_texts(
            texts,
            metadatas=metadatas,
            collection_name=collection_name,
            drop_old=drop_old,
            **kwargs,
        )
        return instance

    @classmethod
    def from_documents(
        cls: Type[Epsilla],
        documents: List[Document],
        embedding: Embeddings,
        client: Any = None,
        db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
        db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
        collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> Epsilla:
        """Create an Epsilla vectorstore from a list of documents.

        Each document's ``page_content`` is used as the text and its
        ``metadata`` as the per-text metadata.

        Args:
            documents (List[Document]): Documents to be inserted.
            embedding (Embeddings): Embedding function.
            client (pyepsilla.vectordb.Client): Epsilla client to connect to.
            db_path (Optional[str]): The path where the database will be persisted.
                Defaults to "/tmp/langchain-epsilla".
            db_name (Optional[str]): Give a name to the loaded database.
                Defaults to "langchain_store".
            collection_name (Optional[str]): Which collection to use.
                Defaults to "langchain_collection".
                If provided, default collection name will be set as well.
            drop_old (Optional[bool]): Whether to drop the previous collection
                and create a new one. Defaults to False.

        Returns:
            Epsilla: Epsilla vector store.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return cls.from_texts(
            texts,
            embedding,
            metadatas=metadatas,
            client=client,
            db_path=db_path,
            db_name=db_name,
            collection_name=collection_name,
            drop_old=drop_old,
            **kwargs,
        )