Programmatically get the model name from the LLM. The LLM now searches
for applicable models in the directory of the executable given a pattern
match and then loads the first one it finds.

Also, add a busy indicator for model loading.
Adam Treat 2023-04-11 08:29:55 -04:00
parent 95cd59b405
commit 0ea31487e3
3 changed files with 73 additions and 23 deletions
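
For context, the discovery logic this commit introduces boils down to a QDir name-filter scan of the executable's directory. A minimal standalone sketch of the same approach (a throwaway harness, not code from this commit):

#include <QCoreApplication>
#include <QDir>
#include <QDebug>

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);

    // Same pattern the commit uses: any ggml-*.bin next to the executable.
    QDir dir(QCoreApplication::applicationDirPath());
    dir.setNameFilters(QStringList() << "ggml-*.bin");
    const QStringList fileNames = dir.entryList(QDir::Files);

    if (fileNames.isEmpty()) {
        qDebug() << "No applicable models in" << dir.absolutePath();
        return 1;
    }
    qDebug() << "Would load:" << fileNames.first();
    return 0;
}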

llm.cpp (34 lines changed)

@@ -31,16 +31,29 @@ bool GPTJObject::loadModel()
     if (isModelLoaded())
         return true;
 
-    QString modelName("ggml-model-q4_0.bin");
-    QString fileName = QCoreApplication::applicationDirPath() + QDir::separator() + modelName;
-    QFile file(fileName);
-    if (file.exists()) {
-        auto fin = std::ifstream(fileName.toStdString(), std::ios::binary);
+    QDir dir(QCoreApplication::applicationDirPath());
+    dir.setNameFilters(QStringList() << "ggml-*.bin");
+    QStringList fileNames = dir.entryList();
+    if (fileNames.isEmpty()) {
+        qDebug() << "ERROR: Could not find any applicable models in directory"
+                 << QCoreApplication::applicationDirPath();
+    }
+
+    QString modelName = fileNames.first();
+    QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + modelName;
+    QFileInfo info(filePath);
+    if (info.exists()) {
+        auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
         m_gptj->loadModel(modelName.toStdString(), fin);
         emit isModelLoadedChanged();
     }
+
+    if (m_gptj) {
+        m_modelName = info.baseName().remove(0, 5); // remove the ggml- prefix
+        emit modelNameChanged();
+    }
+
     return m_gptj;
 }
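
A subtlety worth noting in the hunk above: QFileInfo::baseName() truncates at the first dot, so it strips ".bin" on its own, and remove(0, 5) then drops the five-character "ggml-" prefix. Traced by hand for the old default filename (hypothetical helper name, not from the commit):

#include <QFileInfo>
#include <QString>
#include <cassert>

static QString displayName(const QString &fileName)
{
    // "ggml-model-q4_0.bin" -> baseName()   -> "ggml-model-q4_0"
    //                       -> remove(0, 5) -> "model-q4_0"
    return QFileInfo(fileName).baseName().remove(0, 5);
}

int main()
{
    assert(displayName("ggml-model-q4_0.bin") == QString("model-q4_0"));
    return 0;
}

One consequence of cutting at the first dot: a hypothetical file like ggml-vicuna-7b.q4.bin would display as just "vicuna-7b".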
@@ -64,6 +77,11 @@ QString GPTJObject::response() const
     return QString::fromStdString(m_response);
 }
 
+QString GPTJObject::modelName() const
+{
+    return m_modelName;
+}
+
 bool GPTJObject::handleResponse(const std::string &response)
 {
 #if 0
@@ -97,6 +115,7 @@ LLM::LLM()
     connect(m_gptj, &GPTJObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection);
     connect(m_gptj, &GPTJObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection);
     connect(m_gptj, &GPTJObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
+    connect(m_gptj, &GPTJObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
 
     connect(this, &LLM::promptRequested, m_gptj, &GPTJObject::prompt, Qt::QueuedConnection);
     connect(this, &LLM::resetResponseRequested, m_gptj, &GPTJObject::resetResponse, Qt::BlockingQueuedConnection);
@@ -145,6 +164,11 @@ void LLM::responseStopped()
     emit responseInProgressChanged();
 }
 
+QString LLM::modelName() const
+{
+    return m_gptj->modelName();
+}
+
 bool LLM::checkForUpdates() const
 {
 #if defined(Q_OS_LINUX)
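
All of these connects follow one pattern: GPTJObject lives on a worker thread (m_llmThread in llm.h below), while LLM re-emits its signals on the GUI thread, which Qt::QueuedConnection guarantees by dispatching through the receiver's event loop. A stripped-down sketch of that forwarding arrangement (hypothetical Worker/Proxy names, not the commit's classes):

#include <QObject>
#include <QThread>

class Worker : public QObject {
    Q_OBJECT
signals:
    void modelNameChanged(); // emitted from the worker thread
};

class Proxy : public QObject {
    Q_OBJECT
public:
    explicit Proxy(Worker *worker) {
        worker->moveToThread(&m_thread);
        // Queued: the re-emit runs on the thread that owns *this* (the GUI
        // thread), no matter which thread the worker emitted from.
        connect(worker, &Worker::modelNameChanged,
                this, &Proxy::modelNameChanged, Qt::QueuedConnection);
        m_thread.start();
    }
    ~Proxy() override { m_thread.quit(); m_thread.wait(); }
signals:
    void modelNameChanged(); // safe to bind to from the GUI/QML side
private:
    QThread m_thread;
};

#include "proxy.moc" // single-file Q_OBJECT classes need the moc output included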

llm.h (8 lines changed)

@@ -10,6 +10,7 @@ class GPTJObject : public QObject
     Q_OBJECT
     Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
+    Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
 
 public:
@@ -22,6 +23,7 @@ public:
     void stopGenerating() { m_stopGenerating = true; }
 
     QString response() const;
+    QString modelName() const;
 
 public Q_SLOTS:
     bool prompt(const QString &prompt);
@@ -31,6 +33,7 @@ Q_SIGNALS:
     void responseChanged();
     void responseStarted();
     void responseStopped();
+    void modelNameChanged();
 
 private:
     bool handleResponse(const std::string &response);
@@ -38,6 +41,7 @@ private:
 private:
     GPTJ *m_gptj;
     std::string m_response;
+    QString m_modelName;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
 };
@@ -47,6 +51,7 @@ class LLM : public QObject
     Q_OBJECT
     Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
+    Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
     Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
 
 public:
@@ -61,6 +66,8 @@ public:
     QString response() const;
     bool responseInProgress() const { return m_responseInProgress; }
 
+    QString modelName() const;
+
     Q_INVOKABLE bool checkForUpdates() const;
 
 Q_SIGNALS:
@@ -70,6 +77,7 @@ Q_SIGNALS:
     void promptRequested(const QString &prompt);
     void resetResponseRequested();
     void resetContextRequested();
+    void modelNameChanged();
 
 private Q_SLOTS:
     void responseStarted();
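
The Q_PROPERTY declarations above are what make LLM.modelName visible to QML: READ names the getter, and NOTIFY tells the engine which signal invalidates any binding on the property. The same machinery is reachable from C++ through the meta-object system; a small self-contained sketch (stand-in Model class, not the commit's LLM):

#include <QCoreApplication>
#include <QDebug>
#include <QString>

class Model : public QObject {
    Q_OBJECT
    Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
public:
    QString modelName() const { return m_name; }
    void setName(const QString &name) { m_name = name; emit modelNameChanged(); }
signals:
    void modelNameChanged();
private:
    QString m_name;
};

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);
    Model model;
    // property() resolves "modelName" via the meta-object, the same lookup a
    // QML binding performs; the NOTIFY signal is what re-triggers the binding.
    QObject::connect(&model, &Model::modelNameChanged, [&model]() {
        qDebug() << "re-evaluated:" << model.property("modelName").toString();
    });
    model.setName("model-q4_0"); // prints: re-evaluated: "model-q4_0"
    return 0;
}

#include "main.moc" // single-file Q_OBJECT classes need the moc output included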

main.qml (54 lines changed)

@@ -11,16 +11,26 @@ Window {
     title: qsTr("GPT4All Chat")
     color: "#d1d5db"
 
-    TextField {
+    Rectangle {
         id: header
         anchors.left: parent.left
         anchors.right: parent.right
         anchors.top: parent.top
         height: 100
+        color: "#202123"
+
+        Item {
+            anchors.centerIn: parent
+            width: childrenRect.width
+            height: childrenRect.height
+            visible: LLM.isModelLoaded
+
+            TextField {
+                id: modelNameField
                 color: "#d1d5db"
                 padding: 20
                 font.pixelSize: 24
-                text: "GPT4ALL Model: gpt4all-j"
+                text: "GPT4ALL Model: " + LLM.modelName
                 background: Rectangle {
                     color: "#202123"
                 }
@@ -29,13 +39,21 @@ Window {
             }
 
             Image {
-                anchors.verticalCenter: header.baseline
-                x: parent.width / 2 + 163
+                anchors.left: modelNameField.right
+                anchors.verticalCenter: modelNameField.baseline
                 width: 50
                 height: 65
                 source: "qrc:/gpt4all-chat/icons/logo.svg"
                 z: 300
             }
+        }
+
+        BusyIndicator {
+            anchors.centerIn: parent
+            visible: !LLM.isModelLoaded
+            running: !LLM.isModelLoaded
+        }
+    }
 
     Button {
         id: drawerButton