diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..dd84ea78 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,38 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behavior: +1. Go to '...' +2. Click on '....' +3. Scroll down to '....' +4. See error + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +**Desktop (please complete the following information):** + - OS: [e.g. iOS] + - Browser [e.g. chrome, safari] + - Version [e.g. 22] + +**Smartphone (please complete the following information):** + - Device: [e.g. iPhone6] + - OS: [e.g. iOS8.1] + - Browser [e.g. stock browser, safari] + - Version [e.g. 22] + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..bbcbbe7d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. 
diff --git a/.gitignore b/.gitignore index 1a84c28c..49708f47 100644 --- a/.gitignore +++ b/.gitignore @@ -171,4 +171,9 @@ cython_debug/ .vscode *.bin -.DS_Store \ No newline at end of file +.DS_Store + +# gpt4all-chat +CMakeLists.txt.user +gpt4all-chat/meta/* +gpt4all-chat/models/* diff --git a/.gitmodules b/.gitmodules index 371af62e..eb06ee48 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ -[submodule "peft"] - path = peft - url = https://github.com/huggingface/peft.git +[submodule "llama.cpp"] + path = gpt4all-chat/llmodel/llama.cpp + url = https://github.com/manyoso/llama.cpp.git diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt new file mode 100644 index 00000000..fbe018b9 --- /dev/null +++ b/gpt4all-chat/CMakeLists.txt @@ -0,0 +1,225 @@ +cmake_minimum_required(VERSION 3.16) + +if(APPLE) + option(BUILD_UNIVERSAL "Build a Universal binary on macOS" OFF) + if(BUILD_UNIVERSAL) + # Build a Universal binary on macOS + # This requires that the found Qt library is compiled as Universal binaries. 
+ set(CMAKE_OSX_ARCHITECTURES "arm64;x86_64" CACHE STRING "" FORCE) + else() + # Build for the host architecture on macOS + set(CMAKE_OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}" CACHE STRING "" FORCE) + endif() +endif() + +set(APP_VERSION_MAJOR 2) +set(APP_VERSION_MINOR 4) +set(APP_VERSION_PATCH 2) +set(APP_VERSION "${APP_VERSION_MAJOR}.${APP_VERSION_MINOR}.${APP_VERSION_PATCH}") + +# Include the binary directory for the generated header file +include_directories("${CMAKE_CURRENT_BINARY_DIR}") + +project(gpt4all VERSION ${APP_VERSION} LANGUAGES CXX C) + +set(CMAKE_AUTOMOC ON) +set(CMAKE_AUTORCC ON) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +option(GPT4ALL_LOCALHOST OFF "Build installer for localhost repo") +option(GPT4ALL_AVX_ONLY OFF "Build for avx only") +option(GPT4ALL_OFFLINE_INSTALLER "Build an offline installer" OFF) + +# Generate a header file with the version number +configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/cmake/config.h.in" + "${CMAKE_CURRENT_BINARY_DIR}/config.h" +) + +find_package(Qt6 6.2 COMPONENTS Core Quick QuickDialogs2 Svg REQUIRED) + +# Get the Qt6Core target properties +get_target_property(Qt6Core_INCLUDE_DIRS Qt6::Core INTERFACE_INCLUDE_DIRECTORIES) +get_target_property(Qt6Core_LIBRARY_RELEASE Qt6::Core LOCATION_RELEASE) + +# Find the qmake binary +find_program(QMAKE_EXECUTABLE NAMES qmake qmake6 PATHS ${Qt6Core_INCLUDE_DIRS}/../.. NO_DEFAULT_PATH) + +# Get the Qt 6 root directory +get_filename_component(Qt6_ROOT_DIR "${Qt6Core_LIBRARY_RELEASE}" DIRECTORY) +get_filename_component(Qt6_ROOT_DIR "${Qt6_ROOT_DIR}/.." 
ABSOLUTE) + +message(STATUS "qmake binary: ${QMAKE_EXECUTABLE}") +message(STATUS "Qt 6 root directory: ${Qt6_ROOT_DIR}") + +add_subdirectory(llmodel) + +set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +qt_add_executable(chat + main.cpp + chat.h chat.cpp + chatllm.h chatllm.cpp + chatmodel.h chatlistmodel.h chatlistmodel.cpp + download.h download.cpp + network.h network.cpp + llm.h llm.cpp + sysinfo.h +) + +qt_add_qml_module(chat + URI gpt4all + VERSION 1.0 + QML_FILES + main.qml + qml/ChatDrawer.qml + qml/ModelDownloaderDialog.qml + qml/NetworkDialog.qml + qml/NewVersionDialog.qml + qml/ThumbsDownDialog.qml + qml/SettingsDialog.qml + qml/StartupDialog.qml + qml/PopupDialog.qml + qml/AboutDialog.qml + qml/Theme.qml + RESOURCES + icons/send_message.svg + icons/stop_generating.svg + icons/regenerate.svg + icons/copy.svg + icons/settings.svg + icons/edit.svg + icons/trash.svg + icons/network.svg + icons/thumbs_up.svg + icons/thumbs_down.svg + icons/logo.svg + icons/logo-32.png + icons/logo-48.png + icons/favicon.ico + icons/favicon.icns +) + +set_target_properties(chat PROPERTIES + MACOSX_BUNDLE_GUI_IDENTIFIER gpt4all + MACOSX_BUNDLE_BUNDLE_VERSION ${PROJECT_VERSION} + MACOSX_BUNDLE_SHORT_VERSION_STRING ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR} + MACOSX_BUNDLE TRUE + WIN32_EXECUTABLE TRUE + MACOSX_BUNDLE_ICON_FILE "favicon.icns" +) + +if(${CMAKE_SYSTEM_NAME} MATCHES Darwin) + set_target_properties(chat PROPERTIES + OUTPUT_NAME gpt4all + ) +endif() + +target_compile_definitions(chat + PRIVATE $<$,$>:QT_QML_DEBUG>) +target_link_libraries(chat + PRIVATE Qt6::Quick Qt6::Svg) +target_link_libraries(chat + PRIVATE llmodel) + +set(COMPONENT_NAME_MAIN ${PROJECT_NAME}) +set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/install) + +if(NOT (CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "arm64")) + add_executable(test_hw test_hw.cpp) + install(TARGETS test_hw DESTINATION bin COMPONENT ${COMPONENT_NAME_MAIN}) +endif() + 
+install(TARGETS chat DESTINATION bin COMPONENT ${COMPONENT_NAME_MAIN}) +install(TARGETS llmodel DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN}) +install(TARGETS llama DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN}) + +set(CPACK_GENERATOR "IFW") +set(CPACK_VERBATIM_VARIABLES YES) +set(CPACK_IFW_VERBOSE ON) + +if(${CMAKE_SYSTEM_NAME} MATCHES Linux) + set(LINUXDEPLOYQT "$ENV{HOME}/dev/linuxdeployqt/build/tools/linuxdeployqt/linuxdeployqt") + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/cmake/deploy-qt-linux.cmake.in" + "${CMAKE_BINARY_DIR}/cmake/deploy-qt-linux.cmake" @ONLY) + set(CPACK_PRE_BUILD_SCRIPTS ${CMAKE_BINARY_DIR}/cmake/deploy-qt-linux.cmake) + set(CPACK_IFW_ROOT "~/Qt/Tools/QtInstallerFramework/4.5") + set(CPACK_PACKAGE_FILE_NAME "${COMPONENT_NAME_MAIN}-installer-linux") + set(CPACK_IFW_TARGET_DIRECTORY "@HomeDir@/${COMPONENT_NAME_MAIN}") +elseif(${CMAKE_SYSTEM_NAME} MATCHES Windows) + find_program(WINDEPLOYQT windeployqt HINTS ${_qt_bin_dir}) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/cmake/deploy-qt-windows.cmake.in" + "${CMAKE_BINARY_DIR}/cmake/deploy-qt-windows.cmake" @ONLY) + set(CPACK_PRE_BUILD_SCRIPTS ${CMAKE_BINARY_DIR}/cmake/deploy-qt-windows.cmake) + set(CPACK_IFW_ROOT "C:/Qt/Tools/QtInstallerFramework/4.5") + set(CPACK_IFW_PACKAGE_ICON "${CMAKE_CURRENT_SOURCE_DIR}/icons/favicon.ico") + set(CPACK_PACKAGE_FILE_NAME "${COMPONENT_NAME_MAIN}-installer-win64") + set(CPACK_IFW_TARGET_DIRECTORY "@HomeDir@\\${COMPONENT_NAME_MAIN}") +elseif(${CMAKE_SYSTEM_NAME} MATCHES Darwin) + find_program(MACDEPLOYQT macdeployqt HINTS ${_qt_bin_dir}) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/cmake/deploy-qt-mac.cmake.in" + "${CMAKE_BINARY_DIR}/cmake/deploy-qt-mac.cmake" @ONLY) + set(CPACK_PRE_BUILD_SCRIPTS ${CMAKE_BINARY_DIR}/cmake/deploy-qt-mac.cmake) + set(CPACK_IFW_ROOT "~/Qt/Tools/QtInstallerFramework/4.5") + set(CPACK_IFW_PACKAGE_ICON "${CMAKE_CURRENT_SOURCE_DIR}/icons/favicon.icns") + set(CPACK_PACKAGE_FILE_NAME 
"${COMPONENT_NAME_MAIN}-installer-darwin") + set(CPACK_IFW_TARGET_DIRECTORY "@ApplicationsDir@/${COMPONENT_NAME_MAIN}") + set(CPACK_BUNDLE_NAME ${COMPONENT_NAME_MAIN}) + set(CPACK_BUNDLE_ICON "${CMAKE_CURRENT_SOURCE_DIR}/icons/favicon.icns") +endif() + +set(CPACK_PACKAGE_INSTALL_DIRECTORY ${COMPONENT_NAME_MAIN}) +set(CPACK_PACKAGE_VERSION_MAJOR ${PROJECT_VERSION_MAJOR}) +set(CPACK_PACKAGE_VERSION_MINOR ${PROJECT_VERSION_MINOR}) +SET(CPACK_PACKAGE_VERSION_PATCH ${PROJECT_VERSION_PATCH}) +set(CPACK_PACKAGE_HOMEPAGE_URL "https://gpt4all.io") +set(CPACK_PACKAGE_ICON "${CMAKE_CURRENT_SOURCE_DIR}/icons/logo-48.png") +set(CPACK_RESOURCE_FILE_LICENSE ${CMAKE_CURRENT_SOURCE_DIR}/LICENSE) +set(CPACK_RESOURCE_FILE_README ${CMAKE_CURRENT_SOURCE_DIR}/README.md) +set(CPACK_PACKAGE_EXECUTABLES "GPT4All") +set(CPACK_CREATE_DESKTOP_LINKS "GPT4All") +set(CPACK_IFW_PACKAGE_NAME "GPT4All") +set(CPACK_IFW_PACKAGE_TITLE "GPT4All Installer") +set(CPACK_IFW_PACKAGE_PUBLISHER "Nomic, Inc.") +set(CPACK_IFW_PRODUCT_URL "https://gpt4all.io") +set(CPACK_IFW_PACKAGE_WIZARD_STYLE "Aero") +set(CPACK_IFW_PACKAGE_LOGO "${CMAKE_CURRENT_SOURCE_DIR}/icons/logo-48.png") +set(CPACK_IFW_PACKAGE_WINDOW_ICON "${CMAKE_CURRENT_SOURCE_DIR}/icons/logo-32.png") +set(CPACK_IFW_PACKAGE_WIZARD_SHOW_PAGE_LIST OFF) + +include(InstallRequiredSystemLibraries) +include(CPack) +include(CPackIFW) +cpack_add_component(${COMPONENT_NAME_MAIN} DOWNLOADED) +cpack_ifw_configure_component(${COMPONENT_NAME_MAIN} ESSENTIAL FORCED_INSTALLATION) +cpack_ifw_configure_component(${COMPONENT_NAME_MAIN} VERSION ${APP_VERSION}) +cpack_ifw_configure_component(${COMPONENT_NAME_MAIN} LICENSES "MIT LICENSE" ${CPACK_RESOURCE_FILE_LICENSE}) +cpack_ifw_configure_component(${COMPONENT_NAME_MAIN} SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/cmake/installerscript.qs") +cpack_ifw_configure_component(${COMPONENT_NAME_MAIN} REPLACES "gpt4all-chat") #Was used in very earliest prototypes + +if (GPT4ALL_LOCALHOST) + cpack_ifw_add_repository("GPT4AllRepository" 
URL "http://localhost/repository") +elseif(GPT4ALL_OFFLINE_INSTALLER) + cpack_ifw_add_repository("GPT4AllRepository" URL "file://${CMAKE_BINARY_DIR}/packages") +else() + if(${CMAKE_SYSTEM_NAME} MATCHES Linux) + if (GPT4ALL_AVX_ONLY) + cpack_ifw_add_repository("GPT4AllRepository" URL "https://gpt4all.io/installer_repos/avx_only/linux/repository") + else() + cpack_ifw_add_repository("GPT4AllRepository" URL "https://gpt4all.io/installer_repos/linux/repository") + endif() + elseif(${CMAKE_SYSTEM_NAME} MATCHES Windows) + #To sign the target on windows have to create a batch script add use it as a custom target and then use CPACK_IFW_EXTRA_TARGETS to set this extra target + if (GPT4ALL_AVX_ONLY) + cpack_ifw_add_repository("GPT4AllRepository" URL "https://gpt4all.io/installer_repos/avx_only/windows/repository") + else() + cpack_ifw_add_repository("GPT4AllRepository" URL "https://gpt4all.io/installer_repos/windows/repository") + endif() + elseif(${CMAKE_SYSTEM_NAME} MATCHES Darwin) + if (GPT4ALL_AVX_ONLY) + cpack_ifw_add_repository("GPT4AllRepository" URL "https://gpt4all.io/installer_repos/avx_only/mac/repository") + else() + cpack_ifw_add_repository("GPT4AllRepository" URL "https://gpt4all.io/installer_repos/mac/repository") + endif() + endif() +endif() diff --git a/gpt4all-chat/LICENSE b/gpt4all-chat/LICENSE new file mode 100644 index 00000000..09ca4546 --- /dev/null +++ b/gpt4all-chat/LICENSE @@ -0,0 +1,15 @@ +Copyright 2023 Nomic, Inc., Aaron Miller + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or 
substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +ADDENDUM: + +Any LLM models that are loaded and used by the application are not themselves +subject to this license if indeed they are even copyrightable. The terms of +this license apply only to the application software and its accompanying +documentation and do not extend to any LLM models, whether created by the +author of the application or obtained from third-party sources. diff --git a/gpt4all-chat/README.md b/gpt4all-chat/README.md index d78f4611..4bbe6016 100644 --- a/gpt4all-chat/README.md +++ b/gpt4all-chat/README.md @@ -1,2 +1,78 @@ -# GPT4All Chat -This directory will contain the code to build out the QT chat GUI. +# gpt4all-chat + +Cross platform Qt based GUI for GPT4All versions with GPT-J as the base +model. NOTE: The model seen in the screenshot is actually a preview of a +new training run for GPT4All based on GPT-J. The GPT4All project is busy +at work getting ready to release this model including installers for all +three major OS's. In the meantime, you can try this UI out with the original +GPT-J model by following build instructions below. 
+ +![image](https://user-images.githubusercontent.com/50458173/231464085-da9edff6-a593-410e-8f38-7513f75c8aab.png) + +## Install + +One click installers for macOS, Linux, and Windows at https://gpt4all.io + +## Features + +* Cross-platform (Linux, Windows, MacOSX) +* Fast CPU based inference using ggml for GPT-J based models +* The UI is made to look and feel like you've come to expect from a chatty gpt +* Check for updates so you can always stay fresh with latest models +* Easy to install with precompiled binaries available for all three major desktop platforms +* Multi-modal - Ability to load more than one model and switch between them +* Supports both llama.cpp and gptj.cpp style models +* Model downloader in GUI featuring many popular open source models +* Settings dialog to change temp, top_p, top_k, threads, etc +* Copy your conversation to clipboard +* Check for updates to get the very latest GUI + +## Feature wishlist + +* Multi-chat - a list of current and past chats and the ability to save/delete/export and switch between +* Text to speech - have the AI respond with voice +* Speech to text - give the prompt with your voice +* Python bindings +* Typescript bindings +* Plugin support for langchain and other developer tools +* Save your prompt/responses to disk +* Upload prompt/responses manually/automatically to nomic.ai to aid future training runs +* Syntax highlighting support for programming languages, etc. +* REST API with a built-in webserver in the chat gui itself with a headless operation mode as well +* Advanced settings for changing temperature, topk, etc. (DONE) +* YOUR IDEA HERE + +## Building and running + +* Follow the visual instructions on the [build_and_run](build_and_run.md) page + +## Getting the latest + +If you've already checked out the source code and/or built the program make sure when you do a git fetch to get the latest changes and that you also do ```git submodule update --init --recursive``` to update the submodules. 
+ +## Manual download of models +* https://gpt4all.io/models/ggml-mpt-7b-chat.bin (default) (md5sum 756249d3d6abe23bde3b1ae272628640) Current best non-commercially licensable chat model based on MPT and trained by Mosaic ML. +* https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin (default) (md5sum 81a09a0ddf89690372fc296ff7f625af) Current best commercially licensable model based on GPT-J and trained by Nomic AI on the latest curated GPT4All dataset. +* https://gpt4all.io/models/ggml-gpt4all-l13b-snoozy.bin (md5sum 91f886b68fbce697e9a3cd501951e455) Current best non-commercially licensable model based on Llama 13b and trained by Nomic AI on the latest curated GPT4All dataset. +* https://gpt4all.io/models/ggml-gpt4all-j-v1.2-jazzy.bin (md5sum 879344aaa9d62fdccbda0be7a09e7976) A commercially licensable model based on GPT-J and trained by Nomic AI on the v2 GPT4All dataset. +* https://gpt4all.io/models/ggml-gpt4all-j-v1.1-breezy.bin (md5sum 61d48a82cb188cceb14ebb8082bfec37) A commercially licensable model based on GPT-J and trained by Nomic AI on the v1 GPT4All dataset. +* https://gpt4all.io/models/ggml-gpt4all-j.bin (md5sum 5b5a3f9b858d33b29b52b89692415595) A commercially licensable model based on GPT-J and trained by Nomic AI on the v0 GPT4All dataset. +* https://gpt4all.io/models/ggml-vicuna-7b-1.1-q4_2.bin (md5sum 29119f8fa11712704c6b22ac5ab792ea) A non-commercially licensable model based on Llama 7b and trained by teams from UC Berkeley, CMU, Stanford, MBZUAI, and UC San Diego. +* https://gpt4all.io/models/ggml-vicuna-13b-1.1-q4_2.bin (md5sum 95999b7b0699e2070af63bf5d34101a8) A non-commercially licensable model based on Llama 13b and trained by teams from UC Berkeley, CMU, Stanford, MBZUAI, and UC San Diego. +* https://gpt4all.io/models/ggml-wizardLM-7B.q4_2.bin (md5sum 99e6d129745a3f1fb1121abed747b05a) A non-commercially licensable model based on Llama 7b and trained by Microsoft and Peking University. 
+* https://gpt4all.io/models/ggml-stable-vicuna-13B.q4_2.bin (md5sum 6cb4ee297537c9133bddab9692879de0) An non-commercially licensable model based on Llama 13b and RLHF trained by Stable AI. +* https://gpt4all.io/models/ggml-mpt-7b-base.bin (md5sum 120c32a51d020066288df045ef5d52b9) A commercially licensable model base pre-trained by Mosaic ML. + +## Terminal Only Interface with no Qt dependency + +Check out https://github.com/kuvaus/LlamaGPTJ-chat which is using the llmodel backend so it is compliant with our ecosystem and all models downloaded above should work with it. + +## Contributing + +* Pull requests welcome. See the feature wish list for ideas :) + + +## License +The source code of this chat interface is currently under a MIT license. The underlying GPT4All-j model is released under non-restrictive open-source Apache 2 License. + +The GPT4All-J license allows for users to use generated outputs as they see fit. Users take responsibility for ensuring their content meets applicable requirements for publication in a given context or region. diff --git a/gpt4all-chat/build_and_run.md b/gpt4all-chat/build_and_run.md new file mode 100644 index 00000000..e111d0b4 --- /dev/null +++ b/gpt4all-chat/build_and_run.md @@ -0,0 +1,57 @@ +# Install Qt 6.x and setup/build gpt4all-chat from source + +Depending upon your operating system, there are many ways that Qt is distributed. +Here is the recommended method for getting the Qt dependency installed to setup and build +gpt4all-chat from source. 
+ +## Create a [Qt account](https://login.qt.io/register) + +![image](https://github.com/nomic-ai/gpt4all-chat/assets/10168/d1e44cab-4245-4144-a91c-7b02267df2b2) + +## Go to the Qt open source [download page](https://www.qt.io/download-qt-installer-oss) + +![image](https://github.com/nomic-ai/gpt4all-chat/assets/10168/d68f5f45-cca3-4fe9-acf4-cabdcb95f669) + +## Start the installer and sign in + +![image](https://github.com/nomic-ai/gpt4all-chat/assets/10168/899b1422-51ae-4bb5-acc9-b9027a8e9b19) + +## After some screens about license, select custom + +![image](https://github.com/nomic-ai/gpt4all-chat/assets/10168/2290031a-fdb0-4f47-a7f1-d77ad5451068) + +## Select the following + +![image](https://github.com/nomic-ai/gpt4all-chat/assets/10168/c6e999e5-cc8a-4dfc-8065-b59139e8c7ae) + +NOTE: This is for macOS. For Linux it is similar, but you need ming64 for Windows, not the MSVC install + +## Open up QtCreator + +![image](https://github.com/nomic-ai/gpt4all-chat/assets/10168/a34978f4-a220-459c-af66-e901d7ccd7bb) + +## Clone the git repo for gpt4all-chat + +``` +git clone --recurse-submodules https://github.com/nomic-ai/gpt4all-chat +``` + +## Open the gpt4all-chat project in QtCreator + +![image](https://github.com/nomic-ai/gpt4all-chat/assets/10168/3d3e2743-2a1d-43d6-9e55-62f7f4306de7) + +NOTE: File->Open File or Project and navigate to the gpt4all-chat repo and choose the CMakeLists.txt + +## Configure project + +![image](https://github.com/nomic-ai/gpt4all-chat/assets/10168/44d5aafb-a95d-434b-ba2a-a3138c0e49a0) + +## Build project + +![image](https://github.com/nomic-ai/gpt4all-chat/assets/10168/43cd7b42-32f0-4efa-9612-d51f85637103) + +## Run project + +![image](https://github.com/nomic-ai/gpt4all-chat/assets/10168/611ea795-bdcd-4feb-a466-eb1c2e936e7e) + + diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp new file mode 100644 index 00000000..2a6b941f --- /dev/null +++ b/gpt4all-chat/chat.cpp @@ -0,0 +1,274 @@ +#include "chat.h" +#include "llm.h" +#include 
"network.h" +#include "download.h" + +Chat::Chat(QObject *parent) + : QObject(parent) + , m_id(Network::globalInstance()->generateUniqueId()) + , m_name(tr("New Chat")) + , m_chatModel(new ChatModel(this)) + , m_responseInProgress(false) + , m_creationDate(QDateTime::currentSecsSinceEpoch()) + , m_llmodel(new ChatLLM(this)) +{ + // Should be in same thread + connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection); + connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection); + + // Should be in different threads + connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::modelLoadingError, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); + + connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); + connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection); + connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection); + connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection); + connect(this, &Chat::unloadModelRequested, m_llmodel, &ChatLLM::unloadModel, Qt::QueuedConnection); + 
connect(this, &Chat::reloadModelRequested, m_llmodel, &ChatLLM::reloadModel, Qt::QueuedConnection); + connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection); + + // The following are blocking operations and will block the gui thread, therefore must be fast + // to respond to + connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::BlockingQueuedConnection); + connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection); + connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection); +} + +void Chat::reset() +{ + stopGenerating(); + // Erase our current on disk representation as we're completely resetting the chat along with id + LLM::globalInstance()->chatListModel()->removeChatFile(this); + emit resetContextRequested(); // blocking queued connection + m_id = Network::globalInstance()->generateUniqueId(); + emit idChanged(); + // NOTE: We deliberately do no reset the name or creation date to indictate that this was originally + // an older chat that was reset for another purpose. Resetting this data will lead to the chat + // name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat' + // further down in the list. This might surprise the user. In the future, we me might get rid of + // the "reset context" button in the UI. Right now, by changing the model in the combobox dropdown + // we effectively do a reset context. We *have* to do this right now when switching between different + // types of models. The only way to get rid of that would be a very long recalculate where we rebuild + // the context if we switch between different types of models. Probably the right way to fix this + // is to allow switching models but throwing up a dialog warning users if we switch between types + // of models that a long recalculation will ensue. 
+ m_chatModel->clear(); +} + +bool Chat::isModelLoaded() const +{ + return m_llmodel->isModelLoaded(); +} + +void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, + int32_t repeat_penalty_tokens) +{ + emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, + repeat_penalty, repeat_penalty_tokens, LLM::globalInstance()->threadCount()); +} + +void Chat::regenerateResponse() +{ + emit regenerateResponseRequested(); // blocking queued connection +} + +void Chat::stopGenerating() +{ + m_llmodel->stopGenerating(); +} + +QString Chat::response() const +{ + return m_llmodel->response(); +} + +void Chat::handleResponseChanged() +{ + const int index = m_chatModel->count() - 1; + m_chatModel->updateValue(index, response()); + emit responseChanged(); +} + +void Chat::responseStarted() +{ + m_responseInProgress = true; + emit responseInProgressChanged(); +} + +void Chat::responseStopped() +{ + m_responseInProgress = false; + emit responseInProgressChanged(); + if (m_llmodel->generatedName().isEmpty()) + emit generateNameRequested(); + if (chatModel()->count() < 3) + Network::globalInstance()->sendChatStarted(); +} + +QString Chat::modelName() const +{ + return m_llmodel->modelName(); +} + +void Chat::setModelName(const QString &modelName) +{ + // doesn't block but will unload old model and load new one which the gui can see through changes + // to the isModelLoaded property + emit modelNameChangeRequested(modelName); +} + +void Chat::newPromptResponsePair(const QString &prompt) +{ + m_chatModel->appendPrompt(tr("Prompt: "), prompt); + m_chatModel->appendResponse(tr("Response: "), prompt); + emit resetResponseRequested(); // blocking queued connection +} + +bool Chat::isRecalc() const +{ + return m_llmodel->isRecalc(); +} + +void Chat::loadDefaultModel() +{ + emit loadDefaultModelRequested(); +} + +void Chat::loadModel(const QString 
&modelName) +{ + emit loadModelRequested(modelName); +} + +void Chat::unloadModel() +{ + stopGenerating(); + emit unloadModelRequested(); +} + +void Chat::reloadModel() +{ + emit reloadModelRequested(m_savedModelName); +} + +void Chat::generatedNameChanged() +{ + // Only use the first three words maximum and remove newlines and extra spaces + QString gen = m_llmodel->generatedName().simplified(); + QStringList words = gen.split(' ', Qt::SkipEmptyParts); + int wordCount = qMin(3, words.size()); + m_name = words.mid(0, wordCount).join(' '); + emit nameChanged(); +} + +void Chat::handleRecalculating() +{ + Network::globalInstance()->sendRecalculatingContext(m_chatModel->count()); + emit recalcChanged(); +} + +void Chat::handleModelNameChanged() +{ + m_savedModelName = modelName(); + emit modelNameChanged(); +} + +bool Chat::serialize(QDataStream &stream, int version) const +{ + stream << m_creationDate; + stream << m_id; + stream << m_name; + stream << m_userName; + stream << m_savedModelName; + if (!m_llmodel->serialize(stream, version)) + return false; + if (!m_chatModel->serialize(stream, version)) + return false; + return stream.status() == QDataStream::Ok; +} + +bool Chat::deserialize(QDataStream &stream, int version) +{ + stream >> m_creationDate; + stream >> m_id; + emit idChanged(); + stream >> m_name; + stream >> m_userName; + emit nameChanged(); + stream >> m_savedModelName; + + // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so + // unfortunately, we cannot deserialize these + if (version < 2 && m_savedModelName.contains("gpt4all-j")) + return false; + + if (!m_llmodel->deserialize(stream, version)) + return false; + if (!m_chatModel->deserialize(stream, version)) + return false; + emit chatModelChanged(); + return stream.status() == QDataStream::Ok; +} + +QList Chat::modelList() const +{ + // Build a model list from exepath and from the localpath + QList list; + + QString exePath = 
QCoreApplication::applicationDirPath() + QDir::separator(); + QString localPath = Download::globalInstance()->downloadLocalModelsPath(); + + { + QDir dir(exePath); + dir.setNameFilters(QStringList() << "ggml-*.bin"); + QStringList fileNames = dir.entryList(); + for (QString f : fileNames) { + QString filePath = exePath + f; + QFileInfo info(filePath); + QString name = info.completeBaseName().remove(0, 5); + if (info.exists()) { + if (name == modelName()) + list.prepend(name); + else + list.append(name); + } + } + } + + if (localPath != exePath) { + QDir dir(localPath); + dir.setNameFilters(QStringList() << "ggml-*.bin"); + QStringList fileNames = dir.entryList(); + for (QString f : fileNames) { + QString filePath = localPath + f; + QFileInfo info(filePath); + QString name = info.completeBaseName().remove(0, 5); + if (info.exists() && !list.contains(name)) { // don't allow duplicates + if (name == modelName()) + list.prepend(name); + else + list.append(name); + } + } + } + + if (list.isEmpty()) { + if (exePath != localPath) { + qWarning() << "ERROR: Could not find any applicable models in" + << exePath << "nor" << localPath; + } else { + qWarning() << "ERROR: Could not find any applicable models in" + << exePath; + } + return QList(); + } + + return list; +} diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h new file mode 100644 index 00000000..4ec97ee6 --- /dev/null +++ b/gpt4all-chat/chat.h @@ -0,0 +1,106 @@ +#ifndef CHAT_H +#define CHAT_H + +#include +#include +#include + +#include "chatllm.h" +#include "chatmodel.h" + +class Chat : public QObject +{ + Q_OBJECT + Q_PROPERTY(QString id READ id NOTIFY idChanged) + Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged) + Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged) + Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) + Q_PROPERTY(QString response READ response NOTIFY responseChanged) + Q_PROPERTY(QString modelName READ modelName WRITE 
setModelName NOTIFY modelNameChanged) + Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) + Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) + Q_PROPERTY(QList modelList READ modelList NOTIFY modelListChanged) + QML_ELEMENT + QML_UNCREATABLE("Only creatable from c++!") + +public: + explicit Chat(QObject *parent = nullptr); + + QString id() const { return m_id; } + QString name() const { return m_userName.isEmpty() ? m_name : m_userName; } + void setName(const QString &name) + { + m_userName = name; + emit nameChanged(); + } + ChatModel *chatModel() { return m_chatModel; } + + Q_INVOKABLE void reset(); + Q_INVOKABLE bool isModelLoaded() const; + Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); + Q_INVOKABLE void regenerateResponse(); + Q_INVOKABLE void stopGenerating(); + Q_INVOKABLE void newPromptResponsePair(const QString &prompt); + + QString response() const; + bool responseInProgress() const { return m_responseInProgress; } + QString modelName() const; + void setModelName(const QString &modelName); + bool isRecalc() const; + + void loadDefaultModel(); + void loadModel(const QString &modelName); + void unloadModel(); + void reloadModel(); + + qint64 creationDate() const { return m_creationDate; } + bool serialize(QDataStream &stream, int version) const; + bool deserialize(QDataStream &stream, int version); + + QList modelList() const; + +Q_SIGNALS: + void idChanged(); + void nameChanged(); + void chatModelChanged(); + void isModelLoadedChanged(); + void responseChanged(); + void responseInProgressChanged(); + void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, + int32_t n_threads); + void 
regenerateResponseRequested(); + void resetResponseRequested(); + void resetContextRequested(); + void modelNameChangeRequested(const QString &modelName); + void modelNameChanged(); + void recalcChanged(); + void loadDefaultModelRequested(); + void loadModelRequested(const QString &modelName); + void unloadModelRequested(); + void reloadModelRequested(const QString &modelName); + void generateNameRequested(); + void modelListChanged(); + void modelLoadingError(const QString &error); + +private Q_SLOTS: + void handleResponseChanged(); + void responseStarted(); + void responseStopped(); + void generatedNameChanged(); + void handleRecalculating(); + void handleModelNameChanged(); + +private: + QString m_id; + QString m_name; + QString m_userName; + QString m_savedModelName; + ChatModel *m_chatModel; + bool m_responseInProgress; + qint64 m_creationDate; + ChatLLM *m_llmodel; +}; + +#endif // CHAT_H diff --git a/gpt4all-chat/chatlistmodel.cpp b/gpt4all-chat/chatlistmodel.cpp new file mode 100644 index 00000000..3fd2246f --- /dev/null +++ b/gpt4all-chat/chatlistmodel.cpp @@ -0,0 +1,246 @@ +#include "chatlistmodel.h" +#include "download.h" + +#include +#include + +#define CHAT_FORMAT_MAGIC 0xF5D553CC +#define CHAT_FORMAT_VERSION 2 + +ChatListModel::ChatListModel(QObject *parent) + : QAbstractListModel(parent) + , m_newChat(nullptr) + , m_dummyChat(nullptr) + , m_currentChat(nullptr) + , m_shouldSaveChats(false) +{ + addDummyChat(); + + ChatsRestoreThread *thread = new ChatsRestoreThread; + connect(thread, &ChatsRestoreThread::chatRestored, this, &ChatListModel::restoreChat); + connect(thread, &ChatsRestoreThread::finished, this, &ChatListModel::chatsRestoredFinished); + connect(thread, &ChatsRestoreThread::finished, thread, &QObject::deleteLater); + thread->start(); +} + +bool ChatListModel::shouldSaveChats() const +{ + return m_shouldSaveChats; +} + +void ChatListModel::setShouldSaveChats(bool b) +{ + if (m_shouldSaveChats == b) + return; + m_shouldSaveChats = b; + emit 
shouldSaveChatsChanged(); +} + +void ChatListModel::removeChatFile(Chat *chat) const +{ + const QString savePath = Download::globalInstance()->downloadLocalModelsPath(); + QFile file(savePath + "/gpt4all-" + chat->id() + ".chat"); + if (!file.exists()) + return; + bool success = file.remove(); + if (!success) + qWarning() << "ERROR: Couldn't remove chat file:" << file.fileName(); +} + +void ChatListModel::saveChats() const +{ + if (!m_shouldSaveChats) + return; + + QElapsedTimer timer; + timer.start(); + const QString savePath = Download::globalInstance()->downloadLocalModelsPath(); + for (Chat *chat : m_chats) { + QString fileName = "gpt4all-" + chat->id() + ".chat"; + QFile file(savePath + "/" + fileName); + bool success = file.open(QIODevice::WriteOnly); + if (!success) { + qWarning() << "ERROR: Couldn't save chat to file:" << file.fileName(); + continue; + } + QDataStream out(&file); + + out << (quint32)CHAT_FORMAT_MAGIC; + out << (qint32)CHAT_FORMAT_VERSION; + out.setVersion(QDataStream::Qt_6_2); + + qDebug() << "serializing chat" << fileName; + if (!chat->serialize(out, CHAT_FORMAT_VERSION)) { + qWarning() << "ERROR: Couldn't serialize chat to file:" << file.fileName(); + file.remove(); + } + file.close(); + } + qint64 elapsedTime = timer.elapsed(); + qDebug() << "serializing chats took:" << elapsedTime << "ms"; +} + +void ChatsRestoreThread::run() +{ + QElapsedTimer timer; + timer.start(); + struct FileInfo { + bool oldFile; + qint64 creationDate; + QString file; + }; + QList files; + { + // Look for any files in the original spot which was the settings config directory + QSettings settings; + QFileInfo settingsInfo(settings.fileName()); + QString settingsPath = settingsInfo.absolutePath(); + QDir dir(settingsPath); + dir.setNameFilters(QStringList() << "gpt4all-*.chat"); + QStringList fileNames = dir.entryList(); + for (QString f : fileNames) { + QString filePath = settingsPath + "/" + f; + QFile file(filePath); + bool success = 
file.open(QIODevice::ReadOnly); + if (!success) { + qWarning() << "ERROR: Couldn't restore chat from file:" << file.fileName(); + continue; + } + QDataStream in(&file); + FileInfo info; + info.oldFile = true; + info.file = filePath; + in >> info.creationDate; + files.append(info); + file.close(); + } + } + { + const QString savePath = Download::globalInstance()->downloadLocalModelsPath(); + QDir dir(savePath); + dir.setNameFilters(QStringList() << "gpt4all-*.chat"); + QStringList fileNames = dir.entryList(); + for (QString f : fileNames) { + QString filePath = savePath + "/" + f; + QFile file(filePath); + bool success = file.open(QIODevice::ReadOnly); + if (!success) { + qWarning() << "ERROR: Couldn't restore chat from file:" << file.fileName(); + continue; + } + QDataStream in(&file); + // Read and check the header + quint32 magic; + in >> magic; + if (magic != CHAT_FORMAT_MAGIC) { + qWarning() << "ERROR: Chat file has bad magic:" << file.fileName(); + continue; + } + + // Read the version + qint32 version; + in >> version; + if (version < 1) { + qWarning() << "ERROR: Chat file has non supported version:" << file.fileName(); + continue; + } + + if (version <= 1) + in.setVersion(QDataStream::Qt_6_2); + + FileInfo info; + info.oldFile = false; + info.file = filePath; + in >> info.creationDate; + files.append(info); + file.close(); + } + } + std::sort(files.begin(), files.end(), [](const FileInfo &a, const FileInfo &b) { + return a.creationDate > b.creationDate; + }); + + for (FileInfo &f : files) { + QFile file(f.file); + bool success = file.open(QIODevice::ReadOnly); + if (!success) { + qWarning() << "ERROR: Couldn't restore chat from file:" << file.fileName(); + continue; + } + QDataStream in(&file); + + qint32 version = 0; + if (!f.oldFile) { + // Read and check the header + quint32 magic; + in >> magic; + if (magic != CHAT_FORMAT_MAGIC) { + qWarning() << "ERROR: Chat file has bad magic:" << file.fileName(); + continue; + } + + // Read the version + in >> 
version; + if (version < 1) { + qWarning() << "ERROR: Chat file has non supported version:" << file.fileName(); + continue; + } + + if (version <= 1) + in.setVersion(QDataStream::Qt_6_2); + } + + qDebug() << "deserializing chat" << f.file; + + Chat *chat = new Chat; + chat->moveToThread(qApp->thread()); + if (!chat->deserialize(in, version)) { + qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName(); + file.remove(); + } else { + emit chatRestored(chat); + } + if (f.oldFile) + file.remove(); // No longer storing in this directory + file.close(); + } + + qint64 elapsedTime = timer.elapsed(); + qDebug() << "deserializing chats took:" << elapsedTime << "ms"; +} + +void ChatListModel::restoreChat(Chat *chat) +{ + chat->setParent(this); + connect(chat, &Chat::nameChanged, this, &ChatListModel::nameChanged); + connect(chat, &Chat::modelLoadingError, this, &ChatListModel::handleModelLoadingError); + + if (m_dummyChat) { + beginResetModel(); + m_chats = QList({chat}); + setCurrentChat(chat); + delete m_dummyChat; + m_dummyChat = nullptr; + endResetModel(); + } else { + beginInsertRows(QModelIndex(), m_chats.size(), m_chats.size()); + m_chats.append(chat); + endInsertRows(); + } +} + +void ChatListModel::chatsRestoredFinished() +{ + if (m_dummyChat) { + beginResetModel(); + Chat *dummy = m_dummyChat; + m_dummyChat = nullptr; + m_chats.clear(); + addChat(); + delete dummy; + endResetModel(); + } + + if (m_chats.isEmpty()) + addChat(); +} diff --git a/gpt4all-chat/chatlistmodel.h b/gpt4all-chat/chatlistmodel.h new file mode 100644 index 00000000..c695e05d --- /dev/null +++ b/gpt4all-chat/chatlistmodel.h @@ -0,0 +1,233 @@ +#ifndef CHATLISTMODEL_H +#define CHATLISTMODEL_H + +#include +#include "chat.h" + +class ChatsRestoreThread : public QThread +{ + Q_OBJECT +public: + void run() override; + +Q_SIGNALS: + void chatRestored(Chat *chat); +}; + +class ChatListModel : public QAbstractListModel +{ + Q_OBJECT + Q_PROPERTY(int count READ count NOTIFY 
countChanged) + Q_PROPERTY(Chat *currentChat READ currentChat WRITE setCurrentChat NOTIFY currentChatChanged) + Q_PROPERTY(bool shouldSaveChats READ shouldSaveChats WRITE setShouldSaveChats NOTIFY shouldSaveChatsChanged) + +public: + explicit ChatListModel(QObject *parent = nullptr); + + enum Roles { + IdRole = Qt::UserRole + 1, + NameRole + }; + + int rowCount(const QModelIndex &parent = QModelIndex()) const override + { + Q_UNUSED(parent) + return m_chats.size(); + } + + QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override + { + if (!index.isValid() || index.row() < 0 || index.row() >= m_chats.size()) + return QVariant(); + + const Chat *item = m_chats.at(index.row()); + switch (role) { + case IdRole: + return item->id(); + case NameRole: + return item->name(); + } + + return QVariant(); + } + + QHash roleNames() const override + { + QHash roles; + roles[IdRole] = "id"; + roles[NameRole] = "name"; + return roles; + } + + bool shouldSaveChats() const; + void setShouldSaveChats(bool b); + + Q_INVOKABLE void addChat() + { + // Don't add a new chat if we already have one + if (m_newChat || m_dummyChat) + return; + + // Create a new chat pointer and connect it to determine when it is populated + m_newChat = new Chat(this); + connect(m_newChat->chatModel(), &ChatModel::countChanged, + this, &ChatListModel::newChatCountChanged); + connect(m_newChat, &Chat::nameChanged, + this, &ChatListModel::nameChanged); + + beginInsertRows(QModelIndex(), 0, 0); + m_chats.prepend(m_newChat); + endInsertRows(); + emit countChanged(); + setCurrentChat(m_newChat); + } + + Q_INVOKABLE void addDummyChat() + { + // Create a new dummy chat pointer and don't connect it + m_dummyChat = new Chat(this); + beginInsertRows(QModelIndex(), 0, 0); + m_chats.prepend(m_dummyChat); + endInsertRows(); + emit countChanged(); + m_currentChat = m_dummyChat; + emit currentChatChanged(); + } + + void setNewChat(Chat* chat) + { + // Don't add a new chat if we already have one + if 
(m_newChat) + return; + + m_newChat = chat; + connect(m_newChat->chatModel(), &ChatModel::countChanged, + this, &ChatListModel::newChatCountChanged); + connect(m_newChat, &Chat::nameChanged, + this, &ChatListModel::nameChanged); + connect(m_newChat, &Chat::modelLoadingError, + this, &ChatListModel::handleModelLoadingError); + setCurrentChat(m_newChat); + } + + Q_INVOKABLE void removeChat(Chat* chat) + { + if (!m_chats.contains(chat)) { + qWarning() << "WARNING: Removing chat failed with id" << chat->id(); + return; + } + + removeChatFile(chat); + + if (chat == m_newChat) { + m_newChat->disconnect(this); + m_newChat = nullptr; + } + + const int index = m_chats.indexOf(chat); + if (m_chats.count() < 2) { + addChat(); + } else { + int nextIndex; + if (index == m_chats.count() - 1) + nextIndex = index - 1; + else + nextIndex = index + 1; + Chat *nextChat = get(nextIndex); + Q_ASSERT(nextChat); + setCurrentChat(nextChat); + } + + const int newIndex = m_chats.indexOf(chat); + beginRemoveRows(QModelIndex(), newIndex, newIndex); + m_chats.removeAll(chat); + endRemoveRows(); + delete chat; + } + + Chat *currentChat() const + { + return m_currentChat; + } + + void setCurrentChat(Chat *chat) + { + if (!m_chats.contains(chat)) { + qWarning() << "ERROR: Setting current chat failed with id" << chat->id(); + return; + } + + if (m_currentChat && m_currentChat->isModelLoaded()) + m_currentChat->unloadModel(); + + m_currentChat = chat; + if (!m_currentChat->isModelLoaded()) + m_currentChat->reloadModel(); + emit currentChatChanged(); + } + + Q_INVOKABLE Chat* get(int index) + { + if (index < 0 || index >= m_chats.size()) return nullptr; + return m_chats.at(index); + } + + int count() const { return m_chats.size(); } + + void removeChatFile(Chat *chat) const; + void saveChats() const; + void restoreChat(Chat *chat); + void chatsRestoredFinished(); + +Q_SIGNALS: + void countChanged(); + void currentChatChanged(); + void shouldSaveChatsChanged(); + +private Q_SLOTS: + void 
newChatCountChanged() + { + Q_ASSERT(m_newChat && m_newChat->chatModel()->count()); + m_newChat->chatModel()->disconnect(this); + m_newChat = nullptr; + } + + void nameChanged() + { + Chat *chat = qobject_cast(sender()); + if (!chat) + return; + + int row = m_chats.indexOf(chat); + if (row < 0 || row >= m_chats.size()) + return; + + QModelIndex index = createIndex(row, 0); + emit dataChanged(index, index, {NameRole}); + } + + void handleModelLoadingError(const QString &error) + { + Chat *chat = qobject_cast(sender()); + qWarning() << "ERROR:" << qPrintable(error) << "id" << chat->id(); + removeChat(chat); + } + + void printChats() + { + for (auto c : m_chats) { + qDebug() << c->name() + << (c == m_currentChat ? "currentChat: true" : "currentChat: false") + << (c == m_newChat ? "newChat: true" : "newChat: false"); + } + } + +private: + bool m_shouldSaveChats; + Chat* m_newChat; + Chat* m_dummyChat; + Chat* m_currentChat; + QList m_chats; +}; + +#endif // CHATITEMMODEL_H diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp new file mode 100644 index 00000000..2ffbc3c7 --- /dev/null +++ b/gpt4all-chat/chatllm.cpp @@ -0,0 +1,483 @@ +#include "chatllm.h" +#include "chat.h" +#include "download.h" +#include "network.h" +#include "llmodel/gptj.h" +#include "llmodel/llamamodel.h" +#include "llmodel/mpt.h" + +#include +#include +#include +#include +#include +#include +#include + +//#define DEBUG + +#define MPT_INTERNAL_STATE_VERSION 0 +#define GPTJ_INTERNAL_STATE_VERSION 0 +#define LLAMA_INTERNAL_STATE_VERSION 0 + +static QString modelFilePath(const QString &modelName) +{ + QString appPath = QCoreApplication::applicationDirPath() + + "/ggml-" + modelName + ".bin"; + QFileInfo infoAppPath(appPath); + if (infoAppPath.exists()) + return appPath; + + QString downloadPath = Download::globalInstance()->downloadLocalModelsPath() + + "/ggml-" + modelName + ".bin"; + + QFileInfo infoLocalPath(downloadPath); + if (infoLocalPath.exists()) + return downloadPath; + return 
QString(); +} + +ChatLLM::ChatLLM(Chat *parent) + : QObject{nullptr} + , m_llmodel(nullptr) + , m_promptResponseTokens(0) + , m_responseLogits(0) + , m_isRecalc(false) + , m_chat(parent) +{ + moveToThread(&m_llmThread); + connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup); + connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded); + connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); + m_llmThread.setObjectName(m_chat->id()); + m_llmThread.start(); +} + +bool ChatLLM::loadDefaultModel() +{ + const QList models = m_chat->modelList(); + if (models.isEmpty()) { + // try again when we get a list of models + connect(Download::globalInstance(), &Download::modelListChanged, this, + &ChatLLM::loadDefaultModel, Qt::SingleShotConnection); + return false; + } + + QSettings settings; + settings.sync(); + // The user default model can be set by the user in the settings dialog. The "default" user + // default model is "Application default" which signals we should use the default model that was + // specified by the models.json file. 
+ QString defaultModel = settings.value("userDefaultModel").toString(); + if (defaultModel.isEmpty() || !models.contains(defaultModel) || defaultModel == "Application default") + defaultModel = settings.value("defaultModel").toString(); + if (defaultModel.isEmpty() || !models.contains(defaultModel)) + defaultModel = models.first(); + return loadModel(defaultModel); +} + +bool ChatLLM::loadModel(const QString &modelName) +{ + if (isModelLoaded() && m_modelName == modelName) + return true; + + if (isModelLoaded()) { + resetContextPrivate(); + delete m_llmodel; + m_llmodel = nullptr; + emit isModelLoadedChanged(); + } + + bool isGPTJ = false; + bool isMPT = false; + QString filePath = modelFilePath(modelName); + QFileInfo info(filePath); + if (info.exists()) { + + auto fin = std::ifstream(filePath.toStdString(), std::ios::binary); + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + fin.seekg(0); + fin.close(); + isGPTJ = magic == 0x67676d6c; + isMPT = magic == 0x67676d6d; + if (isGPTJ) { + m_modelType = ModelType::GPTJ_; + m_llmodel = new GPTJ; + m_llmodel->loadModel(filePath.toStdString()); + } else if (isMPT) { + m_modelType = ModelType::MPT_; + m_llmodel = new MPT; + m_llmodel->loadModel(filePath.toStdString()); + } else { + m_modelType = ModelType::LLAMA_; + m_llmodel = new LLamaModel; + m_llmodel->loadModel(filePath.toStdString()); + } + + restoreState(); + +#if defined(DEBUG) + qDebug() << "chatllm modelLoadedChanged" << m_chat->id(); + fflush(stdout); +#endif + + emit isModelLoadedChanged(); + + static bool isFirstLoad = true; + if (isFirstLoad) { + emit sendStartup(); + isFirstLoad = false; + } else + emit sendModelLoaded(); + } else { + const QString error = QString("Could not find model %1").arg(modelName); + emit modelLoadingError(error); + } + + if (m_llmodel) + setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- prefix + + return m_llmodel; +} + +bool ChatLLM::isModelLoaded() const +{ + return m_llmodel && 
m_llmodel->isModelLoaded(); +} + +void ChatLLM::regenerateResponse() +{ + m_ctx.n_past -= m_promptResponseTokens; + m_ctx.n_past = std::max(0, m_ctx.n_past); + // FIXME: This does not seem to be needed in my testing and llama models don't to it. Remove? + m_ctx.logits.erase(m_ctx.logits.end() -= m_responseLogits, m_ctx.logits.end()); + m_ctx.tokens.erase(m_ctx.tokens.end() -= m_promptResponseTokens, m_ctx.tokens.end()); + m_promptResponseTokens = 0; + m_responseLogits = 0; + m_response = std::string(); + emit responseChanged(); +} + +void ChatLLM::resetResponse() +{ + m_promptResponseTokens = 0; + m_responseLogits = 0; + m_response = std::string(); + emit responseChanged(); +} + +void ChatLLM::resetContext() +{ + resetContextPrivate(); + emit sendResetContext(); +} + +void ChatLLM::resetContextPrivate() +{ + regenerateResponse(); + m_ctx = LLModel::PromptContext(); +} + +std::string remove_leading_whitespace(const std::string& input) { + auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { + return !std::isspace(c); + }); + + return std::string(first_non_whitespace, input.end()); +} + +std::string trim_whitespace(const std::string& input) { + auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { + return !std::isspace(c); + }); + + auto last_non_whitespace = std::find_if(input.rbegin(), input.rend(), [](unsigned char c) { + return !std::isspace(c); + }).base(); + + return std::string(first_non_whitespace, last_non_whitespace); +} + +QString ChatLLM::response() const +{ + return QString::fromStdString(remove_leading_whitespace(m_response)); +} + +QString ChatLLM::modelName() const +{ + return m_modelName; +} + +void ChatLLM::setModelName(const QString &modelName) +{ + m_modelName = modelName; + emit modelNameChanged(); +} + +void ChatLLM::modelNameChangeRequested(const QString &modelName) +{ + if (!loadModel(modelName)) + qWarning() << "ERROR: Could not load model" << modelName; +} + +bool 
ChatLLM::handlePrompt(int32_t token) +{ + // m_promptResponseTokens and m_responseLogits are related to last prompt/response not + // the entire context window which we can reset on regenerate prompt +#if defined(DEBUG) + qDebug() << "chatllm prompt process" << m_chat->id() << token; +#endif + ++m_promptResponseTokens; + return !m_stopGenerating; +} + +bool ChatLLM::handleResponse(int32_t token, const std::string &response) +{ +#if defined(DEBUG) + printf("%s", response.c_str()); + fflush(stdout); +#endif + + // check for error + if (token < 0) { + m_response.append(response); + emit responseChanged(); + return false; + } + + // m_promptResponseTokens and m_responseLogits are related to last prompt/response not + // the entire context window which we can reset on regenerate prompt + ++m_promptResponseTokens; + Q_ASSERT(!response.empty()); + m_response.append(response); + emit responseChanged(); + return !m_stopGenerating; +} + +bool ChatLLM::handleRecalculate(bool isRecalc) +{ + if (m_isRecalc != isRecalc) { + m_isRecalc = isRecalc; + emit recalcChanged(); + } + return !m_stopGenerating; +} + +bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, + float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads) +{ + if (!isModelLoaded()) + return false; + + QString instructPrompt = prompt_template.arg(prompt); + + m_stopGenerating = false; + auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1); + auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1, + std::placeholders::_2); + auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1); + emit responseStarted(); + qint32 logitsBefore = m_ctx.logits.size(); + m_ctx.n_predict = n_predict; + m_ctx.top_k = top_k; + m_ctx.top_p = top_p; + m_ctx.temp = temp; + m_ctx.n_batch = n_batch; + m_ctx.repeat_penalty = repeat_penalty; + 
m_ctx.repeat_last_n = repeat_penalty_tokens; + m_llmodel->setThreadCount(n_threads); +#if defined(DEBUG) + printf("%s", qPrintable(instructPrompt)); + fflush(stdout); +#endif + m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx); +#if defined(DEBUG) + printf("\n"); + fflush(stdout); +#endif + m_responseLogits += m_ctx.logits.size() - logitsBefore; + std::string trimmed = trim_whitespace(m_response); + if (trimmed != m_response) { + m_response = trimmed; + emit responseChanged(); + } + emit responseStopped(); + return true; +} + +void ChatLLM::unloadModel() +{ +#if defined(DEBUG) + qDebug() << "chatllm unloadModel" << m_chat->id(); +#endif + saveState(); + delete m_llmodel; + m_llmodel = nullptr; + emit isModelLoadedChanged(); +} + +void ChatLLM::reloadModel(const QString &modelName) +{ +#if defined(DEBUG) + qDebug() << "chatllm reloadModel" << m_chat->id(); +#endif + if (modelName.isEmpty()) { + loadDefaultModel(); + } else { + loadModel(modelName); + } +} + +void ChatLLM::generateName() +{ + Q_ASSERT(isModelLoaded()); + if (!isModelLoaded()) + return; + + QString instructPrompt("### Instruction:\n" + "Describe response above in three words.\n" + "### Response:\n"); + auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1); + auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, + std::placeholders::_2); + auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1); + LLModel::PromptContext ctx = m_ctx; +#if defined(DEBUG) + printf("%s", qPrintable(instructPrompt)); + fflush(stdout); +#endif + m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx); +#if defined(DEBUG) + printf("\n"); + fflush(stdout); +#endif + std::string trimmed = trim_whitespace(m_nameResponse); + if (trimmed != m_nameResponse) { + m_nameResponse = trimmed; + emit generatedNameChanged(); + } +} + +void 
ChatLLM::handleChatIdChanged() +{ + m_llmThread.setObjectName(m_chat->id()); +} + +bool ChatLLM::handleNamePrompt(int32_t token) +{ + Q_UNUSED(token); + qt_noop(); + return true; +} + +bool ChatLLM::handleNameResponse(int32_t token, const std::string &response) +{ + Q_UNUSED(token); + + m_nameResponse.append(response); + emit generatedNameChanged(); + QString gen = QString::fromStdString(m_nameResponse).simplified(); + QStringList words = gen.split(' ', Qt::SkipEmptyParts); + int wordCount = words.size(); + return words.size() <= 3; +} + +bool ChatLLM::handleNameRecalculate(bool isRecalc) +{ + Q_UNUSED(isRecalc); + Q_UNREACHABLE(); + return true; +} + +bool ChatLLM::serialize(QDataStream &stream, int version) +{ + if (version > 1) { + stream << m_modelType; + switch (m_modelType) { + case MPT_: stream << MPT_INTERNAL_STATE_VERSION; break; + case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break; + case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break; + default: Q_UNREACHABLE(); + } + } + stream << response(); + stream << generatedName(); + stream << m_promptResponseTokens; + stream << m_responseLogits; + stream << m_ctx.n_past; + stream << quint64(m_ctx.logits.size()); + stream.writeRawData(reinterpret_cast(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float)); + stream << quint64(m_ctx.tokens.size()); + stream.writeRawData(reinterpret_cast(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int)); + saveState(); + QByteArray compressed = qCompress(m_state); + stream << compressed; +#if defined(DEBUG) + qDebug() << "chatllm serialize" << m_chat->id() << m_state.size(); +#endif + return stream.status() == QDataStream::Ok; +} + +bool ChatLLM::deserialize(QDataStream &stream, int version) +{ + if (version > 1) { + int internalStateVersion; + stream >> m_modelType; + stream >> internalStateVersion; // for future use + } + QString response; + stream >> response; + m_response = response.toStdString(); + QString nameResponse; + stream >> nameResponse; + 
m_nameResponse = nameResponse.toStdString(); + stream >> m_promptResponseTokens; + stream >> m_responseLogits; + stream >> m_ctx.n_past; + quint64 logitsSize; + stream >> logitsSize; + m_ctx.logits.resize(logitsSize); + stream.readRawData(reinterpret_cast(m_ctx.logits.data()), logitsSize * sizeof(float)); + quint64 tokensSize; + stream >> tokensSize; + m_ctx.tokens.resize(tokensSize); + stream.readRawData(reinterpret_cast(m_ctx.tokens.data()), tokensSize * sizeof(int)); + if (version > 0) { + QByteArray compressed; + stream >> compressed; + m_state = qUncompress(compressed); + } else { + stream >> m_state; + } +#if defined(DEBUG) + qDebug() << "chatllm deserialize" << m_chat->id(); +#endif + return stream.status() == QDataStream::Ok; +} + +void ChatLLM::saveState() +{ + if (!isModelLoaded()) + return; + + const size_t stateSize = m_llmodel->stateSize(); + m_state.resize(stateSize); +#if defined(DEBUG) + qDebug() << "chatllm saveState" << m_chat->id() << "size:" << m_state.size(); +#endif + m_llmodel->saveState(static_cast(reinterpret_cast(m_state.data()))); +} + +void ChatLLM::restoreState() +{ + if (!isModelLoaded() || m_state.isEmpty()) + return; + +#if defined(DEBUG) + qDebug() << "chatllm restoreState" << m_chat->id() << "size:" << m_state.size(); +#endif + m_llmodel->restoreState(static_cast(reinterpret_cast(m_state.data()))); + m_state.clear(); + m_state.resize(0); +} diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h new file mode 100644 index 00000000..bb488b16 --- /dev/null +++ b/gpt4all-chat/chatllm.h @@ -0,0 +1,100 @@ +#ifndef CHATLLM_H +#define CHATLLM_H + +#include +#include + +#include "llmodel/llmodel.h" + +class Chat; +class ChatLLM : public QObject +{ + Q_OBJECT + Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) + Q_PROPERTY(QString response READ response NOTIFY responseChanged) + Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) + Q_PROPERTY(bool isRecalc READ isRecalc 
NOTIFY recalcChanged) + Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged) + +public: + enum ModelType { + MPT_, + GPTJ_, + LLAMA_ + }; + + ChatLLM(Chat *parent); + + bool isModelLoaded() const; + void regenerateResponse(); + void resetResponse(); + void resetContext(); + + void stopGenerating() { m_stopGenerating = true; } + + QString response() const; + QString modelName() const; + + void setModelName(const QString &modelName); + + bool isRecalc() const { return m_isRecalc; } + + QString generatedName() const { return QString::fromStdString(m_nameResponse); } + + bool serialize(QDataStream &stream, int version); + bool deserialize(QDataStream &stream, int version); + +public Q_SLOTS: + bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, + int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, + int32_t n_threads); + bool loadDefaultModel(); + bool loadModel(const QString &modelName); + void modelNameChangeRequested(const QString &modelName); + void unloadModel(); + void reloadModel(const QString &modelName); + void generateName(); + void handleChatIdChanged(); + +Q_SIGNALS: + void isModelLoadedChanged(); + void modelLoadingError(const QString &error); + void responseChanged(); + void responseStarted(); + void responseStopped(); + void modelNameChanged(); + void recalcChanged(); + void sendStartup(); + void sendModelLoaded(); + void sendResetContext(); + void generatedNameChanged(); + void stateChanged(); + +private: + void resetContextPrivate(); + bool handlePrompt(int32_t token); + bool handleResponse(int32_t token, const std::string &response); + bool handleRecalculate(bool isRecalc); + bool handleNamePrompt(int32_t token); + bool handleNameResponse(int32_t token, const std::string &response); + bool handleNameRecalculate(bool isRecalc); + void saveState(); + void restoreState(); + +private: + LLModel::PromptContext m_ctx; + LLModel *m_llmodel; 
+ std::string m_response; + std::string m_nameResponse; + quint32 m_promptResponseTokens; + quint32 m_responseLogits; + QString m_modelName; + ModelType m_modelType; + Chat *m_chat; + QByteArray m_state; + QThread m_llmThread; + std::atomic m_stopGenerating; + bool m_isRecalc; +}; + +#endif // CHATLLM_H diff --git a/gpt4all-chat/chatmodel.h b/gpt4all-chat/chatmodel.h new file mode 100644 index 00000000..e3c01a9a --- /dev/null +++ b/gpt4all-chat/chatmodel.h @@ -0,0 +1,261 @@ +#ifndef CHATMODEL_H +#define CHATMODEL_H + +#include +#include +#include + +struct ChatItem +{ + Q_GADGET + Q_PROPERTY(int id MEMBER id) + Q_PROPERTY(QString name MEMBER name) + Q_PROPERTY(QString value MEMBER value) + Q_PROPERTY(QString prompt MEMBER prompt) + Q_PROPERTY(QString newResponse MEMBER newResponse) + Q_PROPERTY(bool currentResponse MEMBER currentResponse) + Q_PROPERTY(bool stopped MEMBER stopped) + Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState) + Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState) + +public: + int id = 0; + QString name; + QString value; + QString prompt; + QString newResponse; + bool currentResponse = false; + bool stopped = false; + bool thumbsUpState = false; + bool thumbsDownState = false; +}; +Q_DECLARE_METATYPE(ChatItem) + +class ChatModel : public QAbstractListModel +{ + Q_OBJECT + Q_PROPERTY(int count READ count NOTIFY countChanged) + +public: + explicit ChatModel(QObject *parent = nullptr) : QAbstractListModel(parent) {} + + enum Roles { + IdRole = Qt::UserRole + 1, + NameRole, + ValueRole, + PromptRole, + NewResponseRole, + CurrentResponseRole, + StoppedRole, + ThumbsUpStateRole, + ThumbsDownStateRole + }; + + int rowCount(const QModelIndex &parent = QModelIndex()) const override + { + Q_UNUSED(parent) + return m_chatItems.size(); + } + + QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override + { + if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size()) + return QVariant(); + + const ChatItem 
&item = m_chatItems.at(index.row()); + switch (role) { + case IdRole: + return item.id; + case NameRole: + return item.name; + case ValueRole: + return item.value; + case PromptRole: + return item.prompt; + case NewResponseRole: + return item.newResponse; + case CurrentResponseRole: + return item.currentResponse; + case StoppedRole: + return item.stopped; + case ThumbsUpStateRole: + return item.thumbsUpState; + case ThumbsDownStateRole: + return item.thumbsDownState; + } + + return QVariant(); + } + + QHash roleNames() const override + { + QHash roles; + roles[IdRole] = "id"; + roles[NameRole] = "name"; + roles[ValueRole] = "value"; + roles[PromptRole] = "prompt"; + roles[NewResponseRole] = "newResponse"; + roles[CurrentResponseRole] = "currentResponse"; + roles[StoppedRole] = "stopped"; + roles[ThumbsUpStateRole] = "thumbsUpState"; + roles[ThumbsDownStateRole] = "thumbsDownState"; + return roles; + } + + void appendPrompt(const QString &name, const QString &value) + { + ChatItem item; + item.name = name; + item.value = value; + beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); + m_chatItems.append(item); + endInsertRows(); + emit countChanged(); + } + + void appendResponse(const QString &name, const QString &prompt) + { + ChatItem item; + item.id = m_chatItems.count(); // This is only relevant for responses + item.name = name; + item.prompt = prompt; + item.currentResponse = true; + beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); + m_chatItems.append(item); + endInsertRows(); + emit countChanged(); + } + + Q_INVOKABLE void clear() + { + if (m_chatItems.isEmpty()) return; + + beginResetModel(); + m_chatItems.clear(); + endResetModel(); + emit countChanged(); + } + + Q_INVOKABLE ChatItem get(int index) + { + if (index < 0 || index >= m_chatItems.size()) return ChatItem(); + return m_chatItems.at(index); + } + + Q_INVOKABLE void updateCurrentResponse(int index, bool b) + { + if (index < 0 || index >= m_chatItems.size()) 
return; + + ChatItem &item = m_chatItems[index]; + if (item.currentResponse != b) { + item.currentResponse = b; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole}); + } + } + + Q_INVOKABLE void updateStopped(int index, bool b) + { + if (index < 0 || index >= m_chatItems.size()) return; + + ChatItem &item = m_chatItems[index]; + if (item.stopped != b) { + item.stopped = b; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole}); + } + } + + Q_INVOKABLE void updateValue(int index, const QString &value) + { + if (index < 0 || index >= m_chatItems.size()) return; + + ChatItem &item = m_chatItems[index]; + if (item.value != value) { + item.value = value; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole}); + } + } + + Q_INVOKABLE void updateThumbsUpState(int index, bool b) + { + if (index < 0 || index >= m_chatItems.size()) return; + + ChatItem &item = m_chatItems[index]; + if (item.thumbsUpState != b) { + item.thumbsUpState = b; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole}); + } + } + + Q_INVOKABLE void updateThumbsDownState(int index, bool b) + { + if (index < 0 || index >= m_chatItems.size()) return; + + ChatItem &item = m_chatItems[index]; + if (item.thumbsDownState != b) { + item.thumbsDownState = b; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole}); + } + } + + Q_INVOKABLE void updateNewResponse(int index, const QString &newResponse) + { + if (index < 0 || index >= m_chatItems.size()) return; + + ChatItem &item = m_chatItems[index]; + if (item.newResponse != newResponse) { + item.newResponse = newResponse; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole}); + } + } + + int count() const { return m_chatItems.size(); } + + bool serialize(QDataStream &stream, int version) const + { + stream << count(); + for (auto c : m_chatItems) { + stream << c.id; + stream << c.name; + 
stream << c.value; + stream << c.prompt; + stream << c.newResponse; + stream << c.currentResponse; + stream << c.stopped; + stream << c.thumbsUpState; + stream << c.thumbsDownState; + } + return stream.status() == QDataStream::Ok; + } + + bool deserialize(QDataStream &stream, int version) + { + int size; + stream >> size; + for (int i = 0; i < size; ++i) { + ChatItem c; + stream >> c.id; + stream >> c.name; + stream >> c.value; + stream >> c.prompt; + stream >> c.newResponse; + stream >> c.currentResponse; + stream >> c.stopped; + stream >> c.thumbsUpState; + stream >> c.thumbsDownState; + beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); + m_chatItems.append(c); + endInsertRows(); + } + emit countChanged(); + return stream.status() == QDataStream::Ok; + } + +Q_SIGNALS: + void countChanged(); + +private: + + QList m_chatItems; +}; + +#endif // CHATMODEL_H diff --git a/gpt4all-chat/cmake/config.h.in b/gpt4all-chat/cmake/config.h.in new file mode 100644 index 00000000..e578a82d --- /dev/null +++ b/gpt4all-chat/cmake/config.h.in @@ -0,0 +1,7 @@ +#ifndef CONFIG_H +#define CONFIG_H + +#define APP_VERSION "@APP_VERSION@" +#define GPT4ALL_AVX_ONLY "@GPT4ALL_AVX_ONLY@" + +#endif // CONFIG_H diff --git a/gpt4all-chat/cmake/deploy-qt-linux.cmake.in b/gpt4all-chat/cmake/deploy-qt-linux.cmake.in new file mode 100644 index 00000000..8c4240f1 --- /dev/null +++ b/gpt4all-chat/cmake/deploy-qt-linux.cmake.in @@ -0,0 +1,12 @@ +set(LINUXDEPLOYQT "@LINUXDEPLOYQT@") +set(COMPONENT_NAME_MAIN "@COMPONENT_NAME_MAIN@") +set(CMAKE_CURRENT_SOURCE_DIR "@CMAKE_CURRENT_SOURCE_DIR@") +set(DATA_DIR ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data) +set(BIN_DIR ${DATA_DIR}/bin) +set(Qt6_ROOT_DIR "@Qt6_ROOT_DIR@") +set(ENV{LD_LIBRARY_PATH} "${BIN_DIR}:${Qt6_ROOT_DIR}/../lib/") +execute_process(COMMAND ${LINUXDEPLOYQT} ${BIN_DIR}/chat -qmldir=${CMAKE_CURRENT_SOURCE_DIR} -bundle-non-qt-libs -qmake=${Qt6_ROOT_DIR}/bin/qmake -verbose=2) +file(COPY 
"${CMAKE_CURRENT_SOURCE_DIR}/icons/logo-32.png" + DESTINATION ${DATA_DIR}) +file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/icons/logo-48.png" + DESTINATION ${DATA_DIR}) diff --git a/gpt4all-chat/cmake/deploy-qt-mac.cmake.in b/gpt4all-chat/cmake/deploy-qt-mac.cmake.in new file mode 100644 index 00000000..f6ceb718 --- /dev/null +++ b/gpt4all-chat/cmake/deploy-qt-mac.cmake.in @@ -0,0 +1,16 @@ +set(MACDEPLOYQT "@MACDEPLOYQT@") +set(COMPONENT_NAME_MAIN "@COMPONENT_NAME_MAIN@") +set(CMAKE_CURRENT_SOURCE_DIR "@CMAKE_CURRENT_SOURCE_DIR@") +execute_process(COMMAND ${MACDEPLOYQT} ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data/bin/gpt4all.app -qmldir=${CMAKE_CURRENT_SOURCE_DIR} -verbose=2) +file(COPY ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data/lib/libllama.dylib + DESTINATION ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data/bin/gpt4all.app/Contents/Frameworks) +file(COPY ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data/lib/libllmodel.dylib + DESTINATION ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data/bin/gpt4all.app/Contents/Frameworks) +file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/icons/favicon.icns" + DESTINATION ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data/bin/gpt4all.app/Contents/Resources) +file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/icons/logo-32.png" + DESTINATION ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data) +file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/icons/logo-48.png" + DESTINATION ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data) +file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/icons/favicon.icns" + DESTINATION ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data) diff --git a/gpt4all-chat/cmake/deploy-qt-windows.cmake.in b/gpt4all-chat/cmake/deploy-qt-windows.cmake.in new file mode 100644 index 00000000..80493951 --- /dev/null +++ 
b/gpt4all-chat/cmake/deploy-qt-windows.cmake.in @@ -0,0 +1,14 @@ +set(WINDEPLOYQT "@WINDEPLOYQT@") +set(COMPONENT_NAME_MAIN "@COMPONENT_NAME_MAIN@") +set(CMAKE_CURRENT_SOURCE_DIR "@CMAKE_CURRENT_SOURCE_DIR@") +execute_process(COMMAND ${WINDEPLOYQT} --qmldir ${CMAKE_CURRENT_SOURCE_DIR} ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data/bin) +file(COPY ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data/lib/libllama.dll + DESTINATION ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data/bin) +file(COPY ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data/lib/libllmodel.dll + DESTINATION ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data/bin) +file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/icons/logo-32.png" + DESTINATION ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data) +file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/icons/logo-48.png" + DESTINATION ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data) +file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/icons/favicon.ico" + DESTINATION ${CPACK_TEMPORARY_INSTALL_DIRECTORY}/packages/${COMPONENT_NAME_MAIN}/data) diff --git a/gpt4all-chat/cmake/installerscript.qs b/gpt4all-chat/cmake/installerscript.qs new file mode 100644 index 00000000..841e1e52 --- /dev/null +++ b/gpt4all-chat/cmake/installerscript.qs @@ -0,0 +1,68 @@ +function Component() { +} + +var targetDirectory; +Component.prototype.beginInstallation = function() { + targetDirectory = installer.value("TargetDir"); +}; + +Component.prototype.createOperations = function() +{ + try { + // call the base create operations function + component.createOperations(); + if (systemInfo.productType === "windows") { + try { + var userProfile = installer.environmentVariable("USERPROFILE"); + installer.setValue("UserProfile", userProfile); + component.addOperation("CreateShortcut", + targetDirectory + "/bin/chat.exe", + 
"@UserProfile@/Desktop/GPT4All.lnk", + "workingDirectory=" + targetDirectory + "/bin", + "iconPath=" + targetDirectory + "/favicon.ico", + "iconId=0", "description=Open GPT4All"); + } catch (e) { + print("ERROR: creating desktop shortcut" + e); + } + component.addOperation("CreateShortcut", + targetDirectory + "/bin/chat.exe", + "@StartMenuDir@/GPT4All.lnk", + "workingDirectory=" + targetDirectory + "/bin", + "iconPath=" + targetDirectory + "/favicon.ico", + "iconId=0", "description=Open GPT4All"); + } else if (systemInfo.productType === "osx") { + var gpt4allAppPath = targetDirectory + "/bin/gpt4all.app"; + var symlinkPath = targetDirectory + "/../GPT4All.app"; + // Remove the symlink if it already exists + component.addOperation("Execute", "rm", "-f", symlinkPath); + // Create the symlink + component.addOperation("Execute", "ln", "-s", gpt4allAppPath, symlinkPath); + } else { // linux + var homeDir = installer.environmentVariable("HOME"); + if (!installer.fileExists(homeDir + "/Desktop/GPT4All.desktop")) { + component.addOperation("CreateDesktopEntry", + homeDir + "/Desktop/GPT4All.desktop", + "Type=Application\nTerminal=false\nExec=\"" + targetDirectory + + "/bin/chat\"\nName=GPT4All\nIcon=" + targetDirectory + + "/logo-48.png\nName[en_US]=GPT4All"); + } + } + } catch (e) { + print("ERROR: running post installscript.qs" + e); + } +} + +Component.prototype.createOperationsForArchive = function(archive) +{ + component.createOperationsForArchive(archive); + + if (systemInfo.productType === "osx") { + var uninstallTargetDirectory = installer.value("TargetDir"); + var symlinkPath = uninstallTargetDirectory + "/../GPT4All.app"; + + // Remove the symlink during uninstallation + if (installer.isUninstaller()) { + component.addOperation("Execute", "rm", "-f", symlinkPath, "UNDOEXECUTE"); + } + } +} diff --git a/gpt4all-chat/cmake/sign_dmg.py b/gpt4all-chat/cmake/sign_dmg.py new file mode 100644 index 00000000..c448db36 --- /dev/null +++ b/gpt4all-chat/cmake/sign_dmg.py 
@@ -0,0 +1,81 @@ +import os +import subprocess +import tempfile +import shutil +import click +from typing import Optional + +# Requires click +# pip install click + +# Example usage +# python sign_dmg.py --input-dmg /path/to/your/input.dmg --output-dmg /path/to/your/output.dmg --signing-identity "Developer ID Application: YOUR_NAME (TEAM_ID)" + +# NOTE: This script assumes that you have the necessary Developer ID Application certificate in your +# Keychain Access and that the codesign and hdiutil command-line tools are available on your system. + +@click.command() +@click.option('--input-dmg', required=True, help='Path to the input DMG file.') +@click.option('--output-dmg', required=True, help='Path to the output signed DMG file.') +@click.option('--sha1-hash', help='SHA-1 hash of the Developer ID Application certificate') +@click.option('--signing-identity', default=None, help='Common name of the Developer ID Application certificate') +def sign_dmg(input_dmg: str, output_dmg: str, signing_identity: Optional[str] = None, sha1_hash: Optional[str] = None) -> None: + if not signing_identity and not sha1_hash: + print("Error: Either --signing-identity or --sha1-hash must be provided.") + exit(1) + + # Mount the input DMG + mount_point = tempfile.mkdtemp() + subprocess.run(['hdiutil', 'attach', input_dmg, '-mountpoint', mount_point]) + + # Copy the contents of the DMG to a temporary folder + temp_dir = tempfile.mkdtemp() + shutil.copytree(mount_point, os.path.join(temp_dir, 'contents')) + subprocess.run(['hdiutil', 'detach', mount_point]) + + # Find the .app bundle in the temporary folder + app_bundle = None + for item in os.listdir(os.path.join(temp_dir, 'contents')): + if item.endswith('.app'): + app_bundle = os.path.join(temp_dir, 'contents', item) + break + + if not app_bundle: + print('No .app bundle found in the DMG.') + exit(1) + + # Sign the .app bundle + try: + subprocess.run([ + 'codesign', + '--deep', + '--force', + '--verbose', + '--options', 'runtime', + 
'--timestamp', + '--sign', sha1_hash or signing_identity, + app_bundle + ], check=True) + except subprocess.CalledProcessError as e: + print(f"Error during codesign: {e}") + # Clean up temporary directories + shutil.rmtree(temp_dir) + shutil.rmtree(mount_point) + exit(1) + + # Create a new DMG containing the signed .app bundle + subprocess.run([ + 'hdiutil', 'create', + '-volname', os.path.splitext(os.path.basename(input_dmg))[0], + '-srcfolder', os.path.join(temp_dir, 'contents'), + '-ov', + '-format', 'UDZO', + output_dmg + ]) + + # Clean up temporary directories + shutil.rmtree(temp_dir) + shutil.rmtree(mount_point) + +if __name__ == '__main__': + sign_dmg() diff --git a/gpt4all-chat/download.cpp b/gpt4all-chat/download.cpp new file mode 100644 index 00000000..736c8fa1 --- /dev/null +++ b/gpt4all-chat/download.cpp @@ -0,0 +1,600 @@ +#include "download.h" +#include "network.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class MyDownload: public Download { }; +Q_GLOBAL_STATIC(MyDownload, downloadInstance) +Download *Download::globalInstance() +{ + return downloadInstance(); +} + +Download::Download() + : QObject(nullptr) + , m_hashAndSave(new HashAndSaveFile) +{ + connect(this, &Download::requestHashAndSave, m_hashAndSave, + &HashAndSaveFile::hashAndSave, Qt::QueuedConnection); + connect(m_hashAndSave, &HashAndSaveFile::hashAndSaveFinished, this, + &Download::handleHashAndSaveFinished, Qt::QueuedConnection); + connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, + &Download::handleSslErrors); + connect(this, &Download::downloadLocalModelsPathChanged, this, &Download::updateModelList); + updateModelList(); + updateReleaseNotes(); + QSettings settings; + settings.sync(); + m_downloadLocalModelsPath = settings.value("modelPath", + defaultLocalModelsPath()).toString(); + m_startTime = QDateTime::currentDateTime(); +} + +bool operator==(const ModelInfo& lhs, const ModelInfo& rhs) { + return 
lhs.filename == rhs.filename && lhs.md5sum == rhs.md5sum; +} + +bool operator==(const ReleaseInfo& lhs, const ReleaseInfo& rhs) { + return lhs.version == rhs.version; +} + +bool compareVersions(const QString &a, const QString &b) { + QStringList aParts = a.split('.'); + QStringList bParts = b.split('.'); + + for (int i = 0; i < std::min(aParts.size(), bParts.size()); ++i) { + int aInt = aParts[i].toInt(); + int bInt = bParts[i].toInt(); + + if (aInt > bInt) { + return true; + } else if (aInt < bInt) { + return false; + } + } + + return aParts.size() > bParts.size(); +} + +QList Download::modelList() const +{ + // We make sure the default model is listed first + QList values = m_modelMap.values(); + ModelInfo defaultInfo; + ModelInfo bestGPTJInfo; + ModelInfo bestLlamaInfo; + ModelInfo bestMPTInfo; + QList filtered; + for (ModelInfo v : values) { + if (v.isDefault) + defaultInfo = v; + if (v.bestGPTJ) + bestGPTJInfo = v; + if (v.bestLlama) + bestLlamaInfo = v; + if (v.bestMPT) + bestMPTInfo = v; + filtered.append(v); + } + + Q_ASSERT(defaultInfo == bestGPTJInfo || defaultInfo == bestLlamaInfo || defaultInfo == bestMPTInfo); + + if (bestLlamaInfo.bestLlama) { + filtered.removeAll(bestLlamaInfo); + filtered.prepend(bestLlamaInfo); + } + + if (bestGPTJInfo.bestGPTJ) { + filtered.removeAll(bestGPTJInfo); + filtered.prepend(bestGPTJInfo); + } + + if (bestMPTInfo.bestMPT) { + filtered.removeAll(bestMPTInfo); + filtered.prepend(bestMPTInfo); + } + + return filtered; +} + +ReleaseInfo Download::releaseInfo() const +{ + const QString currentVersion = QCoreApplication::applicationVersion(); + if (m_releaseMap.contains(currentVersion)) + return m_releaseMap.value(currentVersion); + return ReleaseInfo(); +} + +bool Download::hasNewerRelease() const +{ + const QString currentVersion = QCoreApplication::applicationVersion(); + QList versions = m_releaseMap.keys(); + std::sort(versions.begin(), versions.end(), compareVersions); + if (versions.isEmpty()) + return false; + return 
compareVersions(versions.first(), currentVersion); +} + +QString Download::downloadLocalModelsPath() const { + return m_downloadLocalModelsPath; +} + +void Download::setDownloadLocalModelsPath(const QString &modelPath) { + QString filePath = (modelPath.startsWith("file://") ? + QUrl(modelPath).toLocalFile() : modelPath); + QString canonical = QFileInfo(filePath).canonicalFilePath() + "/"; + if (m_downloadLocalModelsPath != canonical) { + m_downloadLocalModelsPath = canonical; + emit downloadLocalModelsPathChanged(); + } +} + +bool Download::isFirstStart() const +{ + QSettings settings; + settings.sync(); + QString lastVersionStarted = settings.value("download/lastVersionStarted").toString(); + bool first = lastVersionStarted != QCoreApplication::applicationVersion(); + settings.setValue("download/lastVersionStarted", QCoreApplication::applicationVersion()); + settings.sync(); + return first; +} + +QString Download::incompleteDownloadPath(const QString &modelFile) { + QString downloadPath = downloadLocalModelsPath() + "incomplete-" + + modelFile; + return downloadPath; +} + +QString Download::defaultLocalModelsPath() const +{ + QString localPath = QStandardPaths::writableLocation(QStandardPaths::AppLocalDataLocation) + + "/"; + QString testWritePath = localPath + QString("test_write.txt"); + QString canonicalLocalPath = QFileInfo(localPath).canonicalFilePath() + "/"; + QDir localDir(localPath); + if (!localDir.exists()) { + if (!localDir.mkpath(localPath)) { + qWarning() << "ERROR: Local download directory can't be created:" << canonicalLocalPath; + return canonicalLocalPath; + } + } + + if (QFileInfo::exists(testWritePath)) + return canonicalLocalPath; + + QFile testWriteFile(testWritePath); + if (testWriteFile.open(QIODeviceBase::ReadWrite)) { + testWriteFile.close(); + return canonicalLocalPath; + } + + qWarning() << "ERROR: Local download path appears not writeable:" << canonicalLocalPath; + return canonicalLocalPath; +} + +void Download::updateModelList() +{ + 
QUrl jsonUrl("http://gpt4all.io/models/models.json"); + QNetworkRequest request(jsonUrl); + QSslConfiguration conf = request.sslConfiguration(); + conf.setPeerVerifyMode(QSslSocket::VerifyNone); + request.setSslConfiguration(conf); + QNetworkReply *jsonReply = m_networkManager.get(request); + connect(jsonReply, &QNetworkReply::finished, this, &Download::handleModelsJsonDownloadFinished); +} + +void Download::updateReleaseNotes() +{ + QUrl jsonUrl("http://gpt4all.io/meta/release.json"); + QNetworkRequest request(jsonUrl); + QSslConfiguration conf = request.sslConfiguration(); + conf.setPeerVerifyMode(QSslSocket::VerifyNone); + request.setSslConfiguration(conf); + QNetworkReply *jsonReply = m_networkManager.get(request); + connect(jsonReply, &QNetworkReply::finished, this, &Download::handleReleaseJsonDownloadFinished); +} + +void Download::downloadModel(const QString &modelFile) +{ + QFile *tempFile = new QFile(incompleteDownloadPath(modelFile)); + QDateTime modTime = tempFile->fileTime(QFile::FileModificationTime); + bool success = tempFile->open(QIODevice::WriteOnly | QIODevice::Append); + qWarning() << "Opening temp file for writing:" << tempFile->fileName(); + if (!success) { + qWarning() << "ERROR: Could not open temp file:" + << tempFile->fileName() << modelFile; + return; + } + size_t incomplete_size = tempFile->size(); + if (incomplete_size > 0) { + if (modTime < m_startTime) { + qWarning() << "File last modified before app started, rewinding by 1MB"; + if (incomplete_size >= 1024 * 1024) { + incomplete_size -= 1024 * 1024; + } else { + incomplete_size = 0; + } + } + tempFile->seek(incomplete_size); + } + + Network::globalInstance()->sendDownloadStarted(modelFile); + QNetworkRequest request("http://gpt4all.io/models/" + modelFile); + request.setRawHeader("range", QString("bytes=%1-").arg(incomplete_size).toUtf8()); + QSslConfiguration conf = request.sslConfiguration(); + conf.setPeerVerifyMode(QSslSocket::VerifyNone); + request.setSslConfiguration(conf); + 
QNetworkReply *modelReply = m_networkManager.get(request); + connect(modelReply, &QNetworkReply::downloadProgress, this, &Download::handleDownloadProgress); + connect(modelReply, &QNetworkReply::finished, this, &Download::handleModelDownloadFinished); + connect(modelReply, &QNetworkReply::readyRead, this, &Download::handleReadyRead); + m_activeDownloads.insert(modelReply, tempFile); +} + +void Download::cancelDownload(const QString &modelFile) +{ + for (int i = 0; i < m_activeDownloads.size(); ++i) { + QNetworkReply *modelReply = m_activeDownloads.keys().at(i); + QUrl url = modelReply->request().url(); + if (url.toString().endsWith(modelFile)) { + Network::globalInstance()->sendDownloadCanceled(modelFile); + + // Disconnect the signals + disconnect(modelReply, &QNetworkReply::downloadProgress, this, &Download::handleDownloadProgress); + disconnect(modelReply, &QNetworkReply::finished, this, &Download::handleModelDownloadFinished); + + modelReply->abort(); // Abort the download + modelReply->deleteLater(); // Schedule the reply for deletion + + QFile *tempFile = m_activeDownloads.value(modelReply); + tempFile->deleteLater(); + m_activeDownloads.remove(modelReply); + + // Emit downloadFinished signal for cleanup + emit downloadFinished(modelFile); + break; + } + } +} + +void Download::handleSslErrors(QNetworkReply *reply, const QList &errors) +{ + QUrl url = reply->request().url(); + for (auto e : errors) + qWarning() << "ERROR: Received ssl error:" << e.errorString() << "for" << url; +} + +void Download::handleModelsJsonDownloadFinished() +{ +#if 0 + QByteArray jsonData = QString("" + "[" + " {" + " \"md5sum\": \"61d48a82cb188cceb14ebb8082bfec37\"," + " \"filename\": \"ggml-gpt4all-j-v1.1-breezy.bin\"," + " \"filesize\": \"3785248281\"" + " }," + " {" + " \"md5sum\": \"879344aaa9d62fdccbda0be7a09e7976\"," + " \"filename\": \"ggml-gpt4all-j-v1.2-jazzy.bin\"," + " \"filesize\": \"3785248281\"," + " \"isDefault\": \"true\"" + " }," + " {" + " \"md5sum\": 
\"5b5a3f9b858d33b29b52b89692415595\"," + " \"filesize\": \"3785248281\"," + " \"filename\": \"ggml-gpt4all-j.bin\"" + " }" + "]" + ).toUtf8(); + printf("%s\n", jsonData.toStdString().c_str()); + fflush(stdout); +#else + QNetworkReply *jsonReply = qobject_cast(sender()); + if (!jsonReply) + return; + + QByteArray jsonData = jsonReply->readAll(); + jsonReply->deleteLater(); +#endif + parseModelsJsonFile(jsonData); +} + +void Download::parseModelsJsonFile(const QByteArray &jsonData) +{ + QJsonParseError err; + QJsonDocument document = QJsonDocument::fromJson(jsonData, &err); + if (err.error != QJsonParseError::NoError) { + qDebug() << "ERROR: Couldn't parse: " << jsonData << err.errorString(); + return; + } + + QString defaultModel; + QJsonArray jsonArray = document.array(); + const QString currentVersion = QCoreApplication::applicationVersion(); + + m_modelMap.clear(); + for (const QJsonValue &value : jsonArray) { + QJsonObject obj = value.toObject(); + + QString modelFilename = obj["filename"].toString(); + QString modelFilesize = obj["filesize"].toString(); + QString requires = obj["requires"].toString(); + QByteArray modelMd5sum = obj["md5sum"].toString().toLatin1().constData(); + bool isDefault = obj.contains("isDefault") && obj["isDefault"] == QString("true"); + bool bestGPTJ = obj.contains("bestGPTJ") && obj["bestGPTJ"] == QString("true"); + bool bestLlama = obj.contains("bestLlama") && obj["bestLlama"] == QString("true"); + bool bestMPT = obj.contains("bestMPT") && obj["bestMPT"] == QString("true"); + QString description = obj["description"].toString(); + + if (!requires.isEmpty() + && requires != currentVersion + && compareVersions(requires, currentVersion)) { + continue; + } + + if (isDefault) + defaultModel = modelFilename; + quint64 sz = modelFilesize.toULongLong(); + if (sz < 1024) { + modelFilesize = QString("%1 bytes").arg(sz); + } else if (sz < 1024 * 1024) { + modelFilesize = QString("%1 KB").arg(qreal(sz) / 1024, 0, 'g', 3); + } else if (sz < 1024 * 
1024 * 1024) { + modelFilesize = QString("%1 MB").arg(qreal(sz) / (1024 * 1024), 0, 'g', 3); + } else { + modelFilesize = QString("%1 GB").arg(qreal(sz) / (1024 * 1024 * 1024), 0, 'g', 3); + } + + QString filePath = downloadLocalModelsPath() + modelFilename; + QFileInfo info(filePath); + ModelInfo modelInfo; + modelInfo.filename = modelFilename; + modelInfo.filesize = modelFilesize; + modelInfo.md5sum = modelMd5sum; + modelInfo.installed = info.exists(); + modelInfo.isDefault = isDefault; + modelInfo.bestGPTJ = bestGPTJ; + modelInfo.bestLlama = bestLlama; + modelInfo.bestMPT = bestMPT; + modelInfo.description = description; + modelInfo.requires = requires; + m_modelMap.insert(modelInfo.filename, modelInfo); + } + + // remove ggml- prefix and .bin suffix + Q_ASSERT(defaultModel.startsWith("ggml-")); + defaultModel = defaultModel.remove(0, 5); + Q_ASSERT(defaultModel.endsWith(".bin")); + defaultModel.chop(4); + + QSettings settings; + settings.sync(); + settings.setValue("defaultModel", defaultModel); + settings.sync(); + emit modelListChanged(); +} + +void Download::handleReleaseJsonDownloadFinished() +{ + QNetworkReply *jsonReply = qobject_cast(sender()); + if (!jsonReply) + return; + + QByteArray jsonData = jsonReply->readAll(); + jsonReply->deleteLater(); + parseReleaseJsonFile(jsonData); +} + +void Download::parseReleaseJsonFile(const QByteArray &jsonData) +{ + QJsonParseError err; + QJsonDocument document = QJsonDocument::fromJson(jsonData, &err); + if (err.error != QJsonParseError::NoError) { + qDebug() << "ERROR: Couldn't parse: " << jsonData << err.errorString(); + return; + } + + QJsonArray jsonArray = document.array(); + + m_releaseMap.clear(); + for (const QJsonValue &value : jsonArray) { + QJsonObject obj = value.toObject(); + + QString version = obj["version"].toString(); + QString notes = obj["notes"].toString(); + QString contributors = obj["contributors"].toString(); + ReleaseInfo releaseInfo; + releaseInfo.version = version; + releaseInfo.notes = 
notes; + releaseInfo.contributors = contributors; + m_releaseMap.insert(version, releaseInfo); + } + + emit hasNewerReleaseChanged(); + emit releaseInfoChanged(); +} + +void Download::handleErrorOccurred(QNetworkReply::NetworkError code) +{ + QNetworkReply *modelReply = qobject_cast(sender()); + if (!modelReply) + return; + + QString modelFilename = modelReply->url().fileName(); + qWarning() << "ERROR: Network error occurred attempting to download" + << modelFilename + << "code:" << code + << "errorString" << modelReply->errorString(); + Network::globalInstance()->sendDownloadError(modelFilename, (int)code, modelReply->errorString()); + cancelDownload(modelFilename); +} + +void Download::handleDownloadProgress(qint64 bytesReceived, qint64 bytesTotal) +{ + QNetworkReply *modelReply = qobject_cast(sender()); + if (!modelReply) + return; + QFile *tempFile = m_activeDownloads.value(modelReply); + if (!tempFile) + return; + QString contentRange = modelReply->rawHeader("content-range"); + if (contentRange.contains("/")) { + QString contentTotalSize = contentRange.split("/").last(); + bytesTotal = contentTotalSize.toLongLong(); + } + + QString modelFilename = modelReply->url().fileName(); + emit downloadProgress(tempFile->pos(), bytesTotal, modelFilename); +} + +HashAndSaveFile::HashAndSaveFile() + : QObject(nullptr) +{ + moveToThread(&m_hashAndSaveThread); + m_hashAndSaveThread.setObjectName("hashandsave thread"); + m_hashAndSaveThread.start(); +} + +void HashAndSaveFile::hashAndSave(const QString &expectedHash, const QString &saveFilePath, + QFile *tempFile, QNetworkReply *modelReply) +{ + Q_ASSERT(!tempFile->isOpen()); + QString modelFilename = modelReply->url().fileName(); + + // Reopen the tempFile for hashing + if (!tempFile->open(QIODevice::ReadOnly)) { + qWarning() << "ERROR: Could not open temp file for hashing:" + << tempFile->fileName() << modelFilename; + emit hashAndSaveFinished(false, tempFile, modelReply); + return; + } + + QCryptographicHash 
hash(QCryptographicHash::Md5); + while(!tempFile->atEnd()) + hash.addData(tempFile->read(16384)); + if (hash.result().toHex() != expectedHash) { + tempFile->close(); + qWarning() << "ERROR: Download error MD5SUM did not match:" + << hash.result().toHex() + << "!=" << expectedHash << "for" << modelFilename; + tempFile->remove(); + emit hashAndSaveFinished(false, tempFile, modelReply); + return; + } + + // The file save needs the tempFile closed + tempFile->close(); + + // Attempt to *move* the verified tempfile into place - this should be atomic + // but will only work if the destination is on the same filesystem + if (tempFile->rename(saveFilePath)) { + emit hashAndSaveFinished(true, tempFile, modelReply); + return; + } + + // Reopen the tempFile for copying + if (!tempFile->open(QIODevice::ReadOnly)) { + qWarning() << "ERROR: Could not open temp file at finish:" + << tempFile->fileName() << modelFilename; + emit hashAndSaveFinished(false, tempFile, modelReply); + return; + } + + // Save the model file to disk + QFile file(saveFilePath); + if (file.open(QIODevice::WriteOnly)) { + QByteArray buffer; + while (!tempFile->atEnd()) { + buffer = tempFile->read(16384); + file.write(buffer); + } + file.close(); + tempFile->close(); + emit hashAndSaveFinished(true, tempFile, modelReply); + } else { + QFile::FileError error = file.error(); + qWarning() << "ERROR: Could not save model to location:" + << saveFilePath + << "failed with code" << error; + tempFile->close(); + emit hashAndSaveFinished(false, tempFile, modelReply); + return; + } +} + +void Download::handleModelDownloadFinished() +{ + QNetworkReply *modelReply = qobject_cast(sender()); + if (!modelReply) + return; + + QString modelFilename = modelReply->url().fileName(); + QFile *tempFile = m_activeDownloads.value(modelReply); + m_activeDownloads.remove(modelReply); + + if (modelReply->error()) { + qWarning() << "ERROR: downloading:" << modelReply->errorString(); + modelReply->deleteLater(); + 
tempFile->deleteLater(); + emit downloadFinished(modelFilename); + return; + } + + // The hash and save needs the tempFile closed + tempFile->close(); + + // Notify that we are calculating hash + ModelInfo info = m_modelMap.value(modelFilename); + info.calcHash = true; + m_modelMap.insert(modelFilename, info); + emit modelListChanged(); + + const QString saveFilePath = downloadLocalModelsPath() + modelFilename; + emit requestHashAndSave(info.md5sum, saveFilePath, tempFile, modelReply); +} + +void Download::handleHashAndSaveFinished(bool success, + QFile *tempFile, QNetworkReply *modelReply) +{ + // The hash and save should send back with tempfile closed + Q_ASSERT(!tempFile->isOpen()); + QString modelFilename = modelReply->url().fileName(); + Network::globalInstance()->sendDownloadFinished(modelFilename, success); + + ModelInfo info = m_modelMap.value(modelFilename); + info.calcHash = false; + info.installed = success; + m_modelMap.insert(modelFilename, info); + emit modelListChanged(); + + modelReply->deleteLater(); + tempFile->deleteLater(); + emit downloadFinished(modelFilename); +} + +void Download::handleReadyRead() +{ + QNetworkReply *modelReply = qobject_cast(sender()); + if (!modelReply) + return; + + QString modelFilename = modelReply->url().fileName(); + QFile *tempFile = m_activeDownloads.value(modelReply); + QByteArray buffer; + while (!modelReply->atEnd()) { + buffer = modelReply->read(16384); + tempFile->write(buffer); + } + tempFile->flush(); +} diff --git a/gpt4all-chat/download.h b/gpt4all-chat/download.h new file mode 100644 index 00000000..638bae43 --- /dev/null +++ b/gpt4all-chat/download.h @@ -0,0 +1,136 @@ +#ifndef DOWNLOAD_H +#define DOWNLOAD_H + +#include +#include +#include +#include +#include +#include +#include + +struct ModelInfo { + Q_GADGET + Q_PROPERTY(QString filename MEMBER filename) + Q_PROPERTY(QString filesize MEMBER filesize) + Q_PROPERTY(QByteArray md5sum MEMBER md5sum) + Q_PROPERTY(bool calcHash MEMBER calcHash) + 
Q_PROPERTY(bool installed MEMBER installed) + Q_PROPERTY(bool isDefault MEMBER isDefault) + Q_PROPERTY(bool bestGPTJ MEMBER bestGPTJ) + Q_PROPERTY(bool bestLlama MEMBER bestLlama) + Q_PROPERTY(bool bestMPT MEMBER bestMPT) + Q_PROPERTY(QString description MEMBER description) + Q_PROPERTY(QString requires MEMBER requires) + +public: + QString filename; + QString filesize; + QByteArray md5sum; + bool calcHash = false; + bool installed = false; + bool isDefault = false; + bool bestGPTJ = false; + bool bestLlama = false; + bool bestMPT = false; + QString description; + QString requires; +}; +Q_DECLARE_METATYPE(ModelInfo) + +struct ReleaseInfo { + Q_GADGET + Q_PROPERTY(QString version MEMBER version) + Q_PROPERTY(QString notes MEMBER notes) + Q_PROPERTY(QString contributors MEMBER contributors) + +public: + QString version; + QString notes; + QString contributors; +}; + +class HashAndSaveFile : public QObject +{ + Q_OBJECT +public: + HashAndSaveFile(); + +public Q_SLOTS: + void hashAndSave(const QString &hash, const QString &saveFilePath, + QFile *tempFile, QNetworkReply *modelReply); + +Q_SIGNALS: + void hashAndSaveFinished(bool success, + QFile *tempFile, QNetworkReply *modelReply); + +private: + QThread m_hashAndSaveThread; +}; + +class Download : public QObject +{ + Q_OBJECT + Q_PROPERTY(QList modelList READ modelList NOTIFY modelListChanged) + Q_PROPERTY(bool hasNewerRelease READ hasNewerRelease NOTIFY hasNewerReleaseChanged) + Q_PROPERTY(ReleaseInfo releaseInfo READ releaseInfo NOTIFY releaseInfoChanged) + Q_PROPERTY(QString downloadLocalModelsPath READ downloadLocalModelsPath + WRITE setDownloadLocalModelsPath + NOTIFY downloadLocalModelsPathChanged) + +public: + static Download *globalInstance(); + + QList modelList() const; + ReleaseInfo releaseInfo() const; + bool hasNewerRelease() const; + Q_INVOKABLE void updateModelList(); + Q_INVOKABLE void updateReleaseNotes(); + Q_INVOKABLE void downloadModel(const QString &modelFile); + Q_INVOKABLE void 
cancelDownload(const QString &modelFile); + Q_INVOKABLE QString defaultLocalModelsPath() const; + Q_INVOKABLE QString downloadLocalModelsPath() const; + Q_INVOKABLE void setDownloadLocalModelsPath(const QString &modelPath); + Q_INVOKABLE bool isFirstStart() const; + +private Q_SLOTS: + void handleSslErrors(QNetworkReply *reply, const QList &errors); + void handleModelsJsonDownloadFinished(); + void handleReleaseJsonDownloadFinished(); + void handleErrorOccurred(QNetworkReply::NetworkError code); + void handleDownloadProgress(qint64 bytesReceived, qint64 bytesTotal); + void handleModelDownloadFinished(); + void handleHashAndSaveFinished(bool success, + QFile *tempFile, QNetworkReply *modelReply); + void handleReadyRead(); + +Q_SIGNALS: + void downloadProgress(qint64 bytesReceived, qint64 bytesTotal, const QString &modelFile); + void downloadFinished(const QString &modelFile); + void modelListChanged(); + void releaseInfoChanged(); + void hasNewerReleaseChanged(); + void downloadLocalModelsPathChanged(); + void requestHashAndSave(const QString &hash, const QString &saveFilePath, + QFile *tempFile, QNetworkReply *modelReply); + +private: + void parseModelsJsonFile(const QByteArray &jsonData); + void parseReleaseJsonFile(const QByteArray &jsonData); + QString incompleteDownloadPath(const QString &modelFile); + + HashAndSaveFile *m_hashAndSave; + QMap m_modelMap; + QMap m_releaseMap; + QNetworkAccessManager m_networkManager; + QMap m_activeDownloads; + QString m_downloadLocalModelsPath; + QDateTime m_startTime; + +private: + explicit Download(); + ~Download() {} + friend class MyDownload; +}; + +#endif // DOWNLOAD_H diff --git a/gpt4all-chat/icons/copy.svg b/gpt4all-chat/icons/copy.svg new file mode 100644 index 00000000..5ab45b5b --- /dev/null +++ b/gpt4all-chat/icons/copy.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/gpt4all-chat/icons/edit.svg b/gpt4all-chat/icons/edit.svg new file mode 100644 index 00000000..9820173b --- /dev/null +++ 
b/gpt4all-chat/icons/edit.svg @@ -0,0 +1,5 @@ + + diff --git a/gpt4all-chat/icons/favicon.icns b/gpt4all-chat/icons/favicon.icns new file mode 100644 index 00000000..38638c7a Binary files /dev/null and b/gpt4all-chat/icons/favicon.icns differ diff --git a/gpt4all-chat/icons/favicon.ico b/gpt4all-chat/icons/favicon.ico new file mode 100644 index 00000000..d7c27c4b Binary files /dev/null and b/gpt4all-chat/icons/favicon.ico differ diff --git a/gpt4all-chat/icons/logo-1024.png b/gpt4all-chat/icons/logo-1024.png new file mode 100644 index 00000000..2fb7cbdc Binary files /dev/null and b/gpt4all-chat/icons/logo-1024.png differ diff --git a/gpt4all-chat/icons/logo-128.png b/gpt4all-chat/icons/logo-128.png new file mode 100644 index 00000000..81c52374 Binary files /dev/null and b/gpt4all-chat/icons/logo-128.png differ diff --git a/gpt4all-chat/icons/logo-16.png b/gpt4all-chat/icons/logo-16.png new file mode 100644 index 00000000..344ac9b1 Binary files /dev/null and b/gpt4all-chat/icons/logo-16.png differ diff --git a/gpt4all-chat/icons/logo-256.png b/gpt4all-chat/icons/logo-256.png new file mode 100644 index 00000000..291f3003 Binary files /dev/null and b/gpt4all-chat/icons/logo-256.png differ diff --git a/gpt4all-chat/icons/logo-32.png b/gpt4all-chat/icons/logo-32.png new file mode 100644 index 00000000..06628744 Binary files /dev/null and b/gpt4all-chat/icons/logo-32.png differ diff --git a/gpt4all-chat/icons/logo-48.png b/gpt4all-chat/icons/logo-48.png new file mode 100644 index 00000000..4d53f9f7 Binary files /dev/null and b/gpt4all-chat/icons/logo-48.png differ diff --git a/gpt4all-chat/icons/logo-512.png b/gpt4all-chat/icons/logo-512.png new file mode 100644 index 00000000..4f70fbb4 Binary files /dev/null and b/gpt4all-chat/icons/logo-512.png differ diff --git a/gpt4all-chat/icons/logo-64.png b/gpt4all-chat/icons/logo-64.png new file mode 100644 index 00000000..fbffe619 Binary files /dev/null and b/gpt4all-chat/icons/logo-64.png differ diff --git 
a/gpt4all-chat/icons/logo.svg b/gpt4all-chat/icons/logo.svg new file mode 100644 index 00000000..e7084ec0 --- /dev/null +++ b/gpt4all-chat/icons/logo.svg @@ -0,0 +1,14 @@ + + + + + + + GPT + 4All + diff --git a/gpt4all-chat/icons/network.svg b/gpt4all-chat/icons/network.svg new file mode 100644 index 00000000..266f13d6 --- /dev/null +++ b/gpt4all-chat/icons/network.svg @@ -0,0 +1 @@ + diff --git a/gpt4all-chat/icons/regenerate.svg b/gpt4all-chat/icons/regenerate.svg new file mode 100644 index 00000000..016e6a52 --- /dev/null +++ b/gpt4all-chat/icons/regenerate.svg @@ -0,0 +1 @@ + diff --git a/gpt4all-chat/icons/send_message.svg b/gpt4all-chat/icons/send_message.svg new file mode 100644 index 00000000..d8650b66 --- /dev/null +++ b/gpt4all-chat/icons/send_message.svg @@ -0,0 +1 @@ + diff --git a/gpt4all-chat/icons/settings.svg b/gpt4all-chat/icons/settings.svg new file mode 100644 index 00000000..7542ea62 --- /dev/null +++ b/gpt4all-chat/icons/settings.svg @@ -0,0 +1,46 @@ + + + + + + + + diff --git a/gpt4all-chat/icons/stop_generating.svg b/gpt4all-chat/icons/stop_generating.svg new file mode 100644 index 00000000..c627ac0e --- /dev/null +++ b/gpt4all-chat/icons/stop_generating.svg @@ -0,0 +1 @@ + diff --git a/gpt4all-chat/icons/thumbs_down.svg b/gpt4all-chat/icons/thumbs_down.svg new file mode 100644 index 00000000..b01a82d3 --- /dev/null +++ b/gpt4all-chat/icons/thumbs_down.svg @@ -0,0 +1,5 @@ + + diff --git a/gpt4all-chat/icons/thumbs_up.svg b/gpt4all-chat/icons/thumbs_up.svg new file mode 100644 index 00000000..cd5efcd2 --- /dev/null +++ b/gpt4all-chat/icons/thumbs_up.svg @@ -0,0 +1,5 @@ + + diff --git a/gpt4all-chat/icons/trash.svg b/gpt4all-chat/icons/trash.svg new file mode 100644 index 00000000..b7c1a141 --- /dev/null +++ b/gpt4all-chat/icons/trash.svg @@ -0,0 +1,5 @@ + + diff --git a/gpt4all-chat/llm.cpp b/gpt4all-chat/llm.cpp new file mode 100644 index 00000000..e94c461b --- /dev/null +++ b/gpt4all-chat/llm.cpp @@ -0,0 +1,79 @@ +#include "llm.h" +#include 
"config.h" +#include "download.h" +#include "network.h" + +#include +#include +#include +#include +#include +#include +#include + +class MyLLM: public LLM { }; +Q_GLOBAL_STATIC(MyLLM, llmInstance) +LLM *LLM::globalInstance() +{ + return llmInstance(); +} + +LLM::LLM() + : QObject{nullptr} + , m_chatListModel(new ChatListModel(this)) + , m_threadCount(std::min(4, (int32_t) std::thread::hardware_concurrency())) + , m_compatHardware(true) +{ + connect(QCoreApplication::instance(), &QCoreApplication::aboutToQuit, + this, &LLM::aboutToQuit); + +#if defined(__x86_64__) || defined(__i386__) + if (QString(GPT4ALL_AVX_ONLY) == "OFF") { + const bool avx(__builtin_cpu_supports("avx")); + const bool avx2(__builtin_cpu_supports("avx2")); + const bool fma(__builtin_cpu_supports("fma")); + m_compatHardware = avx && avx2 && fma; + emit compatHardwareChanged(); + } +#endif +} + +bool LLM::checkForUpdates() const +{ + Network::globalInstance()->sendCheckForUpdates(); + +#if defined(Q_OS_LINUX) + QString tool("maintenancetool"); +#elif defined(Q_OS_WINDOWS) + QString tool("maintenancetool.exe"); +#elif defined(Q_OS_DARWIN) + QString tool("../../../maintenancetool.app/Contents/MacOS/maintenancetool"); +#endif + + QString fileName = QCoreApplication::applicationDirPath() + + "/../" + tool; + if (!QFileInfo::exists(fileName)) { + qDebug() << "Couldn't find tool at" << fileName << "so cannot check for updates!"; + return false; + } + + return QProcess::startDetached(fileName); +} + +int32_t LLM::threadCount() const +{ + return m_threadCount; +} + +void LLM::setThreadCount(int32_t n_threads) +{ + if (n_threads <= 0) + n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + m_threadCount = n_threads; + emit threadCountChanged(); +} + +void LLM::aboutToQuit() +{ + m_chatListModel->saveChats(); +} diff --git a/gpt4all-chat/llm.h b/gpt4all-chat/llm.h new file mode 100644 index 00000000..ac12981d --- /dev/null +++ b/gpt4all-chat/llm.h @@ -0,0 +1,44 @@ +#ifndef LLM_H +#define 
LLM_H + +#include + +#include "chatlistmodel.h" + +class LLM : public QObject +{ + Q_OBJECT + Q_PROPERTY(ChatListModel *chatListModel READ chatListModel NOTIFY chatListModelChanged) + Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) + Q_PROPERTY(bool compatHardware READ compatHardware NOTIFY compatHardwareChanged) + +public: + static LLM *globalInstance(); + + ChatListModel *chatListModel() const { return m_chatListModel; } + int32_t threadCount() const; + void setThreadCount(int32_t n_threads); + bool compatHardware() const { return m_compatHardware; } + + Q_INVOKABLE bool checkForUpdates() const; + +Q_SIGNALS: + void chatListModelChanged(); + void threadCountChanged(); + void compatHardwareChanged(); + +private Q_SLOTS: + void aboutToQuit(); + +private: + ChatListModel *m_chatListModel; + int32_t m_threadCount; + bool m_compatHardware; + +private: + explicit LLM(); + ~LLM() {} + friend class MyLLM; +}; + +#endif // LLM_H diff --git a/gpt4all-chat/llmodel/CMakeLists.txt b/gpt4all-chat/llmodel/CMakeLists.txt new file mode 100644 index 00000000..704faccc --- /dev/null +++ b/gpt4all-chat/llmodel/CMakeLists.txt @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.16) + +if(APPLE) + option(BUILD_UNIVERSAL "Build a Universal binary on macOS" ON) + if(BUILD_UNIVERSAL) + # Build a Universal binary on macOS + # This requires that the found Qt library is compiled as Universal binaries. 
+ set(CMAKE_OSX_ARCHITECTURES "arm64;x86_64" CACHE STRING "" FORCE) + else() + # Build for the host architecture on macOS + set(CMAKE_OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}" CACHE STRING "" FORCE) + endif() +endif() + +# Include the binary directory for the generated header file +include_directories("${CMAKE_CURRENT_BINARY_DIR}") + +project(llmodel VERSION ${APP_VERSION} LANGUAGES CXX C) + +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(LLAMA_BUILD_EXAMPLES ON CACHE BOOL "llama: build examples" FORCE) +set(BUILD_SHARED_LIBS ON FORCE) + +set(CMAKE_VERBOSE_MAKEFILE ON) +if (GPT4ALL_AVX_ONLY) + set(LLAMA_AVX2 OFF CACHE BOOL "llama: enable AVX2" FORCE) + set(LLAMA_F16C OFF CACHE BOOL "llama: enable F16C" FORCE) + set(LLAMA_FMA OFF CACHE BOOL "llama: enable FMA" FORCE) +endif() + +add_subdirectory(llama.cpp) + +add_library(llmodel + gptj.h gptj.cpp + llamamodel.h llamamodel.cpp + llama.cpp/examples/common.cpp + llmodel.h llmodel_c.h llmodel_c.cpp + mpt.h mpt.cpp + utils.h utils.cpp +) + +target_link_libraries(llmodel + PRIVATE llama) + +set(COMPONENT_NAME_MAIN ${PROJECT_NAME}) +set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/install) diff --git a/gpt4all-chat/llmodel/gptj.cpp b/gpt4all-chat/llmodel/gptj.cpp new file mode 100644 index 00000000..837fd1f8 --- /dev/null +++ b/gpt4all-chat/llmodel/gptj.cpp @@ -0,0 +1,1102 @@ +#include "gptj.h" +#include "llama.cpp/ggml.h" + +#include "utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// default hparams (GPT-J 6B) +static const size_t MB = 1024*1024; + +struct gptj_hparams { + int32_t n_vocab = 50400; + int32_t n_ctx = 2048; + int32_t n_embd = 4096; + int32_t n_head = 16; + int32_t n_layer = 28; + int32_t n_rot = 64; + int32_t f16 = 1; +}; + +struct gptj_layer { + // normalization + struct ggml_tensor * ln_1_g; + struct ggml_tensor * ln_1_b; + + // attention + struct ggml_tensor * c_attn_q_proj_w; + struct ggml_tensor * c_attn_k_proj_w; + 
struct ggml_tensor * c_attn_v_proj_w; + + struct ggml_tensor * c_attn_proj_w; + + // ff + struct ggml_tensor * c_mlp_fc_w; + struct ggml_tensor * c_mlp_fc_b; + + struct ggml_tensor * c_mlp_proj_w; + struct ggml_tensor * c_mlp_proj_b; +}; + +struct gptj_buffer { + uint8_t * addr = NULL; + size_t size = 0; + + void resize(size_t size) { + delete[] addr; + addr = new uint8_t[size]; + this->size = size; + } + + ~gptj_buffer() { + fflush(stdout); + delete[] addr; + } +}; + +struct gptj_kv_cache { + struct ggml_tensor * k; + struct ggml_tensor * v; + + struct ggml_context * ctx = NULL; + + gptj_buffer buf; + + int n; // number of tokens currently in the cache + + ~gptj_kv_cache() { + if (ctx) { + ggml_free(ctx); + } + } +}; + +struct gptj_model { + gptj_hparams hparams; + + // normalization + struct ggml_tensor * ln_f_g; + struct ggml_tensor * ln_f_b; + + struct ggml_tensor * wte; // position embedding + + struct ggml_tensor * lmh_g; // language model head + struct ggml_tensor * lmh_b; // language model bias + + std::vector layers; + + // key + value memory + struct gptj_kv_cache kv_self; + + // + struct ggml_context * ctx; + std::map tensors; + + gptj_buffer buf; + + ~gptj_model() { + if (ctx) { + ggml_free(ctx); + } + } +}; + +static bool kv_cache_init( + const struct gptj_hparams & hparams, + struct gptj_kv_cache & cache, + ggml_type wtype, + int n_ctx) { + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + + const int64_t n_mem = (int64_t)n_layer*n_ctx; + const int64_t n_elements = n_embd*n_mem; + + cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); + + struct ggml_init_params params; + params.mem_size = cache.buf.size; + params.mem_buffer = cache.buf.addr; + params.no_alloc = false; + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = 
ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +// load the model's weights from a stream +bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & model, gpt_vocab & vocab) { + printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); + + // verify magic + { + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); + return false; + } + } + + // load hparams + { + auto & hparams = model.hparams; + + fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx)); + fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd)); + fin.read((char *) &hparams.n_head, sizeof(hparams.n_head)); + fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); + fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot)); + fin.read((char *) &hparams.f16, sizeof(hparams.f16)); + + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); + printf("%s: n_head = %d\n", __func__, hparams.n_head); + printf("%s: n_layer = %d\n", __func__, hparams.n_layer); + printf("%s: n_rot = %d\n", __func__, hparams.n_rot); + printf("%s: f16 = %d\n", __func__, hparams.f16); + } + + // load vocab + { + int32_t n_vocab = 0; + fin.read((char *) &n_vocab, sizeof(n_vocab)); + + if (n_vocab != model.hparams.n_vocab) { + fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", + __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); + return false; + } + + std::string word; + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + fin.read((char *) &len, sizeof(len)); + + word.resize(len); + fin.read((char *) word.data(), len); + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + + // for the big tensors, we 
have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + ggml_type wtype = GGML_TYPE_COUNT; + switch (model.hparams.f16) { + case 0: wtype = GGML_TYPE_F32; break; + case 1: wtype = GGML_TYPE_F16; break; + case 2: wtype = GGML_TYPE_Q4_0; break; + case 3: wtype = GGML_TYPE_Q4_1; break; + case 5: wtype = GGML_TYPE_Q4_2; break; + default: + { + fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n", + __func__, fname.c_str(), model.hparams.f16); + return false; + } + } + + const ggml_type wtype2 = GGML_TYPE_F32; + + auto & ctx = model.ctx; + + size_t ctx_size = 0; + + { + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_vocab = hparams.n_vocab; + + ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g + ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b + + ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // wte + + ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // lmh_g + ctx_size += n_vocab*ggml_type_sizef(GGML_TYPE_F32); // lmh_b + + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b + + ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_q_proj_w + ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_k_proj_w + ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_v_proj_w + + ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w + + ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w + ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b + + ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w + ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b + + ctx_size += 
n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v + + ctx_size += (5 + 10*n_layer)*256; // object overhead + + printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + } + + // create the ggml context + { + struct ggml_init_params params = { + .mem_size = ctx_size, + .mem_buffer = NULL, + }; + + model.ctx = ggml_init(params); + if (!model.ctx) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + + // prepare memory for the weights + { + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_vocab = hparams.n_vocab; + + model.layers.resize(n_layer); + + model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + + model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + model.lmh_g = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + model.lmh_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab); + + // map by name + model.tensors["transformer.wte.weight"] = model.wte; + + model.tensors["transformer.ln_f.weight"] = model.ln_f_g; + model.tensors["transformer.ln_f.bias"] = model.ln_f_b; + + model.tensors["lm_head.weight"] = model.lmh_g; + model.tensors["lm_head.bias"] = model.lmh_b; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + layer.c_attn_q_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_k_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_v_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + + layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + + layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 
4*n_embd); + layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd); + + layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); + layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + // map by name + model.tensors["transformer.h." + std::to_string(i) + ".ln_1.weight"] = layer.ln_1_g; + model.tensors["transformer.h." + std::to_string(i) + ".ln_1.bias"] = layer.ln_1_b; + + model.tensors["transformer.h." + std::to_string(i) + ".attn.q_proj.weight"] = layer.c_attn_q_proj_w; + model.tensors["transformer.h." + std::to_string(i) + ".attn.k_proj.weight"] = layer.c_attn_k_proj_w; + model.tensors["transformer.h." + std::to_string(i) + ".attn.v_proj.weight"] = layer.c_attn_v_proj_w; + + model.tensors["transformer.h." + std::to_string(i) + ".attn.out_proj.weight"] = layer.c_attn_proj_w; + + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.weight"] = layer.c_mlp_fc_w; + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.bias"] = layer.c_mlp_fc_b; + + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.weight"] = layer.c_mlp_proj_w; + model.tensors["transformer.h." 
+ std::to_string(i) + ".mlp.fc_out.bias"] = layer.c_mlp_proj_b; + } + } + + // key + value memory + { + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + + const int n_mem = n_layer*n_ctx; + const int n_elements = n_embd*n_mem; + + if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); + ggml_free(ctx); + return false; + } + + const size_t memory_size = ggml_nbytes(model.kv_self.k) + ggml_nbytes(model.kv_self.v); + printf("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + + // load weights + { + int n_tensors = 0; + size_t total_size = 0; + + printf("%s: ", __func__); + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ftype), sizeof(ftype)); + + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + fin.read(&name[0], length); + + if (model.tensors.find(name.data()) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + if (ggml_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%lu, %lu], expected [%d, %d]\n", + __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]); + return false; + } + + if (0) { + static const char * 
ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor)); + } + + size_t bpe = 0; + + switch (ftype) { + case 0: bpe = ggml_type_size(GGML_TYPE_F32); break; + case 1: bpe = ggml_type_size(GGML_TYPE_F16); break; + case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break; + case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break; + default: + { + fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); + return false; + } + }; + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); + + //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? 
"float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); + total_size += ggml_nbytes(tensor); + if (++n_tensors % 8 == 0) { + printf("."); + fflush(stdout); + } + } + + printf(" done\n"); + + printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors); + } + + return true; +} + +// load the model's weights from a file path +bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & vocab) { + + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); + return false; + } + + bool loaded = gptj_model_load(fname, fin, model, vocab); + fin.close(); + return loaded; +} + +// evaluate the transformer +// +// - model: the model +// - n_threads: number of threads to use +// - n_past: the context size so far +// - embd_inp: the embeddings of the tokens in the context +// - embd_w: the predicted logits for the next token +// +// The GPT-J model requires about 16MB of memory per input token. 
+// +bool gptj_eval( + gptj_model & model, + const int n_threads, + const int n_past, + const std::vector & embd_inp, + std::vector & embd_w, + size_t & mem_per_token) { + const int N = embd_inp.size(); + + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_head = hparams.n_head; + const int n_vocab = hparams.n_vocab; + const int n_rot = hparams.n_rot; + + const int d_key = n_embd/n_head; + + static size_t buf_size = 1024u*MB; + if (!model.buf.addr || model.buf.size < buf_size) + model.buf.resize(buf_size); + + if (mem_per_token > 0 && mem_per_token*N > model.buf.size) { + const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead + printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, model.buf.size, buf_size_new); + + // reallocate + model.buf.resize(buf_size_new); + if (model.buf.addr == nullptr) { + fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.buf.size); + return false; + } + } + + struct ggml_init_params params = { + .mem_size = model.buf.size, + .mem_buffer = model.buf.addr, + }; + + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph gf = { .n_threads = n_threads }; + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd)); + + // wte + struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte, embd); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * cur; + + // norm + { + cur = ggml_norm(ctx0, inpL); + + // cur = ln_1_g*cur + ln_1_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].ln_1_g, cur), + cur), + ggml_repeat(ctx0, model.layers[il].ln_1_b, cur)); + } + + struct ggml_tensor * inpSA = cur; + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur); + struct 
ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur); + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur); + + // store key and value to memory + { + struct ggml_tensor * k = ggml_view_1d(ctx0, model.kv_self.k, N*n_embd, (ggml_element_size(model.kv_self.k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctx0, model.kv_self.v, N*n_embd, (ggml_element_size(model.kv_self.v)*n_embd)*(il*n_ctx + n_past)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + } + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_rope(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)), + n_past, n_rot, 0), + 0, 2, 1, 3); + + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_rope(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd), + n_embd/n_head, n_head, n_past + N), + n_past, n_rot, 1), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ggml_tensor * KQ_scaled = + ggml_scale(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)) + ); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + struct ggml_tensor * V_trans = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.v)*n_embd), + 
n_embd/n_head, n_head, n_past + N), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, model.kv_self.v->type, n_past + N, n_embd/n_head, n_head)); + + // KQV = transpose(V) * KQ_soft_max + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + // projection (no bias) + cur = ggml_mul_mat(ctx0, + model.layers[il].c_attn_proj_w, + cur); + } + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + // this is independent of the self-attention result, so it could be done in parallel to the self-attention + { + // note here we pass inpSA instead of cur + cur = ggml_mul_mat(ctx0, + model.layers[il].c_mlp_fc_w, + inpSA); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur), + cur); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + // cur = proj_w*cur + proj_b + cur = ggml_mul_mat(ctx0, + model.layers[il].c_mlp_proj_w, + cur); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur), + cur); + } + + // self-attention + FF + cur = ggml_add(ctx0, cur, inpFF); + + // input for next layer + inpL = ggml_add(ctx0, cur, inpL); + } + + // norm + { + inpL = ggml_norm(ctx0, inpL); + + // inpL = ln_f_g*inpL + ln_f_b + inpL = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.ln_f_g, inpL), + inpL), + ggml_repeat(ctx0, model.ln_f_b, inpL)); + } + + // lm_head + { + inpL = ggml_mul_mat(ctx0, model.lmh_g, inpL); + + inpL = ggml_add(ctx0, + ggml_repeat(ctx0, model.lmh_b, inpL), + inpL); + } + + // logits -> probs + //inpL = ggml_soft_max(ctx0, inpL); + + // run the computation + ggml_build_forward_expand(&gf, inpL); + ggml_graph_compute (ctx0, &gf); + + //if (n_past%100 == 0) { + // ggml_graph_print (&gf); + // 
ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot"); + //} + + //embd_w.resize(n_vocab*N); + //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); + + // return result for just the last token + embd_w.resize(n_vocab); + memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + + if (mem_per_token == 0) { + mem_per_token = ggml_used_mem(ctx0)/N; + } + //printf("used_mem = %zu\n", ggml_used_mem(ctx0)); + + ggml_free(ctx0); + + return true; +} + +#define GPTJ_MAX_RNG_STATE 64*1024 + +size_t gptj_get_state_size(const gptj_model &model) +{ + // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. + // for reference, std::mt19937(1337) serializes to 6701 bytes. + const size_t s_rng_size = sizeof(size_t); + const size_t s_rng = GPTJ_MAX_RNG_STATE; + const size_t s_kv_size = sizeof(size_t); + const size_t s_kv_ntok = sizeof(int); + const size_t s_kv = model.kv_self.buf.size; + const size_t s_total = ( + + s_rng_size + + s_rng + + s_kv_size + + s_kv_ntok + + s_kv + ); + fflush(stdout); + return s_total; +} + +size_t gptj_copy_state_data(const gptj_model &model, const std::mt19937 &rng, uint8_t *dest) +{ + uint8_t * out = dest; + fflush(stdout); + // copy rng + { + std::stringstream rng_ss; + rng_ss << rng; + + const size_t rng_size = rng_ss.str().size(); + char rng_buf[GPTJ_MAX_RNG_STATE]; + + memset(&rng_buf[0], 0, GPTJ_MAX_RNG_STATE); + memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); + + memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size); + memcpy(out, &rng_buf[0], GPTJ_MAX_RNG_STATE); out += GPTJ_MAX_RNG_STATE; + } + + // copy kv cache + { + const size_t kv_size = model.kv_self.buf.size; + const int kv_ntok = model.kv_self.n; + + memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size); + memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok); + + if (kv_size) { + memcpy(out, model.kv_self.buf.addr, kv_size); out += 
kv_size; + } + } + + const size_t written = out - dest; + const size_t expected = gptj_get_state_size(model); + assert(written == expected); + fflush(stdout); + return written; +} + +size_t gptj_set_state_data(gptj_model *model, std::mt19937 *rng, const uint8_t *src) +{ + const uint8_t * in = src; + + // set rng + { + size_t rng_size; + char rng_buf[GPTJ_MAX_RNG_STATE]; + + memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size); + memcpy(&rng_buf[0], in, GPTJ_MAX_RNG_STATE); in += GPTJ_MAX_RNG_STATE; + + std::stringstream rng_ss; + rng_ss.str(std::string(&rng_buf[0], rng_size)); + rng_ss >> *rng; + + assert(rng_ss.fail() == false); + } + + // set kv cache + { + size_t kv_size; + int kv_ntok; + + memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size); + memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok); + + if (kv_size) { + assert(model->kv_self.buf.size == kv_size); + + void * k_data = model->kv_self.k->data; // remember data pointers + void * v_data = model->kv_self.v->data; // because their value is stored in buf and overwritten by memcpy + + memcpy(model->kv_self.buf.addr, in, kv_size); in += kv_size; + + model->kv_self.k->data = k_data; // restore correct data pointers + model->kv_self.v->data = v_data; + + } + + model->kv_self.n = kv_ntok; + } + + const size_t nread = in - src; + const size_t expected = gptj_get_state_size(*model); + assert(nread == expected); + fflush(stdout); + return nread; +} + +struct GPTJPrivate { + const std::string modelPath; + bool modelLoaded; + gpt_vocab vocab; + gptj_model *model = nullptr; + int64_t n_threads = 0; + size_t mem_per_token = 0; + std::mt19937 rng; +}; + +GPTJ::GPTJ() + : d_ptr(new GPTJPrivate) { + + d_ptr->model = new gptj_model; + d_ptr->modelLoaded = false; +} + +bool GPTJ::loadModel(const std::string &modelPath) { + std::mt19937 rng(time(NULL)); + d_ptr->rng = rng; + + auto fin = std::ifstream(modelPath, std::ios::binary); + + // load the model + if (!gptj_model_load(modelPath, fin, 
*d_ptr->model, d_ptr->vocab)) { + std::cerr << "GPT-J ERROR: failed to load model from " << modelPath; + return false; + } + + d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + d_ptr->modelLoaded = true; + fflush(stdout); + return true; +} + +void GPTJ::setThreadCount(int32_t n_threads) { + d_ptr->n_threads = n_threads; +} + +int32_t GPTJ::threadCount() { + return d_ptr->n_threads; +} + +GPTJ::~GPTJ() +{ + delete d_ptr->model; +} + +bool GPTJ::isModelLoaded() const +{ + return d_ptr->modelLoaded; +} + +size_t GPTJ::stateSize() const +{ + return gptj_get_state_size(*d_ptr->model); +} + +size_t GPTJ::saveState(uint8_t *dest) const +{ + return gptj_copy_state_data(*d_ptr->model, d_ptr->rng, dest); +} + +size_t GPTJ::restoreState(const uint8_t *src) +{ + return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src); +} + +void GPTJ::prompt(const std::string &prompt, + std::function promptCallback, + std::function responseCallback, + std::function recalculateCallback, + PromptContext &promptCtx) { + + if (!isModelLoaded()) { + std::cerr << "GPT-J ERROR: prompt won't work with an unloaded model!\n"; + return; + } + + const int64_t t_main_start_us = ggml_time_us(); + + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + int64_t t_prompt_us = 0; + + // tokenize the prompt + std::vector embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt); + + // save the context size + promptCtx.n_ctx = d_ptr->model->hparams.n_ctx; + + if ((int) embd_inp.size() > promptCtx.n_ctx - 4) { + responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed."); + std::cerr << "GPT-J ERROR: The prompt is" << embd_inp.size() << + "tokens and the context window is" << promptCtx.n_ctx << "!\n"; + return; + } + + promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size()); + promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx); + + // determine the required inference memory per token: + static 
bool initialized = false; + static std::vector p_instruct; + static std::vector r_instruct; + if (!initialized) { + gptj_eval(*d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits, + d_ptr->mem_per_token); + initialized = true; + } + + // process the prompt in batches + size_t i = 0; + const int64_t t_start_prompt_us = ggml_time_us(); + while (i < embd_inp.size()) { + size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); + std::vector batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); + + // Check if the context has run out... + if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) { + const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; + // Erase the first percentage of context from the tokens... + std::cerr << "GPTJ: reached the end of the context window so resizing\n"; + promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); + promptCtx.n_past = promptCtx.tokens.size(); + recalculateContext(promptCtx, recalculateCallback); + assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + } + + if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, + d_ptr->mem_per_token)) { + std::cerr << "GPT-J ERROR: Failed to process prompt\n"; + return; + } + + size_t tokens = batch_end - i; + for (size_t t = 0; t < tokens; ++t) { + if (promptCtx.tokens.size() == promptCtx.n_ctx) + promptCtx.tokens.erase(promptCtx.tokens.begin()); + promptCtx.tokens.push_back(batch.at(t)); + if (!promptCallback(batch.at(t))) + return; + } + promptCtx.n_past += batch.size(); + i = batch_end; + } + t_prompt_us += ggml_time_us() - t_start_prompt_us; + + int p_instructFound = 0; + int r_instructFound = 0; + + std::string cachedResponse; + std::vector cachedTokens; + std::unordered_set reversePrompts + = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" }; + + // predict next tokens + int32_t totalPredictions = 0; + for (int i = 0; i < 
promptCtx.n_predict; i++) { + + // sample next token + const int n_vocab = d_ptr->model->hparams.n_vocab; + gpt_vocab::id id = 0; + { + const int64_t t_start_sample_us = ggml_time_us(); + id = gpt_sample_top_k_top_p(d_ptr->vocab, + promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx, + promptCtx.n_ctx, + promptCtx.logits, + promptCtx.top_k, promptCtx.top_p, promptCtx.temp, + promptCtx.repeat_penalty, + d_ptr->rng); + + t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // Check if the context has run out... + if (promptCtx.n_past + 1 > promptCtx.n_ctx) { + const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; + // Erase the first percentage of context from the tokens... + std::cerr << "GPTJ: reached the end of the context window so resizing\n"; + promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); + promptCtx.n_past = promptCtx.tokens.size(); + recalculateContext(promptCtx, recalculateCallback); + assert(promptCtx.n_past + 1 <= promptCtx.n_ctx); + } + + const int64_t t_start_predict_us = ggml_time_us(); + if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, + d_ptr->mem_per_token)) { + std::cerr << "GPT-J ERROR: Failed to predict next token\n"; + return; + } + t_predict_us += ggml_time_us() - t_start_predict_us; + + promptCtx.n_past += 1; + // display text + ++totalPredictions; + + if (id == 50256 /*end of text*/) + goto stop_generating; + + const std::string str = d_ptr->vocab.id_to_token[id]; + + // Check if the provided str is part of our reverse prompts + bool foundPartialReversePrompt = false; + const std::string completed = cachedResponse + str; + if (reversePrompts.find(completed) != reversePrompts.end()) { + goto stop_generating; + } + + // Check if it partially matches our reverse prompts and if so, cache + for (auto s : reversePrompts) { + if (s.compare(0, completed.size(), completed) == 0) { + foundPartialReversePrompt = true; + cachedResponse = 
completed; + break; + } + } + + // Regardless the token gets added to our cache + cachedTokens.push_back(id); + + // Continue if we have found a partial match + if (foundPartialReversePrompt) + continue; + + // Empty the cache + for (auto t : cachedTokens) { + if (promptCtx.tokens.size() == promptCtx.n_ctx) + promptCtx.tokens.erase(promptCtx.tokens.begin()); + promptCtx.tokens.push_back(t); + if (!responseCallback(t, d_ptr->vocab.id_to_token[t])) + goto stop_generating; + } + cachedTokens.clear(); + } + +stop_generating: + +#if 0 + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + std::cout << "GPT-J INFO: mem per token = " << mem_per_token << " bytes\n"; + std::cout << "GPT-J INFO: sample time = " << t_sample_us/1000.0f << " ms\n"; + std::cout << "GPT-J INFO: prompt time = " << t_prompt_us/1000.0f << " ms\n"; + std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/totalPredictions << " ms per token\n"; + std::cout << "GPT-J INFO: total time = " << (t_main_end_us - t_main_start_us)/1000.0f << " ms\n"; + fflush(stdout); + } +#endif + + return; +} + +void GPTJ::recalculateContext(PromptContext &promptCtx, std::function recalculate) +{ + size_t i = 0; + promptCtx.n_past = 0; + while (i < promptCtx.tokens.size()) { + size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size()); + std::vector batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end); + + assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + + if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, + d_ptr->mem_per_token)) { + std::cerr << "GPTJ ERROR: Failed to process prompt\n"; + goto stop_generating; + } + promptCtx.n_past += batch.size(); + if (!recalculate(true)) + goto stop_generating; + i = batch_end; + } + assert(promptCtx.n_past == promptCtx.tokens.size()); + +stop_generating: + recalculate(false); +} diff --git a/gpt4all-chat/llmodel/gptj.h 
b/gpt4all-chat/llmodel/gptj.h new file mode 100644 index 00000000..3109c1da --- /dev/null +++ b/gpt4all-chat/llmodel/gptj.h @@ -0,0 +1,36 @@ +#ifndef GPTJ_H +#define GPTJ_H + +#include +#include +#include +#include "llmodel.h" + +class GPTJPrivate; +class GPTJ : public LLModel { +public: + GPTJ(); + ~GPTJ(); + + bool loadModel(const std::string &modelPath) override; + bool isModelLoaded() const override; + size_t stateSize() const override; + size_t saveState(uint8_t *dest) const override; + size_t restoreState(const uint8_t *src) override; + void prompt(const std::string &prompt, + std::function promptCallback, + std::function responseCallback, + std::function recalculateCallback, + PromptContext &ctx) override; + void setThreadCount(int32_t n_threads) override; + int32_t threadCount() override; + +protected: + void recalculateContext(PromptContext &promptCtx, + std::function recalculate) override; + +private: + GPTJPrivate *d_ptr; +}; + +#endif // GPTJ_H diff --git a/gpt4all-chat/llmodel/llama.cpp b/gpt4all-chat/llmodel/llama.cpp new file mode 160000 index 00000000..03ceb39c --- /dev/null +++ b/gpt4all-chat/llmodel/llama.cpp @@ -0,0 +1 @@ +Subproject commit 03ceb39c1e729bed4ad1dfa16638a72f1843bf0c diff --git a/gpt4all-chat/llmodel/llamamodel.cpp b/gpt4all-chat/llmodel/llamamodel.cpp new file mode 100644 index 00000000..272633c7 --- /dev/null +++ b/gpt4all-chat/llmodel/llamamodel.cpp @@ -0,0 +1,260 @@ +#include "llamamodel.h" + +#include "llama.cpp/examples/common.h" +#include "llama.cpp/llama.h" +#include "llama.cpp/ggml.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct LLamaPrivate { + const std::string modelPath; + bool modelLoaded; + llama_context *ctx = nullptr; + llama_context_params params; + int64_t n_threads = 0; +}; + +LLamaModel::LLamaModel() + : d_ptr(new LLamaPrivate) { + + d_ptr->modelLoaded = false; +} + +bool LLamaModel::loadModel(const std::string 
&modelPath) +{ + // load the model + d_ptr->params = llama_context_default_params(); + + gpt_params params; + d_ptr->params.n_ctx = 2048; + d_ptr->params.n_parts = params.n_parts; + d_ptr->params.seed = params.seed; + d_ptr->params.f16_kv = params.memory_f16; + d_ptr->params.use_mmap = params.use_mmap; + d_ptr->params.use_mlock = params.use_mlock; + + d_ptr->ctx = llama_init_from_file(modelPath.c_str(), d_ptr->params); + if (!d_ptr->ctx) { + std::cerr << "LLAMA ERROR: failed to load model from " << modelPath << std::endl; + return false; + } + + d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + d_ptr->modelLoaded = true; + fflush(stderr); + return true; +} + +void LLamaModel::setThreadCount(int32_t n_threads) { + d_ptr->n_threads = n_threads; +} + +int32_t LLamaModel::threadCount() { + return d_ptr->n_threads; +} + +LLamaModel::~LLamaModel() +{ + llama_free(d_ptr->ctx); +} + +bool LLamaModel::isModelLoaded() const +{ + return d_ptr->modelLoaded; +} + +size_t LLamaModel::stateSize() const +{ + return llama_get_state_size(d_ptr->ctx); +} + +size_t LLamaModel::saveState(uint8_t *dest) const +{ + return llama_copy_state_data(d_ptr->ctx, dest); +} + +size_t LLamaModel::restoreState(const uint8_t *src) +{ + return llama_set_state_data(d_ptr->ctx, src); +} + +void LLamaModel::prompt(const std::string &prompt, + std::function promptCallback, + std::function responseCallback, + std::function recalculateCallback, + PromptContext &promptCtx) { + + if (!isModelLoaded()) { + std::cerr << "LLAMA ERROR: prompt won't work with an unloaded model!\n"; + return; + } + + gpt_params params; + params.prompt = prompt; + + // Add a space in front of the first character to match OG llama tokenizer behavior + params.prompt.insert(0, 1, ' '); + + // tokenize the prompt + auto embd_inp = ::llama_tokenize(d_ptr->ctx, params.prompt, false); + + // save the context size + promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx); + + if ((int) embd_inp.size() > promptCtx.n_ctx - 
4) { + responseCallback(-1, "The prompt size exceeds the context window size and cannot be processed."); + std::cerr << "LLAMA ERROR: The prompt is" << embd_inp.size() << + "tokens and the context window is" << promptCtx.n_ctx << "!\n"; + return; + } + + promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size()); + promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx); + + // number of tokens to keep when resetting context + params.n_keep = (int)embd_inp.size(); + + // process the prompt in batches + size_t i = 0; + const int64_t t_start_prompt_us = ggml_time_us(); + while (i < embd_inp.size()) { + size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); + std::vector batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); + + // Check if the context has run out... + if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) { + const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; + // Erase the first percentage of context from the tokens... 
+ std::cerr << "LLAMA: reached the end of the context window so resizing\n"; + promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); + promptCtx.n_past = promptCtx.tokens.size(); + recalculateContext(promptCtx, recalculateCallback); + assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + } + + if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) { + std::cerr << "LLAMA ERROR: Failed to process prompt\n"; + return; + } + + size_t tokens = batch_end - i; + for (size_t t = 0; t < tokens; ++t) { + if (promptCtx.tokens.size() == promptCtx.n_ctx) + promptCtx.tokens.erase(promptCtx.tokens.begin()); + promptCtx.tokens.push_back(batch.at(t)); + if (!promptCallback(batch.at(t))) + return; + } + promptCtx.n_past += batch.size(); + i = batch_end; + } + + std::string cachedResponse; + std::vector cachedTokens; + std::unordered_set reversePrompts + = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" }; + + // predict next tokens + int32_t totalPredictions = 0; + for (int i = 0; i < promptCtx.n_predict; i++) { + // sample next token + llama_token id = llama_sample_top_p_top_k(d_ptr->ctx, + promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n, + promptCtx.repeat_last_n, promptCtx.top_k, promptCtx.top_p, promptCtx.temp, + promptCtx.repeat_penalty); + + // Check if the context has run out... + if (promptCtx.n_past + 1 > promptCtx.n_ctx) { + const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; + // Erase the first percentage of context from the tokens... 
+ std::cerr << "LLAMA: reached the end of the context window so resizing\n"; + promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); + promptCtx.n_past = promptCtx.tokens.size(); + recalculateContext(promptCtx, recalculateCallback); + assert(promptCtx.n_past + 1 <= promptCtx.n_ctx); + } + + if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) { + std::cerr << "LLAMA ERROR: Failed to predict next token\n"; + return; + } + + promptCtx.n_past += 1; + // display text + ++totalPredictions; + if (id == llama_token_eos()) + return; + + const std::string str = llama_token_to_str(d_ptr->ctx, id); + + // Check if the provided str is part of our reverse prompts + bool foundPartialReversePrompt = false; + const std::string completed = cachedResponse + str; + if (reversePrompts.find(completed) != reversePrompts.end()) { + return; + } + + // Check if it partially matches our reverse prompts and if so, cache + for (auto s : reversePrompts) { + if (s.compare(0, completed.size(), completed) == 0) { + foundPartialReversePrompt = true; + cachedResponse = completed; + break; + } + } + + // Regardless the token gets added to our cache + cachedTokens.push_back(id); + + // Continue if we have found a partial match + if (foundPartialReversePrompt) + continue; + + // Empty the cache + for (auto t : cachedTokens) { + if (promptCtx.tokens.size() == promptCtx.n_ctx) + promptCtx.tokens.erase(promptCtx.tokens.begin()); + promptCtx.tokens.push_back(t); + if (!responseCallback(t, llama_token_to_str(d_ptr->ctx, t))) + return; + } + cachedTokens.clear(); + } +} + +void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function recalculate) +{ + size_t i = 0; + promptCtx.n_past = 0; + while (i < promptCtx.tokens.size()) { + size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size()); + std::vector batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end); + + assert(promptCtx.n_past + batch.size() <= 
promptCtx.n_ctx); + + if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) { + std::cerr << "LLAMA ERROR: Failed to process prompt\n"; + goto stop_generating; + } + promptCtx.n_past += batch.size(); + if (!recalculate(true)) + goto stop_generating; + i = batch_end; + } + assert(promptCtx.n_past == promptCtx.tokens.size()); + +stop_generating: + recalculate(false); +} diff --git a/gpt4all-chat/llmodel/llamamodel.h b/gpt4all-chat/llmodel/llamamodel.h new file mode 100644 index 00000000..7f487803 --- /dev/null +++ b/gpt4all-chat/llmodel/llamamodel.h @@ -0,0 +1,36 @@ +#ifndef LLAMAMODEL_H +#define LLAMAMODEL_H + +#include +#include +#include +#include "llmodel.h" + +class LLamaPrivate; +class LLamaModel : public LLModel { +public: + LLamaModel(); + ~LLamaModel(); + + bool loadModel(const std::string &modelPath) override; + bool isModelLoaded() const override; + size_t stateSize() const override; + size_t saveState(uint8_t *dest) const override; + size_t restoreState(const uint8_t *src) override; + void prompt(const std::string &prompt, + std::function promptCallback, + std::function responseCallback, + std::function recalculateCallback, + PromptContext &ctx) override; + void setThreadCount(int32_t n_threads) override; + int32_t threadCount() override; + +protected: + void recalculateContext(PromptContext &promptCtx, + std::function recalculate) override; + +private: + LLamaPrivate *d_ptr; +}; + +#endif // LLAMAMODEL_H \ No newline at end of file diff --git a/gpt4all-chat/llmodel/llmodel.h b/gpt4all-chat/llmodel/llmodel.h new file mode 100644 index 00000000..5e254ab5 --- /dev/null +++ b/gpt4all-chat/llmodel/llmodel.h @@ -0,0 +1,47 @@ +#ifndef LLMODEL_H +#define LLMODEL_H + +#include +#include +#include +#include + +class LLModel { +public: + explicit LLModel() {} + virtual ~LLModel() {} + + virtual bool loadModel(const std::string &modelPath) = 0; + virtual bool isModelLoaded() const = 0; + virtual size_t stateSize() const { return 
0; } + virtual size_t saveState(uint8_t *dest) const { return 0; } + virtual size_t restoreState(const uint8_t *src) { return 0; } + struct PromptContext { + std::vector logits; // logits of current context + std::vector tokens; // current tokens in the context window + int32_t n_past = 0; // number of tokens in past conversation + int32_t n_ctx = 0; // number of tokens possible in context window + int32_t n_predict = 200; + int32_t top_k = 40; + float top_p = 0.9f; + float temp = 0.9f; + int32_t n_batch = 9; + float repeat_penalty = 1.10f; + int32_t repeat_last_n = 64; // last n tokens to penalize + float contextErase = 0.75f; // percent of context to erase if we exceed the context + // window + }; + virtual void prompt(const std::string &prompt, + std::function promptCallback, + std::function responseCallback, + std::function recalculateCallback, + PromptContext &ctx) = 0; + virtual void setThreadCount(int32_t n_threads) {} + virtual int32_t threadCount() { return 1; } + +protected: + virtual void recalculateContext(PromptContext &promptCtx, + std::function recalculate) = 0; +}; + +#endif // LLMODEL_H diff --git a/gpt4all-chat/llmodel/llmodel_c.cpp b/gpt4all-chat/llmodel/llmodel_c.cpp new file mode 100644 index 00000000..4361a900 --- /dev/null +++ b/gpt4all-chat/llmodel/llmodel_c.cpp @@ -0,0 +1,161 @@ +#include "llmodel_c.h" + +#include "gptj.h" +#include "llamamodel.h" +#include "mpt.h" + +struct LLModelWrapper { + LLModel *llModel = nullptr; + LLModel::PromptContext promptContext; +}; + +llmodel_model llmodel_gptj_create() +{ + LLModelWrapper *wrapper = new LLModelWrapper; + wrapper->llModel = new GPTJ; + return reinterpret_cast(wrapper); +} + +void llmodel_gptj_destroy(llmodel_model gptj) +{ + LLModelWrapper *wrapper = reinterpret_cast(gptj); + delete wrapper->llModel; + delete wrapper; +} + +llmodel_model llmodel_mpt_create() +{ + LLModelWrapper *wrapper = new LLModelWrapper; + wrapper->llModel = new MPT; + return reinterpret_cast(wrapper); +} + +void 
llmodel_mpt_destroy(llmodel_model mpt) +{ + LLModelWrapper *wrapper = reinterpret_cast(mpt); + delete wrapper->llModel; + delete wrapper; +} + +llmodel_model llmodel_llama_create() +{ + LLModelWrapper *wrapper = new LLModelWrapper; + wrapper->llModel = new LLamaModel; + return reinterpret_cast(wrapper); +} + +void llmodel_llama_destroy(llmodel_model llama) +{ + LLModelWrapper *wrapper = reinterpret_cast(llama); + delete wrapper->llModel; + delete wrapper; +} + +bool llmodel_loadModel(llmodel_model model, const char *model_path) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->loadModel(model_path); +} + +bool llmodel_isModelLoaded(llmodel_model model) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->isModelLoaded(); +} + +uint64_t llmodel_get_state_size(llmodel_model model) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->stateSize(); +} + +uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->saveState(dest); +} + +uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->restoreState(src); +} + +// Wrapper functions for the C callbacks +bool prompt_wrapper(int32_t token_id, void *user_data) { + llmodel_prompt_callback callback = reinterpret_cast(user_data); + return callback(token_id); +} + +bool response_wrapper(int32_t token_id, const std::string &response, void *user_data) { + llmodel_response_callback callback = reinterpret_cast(user_data); + return callback(token_id, response.c_str()); +} + +bool recalculate_wrapper(bool is_recalculating, void *user_data) { + llmodel_recalculate_callback callback = reinterpret_cast(user_data); + return callback(is_recalculating); +} + +void llmodel_prompt(llmodel_model model, const char *prompt, + llmodel_response_callback 
prompt_callback, + llmodel_response_callback response_callback, + llmodel_recalculate_callback recalculate_callback, + llmodel_prompt_context *ctx) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + + // Create std::function wrappers that call the C function pointers + std::function prompt_func = + std::bind(&prompt_wrapper, std::placeholders::_1, reinterpret_cast(prompt_callback)); + std::function response_func = + std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast(response_callback)); + std::function recalc_func = + std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast(recalculate_callback)); + + // Copy the C prompt context + wrapper->promptContext.n_past = ctx->n_past; + wrapper->promptContext.n_ctx = ctx->n_ctx; + wrapper->promptContext.n_predict = ctx->n_predict; + wrapper->promptContext.top_k = ctx->top_k; + wrapper->promptContext.top_p = ctx->top_p; + wrapper->promptContext.temp = ctx->temp; + wrapper->promptContext.n_batch = ctx->n_batch; + wrapper->promptContext.repeat_penalty = ctx->repeat_penalty; + wrapper->promptContext.repeat_last_n = ctx->repeat_last_n; + wrapper->promptContext.contextErase = ctx->context_erase; + + // Call the C++ prompt method + wrapper->llModel->prompt(prompt, prompt_func, response_func, recalc_func, wrapper->promptContext); + + // Update the C context by giving access to the wrappers raw pointers to std::vector data + // which involves no copies + ctx->logits = wrapper->promptContext.logits.data(); + ctx->logits_size = wrapper->promptContext.logits.size(); + ctx->tokens = wrapper->promptContext.tokens.data(); + ctx->tokens_size = wrapper->promptContext.tokens.size(); + + // Update the rest of the C prompt context + ctx->n_past = wrapper->promptContext.n_past; + ctx->n_ctx = wrapper->promptContext.n_ctx; + ctx->n_predict = wrapper->promptContext.n_predict; + ctx->top_k = wrapper->promptContext.top_k; + ctx->top_p = wrapper->promptContext.top_p; + ctx->temp = 
wrapper->promptContext.temp; + ctx->n_batch = wrapper->promptContext.n_batch; + ctx->repeat_penalty = wrapper->promptContext.repeat_penalty; + ctx->repeat_last_n = wrapper->promptContext.repeat_last_n; + ctx->context_erase = wrapper->promptContext.contextErase; +} + +void llmodel_setThreadCount(llmodel_model model, int32_t n_threads) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + wrapper->llModel->setThreadCount(n_threads); +} + +int32_t llmodel_threadCount(llmodel_model model) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->threadCount(); +} diff --git a/gpt4all-chat/llmodel/llmodel_c.h b/gpt4all-chat/llmodel/llmodel_c.h new file mode 100644 index 00000000..f45bdd8d --- /dev/null +++ b/gpt4all-chat/llmodel/llmodel_c.h @@ -0,0 +1,172 @@ +#ifndef LLMODEL_C_H +#define LLMODEL_C_H + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Opaque pointer to the underlying model. + */ +typedef void *llmodel_model; + +/** + * llmodel_prompt_context structure for holding the prompt context. + * NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the + * raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined + * behavior. 
+ */ +typedef struct { + float *logits; // logits of current context + size_t logits_size; // the size of the raw logits vector + int32_t *tokens; // current tokens in the context window + size_t tokens_size; // the size of the raw tokens vector + int32_t n_past; // number of tokens in past conversation + int32_t n_ctx; // number of tokens possible in context window + int32_t n_predict; // number of tokens to predict + int32_t top_k; // top k logits to sample from + float top_p; // nucleus sampling probability threshold + float temp; // temperature to adjust model's output distribution + int32_t n_batch; // number of predictions to generate in parallel + float repeat_penalty; // penalty factor for repeated tokens + int32_t repeat_last_n; // last n tokens to penalize + float context_erase; // percent of context to erase if we exceed the context window +} llmodel_prompt_context; + +/** + * Callback type for prompt processing. + * @param token_id The token id of the prompt. + * @return a bool indicating whether the model should keep processing. + */ +typedef bool (*llmodel_prompt_callback)(int32_t token_id); + +/** + * Callback type for response. + * @param token_id The token id of the response. + * @param response The response string. NOTE: a token_id of -1 indicates the string is an error string. + * @return a bool indicating whether the model should keep generating. + */ +typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response); + +/** + * Callback type for recalculation of context. + * @param whether the model is recalculating the context. + * @return a bool indicating whether the model should keep generating. + */ +typedef bool (*llmodel_recalculate_callback)(bool is_recalculating); + +/** + * Create a GPTJ instance. + * @return A pointer to the GPTJ instance. + */ +llmodel_model llmodel_gptj_create(); + +/** + * Destroy a GPTJ instance. + * @param gptj A pointer to the GPTJ instance. 
+ */ +void llmodel_gptj_destroy(llmodel_model gptj); + +/** + * Create a MPT instance. + * @return A pointer to the MPT instance. + */ +llmodel_model llmodel_mpt_create(); + +/** + * Destroy a MPT instance. + * @param gptj A pointer to the MPT instance. + */ +void llmodel_mpt_destroy(llmodel_model mpt); + +/** + * Create a LLAMA instance. + * @return A pointer to the LLAMA instance. + */ +llmodel_model llmodel_llama_create(); + +/** + * Destroy a LLAMA instance. + * @param llama A pointer to the LLAMA instance. + */ +void llmodel_llama_destroy(llmodel_model llama); + +/** + * Load a model from a file. + * @param model A pointer to the llmodel_model instance. + * @param model_path A string representing the path to the model file. + * @return true if the model was loaded successfully, false otherwise. + */ +bool llmodel_loadModel(llmodel_model model, const char *model_path); + +/** + * Check if a model is loaded. + * @param model A pointer to the llmodel_model instance. + * @return true if the model is loaded, false otherwise. + */ +bool llmodel_isModelLoaded(llmodel_model model); + +/** + * Get the size of the internal state of the model. + * NOTE: This state data is specific to the type of model you have created. + * @param model A pointer to the llmodel_model instance. + * @return the size in bytes of the internal state of the model + */ +uint64_t llmodel_get_state_size(llmodel_model model); + +/** + * Saves the internal state of the model to the specified destination address. + * NOTE: This state data is specific to the type of model you have created. + * @param model A pointer to the llmodel_model instance. + * @param dest A pointer to the destination. + * @return the number of bytes copied + */ +uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest); + +/** + * Restores the internal state of the model using data from the specified address. + * NOTE: This state data is specific to the type of model you have created. 
+ * @param model A pointer to the llmodel_model instance. + * @param src A pointer to the src. + * @return the number of bytes read + */ +uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src); + +/** + * Generate a response using the model. + * @param model A pointer to the llmodel_model instance. + * @param prompt A string representing the input prompt. + * @param prompt_callback A callback function for handling the processing of prompt. + * @param response_callback A callback function for handling the generated response. + * @param recalculate_callback A callback function for handling recalculation requests. + * @param ctx A pointer to the llmodel_prompt_context structure. + */ +void llmodel_prompt(llmodel_model model, const char *prompt, + llmodel_response_callback prompt_callback, + llmodel_response_callback response_callback, + llmodel_recalculate_callback recalculate_callback, + llmodel_prompt_context *ctx); + +/** + * Set the number of threads to be used by the model. + * @param model A pointer to the llmodel_model instance. + * @param n_threads The number of threads to be used. + */ +void llmodel_setThreadCount(llmodel_model model, int32_t n_threads); + +/** + * Get the number of threads currently being used by the model. + * @param model A pointer to the llmodel_model instance. + * @return The number of threads currently being used. 
+ */ +int32_t llmodel_threadCount(llmodel_model model); + +#ifdef __cplusplus +} +#endif + +#endif // LLMODEL_C_H diff --git a/gpt4all-chat/llmodel/mpt.cpp b/gpt4all-chat/llmodel/mpt.cpp new file mode 100644 index 00000000..1a5ce612 --- /dev/null +++ b/gpt4all-chat/llmodel/mpt.cpp @@ -0,0 +1,1240 @@ +#include "mpt.h" +#include "llama.cpp/ggml.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const size_t MB = 1024*1024; + +// default hparams (MPT 7B) +struct mpt_hparams { + int32_t n_vocab = 50432; + int32_t n_ctx = 2048; + int32_t n_embd = 4096; + int32_t n_head = 32; + int32_t n_layer = 32; + float alibi_bias_max = 8; + float clip_qkv = 0; + int32_t expand = 4; + int32_t f16 = 1; +}; + +struct mpt_layer { + // normalization + struct ggml_tensor * norm_1_w; + struct ggml_tensor * norm_2_w; + + // attention + struct ggml_tensor * attn_Wqkv_w; + struct ggml_tensor * attn_out_proj_w; + + // ff + struct ggml_tensor * ffn_up_proj_w; + struct ggml_tensor * ffn_down_proj_w; +}; + +struct mpt_buffer { + uint8_t * addr = NULL; + size_t size = 0; + + void resize(size_t size) { + delete[] addr; + addr = new uint8_t[size]; + this->size = size; + } + + ~mpt_buffer() { + fflush(stdout); + delete[] addr; + } +}; + +struct mpt_kv_cache { + struct ggml_tensor * k; + struct ggml_tensor * v; + + struct ggml_context * ctx = NULL; + + mpt_buffer buf; + + int n; // number of tokens currently in the cache + + ~mpt_kv_cache() { + if (ctx) { + ggml_free(ctx); + } + } +}; + +struct mpt_model { + mpt_hparams hparams; + + // normalization + struct ggml_tensor * norm_f_w; + + struct ggml_tensor * wte; // position embedding + + // mpt does weight tying + + std::vector layers; + + struct mpt_kv_cache kv_self; + struct ggml_context * ctx; + std::map tensors; + + + mpt_buffer buf; + + ~mpt_model() { + if (ctx) { + ggml_free(ctx); + } + } +}; + +static bool kv_cache_init( + const 
struct mpt_hparams & hparams, + struct mpt_kv_cache & cache, + ggml_type wtype, + int n_ctx) { + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + + const int64_t n_mem = (int64_t)n_layer*n_ctx; + const int64_t n_elements = n_embd*n_mem; + + cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); + + struct ggml_init_params params; + params.mem_size = cache.buf.size; + params.mem_buffer = cache.buf.addr; + params.no_alloc = false; + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +struct mpt_vocab { + using id = int32_t; + using token = std::string; + + std::map token_to_id; + std::map id_to_token; + std::vector special_tokens; + + void add_special_token(const std::string &token) { + special_tokens.push_back(token); + } +}; + +std::string regex_escape(const std::string &s) { + static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])"); + return std::regex_replace(s, metacharacters, "\\$&"); +} + +// load the model's weights from a stream +bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab) { + printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); + + // verify magic + { + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + if (magic != 0x67676d6d) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); + return false; + } + } + + // load hparams + { + auto & hparams = model.hparams; + + fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx)); + fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); + fin.read((char *) &hparams.n_head, sizeof(hparams.n_head)); + 
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd)); + fin.read((char *) &hparams.alibi_bias_max, sizeof(hparams.alibi_bias_max)); + fin.read((char *) &hparams.clip_qkv, sizeof(hparams.clip_qkv)); + fin.read((char *) &hparams.f16, sizeof(hparams.f16)); + + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); + printf("%s: n_head = %d\n", __func__, hparams.n_head); + printf("%s: n_layer = %d\n", __func__, hparams.n_layer); + printf("%s: alibi_bias_max = %f\n", __func__, hparams.alibi_bias_max); + printf("%s: clip_qkv = %f\n", __func__, hparams.clip_qkv); + printf("%s: ftype = %d\n", __func__, hparams.f16); + } + + // load vocab + { + int32_t n_vocab = model.hparams.n_vocab; + fin.read((char *) &n_vocab, sizeof(n_vocab)); + + if (n_vocab != model.hparams.n_vocab) { + fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", + __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); + return false; + } + + std::string word; + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + fin.read((char *) &len, sizeof(len)); + bool special = false; + if (len & (1<<31)) { + len = len &~ (1<<31); + special = true; + } + + if (len > 0) { + word.resize(len); + fin.read((char *) word.data(), len); + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + + // TODO: this only kind-of works, the gpt_tokenize can still incorrectly + // tokenize special tokens + if(special) { + vocab.add_special_token(word); + } + } + } + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + ggml_type wtype = GGML_TYPE_COUNT; + switch (model.hparams.f16) { + case 0: wtype = GGML_TYPE_F32; break; + case 1: wtype = GGML_TYPE_F16; break; + case 2: wtype = GGML_TYPE_Q4_0; break; + case 3: wtype = GGML_TYPE_Q4_1; break; + case 5: wtype = GGML_TYPE_Q4_2; 
break; + default: + { + fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n", + __func__, fname.c_str(), model.hparams.f16); + return false; + } + } + + auto & ctx = model.ctx; + + size_t ctx_size = 0; + + { + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_vocab = hparams.n_vocab; + const int expand = hparams.expand; + + + ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_w + + ctx_size += n_embd*n_vocab*ggml_type_sizef(GGML_TYPE_F32); // wte + + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // norm_1_w + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // norm_2_w + + ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // attn_Wqkv_w + ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // attn_out_proj_w + + ctx_size += n_layer*(expand*n_embd*n_embd*ggml_type_sizef(wtype)); // ffn_up_proj_w + ctx_size += n_layer*(expand*n_embd*n_embd*ggml_type_sizef(wtype)); // ffn_down_proj_w + + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v + + // TODO probably less now? 
+ ctx_size += (5 + 10*n_layer)*256; // object overhead + + printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + } + + // create the ggml context + { + struct ggml_init_params params = { + .mem_size = ctx_size, + .mem_buffer = NULL, + .no_alloc = false, + }; + + model.ctx = ggml_init(params); + if (!model.ctx) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + + // prepare memory for the weights + { + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_vocab = hparams.n_vocab; + const int expand = hparams.expand; + + model.layers.resize(n_layer); + + model.wte = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); + model.norm_f_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + // map by name + model.tensors["transformer.wte.weight"] = model.wte; + model.tensors["transformer.norm_f.weight"] = model.norm_f_w; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + layer.norm_1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.norm_2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + layer.attn_Wqkv_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd * 3); + layer.attn_out_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.ffn_up_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, expand*n_embd); + layer.ffn_down_proj_w = ggml_new_tensor_2d(ctx, wtype, expand*n_embd, n_embd); + + // map by name + model.tensors["transformer.blocks." + std::to_string(i) + ".norm_1.weight"] = layer.norm_1_w; + model.tensors["transformer.blocks." + std::to_string(i) + ".norm_2.weight"] = layer.norm_2_w; + model.tensors["transformer.blocks." + std::to_string(i) + ".attn.Wqkv.weight"] = layer.attn_Wqkv_w; + model.tensors["transformer.blocks." + std::to_string(i) + ".attn.out_proj.weight"] = layer.attn_out_proj_w; + + model.tensors["transformer.blocks." 
+ std::to_string(i) + ".ffn.up_proj.weight"] = layer.ffn_up_proj_w; + model.tensors["transformer.blocks." + std::to_string(i) + ".ffn.down_proj.weight"] = layer.ffn_down_proj_w; + } + } + + // key + value memory + { + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + + const int n_mem = n_layer*n_ctx; + const int n_elements = n_embd*n_mem; + + if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); + ggml_free(ctx); + return false; + } + + const size_t memory_size = ggml_nbytes(model.kv_self.k) + ggml_nbytes(model.kv_self.v); + printf("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + + // load weights + { + int n_tensors = 0; + size_t total_size = 0; + + printf("%s: ", __func__); + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ttype), sizeof(ttype)); + + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + fin.read(&name[0], length); + + if (model.tensors.find(name.data()) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + if (ggml_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", + __func__, 
name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]); + return false; + } + + // for debugging + if (0) { + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor)); + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); + + //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ttype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); + total_size += ggml_nbytes(tensor); + if (++n_tensors % 8 == 0) { + printf("."); + fflush(stdout); + } + } + + printf(" done\n"); + + printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors); + } + + return true; +} + +// load the model's weights from a file path +bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) { + + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); + return false; + } + + bool loaded = mpt_model_load(fname, fin, model, vocab); + fin.close(); + return loaded; +} + +bool mpt_eval( + mpt_model & model, + const int n_threads, + const int n_past, + const std::vector & embd_inp, + std::vector & embd_w, + size_t & mem_per_token) { + const int N = embd_inp.size(); + + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_head = hparams.n_head; + const int n_vocab = hparams.n_vocab; + const int expand = hparams.expand; + + const int d_key = n_embd/n_head; + 
+ static size_t buf_size = 256u*1024*1024; + static void * buf = malloc(buf_size); + + if (mem_per_token > 0 && mem_per_token*N > buf_size) { + const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead + //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); + + // reallocate + buf_size = buf_size_new; + buf = realloc(buf, buf_size); + if (buf == nullptr) { + fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); + return false; + } + } + + struct ggml_init_params params = { + .mem_size = buf_size, + .mem_buffer = buf, + .no_alloc = false, + }; + + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph gf = { .n_threads = n_threads }; + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd)); + + // wte + struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte, embd); + + for (int il = 0; il < n_layer; ++il) { + + struct ggml_tensor * inpSA = inpL; + struct ggml_tensor * cur = inpSA; + // self-attention + { + + // norm1 + cur = ggml_norm(ctx0, cur); + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].norm_1_w, cur), + cur); + // compute QKV + cur = ggml_mul_mat(ctx0, + model.layers[il].attn_Wqkv_w, + cur); + + // TODO: clip_qkv + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*ggml_element_size(cur)*n_embd)); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*ggml_element_size(cur)*n_embd)); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*ggml_element_size(cur)*n_embd)); + + // TODO: qk_ln? 
(seems to be False in MPT-7B configs) + { + Vcur = ggml_transpose(ctx0, Vcur); + + struct ggml_tensor * k = ggml_view_1d(ctx0, model.kv_self.k, N*n_embd, (ggml_element_size(model.kv_self.k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_2d(ctx0, model.kv_self.v, N, n_embd, + ( n_ctx)*ggml_element_size(model.kv_self.v), + (il*n_ctx)*ggml_element_size(model.kv_self.v)*n_embd + n_past*ggml_element_size(model.kv_self.v)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + } + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, N), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd), + n_embd/n_head, n_head, n_past + N), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ggml_tensor * KQ_scaled = + ggml_scale(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)) + ); + + + // Alibi + struct ggml_tensor * KQ_scaled_biased = ggml_alibi(ctx0, ggml_cont(ctx0, KQ_scaled), n_past, n_head); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_biased, n_past); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + struct ggml_tensor * V = + ggml_view_3d(ctx0, model.kv_self.v, + n_past + N, n_embd/n_head, n_head, + n_ctx*ggml_element_size(model.kv_self.v), + n_ctx*ggml_element_size(model.kv_self.v)*n_embd/n_head, + il*n_ctx*ggml_element_size(model.kv_self.v)*n_embd); + + // KQV = transpose(V) * KQ_soft_max + struct ggml_tensor * 
KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + // projection (no bias) + cur = ggml_mul_mat(ctx0, + model.layers[il].attn_out_proj_w, + cur); + } + + + // residual + struct ggml_tensor * resSA = ggml_add(ctx0, cur, inpSA); + // feed-forward network + { + cur = resSA; + // norm2 + cur = ggml_norm(ctx0, cur); + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].norm_2_w, cur), + cur); + // ffn + cur = ggml_mul_mat(ctx0, + model.layers[il].ffn_up_proj_w, + cur); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, + model.layers[il].ffn_down_proj_w, + cur); + + } + + // self-attention + FF + inpL = ggml_add(ctx0, cur, resSA); + } + + struct ggml_tensor * out = inpL; + // -> logits + { + out = ggml_norm(ctx0, out); + out = ggml_mul(ctx0, + ggml_repeat(ctx0, model.norm_f_w, out), + out); + out = ggml_mul_mat(ctx0, model.wte, out); + } + + + // run the computation + ggml_build_forward_expand(&gf, out); + ggml_graph_compute (ctx0, &gf); + + + // return result for just the last token + embd_w.resize(n_vocab); + memcpy(embd_w.data(), (float *) ggml_get_data(out) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + + if (mem_per_token == 0) { + mem_per_token = ggml_used_mem(ctx0)/N; + } + //printf("used_mem = %zu\n", ggml_used_mem(ctx0)); + + ggml_free(ctx0); + + return true; +} + +std::vector mpt_tokenize_inner(const mpt_vocab & vocab, const std::string & text) { + // taken from stablelm example in ggml + // they both use the gpt-neox tokenizer + // not sure if this entirely right? 
+ std::vector words; + + + // first split the text into words + { + std::string str = text; + std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; + std::regex re(pat); + std::smatch m; + + while (std::regex_search(str, m, re)) { + for (auto x : m) { + words.push_back(x); + } + str = m.suffix(); + } + } + + // find the longest tokens that form the words: + std::vector tokens; + for (const auto & word : words) { + if (word.size() == 0) continue; + + int i = 0; + int n = word.size(); + while (i < n) { + int j = n; + while (j > i) { + auto it = vocab.token_to_id.find(word.substr(i, j-i)); + if (it != vocab.token_to_id.end()) { + tokens.push_back(it->second); + i = j; + break; + } + --j; + } + if (i == n) { + break; + } + if (j == i) { + auto sub = word.substr(i, 1); + if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) { + tokens.push_back(vocab.token_to_id.at(sub)); + } else { + fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data()); + } + ++i; + } + } + } + + return tokens; +} + +std::vector mpt_tokenize(const mpt_vocab & vocab, const std::string & text) { + // Generate the subpattern from the special_tokens vector if it's not empty + if (!vocab.special_tokens.empty()) { + std::vector out; + std::vector chunks; + std::string str = text; + std::string special_tokens_subpattern; + for (const auto &token : vocab.special_tokens) { + if (!special_tokens_subpattern.empty()) { + special_tokens_subpattern += "|"; + } + special_tokens_subpattern += regex_escape(token); + } + std::regex re(special_tokens_subpattern); + std::smatch m; + while (std::regex_search(str, m, re)) { + auto tok = vocab.token_to_id.find(m.str()); + if (tok != vocab.token_to_id.end()) { + auto tokid = tok->second; + auto pfxtoks = mpt_tokenize_inner(vocab, m.prefix()); + out.insert(out.end(), pfxtoks.begin(), pfxtoks.end()); + out.push_back(tokid); + str = m.suffix(); + } + } + if (!str.empty()) { + auto tokrest = 
mpt_tokenize_inner(vocab, str); + out.insert(out.end(), tokrest.begin(), tokrest.end()); + } + return out; + } else { + return mpt_tokenize_inner(vocab, text); + } +} + +#define MPT_MAX_RNG_STATE 64*1024 + +size_t mpt_get_state_size(const mpt_model &model) +{ + // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. + // for reference, std::mt19937(1337) serializes to 6701 bytes. + const size_t s_rng_size = sizeof(size_t); + const size_t s_rng = MPT_MAX_RNG_STATE; + const size_t s_kv_size = sizeof(size_t); + const size_t s_kv_ntok = sizeof(int); + const size_t s_kv = model.kv_self.buf.size; + const size_t s_total = ( + + s_rng_size + + s_rng + + s_kv_size + + s_kv_ntok + + s_kv + ); + fflush(stdout); + return s_total; +} + +size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint8_t *dest) +{ + uint8_t * out = dest; + fflush(stdout); + // copy rng + { + std::stringstream rng_ss; + rng_ss << rng; + + const size_t rng_size = rng_ss.str().size(); + char rng_buf[MPT_MAX_RNG_STATE]; + + memset(&rng_buf[0], 0, MPT_MAX_RNG_STATE); + memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); + + memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size); + memcpy(out, &rng_buf[0], MPT_MAX_RNG_STATE); out += MPT_MAX_RNG_STATE; + } + + // copy kv cache + { + const size_t kv_size = model.kv_self.buf.size; + const int kv_ntok = model.kv_self.n; + + memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size); + memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok); + + if (kv_size) { + memcpy(out, model.kv_self.buf.addr, kv_size); out += kv_size; + } + } + + const size_t written = out - dest; + const size_t expected = mpt_get_state_size(model); + assert(written == expected); + fflush(stdout); + return written; +} + +mpt_vocab::id mpt_sample_top_k_top_p( + const mpt_vocab & vocab, + const size_t actualVocabSize, + const int32_t * last_n_tokens_data, + int 
last_n_tokens_size, + const std::vector logits, + int top_k, + double top_p, + double temp, + float repeat_penalty, + std::mt19937 & rng) { + int n_logits = actualVocabSize; + + const auto last_n_tokens = std::vector(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size); + const auto * plogits = logits.data() + logits.size() - n_logits; + + std::vector> logits_id; + logits_id.reserve(n_logits); + + { + const float scale = 1.0f/temp; + for (int i = 0; i < n_logits; ++i) { + // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858) + // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main + if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) { + // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if (plogits[i] < 0.0f) { + logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i)); + } else { + logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i)); + } + } else { + logits_id.push_back(std::make_pair(plogits[i]*scale, i)); + } + } + } + + // find the top K tokens + std::partial_sort( + logits_id.begin(), + logits_id.begin() + top_k, logits_id.end(), + [](const std::pair & a, const std::pair & b) { + return a.first > b.first; + }); + + logits_id.resize(top_k); + + double maxl = -INFINITY; + for (const auto & kv : logits_id) { + maxl = std::max(maxl, kv.first); + } + + // compute probs for the top K tokens + std::vector probs; + probs.reserve(logits_id.size()); + + double sum = 0.0; + for (const auto & kv : logits_id) { + double p = exp(kv.first - maxl); + probs.push_back(p); + sum += p; + } + + // normalize the probs + for (auto & p : probs) { + p /= sum; + } + + if (top_p < 1.0f) { + double cumsum = 0.0f; + for (int i = 0; i < top_k; i++) { + cumsum += probs[i]; + if (cumsum >= top_p) { + top_k = i + 1; + probs.resize(top_k); + logits_id.resize(top_k); + break; + } + } + + cumsum = 1.0/cumsum; + for (int i 
= 0; i < (int) probs.size(); i++) { + probs[i] *= cumsum; + } + } + + //printf("\n"); + //for (int i = 0; i < (int) probs.size(); i++) { + // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); + //} + //exit(0); + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + int idx = dist(rng); + + return logits_id[idx].second; +} + +size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src) +{ + const uint8_t * in = src; + + // set rng + { + size_t rng_size; + char rng_buf[MPT_MAX_RNG_STATE]; + + memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size); + memcpy(&rng_buf[0], in, MPT_MAX_RNG_STATE); in += MPT_MAX_RNG_STATE; + + std::stringstream rng_ss; + rng_ss.str(std::string(&rng_buf[0], rng_size)); + rng_ss >> *rng; + + assert(rng_ss.fail() == false); + } + + // set kv cache + { + size_t kv_size; + int kv_ntok; + + memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size); + memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok); + + if (kv_size) { + assert(model->kv_self.buf.size == kv_size); + + void * k_data = model->kv_self.k->data; // remember data pointers + void * v_data = model->kv_self.v->data; // because their value is stored in buf and overwritten by memcpy + + memcpy(model->kv_self.buf.addr, in, kv_size); in += kv_size; + + model->kv_self.k->data = k_data; // restore correct data pointers + model->kv_self.v->data = v_data; + + } + + model->kv_self.n = kv_ntok; + } + + const size_t nread = in - src; + const size_t expected = mpt_get_state_size(*model); + assert(nread == expected); + fflush(stdout); + return nread; +} + +struct MPTPrivate { + const std::string modelPath; + bool modelLoaded; + mpt_vocab vocab; + mpt_model *model = nullptr; + int64_t n_threads = 0; + size_t mem_per_token = 0; + std::mt19937 rng; + bool has_im_end = false; +}; + +MPT::MPT() + : d_ptr(new MPTPrivate) { + + d_ptr->model = new mpt_model; + d_ptr->modelLoaded = false; +} + +bool 
MPT::loadModel(const std::string &modelPath) { + std::mt19937 rng(time(NULL)); + d_ptr->rng = rng; + + auto fin = std::ifstream(modelPath, std::ios::binary); + + // load the model + if (!mpt_model_load(modelPath, fin, *d_ptr->model, d_ptr->vocab)) { + std::cerr << "GPT-J ERROR: failed to load model from " << modelPath; + return false; + } + + d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + d_ptr->modelLoaded = true; + d_ptr->has_im_end = d_ptr->vocab.token_to_id.find("<|im_end|>") != d_ptr->vocab.token_to_id.end(); + fflush(stdout); + return true; +} + +void MPT::setThreadCount(int32_t n_threads) { + d_ptr->n_threads = n_threads; +} + +int32_t MPT::threadCount() { + return d_ptr->n_threads; +} + +MPT::~MPT() +{ + delete d_ptr->model; +} + +bool MPT::isModelLoaded() const +{ + return d_ptr->modelLoaded; +} + +size_t MPT::stateSize() const +{ + return mpt_get_state_size(*d_ptr->model); +} + +size_t MPT::saveState(uint8_t *dest) const +{ + return mpt_copy_state_data(*d_ptr->model, d_ptr->rng, dest); +} + +size_t MPT::restoreState(const uint8_t *src) +{ + return mpt_set_state_data(d_ptr->model, &d_ptr->rng, src); +} + +void MPT::prompt(const std::string &prompt, + std::function promptCallback, + std::function responseCallback, + std::function recalculateCallback, + PromptContext &promptCtx) { + + if (!isModelLoaded()) { + std::cerr << "GPT-J ERROR: prompt won't work with an unloaded model!\n"; + return; + } + + const int64_t t_main_start_us = ggml_time_us(); + + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + int64_t t_prompt_us = 0; + + // tokenize the prompt + std::vector embd_inp = mpt_tokenize(d_ptr->vocab, prompt); + + // save the context size + promptCtx.n_ctx = d_ptr->model->hparams.n_ctx; + + if ((int) embd_inp.size() > promptCtx.n_ctx - 4) { + responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed."); + std::cerr << "GPT-J ERROR: The prompt is" << embd_inp.size() << + "tokens 
and the context window is" << promptCtx.n_ctx << "!\n"; + return; + } + + promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size()); + promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx); + + // determine the required inference memory per token: + static bool initialized = false; + static std::vector p_instruct; + static std::vector r_instruct; + if (!initialized) { + mpt_eval(*d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits, + d_ptr->mem_per_token); + initialized = true; + } + + // process the prompt in batches + size_t i = 0; + const int64_t t_start_prompt_us = ggml_time_us(); + while (i < embd_inp.size()) { + size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); + std::vector batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); + + // Check if the context has run out... + if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) { + const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; + // Erase the first percentage of context from the tokens... 
+ std::cerr << "MPT: reached the end of the context window so resizing\n"; + promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); + promptCtx.n_past = promptCtx.tokens.size(); + recalculateContext(promptCtx, recalculateCallback); + assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + } + + if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, + d_ptr->mem_per_token)) { + std::cerr << "GPT-J ERROR: Failed to process prompt\n"; + return; + } + + size_t tokens = batch_end - i; + for (size_t t = 0; t < tokens; ++t) { + if (promptCtx.tokens.size() == promptCtx.n_ctx) + promptCtx.tokens.erase(promptCtx.tokens.begin()); + promptCtx.tokens.push_back(batch.at(t)); + if (!promptCallback(batch.at(t))) + return; + } + promptCtx.n_past += batch.size(); + i = batch_end; + } + t_prompt_us += ggml_time_us() - t_start_prompt_us; + + int p_instructFound = 0; + int r_instructFound = 0; + + std::string cachedResponse; + std::vector cachedTokens; + std::unordered_set reversePrompts + = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" }; + + // predict next tokens + int32_t totalPredictions = 0; + for (int i = 0; i < promptCtx.n_predict; i++) { + + // sample next token + const int n_vocab = d_ptr->model->hparams.n_vocab; + int id = 0; + { + const int64_t t_start_sample_us = ggml_time_us(); + id = mpt_sample_top_k_top_p(d_ptr->vocab, n_vocab, + promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx, + promptCtx.n_ctx, + promptCtx.logits, + promptCtx.top_k, promptCtx.top_p, promptCtx.temp, + promptCtx.repeat_penalty, + d_ptr->rng); + + t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // Check if the context has run out... + if (promptCtx.n_past + 1 > promptCtx.n_ctx) { + const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase; + // Erase the first percentage of context from the tokens... 
+ std::cerr << "MPT: reached the end of the context window so resizing\n"; + promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); + promptCtx.n_past = promptCtx.tokens.size(); + recalculateContext(promptCtx, recalculateCallback); + assert(promptCtx.n_past + 1 <= promptCtx.n_ctx); + } + + const int64_t t_start_predict_us = ggml_time_us(); + if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, + d_ptr->mem_per_token)) { + std::cerr << "GPT-J ERROR: Failed to predict next token\n"; + return; + } + t_predict_us += ggml_time_us() - t_start_predict_us; + + promptCtx.n_past += 1; + // display text + ++totalPredictions; + + // mpt-7b-chat has special token for end + if (d_ptr->has_im_end && id == d_ptr->vocab.token_to_id["<|im_end|>"]) + goto stop_generating; + + if (id == 0 /*end of text*/) + goto stop_generating; + + const std::string str = d_ptr->vocab.id_to_token[id]; + + // Check if the provided str is part of our reverse prompts + bool foundPartialReversePrompt = false; + const std::string completed = cachedResponse + str; + if (reversePrompts.find(completed) != reversePrompts.end()) { + goto stop_generating; + } + + // Check if it partially matches our reverse prompts and if so, cache + for (auto s : reversePrompts) { + if (s.compare(0, completed.size(), completed) == 0) { + foundPartialReversePrompt = true; + cachedResponse = completed; + break; + } + } + + // Regardless the token gets added to our cache + cachedTokens.push_back(id); + + // Continue if we have found a partial match + if (foundPartialReversePrompt) + continue; + + // Empty the cache + for (auto t : cachedTokens) { + if (promptCtx.tokens.size() == promptCtx.n_ctx) + promptCtx.tokens.erase(promptCtx.tokens.begin()); + promptCtx.tokens.push_back(t); + if (!responseCallback(t, d_ptr->vocab.id_to_token[t])) + goto stop_generating; + } + cachedTokens.clear(); + } + +stop_generating: + +#if 0 + // report timing + { + const int64_t 
t_main_end_us = ggml_time_us(); + + std::cout << "GPT-J INFO: mem per token = " << mem_per_token << " bytes\n"; + std::cout << "GPT-J INFO: sample time = " << t_sample_us/1000.0f << " ms\n"; + std::cout << "GPT-J INFO: prompt time = " << t_prompt_us/1000.0f << " ms\n"; + std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/totalPredictions << " ms per token\n"; + std::cout << "GPT-J INFO: total time = " << (t_main_end_us - t_main_start_us)/1000.0f << " ms\n"; + fflush(stdout); + } +#endif + + return; +} + +void MPT::recalculateContext(PromptContext &promptCtx, std::function recalculate) +{ + size_t i = 0; + promptCtx.n_past = 0; + while (i < promptCtx.tokens.size()) { + size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size()); + std::vector batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end); + + assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); + + if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, + d_ptr->mem_per_token)) { + std::cerr << "MPT ERROR: Failed to process prompt\n"; + goto stop_generating; + } + promptCtx.n_past += batch.size(); + if (!recalculate(true)) + goto stop_generating; + i = batch_end; + } + assert(promptCtx.n_past == promptCtx.tokens.size()); + +stop_generating: + recalculate(false); +} diff --git a/gpt4all-chat/llmodel/mpt.h b/gpt4all-chat/llmodel/mpt.h new file mode 100644 index 00000000..9e693f6a --- /dev/null +++ b/gpt4all-chat/llmodel/mpt.h @@ -0,0 +1,36 @@ +#ifndef MPT_H +#define MPT_H + +#include +#include +#include +#include "llmodel.h" + +class MPTPrivate; +class MPT : public LLModel { +public: + MPT(); + ~MPT(); + + bool loadModel(const std::string &modelPath) override; + bool isModelLoaded() const override; + size_t stateSize() const override; + size_t saveState(uint8_t *dest) const override; + size_t restoreState(const uint8_t *src) override; + void prompt(const std::string &prompt, + 
std::function promptCallback, + std::function responseCallback, + std::function recalculateCallback, + PromptContext &ctx) override; + void setThreadCount(int32_t n_threads) override; + int32_t threadCount() override; + +protected: + void recalculateContext(PromptContext &promptCtx, + std::function recalculate) override; + +private: + MPTPrivate *d_ptr; +}; + +#endif // MPT_H diff --git a/gpt4all-chat/llmodel/scripts/convert_mpt_hf_to_ggml.py b/gpt4all-chat/llmodel/scripts/convert_mpt_hf_to_ggml.py new file mode 100644 index 00000000..981432fc --- /dev/null +++ b/gpt4all-chat/llmodel/scripts/convert_mpt_hf_to_ggml.py @@ -0,0 +1,175 @@ +# Convert Hugging Face fine-tuned MPT models to ggml format +# +# Usage: +# +# python3 scripts/convert_mpt_hf_to_ggml.py model_name dir-output [use-f32] +# +# This script is similar to "convert-pt-to-ggml.py" +# + +import io +import os +import sys +import struct +import json +import code +import torch +import numpy as np + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BloomForCausalLM + +# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
+ """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + +if len(sys.argv) < 3: + print("Usage: python convert-hf-to-ggml.py model_name dir-output [use-f32]") + print(" model_name: name of the model to convert. Example: 'bigscience/bloomz-560m'") + print(" dir-output: directory where the output file will be written") + print(" use-f32: if present, use float32 instead of float16") + sys.exit(1) + +model_name = sys.argv[1] +dir_out = sys.argv[2] + +# make sure the output directory exists +os.makedirs(dir_out, exist_ok=True) + +# possible data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 +# +# map from ftype to string +ftype_str = ["f32", "f16"] +ftype = 1 +if len(sys.argv) > 3: + ftype = 0 + +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) +config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) +hparams = config.to_dict() +print("Loading model: ", model_name) +model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, config=config, torch_dtype=torch.float16 if ftype == 1 else torch.float32, low_cpu_mem_usage=True) +print("Model loaded: ", model_name) + + +fname_out = dir_out + f"/ggml-model-{model_name.split('/')[-1]}-{ftype_str[ftype]}.bin" +fout = open(fname_out, "wb") +vocab = tokenizer.vocab + +hparams["multiple_of"] = 1 +fout.write(struct.pack("i", 0x67676d6d)) # magic: ggml in hex +fout.write(struct.pack("i", hparams["vocab_size"])) +fout.write(struct.pack("i", hparams["max_seq_len"])) +fout.write(struct.pack("i", hparams["d_model"])) +fout.write(struct.pack("i", hparams["n_heads"])) +fout.write(struct.pack("i", hparams["n_layers"])) +# n_rot (unused) +fout.write(struct.pack("i", 0)) +fout.write(struct.pack("i", ftype)) + +# # Is this correct?? 
+# dot_token = tokenizer.encode(".")[0] +# write tokens to ggml file +fout.write(struct.pack("i", hparams["vocab_size"])) + +for i in range(hparams["vocab_size"]): + text = tokenizer.decode([i]).encode('utf-8') + fout.write(struct.pack("i", len(text))) + fout.write(text) + +list_vars = model.state_dict() +for name in list_vars.keys(): + data = list_vars[name].squeeze().numpy() + print("Processing variable: " + name + " with shape: ", data.shape) + + # we don't need these + if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"): + print(" Skipping variable: " + name) + continue + + if "Wqkv.weight" in name: + # chunk qkv + query, key, value = np.split(data, 3, axis=0) + + new_name = name.split("Wqkv.weight")[0] + + for (data, name) in [(query, new_name + "q_proj.weight"), (key, new_name + "k_proj.weight"), (value, new_name + "v_proj.weight")]: + print(f"Processing variable: {name} with shape: {data.shape}") + n_dims = len(data.shape); + + # ftype == 0 -> float32, ftype == 1 -> float16 + ftype_cur = 0; + if ftype != 0: + print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + + # header + str = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(str); + + # data + data.tofile(fout) + + else: + + n_dims = len(data.shape); + + # ftype == 0 -> float32, ftype == 1 -> float16 + ftype_cur = 0; + if ftype != 0: + if name[-7:] == ".weight" and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + + # header + str = name.encode('utf-8') + 
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(str); + + # data + data.tofile(fout) + +fout.close() + +print("Done. Output file: " + fname_out) +print("") \ No newline at end of file diff --git a/gpt4all-chat/llmodel/utils.cpp b/gpt4all-chat/llmodel/utils.cpp new file mode 100644 index 00000000..b9b653f5 --- /dev/null +++ b/gpt4all-chat/llmodel/utils.cpp @@ -0,0 +1,274 @@ +#include "utils.h" + +#include +#include + +void replace(std::string & str, const std::string & needle, const std::string & replacement) { + size_t pos = 0; + while ((pos = str.find(needle, pos)) != std::string::npos) { + str.replace(pos, needle.length(), replacement); + pos += replacement.length(); + } +} + +std::map json_parse(const std::string & fname) { + std::map result; + + // read file into string + std::string json; + { + std::ifstream ifs(fname); + if (!ifs) { + fprintf(stderr, "Failed to open %s\n", fname.c_str()); + exit(1); + } + + json = std::string((std::istreambuf_iterator(ifs)), + (std::istreambuf_iterator())); + } + + if (json[0] != '{') { + return result; + } + + // parse json + { + bool has_key = false; + bool in_token = false; + + std::string str_key = ""; + std::string str_val = ""; + + int n = json.size(); + for (int i = 1; i < n; ++i) { + if (!in_token) { + if (json[i] == ' ') continue; + if (json[i] == '"') { + in_token = true; + continue; + } + } else { + if (json[i] == '\\' && i+1 < n) { + if (has_key == false) { + str_key += json[i]; + } else { + str_val += json[i]; + } + ++i; + } else if (json[i] == '"') { + if (has_key == false) { + has_key = true; + ++i; + while (json[i] == ' ') ++i; + ++i; // : + while (json[i] == ' ') ++i; + if (json[i] != '\"') { + while (json[i] != ',' && json[i] != '}') { + str_val += json[i++]; + } + has_key = false; + } else { + in_token = true; + continue; + } + } else { + has_key = false; + } + + ::replace(str_key, "\\u0120", " " ); 
// \u0120 -> space + ::replace(str_key, "\\u010a", "\n"); // \u010a -> new line + ::replace(str_key, "\\\"", "\""); // \\\" -> " + + try { + result[str_key] = std::stoi(str_val); + } catch (...) { + //fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str()); + + } + str_key = ""; + str_val = ""; + in_token = false; + continue; + } + if (has_key == false) { + str_key += json[i]; + } else { + str_val += json[i]; + } + } + } + } + + return result; +} + +std::vector gpt_tokenize(const gpt_vocab & vocab, const std::string & text) { + std::vector words; + + // first split the text into words + { + std::string str = text; + std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; + + std::regex re(pat); + std::smatch m; + + while (std::regex_search(str, m, re)) { + for (auto x : m) { + words.push_back(x); + } + str = m.suffix(); + } + } + + // find the longest tokens that form the words: + std::vector tokens; + for (const auto & word : words) { + if (word.size() == 0) continue; + + int i = 0; + int n = word.size(); + while (i < n) { + int j = n; + while (j > i) { + auto it = vocab.token_to_id.find(word.substr(i, j-i)); + if (it != vocab.token_to_id.end()) { + tokens.push_back(it->second); + i = j; + break; + } + --j; + } + if (i == n) { + break; + } + if (j == i) { + auto sub = word.substr(i, 1); + if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) { + tokens.push_back(vocab.token_to_id.at(sub)); + } else { + fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data()); + } + ++i; + } + } + } + + return tokens; +} + +bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) { + printf("%s: loading vocab from '%s'\n", __func__, fname.c_str()); + + vocab.token_to_id = ::json_parse(fname); + + for (const auto & kv : vocab.token_to_id) { + vocab.id_to_token[kv.second] = kv.first; + } + + printf("%s: vocab size = %d\n", __func__, (int) 
vocab.token_to_id.size()); + + // print the vocabulary + //for (auto kv : vocab.token_to_id) { + // printf("'%s' -> %d\n", kv.first.data(), kv.second); + //} + + return true; +} + +gpt_vocab::id gpt_sample_top_k_top_p( + const gpt_vocab & vocab, + const int32_t * last_n_tokens_data, + int last_n_tokens_size, + const std::vector logits, + int top_k, + double top_p, + double temp, + float repeat_penalty, + std::mt19937 & rng) { + int n_logits = vocab.id_to_token.size(); + + const auto last_n_tokens = std::vector(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size); + const auto * plogits = logits.data() + logits.size() - n_logits; + + std::vector> logits_id; + logits_id.reserve(n_logits); + + { + const float scale = 1.0f/temp; + for (int i = 0; i < n_logits; ++i) { + // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858) + // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main + if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) { + // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if (plogits[i] < 0.0f) { + logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i)); + } else { + logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i)); + } + } else { + logits_id.push_back(std::make_pair(plogits[i]*scale, i)); + } + } + } + + // find the top K tokens + std::partial_sort( + logits_id.begin(), + logits_id.begin() + top_k, logits_id.end(), + [](const std::pair & a, const std::pair & b) { + return a.first > b.first; + }); + + logits_id.resize(top_k); + + double maxl = -INFINITY; + for (const auto & kv : logits_id) { + maxl = std::max(maxl, kv.first); + } + + // compute probs for the top K tokens + std::vector probs; + probs.reserve(logits_id.size()); + + double sum = 0.0; + for (const auto & kv : logits_id) { + double p = exp(kv.first - maxl); + probs.push_back(p); + sum += p; + } + + // normalize the 
probs + for (auto & p : probs) { + p /= sum; + } + + if (top_p < 1.0f) { + double cumsum = 0.0f; + for (int i = 0; i < top_k; i++) { + cumsum += probs[i]; + if (cumsum >= top_p) { + top_k = i + 1; + probs.resize(top_k); + logits_id.resize(top_k); + break; + } + } + + cumsum = 1.0/cumsum; + for (int i = 0; i < (int) probs.size(); i++) { + probs[i] *= cumsum; + } + } + + //printf("\n"); + //for (int i = 0; i < (int) probs.size(); i++) { + // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); + //} + //exit(0); + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + int idx = dist(rng); + + return logits_id[idx].second; +} diff --git a/gpt4all-chat/llmodel/utils.h b/gpt4all-chat/llmodel/utils.h new file mode 100644 index 00000000..90cfdd97 --- /dev/null +++ b/gpt4all-chat/llmodel/utils.h @@ -0,0 +1,85 @@ +// Various helper functions and utilities + +#pragma once + +#include +#include +#include +#include +#include + +// +// CLI argument parsing +// + +struct gpt_params { + int32_t seed = -1; // RNG seed + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + int32_t n_predict = 200; // new tokens to predict + + // sampling parameters + int32_t top_k = 40; + float top_p = 0.9f; + float temp = 0.9f; + + int32_t n_batch = 8; // batch size for prompt processing + + std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path + std::string prompt; +}; + +bool gpt_params_parse(int argc, char ** argv, gpt_params & params); + +void gpt_print_usage(int argc, char ** argv, const gpt_params & params); + +std::string gpt_random_prompt(std::mt19937 & rng); + +// +// Vocab utils +// + +struct gpt_vocab { + using id = int32_t; + using token = std::string; + + std::map token_to_id; + std::map id_to_token; +}; + +void replace(std::string & str, const std::string & needle, const std::string & replacement); + +// poor-man's JSON parsing +std::map json_parse(const std::string & fname); + +// split text 
into tokens +// +// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53 +// +// Regex (Python): +// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" +// +// Regex (C++): +// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)" +// +std::vector gpt_tokenize(const gpt_vocab & vocab, const std::string & text); + +// load the tokens from encoder.json +bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab); + +// sample next token given probabilities for each embedding +// +// - consider only the top K tokens +// - from them, consider only the top tokens with cumulative probability > P +// +// TODO: not sure if this implementation is correct +// +gpt_vocab::id gpt_sample_top_k_top_p( + const gpt_vocab & vocab, + const int32_t * last_n_tokens_data, + int last_n_tokens_size, + const std::vector logits, + int top_k, + double top_p, + double temp, + float repeat_penalty, + std::mt19937 & rng); diff --git a/gpt4all-chat/main.cpp b/gpt4all-chat/main.cpp new file mode 100644 index 00000000..aece4026 --- /dev/null +++ b/gpt4all-chat/main.cpp @@ -0,0 +1,43 @@ +#include +#include +#include + +#include +#include + +#include "llm.h" +#include "download.h" +#include "network.h" +#include "config.h" + +int main(int argc, char *argv[]) +{ + QCoreApplication::setOrganizationName("nomic.ai"); + QCoreApplication::setOrganizationDomain("gpt4all.io"); + QCoreApplication::setApplicationName("GPT4All"); + QCoreApplication::setApplicationVersion(APP_VERSION); + QSettings::setDefaultFormat(QSettings::IniFormat); + + QGuiApplication app(argc, argv); + QQmlApplicationEngine engine; + qmlRegisterSingletonInstance("llm", 1, 0, "LLM", LLM::globalInstance()); + qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance()); + qmlRegisterSingletonInstance("network", 1, 0, "Network", Network::globalInstance()); + const QUrl 
url(u"qrc:/gpt4all/main.qml"_qs); + + QObject::connect(&engine, &QQmlApplicationEngine::objectCreated, + &app, [url](QObject *obj, const QUrl &objUrl) { + if (!obj && url == objUrl) + QCoreApplication::exit(-1); + }, Qt::QueuedConnection); + engine.load(url); + +#if 0 + QDirIterator it("qrc:", QDirIterator::Subdirectories); + while (it.hasNext()) { + qDebug() << it.next(); + } +#endif + + return app.exec(); +} diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml new file mode 100644 index 00000000..6ab92df0 --- /dev/null +++ b/gpt4all-chat/main.qml @@ -0,0 +1,894 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts +import llm +import download +import network + +Window { + id: window + width: 1280 + height: 720 + visible: true + title: qsTr("GPT4All v") + Qt.application.version + + Theme { + id: theme + } + + property var currentChat: LLM.chatListModel.currentChat + property var chatModel: currentChat.chatModel + + color: theme.textColor + + // Startup code + Component.onCompleted: { + if (!LLM.compatHardware) { + Network.sendNonCompatHardware(); + errorCompatHardware.open(); + } else + startupDialogs(); + } + + Connections { + target: firstStartDialog + function onClosed() { + startupDialogs(); + } + } + + Connections { + target: downloadNewModels + function onClosed() { + startupDialogs(); + } + } + + Connections { + target: Download + function onHasNewerReleaseChanged() { + startupDialogs(); + } + } + + Connections { + target: currentChat + function onResponseInProgressChanged() { + if (Network.isActive && !currentChat.responseInProgress) + Network.sendConversation(currentChat.id, getConversationJson()); + } + } + + function startupDialogs() { + // check for first time start of this version + if (Download.isFirstStart()) { + firstStartDialog.open(); + return; + } + + // check for any current models and if not, open download dialog + if (currentChat.modelList.length === 0 && 
!firstStartDialog.opened) { + downloadNewModels.open(); + return; + } + + // check for new version + if (Download.hasNewerRelease && !firstStartDialog.opened && !downloadNewModels.opened) { + newVersionDialog.open(); + return; + } + } + + PopupDialog { + id: errorCompatHardware + anchors.centerIn: parent + shouldTimeOut: false + shouldShowBusy: false + closePolicy: Popup.NoAutoClose + modal: true + text: qsTr("Incompatible hardware detected. Please try the avx-only installer on https://gpt4all.io") + } + + StartupDialog { + id: firstStartDialog + anchors.centerIn: parent + } + + NewVersionDialog { + id: newVersionDialog + anchors.centerIn: parent + } + + AboutDialog { + id: aboutDialog + anchors.centerIn: parent + } + + Item { + Accessible.role: Accessible.Window + Accessible.name: title + } + + Rectangle { + id: header + anchors.left: parent.left + anchors.right: parent.right + anchors.top: parent.top + height: 100 + color: theme.backgroundDarkest + + Item { + anchors.centerIn: parent + height: childrenRect.height + visible: currentChat.isModelLoaded + + Label { + id: modelLabel + color: theme.textColor + padding: 20 + font.pixelSize: theme.fontSizeLarger + text: "" + background: Rectangle { + color: theme.backgroundDarkest + } + horizontalAlignment: TextInput.AlignRight + } + + ComboBox { + id: comboBox + width: 350 + anchors.top: modelLabel.top + anchors.bottom: modelLabel.bottom + anchors.horizontalCenter: parent.horizontalCenter + font.pixelSize: theme.fontSizeLarge + spacing: 0 + model: currentChat.modelList + Accessible.role: Accessible.ComboBox + Accessible.name: qsTr("ComboBox for displaying/picking the current model") + Accessible.description: qsTr("Use this for picking the current model to use; the first item is the current model") + contentItem: Text { + anchors.horizontalCenter: parent.horizontalCenter + leftPadding: 10 + rightPadding: 10 + text: comboBox.displayText + font: comboBox.font + color: theme.textColor + verticalAlignment: Text.AlignVCenter 
+ horizontalAlignment: Text.AlignHCenter + elide: Text.ElideRight + } + delegate: ItemDelegate { + width: comboBox.width + contentItem: Text { + text: modelData + color: theme.textColor + font: comboBox.font + elide: Text.ElideRight + verticalAlignment: Text.AlignVCenter + } + background: Rectangle { + color: highlighted ? theme.backgroundLight : theme.backgroundDark + } + highlighted: comboBox.highlightedIndex === index + } + popup: Popup { + y: comboBox.height - 1 + width: comboBox.width + implicitHeight: contentItem.implicitHeight + padding: 0 + + contentItem: ListView { + clip: true + implicitHeight: contentHeight + model: comboBox.popup.visible ? comboBox.delegateModel : null + currentIndex: comboBox.highlightedIndex + ScrollIndicator.vertical: ScrollIndicator { } + } + + background: Rectangle { + color: theme.backgroundDark + } + } + + background: Rectangle { + color: theme.backgroundDark + } + + onActivated: { + currentChat.stopGenerating() + currentChat.reset(); + currentChat.modelName = comboBox.currentText + } + } + } + + BusyIndicator { + anchors.centerIn: parent + visible: !currentChat.isModelLoaded + running: !currentChat.isModelLoaded + Accessible.role: Accessible.Animation + Accessible.name: qsTr("Busy indicator") + Accessible.description: qsTr("Displayed when the model is loading") + } + } + + SettingsDialog { + id: settingsDialog + anchors.centerIn: parent + width: Math.min(1024, window.width - (window.width * .2)) + height: Math.min(600, window.height - (window.height * .2)) + } + + Button { + id: drawerButton + anchors.left: parent.left + anchors.top: parent.top + anchors.topMargin: 30 + anchors.leftMargin: 30 + width: 40 + height: 40 + z: 200 + padding: 15 + + Accessible.role: Accessible.ButtonMenu + Accessible.name: qsTr("Hamburger button") + Accessible.description: qsTr("Hamburger button that reveals a drawer on the left of the application") + + background: Item { + anchors.centerIn: parent + width: 30 + height: 30 + + Rectangle { + id: bar1 + 
color: theme.backgroundLightest + width: parent.width + height: 6 + radius: 2 + antialiasing: true + } + + Rectangle { + id: bar2 + anchors.centerIn: parent + color: theme.backgroundLightest + width: parent.width + height: 6 + radius: 2 + antialiasing: true + } + + Rectangle { + id: bar3 + anchors.bottom: parent.bottom + color: theme.backgroundLightest + width: parent.width + height: 6 + radius: 2 + antialiasing: true + } + } + onClicked: { + drawer.visible = !drawer.visible + } + } + + NetworkDialog { + id: networkDialog + anchors.centerIn: parent + width: Math.min(1024, window.width - (window.width * .2)) + height: Math.min(600, window.height - (window.height * .2)) + Item { + Accessible.role: Accessible.Dialog + Accessible.name: qsTr("Network dialog") + Accessible.description: qsTr("Dialog for opt-in to sharing feedback/conversations") + } + } + + Button { + id: networkButton + anchors.right: parent.right + anchors.top: parent.top + anchors.topMargin: 30 + anchors.rightMargin: 30 + width: 40 + height: 40 + z: 200 + padding: 15 + + Accessible.role: Accessible.Button + Accessible.name: qsTr("Network button") + Accessible.description: qsTr("Reveals a dialogue where you can opt-in for sharing data over network") + + background: Item { + anchors.fill: parent + Rectangle { + anchors.fill: parent + color: "transparent" + visible: Network.isActive + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + } + Image { + anchors.centerIn: parent + width: 30 + height: 30 + source: "qrc:/gpt4all/icons/network.svg" + } + } + + onClicked: { + if (Network.isActive) { + Network.isActive = false + Network.sendNetworkToggled(false); + } else + networkDialog.open() + } + } + + Connections { + target: Network + function onHealthCheckFailed(code) { + healthCheckFailed.open(); + } + } + + Button { + id: settingsButton + anchors.right: networkButton.left + anchors.top: parent.top + anchors.topMargin: 30 + anchors.rightMargin: 30 + width: 40 + height: 40 + z: 200 + 
padding: 15 + + background: Item { + anchors.fill: parent + Image { + anchors.centerIn: parent + width: 30 + height: 30 + source: "qrc:/gpt4all/icons/settings.svg" + } + } + + Accessible.role: Accessible.Button + Accessible.name: qsTr("Settings button") + Accessible.description: qsTr("Reveals a dialogue where you can change various settings") + + onClicked: { + settingsDialog.open() + } + } + + PopupDialog { + id: copyMessage + anchors.centerIn: parent + text: qsTr("Conversation copied to clipboard.") + } + + PopupDialog { + id: healthCheckFailed + anchors.centerIn: parent + text: qsTr("Connection to datalake failed.") + } + + PopupDialog { + id: recalcPopup + anchors.centerIn: parent + shouldTimeOut: false + shouldShowBusy: true + text: qsTr("Recalculating context.") + + Connections { + target: currentChat + function onRecalcChanged() { + if (currentChat.isRecalc) + recalcPopup.open() + else + recalcPopup.close() + } + } + } + + Button { + id: copyButton + anchors.right: settingsButton.left + anchors.top: parent.top + anchors.topMargin: 30 + anchors.rightMargin: 30 + width: 40 + height: 40 + z: 200 + padding: 15 + + Accessible.role: Accessible.Button + Accessible.name: qsTr("Copy button") + Accessible.description: qsTr("Copy the conversation to the clipboard") + + background: Item { + anchors.fill: parent + Image { + anchors.centerIn: parent + width: 30 + height: 30 + source: "qrc:/gpt4all/icons/copy.svg" + } + } + + TextEdit{ + id: copyEdit + visible: false + } + + onClicked: { + var conversation = getConversation() + copyEdit.text = conversation + copyEdit.selectAll() + copyEdit.copy() + copyMessage.open() + } + } + + function getConversation() { + var conversation = ""; + for (var i = 0; i < chatModel.count; i++) { + var item = chatModel.get(i) + var string = item.name; + var isResponse = item.name === qsTr("Response: ") + string += chatModel.get(i).value + if (isResponse && item.stopped) + string += " " + string += "\n" + conversation += string + } + return 
conversation + } + + function getConversationJson() { + var str = "{\"conversation\": ["; + for (var i = 0; i < chatModel.count; i++) { + var item = chatModel.get(i) + var isResponse = item.name === qsTr("Response: ") + str += "{\"content\": "; + str += JSON.stringify(item.value) + str += ", \"role\": \"" + (isResponse ? "assistant" : "user") + "\""; + if (isResponse && item.thumbsUpState !== item.thumbsDownState) + str += ", \"rating\": \"" + (item.thumbsUpState ? "positive" : "negative") + "\""; + if (isResponse && item.newResponse !== "") + str += ", \"edited_content\": " + JSON.stringify(item.newResponse); + if (isResponse && item.stopped) + str += ", \"stopped\": \"true\"" + if (!isResponse) + str += "}," + else + str += ((i < chatModel.count - 1) ? "}," : "}") + } + return str + "]}" + } + + Button { + id: resetContextButton + anchors.right: copyButton.left + anchors.top: parent.top + anchors.topMargin: 30 + anchors.rightMargin: 30 + width: 40 + height: 40 + z: 200 + padding: 15 + + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Reset the context which erases current conversation") + + background: Item { + anchors.fill: parent + Image { + anchors.centerIn: parent + width: 30 + height: 30 + source: "qrc:/gpt4all/icons/regenerate.svg" + } + } + + onClicked: { + Network.sendResetContext(chatModel.count) + currentChat.reset(); + } + } + + Dialog { + id: checkForUpdatesError + anchors.centerIn: parent + modal: false + opacity: 0.9 + padding: 20 + Text { + horizontalAlignment: Text.AlignJustify + text: qsTr("ERROR: Update system could not find the MaintenanceTool used
+ to check for updates!

+ Did you install this application using the online installer? If so,
+ the MaintenanceTool executable should be located one directory
+ above where this application resides on your filesystem.

+ If you can't start it manually, then I'm afraid you'll have to
+ reinstall.") + color: theme.textColor + Accessible.role: Accessible.Dialog + Accessible.name: text + Accessible.description: qsTr("Dialog indicating an error") + } + background: Rectangle { + anchors.fill: parent + color: theme.backgroundDarkest + border.width: 1 + border.color: theme.dialogBorder + radius: 10 + } + } + + ModelDownloaderDialog { + id: downloadNewModels + anchors.centerIn: parent + width: Math.min(1024, window.width - (window.width * .2)) + height: Math.min(600, window.height - (window.height * .2)) + Item { + Accessible.role: Accessible.Dialog + Accessible.name: qsTr("Download new models dialog") + Accessible.description: qsTr("Dialog for downloading new models") + } + } + + ChatDrawer { + id: drawer + y: header.height + width: 0.3 * window.width + height: window.height - y + onDownloadClicked: { + downloadNewModels.open() + } + onAboutClicked: { + aboutDialog.open() + } + } + + Rectangle { + id: conversation + color: theme.backgroundLight + anchors.left: parent.left + anchors.right: parent.right + anchors.bottom: parent.bottom + anchors.top: header.bottom + + ScrollView { + id: scrollView + anchors.left: parent.left + anchors.right: parent.right + anchors.top: parent.top + anchors.bottom: textInputView.top + anchors.bottomMargin: 30 + ScrollBar.vertical.policy: ScrollBar.AlwaysOn + + Rectangle { + anchors.fill: parent + color: theme.backgroundLighter + + ListView { + id: listView + anchors.fill: parent + model: chatModel + + Accessible.role: Accessible.List + Accessible.name: qsTr("List of prompt/response pairs") + Accessible.description: qsTr("This is the list of prompt/response pairs comprising the actual conversation with the model") + + delegate: TextArea { + text: value + width: listView.width + color: theme.textColor + wrapMode: Text.WordWrap + focus: false + readOnly: true + font.pixelSize: theme.fontSizeLarge + cursorVisible: currentResponse ? 
currentChat.responseInProgress : false + cursorPosition: text.length + background: Rectangle { + color: name === qsTr("Response: ") ? theme.backgroundLighter : theme.backgroundLight + } + + Accessible.role: Accessible.Paragraph + Accessible.name: name + Accessible.description: name === qsTr("Response: ") ? "The response by the model" : "The prompt by the user" + + topPadding: 20 + bottomPadding: 20 + leftPadding: 100 + rightPadding: 100 + + BusyIndicator { + anchors.left: parent.left + anchors.leftMargin: 90 + anchors.top: parent.top + anchors.topMargin: 5 + visible: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress + running: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress + + Accessible.role: Accessible.Animation + Accessible.name: qsTr("Busy indicator") + Accessible.description: qsTr("Displayed when the model is thinking") + } + + Rectangle { + anchors.left: parent.left + anchors.top: parent.top + anchors.leftMargin: 20 + anchors.topMargin: 20 + width: 30 + height: 30 + radius: 5 + color: name === qsTr("Response: ") ? theme.assistantColor : theme.userColor + + Text { + anchors.centerIn: parent + text: name === qsTr("Response: ") ? "R" : "P" + color: "white" + } + } + + ThumbsDownDialog { + id: thumbsDownDialog + property point globalPoint: mapFromItem(window, + window.width / 2 - width / 2, + window.height / 2 - height / 2) + x: globalPoint.x + y: globalPoint.y + property string text: value + response: newResponse === undefined || newResponse === "" ? 
text : newResponse + onAccepted: { + var responseHasChanged = response !== text && response !== newResponse + if (thumbsDownState && !thumbsUpState && !responseHasChanged) + return + + chatModel.updateNewResponse(index, response) + chatModel.updateThumbsUpState(index, false) + chatModel.updateThumbsDownState(index, true) + Network.sendConversation(currentChat.id, getConversationJson()); + } + } + + Column { + visible: name === qsTr("Response: ") && + (!currentResponse || !currentChat.responseInProgress) && Network.isActive + anchors.right: parent.right + anchors.rightMargin: 20 + anchors.top: parent.top + anchors.topMargin: 20 + spacing: 10 + + Item { + width: childrenRect.width + height: childrenRect.height + Button { + id: thumbsUp + width: 30 + height: 30 + opacity: thumbsUpState || thumbsUpState == thumbsDownState ? 1.0 : 0.2 + background: Image { + anchors.fill: parent + source: "qrc:/gpt4all/icons/thumbs_up.svg" + } + onClicked: { + if (thumbsUpState && !thumbsDownState) + return + + chatModel.updateNewResponse(index, "") + chatModel.updateThumbsUpState(index, true) + chatModel.updateThumbsDownState(index, false) + Network.sendConversation(currentChat.id, getConversationJson()); + } + } + + Button { + id: thumbsDown + anchors.top: thumbsUp.top + anchors.topMargin: 10 + anchors.left: thumbsUp.right + anchors.leftMargin: 2 + width: 30 + height: 30 + checked: thumbsDownState + opacity: thumbsDownState || thumbsUpState == thumbsDownState ? 
1.0 : 0.2 + transform: [ + Matrix4x4 { + matrix: Qt.matrix4x4(-1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1) + }, + Translate { + x: thumbsDown.width + } + ] + background: Image { + anchors.fill: parent + source: "qrc:/gpt4all/icons/thumbs_down.svg" + } + onClicked: { + thumbsDownDialog.open() + } + } + } + } + } + + property bool shouldAutoScroll: true + property bool isAutoScrolling: false + + Connections { + target: currentChat + function onResponseChanged() { + if (listView.shouldAutoScroll) { + listView.isAutoScrolling = true + listView.positionViewAtEnd() + listView.isAutoScrolling = false + } + } + } + + onContentYChanged: { + if (!isAutoScrolling) + shouldAutoScroll = atYEnd + } + + Component.onCompleted: { + shouldAutoScroll = true + positionViewAtEnd() + } + + footer: Item { + id: bottomPadding + width: parent.width + height: 60 + } + } + } + } + + Button { + visible: chatModel.count + Image { + anchors.verticalCenter: parent.verticalCenter + anchors.left: parent.left + anchors.leftMargin: 15 + source: currentChat.responseInProgress ? 
"qrc:/gpt4all/icons/stop_generating.svg" : "qrc:/gpt4all/icons/regenerate.svg" + } + leftPadding: 50 + onClicked: { + var index = Math.max(0, chatModel.count - 1); + var listElement = chatModel.get(index); + + if (currentChat.responseInProgress) { + listElement.stopped = true + currentChat.stopGenerating() + } else { + currentChat.regenerateResponse() + if (chatModel.count) { + if (listElement.name === qsTr("Response: ")) { + chatModel.updateCurrentResponse(index, true); + chatModel.updateStopped(index, false); + chatModel.updateThumbsUpState(index, false); + chatModel.updateThumbsDownState(index, false); + chatModel.updateNewResponse(index, ""); + currentChat.prompt(listElement.prompt, settingsDialog.promptTemplate, + settingsDialog.maxLength, + settingsDialog.topK, settingsDialog.topP, + settingsDialog.temperature, + settingsDialog.promptBatchSize, + settingsDialog.repeatPenalty, + settingsDialog.repeatPenaltyTokens) + } + } + } + } + anchors.bottom: textInputView.top + anchors.horizontalCenter: textInputView.horizontalCenter + anchors.bottomMargin: 40 + padding: 15 + contentItem: Text { + text: currentChat.responseInProgress ? 
qsTr("Stop generating") : qsTr("Regenerate response") + color: theme.textColor + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Controls generation of the response") + } + background: Rectangle { + opacity: .5 + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + } + + ScrollView { + id: textInputView + anchors.left: parent.left + anchors.right: parent.right + anchors.bottom: parent.bottom + anchors.margins: 30 + height: Math.min(contentHeight, 200) + + TextArea { + id: textInput + color: theme.textColor + padding: 20 + rightPadding: 40 + enabled: currentChat.isModelLoaded + wrapMode: Text.WordWrap + font.pixelSize: theme.fontSizeLarge + placeholderText: qsTr("Send a message...") + placeholderTextColor: theme.backgroundLightest + background: Rectangle { + color: theme.backgroundLighter + radius: 10 + } + Accessible.role: Accessible.EditableText + Accessible.name: placeholderText + Accessible.description: qsTr("Textfield for sending messages/prompts to the model") + Keys.onReturnPressed: (event)=> { + if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier) + event.accepted = false; + else { + editingFinished(); + sendMessage() + } + } + function sendMessage() { + if (textInput.text === "") + return + + currentChat.stopGenerating() + + if (chatModel.count) { + var index = Math.max(0, chatModel.count - 1); + var listElement = chatModel.get(index); + chatModel.updateCurrentResponse(index, false); + } + currentChat.newPromptResponsePair(textInput.text); + currentChat.prompt(textInput.text, settingsDialog.promptTemplate, + settingsDialog.maxLength, + settingsDialog.topK, + settingsDialog.topP, + settingsDialog.temperature, + settingsDialog.promptBatchSize, + settingsDialog.repeatPenalty, + settingsDialog.repeatPenaltyTokens) + textInput.text = "" + } + } + } + + Button { + anchors.right: textInputView.right + anchors.verticalCenter: 
textInputView.verticalCenter + anchors.rightMargin: 15 + width: 30 + height: 30 + + background: Image { + anchors.centerIn: parent + source: "qrc:/gpt4all/icons/send_message.svg" + } + + Accessible.role: Accessible.Button + Accessible.name: qsTr("Send the message button") + Accessible.description: qsTr("Sends the message/prompt contained in textfield to the model") + + onClicked: { + textInput.sendMessage() + } + } + } +} diff --git a/gpt4all-chat/network.cpp b/gpt4all-chat/network.cpp new file mode 100644 index 00000000..d70bf7df --- /dev/null +++ b/gpt4all-chat/network.cpp @@ -0,0 +1,531 @@ +#include "network.h" +#include "llm.h" +#include "sysinfo.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//#define DEBUG + +#if defined(Q_OS_MAC) +#include +std::string getCPUModel() { + char buffer[256]; + size_t bufferlen = sizeof(buffer); + sysctlbyname("machdep.cpu.brand_string", &buffer, &bufferlen, NULL, 0); + return std::string(buffer); +} +#endif + +class MyNetwork: public Network { }; +Q_GLOBAL_STATIC(MyNetwork, networkInstance) +Network *Network::globalInstance() +{ + return networkInstance(); +} + +Network::Network() + : QObject{nullptr} + , m_isActive(false) + , m_usageStatsActive(false) + , m_shouldSendStartup(false) +{ + QSettings settings; + settings.sync(); + m_uniqueId = settings.value("uniqueId", generateUniqueId()).toString(); + settings.setValue("uniqueId", m_uniqueId); + settings.sync(); + m_isActive = settings.value("network/isActive", false).toBool(); + if (m_isActive) + sendHealth(); + m_usageStatsActive = settings.value("network/usageStatsActive", false).toBool(); + if (m_usageStatsActive) + sendIpify(); + connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, + &Network::handleSslErrors); +} + +void Network::setActive(bool b) +{ + QSettings settings; + settings.setValue("network/isActive", b); + settings.sync(); + m_isActive = b; + emit activeChanged(); + if (m_isActive) + sendHealth(); 
+} + +void Network::setUsageStatsActive(bool b) +{ + QSettings settings; + settings.setValue("network/usageStatsActive", b); + settings.sync(); + m_usageStatsActive = b; + emit usageStatsActiveChanged(); + if (!m_usageStatsActive) + sendOptOut(); + else { + // model might be loaded already when user opt-in for first time + sendStartup(); + sendIpify(); + } +} + +QString Network::generateUniqueId() const +{ + return QUuid::createUuid().toString(QUuid::WithoutBraces); +} + +bool Network::packageAndSendJson(const QString &ingestId, const QString &json) +{ + if (!m_isActive) + return false; + + QJsonParseError err; + QJsonDocument doc = QJsonDocument::fromJson(json.toUtf8(), &err); + if (err.error != QJsonParseError::NoError) { + qDebug() << "Couldn't parse: " << json << err.errorString(); + return false; + } + + Q_ASSERT(doc.isObject()); + Q_ASSERT(LLM::globalInstance()->chatListModel()->currentChat()); + QJsonObject object = doc.object(); + object.insert("source", "gpt4all-chat"); + object.insert("agent_id", LLM::globalInstance()->chatListModel()->currentChat()->modelName()); + object.insert("submitter_id", m_uniqueId); + object.insert("ingest_id", ingestId); + + QSettings settings; + settings.sync(); + QString attribution = settings.value("network/attribution", QString()).toString(); + if (!attribution.isEmpty()) + object.insert("network/attribution", attribution); + + QJsonDocument newDoc; + newDoc.setObject(object); + +#if defined(DEBUG) + printf("%s\n", qPrintable(newDoc.toJson(QJsonDocument::Indented))); + fflush(stdout); +#endif + + QUrl jsonUrl("https://api.gpt4all.io/v1/ingest/chat"); + QNetworkRequest request(jsonUrl); + QSslConfiguration conf = request.sslConfiguration(); + conf.setPeerVerifyMode(QSslSocket::VerifyNone); + request.setSslConfiguration(conf); + QByteArray body(newDoc.toJson()); + request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); + QNetworkReply *jsonReply = m_networkManager.post(request, body); + connect(jsonReply, 
&QNetworkReply::finished, this, &Network::handleJsonUploadFinished); + m_activeUploads.append(jsonReply); + return true; +} + +void Network::handleJsonUploadFinished() +{ + QNetworkReply *jsonReply = qobject_cast(sender()); + if (!jsonReply) + return; + + m_activeUploads.removeAll(jsonReply); + + QVariant response = jsonReply->attribute(QNetworkRequest::HttpStatusCodeAttribute); + Q_ASSERT(response.isValid()); + bool ok; + int code = response.toInt(&ok); + if (!ok) + qWarning() << "ERROR: ingest invalid response."; + if (code != 200) { + qWarning() << "ERROR: ingest response != 200 code:" << code; + sendHealth(); + } + + QByteArray jsonData = jsonReply->readAll(); + QJsonParseError err; + QJsonDocument document = QJsonDocument::fromJson(jsonData, &err); + if (err.error != QJsonParseError::NoError) { + qDebug() << "ERROR: Couldn't parse: " << jsonData << err.errorString(); + return; + } + +#if defined(DEBUG) + printf("%s\n", qPrintable(document.toJson(QJsonDocument::Indented))); + fflush(stdout); +#endif + + jsonReply->deleteLater(); +} + +void Network::handleSslErrors(QNetworkReply *reply, const QList &errors) +{ + QUrl url = reply->request().url(); + for (auto e : errors) + qWarning() << "ERROR: Received ssl error:" << e.errorString() << "for" << url; +} + +void Network::sendOptOut() +{ + QJsonObject properties; + properties.insert("token", "ce362e568ddaee16ed243eaffb5860a2"); + properties.insert("time", QDateTime::currentSecsSinceEpoch()); + properties.insert("distinct_id", m_uniqueId); + properties.insert("$insert_id", generateUniqueId()); + + QJsonObject event; + event.insert("event", "opt_out"); + event.insert("properties", properties); + + QJsonArray array; + array.append(event); + + QJsonDocument doc; + doc.setArray(array); + sendMixpanel(doc.toJson(), true /*isOptOut*/); + +#if defined(DEBUG) + printf("%s %s\n", qPrintable("opt_out"), qPrintable(doc.toJson(QJsonDocument::Indented))); + fflush(stdout); +#endif +} + +void Network::sendModelLoaded() +{ + if 
(!m_usageStatsActive) + return; + sendMixpanelEvent("model_load"); +} + +void Network::sendResetContext(int conversationLength) +{ + if (!m_usageStatsActive) + return; + + KeyValue kv; + kv.key = QString("length"); + kv.value = QJsonValue(conversationLength); + sendMixpanelEvent("reset_context", QVector{kv}); +} + +void Network::sendStartup() +{ + if (!m_usageStatsActive) + return; + m_shouldSendStartup = true; + if (m_ipify.isEmpty()) + return; // when it completes it will send + sendMixpanelEvent("startup"); +} + +void Network::sendCheckForUpdates() +{ + if (!m_usageStatsActive) + return; + sendMixpanelEvent("check_for_updates"); +} + +void Network::sendModelDownloaderDialog() +{ + if (!m_usageStatsActive) + return; + sendMixpanelEvent("download_dialog"); +} + +void Network::sendDownloadStarted(const QString &model) +{ + if (!m_usageStatsActive) + return; + KeyValue kv; + kv.key = QString("model"); + kv.value = QJsonValue(model); + sendMixpanelEvent("download_started", QVector{kv}); +} + +void Network::sendDownloadCanceled(const QString &model) +{ + if (!m_usageStatsActive) + return; + KeyValue kv; + kv.key = QString("model"); + kv.value = QJsonValue(model); + sendMixpanelEvent("download_canceled", QVector{kv}); +} + +void Network::sendDownloadError(const QString &model, int code, const QString &errorString) +{ + if (!m_usageStatsActive) + return; + KeyValue kv; + kv.key = QString("model"); + kv.value = QJsonValue(model); + KeyValue kvCode; + kvCode.key = QString("code"); + kvCode.value = QJsonValue(code); + KeyValue kvError; + kvError.key = QString("error"); + kvError.value = QJsonValue(errorString); + sendMixpanelEvent("download_error", QVector{kv, kvCode, kvError}); +} + +void Network::sendDownloadFinished(const QString &model, bool success) +{ + if (!m_usageStatsActive) + return; + KeyValue kv; + kv.key = QString("model"); + kv.value = QJsonValue(model); + KeyValue kvSuccess; + kvSuccess.key = QString("success"); + kvSuccess.value = QJsonValue(success); + 
sendMixpanelEvent("download_finished", QVector{kv, kvSuccess}); +} + +void Network::sendSettingsDialog() +{ + if (!m_usageStatsActive) + return; + sendMixpanelEvent("settings_dialog"); +} + +void Network::sendNetworkToggled(bool isActive) +{ + if (!m_usageStatsActive) + return; + KeyValue kv; + kv.key = QString("isActive"); + kv.value = QJsonValue(isActive); + sendMixpanelEvent("network_toggled", QVector{kv}); +} + +void Network::sendSaveChatsToggled(bool isActive) +{ + if (!m_usageStatsActive) + return; + KeyValue kv; + kv.key = QString("isActive"); + kv.value = QJsonValue(isActive); + sendMixpanelEvent("savechats_toggled", QVector{kv}); +} + +void Network::sendNewChat(int count) +{ + if (!m_usageStatsActive) + return; + KeyValue kv; + kv.key = QString("number_of_chats"); + kv.value = QJsonValue(count); + sendMixpanelEvent("new_chat", QVector{kv}); +} + +void Network::sendRemoveChat() +{ + if (!m_usageStatsActive) + return; + sendMixpanelEvent("remove_chat"); +} + +void Network::sendRenameChat() +{ + if (!m_usageStatsActive) + return; + sendMixpanelEvent("rename_chat"); +} + +void Network::sendChatStarted() +{ + if (!m_usageStatsActive) + return; + sendMixpanelEvent("chat_started"); +} + +void Network::sendRecalculatingContext(int conversationLength) +{ + if (!m_usageStatsActive) + return; + + KeyValue kv; + kv.key = QString("length"); + kv.value = QJsonValue(conversationLength); + sendMixpanelEvent("recalc_context", QVector{kv}); +} + +void Network::sendNonCompatHardware() +{ + if (!m_usageStatsActive) + return; + sendMixpanelEvent("noncompat_hardware"); +} + +void Network::sendMixpanelEvent(const QString &ev, const QVector &values) +{ + if (!m_usageStatsActive) + return; + + Q_ASSERT(LLM::globalInstance()->chatListModel()->currentChat()); + QJsonObject properties; + properties.insert("token", "ce362e568ddaee16ed243eaffb5860a2"); + properties.insert("time", QDateTime::currentSecsSinceEpoch()); + properties.insert("distinct_id", m_uniqueId); + 
properties.insert("$insert_id", generateUniqueId()); + properties.insert("$os", QSysInfo::prettyProductName()); + if (!m_ipify.isEmpty()) + properties.insert("ip", m_ipify); + properties.insert("name", QCoreApplication::applicationName() + " v" + + QCoreApplication::applicationVersion()); + properties.insert("model", LLM::globalInstance()->chatListModel()->currentChat()->modelName()); + + // Some additional startup information + if (ev == "startup") { + const QSize display = QGuiApplication::primaryScreen()->size(); + properties.insert("display", QString("%1x%2").arg(display.width()).arg(display.height())); + properties.insert("ram", getSystemTotalRAM()); +#if defined(__x86_64__) || defined(__i386__) + properties.insert("avx", bool(__builtin_cpu_supports("avx"))); + properties.insert("avx2", bool(__builtin_cpu_supports("avx2"))); + properties.insert("fma", bool(__builtin_cpu_supports("fma"))); +#endif +#if defined(Q_OS_MAC) + properties.insert("cpu", QString::fromStdString(getCPUModel())); +#endif + } + + for (auto p : values) + properties.insert(p.key, p.value); + + QJsonObject event; + event.insert("event", ev); + event.insert("properties", properties); + + QJsonArray array; + array.append(event); + + QJsonDocument doc; + doc.setArray(array); + sendMixpanel(doc.toJson()); + +#if defined(DEBUG) + printf("%s %s\n", qPrintable(ev), qPrintable(doc.toJson(QJsonDocument::Indented))); + fflush(stdout); +#endif +} + +void Network::sendIpify() +{ + if (!m_usageStatsActive || !m_ipify.isEmpty()) + return; + + QUrl ipifyUrl("https://api.ipify.org"); + QNetworkRequest request(ipifyUrl); + QSslConfiguration conf = request.sslConfiguration(); + conf.setPeerVerifyMode(QSslSocket::VerifyNone); + request.setSslConfiguration(conf); + QNetworkReply *reply = m_networkManager.get(request); + connect(reply, &QNetworkReply::finished, this, &Network::handleIpifyFinished); +} + +void Network::sendMixpanel(const QByteArray &json, bool isOptOut) +{ + if (!m_usageStatsActive && !isOptOut) + 
return; + + QUrl trackUrl("https://api.mixpanel.com/track"); + QNetworkRequest request(trackUrl); + QSslConfiguration conf = request.sslConfiguration(); + conf.setPeerVerifyMode(QSslSocket::VerifyNone); + request.setSslConfiguration(conf); + request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); + QNetworkReply *trackReply = m_networkManager.post(request, json); + connect(trackReply, &QNetworkReply::finished, this, &Network::handleMixpanelFinished); +} + +void Network::handleIpifyFinished() +{ + Q_ASSERT(m_usageStatsActive); + QNetworkReply *reply = qobject_cast(sender()); + if (!reply) + return; + + QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute); + Q_ASSERT(response.isValid()); + bool ok; + int code = response.toInt(&ok); + if (!ok) + qWarning() << "ERROR: ipify invalid response."; + if (code != 200) + qWarning() << "ERROR: ipify response != 200 code:" << code; + m_ipify = qPrintable(reply->readAll()); +#if defined(DEBUG) + printf("ipify finished %s\n", m_ipify.toLatin1().constData()); + fflush(stdout); +#endif + reply->deleteLater(); + + if (m_shouldSendStartup) + sendStartup(); +} + +void Network::handleMixpanelFinished() +{ + QNetworkReply *reply = qobject_cast(sender()); + if (!reply) + return; + + QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute); + Q_ASSERT(response.isValid()); + bool ok; + int code = response.toInt(&ok); + if (!ok) + qWarning() << "ERROR: track invalid response."; + if (code != 200) + qWarning() << "ERROR: track response != 200 code:" << code; +#if defined(DEBUG) + printf("mixpanel finished %s\n", qPrintable(reply->readAll())); + fflush(stdout); +#endif + reply->deleteLater(); +} + +bool Network::sendConversation(const QString &ingestId, const QString &conversation) +{ + return packageAndSendJson(ingestId, conversation); +} + +void Network::sendHealth() +{ + QUrl healthUrl("https://api.gpt4all.io/v1/health"); + QNetworkRequest request(healthUrl); + 
QSslConfiguration conf = request.sslConfiguration(); + conf.setPeerVerifyMode(QSslSocket::VerifyNone); + request.setSslConfiguration(conf); + QNetworkReply *healthReply = m_networkManager.get(request); + connect(healthReply, &QNetworkReply::finished, this, &Network::handleHealthFinished); +} + +void Network::handleHealthFinished() +{ + QNetworkReply *healthReply = qobject_cast(sender()); + if (!healthReply) + return; + + QVariant response = healthReply->attribute(QNetworkRequest::HttpStatusCodeAttribute); + Q_ASSERT(response.isValid()); + bool ok; + int code = response.toInt(&ok); + if (!ok) + qWarning() << "ERROR: health invalid response."; + if (code != 200) { + qWarning() << "ERROR: health response != 200 code:" << code; + emit healthCheckFailed(code); + setActive(false); + } + healthReply->deleteLater(); +} diff --git a/gpt4all-chat/network.h b/gpt4all-chat/network.h new file mode 100644 index 00000000..1c9de2df --- /dev/null +++ b/gpt4all-chat/network.h @@ -0,0 +1,87 @@ +#ifndef NETWORK_H +#define NETWORK_H + +#include +#include +#include +#include + +struct KeyValue { + QString key; + QJsonValue value; +}; + +class Network : public QObject +{ + Q_OBJECT + Q_PROPERTY(bool isActive READ isActive WRITE setActive NOTIFY activeChanged) + Q_PROPERTY(bool usageStatsActive READ usageStatsActive WRITE setUsageStatsActive NOTIFY usageStatsActiveChanged) + +public: + static Network *globalInstance(); + + bool isActive() const { return m_isActive; } + void setActive(bool b); + + bool usageStatsActive() const { return m_usageStatsActive; } + void setUsageStatsActive(bool b); + + Q_INVOKABLE QString generateUniqueId() const; + Q_INVOKABLE bool sendConversation(const QString &ingestId, const QString &conversation); + +Q_SIGNALS: + void activeChanged(); + void usageStatsActiveChanged(); + void healthCheckFailed(int code); + +public Q_SLOTS: + void sendOptOut(); + void sendModelLoaded(); + void sendStartup(); + void sendCheckForUpdates(); + Q_INVOKABLE void 
sendModelDownloaderDialog(); + Q_INVOKABLE void sendResetContext(int conversationLength); + void sendDownloadStarted(const QString &model); + void sendDownloadCanceled(const QString &model); + void sendDownloadError(const QString &model, int code, const QString &errorString); + void sendDownloadFinished(const QString &model, bool success); + Q_INVOKABLE void sendSettingsDialog(); + Q_INVOKABLE void sendNetworkToggled(bool active); + Q_INVOKABLE void sendSaveChatsToggled(bool active); + Q_INVOKABLE void sendNewChat(int count); + Q_INVOKABLE void sendRemoveChat(); + Q_INVOKABLE void sendRenameChat(); + Q_INVOKABLE void sendNonCompatHardware(); + void sendChatStarted(); + void sendRecalculatingContext(int conversationLength); + +private Q_SLOTS: + void handleIpifyFinished(); + void handleHealthFinished(); + void handleJsonUploadFinished(); + void handleSslErrors(QNetworkReply *reply, const QList &errors); + void handleMixpanelFinished(); + +private: + void sendHealth(); + void sendIpify(); + void sendMixpanelEvent(const QString &event, const QVector &values = QVector()); + void sendMixpanel(const QByteArray &json, bool isOptOut = false); + bool packageAndSendJson(const QString &ingestId, const QString &json); + +private: + bool m_shouldSendStartup; + bool m_isActive; + bool m_usageStatsActive; + QString m_ipify; + QString m_uniqueId; + QNetworkAccessManager m_networkManager; + QVector m_activeUploads; + +private: + explicit Network(); + ~Network() {} + friend class MyNetwork; +}; + +#endif // LLM_H diff --git a/gpt4all-chat/qml/AboutDialog.qml b/gpt4all-chat/qml/AboutDialog.qml new file mode 100644 index 00000000..35d1ee43 --- /dev/null +++ b/gpt4all-chat/qml/AboutDialog.qml @@ -0,0 +1,113 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts +import download +import network +import llm + +Dialog { + id: abpoutDialog + anchors.centerIn: parent + modal: false + opacity: 0.9 + padding: 20 + width: 1024 + 
height: column.height + 40 + + Theme { + id: theme + } + + Column { + id: column + spacing: 20 + Item { + width: childrenRect.width + height: childrenRect.height + Image { + id: img + anchors.top: parent.top + anchors.left: parent.left + width: 60 + height: 60 + source: "qrc:/gpt4all/icons/logo.svg" + } + Text { + anchors.left: img.right + anchors.leftMargin: 30 + anchors.verticalCenter: img.verticalCenter + text: qsTr("About GPT4All") + color: theme.textColor + } + } + + ScrollView { + clip: true + height: 200 + width: 1024 - 40 + ScrollBar.vertical.policy: ScrollBar.AlwaysOn + ScrollBar.horizontal.policy: ScrollBar.AlwaysOff + + TextArea { + id: welcome + wrapMode: Text.Wrap + width: 1024 - 40 + padding: 20 + textFormat: TextEdit.MarkdownText + text: qsTr("### Release notes\n") + + Download.releaseInfo.notes + + qsTr("### Contributors\n") + + Download.releaseInfo.contributors + color: theme.textColor + focus: false + readOnly: true + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Release notes") + Accessible.description: qsTr("Release notes for this version") + background: Rectangle { + color: theme.backgroundLight + radius: 10 + } + } + } + + Label { + id: discordLink + width: parent.width + textFormat: Text.RichText + wrapMode: Text.WordWrap + text: qsTr("Check out our discord channel https://discord.gg/4M2QFmTt2k") + onLinkActivated: { Qt.openUrlExternally("https://discord.gg/4M2QFmTt2k") } + color: theme.textColor + linkColor: theme.linkColor + + Accessible.role: Accessible.Link + Accessible.name: qsTr("Discord link") + } + + Label { + id: nomicProps + width: parent.width + textFormat: Text.RichText + wrapMode: Text.WordWrap + text: qsTr("Thank you to Nomic AI and the community for contributing so much great data, code, ideas, and energy to the growing open source AI ecosystem!") + onLinkActivated: { Qt.openUrlExternally("https://home.nomic.ai") } + color: theme.textColor + linkColor: theme.linkColor + + Accessible.role: Accessible.Paragraph + 
Accessible.name: qsTr("Thank you blurb") + Accessible.description: qsTr("Contains embedded link to https://home.nomic.ai") + } + } + + background: Rectangle { + anchors.fill: parent + color: theme.backgroundDarkest + border.width: 1 + border.color: theme.dialogBorder + radius: 10 + } +} diff --git a/gpt4all-chat/qml/ChatDrawer.qml b/gpt4all-chat/qml/ChatDrawer.qml new file mode 100644 index 00000000..d3298f1a --- /dev/null +++ b/gpt4all-chat/qml/ChatDrawer.qml @@ -0,0 +1,353 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts +import llm +import download +import network + +Drawer { + id: chatDrawer + modal: false + opacity: 0.9 + + Theme { + id: theme + } + + signal downloadClicked + signal aboutClicked + + background: Rectangle { + height: parent.height + color: theme.backgroundDarkest + } + + Item { + anchors.fill: parent + anchors.margins: 10 + + Accessible.role: Accessible.Pane + Accessible.name: qsTr("Drawer on the left of the application") + Accessible.description: qsTr("Drawer that is revealed by pressing the hamburger button") + + Button { + id: newChat + anchors.left: parent.left + anchors.right: parent.right + padding: 15 + font.pixelSize: theme.fontSizeLarger + background: Rectangle { + color: theme.backgroundDarkest + opacity: .5 + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + } + contentItem: Text { + text: qsTr("New chat") + horizontalAlignment: Text.AlignHCenter + color: theme.textColor + + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Use this to launch an external application that will check for updates to the installer") + } + onClicked: { + LLM.chatListModel.addChat(); + Network.sendNewChat(LLM.chatListModel.count) + } + } + + ScrollView { + anchors.left: parent.left + anchors.right: parent.right + anchors.rightMargin: -10 + anchors.topMargin: 10 + anchors.top: newChat.bottom + anchors.bottom: 
checkForUpdatesButton.top + anchors.bottomMargin: 10 + ScrollBar.vertical.policy: ScrollBar.AlwaysOn + + ListView { + id: conversationList + anchors.fill: parent + anchors.rightMargin: 10 + + model: LLM.chatListModel + + delegate: Rectangle { + id: chatRectangle + width: conversationList.width + height: chatName.height + opacity: 0.9 + property bool isCurrent: LLM.chatListModel.currentChat === LLM.chatListModel.get(index) + property bool trashQuestionDisplayed: false + z: isCurrent ? 199 : 1 + color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter + border.width: isCurrent + border.color: chatName.readOnly ? theme.assistantColor : theme.userColor + TextField { + id: chatName + anchors.left: parent.left + anchors.right: buttons.left + color: theme.textColor + padding: 15 + focus: false + readOnly: true + wrapMode: Text.NoWrap + hoverEnabled: false // Disable hover events on the TextArea + selectByMouse: false // Disable text selection in the TextArea + font.pixelSize: theme.fontSizeLarger + text: readOnly ? metrics.elidedText : name + horizontalAlignment: TextInput.AlignLeft + opacity: trashQuestionDisplayed ? 
0.5 : 1.0 + TextMetrics { + id: metrics + font: chatName.font + text: name + elide: Text.ElideRight + elideWidth: chatName.width - 25 + } + background: Rectangle { + color: "transparent" + } + onEditingFinished: { + // Work around a bug in qml where we're losing focus when the whole window + // goes out of focus even though this textfield should be marked as not + // having focus + if (chatName.readOnly) + return; + changeName(); + Network.sendRenameChat() + } + function changeName() { + LLM.chatListModel.get(index).name = chatName.text + chatName.focus = false + chatName.readOnly = true + chatName.selectByMouse = false + } + TapHandler { + onTapped: { + if (isCurrent) + return; + LLM.chatListModel.currentChat = LLM.chatListModel.get(index); + } + } + Accessible.role: Accessible.Button + Accessible.name: qsTr("Select the current chat") + Accessible.description: qsTr("Provides a button to select the current chat or edit the chat when in edit mode") + } + Row { + id: buttons + anchors.verticalCenter: chatName.verticalCenter + anchors.right: chatRectangle.right + anchors.rightMargin: 10 + spacing: 10 + Button { + id: editButton + width: 30 + height: 30 + visible: isCurrent + opacity: trashQuestionDisplayed ? 
0.5 : 1.0 + background: Image { + width: 30 + height: 30 + source: "qrc:/gpt4all/icons/edit.svg" + } + onClicked: { + chatName.focus = true + chatName.readOnly = false + chatName.selectByMouse = true + } + Accessible.role: Accessible.Button + Accessible.name: qsTr("Edit the chat name") + Accessible.description: qsTr("Provides a button to edit the chat name") + } + Button { + id: c + width: 30 + height: 30 + visible: isCurrent + background: Image { + width: 30 + height: 30 + source: "qrc:/gpt4all/icons/trash.svg" + } + onClicked: { + trashQuestionDisplayed = true + timer.start() + } + Accessible.role: Accessible.Button + Accessible.name: qsTr("Delete of the chat") + Accessible.description: qsTr("Provides a button to delete the chat") + } + } + Rectangle { + id: trashSureQuestion + anchors.top: buttons.bottom + anchors.topMargin: 10 + anchors.right: buttons.right + width: childrenRect.width + height: childrenRect.height + color: chatRectangle.color + visible: isCurrent && trashQuestionDisplayed + opacity: 1.0 + radius: 10 + z: 200 + Row { + spacing: 10 + Button { + id: checkMark + width: 30 + height: 30 + contentItem: Text { + color: theme.textErrorColor + text: "\u2713" + font.pixelSize: theme.fontSizeLarger + horizontalAlignment: Text.AlignHCenter + verticalAlignment: Text.AlignVCenter + } + background: Rectangle { + width: 30 + height: 30 + color: "transparent" + } + onClicked: { + LLM.chatListModel.removeChat(LLM.chatListModel.get(index)) + Network.sendRemoveChat() + } + Accessible.role: Accessible.Button + Accessible.name: qsTr("Confirm delete of the chat") + Accessible.description: qsTr("Provides a button to confirm delete of the chat") + } + Button { + id: cancel + width: 30 + height: 30 + contentItem: Text { + color: theme.textColor + text: "\u2715" + font.pixelSize: theme.fontSizeLarger + horizontalAlignment: Text.AlignHCenter + verticalAlignment: Text.AlignVCenter + } + background: Rectangle { + width: 30 + height: 30 + color: "transparent" + } + onClicked: 
{ + trashQuestionDisplayed = false + } + Accessible.role: Accessible.Button + Accessible.name: qsTr("Cancel the delete of the chat") + Accessible.description: qsTr("Provides a button to cancel delete of the chat") + } + } + } + Timer { + id: timer + interval: 3000; running: false; repeat: false + onTriggered: trashQuestionDisplayed = false + } + } + + Accessible.role: Accessible.List + Accessible.name: qsTr("List of chats") + Accessible.description: qsTr("List of chats in the drawer dialog") + } + } + + Button { + id: checkForUpdatesButton + anchors.left: parent.left + anchors.right: parent.right + anchors.bottom: downloadButton.top + anchors.bottomMargin: 10 + padding: 15 + contentItem: Text { + text: qsTr("Updates") + horizontalAlignment: Text.AlignHCenter + color: theme.textColor + + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Use this to launch an external application that will check for updates to the installer") + } + + background: Rectangle { + opacity: .5 + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + + onClicked: { + if (!LLM.checkForUpdates()) + checkForUpdatesError.open() + } + } + + Button { + id: downloadButton + anchors.left: parent.left + anchors.right: parent.right + anchors.bottom: aboutButton.top + anchors.bottomMargin: 10 + padding: 15 + contentItem: Text { + text: qsTr("Downloads") + horizontalAlignment: Text.AlignHCenter + color: theme.textColor + + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Use this to launch a dialog to download new models") + } + + background: Rectangle { + opacity: .5 + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + + onClicked: { + downloadClicked() + } + } + + Button { + id: aboutButton + anchors.left: parent.left + anchors.right: parent.right + anchors.bottom: parent.bottom + padding: 15 + contentItem: Text 
{ + text: qsTr("About") + horizontalAlignment: Text.AlignHCenter + color: theme.textColor + + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Use this to launch a dialog to show the about page") + } + + background: Rectangle { + opacity: .5 + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + + onClicked: { + aboutClicked() + } + } + } +} \ No newline at end of file diff --git a/gpt4all-chat/qml/ModelDownloaderDialog.qml b/gpt4all-chat/qml/ModelDownloaderDialog.qml new file mode 100644 index 00000000..0c2a58a5 --- /dev/null +++ b/gpt4all-chat/qml/ModelDownloaderDialog.qml @@ -0,0 +1,383 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Dialogs +import QtQuick.Layouts +import download +import llm +import network + +Dialog { + id: modelDownloaderDialog + modal: true + opacity: 0.9 + closePolicy: LLM.chatListModel.currentChat.modelList.length === 0 ? 
Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside) + background: Rectangle { + anchors.fill: parent + anchors.margins: -20 + color: theme.backgroundDarkest + border.width: 1 + border.color: theme.dialogBorder + radius: 10 + } + + onOpened: { + Network.sendModelDownloaderDialog(); + } + + property string defaultModelPath: Download.defaultLocalModelsPath() + property alias modelPath: settings.modelPath + Settings { + id: settings + property string modelPath: modelDownloaderDialog.defaultModelPath + } + + Component.onCompleted: { + Download.downloadLocalModelsPath = settings.modelPath + } + + Component.onDestruction: { + settings.sync() + } + + ColumnLayout { + anchors.fill: parent + anchors.margins: 20 + spacing: 30 + + Label { + id: listLabel + text: "Available Models:" + Layout.alignment: Qt.AlignLeft + Layout.fillWidth: true + color: theme.textColor + } + + ScrollView { + id: scrollView + ScrollBar.vertical.policy: ScrollBar.AlwaysOn + Layout.fillWidth: true + Layout.fillHeight: true + clip: true + + ListView { + id: modelList + model: Download.modelList + boundsBehavior: Flickable.StopAtBounds + + delegate: Item { + id: delegateItem + width: modelList.width + height: modelName.height + modelName.padding + + description.height + description.padding + objectName: "delegateItem" + property bool downloading: false + Rectangle { + anchors.fill: parent + color: index % 2 === 0 ? 
theme.backgroundLight : theme.backgroundLighter + } + + Text { + id: modelName + objectName: "modelName" + property string filename: modelData.filename + text: filename.slice(5, filename.length - 4) + padding: 20 + anchors.top: parent.top + anchors.left: parent.left + font.bold: modelData.isDefault || modelData.bestGPTJ || modelData.bestLlama || modelData.bestMPT + color: theme.assistantColor + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Model file") + Accessible.description: qsTr("Model file to be downloaded") + } + + Text { + id: description + text: " - " + modelData.description + leftPadding: 20 + rightPadding: 20 + anchors.top: modelName.bottom + anchors.left: modelName.left + anchors.right: parent.right + wrapMode: Text.WordWrap + color: theme.textColor + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Description") + Accessible.description: qsTr("The description of the file") + } + + Text { + id: isDefault + text: qsTr("(default)") + visible: modelData.isDefault + anchors.top: modelName.top + anchors.left: modelName.right + padding: 20 + color: theme.textColor + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Default file") + Accessible.description: qsTr("Whether the file is the default model") + } + + Text { + text: modelData.filesize + anchors.top: modelName.top + anchors.left: isDefault.visible ? 
isDefault.right : modelName.right + padding: 20 + color: theme.textColor + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("File size") + Accessible.description: qsTr("The size of the file") + } + + Label { + id: speedLabel + anchors.top: modelName.top + anchors.right: itemProgressBar.left + padding: 20 + objectName: "speedLabel" + color: theme.textColor + text: "" + visible: downloading + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Download speed") + Accessible.description: qsTr("Download speed in bytes/kilobytes/megabytes per second") + } + + ProgressBar { + id: itemProgressBar + objectName: "itemProgressBar" + anchors.top: modelName.top + anchors.right: downloadButton.left + anchors.topMargin: 20 + anchors.rightMargin: 20 + width: 100 + visible: downloading + background: Rectangle { + implicitWidth: 200 + implicitHeight: 30 + color: theme.backgroundDarkest + radius: 3 + } + + contentItem: Item { + implicitWidth: 200 + implicitHeight: 25 + + Rectangle { + width: itemProgressBar.visualPosition * parent.width + height: parent.height + radius: 2 + color: theme.assistantColor + } + } + Accessible.role: Accessible.ProgressBar + Accessible.name: qsTr("Download progressBar") + Accessible.description: qsTr("Shows the progress made in the download") + } + + Item { + visible: modelData.calcHash + anchors.top: modelName.top + anchors.right: parent.right + + Label { + id: calcHashLabel + anchors.right: busyCalcHash.left + padding: 20 + objectName: "calcHashLabel" + color: theme.textColor + text: qsTr("Calculating MD5...") + Accessible.role: Accessible.Paragraph + Accessible.name: text + Accessible.description: qsTr("Whether the file hash is being calculated") + } + + BusyIndicator { + id: busyCalcHash + anchors.right: parent.right + padding: 20 + running: modelData.calcHash + Accessible.role: Accessible.Animation + Accessible.name: qsTr("Busy indicator") + Accessible.description: qsTr("Displayed when the file hash is being calculated") + } + 
} + + Label { + id: installedLabel + anchors.top: modelName.top + anchors.right: parent.right + padding: 20 + objectName: "installedLabel" + color: theme.textColor + text: qsTr("Already installed") + visible: modelData.installed + Accessible.role: Accessible.Paragraph + Accessible.name: text + Accessible.description: qsTr("Whether the file is already installed on your system") + } + + Button { + id: downloadButton + contentItem: Text { + color: theme.textColor + text: downloading ? "Cancel" : "Download" + } + anchors.top: modelName.top + anchors.right: parent.right + anchors.topMargin: 15 + anchors.rightMargin: 20 + visible: !modelData.installed && !modelData.calcHash + onClicked: { + if (!downloading) { + downloading = true; + Download.downloadModel(modelData.filename); + } else { + downloading = false; + Download.cancelDownload(modelData.filename); + } + } + background: Rectangle { + opacity: .5 + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Cancel/Download button to stop/start the download") + + } + } + + Component.onCompleted: { + Download.downloadProgress.connect(updateProgress); + Download.downloadFinished.connect(resetProgress); + } + + property var lastUpdate: ({}) + + function updateProgress(bytesReceived, bytesTotal, modelName) { + let currentTime = new Date().getTime(); + + for (let i = 0; i < modelList.contentItem.children.length; i++) { + let delegateItem = modelList.contentItem.children[i]; + if (delegateItem.objectName === "delegateItem") { + let modelNameText = delegateItem.children.find(child => child.objectName === "modelName").filename; + if (modelNameText === modelName) { + let progressBar = delegateItem.children.find(child => child.objectName === "itemProgressBar"); + progressBar.value = bytesReceived / bytesTotal; + + // Calculate the download speed + if (lastUpdate[modelName] && 
lastUpdate[modelName].timestamp) { + let timeDifference = currentTime - lastUpdate[modelName].timestamp; + let bytesDifference = bytesReceived - lastUpdate[modelName].bytesReceived; + let speed = (bytesDifference / timeDifference) * 1000; // bytes per second + delegateItem.downloading = true + + // Update the speed label + let speedLabel = delegateItem.children.find(child => child.objectName === "speedLabel"); + if (speed < 1024) { + speedLabel.text = speed.toFixed(2) + " B/s"; + } else if (speed < 1024 * 1024) { + speedLabel.text = (speed / 1024).toFixed(2) + " KB/s"; + } else { + speedLabel.text = (speed / (1024 * 1024)).toFixed(2) + " MB/s"; + } + } + + // Update the lastUpdate object for the current model + lastUpdate[modelName] = {"timestamp": currentTime, "bytesReceived": bytesReceived}; + break; + } + } + } + } + + function resetProgress(modelName) { + for (let i = 0; i < modelList.contentItem.children.length; i++) { + let delegateItem = modelList.contentItem.children[i]; + if (delegateItem.objectName === "delegateItem") { + let modelNameText = delegateItem.children.find(child => child.objectName === "modelName").filename; + if (modelNameText === modelName) { + let progressBar = delegateItem.children.find(child => child.objectName === "itemProgressBar"); + progressBar.value = 0; + delegateItem.downloading = false; + + // Remove speed label text + let speedLabel = delegateItem.children.find(child => child.objectName === "speedLabel"); + speedLabel.text = ""; + + // Remove the lastUpdate object for the canceled model + delete lastUpdate[modelName]; + break; + } + } + } + } + } + } + + RowLayout { + Layout.alignment: Qt.AlignCenter + Layout.fillWidth: true + spacing: 20 + FolderDialog { + id: modelPathDialog + title: "Please choose a directory" + currentFolder: Download.downloadLocalModelsPath + onAccepted: { + Download.downloadLocalModelsPath = selectedFolder + settings.modelPath = Download.downloadLocalModelsPath + settings.sync() + } + } + Label { + id: 
modelPathLabel + text: qsTr("Download path:") + color: theme.textColor + Layout.row: 1 + Layout.column: 0 + } + TextField { + id: modelPathDisplayLabel + text: Download.downloadLocalModelsPath + readOnly: true + color: theme.textColor + Layout.fillWidth: true + ToolTip.text: qsTr("Path where model files will be downloaded to") + ToolTip.visible: hovered + Accessible.role: Accessible.ToolTip + Accessible.name: modelPathDisplayLabel.text + Accessible.description: ToolTip.text + background: Rectangle { + color: theme.backgroundLighter + radius: 10 + } + } + Button { + text: qsTr("Browse") + contentItem: Text { + text: qsTr("Browse") + horizontalAlignment: Text.AlignHCenter + color: theme.textColor + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Opens a folder picker dialog to choose where to save model files") + } + background: Rectangle { + opacity: .5 + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + onClicked: modelPathDialog.open() + } + } + } +} diff --git a/gpt4all-chat/qml/NetworkDialog.qml b/gpt4all-chat/qml/NetworkDialog.qml new file mode 100644 index 00000000..10b97bf9 --- /dev/null +++ b/gpt4all-chat/qml/NetworkDialog.qml @@ -0,0 +1,174 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts +import download +import network +import llm + +Dialog { + id: networkDialog + anchors.centerIn: parent + modal: true + opacity: 0.9 + padding: 20 + + Theme { + id: theme + } + + Settings { + id: settings + category: "network" + property string attribution: "" + } + + Component.onDestruction: { + settings.sync() + } + + Column { + id: column + spacing: 20 + Item { + width: childrenRect.width + height: childrenRect.height + Image { + id: img + anchors.top: parent.top + anchors.left: parent.left + width: 60 + height: 60 + source: "qrc:/gpt4all/icons/logo.svg" + } + Text { + anchors.left: img.right + 
anchors.leftMargin: 30 + anchors.verticalCenter: img.verticalCenter + text: qsTr("Contribute data to the GPT4All Opensource Datalake.") + color: theme.textColor + } + } + + ScrollView { + clip: true + height: 300 + width: 1024 - 40 + ScrollBar.vertical.policy: ScrollBar.AlwaysOn + ScrollBar.horizontal.policy: ScrollBar.AlwaysOff + + TextArea { + id: textOptIn + wrapMode: Text.Wrap + width: 1024 - 40 + padding: 20 + text: qsTr("By enabling this feature, you will be able to participate in the democratic process of training a large language model by contributing data for future model improvements. + +When a GPT4All model responds to you and you have opted-in, your conversation will be sent to the GPT4All Open Source Datalake. Additionally, you can like/dislike its response. If you dislike a response, you can suggest an alternative response. This data will be collected and aggregated in the GPT4All Datalake. + +NOTE: By turning on this feature, you will be sending your data to the GPT4All Open Source Datalake. You should have no expectation of chat privacy when this feature is enabled. You should; however, have an expectation of an optional attribution if you wish. Your chat data will be openly available for anyone to download and will be used by Nomic AI to improve future GPT4All models. 
Nomic AI will retain all attribution information attached to your data and you will be credited as a contributor to any GPT4All model release that uses your data!") + color: theme.textColor + focus: false + readOnly: true + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Terms for opt-in") + Accessible.description: qsTr("Describes what will happen when you opt-in") + background: Rectangle { + color: theme.backgroundLight + radius: 10 + } + } + } + + TextField { + id: attribution + color: theme.textColor + padding: 20 + width: parent.width + text: settings.attribution + font.pixelSize: theme.fontSizeLarge + placeholderText: qsTr("Please provide a name for attribution (optional)") + placeholderTextColor: theme.backgroundLightest + background: Rectangle { + color: theme.backgroundLighter + radius: 10 + } + Accessible.role: Accessible.EditableText + Accessible.name: qsTr("Attribution (optional)") + Accessible.description: qsTr("Textfield for providing attribution") + onEditingFinished: { + settings.attribution = attribution.text; + settings.sync(); + } + } + } + + background: Rectangle { + anchors.fill: parent + color: theme.backgroundDarkest + border.width: 1 + border.color: theme.dialogBorder + radius: 10 + } + + footer: DialogButtonBox { + id: dialogBox + padding: 20 + alignment: Qt.AlignRight + spacing: 10 + Button { + contentItem: Text { + color: theme.textColor + text: qsTr("Enable") + } + background: Rectangle { + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Enable opt-in button") + + padding: 15 + DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole + } + Button { + contentItem: Text { + color: theme.textColor + text: qsTr("Cancel") + } + background: Rectangle { + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + Accessible.role: 
Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Cancel opt-in button") + + padding: 15 + DialogButtonBox.buttonRole: DialogButtonBox.RejectRole + } + background: Rectangle { + color: "transparent" + } + } + + onAccepted: { + if (Network.isActive) + return + Network.isActive = true; + Network.sendNetworkToggled(true); + } + + onRejected: { + if (!Network.isActive) + return + Network.isActive = false; + Network.sendNetworkToggled(false); + } +} diff --git a/gpt4all-chat/qml/NewVersionDialog.qml b/gpt4all-chat/qml/NewVersionDialog.qml new file mode 100644 index 00000000..8da15f31 --- /dev/null +++ b/gpt4all-chat/qml/NewVersionDialog.qml @@ -0,0 +1,76 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts +import download +import network +import llm + +Dialog { + id: newVerionDialog + anchors.centerIn: parent + modal: true + opacity: 0.9 + width: contentItem.width + height: contentItem.height + padding: 20 + + Theme { + id: theme + } + + background: Rectangle { + anchors.fill: parent + color: theme.backgroundDarkest + border.width: 1 + border.color: theme.dialogBorder + radius: 10 + } + + Item { + id: contentItem + width: childrenRect.width + 40 + height: childrenRect.height + 40 + + Label { + id: label + anchors.top: parent.top + anchors.left: parent.left + topPadding: 20 + bottomPadding: 20 + text: qsTr("New version is available:") + color: theme.textColor + } + + Button { + id: button + anchors.left: label.right + anchors.leftMargin: 10 + anchors.verticalCenter: label.verticalCenter + padding: 20 + contentItem: Text { + text: qsTr("Update") + horizontalAlignment: Text.AlignHCenter + color: theme.textColor + + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Use this to launch an external application that will check for updates to the installer") + } + + background: Rectangle { + opacity: .5 + border.color: theme.backgroundLightest + 
border.width: 1 + radius: 10 + color: theme.backgroundLight + } + + onClicked: { + if (!LLM.checkForUpdates()) + checkForUpdatesError.open() + } + } + } +} diff --git a/gpt4all-chat/qml/PopupDialog.qml b/gpt4all-chat/qml/PopupDialog.qml new file mode 100644 index 00000000..dfd80d54 --- /dev/null +++ b/gpt4all-chat/qml/PopupDialog.qml @@ -0,0 +1,71 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts + +Dialog { + id: popupDialog + anchors.centerIn: parent + opacity: 0.9 + padding: 20 + property alias text: textField.text + property bool shouldTimeOut: true + property bool shouldShowBusy: false + modal: shouldShowBusy + closePolicy: shouldShowBusy ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside) + + Theme { + id: theme + } + + Row { + anchors.centerIn: parent + width: childrenRect.width + height: childrenRect.height + spacing: 20 + + Text { + id: textField + anchors.verticalCenter: busyIndicator.verticalCenter + horizontalAlignment: Text.AlignJustify + color: theme.textColor + Accessible.role: Accessible.HelpBalloon + Accessible.name: text + Accessible.description: qsTr("Reveals a shortlived help balloon") + } + + BusyIndicator { + id: busyIndicator + visible: shouldShowBusy + running: shouldShowBusy + + Accessible.role: Accessible.Animation + Accessible.name: qsTr("Busy indicator") + Accessible.description: qsTr("Displayed when the popup is showing busy") + } + } + + background: Rectangle { + anchors.fill: parent + color: theme.backgroundDarkest + border.width: 1 + border.color: theme.dialogBorder + radius: 10 + } + + exit: Transition { + NumberAnimation { duration: 500; property: "opacity"; from: 1.0; to: 0.0 } + } + + onOpened: { + if (shouldTimeOut) + timer.start() + } + + Timer { + id: timer + interval: 500; running: false; repeat: false + onTriggered: popupDialog.close() + } +} \ No newline at end of file diff --git a/gpt4all-chat/qml/SettingsDialog.qml 
b/gpt4all-chat/qml/SettingsDialog.qml new file mode 100644 index 00000000..c9f3557f --- /dev/null +++ b/gpt4all-chat/qml/SettingsDialog.qml @@ -0,0 +1,828 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Dialogs +import QtQuick.Layouts +import download +import network +import llm + +Dialog { + id: settingsDialog + modal: true + opacity: 0.9 + background: Rectangle { + anchors.fill: parent + anchors.margins: -20 + color: theme.backgroundDarkest + border.width: 1 + border.color: theme.dialogBorder + radius: 10 + } + + onOpened: { + Network.sendSettingsDialog(); + } + + property var currentChat: LLM.chatListModel.currentChat + + Theme { + id: theme + } + + property real defaultTemperature: 0.28 + property real defaultTopP: 0.95 + property int defaultTopK: 40 + property int defaultMaxLength: 4096 + property int defaultPromptBatchSize: 9 + property real defaultRepeatPenalty: 1.10 + property int defaultRepeatPenaltyTokens: 64 + property int defaultThreadCount: 0 + property bool defaultSaveChats: false + property string defaultPromptTemplate: "### Human: +%1 +### Assistant:\n" + property string defaultModelPath: Download.defaultLocalModelsPath() + property string defaultUserDefaultModel: "Application default" + + property alias temperature: settings.temperature + property alias topP: settings.topP + property alias topK: settings.topK + property alias maxLength: settings.maxLength + property alias promptBatchSize: settings.promptBatchSize + property alias promptTemplate: settings.promptTemplate + property alias repeatPenalty: settings.repeatPenalty + property alias repeatPenaltyTokens: settings.repeatPenaltyTokens + property alias threadCount: settings.threadCount + property alias saveChats: settings.saveChats + property alias modelPath: settings.modelPath + property alias userDefaultModel: settings.userDefaultModel + + Settings { + id: settings + property real temperature: settingsDialog.defaultTemperature + 
property real topP: settingsDialog.defaultTopP + property int topK: settingsDialog.defaultTopK + property int maxLength: settingsDialog.defaultMaxLength + property int promptBatchSize: settingsDialog.defaultPromptBatchSize + property int threadCount: settingsDialog.defaultThreadCount + property bool saveChats: settingsDialog.defaultSaveChats + property real repeatPenalty: settingsDialog.defaultRepeatPenalty + property int repeatPenaltyTokens: settingsDialog.defaultRepeatPenaltyTokens + property string promptTemplate: settingsDialog.defaultPromptTemplate + property string modelPath: settingsDialog.defaultModelPath + property string userDefaultModel: settingsDialog.defaultUserDefaultModel + } + + function restoreGenerationDefaults() { + settings.temperature = defaultTemperature + settings.topP = defaultTopP + settings.topK = defaultTopK + settings.maxLength = defaultMaxLength + settings.promptBatchSize = defaultPromptBatchSize + settings.promptTemplate = defaultPromptTemplate + settings.repeatPenalty = defaultRepeatPenalty + settings.repeatPenaltyTokens = defaultRepeatPenaltyTokens + settings.sync() + } + + function restoreApplicationDefaults() { + settings.modelPath = settingsDialog.defaultModelPath + settings.threadCount = defaultThreadCount + settings.saveChats = defaultSaveChats + settings.userDefaultModel = defaultUserDefaultModel + Download.downloadLocalModelsPath = settings.modelPath + LLM.threadCount = settings.threadCount + LLM.chatListModel.shouldSaveChats = settings.saveChats + settings.sync() + } + + Component.onCompleted: { + LLM.threadCount = settings.threadCount + LLM.chatListModel.shouldSaveChats = settings.saveChats + Download.downloadLocalModelsPath = settings.modelPath + } + + Connections { + target: settingsDialog + function onClosed() { + settings.sync() + } + } + + Item { + Accessible.role: Accessible.Dialog + Accessible.name: qsTr("Settings dialog") + Accessible.description: qsTr("Dialog containing various application settings") + } + TabBar { 
+ id: settingsTabBar + width: parent.width / 1.5 + + TabButton { + id: genSettingsButton + contentItem: IconLabel { + color: theme.textColor + font.bold: genSettingsButton.checked + font.pixelSize: genSettingsButton.checked ? theme.fontSizeLarger : theme.fontSizeLarge + text: qsTr("Generation") + } + background: Rectangle { + color: genSettingsButton.checked ? theme.backgroundDarkest : theme.backgroundLight + border.color: theme.tabBorder + border.width: 1 ? genSettingsButton.checked : 0 + } + Accessible.role: Accessible.Button + Accessible.name: qsTr("Generation settings") + Accessible.description: qsTr("Settings related to how the model generates text") + } + + TabButton { + id: appSettingsButton + contentItem: IconLabel { + color: theme.textColor + font.bold: appSettingsButton.checked + font.pixelSize: appSettingsButton.checked ? theme.fontSizeLarger : theme.fontSizeLarge + text: qsTr("Application") + } + background: Rectangle { + color: appSettingsButton.checked ? theme.backgroundDarkest : theme.backgroundLight + border.color: theme.tabBorder + border.width: 1 ? 
appSettingsButton.checked : 0 + } + Accessible.role: Accessible.Button + Accessible.name: qsTr("Application settings") + Accessible.description: qsTr("Settings related to general behavior of the application") + } + } + + StackLayout { + anchors.top: settingsTabBar.bottom + width: parent.width + height: availableHeight + currentIndex: settingsTabBar.currentIndex + + Item { + id: generationSettingsTab + ScrollView { + background: Rectangle { + color: 'transparent' + border.color: theme.tabBorder + border.width: 1 + radius: 2 + } + padding: 10 + width: parent.width + height: parent.height - 30 + contentWidth: availableWidth - 20 + contentHeight: generationSettingsTabInner.implicitHeight + 40 + ScrollBar.vertical.policy: ScrollBar.AlwaysOn + + GridLayout { + id: generationSettingsTabInner + anchors.margins: 10 + columns: 2 + rowSpacing: 10 + columnSpacing: 10 + anchors.fill: parent + + Label { + id: tempLabel + text: qsTr("Temperature:") + color: theme.textColor + Layout.row: 0 + Layout.column: 0 + } + TextField { + text: settings.temperature.toString() + color: theme.textColor + background: Rectangle { + implicitWidth: 150 + color: theme.backgroundLighter + radius: 10 + } + padding: 10 + ToolTip.text: qsTr("Temperature increases the chances of choosing less likely tokens - higher temperature gives more creative but less predictable outputs") + ToolTip.visible: hovered + Layout.row: 0 + Layout.column: 1 + validator: DoubleValidator { + locale: "C" + } + onEditingFinished: { + var val = parseFloat(text) + if (!isNaN(val)) { + settings.temperature = val + settings.sync() + focus = false + } else { + text = settings.temperature.toString() + } + } + Accessible.role: Accessible.EditableText + Accessible.name: tempLabel.text + Accessible.description: ToolTip.text + } + Label { + id: topPLabel + text: qsTr("Top P:") + color: theme.textColor + Layout.row: 1 + Layout.column: 0 + } + TextField { + text: settings.topP.toString() + color: theme.textColor + background: Rectangle { 
+ implicitWidth: 150 + color: theme.backgroundLighter + radius: 10 + } + padding: 10 + ToolTip.text: qsTr("Only the most likely tokens up to a total probability of top_p can be chosen, prevents choosing highly unlikely tokens, aka Nucleus Sampling") + ToolTip.visible: hovered + Layout.row: 1 + Layout.column: 1 + validator: DoubleValidator { + locale: "C" + } + onEditingFinished: { + var val = parseFloat(text) + if (!isNaN(val)) { + settings.topP = val + settings.sync() + focus = false + } else { + text = settings.topP.toString() + } + } + Accessible.role: Accessible.EditableText + Accessible.name: topPLabel.text + Accessible.description: ToolTip.text + } + Label { + id: topKLabel + text: qsTr("Top K:") + color: theme.textColor + Layout.row: 2 + Layout.column: 0 + } + TextField { + text: settings.topK.toString() + color: theme.textColor + background: Rectangle { + implicitWidth: 150 + color: theme.backgroundLighter + radius: 10 + } + padding: 10 + ToolTip.text: qsTr("Only the top K most likely tokens will be chosen from") + ToolTip.visible: hovered + Layout.row: 2 + Layout.column: 1 + validator: IntValidator { + bottom: 1 + } + onEditingFinished: { + var val = parseInt(text) + if (!isNaN(val)) { + settings.topK = val + settings.sync() + focus = false + } else { + text = settings.topK.toString() + } + } + Accessible.role: Accessible.EditableText + Accessible.name: topKLabel.text + Accessible.description: ToolTip.text + } + Label { + id: maxLengthLabel + text: qsTr("Max Length:") + color: theme.textColor + Layout.row: 3 + Layout.column: 0 + } + TextField { + text: settings.maxLength.toString() + color: theme.textColor + background: Rectangle { + implicitWidth: 150 + color: theme.backgroundLighter + radius: 10 + } + padding: 10 + ToolTip.text: qsTr("Maximum length of response in tokens") + ToolTip.visible: hovered + Layout.row: 3 + Layout.column: 1 + validator: IntValidator { + bottom: 1 + } + onEditingFinished: { + var val = parseInt(text) + if (!isNaN(val)) { + 
settings.maxLength = val + settings.sync() + focus = false + } else { + text = settings.maxLength.toString() + } + } + Accessible.role: Accessible.EditableText + Accessible.name: maxLengthLabel.text + Accessible.description: ToolTip.text + } + + Label { + id: batchSizeLabel + text: qsTr("Prompt Batch Size:") + color: theme.textColor + Layout.row: 4 + Layout.column: 0 + } + TextField { + text: settings.promptBatchSize.toString() + color: theme.textColor + background: Rectangle { + implicitWidth: 150 + color: theme.backgroundLighter + radius: 10 + } + padding: 10 + ToolTip.text: qsTr("Amount of prompt tokens to process at once, higher values can speed up reading prompts but will use more RAM") + ToolTip.visible: hovered + Layout.row: 4 + Layout.column: 1 + validator: IntValidator { + bottom: 1 + } + onEditingFinished: { + var val = parseInt(text) + if (!isNaN(val)) { + settings.promptBatchSize = val + settings.sync() + focus = false + } else { + text = settings.promptBatchSize.toString() + } + } + Accessible.role: Accessible.EditableText + Accessible.name: batchSizeLabel.text + Accessible.description: ToolTip.text + } + Label { + id: repeatPenaltyLabel + text: qsTr("Repeat Penalty:") + color: theme.textColor + Layout.row: 5 + Layout.column: 0 + } + TextField { + text: settings.repeatPenalty.toString() + color: theme.textColor + background: Rectangle { + implicitWidth: 150 + color: theme.backgroundLighter + radius: 10 + } + padding: 10 + ToolTip.text: qsTr("Amount to penalize repetitiveness of the output") + ToolTip.visible: hovered + Layout.row: 5 + Layout.column: 1 + validator: DoubleValidator { + locale: "C" + } + onEditingFinished: { + var val = parseFloat(text) + if (!isNaN(val)) { + settings.repeatPenalty = val + settings.sync() + focus = false + } else { + text = settings.repeatPenalty.toString() + } + } + Accessible.role: Accessible.EditableText + Accessible.name: repeatPenaltyLabel.text + Accessible.description: ToolTip.text + } + Label { + id: 
repeatPenaltyTokensLabel + text: qsTr("Repeat Penalty Tokens:") + color: theme.textColor + Layout.row: 6 + Layout.column: 0 + } + TextField { + text: settings.repeatPenaltyTokens.toString() + color: theme.textColor + background: Rectangle { + implicitWidth: 150 + color: theme.backgroundLighter + radius: 10 + } + padding: 10 + ToolTip.text: qsTr("How far back in output to apply repeat penalty") + ToolTip.visible: hovered + Layout.row: 6 + Layout.column: 1 + validator: IntValidator { + bottom: 1 + } + onEditingFinished: { + var val = parseInt(text) + if (!isNaN(val)) { + settings.repeatPenaltyTokens = val + settings.sync() + focus = false + } else { + text = settings.repeatPenaltyTokens.toString() + } + } + Accessible.role: Accessible.EditableText + Accessible.name: repeatPenaltyTokensLabel.text + Accessible.description: ToolTip.text + } + + Label { + id: promptTemplateLabel + text: qsTr("Prompt Template:") + color: theme.textColor + Layout.row: 7 + Layout.column: 0 + } + Rectangle { + Layout.row: 7 + Layout.column: 1 + Layout.fillWidth: true + height: 200 + color: "transparent" + clip: true + Label { + id: promptTemplateLabelHelp + visible: settings.promptTemplate.indexOf( + "%1") === -1 + font.bold: true + color: theme.textErrorColor + text: qsTr("Prompt template must contain %1 to be replaced with the user's input.") + anchors.fill: templateScrollView + z: 200 + padding: 10 + wrapMode: TextArea.Wrap + Accessible.role: Accessible.EditableText + Accessible.name: text + } + ScrollView { + id: templateScrollView + anchors.fill: parent + TextArea { + text: settings.promptTemplate + color: theme.textColor + background: Rectangle { + implicitWidth: 150 + color: theme.backgroundLighter + radius: 10 + } + padding: 10 + wrapMode: TextArea.Wrap + onTextChanged: { + settings.promptTemplate = text + settings.sync() + } + bottomPadding: 10 + Accessible.role: Accessible.EditableText + Accessible.name: promptTemplateLabel.text + Accessible.description: 
promptTemplateLabelHelp.text + } + } + } + Button { + Layout.row: 8 + Layout.column: 1 + Layout.fillWidth: true + padding: 10 + contentItem: Text { + text: qsTr("Restore Defaults") + horizontalAlignment: Text.AlignHCenter + color: theme.textColor + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Restores the settings dialog to a default state") + } + + background: Rectangle { + opacity: .5 + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + onClicked: { + settingsDialog.restoreGenerationDefaults() + } + } + } + } + } + Item { + id: applicationSettingsTab + ScrollView { + background: Rectangle { + color: 'transparent' + border.color: theme.tabBorder + border.width: 1 + radius: 2 + } + padding: 10 + width: parent.width + height: parent.height - 30 + contentWidth: availableWidth - 20 + ScrollBar.vertical.policy: ScrollBar.AlwaysOn + + GridLayout { + anchors.margins: 10 + columns: 3 + rowSpacing: 10 + columnSpacing: 10 + anchors.fill: parent + Label { + id: defaultModelLabel + text: qsTr("Default model:") + color: theme.textColor + Layout.row: 1 + Layout.column: 0 + } + ComboBox { + id: comboBox + Layout.row: 1 + Layout.column: 1 + Layout.minimumWidth: 350 + font.pixelSize: theme.fontSizeLarge + spacing: 0 + padding: 10 + model: modelList + Accessible.role: Accessible.ComboBox + Accessible.name: qsTr("ComboBox for displaying/picking the default model") + Accessible.description: qsTr("Use this for picking the default model to use; the first item is the current default model") + function updateModel(newModelList) { + var newArray = Array.from(newModelList); + newArray.unshift('Application default'); + comboBox.model = newArray; + settings.sync(); + comboBox.currentIndex = comboBox.indexOfValue(settingsDialog.userDefaultModel); + + } + Component.onCompleted: { + comboBox.updateModel(currentChat.modelList) + } + Connections { + target: settings + function 
onUserDefaultModelChanged() { + comboBox.updateModel(currentChat.modelList) + } + } + Connections { + target: currentChat + function onModelListChanged() { + comboBox.updateModel(currentChat.modelList) + } + } + contentItem: Text { + anchors.horizontalCenter: parent.horizontalCenter + leftPadding: 10 + rightPadding: 10 + text: comboBox.displayText + font: comboBox.font + color: theme.textColor + verticalAlignment: Text.AlignVCenter + horizontalAlignment: Text.AlignHCenter + elide: Text.ElideRight + } + delegate: ItemDelegate { + width: comboBox.width + contentItem: Text { + text: modelData + color: theme.textColor + font: comboBox.font + elide: Text.ElideRight + verticalAlignment: Text.AlignVCenter + } + background: Rectangle { + color: highlighted ? theme.backgroundLight : theme.backgroundDark + } + highlighted: comboBox.highlightedIndex === index + } + popup: Popup { + y: comboBox.height - 1 + width: comboBox.width + implicitHeight: contentItem.implicitHeight + padding: 0 + + contentItem: ListView { + clip: true + implicitHeight: contentHeight + model: comboBox.popup.visible ? 
comboBox.delegateModel : null + currentIndex: comboBox.highlightedIndex + ScrollIndicator.vertical: ScrollIndicator { } + } + + background: Rectangle { + color: theme.backgroundDark + } + } + + background: Rectangle { + color: theme.backgroundDark + border.width: 1 + border.color: theme.backgroundLightest + radius: 10 + } + + onActivated: { + settingsDialog.userDefaultModel = comboBox.currentText + settings.sync() + } + } + FolderDialog { + id: modelPathDialog + title: "Please choose a directory" + currentFolder: Download.downloadLocalModelsPath + onAccepted: { + Download.downloadLocalModelsPath = selectedFolder + settings.modelPath = Download.downloadLocalModelsPath + settings.sync() + } + } + Label { + id: modelPathLabel + text: qsTr("Download path:") + color: theme.textColor + Layout.row: 2 + Layout.column: 0 + } + TextField { + id: modelPathDisplayLabel + text: Download.downloadLocalModelsPath + readOnly: true + color: theme.textColor + implicitWidth: 300 + padding: 10 + Layout.row: 2 + Layout.column: 1 + Layout.fillWidth: true + ToolTip.text: qsTr("Path where model files will be downloaded to") + ToolTip.visible: hovered + Accessible.role: Accessible.ToolTip + Accessible.name: modelPathDisplayLabel.text + Accessible.description: ToolTip.text + background: Rectangle { + color: theme.backgroundLighter + radius: 10 + } + } + Button { + Layout.row: 2 + Layout.column: 2 + text: qsTr("Browse") + contentItem: Text { + text: qsTr("Browse") + horizontalAlignment: Text.AlignHCenter + color: theme.textColor + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Opens a folder picker dialog to choose where to save model files") + } + background: Rectangle { + opacity: .5 + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + onClicked: modelPathDialog.open() + } + Label { + id: nThreadsLabel + text: qsTr("CPU Threads:") + color: theme.textColor + Layout.row: 3 + Layout.column: 0 + 
} + TextField { + text: settingsDialog.threadCount.toString() + color: theme.textColor + background: Rectangle { + implicitWidth: 150 + color: theme.backgroundLighter + radius: 10 + } + padding: 10 + ToolTip.text: qsTr("Amount of processing threads to use, a setting of 0 will use the lesser of 4 or your number of CPU threads") + ToolTip.visible: hovered + Layout.row: 3 + Layout.column: 1 + validator: IntValidator { + bottom: 1 + } + onEditingFinished: { + var val = parseInt(text) + if (!isNaN(val)) { + settingsDialog.threadCount = val + LLM.threadCount = val + settings.sync() + focus = false + } else { + text = settingsDialog.threadCount.toString() + } + } + Accessible.role: Accessible.EditableText + Accessible.name: nThreadsLabel.text + Accessible.description: ToolTip.text + } + Label { + id: saveChatsLabel + text: qsTr("Save chats to disk:") + color: theme.textColor + Layout.row: 4 + Layout.column: 0 + } + CheckBox { + id: saveChatsBox + Layout.row: 4 + Layout.column: 1 + checked: settingsDialog.saveChats + onClicked: { + Network.sendSaveChatsToggled(saveChatsBox.checked); + settingsDialog.saveChats = saveChatsBox.checked + LLM.chatListModel.shouldSaveChats = saveChatsBox.checked + settings.sync() + } + + ToolTip.text: qsTr("WARNING: Saving chats to disk can be ~2GB per chat") + ToolTip.visible: hovered + + background: Rectangle { + color: "transparent" + } + + indicator: Rectangle { + implicitWidth: 26 + implicitHeight: 26 + x: saveChatsBox.leftPadding + y: parent.height / 2 - height / 2 + border.color: theme.dialogBorder + color: "transparent" + + Rectangle { + width: 14 + height: 14 + x: 6 + y: 6 + color: theme.textColor + visible: saveChatsBox.checked + } + } + + contentItem: Text { + text: saveChatsBox.text + font: saveChatsBox.font + opacity: enabled ? 
1.0 : 0.3 + color: theme.textColor + verticalAlignment: Text.AlignVCenter + leftPadding: saveChatsBox.indicator.width + saveChatsBox.spacing + } + } + Button { + Layout.row: 5 + Layout.column: 1 + Layout.fillWidth: true + padding: 10 + contentItem: Text { + text: qsTr("Restore Defaults") + horizontalAlignment: Text.AlignHCenter + color: theme.textColor + Accessible.role: Accessible.Button + Accessible.name: text + Accessible.description: qsTr("Restores the settings dialog to a default state") + } + + background: Rectangle { + opacity: .5 + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + onClicked: { + settingsDialog.restoreApplicationDefaults() + } + } + } + } + } + } +} diff --git a/gpt4all-chat/qml/StartupDialog.qml b/gpt4all-chat/qml/StartupDialog.qml new file mode 100644 index 00000000..fabc02ef --- /dev/null +++ b/gpt4all-chat/qml/StartupDialog.qml @@ -0,0 +1,357 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts +import download +import network +import llm + +Dialog { + id: startupDialog + anchors.centerIn: parent + modal: true + opacity: 0.9 + padding: 20 + width: 1024 + height: column.height + 40 + closePolicy: !optInStatisticsRadio.choiceMade || !optInNetworkRadio.choiceMade ? 
Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside) + + Theme { + id: theme + } + + Column { + id: column + spacing: 20 + Item { + width: childrenRect.width + height: childrenRect.height + Image { + id: img + anchors.top: parent.top + anchors.left: parent.left + width: 60 + height: 60 + source: "qrc:/gpt4all/icons/logo.svg" + } + Text { + anchors.left: img.right + anchors.leftMargin: 30 + anchors.verticalCenter: img.verticalCenter + text: qsTr("Welcome!") + color: theme.textColor + } + } + + ScrollView { + clip: true + height: 200 + width: 1024 - 40 + ScrollBar.vertical.policy: ScrollBar.AlwaysOn + ScrollBar.horizontal.policy: ScrollBar.AlwaysOff + + TextArea { + id: welcome + wrapMode: Text.Wrap + width: 1024 - 40 + padding: 20 + textFormat: TextEdit.MarkdownText + text: qsTr("### Release notes\n") + + Download.releaseInfo.notes + + qsTr("### Contributors\n") + + Download.releaseInfo.contributors + color: theme.textColor + focus: false + readOnly: true + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Release notes") + Accessible.description: qsTr("Release notes for this version") + background: Rectangle { + color: theme.backgroundLight + radius: 10 + } + } + } + + ScrollView { + clip: true + height: 150 + width: 1024 - 40 + ScrollBar.vertical.policy: ScrollBar.AlwaysOn + ScrollBar.horizontal.policy: ScrollBar.AlwaysOff + + TextArea { + id: optInTerms + wrapMode: Text.Wrap + width: 1024 - 40 + padding: 20 + textFormat: TextEdit.MarkdownText + text: qsTr( +"### Opt-ins for anonymous usage analytics and datalake +By enabling these features, you will be able to participate in the democratic process of training a +large language model by contributing data for future model improvements. + +When a GPT4All model responds to you and you have opted-in, your conversation will be sent to the GPT4All +Open Source Datalake. Additionally, you can like/dislike its response. If you dislike a response, you +can suggest an alternative response. 
This data will be collected and aggregated in the GPT4All Datalake. + +NOTE: By turning on this feature, you will be sending your data to the GPT4All Open Source Datalake. +You should have no expectation of chat privacy when this feature is enabled. You should; however, have +an expectation of an optional attribution if you wish. Your chat data will be openly available for anyone +to download and will be used by Nomic AI to improve future GPT4All models. Nomic AI will retain all +attribution information attached to your data and you will be credited as a contributor to any GPT4All +model release that uses your data!") + + color: theme.textColor + focus: false + readOnly: true + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Terms for opt-in") + Accessible.description: qsTr("Describes what will happen when you opt-in") + background: Rectangle { + color: theme.backgroundLight + radius: 10 + } + } + } + + GridLayout { + columns: 2 + rowSpacing: 10 + columnSpacing: 10 + anchors.right: parent.right + Label { + id: optInStatistics + text: "Opt-in to anonymous usage analytics used to improve GPT4All" + Layout.row: 0 + Layout.column: 0 + color: theme.textColor + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Opt-in for anonymous usage statistics") + Accessible.description: qsTr("Label for opt-in") + } + + ButtonGroup { + buttons: optInStatisticsRadio.children + onClicked: { + Network.usageStatsActive = optInStatisticsRadio.checked + if (optInNetworkRadio.choiceMade && optInStatisticsRadio.choiceMade) + startupDialog.close(); + } + } + + RowLayout { + id: optInStatisticsRadio + Layout.alignment: Qt.AlignVCenter + Layout.row: 0 + Layout.column: 1 + property bool defaultChecked: Network.usageStatsActive + property alias checked: optInStatisticsRadioYes.checked + property bool choiceMade: optInStatisticsRadioYes.checked || optInStatisticsRadioNo.checked + + RadioButton { + id: optInStatisticsRadioYes + checked: 
optInStatisticsRadio.defaultChecked + text: qsTr("Yes") + Accessible.role: Accessible.RadioButton + Accessible.name: qsTr("Opt-in for anonymous usage statistics") + Accessible.description: qsTr("Radio button to allow opt-in for anonymous usage statistics") + + background: Rectangle { + color: "transparent" + } + + indicator: Rectangle { + implicitWidth: 26 + implicitHeight: 26 + x: optInStatisticsRadioYes.leftPadding + y: parent.height / 2 - height / 2 + radius: 13 + border.color: theme.dialogBorder + color: "transparent" + + Rectangle { + width: 14 + height: 14 + x: 6 + y: 6 + radius: 7 + color: theme.textColor + visible: optInStatisticsRadioYes.checked + } + } + + contentItem: Text { + text: optInStatisticsRadioYes.text + font: optInStatisticsRadioYes.font + opacity: enabled ? 1.0 : 0.3 + color: theme.textColor + verticalAlignment: Text.AlignVCenter + leftPadding: optInStatisticsRadioYes.indicator.width + optInStatisticsRadioYes.spacing + } + } + RadioButton { + id: optInStatisticsRadioNo + text: qsTr("No") + Accessible.role: Accessible.RadioButton + Accessible.name: qsTr("Opt-out for anonymous usage statistics") + Accessible.description: qsTr("Radio button to allow opt-out for anonymous usage statistics") + + background: Rectangle { + color: "transparent" + } + + indicator: Rectangle { + implicitWidth: 26 + implicitHeight: 26 + x: optInStatisticsRadioNo.leftPadding + y: parent.height / 2 - height / 2 + radius: 13 + border.color: theme.dialogBorder + color: "transparent" + + Rectangle { + width: 14 + height: 14 + x: 6 + y: 6 + radius: 7 + color: theme.textColor + visible: optInStatisticsRadioNo.checked + } + } + + contentItem: Text { + text: optInStatisticsRadioNo.text + font: optInStatisticsRadioNo.font + opacity: enabled ? 
1.0 : 0.3 + color: theme.textColor + verticalAlignment: Text.AlignVCenter + leftPadding: optInStatisticsRadioNo.indicator.width + optInStatisticsRadioNo.spacing + } + } + } + + Label { + id: optInNetwork + text: "Opt-in to anonymous sharing of chats to the GPT4All Datalake" + Layout.row: 1 + Layout.column: 0 + color: theme.textColor + Accessible.role: Accessible.Paragraph + Accessible.name: qsTr("Opt-in for network") + Accessible.description: qsTr("Checkbox to allow opt-in for network") + } + + ButtonGroup { + buttons: optInNetworkRadio.children + onClicked: { + Network.isActive = optInNetworkRadio.checked + if (optInNetworkRadio.choiceMade && optInStatisticsRadio.choiceMade) + startupDialog.close(); + } + } + + RowLayout { + id: optInNetworkRadio + Layout.alignment: Qt.AlignVCenter + Layout.row: 1 + Layout.column: 1 + property bool defaultChecked: Network.isActive + property alias checked: optInNetworkRadioYes.checked + property bool choiceMade: optInNetworkRadioYes.checked || optInNetworkRadioNo.checked + + RadioButton { + id: optInNetworkRadioYes + checked: optInNetworkRadio.defaultChecked + text: qsTr("Yes") + Accessible.role: Accessible.RadioButton + Accessible.name: qsTr("Opt-in for network") + Accessible.description: qsTr("Radio button to allow opt-in anonymous sharing of chats to the GPT4All Datalake") + + background: Rectangle { + color: "transparent" + } + + indicator: Rectangle { + implicitWidth: 26 + implicitHeight: 26 + x: optInNetworkRadioYes.leftPadding + y: parent.height / 2 - height / 2 + radius: 13 + border.color: theme.dialogBorder + color: "transparent" + + Rectangle { + width: 14 + height: 14 + x: 6 + y: 6 + radius: 7 + color: theme.textColor + visible: optInNetworkRadioYes.checked + } + } + + contentItem: Text { + text: optInNetworkRadioYes.text + font: optInNetworkRadioYes.font + opacity: enabled ? 
1.0 : 0.3 + color: theme.textColor + verticalAlignment: Text.AlignVCenter + leftPadding: optInNetworkRadioYes.indicator.width + optInNetworkRadioYes.spacing + } + } + RadioButton { + id: optInNetworkRadioNo + text: qsTr("No") + Accessible.role: Accessible.RadioButton + Accessible.name: qsTr("Opt-out for network") + Accessible.description: qsTr("Radio button to allow opt-out anonymous sharing of chats to the GPT4All Datalake") + + background: Rectangle { + color: "transparent" + } + + indicator: Rectangle { + implicitWidth: 26 + implicitHeight: 26 + x: optInNetworkRadioNo.leftPadding + y: parent.height / 2 - height / 2 + radius: 13 + border.color: theme.dialogBorder + color: "transparent" + + Rectangle { + width: 14 + height: 14 + x: 6 + y: 6 + radius: 7 + color: theme.textColor + visible: optInNetworkRadioNo.checked + } + } + + contentItem: Text { + text: optInNetworkRadioNo.text + font: optInNetworkRadioNo.font + opacity: enabled ? 1.0 : 0.3 + color: theme.textColor + verticalAlignment: Text.AlignVCenter + leftPadding: optInNetworkRadioNo.indicator.width + optInNetworkRadioNo.spacing + } + } + } + } + } + + background: Rectangle { + anchors.fill: parent + color: theme.backgroundDarkest + border.width: 1 + border.color: theme.dialogBorder + radius: 10 + } +} diff --git a/gpt4all-chat/qml/Theme.qml b/gpt4all-chat/qml/Theme.qml new file mode 100644 index 00000000..6417550d --- /dev/null +++ b/gpt4all-chat/qml/Theme.qml @@ -0,0 +1,20 @@ +import QtCore +import QtQuick +import QtQuick.Controls.Basic + +QtObject { + property color textColor: "#d1d5db" + property color textErrorColor: "red" + property color backgroundDarkest: "#202123" + property color backgroundDark: "#242528" + property color backgroundLight: "#343541" + property color backgroundLighter: "#444654" + property color backgroundLightest: "#7d7d8e" + property color dialogBorder: "#d1d5db" + property color userColor: "#ec86bf" + property color assistantColor: "#10a37f" + property color linkColor: "white" + 
property color tabBorder: "#aaa" + property real fontSizeLarge: Qt.application.font.pixelSize + property real fontSizeLarger: Qt.application.font.pixelSize + 2 +} diff --git a/gpt4all-chat/qml/ThumbsDownDialog.qml b/gpt4all-chat/qml/ThumbsDownDialog.qml new file mode 100644 index 00000000..8cb1d115 --- /dev/null +++ b/gpt4all-chat/qml/ThumbsDownDialog.qml @@ -0,0 +1,112 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts +import download +import network +import llm + +Dialog { + id: thumbsDownDialog + modal: true + opacity: 0.9 + padding: 20 + + Theme { + id: theme + } + + property alias response: thumbsDownNewResponse.text + + Column { + anchors.fill: parent + spacing: 20 + Item { + width: childrenRect.width + height: childrenRect.height + Image { + id: img + anchors.top: parent.top + anchors.left: parent.left + width: 60 + height: 60 + source: "qrc:/gpt4all/icons/thumbs_down.svg" + } + Text { + anchors.left: img.right + anchors.leftMargin: 30 + anchors.verticalCenter: img.verticalCenter + text: qsTr("Please edit the text below to provide a better response. 
(optional)") + color: theme.textColor + } + } + + ScrollView { + clip: true + height: 300 + width: parent.width + ScrollBar.vertical.policy: ScrollBar.AlwaysOn + ScrollBar.horizontal.policy: ScrollBar.AlwaysOff + + TextArea { + id: thumbsDownNewResponse + color: theme.textColor + padding: 20 + wrapMode: Text.Wrap + font.pixelSize: theme.fontSizeLarge + placeholderText: qsTr("Please provide a better response...") + placeholderTextColor: theme.backgroundLightest + background: Rectangle { + color: theme.backgroundLighter + radius: 10 + } + } + } + } + + background: Rectangle { + anchors.fill: parent + color: theme.backgroundDarkest + border.width: 1 + border.color: theme.dialogBorder + radius: 10 + } + + footer: DialogButtonBox { + padding: 20 + alignment: Qt.AlignRight + spacing: 10 + Button { + contentItem: Text { + color: theme.textColor + text: qsTr("Submit") + } + background: Rectangle { + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + padding: 15 + DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole + } + Button { + contentItem: Text { + color: theme.textColor + text: qsTr("Cancel") + } + background: Rectangle { + border.color: theme.backgroundLightest + border.width: 1 + radius: 10 + color: theme.backgroundLight + } + padding: 15 + DialogButtonBox.buttonRole: DialogButtonBox.RejectRole + } + background: Rectangle { + color: "transparent" + } + } +} \ No newline at end of file diff --git a/gpt4all-chat/sysinfo.h b/gpt4all-chat/sysinfo.h new file mode 100644 index 00000000..4a02826f --- /dev/null +++ b/gpt4all-chat/sysinfo.h @@ -0,0 +1,48 @@ +#include +#include +#include +#include +#include + +#if defined(Q_OS_MAC) +#include +#include +#endif + +#if defined(Q_OS_WIN) +#include +#endif + +QString getSystemTotalRAM() +{ + qint64 totalRAM = 0; + +#if defined(Q_OS_LINUX) + QFile file("/proc/meminfo"); + if (file.open(QIODevice::ReadOnly | QIODevice::Text)) { + QTextStream in(&file); + QString line = 
in.readLine();
+        while (!line.isNull()) {
+            if (line.startsWith("MemTotal")) {
+                QStringList parts = line.split(QRegularExpression("\\s+"));
+                totalRAM = parts[1].toLongLong() * 1024; // Convert from KB to bytes
+                break;
+            }
+            line = in.readLine();
+        }
+        file.close();
+    }
+#elif defined(Q_OS_MAC)
+    int mib[2] = {CTL_HW, HW_MEMSIZE};
+    size_t length = sizeof(totalRAM);
+    sysctl(mib, 2, &totalRAM, &length, NULL, 0);
+#elif defined(Q_OS_WIN)
+    MEMORYSTATUSEX memoryStatus;
+    memoryStatus.dwLength = sizeof(memoryStatus);
+    GlobalMemoryStatusEx(&memoryStatus);
+    totalRAM = memoryStatus.ullTotalPhys;
+#endif
+
+    double totalRAM_GB = static_cast<double>(totalRAM) / (1024 * 1024 * 1024);
+    return QString::number(totalRAM_GB, 'f', 2) + " GB";
+}
diff --git a/gpt4all-chat/test_hw.cpp b/gpt4all-chat/test_hw.cpp
new file mode 100644
index 00000000..eef10129
--- /dev/null
+++ b/gpt4all-chat/test_hw.cpp
@@ -0,0 +1,29 @@
+#include <string>
+#include <cstdio>
+
+int main(int argc, char *argv[])
+{
+    static bool avx = __builtin_cpu_supports("avx");
+    static bool avx2 = __builtin_cpu_supports("avx2");
+    static bool fma = __builtin_cpu_supports("fma");
+    static bool sse3 = __builtin_cpu_supports("sse3");
+    static std::string s;
+    s = "gpt4all hardware test results:\n";
+    s += " AVX = " + std::to_string(avx) + "\n";
+    s += " AVX2 = " + std::to_string(avx2) + "\n";
+    s += " FMA = " + std::to_string(fma) + "\n";
+    s += " SSE3 = " + std::to_string(sse3) + "\n";
+    fprintf(stderr, "%s", s.c_str());
+    fprintf(stderr, "your hardware supports the \"");
+    fflush(stderr);
+    if (avx2)
+        printf("full");
+    else if (avx && fma)
+        printf("avx_only");
+    else
+        printf("bare_minimum");
+    fflush(stdout);
+    fprintf(stderr, "\" version of gpt4all.\n");
+    fflush(stderr);
+    return 0;
+}