Allow unloading/loading/changing of models.

This commit is contained in:
Adam Treat 2023-04-18 11:42:16 -04:00
parent 3a82a1d96c
commit 1eda8f030e
3 changed files with 126 additions and 25 deletions

91
llm.cpp
View File

@ -18,7 +18,7 @@ static LLModel::PromptContext s_ctx;
LLMObject::LLMObject()
: QObject{nullptr}
, m_llmodel(new GPTJ)
, m_llmodel(nullptr)
, m_responseTokens(0)
, m_responseLogits(0)
{
@ -30,19 +30,24 @@ LLMObject::LLMObject()
bool LLMObject::loadModel()
{
if (isModelLoaded())
return loadModelPrivate(modelList().first());
}
bool LLMObject::loadModelPrivate(const QString &modelName)
{
if (isModelLoaded() && m_modelName == modelName)
return true;
QDir dir(QCoreApplication::applicationDirPath());
dir.setNameFilters(QStringList() << "ggml-*.bin");
QStringList fileNames = dir.entryList();
if (fileNames.isEmpty()) {
qDebug() << "ERROR: Could not find any applicable models in directory"
<< QCoreApplication::applicationDirPath();
if (isModelLoaded()) {
delete m_llmodel;
m_llmodel = nullptr;
emit isModelLoadedChanged();
}
QString modelName = fileNames.first();
QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + modelName;
m_llmodel = new GPTJ;
QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() +
"ggml-" + modelName + ".bin";
QFileInfo info(filePath);
if (info.exists()) {
@ -51,17 +56,15 @@ bool LLMObject::loadModel()
emit isModelLoadedChanged();
}
if (m_llmodel) {
m_modelName = info.completeBaseName().remove(0, 5); // remove the ggml- prefix
emit modelNameChanged();
}
if (m_llmodel)
setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- prefix
return m_llmodel;
}
bool LLMObject::isModelLoaded() const
{
return m_llmodel->isModelLoaded();
return m_llmodel && m_llmodel->isModelLoaded();
}
void LLMObject::regenerateResponse()
@ -119,6 +122,46 @@ QString LLMObject::modelName() const
return m_modelName;
}
void LLMObject::setModelName(const QString &modelName)
{
m_modelName = modelName;
emit modelNameChanged();
emit modelListChanged();
}
void LLMObject::modelNameChangeRequested(const QString &modelName)
{
if (!loadModelPrivate(modelName))
qWarning() << "ERROR: Could not load model" << modelName;
}
QList<QString> LLMObject::modelList() const
{
QDir dir(QCoreApplication::applicationDirPath());
dir.setNameFilters(QStringList() << "ggml-*.bin");
QStringList fileNames = dir.entryList();
if (fileNames.isEmpty()) {
qWarning() << "ERROR: Could not find any applicable models in directory"
<< QCoreApplication::applicationDirPath();
return QList<QString>();
}
QList<QString> list;
for (QString f : fileNames) {
QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + f;
QFileInfo info(filePath);
QString name = info.completeBaseName().remove(0, 5);
if (info.exists()) {
if (name == m_modelName)
list.prepend(name);
else
list.append(name);
}
}
return list;
}
bool LLMObject::handleResponse(const std::string &response)
{
#if 0
@ -172,8 +215,12 @@ LLM::LLM()
connect(m_llmodel, &LLMObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);
// The following are blocking operations and will block the gui thread, therefore must be fast
// to respond to
connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
@ -232,6 +279,18 @@ QString LLM::modelName() const
return m_llmodel->modelName();
}
void LLM::setModelName(const QString &modelName)
{
// doesn't block but will unload old model and load new one which the gui can see through changes
// to the isModelLoaded property
emit modelNameChangeRequested(modelName);
}
QList<QString> LLM::modelList() const
{
return m_llmodel->modelList();
}
bool LLM::checkForUpdates() const
{
#if defined(Q_OS_LINUX)

19
llm.h
View File

@ -8,15 +8,15 @@
class LLMObject : public QObject
{
Q_OBJECT
Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
public:
LLMObject();
bool loadModel();
bool isModelLoaded() const;
void regenerateResponse();
void resetResponse();
@ -26,9 +26,14 @@ public:
QString response() const;
QString modelName() const;
QList<QString> modelList() const;
void setModelName(const QString &modelName);
public Q_SLOTS:
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch);
bool loadModel();
void modelNameChangeRequested(const QString &modelName);
Q_SIGNALS:
void isModelLoadedChanged();
@ -36,8 +41,10 @@ Q_SIGNALS:
void responseStarted();
void responseStopped();
void modelNameChanged();
void modelListChanged();
private:
bool loadModelPrivate(const QString &modelName);
bool handleResponse(const std::string &response);
private:
@ -53,9 +60,10 @@ private:
class LLM : public QObject
{
Q_OBJECT
Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
public:
@ -72,7 +80,10 @@ public:
QString response() const;
bool responseInProgress() const { return m_responseInProgress; }
QList<QString> modelList() const;
QString modelName() const;
void setModelName(const QString &modelName);
Q_INVOKABLE bool checkForUpdates() const;
@ -85,7 +96,9 @@ Q_SIGNALS:
void regenerateResponseRequested();
void resetResponseRequested();
void resetContextRequested();
void modelNameChangeRequested(const QString &modelName);
void modelNameChanged();
void modelListChanged();
private Q_SLOTS:
void responseStarted();

View File

@ -32,18 +32,47 @@ Window {
visible: LLM.isModelLoaded
Label {
id: modelNameField
id: modelLabel
color: "#d1d5db"
padding: 20
font.pixelSize: 24
text: "GPT4ALL Model: " + LLM.modelName
text: ""
background: Rectangle {
color: "#202123"
}
horizontalAlignment: TextInput.AlignHCenter
Accessible.role: Accessible.Heading
Accessible.name: text
Accessible.description: qsTr("Displays the model name that is currently loaded")
horizontalAlignment: TextInput.AlignRight
}
ComboBox {
id: comboBox
width: 400
anchors.top: modelLabel.top
anchors.bottom: modelLabel.bottom
anchors.horizontalCenter: parent.horizontalCenter
font.pixelSize: 24
spacing: 0
model: LLM.modelList
Accessible.role: Accessible.ComboBox
Accessible.name: qsTr("ComboBox for displaying/picking the current model")
Accessible.description: qsTr("Use this for picking the current model to use; the first item is the current model")
contentItem: Text {
anchors.horizontalCenter: parent.horizontalCenter
leftPadding: 10
rightPadding: 10
text: comboBox.displayText
font: comboBox.font
color: "#d1d5db"
verticalAlignment: Text.AlignVCenter
horizontalAlignment: Text.AlignHCenter
elide: Text.ElideRight
}
background: Rectangle {
color: "#242528"
}
onActivated: {
LLM.modelName = comboBox.currentText
}
}
}