From f62e439a2dc26d199e5fe5fe2891c8a16c028c43 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Thu, 1 Jun 2023 14:13:12 -0400
Subject: [PATCH] Make localdocs work with server mode.

Retrieval no longer round-trips through LocalDocs::receivedResult on the
Chat object. ChatLLM::prompt() now fetches matching chunks itself over a
blocking queued connection to the Database and augments the prompt template
in place, so the server path reuses the same retrieval code and can report
the retrieved chunks as a "references" array on each completion choice.

---
 gpt4all-chat/chat.cpp      | 74 +++++++++++++------------------------
 gpt4all-chat/chat.h        | 17 ++-------
 gpt4all-chat/chatllm.cpp   | 22 ++++++++++--
 gpt4all-chat/chatllm.h     |  5 +++
 gpt4all-chat/database.cpp  |  7 ++--
 gpt4all-chat/database.h    |  3 +-
 gpt4all-chat/localdocs.cpp |  9 -----
 gpt4all-chat/localdocs.h   |  5 ++-
 gpt4all-chat/server.cpp    | 38 ++++++++++++++++----
 9 files changed, 90 insertions(+), 90 deletions(-)

diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index 79f9d601..b61b1548 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -45,7 +45,6 @@ void Chat::connectLLM()
     // Should be in same thread
     connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
     connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
-    connect(LocalDocs::globalInstance(), &LocalDocs::receivedResult, this, &Chat::handleLocalDocsRetrieved, Qt::DirectConnection);
 
     // Should be in different threads
     connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
@@ -101,52 +100,17 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
     int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
 {
-    Q_ASSERT(m_results.isEmpty());
-    m_results.clear(); // just in case, but the assert above is important
-    m_responseInProgress = true;
-    m_responseState = Chat::LocalDocsRetrieval;
-    emit responseInProgressChanged();
-    emit responseStateChanged();
-    m_queuedPrompt.prompt = prompt;
-    m_queuedPrompt.prompt_template = prompt_template;
-    m_queuedPrompt.n_predict = n_predict;
-    m_queuedPrompt.top_k = top_k;
-    m_queuedPrompt.temp = temp;
-    m_queuedPrompt.n_batch = n_batch;
-    m_queuedPrompt.repeat_penalty = repeat_penalty;
-    m_queuedPrompt.repeat_penalty_tokens = repeat_penalty_tokens;
-    LocalDocs::globalInstance()->requestRetrieve(m_id, m_collections, prompt);
-}
-
-void Chat::handleLocalDocsRetrieved(const QString &uid, const QList<ResultInfo> &results)
-{
-    // If the uid doesn't match, then these are not our results
-    if (uid != m_id)
-        return;
-
-    // Store our results locally
-    m_results = results;
-
-    // Augment the prompt template with the results if any
-    QList<QString> augmentedTemplate;
-    if (!m_results.isEmpty())
-        augmentedTemplate.append("### Context:");
-    for (const ResultInfo &info : m_results)
-        augmentedTemplate.append(info.text);
-
-    augmentedTemplate.append(m_queuedPrompt.prompt_template);
     emit promptRequested(
-        m_queuedPrompt.prompt,
-        augmentedTemplate.join("\n"),
-        m_queuedPrompt.n_predict,
-        m_queuedPrompt.top_k,
-        m_queuedPrompt.top_p,
-        m_queuedPrompt.temp,
-        m_queuedPrompt.n_batch,
-        m_queuedPrompt.repeat_penalty,
-        m_queuedPrompt.repeat_penalty_tokens,
+        prompt,
+        prompt_template,
+        n_predict,
+        top_k,
+        top_p,
+        temp,
+        n_batch,
+        repeat_penalty,
+        repeat_penalty_tokens,
         LLM::globalInstance()->threadCount());
-    m_queuedPrompt = Prompt();
 }
 
 void Chat::regenerateResponse()
 {
@@ -195,9 +159,14 @@ void Chat::handleModelLoadedChanged()
         deleteLater();
 }
 
+QList<ResultInfo> Chat::results() const
+{
+    return m_llmodel->results();
+}
+
 void Chat::promptProcessing()
 {
-    m_responseState = !m_results.isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
+    m_responseState = !results().isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
     emit responseStateChanged();
 }
 
@@ -207,7 +176,7 @@ void Chat::responseStopped()
     QList<QString> references;
    QList<QString> referencesContext;
     int validReferenceNumber = 1;
-    for (const ResultInfo &info : m_results) {
+    for (const ResultInfo &info : results()) {
         if (info.file.isEmpty())
             continue;
         if (validReferenceNumber == 1)
@@ -241,7 +210,6 @@ void Chat::responseStopped()
     m_chatModel->updateReferences(index, references.join("\n"), referencesContext);
     emit responseChanged();
 
-    m_results.clear();
     m_responseInProgress = false;
     m_responseState = Chat::ResponseStopped;
     emit responseInProgressChanged();
@@ -266,6 +234,10 @@ void Chat::setModelName(const QString &modelName)
 
 void Chat::newPromptResponsePair(const QString &prompt)
 {
+    m_responseInProgress = true;
+    m_responseState = Chat::LocalDocsRetrieval;
+    emit responseInProgressChanged();
+    emit responseStateChanged();
     m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
     m_chatModel->appendPrompt(tr("Prompt: "), prompt);
     m_chatModel->appendResponse(tr("Response: "), prompt);
@@ -274,7 +246,11 @@ void Chat::newPromptResponsePair(const QString &prompt)
 
 void Chat::serverNewPromptResponsePair(const QString &prompt)
 {
-    m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
+    m_responseInProgress = true;
+    m_responseState = Chat::LocalDocsRetrieval;
+    emit responseInProgressChanged();
+    emit responseStateChanged();
+    m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
     m_chatModel->appendPrompt(tr("Prompt: "), prompt);
     m_chatModel->appendResponse(tr("Response: "), prompt);
 }
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index 0e9e6d47..984a4bda 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -60,6 +60,8 @@ public:
     Q_INVOKABLE void stopGenerating();
     Q_INVOKABLE void newPromptResponsePair(const QString &prompt);
 
+    QList<ResultInfo> results() const;
+
     QString response() const;
     bool responseInProgress() const { return m_responseInProgress; }
     QString responseState() const;
@@ -115,7 +117,6 @@ Q_SIGNALS:
     void collectionListChanged();
 
 private Q_SLOTS:
-    void handleLocalDocsRetrieved(const QString &uid, const QList<ResultInfo> &results);
     void handleResponseChanged();
     void handleModelLoadedChanged();
     void promptProcessing();
@@ -125,24 +126,11 @@ private Q_SLOTS:
     void handleModelNameChanged();
 
 private:
-    struct Prompt {
-        QString prompt;
-        QString prompt_template;
-        int32_t n_predict;
-        int32_t top_k;
-        float top_p;
-        float temp;
-        int32_t n_batch;
-        float repeat_penalty;
-        int32_t repeat_penalty_tokens;
-    };
-
     QString m_id;
     QString m_name;
     QString m_userName;
     QString m_savedModelName;
     QList<QString> m_collections;
-    QList<ResultInfo> m_results;
     ChatModel *m_chatModel;
     bool m_responseInProgress;
     ResponseState m_responseState;
@@ -150,7 +138,6 @@ private:
     ChatLLM *m_llmodel;
     bool m_isServer;
     bool m_shouldDeleteLater;
-    Prompt m_queuedPrompt;
 };
 
 #endif // CHAT_H
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index 23312ced..6184815a 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -91,9 +91,15 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     moveToThread(&m_llmThread);
     connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
     connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
-    connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, Qt::QueuedConnection);
+    connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
+        Qt::QueuedConnection); // explicitly queued
     connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
     connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted);
+
+    // The following are blocking operations and will block the llm thread
+    connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
+        Qt::BlockingQueuedConnection);
+
     m_llmThread.setObjectName(m_chat->id());
     m_llmThread.start();
 }
@@ -386,7 +392,19 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
     if (!isModelLoaded())
         return false;
 
-    QString instructPrompt = prompt_template.arg(prompt);
+    m_results.clear();
+    const int retrievalSize = LocalDocs::globalInstance()->retrievalSize();
+    emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &m_results); // blocks
+
+    // Augment the prompt template with the results if any
+    QList<QString> augmentedTemplate;
+    if (!m_results.isEmpty())
+        augmentedTemplate.append("### Context:");
+    for (const ResultInfo &info : m_results)
+        augmentedTemplate.append(info.text);
+    augmentedTemplate.append(prompt_template);
+
+    QString instructPrompt = augmentedTemplate.join("\n").arg(prompt);
 
     m_stopGenerating = false;
     auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index d2bd6a82..42de7d08 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -5,6 +5,7 @@
 #include <QObject>
 #include <QThread>
 
+#include "localdocs.h"
 #include "../gpt4all-backend/llmodel.h"
 
 enum LLModelType {
@@ -39,6 +40,7 @@ public:
     void regenerateResponse();
     void resetResponse();
     void resetContext();
+    QList<ResultInfo> results() const { return m_results; }
 
     void stopGenerating() { m_stopGenerating = true; }
 
@@ -85,6 +87,8 @@ Q_SIGNALS:
     void stateChanged();
     void threadStarted();
     void shouldBeLoadedChanged();
+    void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
+
 protected:
     bool handlePrompt(int32_t token);
@@ -111,6 +115,7 @@ protected:
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
     std::atomic<bool> m_shouldBeLoaded;
+    QList<ResultInfo> m_results;
     bool m_isRecalc;
     bool m_isServer;
     bool m_isChatGPT;
diff --git a/gpt4all-chat/database.cpp b/gpt4all-chat/database.cpp
index b5dd9b25..7ec2050a 100644
--- a/gpt4all-chat/database.cpp
+++ b/gpt4all-chat/database.cpp
@@ -892,7 +892,8 @@ bool Database::removeFolderFromWatch(const QString &path)
     return m_watcher->removePath(path);
 }
 
-void Database::retrieveFromDB(const QString &uid, const QList<QString> &collections, const QString &text, int retrievalSize)
+void Database::retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize,
+    QList<ResultInfo> *results)
 {
 #if defined(DEBUG)
     qDebug() << "retrieveFromDB" << collections << text << retrievalSize;
@@ -904,7 +905,6 @@ void Database::retrieveFromDB(const QString &uid, const QList<QString> &collecti
         return;
     }
 
-    QList<ResultInfo> results;
     while (q.next()) {
         const int rowid = q.value(0).toInt();
         const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd");
@@ -924,13 +924,12 @@ void Database::retrieveFromDB(const QString &uid, const QList<QString> &collecti
         info.page = page;
         info.from = from;
         info.to = to;
-        results.append(info);
+        results->append(info);
 #if defined(DEBUG)
         qDebug() << "retrieve rowid:" << rowid << "chunk_text:" << chunk_text;
 #endif
     }
-
-    emit retrieveResult(uid, results);
 }
 
 void Database::cleanDB()
diff --git a/gpt4all-chat/database.h b/gpt4all-chat/database.h
index 2f25ff0d..59325ac4 100644
--- a/gpt4all-chat/database.h
+++ b/gpt4all-chat/database.h
@@ -43,13 +43,12 @@ public Q_SLOTS:
     void scanDocuments(int folder_id, const QString &folder_path);
     void addFolder(const QString &collection, const QString &path);
     void removeFolder(const QString &collection, const QString &path);
-    void retrieveFromDB(const QString &uid, const QList<QString> &collections, const QString &text, int retrievalSize);
+    void retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
     void cleanDB();
     void changeChunkSize(int chunkSize);
 
 Q_SIGNALS:
     void docsToScanChanged();
-    void retrieveResult(const QString &uid, const QList<ResultInfo> &result);
     void collectionListUpdated(const QList<CollectionItem> &collectionList);
 
 private Q_SLOTS:
diff --git a/gpt4all-chat/localdocs.cpp b/gpt4all-chat/localdocs.cpp
index 6e62c0a2..7e4a910d 100644
--- a/gpt4all-chat/localdocs.cpp
+++ b/gpt4all-chat/localdocs.cpp
@@ -24,12 +24,8 @@ LocalDocs::LocalDocs()
         &Database::addFolder, Qt::QueuedConnection);
     connect(this, &LocalDocs::requestRemoveFolder, m_database,
         &Database::removeFolder, Qt::QueuedConnection);
-    connect(this, &LocalDocs::requestRetrieveFromDB, m_database,
-        &Database::retrieveFromDB, Qt::QueuedConnection);
     connect(this, &LocalDocs::requestChunkSizeChange, m_database,
         &Database::changeChunkSize, Qt::QueuedConnection);
-    connect(m_database, &Database::retrieveResult, this,
-        &LocalDocs::receivedResult, Qt::QueuedConnection);
     connect(m_database, &Database::collectionListUpdated,
         m_localDocsModel, &LocalDocsModel::handleCollectionListUpdated, Qt::QueuedConnection);
 }
@@ -49,11 +45,6 @@ void LocalDocs::removeFolder(const QString &collection, const QString &path)
     emit requestRemoveFolder(collection, path);
 }
 
-void LocalDocs::requestRetrieve(const QString &uid, const QList<QString> &collections, const QString &text)
-{
-    emit requestRetrieveFromDB(uid, collections, text, m_retrievalSize);
-}
-
 int LocalDocs::chunkSize() const
 {
     return m_chunkSize;
diff --git a/gpt4all-chat/localdocs.h b/gpt4all-chat/localdocs.h
index 7011602c..deaf6a40 100644
--- a/gpt4all-chat/localdocs.h
+++ b/gpt4all-chat/localdocs.h
@@ -20,7 +20,8 @@ public:
 
     Q_INVOKABLE void addFolder(const QString &collection, const QString &path);
     Q_INVOKABLE void removeFolder(const QString &collection, const QString &path);
-    void requestRetrieve(const QString &uid, const QList<QString> &collections, const QString &text);
+
+    Database *database() const { return m_database; }
 
     int chunkSize() const;
     void setChunkSize(int chunkSize);
@@ -31,9 +32,7 @@ public:
 Q_SIGNALS:
     void requestAddFolder(const QString &collection, const QString &path);
     void requestRemoveFolder(const QString &collection, const QString &path);
-    void requestRetrieveFromDB(const QString &uid, const QList<QString> &collections, const QString &text, int retrievalSize);
     void requestChunkSizeChange(int chunkSize);
-    void receivedResult(const QString &uid, const QList<ResultInfo> &result);
     void localDocsModelChanged();
     void chunkSizeChanged();
     void retrievalSizeChanged();
diff --git a/gpt4all-chat/server.cpp b/gpt4all-chat/server.cpp
index 8ba59e67..c7090670 100644
--- a/gpt4all-chat/server.cpp
+++ b/gpt4all-chat/server.cpp
@@ -51,6 +51,20 @@ static inline QJsonObject modelToJson(const ModelInfo &info)
     return model;
 }
 
+static inline QJsonObject resultToJson(const ResultInfo &info)
+{
+    QJsonObject result;
+    result.insert("file", info.file);
+    result.insert("title", info.title);
+    result.insert("author", info.author);
+    result.insert("date", info.date);
+    result.insert("text", info.text);
result.insert("page", info.page); + result.insert("from", info.from); + result.insert("to", info.to); + return result; +} + Server::Server(Chat *chat) : ChatLLM(chat, true /*isServer*/) , m_chat(chat) @@ -298,7 +312,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re int promptTokens = 0; int responseTokens = 0; - QList responses; + QList>> responses; for (int i = 0; i < n; ++i) { if (!prompt(actualPrompt, promptTemplate, @@ -317,7 +331,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re QString echoedPrompt = actualPrompt; if (!echoedPrompt.endsWith("\n")) echoedPrompt += "\n"; - responses.append((echo ? QString("%1\n").arg(actualPrompt) : QString()) + response()); + responses.append(qMakePair((echo ? QString("%1\n").arg(actualPrompt) : QString()) + response(), m_results)); if (!promptTokens) promptTokens += m_promptTokens; responseTokens += m_promptResponseTokens - m_promptTokens; @@ -335,24 +349,36 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re if (isChat) { int index = 0; - for (QString r : responses) { + for (const auto &r : responses) { + QString result = r.first; + QList infos = r.second; QJsonObject choice; choice.insert("index", index++); choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop"); QJsonObject message; message.insert("role", "assistant"); - message.insert("content", r); + message.insert("content", result); choice.insert("message", message); + QJsonArray references; + for (const auto &ref : infos) + references.append(resultToJson(ref)); + choice.insert("references", references); choices.append(choice); } } else { int index = 0; - for (QString r : responses) { + for (const auto &r : responses) { + QString result = r.first; + QList infos = r.second; QJsonObject choice; - choice.insert("text", r); + choice.insert("text", result); choice.insert("index", index++); choice.insert("logprobs", QJsonValue::Null); // We don't support choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop"); + QJsonArray references; + for (const auto &ref : infos) + references.append(resultToJson(ref)); + choice.insert("references", references); choices.append(choice); } }