mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
376 lines
14 KiB
Python
376 lines
14 KiB
Python
|
"""Wrapper around Epsilla vector database."""
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import logging
|
||
|
import uuid
|
||
|
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type
|
||
|
|
||
|
from langchain_core.documents import Document
|
||
|
from langchain_core.embeddings import Embeddings
|
||
|
from langchain_core.vectorstores import VectorStore
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from pyepsilla import vectordb
|
||
|
|
||
|
# Module-scoped logger: records carry this module's name, so applications can
# filter or silence this integration without touching the root logger.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class Epsilla(VectorStore):
    """Wrapper around Epsilla vector database.

    As a prerequisite, you need to install the ``pyepsilla`` package
    and have a running Epsilla vector database (for example, through our
    docker image). See the following documentation for how to run an Epsilla
    vector database:
    https://epsilla-inc.gitbook.io/epsilladb/quick-start

    Args:
        client (Any): Epsilla client to connect to.
        embeddings (Embeddings): Function used to embed the texts.
        db_path (Optional[str]): The path where the database will be persisted.
            Defaults to "/tmp/langchain-epsilla".
        db_name (Optional[str]): Give a name to the loaded database.
            Defaults to "langchain_store".

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Epsilla
            from pyepsilla import vectordb

            client = vectordb.Client()
            embeddings = OpenAIEmbeddings()
            db_path = "/tmp/vectorstore"
            db_name = "langchain_store"
            epsilla = Epsilla(client, embeddings, db_path, db_name)
    """

    _LANGCHAIN_DEFAULT_DB_NAME = "langchain_store"
    _LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla"
    _LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection"

    def __init__(
        self,
        client: Any,
        embeddings: Embeddings,
        db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
        db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
    ) -> None:
        """Initialize with necessary components.

        Loads the database at ``db_path`` under the name ``db_name`` and
        selects it as the active database of ``client``.

        Raises:
            ImportError: If the ``pyepsilla`` package is not installed.
            TypeError: If ``client`` is not a ``pyepsilla.vectordb.Client``.
        """
        # Imported lazily so that pyepsilla stays an optional dependency.
        try:
            import pyepsilla
        except ImportError as e:
            raise ImportError(
                "Could not import pyepsilla python package. "
                "Please install pyepsilla package with `pip install pyepsilla`."
            ) from e

        if not isinstance(client, pyepsilla.vectordb.Client):
            raise TypeError(
                "client should be an instance of pyepsilla.vectordb.Client, "
                f"got {type(client)}"
            )

        self._client: vectordb.Client = client
        self._db_name = db_name
        self._embeddings = embeddings
        self._collection_name = Epsilla._LANGCHAIN_DEFAULT_TABLE_NAME
        self._client.load_db(db_name=db_name, db_path=db_path)
        self._client.use_db(db_name=db_name)

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Return the embedding function this store was created with."""
        return self._embeddings

    def use_collection(self, collection_name: str) -> None:
        """Set default collection to use.

        Args:
            collection_name (str): The name of the collection.
        """
        self._collection_name = collection_name

    def clear_data(self, collection_name: str = "") -> None:
        """Clear data in a collection by dropping its table.

        Args:
            collection_name (str): The name of the collection.
                If empty, the default collection will be used.
        """
        if not collection_name:
            collection_name = self._collection_name
        # The client's return value is not inspected here, so a failed drop
        # is not reported to the caller.
        self._client.drop_table(collection_name)

    def get(
        self, collection_name: str = "", response_fields: Optional[List[str]] = None
    ) -> List[dict]:
        """Get the records stored in a collection.

        Args:
            collection_name (str): The name of the collection
                to retrieve data from.
                If empty, the default collection will be used.
            response_fields (Optional[List[str]]): List of field names in the
                result. If not specified, all available fields are returned.

        Returns:
            A list of the retrieved records.

        Raises:
            Exception: If the client reports a status code other than 200.
        """
        if not collection_name:
            collection_name = self._collection_name
        status_code, response = self._client.get(
            table_name=collection_name, response_fields=response_fields
        )
        if status_code != 200:
            logger.error(f"Failed to get records: {response['message']}")
            raise Exception("Error: {}.".format(response["message"]))
        return response["result"]

    def _create_collection(
        self, table_name: str, embeddings: list, metadatas: Optional[list[dict]] = None
    ) -> None:
        """Create the table backing a collection, inferring its schema.

        The table always has ``id``, ``text`` and ``embeddings`` fields; the
        vector dimension is taken from the first embedding. Every additional
        key found in ``metadatas`` becomes a field whose data type is derived
        from the first value seen for that key.

        Args:
            table_name (str): Name of the table to create.
            embeddings (list): Embedding vectors; must not be empty.
            metadatas (Optional[list[dict]]): Metadata dicts used to derive
                the extra fields. Defaults to None.

        Raises:
            ValueError: If ``embeddings`` is empty or a metadata value is not
                a ``bool``, ``str``, ``int`` or ``float``.
            Exception: If the client reports a status code other than 200
                or 409.
        """
        if not embeddings:
            raise ValueError("Embeddings list is empty.")

        dim = len(embeddings[0])
        fields: List[dict] = [
            {"name": "id", "dataType": "INT"},
            {"name": "text", "dataType": "STRING"},
            {"name": "embeddings", "dataType": "VECTOR_FLOAT", "dimensions": dim},
        ]
        if metadatas is not None:
            field_names = {field["name"] for field in fields}
            for metadata in metadatas:
                for key, value in metadata.items():
                    if key in field_names:
                        continue
                    d_type: str
                    # bool must be tested before int: bool is a subclass of
                    # int, so the reverse order would label booleans "INT".
                    if isinstance(value, bool):
                        d_type = "BOOL"
                    elif isinstance(value, str):
                        d_type = "STRING"
                    elif isinstance(value, int):
                        d_type = "INT"
                    elif isinstance(value, float):
                        d_type = "FLOAT"
                    else:
                        raise ValueError(f"Unsupported data type for {key}.")
                    fields.append({"name": key, "dataType": d_type})
                    field_names.add(key)

        status_code, response = self._client.create_table(
            table_name, table_fields=fields
        )
        if status_code == 409:
            # 409 is treated as "table already exists": reuse it as is.
            logger.info(f"Continuing with the existing table {table_name}.")
        elif status_code != 200:
            logger.error(
                f"Failed to create collection {table_name}: {response['message']}"
            )
            raise Exception("Error: {}.".format(response["message"]))

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        collection_name: Optional[str] = "",
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> List[str]:
        """Embed texts and add them to the database.

        Args:
            texts (Iterable[str]): The texts to embed.
            metadatas (Optional[List[dict]]): Metadata dicts
                attached to each of the texts. Defaults to None.
            collection_name (Optional[str]): Which collection to use.
                If empty, the current default collection is used.
                If provided, it also becomes the default collection.
            drop_old (Optional[bool]): Whether to drop the previous collection
                and create a new one. Defaults to False.

        Returns:
            List of ids of the added texts.

        Raises:
            Exception: If creating the collection or inserting the records
                fails.
        """
        if not collection_name:
            collection_name = self._collection_name
        else:
            self._collection_name = collection_name

        if drop_old:
            # Drop the collection's table, as clear_data() does. Dropping a
            # *database* named after the collection would not remove it.
            self._client.drop_table(collection_name)

        texts = list(texts)
        try:
            embeddings = self._embeddings.embed_documents(texts)
        except NotImplementedError:
            # Fall back to per-text embedding when batch embedding is missing.
            embeddings = [self._embeddings.embed_query(x) for x in texts]

        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []

        self._create_collection(
            table_name=collection_name, embeddings=embeddings, metadatas=metadatas
        )

        # The "id" field is declared as INT, so ids are integers derived from
        # random UUIDs; they are returned to the caller as strings.
        ids = [hash(uuid.uuid4()) for _ in texts]
        records = []
        for index, record_id in enumerate(ids):
            record = {
                "id": record_id,
                "text": texts[index],
                "embeddings": embeddings[index],
            }
            if metadatas is not None:
                record.update(metadatas[index])
            records.append(record)

        status_code, response = self._client.insert(
            table_name=collection_name, records=records
        )
        if status_code != 200:
            logger.error(
                f"Failed to add records to {collection_name}: {response['message']}"
            )
            raise Exception("Error: {}.".format(response["message"]))
        return [str(record_id) for record_id in ids]

    def similarity_search(
        self, query: str, k: int = 4, collection_name: str = "", **kwargs: Any
    ) -> List[Document]:
        """Return the documents that are semantically most relevant to the query.

        Args:
            query (str): String to query the vectorstore with.
            k (int): Number of documents to return. Defaults to 4.
            collection_name (str): Collection to use.
                If empty, the current default collection is used
                ("langchain_collection" unless changed before).

        Returns:
            List of documents that are semantically most relevant to the query.

        Raises:
            Exception: If the client reports a status code other than 200.
        """
        if not collection_name:
            collection_name = self._collection_name
        query_vector = self._embeddings.embed_query(query)
        status_code, response = self._client.query(
            table_name=collection_name,
            query_field="embeddings",
            query_vector=query_vector,
            limit=k,
        )
        if status_code != 200:
            logger.error(f"Search failed: {response['message']}.")
            raise Exception("Error: {}.".format(response["message"]))

        # Everything except the built-in fields is user metadata.
        exclude_keys = ("id", "text", "embeddings")
        return [
            Document(
                page_content=item["text"],
                metadata={key: item[key] for key in item if key not in exclude_keys},
            )
            for item in response["result"]
        ]

    @classmethod
    def from_texts(
        cls: Type[Epsilla],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        client: Any = None,
        db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
        db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
        collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> Epsilla:
        """Create an Epsilla vectorstore from raw texts.

        Args:
            texts (List[str]): List of text data to be inserted.
            embedding (Embeddings): Embedding function.
            metadatas (Optional[List[dict]]): Metadata for each text.
                Defaults to None.
            client (pyepsilla.vectordb.Client): Epsilla client to connect to.
            db_path (Optional[str]): The path where the database will be
                persisted. Defaults to "/tmp/langchain-epsilla".
            db_name (Optional[str]): Give a name to the loaded database.
                Defaults to "langchain_store".
            collection_name (Optional[str]): Which collection to use.
                Defaults to "langchain_collection".
                If provided, default collection name will be set as well.
            drop_old (Optional[bool]): Whether to drop the previous collection
                and create a new one. Defaults to False.

        Returns:
            Epsilla: Epsilla vector store.
        """
        # Build through ``cls`` so that subclasses get an instance of their
        # own type rather than of the base class.
        instance = cls(client, embedding, db_path=db_path, db_name=db_name)
        instance.add_texts(
            texts,
            metadatas=metadatas,
            collection_name=collection_name,
            drop_old=drop_old,
            **kwargs,
        )

        return instance

    @classmethod
    def from_documents(
        cls: Type[Epsilla],
        documents: List[Document],
        embedding: Embeddings,
        client: Any = None,
        db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
        db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
        collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> Epsilla:
        """Create an Epsilla vectorstore from a list of documents.

        Each document's ``page_content`` is inserted as the text and its
        ``metadata`` as the matching metadata dict.

        Args:
            documents (List[Document]): Documents to be inserted.
            embedding (Embeddings): Embedding function.
            client (pyepsilla.vectordb.Client): Epsilla client to connect to.
            db_path (Optional[str]): The path where the database will be
                persisted. Defaults to "/tmp/langchain-epsilla".
            db_name (Optional[str]): Give a name to the loaded database.
                Defaults to "langchain_store".
            collection_name (Optional[str]): Which collection to use.
                Defaults to "langchain_collection".
                If provided, default collection name will be set as well.
            drop_old (Optional[bool]): Whether to drop the previous collection
                and create a new one. Defaults to False.

        Returns:
            Epsilla: Epsilla vector store.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]

        return cls.from_texts(
            texts,
            embedding,
            metadatas=metadatas,
            client=client,
            db_path=db_path,
            db_name=db_name,
            collection_name=collection_name,
            drop_old=drop_old,
            **kwargs,
        )
|