Mirror of https://github.com/nomic-ai/gpt4all (synced 2024-11-18 03:25:46 +00:00)
chat: report reason for fallback to CPU
commit 2eb83b9f2a (parent 906699e8e9)
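The change threads one piece of state — the reason inference fell back to CPU — from the inference worker (ChatLLM) through the Chat facade up to the QML status line. A minimal sketch of that pattern, using illustrative class names rather than the real gpt4all types:

#include <QObject>
#include <QString>

// Worker-side object; in the real app it lives on an inference thread.
class Worker : public QObject {
    Q_OBJECT
Q_SIGNALS:
    void reportFallbackReason(const QString &reason);
};

// GUI-facing object: caches the value and exposes it to QML.
class Facade : public QObject {
    Q_OBJECT
    Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged)
public:
    QString fallbackReason() const { return m_fallbackReason; }

public Q_SLOTS:
    void handleFallbackReasonChanged(const QString &reason) {
        m_fallbackReason = reason;
        emit fallbackReasonChanged(); // QML bindings re-evaluate here
    }

Q_SIGNALS:
    void fallbackReasonChanged();

private:
    QString m_fallbackReason;
};

// Wiring, as in Chat::connectLLM() below; Qt::QueuedConnection delivers the
// signal on the receiver's (GUI) thread rather than the emitter's:
//   connect(&worker, &Worker::reportFallbackReason,
//           &facade, &Facade::handleFallbackReasonChanged, Qt::QueuedConnection);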
@@ -57,6 +57,7 @@ void Chat::connectLLM()
     connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::reportDevice, this, &Chat::handleDeviceChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::reportFallbackReason, this, &Chat::handleFallbackReasonChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
 
@@ -352,6 +353,12 @@ void Chat::handleDeviceChanged(const QString &device)
     emit deviceChanged();
 }
 
+void Chat::handleFallbackReasonChanged(const QString &fallbackReason)
+{
+    m_fallbackReason = fallbackReason;
+    emit fallbackReasonChanged();
+}
+
 void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
 {
     m_databaseResults = results;
@@ -26,6 +26,7 @@ class Chat : public QObject
     Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged)
     Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged);
     Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
+    Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged);
     QML_ELEMENT
     QML_UNCREATABLE("Only creatable from c++!")
 
@@ -90,6 +91,7 @@ public:
 
     QString tokenSpeed() const { return m_tokenSpeed; }
     QString device() const { return m_device; }
+    QString fallbackReason() const { return m_fallbackReason; }
 
 public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt);
@@ -118,6 +120,7 @@ Q_SIGNALS:
     void collectionListChanged(const QList<QString> &collectionList);
     void tokenSpeedChanged();
     void deviceChanged();
+    void fallbackReasonChanged();
 
 private Q_SLOTS:
     void handleResponseChanged(const QString &response);
@@ -129,6 +132,7 @@ private Q_SLOTS:
     void handleModelLoadingError(const QString &error);
     void handleTokenSpeedChanged(const QString &tokenSpeed);
     void handleDeviceChanged(const QString &device);
+    void handleFallbackReasonChanged(const QString &fallbackReason);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
     void handleModelInfoChanged(const ModelInfo &modelInfo);
     void handleModelInstalled();
@@ -142,6 +146,7 @@ private:
     QString m_modelLoadingError;
     QString m_tokenSpeed;
     QString m_device;
+    QString m_fallbackReason;
     QString m_response;
     QList<QString> m_collections;
     ChatModel *m_chatModel;
@@ -267,27 +267,46 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         if (requestedDevice != "CPU") {
             const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString());
             std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
+            LLModel::GPUDevice *device = nullptr;
+
             if (!availableDevices.empty() && requestedDevice == "Auto" && availableDevices.front().type == 2 /*a discrete gpu*/) {
-                m_llModelInfo.model->initializeGPUDevice(availableDevices.front());
-                actualDevice = QString::fromStdString(availableDevices.front().name);
+                device = &availableDevices.front();
             } else {
                 for (LLModel::GPUDevice &d : availableDevices) {
                     if (QString::fromStdString(d.name) == requestedDevice) {
-                        m_llModelInfo.model->initializeGPUDevice(d);
-                        actualDevice = QString::fromStdString(d.name);
+                        device = &d;
                         break;
                     }
                 }
             }
+
+            if (!device) {
+                emit reportFallbackReason("<br>Using CPU: device not found");
+            } else if (!m_llModelInfo.model->initializeGPUDevice(*device)) {
+                emit reportFallbackReason("<br>Using CPU: failed to init device");
+            } else {
+                actualDevice = QString::fromStdString(device->name);
+            }
         }
 
         // Report which device we're actually using
         emit reportDevice(actualDevice);
 
         bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
-        if (!success && actualDevice != "CPU") {
+        if (actualDevice == "CPU") {
+            // we asked llama.cpp to use the CPU
+        } else if (!success) {
+            // llama_init_from_file returned nullptr
+            // this may happen because ggml_metal_add_buffer failed
             emit reportDevice("CPU");
+            emit reportFallbackReason("<br>Using CPU: llama_init_from_file failed");
             success = m_llModelInfo.model->loadModel(filePath.toStdString());
+        } else if (!m_llModelInfo.model->usingGPUDevice()) {
+            // ggml_vk_init was not called in llama.cpp
+            // We might have had to fallback to CPU after load if the model is not possible to accelerate
+            // for instance if the quantization method is not supported on Vulkan yet
+            emit reportDevice("CPU");
+            emit reportFallbackReason("<br>Using CPU: unsupported quantization type");
         }
 
         MySettings::globalInstance()->setAttemptModelLoad(QString());
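Condensed, the post-load logic in the hunk above distinguishes three outcomes. A standalone sketch of that decision tree — the Model struct and free functions are stand-ins for m_llModelInfo.model and the Qt signals, not real gpt4all API:

#include <iostream>
#include <string>

// Stubs so the control flow can be read (and compiled) in isolation.
struct Model {
    bool loadModel(const std::string &path) { (void)path; return true; } // stub
    bool usingGPUDevice() const { return true; }                         // stub
};

void reportDevice(const std::string &d)         { std::cout << "device: " << d << '\n'; }
void reportFallbackReason(const std::string &r) { std::cout << "fallback: " << r << '\n'; }

// Three outcomes: CPU was requested (nothing to report); the GPU-backed load
// failed outright (report, then retry on CPU); or the load succeeded but the
// backend silently fell back to CPU (e.g. unsupported quantization).
bool loadWithFallback(Model &m, const std::string &path, const std::string &actualDevice) {
    bool success = m.loadModel(path);
    if (actualDevice == "CPU") {
        // CPU was requested explicitly; nothing to report
    } else if (!success) {
        reportDevice("CPU");
        reportFallbackReason("<br>Using CPU: llama_init_from_file failed");
        success = m.loadModel(path); // second attempt runs on CPU
    } else if (!m.usingGPUDevice()) {
        reportDevice("CPU");
        reportFallbackReason("<br>Using CPU: unsupported quantization type");
    }
    return success;
}

int main() {
    Model m;
    return loadWithFallback(m, "model.bin", "Auto") ? 0 : 1;
}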
@@ -299,11 +318,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             m_llModelInfo = LLModelInfo();
             emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
         } else {
-            // We might have had to fallback to CPU after load if the model is not possible to accelerate
-            // for instance if the quantization method is not supported on Vulkan yet
-            if (actualDevice != "CPU" && !m_llModelInfo.model->usingGPUDevice())
-                emit reportDevice("CPU");
-
             switch (m_llModelInfo.model->implementation().modelType()[0]) {
             case 'L': m_llModelType = LLModelType::LLAMA_; break;
             case 'G': m_llModelType = LLModelType::GPTJ_; break;
@@ -127,6 +127,7 @@ Q_SIGNALS:
     void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
     void reportSpeed(const QString &speed);
     void reportDevice(const QString &device);
+    void reportFallbackReason(const QString &fallbackReason);
     void databaseResultsChanged(const QList<ResultInfo>&);
     void modelInfoChanged(const ModelInfo &modelInfo);
 
@@ -1013,7 +1013,7 @@ Window {
         anchors.rightMargin: 30
         color: theme.mutedTextColor
         visible: currentChat.tokenSpeed !== ""
-        text: qsTr("Speed: ") + currentChat.tokenSpeed + "<br>" + qsTr("Device: ") + currentChat.device
+        text: qsTr("Speed: ") + currentChat.tokenSpeed + "<br>" + qsTr("Device: ") + currentChat.device + currentChat.fallbackReason
         font.pixelSize: theme.fontSizeLarge
     }
 
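Note the QML binding appends currentChat.fallbackReason to the device string with no separator. That works because every reason string emitted by ChatLLM begins with "<br>", and the property defaults to an empty string, so the status line renders unchanged until a fallback is actually reported.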