From 9273b49b62b4f1073affeb7bf63671a049859f0b Mon Sep 17 00:00:00 2001 From: AT Date: Mon, 24 Jun 2024 18:49:23 -0400 Subject: [PATCH] chat: major UI redesign for v3.0.0 (#2396) Signed-off-by: Adam Treat Signed-off-by: Jared Van Bortel Co-authored-by: Jared Van Bortel --- .gitmodules | 3 + gpt4all-backend/dlhandle.cpp | 18 +- gpt4all-backend/gptj.cpp | 33 +- gpt4all-backend/llamamodel.cpp | 72 +- gpt4all-backend/llmodel.cpp | 48 +- gpt4all-backend/llmodel_c.cpp | 12 +- gpt4all-backend/llmodel_shared.cpp | 6 +- gpt4all-backend/llmodel_shared.h | 3 +- gpt4all-backend/utils.cpp | 18 +- gpt4all-backend/utils.h | 3 +- gpt4all-chat/CMakeLists.txt | 74 +- gpt4all-chat/chat.cpp | 47 +- gpt4all-chat/chat.h | 2 +- gpt4all-chat/chatapi.cpp | 11 +- gpt4all-chat/chatlistmodel.cpp | 4 +- gpt4all-chat/chatlistmodel.h | 24 +- gpt4all-chat/chatllm.cpp | 46 +- gpt4all-chat/chatllm.h | 4 +- gpt4all-chat/chatmodel.h | 213 +- gpt4all-chat/database.cpp | 2147 ++++++++++----- gpt4all-chat/database.h | 182 +- gpt4all-chat/download.cpp | 68 +- gpt4all-chat/download.h | 6 + gpt4all-chat/embeddings.cpp | 202 -- gpt4all-chat/embeddings.h | 48 - gpt4all-chat/embllm.cpp | 215 +- gpt4all-chat/embllm.h | 33 +- gpt4all-chat/hnswlib/bruteforce.h | 167 -- gpt4all-chat/hnswlib/hnswalg.h | 1271 --------- gpt4all-chat/hnswlib/hnswlib.h | 199 -- gpt4all-chat/hnswlib/space_ip.h | 375 --- gpt4all-chat/hnswlib/space_l2.h | 324 --- gpt4all-chat/hnswlib/visited_list_pool.h | 78 - gpt4all-chat/icons/alt_logo.svg | 52 + gpt4all-chat/icons/antenna_1.svg | 4 + gpt4all-chat/icons/antenna_2.svg | 5 + gpt4all-chat/icons/antenna_3.svg | 6 + gpt4all-chat/icons/changelog.svg | 3 + gpt4all-chat/icons/chat.svg | 3 + gpt4all-chat/icons/db.svg | 8 +- gpt4all-chat/icons/discord.svg | 3 + gpt4all-chat/icons/edit.svg | 8 +- gpt4all-chat/icons/email.svg | 1 + gpt4all-chat/icons/file-md.svg | 1 + gpt4all-chat/icons/file-pdf.svg | 1 + gpt4all-chat/icons/file-txt.svg | 1 + gpt4all-chat/icons/file.svg | 1 + 
gpt4all-chat/icons/github.svg | 10 + gpt4all-chat/icons/globe.svg | 3 + gpt4all-chat/icons/home.svg | 3 + gpt4all-chat/icons/info.svg | 3 + gpt4all-chat/icons/local-docs.svg | 3 + gpt4all-chat/icons/models.svg | 3 + gpt4all-chat/icons/nomic_logo.svg | 7 + gpt4all-chat/icons/notes.svg | 1 + gpt4all-chat/icons/search.svg | 6 + gpt4all-chat/icons/settings.svg | 47 +- gpt4all-chat/icons/trash.svg | 8 +- gpt4all-chat/icons/twitter.svg | 3 + gpt4all-chat/icons/you.svg | 41 + gpt4all-chat/llm.cpp | 14 +- gpt4all-chat/localdocs.cpp | 67 +- gpt4all-chat/localdocs.h | 15 +- gpt4all-chat/localdocsmodel.cpp | 202 +- gpt4all-chat/localdocsmodel.h | 53 +- gpt4all-chat/logger.cpp | 4 +- gpt4all-chat/main.cpp | 4 +- gpt4all-chat/main.qml | 576 +++- gpt4all-chat/modellist.cpp | 263 +- gpt4all-chat/modellist.h | 42 +- gpt4all-chat/mysettings.cpp | 962 ++----- gpt4all-chat/mysettings.h | 187 +- gpt4all-chat/network.cpp | 15 +- gpt4all-chat/oscompat.cpp | 70 + gpt4all-chat/oscompat.h | 7 + gpt4all-chat/qml/AboutDialog.qml | 101 - gpt4all-chat/qml/AddCollectionView.qml | 170 ++ gpt4all-chat/qml/AddModelView.qml | 726 +++++ gpt4all-chat/qml/ApplicationSettings.qml | 318 ++- gpt4all-chat/qml/ChatDrawer.qml | 124 +- gpt4all-chat/qml/ChatView.qml | 2421 ++++++++--------- gpt4all-chat/qml/CollectionsDialog.qml | 148 - gpt4all-chat/qml/CollectionsDrawer.qml | 148 + gpt4all-chat/qml/HomeView.qml | 278 ++ gpt4all-chat/qml/LocalDocsSettings.qml | 450 ++- gpt4all-chat/qml/LocalDocsView.qml | 457 ++++ gpt4all-chat/qml/ModelDownloaderDialog.qml | 643 ----- gpt4all-chat/qml/ModelSettings.qml | 88 +- gpt4all-chat/qml/ModelsView.qml | 321 +++ gpt4all-chat/qml/MyBusyIndicator.qml | 25 +- gpt4all-chat/qml/MyButton.qml | 9 +- gpt4all-chat/qml/MyComboBox.qml | 2 +- gpt4all-chat/qml/MyDirectoryField.qml | 2 + gpt4all-chat/qml/MyFancyLink.qml | 44 + gpt4all-chat/qml/MyMiniButton.qml | 5 +- gpt4all-chat/qml/MySettingsButton.qml | 8 +- .../qml/MySettingsDestructiveButton.qml | 10 +- 
gpt4all-chat/qml/MySettingsLabel.qml | 39 +- gpt4all-chat/qml/MySettingsStack.qml | 41 +- gpt4all-chat/qml/MySettingsTab.qml | 32 +- gpt4all-chat/qml/MySlug.qml | 2 +- gpt4all-chat/qml/MyToolButton.qml | 16 +- gpt4all-chat/qml/MyWelcomeButton.qml | 77 + gpt4all-chat/qml/SettingsDialog.qml | 132 - gpt4all-chat/qml/SettingsView.qml | 158 ++ gpt4all-chat/qml/StartupDialog.qml | 16 +- gpt4all-chat/qml/Theme.qml | 586 +++- gpt4all-chat/responsetext.cpp | 128 +- gpt4all-chat/responsetext.h | 17 - gpt4all-chat/server.cpp | 4 +- gpt4all-chat/usearch | 1 + 111 files changed, 8540 insertions(+), 7879 deletions(-) delete mode 100644 gpt4all-chat/embeddings.cpp delete mode 100644 gpt4all-chat/embeddings.h delete mode 100644 gpt4all-chat/hnswlib/bruteforce.h delete mode 100644 gpt4all-chat/hnswlib/hnswalg.h delete mode 100644 gpt4all-chat/hnswlib/hnswlib.h delete mode 100644 gpt4all-chat/hnswlib/space_ip.h delete mode 100644 gpt4all-chat/hnswlib/space_l2.h delete mode 100644 gpt4all-chat/hnswlib/visited_list_pool.h create mode 100644 gpt4all-chat/icons/alt_logo.svg create mode 100644 gpt4all-chat/icons/antenna_1.svg create mode 100644 gpt4all-chat/icons/antenna_2.svg create mode 100644 gpt4all-chat/icons/antenna_3.svg create mode 100644 gpt4all-chat/icons/changelog.svg create mode 100644 gpt4all-chat/icons/chat.svg create mode 100644 gpt4all-chat/icons/discord.svg create mode 100644 gpt4all-chat/icons/email.svg create mode 100644 gpt4all-chat/icons/file-md.svg create mode 100644 gpt4all-chat/icons/file-pdf.svg create mode 100644 gpt4all-chat/icons/file-txt.svg create mode 100644 gpt4all-chat/icons/file.svg create mode 100644 gpt4all-chat/icons/github.svg create mode 100644 gpt4all-chat/icons/globe.svg create mode 100644 gpt4all-chat/icons/home.svg create mode 100644 gpt4all-chat/icons/info.svg create mode 100644 gpt4all-chat/icons/local-docs.svg create mode 100644 gpt4all-chat/icons/models.svg create mode 100644 gpt4all-chat/icons/nomic_logo.svg create mode 100644 
gpt4all-chat/icons/notes.svg create mode 100644 gpt4all-chat/icons/search.svg create mode 100644 gpt4all-chat/icons/twitter.svg create mode 100644 gpt4all-chat/icons/you.svg create mode 100644 gpt4all-chat/oscompat.cpp create mode 100644 gpt4all-chat/oscompat.h delete mode 100644 gpt4all-chat/qml/AboutDialog.qml create mode 100644 gpt4all-chat/qml/AddCollectionView.qml create mode 100644 gpt4all-chat/qml/AddModelView.qml delete mode 100644 gpt4all-chat/qml/CollectionsDialog.qml create mode 100644 gpt4all-chat/qml/CollectionsDrawer.qml create mode 100644 gpt4all-chat/qml/HomeView.qml create mode 100644 gpt4all-chat/qml/LocalDocsView.qml delete mode 100644 gpt4all-chat/qml/ModelDownloaderDialog.qml create mode 100644 gpt4all-chat/qml/ModelsView.qml create mode 100644 gpt4all-chat/qml/MyFancyLink.qml create mode 100644 gpt4all-chat/qml/MyWelcomeButton.qml delete mode 100644 gpt4all-chat/qml/SettingsDialog.qml create mode 100644 gpt4all-chat/qml/SettingsView.qml create mode 160000 gpt4all-chat/usearch diff --git a/.gitmodules b/.gitmodules index 03751865..4a1439e2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,6 @@ path = gpt4all-backend/llama.cpp-mainline url = https://github.com/nomic-ai/llama.cpp.git branch = master +[submodule "gpt4all-chat/usearch"] + path = gpt4all-chat/usearch + url = https://github.com/unum-cloud/usearch.git diff --git a/gpt4all-backend/dlhandle.cpp b/gpt4all-backend/dlhandle.cpp index fd9a1f22..d7d46b29 100644 --- a/gpt4all-backend/dlhandle.cpp +++ b/gpt4all-backend/dlhandle.cpp @@ -20,24 +20,28 @@ namespace fs = std::filesystem; #ifndef _WIN32 -Dlhandle::Dlhandle(const fs::path &fpath) { +Dlhandle::Dlhandle(const fs::path &fpath) +{ chandle = dlopen(fpath.c_str(), RTLD_LAZY | RTLD_LOCAL); if (!chandle) { throw Exception("dlopen: "s + dlerror()); } } -Dlhandle::~Dlhandle() { +Dlhandle::~Dlhandle() +{ if (chandle) dlclose(chandle); } -void *Dlhandle::get_internal(const char *symbol) const { +void *Dlhandle::get_internal(const char 
*symbol) const +{ return dlsym(chandle, symbol); } #else // defined(_WIN32) -Dlhandle::Dlhandle(const fs::path &fpath) { +Dlhandle::Dlhandle(const fs::path &fpath) +{ fs::path afpath = fs::absolute(fpath); // Suppress the "Entry Point Not Found" dialog, caused by outdated nvcuda.dll from the GPU driver @@ -58,11 +62,13 @@ Dlhandle::Dlhandle(const fs::path &fpath) { } } -Dlhandle::~Dlhandle() { +Dlhandle::~Dlhandle() +{ if (chandle) FreeLibrary(HMODULE(chandle)); } -void *Dlhandle::get_internal(const char *symbol) const { +void *Dlhandle::get_internal(const char *symbol) const +{ return GetProcAddress(HMODULE(chandle), symbol); } diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp index 90ec829a..48a141c7 100644 --- a/gpt4all-backend/gptj.cpp +++ b/gpt4all-backend/gptj.cpp @@ -123,7 +123,8 @@ static bool kv_cache_init( } // load the model's weights from a file path -bool gptj_model_load(const std::string &fname, gptj_model & model, gpt_vocab & vocab, size_t * mem_req = nullptr) { +bool gptj_model_load(const std::string &fname, gptj_model & model, gpt_vocab & vocab, size_t * mem_req = nullptr) +{ printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); if(mem_req != nullptr) { *mem_req = 0; @@ -667,7 +668,8 @@ GPTJ::GPTJ() d_ptr->modelLoaded = false; } -size_t GPTJ::requiredMem(const std::string &modelPath, int n_ctx, int ngl) { +size_t GPTJ::requiredMem(const std::string &modelPath, int n_ctx, int ngl) +{ (void)n_ctx; (void)ngl; gptj_model dummy_model; @@ -677,7 +679,8 @@ size_t GPTJ::requiredMem(const std::string &modelPath, int n_ctx, int ngl) { return mem_req; } -bool GPTJ::loadModel(const std::string &modelPath, int n_ctx, int ngl) { +bool GPTJ::loadModel(const std::string &modelPath, int n_ctx, int ngl) +{ (void)n_ctx; (void)ngl; d_ptr->modelLoaded = false; @@ -698,7 +701,8 @@ bool GPTJ::loadModel(const std::string &modelPath, int n_ctx, int ngl) { return true; } -void GPTJ::setThreadCount(int32_t n_threads) { +void 
GPTJ::setThreadCount(int32_t n_threads) +{ d_ptr->n_threads = n_threads; } @@ -780,7 +784,8 @@ const std::vector &GPTJ::endTokens() const return fres; } -const char *get_arch_name(gguf_context *ctx_gguf) { +const char *get_arch_name(gguf_context *ctx_gguf) +{ const int kid = gguf_find_key(ctx_gguf, "general.architecture"); if (kid == -1) throw std::runtime_error("key not found in model: general.architecture"); @@ -799,19 +804,23 @@ const char *get_arch_name(gguf_context *ctx_gguf) { #endif extern "C" { -DLL_EXPORT bool is_g4a_backend_model_implementation() { +DLL_EXPORT bool is_g4a_backend_model_implementation() +{ return true; } -DLL_EXPORT const char *get_model_type() { +DLL_EXPORT const char *get_model_type() +{ return modelType_; } -DLL_EXPORT const char *get_build_variant() { +DLL_EXPORT const char *get_build_variant() +{ return GGML_BUILD_VARIANT; } -DLL_EXPORT char *get_file_arch(const char *fname) { +DLL_EXPORT char *get_file_arch(const char *fname) +{ struct ggml_context * ctx_meta = NULL; struct gguf_init_params params = { /*.no_alloc = */ true, @@ -832,11 +841,13 @@ DLL_EXPORT char *get_file_arch(const char *fname) { return arch; } -DLL_EXPORT bool is_arch_supported(const char *arch) { +DLL_EXPORT bool is_arch_supported(const char *arch) +{ return !strcmp(arch, "gptj"); } -DLL_EXPORT LLModel *construct() { +DLL_EXPORT LLModel *construct() +{ return new GPTJ; } } diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp index 685e8aa5..a55c23bf 100644 --- a/gpt4all-backend/llamamodel.cpp +++ b/gpt4all-backend/llamamodel.cpp @@ -84,16 +84,19 @@ static const std::vector EMBEDDING_ARCHES { "bert", "nomic-bert", }; -static bool is_embedding_arch(const std::string &arch) { +static bool is_embedding_arch(const std::string &arch) +{ return std::find(EMBEDDING_ARCHES.begin(), EMBEDDING_ARCHES.end(), arch) < EMBEDDING_ARCHES.end(); } -static bool llama_verbose() { +static bool llama_verbose() +{ const char* var = 
getenv("GPT4ALL_VERBOSE_LLAMACPP"); return var && *var; } -static void llama_log_callback(enum ggml_log_level level, const char *text, void *userdata) { +static void llama_log_callback(enum ggml_log_level level, const char *text, void *userdata) +{ (void)userdata; if (llama_verbose() || level <= GGML_LOG_LEVEL_ERROR) { fputs(text, stderr); @@ -147,7 +150,8 @@ static int llama_sample_top_p_top_k( return llama_sample_token(ctx, &candidates_p); } -const char *get_arch_name(gguf_context *ctx_gguf) { +const char *get_arch_name(gguf_context *ctx_gguf) +{ const int kid = gguf_find_key(ctx_gguf, "general.architecture"); if (kid == -1) throw std::runtime_error("key not found in model: general.architecture"); @@ -159,7 +163,8 @@ const char *get_arch_name(gguf_context *ctx_gguf) { return gguf_get_val_str(ctx_gguf, kid); } -static gguf_context *load_gguf(const char *fname) { +static gguf_context *load_gguf(const char *fname) +{ struct gguf_init_params params = { /*.no_alloc = */ true, /*.ctx = */ nullptr, @@ -180,7 +185,8 @@ static gguf_context *load_gguf(const char *fname) { return ctx; } -static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) { +static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) +{ int32_t value = -1; std::string arch; @@ -237,7 +243,8 @@ struct llama_file_hparams { enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; }; -size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx, int ngl) { +size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx, int ngl) +{ // TODO(cebtenzzre): update to GGUF (void)ngl; // FIXME(cetenzzre): use this value auto fin = std::ifstream(modelPath, std::ios::binary); @@ -261,7 +268,8 @@ size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx, int ngl) return filesize + est_kvcache_size; } -bool LLamaModel::isModelBlacklisted(const std::string &modelPath) const { +bool LLamaModel::isModelBlacklisted(const std::string 
&modelPath) const +{ auto * ctx = load_gguf(modelPath.c_str()); if (!ctx) { std::cerr << __func__ << ": failed to load " << modelPath << "\n"; @@ -297,7 +305,8 @@ bool LLamaModel::isModelBlacklisted(const std::string &modelPath) const { return res; } -bool LLamaModel::isEmbeddingModel(const std::string &modelPath) const { +bool LLamaModel::isEmbeddingModel(const std::string &modelPath) const +{ bool result = false; std::string arch; @@ -453,12 +462,14 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl) return true; } -void LLamaModel::setThreadCount(int32_t n_threads) { +void LLamaModel::setThreadCount(int32_t n_threads) +{ d_ptr->n_threads = n_threads; llama_set_n_threads(d_ptr->ctx, n_threads, n_threads); } -int32_t LLamaModel::threadCount() const { +int32_t LLamaModel::threadCount() const +{ return d_ptr->n_threads; } @@ -581,7 +592,8 @@ int32_t LLamaModel::layerCount(std::string const &modelPath) const } #ifdef GGML_USE_VULKAN -static const char *getVulkanVendorName(uint32_t vendorID) { +static const char *getVulkanVendorName(uint32_t vendorID) +{ switch (vendorID) { case 0x10DE: return "nvidia"; case 0x1002: return "amd"; @@ -738,11 +750,13 @@ bool LLamaModel::usingGPUDevice() const return hasDevice; } -const char *LLamaModel::backendName() const { +const char *LLamaModel::backendName() const +{ return d_ptr->backend_name; } -const char *LLamaModel::gpuDeviceName() const { +const char *LLamaModel::gpuDeviceName() const +{ if (usingGPUDevice()) { #if defined(GGML_USE_KOMPUTE) || defined(GGML_USE_VULKAN) || defined(GGML_USE_CUDA) return d_ptr->deviceName.c_str(); @@ -768,13 +782,15 @@ void llama_batch_add( batch.n_tokens++; } -static void batch_add_seq(llama_batch &batch, const std::vector &tokens, int seq_id) { +static void batch_add_seq(llama_batch &batch, const std::vector &tokens, int seq_id) +{ for (unsigned i = 0; i < tokens.size(); i++) { llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); } } -size_t 
LLamaModel::embeddingSize() const { +size_t LLamaModel::embeddingSize() const +{ return llama_n_embd(d_ptr->model); } @@ -894,12 +910,14 @@ void LLamaModel::embed( // MD5 hash of "nomic empty" static const char EMPTY_PLACEHOLDER[] = "24df574ea1c998de59d5be15e769658e"; -auto product(double a) -> std::function { +auto product(double a) -> std::function +{ return [a](double b) { return a * b; }; } template -double getL2NormScale(T *start, T *end) { +double getL2NormScale(T *start, T *end) +{ double magnitude = std::sqrt(std::inner_product(start, end, start, 0.0)); return 1.0 / std::max(magnitude, 1e-12); } @@ -1107,19 +1125,23 @@ void LLamaModel::embedInternal( #endif extern "C" { -DLL_EXPORT bool is_g4a_backend_model_implementation() { +DLL_EXPORT bool is_g4a_backend_model_implementation() +{ return true; } -DLL_EXPORT const char *get_model_type() { +DLL_EXPORT const char *get_model_type() +{ return modelType_; } -DLL_EXPORT const char *get_build_variant() { +DLL_EXPORT const char *get_build_variant() +{ return GGML_BUILD_VARIANT; } -DLL_EXPORT char *get_file_arch(const char *fname) { +DLL_EXPORT char *get_file_arch(const char *fname) +{ char *arch = nullptr; std::string archStr; @@ -1144,11 +1166,13 @@ cleanup: return arch; } -DLL_EXPORT bool is_arch_supported(const char *arch) { +DLL_EXPORT bool is_arch_supported(const char *arch) +{ return std::find(KNOWN_ARCHES.begin(), KNOWN_ARCHES.end(), std::string(arch)) < KNOWN_ARCHES.end(); } -DLL_EXPORT LLModel *construct() { +DLL_EXPORT LLModel *construct() +{ llama_log_set(llama_log_callback, nullptr); return new LLamaModel; } diff --git a/gpt4all-backend/llmodel.cpp b/gpt4all-backend/llmodel.cpp index 402553ba..2f35180e 100644 --- a/gpt4all-backend/llmodel.cpp +++ b/gpt4all-backend/llmodel.cpp @@ -92,17 +92,20 @@ LLModel::Implementation::Implementation(Implementation &&o) o.m_dlhandle = nullptr; } -LLModel::Implementation::~Implementation() { +LLModel::Implementation::~Implementation() +{ delete m_dlhandle; } -static 
bool isImplementation(const Dlhandle &dl) { +static bool isImplementation(const Dlhandle &dl) +{ return dl.get("is_g4a_backend_model_implementation"); } // Add the CUDA Toolkit to the DLL search path on Windows. // This is necessary for chat.exe to find CUDA when started from Qt Creator. -static void addCudaSearchPath() { +static void addCudaSearchPath() +{ #ifdef _WIN32 if (const auto *cudaPath = _wgetenv(L"CUDA_PATH")) { auto libDir = std::wstring(cudaPath) + L"\\bin"; @@ -114,7 +117,8 @@ static void addCudaSearchPath() { #endif } -const std::vector &LLModel::Implementation::implementationList() { +const std::vector &LLModel::Implementation::implementationList() +{ if (cpu_supports_avx() == 0) { throw std::runtime_error("CPU does not support AVX"); } @@ -169,14 +173,16 @@ const std::vector &LLModel::Implementation::implementat return *libs; } -static std::string applyCPUVariant(const std::string &buildVariant) { +static std::string applyCPUVariant(const std::string &buildVariant) +{ if (buildVariant != "metal" && cpu_supports_avx2() == 0) { return buildVariant + "-avxonly"; } return buildVariant; } -const LLModel::Implementation* LLModel::Implementation::implementation(const char *fname, const std::string& buildVariant) { +const LLModel::Implementation* LLModel::Implementation::implementation(const char *fname, const std::string& buildVariant) +{ bool buildVariantMatched = false; std::optional archName; for (const auto& i : implementationList()) { @@ -200,7 +206,8 @@ const LLModel::Implementation* LLModel::Implementation::implementation(const cha throw BadArchError(std::move(*archName)); } -LLModel *LLModel::Implementation::construct(const std::string &modelPath, const std::string &backend, int n_ctx) { +LLModel *LLModel::Implementation::construct(const std::string &modelPath, const std::string &backend, int n_ctx) +{ std::vector desiredBackends; if (backend != "auto") { desiredBackends.push_back(backend); @@ -240,7 +247,8 @@ LLModel 
*LLModel::Implementation::construct(const std::string &modelPath, const throw MissingImplementationError("Could not find any implementations for backend: " + backend); } -LLModel *LLModel::Implementation::constructGlobalLlama(const std::optional &backend) { +LLModel *LLModel::Implementation::constructGlobalLlama(const std::optional &backend) +{ static std::unordered_map> implCache; const std::vector *impls; @@ -284,7 +292,8 @@ LLModel *LLModel::Implementation::constructGlobalLlama(const std::optional LLModel::Implementation::availableGPUDevices(size_t memoryRequired) { +std::vector LLModel::Implementation::availableGPUDevices(size_t memoryRequired) +{ std::vector devices; #ifndef __APPLE__ static const std::string backends[] = {"kompute", "cuda"}; @@ -299,33 +308,40 @@ std::vector LLModel::Implementation::availableGPUDevices(siz return devices; } -int32_t LLModel::Implementation::maxContextLength(const std::string &modelPath) { +int32_t LLModel::Implementation::maxContextLength(const std::string &modelPath) +{ auto *llama = constructGlobalLlama(); return llama ? llama->maxContextLength(modelPath) : -1; } -int32_t LLModel::Implementation::layerCount(const std::string &modelPath) { +int32_t LLModel::Implementation::layerCount(const std::string &modelPath) +{ auto *llama = constructGlobalLlama(); return llama ? 
llama->layerCount(modelPath) : -1; } -bool LLModel::Implementation::isEmbeddingModel(const std::string &modelPath) { +bool LLModel::Implementation::isEmbeddingModel(const std::string &modelPath) +{ auto *llama = constructGlobalLlama(); return llama && llama->isEmbeddingModel(modelPath); } -void LLModel::Implementation::setImplementationsSearchPath(const std::string& path) { +void LLModel::Implementation::setImplementationsSearchPath(const std::string& path) +{ s_implementations_search_path = path; } -const std::string& LLModel::Implementation::implementationsSearchPath() { +const std::string& LLModel::Implementation::implementationsSearchPath() +{ return s_implementations_search_path; } -bool LLModel::Implementation::hasSupportedCPU() { +bool LLModel::Implementation::hasSupportedCPU() +{ return cpu_supports_avx() != 0; } -int LLModel::Implementation::cpuSupportsAVX2() { +int LLModel::Implementation::cpuSupportsAVX2() +{ return cpu_supports_avx2(); } diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp index 439bf05a..62663d7f 100644 --- a/gpt4all-backend/llmodel_c.cpp +++ b/gpt4all-backend/llmodel_c.cpp @@ -20,7 +20,8 @@ struct LLModelWrapper { ~LLModelWrapper() { delete llModel; } }; -llmodel_model llmodel_model_create(const char *model_path) { +llmodel_model llmodel_model_create(const char *model_path) +{ const char *error; auto fres = llmodel_model_create2(model_path, "auto", &error); if (!fres) { @@ -29,7 +30,8 @@ llmodel_model llmodel_model_create(const char *model_path) { return fres; } -static void llmodel_set_error(const char **errptr, const char *message) { +static void llmodel_set_error(const char **errptr, const char *message) +{ thread_local static std::string last_error_message; if (errptr) { last_error_message = message; @@ -37,7 +39,8 @@ static void llmodel_set_error(const char **errptr, const char *message) { } } -llmodel_model llmodel_model_create2(const char *model_path, const char *backend, const char **error) { 
+llmodel_model llmodel_model_create2(const char *model_path, const char *backend, const char **error) +{ LLModel *llModel; try { llModel = LLModel::Implementation::construct(model_path, backend); @@ -51,7 +54,8 @@ llmodel_model llmodel_model_create2(const char *model_path, const char *backend, return wrapper; } -void llmodel_model_destroy(llmodel_model model) { +void llmodel_model_destroy(llmodel_model model) +{ delete static_cast(model); } diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp index 1296dc4e..1f797e8f 100644 --- a/gpt4all-backend/llmodel_shared.cpp +++ b/gpt4all-backend/llmodel_shared.cpp @@ -14,7 +14,8 @@ #include // TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is) -void LLModel::recalculateContext(PromptContext &promptCtx, std::function recalculate) { +void LLModel::recalculateContext(PromptContext &promptCtx, std::function recalculate) +{ int n_keep = shouldAddBOS(); const int32_t n_discard = (promptCtx.n_ctx - n_keep) * promptCtx.contextErase; @@ -43,7 +44,8 @@ stop_generating: recalculate(false); } -static bool parsePromptTemplate(const std::string &tmpl, std::vector &placeholders, std::string &err) { +static bool parsePromptTemplate(const std::string &tmpl, std::vector &placeholders, std::string &err) +{ static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))"); auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex); diff --git a/gpt4all-backend/llmodel_shared.h b/gpt4all-backend/llmodel_shared.h index b19db356..94a267bf 100644 --- a/gpt4all-backend/llmodel_shared.h +++ b/gpt4all-backend/llmodel_shared.h @@ -38,7 +38,8 @@ struct llm_kv_cache { } }; -inline void ggml_graph_compute_g4a(llm_buffer& buf, ggml_cgraph * graph, int n_threads) { +inline void ggml_graph_compute_g4a(llm_buffer& buf, ggml_cgraph * graph, int n_threads) +{ struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); if (plan.work_size > 0) { 
buf.resize(plan.work_size); diff --git a/gpt4all-backend/utils.cpp b/gpt4all-backend/utils.cpp index 2d25686e..5d32c91e 100644 --- a/gpt4all-backend/utils.cpp +++ b/gpt4all-backend/utils.cpp @@ -8,7 +8,8 @@ #include #include -void replace(std::string & str, const std::string & needle, const std::string & replacement) { +void replace(std::string & str, const std::string & needle, const std::string & replacement) +{ size_t pos = 0; while ((pos = str.find(needle, pos)) != std::string::npos) { str.replace(pos, needle.length(), replacement); @@ -16,7 +17,8 @@ void replace(std::string & str, const std::string & needle, const std::string & } } -std::map json_parse(const std::string & fname) { +std::map json_parse(const std::string & fname) +{ std::map result; // read file into string @@ -107,7 +109,8 @@ std::map json_parse(const std::string & fname) { return result; } -std::vector gpt_tokenize_inner(const gpt_vocab & vocab, const std::string & text) { +std::vector gpt_tokenize_inner(const gpt_vocab & vocab, const std::string & text) +{ std::vector words; // first split the text into words @@ -162,12 +165,14 @@ std::vector gpt_tokenize_inner(const gpt_vocab & vocab, const std return tokens; } -std::string regex_escape(const std::string &s) { +std::string regex_escape(const std::string &s) +{ static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])"); return std::regex_replace(s, metacharacters, "\\$&"); } -std::vector gpt_tokenize(const gpt_vocab & vocab, const std::string & text) { +std::vector gpt_tokenize(const gpt_vocab & vocab, const std::string & text) +{ // Generate the subpattern from the special_tokens vector if it's not empty if (!vocab.special_tokens.empty()) { std::vector out; @@ -203,7 +208,8 @@ std::vector gpt_tokenize(const gpt_vocab & vocab, const std::stri } -bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) { +bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) +{ printf("%s: loading vocab from '%s'\n", 
__func__, fname.c_str()); vocab.token_to_id = ::json_parse(fname); diff --git a/gpt4all-backend/utils.h b/gpt4all-backend/utils.h index ea99f7a0..9740aabd 100644 --- a/gpt4all-backend/utils.h +++ b/gpt4all-backend/utils.h @@ -14,7 +14,8 @@ // // General purpose inline functions // -constexpr inline unsigned long long operator ""_MiB(unsigned long long bytes) { +constexpr inline unsigned long long operator ""_MiB(unsigned long long bytes) +{ return bytes*1024*1024; } diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index bdd64092..742a2e5d 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -16,15 +16,16 @@ if(APPLE) endif() endif() -set(APP_VERSION_MAJOR 2) -set(APP_VERSION_MINOR 8) -set(APP_VERSION_PATCH 1) -set(APP_VERSION "${APP_VERSION_MAJOR}.${APP_VERSION_MINOR}.${APP_VERSION_PATCH}") +set(APP_VERSION_MAJOR 3) +set(APP_VERSION_MINOR 0) +set(APP_VERSION_PATCH 0) +set(APP_VERSION_BASE "${APP_VERSION_MAJOR}.${APP_VERSION_MINOR}.${APP_VERSION_PATCH}") +set(APP_VERSION "${APP_VERSION_BASE}-rc1") # Include the binary directory for the generated header file include_directories("${CMAKE_CURRENT_BINARY_DIR}") -project(gpt4all VERSION ${APP_VERSION} LANGUAGES CXX C) +project(gpt4all VERSION ${APP_VERSION_BASE} LANGUAGES CXX C) set(CMAKE_AUTOMOC ON) set(CMAKE_AUTORCC ON) @@ -91,7 +92,6 @@ qt_add_executable(chat chatmodel.h chatlistmodel.h chatlistmodel.cpp chatapi.h chatapi.cpp database.h database.cpp - embeddings.h embeddings.cpp download.h download.cpp embllm.cpp embllm.h localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp @@ -102,6 +102,7 @@ qt_add_executable(chat server.h server.cpp logger.h logger.cpp responsetext.h responsetext.cpp + oscompat.h oscompat.cpp ${METAL_SHADER_FILE} ${APP_ICON_RESOURCE} ) @@ -112,21 +113,24 @@ qt_add_qml_module(chat NO_CACHEGEN QML_FILES main.qml + qml/AddCollectionView.qml + qml/AddModelView.qml qml/ChatDrawer.qml qml/ChatView.qml - qml/CollectionsDialog.qml - 
qml/ModelDownloaderDialog.qml + qml/CollectionsDrawer.qml + qml/HomeView.qml + qml/ModelsView.qml qml/NetworkDialog.qml qml/NewVersionDialog.qml qml/ThumbsDownDialog.qml - qml/SettingsDialog.qml + qml/SettingsView.qml qml/StartupDialog.qml qml/PopupDialog.qml - qml/AboutDialog.qml qml/Theme.qml qml/ModelSettings.qml qml/ApplicationSettings.qml qml/LocalDocsSettings.qml + qml/LocalDocsView.qml qml/SwitchModelDialog.qml qml/MySettingsTab.qml qml/MySettingsStack.qml @@ -138,33 +142,58 @@ qt_add_qml_module(chat qml/MyComboBox.qml qml/MyDialog.qml qml/MyDirectoryField.qml + qml/MyFancyLink.qml qml/MyTextArea.qml qml/MyTextField.qml qml/MyCheckBox.qml qml/MyBusyIndicator.qml qml/MyMiniButton.qml qml/MyToolButton.qml + qml/MyWelcomeButton.qml RESOURCES + icons/antenna_1.svg + icons/antenna_2.svg + icons/antenna_3.svg icons/send_message.svg icons/stop_generating.svg icons/regenerate.svg + icons/chat.svg + icons/changelog.svg icons/close.svg icons/copy.svg icons/db.svg + icons/discord.svg icons/download.svg icons/settings.svg icons/eject.svg icons/edit.svg + icons/email.svg + icons/file.svg + icons/file-md.svg + icons/file-pdf.svg + icons/file-txt.svg + icons/github.svg + icons/globe.svg + icons/home.svg icons/image.svg + icons/info.svg + icons/local-docs.svg + icons/models.svg + icons/nomic_logo.svg + icons/notes.svg + icons/search.svg icons/trash.svg icons/network.svg icons/thumbs_up.svg icons/thumbs_down.svg + icons/twitter.svg icons/left_panel_closed.svg icons/left_panel_open.svg icons/logo.svg icons/logo-32.png icons/logo-48.png + icons/you.svg + icons/alt_logo.svg ) set_target_properties(chat PROPERTIES @@ -190,6 +219,13 @@ endif() target_compile_definitions(chat PRIVATE $<$,$>:QT_QML_DEBUG>) + +# usearch uses the identifier 'slots' which conflicts with Qt's 'slots' keyword +target_compile_definitions(chat PRIVATE QT_NO_SIGNALS_SLOTS_KEYWORDS) + +target_include_directories(chat PRIVATE usearch/include + usearch/fp16/include) + if(LINUX) target_link_libraries(chat 
PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf Qt6::WaylandCompositor) @@ -200,6 +236,20 @@ endif() target_link_libraries(chat PRIVATE llmodel) + +# -- extra resources -- + +set(LOCAL_EMBEDDING_MODEL "nomic-embed-text-v1.5.f16.gguf") +set(LOCAL_EMBEDDING_MODEL_MD5 "a5401e7f7e46ed9fcaed5b60a281d547") +file(DOWNLOAD + "https://gpt4all.io/models/gguf/${LOCAL_EMBEDDING_MODEL}" + "${CMAKE_BINARY_DIR}/resources/${LOCAL_EMBEDDING_MODEL}" + EXPECTED_HASH "MD5=${LOCAL_EMBEDDING_MODEL_MD5}" +) + + +# -- install -- + set(COMPONENT_NAME_MAIN ${PROJECT_NAME}) if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) @@ -264,6 +314,10 @@ if (LLMODEL_CUDA) endif() endif() +install(FILES "${CMAKE_BINARY_DIR}/resources/${LOCAL_EMBEDDING_MODEL}" + DESTINATION resources + COMPONENT ${COMPONENT_NAME_MAIN}) + set(CPACK_GENERATOR "IFW") set(CPACK_VERBATIM_VARIABLES YES) set(CPACK_IFW_VERBOSE ON) diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp index 42aa15bb..b5564657 100644 --- a/gpt4all-chat/chat.cpp +++ b/gpt4all-chat/chat.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include @@ -130,7 +131,7 @@ void Chat::prompt(const QString &prompt) void Chat::regenerateResponse() { const int index = m_chatModel->count() - 1; - m_chatModel->updateReferences(index, QString(), QList()); + m_chatModel->updateSources(index, QList()); emit regenerateResponseRequested(); } @@ -193,43 +194,6 @@ void Chat::responseStopped(qint64 promptResponseMs) { m_tokenSpeed = QString(); emit tokenSpeedChanged(); - - const QString chatResponse = response(); - QList references; - QList referencesContext; - int validReferenceNumber = 1; - for (const ResultInfo &info : databaseResults()) { - if (info.file.isEmpty()) - continue; - if (validReferenceNumber == 1) - references.append((!chatResponse.endsWith("\n") ? "\n" : QString()) + QStringLiteral("\n---")); - QString reference; - { - QTextStream stream(&reference); - stream << (validReferenceNumber++) << ". 
"; - if (!info.title.isEmpty()) - stream << "\"" << info.title << "\". "; - if (!info.author.isEmpty()) - stream << "By " << info.author << ". "; - if (!info.date.isEmpty()) - stream << "Date: " << info.date << ". "; - stream << "In " << info.file << ". "; - if (info.page != -1) - stream << "Page " << info.page << ". "; - if (info.from != -1) { - stream << "Lines " << info.from; - if (info.to != -1) - stream << "-" << info.to; - stream << ". "; - } - stream << "[Context](context://" << validReferenceNumber - 1 << ")"; - } - references.append(reference); - referencesContext.append(info.text); - } - - const int index = m_chatModel->count() - 1; - m_chatModel->updateReferences(index, references.join("\n"), referencesContext); emit responseChanged(); m_responseInProgress = false; @@ -336,7 +300,7 @@ void Chat::generatedNameChanged(const QString &name) // Only use the first three words maximum and remove newlines and extra spaces m_generatedName = name.simplified(); QStringList words = m_generatedName.split(' ', Qt::SkipEmptyParts); - int wordCount = qMin(3, words.size()); + int wordCount = qMin(7, words.size()); m_name = words.mid(0, wordCount).join(' '); emit nameChanged(); } @@ -378,6 +342,8 @@ void Chat::handleFallbackReasonChanged(const QString &fallbackReason) void Chat::handleDatabaseResultsChanged(const QList &results) { m_databaseResults = results; + const int index = m_chatModel->count() - 1; + m_chatModel->updateSources(index, m_databaseResults); } void Chat::handleModelInfoChanged(const ModelInfo &modelInfo) @@ -389,7 +355,8 @@ void Chat::handleModelInfoChanged(const ModelInfo &modelInfo) emit modelInfoChanged(); } -void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value) { +void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value) +{ m_trySwitchContextInProgress = value; emit trySwitchContextInProgressChanged(); } diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h index 4f1e8c5a..9da04459 100644 --- a/gpt4all-chat/chat.h +++ 
b/gpt4all-chat/chat.h @@ -95,7 +95,7 @@ public: void unloadAndDeleteLater(); void markForDeletion(); - qint64 creationDate() const { return m_creationDate; } + QDateTime creationDate() const { return QDateTime::fromSecsSinceEpoch(m_creationDate); } bool serialize(QDataStream &stream, int version) const; bool deserialize(QDataStream &stream, int version); bool isServer() const { return m_isServer; } diff --git a/gpt4all-chat/chatapi.cpp b/gpt4all-chat/chatapi.cpp index 560eb55f..e9106dfc 100644 --- a/gpt4all-chat/chatapi.cpp +++ b/gpt4all-chat/chatapi.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -19,6 +20,8 @@ #include +using namespace Qt::Literals::StringLiterals; + //#define DEBUG ChatAPI::ChatAPI() @@ -194,7 +197,7 @@ void ChatAPIWorker::request(const QString &apiKey, m_ctx = promptCtx; QUrl apiUrl(m_chat->url()); - const QString authorization = QString("Bearer %1").arg(apiKey).trimmed(); + const QString authorization = u"Bearer %1"_s.arg(apiKey).trimmed(); QNetworkRequest request(apiUrl); request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); request.setRawHeader("Authorization", authorization.toUtf8()); @@ -241,8 +244,8 @@ void ChatAPIWorker::handleReadyRead() if (!ok || code != 200) { m_chat->callResponse( -1, - QString("ERROR: ChatAPIWorker::handleReadyRead got HTTP Error %1 %2: %3") - .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString() + u"ERROR: ChatAPIWorker::handleReadyRead got HTTP Error %1 %2: %3"_s + .arg(code).arg(reply->errorString(), reply->readAll()).toStdString() ); emit finished(); return; @@ -263,7 +266,7 @@ void ChatAPIWorker::handleReadyRead() QJsonParseError err; const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err); if (err.error != QJsonParseError::NoError) { - m_chat->callResponse(-1, QString("ERROR: ChatAPI responded with invalid json \"%1\"") + m_chat->callResponse(-1, u"ERROR: ChatAPI responded with invalid json \"%1\""_s 
.arg(err.errorString()).toStdString()); continue; } diff --git a/gpt4all-chat/chatlistmodel.cpp b/gpt4all-chat/chatlistmodel.cpp index 03d592cf..d5b7070a 100644 --- a/gpt4all-chat/chatlistmodel.cpp +++ b/gpt4all-chat/chatlistmodel.cpp @@ -13,12 +13,13 @@ #include #include #include +#include #include #include #define CHAT_FORMAT_MAGIC 0xF5D553CC -#define CHAT_FORMAT_VERSION 7 +#define CHAT_FORMAT_VERSION 8 class MyChatListModel: public ChatListModel { }; Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance) @@ -64,7 +65,6 @@ ChatSaver::ChatSaver() void ChatListModel::saveChats() { - const QString savePath = MySettings::globalInstance()->modelPath(); QVector toSave; for (Chat *chat : m_chats) { if (chat == m_serverChat) diff --git a/gpt4all-chat/chatlistmodel.h b/gpt4all-chat/chatlistmodel.h index 691b5780..43ccbb41 100644 --- a/gpt4all-chat/chatlistmodel.h +++ b/gpt4all-chat/chatlistmodel.h @@ -56,7 +56,8 @@ public: enum Roles { IdRole = Qt::UserRole + 1, - NameRole + NameRole, + SectionRole }; int rowCount(const QModelIndex &parent = QModelIndex()) const override @@ -76,6 +77,26 @@ public: return item->id(); case NameRole: return item->name(); + case SectionRole: { + if (item == m_serverChat) + return QString(); + const QDate date = QDate::currentDate(); + const QDate itemDate = item->creationDate().date(); + if (date == itemDate) + return tr("TODAY"); + else if (itemDate >= date.addDays(-7)) + return tr("THIS WEEK"); + else if (itemDate >= date.addMonths(-1)) + return tr("THIS MONTH"); + else if (itemDate >= date.addMonths(-6)) + return tr("LAST SIX MONTHS"); + else if (itemDate.year() == date.year()) + return tr("THIS YEAR"); + else if (itemDate.year() == date.year() - 1) + return tr("LAST YEAR"); + else + return QString::number(itemDate.year()); + } } return QVariant(); @@ -86,6 +107,7 @@ public: QHash roles; roles[IdRole] = "id"; roles[NameRole] = "name"; + roles[SectionRole] = "section"; return roles; } diff --git a/gpt4all-chat/chatllm.cpp 
b/gpt4all-chat/chatllm.cpp index 11c51041..581eaab5 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -28,9 +29,12 @@ #include #include #include +#include #include #include +using namespace Qt::Literals::StringLiterals; + //#define DEBUG //#define DEBUG_MODEL_LOADING @@ -180,7 +184,7 @@ bool ChatLLM::loadDefaultModel() { ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo(); if (defaultModel.filename().isEmpty()) { - emit modelLoadingError(QString("Could not find any model to load")); + emit modelLoadingError(u"Could not find any model to load"_qs); return false; } return loadModel(defaultModel); @@ -292,7 +296,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) setModelInfo(modelInfo); Q_ASSERT(!m_modelInfo.filename().isEmpty()); if (m_modelInfo.filename().isEmpty()) - emit modelLoadingError(QString("Modelinfo is left null for %1").arg(modelInfo.filename())); + emit modelLoadingError(u"Modelinfo is left null for %1"_s.arg(modelInfo.filename())); else processSystemPrompt(); return true; @@ -377,9 +381,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) static QSet warned; auto fname = modelInfo.filename(); if (!warned.contains(fname)) { - emit modelLoadingWarning(QString( - "%1 is known to be broken. Please get a replacement via the download dialog." - ).arg(fname)); + emit modelLoadingWarning( + u"%1 is known to be broken. 
Please get a replacement via the download dialog."_s.arg(fname) + ); warned.insert(fname); // don't warn again until restart } } @@ -485,7 +489,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) if (!m_isServer) LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); m_llModelInfo = LLModelInfo(); - emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename())); + emit modelLoadingError(u"Could not load model due to invalid model file for %1"_s.arg(modelInfo.filename())); modelLoadProps.insert("error", "loadmodel_failed"); } else { switch (m_llModelInfo.model->implementation().modelType()[0]) { @@ -497,7 +501,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) if (!m_isServer) LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); m_llModelInfo = LLModelInfo(); - emit modelLoadingError(QString("Could not determine model type for %1").arg(modelInfo.filename())); + emit modelLoadingError(u"Could not determine model type for %1"_s.arg(modelInfo.filename())); } } @@ -507,7 +511,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) if (!m_isServer) LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); m_llModelInfo = LLModelInfo(); - emit modelLoadingError(QString("Error loading %1: %2").arg(modelInfo.filename()).arg(constructError)); + emit modelLoadingError(u"Error loading %1: %2"_s.arg(modelInfo.filename(), constructError)); } } #if defined(DEBUG_MODEL_LOADING) @@ -527,7 +531,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) if (!m_isServer) LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); // release back into the store m_llModelInfo = LLModelInfo(); - emit modelLoadingError(QString("Could not find file for model %1").arg(modelInfo.filename())); + emit modelLoadingError(u"Could not find file for model %1"_s.arg(modelInfo.filename())); } if (m_llModelInfo.model) { @@ -542,7 +546,8 @@ bool ChatLLM::isModelLoaded() const 
return m_llModelInfo.model && m_llModelInfo.model->isModelLoaded(); } -std::string remove_leading_whitespace(const std::string& input) { +std::string remove_leading_whitespace(const std::string& input) +{ auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { return !std::isspace(c); }); @@ -553,7 +558,8 @@ std::string remove_leading_whitespace(const std::string& input) { return std::string(first_non_whitespace, input.end()); } -std::string trim_whitespace(const std::string& input) { +std::string trim_whitespace(const std::string& input) +{ auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { return !std::isspace(c); }); @@ -706,11 +712,15 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString } // Augment the prompt template with the results if any - QList docsContext; - if (!databaseResults.isEmpty()) - docsContext.append("### Context:"); - for (const ResultInfo &info : databaseResults) - docsContext.append(info.text); + QString docsContext; + if (!databaseResults.isEmpty()) { + QStringList results; + for (const ResultInfo &info : databaseResults) + results << u"Collection: %1\nPath: %2\nSnippet: %3"_s.arg(info.collection, info.path, info.text); + + // FIXME(jared): use a Jinja prompt template instead of hardcoded Alpaca-style localdocs template + docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n")); + } int n_threads = MySettings::globalInstance()->threadCount(); @@ -738,7 +748,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString m_timer->start(); if (!docsContext.isEmpty()) { auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response - m_llModelInfo.model->prompt(docsContext.join("\n").toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx); + m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx); m_ctx.n_predict = 
old_n_predict; // now we are ready for a response } m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx); @@ -836,7 +846,7 @@ void ChatLLM::generateName() auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2); auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1); LLModel::PromptContext ctx = m_ctx; - m_llModelInfo.model->prompt("Describe the above conversation in three words or less.", + m_llModelInfo.model->prompt("Describe the above conversation in seven words or less.", promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, ctx); std::string trimmed = trim_whitespace(m_nameResponse); if (trimmed != m_nameResponse) { diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index 67791164..fde97f1c 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -22,6 +22,8 @@ #include #include +using namespace Qt::Literals::StringLiterals; + class QDataStream; enum LLModelType { @@ -68,7 +70,7 @@ private Q_SLOTS: void handleTimeout() { m_elapsed += m_time.restart(); - emit report(QString("%1 tokens/sec").arg(m_tokens / float(m_elapsed / 1000.0f), 0, 'g', 2)); + emit report(u"%1 tokens/sec"_s.arg(m_tokens / float(m_elapsed / 1000.0f), 0, 'g', 2)); } private: diff --git a/gpt4all-chat/chatmodel.h b/gpt4all-chat/chatmodel.h index 2de88744..fa4f5506 100644 --- a/gpt4all-chat/chatmodel.h +++ b/gpt4all-chat/chatmodel.h @@ -1,6 +1,8 @@ #ifndef CHATMODEL_H #define CHATMODEL_H +#include "database.h" + #include #include #include @@ -26,17 +28,18 @@ struct ChatItem Q_PROPERTY(bool stopped MEMBER stopped) Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState) Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState) - Q_PROPERTY(QString references MEMBER references) - Q_PROPERTY(QList referencesContext MEMBER referencesContext) + Q_PROPERTY(QList sources MEMBER sources) + Q_PROPERTY(QList 
consolidatedSources MEMBER consolidatedSources) public: + // TODO: Maybe we should include the model name here as well as timestamp? int id = 0; QString name; QString value; QString prompt; QString newResponse; - QString references; - QList referencesContext; + QList sources; + QList consolidatedSources; bool currentResponse = false; bool stopped = false; bool thumbsUpState = false; @@ -62,8 +65,8 @@ public: StoppedRole, ThumbsUpStateRole, ThumbsDownStateRole, - ReferencesRole, - ReferencesContextRole + SourcesRole, + ConsolidatedSourcesRole }; int rowCount(const QModelIndex &parent = QModelIndex()) const override @@ -97,10 +100,10 @@ public: return item.thumbsUpState; case ThumbsDownStateRole: return item.thumbsDownState; - case ReferencesRole: - return item.references; - case ReferencesContextRole: - return item.referencesContext; + case SourcesRole: + return QVariant::fromValue(item.sources); + case ConsolidatedSourcesRole: + return QVariant::fromValue(item.consolidatedSources); } return QVariant(); @@ -118,8 +121,8 @@ public: roles[StoppedRole] = "stopped"; roles[ThumbsUpStateRole] = "thumbsUpState"; roles[ThumbsDownStateRole] = "thumbsDownState"; - roles[ReferencesRole] = "references"; - roles[ReferencesContextRole] = "referencesContext"; + roles[SourcesRole] = "sources"; + roles[ConsolidatedSourcesRole] = "consolidatedSources"; return roles; } @@ -196,19 +199,28 @@ public: } } - Q_INVOKABLE void updateReferences(int index, const QString &references, const QList &referencesContext) + QList consolidateSources(const QList &sources) { + QMap groupedData; + for (const ResultInfo &info : sources) { + if (groupedData.contains(info.file)) { + groupedData[info.file].text += "\n---\n" + info.text; + } else { + groupedData[info.file] = info; + } + } + QList consolidatedSources = groupedData.values(); + return consolidatedSources; + } + + Q_INVOKABLE void updateSources(int index, const QList &sources) { if (index < 0 || index >= m_chatItems.size()) return; ChatItem &item 
= m_chatItems[index]; - if (item.references != references) { - item.references = references; - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ReferencesRole}); - } - if (item.referencesContext != referencesContext) { - item.referencesContext = referencesContext; - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ReferencesContextRole}); - } + item.sources = sources; + item.consolidatedSources = consolidateSources(sources); + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole}); + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole}); } Q_INVOKABLE void updateThumbsUpState(int index, bool b) @@ -259,9 +271,56 @@ public: stream << c.stopped; stream << c.thumbsUpState; stream << c.thumbsDownState; - if (version > 2) { - stream << c.references; - stream << c.referencesContext; + if (version > 7) { + stream << c.sources.size(); + for (const ResultInfo &info : c.sources) { + Q_ASSERT(!info.file.isEmpty()); + stream << info.collection; + stream << info.path; + stream << info.file; + stream << info.title; + stream << info.author; + stream << info.date; + stream << info.text; + stream << info.page; + stream << info.from; + stream << info.to; + } + } else if (version > 2) { + QList references; + QList referencesContext; + int validReferenceNumber = 1; + for (const ResultInfo &info : c.sources) { + if (info.file.isEmpty()) + continue; + + QString reference; + { + QTextStream stream(&reference); + stream << (validReferenceNumber++) << ". "; + if (!info.title.isEmpty()) + stream << "\"" << info.title << "\". "; + if (!info.author.isEmpty()) + stream << "By " << info.author << ". "; + if (!info.date.isEmpty()) + stream << "Date: " << info.date << ". "; + stream << "In " << info.file << ". "; + if (info.page != -1) + stream << "Page " << info.page << ". "; + if (info.from != -1) { + stream << "Lines " << info.from; + if (info.to != -1) + stream << "-" << info.to; + stream << ". 
"; + } + stream << "[Context](context://" << validReferenceNumber - 1 << ")"; + } + references.append(reference); + referencesContext.append(info.text); + } + + stream << references.join("\n"); + stream << referencesContext; } } return stream.status() == QDataStream::Ok; @@ -282,9 +341,109 @@ public: stream >> c.stopped; stream >> c.thumbsUpState; stream >> c.thumbsDownState; - if (version > 2) { - stream >> c.references; - stream >> c.referencesContext; + if (version > 7) { + qsizetype count; + stream >> count; + QList sources; + for (int i = 0; i < count; ++i) { + ResultInfo info; + stream >> info.collection; + stream >> info.path; + stream >> info.file; + stream >> info.title; + stream >> info.author; + stream >> info.date; + stream >> info.text; + stream >> info.page; + stream >> info.from; + stream >> info.to; + sources.append(info); + } + c.sources = sources; + c.consolidatedSources = consolidateSources(sources); + }else if (version > 2) { + QString references; + QList referencesContext; + stream >> references; + stream >> referencesContext; + + if (!references.isEmpty()) { + QList sources; + QList referenceList = references.split("\n"); + + // Ignore empty lines and those that begin with "---" which is no longer used + for (auto it = referenceList.begin(); it != referenceList.end();) { + if (it->trimmed().isEmpty() || it->trimmed().startsWith("---")) + it = referenceList.erase(it); + else + ++it; + } + + Q_ASSERT(referenceList.size() == referencesContext.size()); + for (int j = 0; j < referenceList.size(); ++j) { + QString reference = referenceList[j]; + QString context = referencesContext[j]; + ResultInfo info; + QTextStream refStream(&reference); + QString dummy; + int validReferenceNumber; + refStream >> validReferenceNumber >> dummy; + // Extract title (between quotes) + if (reference.contains("\"")) { + int startIndex = reference.indexOf('"') + 1; + int endIndex = reference.indexOf('"', startIndex); + info.title = reference.mid(startIndex, endIndex - 
startIndex); + } + + // Extract author (after "By " and before the next period) + if (reference.contains("By ")) { + int startIndex = reference.indexOf("By ") + 3; + int endIndex = reference.indexOf('.', startIndex); + info.author = reference.mid(startIndex, endIndex - startIndex).trimmed(); + } + + // Extract date (after "Date: " and before the next period) + if (reference.contains("Date: ")) { + int startIndex = reference.indexOf("Date: ") + 6; + int endIndex = reference.indexOf('.', startIndex); + info.date = reference.mid(startIndex, endIndex - startIndex).trimmed(); + } + + // Extract file name (after "In " and before the "[Context]") + if (reference.contains("In ") && reference.contains(". [Context]")) { + int startIndex = reference.indexOf("In ") + 3; + int endIndex = reference.indexOf(". [Context]", startIndex); + info.file = reference.mid(startIndex, endIndex - startIndex).trimmed(); + } + + // Extract page number (after "Page " and before the next space) + if (reference.contains("Page ")) { + int startIndex = reference.indexOf("Page ") + 5; + int endIndex = reference.indexOf(' ', startIndex); + if (endIndex == -1) endIndex = reference.length(); + info.page = reference.mid(startIndex, endIndex - startIndex).toInt(); + } + + // Extract lines (after "Lines " and before the next space or hyphen) + if (reference.contains("Lines ")) { + int startIndex = reference.indexOf("Lines ") + 6; + int endIndex = reference.indexOf(' ', startIndex); + if (endIndex == -1) endIndex = reference.length(); + int hyphenIndex = reference.indexOf('-', startIndex); + if (hyphenIndex != -1 && hyphenIndex < endIndex) { + info.from = reference.mid(startIndex, hyphenIndex - startIndex).toInt(); + info.to = reference.mid(hyphenIndex + 1, endIndex - hyphenIndex - 1).toInt(); + } else { + info.from = reference.mid(startIndex, endIndex - startIndex).toInt(); + } + } + info.text = context; + sources.append(info); + } + + c.sources = sources; + c.consolidatedSources = 
consolidateSources(sources); + } } beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); m_chatItems.append(c); diff --git a/gpt4all-chat/database.cpp b/gpt4all-chat/database.cpp index b92f0667..8deb1175 100644 --- a/gpt4all-chat/database.cpp +++ b/gpt4all-chat/database.cpp @@ -1,262 +1,365 @@ #include "database.h" -#include "embeddings.h" -#include "modellist.h" #include "mysettings.h" -#include "network.h" -#include +#include + #include #include #include +#include #include #include #include #include #include #include -#include #include #include #include #include #include #include +#include #include +#include #include +#include #include #include +using namespace Qt::Literals::StringLiterals; +namespace us = unum::usearch; + //#define DEBUG //#define DEBUG_EXAMPLE -#define LOCALDOCS_VERSION 1 +namespace { -const auto INSERT_CHUNK_SQL = QLatin1String(R"( - insert into chunks(document_id, chunk_text, - file, title, author, subject, keywords, page, line_from, line_to) - values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?); - )"); +/* QFile that checks input for binary data. If seen, it fails the read and returns true + * for binarySeen(). 
*/ +class BinaryDetectingFile: public QFile { +public: + using QFile::QFile; -const auto INSERT_CHUNK_FTS_SQL = QLatin1String(R"( - insert into chunks_fts(document_id, chunk_id, chunk_text, - file, title, author, subject, keywords, page, line_from, line_to) - values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); - )"); + bool binarySeen() const { return m_binarySeen; } -const auto DELETE_CHUNKS_SQL = QLatin1String(R"( - delete from chunks WHERE document_id = ?; - )"); - -const auto DELETE_CHUNKS_FTS_SQL = QLatin1String(R"( - delete from chunks_fts WHERE document_id = ?; - )"); - -const auto CHUNKS_SQL = QLatin1String(R"( - create table chunks(document_id integer, chunk_id integer primary key autoincrement, chunk_text varchar, - file varchar, title varchar, author varchar, subject varchar, keywords varchar, - page integer, line_from integer, line_to integer); - )"); - -const auto FTS_CHUNKS_SQL = QLatin1String(R"( - create virtual table chunks_fts using fts5(document_id unindexed, chunk_id unindexed, chunk_text, - file, title, author, subject, keywords, page, line_from, line_to, tokenize="trigram"); - )"); - -const auto SELECT_CHUNKS_BY_DOCUMENT_SQL = QLatin1String(R"( - select chunk_id from chunks WHERE document_id = ?; - )"); - -const auto SELECT_CHUNKS_SQL = QLatin1String(R"( - select chunks.chunk_id, documents.document_time, - chunks.chunk_text, chunks.file, chunks.title, chunks.author, chunks.page, - chunks.line_from, chunks.line_to - from chunks - join documents ON chunks.document_id = documents.id - join folders ON documents.folder_id = folders.id - join collections ON folders.id = collections.folder_id - where chunks.chunk_id in (%1) and collections.collection_name in (%2); -)"); - -const auto SELECT_NGRAM_SQL = QLatin1String(R"( - select chunks_fts.chunk_id, documents.document_time, - chunks_fts.chunk_text, chunks_fts.file, chunks_fts.title, chunks_fts.author, chunks_fts.page, - chunks_fts.line_from, chunks_fts.line_to - from chunks_fts - join documents ON 
chunks_fts.document_id = documents.id - join folders ON documents.folder_id = folders.id - join collections ON folders.id = collections.folder_id - where chunks_fts match ? and collections.collection_name in (%1) - order by bm25(chunks_fts) - limit %2; - )"); - -bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, - const QString &file, const QString &title, const QString &author, const QString &subject, const QString &keywords, - int page, int from, int to, int *chunk_id) -{ - { - if (!q.prepare(INSERT_CHUNK_SQL)) - return false; - q.addBindValue(document_id); - q.addBindValue(chunk_text); - q.addBindValue(file); - q.addBindValue(title); - q.addBindValue(author); - q.addBindValue(subject); - q.addBindValue(keywords); - q.addBindValue(page); - q.addBindValue(from); - q.addBindValue(to); - if (!q.exec()) - return false; +protected: + qint64 readData(char *data, qint64 maxSize) override { + qint64 res = QFile::readData(data, maxSize); + return checkData(data, res); } - if (!q.exec("select last_insert_rowid();")) + + qint64 readLineData(char *data, qint64 maxSize) override { + qint64 res = QFile::readLineData(data, maxSize); + return checkData(data, res); + } + +private: + qint64 checkData(const char *data, qint64 size) { + Q_ASSERT(!isTextModeEnabled()); // We need raw bytes from the underlying QFile + if (size != -1 && !m_binarySeen) { + for (qint64 i = 0; i < size; i++) { + /* Control characters we should never see in plain text: + * 0x00 NUL - 0x06 ACK + * 0x0E SO - 0x1A SUB + * 0x1C FS - 0x1F US */ + auto c = static_cast(data[i]); + if (c < 0x07 || (c >= 0x0E && c < 0x1B) || (c >= 0x1C && c < 0x20)) { + m_binarySeen = true; + break; + } + } + } + return m_binarySeen ? 
-1 : size; + } + + bool m_binarySeen = false; +}; + +} // namespace + +static int s_batchSize = 100; + +static const QString INIT_DB_SQL[] = { + // automatically free unused disk space + u"pragma auto_vacuum = FULL;"_s, + // create tables + uR"( + create table chunks( + id integer primary key autoincrement, + document_id integer not null, + chunk_text text not null, + file text not null, + title text, + author text, + subject text, + keywords text, + page integer, + line_from integer, + line_to integer, + words integer default 0 not null, + tokens integer default 0 not null, + foreign key(document_id) references documents(id) + ); + )"_s, uR"( + create table collections( + id integer primary key, + name text unique not null, + start_update_time integer, + last_update_time integer, + embedding_model text + ); + )"_s, uR"( + create table folders( + id integer primary key autoincrement, + path text unique not null + ); + )"_s, uR"( + create table collection_items( + collection_id integer not null, + folder_id integer not null, + foreign key(collection_id) references collections(id) + foreign key(folder_id) references folders(id), + unique(collection_id, folder_id) + ); + )"_s, uR"( + create table documents( + id integer primary key, + folder_id integer not null, + document_time integer not null, + document_path text unique not null, + foreign key(folder_id) references folders(id) + ); + )"_s, uR"( + create table embeddings( + model text not null, + folder_id integer not null, + chunk_id integer not null, + embedding blob not null, + primary key(model, folder_id, chunk_id), + foreign key(folder_id) references folders(id), + foreign key(chunk_id) references chunks(id), + unique(model, chunk_id) + ); + )"_s, +}; + +static const QString INSERT_CHUNK_SQL = uR"( + insert into chunks(document_id, chunk_text, + file, title, author, subject, keywords, page, line_from, line_to, words) + values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ returning id; + )"_s; + +static const QString DELETE_CHUNKS_SQL[] = { + uR"( + delete from embeddings + where chunk_id in ( + select id from chunks where document_id = ? + ); + )"_s, uR"( + delete from chunks where document_id = ?; + )"_s, +}; + +static const QString SELECT_CHUNKS_BY_DOCUMENT_SQL = uR"( + select id from chunks WHERE document_id = ?; + )"_s; + +static const QString SELECT_CHUNKS_SQL = uR"( + select c.id, d.document_time, d.document_path, c.chunk_text, c.file, c.title, c.author, c.page, c.line_from, c.line_to, co.name + from chunks c + join documents d on d.id = c.document_id + join folders f on f.id = d.folder_id + join collection_items ci on ci.folder_id = f.id + join collections co on co.id = ci.collection_id + where c.id in (%1); +)"_s; + +static const QString SELECT_UNCOMPLETED_CHUNKS_SQL = uR"( + select co.name, co.embedding_model, c.id, d.folder_id, c.chunk_text + from chunks c + join documents d on d.id = c.document_id + join folders f on f.id = d.folder_id + join collection_items ci on ci.folder_id = f.id + join collections co on co.id = ci.collection_id and co.embedding_model is not null + where not exists( + select 1 + from embeddings e + where e.chunk_id = c.id and e.model = co.embedding_model + ); + )"_s; + +static const QString SELECT_COUNT_CHUNKS_SQL = uR"( + select count(c.id) + from chunks c + join documents d on d.id = c.document_id + where d.folder_id = ?; + )"_s; + +static bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, const QString &file, + const QString &title, const QString &author, const QString &subject, const QString &keywords, + int page, int from, int to, int words, int *chunk_id) +{ + if (!q.prepare(INSERT_CHUNK_SQL)) return false; - if (!q.next()) + q.addBindValue(document_id); + q.addBindValue(chunk_text); + q.addBindValue(file); + q.addBindValue(title); + q.addBindValue(author); + q.addBindValue(subject); + q.addBindValue(keywords); + q.addBindValue(page); + q.addBindValue(from); + 
q.addBindValue(to); + q.addBindValue(words); + if (!q.exec() || !q.next()) return false; *chunk_id = q.value(0).toInt(); - { - if (!q.prepare(INSERT_CHUNK_FTS_SQL)) + return true; +} + +static bool removeChunksByDocumentId(QSqlQuery &q, int document_id) +{ + for (const auto &cmd: DELETE_CHUNKS_SQL) { + if (!q.prepare(cmd)) return false; q.addBindValue(document_id); - q.addBindValue(*chunk_id); - q.addBindValue(chunk_text); - q.addBindValue(file); - q.addBindValue(title); - q.addBindValue(author); - q.addBindValue(subject); - q.addBindValue(keywords); - q.addBindValue(page); - q.addBindValue(from); - q.addBindValue(to); if (!q.exec()) return false; } return true; } -bool removeChunksByDocumentId(QSqlQuery &q, int document_id) +#define NAMED_PAIR(name, typea, a, typeb, b) \ + struct name { typea a; typeb b; }; \ + static bool operator==(const name &x, const name &y) { return x.a == y.a && x.b == y.b; } \ + static size_t qHash(const name &x, size_t seed) { return qHashMulti(seed, x.a, x.b); } + +// struct compared by embedding key, can be extended with additional unique data +NAMED_PAIR(EmbeddingKey, QString, embedding_model, int, chunk_id) + +namespace { + struct IncompleteChunk: EmbeddingKey { int folder_id; QString text; }; +} // namespace + +static bool selectAllUncompletedChunks(QSqlQuery &q, QHash &chunks) { - { - if (!q.prepare(DELETE_CHUNKS_SQL)) - return false; - q.addBindValue(document_id); - if (!q.exec()) - return false; + if (!q.exec(SELECT_UNCOMPLETED_CHUNKS_SQL)) + return false; + while (q.next()) { + QString collection = q.value(0).toString(); + IncompleteChunk ic { + /*embedding_model =*/ q.value(1).toString(), + /*chunk_id =*/ q.value(2).toInt(), + /*folder_id =*/ q.value(3).toInt(), + /*text =*/ q.value(4).toString(), + }; + chunks[ic] << collection; } - - { - if (!q.prepare(DELETE_CHUNKS_FTS_SQL)) - return false; - q.addBindValue(document_id); - if (!q.exec()) - return false; - } - return true; } -QStringList generateGrams(const QString &input, int 
N) +static bool selectCountChunks(QSqlQuery &q, int folder_id, int &count) { - // Remove common English punctuation using QRegularExpression - static QRegularExpression punctuation(R"([.,;:!?'"()\-])"); - QString cleanedInput = input; - cleanedInput = cleanedInput.remove(punctuation); - - // Split the cleaned input into words using whitespace - static QRegularExpression spaces("\\s+"); - QStringList words = cleanedInput.split(spaces, Qt::SkipEmptyParts); - N = qMin(words.size(), N); - - // Generate all possible N-grams - QStringList ngrams; - for (int i = 0; i < words.size() - (N - 1); ++i) { - QStringList currentNgram; - for (int j = 0; j < N; ++j) { - currentNgram.append("\"" + words[i + j] + "\""); - } - ngrams.append("NEAR(" + currentNgram.join(" ") + ", " + QString::number(N) + ")"); + if (!q.prepare(SELECT_COUNT_CHUNKS_SQL)) + return false; + q.addBindValue(folder_id); + if (!q.exec()) + return false; + if (!q.next()) { + count = 0; + return false; } - return ngrams; + count = q.value(0).toInt(); + return true; } -bool selectChunk(QSqlQuery &q, const QList &collection_names, const std::vector &chunk_ids, int retrievalSize) +static bool selectChunk(QSqlQuery &q, const QList &chunk_ids, int retrievalSize) { QString chunk_ids_str = QString::number(chunk_ids[0]); for (size_t i = 1; i < chunk_ids.size(); ++i) chunk_ids_str += "," + QString::number(chunk_ids[i]); - const QString collection_names_str = collection_names.join("', '"); - const QString formatted_query = SELECT_CHUNKS_SQL.arg(chunk_ids_str).arg("'" + collection_names_str + "'"); + const QString formatted_query = SELECT_CHUNKS_SQL.arg(chunk_ids_str); if (!q.prepare(formatted_query)) return false; return q.exec(); } -bool selectChunk(QSqlQuery &q, const QList &collection_names, const QString &chunk_text, int retrievalSize) -{ - static QRegularExpression spaces("\\s+"); - const int N_WORDS = chunk_text.split(spaces).size(); - for (int N = N_WORDS; N > 2; N--) { - // first try trigrams - QList text = 
generateGrams(chunk_text, N); - QString orText = text.join(" OR "); - const QString collection_names_str = collection_names.join("', '"); - const QString formatted_query = SELECT_NGRAM_SQL.arg("'" + collection_names_str + "'").arg(QString::number(retrievalSize)); - if (!q.prepare(formatted_query)) - return false; - q.addBindValue(orText); - bool success = q.exec(); - if (!success) return false; - if (q.next()) { -#if defined(DEBUG) - qDebug() << "hit on" << N << "before" << chunk_text << "after" << orText; -#endif - q.previous(); - return true; - } - } - return true; -} +static const QString INSERT_COLLECTION_SQL = uR"( + insert into collections(name, start_update_time, last_update_time, embedding_model) + values(?, ?, ?, ?) + returning id; + )"_s; -const auto INSERT_COLLECTION_SQL = QLatin1String(R"( - insert into collections(collection_name, folder_id) values(?, ?); - )"); +static const QString DELETE_COLLECTION_SQL = uR"( + delete from collections where name = ? and folder_id = ?; + )"_s; -const auto DELETE_COLLECTION_SQL = QLatin1String(R"( - delete from collections where collection_name = ? 
and folder_id = ?; - )"); +static const QString SELECT_FOLDERS_FROM_COLLECTIONS_SQL = uR"( + select f.id, f.path + from collections c + join collection_items ci on ci.collection_id = c.id + join folders f on ci.folder_id = f.id + where c.name = ?; + )"_s; -const auto COLLECTIONS_SQL = QLatin1String(R"( - create table collections(collection_name varchar, folder_id integer, unique(collection_name, folder_id)); - )"); - -const auto SELECT_FOLDERS_FROM_COLLECTIONS_SQL = QLatin1String(R"( - select folder_id from collections where collection_name = ?; - )"); - -const auto SELECT_COLLECTIONS_FROM_FOLDER_SQL = QLatin1String(R"( - select collection_name from collections where folder_id = ?; - )"); - -const auto SELECT_COLLECTIONS_SQL = QLatin1String(R"( +static const QString SELECT_COLLECTIONS_SQL_V1 = uR"( select c.collection_name, f.folder_path, f.id from collections c join folders f on c.folder_id = f.id order by c.collection_name asc, f.folder_path asc; - )"); + )"_s; -bool addCollection(QSqlQuery &q, const QString &collection_name, int folder_id) +static const QString SELECT_COLLECTIONS_SQL_V2 = uR"( + select c.id, c.name, f.path, f.id, c.start_update_time, c.last_update_time, c.embedding_model + from collections c + join collection_items ci on ci.collection_id = c.id + join folders f on ci.folder_id = f.id + order by c.name asc, f.path asc; + )"_s; + +static const QString SELECT_COLLECTION_BY_NAME_SQL = uR"( + select id, name, start_update_time, last_update_time, embedding_model + from collections c + where name = ?; + )"_s; + +static const QString SET_COLLECTION_EMBEDDING_MODEL_SQL = uR"( + update collections + set embedding_model = ? + where name = ?; + )"_s; + +static const QString UPDATE_START_UPDATE_TIME_SQL = uR"( + update collections set start_update_time = ? where id = ?; +)"_s; + +static const QString UPDATE_LAST_UPDATE_TIME_SQL = uR"( + update collections set last_update_time = ? 
where id = ?; +)"_s; + +static bool addCollection(QSqlQuery &q, const QString &collection_name, const QDateTime &start_update, + const QDateTime &last_update, const QString &embedding_model, CollectionItem &item) { if (!q.prepare(INSERT_COLLECTION_SQL)) return false; q.addBindValue(collection_name); - q.addBindValue(folder_id); - return q.exec(); + q.addBindValue(start_update); + q.addBindValue(last_update); + q.addBindValue(embedding_model); + if (!q.exec() || !q.next()) + return false; + item.collection_id = q.value(0).toInt(); + item.collection = collection_name; + item.embeddingModel = embedding_model; + return true; } -bool removeCollection(QSqlQuery &q, const QString &collection_name, int folder_id) +static bool removeCollection(QSqlQuery &q, const QString &collection_name, int folder_id) { if (!q.prepare(DELETE_COLLECTION_SQL)) return false; @@ -265,70 +368,171 @@ bool removeCollection(QSqlQuery &q, const QString &collection_name, int folder_i return q.exec(); } -bool selectFoldersFromCollection(QSqlQuery &q, const QString &collection_name, QList *folderIds) { +static bool selectFoldersFromCollection(QSqlQuery &q, const QString &collection_name, QList> *folders) +{ if (!q.prepare(SELECT_FOLDERS_FROM_COLLECTIONS_SQL)) return false; q.addBindValue(collection_name); if (!q.exec()) return false; while (q.next()) - folderIds->append(q.value(0).toInt()); + folders->append({q.value(0).toInt(), q.value(1).toString()}); return true; } -bool selectCollectionsFromFolder(QSqlQuery &q, int folder_id, QList *collections) { - if (!q.prepare(SELECT_COLLECTIONS_FROM_FOLDER_SQL)) - return false; - q.addBindValue(folder_id); - if (!q.exec()) - return false; - while (q.next()) - collections->append(q.value(0).toString()); - return true; -} - -bool selectAllFromCollections(QSqlQuery &q, QList *collections) { - if (!q.prepare(SELECT_COLLECTIONS_SQL)) - return false; - if (!q.exec()) - return false; +static QList sqlExtractCollections(QSqlQuery &q, bool with_folder = false, int 
version = LOCALDOCS_VERSION) +{ + QList collections; while (q.next()) { CollectionItem i; - i.collection = q.value(0).toString(); - i.folder_path = q.value(1).toString(); - i.folder_id = q.value(2).toInt(); + int idx = 0; + if (version >= 2) + i.collection_id = q.value(idx++).toInt(); + i.collection = q.value(idx++).toString(); + if (with_folder) { + i.folder_path = q.value(idx++).toString(); + i.folder_id = q.value(idx++).toInt(); + } i.indexing = false; i.installed = true; - collections->append(i); + + if (version >= 2) { + bool ok; + const qint64 start_update = q.value(idx++).toLongLong(&ok); + if (ok) i.startUpdate = QDateTime::fromMSecsSinceEpoch(start_update); + const qint64 last_update = q.value(idx++).toLongLong(&ok); + if (ok) i.lastUpdate = QDateTime::fromMSecsSinceEpoch(last_update); + + i.embeddingModel = q.value(idx++).toString(); + } + if (i.embeddingModel.isNull()) { + // unknown embedding model -> need to re-index + i.forceIndexing = true; + } + + collections << i; } + return collections; +} + +static bool selectAllFromCollections(QSqlQuery &q, QList *collections, int version = LOCALDOCS_VERSION) +{ + + switch (version) { + case 1: + if (!q.prepare(SELECT_COLLECTIONS_SQL_V1)) + return false; + break; + case 2: + if (!q.prepare(SELECT_COLLECTIONS_SQL_V2)) + return false; + break; + default: + Q_UNREACHABLE(); + return false; + } + + if (!q.exec()) + return false; + *collections = sqlExtractCollections(q, true, version); return true; } -const auto INSERT_FOLDERS_SQL = QLatin1String(R"( - insert into folders(folder_path) values(?); - )"); +static bool selectCollectionByName(QSqlQuery &q, const QString &name, std::optional &collection) +{ + if (!q.prepare(SELECT_COLLECTION_BY_NAME_SQL)) + return false; + q.addBindValue(name); + if (!q.exec()) + return false; + QList collections = sqlExtractCollections(q); + Q_ASSERT(collections.count() <= 1); + collection.reset(); + if (!collections.isEmpty()) + collection = collections.first(); + return true; +} -const 
auto DELETE_FOLDERS_SQL = QLatin1String(R"( +static bool setCollectionEmbeddingModel(QSqlQuery &q, const QString &collection_name, const QString &embedding_model) +{ + if (!q.prepare(SET_COLLECTION_EMBEDDING_MODEL_SQL)) + return false; + q.addBindValue(embedding_model); + q.addBindValue(collection_name); + return q.exec(); +} + +static bool updateStartUpdateTime(QSqlQuery &q, int id, qint64 update_time) +{ + if (!q.prepare(UPDATE_START_UPDATE_TIME_SQL)) + return false; + q.addBindValue(update_time); + q.addBindValue(id); + return q.exec(); +} + +static bool updateLastUpdateTime(QSqlQuery &q, int id, qint64 update_time) +{ + if (!q.prepare(UPDATE_LAST_UPDATE_TIME_SQL)) + return false; + q.addBindValue(update_time); + q.addBindValue(id); + return q.exec(); +} + +static const QString INSERT_FOLDERS_SQL = uR"( + insert into folders(path) values(?); + )"_s; + +static const QString DELETE_FOLDERS_SQL = uR"( delete from folders where id = ?; - )"); + )"_s; -const auto SELECT_FOLDERS_FROM_PATH_SQL = QLatin1String(R"( - select id from folders where folder_path = ?; - )"); +static const QString SELECT_FOLDERS_FROM_PATH_SQL = uR"( + select id from folders where path = ?; + )"_s; -const auto SELECT_FOLDERS_FROM_ID_SQL = QLatin1String(R"( - select folder_path from folders where id = ?; - )"); +static const QString GET_FOLDER_EMBEDDING_MODEL_SQL = uR"( + select co.embedding_model + from collections co + join collection_items ci on ci.collection_id = co.id + where ci.folder_id = ?; + )"_s; -const auto SELECT_ALL_FOLDERPATHS_SQL = QLatin1String(R"( - select folder_path from folders; - )"); +static const QString SELECT_ALL_FOLDERPATHS_SQL = uR"( + select path from folders; + )"_s; -const auto FOLDERS_SQL = QLatin1String(R"( - create table folders(id integer primary key, folder_path varchar unique); - )"); +static const QString FOLDER_REMOVE_ALL_DOCS_SQL[] = { + uR"( + delete from embeddings + where chunk_id in ( + select c.id + from chunks c + join documents d on d.id = 
c.document_id + join folders f on f.id = d.folder_id + where f.path = ? + ); + )"_s, uR"( + delete from chunks + where document_id in ( + select d.id + from documents d + join folders f on f.id = d.folder_id + where f.path = ? + ); + )"_s, uR"( + delete from documents + where id in ( + select d.id + from documents d + join folders f on f.id = d.folder_id + where f.path = ? + ); + )"_s, +}; -bool addFolderToDB(QSqlQuery &q, const QString &folder_path, int *folder_id) +static bool addFolderToDB(QSqlQuery &q, const QString &folder_path, int *folder_id) { if (!q.prepare(INSERT_FOLDERS_SQL)) return false; @@ -339,14 +543,16 @@ bool addFolderToDB(QSqlQuery &q, const QString &folder_path, int *folder_id) return true; } -bool removeFolderFromDB(QSqlQuery &q, int folder_id) { +static bool removeFolderFromDB(QSqlQuery &q, int folder_id) +{ if (!q.prepare(DELETE_FOLDERS_SQL)) return false; q.addBindValue(folder_id); return q.exec(); } -bool selectFolder(QSqlQuery &q, const QString &folder_path, int *id) { +static bool selectFolder(QSqlQuery &q, const QString &folder_path, int *id) +{ if (!q.prepare(SELECT_FOLDERS_FROM_PATH_SQL)) return false; q.addBindValue(folder_path); @@ -358,19 +564,21 @@ bool selectFolder(QSqlQuery &q, const QString &folder_path, int *id) { return true; } -bool selectFolder(QSqlQuery &q, int id, QString *folder_path) { - if (!q.prepare(SELECT_FOLDERS_FROM_ID_SQL)) +static bool sqlGetFolderEmbeddingModel(QSqlQuery &q, int id, QString &embedding_model) +{ + if (!q.prepare(GET_FOLDER_EMBEDDING_MODEL_SQL)) return false; q.addBindValue(id); - if (!q.exec()) + if (!q.exec() || !q.next()) return false; + // FIXME(jared): there may be more than one if a folder is shared between collections Q_ASSERT(q.size() < 2); - if (q.next()) - *folder_path = q.value(0).toString(); + embedding_model = q.value(0).toString(); return true; } -bool selectAllFolderPaths(QSqlQuery &q, QList *folder_paths) { +static bool selectAllFolderPaths(QSqlQuery &q, QList *folder_paths) +{ if 
(!q.prepare(SELECT_ALL_FOLDERPATHS_SQL)) return false; if (!q.exec()) @@ -380,35 +588,96 @@ bool selectAllFolderPaths(QSqlQuery &q, QList *folder_paths) { return true; } -const auto INSERT_DOCUMENTS_SQL = QLatin1String(R"( +static bool sqlRemoveDocsByFolderPath(QSqlQuery &q, const QString &path) +{ + for (const auto &cmd: FOLDER_REMOVE_ALL_DOCS_SQL) { + if (!q.prepare(cmd)) + return false; + q.addBindValue(path); + if (!q.exec()) + return false; + } + return true; +} + +static const QString INSERT_COLLECTION_ITEM_SQL = uR"( + insert into collection_items(collection_id, folder_id) + values(?, ?) + on conflict do nothing; +)"_s; + +static const QString DELETE_COLLECTION_FOLDER_SQL = uR"( + delete from collection_items + where collection_id = (select id from collections where name = :name) and folder_id = :folder_id + returning (select count(*) from collection_items where folder_id = :folder_id); +)"_s; + +static const QString PRUNE_COLLECTIONS_SQL = uR"( + delete from collections + where id not in (select collection_id from collection_items); +)"_s; + +// 0 = already exists, 1 = added, -1 = error +static int addCollectionItem(QSqlQuery &q, int collection_id, int folder_id) +{ + if (!q.prepare(INSERT_COLLECTION_ITEM_SQL)) + return -1; + q.addBindValue(collection_id); + q.addBindValue(folder_id); + if (q.exec()) + return q.numRowsAffected(); + return -1; +} + +// returns the number of remaining references to the folder, or -1 on error +static int removeCollectionFolder(QSqlQuery &q, const QString &collection_name, int folder_id) +{ + if (!q.prepare(DELETE_COLLECTION_FOLDER_SQL)) + return -1; + q.bindValue(":name", collection_name); + q.bindValue(":folder_id", folder_id); + if (!q.exec() || !q.next()) + return -1; + return q.value(0).toInt(); +} + +static bool sqlPruneCollections(QSqlQuery &q) +{ + return q.exec(PRUNE_COLLECTIONS_SQL); +} + +static const QString INSERT_DOCUMENTS_SQL = uR"( insert into documents(folder_id, document_time, document_path) values(?, ?, ?); - 
)"); + )"_s; -const auto UPDATE_DOCUMENT_TIME_SQL = QLatin1String(R"( +static const QString UPDATE_DOCUMENT_TIME_SQL = uR"( update documents set document_time = ? where id = ?; - )"); + )"_s; -const auto DELETE_DOCUMENTS_SQL = QLatin1String(R"( +static const QString DELETE_DOCUMENTS_SQL = uR"( delete from documents where id = ?; - )"); + )"_s; -const auto DOCUMENTS_SQL = QLatin1String(R"( - create table documents(id integer primary key, folder_id integer, document_time integer, document_path varchar unique); - )"); - -const auto SELECT_DOCUMENT_SQL = QLatin1String(R"( +static const QString SELECT_DOCUMENT_SQL = uR"( select id, document_time from documents where document_path = ?; - )"); + )"_s; -const auto SELECT_DOCUMENTS_SQL = QLatin1String(R"( +static const QString SELECT_DOCUMENTS_SQL = uR"( select id from documents where folder_id = ?; - )"); + )"_s; -const auto SELECT_ALL_DOCUMENTS_SQL = QLatin1String(R"( +static const QString SELECT_ALL_DOCUMENTS_SQL = uR"( select id, document_path from documents; - )"); + )"_s; -bool addDocument(QSqlQuery &q, int folder_id, qint64 document_time, const QString &document_path, int *document_id) +static const QString SELECT_COUNT_STATISTICS_SQL = uR"( + select count(distinct d.id), sum(c.words), sum(c.tokens) + from documents d + left join chunks c on d.id = c.document_id + where d.folder_id = ?; + )"_s; + +static bool addDocument(QSqlQuery &q, int folder_id, qint64 document_time, const QString &document_path, int *document_id) { if (!q.prepare(INSERT_DOCUMENTS_SQL)) return false; @@ -421,14 +690,15 @@ bool addDocument(QSqlQuery &q, int folder_id, qint64 document_time, const QStrin return true; } -bool removeDocument(QSqlQuery &q, int document_id) { +static bool removeDocument(QSqlQuery &q, int document_id) +{ if (!q.prepare(DELETE_DOCUMENTS_SQL)) return false; q.addBindValue(document_id); return q.exec(); } -bool updateDocument(QSqlQuery &q, int id, qint64 document_time) +static bool updateDocument(QSqlQuery &q, int id, 
qint64 document_time) { if (!q.prepare(UPDATE_DOCUMENT_TIME_SQL)) return false; @@ -437,7 +707,8 @@ bool updateDocument(QSqlQuery &q, int id, qint64 document_time) return q.exec(); } -bool selectDocument(QSqlQuery &q, const QString &document_path, int *id, qint64 *document_time) { +static bool selectDocument(QSqlQuery &q, const QString &document_path, int *id, qint64 *document_time) +{ if (!q.prepare(SELECT_DOCUMENT_SQL)) return false; q.addBindValue(document_path); @@ -451,7 +722,8 @@ bool selectDocument(QSqlQuery &q, const QString &document_path, int *id, qint64 return true; } -bool selectDocuments(QSqlQuery &q, int folder_id, QList *documentIds) { +static bool selectDocuments(QSqlQuery &q, int folder_id, QList *documentIds) +{ if (!q.prepare(SELECT_DOCUMENTS_SQL)) return false; q.addBindValue(folder_id); @@ -462,107 +734,221 @@ bool selectDocuments(QSqlQuery &q, int folder_id, QList *documentIds) { return true; } -QSqlError initDb() +static bool selectCountStatistics(QSqlQuery &q, int folder_id, int *total_docs, int *total_words, int *total_tokens) { - QString dbPath = MySettings::globalInstance()->modelPath() - + QString("localdocs_v%1.db").arg(LOCALDOCS_VERSION); - QSqlDatabase db = QSqlDatabase::addDatabase("QSQLITE"); - db.setDatabaseName(dbPath); - - if (!db.open()) - return db.lastError(); - - QStringList tables = db.tables(); - if (tables.contains("chunks", Qt::CaseInsensitive)) - return QSqlError(); - - QSqlQuery q; - if (!q.exec(CHUNKS_SQL)) - return q.lastError(); - - if (!q.exec(FTS_CHUNKS_SQL)) - return q.lastError(); - - if (!q.exec(COLLECTIONS_SQL)) - return q.lastError(); - - if (!q.exec(FOLDERS_SQL)) - return q.lastError(); - - if (!q.exec(DOCUMENTS_SQL)) - return q.lastError(); - -#if defined(DEBUG_EXAMPLE) - // Add a folder - QString folder_path = "/example/folder"; - int folder_id; - if (!addFolderToDB(q, folder_path, &folder_id)) { - qDebug() << "Error adding folder:" << q.lastError().text(); - return q.lastError(); + if 
(!q.prepare(SELECT_COUNT_STATISTICS_SQL)) + return false; + q.addBindValue(folder_id); + if (!q.exec()) + return false; + if (q.next()) { + *total_docs = q.value(0).toInt(); + *total_words = q.value(1).toInt(); + *total_tokens = q.value(2).toInt(); } - - // Add a collection - QString collection_name = "Example Collection"; - if (!addCollection(q, collection_name, folder_id)) { - qDebug() << "Error adding collection:" << q.lastError().text(); - return q.lastError(); - } - - CollectionItem i; - i.collection = collection_name; - i.folder_path = folder_path; - i.folder_id = folder_id; - emit addCollectionItem(i, false); - - // Add a document - int document_time = 123456789; - int document_id; - QString document_path = "/example/folder/document1.txt"; - if (!addDocument(q, folder_id, document_time, document_path, &document_id)) { - qDebug() << "Error adding document:" << q.lastError().text(); - return q.lastError(); - } - - // Add chunks to the document - QString chunk_text1 = "This is an example chunk."; - QString chunk_text2 = "Another example chunk."; - QString embedding_path = "/example/embeddings/embedding1.bin"; - QString file = "document1.txt"; - QString title; - QString author; - QString subject; - QString keywords; - int page = -1; - int from = -1; - int to = -1;; - int embedding_id = 1; - - if (!addChunk(q, document_id, 1, chunk_text1, file, title, author, subject, keywords, page, from, to, embedding_id, embedding_path) || - !addChunk(q, document_id, 2, chunk_text2, file, title, author, subject, keywords, page, from, to, embedding_id, embedding_path)) { - qDebug() << "Error adding chunks:" << q.lastError().text(); - return q.lastError(); - } - - // Perform a search - QList collection_names = {collection_name}; - QString search_text = "example"; - if (!selectChunk(q, collection_names, search_text, 3)) { - qDebug() << "Error selecting chunks:" << q.lastError().text(); - return q.lastError(); - } -#endif - - return QSqlError(); + return true; } 
-Database::Database(int chunkSize) +// insert embedding only if still needed +static const QString INSERT_EMBEDDING_SQL = uR"( + insert into embeddings(model, folder_id, chunk_id, embedding) + select :model, d.folder_id, :chunk_id, :embedding + from chunks c + join documents d on d.id = c.document_id + join folders f on f.id = d.folder_id + join collection_items ci on ci.folder_id = f.id + join collections co on co.id = ci.collection_id + where co.embedding_model = :model and c.id = :chunk_id + limit 1; +)"_s; + +static const QString GET_COLLECTION_EMBEDDINGS_SQL = uR"( + select e.chunk_id, e.embedding + from embeddings e + join collections co on co.embedding_model = e.model + join collection_items ci on ci.folder_id = e.folder_id and ci.collection_id = co.id + where co.name in ('%1'); +)"_s; + +static const QString GET_CHUNK_FILE_SQL = uR"( + select file from chunks where id = ?; +)"_s; + +namespace { + struct Embedding { QString model; int folder_id; int chunk_id; QByteArray data; }; + struct EmbeddingStat { QString lastFile; int nAdded; int nSkipped; }; +} // namespace + +NAMED_PAIR(EmbeddingFolder, QString, embedding_model, int, folder_id) + +static bool sqlAddEmbeddings(QSqlQuery &q, const QList &embeddings, QHash &embeddingStats) +{ + if (!q.prepare(INSERT_EMBEDDING_SQL)) + return false; + + // insert embedding if needed + for (const auto &e: embeddings) { + q.bindValue(":model", e.model); + q.bindValue(":chunk_id", e.chunk_id); + q.bindValue(":embedding", e.data); + if (!q.exec()) + return false; + + auto &stat = embeddingStats[{ e.model, e.folder_id }]; + if (q.numRowsAffected()) { + stat.nAdded++; // embedding added + } else { + stat.nSkipped++; // embedding no longer needed + } + } + + if (!q.prepare(GET_CHUNK_FILE_SQL)) + return false; + + // populate statistics for each collection item + for (const auto &e: embeddings) { + auto &stat = embeddingStats[{ e.model, e.folder_id }]; + if (stat.nAdded && stat.lastFile.isNull()) { + q.addBindValue(e.chunk_id); 
+ if (!q.exec() || !q.next()) + return false; + stat.lastFile = q.value(0).toString(); + } + } + + return true; +} + +void Database::transaction() +{ + bool ok = m_db.transaction(); + Q_ASSERT(ok); +} + +void Database::commit() +{ + bool ok = m_db.commit(); + Q_ASSERT(ok); +} + +void Database::rollback() +{ + bool ok = m_db.rollback(); + Q_ASSERT(ok); +} + +bool Database::hasContent() +{ + return m_db.tables().contains("chunks", Qt::CaseInsensitive); +} + +int Database::openDatabase(const QString &modelPath, bool create, int ver) +{ + if (m_db.isOpen()) + m_db.close(); + auto dbPath = u"%1/localdocs_v%2.db"_s.arg(modelPath).arg(ver); + if (!create && !QFileInfo::exists(dbPath)) + return 0; + m_db.setDatabaseName(dbPath); + if (!m_db.open()) { + qWarning() << "ERROR: opening db" << m_db.lastError(); + return -1; + } + return hasContent(); +} + +bool Database::openLatestDb(const QString &modelPath, QList &oldCollections) +{ + /* + * Support upgrade path from older versions: + * + * 1. Detect and load dbPath with older versions + * 2. Provide versioned SQL select statements + * 3. Upgrade the tables to the new version + * 4. 
By default mark all collections of older versions as force indexing and present to the user + * an 'update' button letting them know a breaking change happened and that the collection + * will need to be indexed again + */ + + int dbVer; + for (dbVer = LOCALDOCS_VERSION;; dbVer--) { + if (dbVer < LOCALDOCS_MIN_VER) return true; // create a new db + int res = openDatabase(modelPath, false, dbVer); + if (res == 1) break; // found one with content + if (res == -1) return false; // error + } + + if (dbVer == LOCALDOCS_VERSION) return true; // already up-to-date + + // If we're upgrading, then we need to do a select on the current version of the collections table, + // then create the new one and populate the collections table and mark them as needing forced + // indexing + +#if defined(DEBUG) + qDebug() << "Older localdocs version found" << dbVer << "upgrade to" << LOCALDOCS_VERSION; +#endif + + // Select the current collections which will be marked to force indexing + QSqlQuery q(m_db); + if (!selectAllFromCollections(q, &oldCollections, dbVer)) { + qWarning() << "ERROR: Could not open select old collections" << q.lastError(); + return false; + } + + m_db.close(); + return true; +} + +bool Database::initDb(const QString &modelPath, const QList &oldCollections) +{ + if (!m_db.isOpen()) { + int res = openDatabase(modelPath); + if (res == 1) return true; // already populated + if (res == -1) return false; // error + } else if (hasContent()) { + return true; // already populated + } + + transaction(); + + QSqlQuery q(m_db); + for (const auto &cmd: INIT_DB_SQL) { + if (!q.exec(cmd)) { + qWarning() << "ERROR: failed to create tables" << q.lastError(); + rollback(); + return false; + } + } + + /* These are collection items that came from an older version of localdocs which + * require forced indexing that should only be done when the user has explicitly asked + * for them to be indexed again */ + for (const CollectionItem &item : oldCollections) { + if
(!addFolder(item.collection, item.folder_path, QString())) { + qWarning() << "ERROR: failed to add previous collections to new database"; + rollback(); + return false; + } + } + + commit(); + return true; +} + +Database::Database(int chunkSize, QStringList extensions) : QObject(nullptr) , m_chunkSize(chunkSize) + , m_scannedFileExtensions(std::move(extensions)) , m_scanTimer(new QTimer(this)) , m_watcher(new QFileSystemWatcher(this)) , m_embLLM(new EmbeddingLLM) - , m_embeddings(new Embeddings(this)) + , m_databaseValid(true) { + m_db = QSqlDatabase::database(QSqlDatabase::defaultConnection, false); + if (!m_db.isValid()) + m_db = QSqlDatabase::addDatabase("QSQLITE"); + Q_ASSERT(m_db.isValid()); + moveToThread(&m_dbThread); m_dbThread.setObjectName("database"); m_dbThread.start(); @@ -575,139 +961,227 @@ Database::~Database() delete m_embLLM; } -void Database::scheduleNext(int folder_id, size_t countForFolder) +void Database::setStartUpdateTime(CollectionItem &item) { - emit updateCurrentDocsToIndex(folder_id, countForFolder); + QSqlQuery q(m_db); + const qint64 update_time = QDateTime::currentMSecsSinceEpoch(); + if (!updateStartUpdateTime(q, item.collection_id, update_time)) + qWarning() << "Database ERROR: failed to set start update time:" << q.lastError(); + else + item.startUpdate = QDateTime::fromMSecsSinceEpoch(update_time); +} + +void Database::setLastUpdateTime(CollectionItem &item) +{ + QSqlQuery q(m_db); + const qint64 update_time = QDateTime::currentMSecsSinceEpoch(); + if (!updateLastUpdateTime(q, item.collection_id, update_time)) + qWarning() << "Database ERROR: failed to set last update time:" << q.lastError(); + else + item.lastUpdate = QDateTime::fromMSecsSinceEpoch(update_time); +} + +CollectionItem Database::guiCollectionItem(int folder_id) const +{ + Q_ASSERT(m_collectionMap.contains(folder_id)); + return m_collectionMap.value(folder_id); +} + +void Database::updateGuiForCollectionItem(const CollectionItem &item) +{ + 
m_collectionMap.insert(item.folder_id, item); + emit requestUpdateGuiForCollectionItem(item); +} + +void Database::addGuiCollectionItem(const CollectionItem &item) +{ + m_collectionMap.insert(item.folder_id, item); + emit requestAddGuiCollectionItem(item); +} + +void Database::removeGuiFolderById(const QString &collection, int folder_id) +{ + emit requestRemoveGuiFolderById(collection, folder_id); +} + +void Database::guiCollectionListUpdated(const QList &collectionList) +{ + for (const auto &i : collectionList) + m_collectionMap.insert(i.folder_id, i); + emit requestGuiCollectionListUpdated(collectionList); +} + +void Database::updateFolderToIndex(int folder_id, size_t countForFolder, bool sendChunks) +{ + CollectionItem item = guiCollectionItem(folder_id); + item.currentDocsToIndex = countForFolder; if (!countForFolder) { - updateFolderStatus(folder_id, FolderStatus::Complete); - emit updateInstalled(folder_id, true); - } - if (m_docsToScan.isEmpty()) { - m_scanTimer->stop(); - updateIndexingStatus(); + if (sendChunks && !m_chunkList.isEmpty()) + sendChunkList(); // send any remaining embedding chunks to llm + item.indexing = false; + item.installed = true; + + // Set the last update if we are done + if (item.startUpdate > item.lastUpdate && item.currentEmbeddingsToIndex == 0) + setLastUpdateTime(item); } + updateGuiForCollectionItem(item); } void Database::handleDocumentError(const QString &errorMessage, int document_id, const QString &document_path, const QSqlError &error) { - qWarning() << errorMessage << document_id << document_path << error.text(); + qWarning() << errorMessage << document_id << document_path << error; } -size_t Database::chunkStream(QTextStream &stream, int folder_id, int document_id, const QString &file, - const QString &title, const QString &author, const QString &subject, const QString &keywords, int page, - int maxChunks) +size_t Database::chunkStream(QTextStream &stream, int folder_id, int document_id, const QString &embedding_model, + 
const QString &file, const QString &title, const QString &author, const QString &subject, const QString &keywords, + int page, int maxChunks) { int charCount = 0; - int line_from = -1; - int line_to = -1; + // TODO: implement line_from/line_to + constexpr int line_from = -1; + constexpr int line_to = -1; QList words; int chunks = 0; + int addedWords = 0; - QVector chunkList; - - while (!stream.atEnd()) { + for (;;) { QString word; stream >> word; + if (stream.status() && !stream.atEnd()) + return -1; charCount += word.length(); - words.append(word); - if (charCount + words.size() - 1 >= m_chunkSize || stream.atEnd()) { - const QString chunk = words.join(" "); - QSqlQuery q; - int chunk_id = 0; - if (!addChunk(q, - document_id, - chunk, - file, - title, - author, - subject, - keywords, - page, - line_from, - line_to, - &chunk_id - )) { - qWarning() << "ERROR: Could not insert chunk into db" << q.lastError(); + if (!word.isEmpty()) + words.append(word); + if (stream.status() || charCount + words.size() - 1 >= m_chunkSize) { + if (!words.isEmpty()) { + const QString chunk = words.join(" "); + QSqlQuery q(m_db); + int chunk_id = 0; + if (!addChunk(q, + document_id, + chunk, + file, + title, + author, + subject, + keywords, + page, + line_from, + line_to, + words.size(), + &chunk_id + )) { + qWarning() << "ERROR: Could not insert chunk into db" << q.lastError(); + } + + addedWords += words.size(); + + EmbeddingChunk toEmbed; + toEmbed.model = embedding_model; + toEmbed.folder_id = folder_id; + toEmbed.chunk_id = chunk_id; + toEmbed.chunk = chunk; + appendChunk(toEmbed); + ++chunks; + + words.clear(); + charCount = 0; } -#if 1 - EmbeddingChunk toEmbed; - toEmbed.folder_id = folder_id; - toEmbed.chunk_id = chunk_id; - toEmbed.chunk = chunk; - chunkList << toEmbed; - if (chunkList.count() == 100) { - m_embLLM->generateAsyncEmbeddings(chunkList); - emit updateTotalEmbeddingsToIndex(folder_id, 100); - chunkList.clear(); - } -#else - const std::vector result = 
m_embLLM->generateEmbeddings(chunk); - if (!m_embeddings->add(result, chunk_id)) - qWarning() << "ERROR: Cannot add point to embeddings index"; -#endif - - ++chunks; - - words.clear(); - charCount = 0; - - if (maxChunks > 0 && chunks == maxChunks) + if (stream.status() || (maxChunks > 0 && chunks == maxChunks)) break; } } - if (!chunkList.isEmpty()) { - m_embLLM->generateAsyncEmbeddings(chunkList); - emit updateTotalEmbeddingsToIndex(folder_id, chunkList.count()); - chunkList.clear(); + if (chunks) { + CollectionItem item = guiCollectionItem(folder_id); + + // Set the start update if we haven't done so already + if (item.startUpdate <= item.lastUpdate && item.currentEmbeddingsToIndex == 0) + setStartUpdateTime(item); + + item.currentEmbeddingsToIndex += chunks; + item.totalEmbeddingsToIndex += chunks; + item.totalWords += addedWords; + updateGuiForCollectionItem(item); } return stream.pos(); } +void Database::appendChunk(const EmbeddingChunk &chunk) +{ + m_chunkList.reserve(s_batchSize); + m_chunkList.append(chunk); + if (m_chunkList.size() >= s_batchSize) + sendChunkList(); +} + +void Database::sendChunkList() +{ + m_embLLM->generateDocEmbeddingsAsync(m_chunkList); + m_chunkList.clear(); +} + void Database::handleEmbeddingsGenerated(const QVector &embeddings) { - if (embeddings.isEmpty()) - return; + Q_ASSERT(!embeddings.isEmpty()); - int folder_id = 0; - for (auto e : embeddings) { - folder_id = e.folder_id; - if (!m_embeddings->add(e.embedding, e.chunk_id)) - qWarning() << "ERROR: Cannot add point to embeddings index"; + QList sqlEmbeddings; + for (const auto &e: embeddings) { + auto data = QByteArray::fromRawData( + reinterpret_cast(e.embedding.data()), + e.embedding.size() * sizeof(e.embedding.front()) + ); + sqlEmbeddings.append({e.model, e.folder_id, e.chunk_id, std::move(data)}); + } + + transaction(); + + QSqlQuery q(m_db); + QHash stats; + if (!sqlAddEmbeddings(q, sqlEmbeddings, stats)) { + qWarning() << "Database ERROR: failed to add embeddings:" << 
q.lastError(); + return rollback(); + } + + commit(); + + // FIXME(jared): embedding counts are per-collectionitem, not per-folder + for (const auto &[key, stat]: std::as_const(stats).asKeyValueRange()) { + if (!m_collectionMap.contains(key.folder_id)) continue; + CollectionItem item = guiCollectionItem(key.folder_id); + item.currentEmbeddingsToIndex -= stat.nAdded + stat.nSkipped; + item.totalEmbeddingsToIndex -= stat.nSkipped; + if (!stat.lastFile.isNull()) + item.fileCurrentlyProcessing = stat.lastFile; + + // Set the last update if we are done + Q_ASSERT(item.startUpdate > item.lastUpdate); + if (!item.indexing && item.currentEmbeddingsToIndex == 0) + setLastUpdateTime(item); + + updateGuiForCollectionItem(item); } - emit updateCurrentEmbeddingsToIndex(folder_id, embeddings.count()); - m_embeddings->save(); } -void Database::handleErrorGenerated(int folder_id, const QString &error) +void Database::handleErrorGenerated(const QVector &chunks, const QString &error) { - emit updateError(folder_id, error); -} + /* FIXME(jared): errors are actually collection-specific because they are conditioned + * on the embedding model, but this sets the error on all collections for a given + * folder */ -void Database::removeEmbeddingsByDocumentId(int document_id) -{ - QSqlQuery q; + QSet folder_ids; + for (const auto &c: chunks) { folder_ids << c.folder_id; } - if (!q.prepare(SELECT_CHUNKS_BY_DOCUMENT_SQL)) { - qWarning() << "ERROR: Cannot prepare sql for select chunks by document" << q.lastError(); - return; + for (int fid: folder_ids) { + if (!m_collectionMap.contains(fid)) continue; + CollectionItem item = guiCollectionItem(fid); + item.error = error; + updateGuiForCollectionItem(item); } - - q.addBindValue(document_id); - - if (!q.exec()) { - qWarning() << "ERROR: Cannot exec sql for select chunks by document" << q.lastError(); - return; - } - - while (q.next()) { - const int chunk_id = q.value(0).toInt(); - m_embeddings->remove(chunk_id); - } - m_embeddings->save(); } 
size_t Database::countOfDocuments(int folder_id) const @@ -745,7 +1219,6 @@ void Database::removeFolderFromDocumentQueue(int folder_id) if (!m_docsToScan.contains(folder_id)) return; m_docsToScan.remove(folder_id); - emit removeFolderById(folder_id); } void Database::enqueueDocumentInternal(const DocumentInfo &info, bool prepend) @@ -764,22 +1237,36 @@ void Database::enqueueDocuments(int folder_id, const QVector &info for (int i = 0; i < infos.size(); ++i) enqueueDocumentInternal(infos[i]); const size_t count = countOfDocuments(folder_id); - emit updateCurrentDocsToIndex(folder_id, count); - emit updateTotalDocsToIndex(folder_id, count); + + CollectionItem item = guiCollectionItem(folder_id); + item.currentDocsToIndex = count; + item.totalDocsToIndex = count; const size_t bytes = countOfBytes(folder_id); - emit updateCurrentBytesToIndex(folder_id, bytes); - emit updateTotalBytesToIndex(folder_id, bytes); + item.currentBytesToIndex = bytes; + item.totalBytesToIndex = bytes; + updateGuiForCollectionItem(item); m_scanTimer->start(); } +void Database::scanQueueBatch() +{ + QElapsedTimer timer; + timer.start(); + + transaction(); + + // scan for up to 100ms or until we run out of documents + while (!m_docsToScan.isEmpty() && timer.elapsed() < 100) + scanQueue(); + + commit(); + + if (m_docsToScan.isEmpty()) + m_scanTimer->stop(); +} + void Database::scanQueue() { - if (m_docsToScan.isEmpty()) { - m_scanTimer->stop(); - updateIndexingStatus(); - return; - } - DocumentInfo info = dequeueDocument(); const size_t countForFolder = countOfDocuments(info.folder); const int folder_id = info.folder; @@ -789,22 +1276,21 @@ void Database::scanQueue() // If the doc has since been deleted or no longer readable, then we schedule more work and return // leaving the cleanup for the cleanup handler - if (!info.doc.exists() || !info.doc.isReadable()) { - return scheduleNext(folder_id, countForFolder); - } + if (!info.doc.exists() || !info.doc.isReadable()) + return 
updateFolderToIndex(folder_id, countForFolder); const qint64 document_time = info.doc.fileTime(QFile::FileModificationTime).toMSecsSinceEpoch(); const QString document_path = info.doc.canonicalFilePath(); const bool currentlyProcessing = info.currentlyProcessing; // Check and see if we already have this document - QSqlQuery q; + QSqlQuery q(m_db); int existing_id = -1; qint64 existing_time = -1; if (!selectDocument(q, document_path, &existing_id, &existing_time)) { handleDocumentError("ERROR: Cannot select document", existing_id, document_path, q.lastError()); - return scheduleNext(folder_id, countForFolder); + return updateFolderToIndex(folder_id, countForFolder); } // If we have the document, we need to compare the last modification time and if it is newer @@ -813,15 +1299,14 @@ void Database::scanQueue() Q_ASSERT(existing_time != -1); if (document_time == existing_time) { // No need to rescan, but we do have to schedule next - return scheduleNext(folder_id, countForFolder); - } else { - removeEmbeddingsByDocumentId(existing_id); - if (!removeChunksByDocumentId(q, existing_id)) { - handleDocumentError("ERROR: Cannot remove chunks of document", - existing_id, document_path, q.lastError()); - return scheduleNext(folder_id, countForFolder); - } + return updateFolderToIndex(folder_id, countForFolder); } + if (!removeChunksByDocumentId(q, existing_id)) { + handleDocumentError("ERROR: Cannot remove chunks of document", + existing_id, document_path, q.lastError()); + return updateFolderToIndex(folder_id, countForFolder); + } + updateCollectionStatistics(); } // Update the document_time for an existing document, or add it for the first time now @@ -831,27 +1316,37 @@ void Database::scanQueue() if (!updateDocument(q, document_id, document_time)) { handleDocumentError("ERROR: Could not update document_time", document_id, document_path, q.lastError()); - return scheduleNext(folder_id, countForFolder); + return updateFolderToIndex(folder_id, countForFolder); } } else { if 
(!addDocument(q, folder_id, document_time, document_path, &document_id)) { handleDocumentError("ERROR: Could not add document", document_id, document_path, q.lastError()); - return scheduleNext(folder_id, countForFolder); + return updateFolderToIndex(folder_id, countForFolder); } + + CollectionItem item = guiCollectionItem(folder_id); + item.totalDocs += 1; + updateGuiForCollectionItem(item); } } - QSqlDatabase::database().transaction(); + // Get the embedding model for this folder + // FIXME(jared): there can be more than one since we allow one folder to be in multiple collections + QString embedding_model; + if (!sqlGetFolderEmbeddingModel(q, folder_id, embedding_model)) { + handleDocumentError("ERROR: Could not get embedding model", + document_id, document_path, q.lastError()); + return updateFolderToIndex(folder_id, countForFolder); + } + Q_ASSERT(document_id != -1); if (info.isPdf()) { - updateFolderStatus(folder_id, FolderStatus::Embedding, -1, info.currentPage == 0); - QPdfDocument doc; if (QPdfDocument::Error::None != doc.load(info.doc.canonicalFilePath())) { handleDocumentError("ERROR: Could not load pdf", document_id, document_path, q.lastError()); - return scheduleNext(folder_id, countForFolder); + return updateFolderToIndex(folder_id, countForFolder); } const size_t bytes = info.doc.size(); const size_t bytesPerPage = std::floor(bytes / doc.pageCount()); @@ -862,71 +1357,99 @@ void Database::scanQueue() const QPdfSelection selection = doc.getAllText(pageIndex); QString text = selection.text(); QTextStream stream(&text); - chunkStream(stream, info.folder, document_id, info.doc.fileName(), + chunkStream(stream, info.folder, document_id, embedding_model, info.doc.fileName(), doc.metaData(QPdfDocument::MetaDataField::Title).toString(), doc.metaData(QPdfDocument::MetaDataField::Author).toString(), doc.metaData(QPdfDocument::MetaDataField::Subject).toString(), doc.metaData(QPdfDocument::MetaDataField::Keywords).toString(), pageIndex + 1 ); - emit 
subtractCurrentBytesToIndex(info.folder, bytesPerPage); + CollectionItem item = guiCollectionItem(info.folder); + item.currentBytesToIndex -= bytesPerPage; + updateGuiForCollectionItem(item); if (info.currentPage < doc.pageCount()) { info.currentPage += 1; info.currentlyProcessing = true; enqueueDocumentInternal(info, true /*prepend*/); - return scheduleNext(folder_id, countForFolder + 1); - } else { - emit subtractCurrentBytesToIndex(info.folder, bytes - (bytesPerPage * doc.pageCount())); + return updateFolderToIndex(folder_id, countForFolder + 1); } - } else { - updateFolderStatus(folder_id, FolderStatus::Embedding, -1, info.currentPosition == 0); - QFile file(document_path); + item.currentBytesToIndex -= bytes - (bytesPerPage * doc.pageCount()); + updateGuiForCollectionItem(item); + } else { + BinaryDetectingFile file(document_path); if (!file.open(QIODevice::ReadOnly)) { handleDocumentError("ERROR: Cannot open file for scanning", existing_id, document_path, q.lastError()); - return scheduleNext(folder_id, countForFolder); + return updateFolderToIndex(folder_id, countForFolder); } + Q_ASSERT(!file.isSequential()); // we need to seek const size_t bytes = info.doc.size(); QTextStream stream(&file); const size_t byteIndex = info.currentPosition; - if (!stream.seek(byteIndex)) { - handleDocumentError("ERROR: Cannot seek to pos for scanning", - existing_id, document_path, q.lastError()); - return scheduleNext(folder_id, countForFolder); + if (byteIndex) { + /* Read the Unicode BOM to detect the encoding. Without this, QTextStream will + * always interpret the text as UTF-8 when byteIndex is nonzero. 
*/ + stream.read(1); + + if (!stream.seek(byteIndex)) { + handleDocumentError("ERROR: Cannot seek to pos for scanning", + existing_id, document_path, q.lastError()); + return updateFolderToIndex(folder_id, countForFolder); + } } #if defined(DEBUG) qDebug() << "scanning byteIndex" << byteIndex << "of" << bytes << document_path; #endif - int pos = chunkStream(stream, info.folder, document_id, info.doc.fileName(), QString() /*title*/, QString() /*author*/, - QString() /*subject*/, QString() /*keywords*/, -1 /*page*/, 100 /*maxChunks*/); + int pos = chunkStream(stream, info.folder, document_id, embedding_model, info.doc.fileName(), + QString() /*title*/, QString() /*author*/, QString() /*subject*/, QString() /*keywords*/, -1 /*page*/, + 100 /*maxChunks*/); + if (pos < 0) { + if (!file.binarySeen()) { + handleDocumentError(u"ERROR: Failed to read file (status %1)"_s.arg(stream.status()), + existing_id, document_path, q.lastError()); + return updateFolderToIndex(folder_id, countForFolder); + } + + /* When we see a binary file, we treat it like an empty file so we know not to + * scan it again. All existing chunks are removed, and in-progress embeddings + * are ignored when they complete. 
*/ + + qInfo() << "LocalDocs: Ignoring file with binary data:" << document_path; + + // this will also ensure in-flight embeddings are ignored + if (!removeChunksByDocumentId(q, existing_id)) { + handleDocumentError("ERROR: Cannot remove chunks of document", + existing_id, document_path, q.lastError()); + } + updateCollectionStatistics(); + return updateFolderToIndex(folder_id, countForFolder); + } file.close(); const size_t bytesChunked = pos - byteIndex; - emit subtractCurrentBytesToIndex(info.folder, bytesChunked); + CollectionItem item = guiCollectionItem(info.folder); + item.currentBytesToIndex -= bytesChunked; + updateGuiForCollectionItem(item); if (info.currentPosition < bytes) { info.currentPosition = pos; info.currentlyProcessing = true; enqueueDocumentInternal(info, true /*prepend*/); - return scheduleNext(folder_id, countForFolder + 1); + return updateFolderToIndex(folder_id, countForFolder + 1); } } - QSqlDatabase::database().commit(); - return scheduleNext(folder_id, countForFolder); + + return updateFolderToIndex(folder_id, countForFolder); } -void Database::scanDocuments(int folder_id, const QString &folder_path, bool isNew) +void Database::scanDocuments(int folder_id, const QString &folder_path) { #if defined(DEBUG) qDebug() << "scanning folder for documents" << folder_path; #endif - static const QList extensions { "txt", "pdf", "md", "rst" }; - - QDir dir(folder_path); - Q_ASSERT(dir.exists()); - Q_ASSERT(dir.isReadable()); - QDirIterator it(folder_path, QDir::Readable | QDir::Files, QDirIterator::Subdirectories); + QDirIterator it(folder_path, QDir::Readable | QDir::Files | QDir::Dirs | QDir::NoDotAndDotDot, + QDirIterator::Subdirectories); QVector infos; while (it.hasNext()) { it.next(); @@ -936,7 +1459,7 @@ void Database::scanDocuments(int folder_id, const QString &folder_path, bool isN continue; } - if (!extensions.contains(fileInfo.suffix())) + if (!m_scannedFileExtensions.contains(fileInfo.suffix())) continue; DocumentInfo info; @@ -946,8 
+1469,12 @@ void Database::scanDocuments(int folder_id, const QString &folder_path, bool isN } if (!infos.isEmpty()) { - updateFolderStatus(folder_id, FolderStatus::Started, infos.count(), false, isNew); + CollectionItem item = guiCollectionItem(folder_id); + item.indexing = true; + updateGuiForCollectionItem(item); enqueueDocuments(folder_id, infos); + } else { + updateFolderToIndex(folder_id, 0, false); } } @@ -956,101 +1483,259 @@ void Database::start() connect(m_watcher, &QFileSystemWatcher::directoryChanged, this, &Database::directoryChanged); connect(m_embLLM, &EmbeddingLLM::embeddingsGenerated, this, &Database::handleEmbeddingsGenerated); connect(m_embLLM, &EmbeddingLLM::errorGenerated, this, &Database::handleErrorGenerated); - m_scanTimer->callOnTimeout(this, &Database::scanQueue); - if (!QSqlDatabase::drivers().contains("QSQLITE")) { - qWarning() << "ERROR: missing sqlite driver"; + m_scanTimer->callOnTimeout(this, &Database::scanQueueBatch); + + const QString modelPath = MySettings::globalInstance()->modelPath(); + QList oldCollections; + + if (!openLatestDb(modelPath, oldCollections)) { + m_databaseValid = false; + } else if (!initDb(modelPath, oldCollections)) { + m_databaseValid = false; } else { - QSqlError err = initDb(); - if (err.type() != QSqlError::NoError) - qWarning() << "ERROR: initializing db" << err.text(); + //cleanDB(); + addCurrentFolders(); } - if (m_embeddings->fileExists() && !m_embeddings->load()) - qWarning() << "ERROR: Could not load embeddings"; - - int nAdded = addCurrentFolders(); - Network::globalInstance()->trackEvent("localdocs_startup", { {"doc_collections_total", nAdded} }); + if (!m_databaseValid) + emit databaseValidChanged(); } -int Database::addCurrentFolders() +void Database::addCurrentFolders() { #if defined(DEBUG) qDebug() << "addCurrentFolders"; #endif - QSqlQuery q; + QSqlQuery q(m_db); QList collections; if (!selectAllFromCollections(q, &collections)) { qWarning() << "ERROR: Cannot select collections" << 
q.lastError(); - return 0; + return; } - emit collectionListUpdated(collections); + guiCollectionListUpdated(collections); - int nAdded = 0; - for (const auto &i : collections) - nAdded += addFolder(i.collection, i.folder_path, true); + scheduleUncompletedEmbeddings(); - updateIndexingStatus(); + for (const auto &i : collections) { + if (!i.forceIndexing) { + addFolderToWatch(i.folder_path); + scanDocuments(i.folder_id, i.folder_path); + } + } - return collections.count() + nAdded; + updateCollectionStatistics(); } -bool Database::addFolder(const QString &collection, const QString &path, bool fromDb) +void Database::scheduleUncompletedEmbeddings() +{ + QHash chunkList; + QSqlQuery q(m_db); + if (!selectAllUncompletedChunks(q, chunkList)) { + qWarning() << "ERROR: Cannot select uncompleted chunks" << q.lastError(); + return; + } + + if (chunkList.isEmpty()) + return; + + // map of folder_id -> chunk count + QMap folderNChunks; + for (auto it = chunkList.keyBegin(), end = chunkList.keyEnd(); it != end; ++it) { + int folder_id = it->folder_id; + + if (folderNChunks.contains(folder_id)) continue; + int total = 0; + if (!selectCountChunks(q, folder_id, total)) { + qWarning() << "ERROR: Cannot count total chunks" << q.lastError(); + return; + } + folderNChunks.insert(folder_id, total); + } + + // map of (folder_id, collection) -> incomplete count + QMap, int> itemNIncomplete; + for (const auto &[chunk, collections]: std::as_const(chunkList).asKeyValueRange()) + for (const auto &collection: std::as_const(collections)) + itemNIncomplete[{ chunk.folder_id, collection }]++; + + for (const auto &[key, nIncomplete]: std::as_const(itemNIncomplete).asKeyValueRange()) { + const auto &[folder_id, collection] = key; + + /* FIXME(jared): this needs to be split by collection because different + * collections have different embedding models */ + int total = folderNChunks.value(folder_id); + CollectionItem item = guiCollectionItem(folder_id); + item.totalEmbeddingsToIndex = total; + 
item.currentEmbeddingsToIndex = nIncomplete; + updateGuiForCollectionItem(item); + } + + for (auto it = chunkList.keyBegin(), end = chunkList.keyEnd(); it != end;) { + QList batch; + for (; it != end && batch.size() < s_batchSize; ++it) + batch.append({ /*model*/ it->embedding_model, /*folder_id*/ it->folder_id, /*chunk_id*/ it->chunk_id, /*chunk*/ it->text }); + Q_ASSERT(!batch.isEmpty()); + m_embLLM->generateDocEmbeddingsAsync(batch); + } +} + +void Database::updateCollectionStatistics() +{ + QSqlQuery q(m_db); + QList collections; + if (!selectAllFromCollections(q, &collections)) { + qWarning() << "ERROR: Cannot select collections" << q.lastError(); + return; + } + + for (const auto &i: std::as_const(collections)) { + int total_docs = 0; + int total_words = 0; + int total_tokens = 0; + if (!selectCountStatistics(q, i.folder_id, &total_docs, &total_words, &total_tokens)) { + qWarning() << "ERROR: could not count statistics for folder" << q.lastError(); + } else { + CollectionItem item = guiCollectionItem(i.folder_id); + item.totalDocs = total_docs; + item.totalWords = total_words; + item.totalTokens = total_tokens; + updateGuiForCollectionItem(item); + } + } +} + +int Database::checkAndAddFolderToDB(const QString &path) { QFileInfo info(path); if (!info.exists() || !info.isReadable()) { qWarning() << "ERROR: Cannot add folder that doesn't exist or not readable" << path; - return false; + return -1; } - QSqlQuery q; + QSqlQuery q(m_db); int folder_id = -1; // See if the folder exists in the db if (!selectFolder(q, path, &folder_id)) { qWarning() << "ERROR: Cannot select folder from path" << path << q.lastError(); - return false; + return -1; } // Add the folder if (folder_id == -1 && !addFolderToDB(q, path, &folder_id)) { qWarning() << "ERROR: Cannot add folder to db with path" << path << q.lastError(); - return false; + return -1; } Q_ASSERT(folder_id != -1); + return folder_id; +} - // See if the folder has already been added to the collection - QList folders; 
+void Database::forceIndexing(const QString &collection, const QString &embedding_model) +{ + Q_ASSERT(!embedding_model.isNull()); + + QSqlQuery q(m_db); + QList> folders; if (!selectFoldersFromCollection(q, collection, &folders)) { qWarning() << "ERROR: Cannot select folders from collections" << collection << q.lastError(); + return; + } + + if (!setCollectionEmbeddingModel(q, collection, embedding_model)) { + qWarning().nospace() << "ERROR: Cannot set embedding model for collection " << collection << ": " + << q.lastError(); + return; + } + + for (const auto &folder: std::as_const(folders)) { + CollectionItem item = guiCollectionItem(folder.first); + item.embeddingModel = embedding_model; + item.forceIndexing = false; + updateGuiForCollectionItem(item); + addFolderToWatch(folder.second); + scanDocuments(folder.first, folder.second); + } +} + +void Database::forceRebuildFolder(const QString &path) +{ + QSqlQuery q(m_db); + int folder_id; + if (!selectFolder(q, path, &folder_id)) { + qWarning().nospace() << "Database ERROR: Cannot select folder from path " << path << ": " << q.lastError(); + return; + } + + Q_ASSERT(!m_docsToScan.contains(folder_id)); + + transaction(); + + if (!sqlRemoveDocsByFolderPath(q, path)) { + qWarning().nospace() << "Database ERROR: Cannot remove chunks for folder " << path << ": " << q.lastError(); + return rollback(); + } + + commit(); + + updateCollectionStatistics(); + + // We now have zero embeddings. Document progress will be updated by scanDocuments. 
+ // FIXME(jared): this updates the folder, but these values should also depend on the collection + CollectionItem item = guiCollectionItem(folder_id); + item.currentEmbeddingsToIndex = item.totalEmbeddingsToIndex = 0; + updateGuiForCollectionItem(item); + + scanDocuments(folder_id, path); +} + +bool Database::addFolder(const QString &collection, const QString &path, const QString &embedding_model) +{ + // add the folder, if needed + const int folder_id = checkAndAddFolderToDB(path); + if (folder_id == -1) + return false; + + std::optional item; + QSqlQuery q(m_db); + if (!selectCollectionByName(q, collection, item)) { + qWarning().nospace() << "Database ERROR: Cannot select collection " << collection << ": " << q.lastError(); return false; } - bool added = false; - if (!folders.contains(folder_id)) { - if (!addCollection(q, collection, folder_id)) { - qWarning() << "ERROR: Cannot add folder to collection" << collection << path << q.lastError(); + // add the collection, if needed + if (!item) { + item.emplace(); + if (!addCollection(q, collection, QDateTime() /*start_update*/, QDateTime() /*last_update*/, + embedding_model /*embedding_model*/, *item)) { + qWarning().nospace() << "ERROR: Cannot add collection " << collection << ": " << q.lastError(); return false; } - - CollectionItem i; - i.collection = collection; - i.folder_path = path; - i.folder_id = folder_id; - emit addCollectionItem(i, fromDb); - added = true; } - addFolderToWatch(path); - scanDocuments(folder_id, path, !fromDb); - - if (!fromDb) { - updateIndexingStatus(); + // link the folder and the collection, if needed + int res = addCollectionItem(q, item->collection_id, folder_id); + if (res < 0) { // error + qWarning().nospace() << "Database ERROR: Cannot add folder " << path << " to collection " << collection << ": " + << q.lastError(); + return false; } - return added; + // add the new collection item to the UI + if (res == 1) { // new item added + item->folder_path = path; + item->folder_id = 
folder_id; + addGuiCollectionItem(item.value()); + + // note: this is the existing embedding model if the collection was found + if (!item->embeddingModel.isNull()) { + addFolderToWatch(path); + scanDocuments(folder_id, path); + } + } + return true; } void Database::removeFolder(const QString &collection, const QString &path) @@ -1059,7 +1744,7 @@ void Database::removeFolder(const QString &collection, const QString &path) qDebug() << "removeFolder" << path; #endif - QSqlQuery q; + QSqlQuery q(m_db); int folder_id = -1; // See if the folder exists in the db @@ -1076,28 +1761,37 @@ void Database::removeFolder(const QString &collection, const QString &path) return; } - removeFolderInternal(collection, folder_id, path); + transaction(); + + if (removeFolderInternal(collection, folder_id, path)) { + commit(); + } else { + rollback(); + } } -void Database::removeFolderInternal(const QString &collection, int folder_id, const QString &path) +bool Database::removeFolderInternal(const QString &collection, int folder_id, const QString &path) { - // Determine if the folder is used by more than one collection - QSqlQuery q; - QList collections; - if (!selectCollectionsFromFolder(q, folder_id, &collections)) { - qWarning() << "ERROR: Cannot select collections from folder" << folder_id << q.lastError(); - return; + // Remove it from the collection + QSqlQuery q(m_db); + int nRemaining = removeCollectionFolder(q, collection, folder_id); + if (nRemaining == -1) { + qWarning().nospace() << "Database ERROR: Cannot remove collection " << collection << " from folder " + << folder_id << ": " << q.lastError(); + return false; + } + removeGuiFolderById(collection, folder_id); + + if (!sqlPruneCollections(q)) { + qWarning() << "Database ERROR: Cannot prune collections:" << q.lastError(); + return false; } - // Remove it from the collections - if (!removeCollection(q, collection, folder_id)) { - qWarning() << "ERROR: Cannot remove collection" << collection << folder_id << q.lastError(); - 
return; - } + // Keep folder if it is still referenced + if (nRemaining) + return true; - // If the folder is associated with more than one collection, then return - if (collections.count() > 1) - return; + // Remove the last reference to a folder // First remove all upcoming jobs associated with this folder removeFolderFromDocumentQueue(folder_id); @@ -1106,47 +1800,137 @@ void Database::removeFolderInternal(const QString &collection, int folder_id, co QList documentIds; if (!selectDocuments(q, folder_id, &documentIds)) { qWarning() << "ERROR: Cannot select documents" << folder_id << q.lastError(); - return; + return false; } // Remove all chunks and documents associated with this folder - for (int document_id : documentIds) { - removeEmbeddingsByDocumentId(document_id); + for (int document_id: std::as_const(documentIds)) { if (!removeChunksByDocumentId(q, document_id)) { qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << q.lastError(); - return; + return false; } if (!removeDocument(q, document_id)) { qWarning() << "ERROR: Cannot remove document_id" << document_id << q.lastError(); - return; + return false; } } if (!removeFolderFromDB(q, folder_id)) { qWarning() << "ERROR: Cannot remove folder_id" << folder_id << q.lastError(); - return; + return false; } - emit removeFolderById(folder_id); - + m_collectionMap.remove(folder_id); removeFolderFromWatch(path); + return true; } -bool Database::addFolderToWatch(const QString &path) +void Database::addFolderToWatch(const QString &path) { #if defined(DEBUG) qDebug() << "addFolderToWatch" << path; #endif - return m_watcher->addPath(path); + // pre-check because addPath returns false for already watched paths + if (!m_watchedPaths.contains(path)) { + if (!m_watcher->addPath(path)) + qWarning() << "Database::addFolderToWatch: failed to watch" << path; + // add unconditionally to suppress repeated warnings + m_watchedPaths << path; + } } -bool Database::removeFolderFromWatch(const QString &path) 
+void Database::removeFolderFromWatch(const QString &path) { #if defined(DEBUG) qDebug() << "removeFolderFromWatch" << path; #endif - return m_watcher->removePath(path); + QDirIterator it(path, QDir::Readable | QDir::Dirs | QDir::NoDotAndDotDot, QDirIterator::Subdirectories); + QStringList children { path }; + while (it.hasNext()) + children.append(it.next()); + + m_watcher->removePaths(children); + m_watchedPaths -= QSet(children.begin(), children.end()); +} + +QList Database::searchEmbeddings(const std::vector &query, const QList &collections, int nNeighbors) +{ + constexpr int BATCH_SIZE = 2048; + + const int n_embd = query.size(); + const us::metric_punned_t metric(n_embd, us::metric_kind_t::ip_k); // inner product + + QSqlQuery q(m_db); + if (!q.exec(GET_COLLECTION_EMBEDDINGS_SQL.arg(collections.join("', '")))) { + qWarning() << "Database ERROR: Failed to exec embeddings query:" << q.lastError(); + return {}; + } + + us::executor_default_t executor(std::thread::hardware_concurrency()); + us::exact_search_t search; + + QList batchChunkIds; + QList batchEmbeddings; + batchChunkIds.reserve(BATCH_SIZE); + batchEmbeddings.reserve(BATCH_SIZE * n_embd); + + struct Result { int chunkId; us::distance_punned_t dist; }; + QList results; + + while (q.at() != QSql::AfterLastRow) { // batches + batchChunkIds.clear(); + batchEmbeddings.clear(); + + while (batchChunkIds.count() < BATCH_SIZE && q.next()) { // batch + batchChunkIds << q.value(0).toInt(); + batchEmbeddings.resize(batchEmbeddings.size() + n_embd); + QVariant embdCol = q.value(1); + if (embdCol.userType() != QMetaType::QByteArray) { + qWarning() << "Database ERROR: Expected embedding to be blob, got" << embdCol.userType(); + return {}; + } + auto *embd = static_cast(embdCol.constData()); + const int embd_stride = n_embd * sizeof(float); + if (embd->size() != embd_stride) { + qWarning() << "Database ERROR: Expected embedding to be" << embd_stride << "bytes, got" + << embd->size(); + return {}; + } + 
memcpy(&*(batchEmbeddings.end() - n_embd), embd->constData(), embd_stride); + } + + int nBatch = batchChunkIds.count(); + if (!nBatch) + break; + + // get top-k nearest neighbors of this batch + int kBatch = qMin(nNeighbors, nBatch); + us::exact_search_results_t batchResults = search( + (us::byte_t const *)batchEmbeddings.data(), nBatch, n_embd * sizeof(float), + (us::byte_t const *)query.data(), 1, n_embd * sizeof(float), + kBatch, metric + ); + + for (int i = 0; i < kBatch; ++i) { + auto offset = batchResults.at(0)[i].offset; + us::distance_punned_t distance = batchResults.at(0)[i].distance; + results.append({batchChunkIds[offset], distance}); + } + } + + // get top-k nearest neighbors of combined results + nNeighbors = qMin(nNeighbors, results.size()); + std::partial_sort( + results.begin(), results.begin() + nNeighbors, results.end(), + [](const Result &a, const Result &b) { return a.dist < b.dist; } + ); + + QList chunkIds; + chunkIds.reserve(nNeighbors); + for (int i = 0; i < nNeighbors; i++) + chunkIds << results[i].chunkId; + return chunkIds; } void Database::retrieveFromDB(const QList &collections, const QString &text, int retrievalSize, @@ -1156,38 +1940,39 @@ void Database::retrieveFromDB(const QList &collections, const QString & qDebug() << "retrieveFromDB" << collections << text << retrievalSize; #endif - QSqlQuery q; - if (m_embeddings->isLoaded()) { - std::vector result = m_embLLM->generateEmbeddings(text); - if (result.empty()) { - qDebug() << "ERROR: generating embeddings returned a null result"; - return; - } - std::vector embeddings = m_embeddings->search(result, retrievalSize); - if (!selectChunk(q, collections, embeddings, retrievalSize)) { - qDebug() << "ERROR: selecting chunks:" << q.lastError().text(); - return; - } - } else { - if (!selectChunk(q, collections, text, retrievalSize)) { - qDebug() << "ERROR: selecting chunks:" << q.lastError().text(); - return; - } + std::vector queryEmbd = m_embLLM->generateQueryEmbedding(text); + if 
(queryEmbd.empty()) { + qDebug() << "ERROR: generating embeddings returned a null result"; + return; + } + + QList searchResults = searchEmbeddings(queryEmbd, collections, retrievalSize); + if (searchResults.isEmpty()) + return; + + QSqlQuery q(m_db); + if (!selectChunk(q, searchResults, retrievalSize)) { + qDebug() << "ERROR: selecting chunks:" << q.lastError(); + return; } while (q.next()) { #if defined(DEBUG) const int rowid = q.value(0).toInt(); #endif - const QString chunk_text = q.value(2).toString(); + const QString document_path = q.value(2).toString(); + const QString chunk_text = q.value(3).toString(); const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd"); - const QString file = q.value(3).toString(); - const QString title = q.value(4).toString(); - const QString author = q.value(5).toString(); - const int page = q.value(6).toInt(); - const int from =q.value(7).toInt(); - const int to =q.value(8).toInt(); + const QString file = q.value(4).toString(); + const QString title = q.value(5).toString(); + const QString author = q.value(6).toString(); + const int page = q.value(7).toInt(); + const int from = q.value(8).toInt(); + const int to = q.value(9).toInt(); + const QString collectionName = q.value(10).toString(); ResultInfo info; + info.collection = collectionName; + info.path = document_path; info.file = file; info.title = title; info.author = author; @@ -1204,47 +1989,57 @@ void Database::retrieveFromDB(const QList &collections, const QString & } } -void Database::cleanDB() +// FIXME This is very slow and non-interruptible and when we close the application and we're +// cleaning a large table this can cause the app to take forever to shut down. 
This would ideally be +// interruptible and we'd continue 'cleaning' when we restart +bool Database::cleanDB() { #if defined(DEBUG) qDebug() << "cleanDB"; #endif // Scan all folders in db to make sure they still exist - QSqlQuery q; + QSqlQuery q(m_db); QList collections; if (!selectAllFromCollections(q, &collections)) { qWarning() << "ERROR: Cannot select collections" << q.lastError(); - return; + return false; } - for (const auto &i : collections) { + transaction(); + + for (const auto &i: std::as_const(collections)) { // Find the path for the folder QFileInfo info(i.folder_path); if (!info.exists() || !info.isReadable()) { #if defined(DEBUG) qDebug() << "clean db removing folder" << i.folder_id << i.folder_path; #endif - removeFolderInternal(i.collection, i.folder_id, i.folder_path); + if (!removeFolderInternal(i.collection, i.folder_id, i.folder_path)) { + rollback(); + return false; + } } } // Scan all documents in db to make sure they still exist if (!q.prepare(SELECT_ALL_DOCUMENTS_SQL)) { qWarning() << "ERROR: Cannot prepare sql for select all documents" << q.lastError(); - return; + rollback(); + return false; } if (!q.exec()) { qWarning() << "ERROR: Cannot exec sql for select all documents" << q.lastError(); - return; + rollback(); + return false; } while (q.next()) { int document_id = q.value(0).toInt(); QString document_path = q.value(1).toString(); QFileInfo info(document_path); - if (info.exists() && info.isReadable()) + if (info.exists() && info.isReadable() && m_scannedFileExtensions.contains(info.suffix())) continue; #if defined(DEBUG) @@ -1252,16 +2047,22 @@ void Database::cleanDB() #endif // Remove all chunks and documents that either don't exist or have become unreadable - QSqlQuery query; - removeEmbeddingsByDocumentId(document_id); + QSqlQuery query(m_db); if (!removeChunksByDocumentId(query, document_id)) { qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << query.lastError(); + rollback(); + return false; } if 
(!removeDocument(query, document_id)) { qWarning() << "ERROR: Cannot remove document_id" << document_id << query.lastError(); + rollback(); + return false; } } + + commit(); + return true; } void Database::changeChunkSize(int chunkSize) @@ -1273,9 +2074,7 @@ void Database::changeChunkSize(int chunkSize) qDebug() << "changeChunkSize" << chunkSize; #endif - m_chunkSize = chunkSize; - - QSqlQuery q; + QSqlQuery q(m_db); // Scan all documents in db to make sure they still exist if (!q.prepare(SELECT_ALL_DOCUMENTS_SQL)) { qWarning() << "ERROR: Cannot prepare sql for select all documents" << q.lastError(); @@ -1287,20 +2086,51 @@ void Database::changeChunkSize(int chunkSize) return; } + transaction(); + while (q.next()) { int document_id = q.value(0).toInt(); // Remove all chunks and documents to change the chunk size - QSqlQuery query; - removeEmbeddingsByDocumentId(document_id); + QSqlQuery query(m_db); if (!removeChunksByDocumentId(query, document_id)) { qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << query.lastError(); + return rollback(); } if (!removeDocument(query, document_id)) { qWarning() << "ERROR: Cannot remove document_id" << document_id << query.lastError(); + return rollback(); } } + + commit(); + + m_chunkSize = chunkSize; addCurrentFolders(); + updateCollectionStatistics(); +} + +void Database::changeFileExtensions(const QStringList &extensions) +{ +#if defined(DEBUG) + qDebug() << "changeFileExtensions"; +#endif + + m_scannedFileExtensions = extensions; + + cleanDB(); + + QSqlQuery q(m_db); + QList collections; + if (!selectAllFromCollections(q, &collections)) { + qWarning() << "ERROR: Cannot select collections" << q.lastError(); + return; + } + + for (const auto &i: std::as_const(collections)) { + if (!i.forceIndexing) + scanDocuments(i.folder_id, i.folder_path); + } } void Database::directoryChanged(const QString &path) @@ -1309,90 +2139,33 @@ void Database::directoryChanged(const QString &path) qDebug() << 
"directoryChanged" << path; #endif - QSqlQuery q; + // search for a collection that contains this folder (we watch subdirectories) int folder_id = -1; + QDir dir(path); + for (;;) { + QSqlQuery q(m_db); + if (!selectFolder(q, dir.path(), &folder_id)) { + qWarning() << "ERROR: Cannot select folder from path" << dir.path() << q.lastError(); + return; + } + if (folder_id != -1) + break; - // Lookup the folder_id in the db - if (!selectFolder(q, path, &folder_id)) { - qWarning() << "ERROR: Cannot select folder from path" << path << q.lastError(); - return; - } - - // If we don't have a folder_id in the db, then something bad has happened - Q_ASSERT(folder_id != -1); - if (folder_id == -1) { - qWarning() << "ERROR: Watched folder does not exist in db" << path; - m_watcher->removePath(path); - return; + // check next parent + if (!dir.cdUp()) { + if (!dir.isRoot()) break; // folder removed + Q_ASSERT(false); + qWarning() << "ERROR: Watched folder does not exist in db" << path; + m_watcher->removePath(path); + return; + } } // Clean the database - cleanDB(); + if (cleanDB()) + updateCollectionStatistics(); // Rescan the documents associated with the folder - scanDocuments(folder_id, path, false); - updateIndexingStatus(); -} - -void Database::updateIndexingStatus() { - Q_ASSERT(m_scanTimer->isActive() || m_docsToScan.isEmpty()); - if (!m_indexingTimer.isValid() && m_scanTimer->isActive()) { - Network::globalInstance()->trackEvent("localdocs_indexing_start"); - m_indexingTimer.start(); - } else if (m_indexingTimer.isValid() && !m_scanTimer->isActive()) { - qint64 durationMs = m_indexingTimer.elapsed(); - Network::globalInstance()->trackEvent("localdocs_indexing_complete", { {"$duration", durationMs / 1000.} }); - m_indexingTimer.invalidate(); - } -} - -void Database::updateFolderStatus(int folder_id, Database::FolderStatus status, int numDocs, bool atStart, bool isNew) { - FolderStatusRecord *lastRecord = nullptr; - if (m_foldersBeingIndexed.contains(folder_id)) { - 
lastRecord = &m_foldersBeingIndexed[folder_id]; - } - Q_ASSERT(lastRecord || status == FolderStatus::Started); - - switch (status) { - case FolderStatus::Started: - if (lastRecord == nullptr) { - // record timestamp but don't send an event yet - m_foldersBeingIndexed.insert(folder_id, { QDateTime::currentMSecsSinceEpoch(), isNew, numDocs }); - emit updateIndexing(folder_id, true); - } - break; - case FolderStatus::Embedding: - if (!lastRecord->docsChanged) { - Q_ASSERT(atStart); - // send start event with the original timestamp for folders that need updating - const auto *embeddingModels = ModelList::globalInstance()->installedEmbeddingModels(); - Network::globalInstance()->trackEvent("localdocs_folder_indexing", { - {"folder_id", folder_id}, - {"is_new_collection", lastRecord->isNew}, - {"document_count", lastRecord->numDocs}, - {"embedding_model", embeddingModels->defaultModelInfo().filename()}, - {"chunk_size", m_chunkSize}, - {"time", lastRecord->startTime}, - }); - } - lastRecord->docsChanged += atStart; - lastRecord->chunksRead++; - break; - case FolderStatus::Complete: - if (lastRecord->docsChanged) { - // send complete event for folders that were updated - qint64 durationMs = QDateTime::currentMSecsSinceEpoch() - lastRecord->startTime; - Network::globalInstance()->trackEvent("localdocs_folder_complete", { - {"folder_id", folder_id}, - {"is_new_collection", lastRecord->isNew}, - {"documents_total", lastRecord->numDocs}, - {"documents_changed", lastRecord->docsChanged}, - {"chunks_read", lastRecord->chunksRead}, - {"$duration", durationMs / 1000.}, - }); - } - m_foldersBeingIndexed.remove(folder_id); - emit updateIndexing(folder_id, false); - break; - } + if (folder_id != -1) + scanDocuments(folder_id, path); } diff --git a/gpt4all-chat/database.h b/gpt4all-chat/database.h index 34cf7681..0f716e14 100644 --- a/gpt4all-chat/database.h +++ b/gpt4all-chat/database.h @@ -3,28 +3,39 @@ #include "embllm.h" // IWYU pragma: keep -#include +#include #include +#include 
#include #include #include #include #include +#include +#include #include +#include #include #include -#include -#include #include -class EmbeddingLLM; -class Embeddings; +using namespace Qt::Literals::StringLiterals; + class QFileSystemWatcher; class QSqlError; class QTextStream; class QTimer; +/* Version 0: GPT4All v2.4.3, full-text search + * Version 1: GPT4All v2.5.3, embeddings in hsnwlib + * Version 2: GPT4All v3.0.0, embeddings in sqlite */ + +// minimum supported version +static const int LOCALDOCS_MIN_VER = 1; +// current version +static const int LOCALDOCS_VERSION = 2; + struct DocumentInfo { int folder; @@ -33,34 +44,82 @@ struct DocumentInfo size_t currentPosition = 0; bool currentlyProcessing = false; bool isPdf() const { - return doc.suffix() == QLatin1String("pdf"); + return doc.suffix() == u"pdf"_s; } }; struct ResultInfo { - QString file; // [Required] The name of the file, but not the full path - QString title; // [Optional] The title of the document - QString author; // [Optional] The author of the document - QString date; // [Required] The creation or the last modification date whichever is latest - QString text; // [Required] The text actually used in the augmented context - int page = -1; // [Optional] The page where the text was found - int from = -1; // [Optional] The line number where the text begins - int to = -1; // [Optional] The line number where the text ends + QString collection; // [Required] The name of the collection + QString path; // [Required] The full path + QString file; // [Required] The name of the file, but not the full path + QString title; // [Optional] The title of the document + QString author; // [Optional] The author of the document + QString date; // [Required] The creation or the last modification date whichever is latest + QString text; // [Required] The text actually used in the augmented context + int page = -1; // [Optional] The page where the text was found + int from = -1; // [Optional] The line number where 
the text begins + int to = -1; // [Optional] The line number where the text ends + + bool operator==(const ResultInfo &other) const { + return file == other.file && + title == other.title && + author == other.author && + date == other.date && + text == other.text && + page == other.page && + from == other.from && + to == other.to; + } + bool operator!=(const ResultInfo &other) const { + return !(*this == other); + } + + Q_GADGET + Q_PROPERTY(QString collection MEMBER collection) + Q_PROPERTY(QString path MEMBER path) + Q_PROPERTY(QString file MEMBER file) + Q_PROPERTY(QString title MEMBER title) + Q_PROPERTY(QString author MEMBER author) + Q_PROPERTY(QString date MEMBER date) + Q_PROPERTY(QString text MEMBER text) + Q_PROPERTY(int page MEMBER page) + Q_PROPERTY(int from MEMBER from) + Q_PROPERTY(int to MEMBER to) }; +Q_DECLARE_METATYPE(ResultInfo) + struct CollectionItem { + // -- Fields persisted to database -- + + int collection_id = -1; + int folder_id = -1; QString collection; QString folder_path; - int folder_id = -1; + QString embeddingModel; + + // -- Transient fields -- + bool installed = false; bool indexing = false; + bool forceIndexing = false; QString error; + + // progress int currentDocsToIndex = 0; int totalDocsToIndex = 0; size_t currentBytesToIndex = 0; size_t totalBytesToIndex = 0; size_t currentEmbeddingsToIndex = 0; size_t totalEmbeddingsToIndex = 0; + + // statistics + size_t totalDocs = 0; + size_t totalWords = 0; + size_t totalTokens = 0; + QDateTime startUpdate; + QDateTime lastUpdate; + QString fileCurrentlyProcessing; }; Q_DECLARE_METATYPE(CollectionItem) @@ -68,53 +127,55 @@ class Database : public QObject { Q_OBJECT public: - Database(int chunkSize); - virtual ~Database(); + Database(int chunkSize, QStringList extensions); + ~Database() override; + + bool isValid() const { return m_databaseValid; } public Q_SLOTS: void start(); - void scanQueue(); - void scanDocuments(int folder_id, const QString &folder_path, bool isNew); - bool 
addFolder(const QString &collection, const QString &path, bool fromDb); + void scanQueueBatch(); + void scanDocuments(int folder_id, const QString &folder_path); + void forceIndexing(const QString &collection, const QString &embedding_model); + void forceRebuildFolder(const QString &path); + bool addFolder(const QString &collection, const QString &path, const QString &embedding_model); void removeFolder(const QString &collection, const QString &path); void retrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); - void cleanDB(); void changeChunkSize(int chunkSize); + void changeFileExtensions(const QStringList &extensions); Q_SIGNALS: - void docsToScanChanged(); - void updateInstalled(int folder_id, bool b); - void updateIndexing(int folder_id, bool b); - void updateError(int folder_id, const QString &error); - void updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex); - void updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex); - void subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes); - void updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex); - void updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex); - void updateCurrentEmbeddingsToIndex(int folder_id, size_t currentBytesToIndex); - void updateTotalEmbeddingsToIndex(int folder_id, size_t totalBytesToIndex); - void addCollectionItem(const CollectionItem &item, bool fromDb); - void removeFolderById(int folder_id); - void collectionListUpdated(const QList &collectionList); + // Signals for the gui only + void requestUpdateGuiForCollectionItem(const CollectionItem &item); + void requestAddGuiCollectionItem(const CollectionItem &item); + void requestRemoveGuiFolderById(const QString &collection, int folder_id); + void requestGuiCollectionListUpdated(const QList &collectionList); + void databaseValidChanged(); private Q_SLOTS: void directoryChanged(const QString &path); - bool addFolderToWatch(const QString 
&path); - bool removeFolderFromWatch(const QString &path); - int addCurrentFolders(); + void addCurrentFolders(); void handleEmbeddingsGenerated(const QVector &embeddings); - void handleErrorGenerated(int folder_id, const QString &error); + void handleErrorGenerated(const QVector &chunks, const QString &error); private: - enum class FolderStatus { Started, Embedding, Complete }; - struct FolderStatusRecord { qint64 startTime; bool isNew; int numDocs, docsChanged, chunksRead; }; + void transaction(); + void commit(); + void rollback(); - void removeFolderInternal(const QString &collection, int folder_id, const QString &path); - size_t chunkStream(QTextStream &stream, int folder_id, int document_id, const QString &file, - const QString &title, const QString &author, const QString &subject, const QString &keywords, int page, - int maxChunks = -1); - void removeEmbeddingsByDocumentId(int document_id); - void scheduleNext(int folder_id, size_t countForFolder); + bool hasContent(); + // not found -> 0, , exists and has content -> 1, error -> -1 + int openDatabase(const QString &modelPath, bool create = true, int ver = LOCALDOCS_VERSION); + bool openLatestDb(const QString &modelPath, QList &oldCollections); + bool initDb(const QString &modelPath, const QList &oldCollections); + int checkAndAddFolderToDB(const QString &path); + bool removeFolderInternal(const QString &collection, int folder_id, const QString &path); + size_t chunkStream(QTextStream &stream, int folder_id, int document_id, const QString &embedding_model, + const QString &file, const QString &title, const QString &author, const QString &subject, + const QString &keywords, int page, int maxChunks = -1); + void appendChunk(const EmbeddingChunk &chunk); + void sendChunkList(); + void updateFolderToIndex(int folder_id, size_t countForFolder, bool sendChunks = true); void handleDocumentError(const QString &errorMessage, int document_id, const QString &document_path, const QSqlError &error); size_t 
countOfDocuments(int folder_id) const; @@ -123,20 +184,37 @@ private: void removeFolderFromDocumentQueue(int folder_id); void enqueueDocumentInternal(const DocumentInfo &info, bool prepend = false); void enqueueDocuments(int folder_id, const QVector &infos); - void updateIndexingStatus(); - void updateFolderStatus(int folder_id, FolderStatus status, int numDocs = -1, bool atStart = false, bool isNew = false); + void scanQueue(); + bool cleanDB(); + void addFolderToWatch(const QString &path); + void removeFolderFromWatch(const QString &path); + QList searchEmbeddings(const std::vector &query, const QList &collections, int nNeighbors); + + void setStartUpdateTime(CollectionItem &item); + void setLastUpdateTime(CollectionItem &item); + + CollectionItem guiCollectionItem(int folder_id) const; + void updateGuiForCollectionItem(const CollectionItem &item); + void addGuiCollectionItem(const CollectionItem &item); + void removeGuiFolderById(const QString &collection, int folder_id); + void guiCollectionListUpdated(const QList &collectionList); + void scheduleUncompletedEmbeddings(); + void updateCollectionStatistics(); private: + QSqlDatabase m_db; int m_chunkSize; + QStringList m_scannedFileExtensions; QTimer *m_scanTimer; QMap> m_docsToScan; - QElapsedTimer m_indexingTimer; - QMap m_foldersBeingIndexed; QList m_retrieve; QThread m_dbThread; QFileSystemWatcher *m_watcher; + QSet m_watchedPaths; EmbeddingLLM *m_embLLM; - Embeddings *m_embeddings; + QVector m_chunkList; + QHash m_collectionMap; // used only for tracking indexing/embedding progress + std::atomic m_databaseValid; }; #endif // DATABASE_H diff --git a/gpt4all-chat/download.cpp b/gpt4all-chat/download.cpp index 66037736..6a138698 100644 --- a/gpt4all-chat/download.cpp +++ b/gpt4all-chat/download.cpp @@ -31,6 +31,8 @@ #include #include +using namespace Qt::Literals::StringLiterals; + class MyDownload: public Download { }; Q_GLOBAL_STATIC(MyDownload, downloadInstance) Download *Download::globalInstance() @@ -48,15 
+50,18 @@ Download::Download() &Download::handleHashAndSaveFinished, Qt::QueuedConnection); connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, &Download::handleSslErrors); + updateLatestNews(); updateReleaseNotes(); m_startTime = QDateTime::currentDateTime(); } -static bool operator==(const ReleaseInfo& lhs, const ReleaseInfo& rhs) { +static bool operator==(const ReleaseInfo& lhs, const ReleaseInfo& rhs) +{ return lhs.version == rhs.version; } -static bool compareVersions(const QString &a, const QString &b) { +static bool compareVersions(const QString &a, const QString &b) +{ QStringList aParts = a.split('.'); QStringList bParts = b.split('.'); @@ -79,6 +84,8 @@ ReleaseInfo Download::releaseInfo() const const QString currentVersion = QCoreApplication::applicationVersion(); if (m_releaseMap.contains(currentVersion)) return m_releaseMap.value(currentVersion); + if (!m_releaseMap.empty()) + return m_releaseMap.last(); return ReleaseInfo(); } @@ -97,7 +104,6 @@ bool Download::isFirstStart(bool writeVersion) const auto *mySettings = MySettings::globalInstance(); QSettings settings; - settings.sync(); QString lastVersionStarted = settings.value("download/lastVersionStarted").toString(); bool first = lastVersionStarted != QCoreApplication::applicationVersion(); if (first && writeVersion) { @@ -105,7 +111,6 @@ bool Download::isFirstStart(bool writeVersion) const // let the user select these again settings.remove("network/usageStatsActive"); settings.remove("network/isActive"); - settings.sync(); emit mySettings->networkUsageStatsActiveChanged(); emit mySettings->networkIsActiveChanged(); } @@ -125,15 +130,26 @@ void Download::updateReleaseNotes() connect(jsonReply, &QNetworkReply::finished, this, &Download::handleReleaseJsonDownloadFinished); } +void Download::updateLatestNews() +{ + QUrl url("http://gpt4all.io/meta/latestnews.md"); + QNetworkRequest request(url); + QSslConfiguration conf = request.sslConfiguration(); + 
conf.setPeerVerifyMode(QSslSocket::VerifyNone); + request.setSslConfiguration(conf); + QNetworkReply *reply = m_networkManager.get(request); + connect(qGuiApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort); + connect(reply, &QNetworkReply::finished, this, &Download::handleLatestNewsDownloadFinished); +} + void Download::downloadModel(const QString &modelFile) { QFile *tempFile = new QFile(ModelList::globalInstance()->incompleteDownloadPath(modelFile)); - QDateTime modTime = tempFile->fileTime(QFile::FileModificationTime); bool success = tempFile->open(QIODevice::WriteOnly | QIODevice::Append); qWarning() << "Opening temp file for writing:" << tempFile->fileName(); if (!success) { const QString error - = QString("ERROR: Could not open temp file: %1 %2").arg(tempFile->fileName()).arg(modelFile); + = u"ERROR: Could not open temp file: %1 %2"_s.arg(tempFile->fileName(), modelFile); qWarning() << error; clearRetry(modelFile); ModelList::globalInstance()->updateDataByFilename(modelFile, {{ ModelList::DownloadErrorRole, error }}); @@ -161,7 +177,7 @@ void Download::downloadModel(const QString &modelFile) Network::globalInstance()->trackEvent("download_started", { {"model", modelFile} }); QNetworkRequest request(url); request.setAttribute(QNetworkRequest::User, modelFile); - request.setRawHeader("range", QString("bytes=%1-").arg(tempFile->pos()).toUtf8()); + request.setRawHeader("range", u"bytes=%1-"_s.arg(tempFile->pos()).toUtf8()); QSslConfiguration conf = request.sslConfiguration(); conf.setPeerVerifyMode(QSslSocket::VerifyNone); request.setSslConfiguration(conf); @@ -176,8 +192,7 @@ void Download::downloadModel(const QString &modelFile) void Download::cancelDownload(const QString &modelFile) { - for (int i = 0; i < m_activeDownloads.size(); ++i) { - QNetworkReply *modelReply = m_activeDownloads.keys().at(i); + for (auto [modelReply, tempFile]: m_activeDownloads.asKeyValueRange()) { QUrl url = modelReply->request().url(); if 
(url.toString().endsWith(modelFile)) { Network::globalInstance()->trackEvent("download_canceled", { {"model", modelFile} }); @@ -189,7 +204,6 @@ void Download::cancelDownload(const QString &modelFile) modelReply->abort(); // Abort the download modelReply->deleteLater(); // Schedule the reply for deletion - QFile *tempFile = m_activeDownloads.value(modelReply); tempFile->deleteLater(); m_activeDownloads.remove(modelReply); @@ -308,6 +322,24 @@ void Download::parseReleaseJsonFile(const QByteArray &jsonData) emit releaseInfoChanged(); } +void Download::handleLatestNewsDownloadFinished() +{ + QNetworkReply *reply = qobject_cast(sender()); + if (!reply) + return; + + if (reply->error() != QNetworkReply::NoError) { + qWarning() << "ERROR: network error occurred attempting to download latest news:" << reply->errorString(); + reply->deleteLater(); + return; + } + + QByteArray responseData = reply->readAll(); + m_latestNews = QString::fromUtf8(responseData); + reply->deleteLater(); + emit latestNewsChanged(); +} + bool Download::hasRetry(const QString &filename) const { return m_activeRetries.contains(filename); @@ -354,7 +386,7 @@ void Download::handleErrorOccurred(QNetworkReply::NetworkError code) clearRetry(modelFilename); const QString error - = QString("ERROR: Network error occurred attempting to download %1 code: %2 errorString %3") + = u"ERROR: Network error occurred attempting to download %1 code: %2 errorString %3"_s .arg(modelFilename) .arg(code) .arg(modelReply->errorString()); @@ -428,7 +460,7 @@ void HashAndSaveFile::hashAndSave(const QString &expectedHash, QCryptographicHas // Reopen the tempFile for hashing if (!tempFile->open(QIODevice::ReadOnly)) { const QString error - = QString("ERROR: Could not open temp file for hashing: %1 %2").arg(tempFile->fileName()).arg(modelFilename); + = u"ERROR: Could not open temp file for hashing: %1 %2"_s.arg(tempFile->fileName(), modelFilename); qWarning() << error; emit hashAndSaveFinished(false, error, tempFile, 
modelReply); return; @@ -440,10 +472,8 @@ void HashAndSaveFile::hashAndSave(const QString &expectedHash, QCryptographicHas if (hash.result().toHex() != expectedHash.toLatin1()) { tempFile->close(); const QString error - = QString("ERROR: Download error hash did not match: %1 != %2 for %3") - .arg(hash.result().toHex()) - .arg(expectedHash.toLatin1()) - .arg(modelFilename); + = u"ERROR: Download error hash did not match: %1 != %2 for %3"_s + .arg(hash.result().toHex(), expectedHash.toLatin1(), modelFilename); qWarning() << error; tempFile->remove(); emit hashAndSaveFinished(false, error, tempFile, modelReply); @@ -464,7 +494,7 @@ void HashAndSaveFile::hashAndSave(const QString &expectedHash, QCryptographicHas // Reopen the tempFile for copying if (!tempFile->open(QIODevice::ReadOnly)) { const QString error - = QString("ERROR: Could not open temp file at finish: %1 %2").arg(tempFile->fileName()).arg(modelFilename); + = u"ERROR: Could not open temp file at finish: %1 %2"_s.arg(tempFile->fileName(), modelFilename); qWarning() << error; emit hashAndSaveFinished(false, error, tempFile, modelReply); return; @@ -484,7 +514,7 @@ void HashAndSaveFile::hashAndSave(const QString &expectedHash, QCryptographicHas } else { QFile::FileError error = file.error(); const QString errorString - = QString("ERROR: Could not save model to location: %1 failed with code %1").arg(saveFilePath).arg(error); + = u"ERROR: Could not save model to location: %1 failed with code %1"_s.arg(saveFilePath).arg(error); qWarning() << errorString; tempFile->close(); emit hashAndSaveFinished(false, errorString, tempFile, modelReply); @@ -505,7 +535,7 @@ void Download::handleModelDownloadFinished() if (modelReply->error()) { const QString errorString - = QString("ERROR: Downloading failed with code %1 \"%2\"").arg(modelReply->error()).arg(modelReply->errorString()); + = u"ERROR: Downloading failed with code %1 \"%2\""_s.arg(modelReply->error()).arg(modelReply->errorString()); qWarning() << errorString; 
modelReply->deleteLater(); tempFile->deleteLater(); diff --git a/gpt4all-chat/download.h b/gpt4all-chat/download.h index 8c1bb021..7fcedd2a 100644 --- a/gpt4all-chat/download.h +++ b/gpt4all-chat/download.h @@ -52,12 +52,14 @@ class Download : public QObject Q_OBJECT Q_PROPERTY(bool hasNewerRelease READ hasNewerRelease NOTIFY hasNewerReleaseChanged) Q_PROPERTY(ReleaseInfo releaseInfo READ releaseInfo NOTIFY releaseInfoChanged) + Q_PROPERTY(QString latestNews READ latestNews NOTIFY latestNewsChanged) public: static Download *globalInstance(); ReleaseInfo releaseInfo() const; bool hasNewerRelease() const; + QString latestNews() const { return m_latestNews; } Q_INVOKABLE void downloadModel(const QString &modelFile); Q_INVOKABLE void cancelDownload(const QString &modelFile); Q_INVOKABLE void installModel(const QString &modelFile, const QString &apiKey); @@ -65,11 +67,13 @@ public: Q_INVOKABLE bool isFirstStart(bool writeVersion = false) const; public Q_SLOTS: + void updateLatestNews(); void updateReleaseNotes(); private Q_SLOTS: void handleSslErrors(QNetworkReply *reply, const QList &errors); void handleReleaseJsonDownloadFinished(); + void handleLatestNewsDownloadFinished(); void handleErrorOccurred(QNetworkReply::NetworkError code); void handleDownloadProgress(qint64 bytesReceived, qint64 bytesTotal); void handleModelDownloadFinished(); @@ -82,6 +86,7 @@ Q_SIGNALS: void hasNewerReleaseChanged(); void requestHashAndSave(const QString &hash, QCryptographicHash::Algorithm a, const QString &saveFilePath, QFile *tempFile, QNetworkReply *modelReply); + void latestNewsChanged(); private: void parseReleaseJsonFile(const QByteArray &jsonData); @@ -92,6 +97,7 @@ private: HashAndSaveFile *m_hashAndSave; QMap m_releaseMap; + QString m_latestNews; QNetworkAccessManager m_networkManager; QMap m_activeDownloads; QHash m_activeRetries; diff --git a/gpt4all-chat/embeddings.cpp b/gpt4all-chat/embeddings.cpp deleted file mode 100644 index 97890d97..00000000 --- 
a/gpt4all-chat/embeddings.cpp +++ /dev/null @@ -1,202 +0,0 @@ -#include "embeddings.h" - -#include "mysettings.h" - -#include "hnswlib/hnswlib.h" - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#define EMBEDDINGS_VERSION 0 - -const int s_dim = 384; // Dimension of the elements -const int s_ef_construction = 200; // Controls index search speed/build speed tradeoff -const int s_M = 16; // Tightly connected with internal dimensionality of the data - // strongly affects the memory consumption - -Embeddings::Embeddings(QObject *parent) - : QObject(parent) - , m_space(nullptr) - , m_hnsw(nullptr) -{ - m_filePath = MySettings::globalInstance()->modelPath() - + QString("embeddings_v%1.dat").arg(EMBEDDINGS_VERSION); -} - -Embeddings::~Embeddings() -{ - delete m_hnsw; - m_hnsw = nullptr; - delete m_space; - m_space = nullptr; -} - -bool Embeddings::load() -{ - QFileInfo info(m_filePath); - if (!info.exists()) { - qWarning() << "ERROR: loading embeddings file does not exist" << m_filePath; - return false; - } - - if (!info.isReadable()) { - qWarning() << "ERROR: loading embeddings file is not readable" << m_filePath; - return false; - } - - if (!info.isWritable()) { - qWarning() << "ERROR: loading embeddings file is not writeable" << m_filePath; - return false; - } - - try { - m_space = new hnswlib::InnerProductSpace(s_dim); - m_hnsw = new hnswlib::HierarchicalNSW(m_space, m_filePath.toStdString(), s_M, s_ef_construction); - } catch (const std::exception &e) { - qWarning() << "ERROR: could not load hnswlib index:" << e.what(); - return false; - } - return isLoaded(); -} - -bool Embeddings::load(qint64 maxElements) -{ - try { - m_space = new hnswlib::InnerProductSpace(s_dim); - m_hnsw = new hnswlib::HierarchicalNSW(m_space, maxElements, s_M, s_ef_construction); - } catch (const std::exception &e) { - qWarning() << "ERROR: could not create hnswlib index:" << e.what(); - return false; - } - return isLoaded(); -} - -bool 
Embeddings::save() -{ - if (!isLoaded()) - return false; - try { - m_hnsw->saveIndex(m_filePath.toStdString()); - } catch (const std::exception &e) { - qWarning() << "ERROR: could not save hnswlib index:" << e.what(); - return false; - } - return true; -} - -bool Embeddings::isLoaded() const -{ - return m_hnsw != nullptr; -} - -bool Embeddings::fileExists() const -{ - QFileInfo info(m_filePath); - return info.exists(); -} - -bool Embeddings::resize(qint64 size) -{ - if (!isLoaded()) { - qWarning() << "ERROR: attempting to resize an embedding when the embeddings are not open!"; - return false; - } - - Q_ASSERT(m_hnsw); - try { - m_hnsw->resizeIndex(size); - } catch (const std::exception &e) { - qWarning() << "ERROR: could not resize hnswlib index:" << e.what(); - return false; - } - return true; -} - -bool Embeddings::add(const std::vector &embedding, qint64 label) -{ - if (!isLoaded()) { - bool success = load(500); - if (!success) { - qWarning() << "ERROR: attempting to add an embedding when the embeddings are not open!"; - return false; - } - } - - Q_ASSERT(m_hnsw); - if (m_hnsw->cur_element_count + 1 > m_hnsw->max_elements_) { - if (!resize(m_hnsw->max_elements_ + 500)) { - return false; - } - } - - if (embedding.empty()) - return false; - - try { - m_hnsw->addPoint(embedding.data(), label, false); - } catch (const std::exception &e) { - qWarning() << "ERROR: could not add embedding to hnswlib index:" << e.what(); - return false; - } - return true; -} - -void Embeddings::remove(qint64 label) -{ - if (!isLoaded()) { - qWarning() << "ERROR: attempting to remove an embedding when the embeddings are not open!"; - return; - } - - Q_ASSERT(m_hnsw); - try { - m_hnsw->markDelete(label); - } catch (const std::exception &e) { - qWarning() << "ERROR: could not add remove embedding from hnswlib index:" << e.what(); - } -} - -void Embeddings::clear() -{ - delete m_hnsw; - m_hnsw = nullptr; - delete m_space; - m_space = nullptr; -} - -std::vector Embeddings::search(const 
std::vector &embedding, int K) -{ - if (!isLoaded()) - return {}; - - Q_ASSERT(m_hnsw); - std::priority_queue> result; - try { - result = m_hnsw->searchKnn(embedding.data(), K); - } catch (const std::exception &e) { - qWarning() << "ERROR: could not search hnswlib index:" << e.what(); - return {}; - } - - std::vector neighbors; - neighbors.reserve(K); - - while(!result.empty()) { - neighbors.push_back(result.top().second); - result.pop(); - } - - // Reverse the neighbors, as the top of the priority queue is the farthest neighbor. - std::reverse(neighbors.begin(), neighbors.end()); - - return neighbors; -} diff --git a/gpt4all-chat/embeddings.h b/gpt4all-chat/embeddings.h deleted file mode 100644 index 165e61c0..00000000 --- a/gpt4all-chat/embeddings.h +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef EMBEDDINGS_H -#define EMBEDDINGS_H - -#include -#include -#include - -#include - -namespace hnswlib { - class InnerProductSpace; - template class HierarchicalNSW; -} - -class Embeddings : public QObject -{ - Q_OBJECT -public: - Embeddings(QObject *parent); - virtual ~Embeddings(); - - bool load(); - bool load(qint64 maxElements); - bool save(); - bool isLoaded() const; - bool fileExists() const; - bool resize(qint64 size); - - // Adds the embedding and returns the label used - bool add(const std::vector &embedding, qint64 label); - - // Removes the embedding at label by marking it as unused - void remove(qint64 label); - - // Clears the embeddings - void clear(); - - // Performs a nearest neighbor search of the embeddings and returns a vector of labels - // for the K nearest neighbors of the given embedding - std::vector search(const std::vector &embedding, int K); - -private: - QString m_filePath; - hnswlib::InnerProductSpace *m_space; - hnswlib::HierarchicalNSW *m_hnsw; -}; - -#endif // EMBEDDINGS_H diff --git a/gpt4all-chat/embllm.cpp b/gpt4all-chat/embllm.cpp index fe6fb4cb..b16b1616 100644 --- a/gpt4all-chat/embllm.cpp +++ b/gpt4all-chat/embllm.cpp @@ -1,6 +1,7 @@ #include 
"embllm.h" #include "modellist.h" +#include "mysettings.h" #include "../gpt4all-backend/llmodel.h" @@ -13,8 +14,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -24,16 +25,20 @@ #include #include -#include #include +using namespace Qt::Literals::StringLiterals; + +static const QString EMBEDDING_MODEL_NAME = u"nomic-embed-text-v1.5"_s; +static const QString LOCAL_EMBEDDING_MODEL = u"nomic-embed-text-v1.5.f16.gguf"_s; + EmbeddingLLMWorker::EmbeddingLLMWorker() : QObject(nullptr) , m_networkManager(new QNetworkAccessManager(this)) - , m_model(nullptr) , m_stopGenerating(false) { moveToThread(&m_workerThread); + connect(this, &EmbeddingLLMWorker::requestAtlasQueryEmbedding, this, &EmbeddingLLMWorker::atlasQueryEmbeddingRequested); connect(this, &EmbeddingLLMWorker::finished, &m_workerThread, &QThread::quit, Qt::DirectConnection); m_workerThread.setObjectName("embedding"); m_workerThread.start(); @@ -58,44 +63,31 @@ void EmbeddingLLMWorker::wait() bool EmbeddingLLMWorker::loadModel() { - const EmbeddingModels *embeddingModels = ModelList::globalInstance()->installedEmbeddingModels(); - if (!embeddingModels->count()) - return false; + m_nomicAPIKey.clear(); + m_model = nullptr; - const ModelInfo defaultModel = embeddingModels->defaultModelInfo(); - - QString filePath = defaultModel.dirpath + defaultModel.filename(); - QFileInfo fileInfo(filePath); - if (!fileInfo.exists()) { - qWarning() << "WARNING: Could not load sbert because file does not exist"; - m_model = nullptr; - return false; + if (MySettings::globalInstance()->localDocsUseRemoteEmbed()) { + m_nomicAPIKey = MySettings::globalInstance()->localDocsNomicAPIKey(); + return true; } - auto filename = fileInfo.fileName(); - bool isNomic = filename.startsWith("gpt4all-nomic-") && filename.endsWith(".rmodel"); - if (isNomic) { - QFile file(filePath); - if (!file.open(QIODeviceBase::ReadOnly)) { - qWarning() << "failed to open" << filePath << ":" << file.errorString(); - m_model 
= nullptr; - return false; - } - QJsonDocument doc = QJsonDocument::fromJson(file.readAll()); - QJsonObject obj = doc.object(); - m_nomicAPIKey = obj["apiKey"].toString(); - file.close(); - return true; + QString filePath = u"%1/../resources/%2"_s.arg(QCoreApplication::applicationDirPath(), LOCAL_EMBEDDING_MODEL); + if (!QFileInfo::exists(filePath)) { + qWarning() << "WARNING: Local embedding model not found"; + return false; } try { m_model = LLModel::Implementation::construct(filePath.toStdString()); } catch (const std::exception &e) { qWarning() << "WARNING: Could not load embedding model:" << e.what(); - m_model = nullptr; return false; } + // FIXME(jared): the user may want this to take effect without having to restart + int n_threads = MySettings::globalInstance()->threadCount(); + m_model->setThreadCount(n_threads); + // NOTE: explicitly loads model on CPU to avoid GPU OOM // TODO(cebtenzzre): support GPU-accelerated embeddings bool success = m_model->loadModel(filePath.toStdString(), 2048, 0); @@ -115,31 +107,38 @@ bool EmbeddingLLMWorker::loadModel() return true; } -bool EmbeddingLLMWorker::hasModel() const +std::vector EmbeddingLLMWorker::generateQueryEmbedding(const QString &text) { - return m_model || !m_nomicAPIKey.isEmpty(); -} + { + QMutexLocker locker(&m_mutex); -bool EmbeddingLLMWorker::isNomic() const -{ - return !m_nomicAPIKey.isEmpty(); -} + if (!hasModel() && !loadModel()) { + qWarning() << "WARNING: Could not load model for embeddings"; + return {}; + } -// this function is always called for retrieval tasks -std::vector EmbeddingLLMWorker::generateSyncEmbedding(const QString &text) -{ - Q_ASSERT(!isNomic()); - std::vector embedding(m_model->embeddingSize()); - try { - m_model->embed({text.toStdString()}, embedding.data(), true); - } catch (const std::exception &e) { - qWarning() << "WARNING: LLModel::embed failed: " << e.what(); - return {}; + if (!isNomic()) { + std::vector embedding(m_model->embeddingSize()); + + try { + 
m_model->embed({text.toStdString()}, embedding.data(), true); + } catch (const std::exception &e) { + qWarning() << "WARNING: LLModel::embed failed:" << e.what(); + return {}; + } + + return embedding; + } } - return embedding; + + EmbeddingLLMWorker worker; + emit worker.requestAtlasQueryEmbedding(text); + worker.wait(); + return worker.lastResponse(); } -void EmbeddingLLMWorker::sendAtlasRequest(const QStringList &texts, const QString &taskType, QVariant userData) { +void EmbeddingLLMWorker::sendAtlasRequest(const QStringList &texts, const QString &taskType, const QVariant &userData) +{ QJsonObject root; root.insert("model", "nomic-embed-text-v1"); root.insert("texts", QJsonArray::fromStringList(texts)); @@ -148,7 +147,7 @@ void EmbeddingLLMWorker::sendAtlasRequest(const QStringList &texts, const QStrin QJsonDocument doc(root); QUrl nomicUrl("https://api-atlas.nomic.ai/v1/embedding/text"); - const QString authorization = QString("Bearer %1").arg(m_nomicAPIKey).trimmed(); + const QString authorization = u"Bearer %1"_s.arg(m_nomicAPIKey).trimmed(); QNetworkRequest request(nomicUrl); request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); request.setRawHeader("Authorization", authorization.toUtf8()); @@ -158,50 +157,63 @@ void EmbeddingLLMWorker::sendAtlasRequest(const QStringList &texts, const QStrin connect(reply, &QNetworkReply::finished, this, &EmbeddingLLMWorker::handleFinished); } -// this function is always called for retrieval tasks -void EmbeddingLLMWorker::requestSyncEmbedding(const QString &text) +void EmbeddingLLMWorker::atlasQueryEmbeddingRequested(const QString &text) { - if (!hasModel() && !loadModel()) { - qWarning() << "WARNING: Could not load model for embeddings"; - return; - } + { + QMutexLocker locker(&m_mutex); + if (!hasModel() && !loadModel()) { + qWarning() << "WARNING: Could not load model for embeddings"; + return; + } - if (!isNomic()) { - qWarning() << "WARNING: Request to generate sync embeddings for local model 
invalid"; - return; - } + if (!isNomic()) { + qWarning() << "WARNING: Request to generate sync embeddings for local model invalid"; + return; + } - Q_ASSERT(hasModel()); + Q_ASSERT(hasModel()); + } sendAtlasRequest({text}, "search_query"); } -// this function is always called for storage into the database -void EmbeddingLLMWorker::requestAsyncEmbedding(const QVector &chunks) +void EmbeddingLLMWorker::docEmbeddingsRequested(const QVector &chunks) { if (m_stopGenerating) return; - if (!hasModel() && !loadModel()) { - qWarning() << "WARNING: Could not load model for embeddings"; - return; + bool isNomic; + { + QMutexLocker locker(&m_mutex); + if (!hasModel() && !loadModel()) { + qWarning() << "WARNING: Could not load model for embeddings"; + return; + } + + isNomic = this->isNomic(); } - if (m_nomicAPIKey.isEmpty()) { + if (!isNomic) { QVector results; results.reserve(chunks.size()); - for (auto c : chunks) { + for (const auto &c: chunks) { EmbeddingResult result; + result.model = c.model; result.folder_id = c.folder_id; result.chunk_id = c.chunk_id; // TODO(cebtenzzre): take advantage of batched embeddings result.embedding.resize(m_model->embeddingSize()); - try { - m_model->embed({c.chunk.toStdString()}, result.embedding.data(), false); - } catch (const std::exception &e) { - qWarning() << "WARNING: LLModel::embed failed:" << e.what(); - return; + + { + QMutexLocker locker(&m_mutex); + try { + m_model->embed({c.chunk.toStdString()}, result.embedding.data(), false); + } catch (const std::exception &e) { + qWarning() << "WARNING: LLModel::embed failed:" << e.what(); + return; + } } + results << result; } emit embeddingsGenerated(results); @@ -214,14 +226,15 @@ void EmbeddingLLMWorker::requestAsyncEmbedding(const QVector &ch sendAtlasRequest(texts, "search_document", QVariant::fromValue(chunks)); } -std::vector jsonArrayToVector(const QJsonArray &jsonArray) { +std::vector jsonArrayToVector(const QJsonArray &jsonArray) +{ std::vector result; - for (const QJsonValue 
&innerValue : jsonArray) { + for (const auto &innerValue: jsonArray) { if (innerValue.isArray()) { QJsonArray innerArray = innerValue.toArray(); result.reserve(result.size() + innerArray.size()); - for (const QJsonValue &value : innerArray) { + for (const auto &value: innerArray) { result.push_back(static_cast(value.toDouble())); } } @@ -230,7 +243,8 @@ std::vector jsonArrayToVector(const QJsonArray &jsonArray) { return result; } -QVector jsonArrayToEmbeddingResults(const QVector& chunks, const QJsonArray& embeddings) { +QVector jsonArrayToEmbeddingResults(const QVector& chunks, const QJsonArray& embeddings) +{ QVector results; if (chunks.size() != embeddings.size()) { @@ -243,10 +257,11 @@ QVector jsonArrayToEmbeddingResults(const QVector embeddingVector; - for (const QJsonValue& value : embeddingArray) + for (const auto &value: embeddingArray) embeddingVector.push_back(static_cast(value.toDouble())); EmbeddingResult result; + result.model = chunk.model; result.folder_id = chunk.folder_id; result.chunk_id = chunk.chunk_id; result.embedding = std::move(embeddingVector); @@ -267,10 +282,6 @@ void EmbeddingLLMWorker::handleFinished() if (retrievedData.isValid() && retrievedData.canConvert>()) chunks = retrievedData.value>(); - int folder_id = 0; - if (!chunks.isEmpty()) - folder_id = chunks.first().folder_id; - QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute); Q_ASSERT(response.isValid()); bool ok; @@ -279,13 +290,13 @@ void EmbeddingLLMWorker::handleFinished() QString errorDetails; QString replyErrorString = reply->errorString().trimmed(); QByteArray replyContent = reply->readAll().trimmed(); - errorDetails = QString("ERROR: Nomic Atlas responded with error code \"%1\"").arg(code); + errorDetails = u"ERROR: Nomic Atlas responded with error code \"%1\""_s.arg(code); if (!replyErrorString.isEmpty()) - errorDetails += QString(". Error Details: \"%1\"").arg(replyErrorString); + errorDetails += u". 
Error Details: \"%1\""_s.arg(replyErrorString); if (!replyContent.isEmpty()) - errorDetails += QString(". Response Content: \"%1\"").arg(QString::fromUtf8(replyContent)); + errorDetails += u". Response Content: \"%1\""_s.arg(QString::fromUtf8(replyContent)); qWarning() << errorDetails; - emit errorGenerated(folder_id, errorDetails); + emit errorGenerated(chunks, errorDetails); return; } @@ -294,7 +305,7 @@ void EmbeddingLLMWorker::handleFinished() QJsonParseError err; QJsonDocument document = QJsonDocument::fromJson(jsonData, &err); if (err.error != QJsonParseError::NoError) { - qWarning() << "ERROR: Couldn't parse Nomic Atlas response: " << jsonData << err.errorString(); + qWarning() << "ERROR: Couldn't parse Nomic Atlas response:" << jsonData << err.errorString(); return; } @@ -315,8 +326,8 @@ EmbeddingLLM::EmbeddingLLM() : QObject(nullptr) , m_embeddingWorker(new EmbeddingLLMWorker) { - connect(this, &EmbeddingLLM::requestAsyncEmbedding, m_embeddingWorker, - &EmbeddingLLMWorker::requestAsyncEmbedding, Qt::QueuedConnection); + connect(this, &EmbeddingLLM::requestDocEmbeddings, m_embeddingWorker, + &EmbeddingLLMWorker::docEmbeddingsRequested, Qt::QueuedConnection); connect(m_embeddingWorker, &EmbeddingLLMWorker::embeddingsGenerated, this, &EmbeddingLLM::embeddingsGenerated, Qt::QueuedConnection); connect(m_embeddingWorker, &EmbeddingLLMWorker::errorGenerated, this, @@ -329,26 +340,18 @@ EmbeddingLLM::~EmbeddingLLM() m_embeddingWorker = nullptr; } -std::vector EmbeddingLLM::generateEmbeddings(const QString &text) +QString EmbeddingLLM::model() { - if (!m_embeddingWorker->hasModel() && !m_embeddingWorker->loadModel()) { - qWarning() << "WARNING: Could not load model for embeddings"; - return {}; - } - - if (!m_embeddingWorker->isNomic()) { - return m_embeddingWorker->generateSyncEmbedding(text); - } - - EmbeddingLLMWorker worker; - connect(this, &EmbeddingLLM::requestSyncEmbedding, &worker, - &EmbeddingLLMWorker::requestSyncEmbedding, Qt::QueuedConnection); - emit 
requestSyncEmbedding(text); - worker.wait(); - return worker.lastResponse(); + return EMBEDDING_MODEL_NAME; } -void EmbeddingLLM::generateAsyncEmbeddings(const QVector &chunks) +// TODO(jared): embed using all necessary embedding models given collection +std::vector EmbeddingLLM::generateQueryEmbedding(const QString &text) { - emit requestAsyncEmbedding(chunks); + return m_embeddingWorker->generateQueryEmbedding(text); +} + +void EmbeddingLLM::generateDocEmbeddingsAsync(const QVector &chunks) +{ + emit requestDocEmbeddings(chunks); } diff --git a/gpt4all-chat/embllm.h b/gpt4all-chat/embllm.h index 06ec94d6..91376650 100644 --- a/gpt4all-chat/embllm.h +++ b/gpt4all-chat/embllm.h @@ -2,6 +2,7 @@ #define EMBLLM_H #include +#include #include #include #include @@ -16,6 +17,7 @@ class LLModel; class QNetworkAccessManager; struct EmbeddingChunk { + QString model; // TODO(jared): use to select model int folder_id; int chunk_id; QString chunk; @@ -24,6 +26,7 @@ struct EmbeddingChunk { Q_DECLARE_METATYPE(EmbeddingChunk) struct EmbeddingResult { + QString model; int folder_id; int chunk_id; std::vector embedding; @@ -33,32 +36,33 @@ class EmbeddingLLMWorker : public QObject { Q_OBJECT public: EmbeddingLLMWorker(); - virtual ~EmbeddingLLMWorker(); + ~EmbeddingLLMWorker() override; void wait(); std::vector lastResponse() const { return m_lastResponse; } bool loadModel(); - bool hasModel() const; - bool isNomic() const; + bool isNomic() const { return !m_nomicAPIKey.isEmpty(); } + bool hasModel() const { return isNomic() || m_model; } - std::vector generateSyncEmbedding(const QString &text); + std::vector generateQueryEmbedding(const QString &text); public Q_SLOTS: - void requestSyncEmbedding(const QString &text); - void requestAsyncEmbedding(const QVector &chunks); + void atlasQueryEmbeddingRequested(const QString &text); + void docEmbeddingsRequested(const QVector &chunks); Q_SIGNALS: + void requestAtlasQueryEmbedding(const QString &text); void embeddingsGenerated(const 
QVector &embeddings); - void errorGenerated(int folder_id, const QString &error); + void errorGenerated(const QVector &chunks, const QString &error); void finished(); private Q_SLOTS: void handleFinished(); private: - void sendAtlasRequest(const QStringList &texts, const QString &taskType, QVariant userData = {}); + void sendAtlasRequest(const QStringList &texts, const QString &taskType, const QVariant &userData = {}); QString m_nomicAPIKey; QNetworkAccessManager *m_networkManager; @@ -66,6 +70,7 @@ private: LLModel *m_model = nullptr; std::atomic m_stopGenerating; QThread m_workerThread; + QMutex m_mutex; // guards m_model and m_nomicAPIKey }; class EmbeddingLLM : public QObject @@ -73,20 +78,20 @@ class EmbeddingLLM : public QObject Q_OBJECT public: EmbeddingLLM(); - virtual ~EmbeddingLLM(); + ~EmbeddingLLM() override; + static QString model(); bool loadModel(); bool hasModel() const; public Q_SLOTS: - std::vector generateEmbeddings(const QString &text); // synchronous - void generateAsyncEmbeddings(const QVector &chunks); + std::vector generateQueryEmbedding(const QString &text); // synchronous + void generateDocEmbeddingsAsync(const QVector &chunks); Q_SIGNALS: - void requestSyncEmbedding(const QString &text); - void requestAsyncEmbedding(const QVector &chunks); + void requestDocEmbeddings(const QVector &chunks); void embeddingsGenerated(const QVector &embeddings); - void errorGenerated(int folder_id, const QString &error); + void errorGenerated(const QVector &chunks, const QString &error); private: EmbeddingLLMWorker *m_embeddingWorker; diff --git a/gpt4all-chat/hnswlib/bruteforce.h b/gpt4all-chat/hnswlib/bruteforce.h deleted file mode 100644 index 30b33ae9..00000000 --- a/gpt4all-chat/hnswlib/bruteforce.h +++ /dev/null @@ -1,167 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include - -namespace hnswlib { -template -class BruteforceSearch : public AlgorithmInterface { - public: - char *data_; - size_t maxelements_; - size_t cur_element_count; 
- size_t size_per_element_; - - size_t data_size_; - DISTFUNC fstdistfunc_; - void *dist_func_param_; - std::mutex index_lock; - - std::unordered_map dict_external_to_internal; - - - BruteforceSearch(SpaceInterface *s) - : data_(nullptr), - maxelements_(0), - cur_element_count(0), - size_per_element_(0), - data_size_(0), - dist_func_param_(nullptr) { - } - - - BruteforceSearch(SpaceInterface *s, const std::string &location) - : data_(nullptr), - maxelements_(0), - cur_element_count(0), - size_per_element_(0), - data_size_(0), - dist_func_param_(nullptr) { - loadIndex(location, s); - } - - - BruteforceSearch(SpaceInterface *s, size_t maxElements) { - maxelements_ = maxElements; - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - size_per_element_ = data_size_ + sizeof(labeltype); - data_ = (char *) malloc(maxElements * size_per_element_); - if (data_ == nullptr) - throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); - cur_element_count = 0; - } - - - ~BruteforceSearch() { - free(data_); - } - - - void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) { - int idx; - { - std::unique_lock lock(index_lock); - - auto search = dict_external_to_internal.find(label); - if (search != dict_external_to_internal.end()) { - idx = search->second; - } else { - if (cur_element_count >= maxelements_) { - throw std::runtime_error("The number of elements exceeds the specified limit\n"); - } - idx = cur_element_count; - dict_external_to_internal[label] = idx; - cur_element_count++; - } - } - memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); - memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); - } - - - void removePoint(labeltype cur_external) { - size_t cur_c = dict_external_to_internal[cur_external]; - - dict_external_to_internal.erase(cur_external); - - labeltype label = *((labeltype*)(data_ + 
size_per_element_ * (cur_element_count-1) + data_size_)); - dict_external_to_internal[label] = cur_c; - memcpy(data_ + size_per_element_ * cur_c, - data_ + size_per_element_ * (cur_element_count-1), - data_size_+sizeof(labeltype)); - cur_element_count--; - } - - - std::priority_queue> - searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { - assert(k <= cur_element_count); - std::priority_queue> topResults; - if (cur_element_count == 0) return topResults; - for (int i = 0; i < k; i++) { - dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); - if ((!isIdAllowed) || (*isIdAllowed)(label)) { - topResults.push(std::pair(dist, label)); - } - } - dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; - for (int i = k; i < cur_element_count; i++) { - dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - if (dist <= lastdist) { - labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); - if ((!isIdAllowed) || (*isIdAllowed)(label)) { - topResults.push(std::pair(dist, label)); - } - if (topResults.size() > k) - topResults.pop(); - - if (!topResults.empty()) { - lastdist = topResults.top().first; - } - } - } - return topResults; - } - - - void saveIndex(const std::string &location) { - std::ofstream output(location, std::ios::binary); - std::streampos position; - - writeBinaryPOD(output, maxelements_); - writeBinaryPOD(output, size_per_element_); - writeBinaryPOD(output, cur_element_count); - - output.write(data_, maxelements_ * size_per_element_); - - output.close(); - } - - - void loadIndex(const std::string &location, SpaceInterface *s) { - std::ifstream input(location, std::ios::binary); - std::streampos position; - - readBinaryPOD(input, maxelements_); - readBinaryPOD(input, size_per_element_); - readBinaryPOD(input, 
cur_element_count); - - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - size_per_element_ = data_size_ + sizeof(labeltype); - data_ = (char *) malloc(maxelements_ * size_per_element_); - if (data_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate data"); - - input.read(data_, maxelements_ * size_per_element_); - - input.close(); - } -}; -} // namespace hnswlib diff --git a/gpt4all-chat/hnswlib/hnswalg.h b/gpt4all-chat/hnswlib/hnswalg.h deleted file mode 100644 index bef00170..00000000 --- a/gpt4all-chat/hnswlib/hnswalg.h +++ /dev/null @@ -1,1271 +0,0 @@ -#pragma once - -#include "visited_list_pool.h" -#include "hnswlib.h" -#include -#include -#include -#include -#include -#include - -namespace hnswlib { -typedef unsigned int tableint; -typedef unsigned int linklistsizeint; - -template -class HierarchicalNSW : public AlgorithmInterface { - public: - static const tableint MAX_LABEL_OPERATION_LOCKS = 65536; - static const unsigned char DELETE_MARK = 0x01; - - size_t max_elements_{0}; - mutable std::atomic cur_element_count{0}; // current number of elements - size_t size_data_per_element_{0}; - size_t size_links_per_element_{0}; - mutable std::atomic num_deleted_{0}; // number of deleted elements - size_t M_{0}; - size_t maxM_{0}; - size_t maxM0_{0}; - size_t ef_construction_{0}; - size_t ef_{ 0 }; - - double mult_{0.0}, revSize_{0.0}; - int maxlevel_{0}; - - VisitedListPool *visited_list_pool_{nullptr}; - - // Locks operations with element by label value - mutable std::vector label_op_locks_; - - std::mutex global; - std::vector link_list_locks_; - - tableint enterpoint_node_{0}; - - size_t size_links_level0_{0}; - size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 }; - - char *data_level0_memory_{nullptr}; - char **linkLists_{nullptr}; - std::vector element_levels_; // keeps level of each element - - size_t data_size_{0}; - - DISTFUNC fstdistfunc_; - 
void *dist_func_param_{nullptr}; - - mutable std::mutex label_lookup_lock; // lock for label_lookup_ - std::unordered_map label_lookup_; - - std::default_random_engine level_generator_; - std::default_random_engine update_probability_generator_; - - mutable std::atomic metric_distance_computations{0}; - mutable std::atomic metric_hops{0}; - - bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions - - std::mutex deleted_elements_lock; // lock for deleted_elements - std::unordered_set deleted_elements; // contains internal ids of deleted elements - - - HierarchicalNSW(SpaceInterface *s) { - } - - - HierarchicalNSW( - SpaceInterface *s, - const std::string &location, - bool nmslib = false, - size_t max_elements = 0, - bool allow_replace_deleted = false) - : allow_replace_deleted_(allow_replace_deleted) { - loadIndex(location, s, max_elements); - } - - - HierarchicalNSW( - SpaceInterface *s, - size_t max_elements, - size_t M = 16, - size_t ef_construction = 200, - size_t random_seed = 100, - bool allow_replace_deleted = false) - : link_list_locks_(max_elements), - label_op_locks_(MAX_LABEL_OPERATION_LOCKS), - element_levels_(max_elements), - allow_replace_deleted_(allow_replace_deleted) { - max_elements_ = max_elements; - num_deleted_ = 0; - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - M_ = M; - maxM_ = M_; - maxM0_ = M_ * 2; - ef_construction_ = std::max(ef_construction, M_); - ef_ = 10; - - level_generator_.seed(random_seed); - update_probability_generator_.seed(random_seed + 1); - - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); - offsetData_ = size_links_level0_; - label_offset_ = size_links_level0_ + data_size_; - offsetLevel0_ = 0; - - data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); - if 
(data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory"); - - cur_element_count = 0; - - visited_list_pool_ = new VisitedListPool(1, max_elements); - - // initializations for special treatment of the first node - enterpoint_node_ = -1; - maxlevel_ = -1; - - linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); - if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); - mult_ = 1 / log(1.0 * M_); - revSize_ = 1.0 / mult_; - } - - - ~HierarchicalNSW() { - free(data_level0_memory_); - for (tableint i = 0; i < cur_element_count; i++) { - if (element_levels_[i] > 0) - free(linkLists_[i]); - } - free(linkLists_); - delete visited_list_pool_; - } - - - struct CompareByFirst { - constexpr bool operator()(std::pair const& a, - std::pair const& b) const noexcept { - return a.first < b.first; - } - }; - - - void setEf(size_t ef) { - ef_ = ef; - } - - - inline std::mutex& getLabelOpMutex(labeltype label) const { - // calculate hash - size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1); - return label_op_locks_[lock_id]; - } - - - inline labeltype getExternalLabel(tableint internal_id) const { - labeltype return_label; - memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); - return return_label; - } - - - inline void setExternalLabel(tableint internal_id, labeltype label) const { - memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); - } - - - inline labeltype *getExternalLabeLp(tableint internal_id) const { - return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); - } - - - inline char *getDataByInternalId(tableint internal_id) const { - return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); - } - - - 
int getRandomLevel(double reverse_size) { - std::uniform_real_distribution distribution(0.0, 1.0); - double r = -log(distribution(level_generator_)) * reverse_size; - return (int) r; - } - - size_t getMaxElements() { - return max_elements_; - } - - size_t getCurrentElementCount() { - return cur_element_count; - } - - size_t getDeletedCount() { - return num_deleted_; - } - - std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayer(tableint ep_id, const void *data_point, int layer) { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; - - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - std::priority_queue, std::vector>, CompareByFirst> candidateSet; - - dist_t lowerBound; - if (!isMarkedDeleted(ep_id)) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); - top_candidates.emplace(dist, ep_id); - lowerBound = dist; - candidateSet.emplace(-dist, ep_id); - } else { - lowerBound = std::numeric_limits::max(); - candidateSet.emplace(-lowerBound, ep_id); - } - visited_array[ep_id] = visited_array_tag; - - while (!candidateSet.empty()) { - std::pair curr_el_pair = candidateSet.top(); - if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) { - break; - } - candidateSet.pop(); - - tableint curNodeNum = curr_el_pair.second; - - std::unique_lock lock(link_list_locks_[curNodeNum]); - - int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); - if (layer == 0) { - data = (int*)get_linklist0(curNodeNum); - } else { - data = (int*)get_linklist(curNodeNum, layer); -// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); - } - size_t size = getListCount((linklistsizeint*)data); - tableint *datal = (tableint *) (data + 1); -#ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char *) (visited_array + *(data + 
1) + 64), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); -#endif - - for (size_t j = 0; j < size; j++) { - tableint candidate_id = *(datal + j); -// if (candidate_id == 0) continue; -#ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); -#endif - if (visited_array[candidate_id] == visited_array_tag) continue; - visited_array[candidate_id] = visited_array_tag; - char *currObj1 = (getDataByInternalId(candidate_id)); - - dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); - if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { - candidateSet.emplace(-dist1, candidate_id); -#ifdef USE_SSE - _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); -#endif - - if (!isMarkedDeleted(candidate_id)) - top_candidates.emplace(dist1, candidate_id); - - if (top_candidates.size() > ef_construction_) - top_candidates.pop(); - - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } - } - } - visited_list_pool_->releaseVisitedList(vl); - - return top_candidates; - } - - - template - std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; - - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - std::priority_queue, std::vector>, CompareByFirst> candidate_set; - - dist_t lowerBound; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); - lowerBound = dist; - top_candidates.emplace(dist, ep_id); - 
candidate_set.emplace(-dist, ep_id); - } else { - lowerBound = std::numeric_limits::max(); - candidate_set.emplace(-lowerBound, ep_id); - } - - visited_array[ep_id] = visited_array_tag; - - while (!candidate_set.empty()) { - std::pair current_node_pair = candidate_set.top(); - - if ((-current_node_pair.first) > lowerBound && - (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { - break; - } - candidate_set.pop(); - - tableint current_node_id = current_node_pair.second; - int *data = (int *) get_linklist0(current_node_id); - size_t size = getListCount((linklistsizeint*)data); -// bool cur_node_deleted = isMarkedDeleted(current_node_id); - if (collect_metrics) { - metric_hops++; - metric_distance_computations+=size; - } - -#ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); - _mm_prefetch((char *) (data + 2), _MM_HINT_T0); -#endif - - for (size_t j = 1; j <= size; j++) { - int candidate_id = *(data + j); -// if (candidate_id == 0) continue; -#ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, - _MM_HINT_T0); //////////// -#endif - if (!(visited_array[candidate_id] == visited_array_tag)) { - visited_array[candidate_id] = visited_array_tag; - - char *currObj1 = (getDataByInternalId(candidate_id)); - dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); - - if (top_candidates.size() < ef || lowerBound > dist) { - candidate_set.emplace(-dist, candidate_id); -#ifdef USE_SSE - _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + - offsetLevel0_, /////////// - _MM_HINT_T0); //////////////////////// -#endif - - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && 
((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) - top_candidates.emplace(dist, candidate_id); - - if (top_candidates.size() > ef) - top_candidates.pop(); - - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } - } - } - } - - visited_list_pool_->releaseVisitedList(vl); - return top_candidates; - } - - - void getNeighborsByHeuristic2( - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - const size_t M) { - if (top_candidates.size() < M) { - return; - } - - std::priority_queue> queue_closest; - std::vector> return_list; - while (top_candidates.size() > 0) { - queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); - top_candidates.pop(); - } - - while (queue_closest.size()) { - if (return_list.size() >= M) - break; - std::pair curent_pair = queue_closest.top(); - dist_t dist_to_query = -curent_pair.first; - queue_closest.pop(); - bool good = true; - - for (std::pair second_pair : return_list) { - dist_t curdist = - fstdistfunc_(getDataByInternalId(second_pair.second), - getDataByInternalId(curent_pair.second), - dist_func_param_); - if (curdist < dist_to_query) { - good = false; - break; - } - } - if (good) { - return_list.push_back(curent_pair); - } - } - - for (std::pair curent_pair : return_list) { - top_candidates.emplace(-curent_pair.first, curent_pair.second); - } - } - - - linklistsizeint *get_linklist0(tableint internal_id) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); - } - - - linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); - } - - - linklistsizeint *get_linklist(tableint internal_id, int level) const { - return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); - } - - - linklistsizeint *get_linklist_at_level(tableint 
internal_id, int level) const { - return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); - } - - - tableint mutuallyConnectNewElement( - const void *data_point, - tableint cur_c, - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - int level, - bool isUpdate) { - size_t Mcurmax = level ? maxM_ : maxM0_; - getNeighborsByHeuristic2(top_candidates, M_); - if (top_candidates.size() > M_) - throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); - - std::vector selectedNeighbors; - selectedNeighbors.reserve(M_); - while (top_candidates.size() > 0) { - selectedNeighbors.push_back(top_candidates.top().second); - top_candidates.pop(); - } - - tableint next_closest_entry_point = selectedNeighbors.back(); - - { - // lock only during the update - // because during the addition the lock for cur_c is already acquired - std::unique_lock lock(link_list_locks_[cur_c], std::defer_lock); - if (isUpdate) { - lock.lock(); - } - linklistsizeint *ll_cur; - if (level == 0) - ll_cur = get_linklist0(cur_c); - else - ll_cur = get_linklist(cur_c, level); - - if (*ll_cur && !isUpdate) { - throw std::runtime_error("The newly inserted element should have blank link list"); - } - setListCount(ll_cur, selectedNeighbors.size()); - tableint *data = (tableint *) (ll_cur + 1); - for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { - if (data[idx] && !isUpdate) - throw std::runtime_error("Possible memory corruption"); - if (level > element_levels_[selectedNeighbors[idx]]) - throw std::runtime_error("Trying to make a link on a non-existent level"); - - data[idx] = selectedNeighbors[idx]; - } - } - - for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { - std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); - - linklistsizeint *ll_other; - if (level == 0) - ll_other = get_linklist0(selectedNeighbors[idx]); - else - ll_other = get_linklist(selectedNeighbors[idx], level); - - size_t 
sz_link_list_other = getListCount(ll_other); - - if (sz_link_list_other > Mcurmax) - throw std::runtime_error("Bad value of sz_link_list_other"); - if (selectedNeighbors[idx] == cur_c) - throw std::runtime_error("Trying to connect an element to itself"); - if (level > element_levels_[selectedNeighbors[idx]]) - throw std::runtime_error("Trying to make a link on a non-existent level"); - - tableint *data = (tableint *) (ll_other + 1); - - bool is_cur_c_present = false; - if (isUpdate) { - for (size_t j = 0; j < sz_link_list_other; j++) { - if (data[j] == cur_c) { - is_cur_c_present = true; - break; - } - } - } - - // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. - if (!is_cur_c_present) { - if (sz_link_list_other < Mcurmax) { - data[sz_link_list_other] = cur_c; - setListCount(ll_other, sz_link_list_other + 1); - } else { - // finding the "weakest" element to replace it with the new one - dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_); - // Heuristic: - std::priority_queue, std::vector>, CompareByFirst> candidates; - candidates.emplace(d_max, cur_c); - - for (size_t j = 0; j < sz_link_list_other; j++) { - candidates.emplace( - fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_), data[j]); - } - - getNeighborsByHeuristic2(candidates, Mcurmax); - - int indx = 0; - while (candidates.size() > 0) { - data[indx] = candidates.top().second; - candidates.pop(); - indx++; - } - - setListCount(ll_other, indx); - // Nearest K: - /*int indx = -1; - for (int j = 0; j < sz_link_list_other; j++) { - dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); - if (d > d_max) { - indx = j; - d_max = d; - } - } - if (indx >= 0) { - data[indx] = cur_c; - } */ - } - } - } - - return next_closest_entry_point; - } - - 
- void resizeIndex(size_t new_max_elements) { - if (new_max_elements < cur_element_count) - throw std::runtime_error("Cannot resize, max element is less than the current number of elements"); - - delete visited_list_pool_; - visited_list_pool_ = new VisitedListPool(1, new_max_elements); - - element_levels_.resize(new_max_elements); - - std::vector(new_max_elements).swap(link_list_locks_); - - // Reallocate base layer - char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); - if (data_level0_memory_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); - data_level0_memory_ = data_level0_memory_new; - - // Reallocate all other layers - char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); - if (linkLists_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); - linkLists_ = linkLists_new; - - max_elements_ = new_max_elements; - } - - - void saveIndex(const std::string &location) { - std::ofstream output(location, std::ios::binary); - std::streampos position; - - writeBinaryPOD(output, offsetLevel0_); - writeBinaryPOD(output, max_elements_); - writeBinaryPOD(output, cur_element_count); - writeBinaryPOD(output, size_data_per_element_); - writeBinaryPOD(output, label_offset_); - writeBinaryPOD(output, offsetData_); - writeBinaryPOD(output, maxlevel_); - writeBinaryPOD(output, enterpoint_node_); - writeBinaryPOD(output, maxM_); - - writeBinaryPOD(output, maxM0_); - writeBinaryPOD(output, M_); - writeBinaryPOD(output, mult_); - writeBinaryPOD(output, ef_construction_); - - output.write(data_level0_memory_, cur_element_count * size_data_per_element_); - - for (size_t i = 0; i < cur_element_count; i++) { - unsigned int linkListSize = element_levels_[i] > 0 ? 
size_links_per_element_ * element_levels_[i] : 0; - writeBinaryPOD(output, linkListSize); - if (linkListSize) - output.write(linkLists_[i], linkListSize); - } - output.close(); - } - - - void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0) { - std::ifstream input(location, std::ios::binary); - - if (!input.is_open()) - throw std::runtime_error("Cannot open file"); - - // get file size: - input.seekg(0, input.end); - std::streampos total_filesize = input.tellg(); - input.seekg(0, input.beg); - - readBinaryPOD(input, offsetLevel0_); - readBinaryPOD(input, max_elements_); - readBinaryPOD(input, cur_element_count); - - size_t max_elements = max_elements_i; - if (max_elements < cur_element_count) - max_elements = max_elements_; - max_elements_ = max_elements; - readBinaryPOD(input, size_data_per_element_); - readBinaryPOD(input, label_offset_); - readBinaryPOD(input, offsetData_); - readBinaryPOD(input, maxlevel_); - readBinaryPOD(input, enterpoint_node_); - - readBinaryPOD(input, maxM_); - readBinaryPOD(input, maxM0_); - readBinaryPOD(input, M_); - readBinaryPOD(input, mult_); - readBinaryPOD(input, ef_construction_); - - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - - auto pos = input.tellg(); - - /// Optional - check if index is ok: - input.seekg(cur_element_count * size_data_per_element_, input.cur); - for (size_t i = 0; i < cur_element_count; i++) { - if (input.tellg() < 0 || input.tellg() >= total_filesize) { - throw std::runtime_error("Index seems to be corrupted or unsupported"); - } - - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize != 0) { - input.seekg(linkListSize, input.cur); - } - } - - // throw exception if it either corrupted or old index - if (input.tellg() != total_filesize) - throw std::runtime_error("Index seems to be corrupted or unsupported"); - - input.clear(); - /// Optional check end - - 
input.seekg(pos, input.beg); - - data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); - input.read(data_level0_memory_, cur_element_count * size_data_per_element_); - - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); - - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - std::vector(max_elements).swap(link_list_locks_); - std::vector(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); - - visited_list_pool_ = new VisitedListPool(1, max_elements); - - linkLists_ = (char **) malloc(sizeof(void *) * max_elements); - if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); - element_levels_ = std::vector(max_elements); - revSize_ = 1.0 / mult_; - ef_ = 10; - for (size_t i = 0; i < cur_element_count; i++) { - label_lookup_[getExternalLabel(i)] = i; - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize == 0) { - element_levels_[i] = 0; - linkLists_[i] = nullptr; - } else { - element_levels_[i] = linkListSize / size_links_per_element_; - linkLists_[i] = (char *) malloc(linkListSize); - if (linkLists_[i] == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); - input.read(linkLists_[i], linkListSize); - } - } - - for (size_t i = 0; i < cur_element_count; i++) { - if (isMarkedDeleted(i)) { - num_deleted_ += 1; - if (allow_replace_deleted_) deleted_elements.insert(i); - } - } - - input.close(); - - return; - } - - - template - std::vector getDataByLabel(labeltype label) const { - // lock all operations with element by label - std::unique_lock lock_label(getLabelOpMutex(label)); - - std::unique_lock lock_table(label_lookup_lock); - auto search = label_lookup_.find(label); - if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { 
- throw std::runtime_error("Label not found"); - } - tableint internalId = search->second; - lock_table.unlock(); - - char* data_ptrv = getDataByInternalId(internalId); - size_t dim = *((size_t *) dist_func_param_); - std::vector data; - data_t* data_ptr = (data_t*) data_ptrv; - for (int i = 0; i < dim; i++) { - data.push_back(*data_ptr); - data_ptr += 1; - } - return data; - } - - - /* - * Marks an element with the given label deleted, does NOT really change the current graph. - */ - void markDelete(labeltype label) { - // lock all operations with element by label - std::unique_lock lock_label(getLabelOpMutex(label)); - - std::unique_lock lock_table(label_lookup_lock); - auto search = label_lookup_.find(label); - if (search == label_lookup_.end()) { - throw std::runtime_error("Label not found"); - } - tableint internalId = search->second; - lock_table.unlock(); - - markDeletedInternal(internalId); - } - - - /* - * Uses the last 16 bits of the memory for the linked list size to store the mark, - * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. - */ - void markDeletedInternal(tableint internalId) { - assert(internalId < cur_element_count); - if (!isMarkedDeleted(internalId)) { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur |= DELETE_MARK; - num_deleted_ += 1; - if (allow_replace_deleted_) { - std::unique_lock lock_deleted_elements(deleted_elements_lock); - deleted_elements.insert(internalId); - } - } else { - throw std::runtime_error("The requested to delete element is already deleted"); - } - } - - - /* - * Removes the deleted mark of the node, does NOT really change the current graph. 
- * - * Note: the method is not safe to use when replacement of deleted elements is enabled, - * because elements marked as deleted can be completely removed by addPoint - */ - void unmarkDelete(labeltype label) { - // lock all operations with element by label - std::unique_lock lock_label(getLabelOpMutex(label)); - - std::unique_lock lock_table(label_lookup_lock); - auto search = label_lookup_.find(label); - if (search == label_lookup_.end()) { - throw std::runtime_error("Label not found"); - } - tableint internalId = search->second; - lock_table.unlock(); - - unmarkDeletedInternal(internalId); - } - - - - /* - * Remove the deleted mark of the node. - */ - void unmarkDeletedInternal(tableint internalId) { - assert(internalId < cur_element_count); - if (isMarkedDeleted(internalId)) { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2; - *ll_cur &= ~DELETE_MARK; - num_deleted_ -= 1; - if (allow_replace_deleted_) { - std::unique_lock lock_deleted_elements(deleted_elements_lock); - deleted_elements.erase(internalId); - } - } else { - throw std::runtime_error("The requested to undelete element is not deleted"); - } - } - - - /* - * Checks the first 16 bits of the memory to see if the element is marked deleted. - */ - bool isMarkedDeleted(tableint internalId) const { - unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2; - return *ll_cur & DELETE_MARK; - } - - - unsigned short int getListCount(linklistsizeint * ptr) const { - return *((unsigned short int *)ptr); - } - - - void setListCount(linklistsizeint * ptr, unsigned short int size) const { - *((unsigned short int*)(ptr))=*((unsigned short int *)&size); - } - - - /* - * Adds point. Updates the point if it is already in the index. 
- * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point - */ - void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) { - if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { - throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); - } - - // lock all operations with element by label - std::unique_lock lock_label(getLabelOpMutex(label)); - if (!replace_deleted) { - addPoint(data_point, label, -1); - return; - } - // check if there is vacant place - tableint internal_id_replaced; - std::unique_lock lock_deleted_elements(deleted_elements_lock); - bool is_vacant_place = !deleted_elements.empty(); - if (is_vacant_place) { - internal_id_replaced = *deleted_elements.begin(); - deleted_elements.erase(internal_id_replaced); - } - lock_deleted_elements.unlock(); - - // if there is no vacant place then add or update point - // else add point to vacant place - if (!is_vacant_place) { - addPoint(data_point, label, -1); - } else { - // we assume that there are no concurrent operations on deleted element - labeltype label_replaced = getExternalLabel(internal_id_replaced); - setExternalLabel(internal_id_replaced, label); - - std::unique_lock lock_table(label_lookup_lock); - label_lookup_.erase(label_replaced); - label_lookup_[label] = internal_id_replaced; - lock_table.unlock(); - - unmarkDeletedInternal(internal_id_replaced); - updatePoint(data_point, internal_id_replaced, 1.0); - } - } - - - void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { - // update the feature vector associated with existing point with new vector - memcpy(getDataByInternalId(internalId), dataPoint, data_size_); - - int maxLevelCopy = maxlevel_; - tableint entryPointCopy = enterpoint_node_; - // If point to be updated is entry point and graph just contains single element then just return. 
- if (entryPointCopy == internalId && cur_element_count == 1) - return; - - int elemLevel = element_levels_[internalId]; - std::uniform_real_distribution distribution(0.0, 1.0); - for (int layer = 0; layer <= elemLevel; layer++) { - std::unordered_set sCand; - std::unordered_set sNeigh; - std::vector listOneHop = getConnectionsWithLock(internalId, layer); - if (listOneHop.size() == 0) - continue; - - sCand.insert(internalId); - - for (auto&& elOneHop : listOneHop) { - sCand.insert(elOneHop); - - if (distribution(update_probability_generator_) > updateNeighborProbability) - continue; - - sNeigh.insert(elOneHop); - - std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer); - for (auto&& elTwoHop : listTwoHop) { - sCand.insert(elTwoHop); - } - } - - for (auto&& neigh : sNeigh) { - // if (neigh == internalId) - // continue; - - std::priority_queue, std::vector>, CompareByFirst> candidates; - size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 - size_t elementsToKeep = std::min(ef_construction_, size); - for (auto&& cand : sCand) { - if (cand == neigh) - continue; - - dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); - if (candidates.size() < elementsToKeep) { - candidates.emplace(distance, cand); - } else { - if (distance < candidates.top().first) { - candidates.pop(); - candidates.emplace(distance, cand); - } - } - } - - // Retrieve neighbours using heuristic and set connections. - getNeighborsByHeuristic2(candidates, layer == 0 ? 
maxM0_ : maxM_); - - { - std::unique_lock lock(link_list_locks_[neigh]); - linklistsizeint *ll_cur; - ll_cur = get_linklist_at_level(neigh, layer); - size_t candSize = candidates.size(); - setListCount(ll_cur, candSize); - tableint *data = (tableint *) (ll_cur + 1); - for (size_t idx = 0; idx < candSize; idx++) { - data[idx] = candidates.top().second; - candidates.pop(); - } - } - } - } - - repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); - } - - - void repairConnectionsForUpdate( - const void *dataPoint, - tableint entryPointInternalId, - tableint dataPointInternalId, - int dataPointLevel, - int maxLevel) { - tableint currObj = entryPointInternalId; - if (dataPointLevel < maxLevel) { - dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); - for (int level = maxLevel; level > dataPointLevel; level--) { - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; - std::unique_lock lock(link_list_locks_[currObj]); - data = get_linklist_at_level(currObj, level); - int size = getListCount(data); - tableint *datal = (tableint *) (data + 1); -#ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); -#endif - for (int i = 0; i < size; i++) { -#ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); -#endif - tableint cand = datal[i]; - dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } - } - } - } - } - - if (dataPointLevel > maxLevel) - throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); - - for (int level = dataPointLevel; level >= 0; level--) { - std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer( - currObj, dataPoint, level); - - std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates; - while (topCandidates.size() > 0) { - if 
(topCandidates.top().second != dataPointInternalId) - filteredTopCandidates.push(topCandidates.top()); - - topCandidates.pop(); - } - - // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. - // To prevent self loops, the `topCandidates` is filtered and thus can be empty. - if (filteredTopCandidates.size() > 0) { - bool epDeleted = isMarkedDeleted(entryPointInternalId); - if (epDeleted) { - filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); - if (filteredTopCandidates.size() > ef_construction_) - filteredTopCandidates.pop(); - } - - currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); - } - } - } - - - std::vector getConnectionsWithLock(tableint internalId, int level) { - std::unique_lock lock(link_list_locks_[internalId]); - unsigned int *data = get_linklist_at_level(internalId, level); - int size = getListCount(data); - std::vector result(size); - tableint *ll = (tableint *) (data + 1); - memcpy(result.data(), ll, size * sizeof(tableint)); - return result; - } - - - tableint addPoint(const void *data_point, labeltype label, int level) { - tableint cur_c = 0; - { - // Checking if the element with the same label already exists - // if so, updating it *instead* of creating a new element. 
- std::unique_lock lock_table(label_lookup_lock); - auto search = label_lookup_.find(label); - if (search != label_lookup_.end()) { - tableint existingInternalId = search->second; - if (allow_replace_deleted_) { - if (isMarkedDeleted(existingInternalId)) { - throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled."); - } - } - lock_table.unlock(); - - if (isMarkedDeleted(existingInternalId)) { - unmarkDeletedInternal(existingInternalId); - } - updatePoint(data_point, existingInternalId, 1.0); - - return existingInternalId; - } - - if (cur_element_count >= max_elements_) { - throw std::runtime_error("The number of elements exceeds the specified limit"); - } - - cur_c = cur_element_count; - cur_element_count++; - label_lookup_[label] = cur_c; - } - - std::unique_lock lock_el(link_list_locks_[cur_c]); - int curlevel = getRandomLevel(mult_); - if (level > 0) - curlevel = level; - - element_levels_[cur_c] = curlevel; - - std::unique_lock templock(global); - int maxlevelcopy = maxlevel_; - if (curlevel <= maxlevelcopy) - templock.unlock(); - tableint currObj = enterpoint_node_; - tableint enterpoint_copy = enterpoint_node_; - - memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); - - // Initialisation of the data and label - memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); - memcpy(getDataByInternalId(cur_c), data_point, data_size_); - - if (curlevel) { - linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); - if (linkLists_[cur_c] == nullptr) - throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); - memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); - } - - if ((signed)currObj != -1) { - if (curlevel < maxlevelcopy) { - dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); - for (int level = maxlevelcopy; level > curlevel; level--) { - 
bool changed = true; - while (changed) { - changed = false; - unsigned int *data; - std::unique_lock lock(link_list_locks_[currObj]); - data = get_linklist(currObj, level); - int size = getListCount(data); - - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } - } - } - } - } - - bool epDeleted = isMarkedDeleted(enterpoint_copy); - for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { - if (level > maxlevelcopy || level < 0) // possible? - throw std::runtime_error("Level error"); - - std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( - currObj, data_point, level); - if (epDeleted) { - top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); - if (top_candidates.size() > ef_construction_) - top_candidates.pop(); - } - currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); - } - } else { - // Do nothing for the first element - enterpoint_node_ = 0; - maxlevel_ = curlevel; - } - - // Releasing lock for the maximum level - if (curlevel > maxlevelcopy) { - enterpoint_node_ = cur_c; - maxlevel_ = curlevel; - } - return cur_c; - } - - - std::priority_queue> - searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { - std::priority_queue> result; - if (cur_element_count == 0) return result; - - tableint currObj = enterpoint_node_; - dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); - - for (int level = maxlevel_; level > 0; level--) { - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; - - data = (unsigned int *) 
get_linklist(currObj, level); - int size = getListCount(data); - metric_hops++; - metric_distance_computations+=size; - - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); - - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } - } - } - } - - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - if (num_deleted_) { - top_candidates = searchBaseLayerST( - currObj, query_data, std::max(ef_, k), isIdAllowed); - } else { - top_candidates = searchBaseLayerST( - currObj, query_data, std::max(ef_, k), isIdAllowed); - } - - while (top_candidates.size() > k) { - top_candidates.pop(); - } - while (top_candidates.size() > 0) { - std::pair rez = top_candidates.top(); - result.push(std::pair(rez.first, getExternalLabel(rez.second))); - top_candidates.pop(); - } - return result; - } - - - void checkIntegrity() { - int connections_checked = 0; - std::vector inbound_connections_num(cur_element_count, 0); - for (int i = 0; i < cur_element_count; i++) { - for (int l = 0; l <= element_levels_[i]; l++) { - linklistsizeint *ll_cur = get_linklist_at_level(i, l); - int size = getListCount(ll_cur); - tableint *data = (tableint *) (ll_cur + 1); - std::unordered_set s; - for (int j = 0; j < size; j++) { - assert(data[j] > 0); - assert(data[j] < cur_element_count); - assert(data[j] != i); - inbound_connections_num[data[j]]++; - s.insert(data[j]); - connections_checked++; - } - assert(s.size() == size); - } - } - if (cur_element_count > 1) { - int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0]; - for (int i=0; i < cur_element_count; i++) { - assert(inbound_connections_num[i] > 0); - min1 = std::min(inbound_connections_num[i], min1); - max1 = std::max(inbound_connections_num[i], max1); - } - std::cout << "Min 
inbound: " << min1 << ", Max inbound:" << max1 << "\n"; - } - std::cout << "integrity ok, checked " << connections_checked << " connections\n"; - } -}; -} // namespace hnswlib diff --git a/gpt4all-chat/hnswlib/hnswlib.h b/gpt4all-chat/hnswlib/hnswlib.h deleted file mode 100644 index fb7118fa..00000000 --- a/gpt4all-chat/hnswlib/hnswlib.h +++ /dev/null @@ -1,199 +0,0 @@ -#pragma once -#ifndef NO_MANUAL_VECTORIZATION -#if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64)) -#define USE_SSE -#ifdef __AVX__ -#define USE_AVX -#ifdef __AVX512F__ -#define USE_AVX512 -#endif -#endif -#endif -#endif - -#if defined(USE_AVX) || defined(USE_SSE) -#ifdef _MSC_VER -#include -#include -void cpuid(int32_t out[4], int32_t eax, int32_t ecx) { - __cpuidex(out, eax, ecx); -} -static __int64 xgetbv(unsigned int x) { - return _xgetbv(x); -} -#else -#include -#include -#include -static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) { - __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]); -} -static uint64_t xgetbv(unsigned int index) { - uint32_t eax, edx; - __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index)); - return ((uint64_t)edx << 32) | eax; -} -#endif - -#if defined(USE_AVX512) -#include -#endif - -#if defined(__GNUC__) -#define PORTABLE_ALIGN32 __attribute__((aligned(32))) -#define PORTABLE_ALIGN64 __attribute__((aligned(64))) -#else -#define PORTABLE_ALIGN32 __declspec(align(32)) -#define PORTABLE_ALIGN64 __declspec(align(64)) -#endif - -// Adapted from https://github.com/Mysticial/FeatureDetector -#define _XCR_XFEATURE_ENABLED_MASK 0 - -static bool AVXCapable() { - int cpuInfo[4]; - - // CPU support - cpuid(cpuInfo, 0, 0); - int nIds = cpuInfo[0]; - - bool HW_AVX = false; - if (nIds >= 0x00000001) { - cpuid(cpuInfo, 0x00000001, 0); - HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0; - } - - // OS support - cpuid(cpuInfo, 1, 0); - - bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; - bool cpuAVXSuport = 
(cpuInfo[2] & (1 << 28)) != 0; - - bool avxSupported = false; - if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { - uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); - avxSupported = (xcrFeatureMask & 0x6) == 0x6; - } - return HW_AVX && avxSupported; -} - -static bool AVX512Capable() { - if (!AVXCapable()) return false; - - int cpuInfo[4]; - - // CPU support - cpuid(cpuInfo, 0, 0); - int nIds = cpuInfo[0]; - - bool HW_AVX512F = false; - if (nIds >= 0x00000007) { // AVX512 Foundation - cpuid(cpuInfo, 0x00000007, 0); - HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0; - } - - // OS support - cpuid(cpuInfo, 1, 0); - - bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; - bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0; - - bool avx512Supported = false; - if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { - uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); - avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6; - } - return HW_AVX512F && avx512Supported; -} -#endif - -#include -#include -#include -#include - -namespace hnswlib { -typedef size_t labeltype; - -// This can be extended to store state for filtering (e.g. 
from a std::set) -class BaseFilterFunctor { - public: - virtual bool operator()(hnswlib::labeltype id) { return true; } -}; - -template -class pairGreater { - public: - bool operator()(const T& p1, const T& p2) { - return p1.first > p2.first; - } -}; - -template -static void writeBinaryPOD(std::ostream &out, const T &podRef) { - out.write((char *) &podRef, sizeof(T)); -} - -template -static void readBinaryPOD(std::istream &in, T &podRef) { - in.read((char *) &podRef, sizeof(T)); -} - -template -using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); - -template -class SpaceInterface { - public: - // virtual void search(void *); - virtual size_t get_data_size() = 0; - - virtual DISTFUNC get_dist_func() = 0; - - virtual void *get_dist_func_param() = 0; - - virtual ~SpaceInterface() {} -}; - -template -class AlgorithmInterface { - public: - virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0; - - virtual std::priority_queue> - searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; - - // Return k nearest neighbor in the order of closer fist - virtual std::vector> - searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; - - virtual void saveIndex(const std::string &location) = 0; - virtual ~AlgorithmInterface(){ - } -}; - -template -std::vector> -AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, - BaseFilterFunctor* isIdAllowed) const { - std::vector> result; - - // here searchKnn returns the result in the order of further first - auto ret = searchKnn(query_data, k, isIdAllowed); - { - size_t sz = ret.size(); - result.resize(sz); - while (!ret.empty()) { - result[--sz] = ret.top(); - ret.pop(); - } - } - - return result; -} -} // namespace hnswlib - -#include "space_l2.h" -#include "space_ip.h" -#include "bruteforce.h" -#include "hnswalg.h" diff --git a/gpt4all-chat/hnswlib/space_ip.h 
b/gpt4all-chat/hnswlib/space_ip.h deleted file mode 100644 index 2b1c359e..00000000 --- a/gpt4all-chat/hnswlib/space_ip.h +++ /dev/null @@ -1,375 +0,0 @@ -#pragma once -#include "hnswlib.h" - -namespace hnswlib { - -static float -InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - float res = 0; - for (unsigned i = 0; i < qty; i++) { - res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; - } - return res; -} - -static float -InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) { - return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr); -} - -#if defined(USE_AVX) - -// Favor using AVX if available. -static float -InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - size_t qty16 = qty / 16; - size_t qty4 = qty / 4; - - const float *pEnd1 = pVect1 + 16 * qty16; - const float *pEnd2 = pVect1 + 4 * qty4; - - __m256 sum256 = _mm256_set1_ps(0); - - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - - __m256 v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - __m256 v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - } - - __m128 v1, v2; - __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); - - while (pVect1 < pEnd2) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } - - _mm_store_ps(TmpRes, sum_prod); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - return sum; -} - -static float -InnerProductDistanceSIMD4ExtAVX(const 
void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr); -} - -#endif - -#if defined(USE_SSE) - -static float -InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - size_t qty16 = qty / 16; - size_t qty4 = qty / 4; - - const float *pEnd1 = pVect1 + 16 * qty16; - const float *pEnd2 = pVect1 + 4 * qty4; - - __m128 v1, v2; - __m128 sum_prod = _mm_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } - - while (pVect1 < pEnd2) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } - - _mm_store_ps(TmpRes, sum_prod); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - - return sum; -} - -static float -InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr); -} - -#endif - - -#if defined(USE_AVX512) - -static float -InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN64 TmpRes[16]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) 
qty_ptr); - - size_t qty16 = qty / 16; - - - const float *pEnd1 = pVect1 + 16 * qty16; - - __m512 sum512 = _mm512_set1_ps(0); - - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - - __m512 v1 = _mm512_loadu_ps(pVect1); - pVect1 += 16; - __m512 v2 = _mm512_loadu_ps(pVect2); - pVect2 += 16; - sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2)); - } - - _mm512_store_ps(TmpRes, sum512); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15]; - - return sum; -} - -static float -InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr); -} - -#endif - -#if defined(USE_AVX) - -static float -InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - size_t qty16 = qty / 16; - - - const float *pEnd1 = pVect1 + 16 * qty16; - - __m256 sum256 = _mm256_set1_ps(0); - - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - - __m256 v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - __m256 v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - } - - _mm256_store_ps(TmpRes, sum256); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; - - return sum; -} - -static float -InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, 
qty_ptr); -} - -#endif - -#if defined(USE_SSE) - -static float -InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - size_t qty16 = qty / 16; - - const float *pEnd1 = pVect1 + 16 * qty16; - - __m128 v1, v2; - __m128 sum_prod = _mm_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } - _mm_store_ps(TmpRes, sum_prod); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - - return sum; -} - -static float -InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr); -} - -#endif - -#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) -static DISTFUNC InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; -static DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; -static DISTFUNC InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE; -static DISTFUNC InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE; - -static float -InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty16 = qty >> 4 << 4; - float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); - float *pVect1 = (float *) 
pVect1v + qty16; - float *pVect2 = (float *) pVect2v + qty16; - - size_t qty_left = qty - qty16; - float res_tail = InnerProduct(pVect1, pVect2, &qty_left); - return 1.0f - (res + res_tail); -} - -static float -InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty4 = qty >> 2 << 2; - - float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); - size_t qty_left = qty - qty4; - - float *pVect1 = (float *) pVect1v + qty4; - float *pVect2 = (float *) pVect2v + qty4; - float res_tail = InnerProduct(pVect1, pVect2, &qty_left); - - return 1.0f - (res + res_tail); -} -#endif - -class InnerProductSpace : public SpaceInterface { - DISTFUNC fstdistfunc_; - size_t data_size_; - size_t dim_; - - public: - InnerProductSpace(size_t dim) { - fstdistfunc_ = InnerProductDistance; -#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) - #if defined(USE_AVX512) - if (AVX512Capable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; - } else if (AVXCapable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; - } - #elif defined(USE_AVX) - if (AVXCapable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; - } - #endif - #if defined(USE_AVX) - if (AVXCapable()) { - InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; - InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; - } - #endif - - if (dim % 16 == 0) - fstdistfunc_ = InnerProductDistanceSIMD16Ext; - else if (dim % 4 == 0) - fstdistfunc_ = InnerProductDistanceSIMD4Ext; - else if (dim > 16) - fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; - else if (dim > 4) - fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; -#endif - dim_ = dim; - data_size_ = dim * sizeof(float); - } - - size_t 
get_data_size() { - return data_size_; - } - - DISTFUNC get_dist_func() { - return fstdistfunc_; - } - - void *get_dist_func_param() { - return &dim_; - } - -~InnerProductSpace() {} -}; - -} // namespace hnswlib diff --git a/gpt4all-chat/hnswlib/space_l2.h b/gpt4all-chat/hnswlib/space_l2.h deleted file mode 100644 index 834d19f7..00000000 --- a/gpt4all-chat/hnswlib/space_l2.h +++ /dev/null @@ -1,324 +0,0 @@ -#pragma once -#include "hnswlib.h" - -namespace hnswlib { - -static float -L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - float res = 0; - for (size_t i = 0; i < qty; i++) { - float t = *pVect1 - *pVect2; - pVect1++; - pVect2++; - res += t * t; - } - return (res); -} - -#if defined(USE_AVX512) - -// Favor using AVX512 if available. -static float -L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - float PORTABLE_ALIGN64 TmpRes[16]; - size_t qty16 = qty >> 4; - - const float *pEnd1 = pVect1 + (qty16 << 4); - - __m512 diff, v1, v2; - __m512 sum = _mm512_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm512_loadu_ps(pVect1); - pVect1 += 16; - v2 = _mm512_loadu_ps(pVect2); - pVect2 += 16; - diff = _mm512_sub_ps(v1, v2); - // sum = _mm512_fmadd_ps(diff, diff, sum); - sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff)); - } - - _mm512_store_ps(TmpRes, sum); - float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + - TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + - TmpRes[13] + TmpRes[14] + TmpRes[15]; - - return (res); -} -#endif - -#if defined(USE_AVX) - -// Favor using AVX if available. 
-static float -L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - float PORTABLE_ALIGN32 TmpRes[8]; - size_t qty16 = qty >> 4; - - const float *pEnd1 = pVect1 + (qty16 << 4); - - __m256 diff, v1, v2; - __m256 sum = _mm256_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - diff = _mm256_sub_ps(v1, v2); - sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); - - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - diff = _mm256_sub_ps(v1, v2); - sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); - } - - _mm256_store_ps(TmpRes, sum); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; -} - -#endif - -#if defined(USE_SSE) - -static float -L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - float PORTABLE_ALIGN32 TmpRes[8]; - size_t qty16 = qty >> 4; - - const float *pEnd1 = pVect1 + (qty16 << 4); - - __m128 diff, v1, v2; - __m128 sum = _mm_set1_ps(0); - - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 
+= 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - } - - _mm_store_ps(TmpRes, sum); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; -} -#endif - -#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) -static DISTFUNC L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE; - -static float -L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty16 = qty >> 4 << 4; - float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); - float *pVect1 = (float *) pVect1v + qty16; - float *pVect2 = (float *) pVect2v + qty16; - - size_t qty_left = qty - qty16; - float res_tail = L2Sqr(pVect1, pVect2, &qty_left); - return (res + res_tail); -} -#endif - - -#if defined(USE_SSE) -static float -L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - - size_t qty4 = qty >> 2; - - const float *pEnd1 = pVect1 + (qty4 << 2); - - __m128 diff, v1, v2; - __m128 sum = _mm_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - } - _mm_store_ps(TmpRes, sum); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; -} - -static float -L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty4 = qty >> 2 << 2; - - float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); - size_t qty_left = qty - qty4; - - float *pVect1 = (float *) pVect1v + qty4; - float *pVect2 = (float *) pVect2v + qty4; - float res_tail = L2Sqr(pVect1, pVect2, &qty_left); - - return (res + res_tail); -} -#endif - -class L2Space : public SpaceInterface { - DISTFUNC fstdistfunc_; - size_t data_size_; - size_t dim_; - - 
public: - L2Space(size_t dim) { - fstdistfunc_ = L2Sqr; -#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) - #if defined(USE_AVX512) - if (AVX512Capable()) - L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; - else if (AVXCapable()) - L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; - #elif defined(USE_AVX) - if (AVXCapable()) - L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; - #endif - - if (dim % 16 == 0) - fstdistfunc_ = L2SqrSIMD16Ext; - else if (dim % 4 == 0) - fstdistfunc_ = L2SqrSIMD4Ext; - else if (dim > 16) - fstdistfunc_ = L2SqrSIMD16ExtResiduals; - else if (dim > 4) - fstdistfunc_ = L2SqrSIMD4ExtResiduals; -#endif - dim_ = dim; - data_size_ = dim * sizeof(float); - } - - size_t get_data_size() { - return data_size_; - } - - DISTFUNC get_dist_func() { - return fstdistfunc_; - } - - void *get_dist_func_param() { - return &dim_; - } - - ~L2Space() {} -}; - -static int -L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - int res = 0; - unsigned char *a = (unsigned char *) pVect1; - unsigned char *b = (unsigned char *) pVect2; - - qty = qty >> 2; - for (size_t i = 0; i < qty; i++) { - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - } - return (res); -} - -static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) { - size_t qty = *((size_t*)qty_ptr); - int res = 0; - unsigned char* a = (unsigned char*)pVect1; - unsigned char* b = (unsigned char*)pVect2; - - for (size_t i = 0; i < qty; i++) { - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - } - return (res); -} - -class L2SpaceI : public SpaceInterface { - DISTFUNC fstdistfunc_; - size_t data_size_; - size_t dim_; - - public: - L2SpaceI(size_t dim) { - if (dim % 4 == 0) { - fstdistfunc_ = L2SqrI4x; - } else { - 
fstdistfunc_ = L2SqrI; - } - dim_ = dim; - data_size_ = dim * sizeof(unsigned char); - } - - size_t get_data_size() { - return data_size_; - } - - DISTFUNC get_dist_func() { - return fstdistfunc_; - } - - void *get_dist_func_param() { - return &dim_; - } - - ~L2SpaceI() {} -}; -} // namespace hnswlib diff --git a/gpt4all-chat/hnswlib/visited_list_pool.h b/gpt4all-chat/hnswlib/visited_list_pool.h deleted file mode 100644 index 2e201ec4..00000000 --- a/gpt4all-chat/hnswlib/visited_list_pool.h +++ /dev/null @@ -1,78 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace hnswlib { -typedef unsigned short int vl_type; - -class VisitedList { - public: - vl_type curV; - vl_type *mass; - unsigned int numelements; - - VisitedList(int numelements1) { - curV = -1; - numelements = numelements1; - mass = new vl_type[numelements]; - } - - void reset() { - curV++; - if (curV == 0) { - memset(mass, 0, sizeof(vl_type) * numelements); - curV++; - } - } - - ~VisitedList() { delete[] mass; } -}; -/////////////////////////////////////////////////////////// -// -// Class for multi-threaded pool-management of VisitedLists -// -///////////////////////////////////////////////////////// - -class VisitedListPool { - std::deque pool; - std::mutex poolguard; - int numelements; - - public: - VisitedListPool(int initmaxpools, int numelements1) { - numelements = numelements1; - for (int i = 0; i < initmaxpools; i++) - pool.push_front(new VisitedList(numelements)); - } - - VisitedList *getFreeVisitedList() { - VisitedList *rez; - { - std::unique_lock lock(poolguard); - if (pool.size() > 0) { - rez = pool.front(); - pool.pop_front(); - } else { - rez = new VisitedList(numelements); - } - } - rez->reset(); - return rez; - } - - void releaseVisitedList(VisitedList *vl) { - std::unique_lock lock(poolguard); - pool.push_front(vl); - } - - ~VisitedListPool() { - while (pool.size()) { - VisitedList *rez = pool.front(); - pool.pop_front(); - delete rez; - } - } -}; -} // namespace hnswlib diff 
--git a/gpt4all-chat/icons/alt_logo.svg b/gpt4all-chat/icons/alt_logo.svg new file mode 100644 index 00000000..9ecfd3ac --- /dev/null +++ b/gpt4all-chat/icons/alt_logo.svg @@ -0,0 +1,52 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/gpt4all-chat/icons/antenna_1.svg b/gpt4all-chat/icons/antenna_1.svg new file mode 100644 index 00000000..bc82f574 --- /dev/null +++ b/gpt4all-chat/icons/antenna_1.svg @@ -0,0 +1,4 @@ + + + + diff --git a/gpt4all-chat/icons/antenna_2.svg b/gpt4all-chat/icons/antenna_2.svg new file mode 100644 index 00000000..0e025bb6 --- /dev/null +++ b/gpt4all-chat/icons/antenna_2.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/gpt4all-chat/icons/antenna_3.svg b/gpt4all-chat/icons/antenna_3.svg new file mode 100644 index 00000000..3d75bd0d --- /dev/null +++ b/gpt4all-chat/icons/antenna_3.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/gpt4all-chat/icons/changelog.svg b/gpt4all-chat/icons/changelog.svg new file mode 100644 index 00000000..9374a50f --- /dev/null +++ b/gpt4all-chat/icons/changelog.svg @@ -0,0 +1,3 @@ + + + diff --git a/gpt4all-chat/icons/chat.svg b/gpt4all-chat/icons/chat.svg new file mode 100644 index 00000000..62a4eb14 --- /dev/null +++ b/gpt4all-chat/icons/chat.svg @@ -0,0 +1,3 @@ + + + diff --git a/gpt4all-chat/icons/db.svg b/gpt4all-chat/icons/db.svg index 4b0d1082..fc0816f7 100644 --- a/gpt4all-chat/icons/db.svg +++ b/gpt4all-chat/icons/db.svg @@ -1,5 +1,3 @@ - - + + + diff --git a/gpt4all-chat/icons/discord.svg b/gpt4all-chat/icons/discord.svg new file mode 100644 index 00000000..822821be --- /dev/null +++ b/gpt4all-chat/icons/discord.svg @@ -0,0 +1,3 @@ + + + diff --git a/gpt4all-chat/icons/edit.svg b/gpt4all-chat/icons/edit.svg index 9820173b..5a79a50e 100644 --- a/gpt4all-chat/icons/edit.svg +++ b/gpt4all-chat/icons/edit.svg @@ -1,5 +1,3 @@ - - + + + diff --git a/gpt4all-chat/icons/email.svg b/gpt4all-chat/icons/email.svg new file mode 100644 index 
00000000..cf757ac7 --- /dev/null +++ b/gpt4all-chat/icons/email.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/gpt4all-chat/icons/file-md.svg b/gpt4all-chat/icons/file-md.svg new file mode 100644 index 00000000..adcb6d04 --- /dev/null +++ b/gpt4all-chat/icons/file-md.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/gpt4all-chat/icons/file-pdf.svg b/gpt4all-chat/icons/file-pdf.svg new file mode 100644 index 00000000..63fc1ae2 --- /dev/null +++ b/gpt4all-chat/icons/file-pdf.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/gpt4all-chat/icons/file-txt.svg b/gpt4all-chat/icons/file-txt.svg new file mode 100644 index 00000000..265ab5e7 --- /dev/null +++ b/gpt4all-chat/icons/file-txt.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/gpt4all-chat/icons/file.svg b/gpt4all-chat/icons/file.svg new file mode 100644 index 00000000..85a75443 --- /dev/null +++ b/gpt4all-chat/icons/file.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/gpt4all-chat/icons/github.svg b/gpt4all-chat/icons/github.svg new file mode 100644 index 00000000..81db0a32 --- /dev/null +++ b/gpt4all-chat/icons/github.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/gpt4all-chat/icons/globe.svg b/gpt4all-chat/icons/globe.svg new file mode 100644 index 00000000..e19e5fbc --- /dev/null +++ b/gpt4all-chat/icons/globe.svg @@ -0,0 +1,3 @@ + + + diff --git a/gpt4all-chat/icons/home.svg b/gpt4all-chat/icons/home.svg new file mode 100644 index 00000000..4a682984 --- /dev/null +++ b/gpt4all-chat/icons/home.svg @@ -0,0 +1,3 @@ + + + diff --git a/gpt4all-chat/icons/info.svg b/gpt4all-chat/icons/info.svg new file mode 100644 index 00000000..2c207ecb --- /dev/null +++ b/gpt4all-chat/icons/info.svg @@ -0,0 +1,3 @@ + + + diff --git a/gpt4all-chat/icons/local-docs.svg b/gpt4all-chat/icons/local-docs.svg new file mode 100644 index 00000000..06031c34 --- /dev/null +++ b/gpt4all-chat/icons/local-docs.svg @@ -0,0 +1,3 @@ + + + diff --git 
a/gpt4all-chat/icons/models.svg b/gpt4all-chat/icons/models.svg new file mode 100644 index 00000000..4e9b5306 --- /dev/null +++ b/gpt4all-chat/icons/models.svg @@ -0,0 +1,3 @@ + + + diff --git a/gpt4all-chat/icons/nomic_logo.svg b/gpt4all-chat/icons/nomic_logo.svg new file mode 100644 index 00000000..c3bc1429 --- /dev/null +++ b/gpt4all-chat/icons/nomic_logo.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/gpt4all-chat/icons/notes.svg b/gpt4all-chat/icons/notes.svg new file mode 100644 index 00000000..a5378aa7 --- /dev/null +++ b/gpt4all-chat/icons/notes.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/gpt4all-chat/icons/search.svg b/gpt4all-chat/icons/search.svg new file mode 100644 index 00000000..bae556fc --- /dev/null +++ b/gpt4all-chat/icons/search.svg @@ -0,0 +1,6 @@ + + + diff --git a/gpt4all-chat/icons/settings.svg b/gpt4all-chat/icons/settings.svg index 7542ea62..3d885b3f 100644 --- a/gpt4all-chat/icons/settings.svg +++ b/gpt4all-chat/icons/settings.svg @@ -1,46 +1,3 @@ - - - - - - - + + diff --git a/gpt4all-chat/icons/trash.svg b/gpt4all-chat/icons/trash.svg index b7c1a141..fa800005 100644 --- a/gpt4all-chat/icons/trash.svg +++ b/gpt4all-chat/icons/trash.svg @@ -1,5 +1,3 @@ - - + + + diff --git a/gpt4all-chat/icons/twitter.svg b/gpt4all-chat/icons/twitter.svg new file mode 100644 index 00000000..633d226c --- /dev/null +++ b/gpt4all-chat/icons/twitter.svg @@ -0,0 +1,3 @@ + + + diff --git a/gpt4all-chat/icons/you.svg b/gpt4all-chat/icons/you.svg new file mode 100644 index 00000000..4376bef8 --- /dev/null +++ b/gpt4all-chat/icons/you.svg @@ -0,0 +1,41 @@ + + + + + + diff --git a/gpt4all-chat/llm.cpp b/gpt4all-chat/llm.cpp index 7eb9c4d0..dfbe7ca5 100644 --- a/gpt4all-chat/llm.cpp +++ b/gpt4all-chat/llm.cpp @@ -13,12 +13,14 @@ #include #include -#include - -#ifndef GPT4ALL_OFFLINE_INSTALLER +#ifdef GPT4ALL_OFFLINE_INSTALLER +# include +#else # include "network.h" #endif +using namespace Qt::Literals::StringLiterals; + class MyLLM: public LLM { }; 
Q_GLOBAL_STATIC(MyLLM, llmInstance) LLM *LLM::globalInstance() @@ -54,11 +56,11 @@ bool LLM::checkForUpdates() const Network::globalInstance()->trackEvent("check_for_updates"); #if defined(Q_OS_LINUX) - QString tool("maintenancetool"); + QString tool = u"maintenancetool"_s; #elif defined(Q_OS_WINDOWS) - QString tool("maintenancetool.exe"); + QString tool = u"maintenancetool.exe"_s; #elif defined(Q_OS_DARWIN) - QString tool("../../../maintenancetool.app/Contents/MacOS/maintenancetool"); + QString tool = u"../../../maintenancetool.app/Contents/MacOS/maintenancetool"_s; #endif QString fileName = QCoreApplication::applicationDirPath() diff --git a/gpt4all-chat/localdocs.cpp b/gpt4all-chat/localdocs.cpp index d0670618..0b69e834 100644 --- a/gpt4all-chat/localdocs.cpp +++ b/gpt4all-chat/localdocs.cpp @@ -1,6 +1,7 @@ #include "localdocs.h" #include "database.h" +#include "embllm.h" #include "mysettings.h" #include @@ -22,45 +23,37 @@ LocalDocs::LocalDocs() , m_database(nullptr) { connect(MySettings::globalInstance(), &MySettings::localDocsChunkSizeChanged, this, &LocalDocs::handleChunkSizeChanged); + connect(MySettings::globalInstance(), &MySettings::localDocsFileExtensionsChanged, this, &LocalDocs::handleFileExtensionsChanged); // Create the DB with the chunk size from settings - m_database = new Database(MySettings::globalInstance()->localDocsChunkSize()); + m_database = new Database(MySettings::globalInstance()->localDocsChunkSize(), + MySettings::globalInstance()->localDocsFileExtensions()); connect(this, &LocalDocs::requestStart, m_database, &Database::start, Qt::QueuedConnection); + connect(this, &LocalDocs::requestForceIndexing, m_database, + &Database::forceIndexing, Qt::QueuedConnection); + connect(this, &LocalDocs::forceRebuildFolder, m_database, + &Database::forceRebuildFolder, Qt::QueuedConnection); connect(this, &LocalDocs::requestAddFolder, m_database, &Database::addFolder, Qt::QueuedConnection); connect(this, &LocalDocs::requestRemoveFolder, m_database, 
&Database::removeFolder, Qt::QueuedConnection); connect(this, &LocalDocs::requestChunkSizeChange, m_database, &Database::changeChunkSize, Qt::QueuedConnection); + connect(this, &LocalDocs::requestFileExtensionsChange, m_database, + &Database::changeFileExtensions, Qt::QueuedConnection); + connect(m_database, &Database::databaseValidChanged, + this, &LocalDocs::databaseValidChanged, Qt::QueuedConnection); // Connections for modifying the model and keeping it updated with the database - connect(m_database, &Database::updateInstalled, - m_localDocsModel, &LocalDocsModel::updateInstalled, Qt::QueuedConnection); - connect(m_database, &Database::updateIndexing, - m_localDocsModel, &LocalDocsModel::updateIndexing, Qt::QueuedConnection); - connect(m_database, &Database::updateError, - m_localDocsModel, &LocalDocsModel::updateError, Qt::QueuedConnection); - connect(m_database, &Database::updateCurrentDocsToIndex, - m_localDocsModel, &LocalDocsModel::updateCurrentDocsToIndex, Qt::QueuedConnection); - connect(m_database, &Database::updateTotalDocsToIndex, - m_localDocsModel, &LocalDocsModel::updateTotalDocsToIndex, Qt::QueuedConnection); - connect(m_database, &Database::subtractCurrentBytesToIndex, - m_localDocsModel, &LocalDocsModel::subtractCurrentBytesToIndex, Qt::QueuedConnection); - connect(m_database, &Database::updateCurrentBytesToIndex, - m_localDocsModel, &LocalDocsModel::updateCurrentBytesToIndex, Qt::QueuedConnection); - connect(m_database, &Database::updateTotalBytesToIndex, - m_localDocsModel, &LocalDocsModel::updateTotalBytesToIndex, Qt::QueuedConnection); - connect(m_database, &Database::updateCurrentEmbeddingsToIndex, - m_localDocsModel, &LocalDocsModel::updateCurrentEmbeddingsToIndex, Qt::QueuedConnection); - connect(m_database, &Database::updateTotalEmbeddingsToIndex, - m_localDocsModel, &LocalDocsModel::updateTotalEmbeddingsToIndex, Qt::QueuedConnection); - connect(m_database, &Database::addCollectionItem, + connect(m_database, 
&Database::requestUpdateGuiForCollectionItem, + m_localDocsModel, &LocalDocsModel::updateCollectionItem, Qt::QueuedConnection); + connect(m_database, &Database::requestAddGuiCollectionItem, m_localDocsModel, &LocalDocsModel::addCollectionItem, Qt::QueuedConnection); - connect(m_database, &Database::removeFolderById, + connect(m_database, &Database::requestRemoveGuiFolderById, m_localDocsModel, &LocalDocsModel::removeFolderById, Qt::QueuedConnection); - connect(m_database, &Database::collectionListUpdated, + connect(m_database, &Database::requestGuiCollectionListUpdated, m_localDocsModel, &LocalDocsModel::collectionListUpdated, Qt::QueuedConnection); connect(qGuiApp, &QCoreApplication::aboutToQuit, this, &LocalDocs::aboutToQuit); @@ -76,16 +69,38 @@ void LocalDocs::addFolder(const QString &collection, const QString &path) { const QUrl url(path); const QString localPath = url.isLocalFile() ? url.toLocalFile() : path; - emit requestAddFolder(collection, localPath, false); + + const QString embedding_model = EmbeddingLLM::model(); + if (embedding_model.isEmpty()) { + qWarning() << "ERROR: We have no embedding model"; + return; + } + + emit requestAddFolder(collection, localPath, embedding_model); } void LocalDocs::removeFolder(const QString &collection, const QString &path) { - m_localDocsModel->removeCollectionPath(collection, path); emit requestRemoveFolder(collection, path); } +void LocalDocs::forceIndexing(const QString &collection) +{ + const QString embedding_model = EmbeddingLLM::model(); + if (embedding_model.isEmpty()) { + qWarning() << "ERROR: We have no embedding model"; + return; + } + + emit requestForceIndexing(collection, embedding_model); +} + void LocalDocs::handleChunkSizeChanged() { emit requestChunkSizeChange(MySettings::globalInstance()->localDocsChunkSize()); } + +void LocalDocs::handleFileExtensionsChanged() +{ + emit requestFileExtensionsChange(MySettings::globalInstance()->localDocsFileExtensions()); +} diff --git a/gpt4all-chat/localdocs.h 
b/gpt4all-chat/localdocs.h index ad694524..4a017fe2 100644 --- a/gpt4all-chat/localdocs.h +++ b/gpt4all-chat/localdocs.h @@ -1,16 +1,17 @@ #ifndef LOCALDOCS_H #define LOCALDOCS_H +#include "database.h" #include "localdocsmodel.h" // IWYU pragma: keep #include #include - -class Database; +#include class LocalDocs : public QObject { Q_OBJECT + Q_PROPERTY(bool databaseValid READ databaseValid NOTIFY databaseValidChanged) Q_PROPERTY(LocalDocsModel *localDocsModel READ localDocsModel NOTIFY localDocsModelChanged) public: @@ -20,19 +21,27 @@ public: Q_INVOKABLE void addFolder(const QString &collection, const QString &path); Q_INVOKABLE void removeFolder(const QString &collection, const QString &path); + Q_INVOKABLE void forceIndexing(const QString &collection); Database *database() const { return m_database; } + bool databaseValid() const { return m_database->isValid(); } + public Q_SLOTS: void handleChunkSizeChanged(); + void handleFileExtensionsChanged(); void aboutToQuit(); Q_SIGNALS: void requestStart(); - void requestAddFolder(const QString &collection, const QString &path, bool fromDb); + void requestForceIndexing(const QString &collection, const QString &embedding_model); + void forceRebuildFolder(const QString &path); + void requestAddFolder(const QString &collection, const QString &path, const QString &embedding_model); void requestRemoveFolder(const QString &collection, const QString &path); void requestChunkSizeChange(int chunkSize); + void requestFileExtensionsChange(const QStringList &extensions); void localDocsModelChanged(); + void databaseValidChanged(); private: LocalDocsModel *m_localDocsModel; diff --git a/gpt4all-chat/localdocsmodel.cpp b/gpt4all-chat/localdocsmodel.cpp index 9ce92e60..fba4c4ff 100644 --- a/gpt4all-chat/localdocsmodel.cpp +++ b/gpt4all-chat/localdocsmodel.cpp @@ -3,7 +3,9 @@ #include "localdocs.h" #include "network.h" +#include #include +#include #include #include @@ -12,6 +14,13 @@ 
LocalDocsCollectionsModel::LocalDocsCollectionsModel(QObject *parent) : QSortFilterProxyModel(parent) { setSourceModel(LocalDocs::globalInstance()->localDocsModel()); + + connect(LocalDocs::globalInstance()->localDocsModel(), + &LocalDocsModel::updatingChanged, this, &LocalDocsCollectionsModel::maybeTriggerUpdatingCountChanged); + connect(this, &LocalDocsCollectionsModel::rowsInserted, this, &LocalDocsCollectionsModel::countChanged); + connect(this, &LocalDocsCollectionsModel::rowsRemoved, this, &LocalDocsCollectionsModel::countChanged); + connect(this, &LocalDocsCollectionsModel::modelReset, this, &LocalDocsCollectionsModel::countChanged); + connect(this, &LocalDocsCollectionsModel::layoutChanged, this, &LocalDocsCollectionsModel::countChanged); } bool LocalDocsCollectionsModel::filterAcceptsRow(int sourceRow, @@ -26,11 +35,39 @@ void LocalDocsCollectionsModel::setCollections(const QList &collections { m_collections = collections; invalidateFilter(); + maybeTriggerUpdatingCountChanged(); +} + +int LocalDocsCollectionsModel::updatingCount() const +{ + return m_updatingCount; +} + +void LocalDocsCollectionsModel::maybeTriggerUpdatingCountChanged() +{ + int updatingCount = 0; + for (int row = 0; row < sourceModel()->rowCount(); ++row) { + QModelIndex index = sourceModel()->index(row, 0); + const QString collection = sourceModel()->data(index, LocalDocsModel::CollectionRole).toString(); + if (!m_collections.contains(collection)) + continue; + bool updating = sourceModel()->data(index, LocalDocsModel::UpdatingRole).toBool(); + if (updating) + ++updatingCount; + } + if (updatingCount != m_updatingCount) { + m_updatingCount = updatingCount; + emit updatingCountChanged(); + } } LocalDocsModel::LocalDocsModel(QObject *parent) : QAbstractListModel(parent) { + connect(this, &LocalDocsModel::rowsInserted, this, &LocalDocsModel::countChanged); + connect(this, &LocalDocsModel::rowsRemoved, this, &LocalDocsModel::countChanged); + connect(this, &LocalDocsModel::modelReset, this, 
&LocalDocsModel::countChanged); + connect(this, &LocalDocsModel::layoutChanged, this, &LocalDocsModel::countChanged); } int LocalDocsModel::rowCount(const QModelIndex &parent) const @@ -56,6 +93,8 @@ QVariant LocalDocsModel::data(const QModelIndex &index, int role) const return item.indexing; case ErrorRole: return item.error; + case ForceIndexingRole: + return item.forceIndexing; case CurrentDocsToIndexRole: return item.currentDocsToIndex; case TotalDocsToIndexRole: @@ -68,6 +107,22 @@ QVariant LocalDocsModel::data(const QModelIndex &index, int role) const return quint64(item.currentEmbeddingsToIndex); case TotalEmbeddingsToIndexRole: return quint64(item.totalEmbeddingsToIndex); + case TotalDocsRole: + return quint64(item.totalDocs); + case TotalWordsRole: + return quint64(item.totalWords); + case TotalTokensRole: + return quint64(item.totalTokens); + case StartUpdateRole: + return item.startUpdate; + case LastUpdateRole: + return item.lastUpdate; + case FileCurrentlyProcessingRole: + return item.fileCurrentlyProcessing; + case EmbeddingModelRole: + return item.embeddingModel; + case UpdatingRole: + return item.indexing || item.currentEmbeddingsToIndex != 0; } return QVariant(); @@ -81,103 +136,94 @@ QHash LocalDocsModel::roleNames() const roles[InstalledRole] = "installed"; roles[IndexingRole] = "indexing"; roles[ErrorRole] = "error"; + roles[ForceIndexingRole] = "forceIndexing"; roles[CurrentDocsToIndexRole] = "currentDocsToIndex"; roles[TotalDocsToIndexRole] = "totalDocsToIndex"; roles[CurrentBytesToIndexRole] = "currentBytesToIndex"; roles[TotalBytesToIndexRole] = "totalBytesToIndex"; roles[CurrentEmbeddingsToIndexRole] = "currentEmbeddingsToIndex"; roles[TotalEmbeddingsToIndexRole] = "totalEmbeddingsToIndex"; + roles[TotalDocsRole] = "totalDocs"; + roles[TotalWordsRole] = "totalWords"; + roles[TotalTokensRole] = "totalTokens"; + roles[StartUpdateRole] = "startUpdate"; + roles[LastUpdateRole] = "lastUpdate"; + roles[FileCurrentlyProcessingRole] = 
"fileCurrentlyProcessing"; + roles[EmbeddingModelRole] = "embeddingModel"; + roles[UpdatingRole] = "updating"; return roles; } -template -void LocalDocsModel::updateField(int folder_id, T value, - const std::function& updater, - const QVector& roles) +void LocalDocsModel::updateCollectionItem(const CollectionItem &item) { for (int i = 0; i < m_collectionList.size(); ++i) { - if (m_collectionList.at(i).folder_id != folder_id) + CollectionItem &stored = m_collectionList[i]; + if (stored.folder_id != item.folder_id) continue; - updater(m_collectionList[i], value); - emit dataChanged(this->index(i), this->index(i), roles); + QVector changed; + if (stored.folder_path != item.folder_path) + changed.append(FolderPathRole); + if (stored.installed != item.installed) + changed.append(InstalledRole); + if (stored.indexing != item.indexing) { + changed.append(IndexingRole); + changed.append(UpdatingRole); + } + if (stored.error != item.error) + changed.append(ErrorRole); + if (stored.forceIndexing != item.forceIndexing) + changed.append(ForceIndexingRole); + if (stored.currentDocsToIndex != item.currentDocsToIndex) + changed.append(CurrentDocsToIndexRole); + if (stored.totalDocsToIndex != item.totalDocsToIndex) + changed.append(TotalDocsToIndexRole); + if (stored.currentBytesToIndex != item.currentBytesToIndex) + changed.append(CurrentBytesToIndexRole); + if (stored.totalBytesToIndex != item.totalBytesToIndex) + changed.append(TotalBytesToIndexRole); + if (stored.currentEmbeddingsToIndex != item.currentEmbeddingsToIndex) { + changed.append(CurrentEmbeddingsToIndexRole); + changed.append(UpdatingRole); + } + if (stored.totalEmbeddingsToIndex != item.totalEmbeddingsToIndex) + changed.append(TotalEmbeddingsToIndexRole); + if (stored.totalDocs != item.totalDocs) + changed.append(TotalDocsRole); + if (stored.totalWords != item.totalWords) + changed.append(TotalWordsRole); + if (stored.totalTokens != item.totalTokens) + changed.append(TotalTokensRole); + if (stored.startUpdate != 
item.startUpdate) + changed.append(StartUpdateRole); + if (stored.lastUpdate != item.lastUpdate) + changed.append(LastUpdateRole); + if (stored.fileCurrentlyProcessing != item.fileCurrentlyProcessing) + changed.append(FileCurrentlyProcessingRole); + if (stored.embeddingModel != item.embeddingModel) + changed.append(EmbeddingModelRole); + + // preserve collection name as we ignore it for matching + QString collection = stored.collection; + stored = item; + stored.collection = collection; + + emit dataChanged(this->index(i), this->index(i), changed); + + if (changed.contains(UpdatingRole)) + emit updatingChanged(item.collection); } } -void LocalDocsModel::updateInstalled(int folder_id, bool b) -{ - updateField(folder_id, b, - [](CollectionItem& item, bool val) { item.installed = val; }, {InstalledRole}); -} - -void LocalDocsModel::updateIndexing(int folder_id, bool b) -{ - updateField(folder_id, b, - [](CollectionItem& item, bool val) { item.indexing = val; }, {IndexingRole}); -} - -void LocalDocsModel::updateError(int folder_id, const QString &error) -{ - updateField(folder_id, error, - [](CollectionItem& item, QString val) { item.error = val; }, {ErrorRole}); -} - -void LocalDocsModel::updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex) -{ - updateField(folder_id, currentDocsToIndex, - [](CollectionItem& item, size_t val) { item.currentDocsToIndex = val; }, {CurrentDocsToIndexRole}); -} - -void LocalDocsModel::updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex) -{ - updateField(folder_id, totalDocsToIndex, - [](CollectionItem& item, size_t val) { item.totalDocsToIndex = val; }, {TotalDocsToIndexRole}); -} - -void LocalDocsModel::subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes) -{ - updateField(folder_id, subtractedBytes, - [](CollectionItem& item, size_t val) { item.currentBytesToIndex -= val; }, {CurrentBytesToIndexRole}); -} - -void LocalDocsModel::updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex) -{ 
- updateField(folder_id, currentBytesToIndex, - [](CollectionItem& item, size_t val) { item.currentBytesToIndex = val; }, {CurrentBytesToIndexRole}); -} - -void LocalDocsModel::updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex) -{ - updateField(folder_id, totalBytesToIndex, - [](CollectionItem& item, size_t val) { item.totalBytesToIndex = val; }, {TotalBytesToIndexRole}); -} - -void LocalDocsModel::updateCurrentEmbeddingsToIndex(int folder_id, size_t currentEmbeddingsToIndex) -{ - updateField(folder_id, currentEmbeddingsToIndex, - [](CollectionItem& item, size_t val) { item.currentEmbeddingsToIndex += val; }, {CurrentEmbeddingsToIndexRole}); -} - -void LocalDocsModel::updateTotalEmbeddingsToIndex(int folder_id, size_t totalEmbeddingsToIndex) -{ - updateField(folder_id, totalEmbeddingsToIndex, - [](CollectionItem& item, size_t val) { item.totalEmbeddingsToIndex += val; }, {TotalEmbeddingsToIndexRole}); -} - -void LocalDocsModel::addCollectionItem(const CollectionItem &item, bool fromDb) +void LocalDocsModel::addCollectionItem(const CollectionItem &item) { beginInsertRows(QModelIndex(), m_collectionList.size(), m_collectionList.size()); m_collectionList.append(item); endInsertRows(); - - if (!fromDb) { - Network::globalInstance()->trackEvent("doc_collection_add", { - {"collection_count", m_collectionList.count()}, - }); - } } -void LocalDocsModel::removeCollectionIf(std::function const &predicate) { +void LocalDocsModel::removeCollectionIf(std::function const &predicate) +{ for (int i = 0; i < m_collectionList.size();) { if (predicate(m_collectionList.at(i))) { beginRemoveRows(QModelIndex(), i, i); @@ -193,9 +239,11 @@ void LocalDocsModel::removeCollectionIf(std::function cons } } -void LocalDocsModel::removeFolderById(int folder_id) +void LocalDocsModel::removeFolderById(const QString &collection, int folder_id) { - removeCollectionIf([folder_id](const auto &c) { return c.folder_id == folder_id; }); + removeCollectionIf([collection, folder_id](const 
auto &c) { + return c.collection == collection && c.folder_id == folder_id; + }); } void LocalDocsModel::removeCollectionPath(const QString &name, const QString &path) diff --git a/gpt4all-chat/localdocsmodel.h b/gpt4all-chat/localdocsmodel.h index ed9fcff1..82b5f882 100644 --- a/gpt4all-chat/localdocsmodel.h +++ b/gpt4all-chat/localdocsmodel.h @@ -7,36 +7,46 @@ #include #include #include -#include #include #include #include #include -#include #include -#include #include class LocalDocsCollectionsModel : public QSortFilterProxyModel { Q_OBJECT + Q_PROPERTY(int count READ count NOTIFY countChanged) + Q_PROPERTY(int updatingCount READ updatingCount NOTIFY updatingCountChanged) public: explicit LocalDocsCollectionsModel(QObject *parent); public Q_SLOTS: + int count() const { return rowCount(); } void setCollections(const QList &collections); + int updatingCount() const; + +Q_SIGNALS: + void countChanged(); + void updatingCountChanged(); + +private Q_SLOT: + void maybeTriggerUpdatingCountChanged(); protected: bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override; private: QList m_collections; + int m_updatingCount = 0; }; class LocalDocsModel : public QAbstractListModel { Q_OBJECT + Q_PROPERTY(int count READ count NOTIFY countChanged) public: enum Roles { @@ -45,43 +55,42 @@ public: InstalledRole, IndexingRole, ErrorRole, + ForceIndexingRole, CurrentDocsToIndexRole, TotalDocsToIndexRole, CurrentBytesToIndexRole, TotalBytesToIndexRole, CurrentEmbeddingsToIndexRole, - TotalEmbeddingsToIndexRole + TotalEmbeddingsToIndexRole, + TotalDocsRole, + TotalWordsRole, + TotalTokensRole, + StartUpdateRole, + LastUpdateRole, + FileCurrentlyProcessingRole, + EmbeddingModelRole, + UpdatingRole }; explicit LocalDocsModel(QObject *parent = nullptr); int rowCount(const QModelIndex & = QModelIndex()) const override; QVariant data(const QModelIndex &index, int role) const override; QHash roleNames() const override; + int count() const { return rowCount(); } 
public Q_SLOTS: - void updateInstalled(int folder_id, bool b); - void updateIndexing(int folder_id, bool b); - void updateError(int folder_id, const QString &error); - void updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex); - void updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex); - void subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes); - void updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex); - void updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex); - void updateCurrentEmbeddingsToIndex(int folder_id, size_t currentBytesToIndex); - void updateTotalEmbeddingsToIndex(int folder_id, size_t totalBytesToIndex); - void addCollectionItem(const CollectionItem &item, bool fromDb); - void removeFolderById(int folder_id); + void updateCollectionItem(const CollectionItem&); + void addCollectionItem(const CollectionItem &item); + void removeFolderById(const QString &collection, int folder_id); void removeCollectionPath(const QString &name, const QString &path); void collectionListUpdated(const QList &collectionList); -private: - template - void updateField(int folder_id, T value, - const std::function& updater, - const QVector& roles); - void removeCollectionIf(std::function const &predicate); +Q_SIGNALS: + void countChanged(); + void updatingChanged(const QString &collection); private: + void removeCollectionIf(std::function const &predicate); QList m_collectionList; }; diff --git a/gpt4all-chat/logger.cpp b/gpt4all-chat/logger.cpp index fc9fea89..6c730757 100644 --- a/gpt4all-chat/logger.cpp +++ b/gpt4all-chat/logger.cpp @@ -10,6 +10,8 @@ #include #include +using namespace Qt::Literals::StringLiterals; + class MyLogger: public Logger { }; Q_GLOBAL_STATIC(MyLogger, loggerInstance) Logger *Logger::globalInstance() @@ -61,7 +63,7 @@ void Logger::messageHandler(QtMsgType type, const QMessageLogContext &, const QS // Get time and date auto timestamp = QDateTime::currentDateTime().toString(); // 
Write message - const std::string out = QString("[%1] (%2): %4\n").arg(typeString, timestamp, msg).toStdString(); + const std::string out = u"[%1] (%2): %3\n"_s.arg(typeString, timestamp, msg).toStdString(); logger->m_file.write(out.c_str()); logger->m_file.flush(); std::cerr << out; diff --git a/gpt4all-chat/main.cpp b/gpt4all-chat/main.cpp index 1068162f..b7a8e9ec 100644 --- a/gpt4all-chat/main.cpp +++ b/gpt4all-chat/main.cpp @@ -15,9 +15,10 @@ #include #include #include +#include #include -#include #include +#include int main(int argc, char *argv[]) { @@ -25,6 +26,7 @@ int main(int argc, char *argv[]) QCoreApplication::setOrganizationDomain("gpt4all.io"); QCoreApplication::setApplicationName("GPT4All"); QCoreApplication::setApplicationVersion(APP_VERSION); + QSettings::setDefaultFormat(QSettings::IniFormat); Logger::globalInstance(); diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml index c9d0ad67..6d7fdafb 100644 --- a/gpt4all-chat/main.qml +++ b/gpt4all-chat/main.qml @@ -10,14 +10,15 @@ import download import modellist import network import gpt4all +import localdocs import mysettings Window { id: window width: 1920 height: 1080 - minimumWidth: 720 - minimumHeight: 480 + minimumWidth: 1280 + minimumHeight: 720 visible: true title: qsTr("GPT4All v") + Qt.application.version @@ -32,6 +33,128 @@ Window { id: theme } + Item { + Accessible.role: Accessible.Window + Accessible.name: title + } + + // Startup code + Component.onCompleted: { + startupDialogs(); + } + + Component.onDestruction: { + Network.trackEvent("session_end") + } + + Connections { + target: firstStartDialog + function onClosed() { + startupDialogs(); + } + } + + Connections { + target: Download + function onHasNewerReleaseChanged() { + startupDialogs(); + } + } + + property bool hasCheckedFirstStart: false + property bool hasShownSettingsAccess: false + + function startupDialogs() { + if (!LLM.compatHardware()) { + Network.trackEvent("noncompat_hardware") + errorCompatHardware.open(); + 
return; + } + + // check if we have access to settings and if not show an error + if (!hasShownSettingsAccess && !LLM.hasSettingsAccess()) { + errorSettingsAccess.open(); + hasShownSettingsAccess = true; + return; + } + + // check for first time start of this version + if (!hasCheckedFirstStart) { + if (Download.isFirstStart(/*writeVersion*/ true)) { + firstStartDialog.open(); + return; + } + + // send startup or opt-out now that the user has made their choice + Network.sendStartup() + // start localdocs + LocalDocs.requestStart() + + hasCheckedFirstStart = true + } + + // check for new version + if (Download.hasNewerRelease && !firstStartDialog.opened) { + newVersionDialog.open(); + return; + } + } + + PopupDialog { + id: errorCompatHardware + anchors.centerIn: parent + shouldTimeOut: false + shouldShowBusy: false + closePolicy: Popup.NoAutoClose + modal: true + text: qsTr("

Encountered an error starting up:


") + + qsTr("\"Incompatible hardware detected.\"") + + qsTr("

Unfortunately, your CPU does not meet the minimal requirements to run ") + + qsTr("this program. In particular, it does not support AVX intrinsics which this ") + + qsTr("program requires to successfully run a modern large language model. ") + + qsTr("The only solution at this time is to upgrade your hardware to a more modern CPU.") + + qsTr("

See here for more information: ") + + qsTr("https://en.wikipedia.org/wiki/Advanced_Vector_Extensions") + } + + PopupDialog { + id: errorSettingsAccess + anchors.centerIn: parent + shouldTimeOut: false + shouldShowBusy: false + modal: true + text: qsTr("

Encountered an error starting up:


") + + qsTr("\"Inability to access settings file.\"") + + qsTr("

Unfortunately, something is preventing the program from accessing ") + + qsTr("the settings file. This could be caused by incorrect permissions in the local ") + + qsTr("app config directory where the settings file is located. ") + + qsTr("Check out our discord channel for help.") + } + + StartupDialog { + id: firstStartDialog + anchors.centerIn: parent + } + + NewVersionDialog { + id: newVersionDialog + anchors.centerIn: parent + } + + Connections { + target: Network + function onHealthCheckFailed(code) { + healthCheckFailed.open(); + } + } + + PopupDialog { + id: healthCheckFailed + anchors.centerIn: parent + text: qsTr("Connection to datalake failed.") + font.pixelSize: theme.fontSizeLarge + } + property bool hasSaved: false PopupDialog { @@ -43,6 +166,18 @@ Window { font.pixelSize: theme.fontSizeLarge } + NetworkDialog { + id: networkDialog + anchors.centerIn: parent + width: Math.min(1024, window.width - (window.width * .2)) + height: Math.min(600, window.height - (window.height * .2)) + Item { + Accessible.role: Accessible.Dialog + Accessible.name: qsTr("Network dialog") + Accessible.description: qsTr("opt-in to share feedback/conversations") + } + } + onClosing: function(close) { if (window.hasSaved) return; @@ -61,9 +196,440 @@ Window { } } - color: theme.black + color: theme.viewBarBackground - ChatView { - anchors.fill: parent + Rectangle { + id: viewBar + anchors.top: parent.top + anchors.bottom: parent.bottom + anchors.left: parent.left + width: MySettings.fontSize === "Small" ? 86 : 100 + color: theme.viewBarBackground + + ColumnLayout { + id: viewsLayout + anchors.top: parent.top + anchors.topMargin: 30 + anchors.horizontalCenter: parent.horizontalCenter + Layout.margins: 0 + spacing: 18 + + MyToolButton { + id: homeButton + backgroundColor: toggled ? 
theme.iconBackgroundViewBarHovered : theme.iconBackgroundViewBar + backgroundColorHovered: theme.iconBackgroundViewBarHovered + Layout.preferredWidth: 48 + Layout.preferredHeight: 48 + Layout.alignment: Qt.AlignCenter + toggledWidth: 0 + toggled: homeView.isShown() + toggledColor: theme.iconBackgroundViewBarToggled + imageWidth: 34 + imageHeight: 34 + source: "qrc:/gpt4all/icons/home.svg" + Accessible.name: qsTr("Home view") + Accessible.description: qsTr("Home view of application") + onClicked: { + homeView.show() + } + } + + Text { + Layout.topMargin: -20 + text: qsTr("Home") + font.pixelSize: theme.fontSizeLargeCapped + font.bold: true + color: homeButton.hovered ? homeButton.backgroundColorHovered : homeButton.backgroundColor + Layout.preferredWidth: 48 + horizontalAlignment: Text.AlignHCenter + TapHandler { + onTapped: function(eventPoint, button) { + homeView.show() + } + } + } + + MyToolButton { + id: chatButton + backgroundColor: toggled ? theme.iconBackgroundViewBarHovered : theme.iconBackgroundViewBar + backgroundColorHovered: theme.iconBackgroundViewBarHovered + Layout.preferredWidth: 48 + Layout.preferredHeight: 48 + Layout.alignment: Qt.AlignCenter + toggledWidth: 0 + toggled: chatView.isShown() + toggledColor: theme.iconBackgroundViewBarToggled + imageWidth: 34 + imageHeight: 34 + source: "qrc:/gpt4all/icons/chat.svg" + Accessible.name: qsTr("Chat view") + Accessible.description: qsTr("Chat view to interact with models") + onClicked: { + chatView.show() + } + } + + Text { + Layout.topMargin: -20 + text: qsTr("Chats") + font.pixelSize: theme.fontSizeLargeCapped + font.bold: true + color: chatButton.hovered ? chatButton.backgroundColorHovered : chatButton.backgroundColor + Layout.preferredWidth: 48 + horizontalAlignment: Text.AlignHCenter + TapHandler { + onTapped: function(eventPoint, button) { + chatView.show() + } + } + } + + MyToolButton { + id: modelsButton + backgroundColor: toggled ? 
theme.iconBackgroundViewBarHovered : theme.iconBackgroundViewBar + backgroundColorHovered: theme.iconBackgroundViewBarHovered + Layout.preferredWidth: 48 + Layout.preferredHeight: 48 + toggledWidth: 0 + toggled: modelsView.isShown() + toggledColor: theme.iconBackgroundViewBarToggled + imageWidth: 34 + imageHeight: 34 + source: "qrc:/gpt4all/icons/models.svg" + Accessible.name: qsTr("Models") + Accessible.description: qsTr("Models view for installed models") + onClicked: { + modelsView.show() + } + } + + Text { + Layout.topMargin: -20 + text: qsTr("Models") + font.pixelSize: theme.fontSizeLargeCapped + font.bold: true + color: modelsButton.hovered ? modelsButton.backgroundColorHovered : modelsButton.backgroundColor + Layout.preferredWidth: 48 + horizontalAlignment: Text.AlignHCenter + TapHandler { + onTapped: function(eventPoint, button) { + modelsView.show() + } + } + } + + MyToolButton { + id: localdocsButton + backgroundColor: toggled ? theme.iconBackgroundViewBarHovered : theme.iconBackgroundViewBar + backgroundColorHovered: theme.iconBackgroundViewBarHovered + Layout.preferredWidth: 48 + Layout.preferredHeight: 48 + toggledWidth: 0 + toggledColor: theme.iconBackgroundViewBarToggled + toggled: localDocsView.isShown() + imageWidth: 34 + imageHeight: 34 + source: "qrc:/gpt4all/icons/db.svg" + Accessible.name: qsTr("LocalDocs") + Accessible.description: qsTr("LocalDocs view to configure and use local docs") + onClicked: { + localDocsView.show() + } + } + + Text { + Layout.topMargin: -20 + text: qsTr("LocalDocs") + font.pixelSize: theme.fontSizeLargeCapped + font.bold: true + color: localdocsButton.hovered ? localdocsButton.backgroundColorHovered : localdocsButton.backgroundColor + Layout.preferredWidth: 48 + horizontalAlignment: Text.AlignHCenter + TapHandler { + onTapped: function(eventPoint, button) { + localDocsView.show() + } + } + } + + MyToolButton { + id: settingsButton + backgroundColor: toggled ? 
theme.iconBackgroundViewBarHovered : theme.iconBackgroundViewBar + backgroundColorHovered: theme.iconBackgroundViewBarHovered + Layout.preferredWidth: 48 + Layout.preferredHeight: 48 + toggledWidth: 0 + toggledColor: theme.iconBackgroundViewBarToggled + toggled: settingsView.isShown() + imageWidth: 34 + imageHeight: 34 + source: "qrc:/gpt4all/icons/settings.svg" + Accessible.name: qsTr("Settings") + Accessible.description: qsTr("Settings view for application configuration") + onClicked: { + settingsView.show(0 /*pageToDisplay*/) + } + } + + Text { + Layout.topMargin: -20 + text: qsTr("Settings") + font.pixelSize: theme.fontSizeLargeCapped + font.bold: true + color: settingsButton.hovered ? settingsButton.backgroundColorHovered : settingsButton.backgroundColor + Layout.preferredWidth: 48 + horizontalAlignment: Text.AlignHCenter + TapHandler { + onTapped: function(eventPoint, button) { + settingsView.show(0 /*pageToDisplay*/) + } + } + } + } + + ColumnLayout { + id: buttonsLayout + anchors.bottom: parent.bottom + anchors.margins: 0 + anchors.bottomMargin: 25 + anchors.horizontalCenter: parent.horizontalCenter + Layout.margins: 0 + spacing: 22 + + Rectangle { + Layout.alignment: Qt.AlignCenter + Layout.preferredWidth: image.width + Layout.preferredHeight: image.height + color: "transparent" + + Image { + id: image + anchors.centerIn: parent + sourceSize: Qt.size(60, 40) + fillMode: Image.PreserveAspectFit + mipmap: true + visible: false + source: "qrc:/gpt4all/icons/nomic_logo.svg" + } + + ColorOverlay { + anchors.fill: image + source: image + color: image.hovered ? 
theme.mutedDarkTextColorHovered : theme.mutedDarkTextColor + TapHandler { + onTapped: function(eventPoint, button) { + Qt.openUrlExternally("https://nomic.ai") + } + } + } + } + } + } + + Rectangle { + id: roundedFrame + z: 299 + anchors.top: parent.top + anchors.bottom: parent.bottom + anchors.left: viewBar.right + anchors.right: parent.right + anchors.topMargin: 15 + anchors.bottomMargin: 15 + anchors.rightMargin: 15 + radius: 15 + border.width: 1 + border.color: theme.dividerColor + color: "transparent" + clip: true + } + + RectangularGlow { + id: effect + anchors.fill: roundedFrame + glowRadius: 15 + spread: 0 + color: theme.dividerColor + cornerRadius: 10 + opacity: 0.5 + } + + StackLayout { + id: stackLayout + anchors.top: parent.top + anchors.bottom: parent.bottom + anchors.left: viewBar.right + anchors.right: parent.right + anchors.topMargin: 15 + anchors.bottomMargin: 15 + anchors.rightMargin: 15 + + layer.enabled: true + layer.effect: OpacityMask { + maskSource: Rectangle { + width: roundedFrame.width + height: roundedFrame.height + radius: 15 + } + } + + HomeView { + id: homeView + Layout.fillWidth: true + Layout.fillHeight: true + shouldShowFirstStart: !hasCheckedFirstStart + + function show() { + stackLayout.currentIndex = 0; + } + + function isShown() { + return stackLayout.currentIndex === 0 + } + + Connections { + target: homeView + function onChatViewRequested() { + chatView.show(); + } + function onLocalDocsViewRequested() { + localDocsView.show(); + } + function onAddModelViewRequested() { + addModelView.show(); + } + function onSettingsViewRequested(page) { + settingsView.show(page); + } + } + } + + ChatView { + id: chatView + Layout.fillWidth: true + Layout.fillHeight: true + + function show() { + stackLayout.currentIndex = 1; + } + + function isShown() { + return stackLayout.currentIndex === 1 + } + + Connections { + target: chatView + function onAddCollectionViewRequested() { + addCollectionView.show(); + } + function 
onAddModelViewRequested() { + addModelView.show(); + } + } + } + + ModelsView { + id: modelsView + Layout.fillWidth: true + Layout.fillHeight: true + + function show() { + stackLayout.currentIndex = 2; + // FIXME This expanded code should be removed and we should be changing the names of + // the classes here in ModelList for the proxy/filter models + ModelList.downloadableModels.expanded = true + } + + function isShown() { + return stackLayout.currentIndex === 2 + } + + Item { + Accessible.name: qsTr("Installed models") + Accessible.description: qsTr("View of installed models") + } + + Connections { + target: modelsView + function onAddModelViewRequested() { + addModelView.show(); + } + } + } + + LocalDocsView { + id: localDocsView + Layout.fillWidth: true + Layout.fillHeight: true + + function show() { + stackLayout.currentIndex = 3; + } + + function isShown() { + return stackLayout.currentIndex === 3 + } + + Connections { + target: localDocsView + function onAddCollectionViewRequested() { + addCollectionView.show(); + } + } + } + + SettingsView { + id: settingsView + Layout.fillWidth: true + Layout.fillHeight: true + + function show(page) { + settingsView.pageToDisplay = page; + stackLayout.currentIndex = 4; + } + + function isShown() { + return stackLayout.currentIndex === 4 + } + } + + AddCollectionView { + id: addCollectionView + Layout.fillWidth: true + Layout.fillHeight: true + + function show() { + stackLayout.currentIndex = 5; + } + function isShown() { + return stackLayout.currentIndex === 5 + } + + Connections { + target: addCollectionView + function onLocalDocsViewRequested() { + localDocsView.show(); + } + } + } + + AddModelView { + id: addModelView + Layout.fillWidth: true + Layout.fillHeight: true + + function show() { + stackLayout.currentIndex = 6; + } + function isShown() { + return stackLayout.currentIndex === 6 + } + + Connections { + target: addModelView + function onModelsViewRequested() { + modelsView.show(); + } + } + } } } diff --git 
a/gpt4all-chat/modellist.cpp b/gpt4all-chat/modellist.cpp index 2fca1331..359f29a4 100644 --- a/gpt4all-chat/modellist.cpp +++ b/gpt4all-chat/modellist.cpp @@ -26,10 +26,11 @@ #include #include #include +#include #include #include -#include #include +#include #include #include @@ -38,12 +39,11 @@ #include #include +using namespace Qt::Literals::StringLiterals; + //#define USE_LOCAL_MODELSJSON -const char * const KNOWN_EMBEDDING_MODELS[] { - "all-MiniLM-L6-v2.gguf2.f16.gguf", - "gpt4all-nomic-embed-text-v1.rmodel", -}; +static const QStringList FILENAME_BLACKLIST { u"gpt4all-nomic-embed-text-v1.rmodel"_s }; QString ModelInfo::id() const { @@ -339,56 +339,32 @@ bool ModelInfo::shouldSaveMetadata() const return installed && (isClone() || isDiscovered() || description() == "" /*indicates sideloaded*/); } -EmbeddingModels::EmbeddingModels(QObject *parent, bool requireInstalled) - : QSortFilterProxyModel(parent) +QVariantMap ModelInfo::getFields() const { - m_requireInstalled = requireInstalled; - - connect(this, &EmbeddingModels::rowsInserted, this, &EmbeddingModels::countChanged); - connect(this, &EmbeddingModels::rowsRemoved, this, &EmbeddingModels::countChanged); - connect(this, &EmbeddingModels::modelReset, this, &EmbeddingModels::countChanged); - connect(this, &EmbeddingModels::layoutChanged, this, &EmbeddingModels::countChanged); -} - -bool EmbeddingModels::filterAcceptsRow(int sourceRow, - const QModelIndex &sourceParent) const -{ - QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent); - bool isEmbeddingModel = sourceModel()->data(index, ModelList::IsEmbeddingModelRole).toBool(); - bool installed = sourceModel()->data(index, ModelList::InstalledRole).toBool(); - QString filename = sourceModel()->data(index, ModelList::FilenameRole).toString(); - auto &known = KNOWN_EMBEDDING_MODELS; - if (std::find(known, std::end(known), filename.toStdString()) == std::end(known)) - return false; // we are currently not prepared to support other embedding models 
- - return isEmbeddingModel && (!m_requireInstalled || installed); -} - -int EmbeddingModels::defaultModelIndex() const -{ - auto *sourceListModel = qobject_cast(sourceModel()); - if (!sourceListModel) return -1; - - int rows = sourceListModel->rowCount(); - for (int i = 0; i < rows; ++i) { - if (filterAcceptsRow(i, sourceListModel->index(i, 0).parent())) - return i; - } - - return -1; -} - -ModelInfo EmbeddingModels::defaultModelInfo() const -{ - auto *sourceListModel = qobject_cast(sourceModel()); - if (!sourceListModel) return ModelInfo(); - - int i = defaultModelIndex(); - if (i < 0) return ModelInfo(); - - QModelIndex sourceIndex = sourceListModel->index(i, 0); - auto id = sourceListModel->data(sourceIndex, ModelList::IdRole).toString(); - return sourceListModel->modelInfo(id); + return { + { "filename", m_filename }, + { "description", m_description }, + { "url", m_url }, + { "quant", m_quant }, + { "type", m_type }, + { "isClone", m_isClone }, + { "isDiscovered", m_isDiscovered }, + { "likes", m_likes }, + { "downloads", m_downloads }, + { "recency", m_recency }, + { "temperature", m_temperature }, + { "topP", m_topP }, + { "minP", m_minP }, + { "topK", m_topK }, + { "maxLength", m_maxLength }, + { "promptBatchSize", m_promptBatchSize }, + { "contextLength", m_contextLength }, + { "gpuLayers", m_gpuLayers }, + { "repeatPenalty", m_repeatPenalty }, + { "repeatPenaltyTokens", m_repeatPenaltyTokens }, + { "promptTemplate", m_promptTemplate }, + { "systemPrompt", m_systemPrompt }, + }; } InstalledModels::InstalledModels(QObject *parent) @@ -424,13 +400,14 @@ DownloadableModels::DownloadableModels(QObject *parent) bool DownloadableModels::filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const { + // FIXME We can eliminate the 'expanded' code as the UI no longer uses this bool withinLimit = sourceRow < (m_expanded ? 
sourceModel()->rowCount() : m_limit); QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent); bool isDownloadable = !sourceModel()->data(index, ModelList::DescriptionRole).toString().isEmpty(); - bool isInstalled = !sourceModel()->data(index, ModelList::InstalledRole).toString().isEmpty(); - bool isIncomplete = !sourceModel()->data(index, ModelList::IncompleteRole).toString().isEmpty(); + bool isInstalled = sourceModel()->data(index, ModelList::InstalledRole).toBool(); + bool isIncomplete = sourceModel()->data(index, ModelList::IncompleteRole).toBool(); bool isClone = sourceModel()->data(index, ModelList::IsCloneRole).toBool(); - return withinLimit && !isClone && (isDownloadable || isInstalled || isIncomplete); + return withinLimit && !isClone && !isInstalled && (isDownloadable || isIncomplete); } int DownloadableModels::count() const @@ -468,9 +445,7 @@ ModelList *ModelList::globalInstance() ModelList::ModelList() : QAbstractListModel(nullptr) - , m_embeddingModels(new EmbeddingModels(this, false /* all models */)) , m_installedModels(new InstalledModels(this)) - , m_installedEmbeddingModels(new EmbeddingModels(this, true /* installed models */)) , m_downloadableModels(new DownloadableModels(this)) , m_asyncModelRequestOngoing(false) , m_discoverLimit(20) @@ -480,9 +455,7 @@ ModelList::ModelList() , m_discoverResultsCompleted(0) , m_discoverInProgress(false) { - m_embeddingModels->setSourceModel(this); m_installedModels->setSourceModel(this); - m_installedEmbeddingModels->setSourceModel(this); m_downloadableModels->setSourceModel(this); connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromDirectory); @@ -552,17 +525,11 @@ const QList ModelList::userDefaultModelList() const return models; } -int ModelList::defaultEmbeddingModelIndex() const -{ - return embeddingModels()->defaultModelIndex(); -} - ModelInfo ModelList::defaultModelInfo() const { QMutexLocker locker(&m_mutex); QSettings settings; - 
settings.sync(); // The user default model can be set by the user in the settings dialog. The "default" user // default model is "Application default" which signals we should use the logic here. @@ -1153,7 +1120,7 @@ void ModelList::removeInternal(const ModelInfo &model) QString ModelList::uniqueModelName(const ModelInfo &model) const { QMutexLocker locker(&m_mutex); - QRegularExpression re("^(.*)~(\\d+)$"); + static const QRegularExpression re("^(.*)~(\\d+)$"); QRegularExpressionMatch match = re.match(model.name()); QString baseName; if (match.hasMatch()) @@ -1208,13 +1175,11 @@ void ModelList::updateModelsFromDirectory() it.next(); if (!it.fileInfo().isDir()) { QString filename = it.fileName(); - if (filename.endsWith(".txt") && (filename.startsWith("chatgpt-") || filename.startsWith("nomic-"))) { + if (filename.startsWith("chatgpt-") && filename.endsWith(".txt")) { QString apikey; QString modelname(filename); modelname.chop(4); // strip ".txt" extension - if (filename.startsWith("chatgpt-")) { - modelname.remove(0, 8); // strip "chatgpt-" prefix - } + modelname.remove(0, 8); // strip "chatgpt-" prefix QFile file(path + filename); if (file.open(QIODevice::ReadWrite)) { QTextStream in(&file); @@ -1227,7 +1192,7 @@ void ModelList::updateModelsFromDirectory() obj.insert("modelName", modelname); QJsonDocument doc(obj); - auto newfilename = QString("gpt4all-%1.rmodel").arg(modelname); + auto newfilename = u"gpt4all-%1.rmodel"_s.arg(modelname); QFile newfile(path + newfilename); if (newfile.open(QIODevice::ReadWrite)) { QTextStream out(&newfile); @@ -1241,46 +1206,41 @@ void ModelList::updateModelsFromDirectory() }; auto processDirectory = [&](const QString& path) { - QDirIterator it(path, QDirIterator::Subdirectories); + QDirIterator it(path, QDir::Files, QDirIterator::Subdirectories); while (it.hasNext()) { it.next(); - if (!it.fileInfo().isDir()) { - QString filename = it.fileName(); + QString filename = it.fileName(); + if (filename.startsWith("incomplete") || 
FILENAME_BLACKLIST.contains(filename)) + continue; + if (!filename.endsWith(".gguf") && !filename.endsWith(".rmodel")) + continue; - if ((filename.endsWith(".gguf") && !filename.startsWith("incomplete")) || filename.endsWith(".rmodel")) { + QVector modelsById; + { + QMutexLocker locker(&m_mutex); + for (ModelInfo *info : m_models) + if (info->filename() == filename) + modelsById.append(info->id()); + } - QString filePath = it.filePath(); - QFileInfo info(filePath); + if (modelsById.isEmpty()) { + if (!contains(filename)) + addModel(filename); + modelsById.append(filename); + } - if (!info.exists()) - continue; + QFileInfo info = it.fileInfo(); - QVector modelsById; - { - QMutexLocker locker(&m_mutex); - for (ModelInfo *info : m_models) - if (info->filename() == filename) - modelsById.append(info->id()); - } - - if (modelsById.isEmpty()) { - if (!contains(filename)) - addModel(filename); - modelsById.append(filename); - } - - for (const QString &id : modelsById) { - QVector> data { - { InstalledRole, true }, - { FilenameRole, filename }, - { OnlineRole, filename.endsWith(".rmodel") }, - { DirpathRole, info.dir().absolutePath() + "/" }, - { FilesizeRole, toFileSize(info.size()) }, - }; - updateData(id, data); - } - } + for (const QString &id : modelsById) { + QVector> data { + { InstalledRole, true }, + { FilenameRole, filename }, + { OnlineRole, filename.endsWith(".rmodel") }, + { DirpathRole, info.dir().absolutePath() + "/" }, + { FilesizeRole, toFileSize(info.size()) }, + }; + updateData(id, data); } } }; @@ -1299,9 +1259,9 @@ void ModelList::updateModelsFromDirectory() void ModelList::updateModelsFromJson() { #if defined(USE_LOCAL_MODELSJSON) - QUrl jsonUrl("file://" + QDir::homePath() + QString("/dev/large_language_models/gpt4all/gpt4all-chat/metadata/models%1.json").arg(MODELS_VERSION)); + QUrl jsonUrl("file://" + QDir::homePath() + u"/dev/large_language_models/gpt4all/gpt4all-chat/metadata/models%1.json"_s.arg(MODELS_VERSION)); #else - QUrl 
jsonUrl(QString("http://gpt4all.io/models/models%1.json").arg(MODELS_VERSION)); + QUrl jsonUrl(u"http://gpt4all.io/models/models%1.json"_s.arg(MODELS_VERSION)); #endif QNetworkRequest request(jsonUrl); QSslConfiguration conf = request.sslConfiguration(); @@ -1343,9 +1303,9 @@ void ModelList::updateModelsFromJsonAsync() emit asyncModelRequestOngoingChanged(); #if defined(USE_LOCAL_MODELSJSON) - QUrl jsonUrl("file://" + QDir::homePath() + QString("/dev/large_language_models/gpt4all/gpt4all-chat/metadata/models%1.json").arg(MODELS_VERSION)); + QUrl jsonUrl("file://" + QDir::homePath() + u"/dev/large_language_models/gpt4all/gpt4all-chat/metadata/models%1.json"_s.arg(MODELS_VERSION)); #else - QUrl jsonUrl(QString("http://gpt4all.io/models/models%1.json").arg(MODELS_VERSION)); + QUrl jsonUrl(u"http://gpt4all.io/models/models%1.json"_s.arg(MODELS_VERSION)); #endif QNetworkRequest request(jsonUrl); QSslConfiguration conf = request.sslConfiguration(); @@ -1383,7 +1343,7 @@ void ModelList::handleModelsJsonDownloadErrorOccurred(QNetworkReply::NetworkErro if (!reply) return; - qWarning() << QString("ERROR: Modellist download failed with error code \"%1-%2\"") + qWarning() << u"ERROR: Modellist download failed with error code \"%1-%2\""_s .arg(code).arg(reply->errorString()); } @@ -1399,7 +1359,8 @@ void ModelList::updateDataForSettings() emit dataChanged(index(0, 0), index(m_models.size() - 1, 0)); } -static std::strong_ordering compareVersions(const QString &a, const QString &b) { +static std::strong_ordering compareVersions(const QString &a, const QString &b) +{ QStringList aParts = a.split('.'); QStringList bParts = b.split('.'); @@ -1450,8 +1411,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) QString versionRemoved = obj["removedIn"].toString(); QString url = obj["url"].toString(); QByteArray modelHash = obj["md5sum"].toString().toLatin1(); - bool isDefault = obj.contains("isDefault") && obj["isDefault"] == QString("true"); - bool disableGUI 
= obj.contains("disableGUI") && obj["disableGUI"] == QString("true"); + bool isDefault = obj.contains("isDefault") && obj["isDefault"] == u"true"_s; + bool disableGUI = obj.contains("disableGUI") && obj["disableGUI"] == u"true"_s; QString description = obj["description"].toString(); QString order = obj["order"].toString(); int ramrequired = obj["ramrequired"].toString().toInt(); @@ -1586,8 +1547,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) }; updateData(id, data); } - - const QString mistralDesc = tr("
  • Requires personal Mistral API key.
  • WARNING: Will send" + + const QString mistralDesc = tr("
    • Requires personal Mistral API key.
    • WARNING: Will send" " your chats to Mistral!
    • Your API key will be stored on disk
    • Will only be used" " to communicate with Mistral
    • You can apply for an API key" " here.
    • "); @@ -1642,7 +1603,7 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) }; updateData(id, data); } - + { const QString modelName = "Mistral Medium API"; const QString id = modelName; @@ -1668,38 +1629,6 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) }; updateData(id, data); } - - - { - const QString nomicEmbedDesc = tr("
      • For use with LocalDocs feature
      • " - "
      • Used for retrieval augmented generation (RAG)
      • " - "
      • Requires personal Nomic API key.
      • " - "
      • WARNING: Will send your localdocs to Nomic Atlas!
      • " - "
      • You can apply for an API key with Nomic Atlas.
      • "); - const QString modelName = "Nomic Embed"; - const QString id = modelName; - const QString modelFilename = "gpt4all-nomic-embed-text-v1.rmodel"; - if (contains(modelFilename)) - changeId(modelFilename, id); - if (!contains(id)) - addModel(id); - QVector> data { - { ModelList::NameRole, modelName }, - { ModelList::FilenameRole, modelFilename }, - { ModelList::FilesizeRole, "minimal" }, - { ModelList::OnlineRole, true }, - { ModelList::IsEmbeddingModelRole, true }, - { ModelList::DescriptionRole, - tr("LocalDocs Nomic Atlas Embed
        ") + nomicEmbedDesc }, - { ModelList::RequiresVersionRole, "2.6.3" }, - { ModelList::OrderRole, "na" }, - { ModelList::RamrequiredRole, 0 }, - { ModelList::ParametersRole, "?" }, - { ModelList::QuantRole, "NA" }, - { ModelList::TypeRole, "Bert" }, - }; - updateData(id, data); - } } void ModelList::updateDiscoveredInstalled(const ModelInfo &info) @@ -1723,9 +1652,8 @@ void ModelList::updateDiscoveredInstalled(const ModelInfo &info) void ModelList::updateModelsFromSettings() { QSettings settings; - settings.sync(); QStringList groups = settings.childGroups(); - for (const QString g : groups) { + for (const QString &g: groups) { if (!g.startsWith("model-")) continue; @@ -1913,7 +1841,7 @@ void ModelList::discoverSearch(const QString &search) m_discoverNumberOfResults = 0; m_discoverResultsCompleted = 0; - discoverProgressChanged(); + emit discoverProgressChanged(); if (search.isEmpty()) { return; @@ -1922,9 +1850,10 @@ void ModelList::discoverSearch(const QString &search) m_discoverInProgress = true; emit discoverInProgressChanged(); - QStringList searchParams = search.split(QRegularExpression("\\s+")); // split by whitespace - QString searchString = QString("search=%1&").arg(searchParams.join('+')); - QString limitString = m_discoverLimit > 0 ? QString("limit=%1&").arg(m_discoverLimit) : QString(); + static const QRegularExpression wsRegex("\\s+"); + QStringList searchParams = search.split(wsRegex); // split by whitespace + QString searchString = u"search=%1&"_s.arg(searchParams.join('+')); + QString limitString = m_discoverLimit > 0 ? u"limit=%1&"_s.arg(m_discoverLimit) : QString(); QString sortString; switch (m_discoverSort) { @@ -1937,9 +1866,10 @@ void ModelList::discoverSearch(const QString &search) sortString = "sort=lastModified&"; break; } - QString directionString = !sortString.isEmpty() ? QString("direction=%1&").arg(m_discoverSortDirection) : QString(); + QString directionString = !sortString.isEmpty() ? 
u"direction=%1&"_s.arg(m_discoverSortDirection) : QString(); - QUrl hfUrl(QString("https://huggingface.co/api/models?filter=gguf&%1%2%3%4full=true&config=true").arg(searchString).arg(limitString).arg(sortString).arg(directionString)); + QUrl hfUrl(u"https://huggingface.co/api/models?filter=gguf&%1%2%3%4full=true&config=true"_s + .arg(searchString, limitString, sortString, directionString)); QNetworkRequest request(hfUrl); request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); @@ -1965,7 +1895,7 @@ void ModelList::handleDiscoveryErrorOccurred(QNetworkReply::NetworkError code) QNetworkReply *reply = qobject_cast(sender()); if (!reply) return; - qWarning() << QString("ERROR: Discovery failed with error code \"%1-%2\"") + qWarning() << u"ERROR: Discovery failed with error code \"%1-%2\""_s .arg(code).arg(reply->errorString()).toStdString(); } @@ -2005,7 +1935,7 @@ void ModelList::parseDiscoveryJsonFile(const QByteArray &jsonData) qWarning() << "ERROR: Couldn't parse: " << jsonData << err.errorString(); m_discoverNumberOfResults = 0; m_discoverResultsCompleted = 0; - discoverProgressChanged(); + emit discoverProgressChanged(); m_discoverInProgress = false; emit discoverInProgressChanged(); return; @@ -2045,7 +1975,7 @@ void ModelList::parseDiscoveryJsonFile(const QByteArray &jsonData) QString filename = file.second; ++m_discoverNumberOfResults; - QUrl url(QString("https://huggingface.co/%1/resolve/main/%2").arg(repo_id).arg(filename)); + QUrl url(u"https://huggingface.co/%1/resolve/main/%2"_s.arg(repo_id, filename)); QNetworkRequest request(url); request.setRawHeader("Accept-Encoding", "identity"); request.setAttribute(QNetworkRequest::RedirectPolicyAttribute, QNetworkRequest::ManualRedirectPolicy); @@ -2084,21 +2014,16 @@ void ModelList::handleDiscoveryItemFinished() QJsonObject config = obj["config"].toObject(); QString type = config["model_type"].toString(); - QByteArray repoCommitHeader = reply->rawHeader("X-Repo-Commit"); + // QByteArray 
repoCommitHeader = reply->rawHeader("X-Repo-Commit"); QByteArray linkedSizeHeader = reply->rawHeader("X-Linked-Size"); QByteArray linkedEtagHeader = reply->rawHeader("X-Linked-Etag"); // For some reason these seem to contain quotation marks ewww linkedEtagHeader.replace("\"", ""); linkedEtagHeader.replace("\'", ""); - QString locationHeader = reply->header(QNetworkRequest::LocationHeader).toString(); - - QString repoCommit = QString::fromUtf8(repoCommitHeader); - QString linkedSize = QString::fromUtf8(linkedSizeHeader); - QString linkedEtag = QString::fromUtf8(linkedEtagHeader); + // QString locationHeader = reply->header(QNetworkRequest::LocationHeader).toString(); QString modelFilename = reply->request().attribute(QNetworkRequest::UserMax).toString(); - QString modelFilesize = linkedSize; - modelFilesize = ModelList::toFileSize(modelFilesize.toULongLong()); + QString modelFilesize = ModelList::toFileSize(QString(linkedSizeHeader).toULongLong()); QString description = tr("Created by %1.
          " "
        • Published on %2." @@ -2155,6 +2080,6 @@ void ModelList::handleDiscoveryItemErrorOccurred(QNetworkReply::NetworkError cod if (!reply) return; - qWarning() << QString("ERROR: Discovery item failed with error code \"%1-%2\"") + qWarning() << u"ERROR: Discovery item failed with error code \"%1-%2\""_s .arg(code).arg(reply->errorString()).toStdString(); } diff --git a/gpt4all-chat/modellist.h b/gpt4all-chat/modellist.h index 1209a7b3..a7df583c 100644 --- a/gpt4all-chat/modellist.h +++ b/gpt4all-chat/modellist.h @@ -20,6 +20,8 @@ #include #include +using namespace Qt::Literals::StringLiterals; + struct ModelInfo { Q_GADGET Q_PROPERTY(QString id READ id WRITE setId) @@ -169,6 +171,8 @@ public: bool shouldSaveMetadata() const; private: + QVariantMap getFields() const; + QString m_id; QString m_name; QString m_filename; @@ -199,28 +203,6 @@ private: }; Q_DECLARE_METATYPE(ModelInfo) -class EmbeddingModels : public QSortFilterProxyModel -{ - Q_OBJECT - Q_PROPERTY(int count READ count NOTIFY countChanged) -public: - EmbeddingModels(QObject *parent, bool requireInstalled); - int count() const { return rowCount(); } - - int defaultModelIndex() const; - ModelInfo defaultModelInfo() const; - -Q_SIGNALS: - void countChanged(); - void defaultModelIndexChanged(); - -protected: - bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override; - -private: - bool m_requireInstalled; -}; - class InstalledModels : public QSortFilterProxyModel { Q_OBJECT @@ -269,8 +251,6 @@ class ModelList : public QAbstractListModel { Q_OBJECT Q_PROPERTY(int count READ count NOTIFY countChanged) - Q_PROPERTY(int defaultEmbeddingModelIndex READ defaultEmbeddingModelIndex) - Q_PROPERTY(EmbeddingModels* installedEmbeddingModels READ installedEmbeddingModels NOTIFY installedEmbeddingModelsChanged) Q_PROPERTY(InstalledModels* installedModels READ installedModels NOTIFY installedModelsChanged) Q_PROPERTY(DownloadableModels* downloadableModels READ downloadableModels NOTIFY 
downloadableModelsChanged) Q_PROPERTY(QList userDefaultModelList READ userDefaultModelList NOTIFY userDefaultModelListChanged) @@ -408,7 +388,6 @@ public: Q_INVOKABLE void removeClone(const ModelInfo &model); Q_INVOKABLE void removeInstalled(const ModelInfo &model); ModelInfo defaultModelInfo() const; - int defaultEmbeddingModelIndex() const; void addModel(const QString &id); void changeId(const QString &oldId, const QString &newId); @@ -416,20 +395,18 @@ public: const QList exportModelList() const; const QList userDefaultModelList() const; - EmbeddingModels *embeddingModels() const { return m_embeddingModels; } - EmbeddingModels *installedEmbeddingModels() const { return m_installedEmbeddingModels; } InstalledModels *installedModels() const { return m_installedModels; } DownloadableModels *downloadableModels() const { return m_downloadableModels; } static inline QString toFileSize(quint64 sz) { if (sz < 1024) { - return QString("%1 bytes").arg(sz); + return u"%1 bytes"_s.arg(sz); } else if (sz < 1024 * 1024) { - return QString("%1 KB").arg(qreal(sz) / 1024, 0, 'g', 3); + return u"%1 KB"_s.arg(qreal(sz) / 1024, 0, 'g', 3); } else if (sz < 1024 * 1024 * 1024) { - return QString("%1 MB").arg(qreal(sz) / (1024 * 1024), 0, 'g', 3); + return u"%1 MB"_s.arg(qreal(sz) / (1024 * 1024), 0, 'g', 3); } else { - return QString("%1 GB").arg(qreal(sz) / (1024 * 1024 * 1024), 0, 'g', 3); + return u"%1 GB"_s.arg(qreal(sz) / (1024 * 1024 * 1024), 0, 'g', 3); } } @@ -455,7 +432,6 @@ public: Q_SIGNALS: void countChanged(); - void installedEmbeddingModelsChanged(); void installedModelsChanged(); void downloadableModelsChanged(); void userDefaultModelListChanged(); @@ -494,8 +470,6 @@ private: private: mutable QMutex m_mutex; QNetworkAccessManager m_networkManager; - EmbeddingModels *m_embeddingModels; - EmbeddingModels *m_installedEmbeddingModels; InstalledModels *m_installedModels; DownloadableModels *m_downloadableModels; QList m_models; diff --git a/gpt4all-chat/mysettings.cpp 
b/gpt4all-chat/mysettings.cpp index c52ab638..69d28ace 100644 --- a/gpt4all-chat/mysettings.cpp +++ b/gpt4all-chat/mysettings.cpp @@ -8,40 +8,53 @@ #include #include #include -#include +#include +#include #include #include -#include #include #include +#include #include #include #include #include -static int default_threadCount = std::min(4, (int32_t) std::thread::hardware_concurrency()); -static bool default_saveChatsContext = false; -static bool default_serverChat = false; -static QString default_userDefaultModel = "Application default"; -static bool default_forceMetal = false; -static QString default_lastVersionStarted = ""; -static int default_localDocsChunkSize = 256; -static QString default_chatTheme = "Dark"; -static QString default_fontSize = "Small"; -static int default_localDocsRetrievalSize = 3; -static bool default_localDocsShowReferences = true; -static QString default_networkAttribution = ""; -static bool default_networkIsActive = false; -static int default_networkPort = 4891; -static bool default_networkUsageStatsActive = false; -static QString default_device = "Auto"; +using namespace Qt::Literals::StringLiterals; + +namespace defaults { + +static const int threadCount = std::min(4, (int32_t) std::thread::hardware_concurrency()); +static const bool forceMetal = false; +static const bool networkIsActive = false; +static const bool networkUsageStatsActive = false; +static const QString device = "Auto"; + +} // namespace defaults + +static const QVariantMap basicDefaults { + { "chatTheme", "Light" }, + { "fontSize", "Small" }, + { "lastVersionStarted", "" }, + { "networkPort", 4891, }, + { "saveChatsContext", false }, + { "serverChat", false }, + { "userDefaultModel", "Application default" }, + { "localdocs/chunkSize", 256 }, + { "localdocs/retrievalSize", 3 }, + { "localdocs/showReferences", true }, + { "localdocs/fileExtensions", QStringList { "txt", "pdf", "md", "rst" } }, + { "localdocs/useRemoteEmbed", false }, + { "localdocs/nomicAPIKey", "" }, + 
{ "network/attribution", "" }, +}; static QString defaultLocalModelsPath() { QString localPath = QStandardPaths::writableLocation(QStandardPaths::AppLocalDataLocation) + "/"; - QString testWritePath = localPath + QString("test_write.txt"); + QString testWritePath = localPath + u"test_write.txt"_s; QString canonicalLocalPath = QFileInfo(localPath).canonicalFilePath() + "/"; QDir localDir(localPath); if (!localDir.exists()) { @@ -74,8 +87,6 @@ MySettings *MySettings::globalInstance() MySettings::MySettings() : QObject{nullptr} { - QSettings::setDefaultFormat(QSettings::IniFormat); - QVector deviceList{ "Auto" }; #if defined(Q_OS_MAC) && defined(__aarch64__) deviceList << "Metal"; @@ -88,722 +99,336 @@ MySettings::MySettings() setDeviceList(deviceList); } +QVariant MySettings::getBasicSetting(const QString &name) const +{ + return m_settings.value(name, basicDefaults.value(name)); +} + +void MySettings::setBasicSetting(const QString &name, const QVariant &value, std::optional signal) +{ + if (getBasicSetting(name) == value) + return; + + m_settings.setValue(name, value); + QMetaObject::invokeMethod(this, u"%1Changed"_s.arg(signal.value_or(name)).toLatin1().constData()); +} + Q_INVOKABLE QVector MySettings::deviceList() const { return m_deviceList; } -void MySettings::setDeviceList(const QVector &deviceList) +void MySettings::setDeviceList(const QVector &value) { - m_deviceList = deviceList; + m_deviceList = value; emit deviceListChanged(); } -void MySettings::restoreModelDefaults(const ModelInfo &model) +void MySettings::restoreModelDefaults(const ModelInfo &info) { - setModelTemperature(model, model.m_temperature); - setModelTopP(model, model.m_topP); - setModelMinP(model, model.m_minP); - setModelTopK(model, model.m_topK);; - setModelMaxLength(model, model.m_maxLength); - setModelPromptBatchSize(model, model.m_promptBatchSize); - setModelContextLength(model, model.m_contextLength); - setModelGpuLayers(model, model.m_gpuLayers); - setModelRepeatPenalty(model, 
model.m_repeatPenalty); - setModelRepeatPenaltyTokens(model, model.m_repeatPenaltyTokens); - setModelPromptTemplate(model, model.m_promptTemplate); - setModelSystemPrompt(model, model.m_systemPrompt); + setModelTemperature(info, info.m_temperature); + setModelTopP(info, info.m_topP); + setModelMinP(info, info.m_minP); + setModelTopK(info, info.m_topK);; + setModelMaxLength(info, info.m_maxLength); + setModelPromptBatchSize(info, info.m_promptBatchSize); + setModelContextLength(info, info.m_contextLength); + setModelGpuLayers(info, info.m_gpuLayers); + setModelRepeatPenalty(info, info.m_repeatPenalty); + setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens); + setModelPromptTemplate(info, info.m_promptTemplate); + setModelSystemPrompt(info, info.m_systemPrompt); } void MySettings::restoreApplicationDefaults() { - setChatTheme(default_chatTheme); - setFontSize(default_fontSize); - setDevice(default_device); - setThreadCount(default_threadCount); - setSaveChatsContext(default_saveChatsContext); - setServerChat(default_serverChat); - setNetworkPort(default_networkPort); + setChatTheme(basicDefaults.value("chatTheme").toString()); + setFontSize(basicDefaults.value("fontSize").toString()); + setDevice(defaults::device); + setThreadCount(defaults::threadCount); + setSaveChatsContext(basicDefaults.value("saveChatsContext").toBool()); + setServerChat(basicDefaults.value("serverChat").toBool()); + setNetworkPort(basicDefaults.value("networkPort").toInt()); setModelPath(defaultLocalModelsPath()); - setUserDefaultModel(default_userDefaultModel); - setForceMetal(default_forceMetal); + setUserDefaultModel(basicDefaults.value("userDefaultModel").toString()); + setForceMetal(defaults::forceMetal); } void MySettings::restoreLocalDocsDefaults() { - setLocalDocsChunkSize(default_localDocsChunkSize); - setLocalDocsRetrievalSize(default_localDocsRetrievalSize); - setLocalDocsShowReferences(default_localDocsShowReferences); + 
setLocalDocsChunkSize(basicDefaults.value("localdocs/chunkSize").toInt()); + setLocalDocsRetrievalSize(basicDefaults.value("localdocs/retrievalSize").toInt()); + setLocalDocsShowReferences(basicDefaults.value("localdocs/showReferences").toBool()); + setLocalDocsFileExtensions(basicDefaults.value("localdocs/fileExtensions").toStringList()); + setLocalDocsUseRemoteEmbed(basicDefaults.value("localdocs/useRemoteEmbed").toBool()); + setLocalDocsNomicAPIKey(basicDefaults.value("localdocs/nomicAPIKey").toString()); } -void MySettings::eraseModel(const ModelInfo &m) +void MySettings::eraseModel(const ModelInfo &info) { - QSettings settings; - settings.remove(QString("model-%1").arg(m.id())); - settings.sync(); + m_settings.remove(u"model-%1"_s.arg(info.id())); } -QString MySettings::modelName(const ModelInfo &m) const +QString MySettings::modelName(const ModelInfo &info) const { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/name", - !m.m_name.isEmpty() ? m.m_name : m.m_filename).toString(); + return m_settings.value(u"model-%1/name"_s.arg(info.id()), + !info.m_name.isEmpty() ? 
info.m_name : info.m_filename).toString(); } -void MySettings::setModelName(const ModelInfo &m, const QString &name, bool force) +void MySettings::setModelName(const ModelInfo &info, const QString &value, bool force) { - if ((modelName(m) == name || m.id().isEmpty()) && !force) + if ((modelName(info) == value || info.id().isEmpty()) && !force) return; - QSettings setting; - if ((m.m_name == name || m.m_filename == name) && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/name"); + if ((info.m_name == value || info.m_filename == value) && !info.shouldSaveMetadata()) + m_settings.remove(u"model-%1/name"_s.arg(info.id())); else - setting.setValue(QString("model-%1").arg(m.id()) + "/name", name); - setting.sync(); + m_settings.setValue(u"model-%1/name"_s.arg(info.id()), value); if (!force) - emit nameChanged(m); + emit nameChanged(info); } -QString MySettings::modelFilename(const ModelInfo &m) const +static QString modelSettingName(const ModelInfo &info, const QString &name) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/filename", m.m_filename).toString(); + return u"model-%1/%2"_s.arg(info.id(), name); } -void MySettings::setModelFilename(const ModelInfo &m, const QString &filename, bool force) +QVariant MySettings::getModelSetting(const QString &name, const ModelInfo &info) const { - if ((modelFilename(m) == filename || m.id().isEmpty()) && !force) + return m_settings.value(modelSettingName(info, name), info.getFields().value(name)); +} + +void MySettings::setModelSetting(const QString &name, const ModelInfo &info, const QVariant &value, bool force, + bool signal) +{ + if (!force && (info.id().isEmpty() || getModelSetting(name, info) == value)) return; - QSettings setting; - if (m.m_filename == filename && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/filename"); + QString settingName = modelSettingName(info, name); + if (info.getFields().value(name) == 
value && !info.shouldSaveMetadata()) + m_settings.remove(settingName); else - setting.setValue(QString("model-%1").arg(m.id()) + "/filename", filename); - setting.sync(); - if (!force) - emit filenameChanged(m); + m_settings.setValue(settingName, value); + if (signal && !force) + QMetaObject::invokeMethod(this, u"%1Changed"_s.arg(name).toLatin1().constData(), Q_ARG(ModelInfo, info)); } -QString MySettings::modelDescription(const ModelInfo &m) const +QString MySettings::modelFilename (const ModelInfo &info) const { return getModelSetting("filename", info).toString(); } +QString MySettings::modelDescription (const ModelInfo &info) const { return getModelSetting("description", info).toString(); } +QString MySettings::modelUrl (const ModelInfo &info) const { return getModelSetting("url", info).toString(); } +QString MySettings::modelQuant (const ModelInfo &info) const { return getModelSetting("quant", info).toString(); } +QString MySettings::modelType (const ModelInfo &info) const { return getModelSetting("type", info).toString(); } +bool MySettings::modelIsClone (const ModelInfo &info) const { return getModelSetting("isClone", info).toBool(); } +bool MySettings::modelIsDiscovered (const ModelInfo &info) const { return getModelSetting("isDiscovered", info).toBool(); } +int MySettings::modelLikes (const ModelInfo &info) const { return getModelSetting("likes", info).toInt(); } +int MySettings::modelDownloads (const ModelInfo &info) const { return getModelSetting("downloads", info).toInt(); } +QDateTime MySettings::modelRecency (const ModelInfo &info) const { return getModelSetting("recency", info).toDateTime(); } +double MySettings::modelTemperature (const ModelInfo &info) const { return getModelSetting("temperature", info).toDouble(); } +double MySettings::modelTopP (const ModelInfo &info) const { return getModelSetting("topP", info).toDouble(); } +double MySettings::modelMinP (const ModelInfo &info) const { return getModelSetting("minP", info).toDouble(); } +int 
MySettings::modelTopK (const ModelInfo &info) const { return getModelSetting("topK", info).toInt(); } +int MySettings::modelMaxLength (const ModelInfo &info) const { return getModelSetting("maxLength", info).toInt(); } +int MySettings::modelPromptBatchSize (const ModelInfo &info) const { return getModelSetting("promptBatchSize", info).toInt(); } +int MySettings::modelContextLength (const ModelInfo &info) const { return getModelSetting("contextLength", info).toInt(); } +int MySettings::modelGpuLayers (const ModelInfo &info) const { return getModelSetting("gpuLayers", info).toInt(); } +double MySettings::modelRepeatPenalty (const ModelInfo &info) const { return getModelSetting("repeatPenalty", info).toDouble(); } +int MySettings::modelRepeatPenaltyTokens(const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); } +QString MySettings::modelPromptTemplate (const ModelInfo &info) const { return getModelSetting("promptTemplate", info).toString(); } +QString MySettings::modelSystemPrompt (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); } + +void MySettings::setModelFilename(const ModelInfo &info, const QString &value, bool force) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/description", m.m_description).toString(); + setModelSetting("filename", info, value, force, true); } -void MySettings::setModelDescription(const ModelInfo &m, const QString &d, bool force) +void MySettings::setModelDescription(const ModelInfo &info, const QString &value, bool force) { - if ((modelDescription(m) == d || m.id().isEmpty()) && !force) - return; - - QSettings setting; - if (m.m_description == d && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/description"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/description", d); - setting.sync(); + setModelSetting("description", info, value, force, true); } -QString 
MySettings::modelUrl(const ModelInfo &m) const +void MySettings::setModelUrl(const ModelInfo &info, const QString &value, bool force) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/url", m.m_url).toString(); + setModelSetting("url", info, value, force); } -void MySettings::setModelUrl(const ModelInfo &m, const QString &u, bool force) +void MySettings::setModelQuant(const ModelInfo &info, const QString &value, bool force) { - if ((modelUrl(m) == u || m.id().isEmpty()) && !force) - return; - - QSettings setting; - if (m.m_url == u && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/url"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/url", u); - setting.sync(); + setModelSetting("quant", info, value, force); } -QString MySettings::modelQuant(const ModelInfo &m) const +void MySettings::setModelType(const ModelInfo &info, const QString &value, bool force) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/quant", m.m_quant).toString(); + setModelSetting("type", info, value, force); } -void MySettings::setModelQuant(const ModelInfo &m, const QString &q, bool force) +void MySettings::setModelIsClone(const ModelInfo &info, bool value, bool force) { - if ((modelUrl(m) == q || m.id().isEmpty()) && !force) - return; - - QSettings setting; - if (m.m_quant == q && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/quant"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/quant", q); - setting.sync(); + setModelSetting("isClone", info, value, force); } -QString MySettings::modelType(const ModelInfo &m) const +void MySettings::setModelIsDiscovered(const ModelInfo &info, bool value, bool force) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/type", m.m_type).toString(); + setModelSetting("isDiscovered", info, value, force); } -void 
MySettings::setModelType(const ModelInfo &m, const QString &t, bool force) +void MySettings::setModelLikes(const ModelInfo &info, int value, bool force) { - if ((modelType(m) == t || m.id().isEmpty()) && !force) - return; - - QSettings setting; - if (m.m_type == t && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/type"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/type", t); - setting.sync(); + setModelSetting("likes", info, value, force); } -bool MySettings::modelIsClone(const ModelInfo &m) const +void MySettings::setModelDownloads(const ModelInfo &info, int value, bool force) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/isClone", m.m_isClone).toBool(); + setModelSetting("downloads", info, value, force); } -void MySettings::setModelIsClone(const ModelInfo &m, bool b, bool force) +void MySettings::setModelRecency(const ModelInfo &info, const QDateTime &value, bool force) { - if ((modelIsClone(m) == b || m.id().isEmpty()) && !force) - return; - - QSettings setting; - if (m.m_isClone == b && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/isClone"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/isClone", b); - setting.sync(); + setModelSetting("recency", info, value, force); } -bool MySettings::modelIsDiscovered(const ModelInfo &m) const +void MySettings::setModelTemperature(const ModelInfo &info, double value, bool force) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/isDiscovered", m.m_isDiscovered).toBool(); + setModelSetting("temperature", info, value, force, true); } -void MySettings::setModelIsDiscovered(const ModelInfo &m, bool b, bool force) +void MySettings::setModelTopP(const ModelInfo &info, double value, bool force) { - if ((modelIsDiscovered(m) == b || m.id().isEmpty()) && !force) - return; - - QSettings setting; - if (m.m_isDiscovered == b && 
!m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/isDiscovered"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/isDiscovered", b); - setting.sync(); + setModelSetting("topP", info, value, force, true); } -int MySettings::modelLikes(const ModelInfo &m) const +void MySettings::setModelMinP(const ModelInfo &info, double value, bool force) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/likes", m.m_likes).toInt(); + setModelSetting("minP", info, value, force, true); } -void MySettings::setModelLikes(const ModelInfo &m, int l, bool force) +void MySettings::setModelTopK(const ModelInfo &info, int value, bool force) { - if ((modelLikes(m) == l || m.id().isEmpty()) && !force) - return; - - QSettings setting; - if (m.m_likes == l && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/likes"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/likes", l); - setting.sync(); + setModelSetting("topK", info, value, force, true); } -int MySettings::modelDownloads(const ModelInfo &m) const +void MySettings::setModelMaxLength(const ModelInfo &info, int value, bool force) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/downloads", m.m_downloads).toInt(); + setModelSetting("maxLength", info, value, force, true); } -void MySettings::setModelDownloads(const ModelInfo &m, int d, bool force) +void MySettings::setModelPromptBatchSize(const ModelInfo &info, int value, bool force) { - if ((modelDownloads(m) == d || m.id().isEmpty()) && !force) - return; - - QSettings setting; - if (m.m_downloads == d && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/downloads"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/downloads", d); - setting.sync(); + setModelSetting("promptBatchSize", info, value, force, true); } -QDateTime MySettings::modelRecency(const ModelInfo &m) 
const +void MySettings::setModelContextLength(const ModelInfo &info, int value, bool force) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/recency", m.m_recency).toDateTime(); + setModelSetting("contextLength", info, value, force, true); } -void MySettings::setModelRecency(const ModelInfo &m, const QDateTime &r, bool force) +void MySettings::setModelGpuLayers(const ModelInfo &info, int value, bool force) { - if ((modelRecency(m) == r || m.id().isEmpty()) && !force) - return; - - QSettings setting; - if (m.m_recency == r && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/recency"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/recency", r); - setting.sync(); + setModelSetting("gpuLayers", info, value, force, true); } -double MySettings::modelTemperature(const ModelInfo &m) const +void MySettings::setModelRepeatPenalty(const ModelInfo &info, double value, bool force) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/temperature", m.m_temperature).toDouble(); + setModelSetting("repeatPenalty", info, value, force, true); } -void MySettings::setModelTemperature(const ModelInfo &m, double t, bool force) +void MySettings::setModelRepeatPenaltyTokens(const ModelInfo &info, int value, bool force) { - if (modelTemperature(m) == t && !force) - return; - - QSettings setting; - if (m.m_temperature == t && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/temperature"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/temperature", t); - setting.sync(); - if (!force) - emit temperatureChanged(m); + setModelSetting("repeatPenaltyTokens", info, value, force, true); } -double MySettings::modelTopP(const ModelInfo &m) const +void MySettings::setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force) { - QSettings setting; - setting.sync(); - return 
setting.value(QString("model-%1").arg(m.id()) + "/topP", m.m_topP).toDouble(); + setModelSetting("promptTemplate", info, value, force, true); } -double MySettings::modelMinP(const ModelInfo &m) const +void MySettings::setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force) { - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/minP", m.m_minP).toDouble(); -} - -void MySettings::setModelTopP(const ModelInfo &m, double p, bool force) -{ - if (modelTopP(m) == p && !force) - return; - - QSettings setting; - if (m.m_topP == p && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/topP"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/topP", p); - setting.sync(); - if (!force) - emit topPChanged(m); -} - -void MySettings::setModelMinP(const ModelInfo &m, double p, bool force) -{ - if (modelMinP(m) == p && !force) - return; - - QSettings setting; - if (m.m_minP == p && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/minP"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/minP", p); - setting.sync(); - if (!force) - emit minPChanged(m); -} - -int MySettings::modelTopK(const ModelInfo &m) const -{ - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/topK", m.m_topK).toInt(); -} - -void MySettings::setModelTopK(const ModelInfo &m, int k, bool force) -{ - if (modelTopK(m) == k && !force) - return; - - QSettings setting; - if (m.m_topK == k && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/topK"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/topK", k); - setting.sync(); - if (!force) - emit topKChanged(m); -} - -int MySettings::modelMaxLength(const ModelInfo &m) const -{ - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/maxLength", m.m_maxLength).toInt(); -} - -void 
MySettings::setModelMaxLength(const ModelInfo &m, int l, bool force) -{ - if (modelMaxLength(m) == l && !force) - return; - - QSettings setting; - if (m.m_maxLength == l && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/maxLength"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/maxLength", l); - setting.sync(); - if (!force) - emit maxLengthChanged(m); -} - -int MySettings::modelPromptBatchSize(const ModelInfo &m) const -{ - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/promptBatchSize", m.m_promptBatchSize).toInt(); -} - -void MySettings::setModelPromptBatchSize(const ModelInfo &m, int s, bool force) -{ - if (modelPromptBatchSize(m) == s && !force) - return; - - QSettings setting; - if (m.m_promptBatchSize == s && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/promptBatchSize"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/promptBatchSize", s); - setting.sync(); - if (!force) - emit promptBatchSizeChanged(m); -} - -int MySettings::modelContextLength(const ModelInfo &m) const -{ - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/contextLength", m.m_contextLength).toInt(); -} - -void MySettings::setModelContextLength(const ModelInfo &m, int l, bool force) -{ - if (modelContextLength(m) == l && !force) - return; - - QSettings setting; - if (m.m_contextLength == l && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/contextLength"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/contextLength", l); - setting.sync(); - if (!force) - emit contextLengthChanged(m); -} - -int MySettings::modelGpuLayers(const ModelInfo &m) const -{ - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/gpuLayers", m.m_gpuLayers).toInt(); -} - -void MySettings::setModelGpuLayers(const ModelInfo &m, int l, bool force) 
-{ - if (modelGpuLayers(m) == l && !force) - return; - - QSettings setting; - if (m.m_gpuLayers == l && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/gpuLayers"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/gpuLayers", l); - setting.sync(); - if (!force) - emit gpuLayersChanged(m); -} - -double MySettings::modelRepeatPenalty(const ModelInfo &m) const -{ - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/repeatPenalty", m.m_repeatPenalty).toDouble(); -} - -void MySettings::setModelRepeatPenalty(const ModelInfo &m, double p, bool force) -{ - if (modelRepeatPenalty(m) == p && !force) - return; - - QSettings setting; - if (m.m_repeatPenalty == p && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/repeatPenalty"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/repeatPenalty", p); - setting.sync(); - if (!force) - emit repeatPenaltyChanged(m); -} - -int MySettings::modelRepeatPenaltyTokens(const ModelInfo &m) const -{ - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/repeatPenaltyTokens", m.m_repeatPenaltyTokens).toInt(); -} - -void MySettings::setModelRepeatPenaltyTokens(const ModelInfo &m, int t, bool force) -{ - if (modelRepeatPenaltyTokens(m) == t && !force) - return; - - QSettings setting; - if (m.m_repeatPenaltyTokens == t && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/repeatPenaltyTokens"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/repeatPenaltyTokens", t); - setting.sync(); - if (!force) - emit repeatPenaltyTokensChanged(m); -} - -QString MySettings::modelPromptTemplate(const ModelInfo &m) const -{ - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/promptTemplate", m.m_promptTemplate).toString(); -} - -void MySettings::setModelPromptTemplate(const ModelInfo &m, const QString &t, 
bool force) -{ - if (modelPromptTemplate(m) == t && !force) - return; - - QSettings setting; - if (m.m_promptTemplate == t && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/promptTemplate"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/promptTemplate", t); - setting.sync(); - if (!force) - emit promptTemplateChanged(m); -} - -QString MySettings::modelSystemPrompt(const ModelInfo &m) const -{ - QSettings setting; - setting.sync(); - return setting.value(QString("model-%1").arg(m.id()) + "/systemPrompt", m.m_systemPrompt).toString(); -} - -void MySettings::setModelSystemPrompt(const ModelInfo &m, const QString &p, bool force) -{ - if (modelSystemPrompt(m) == p && !force) - return; - - QSettings setting; - if (m.m_systemPrompt == p && !m.shouldSaveMetadata()) - setting.remove(QString("model-%1").arg(m.id()) + "/systemPrompt"); - else - setting.setValue(QString("model-%1").arg(m.id()) + "/systemPrompt", p); - setting.sync(); - if (!force) - emit systemPromptChanged(m); + setModelSetting("systemPrompt", info, value, force, true); } int MySettings::threadCount() const { - QSettings setting; - setting.sync(); - int c = setting.value("threadCount", default_threadCount).toInt(); + int c = m_settings.value("threadCount", defaults::threadCount).toInt(); // The old thread setting likely left many people with 0 in settings config file, which means // we should reset it to the default going forward if (c <= 0) - c = default_threadCount; + c = defaults::threadCount; c = std::max(c, 1); c = std::min(c, QThread::idealThreadCount()); return c; } -void MySettings::setThreadCount(int c) +void MySettings::setThreadCount(int value) { - if (threadCount() == c) + if (threadCount() == value) return; - c = std::max(c, 1); - c = std::min(c, QThread::idealThreadCount()); - QSettings setting; - setting.setValue("threadCount", c); - setting.sync(); + value = std::max(value, 1); + value = std::min(value, QThread::idealThreadCount()); + 
m_settings.setValue("threadCount", value); emit threadCountChanged(); } -bool MySettings::saveChatsContext() const +bool MySettings::saveChatsContext() const { return getBasicSetting("saveChatsContext" ).toBool(); } +bool MySettings::serverChat() const { return getBasicSetting("serverChat" ).toBool(); } +int MySettings::networkPort() const { return getBasicSetting("networkPort" ).toInt(); } +QString MySettings::userDefaultModel() const { return getBasicSetting("userDefaultModel" ).toString(); } +QString MySettings::chatTheme() const { return getBasicSetting("chatTheme" ).toString(); } +QString MySettings::fontSize() const { return getBasicSetting("fontSize" ).toString(); } +QString MySettings::lastVersionStarted() const { return getBasicSetting("lastVersionStarted" ).toString(); } +int MySettings::localDocsChunkSize() const { return getBasicSetting("localdocs/chunkSize" ).toInt(); } +int MySettings::localDocsRetrievalSize() const { return getBasicSetting("localdocs/retrievalSize" ).toInt(); } +bool MySettings::localDocsShowReferences() const { return getBasicSetting("localdocs/showReferences").toBool(); } +QStringList MySettings::localDocsFileExtensions() const { return getBasicSetting("localdocs/fileExtensions").toStringList(); } +bool MySettings::localDocsUseRemoteEmbed() const { return getBasicSetting("localdocs/useRemoteEmbed").toBool(); } +QString MySettings::localDocsNomicAPIKey() const { return getBasicSetting("localdocs/nomicAPIKey" ).toString(); } +QString MySettings::networkAttribution() const { return getBasicSetting("network/attribution" ).toString(); } + +void MySettings::setSaveChatsContext(bool value) { setBasicSetting("saveChatsContext", value); } +void MySettings::setServerChat(bool value) { setBasicSetting("serverChat", value); } +void MySettings::setNetworkPort(int value) { setBasicSetting("networkPort", value); } +void MySettings::setUserDefaultModel(const QString &value) { setBasicSetting("userDefaultModel", value); } +void 
MySettings::setChatTheme(const QString &value) { setBasicSetting("chatTheme", value); } +void MySettings::setFontSize(const QString &value) { setBasicSetting("fontSize", value); } +void MySettings::setLastVersionStarted(const QString &value) { setBasicSetting("lastVersionStarted", value); } +void MySettings::setLocalDocsChunkSize(int value) { setBasicSetting("localdocs/chunkSize", value, "localDocsChunkSize"); } +void MySettings::setLocalDocsRetrievalSize(int value) { setBasicSetting("localdocs/retrievalSize", value, "localDocsRetrievalSize"); } +void MySettings::setLocalDocsShowReferences(bool value) { setBasicSetting("localdocs/showReferences", value, "localDocsShowReferences"); } +void MySettings::setLocalDocsFileExtensions(const QStringList &value) { setBasicSetting("localdocs/fileExtensions", value, "localDocsFileExtensions"); } +void MySettings::setLocalDocsUseRemoteEmbed(bool value) { setBasicSetting("localdocs/useRemoteEmbed", value, "localDocsUseRemoteEmbed"); } +void MySettings::setLocalDocsNomicAPIKey(const QString &value) { setBasicSetting("localdocs/nomicAPIKey", value, "localDocsNomicAPIKey"); } +void MySettings::setNetworkAttribution(const QString &value) { setBasicSetting("network/attribution", value, "networkAttribution"); } + +QString MySettings::modelPath() { - QSettings setting; - setting.sync(); - return setting.value("saveChatsContext", default_saveChatsContext).toBool(); -} - -void MySettings::setSaveChatsContext(bool b) -{ - if (saveChatsContext() == b) - return; - - QSettings setting; - setting.setValue("saveChatsContext", b); - setting.sync(); - emit saveChatsContextChanged(); -} - -bool MySettings::serverChat() const -{ - QSettings setting; - setting.sync(); - return setting.value("serverChat", default_serverChat).toBool(); -} - -void MySettings::setServerChat(bool b) -{ - if (serverChat() == b) - return; - - QSettings setting; - setting.setValue("serverChat", b); - setting.sync(); - emit serverChatChanged(); -} - -int 
MySettings::networkPort() const -{ - QSettings setting; - setting.sync(); - return setting.value("networkPort", default_networkPort).toInt(); -} - -void MySettings::setNetworkPort(int c) -{ - if (networkPort() == c) - return; - - QSettings setting; - setting.setValue("networkPort", c); - setting.sync(); - emit networkPortChanged(); -} - -QString MySettings::modelPath() const -{ - QSettings setting; - setting.sync(); // We have to migrate the old setting because I changed the setting key recklessly in v2.4.11 // which broke a lot of existing installs - const bool containsOldSetting = setting.contains("modelPaths"); + const bool containsOldSetting = m_settings.contains("modelPaths"); if (containsOldSetting) { - const bool containsNewSetting = setting.contains("modelPath"); + const bool containsNewSetting = m_settings.contains("modelPath"); if (!containsNewSetting) - setting.setValue("modelPath", setting.value("modelPaths")); - setting.remove("modelPaths"); - setting.sync(); + m_settings.setValue("modelPath", m_settings.value("modelPaths")); + m_settings.remove("modelPaths"); } - return setting.value("modelPath", defaultLocalModelsPath()).toString(); + return m_settings.value("modelPath", defaultLocalModelsPath()).toString(); } -void MySettings::setModelPath(const QString &p) +void MySettings::setModelPath(const QString &value) { - QString filePath = (p.startsWith("file://") ? - QUrl(p).toLocalFile() : p); + QString filePath = (value.startsWith("file://") ? 
+ QUrl(value).toLocalFile() : value); QString canonical = QFileInfo(filePath).canonicalFilePath() + "/"; if (modelPath() == canonical) return; - QSettings setting; - setting.setValue("modelPath", canonical); - setting.sync(); + m_settings.setValue("modelPath", canonical); emit modelPathChanged(); } -QString MySettings::userDefaultModel() const +QString MySettings::device() { - QSettings setting; - setting.sync(); - return setting.value("userDefaultModel", default_userDefaultModel).toString(); -} - -void MySettings::setUserDefaultModel(const QString &u) -{ - if (userDefaultModel() == u) - return; - - QSettings setting; - setting.setValue("userDefaultModel", u); - setting.sync(); - emit userDefaultModelChanged(); -} - -QString MySettings::chatTheme() const -{ - QSettings setting; - setting.sync(); - return setting.value("chatTheme", default_chatTheme).toString(); -} - -void MySettings::setChatTheme(const QString &u) -{ - if (chatTheme() == u) - return; - - QSettings setting; - setting.setValue("chatTheme", u); - setting.sync(); - emit chatThemeChanged(); -} - -QString MySettings::fontSize() const -{ - QSettings setting; - setting.sync(); - return setting.value("fontSize", default_fontSize).toString(); -} - -void MySettings::setFontSize(const QString &u) -{ - if (fontSize() == u) - return; - - QSettings setting; - setting.setValue("fontSize", u); - setting.sync(); - emit fontSizeChanged(); -} - -QString MySettings::device() const -{ - QSettings setting; - setting.sync(); - auto value = setting.value("device"); + auto value = m_settings.value("device"); if (!value.isValid()) - return default_device; + return defaults::device; auto device = value.toString(); if (!device.isEmpty()) { @@ -813,21 +438,18 @@ QString MySettings::device() const auto newName = QString::fromStdString(newNameStr); qWarning() << "updating device name:" << device << "->" << newName; device = newName; - setting.setValue("device", device); - setting.sync(); + m_settings.setValue("device", device); } 
} return device; } -void MySettings::setDevice(const QString &u) +void MySettings::setDevice(const QString &value) { - if (device() == u) + if (device() == value) return; - QSettings setting; - setting.setValue("device", u); - setting.sync(); + m_settings.setValue("device", value); emit deviceChanged(); } @@ -836,152 +458,48 @@ bool MySettings::forceMetal() const return m_forceMetal; } -void MySettings::setForceMetal(bool b) +void MySettings::setForceMetal(bool value) { - if (m_forceMetal == b) + if (m_forceMetal == value) return; - m_forceMetal = b; - emit forceMetalChanged(b); -} - -QString MySettings::lastVersionStarted() const -{ - QSettings setting; - setting.sync(); - return setting.value("lastVersionStarted", default_lastVersionStarted).toString(); -} - -void MySettings::setLastVersionStarted(const QString &v) -{ - if (lastVersionStarted() == v) - return; - - QSettings setting; - setting.setValue("lastVersionStarted", v); - setting.sync(); - emit lastVersionStartedChanged(); -} - -int MySettings::localDocsChunkSize() const -{ - QSettings setting; - setting.sync(); - return setting.value("localdocs/chunkSize", default_localDocsChunkSize).toInt(); -} - -void MySettings::setLocalDocsChunkSize(int s) -{ - if (localDocsChunkSize() == s) - return; - - QSettings setting; - setting.setValue("localdocs/chunkSize", s); - setting.sync(); - emit localDocsChunkSizeChanged(); -} - -int MySettings::localDocsRetrievalSize() const -{ - QSettings setting; - setting.sync(); - return setting.value("localdocs/retrievalSize", default_localDocsRetrievalSize).toInt(); -} - -void MySettings::setLocalDocsRetrievalSize(int s) -{ - if (localDocsRetrievalSize() == s) - return; - - QSettings setting; - setting.setValue("localdocs/retrievalSize", s); - setting.sync(); - emit localDocsRetrievalSizeChanged(); -} - -bool MySettings::localDocsShowReferences() const -{ - QSettings setting; - setting.sync(); - return setting.value("localdocs/showReferences", 
default_localDocsShowReferences).toBool(); -} - -void MySettings::setLocalDocsShowReferences(bool b) -{ - if (localDocsShowReferences() == b) - return; - - QSettings setting; - setting.setValue("localdocs/showReferences", b); - setting.sync(); - emit localDocsShowReferencesChanged(); -} - -QString MySettings::networkAttribution() const -{ - QSettings setting; - setting.sync(); - return setting.value("network/attribution", default_networkAttribution).toString(); -} - -void MySettings::setNetworkAttribution(const QString &a) -{ - if (networkAttribution() == a) - return; - - QSettings setting; - setting.setValue("network/attribution", a); - setting.sync(); - emit networkAttributionChanged(); + m_forceMetal = value; + emit forceMetalChanged(value); } bool MySettings::networkIsActive() const { - QSettings setting; - setting.sync(); - return setting.value("network/isActive", default_networkIsActive).toBool(); + return m_settings.value("network/isActive", defaults::networkIsActive).toBool(); } bool MySettings::isNetworkIsActiveSet() const { - QSettings setting; - setting.sync(); - return setting.value("network/isActive").isValid(); + return m_settings.value("network/isActive").isValid(); } -void MySettings::setNetworkIsActive(bool b) +void MySettings::setNetworkIsActive(bool value) { - QSettings setting; - setting.sync(); - auto cur = setting.value("network/isActive"); - if (!cur.isValid() || cur.toBool() != b) { - setting.setValue("network/isActive", b); - setting.sync(); + auto cur = m_settings.value("network/isActive"); + if (!cur.isValid() || cur.toBool() != value) { + m_settings.setValue("network/isActive", value); emit networkIsActiveChanged(); } } bool MySettings::networkUsageStatsActive() const { - QSettings setting; - setting.sync(); - return setting.value("network/usageStatsActive", default_networkUsageStatsActive).toBool(); + return m_settings.value("network/usageStatsActive", defaults::networkUsageStatsActive).toBool(); } bool 
MySettings::isNetworkUsageStatsActiveSet() const { - QSettings setting; - setting.sync(); - return setting.value("network/usageStatsActive").isValid(); + return m_settings.value("network/usageStatsActive").isValid(); } -void MySettings::setNetworkUsageStatsActive(bool b) +void MySettings::setNetworkUsageStatsActive(bool value) { - QSettings setting; - setting.sync(); - auto cur = setting.value("network/usageStatsActive"); - if (!cur.isValid() || cur.toBool() != b) { - setting.setValue("network/usageStatsActive", b); - setting.sync(); + auto cur = m_settings.value("network/usageStatsActive"); + if (!cur.isValid() || cur.toBool() != value) { + m_settings.setValue("network/usageStatsActive", value); emit networkUsageStatsActiveChanged(); } } diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h index 86b9183b..fc845990 100644 --- a/gpt4all-chat/mysettings.h +++ b/gpt4all-chat/mysettings.h @@ -5,10 +5,13 @@ #include #include +#include #include +#include #include #include +#include class MySettings : public QObject { @@ -25,6 +28,9 @@ class MySettings : public QObject Q_PROPERTY(int localDocsChunkSize READ localDocsChunkSize WRITE setLocalDocsChunkSize NOTIFY localDocsChunkSizeChanged) Q_PROPERTY(int localDocsRetrievalSize READ localDocsRetrievalSize WRITE setLocalDocsRetrievalSize NOTIFY localDocsRetrievalSizeChanged) Q_PROPERTY(bool localDocsShowReferences READ localDocsShowReferences WRITE setLocalDocsShowReferences NOTIFY localDocsShowReferencesChanged) + Q_PROPERTY(QStringList localDocsFileExtensions READ localDocsFileExtensions WRITE setLocalDocsFileExtensions NOTIFY localDocsFileExtensionsChanged) + Q_PROPERTY(bool localDocsUseRemoteEmbed READ localDocsUseRemoteEmbed WRITE setLocalDocsUseRemoteEmbed NOTIFY localDocsUseRemoteEmbedChanged) + Q_PROPERTY(QString localDocsNomicAPIKey READ localDocsNomicAPIKey WRITE setLocalDocsNomicAPIKey NOTIFY localDocsNomicAPIKeyChanged) Q_PROPERTY(QString networkAttribution READ networkAttribution WRITE 
setNetworkAttribution NOTIFY networkAttributionChanged) Q_PROPERTY(bool networkIsActive READ networkIsActive WRITE setNetworkIsActive NOTIFY networkIsActiveChanged) Q_PROPERTY(bool networkUsageStatsActive READ networkUsageStatsActive WRITE setNetworkUsageStatsActive NOTIFY networkUsageStatsActiveChanged) @@ -36,80 +42,80 @@ public: static MySettings *globalInstance(); // Restore methods - Q_INVOKABLE void restoreModelDefaults(const ModelInfo &model); + Q_INVOKABLE void restoreModelDefaults(const ModelInfo &info); Q_INVOKABLE void restoreApplicationDefaults(); Q_INVOKABLE void restoreLocalDocsDefaults(); // Model/Character settings - void eraseModel(const ModelInfo &m); - QString modelName(const ModelInfo &m) const; - Q_INVOKABLE void setModelName(const ModelInfo &m, const QString &name, bool force = false); - QString modelFilename(const ModelInfo &m) const; - Q_INVOKABLE void setModelFilename(const ModelInfo &m, const QString &filename, bool force = false); + void eraseModel(const ModelInfo &info); + QString modelName(const ModelInfo &info) const; + Q_INVOKABLE void setModelName(const ModelInfo &info, const QString &name, bool force = false); + QString modelFilename(const ModelInfo &info) const; + Q_INVOKABLE void setModelFilename(const ModelInfo &info, const QString &filename, bool force = false); - QString modelDescription(const ModelInfo &m) const; - void setModelDescription(const ModelInfo &m, const QString &d, bool force = false); - QString modelUrl(const ModelInfo &m) const; - void setModelUrl(const ModelInfo &m, const QString &u, bool force = false); - QString modelQuant(const ModelInfo &m) const; - void setModelQuant(const ModelInfo &m, const QString &q, bool force = false); - QString modelType(const ModelInfo &m) const; - void setModelType(const ModelInfo &m, const QString &t, bool force = false); - bool modelIsClone(const ModelInfo &m) const; - void setModelIsClone(const ModelInfo &m, bool b, bool force = false); - bool modelIsDiscovered(const ModelInfo 
&m) const; - void setModelIsDiscovered(const ModelInfo &m, bool b, bool force = false); - int modelLikes(const ModelInfo &m) const; - void setModelLikes(const ModelInfo &m, int l, bool force = false); - int modelDownloads(const ModelInfo &m) const; - void setModelDownloads(const ModelInfo &m, int d, bool force = false); - QDateTime modelRecency(const ModelInfo &m) const; - void setModelRecency(const ModelInfo &m, const QDateTime &r, bool force = false); + QString modelDescription(const ModelInfo &info) const; + void setModelDescription(const ModelInfo &info, const QString &value, bool force = false); + QString modelUrl(const ModelInfo &info) const; + void setModelUrl(const ModelInfo &info, const QString &value, bool force = false); + QString modelQuant(const ModelInfo &info) const; + void setModelQuant(const ModelInfo &info, const QString &value, bool force = false); + QString modelType(const ModelInfo &info) const; + void setModelType(const ModelInfo &info, const QString &value, bool force = false); + bool modelIsClone(const ModelInfo &info) const; + void setModelIsClone(const ModelInfo &info, bool value, bool force = false); + bool modelIsDiscovered(const ModelInfo &info) const; + void setModelIsDiscovered(const ModelInfo &info, bool value, bool force = false); + int modelLikes(const ModelInfo &info) const; + void setModelLikes(const ModelInfo &info, int value, bool force = false); + int modelDownloads(const ModelInfo &info) const; + void setModelDownloads(const ModelInfo &info, int value, bool force = false); + QDateTime modelRecency(const ModelInfo &info) const; + void setModelRecency(const ModelInfo &info, const QDateTime &value, bool force = false); - double modelTemperature(const ModelInfo &m) const; - Q_INVOKABLE void setModelTemperature(const ModelInfo &m, double t, bool force = false); - double modelTopP(const ModelInfo &m) const; - Q_INVOKABLE void setModelTopP(const ModelInfo &m, double p, bool force = false); - double modelMinP(const ModelInfo &m) 
const; - Q_INVOKABLE void setModelMinP(const ModelInfo &m, double p, bool force = false); - int modelTopK(const ModelInfo &m) const; - Q_INVOKABLE void setModelTopK(const ModelInfo &m, int k, bool force = false); - int modelMaxLength(const ModelInfo &m) const; - Q_INVOKABLE void setModelMaxLength(const ModelInfo &m, int l, bool force = false); - int modelPromptBatchSize(const ModelInfo &m) const; - Q_INVOKABLE void setModelPromptBatchSize(const ModelInfo &m, int s, bool force = false); - double modelRepeatPenalty(const ModelInfo &m) const; - Q_INVOKABLE void setModelRepeatPenalty(const ModelInfo &m, double p, bool force = false); - int modelRepeatPenaltyTokens(const ModelInfo &m) const; - Q_INVOKABLE void setModelRepeatPenaltyTokens(const ModelInfo &m, int t, bool force = false); - QString modelPromptTemplate(const ModelInfo &m) const; - Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &m, const QString &t, bool force = false); - QString modelSystemPrompt(const ModelInfo &m) const; - Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &m, const QString &p, bool force = false); - int modelContextLength(const ModelInfo &m) const; - Q_INVOKABLE void setModelContextLength(const ModelInfo &m, int s, bool force = false); - int modelGpuLayers(const ModelInfo &m) const; - Q_INVOKABLE void setModelGpuLayers(const ModelInfo &m, int s, bool force = false); + double modelTemperature(const ModelInfo &info) const; + Q_INVOKABLE void setModelTemperature(const ModelInfo &info, double value, bool force = false); + double modelTopP(const ModelInfo &info) const; + Q_INVOKABLE void setModelTopP(const ModelInfo &info, double value, bool force = false); + double modelMinP(const ModelInfo &info) const; + Q_INVOKABLE void setModelMinP(const ModelInfo &info, double value, bool force = false); + int modelTopK(const ModelInfo &info) const; + Q_INVOKABLE void setModelTopK(const ModelInfo &info, int value, bool force = false); + int modelMaxLength(const ModelInfo &info) const; + 
Q_INVOKABLE void setModelMaxLength(const ModelInfo &info, int value, bool force = false); + int modelPromptBatchSize(const ModelInfo &info) const; + Q_INVOKABLE void setModelPromptBatchSize(const ModelInfo &info, int value, bool force = false); + double modelRepeatPenalty(const ModelInfo &info) const; + Q_INVOKABLE void setModelRepeatPenalty(const ModelInfo &info, double value, bool force = false); + int modelRepeatPenaltyTokens(const ModelInfo &info) const; + Q_INVOKABLE void setModelRepeatPenaltyTokens(const ModelInfo &info, int value, bool force = false); + QString modelPromptTemplate(const ModelInfo &info) const; + Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force = false); + QString modelSystemPrompt(const ModelInfo &info) const; + Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force = false); + int modelContextLength(const ModelInfo &info) const; + Q_INVOKABLE void setModelContextLength(const ModelInfo &info, int value, bool force = false); + int modelGpuLayers(const ModelInfo &info) const; + Q_INVOKABLE void setModelGpuLayers(const ModelInfo &info, int value, bool force = false); // Application settings int threadCount() const; - void setThreadCount(int c); + void setThreadCount(int value); bool saveChatsContext() const; - void setSaveChatsContext(bool b); + void setSaveChatsContext(bool value); bool serverChat() const; - void setServerChat(bool b); - QString modelPath() const; - void setModelPath(const QString &p); + void setServerChat(bool value); + QString modelPath(); + void setModelPath(const QString &value); QString userDefaultModel() const; - void setUserDefaultModel(const QString &u); + void setUserDefaultModel(const QString &value); QString chatTheme() const; - void setChatTheme(const QString &u); + void setChatTheme(const QString &value); QString fontSize() const; - void setFontSize(const QString &u); + void setFontSize(const QString &value); bool forceMetal() 
const; - void setForceMetal(bool b); - QString device() const; - void setDevice(const QString &u); + void setForceMetal(bool value); + QString device(); + void setDevice(const QString &value); int32_t contextLength() const; void setContextLength(int32_t value); int32_t gpuLayers() const; @@ -117,46 +123,53 @@ public: // Release/Download settings QString lastVersionStarted() const; - void setLastVersionStarted(const QString &v); + void setLastVersionStarted(const QString &value); // Localdocs settings int localDocsChunkSize() const; - void setLocalDocsChunkSize(int s); + void setLocalDocsChunkSize(int value); int localDocsRetrievalSize() const; - void setLocalDocsRetrievalSize(int s); + void setLocalDocsRetrievalSize(int value); bool localDocsShowReferences() const; - void setLocalDocsShowReferences(bool b); + void setLocalDocsShowReferences(bool value); + QStringList localDocsFileExtensions() const; + void setLocalDocsFileExtensions(const QStringList &value); + bool localDocsUseRemoteEmbed() const; + void setLocalDocsUseRemoteEmbed(bool value); + QString localDocsNomicAPIKey() const; + void setLocalDocsNomicAPIKey(const QString &value); // Network settings QString networkAttribution() const; - void setNetworkAttribution(const QString &a); + void setNetworkAttribution(const QString &value); bool networkIsActive() const; Q_INVOKABLE bool isNetworkIsActiveSet() const; - void setNetworkIsActive(bool b); + void setNetworkIsActive(bool value); bool networkUsageStatsActive() const; Q_INVOKABLE bool isNetworkUsageStatsActiveSet() const; - void setNetworkUsageStatsActive(bool b); + void setNetworkUsageStatsActive(bool value); int networkPort() const; - void setNetworkPort(int c); + void setNetworkPort(int value); QVector deviceList() const; - void setDeviceList(const QVector &deviceList); + void setDeviceList(const QVector &value); Q_SIGNALS: - void nameChanged(const ModelInfo &model); - void filenameChanged(const ModelInfo &model); - void temperatureChanged(const ModelInfo 
&model); - void topPChanged(const ModelInfo &model); - void minPChanged(const ModelInfo &model); - void topKChanged(const ModelInfo &model); - void maxLengthChanged(const ModelInfo &model); - void promptBatchSizeChanged(const ModelInfo &model); - void contextLengthChanged(const ModelInfo &model); - void gpuLayersChanged(const ModelInfo &model); - void repeatPenaltyChanged(const ModelInfo &model); - void repeatPenaltyTokensChanged(const ModelInfo &model); - void promptTemplateChanged(const ModelInfo &model); - void systemPromptChanged(const ModelInfo &model); + void nameChanged(const ModelInfo &info); + void filenameChanged(const ModelInfo &info); + void descriptionChanged(const ModelInfo &info); + void temperatureChanged(const ModelInfo &info); + void topPChanged(const ModelInfo &info); + void minPChanged(const ModelInfo &info); + void topKChanged(const ModelInfo &info); + void maxLengthChanged(const ModelInfo &info); + void promptBatchSizeChanged(const ModelInfo &info); + void contextLengthChanged(const ModelInfo &info); + void gpuLayersChanged(const ModelInfo &info); + void repeatPenaltyChanged(const ModelInfo &info); + void repeatPenaltyTokensChanged(const ModelInfo &info); + void promptTemplateChanged(const ModelInfo &info); + void systemPromptChanged(const ModelInfo &info); void threadCountChanged(); void saveChatsContextChanged(); void serverChatChanged(); @@ -169,6 +182,9 @@ Q_SIGNALS: void localDocsChunkSizeChanged(); void localDocsRetrievalSizeChanged(); void localDocsShowReferencesChanged(); + void localDocsFileExtensionsChanged(); + void localDocsUseRemoteEmbedChanged(); + void localDocsNomicAPIKeyChanged(); void networkAttributionChanged(); void networkIsActiveChanged(); void networkPortChanged(); @@ -178,6 +194,7 @@ Q_SIGNALS: void deviceListChanged(); private: + QSettings m_settings; bool m_forceMetal; QVector m_deviceList; @@ -185,6 +202,12 @@ private: explicit MySettings(); ~MySettings() {} friend class MyPrivateSettings; + + QVariant 
getBasicSetting(const QString &name) const; + void setBasicSetting(const QString &name, const QVariant &value, std::optional signal = std::nullopt); + QVariant getModelSetting(const QString &name, const ModelInfo &info) const; + void setModelSetting(const QString &name, const ModelInfo &info, const QVariant &value, bool force, + bool signal = false); }; #endif // MYSETTINGS_H diff --git a/gpt4all-chat/network.cpp b/gpt4all-chat/network.cpp index b9f435a6..9b99683a 100644 --- a/gpt4all-chat/network.cpp +++ b/gpt4all-chat/network.cpp @@ -36,6 +36,8 @@ #include #include +using namespace Qt::Literals::StringLiterals; + //#define DEBUG static const char MIXPANEL_TOKEN[] = "ce362e568ddaee16ed243eaffb5860a2"; @@ -43,7 +45,8 @@ static const char MIXPANEL_TOKEN[] = "ce362e568ddaee16ed243eaffb5860a2"; #if defined(Q_OS_MAC) #include -static QString getCPUModel() { +static QString getCPUModel() +{ char buffer[256]; size_t bufferlen = sizeof(buffer); sysctlbyname("machdep.cpu.brand_string", &buffer, &bufferlen, NULL, 0); @@ -53,14 +56,16 @@ static QString getCPUModel() { #elif defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) #ifndef _MSC_VER -static void get_cpuid(int level, int *regs) { +static void get_cpuid(int level, int *regs) +{ asm volatile("cpuid" : "=a" (regs[0]), "=b" (regs[1]), "=c" (regs[2]), "=d" (regs[3]) : "0" (level) : "memory"); } #else #define get_cpuid(level, regs) __cpuid(regs, level) #endif -static QString getCPUModel() { +static QString getCPUModel() +{ int regs[12]; // EAX=800000000h: Get Highest Extended Function Implemented @@ -98,10 +103,8 @@ Network::Network() : QObject{nullptr} { QSettings settings; - settings.sync(); m_uniqueId = settings.value("uniqueId", generateUniqueId()).toString(); settings.setValue("uniqueId", m_uniqueId); - settings.sync(); m_sessionId = generateUniqueId(); // allow sendMixpanel to be called from any thread @@ -275,7 +278,7 @@ void Network::sendStartup() const auto *display = 
QGuiApplication::primaryScreen(); trackEvent("startup", { {"$screen_dpi", std::round(display->physicalDotsPerInch())}, - {"display", QString("%1x%2").arg(display->size().width()).arg(display->size().height())}, + {"display", u"%1x%2"_s.arg(display->size().width()).arg(display->size().height())}, {"ram", LLM::globalInstance()->systemTotalRAMInGB()}, {"cpu", getCPUModel()}, {"cpu_supports_avx2", LLModel::Implementation::cpuSupportsAVX2()}, diff --git a/gpt4all-chat/oscompat.cpp b/gpt4all-chat/oscompat.cpp new file mode 100644 index 00000000..a1edc8dd --- /dev/null +++ b/gpt4all-chat/oscompat.cpp @@ -0,0 +1,72 @@ +#include "oscompat.h" + +#include +#include +#include + +#ifdef Q_OS_WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +#endif + +// Flushes buffered writes for fd to the storage device; returns true on success, false with errno set on failure. +bool gpt4all_fsync(int fd) +{ +#if defined(Q_OS_WIN32) + HANDLE handle = HANDLE(_get_osfhandle(fd)); + if (handle == INVALID_HANDLE_VALUE) { + errno = EBADF; + return false; + } + + if (FlushFileBuffers(handle)) + return true; + + DWORD error = GetLastError(); + switch (error) { + case ERROR_ACCESS_DENIED: // read-only file + return true; + case ERROR_INVALID_HANDLE: // not a regular file + errno = EINVAL; + break; // without this, control falls into default and EINVAL is overwritten by EIO + default: + errno = EIO; + } + + return false; +#elif defined(Q_OS_DARWIN) + return fcntl(fd, F_FULLFSYNC, 0) == 0; +#else + return fsync(fd) == 0; +#endif +} + +bool gpt4all_fdatasync(int fd) +{ +#if defined(Q_OS_WIN32) || defined(Q_OS_DARWIN) + return gpt4all_fsync(fd); +#else + return fdatasync(fd) == 0; +#endif +} + +bool gpt4all_syncdir(const QString &path) +{ +#if defined(Q_OS_WIN32) + (void)path; // cannot sync a directory on Windows + return true; +#else + int fd = open(path.toLocal8Bit().constData(), O_RDONLY | O_DIRECTORY); + if (fd == -1) return false; + bool ok = gpt4all_fdatasync(fd); + close(fd); + return ok; +#endif +} diff --git a/gpt4all-chat/oscompat.h b/gpt4all-chat/oscompat.h new file mode 100644 index 00000000..e4a8cfef ---
/dev/null +++ b/gpt4all-chat/oscompat.h @@ -0,0 +1,7 @@ +#pragma once + +class QString; + +bool gpt4all_fsync(int fd); +bool gpt4all_fdatasync(int fd); +bool gpt4all_syncdir(const QString &path); diff --git a/gpt4all-chat/qml/AboutDialog.qml b/gpt4all-chat/qml/AboutDialog.qml deleted file mode 100644 index ca63f9e6..00000000 --- a/gpt4all-chat/qml/AboutDialog.qml +++ /dev/null @@ -1,101 +0,0 @@ -import QtCore -import QtQuick -import QtQuick.Controls -import QtQuick.Controls.Basic -import QtQuick.Layouts -import download -import network -import llm - -MyDialog { - id: abpoutDialog - anchors.centerIn: parent - modal: false - padding: 20 - width: 1024 - height: column.height + 40 - - Theme { - id: theme - } - - Column { - id: column - spacing: 20 - Item { - width: childrenRect.width - height: childrenRect.height - Image { - id: img - anchors.top: parent.top - anchors.left: parent.left - width: 60 - height: 60 - source: "qrc:/gpt4all/icons/logo.svg" - } - Text { - anchors.left: img.right - anchors.leftMargin: 30 - anchors.verticalCenter: img.verticalCenter - text: qsTr("About GPT4All") - color: theme.textColor - font.pixelSize: theme.fontSizeLarge - font.bold: true - } - } - - ScrollView { - clip: true - height: 200 - width: 1024 - 40 - ScrollBar.vertical.policy: ScrollBar.AlwaysOn - ScrollBar.horizontal.policy: ScrollBar.AlwaysOff - - MyTextArea { - id: welcome - width: 1024 - 40 - textFormat: TextEdit.MarkdownText - text: qsTr("### Release notes\n") - + Download.releaseInfo.notes - + qsTr("### Contributors\n") - + Download.releaseInfo.contributors - focus: false - readOnly: true - Accessible.role: Accessible.Paragraph - Accessible.name: qsTr("Release notes") - Accessible.description: qsTr("Release notes for this version") - } - } - - MySettingsLabel { - id: discordLink - width: parent.width - textFormat: Text.StyledText - wrapMode: Text.WordWrap - text: qsTr("Check out our discord channel https://discord.gg/4M2QFmTt2k") - font.pixelSize: theme.fontSizeLarge - 
onLinkActivated: { Qt.openUrlExternally("https://discord.gg/4M2QFmTt2k") } - color: theme.textColor - linkColor: theme.linkColor - - Accessible.role: Accessible.Link - Accessible.name: qsTr("Discord link") - } - - MySettingsLabel { - id: nomicProps - width: parent.width - textFormat: Text.StyledText - wrapMode: Text.WordWrap - text: qsTr("Thank you to Nomic AI and the community for contributing so much great data, code, ideas, and energy to the growing open source AI ecosystem!") - font.pixelSize: theme.fontSizeLarge - onLinkActivated: { Qt.openUrlExternally("https://home.nomic.ai") } - color: theme.textColor - linkColor: theme.linkColor - - Accessible.role: Accessible.Paragraph - Accessible.name: qsTr("Thank you blurb") - Accessible.description: qsTr("Contains embedded link to https://home.nomic.ai") - } - } -} diff --git a/gpt4all-chat/qml/AddCollectionView.qml b/gpt4all-chat/qml/AddCollectionView.qml new file mode 100644 index 00000000..d13585b3 --- /dev/null +++ b/gpt4all-chat/qml/AddCollectionView.qml @@ -0,0 +1,170 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts +import QtQuick.Dialogs +import Qt.labs.folderlistmodel +import Qt5Compat.GraphicalEffects +import llm +import chatlistmodel +import download +import modellist +import network +import gpt4all +import mysettings +import localdocs + +Rectangle { + id: addCollectionView + + Theme { + id: theme + } + + color: theme.viewBackground + signal localDocsViewRequested() + + ColumnLayout { + id: mainArea + anchors.left: parent.left + anchors.right: parent.right + anchors.top: parent.top + anchors.bottom: parent.bottom + anchors.margins: 30 + spacing: 50 + + RowLayout { + Layout.fillWidth: true + Layout.alignment: Qt.AlignTop + spacing: 50 + + MyButton { + id: backButton + Layout.alignment: Qt.AlignTop | Qt.AlignLeft + text: qsTr("\u2190 Existing Collections") + + borderWidth: 0 + backgroundColor: theme.lighterButtonBackground + 
backgroundColorHovered: theme.lighterButtonBackgroundHovered + backgroundRadius: 5 + padding: 15 + topPadding: 8 + bottomPadding: 8 + textColor: theme.lighterButtonForeground + fontPixelSize: theme.fontSizeLarge + fontPixelBold: true + + onClicked: { + localDocsViewRequested() + } + } + } + + ColumnLayout { + id: root + Layout.alignment: Qt.AlignTop | Qt.AlignCenter + spacing: 50 + + property alias collection: collection.text + property alias folder_path: folderEdit.text + + FolderDialog { + id: folderDialog + title: qsTr("Please choose a directory") + } + + function openFolderDialog(currentFolder, onAccepted) { + folderDialog.currentFolder = currentFolder; + folderDialog.accepted.connect(function() { onAccepted(folderDialog.currentFolder); }); + folderDialog.open(); + } + + Text { + horizontalAlignment: Qt.AlignHCenter + text: qsTr("New Local Doc\nCollection") + font.pixelSize: theme.fontSizeBanner + color: theme.titleTextColor + } + + MyTextField { + id: collection + Layout.alignment: Qt.AlignCenter + Layout.minimumWidth: 400 + horizontalAlignment: Text.AlignJustify + color: theme.textColor + font.pixelSize: theme.fontSizeLarge + placeholderText: qsTr("Collection name...") + placeholderTextColor: theme.mutedTextColor + ToolTip.text: qsTr("Name of the collection to add (Required)") + ToolTip.visible: hovered + Accessible.role: Accessible.EditableText + Accessible.name: collection.text + Accessible.description: ToolTip.text + function showError() { + collection.placeholderTextColor = theme.textErrorColor + } + onTextChanged: { + collection.placeholderTextColor = theme.mutedTextColor + } + } + + RowLayout { + Layout.alignment: Qt.AlignCenter + Layout.minimumWidth: 400 + Layout.maximumWidth: 400 + spacing: 10 + MyDirectoryField { + id: folderEdit + Layout.fillWidth: true + text: root.folder_path + placeholderText: qsTr("Folder path...") + font.pixelSize: theme.fontSizeLarge + placeholderTextColor: theme.mutedTextColor + ToolTip.text: qsTr("Folder path to documents 
(Required)") + ToolTip.visible: hovered + function showError() { + folderEdit.placeholderTextColor = theme.textErrorColor + } + onTextChanged: { + folderEdit.placeholderTextColor = theme.mutedTextColor + } + } + + MySettingsButton { + id: browseButton + text: qsTr("Browse") + onClicked: { + root.openFolderDialog(StandardPaths.writableLocation(StandardPaths.HomeLocation), function(selectedFolder) { + root.folder_path = selectedFolder + }) + } + } + } + + MyButton { + Layout.alignment: Qt.AlignCenter + Layout.minimumWidth: 400 + text: qsTr("Create Collection") + onClicked: { + var isError = false; + if (root.collection === "") { + isError = true; + collection.showError(); + } + if (root.folder_path === "" || !folderEdit.isValid) { + isError = true; + folderEdit.showError(); + } + if (isError) + return; + LocalDocs.addFolder(root.collection, root.folder_path) + root.collection = "" + root.folder_path = "" + collection.clear() + localDocsViewRequested() + } + } + } + } +} diff --git a/gpt4all-chat/qml/AddModelView.qml b/gpt4all-chat/qml/AddModelView.qml new file mode 100644 index 00000000..2762e136 --- /dev/null +++ b/gpt4all-chat/qml/AddModelView.qml @@ -0,0 +1,726 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts +import QtQuick.Dialogs +import Qt.labs.folderlistmodel +import Qt5Compat.GraphicalEffects +import llm +import chatlistmodel +import download +import modellist +import network +import gpt4all +import mysettings +import localdocs + +Rectangle { + id: addModelView + + Theme { + id: theme + } + + color: theme.viewBackground + signal modelsViewRequested() + + PopupDialog { + id: downloadingErrorPopup + anchors.centerIn: parent + shouldTimeOut: false + } + + ColumnLayout { + id: mainArea + anchors.left: parent.left + anchors.right: parent.right + anchors.top: parent.top + anchors.bottom: parent.bottom + anchors.margins: 30 + spacing: 50 + + ColumnLayout { + Layout.fillWidth: true + Layout.alignment: 
Qt.AlignTop + spacing: 50 + + MyButton { + id: backButton + Layout.alignment: Qt.AlignTop | Qt.AlignLeft + text: qsTr("\u2190 Existing Models") + + borderWidth: 0 + backgroundColor: theme.lighterButtonBackground + backgroundColorHovered: theme.lighterButtonBackgroundHovered + backgroundRadius: 5 + padding: 15 + topPadding: 8 + bottomPadding: 8 + textColor: theme.lighterButtonForeground + fontPixelSize: theme.fontSizeLarge + fontPixelBold: true + + onClicked: { + modelsViewRequested() + } + } + + Text { + id: welcome + text: qsTr("Explore Models") + font.pixelSize: theme.fontSizeBanner + color: theme.titleTextColor + } + + RowLayout { + Layout.fillWidth: true + Layout.alignment: Qt.AlignCenter + Layout.margins: 0 + spacing: 10 + MyTextField { + id: discoverField + property string textBeingSearched: "" + readOnly: ModelList.discoverInProgress + Layout.alignment: Qt.AlignCenter + Layout.fillWidth: true + Layout.preferredHeight: 90 + font.pixelSize: theme.fontSizeLarger + placeholderText: qsTr("Discover and download models by keyword search...") + Accessible.role: Accessible.EditableText + Accessible.name: placeholderText + Accessible.description: qsTr("Text field for discovering and filtering downloadable models") + Connections { + target: ModelList + function onDiscoverInProgressChanged() { + if (ModelList.discoverInProgress) { + discoverField.textBeingSearched = discoverField.text; + discoverField.text = qsTr("Searching \u00B7 ") + discoverField.textBeingSearched; + } else { + discoverField.text = discoverField.textBeingSearched; + discoverField.textBeingSearched = ""; + } + } + } + background: ProgressBar { + id: discoverProgressBar + indeterminate: ModelList.discoverInProgress && ModelList.discoverProgress === 0.0 + value: ModelList.discoverProgress + background: Rectangle { + color: theme.controlBackground + border.color: theme.controlBorder + radius: 10 + } + contentItem: Item { + Rectangle { + visible: ModelList.discoverInProgress + anchors.bottom: 
parent.bottom + width: discoverProgressBar.visualPosition * parent.width + height: 10 + radius: 2 + color: theme.progressForeground + } + } + } + + Keys.onReturnPressed: (event)=> { + if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier) + event.accepted = false; + else { + editingFinished(); + sendDiscovery() + } + } + function sendDiscovery() { + ModelList.downloadableModels.discoverAndFilter(discoverField.text); + } + RowLayout { + spacing: 0 + anchors.right: discoverField.right + anchors.verticalCenter: discoverField.verticalCenter + anchors.rightMargin: 15 + visible: !ModelList.discoverInProgress + MyMiniButton { + id: clearDiscoverButton + backgroundColor: theme.textColor + backgroundColorHovered: theme.iconBackgroundDark + visible: discoverField.text !== "" + contentItem: Text { + color: clearDiscoverButton.hovered ? theme.iconBackgroundDark : theme.textColor + text: "\u2715" + font.pixelSize: theme.fontSizeLarge + } + onClicked: { + discoverField.text = "" + discoverField.sendDiscovery() // should clear results + } + } + MyMiniButton { + backgroundColor: theme.textColor + backgroundColorHovered: theme.iconBackgroundDark + source: "qrc:/gpt4all/icons/settings.svg" + onClicked: { + discoveryTools.visible = !discoveryTools.visible + } + } + MyMiniButton { + id: sendButton + enabled: !ModelList.discoverInProgress + backgroundColor: theme.textColor + backgroundColorHovered: theme.iconBackgroundDark + source: "qrc:/gpt4all/icons/send_message.svg" + Accessible.name: qsTr("Initiate model discovery and filtering") + Accessible.description: qsTr("Triggers discovery and filtering of models") + onClicked: { + discoverField.sendDiscovery() + } + } + } + } + } + + RowLayout { + id: discoveryTools + Layout.fillWidth: true + Layout.alignment: Qt.AlignCenter + Layout.margins: 0 + spacing: 20 + visible: false + MyComboBox { + id: comboSort + model: [qsTr("Default"), qsTr("Likes"), qsTr("Downloads"), qsTr("Recent")] + currentIndex: 
ModelList.discoverSort + contentItem: Text { + anchors.horizontalCenter: parent.horizontalCenter + rightPadding: 30 + color: theme.textColor + text: { + return qsTr("Sort by: ") + comboSort.displayText + } + font.pixelSize: theme.fontSizeLarger + verticalAlignment: Text.AlignVCenter + horizontalAlignment: Text.AlignHCenter + elide: Text.ElideRight + } + onActivated: function (index) { + ModelList.discoverSort = index; + } + } + MyComboBox { + id: comboSortDirection + model: [qsTr("Asc"), qsTr("Desc")] + currentIndex: { + if (ModelList.discoverSortDirection === 1) + return 0 + else + return 1; + } + contentItem: Text { + anchors.horizontalCenter: parent.horizontalCenter + rightPadding: 30 + color: theme.textColor + text: { + return qsTr("Sort dir: ") + comboSortDirection.displayText + } + font.pixelSize: theme.fontSizeLarger + verticalAlignment: Text.AlignVCenter + horizontalAlignment: Text.AlignHCenter + elide: Text.ElideRight + } + onActivated: function (index) { + if (index === 0) + ModelList.discoverSortDirection = 1; + else + ModelList.discoverSortDirection = -1; + } + } + MyComboBox { + id: comboLimit + model: ["5", "10", "20", "50", "100", qsTr("None")] + currentIndex: { + if (ModelList.discoverLimit === 5) + return 0; + else if (ModelList.discoverLimit === 10) + return 1; + else if (ModelList.discoverLimit === 20) + return 2; + else if (ModelList.discoverLimit === 50) + return 3; + else if (ModelList.discoverLimit === 100) + return 4; + else if (ModelList.discoverLimit === -1) + return 5; + } + contentItem: Text { + anchors.horizontalCenter: parent.horizontalCenter + rightPadding: 30 + color: theme.textColor + text: { + return qsTr("Limit: ") + comboLimit.displayText + } + font.pixelSize: theme.fontSizeLarger + verticalAlignment: Text.AlignVCenter + horizontalAlignment: Text.AlignHCenter + elide: Text.ElideRight + } + onActivated: function (index) { + switch (index) { + case 0: + ModelList.discoverLimit = 5; break; + case 1: + ModelList.discoverLimit = 10; 
break; + case 2: + ModelList.discoverLimit = 20; break; + case 3: + ModelList.discoverLimit = 50; break; + case 4: + ModelList.discoverLimit = 100; break; + case 5: + ModelList.discoverLimit = -1; break; + } + } + } + } + } + + Label { + visible: !ModelList.downloadableModels.count && !ModelList.asyncModelRequestOngoing + Layout.fillWidth: true + Layout.fillHeight: true + horizontalAlignment: Qt.AlignHCenter + verticalAlignment: Qt.AlignVCenter + text: qsTr("Network error: could not retrieve http://gpt4all.io/models/models3.json") + font.pixelSize: theme.fontSizeLarge + color: theme.mutedTextColor + } + + MyBusyIndicator { + visible: !ModelList.downloadableModels.count && ModelList.asyncModelRequestOngoing + running: ModelList.asyncModelRequestOngoing + Accessible.role: Accessible.Animation + Layout.alignment: Qt.AlignCenter + Accessible.name: qsTr("Busy indicator") + Accessible.description: qsTr("Displayed when the models request is ongoing") + } + + ScrollView { + id: scrollView + ScrollBar.vertical.policy: ScrollBar.AsNeeded + Layout.fillWidth: true + Layout.fillHeight: true + clip: true + + ListView { + id: modelListView + model: ModelList.downloadableModels + boundsBehavior: Flickable.StopAtBounds + spacing: 30 + + delegate: Rectangle { + id: delegateItem + width: modelListView.width + height: childrenRect.height + 60 + color: theme.conversationBackground + radius: 10 + border.width: 1 + border.color: theme.controlBorder + + ColumnLayout { + anchors.top: parent.top + anchors.left: parent.left + anchors.right: parent.right + anchors.margins: 30 + + Text { + Layout.fillWidth: true + Layout.alignment: Qt.AlignLeft + text: name + elide: Text.ElideRight + color: theme.titleTextColor + font.pixelSize: theme.fontSizeLargest + font.bold: true + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Model file") + Accessible.description: qsTr("Model file to be downloaded") + } + + + Rectangle { + Layout.fillWidth: true + height: 1 + color: theme.dividerColor + 
} + + RowLayout { + Layout.topMargin: 10 + Layout.fillWidth: true + Text { + id: descriptionText + text: description + font.pixelSize: theme.fontSizeLarge + Layout.fillWidth: true + wrapMode: Text.WordWrap + textFormat: Text.StyledText + color: theme.textColor + linkColor: theme.textColor + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Description") + Accessible.description: qsTr("File description") + onLinkActivated: Qt.openUrlExternally(link) + } + + // FIXME Need to overhaul design here which must take into account + // features not present in current figma including: + // * Ability to cancel a current download + // * Ability to resume a download + // * The presentation of an error if encountered + // * Whether to show already installed models + // * Install of remote models with API keys + // * The presentation of the progress bar + Rectangle { + id: actionBox + width: childrenRect.width + 20 + color: "transparent" + border.width: 1 + border.color: theme.dividerColor + radius: 10 + Layout.rightMargin: 20 + Layout.bottomMargin: 20 + Layout.minimumHeight: childrenRect.height + 20 + Layout.alignment: Qt.AlignRight | Qt.AlignTop + + ColumnLayout { + spacing: 0 + MySettingsButton { + id: downloadButton + text: isDownloading ? qsTr("Cancel") : isIncomplete ? 
qsTr("Resume") : qsTr("Download") + font.pixelSize: theme.fontSizeLarge + Layout.topMargin: 20 + Layout.leftMargin: 20 + Layout.minimumWidth: 200 + Layout.fillWidth: true + Layout.alignment: Qt.AlignTop | Qt.AlignHCenter + visible: !isOnline && !installed && !calcHash && downloadError === "" + Accessible.description: qsTr("Stop/restart/start the download") + onClicked: { + if (!isDownloading) { + Download.downloadModel(filename); + } else { + Download.cancelDownload(filename); + } + } + } + + MySettingsDestructiveButton { + id: removeButton + text: qsTr("Remove") + Layout.topMargin: 20 + Layout.leftMargin: 20 + Layout.minimumWidth: 200 + Layout.fillWidth: true + Layout.alignment: Qt.AlignTop | Qt.AlignHCenter + visible: installed || downloadError !== "" + Accessible.description: qsTr("Remove model from filesystem") + onClicked: { + Download.removeModel(filename); + } + } + + MySettingsButton { + id: installButton + visible: !installed && isOnline + Layout.topMargin: 20 + Layout.leftMargin: 20 + Layout.minimumWidth: 200 + Layout.fillWidth: true + Layout.alignment: Qt.AlignTop | Qt.AlignHCenter + text: qsTr("Install") + font.pixelSize: theme.fontSizeLarge + onClicked: { + if (apiKey.text === "") + apiKey.showError(); + else + Download.installModel(filename, apiKey.text); + } + Accessible.role: Accessible.Button + Accessible.name: qsTr("Install") + Accessible.description: qsTr("Install online model") + } + + ColumnLayout { + spacing: 0 + Label { + Layout.topMargin: 20 + Layout.leftMargin: 20 + visible: downloadError !== "" + textFormat: Text.StyledText + text: "" + + qsTr("Error") + + "" + color: theme.textColor + font.pixelSize: theme.fontSizeLarge + linkColor: theme.textErrorColor + Accessible.role: Accessible.Paragraph + Accessible.name: text + Accessible.description: qsTr("Describes an error that occurred when downloading") + onLinkActivated: { + downloadingErrorPopup.text = downloadError; + downloadingErrorPopup.open(); + } + } + + Label { + visible: 
LLM.systemTotalRAMInGB() < ramrequired + Layout.topMargin: 20 + Layout.leftMargin: 20 + Layout.maximumWidth: 300 + textFormat: Text.StyledText + text: qsTr("WARNING: Not recommended for your hardware.") + + qsTr(" Model requires more memory (") + ramrequired + + qsTr(" GB) than your system has available (") + + LLM.systemTotalRAMInGBString() + ")." + color: theme.textErrorColor + font.pixelSize: theme.fontSizeLarge + wrapMode: Text.WordWrap + Accessible.role: Accessible.Paragraph + Accessible.name: text + Accessible.description: qsTr("Error for incompatible hardware") + onLinkActivated: { + downloadingErrorPopup.text = downloadError; + downloadingErrorPopup.open(); + } + } + } + + ColumnLayout { + visible: isDownloading && !calcHash + Layout.topMargin: 20 + Layout.leftMargin: 20 + Layout.minimumWidth: 200 + Layout.fillWidth: true + Layout.alignment: Qt.AlignTop | Qt.AlignHCenter + spacing: 20 + + ProgressBar { + id: itemProgressBar + Layout.fillWidth: true + width: 200 + value: bytesReceived / bytesTotal + background: Rectangle { + implicitHeight: 45 + color: theme.progressBackground + radius: 3 + } + contentItem: Item { + implicitHeight: 40 + + Rectangle { + width: itemProgressBar.visualPosition * parent.width + height: parent.height + radius: 2 + color: theme.progressForeground + } + } + Accessible.role: Accessible.ProgressBar + Accessible.name: qsTr("Download progressBar") + Accessible.description: qsTr("Shows the progress made in the download") + } + + Label { + id: speedLabel + color: theme.textColor + Layout.alignment: Qt.AlignRight + text: speed + font.pixelSize: theme.fontSizeLarge + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Download speed") + Accessible.description: qsTr("Download speed in bytes/kilobytes/megabytes per second") + } + } + + RowLayout { + visible: calcHash + Layout.topMargin: 20 + Layout.leftMargin: 20 + Layout.minimumWidth: 200 + Layout.maximumWidth: 200 + Layout.fillWidth: true + Layout.alignment: Qt.AlignTop | 
Qt.AlignHCenter + clip: true + + Label { + id: calcHashLabel + color: theme.textColor + text: qsTr("Calculating...") + font.pixelSize: theme.fontSizeLarge + Accessible.role: Accessible.Paragraph + Accessible.name: text + Accessible.description: qsTr("Whether the file hash is being calculated") + } + + MyBusyIndicator { + id: busyCalcHash + running: calcHash + Accessible.role: Accessible.Animation + Accessible.name: qsTr("Busy indicator") + Accessible.description: qsTr("Displayed when the file hash is being calculated") + } + } + + MyTextField { + id: apiKey + visible: !installed && isOnline + Layout.topMargin: 20 + Layout.leftMargin: 20 + Layout.minimumWidth: 200 + Layout.alignment: Qt.AlignTop | Qt.AlignHCenter + wrapMode: Text.WrapAnywhere + function showError() { + apiKey.placeholderTextColor = theme.textErrorColor + } + onTextChanged: { + apiKey.placeholderTextColor = theme.mutedTextColor + } + placeholderText: qsTr("enter $API_KEY") + Accessible.role: Accessible.EditableText + Accessible.name: placeholderText + Accessible.description: qsTr("Text field for entering the API key used to install the online model") + } + } + } + } + + Item { + Layout.minimumWidth: childrenRect.width + Layout.minimumHeight: childrenRect.height + Layout.bottomMargin: 10 + RowLayout { + id: paramRow + anchors.centerIn: parent + ColumnLayout { + Layout.topMargin: 10 + Layout.bottomMargin: 10 + Layout.leftMargin: 20 + Layout.rightMargin: 20 + Text { + text: qsTr("File size") + font.pixelSize: theme.fontSizeSmaller + color: theme.mutedDarkTextColor + } + Text { + text: filesize + color: theme.textColor + font.pixelSize: theme.fontSizeSmaller + font.bold: true + } + } + Rectangle { + width: 1 + Layout.fillHeight: true + color: theme.dividerColor + } + ColumnLayout { + Layout.topMargin: 10 + Layout.bottomMargin: 10 + Layout.leftMargin: 20 + Layout.rightMargin: 20 + Text { + text: qsTr("RAM required") + font.pixelSize: theme.fontSizeSmaller + color: theme.mutedDarkTextColor + } + Text { + text: ramrequired + qsTr("
GB") + color: theme.textColor + font.pixelSize: theme.fontSizeSmaller + font.bold: true + } + } + Rectangle { + width: 1 + Layout.fillHeight: true + color: theme.dividerColor + } + ColumnLayout { + Layout.topMargin: 10 + Layout.bottomMargin: 10 + Layout.leftMargin: 20 + Layout.rightMargin: 20 + Text { + text: qsTr("Parameters") + font.pixelSize: theme.fontSizeSmaller + color: theme.mutedDarkTextColor + } + Text { + text: parameters + color: theme.textColor + font.pixelSize: theme.fontSizeSmaller + font.bold: true + } + } + Rectangle { + width: 1 + Layout.fillHeight: true + color: theme.dividerColor + } + ColumnLayout { + Layout.topMargin: 10 + Layout.bottomMargin: 10 + Layout.leftMargin: 20 + Layout.rightMargin: 20 + Text { + text: qsTr("Quant") + font.pixelSize: theme.fontSizeSmaller + color: theme.mutedDarkTextColor + } + Text { + text: quant + color: theme.textColor + font.pixelSize: theme.fontSizeSmaller + font.bold: true + } + } + Rectangle { + width: 1 + Layout.fillHeight: true + color: theme.dividerColor + } + ColumnLayout { + Layout.topMargin: 10 + Layout.bottomMargin: 10 + Layout.leftMargin: 20 + Layout.rightMargin: 20 + Text { + text: qsTr("Type") + font.pixelSize: theme.fontSizeSmaller + color: theme.mutedDarkTextColor + } + Text { + text: type + color: theme.textColor + font.pixelSize: theme.fontSizeSmaller + font.bold: true + } + } + } + + Rectangle { + color: "transparent" + anchors.fill: paramRow + border.color: theme.dividerColor + border.width: 1 + radius: 10 + } + } + + Rectangle { + Layout.fillWidth: true + height: 1 + color: theme.dividerColor + } + } + } + } + } + } +} diff --git a/gpt4all-chat/qml/ApplicationSettings.qml b/gpt4all-chat/qml/ApplicationSettings.qml index d2dc2150..f58f7e9d 100644 --- a/gpt4all-chat/qml/ApplicationSettings.qml +++ b/gpt4all-chat/qml/ApplicationSettings.qml @@ -7,31 +7,97 @@ import QtQuick.Dialogs import modellist import mysettings import network +import llm MySettingsTab { onRestoreDefaultsClicked: { 
MySettings.restoreApplicationDefaults(); } title: qsTr("Application") + + NetworkDialog { + id: networkDialog + anchors.centerIn: parent + width: Math.min(1024, window.width - (window.width * .2)) + height: Math.min(600, window.height - (window.height * .2)) + Item { + Accessible.role: Accessible.Dialog + Accessible.name: qsTr("Network dialog") + Accessible.description: qsTr("opt-in to share feedback/conversations") + } + } + + Dialog { + id: checkForUpdatesError + anchors.centerIn: parent + modal: false + padding: 20 + Text { + horizontalAlignment: Text.AlignJustify + text: qsTr("ERROR: Update system could not find the MaintenanceTool used
          + to check for updates!

          + Did you install this application using the online installer? If so,
          + the MaintenanceTool executable should be located one directory
          + above where this application resides on your filesystem.

          + If you can't start it manually, then I'm afraid you'll have to
          + reinstall.") + color: theme.textErrorColor + font.pixelSize: theme.fontSizeLarge + Accessible.role: Accessible.Dialog + Accessible.name: text + Accessible.description: qsTr("Error dialog") + } + background: Rectangle { + anchors.fill: parent + color: theme.containerBackground + border.width: 1 + border.color: theme.dialogBorder + radius: 10 + } + } + contentItem: GridLayout { id: applicationSettingsTabInner columns: 3 - rowSpacing: 10 + rowSpacing: 30 columnSpacing: 10 + + ColumnLayout { + Layout.row: 0 + Layout.column: 0 + Layout.columnSpan: 3 + Layout.fillWidth: true + spacing: 10 + Label { + color: theme.styledTextColor + font.pixelSize: theme.fontSizeLarge + font.bold: true + text: "General" + } + + Rectangle { + Layout.fillWidth: true + height: 2 + color: theme.settingsDivider + } + } + MySettingsLabel { id: themeLabel text: qsTr("Theme") + helpText: qsTr("Customize the colors of GPT4All") Layout.row: 1 Layout.column: 0 } MyComboBox { id: themeBox Layout.row: 1 - Layout.column: 1 - Layout.columnSpan: 1 + Layout.column: 2 Layout.minimumWidth: 200 + Layout.maximumWidth: 200 Layout.fillWidth: false - model: ["Dark", "Light", "LegacyDark"] + Layout.alignment: Qt.AlignRight + model: [qsTr("Dark"), qsTr("Light"), qsTr("LegacyDark")] Accessible.role: Accessible.ComboBox Accessible.name: qsTr("Color theme") Accessible.description: qsTr("Color theme for the chat client to use") @@ -54,16 +120,18 @@ MySettingsTab { MySettingsLabel { id: fontLabel text: qsTr("Font Size") + helpText: qsTr("How big your font is displayed") Layout.row: 2 Layout.column: 0 } MyComboBox { id: fontBox Layout.row: 2 - Layout.column: 1 - Layout.columnSpan: 1 - Layout.minimumWidth: 100 + Layout.column: 2 + Layout.minimumWidth: 200 + Layout.maximumWidth: 200 Layout.fillWidth: false + Layout.alignment: Qt.AlignRight model: ["Small", "Medium", "Large"] Accessible.role: Accessible.ComboBox Accessible.name: qsTr("Font size") @@ -87,16 +155,18 @@ MySettingsTab { MySettingsLabel { id: 
deviceLabel text: qsTr("Device") + helpText: qsTr("The hardware device used to load the model") Layout.row: 3 Layout.column: 0 } MyComboBox { id: deviceBox Layout.row: 3 - Layout.column: 1 - Layout.columnSpan: 1 - Layout.minimumWidth: 350 + Layout.column: 2 + Layout.minimumWidth: 400 + Layout.maximumWidth: 400 Layout.fillWidth: false + Layout.alignment: Qt.AlignRight model: MySettings.deviceList Accessible.role: Accessible.ComboBox Accessible.name: qsTr("Device") @@ -123,16 +193,17 @@ MySettingsTab { MySettingsLabel { id: defaultModelLabel text: qsTr("Default model") + helpText: qsTr("The preferred default model") Layout.row: 4 Layout.column: 0 } MyComboBox { id: comboBox Layout.row: 4 - Layout.column: 1 - Layout.columnSpan: 2 - Layout.minimumWidth: 350 - Layout.fillWidth: true + Layout.column: 2 + Layout.minimumWidth: 400 + Layout.maximumWidth: 400 + Layout.alignment: Qt.AlignRight model: ModelList.userDefaultModelList Accessible.role: Accessible.ComboBox Accessible.name: qsTr("Default model") @@ -156,45 +227,96 @@ MySettingsTab { MySettingsLabel { id: modelPathLabel text: qsTr("Download path") + helpText: qsTr("The download folder for models") Layout.row: 5 Layout.column: 0 } - MyDirectoryField { - id: modelPathDisplayField - text: MySettings.modelPath - font.pixelSize: theme.fontSizeLarge - implicitWidth: 300 + + RowLayout { Layout.row: 5 - Layout.column: 1 - Layout.fillWidth: true - ToolTip.text: qsTr("Path where model files will be downloaded to") - ToolTip.visible: hovered - Accessible.role: Accessible.ToolTip - Accessible.name: modelPathDisplayField.text - Accessible.description: ToolTip.text - onEditingFinished: { - if (isValid) { - MySettings.modelPath = modelPathDisplayField.text - } else { - text = MySettings.modelPath + Layout.column: 2 + Layout.alignment: Qt.AlignRight + Layout.minimumWidth: 400 + Layout.maximumWidth: 400 + spacing: 10 + MyDirectoryField { + id: modelPathDisplayField + text: MySettings.modelPath + font.pixelSize: theme.fontSizeLarge + 
implicitWidth: 300 + Layout.fillWidth: true + ToolTip.text: qsTr("Path where model files will be downloaded to") + ToolTip.visible: hovered + Accessible.role: Accessible.ToolTip + Accessible.name: modelPathDisplayField.text + Accessible.description: ToolTip.text + onEditingFinished: { + if (isValid) { + MySettings.modelPath = modelPathDisplayField.text + } else { + text = MySettings.modelPath + } + } + } + MySettingsButton { + text: qsTr("Browse") + Accessible.description: qsTr("Choose where to save model files") + onClicked: { + openFolderDialog("file://" + MySettings.modelPath, function(selectedFolder) { + MySettings.modelPath = selectedFolder + }) } } } - MySettingsButton { - Layout.row: 5 + + MySettingsLabel { + id: dataLakeLabel + text: qsTr("Opensource Datalake") + helpText: qsTr("Send your data to the GPT4All Open Source Datalake.") + Layout.row: 6 + Layout.column: 0 + } + MyCheckBox { + id: dataLakeBox + Layout.row: 6 Layout.column: 2 - text: qsTr("Browse") - Accessible.description: qsTr("Choose where to save model files") + Layout.alignment: Qt.AlignRight + checked: MySettings.networkIsActive onClicked: { - openFolderDialog("file://" + MySettings.modelPath, function(selectedFolder) { - MySettings.modelPath = selectedFolder - }) + if (MySettings.networkIsActive) { + MySettings.networkIsActive = false + } else + networkDialog.open() + } + ToolTip.text: qsTr("Reveals a dialogue where you can opt-in for sharing data over network") + ToolTip.visible: hovered + } + + ColumnLayout { + Layout.row: 7 + Layout.column: 0 + Layout.columnSpan: 3 + Layout.fillWidth: true + spacing: 10 + Label { + color: theme.styledTextColor + font.pixelSize: theme.fontSizeLarge + font.bold: true + text: "Advanced" + } + + Rectangle { + Layout.fillWidth: true + height: 2 + color: theme.settingsDivider } } + MySettingsLabel { id: nThreadsLabel text: qsTr("CPU Threads") - Layout.row: 6 + helpText: qsTr("Number of CPU threads for inference and embedding") + Layout.row: 8 Layout.column: 0 } 
MyTextField { @@ -203,8 +325,11 @@ MySettingsTab { font.pixelSize: theme.fontSizeLarge ToolTip.text: qsTr("Amount of processing threads to use bounded by 1 and number of logical processors") ToolTip.visible: hovered - Layout.row: 6 - Layout.column: 1 + Layout.alignment: Qt.AlignRight + Layout.row: 8 + Layout.column: 2 + Layout.minimumWidth: 200 + Layout.maximumWidth: 200 validator: IntValidator { bottom: 1 } @@ -223,14 +348,16 @@ MySettingsTab { } MySettingsLabel { id: saveChatsContextLabel - text: qsTr("Save chats context to disk") - Layout.row: 7 + text: qsTr("Save chat context") + helpText: qsTr("Save chat context to disk") + Layout.row: 9 Layout.column: 0 } MyCheckBox { id: saveChatsContextBox - Layout.row: 7 - Layout.column: 1 + Layout.row: 9 + Layout.column: 2 + Layout.alignment: Qt.AlignRight checked: MySettings.saveChatsContext onClicked: { MySettings.saveChatsContext = !MySettings.saveChatsContext @@ -241,13 +368,15 @@ MySettingsTab { MySettingsLabel { id: serverChatLabel text: qsTr("Enable API server") - Layout.row: 8 + helpText: qsTr("A local http server running on local port") + Layout.row: 10 Layout.column: 0 } MyCheckBox { id: serverChatBox - Layout.row: 8 - Layout.column: 1 + Layout.row: 10 + Layout.column: 2 + Layout.alignment: Qt.AlignRight checked: MySettings.serverChat onClicked: { MySettings.serverChat = !MySettings.serverChat @@ -257,8 +386,9 @@ MySettingsTab { } MySettingsLabel { id: serverPortLabel - text: qsTr("API Server Port (Requires restart):") - Layout.row: 9 + text: qsTr("API Server Port:") + helpText: qsTr("A local port to run the server (Requires restart)") + Layout.row: 11 Layout.column: 0 } MyTextField { @@ -268,8 +398,11 @@ MySettingsTab { font.pixelSize: theme.fontSizeLarge ToolTip.text: qsTr("Api server port.
WARNING: You need to restart the application for it to take effect") ToolTip.visible: hovered - Layout.row: 9 - Layout.column: 1 + Layout.row: 11 + Layout.column: 2 + Layout.minimumWidth: 200 + Layout.maximumWidth: 200 + Layout.alignment: Qt.AlignRight validator: IntValidator { bottom: 1 } @@ -286,58 +419,53 @@ MySettingsTab { Accessible.name: serverPortField.text Accessible.description: ToolTip.text } - Rectangle { - Layout.row: 10 - Layout.column: 0 - Layout.columnSpan: 3 - Layout.fillWidth: true - height: 3 - color: theme.accentColor - } - } - advancedSettings: GridLayout { - columns: 3 - rowSpacing: 10 - columnSpacing: 10 - Rectangle { - Layout.row: 2 - Layout.column: 0 - Layout.fillWidth: true - Layout.columnSpan: 3 - height: 3 - color: theme.accentColor - } + MySettingsLabel { id: gpuOverrideLabel text: qsTr("Force Metal (macOS+arm)") - Layout.row: 1 + Layout.row: 13 Layout.column: 0 } - RowLayout { - Layout.row: 1 - Layout.column: 1 - Layout.columnSpan: 2 - MyCheckBox { - id: gpuOverrideBox - checked: MySettings.forceMetal - onClicked: { - MySettings.forceMetal = !MySettings.forceMetal - } + MyCheckBox { + id: gpuOverrideBox + Layout.row: 13 + Layout.column: 2 + Layout.alignment: Qt.AlignRight + checked: MySettings.forceMetal + onClicked: { + MySettings.forceMetal = !MySettings.forceMetal } + ToolTip.text: qsTr("WARNING: On macOS with arm (M1+) this setting forces usage of the GPU. Can cause crashes if the model requires more RAM than the system supports. Because of crash possibility the setting will not persist across restarts of the application. This has no effect on non-macs or intel.") + ToolTip.visible: hovered + } - Item { - Layout.fillWidth: true - Layout.alignment: Qt.AlignTop - Layout.minimumHeight: warningLabel.height - MySettingsLabel { - id: warningLabel - width: parent.width - color: theme.textErrorColor - wrapMode: Text.WordWrap - text: qsTr("WARNING: On macOS with arm (M1+) this setting forces usage of the GPU. 
Can cause crashes if the model requires more RAM than the system supports. Because of crash possibility the setting will not persist across restarts of the application. This has no effect on non-macs or intel.") - } + MySettingsLabel { + id: updatesLabel + text: qsTr("Check for updates") + helpText: qsTr("Click to see if an update to the application is available"); + Layout.row: 14 + Layout.column: 0 + } + + MySettingsButton { + Layout.row: 14 + Layout.column: 2 + Layout.alignment: Qt.AlignRight + text: qsTr("Updates"); + onClicked: { + if (!LLM.checkForUpdates()) + checkForUpdatesError.open() } } + + Rectangle { + Layout.row: 15 + Layout.column: 0 + Layout.columnSpan: 3 + Layout.fillWidth: true + height: 2 + color: theme.settingsDivider + } } } diff --git a/gpt4all-chat/qml/ChatDrawer.qml b/gpt4all-chat/qml/ChatDrawer.qml index 893fbc73..f107ef8f 100644 --- a/gpt4all-chat/qml/ChatDrawer.qml +++ b/gpt4all-chat/qml/ChatDrawer.qml @@ -16,23 +16,33 @@ Rectangle { id: theme } - signal downloadClicked - signal aboutClicked + color: theme.viewBackground - color: theme.containerBackground + Rectangle { + id: borderRight + anchors.top: parent.top + anchors.bottom: parent.bottom + anchors.right: parent.right + width: 2 + color: theme.dividerColor + } Item { - anchors.fill: parent - anchors.margins: 10 + anchors.top: parent.top + anchors.bottom: parent.bottom + anchors.left: parent.left + anchors.right: borderRight.left Accessible.role: Accessible.Pane Accessible.name: qsTr("Drawer") Accessible.description: qsTr("Main navigation drawer") - MyButton { + MySettingsButton { id: newChat + anchors.top: parent.top anchors.left: parent.left anchors.right: parent.right + anchors.margins: 20 font.pixelSize: theme.fontSizeLarger topPadding: 20 bottomPadding: 20 @@ -45,20 +55,31 @@ Rectangle { } } + Rectangle { + id: divider + anchors.top: newChat.bottom + anchors.margins: 20 + anchors.topMargin: 15 + anchors.left: parent.left + anchors.right: parent.right + height: 1 + color: 
theme.dividerColor + } + ScrollView { anchors.left: parent.left anchors.right: parent.right - anchors.rightMargin: -10 - anchors.topMargin: 10 - anchors.top: newChat.bottom - anchors.bottom: checkForUpdatesButton.top - anchors.bottomMargin: 10 + anchors.topMargin: 15 + anchors.top: divider.bottom + anchors.bottom: parent.bottom + anchors.bottomMargin: 15 ScrollBar.vertical.policy: ScrollBar.AlwaysOff clip: true ListView { id: conversationList anchors.fill: parent + anchors.leftMargin: 10 anchors.rightMargin: 10 model: ChatListModel @@ -71,6 +92,33 @@ Rectangle { anchors.bottom: conversationList.bottom } + Component { + id: sectionHeading + Rectangle { + width: ListView.view.width + height: childrenRect.height + color: "transparent" + property bool isServer: ChatListModel.get(parent.index) && ChatListModel.get(parent.index).isServer + visible: !isServer || MySettings.serverChat + + required property string section + + Text { + leftPadding: 10 + rightPadding: 10 + topPadding: 15 + bottomPadding: 5 + text: parent.section + color: theme.styledTextColor + font.pixelSize: theme.fontSizeLarge + } + } + } + + section.property: "section" + section.criteria: ViewSection.FullString + section.delegate: sectionHeading + delegate: Rectangle { id: chatRectangle width: conversationList.width @@ -80,21 +128,25 @@ Rectangle { property bool trashQuestionDisplayed: false visible: !isServer || MySettings.serverChat z: isCurrent ? 199 : 1 - color: index % 2 === 0 ? theme.darkContrast : theme.lightContrast + color: isCurrent ? theme.selectedBackground : "transparent" border.width: isCurrent - border.color: chatName.readOnly ? 
theme.assistantColor : theme.userColor + border.color: theme.dividerColor + radius: 10 + TextField { id: chatName anchors.left: parent.left anchors.right: buttons.left - color: theme.textColor - padding: 15 + color: theme.styledTextColor + topPadding: 15 + bottomPadding: 15 focus: false readOnly: true wrapMode: Text.NoWrap hoverEnabled: false // Disable hover events on the TextArea selectByMouse: false // Disable text selection in the TextArea font.pixelSize: theme.fontSizeLarge + font.bold: true text: readOnly ? metrics.elidedText : name horizontalAlignment: TextInput.AlignLeft opacity: trashQuestionDisplayed ? 0.5 : 1.0 @@ -103,7 +155,7 @@ Rectangle { font: chatName.font text: name elide: Text.ElideRight - elideWidth: chatName.width - 40 + elideWidth: chatName.width - 15 } background: Rectangle { color: "transparent" @@ -240,45 +292,5 @@ Rectangle { Accessible.description: qsTr("List of chats in the drawer dialog") } } - - MyButton { - id: checkForUpdatesButton - anchors.left: parent.left - anchors.right: parent.right - anchors.bottom: downloadButton.top - anchors.bottomMargin: 10 - text: qsTr("Updates") - font.pixelSize: theme.fontSizeLarge - Accessible.description: qsTr("Launch an external application that will check for updates to the installer") - onClicked: { - if (!LLM.checkForUpdates()) - checkForUpdatesError.open() - } - } - - MyButton { - id: downloadButton - anchors.left: parent.left - anchors.right: parent.right - anchors.bottom: aboutButton.top - anchors.bottomMargin: 10 - text: qsTr("Downloads") - Accessible.description: qsTr("Launch a dialog to download new models") - onClicked: { - downloadClicked() - } - } - - MyButton { - id: aboutButton - anchors.left: parent.left - anchors.right: parent.right - anchors.bottom: parent.bottom - text: qsTr("About") - Accessible.description: qsTr("Launch a dialog to show the about page") - onClicked: { - aboutClicked() - } - } } } diff --git a/gpt4all-chat/qml/ChatView.qml b/gpt4all-chat/qml/ChatView.qml index 
d72e651e..2d614042 100644 --- a/gpt4all-chat/qml/ChatView.qml +++ b/gpt4all-chat/qml/ChatView.qml @@ -23,38 +23,10 @@ Rectangle { property var currentChat: ChatListModel.currentChat property var chatModel: currentChat.chatModel + signal addCollectionViewRequested() + signal addModelViewRequested() - color: theme.black - - // Startup code - Component.onCompleted: { - startupDialogs(); - } - - Component.onDestruction: { - Network.trackEvent("session_end") - } - - Connections { - target: firstStartDialog - function onClosed() { - startupDialogs(); - } - } - - Connections { - target: downloadNewModels - function onClosed() { - startupDialogs(); - } - } - - Connections { - target: Download - function onHasNewerReleaseChanged() { - startupDialogs(); - } - } + color: theme.viewBackground Connections { target: currentChat @@ -71,126 +43,26 @@ Rectangle { } } - property bool hasShownModelDownload: false - property bool hasCheckedFirstStart: false - property bool hasShownSettingsAccess: false - - function startupDialogs() { - if (!LLM.compatHardware()) { - Network.trackEvent("noncompat_hardware") - errorCompatHardware.open(); - return; - } - - // check if we have access to settings and if not show an error - if (!hasShownSettingsAccess && !LLM.hasSettingsAccess()) { - errorSettingsAccess.open(); - hasShownSettingsAccess = true; - return; - } - - // check for first time start of this version - if (!hasCheckedFirstStart) { - if (Download.isFirstStart(/*writeVersion*/ true)) { - firstStartDialog.open(); - return; - } - - // send startup or opt-out now that the user has made their choice - Network.sendStartup() - // start localdocs - LocalDocs.requestStart() - - hasCheckedFirstStart = true - } - - // check for any current models and if not, open download dialog once - if (!hasShownModelDownload && ModelList.installedModels.count === 0 && !firstStartDialog.opened) { - downloadNewModels.open(); - hasShownModelDownload = true; - return; - } - - // check for new version - if 
(Download.hasNewerRelease && !firstStartDialog.opened && !downloadNewModels.opened) { - newVersionDialog.open(); - return; - } - } - function currentModelName() { return ModelList.modelInfo(currentChat.modelInfo.id).name; } - PopupDialog { - id: errorCompatHardware - anchors.centerIn: parent - shouldTimeOut: false - shouldShowBusy: false - closePolicy: Popup.NoAutoClose - modal: true - text: qsTr("

          Encountered an error starting up:


          ") - + qsTr("\"Incompatible hardware detected.\"") - + qsTr("

          Unfortunately, your CPU does not meet the minimal requirements to run ") - + qsTr("this program. In particular, it does not support AVX intrinsics which this ") - + qsTr("program requires to successfully run a modern large language model. ") - + qsTr("The only solution at this time is to upgrade your hardware to a more modern CPU.") - + qsTr("

          See here for more information: ") - + qsTr("https://en.wikipedia.org/wiki/Advanced_Vector_Extensions") - } - - PopupDialog { - id: errorSettingsAccess - anchors.centerIn: parent - shouldTimeOut: false - shouldShowBusy: false - modal: true - text: qsTr("

          Encountered an error starting up:


          ") - + qsTr("\"Inability to access settings file.\"") - + qsTr("

          Unfortunately, something is preventing the program from accessing ") - + qsTr("the settings file. This could be caused by incorrect permissions in the local ") - + qsTr("app config directory where the settings file is located. ") - + qsTr("Check out our discord channel for help.") - } - - StartupDialog { - id: firstStartDialog - anchors.centerIn: parent - } - - NewVersionDialog { - id: newVersionDialog - anchors.centerIn: parent - } - - AboutDialog { - id: aboutDialog - anchors.centerIn: parent - width: Math.min(1024, window.width - (window.width * .2)) - height: Math.min(600, window.height - (window.height * .2)) - } - - Item { - Accessible.role: Accessible.Window - Accessible.name: title - } - PopupDialog { id: modelLoadingErrorPopup anchors.centerIn: parent shouldTimeOut: false text: qsTr("

          Encountered an error loading model:


          ") - + "\"" + currentChat.modelLoadingError + "\"" - + qsTr("

          Model loading failures can happen for a variety of reasons, but the most common " - + "causes include a bad file format, an incomplete or corrupted download, the wrong file " - + "type, not enough system RAM or an incompatible model type. Here are some suggestions for resolving the problem:" - + "