mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-02 09:40:42 +00:00
Add an abstraction around gpt-j that will allow other arch models to be loaded in ui.
This commit is contained in:
parent
4e98e71eaf
commit
305975451c
@ -15,6 +15,7 @@ qt_add_executable(chat
|
|||||||
main.cpp
|
main.cpp
|
||||||
gptj.h gptj.cpp
|
gptj.h gptj.cpp
|
||||||
llm.h llm.cpp
|
llm.h llm.cpp
|
||||||
|
llmodel.h
|
||||||
)
|
)
|
||||||
|
|
||||||
qt_add_qml_module(chat
|
qt_add_qml_module(chat
|
||||||
|
14
gptj.h
14
gptj.h
@ -2,25 +2,21 @@
|
|||||||
#define GPTJ_H
|
#define GPTJ_H
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <sstream>
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "llmodel.h"
|
||||||
|
|
||||||
class GPTJPrivate;
|
class GPTJPrivate;
|
||||||
class GPTJ {
|
class GPTJ : public LLModel {
|
||||||
public:
|
public:
|
||||||
GPTJ();
|
GPTJ();
|
||||||
~GPTJ();
|
~GPTJ();
|
||||||
|
|
||||||
bool loadModel(const std::string &modelPath, std::istream &fin);
|
bool loadModel(const std::string &modelPath, std::istream &fin) override;
|
||||||
bool isModelLoaded() const;
|
bool isModelLoaded() const override;
|
||||||
struct PromptContext {
|
|
||||||
std::vector<float> logits;
|
|
||||||
int32_t n_past = 0; // number of tokens in past conversation
|
|
||||||
};
|
|
||||||
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||||
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
|
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
|
||||||
float temp = 0.9f, int32_t n_batch = 9);
|
float temp = 0.9f, int32_t n_batch = 9) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GPTJPrivate *d_ptr;
|
GPTJPrivate *d_ptr;
|
||||||
|
64
llm.cpp
64
llm.cpp
@ -14,19 +14,19 @@ LLM *LLM::globalInstance()
|
|||||||
return llmInstance();
|
return llmInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
static GPTJ::PromptContext s_ctx;
|
static LLModel::PromptContext s_ctx;
|
||||||
|
|
||||||
GPTJObject::GPTJObject()
|
LLMObject::LLMObject()
|
||||||
: QObject{nullptr}
|
: QObject{nullptr}
|
||||||
, m_gptj(new GPTJ)
|
, m_llmodel(new GPTJ)
|
||||||
{
|
{
|
||||||
moveToThread(&m_llmThread);
|
moveToThread(&m_llmThread);
|
||||||
connect(&m_llmThread, &QThread::started, this, &GPTJObject::loadModel);
|
connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);
|
||||||
m_llmThread.setObjectName("llm thread");
|
m_llmThread.setObjectName("llm thread");
|
||||||
m_llmThread.start();
|
m_llmThread.start();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GPTJObject::loadModel()
|
bool LLMObject::loadModel()
|
||||||
{
|
{
|
||||||
if (isModelLoaded())
|
if (isModelLoaded())
|
||||||
return true;
|
return true;
|
||||||
@ -45,45 +45,45 @@ bool GPTJObject::loadModel()
|
|||||||
if (info.exists()) {
|
if (info.exists()) {
|
||||||
|
|
||||||
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
|
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
|
||||||
m_gptj->loadModel(modelName.toStdString(), fin);
|
m_llmodel->loadModel(modelName.toStdString(), fin);
|
||||||
emit isModelLoadedChanged();
|
emit isModelLoadedChanged();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (m_gptj) {
|
if (m_llmodel) {
|
||||||
m_modelName = info.baseName().remove(0, 5); // remove the ggml- prefix
|
m_modelName = info.baseName().remove(0, 5); // remove the ggml- prefix
|
||||||
emit modelNameChanged();
|
emit modelNameChanged();
|
||||||
}
|
}
|
||||||
|
|
||||||
return m_gptj;
|
return m_llmodel;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GPTJObject::isModelLoaded() const
|
bool LLMObject::isModelLoaded() const
|
||||||
{
|
{
|
||||||
return m_gptj->isModelLoaded();
|
return m_llmodel->isModelLoaded();
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPTJObject::resetResponse()
|
void LLMObject::resetResponse()
|
||||||
{
|
{
|
||||||
m_response = std::string();
|
m_response = std::string();
|
||||||
emit responseChanged();
|
emit responseChanged();
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPTJObject::resetContext()
|
void LLMObject::resetContext()
|
||||||
{
|
{
|
||||||
s_ctx = GPTJ::PromptContext();
|
s_ctx = LLModel::PromptContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
QString GPTJObject::response() const
|
QString LLMObject::response() const
|
||||||
{
|
{
|
||||||
return QString::fromStdString(m_response);
|
return QString::fromStdString(m_response);
|
||||||
}
|
}
|
||||||
|
|
||||||
QString GPTJObject::modelName() const
|
QString LLMObject::modelName() const
|
||||||
{
|
{
|
||||||
return m_modelName;
|
return m_modelName;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GPTJObject::handleResponse(const std::string &response)
|
bool LLMObject::handleResponse(const std::string &response)
|
||||||
{
|
{
|
||||||
#if 0
|
#if 0
|
||||||
printf("%s", response.c_str());
|
printf("%s", response.c_str());
|
||||||
@ -96,38 +96,38 @@ bool GPTJObject::handleResponse(const std::string &response)
|
|||||||
return !m_stopGenerating;
|
return !m_stopGenerating;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GPTJObject::prompt(const QString &prompt)
|
bool LLMObject::prompt(const QString &prompt)
|
||||||
{
|
{
|
||||||
if (!isModelLoaded())
|
if (!isModelLoaded())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
m_stopGenerating = false;
|
m_stopGenerating = false;
|
||||||
auto func = std::bind(&GPTJObject::handleResponse, this, std::placeholders::_1);
|
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1);
|
||||||
emit responseStarted();
|
emit responseStarted();
|
||||||
m_gptj->prompt(prompt.toStdString(), func, s_ctx, 4096 /*number of chars to predict*/);
|
m_llmodel->prompt(prompt.toStdString(), func, s_ctx, 4096 /*number of chars to predict*/);
|
||||||
emit responseStopped();
|
emit responseStopped();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
LLM::LLM()
|
LLM::LLM()
|
||||||
: QObject{nullptr}
|
: QObject{nullptr}
|
||||||
, m_gptj(new GPTJObject)
|
, m_llmodel(new LLMObject)
|
||||||
, m_responseInProgress(false)
|
, m_responseInProgress(false)
|
||||||
{
|
{
|
||||||
connect(m_gptj, &GPTJObject::isModelLoadedChanged, this, &LLM::isModelLoadedChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &LLMObject::isModelLoadedChanged, this, &LLM::isModelLoadedChanged, Qt::QueuedConnection);
|
||||||
connect(m_gptj, &GPTJObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &LLMObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection);
|
||||||
connect(m_gptj, &GPTJObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection);
|
connect(m_llmodel, &LLMObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection);
|
||||||
connect(m_gptj, &GPTJObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
|
connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
|
||||||
connect(m_gptj, &GPTJObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
|
||||||
|
|
||||||
connect(this, &LLM::promptRequested, m_gptj, &GPTJObject::prompt, Qt::QueuedConnection);
|
connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
|
||||||
connect(this, &LLM::resetResponseRequested, m_gptj, &GPTJObject::resetResponse, Qt::BlockingQueuedConnection);
|
connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
|
||||||
connect(this, &LLM::resetContextRequested, m_gptj, &GPTJObject::resetContext, Qt::BlockingQueuedConnection);
|
connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool LLM::isModelLoaded() const
|
bool LLM::isModelLoaded() const
|
||||||
{
|
{
|
||||||
return m_gptj->isModelLoaded();
|
return m_llmodel->isModelLoaded();
|
||||||
}
|
}
|
||||||
|
|
||||||
void LLM::prompt(const QString &prompt)
|
void LLM::prompt(const QString &prompt)
|
||||||
@ -147,12 +147,12 @@ void LLM::resetContext()
|
|||||||
|
|
||||||
void LLM::stopGenerating()
|
void LLM::stopGenerating()
|
||||||
{
|
{
|
||||||
m_gptj->stopGenerating();
|
m_llmodel->stopGenerating();
|
||||||
}
|
}
|
||||||
|
|
||||||
QString LLM::response() const
|
QString LLM::response() const
|
||||||
{
|
{
|
||||||
return m_gptj->response();
|
return m_llmodel->response();
|
||||||
}
|
}
|
||||||
|
|
||||||
void LLM::responseStarted()
|
void LLM::responseStarted()
|
||||||
@ -169,7 +169,7 @@ void LLM::responseStopped()
|
|||||||
|
|
||||||
QString LLM::modelName() const
|
QString LLM::modelName() const
|
||||||
{
|
{
|
||||||
return m_gptj->modelName();
|
return m_llmodel->modelName();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool LLM::checkForUpdates() const
|
bool LLM::checkForUpdates() const
|
||||||
|
8
llm.h
8
llm.h
@ -5,7 +5,7 @@
|
|||||||
#include <QThread>
|
#include <QThread>
|
||||||
#include "gptj.h"
|
#include "gptj.h"
|
||||||
|
|
||||||
class GPTJObject : public QObject
|
class LLMObject : public QObject
|
||||||
{
|
{
|
||||||
Q_OBJECT
|
Q_OBJECT
|
||||||
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
|
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
|
||||||
@ -14,7 +14,7 @@ class GPTJObject : public QObject
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
GPTJObject();
|
LLMObject();
|
||||||
|
|
||||||
bool loadModel();
|
bool loadModel();
|
||||||
bool isModelLoaded() const;
|
bool isModelLoaded() const;
|
||||||
@ -39,7 +39,7 @@ private:
|
|||||||
bool handleResponse(const std::string &response);
|
bool handleResponse(const std::string &response);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GPTJ *m_gptj;
|
LLModel *m_llmodel;
|
||||||
std::string m_response;
|
std::string m_response;
|
||||||
QString m_modelName;
|
QString m_modelName;
|
||||||
QThread m_llmThread;
|
QThread m_llmThread;
|
||||||
@ -84,7 +84,7 @@ private Q_SLOTS:
|
|||||||
void responseStopped();
|
void responseStopped();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GPTJObject *m_gptj;
|
LLMObject *m_llmodel;
|
||||||
bool m_responseInProgress;
|
bool m_responseInProgress;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
24
llmodel.h
Normal file
24
llmodel.h
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
#ifndef LLMODEL_H
|
||||||
|
#define LLMODEL_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <functional>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
class LLModel {
|
||||||
|
public:
|
||||||
|
explicit LLModel() {}
|
||||||
|
virtual ~LLModel() {}
|
||||||
|
|
||||||
|
virtual bool loadModel(const std::string &modelPath, std::istream &fin) = 0;
|
||||||
|
virtual bool isModelLoaded() const = 0;
|
||||||
|
struct PromptContext {
|
||||||
|
std::vector<float> logits;
|
||||||
|
int32_t n_past = 0; // number of tokens in past conversation
|
||||||
|
};
|
||||||
|
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||||
|
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
|
||||||
|
float temp = 0.9f, int32_t n_batch = 9) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // LLMODEL_H
|
Loading…
Reference in New Issue
Block a user