LocalDocs version 2 with text embeddings.

localdocs_fix
Adam Treat 7 months ago committed by AT
parent d4ce9f4a7c
commit 371e2a5cbc

@ -490,6 +490,9 @@ struct bert_ctx * bert_load_from_file(const char *fname)
#endif
bert_ctx * new_bert = new bert_ctx;
new_bert->buf_compute.force_cpu = true;
new_bert->work_buf.force_cpu = true;
bert_model & model = new_bert->model;
bert_vocab & vocab = new_bert->vocab;

@ -10,13 +10,14 @@ struct llm_buffer {
uint8_t * addr = NULL;
size_t size = 0;
ggml_vk_memory memory;
bool force_cpu = false;
llm_buffer() = default;
void resize(size_t size) {
free();
if (!ggml_vk_has_device()) {
if (!ggml_vk_has_device() || force_cpu) {
this->addr = new uint8_t[size];
this->size = size;
} else {

@ -75,7 +75,9 @@ qt_add_executable(chat
chatmodel.h chatlistmodel.h chatlistmodel.cpp
chatgpt.h chatgpt.cpp
database.h database.cpp
embeddings.h embeddings.cpp
download.h download.cpp
embllm.cpp embllm.h
localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp
llm.h llm.cpp
modellist.h modellist.cpp
@ -90,6 +92,7 @@ qt_add_executable(chat
qt_add_qml_module(chat
URI gpt4all
VERSION 1.0
NO_CACHEGEN
QML_FILES
main.qml
qml/ChatDrawer.qml
@ -170,7 +173,7 @@ else()
PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf)
endif()
target_link_libraries(chat
PRIVATE llmodel)
PRIVATE llmodel bert-default)
set(COMPONENT_NAME_MAIN ${PROJECT_NAME})
set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/install)

@ -18,6 +18,7 @@ Chat::Chat(QObject *parent)
, m_shouldDeleteLater(false)
, m_isModelLoaded(false)
, m_shouldLoadModelWhenInstalled(false)
, m_collectionModel(new LocalDocsCollectionsModel(this))
{
connectLLM();
}
@ -35,6 +36,7 @@ Chat::Chat(bool isServer, QObject *parent)
, m_shouldDeleteLater(false)
, m_isModelLoaded(false)
, m_shouldLoadModelWhenInstalled(false)
, m_collectionModel(new LocalDocsCollectionsModel(this))
{
connectLLM();
}
@ -71,6 +73,7 @@ void Chat::connectLLM()
connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection);
connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection);
connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
connect(ModelList::globalInstance()->installedModels(), &InstalledModels::countChanged,
this, &Chat::handleModelInstalled, Qt::QueuedConnection);
}

@ -27,6 +27,7 @@ class Chat : public QObject
Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged);
Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged);
Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged)
QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!")
@ -83,6 +84,7 @@ public:
bool isServer() const { return m_isServer; }
QList<QString> collectionList() const;
LocalDocsCollectionsModel *collectionModel() const { return m_collectionModel; }
Q_INVOKABLE bool hasCollection(const QString &collection) const;
Q_INVOKABLE void addCollection(const QString &collection);
@ -123,6 +125,7 @@ Q_SIGNALS:
void tokenSpeedChanged();
void deviceChanged();
void fallbackReasonChanged();
void collectionModelChanged();
private Q_SLOTS:
void handleResponseChanged(const QString &response);
@ -161,6 +164,7 @@ private:
bool m_shouldDeleteLater;
bool m_isModelLoaded;
bool m_shouldLoadModelWhenInstalled;
LocalDocsCollectionsModel *m_collectionModel;
};
#endif // CHAT_H

@ -1,5 +1,7 @@
#include "database.h"
#include "mysettings.h"
#include "embllm.h"
#include "embeddings.h"
#include <QTimer>
#include <QPdfDocument>
@ -7,18 +9,18 @@
//#define DEBUG
//#define DEBUG_EXAMPLE
#define LOCALDOCS_VERSION 0
#define LOCALDOCS_VERSION 1
const auto INSERT_CHUNK_SQL = QLatin1String(R"(
insert into chunks(document_id, chunk_id, chunk_text,
file, title, author, subject, keywords, page, line_from, line_to,
embedding_id, embedding_path) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
insert into chunks(document_id, chunk_text,
file, title, author, subject, keywords, page, line_from, line_to)
values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
)");
const auto INSERT_CHUNK_FTS_SQL = QLatin1String(R"(
insert into chunks_fts(document_id, chunk_id, chunk_text,
file, title, author, subject, keywords, page, line_from, line_to,
embedding_id, embedding_path) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
file, title, author, subject, keywords, page, line_from, line_to)
values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
)");
const auto DELETE_CHUNKS_SQL = QLatin1String(R"(
@ -30,20 +32,33 @@ const auto DELETE_CHUNKS_FTS_SQL = QLatin1String(R"(
)");
const auto CHUNKS_SQL = QLatin1String(R"(
create table chunks(document_id integer, chunk_id integer, chunk_text varchar,
create table chunks(document_id integer, chunk_id integer primary key autoincrement, chunk_text varchar,
file varchar, title varchar, author varchar, subject varchar, keywords varchar,
page integer, line_from integer, line_to integer,
embedding_id integer, embedding_path varchar);
page integer, line_from integer, line_to integer);
)");
const auto FTS_CHUNKS_SQL = QLatin1String(R"(
create virtual table chunks_fts using fts5(document_id unindexed, chunk_id unindexed, chunk_text,
file, title, author, subject, keywords, page, line_from, line_to,
embedding_id unindexed, embedding_path unindexed, tokenize="trigram");
file, title, author, subject, keywords, page, line_from, line_to, tokenize="trigram");
)");
const auto SELECT_SQL = QLatin1String(R"(
select chunks_fts.rowid, documents.document_time,
const auto SELECT_CHUNKS_BY_DOCUMENT_SQL = QLatin1String(R"(
select chunk_id from chunks WHERE document_id = ?;
)");
const auto SELECT_CHUNKS_SQL = QLatin1String(R"(
select chunks.chunk_id, documents.document_time,
chunks.chunk_text, chunks.file, chunks.title, chunks.author, chunks.page,
chunks.line_from, chunks.line_to
from chunks
join documents ON chunks.document_id = documents.id
join folders ON documents.folder_id = folders.id
join collections ON folders.id = collections.folder_id
where chunks.chunk_id in (%1) and collections.collection_name in (%2);
)");
const auto SELECT_NGRAM_SQL = QLatin1String(R"(
select chunks_fts.chunk_id, documents.document_time,
chunks_fts.chunk_text, chunks_fts.file, chunks_fts.title, chunks_fts.author, chunks_fts.page,
chunks_fts.line_from, chunks_fts.line_to
from chunks_fts
@ -55,16 +70,14 @@ const auto SELECT_SQL = QLatin1String(R"(
limit %2;
)");
bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_text,
bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text,
const QString &file, const QString &title, const QString &author, const QString &subject, const QString &keywords,
int page, int from, int to,
int embedding_id, const QString &embedding_path)
int page, int from, int to, int *chunk_id)
{
{
if (!q.prepare(INSERT_CHUNK_SQL))
return false;
q.addBindValue(document_id);
q.addBindValue(chunk_id);
q.addBindValue(chunk_text);
q.addBindValue(file);
q.addBindValue(title);
@ -74,16 +87,19 @@ bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_
q.addBindValue(page);
q.addBindValue(from);
q.addBindValue(to);
q.addBindValue(embedding_id);
q.addBindValue(embedding_path);
if (!q.exec())
return false;
}
if (!q.exec("select last_insert_rowid();"))
return false;
if (!q.next())
return false;
*chunk_id = q.value(0).toInt();
{
if (!q.prepare(INSERT_CHUNK_FTS_SQL))
return false;
q.addBindValue(document_id);
q.addBindValue(chunk_id);
q.addBindValue(*chunk_id);
q.addBindValue(chunk_text);
q.addBindValue(file);
q.addBindValue(title);
@ -93,8 +109,6 @@ bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_
q.addBindValue(page);
q.addBindValue(from);
q.addBindValue(to);
q.addBindValue(embedding_id);
q.addBindValue(embedding_path);
if (!q.exec())
return false;
}
@ -146,6 +160,18 @@ QStringList generateGrams(const QString &input, int N)
return ngrams;
}
bool selectChunk(QSqlQuery &q, const QList<QString> &collection_names, const std::vector<qint64> &chunk_ids, int retrievalSize)
{
QString chunk_ids_str = QString::number(chunk_ids[0]);
for (size_t i = 1; i < chunk_ids.size(); ++i)
chunk_ids_str += "," + QString::number(chunk_ids[i]);
const QString collection_names_str = collection_names.join("', '");
const QString formatted_query = SELECT_CHUNKS_SQL.arg(chunk_ids_str).arg("'" + collection_names_str + "'");
if (!q.prepare(formatted_query))
return false;
return q.exec();
}
bool selectChunk(QSqlQuery &q, const QList<QString> &collection_names, const QString &chunk_text, int retrievalSize)
{
static QRegularExpression spaces("\\s+");
@ -155,7 +181,7 @@ bool selectChunk(QSqlQuery &q, const QList<QString> &collection_names, const QSt
QList<QString> text = generateGrams(chunk_text, N);
QString orText = text.join(" OR ");
const QString collection_names_str = collection_names.join("', '");
const QString formatted_query = SELECT_SQL.arg("'" + collection_names_str + "'").arg(QString::number(retrievalSize));
const QString formatted_query = SELECT_NGRAM_SQL.arg("'" + collection_names_str + "'").arg(QString::number(retrievalSize));
if (!q.prepare(formatted_query))
return false;
q.addBindValue(orText);
@ -248,7 +274,8 @@ bool selectAllFromCollections(QSqlQuery &q, QList<CollectionItem> *collections)
CollectionItem i;
i.collection = q.value(0).toString();
i.folder_path = q.value(1).toString();
i.folder_id = q.value(0).toInt();
i.folder_id = q.value(2).toInt();
i.indexing = false;
i.installed = true;
collections->append(i);
}
@ -459,6 +486,12 @@ QSqlError initDb()
return q.lastError();
}
CollectionItem i;
i.collection = collection_name;
i.folder_path = folder_path;
i.folder_id = folder_id;
emit addCollectionItem(i);
// Add a document
int document_time = 123456789;
int document_id;
@ -504,6 +537,8 @@ Database::Database(int chunkSize)
: QObject(nullptr)
, m_watcher(new QFileSystemWatcher(this))
, m_chunkSize(chunkSize)
, m_embLLM(new EmbeddingLLM)
, m_embeddings(new Embeddings(this))
{
moveToThread(&m_dbThread);
connect(&m_dbThread, &QThread::started, this, &Database::start);
@ -511,22 +546,39 @@ Database::Database(int chunkSize)
m_dbThread.start();
}
void Database::handleDocumentErrorAndScheduleNext(const QString &errorMessage,
int document_id, const QString &document_path, const QSqlError &error)
Database::~Database()
{
qWarning() << errorMessage << document_id << document_path << error.text();
m_dbThread.quit();
m_dbThread.wait();
}
void Database::scheduleNext(int folder_id, size_t countForFolder)
{
emit updateCurrentDocsToIndex(folder_id, countForFolder);
if (!countForFolder) {
emit updateIndexing(folder_id, false);
emit updateInstalled(folder_id, true);
m_embeddings->save();
}
if (!m_docsToScan.isEmpty())
QTimer::singleShot(0, this, &Database::scanQueue);
}
void Database::chunkStream(QTextStream &stream, int document_id, const QString &file,
const QString &title, const QString &author, const QString &subject, const QString &keywords, int page)
void Database::handleDocumentError(const QString &errorMessage,
int document_id, const QString &document_path, const QSqlError &error)
{
qWarning() << errorMessage << document_id << document_path << error.text();
}
size_t Database::chunkStream(QTextStream &stream, int document_id, const QString &file,
const QString &title, const QString &author, const QString &subject, const QString &keywords, int page,
int maxChunks)
{
int chunk_id = 0;
int charCount = 0;
int line_from = -1;
int line_to = -1;
QList<QString> words;
int chunks = 0;
while (!stream.atEnd()) {
QString word;
@ -536,9 +588,9 @@ void Database::chunkStream(QTextStream &stream, int document_id, const QString &
if (charCount + words.size() - 1 >= m_chunkSize || stream.atEnd()) {
const QString chunk = words.join(" ");
QSqlQuery q;
int chunk_id = 0;
if (!addChunk(q,
document_id,
++chunk_id,
chunk,
file,
title,
@ -548,15 +600,111 @@ void Database::chunkStream(QTextStream &stream, int document_id, const QString &
page,
line_from,
line_to,
0 /*embedding_id*/,
QString() /*embedding_path*/
&chunk_id
)) {
qWarning() << "ERROR: Could not insert chunk into db" << q.lastError();
}
const std::vector<float> result = m_embLLM->generateEmbeddings(chunk);
if (!m_embeddings->add(result, chunk_id))
qWarning() << "ERROR: Cannot add point to embeddings index";
++chunks;
words.clear();
charCount = 0;
if (maxChunks > 0 && chunks == maxChunks)
return stream.pos();
}
}
return stream.pos();
}
void Database::removeEmbeddingsByDocumentId(int document_id)
{
QSqlQuery q;
if (!q.prepare(SELECT_CHUNKS_BY_DOCUMENT_SQL)) {
qWarning() << "ERROR: Cannot prepare sql for select chunks by document" << q.lastError();
return;
}
q.addBindValue(document_id);
if (!q.exec()) {
qWarning() << "ERROR: Cannot exec sql for select chunks by document" << q.lastError();
return;
}
while (q.next()) {
const int chunk_id = q.value(0).toInt();
m_embeddings->remove(chunk_id);
}
m_embeddings->save();
}
size_t Database::countOfDocuments(int folder_id) const
{
if (!m_docsToScan.contains(folder_id))
return 0;
return m_docsToScan.value(folder_id).size();
}
size_t Database::countOfBytes(int folder_id) const
{
if (!m_docsToScan.contains(folder_id))
return 0;
size_t totalBytes = 0;
const QQueue<DocumentInfo> &docs = m_docsToScan.value(folder_id);
for (const DocumentInfo &f : docs)
totalBytes += f.doc.size();
return totalBytes;
}
DocumentInfo Database::dequeueDocument()
{
Q_ASSERT(!m_docsToScan.isEmpty());
const int firstKey = m_docsToScan.firstKey();
QQueue<DocumentInfo> &queue = m_docsToScan[firstKey];
Q_ASSERT(!queue.isEmpty());
DocumentInfo result = queue.dequeue();
if (queue.isEmpty())
m_docsToScan.remove(firstKey);
return result;
}
void Database::removeFolderFromDocumentQueue(int folder_id)
{
if (!m_docsToScan.contains(folder_id))
return;
m_docsToScan.remove(folder_id);
emit removeFolderById(folder_id);
emit docsToScanChanged();
}
void Database::enqueueDocumentInternal(const DocumentInfo &info, bool prepend)
{
const int key = info.folder;
if (!m_docsToScan.contains(key))
m_docsToScan[key] = QQueue<DocumentInfo>();
if (prepend)
m_docsToScan[key].prepend(info);
else
m_docsToScan[key].enqueue(info);
}
void Database::enqueueDocuments(int folder_id, const QVector<DocumentInfo> &infos)
{
for (int i = 0; i < infos.size(); ++i)
enqueueDocumentInternal(infos[i]);
const size_t count = countOfDocuments(folder_id);
emit updateCurrentDocsToIndex(folder_id, count);
emit updateTotalDocsToIndex(folder_id, count);
const size_t bytes = countOfBytes(folder_id);
emit updateCurrentBytesToIndex(folder_id, bytes);
emit updateTotalBytesToIndex(folder_id, bytes);
emit docsToScanChanged();
}
void Database::scanQueue()
@ -564,7 +712,9 @@ void Database::scanQueue()
if (m_docsToScan.isEmpty())
return;
DocumentInfo info = m_docsToScan.dequeue();
DocumentInfo info = dequeueDocument();
const size_t countForFolder = countOfDocuments(info.folder);
const int folder_id = info.folder;
// Update info
info.doc.stat();
@ -572,99 +722,127 @@ void Database::scanQueue()
// If the doc has since been deleted or no longer readable, then we schedule more work and return
// leaving the cleanup for the cleanup handler
if (!info.doc.exists() || !info.doc.isReadable()) {
if (!m_docsToScan.isEmpty()) QTimer::singleShot(0, this, &Database::scanQueue);
return;
return scheduleNext(folder_id, countForFolder);
}
const int folder_id = info.folder;
const qint64 document_time = info.doc.fileTime(QFile::FileModificationTime).toMSecsSinceEpoch();
const QString document_path = info.doc.canonicalFilePath();
#if defined(DEBUG)
qDebug() << "scanning document" << document_path;
#endif
const bool currentlyProcessing = info.currentlyProcessing;
// Check and see if we already have this document
QSqlQuery q;
int existing_id = -1;
qint64 existing_time = -1;
if (!selectDocument(q, document_path, &existing_id, &existing_time)) {
return handleDocumentErrorAndScheduleNext("ERROR: Cannot select document",
handleDocumentError("ERROR: Cannot select document",
existing_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
}
// If we have the document, we need to compare the last modification time and if it is newer
// we must rescan the document, otherwise return
if (existing_id != -1) {
if (existing_id != -1 && !currentlyProcessing) {
Q_ASSERT(existing_time != -1);
if (document_time == existing_time) {
// No need to rescan, but we do have to schedule next
if (!m_docsToScan.isEmpty()) QTimer::singleShot(0, this, &Database::scanQueue);
return;
return scheduleNext(folder_id, countForFolder);
} else {
removeEmbeddingsByDocumentId(existing_id);
if (!removeChunksByDocumentId(q, existing_id)) {
return handleDocumentErrorAndScheduleNext("ERROR: Cannot remove chunks of document",
handleDocumentError("ERROR: Cannot remove chunks of document",
existing_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
}
}
}
// Update the document_time for an existing document, or add it for the first time now
int document_id = existing_id;
if (document_id != -1) {
if (!updateDocument(q, document_id, document_time)) {
return handleDocumentErrorAndScheduleNext("ERROR: Could not update document_time",
document_id, document_path, q.lastError());
}
} else {
if (!addDocument(q, folder_id, document_time, document_path, &document_id)) {
return handleDocumentErrorAndScheduleNext("ERROR: Could not add document",
document_id, document_path, q.lastError());
if (!currentlyProcessing) {
if (document_id != -1) {
if (!updateDocument(q, document_id, document_time)) {
handleDocumentError("ERROR: Could not update document_time",
document_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
}
} else {
if (!addDocument(q, folder_id, document_time, document_path, &document_id)) {
handleDocumentError("ERROR: Could not add document",
document_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
}
}
}
QElapsedTimer timer;
timer.start();
QSqlDatabase::database().transaction();
Q_ASSERT(document_id != -1);
if (info.doc.suffix() == QLatin1String("pdf")) {
if (info.isPdf()) {
QPdfDocument doc;
if (QPdfDocument::Error::None != doc.load(info.doc.canonicalFilePath())) {
return handleDocumentErrorAndScheduleNext("ERROR: Could not load pdf",
handleDocumentError("ERROR: Could not load pdf",
document_id, document_path, q.lastError());
return;
return scheduleNext(folder_id, countForFolder);
}
for (int i = 0; i < doc.pageCount(); ++i) {
const QPdfSelection selection = doc.getAllText(i);
QString text = selection.text();
QTextStream stream(&text);
chunkStream(stream, document_id, info.doc.fileName(),
doc.metaData(QPdfDocument::MetaDataField::Title).toString(),
doc.metaData(QPdfDocument::MetaDataField::Author).toString(),
doc.metaData(QPdfDocument::MetaDataField::Subject).toString(),
doc.metaData(QPdfDocument::MetaDataField::Keywords).toString(),
i + 1
);
const size_t bytes = info.doc.size();
const size_t bytesPerPage = std::floor(bytes / doc.pageCount());
const int pageIndex = info.currentPage;
#if defined(DEBUG)
qDebug() << "scanning page" << pageIndex << "of" << doc.pageCount() << document_path;
#endif
const QPdfSelection selection = doc.getAllText(pageIndex);
QString text = selection.text();
QTextStream stream(&text);
chunkStream(stream, document_id, info.doc.fileName(),
doc.metaData(QPdfDocument::MetaDataField::Title).toString(),
doc.metaData(QPdfDocument::MetaDataField::Author).toString(),
doc.metaData(QPdfDocument::MetaDataField::Subject).toString(),
doc.metaData(QPdfDocument::MetaDataField::Keywords).toString(),
pageIndex + 1
);
m_embeddings->save();
emit subtractCurrentBytesToIndex(info.folder, bytesPerPage);
if (info.currentPage < doc.pageCount()) {
info.currentPage += 1;
info.currentlyProcessing = true;
enqueueDocumentInternal(info, true /*prepend*/);
return scheduleNext(folder_id, countForFolder + 1);
} else {
emit subtractCurrentBytesToIndex(info.folder, bytes - (bytesPerPage * doc.pageCount()));
}
} else {
QFile file(document_path);
if (!file.open( QIODevice::ReadOnly)) {
return handleDocumentErrorAndScheduleNext("ERROR: Cannot open file for scanning",
existing_id, document_path, q.lastError());
if (!file.open(QIODevice::ReadOnly)) {
handleDocumentError("ERROR: Cannot open file for scanning",
existing_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
}
const size_t bytes = info.doc.size();
QTextStream stream(&file);
chunkStream(stream, document_id, info.doc.fileName(), QString() /*title*/, QString() /*author*/,
QString() /*subject*/, QString() /*keywords*/, -1 /*page*/);
const size_t byteIndex = info.currentPosition;
if (!stream.seek(byteIndex)) {
handleDocumentError("ERROR: Cannot seek to pos for scanning",
existing_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
}
#if defined(DEBUG)
qDebug() << "scanning byteIndex" << byteIndex << "of" << bytes << document_path;
#endif
int pos = chunkStream(stream, document_id, info.doc.fileName(), QString() /*title*/, QString() /*author*/,
QString() /*subject*/, QString() /*keywords*/, -1 /*page*/, 5 /*maxChunks*/);
m_embeddings->save();
file.close();
const size_t bytesChunked = pos - byteIndex;
emit subtractCurrentBytesToIndex(info.folder, bytesChunked);
if (info.currentPosition < bytes) {
info.currentPosition = pos;
info.currentlyProcessing = true;
enqueueDocumentInternal(info, true /*prepend*/);
return scheduleNext(folder_id, countForFolder + 1);
}
}
QSqlDatabase::database().commit();
#if defined(DEBUG)
qDebug() << "chunking" << document_path << "took" << timer.elapsed() << "ms";
#endif
if (!m_docsToScan.isEmpty()) QTimer::singleShot(0, this, &Database::scanQueue);
return scheduleNext(folder_id, countForFolder);
}
void Database::scanDocuments(int folder_id, const QString &folder_path)
@ -687,6 +865,7 @@ void Database::scanDocuments(int folder_id, const QString &folder_path)
Q_ASSERT(dir.exists());
Q_ASSERT(dir.isReadable());
QDirIterator it(folder_path, QDir::Readable | QDir::Files, QDirIterator::Subdirectories);
QVector<DocumentInfo> infos;
while (it.hasNext()) {
it.next();
QFileInfo fileInfo = it.fileInfo();
@ -701,9 +880,13 @@ void Database::scanDocuments(int folder_id, const QString &folder_path)
DocumentInfo info;
info.folder = folder_id;
info.doc = fileInfo;
m_docsToScan.enqueue(info);
infos.append(info);
}
if (!infos.isEmpty()) {
emit updateIndexing(folder_id, true);
enqueueDocuments(folder_id, infos);
}
emit docsToScanChanged();
}
void Database::start()
@ -717,6 +900,10 @@ void Database::start()
if (err.type() != QSqlError::NoError)
qWarning() << "ERROR: initializing db" << err.text();
}
if (m_embeddings->fileExists() && !m_embeddings->load())
qWarning() << "ERROR: Could not load embeddings";
addCurrentFolders();
}
@ -733,25 +920,12 @@ void Database::addCurrentFolders()
return;
}
emit collectionListUpdated(collections);
for (const auto &i : collections)
addFolder(i.collection, i.folder_path);
}
void Database::updateCollectionList()
{
#if defined(DEBUG)
qDebug() << "updateCollectionList";
#endif
QSqlQuery q;
QList<CollectionItem> collections;
if (!selectAllFromCollections(q, &collections)) {
qWarning() << "ERROR: Cannot select collections" << q.lastError();
return;
}
emit collectionListUpdated(collections);
}
void Database::addFolder(const QString &collection, const QString &path)
{
QFileInfo info(path);
@ -784,14 +958,21 @@ void Database::addFolder(const QString &collection, const QString &path)
return;
}
if (!folders.contains(folder_id) && !addCollection(q, collection, folder_id)) {
qWarning() << "ERROR: Cannot add folder to collection" << collection << path << q.lastError();
return;
if (!folders.contains(folder_id)) {
if (!addCollection(q, collection, folder_id)) {
qWarning() << "ERROR: Cannot add folder to collection" << collection << path << q.lastError();
return;
}
CollectionItem i;
i.collection = collection;
i.folder_path = path;
i.folder_id = folder_id;
emit addCollectionItem(i);
}
addFolderToWatch(path);
scanDocuments(folder_id, path);
updateCollectionList();
}
void Database::removeFolder(const QString &collection, const QString &path)
@ -840,15 +1021,8 @@ void Database::removeFolderInternal(const QString &collection, int folder_id, co
if (collections.count() > 1)
return;
// First remove all upcoming jobs associated with this folder by performing an opt-in filter
QQueue<DocumentInfo> docsToScan;
for (const DocumentInfo &info : m_docsToScan) {
if (info.folder == folder_id)
continue;
docsToScan.append(info);
}
m_docsToScan = docsToScan;
emit docsToScanChanged();
// First remove all upcoming jobs associated with this folder
removeFolderFromDocumentQueue(folder_id);
// Get a list of all documents associated with folder
QList<int> documentIds;
@ -859,6 +1033,7 @@ void Database::removeFolderInternal(const QString &collection, int folder_id, co
// Remove all chunks and documents associated with this folder
for (int document_id : documentIds) {
removeEmbeddingsByDocumentId(document_id);
if (!removeChunksByDocumentId(q, document_id)) {
qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << q.lastError();
return;
@ -875,8 +1050,9 @@ void Database::removeFolderInternal(const QString &collection, int folder_id, co
return;
}
emit removeFolderById(folder_id);
removeFolderFromWatch(path);
updateCollectionList();
}
bool Database::addFolderToWatch(const QString &path)
@ -903,9 +1079,18 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
#endif
QSqlQuery q;
if (!selectChunk(q, collections, text, retrievalSize)) {
qDebug() << "ERROR: selecting chunks:" << q.lastError().text();
return;
if (m_embeddings->isLoaded()) {
std::vector<float> result = m_embLLM->generateEmbeddings(text);
std::vector<qint64> embeddings = m_embeddings->search(result, retrievalSize);
if (!selectChunk(q, collections, embeddings, retrievalSize)) {
qDebug() << "ERROR: selecting chunks:" << q.lastError().text();
return;
}
} else {
if (!selectChunk(q, collections, text, retrievalSize)) {
qDebug() << "ERROR: selecting chunks:" << q.lastError().text();
return;
}
}
while (q.next()) {
@ -986,6 +1171,7 @@ void Database::cleanDB()
// Remove all chunks and documents that either don't exist or have become unreadable
QSqlQuery query;
removeEmbeddingsByDocumentId(document_id);
if (!removeChunksByDocumentId(query, document_id)) {
qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << query.lastError();
}
@ -994,7 +1180,6 @@ void Database::cleanDB()
qWarning() << "ERROR: Cannot remove document_id" << document_id << query.lastError();
}
}
updateCollectionList();
}
void Database::changeChunkSize(int chunkSize)
@ -1024,6 +1209,7 @@ void Database::changeChunkSize(int chunkSize)
int document_id = q.value(0).toInt();
// Remove all chunks and documents to change the chunk size
QSqlQuery query;
removeEmbeddingsByDocumentId(document_id);
if (!removeChunksByDocumentId(query, document_id)) {
qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << query.lastError();
}

@ -8,10 +8,18 @@
#include <QThread>
#include <QFileSystemWatcher>
class Embeddings;
class EmbeddingLLM;
struct DocumentInfo
{
int folder;
QFileInfo doc;
int currentPage = 0;
size_t currentPosition = 0;
bool currentlyProcessing = false;
bool isPdf() const {
return doc.suffix() == QLatin1String("pdf");
}
};
struct ResultInfo {
@ -30,6 +38,11 @@ struct CollectionItem {
QString folder_path;
int folder_id = -1;
bool installed = false;
bool indexing = false;
int currentDocsToIndex = 0;
int totalDocsToIndex = 0;
size_t currentBytesToIndex = 0;
size_t totalBytesToIndex = 0;
};
Q_DECLARE_METATYPE(CollectionItem)
@ -38,6 +51,7 @@ class Database : public QObject
Q_OBJECT
public:
Database(int chunkSize);
virtual ~Database();
public Q_SLOTS:
void scanQueue();
@ -50,6 +64,16 @@ public Q_SLOTS:
Q_SIGNALS:
void docsToScanChanged();
void updateInstalled(int folder_id, bool b);
void updateIndexing(int folder_id, bool b);
void updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex);
void updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex);
void subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes);
void updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex);
void updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex);
void addCollectionItem(const CollectionItem &item);
void removeFolderById(int folder_id);
void removeCollectionItem(const QString &collectionName);
void collectionListUpdated(const QList<CollectionItem> &collectionList);
private Q_SLOTS:
@ -58,21 +82,31 @@ private Q_SLOTS:
bool addFolderToWatch(const QString &path);
bool removeFolderFromWatch(const QString &path);
void addCurrentFolders();
void updateCollectionList();
private:
void removeFolderInternal(const QString &collection, int folder_id, const QString &path);
void chunkStream(QTextStream &stream, int document_id, const QString &file,
const QString &title, const QString &author, const QString &subject, const QString &keywords, int page);
void handleDocumentErrorAndScheduleNext(const QString &errorMessage,
size_t chunkStream(QTextStream &stream, int document_id, const QString &file,
const QString &title, const QString &author, const QString &subject, const QString &keywords, int page,
int maxChunks = -1);
void removeEmbeddingsByDocumentId(int document_id);
void scheduleNext(int folder_id, size_t countForFolder);
void handleDocumentError(const QString &errorMessage,
int document_id, const QString &document_path, const QSqlError &error);
size_t countOfDocuments(int folder_id) const;
size_t countOfBytes(int folder_id) const;
DocumentInfo dequeueDocument();
void removeFolderFromDocumentQueue(int folder_id);
void enqueueDocumentInternal(const DocumentInfo &info, bool prepend = false);
void enqueueDocuments(int folder_id, const QVector<DocumentInfo> &infos);
private:
int m_chunkSize;
QQueue<DocumentInfo> m_docsToScan;
QMap<int, QQueue<DocumentInfo>> m_docsToScan;
QList<ResultInfo> m_retrieve;
QThread m_dbThread;
QFileSystemWatcher *m_watcher;
EmbeddingLLM *m_embLLM;
Embeddings *m_embeddings;
};
#endif // DATABASE_H

@ -0,0 +1,190 @@
#include "embeddings.h"
#include <QFile>
#include <QFileInfo>
#include <QDebug>
#include "mysettings.h"
#include "hnswlib/hnswlib.h"
#define EMBEDDINGS_VERSION 0
const int s_dim = 384; // Dimension of the elements
const int s_ef_construction = 200; // Controls index search speed/build speed tradeoff
const int s_M = 16; // Tightly connected with internal dimensionality of the data
// strongly affects the memory consumption
Embeddings::Embeddings(QObject *parent)
: QObject(parent)
, m_space(nullptr)
, m_hnsw(nullptr)
{
m_filePath = MySettings::globalInstance()->modelPath()
+ QString("embeddings_v%1.dat").arg(EMBEDDINGS_VERSION);
}
Embeddings::~Embeddings()
{
delete m_hnsw;
m_hnsw = nullptr;
delete m_space;
m_space = nullptr;
}
bool Embeddings::load()
{
QFileInfo info(m_filePath);
if (!info.exists()) {
qWarning() << "ERROR: loading embeddings file does not exist" << m_filePath;
return false;
}
if (!info.isReadable()) {
qWarning() << "ERROR: loading embeddings file is not readable" << m_filePath;
return false;
}
if (!info.isWritable()) {
qWarning() << "ERROR: loading embeddings file is not writeable" << m_filePath;
return false;
}
try {
m_space = new hnswlib::InnerProductSpace(s_dim);
m_hnsw = new hnswlib::HierarchicalNSW<float>(m_space, m_filePath.toStdString(), s_M, s_ef_construction);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not load hnswlib index:" << e.what();
return false;
}
return isLoaded();
}
bool Embeddings::load(qint64 maxElements)
{
try {
m_space = new hnswlib::InnerProductSpace(s_dim);
m_hnsw = new hnswlib::HierarchicalNSW<float>(m_space, maxElements, s_M, s_ef_construction);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not create hnswlib index:" << e.what();
return false;
}
return isLoaded();
}
bool Embeddings::save()
{
if (!isLoaded())
return false;
try {
m_hnsw->saveIndex(m_filePath.toStdString());
} catch (const std::exception &e) {
qWarning() << "ERROR: could not save hnswlib index:" << e.what();
return false;
}
return true;
}
bool Embeddings::isLoaded() const
{
return m_hnsw != nullptr;
}
bool Embeddings::fileExists() const
{
QFileInfo info(m_filePath);
return info.exists();
}
bool Embeddings::resize(qint64 size)
{
if (!isLoaded()) {
qWarning() << "ERROR: attempting to resize an embedding when the embeddings are not open!";
return false;
}
Q_ASSERT(m_hnsw);
try {
m_hnsw->resizeIndex(size);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not resize hnswlib index:" << e.what();
return false;
}
return true;
}
bool Embeddings::add(const std::vector<float> &embedding, qint64 label)
{
if (!isLoaded()) {
bool success = load(500);
if (!success) {
qWarning() << "ERROR: attempting to add an embedding when the embeddings are not open!";
return false;
}
}
Q_ASSERT(m_hnsw);
if (m_hnsw->cur_element_count + 1 > m_hnsw->max_elements_) {
if (!resize(m_hnsw->max_elements_ + 500)) {
return false;
}
}
try {
m_hnsw->addPoint(embedding.data(), label, false);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not add embedding to hnswlib index:" << e.what();
return false;
}
return true;
}
void Embeddings::remove(qint64 label)
{
if (!isLoaded()) {
qWarning() << "ERROR: attempting to remove an embedding when the embeddings are not open!";
return;
}
Q_ASSERT(m_hnsw);
try {
m_hnsw->markDelete(label);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not add remove embedding from hnswlib index:" << e.what();
}
}
void Embeddings::clear()
{
delete m_hnsw;
m_hnsw = nullptr;
delete m_space;
m_space = nullptr;
}
std::vector<qint64> Embeddings::search(const std::vector<float> &embedding, int K)
{
if (!isLoaded())
return std::vector<qint64>();
Q_ASSERT(m_hnsw);
std::priority_queue<std::pair<float, hnswlib::labeltype>> result;
try {
result = m_hnsw->searchKnn(embedding.data(), K);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not search hnswlib index:" << e.what();
return std::vector<qint64>();
}
std::vector<qint64> neighbors;
neighbors.reserve(K);
while(!result.empty()) {
neighbors.push_back(result.top().second);
result.pop();
}
// Reverse the neighbors, as the top of the priority queue is the farthest neighbor.
std::reverse(neighbors.begin(), neighbors.end());
return neighbors;
}

@ -0,0 +1,45 @@
#ifndef EMBEDDINGS_H
#define EMBEDDINGS_H
#include <QObject>
namespace hnswlib {
template <typename T>
class HierarchicalNSW;
class InnerProductSpace;
}
class Embeddings : public QObject
{
Q_OBJECT
public:
Embeddings(QObject *parent);
virtual ~Embeddings();
bool load();
bool load(qint64 maxElements);
bool save();
bool isLoaded() const;
bool fileExists() const;
bool resize(qint64 size);
// Adds the embedding and returns the label used
bool add(const std::vector<float> &embedding, qint64 label);
// Removes the embedding at label by marking it as unused
void remove(qint64 label);
// Clears the embeddings
void clear();
// Performs a nearest neighbor search of the embeddings and returns a vector of labels
// for the K nearest neighbors of the given embedding
std::vector<qint64> search(const std::vector<float> &embedding, int K);
private:
QString m_filePath;
hnswlib::InnerProductSpace *m_space;
hnswlib::HierarchicalNSW<float> *m_hnsw;
};
#endif // EMBEDDINGS_H

@ -0,0 +1,64 @@
#include "embllm.h"
#include "modellist.h"
EmbeddingLLM::EmbeddingLLM()
: QObject{nullptr}
, m_model{nullptr}
{
}
EmbeddingLLM::~EmbeddingLLM()
{
delete m_model;
m_model = nullptr;
}
bool EmbeddingLLM::loadModel()
{
const EmbeddingModels *embeddingModels = ModelList::globalInstance()->embeddingModels();
if (!embeddingModels->count())
return false;
const ModelInfo defaultModel = embeddingModels->defaultModelInfo();
QString filePath = defaultModel.dirpath + defaultModel.filename();
QFileInfo fileInfo(filePath);
if (!fileInfo.exists()) {
qWarning() << "WARNING: Could not load sbert because file does not exist";
m_model = nullptr;
return false;
}
m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
bool success = m_model->loadModel(filePath.toStdString());
if (!success) {
qWarning() << "WARNING: Could not load sbert";
delete m_model;
m_model = nullptr;
return false;
}
if (m_model->implementation().modelType()[0] != 'B') {
qWarning() << "WARNING: Model type is not sbert";
delete m_model;
m_model = nullptr;
return false;
}
return true;
}
bool EmbeddingLLM::hasModel() const
{
return m_model;
}
std::vector<float> EmbeddingLLM::generateEmbeddings(const QString &text)
{
if (!hasModel() && !loadModel()) {
qWarning() << "WARNING: Could not load sbert model for embeddings";
return std::vector<float>();
}
Q_ASSERT(hasModel());
return m_model->embedding(text.toStdString());
}

@ -0,0 +1,27 @@
#ifndef EMBLLM_H
#define EMBLLM_H
#include <QObject>
#include <QThread>
#include "../gpt4all-backend/llmodel.h"
class EmbeddingLLM : public QObject
{
Q_OBJECT
public:
EmbeddingLLM();
virtual ~EmbeddingLLM();
bool hasModel() const;
public Q_SLOTS:
std::vector<float> generateEmbeddings(const QString &text);
private:
bool loadModel();
private:
LLModel *m_model = nullptr;
};
#endif // EMBLLM_H

@ -0,0 +1,167 @@
#pragma once
#include <unordered_map>
#include <fstream>
#include <mutex>
#include <algorithm>
#include <assert.h>
namespace hnswlib {
template<typename dist_t>
class BruteforceSearch : public AlgorithmInterface<dist_t> {
public:
char *data_;
size_t maxelements_;
size_t cur_element_count;
size_t size_per_element_;
size_t data_size_;
DISTFUNC <dist_t> fstdistfunc_;
void *dist_func_param_;
std::mutex index_lock;
std::unordered_map<labeltype, size_t > dict_external_to_internal;
BruteforceSearch(SpaceInterface <dist_t> *s)
: data_(nullptr),
maxelements_(0),
cur_element_count(0),
size_per_element_(0),
data_size_(0),
dist_func_param_(nullptr) {
}
BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
: data_(nullptr),
maxelements_(0),
cur_element_count(0),
size_per_element_(0),
data_size_(0),
dist_func_param_(nullptr) {
loadIndex(location, s);
}
BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
maxelements_ = maxElements;
data_size_ = s->get_data_size();
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
data_ = (char *) malloc(maxElements * size_per_element_);
if (data_ == nullptr)
throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
cur_element_count = 0;
}
~BruteforceSearch() {
free(data_);
}
void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
int idx;
{
std::unique_lock<std::mutex> lock(index_lock);
auto search = dict_external_to_internal.find(label);
if (search != dict_external_to_internal.end()) {
idx = search->second;
} else {
if (cur_element_count >= maxelements_) {
throw std::runtime_error("The number of elements exceeds the specified limit\n");
}
idx = cur_element_count;
dict_external_to_internal[label] = idx;
cur_element_count++;
}
}
memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
}
void removePoint(labeltype cur_external) {
size_t cur_c = dict_external_to_internal[cur_external];
dict_external_to_internal.erase(cur_external);
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
dict_external_to_internal[label] = cur_c;
memcpy(data_ + size_per_element_ * cur_c,
data_ + size_per_element_ * (cur_element_count-1),
data_size_+sizeof(labeltype));
cur_element_count--;
}
std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
assert(k <= cur_element_count);
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
if (cur_element_count == 0) return topResults;
for (int i = 0; i < k; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
}
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
for (int i = k; i < cur_element_count; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
if (dist <= lastdist) {
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
if (topResults.size() > k)
topResults.pop();
if (!topResults.empty()) {
lastdist = topResults.top().first;
}
}
}
return topResults;
}
void saveIndex(const std::string &location) {
std::ofstream output(location, std::ios::binary);
std::streampos position;
writeBinaryPOD(output, maxelements_);
writeBinaryPOD(output, size_per_element_);
writeBinaryPOD(output, cur_element_count);
output.write(data_, maxelements_ * size_per_element_);
output.close();
}
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
std::ifstream input(location, std::ios::binary);
std::streampos position;
readBinaryPOD(input, maxelements_);
readBinaryPOD(input, size_per_element_);
readBinaryPOD(input, cur_element_count);
data_size_ = s->get_data_size();
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
data_ = (char *) malloc(maxelements_ * size_per_element_);
if (data_ == nullptr)
throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
input.read(data_, maxelements_ * size_per_element_);
input.close();
}
};
} // namespace hnswlib

File diff suppressed because it is too large Load Diff

@ -0,0 +1,199 @@
#pragma once
#ifndef NO_MANUAL_VECTORIZATION
#if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64))
#define USE_SSE
#ifdef __AVX__
#define USE_AVX
#ifdef __AVX512F__
#define USE_AVX512
#endif
#endif
#endif
#endif
#if defined(USE_AVX) || defined(USE_SSE)
#ifdef _MSC_VER
#include <intrin.h>
#include <stdexcept>
void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
__cpuidex(out, eax, ecx);
}
static __int64 xgetbv(unsigned int x) {
return _xgetbv(x);
}
#else
#include <x86intrin.h>
#include <cpuid.h>
#include <stdint.h>
static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
__cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
}
static uint64_t xgetbv(unsigned int index) {
uint32_t eax, edx;
__asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
return ((uint64_t)edx << 32) | eax;
}
#endif
#if defined(USE_AVX512)
#include <immintrin.h>
#endif
#if defined(__GNUC__)
#define PORTABLE_ALIGN32 __attribute__((aligned(32)))
#define PORTABLE_ALIGN64 __attribute__((aligned(64)))
#else
#define PORTABLE_ALIGN32 __declspec(align(32))
#define PORTABLE_ALIGN64 __declspec(align(64))
#endif
// Adapted from https://github.com/Mysticial/FeatureDetector
#define _XCR_XFEATURE_ENABLED_MASK 0
static bool AVXCapable() {
int cpuInfo[4];
// CPU support
cpuid(cpuInfo, 0, 0);
int nIds = cpuInfo[0];
bool HW_AVX = false;
if (nIds >= 0x00000001) {
cpuid(cpuInfo, 0x00000001, 0);
HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0;
}
// OS support
cpuid(cpuInfo, 1, 0);
bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;
bool avxSupported = false;
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
avxSupported = (xcrFeatureMask & 0x6) == 0x6;
}
return HW_AVX && avxSupported;
}
static bool AVX512Capable() {
if (!AVXCapable()) return false;
int cpuInfo[4];
// CPU support
cpuid(cpuInfo, 0, 0);
int nIds = cpuInfo[0];
bool HW_AVX512F = false;
if (nIds >= 0x00000007) { // AVX512 Foundation
cpuid(cpuInfo, 0x00000007, 0);
HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0;
}
// OS support
cpuid(cpuInfo, 1, 0);
bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;
bool avx512Supported = false;
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6;
}
return HW_AVX512F && avx512Supported;
}
#endif
#include <queue>
#include <vector>
#include <iostream>
#include <string.h>
namespace hnswlib {
typedef size_t labeltype;
// This can be extended to store state for filtering (e.g. from a std::set)
class BaseFilterFunctor {
public:
virtual bool operator()(hnswlib::labeltype id) { return true; }
};
template <typename T>
class pairGreater {
public:
bool operator()(const T& p1, const T& p2) {
return p1.first > p2.first;
}
};
template<typename T>
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
out.write((char *) &podRef, sizeof(T));
}
template<typename T>
static void readBinaryPOD(std::istream &in, T &podRef) {
in.read((char *) &podRef, sizeof(T));
}
template<typename MTYPE>
using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);
template<typename MTYPE>
class SpaceInterface {
public:
// virtual void search(void *);
virtual size_t get_data_size() = 0;
virtual DISTFUNC<MTYPE> get_dist_func() = 0;
virtual void *get_dist_func_param() = 0;
virtual ~SpaceInterface() {}
};
template<typename dist_t>
class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0;
virtual std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;
// Return k nearest neighbor in the order of closer fist
virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;
virtual void saveIndex(const std::string &location) = 0;
virtual ~AlgorithmInterface(){
}
};
template<typename dist_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
BaseFilterFunctor* isIdAllowed) const {
std::vector<std::pair<dist_t, labeltype>> result;
// here searchKnn returns the result in the order of further first
auto ret = searchKnn(query_data, k, isIdAllowed);
{
size_t sz = ret.size();
result.resize(sz);
while (!ret.empty()) {
result[--sz] = ret.top();
ret.pop();
}
}
return result;
}
} // namespace hnswlib
#include "space_l2.h"
#include "space_ip.h"
#include "bruteforce.h"
#include "hnswalg.h"

@ -0,0 +1,375 @@
#pragma once
#include "hnswlib.h"
namespace hnswlib {
static float
InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
float res = 0;
for (unsigned i = 0; i < qty; i++) {
res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
}
return res;
}
static float
InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) {
return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr);
}
#if defined(USE_AVX)
// Favor using AVX if available.
static float
InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty / 16;
size_t qty4 = qty / 4;
const float *pEnd1 = pVect1 + 16 * qty16;
const float *pEnd2 = pVect1 + 4 * qty4;
__m256 sum256 = _mm256_set1_ps(0);
while (pVect1 < pEnd1) {
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
__m256 v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
__m256 v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
}
__m128 v1, v2;
__m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
while (pVect1 < pEnd2) {
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
}
_mm_store_ps(TmpRes, sum_prod);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
return sum;
}
static float
InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr);
}
#endif
#if defined(USE_SSE)
static float
InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty / 16;
size_t qty4 = qty / 4;
const float *pEnd1 = pVect1 + 16 * qty16;
const float *pEnd2 = pVect1 + 4 * qty4;
__m128 v1, v2;
__m128 sum_prod = _mm_set1_ps(0);
while (pVect1 < pEnd1) {
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
}
while (pVect1 < pEnd2) {
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
}
_mm_store_ps(TmpRes, sum_prod);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
return sum;
}
static float
InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr);
}
#endif
#if defined(USE_AVX512)
static float
InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN64 TmpRes[16];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty / 16;
const float *pEnd1 = pVect1 + 16 * qty16;
__m512 sum512 = _mm512_set1_ps(0);
while (pVect1 < pEnd1) {
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
__m512 v1 = _mm512_loadu_ps(pVect1);
pVect1 += 16;
__m512 v2 = _mm512_loadu_ps(pVect2);
pVect2 += 16;
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2));
}
_mm512_store_ps(TmpRes, sum512);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15];
return sum;
}
static float
InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr);
}
#endif
#if defined(USE_AVX)
static float
InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty / 16;
const float *pEnd1 = pVect1 + 16 * qty16;
__m256 sum256 = _mm256_set1_ps(0);
while (pVect1 < pEnd1) {
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
__m256 v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
__m256 v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
}
_mm256_store_ps(TmpRes, sum256);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
return sum;
}
static float
InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr);
}
#endif
#if defined(USE_SSE)
static float
InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty / 16;
const float *pEnd1 = pVect1 + 16 * qty16;
__m128 v1, v2;
__m128 sum_prod = _mm_set1_ps(0);
while (pVect1 < pEnd1) {
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
}
_mm_store_ps(TmpRes, sum_prod);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
return sum;
}
static float
InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr);
}
#endif
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
static DISTFUNC<float> InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
static DISTFUNC<float> InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;
static DISTFUNC<float> InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE;
static DISTFUNC<float> InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE;
static float
InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty >> 4 << 4;
float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16);
float *pVect1 = (float *) pVect1v + qty16;
float *pVect2 = (float *) pVect2v + qty16;
size_t qty_left = qty - qty16;
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
return 1.0f - (res + res_tail);
}
static float
InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty4 = qty >> 2 << 2;
float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4);
size_t qty_left = qty - qty4;
float *pVect1 = (float *) pVect1v + qty4;
float *pVect2 = (float *) pVect2v + qty4;
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
return 1.0f - (res + res_tail);
}
#endif
class InnerProductSpace : public SpaceInterface<float> {
DISTFUNC<float> fstdistfunc_;
size_t data_size_;
size_t dim_;
public:
InnerProductSpace(size_t dim) {
fstdistfunc_ = InnerProductDistance;
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
#if defined(USE_AVX512)
if (AVX512Capable()) {
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
} else if (AVXCapable()) {
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
}
#elif defined(USE_AVX)
if (AVXCapable()) {
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
}
#endif
#if defined(USE_AVX)
if (AVXCapable()) {
InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
}
#endif
if (dim % 16 == 0)
fstdistfunc_ = InnerProductDistanceSIMD16Ext;
else if (dim % 4 == 0)
fstdistfunc_ = InnerProductDistanceSIMD4Ext;
else if (dim > 16)
fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
else if (dim > 4)
fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
#endif
dim_ = dim;
data_size_ = dim * sizeof(float);
}
size_t get_data_size() {
return data_size_;
}
DISTFUNC<float> get_dist_func() {
return fstdistfunc_;
}
void *get_dist_func_param() {
return &dim_;
}
~InnerProductSpace() {}
};
} // namespace hnswlib

@ -0,0 +1,324 @@
#pragma once
#include "hnswlib.h"
namespace hnswlib {
static float
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
float res = 0;
for (size_t i = 0; i < qty; i++) {
float t = *pVect1 - *pVect2;
pVect1++;
pVect2++;
res += t * t;
}
return (res);
}
#if defined(USE_AVX512)
// Favor using AVX512 if available.
static float
L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
float PORTABLE_ALIGN64 TmpRes[16];
size_t qty16 = qty >> 4;
const float *pEnd1 = pVect1 + (qty16 << 4);
__m512 diff, v1, v2;
__m512 sum = _mm512_set1_ps(0);
while (pVect1 < pEnd1) {
v1 = _mm512_loadu_ps(pVect1);
pVect1 += 16;
v2 = _mm512_loadu_ps(pVect2);
pVect2 += 16;
diff = _mm512_sub_ps(v1, v2);
// sum = _mm512_fmadd_ps(diff, diff, sum);
sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff));
}
_mm512_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] +
TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] +
TmpRes[13] + TmpRes[14] + TmpRes[15];
return (res);
}
#endif
#if defined(USE_AVX)
// Favor using AVX if available.
static float
L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
float PORTABLE_ALIGN32 TmpRes[8];
size_t qty16 = qty >> 4;
const float *pEnd1 = pVect1 + (qty16 << 4);
__m256 diff, v1, v2;
__m256 sum = _mm256_set1_ps(0);
while (pVect1 < pEnd1) {
v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
diff = _mm256_sub_ps(v1, v2);
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
diff = _mm256_sub_ps(v1, v2);
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
}
_mm256_store_ps(TmpRes, sum);
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
}
#endif
#if defined(USE_SSE)
static float
L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
float PORTABLE_ALIGN32 TmpRes[8];
size_t qty16 = qty >> 4;
const float *pEnd1 = pVect1 + (qty16 << 4);
__m128 diff, v1, v2;
__m128 sum = _mm_set1_ps(0);
while (pVect1 < pEnd1) {
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
}
_mm_store_ps(TmpRes, sum);
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}
#endif
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
static DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;
static float
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty >> 4 << 4;
float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
float *pVect1 = (float *) pVect1v + qty16;
float *pVect2 = (float *) pVect2v + qty16;
size_t qty_left = qty - qty16;
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
return (res + res_tail);
}
#endif
#if defined(USE_SSE)
static float
L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty4 = qty >> 2;
const float *pEnd1 = pVect1 + (qty4 << 2);
__m128 diff, v1, v2;
__m128 sum = _mm_set1_ps(0);
while (pVect1 < pEnd1) {
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
}
_mm_store_ps(TmpRes, sum);
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}
static float
L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty4 = qty >> 2 << 2;
float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
size_t qty_left = qty - qty4;
float *pVect1 = (float *) pVect1v + qty4;
float *pVect2 = (float *) pVect2v + qty4;
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
return (res + res_tail);
}
#endif
class L2Space : public SpaceInterface<float> {
DISTFUNC<float> fstdistfunc_;
size_t data_size_;
size_t dim_;
public:
L2Space(size_t dim) {
fstdistfunc_ = L2Sqr;
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
#if defined(USE_AVX512)
if (AVX512Capable())
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
else if (AVXCapable())
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
#elif defined(USE_AVX)
if (AVXCapable())
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
#endif
if (dim % 16 == 0)
fstdistfunc_ = L2SqrSIMD16Ext;
else if (dim % 4 == 0)
fstdistfunc_ = L2SqrSIMD4Ext;
else if (dim > 16)
fstdistfunc_ = L2SqrSIMD16ExtResiduals;
else if (dim > 4)
fstdistfunc_ = L2SqrSIMD4ExtResiduals;
#endif
dim_ = dim;
data_size_ = dim * sizeof(float);
}
size_t get_data_size() {
return data_size_;
}
DISTFUNC<float> get_dist_func() {
return fstdistfunc_;
}
void *get_dist_func_param() {
return &dim_;
}
~L2Space() {}
};
static int
L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
int res = 0;
unsigned char *a = (unsigned char *) pVect1;
unsigned char *b = (unsigned char *) pVect2;
qty = qty >> 2;
for (size_t i = 0; i < qty; i++) {
res += ((*a) - (*b)) * ((*a) - (*b));
a++;
b++;
res += ((*a) - (*b)) * ((*a) - (*b));
a++;
b++;
res += ((*a) - (*b)) * ((*a) - (*b));
a++;
b++;
res += ((*a) - (*b)) * ((*a) - (*b));
a++;
b++;
}
return (res);
}
static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
size_t qty = *((size_t*)qty_ptr);
int res = 0;
unsigned char* a = (unsigned char*)pVect1;
unsigned char* b = (unsigned char*)pVect2;
for (size_t i = 0; i < qty; i++) {
res += ((*a) - (*b)) * ((*a) - (*b));
a++;
b++;
}
return (res);
}
class L2SpaceI : public SpaceInterface<int> {
DISTFUNC<int> fstdistfunc_;
size_t data_size_;
size_t dim_;
public:
L2SpaceI(size_t dim) {
if (dim % 4 == 0) {
fstdistfunc_ = L2SqrI4x;
} else {
fstdistfunc_ = L2SqrI;
}
dim_ = dim;
data_size_ = dim * sizeof(unsigned char);
}
size_t get_data_size() {
return data_size_;
}
DISTFUNC<int> get_dist_func() {
return fstdistfunc_;
}
void *get_dist_func_param() {
return &dim_;
}
~L2SpaceI() {}
};
} // namespace hnswlib

@ -0,0 +1,78 @@
#pragma once
#include <mutex>
#include <string.h>
#include <deque>
namespace hnswlib {
typedef unsigned short int vl_type;
class VisitedList {
public:
vl_type curV;
vl_type *mass;
unsigned int numelements;
VisitedList(int numelements1) {
curV = -1;
numelements = numelements1;
mass = new vl_type[numelements];
}
void reset() {
curV++;
if (curV == 0) {
memset(mass, 0, sizeof(vl_type) * numelements);
curV++;
}
}
~VisitedList() { delete[] mass; }
};
///////////////////////////////////////////////////////////
//
// Class for multi-threaded pool-management of VisitedLists
//
/////////////////////////////////////////////////////////
class VisitedListPool {
std::deque<VisitedList *> pool;
std::mutex poolguard;
int numelements;
public:
VisitedListPool(int initmaxpools, int numelements1) {
numelements = numelements1;
for (int i = 0; i < initmaxpools; i++)
pool.push_front(new VisitedList(numelements));
}
VisitedList *getFreeVisitedList() {
VisitedList *rez;
{
std::unique_lock <std::mutex> lock(poolguard);
if (pool.size() > 0) {
rez = pool.front();
pool.pop_front();
} else {
rez = new VisitedList(numelements);
}
}
rez->reset();
return rez;
}
void releaseVisitedList(VisitedList *vl) {
std::unique_lock <std::mutex> lock(poolguard);
pool.push_front(vl);
}
~VisitedListPool() {
while (pool.size()) {
VisitedList *rez = pool.front();
pool.pop_front();
delete rez;
}
}
};
} // namespace hnswlib

@ -24,24 +24,50 @@ LocalDocs::LocalDocs()
&Database::removeFolder, Qt::QueuedConnection);
connect(this, &LocalDocs::requestChunkSizeChange, m_database,
&Database::changeChunkSize, Qt::QueuedConnection);
// Connections for modifying the model and keeping it updated with the database
connect(m_database, &Database::updateInstalled,
m_localDocsModel, &LocalDocsModel::updateInstalled, Qt::QueuedConnection);
connect(m_database, &Database::updateIndexing,
m_localDocsModel, &LocalDocsModel::updateIndexing, Qt::QueuedConnection);
connect(m_database, &Database::updateCurrentDocsToIndex,
m_localDocsModel, &LocalDocsModel::updateCurrentDocsToIndex, Qt::QueuedConnection);
connect(m_database, &Database::updateTotalDocsToIndex,
m_localDocsModel, &LocalDocsModel::updateTotalDocsToIndex, Qt::QueuedConnection);
connect(m_database, &Database::subtractCurrentBytesToIndex,
m_localDocsModel, &LocalDocsModel::subtractCurrentBytesToIndex, Qt::QueuedConnection);
connect(m_database, &Database::updateCurrentBytesToIndex,
m_localDocsModel, &LocalDocsModel::updateCurrentBytesToIndex, Qt::QueuedConnection);
connect(m_database, &Database::updateTotalBytesToIndex,
m_localDocsModel, &LocalDocsModel::updateTotalBytesToIndex, Qt::QueuedConnection);
connect(m_database, &Database::addCollectionItem,
m_localDocsModel, &LocalDocsModel::addCollectionItem, Qt::QueuedConnection);
connect(m_database, &Database::removeFolderById,
m_localDocsModel, &LocalDocsModel::removeFolderById, Qt::QueuedConnection);
connect(m_database, &Database::removeCollectionItem,
m_localDocsModel, &LocalDocsModel::removeCollectionItem, Qt::QueuedConnection);
connect(m_database, &Database::collectionListUpdated,
m_localDocsModel, &LocalDocsModel::handleCollectionListUpdated, Qt::QueuedConnection);
m_localDocsModel, &LocalDocsModel::collectionListUpdated, Qt::QueuedConnection);
connect(qApp, &QCoreApplication::aboutToQuit, this, &LocalDocs::aboutToQuit);
}
void LocalDocs::aboutToQuit()
{
delete m_database;
m_database = nullptr;
}
void LocalDocs::addFolder(const QString &collection, const QString &path)
{
const QUrl url(path);
const QString localPath = url.isLocalFile() ? url.toLocalFile() : path;
// Add a placeholder collection that is not installed yet
CollectionItem i;
i.collection = collection;
i.folder_path = localPath;
m_localDocsModel->addCollectionItem(i);
emit requestAddFolder(collection, localPath);
}
void LocalDocs::removeFolder(const QString &collection, const QString &path)
{
m_localDocsModel->removeCollectionPath(collection, path);
emit requestRemoveFolder(collection, path);
}

@ -23,6 +23,7 @@ public:
public Q_SLOTS:
void handleChunkSizeChanged();
void aboutToQuit();
Q_SIGNALS:
void requestAddFolder(const QString &collection, const QString &path);
@ -36,7 +37,6 @@ private:
private:
explicit LocalDocs();
~LocalDocs() {}
friend class MyLocalDocs;
};

@ -1,5 +1,27 @@
#include "localdocsmodel.h"
#include "localdocs.h"
LocalDocsCollectionsModel::LocalDocsCollectionsModel(QObject *parent)
: QSortFilterProxyModel(parent)
{
setSourceModel(LocalDocs::globalInstance()->localDocsModel());
}
bool LocalDocsCollectionsModel::filterAcceptsRow(int sourceRow,
const QModelIndex &sourceParent) const
{
QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent);
const QString collection = sourceModel()->data(index, LocalDocsModel::CollectionRole).toString();
return m_collections.contains(collection);
}
void LocalDocsCollectionsModel::setCollections(const QList<QString> &collections)
{
m_collections = collections;
invalidateFilter();
}
LocalDocsModel::LocalDocsModel(QObject *parent)
: QAbstractListModel(parent)
{
@ -24,6 +46,16 @@ QVariant LocalDocsModel::data(const QModelIndex &index, int role) const
return item.folder_path;
case InstalledRole:
return item.installed;
case IndexingRole:
return item.indexing;
case CurrentDocsToIndexRole:
return item.currentDocsToIndex;
case TotalDocsToIndexRole:
return item.totalDocsToIndex;
case CurrentBytesToIndexRole:
return quint64(item.currentBytesToIndex);
case TotalBytesToIndexRole:
return quint64(item.totalBytesToIndex);
}
return QVariant();
@ -35,9 +67,98 @@ QHash<int, QByteArray> LocalDocsModel::roleNames() const
roles[CollectionRole] = "collection";
roles[FolderPathRole] = "folder_path";
roles[InstalledRole] = "installed";
roles[IndexingRole] = "indexing";
roles[CurrentDocsToIndexRole] = "currentDocsToIndex";
roles[TotalDocsToIndexRole] = "totalDocsToIndex";
roles[CurrentBytesToIndexRole] = "currentBytesToIndex";
roles[TotalBytesToIndexRole] = "totalBytesToIndex";
return roles;
}
void LocalDocsModel::updateInstalled(int folder_id, bool b)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].installed = b;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {InstalledRole});
}
}
void LocalDocsModel::updateIndexing(int folder_id, bool b)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].indexing = b;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {IndexingRole});
}
}
void LocalDocsModel::updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].currentDocsToIndex = currentDocsToIndex;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {CurrentDocsToIndexRole});
}
}
void LocalDocsModel::updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].totalDocsToIndex = totalDocsToIndex;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {TotalDocsToIndexRole});
}
}
void LocalDocsModel::subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].currentBytesToIndex -= subtractedBytes;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {CurrentBytesToIndexRole});
}
}
void LocalDocsModel::updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].currentBytesToIndex = currentBytesToIndex;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {CurrentBytesToIndexRole});
}
}
void LocalDocsModel::updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].totalBytesToIndex = totalBytesToIndex;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {TotalBytesToIndexRole});
}
}
void LocalDocsModel::addCollectionItem(const CollectionItem &item)
{
beginInsertRows(QModelIndex(), m_collectionList.size(), m_collectionList.size());
@ -45,7 +166,46 @@ void LocalDocsModel::addCollectionItem(const CollectionItem &item)
endInsertRows();
}
void LocalDocsModel::handleCollectionListUpdated(const QList<CollectionItem> &collectionList)
void LocalDocsModel::removeFolderById(int folder_id)
{
for (int i = 0; i < m_collectionList.size();) {
if (m_collectionList.at(i).folder_id == folder_id) {
beginRemoveRows(QModelIndex(), i, i);
m_collectionList.removeAt(i);
endRemoveRows();
} else {
++i;
}
}
}
void LocalDocsModel::removeCollectionPath(const QString &name, const QString &path)
{
for (int i = 0; i < m_collectionList.size();) {
if (m_collectionList.at(i).collection == name && m_collectionList.at(i).folder_path == path) {
beginRemoveRows(QModelIndex(), i, i);
m_collectionList.removeAt(i);
endRemoveRows();
} else {
++i;
}
}
}
void LocalDocsModel::removeCollectionItem(const QString &collectionName)
{
for (int i = 0; i < m_collectionList.size();) {
if (m_collectionList.at(i).collection == collectionName) {
beginRemoveRows(QModelIndex(), i, i);
m_collectionList.removeAt(i);
endRemoveRows();
} else {
++i;
}
}
}
void LocalDocsModel::collectionListUpdated(const QList<CollectionItem> &collectionList)
{
beginResetModel();
m_collectionList = collectionList;

@ -4,6 +4,22 @@
#include <QAbstractListModel>
#include "database.h"
class LocalDocsCollectionsModel : public QSortFilterProxyModel
{
Q_OBJECT
public:
explicit LocalDocsCollectionsModel(QObject *parent);
public Q_SLOTS:
void setCollections(const QList<QString> &collections);
protected:
bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override;
private:
QList<QString> m_collections;
};
class LocalDocsModel : public QAbstractListModel
{
Q_OBJECT
@ -12,7 +28,13 @@ public:
enum Roles {
CollectionRole = Qt::UserRole + 1,
FolderPathRole,
InstalledRole
InstalledRole,
IndexingRole,
EmbeddingRole,
CurrentDocsToIndexRole,
TotalDocsToIndexRole,
CurrentBytesToIndexRole,
TotalBytesToIndexRole
};
explicit LocalDocsModel(QObject *parent = nullptr);
@ -20,9 +42,25 @@ public:
QVariant data(const QModelIndex &index, int role) const override;
QHash<int, QByteArray> roleNames() const override;
Q_SIGNALS:
void collectionItemUpdated(int index, const CollectionItem& item);
public Q_SLOTS:
void updateInstalled(int folder_id, bool b);
void updateIndexing(int folder_id, bool b);
void updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex);
void updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex);
void subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes);
void updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex);
void updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex);
void addCollectionItem(const CollectionItem &item);
void handleCollectionListUpdated(const QList<CollectionItem> &collectionList);
void removeFolderById(int folder_id);
void removeCollectionPath(const QString &name, const QString &path);
void removeCollectionItem(const QString &collectionName);
void collectionListUpdated(const QList<CollectionItem> &collectionList);
private:
void updateItem(int index, const CollectionItem& item);
private:
QList<CollectionItem> m_collectionList;

@ -325,6 +325,10 @@ Window {
anchors.centerIn: parent
width: Math.min(1280, window.width - (window.width * .1))
height: window.height - (window.height * .1)
onDownloadClicked: {
downloadNewModels.showEmbeddingModels = true
downloadNewModels.open()
}
}
Button {
@ -652,6 +656,7 @@ Window {
width: Math.min(600, 0.3 * window.width)
height: window.height - y
onDownloadClicked: {
downloadNewModels.showEmbeddingModels = false
downloadNewModels.open()
}
onAboutClicked: {
@ -818,11 +823,11 @@ Window {
color: theme.textAccent
text: {
switch (currentChat.responseState) {
case Chat.ResponseStopped: return "response stopped ...";
case Chat.LocalDocsRetrieval: return "retrieving " + currentChat.collectionList.join(", ") + " ...";
case Chat.LocalDocsProcessing: return "processing " + currentChat.collectionList.join(", ") + " ...";
case Chat.PromptProcessing: return "processing ..."
case Chat.ResponseGeneration: return "generating response ...";
case Chat.ResponseStopped: return qsTr("response stopped ...");
case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: ") + currentChat.collectionList.join(", ") + " ...";
case Chat.LocalDocsProcessing: return qsTr("searching localdocs: ") + currentChat.collectionList.join(", ") + " ...";
case Chat.PromptProcessing: return qsTr("processing ...")
case Chat.ResponseGeneration: return qsTr("generating response ...");
default: return ""; // handle unexpected values
}
}

@ -138,7 +138,7 @@
"type": "Replit",
"systemPrompt": " ",
"promptTemplate": "%1",
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>Licensed for commercial use</ul>",
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>Licensed for commercial use<li>WARNING: Not available for chat GUI</ul>",
"url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-q4_0.gguf"
},
{
@ -155,7 +155,7 @@
"type": "Starcoder",
"systemPrompt": " ",
"promptTemplate": "%1",
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based</ul>",
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</ul>",
"url": "https://gpt4all.io/models/gguf/starcoder-q4_0.gguf"
},
{
@ -172,7 +172,7 @@
"type": "LLaMA",
"systemPrompt": " ",
"promptTemplate": "%1",
"description": "Code completion model",
"description": "<strong>Trained on collection of Python and TypeScript</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</li>",
"url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf"
},
{
@ -184,11 +184,11 @@
"filesize": "45887744",
"requires": "2.5.0",
"ramrequired": "1",
"parameters": "1 million",
"parameters": "40 million",
"quant": "f16",
"type": "Bert",
"systemPrompt": " ",
"description": "<strong>Sbert</strong><br><ul><li>For embeddings",
"description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>Necessary for LocalDocs feature<li>Used for retrieval augmented generation (RAG)",
"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf"
},
{

@ -139,6 +139,50 @@ void ModelInfo::setSystemPrompt(const QString &p)
m_systemPrompt = p;
}
EmbeddingModels::EmbeddingModels(QObject *parent)
: QSortFilterProxyModel(parent)
{
connect(this, &EmbeddingModels::rowsInserted, this, &EmbeddingModels::countChanged);
connect(this, &EmbeddingModels::rowsRemoved, this, &EmbeddingModels::countChanged);
connect(this, &EmbeddingModels::modelReset, this, &EmbeddingModels::countChanged);
connect(this, &EmbeddingModels::layoutChanged, this, &EmbeddingModels::countChanged);
}
bool EmbeddingModels::filterAcceptsRow(int sourceRow,
const QModelIndex &sourceParent) const
{
QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent);
bool isInstalled = sourceModel()->data(index, ModelList::InstalledRole).toBool();
bool isEmbedding = sourceModel()->data(index, ModelList::FilenameRole).toString() == "all-MiniLM-L6-v2-f16.gguf";
return isInstalled && isEmbedding;
}
int EmbeddingModels::count() const
{
return rowCount();
}
ModelInfo EmbeddingModels::defaultModelInfo() const
{
if (!sourceModel())
return ModelInfo();
const ModelList *sourceListModel = qobject_cast<const ModelList*>(sourceModel());
if (!sourceListModel)
return ModelInfo();
const int rows = sourceListModel->rowCount();
for (int i = 0; i < rows; ++i) {
QModelIndex sourceIndex = sourceListModel->index(i, 0);
if (filterAcceptsRow(i, sourceIndex.parent())) {
const QString id = sourceListModel->data(sourceIndex, ModelList::IdRole).toString();
return sourceListModel->modelInfo(id);
}
}
return ModelInfo();
}
InstalledModels::InstalledModels(QObject *parent)
: QSortFilterProxyModel(parent)
{
@ -153,7 +197,8 @@ bool InstalledModels::filterAcceptsRow(int sourceRow,
{
QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent);
bool isInstalled = sourceModel()->data(index, ModelList::InstalledRole).toBool();
return isInstalled;
bool showInGUI = !sourceModel()->data(index, ModelList::DisableGUIRole).toBool();
return isInstalled && showInGUI;
}
int InstalledModels::count() const
@ -178,8 +223,7 @@ bool DownloadableModels::filterAcceptsRow(int sourceRow,
bool withinLimit = sourceRow < (m_expanded ? sourceModel()->rowCount() : m_limit);
QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent);
bool isDownloadable = !sourceModel()->data(index, ModelList::DescriptionRole).toString().isEmpty();
bool showInGUI = !sourceModel()->data(index, ModelList::DisableGUIRole).toBool();
return withinLimit && isDownloadable && showInGUI;
return withinLimit && isDownloadable;
}
int DownloadableModels::count() const
@ -210,10 +254,12 @@ ModelList *ModelList::globalInstance()
ModelList::ModelList()
: QAbstractListModel(nullptr)
, m_embeddingModels(new EmbeddingModels(this))
, m_installedModels(new InstalledModels(this))
, m_downloadableModels(new DownloadableModels(this))
, m_asyncModelRequestOngoing(false)
{
m_embeddingModels->setSourceModel(this);
m_installedModels->setSourceModel(this);
m_downloadableModels->setSourceModel(this);
m_watcher = new QFileSystemWatcher(this);
@ -280,6 +326,17 @@ const QList<QString> ModelList::userDefaultModelList() const
return models;
}
int ModelList::defaultEmbeddingModelIndex() const
{
QMutexLocker locker(&m_mutex);
for (int i = 0; i < m_models.size(); ++i) {
const ModelInfo *info = m_models.at(i);
const bool isEmbedding = info->filename() == "all-MiniLM-L6-v2-f16.gguf";
if (isEmbedding) return i;
}
return -1;
}
ModelInfo ModelList::defaultModelInfo() const
{
QMutexLocker locker(&m_mutex);

@ -120,6 +120,24 @@ private:
};
Q_DECLARE_METATYPE(ModelInfo)
class EmbeddingModels : public QSortFilterProxyModel
{
Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged)
public:
explicit EmbeddingModels(QObject *parent);
int count() const;
ModelInfo defaultModelInfo() const;
Q_SIGNALS:
void countChanged();
void defaultModelIndexChanged();
protected:
bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override;
};
class InstalledModels : public QSortFilterProxyModel
{
Q_OBJECT
@ -165,6 +183,8 @@ class ModelList : public QAbstractListModel
{
Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged)
Q_PROPERTY(int defaultEmbeddingModelIndex READ defaultEmbeddingModelIndex NOTIFY defaultEmbeddingModelIndexChanged)
Q_PROPERTY(EmbeddingModels* embeddingModels READ embeddingModels NOTIFY embeddingModelsChanged)
Q_PROPERTY(InstalledModels* installedModels READ installedModels NOTIFY installedModelsChanged)
Q_PROPERTY(DownloadableModels* downloadableModels READ downloadableModels NOTIFY downloadableModelsChanged)
Q_PROPERTY(QList<QString> userDefaultModelList READ userDefaultModelList NOTIFY userDefaultModelListChanged)
@ -273,6 +293,7 @@ public:
Q_INVOKABLE QString clone(const ModelInfo &model);
Q_INVOKABLE void remove(const ModelInfo &model);
ModelInfo defaultModelInfo() const;
int defaultEmbeddingModelIndex() const;
void addModel(const QString &id);
void changeId(const QString &oldId, const QString &newId);
@ -280,6 +301,7 @@ public:
const QList<ModelInfo> exportModelList() const;
const QList<QString> userDefaultModelList() const;
EmbeddingModels *embeddingModels() const { return m_embeddingModels; }
InstalledModels *installedModels() const { return m_installedModels; }
DownloadableModels *downloadableModels() const { return m_downloadableModels; }
@ -300,10 +322,12 @@ public:
Q_SIGNALS:
void countChanged();
void embeddingModelsChanged();
void installedModelsChanged();
void downloadableModelsChanged();
void userDefaultModelListChanged();
void asyncModelRequestOngoingChanged();
void defaultEmbeddingModelIndexChanged();
private Q_SLOTS:
void updateModelsFromJson();
@ -326,6 +350,7 @@ private:
private:
mutable QMutex m_mutex;
QNetworkAccessManager m_networkManager;
EmbeddingModels *m_embeddingModels;
InstalledModels *m_installedModels;
DownloadableModels *m_downloadableModels;
QList<ModelInfo*> m_models;

@ -21,7 +21,7 @@ MyDialog {
id: listLabel
anchors.top: parent.top
anchors.left: parent.left
text: "Available LocalDocs Collections:"
text: qsTr("Local Documents:")
font.pixelSize: theme.fontSizeLarge
color: theme.textColor
}
@ -63,17 +63,60 @@ MyDialog {
currentChat.removeCollection(collection)
}
}
ToolTip.text: qsTr("Warning: searching collections while indexing can return incomplete results")
ToolTip.visible: hovered && model.indexing
}
Text {
id: collectionId
anchors.verticalCenter: parent.verticalCenter
anchors.left: checkBox.right
anchors.margins: 20
anchors.leftMargin: 10
text: collection
font.pixelSize: theme.fontSizeLarge
elide: Text.ElideRight
color: theme.textColor
}
ProgressBar {
id: itemProgressBar
anchors.verticalCenter: parent.verticalCenter
anchors.left: collectionId.right
anchors.right: parent.right
anchors.margins: 20
anchors.leftMargin: 40
visible: model.indexing
value: (model.totalBytesToIndex - model.currentBytesToIndex) / model.totalBytesToIndex
background: Rectangle {
implicitHeight: 45
color: theme.backgroundDarkest
radius: 3
}
contentItem: Item {
implicitHeight: 40
Rectangle {
width: itemProgressBar.visualPosition * parent.width
height: parent.height
radius: 2
color: theme.assistantColor
}
}
Accessible.role: Accessible.ProgressBar
Accessible.name: qsTr("Indexing progressBar")
Accessible.description: qsTr("Shows the progress made in the indexing")
}
Label {
id: speedLabel
color: theme.textColor
visible: model.indexing
anchors.verticalCenter: itemProgressBar.verticalCenter
anchors.left: itemProgressBar.left
anchors.right: itemProgressBar.right
horizontalAlignment: Text.AlignHCenter
text: qsTr("indexing...")
elide: Text.ElideRight
font.pixelSize: theme.fontSizeLarge
}
}
}
}

@ -5,6 +5,7 @@ import QtQuick.Controls.Basic
import QtQuick.Layouts
import QtQuick.Dialogs
import localdocs
import modellist
import mysettings
import network
@ -13,7 +14,11 @@ MySettingsTab {
MySettings.restoreLocalDocsDefaults();
}
title: qsTr("LocalDocs Plugin (BETA)")
property bool hasEmbeddingModel: ModelList.embeddingModels.count !== 0
showAdvancedSettingsButton: hasEmbeddingModel
showRestoreDefaultsButton: hasEmbeddingModel
title: qsTr("LocalDocs")
contentItem: ColumnLayout {
id: root
spacing: 10
@ -21,7 +26,30 @@ MySettingsTab {
property alias collection: collection.text
property alias folder_path: folderEdit.text
Label {
id: downloadLabel
Layout.fillWidth: true
Layout.maximumWidth: parent.width
wrapMode: Text.Wrap
visible: !hasEmbeddingModel
Layout.alignment: Qt.AlignLeft
text: qsTr("This feature requires the download of a text embedding model in order to index documents for later search. Please download the <b>SBert</a> text embedding model from the download dialog to proceed.")
font.pixelSize: theme.fontSizeLarger
}
MyButton {
visible: !hasEmbeddingModel
Layout.topMargin: 20
Layout.alignment: Qt.AlignLeft
text: qsTr("Download")
font.pixelSize: theme.fontSizeLarger
onClicked: {
downloadClicked()
}
}
Item {
visible: hasEmbeddingModel
Layout.fillWidth: true
height: row.height
RowLayout {
@ -106,6 +134,7 @@ MySettingsTab {
}
ColumnLayout {
visible: hasEmbeddingModel
spacing: 0
Repeater {
model: LocalDocs.localDocsModel
@ -145,29 +174,25 @@ MySettingsTab {
anchors.right: parent.right
anchors.verticalCenter: parent.verticalCenter
anchors.margins: 20
width: Math.max(removeButton.width, busyIndicator.width)
height: Math.max(removeButton.height, busyIndicator.height)
width: removeButton.width
height:removeButton.height
MyButton {
id: removeButton
anchors.centerIn: parent
text: qsTr("Remove")
visible: !item.removing && installed
visible: !item.removing
onClicked: {
item.removing = true
LocalDocs.removeFolder(collection, folder_path)
}
}
MyBusyIndicator {
id: busyIndicator
anchors.centerIn: parent
visible: item.removing || !installed
}
}
}
}
}
RowLayout {
visible: hasEmbeddingModel
Label {
id: showReferencesLabel
text: qsTr("Show references:")
@ -186,6 +211,7 @@ MySettingsTab {
}
Rectangle {
visible: hasEmbeddingModel
Layout.fillWidth: true
height: 1
color: theme.tabBorder
@ -196,6 +222,7 @@ MySettingsTab {
columns: 3
rowSpacing: 10
columnSpacing: 10
visible: hasEmbeddingModel
Rectangle {
Layout.row: 3

@ -16,9 +16,17 @@ MyDialog {
modal: true
closePolicy: Popup.CloseOnEscape | Popup.CloseOnPressOutside
padding: 10
property bool showEmbeddingModels: false
onOpened: {
Network.sendModelDownloaderDialog();
if (showEmbeddingModels) {
ModelList.downloadableModels.expanded = true
var targetModelIndex = ModelList.defaultEmbeddingModelIndex
console.log("targetModelIndex " + targetModelIndex)
modelListView.positionViewAtIndex(targetModelIndex, ListView.Contain);
}
}
PopupDialog {

@ -9,8 +9,11 @@ Item {
property string title: ""
property Item contentItem: null
property Item advancedSettings: null
property bool showAdvancedSettingsButton: true
property bool showRestoreDefaultsButton: true
property var openFolderDialog
signal restoreDefaultsClicked
signal downloadClicked
onContentItemChanged: function() {
if (contentItem) {
@ -64,6 +67,7 @@ Item {
MyButton {
id: restoreDefaultsButton
anchors.left: parent.left
visible: showRestoreDefaultsButton
width: implicitWidth
text: qsTr("Restore Defaults")
font.pixelSize: theme.fontSizeLarge
@ -77,7 +81,7 @@ Item {
MyButton {
id: advancedSettingsButton
anchors.right: parent.right
visible: root.advancedSettings
visible: root.advancedSettings && showAdvancedSettingsButton
width: implicitWidth
text: !advancedInner.visible ? qsTr("Advanced Settings") : qsTr("Hide Advanced Settings")
font.pixelSize: theme.fontSizeLarge

@ -19,6 +19,8 @@ MyDialog {
Network.sendSettingsDialog();
}
signal downloadClicked
Item {
Accessible.role: Accessible.Dialog
Accessible.name: qsTr("Settings")
@ -28,13 +30,13 @@ MyDialog {
ListModel {
id: stacksModel
ListElement {
title: "Models"
title: qsTr("Models")
}
ListElement {
title: "Application"
title: qsTr("Application")
}
ListElement {
title: "Plugins"
title: qsTr("LocalDocs")
}
}
@ -107,9 +109,16 @@ MyDialog {
}
MySettingsStack {
title: qsTr("LocalDocs Plugin (BETA) Settings")
title: qsTr("Local Document Collections")
tabs: [
Component { LocalDocsSettings { } }
Component {
LocalDocsSettings {
id: localDocsSettings
Component.onCompleted: {
localDocsSettings.downloadClicked.connect(settingsDialog.downloadClicked);
}
}
}
]
}
}

Loading…
Cancel
Save