Don't store db results in ChatLLM.

This commit is contained in:
Adam Treat 2023-06-19 18:23:54 -04:00 committed by AT
parent 0cfe225506
commit a3a6a20146
5 changed files with 17 additions and 12 deletions

View File

@ -58,6 +58,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
@ -177,11 +178,6 @@ void Chat::handleModelLoadedChanged()
deleteLater();
}
// Returns the list of database results by forwarding to the ChatLLM
// instance held in m_llmodel.
QList<ResultInfo> Chat::databaseResults() const
{
return m_llmodel->databaseResults();
}
void Chat::promptProcessing()
{
m_responseState = !databaseResults().isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
@ -348,6 +344,11 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
emit tokenSpeedChanged();
}
// Slot connected in connectLLM() to ChatLLM::databaseResultsChanged using
// Qt::QueuedConnection. Stores the delivered results in m_databaseResults,
// replacing whatever was held before.
void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
{
m_databaseResults = results;
}
bool Chat::serialize(QDataStream &stream, int version) const
{
stream << m_creationDate;

View File

@ -61,7 +61,7 @@ public:
Q_INVOKABLE void stopGenerating();
Q_INVOKABLE void newPromptResponsePair(const QString &prompt);
QList<ResultInfo> databaseResults() const;
// Returns the database results currently stored in m_databaseResults.
QList<ResultInfo> databaseResults() const { return m_databaseResults; }
QString response() const;
// Returns the current value of the m_responseInProgress flag.
bool responseInProgress() const { return m_responseInProgress; }
@ -133,6 +133,7 @@ private Q_SLOTS:
void handleModelNameChanged();
void handleModelLoadingError(const QString &error);
void handleTokenSpeedChanged(const QString &tokenSpeed);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
private:
QString m_id;
@ -147,6 +148,7 @@ private:
ResponseState m_responseState;
qint64 m_creationDate;
ChatLLM *m_llmodel;
QList<ResultInfo> m_databaseResults;
bool m_isServer;
bool m_shouldDeleteLater;
};

View File

@ -413,15 +413,16 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
if (!isModelLoaded())
return false;
m_databaseResults.clear();
QList<ResultInfo> databaseResults;
const int retrievalSize = LocalDocs::globalInstance()->retrievalSize();
emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &m_databaseResults); // blocks
emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &databaseResults); // blocks
emit databaseResultsChanged(databaseResults);
// Augment the prompt template with the results if any
QList<QString> augmentedTemplate;
if (!m_databaseResults.isEmpty())
if (!databaseResults.isEmpty())
augmentedTemplate.append("### Context:");
for (const ResultInfo &info : m_databaseResults)
for (const ResultInfo &info : databaseResults)
augmentedTemplate.append(info.text);
augmentedTemplate.append(prompt_template);

View File

@ -81,7 +81,6 @@ public:
void regenerateResponse();
void resetResponse();
void resetContext();
// Returns the database results stored in this object's m_databaseResults member.
QList<ResultInfo> databaseResults() const { return m_databaseResults; }
// Sets the std::atomic<bool> m_stopGenerating flag to true.
// NOTE(review): where the flag is read and reset is not visible here — confirm.
void stopGenerating() { m_stopGenerating = true; }
@ -131,6 +130,7 @@ Q_SIGNALS:
void shouldBeLoadedChanged();
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed);
void databaseResultsChanged(const QList<ResultInfo>&);
protected:
bool handlePrompt(int32_t token);
@ -157,7 +157,6 @@ protected:
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;
QList<ResultInfo> m_databaseResults;
bool m_isRecalc;
bool m_isServer;
bool m_isChatGPT;

View File

@ -22,10 +22,12 @@ Q_SIGNALS:
private Q_SLOTS:
QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat);
// Private slot: replaces m_databaseResults with the given results.
// NOTE(review): the connect() that wires this slot is not visible here — confirm the emitting signal.
void handleDatabaseResultsChanged(const QList<ResultInfo> &results) { m_databaseResults = results; }
private:
Chat *m_chat;
QHttpServer *m_server;
QList<ResultInfo> m_databaseResults;
};
#endif // SERVER_H