Programmatically get the model name from the LLM. The LLM now searches

for applicable models in the directory of the executable given a pattern
match and then loads the first one it finds.

Also, add a busy indicator for model loading.
pull/520/head
Adam Treat 1 year ago
parent 95cd59b405
commit 0ea31487e3

@ -31,16 +31,29 @@ bool GPTJObject::loadModel()
if (isModelLoaded()) if (isModelLoaded())
return true; return true;
QString modelName("ggml-model-q4_0.bin"); QDir dir(QCoreApplication::applicationDirPath());
QString fileName = QCoreApplication::applicationDirPath() + QDir::separator() + modelName; dir.setNameFilters(QStringList() << "ggml-*.bin");
QFile file(fileName); QStringList fileNames = dir.entryList();
if (file.exists()) { if (fileNames.isEmpty()) {
qDebug() << "ERROR: Could not find any applicable models in directory"
<< QCoreApplication::applicationDirPath();
}
QString modelName = fileNames.first();
QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + modelName;
QFileInfo info(filePath);
if (info.exists()) {
auto fin = std::ifstream(fileName.toStdString(), std::ios::binary); auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
m_gptj->loadModel(modelName.toStdString(), fin); m_gptj->loadModel(modelName.toStdString(), fin);
emit isModelLoadedChanged(); emit isModelLoadedChanged();
} }
if (m_gptj) {
m_modelName = info.baseName().remove(0, 5); // remove the ggml- prefix
emit modelNameChanged();
}
return m_gptj; return m_gptj;
} }
@ -64,6 +77,11 @@ QString GPTJObject::response() const
return QString::fromStdString(m_response); return QString::fromStdString(m_response);
} }
QString GPTJObject::modelName() const
{
return m_modelName;
}
bool GPTJObject::handleResponse(const std::string &response) bool GPTJObject::handleResponse(const std::string &response)
{ {
#if 0 #if 0
@ -97,6 +115,7 @@ LLM::LLM()
connect(m_gptj, &GPTJObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection); connect(m_gptj, &GPTJObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection);
connect(m_gptj, &GPTJObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection); connect(m_gptj, &GPTJObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection);
connect(m_gptj, &GPTJObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection); connect(m_gptj, &GPTJObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
connect(m_gptj, &GPTJObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
connect(this, &LLM::promptRequested, m_gptj, &GPTJObject::prompt, Qt::QueuedConnection); connect(this, &LLM::promptRequested, m_gptj, &GPTJObject::prompt, Qt::QueuedConnection);
connect(this, &LLM::resetResponseRequested, m_gptj, &GPTJObject::resetResponse, Qt::BlockingQueuedConnection); connect(this, &LLM::resetResponseRequested, m_gptj, &GPTJObject::resetResponse, Qt::BlockingQueuedConnection);
@ -145,6 +164,11 @@ void LLM::responseStopped()
emit responseInProgressChanged(); emit responseInProgressChanged();
} }
QString LLM::modelName() const
{
return m_gptj->modelName();
}
bool LLM::checkForUpdates() const bool LLM::checkForUpdates() const
{ {
#if defined(Q_OS_LINUX) #if defined(Q_OS_LINUX)

@ -10,6 +10,7 @@ class GPTJObject : public QObject
Q_OBJECT Q_OBJECT
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
public: public:
@ -22,6 +23,7 @@ public:
void stopGenerating() { m_stopGenerating = true; } void stopGenerating() { m_stopGenerating = true; }
QString response() const; QString response() const;
QString modelName() const;
public Q_SLOTS: public Q_SLOTS:
bool prompt(const QString &prompt); bool prompt(const QString &prompt);
@ -31,6 +33,7 @@ Q_SIGNALS:
void responseChanged(); void responseChanged();
void responseStarted(); void responseStarted();
void responseStopped(); void responseStopped();
void modelNameChanged();
private: private:
bool handleResponse(const std::string &response); bool handleResponse(const std::string &response);
@ -38,6 +41,7 @@ private:
private: private:
GPTJ *m_gptj; GPTJ *m_gptj;
std::string m_response; std::string m_response;
QString m_modelName;
QThread m_llmThread; QThread m_llmThread;
std::atomic<bool> m_stopGenerating; std::atomic<bool> m_stopGenerating;
}; };
@ -47,6 +51,7 @@ class LLM : public QObject
Q_OBJECT Q_OBJECT
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
public: public:
@ -61,6 +66,8 @@ public:
QString response() const; QString response() const;
bool responseInProgress() const { return m_responseInProgress; } bool responseInProgress() const { return m_responseInProgress; }
QString modelName() const;
Q_INVOKABLE bool checkForUpdates() const; Q_INVOKABLE bool checkForUpdates() const;
Q_SIGNALS: Q_SIGNALS:
@ -70,6 +77,7 @@ Q_SIGNALS:
void promptRequested(const QString &prompt); void promptRequested(const QString &prompt);
void resetResponseRequested(); void resetResponseRequested();
void resetContextRequested(); void resetContextRequested();
void modelNameChanged();
private Q_SLOTS: private Q_SLOTS:
void responseStarted(); void responseStarted();

@ -11,30 +11,48 @@ Window {
title: qsTr("GPT4All Chat") title: qsTr("GPT4All Chat")
color: "#d1d5db" color: "#d1d5db"
TextField { Rectangle {
id: header id: header
anchors.left: parent.left anchors.left: parent.left
anchors.right: parent.right anchors.right: parent.right
anchors.top: parent.top anchors.top: parent.top
height: 100 height: 100
color: "#d1d5db" color: "#202123"
padding: 20
font.pixelSize: 24 Item {
text: "GPT4ALL Model: gpt4all-j" anchors.centerIn: parent
background: Rectangle { width: childrenRect.width
color: "#202123" height: childrenRect.height
visible: LLM.isModelLoaded
TextField {
id: modelNameField
color: "#d1d5db"
padding: 20
font.pixelSize: 24
text: "GPT4ALL Model: " + LLM.modelName
background: Rectangle {
color: "#202123"
}
focus: false
horizontalAlignment: TextInput.AlignHCenter
}
Image {
anchors.left: modelNameField.right
anchors.verticalCenter: modelNameField.baseline
width: 50
height: 65
source: "qrc:/gpt4all-chat/icons/logo.svg"
z: 300
}
} }
focus: false
horizontalAlignment: TextInput.AlignHCenter
}
Image { BusyIndicator {
anchors.verticalCenter: header.baseline anchors.centerIn: parent
x: parent.width / 2 + 163 visible: !LLM.isModelLoaded
width: 50 running: !LLM.isModelLoaded
height: 65 }
source: "qrc:/gpt4all-chat/icons/logo.svg"
z: 300
} }
Button { Button {

Loading…
Cancel
Save