Show token generation speed in gui. (#1020)

This commit is contained in:
AT 2023-06-19 11:34:53 -07:00 committed by GitHub
parent fd419caa55
commit 2b6cc99a31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 84 additions and 2 deletions

View File

@ -57,6 +57,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
@ -102,6 +103,8 @@ void Chat::resetResponseState()
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
return;
m_tokenSpeed = QString();
emit tokenSpeedChanged();
m_responseInProgress = true;
m_responseState = Chat::LocalDocsRetrieval;
emit responseInProgressChanged();
@ -187,6 +190,9 @@ void Chat::promptProcessing()
void Chat::responseStopped()
{
m_tokenSpeed = QString();
emit tokenSpeedChanged();
const QString chatResponse = response();
QList<QString> references;
QList<QString> referencesContext;
@ -336,6 +342,12 @@ void Chat::handleModelLoadingError(const QString &error)
emit modelLoadingErrorChanged();
}
void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
{
m_tokenSpeed = tokenSpeed;
emit tokenSpeedChanged();
}
bool Chat::serialize(QDataStream &stream, int version) const
{
stream << m_creationDate;

View File

@ -25,6 +25,7 @@ class Chat : public QObject
Q_PROPERTY(QString responseState READ responseState NOTIFY responseStateChanged)
Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged)
Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged);
QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!")
@ -91,6 +92,8 @@ public:
QString modelLoadingError() const { return m_modelLoadingError; }
QString tokenSpeed() const { return m_tokenSpeed; }
public Q_SLOTS:
void serverNewPromptResponsePair(const QString &prompt);
@ -118,6 +121,7 @@ Q_SIGNALS:
void modelLoadingErrorChanged();
void isServerChanged();
void collectionListChanged();
void tokenSpeedChanged();
private Q_SLOTS:
void handleResponseChanged();
@ -128,6 +132,7 @@ private Q_SLOTS:
void handleRecalculating();
void handleModelNameChanged();
void handleModelLoadingError(const QString &error);
void handleTokenSpeedChanged(const QString &tokenSpeed);
private:
QString m_id;
@ -135,6 +140,7 @@ private:
QString m_userName;
QString m_savedModelName;
QString m_modelLoadingError;
QString m_tokenSpeed;
QList<QString> m_collections;
ChatModel *m_chatModel;
bool m_responseInProgress;

View File

@ -94,6 +94,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_responseLogits(0)
, m_isRecalc(false)
, m_chat(parent)
, m_timer(nullptr)
, m_isServer(isServer)
, m_isChatGPT(false)
{
@ -103,7 +104,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
Qt::QueuedConnection); // explicitly queued
connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
// The following are blocking operations and will block the llm thread
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
@ -126,6 +127,13 @@ ChatLLM::~ChatLLM()
}
}
void ChatLLM::handleThreadStarted()
{
m_timer = new TokenTimer(this);
connect(m_timer, &TokenTimer::report, this, &ChatLLM::reportSpeed);
emit threadStarted();
}
bool ChatLLM::loadDefaultModel()
{
const QList<QString> models = m_chat->modelList();
@ -367,6 +375,7 @@ bool ChatLLM::handlePrompt(int32_t token)
#endif
++m_promptTokens;
++m_promptResponseTokens;
m_timer->inc();
return !m_stopGenerating;
}
@ -387,6 +396,7 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
++m_promptResponseTokens;
m_timer->inc();
Q_ASSERT(!response.empty());
m_response.append(response);
emit responseChanged();
@ -441,11 +451,13 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
printf("%s", qPrintable(instructPrompt));
fflush(stdout);
#endif
m_timer->start();
m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
#endif
m_timer->stop();
m_responseLogits += m_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) {

View File

@ -23,6 +23,46 @@ struct LLModelInfo {
// must be able to serialize the information even if it is in the unloaded state
};
class TokenTimer : public QObject {
Q_OBJECT
public:
explicit TokenTimer(QObject *parent)
: QObject(parent)
, m_elapsed(0) {}
static int rollingAverage(int oldAvg, int newNumber, int n)
{
// i.e. to calculate the new average after then nth number,
// you multiply the old average by n1, add the new number, and divide the total by n.
return qRound(((float(oldAvg) * (n - 1)) + newNumber) / float(n));
}
void start() { m_tokens = 0; m_elapsed = 0; m_time.invalidate(); }
void stop() { handleTimeout(); }
void inc() {
if (!m_time.isValid())
m_time.start();
++m_tokens;
if (m_time.elapsed() > 999)
handleTimeout();
}
Q_SIGNALS:
void report(const QString &speed);
private Q_SLOTS:
void handleTimeout()
{
m_elapsed += m_time.restart();
emit report(QString("%1 tokens/sec").arg(m_tokens / float(m_elapsed / 1000.0f), 0, 'g', 2));
}
private:
QElapsedTimer m_time;
qint64 m_elapsed;
quint32 m_tokens;
};
class Chat;
class ChatLLM : public QObject
{
@ -73,6 +113,7 @@ public Q_SLOTS:
void generateName();
void handleChatIdChanged();
void handleShouldBeLoadedChanged();
void handleThreadStarted();
Q_SIGNALS:
void isModelLoadedChanged();
@ -89,7 +130,7 @@ Q_SIGNALS:
void threadStarted();
void shouldBeLoadedChanged();
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed);
protected:
bool handlePrompt(int32_t token);
@ -112,6 +153,7 @@ protected:
quint32 m_responseLogits;
QString m_modelName;
Chat *m_chat;
TokenTimer *m_timer;
QByteArray m_state;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;

View File

@ -845,6 +845,16 @@ Window {
Accessible.description: qsTr("Controls generation of the response")
}
Text {
id: speed
anchors.bottom: textInputView.top
anchors.bottomMargin: 20
anchors.right: parent.right
anchors.rightMargin: 30
color: theme.mutedTextColor
text: currentChat.tokenSpeed
}
RectangularGlow {
id: effect
anchors.fill: textInputView