Mirror of https://github.com/nomic-ai/gpt4all (synced 2024-11-18 03:25:46 +00:00).
Commit 2b6cc99a31 — "Show token generation speed in gui." (#1020); parent commit fd419caa55.
@ -57,6 +57,7 @@ void Chat::connectLLM()
|
||||
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
|
||||
|
||||
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
|
||||
connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
|
||||
@ -102,6 +103,8 @@ void Chat::resetResponseState()
|
||||
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
|
||||
return;
|
||||
|
||||
m_tokenSpeed = QString();
|
||||
emit tokenSpeedChanged();
|
||||
m_responseInProgress = true;
|
||||
m_responseState = Chat::LocalDocsRetrieval;
|
||||
emit responseInProgressChanged();
|
||||
@ -187,6 +190,9 @@ void Chat::promptProcessing()
|
||||
|
||||
void Chat::responseStopped()
|
||||
{
|
||||
m_tokenSpeed = QString();
|
||||
emit tokenSpeedChanged();
|
||||
|
||||
const QString chatResponse = response();
|
||||
QList<QString> references;
|
||||
QList<QString> referencesContext;
|
||||
@ -336,6 +342,12 @@ void Chat::handleModelLoadingError(const QString &error)
|
||||
emit modelLoadingErrorChanged();
|
||||
}
|
||||
|
||||
void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
|
||||
{
|
||||
m_tokenSpeed = tokenSpeed;
|
||||
emit tokenSpeedChanged();
|
||||
}
|
||||
|
||||
bool Chat::serialize(QDataStream &stream, int version) const
|
||||
{
|
||||
stream << m_creationDate;
|
||||
|
@ -25,6 +25,7 @@ class Chat : public QObject
|
||||
Q_PROPERTY(QString responseState READ responseState NOTIFY responseStateChanged)
|
||||
Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
|
||||
Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged)
|
||||
Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged);
|
||||
QML_ELEMENT
|
||||
QML_UNCREATABLE("Only creatable from c++!")
|
||||
|
||||
@ -91,6 +92,8 @@ public:
|
||||
|
||||
QString modelLoadingError() const { return m_modelLoadingError; }
|
||||
|
||||
QString tokenSpeed() const { return m_tokenSpeed; }
|
||||
|
||||
public Q_SLOTS:
|
||||
void serverNewPromptResponsePair(const QString &prompt);
|
||||
|
||||
@ -118,6 +121,7 @@ Q_SIGNALS:
|
||||
void modelLoadingErrorChanged();
|
||||
void isServerChanged();
|
||||
void collectionListChanged();
|
||||
void tokenSpeedChanged();
|
||||
|
||||
private Q_SLOTS:
|
||||
void handleResponseChanged();
|
||||
@ -128,6 +132,7 @@ private Q_SLOTS:
|
||||
void handleRecalculating();
|
||||
void handleModelNameChanged();
|
||||
void handleModelLoadingError(const QString &error);
|
||||
void handleTokenSpeedChanged(const QString &tokenSpeed);
|
||||
|
||||
private:
|
||||
QString m_id;
|
||||
@ -135,6 +140,7 @@ private:
|
||||
QString m_userName;
|
||||
QString m_savedModelName;
|
||||
QString m_modelLoadingError;
|
||||
QString m_tokenSpeed;
|
||||
QList<QString> m_collections;
|
||||
ChatModel *m_chatModel;
|
||||
bool m_responseInProgress;
|
||||
|
@ -94,6 +94,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
|
||||
, m_responseLogits(0)
|
||||
, m_isRecalc(false)
|
||||
, m_chat(parent)
|
||||
, m_timer(nullptr)
|
||||
, m_isServer(isServer)
|
||||
, m_isChatGPT(false)
|
||||
{
|
||||
@ -103,7 +104,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
|
||||
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
|
||||
Qt::QueuedConnection); // explicitly queued
|
||||
connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
|
||||
connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted);
|
||||
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
|
||||
|
||||
// The following are blocking operations and will block the llm thread
|
||||
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
|
||||
@ -126,6 +127,13 @@ ChatLLM::~ChatLLM()
|
||||
}
|
||||
}
|
||||
|
||||
void ChatLLM::handleThreadStarted()
|
||||
{
|
||||
m_timer = new TokenTimer(this);
|
||||
connect(m_timer, &TokenTimer::report, this, &ChatLLM::reportSpeed);
|
||||
emit threadStarted();
|
||||
}
|
||||
|
||||
bool ChatLLM::loadDefaultModel()
|
||||
{
|
||||
const QList<QString> models = m_chat->modelList();
|
||||
@ -367,6 +375,7 @@ bool ChatLLM::handlePrompt(int32_t token)
|
||||
#endif
|
||||
++m_promptTokens;
|
||||
++m_promptResponseTokens;
|
||||
m_timer->inc();
|
||||
return !m_stopGenerating;
|
||||
}
|
||||
|
||||
@ -387,6 +396,7 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
|
||||
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
|
||||
// the entire context window which we can reset on regenerate prompt
|
||||
++m_promptResponseTokens;
|
||||
m_timer->inc();
|
||||
Q_ASSERT(!response.empty());
|
||||
m_response.append(response);
|
||||
emit responseChanged();
|
||||
@ -441,11 +451,13 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
|
||||
printf("%s", qPrintable(instructPrompt));
|
||||
fflush(stdout);
|
||||
#endif
|
||||
m_timer->start();
|
||||
m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
|
||||
#if defined(DEBUG)
|
||||
printf("\n");
|
||||
fflush(stdout);
|
||||
#endif
|
||||
m_timer->stop();
|
||||
m_responseLogits += m_ctx.logits.size() - logitsBefore;
|
||||
std::string trimmed = trim_whitespace(m_response);
|
||||
if (trimmed != m_response) {
|
||||
|
@ -23,6 +23,46 @@ struct LLModelInfo {
|
||||
// must be able to serialize the information even if it is in the unloaded state
|
||||
};
|
||||
|
||||
class TokenTimer : public QObject {
|
||||
Q_OBJECT
|
||||
public:
|
||||
explicit TokenTimer(QObject *parent)
|
||||
: QObject(parent)
|
||||
, m_elapsed(0) {}
|
||||
|
||||
static int rollingAverage(int oldAvg, int newNumber, int n)
|
||||
{
|
||||
// i.e. to calculate the new average after then nth number,
|
||||
// you multiply the old average by n−1, add the new number, and divide the total by n.
|
||||
return qRound(((float(oldAvg) * (n - 1)) + newNumber) / float(n));
|
||||
}
|
||||
|
||||
void start() { m_tokens = 0; m_elapsed = 0; m_time.invalidate(); }
|
||||
void stop() { handleTimeout(); }
|
||||
void inc() {
|
||||
if (!m_time.isValid())
|
||||
m_time.start();
|
||||
++m_tokens;
|
||||
if (m_time.elapsed() > 999)
|
||||
handleTimeout();
|
||||
}
|
||||
|
||||
Q_SIGNALS:
|
||||
void report(const QString &speed);
|
||||
|
||||
private Q_SLOTS:
|
||||
void handleTimeout()
|
||||
{
|
||||
m_elapsed += m_time.restart();
|
||||
emit report(QString("%1 tokens/sec").arg(m_tokens / float(m_elapsed / 1000.0f), 0, 'g', 2));
|
||||
}
|
||||
|
||||
private:
|
||||
QElapsedTimer m_time;
|
||||
qint64 m_elapsed;
|
||||
quint32 m_tokens;
|
||||
};
|
||||
|
||||
class Chat;
|
||||
class ChatLLM : public QObject
|
||||
{
|
||||
@ -73,6 +113,7 @@ public Q_SLOTS:
|
||||
void generateName();
|
||||
void handleChatIdChanged();
|
||||
void handleShouldBeLoadedChanged();
|
||||
void handleThreadStarted();
|
||||
|
||||
Q_SIGNALS:
|
||||
void isModelLoadedChanged();
|
||||
@ -89,7 +130,7 @@ Q_SIGNALS:
|
||||
void threadStarted();
|
||||
void shouldBeLoadedChanged();
|
||||
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
|
||||
|
||||
void reportSpeed(const QString &speed);
|
||||
|
||||
protected:
|
||||
bool handlePrompt(int32_t token);
|
||||
@ -112,6 +153,7 @@ protected:
|
||||
quint32 m_responseLogits;
|
||||
QString m_modelName;
|
||||
Chat *m_chat;
|
||||
TokenTimer *m_timer;
|
||||
QByteArray m_state;
|
||||
QThread m_llmThread;
|
||||
std::atomic<bool> m_stopGenerating;
|
||||
|
@ -845,6 +845,16 @@ Window {
|
||||
Accessible.description: qsTr("Controls generation of the response")
|
||||
}
|
||||
|
||||
// Live token-generation speed label (bound to Chat's tokenSpeed property,
// e.g. "12 tokens/sec"), anchored bottom-right just above the text input.
Text {
    id: speed
    anchors.bottom: textInputView.top
    anchors.bottomMargin: 20
    anchors.right: parent.right
    anchors.rightMargin: 30
    // Muted so the stat reads as auxiliary info, not chat content.
    color: theme.mutedTextColor
    // Empty string (property cleared) hides the label between responses.
    text: currentChat.tokenSpeed
}
|
||||
|
||||
RectangularGlow {
|
||||
id: effect
|
||||
anchors.fill: textInputView
|
||||
|
Loading…
Reference in New Issue
Block a user