From a3a6a201464e8e0b72766097f82b9229af46d523 Mon Sep 17 00:00:00 2001 From: Adam Treat <treat.adam@gmail.com> Date: Mon, 19 Jun 2023 18:23:54 -0400 Subject: [PATCH] Don't store db results in ChatLLM. --- gpt4all-chat/chat.cpp | 11 ++++++----- gpt4all-chat/chat.h | 4 +++- gpt4all-chat/chatllm.cpp | 9 +++++---- gpt4all-chat/chatllm.h | 3 +-- gpt4all-chat/server.h | 2 ++ 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp index 9c54afb3..9cbeccca 100644 --- a/gpt4all-chat/chat.cpp +++ b/gpt4all-chat/chat.cpp @@ -58,6 +58,7 @@ void Chat::connectLLM() connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection); @@ -177,11 +178,6 @@ void Chat::handleModelLoadedChanged() deleteLater(); } -QList<ResultInfo> Chat::databaseResults() const -{ - return m_llmodel->databaseResults(); -} - void Chat::promptProcessing() { m_responseState = !databaseResults().isEmpty() ? 
Chat::LocalDocsProcessing : Chat::PromptProcessing; @@ -348,6 +344,11 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed) emit tokenSpeedChanged(); } +void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results) +{ + m_databaseResults = results; +} + bool Chat::serialize(QDataStream &stream, int version) const { stream << m_creationDate; diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h index 71c2f761..dc514274 100644 --- a/gpt4all-chat/chat.h +++ b/gpt4all-chat/chat.h @@ -61,7 +61,7 @@ public: Q_INVOKABLE void stopGenerating(); Q_INVOKABLE void newPromptResponsePair(const QString &prompt); - QList<ResultInfo> databaseResults() const; + QList<ResultInfo> databaseResults() const { return m_databaseResults; } QString response() const; bool responseInProgress() const { return m_responseInProgress; } @@ -133,6 +133,7 @@ private Q_SLOTS: void handleModelNameChanged(); void handleModelLoadingError(const QString &error); void handleTokenSpeedChanged(const QString &tokenSpeed); + void handleDatabaseResultsChanged(const QList<ResultInfo> &results); private: QString m_id; @@ -147,6 +148,7 @@ private: ResponseState m_responseState; qint64 m_creationDate; ChatLLM *m_llmodel; + QList<ResultInfo> m_databaseResults; bool m_isServer; bool m_shouldDeleteLater; }; diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index f28c954d..613ced0a 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -413,15 +413,16 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3 if (!isModelLoaded()) return false; - m_databaseResults.clear(); + QList<ResultInfo> databaseResults; const int retrievalSize = LocalDocs::globalInstance()->retrievalSize(); - emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &m_databaseResults); // blocks + emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &databaseResults); // blocks + emit databaseResultsChanged(databaseResults); // Augment the prompt template with the results if any QList<QString> 
augmentedTemplate; - if (!m_databaseResults.isEmpty()) + if (!databaseResults.isEmpty()) augmentedTemplate.append("### Context:"); - for (const ResultInfo &info : m_databaseResults) + for (const ResultInfo &info : databaseResults) augmentedTemplate.append(info.text); augmentedTemplate.append(prompt_template); diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index 3cb0b7e2..67226d4e 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -81,7 +81,6 @@ public: void regenerateResponse(); void resetResponse(); void resetContext(); - QList<ResultInfo> databaseResults() const { return m_databaseResults; } void stopGenerating() { m_stopGenerating = true; } @@ -131,6 +130,7 @@ Q_SIGNALS: void shouldBeLoadedChanged(); void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results); void reportSpeed(const QString &speed); + void databaseResultsChanged(const QList<ResultInfo>&); protected: bool handlePrompt(int32_t token); @@ -157,7+157,6 @@ protected: QThread m_llmThread; std::atomic<bool> m_stopGenerating; std::atomic<bool> m_shouldBeLoaded; - QList<ResultInfo> m_databaseResults; bool m_isRecalc; bool m_isServer; bool m_isChatGPT; diff --git a/gpt4all-chat/server.h b/gpt4all-chat/server.h index 90a89cfb..ac6f1f75 100644 --- a/gpt4all-chat/server.h +++ b/gpt4all-chat/server.h @@ -22,10 +22,12 @@ Q_SIGNALS: private Q_SLOTS: QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat); + void handleDatabaseResultsChanged(const QList<ResultInfo> &results) { m_databaseResults = results; } private: Chat *m_chat; QHttpServer *m_server; + QList<ResultInfo> m_databaseResults; }; #endif // SERVER_H