Report the actual device we're using.

pull/1420/head
Adam Treat 1 year ago
parent cf4eb530ce
commit 1fa67a585c

@ -56,6 +56,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportDevice, this, &Chat::handleDeviceChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
@ -345,6 +346,12 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
emit tokenSpeedChanged(); emit tokenSpeedChanged();
} }
// Slot wired to ChatLLM::reportDevice in Chat::connectLLM() via a queued connection.
// Stores the device name reported by the model loader and notifies listeners of the
// `device` property so bound views refresh.
void Chat::handleDeviceChanged(const QString &device)
{
// Update the cached value first so that readers reacting to the signal see the new device.
m_device = device;
emit deviceChanged();
}
void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results) void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
{ {
m_databaseResults = results; m_databaseResults = results;

@ -25,6 +25,7 @@ class Chat : public QObject
Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged) Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged) Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged)
Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged); Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged);
Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
QML_ELEMENT QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!") QML_UNCREATABLE("Only creatable from c++!")
@ -88,6 +89,7 @@ public:
QString modelLoadingError() const { return m_modelLoadingError; } QString modelLoadingError() const { return m_modelLoadingError; }
QString tokenSpeed() const { return m_tokenSpeed; } QString tokenSpeed() const { return m_tokenSpeed; }
QString device() const { return m_device; }
public Q_SLOTS: public Q_SLOTS:
void serverNewPromptResponsePair(const QString &prompt); void serverNewPromptResponsePair(const QString &prompt);
@ -115,6 +117,7 @@ Q_SIGNALS:
void isServerChanged(); void isServerChanged();
void collectionListChanged(const QList<QString> &collectionList); void collectionListChanged(const QList<QString> &collectionList);
void tokenSpeedChanged(); void tokenSpeedChanged();
void deviceChanged();
private Q_SLOTS: private Q_SLOTS:
void handleResponseChanged(const QString &response); void handleResponseChanged(const QString &response);
@ -125,6 +128,7 @@ private Q_SLOTS:
void handleRecalculating(); void handleRecalculating();
void handleModelLoadingError(const QString &error); void handleModelLoadingError(const QString &error);
void handleTokenSpeedChanged(const QString &tokenSpeed); void handleTokenSpeedChanged(const QString &tokenSpeed);
void handleDeviceChanged(const QString &device);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results); void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
void handleModelInfoChanged(const ModelInfo &modelInfo); void handleModelInfoChanged(const ModelInfo &modelInfo);
void handleModelInstalled(); void handleModelInstalled();
@ -137,6 +141,7 @@ private:
ModelInfo m_modelInfo; ModelInfo m_modelInfo;
QString m_modelLoadingError; QString m_modelLoadingError;
QString m_tokenSpeed; QString m_tokenSpeed;
QString m_device;
QString m_response; QString m_response;
QList<QString> m_collections; QList<QString> m_collections;
ChatModel *m_chatModel; ChatModel *m_chatModel;

@ -271,22 +271,28 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
MySettings::globalInstance()->setDeviceList(deviceList); MySettings::globalInstance()->setDeviceList(deviceList);
// Pick the best match for the device // Pick the best match for the device
QString actualDevice = m_llModelInfo.model->implementation().buildVariant() == "metal" ? "Metal" : "CPU";
const QString requestedDevice = MySettings::globalInstance()->device(); const QString requestedDevice = MySettings::globalInstance()->device();
if (requestedDevice != "CPU") { if (requestedDevice != "CPU") {
const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString()); const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString());
std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory); std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
if (!availableDevices.empty() && requestedDevice == "Auto" && availableDevices.front().type == 2 /*a discrete gpu*/) { if (!availableDevices.empty() && requestedDevice == "Auto" && availableDevices.front().type == 2 /*a discrete gpu*/) {
m_llModelInfo.model->initializeGPUDevice(availableDevices.front()); m_llModelInfo.model->initializeGPUDevice(availableDevices.front());
actualDevice = QString::fromStdString(availableDevices.front().name);
} else { } else {
for (LLModel::GPUDevice &d : availableDevices) { for (LLModel::GPUDevice &d : availableDevices) {
if (QString::fromStdString(d.name) == requestedDevice) { if (QString::fromStdString(d.name) == requestedDevice) {
m_llModelInfo.model->initializeGPUDevice(d); m_llModelInfo.model->initializeGPUDevice(d);
actualDevice = QString::fromStdString(d.name);
break; break;
} }
} }
} }
} }
// Report which device we're actually using
emit reportDevice(actualDevice);
bool success = m_llModelInfo.model->loadModel(filePath.toStdString()); bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
MySettings::globalInstance()->setAttemptModelLoad(QString()); MySettings::globalInstance()->setAttemptModelLoad(QString());
if (!success) { if (!success) {

@ -129,6 +129,7 @@ Q_SIGNALS:
void shouldBeLoadedChanged(); void shouldBeLoadedChanged();
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results); void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed); void reportSpeed(const QString &speed);
void reportDevice(const QString &device);
void databaseResultsChanged(const QList<ResultInfo>&); void databaseResultsChanged(const QList<ResultInfo>&);
void modelInfoChanged(const ModelInfo &modelInfo); void modelInfoChanged(const ModelInfo &modelInfo);

@ -1013,7 +1013,7 @@ Window {
anchors.rightMargin: 30 anchors.rightMargin: 30
color: theme.mutedTextColor color: theme.mutedTextColor
visible: currentChat.tokenSpeed !== "" visible: currentChat.tokenSpeed !== ""
text: qsTr("Speed: ") + currentChat.tokenSpeed + "<br>" + qsTr("Device: ") + MySettings.device text: qsTr("Speed: ") + currentChat.tokenSpeed + "<br>" + qsTr("Device: ") + currentChat.device
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
} }

Loading…
Cancel
Save