Add prompt processing and localdocs to the busy indicator in the UI.

Adam Treat 2023-05-20 20:04:36 -04:00 committed by AT
parent 837ece220f
commit c33bf0e895
5 changed files with 71 additions and 14 deletions
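
This change replaces the single responseStarted signal with a promptProcessing signal and threads a small ResponseState machine through Chat, which the QML busy indicator reads as a string. The standalone sketch below mirrors the enum values and label strings introduced in the diff; the responseStateLabel() helper, the flow array, and main() are illustrative assumptions and are not part of the commit.

// Sketch only: the enum values and labels mirror Chat::ResponseState and
// Chat::responseState() from this commit; everything else is illustrative.
#include <cstdio>

enum ResponseState {
    ResponseStopped,
    LocalDocsRetrieval,
    LocalDocsProcessing,
    PromptProcessing,
    ResponseGeneration
};

static const char *responseStateLabel(ResponseState s)
{
    switch (s) {
    case ResponseStopped:     return "response stopped";
    case LocalDocsRetrieval:  return "retrieving localdocs";
    case LocalDocsProcessing: return "processing localdocs";
    case PromptProcessing:    return "processing prompt";
    case ResponseGeneration:  return "generating response";
    }
    return "";
}

int main()
{
    // Typical progression for a prompt that pulls in localdocs context; when no
    // localdocs results are found, LocalDocsProcessing becomes PromptProcessing.
    const ResponseState flow[] = { LocalDocsRetrieval, LocalDocsProcessing,
                                   ResponseGeneration, ResponseStopped };
    for (ResponseState s : flow)
        std::printf("%s...\n", responseStateLabel(s));
    return 0;
}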


@@ -10,10 +10,12 @@ Chat::Chat(QObject *parent)
, m_name(tr("New Chat"))
, m_chatModel(new ChatModel(this))
, m_responseInProgress(false)
, m_responseState(Chat::ResponseStopped)
, m_creationDate(QDateTime::currentSecsSinceEpoch())
, m_llmodel(new ChatLLM(this))
, m_isServer(false)
, m_shouldDeleteLater(false)
, m_contextContainsLocalDocs(false)
{
connectLLM();
}
@@ -24,10 +26,12 @@ Chat::Chat(bool isServer, QObject *parent)
, m_name(tr("Server Chat"))
, m_chatModel(new ChatModel(this))
, m_responseInProgress(false)
, m_responseState(Chat::ResponseStopped)
, m_creationDate(QDateTime::currentSecsSinceEpoch())
, m_llmodel(new Server(this))
, m_isServer(true)
, m_shouldDeleteLater(false)
, m_contextContainsLocalDocs(false)
{
connectLLM();
}
@@ -49,7 +53,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::handleModelLoadedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::modelLoadingError, Qt::QueuedConnection);
@@ -99,6 +103,11 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens)
{
m_contextContainsLocalDocs = false;
m_responseInProgress = true;
m_responseState = Chat::LocalDocsRetrieval;
emit responseInProgressChanged();
emit responseStateChanged();
m_queuedPrompt.prompt = prompt;
m_queuedPrompt.prompt_template = prompt_template;
m_queuedPrompt.n_predict = n_predict;
@@ -116,8 +125,9 @@ void Chat::handleLocalDocsRetrieved()
QList<QString> results = LocalDocs::globalInstance()->result();
if (!results.isEmpty()) {
augmentedTemplate.append("### Context:");
augmentedTemplate.append(results);
augmentedTemplate.append(results.join("\n\n"));
}
m_contextContainsLocalDocs = !results.isEmpty();
augmentedTemplate.append(m_queuedPrompt.prompt_template);
emit promptRequested(
m_queuedPrompt.prompt,
@@ -148,8 +158,26 @@ QString Chat::response() const
return m_llmodel->response();
}
QString Chat::responseState() const
{
switch (m_responseState) {
case ResponseStopped: return QStringLiteral("response stopped");
case LocalDocsRetrieval: return QStringLiteral("retrieving localdocs");
case LocalDocsProcessing: return QStringLiteral("processing localdocs");
case PromptProcessing: return QStringLiteral("processing prompt");
case ResponseGeneration: return QStringLiteral("generating response");
};
Q_UNREACHABLE();
return QString();
}
void Chat::handleResponseChanged()
{
if (m_responseState != Chat::ResponseGeneration) {
m_responseState = Chat::ResponseGeneration;
emit responseStateChanged();
}
const int index = m_chatModel->count() - 1;
m_chatModel->updateValue(index, response());
emit responseChanged();
@@ -161,16 +189,19 @@ void Chat::handleModelLoadedChanged()
deleteLater();
}
void Chat::responseStarted()
void Chat::promptProcessing()
{
m_responseInProgress = true;
emit responseInProgressChanged();
m_responseState = m_contextContainsLocalDocs ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
emit responseStateChanged();
}
void Chat::responseStopped()
{
m_contextContainsLocalDocs = false;
m_responseInProgress = false;
m_responseState = Chat::ResponseStopped;
emit responseInProgressChanged();
emit responseStateChanged();
if (m_llmodel->generatedName().isEmpty())
emit generateNameRequested();
if (chatModel()->count() < 3)
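
The changes above also reshape how localdocs results are folded into the prompt: the retrieved excerpts are now joined with blank lines and placed under a "### Context:" header ahead of the original prompt template, and m_contextContainsLocalDocs records whether any context was found. A rough standalone sketch of that assembly follows; the promptTemplate and results values are made-up placeholders, and only the header string plus the join/append pattern come from the diff.

// Sketch of the augmented-template assembly in Chat::handleLocalDocsRetrieved().
// promptTemplate and results are hypothetical stand-ins for
// m_queuedPrompt.prompt_template and LocalDocs::globalInstance()->result().
#include <QDebug>
#include <QString>
#include <QStringList>

int main()
{
    const QString promptTemplate = QStringLiteral("### Prompt:\n%1\n### Response:\n");
    const QStringList results = { QStringLiteral("first excerpt"),
                                  QStringLiteral("second excerpt") };

    QStringList augmentedTemplate;
    if (!results.isEmpty()) {
        augmentedTemplate.append(QStringLiteral("### Context:"));
        augmentedTemplate.append(results.join(QStringLiteral("\n\n")));
    }
    augmentedTemplate.append(promptTemplate);

    // Joined here purely for display; in the commit the pieces are handed on
    // through the promptRequested signal instead.
    qDebug().noquote() << augmentedTemplate.join(QStringLiteral("\n"));
    return 0;
}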


@@ -22,10 +22,20 @@ class Chat : public QObject
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged)
Q_PROPERTY(QString responseState READ responseState NOTIFY responseStateChanged)
QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!")
public:
enum ResponseState {
ResponseStopped,
LocalDocsRetrieval,
LocalDocsProcessing,
PromptProcessing,
ResponseGeneration
};
Q_ENUM(ResponseState)
explicit Chat(QObject *parent = nullptr);
explicit Chat(bool isServer, QObject *parent = nullptr);
virtual ~Chat();
@@ -50,6 +60,7 @@ public:
QString response() const;
bool responseInProgress() const { return m_responseInProgress; }
QString responseState() const;
QString modelName() const;
void setModelName(const QString &modelName);
bool isRecalc() const;
@@ -77,6 +88,7 @@ Q_SIGNALS:
void isModelLoadedChanged();
void responseChanged();
void responseInProgressChanged();
void responseStateChanged();
void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict,
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
int32_t n_threads);
@@ -97,7 +109,7 @@ private Q_SLOTS:
void handleLocalDocsRetrieved();
void handleResponseChanged();
void handleModelLoadedChanged();
void responseStarted();
void promptProcessing();
void responseStopped();
void generatedNameChanged();
void handleRecalculating();
@@ -122,10 +134,12 @@ private:
QString m_savedModelName;
ChatModel *m_chatModel;
bool m_responseInProgress;
ResponseState m_responseState;
qint64 m_creationDate;
ChatLLM *m_llmodel;
bool m_isServer;
bool m_shouldDeleteLater;
bool m_contextContainsLocalDocs;
Prompt m_queuedPrompt;
};


@@ -411,7 +411,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
emit responseStarted();
emit promptProcessing();
qint32 logitsBefore = m_ctx.logits.size();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;


@@ -75,7 +75,7 @@ Q_SIGNALS:
void isModelLoadedChanged();
void modelLoadingError(const QString &error);
void responseChanged();
void responseStarted();
void promptProcessing();
void responseStopped();
void modelNameChanged();
void recalcChanged();


@@ -577,17 +577,29 @@ Window {
leftPadding: 100
rightPadding: 100
BusyIndicator {
Item {
anchors.left: parent.left
anchors.leftMargin: 90
anchors.top: parent.top
anchors.topMargin: 5
visible: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
running: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
Accessible.role: Accessible.Animation
Accessible.name: qsTr("Busy indicator")
Accessible.description: qsTr("Displayed when the model is thinking")
width: childrenRect.width
height: childrenRect.height
Row {
spacing: 5
BusyIndicator {
anchors.verticalCenter: parent.verticalCenter
running: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
Accessible.role: Accessible.Animation
Accessible.name: qsTr("Busy indicator")
Accessible.description: qsTr("Displayed when the model is thinking")
}
Label {
anchors.verticalCenter: parent.verticalCenter
text: currentChat.responseState + "..."
color: theme.mutedTextColor
}
}
}
Rectangle {