Add an abstraction around gpt-j that will allow other arch models to be loaded in ui.

This commit is contained in:
Adam Treat 2023-04-13 22:15:40 -04:00
parent 4e98e71eaf
commit 9de185488c
5 changed files with 66 additions and 45 deletions

View File

@ -15,6 +15,7 @@ qt_add_executable(chat
main.cpp main.cpp
gptj.h gptj.cpp gptj.h gptj.cpp
llm.h llm.cpp llm.h llm.cpp
llmodel.h
) )
qt_add_qml_module(chat qt_add_qml_module(chat

14
gptj.h
View File

@ -2,25 +2,21 @@
#define GPTJ_H #define GPTJ_H
#include <string> #include <string>
#include <sstream>
#include <functional> #include <functional>
#include <vector> #include <vector>
#include "llmodel.h"
class GPTJPrivate; class GPTJPrivate;
class GPTJ { class GPTJ : public LLModel {
public: public:
GPTJ(); GPTJ();
~GPTJ(); ~GPTJ();
bool loadModel(const std::string &modelPath, std::istream &fin); bool loadModel(const std::string &modelPath, std::istream &fin) override;
bool isModelLoaded() const; bool isModelLoaded() const override;
struct PromptContext {
std::vector<float> logits;
int32_t n_past = 0; // number of tokens in past conversation
};
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response, void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f, PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
float temp = 0.9f, int32_t n_batch = 9); float temp = 0.9f, int32_t n_batch = 9) override;
private: private:
GPTJPrivate *d_ptr; GPTJPrivate *d_ptr;

64
llm.cpp
View File

@ -14,19 +14,19 @@ LLM *LLM::globalInstance()
return llmInstance(); return llmInstance();
} }
static GPTJ::PromptContext s_ctx; static LLModel::PromptContext s_ctx;
GPTJObject::GPTJObject() LLMObject::LLMObject()
: QObject{nullptr} : QObject{nullptr}
, m_gptj(new GPTJ) , m_llmodel(new GPTJ)
{ {
moveToThread(&m_llmThread); moveToThread(&m_llmThread);
connect(&m_llmThread, &QThread::started, this, &GPTJObject::loadModel); connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);
m_llmThread.setObjectName("llm thread"); m_llmThread.setObjectName("llm thread");
m_llmThread.start(); m_llmThread.start();
} }
bool GPTJObject::loadModel() bool LLMObject::loadModel()
{ {
if (isModelLoaded()) if (isModelLoaded())
return true; return true;
@ -45,45 +45,45 @@ bool GPTJObject::loadModel()
if (info.exists()) { if (info.exists()) {
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary); auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
m_gptj->loadModel(modelName.toStdString(), fin); m_llmodel->loadModel(modelName.toStdString(), fin);
emit isModelLoadedChanged(); emit isModelLoadedChanged();
} }
if (m_gptj) { if (m_llmodel) {
m_modelName = info.baseName().remove(0, 5); // remove the ggml- prefix m_modelName = info.baseName().remove(0, 5); // remove the ggml- prefix
emit modelNameChanged(); emit modelNameChanged();
} }
return m_gptj; return m_llmodel;
} }
bool GPTJObject::isModelLoaded() const bool LLMObject::isModelLoaded() const
{ {
return m_gptj->isModelLoaded(); return m_llmodel->isModelLoaded();
} }
void GPTJObject::resetResponse() void LLMObject::resetResponse()
{ {
m_response = std::string(); m_response = std::string();
emit responseChanged(); emit responseChanged();
} }
void GPTJObject::resetContext() void LLMObject::resetContext()
{ {
s_ctx = GPTJ::PromptContext(); s_ctx = LLModel::PromptContext();
} }
QString GPTJObject::response() const QString LLMObject::response() const
{ {
return QString::fromStdString(m_response); return QString::fromStdString(m_response);
} }
QString GPTJObject::modelName() const QString LLMObject::modelName() const
{ {
return m_modelName; return m_modelName;
} }
bool GPTJObject::handleResponse(const std::string &response) bool LLMObject::handleResponse(const std::string &response)
{ {
#if 0 #if 0
printf("%s", response.c_str()); printf("%s", response.c_str());
@ -96,38 +96,38 @@ bool GPTJObject::handleResponse(const std::string &response)
return !m_stopGenerating; return !m_stopGenerating;
} }
bool GPTJObject::prompt(const QString &prompt) bool LLMObject::prompt(const QString &prompt)
{ {
if (!isModelLoaded()) if (!isModelLoaded())
return false; return false;
m_stopGenerating = false; m_stopGenerating = false;
auto func = std::bind(&GPTJObject::handleResponse, this, std::placeholders::_1); auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1);
emit responseStarted(); emit responseStarted();
m_gptj->prompt(prompt.toStdString(), func, s_ctx, 4096 /*number of chars to predict*/); m_llmodel->prompt(prompt.toStdString(), func, s_ctx, 4096 /*number of chars to predict*/);
emit responseStopped(); emit responseStopped();
return true; return true;
} }
LLM::LLM() LLM::LLM()
: QObject{nullptr} : QObject{nullptr}
, m_gptj(new GPTJObject) , m_llmodel(new LLMObject)
, m_responseInProgress(false) , m_responseInProgress(false)
{ {
connect(m_gptj, &GPTJObject::isModelLoadedChanged, this, &LLM::isModelLoadedChanged, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::isModelLoadedChanged, this, &LLM::isModelLoadedChanged, Qt::QueuedConnection);
connect(m_gptj, &GPTJObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection);
connect(m_gptj, &GPTJObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection);
connect(m_gptj, &GPTJObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
connect(m_gptj, &GPTJObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
connect(this, &LLM::promptRequested, m_gptj, &GPTJObject::prompt, Qt::QueuedConnection); connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
connect(this, &LLM::resetResponseRequested, m_gptj, &GPTJObject::resetResponse, Qt::BlockingQueuedConnection); connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetContextRequested, m_gptj, &GPTJObject::resetContext, Qt::BlockingQueuedConnection); connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
} }
bool LLM::isModelLoaded() const bool LLM::isModelLoaded() const
{ {
return m_gptj->isModelLoaded(); return m_llmodel->isModelLoaded();
} }
void LLM::prompt(const QString &prompt) void LLM::prompt(const QString &prompt)
@ -147,12 +147,12 @@ void LLM::resetContext()
void LLM::stopGenerating() void LLM::stopGenerating()
{ {
m_gptj->stopGenerating(); m_llmodel->stopGenerating();
} }
QString LLM::response() const QString LLM::response() const
{ {
return m_gptj->response(); return m_llmodel->response();
} }
void LLM::responseStarted() void LLM::responseStarted()
@ -169,7 +169,7 @@ void LLM::responseStopped()
QString LLM::modelName() const QString LLM::modelName() const
{ {
return m_gptj->modelName(); return m_llmodel->modelName();
} }
bool LLM::checkForUpdates() const bool LLM::checkForUpdates() const

8
llm.h
View File

@ -5,7 +5,7 @@
#include <QThread> #include <QThread>
#include "gptj.h" #include "gptj.h"
class GPTJObject : public QObject class LLMObject : public QObject
{ {
Q_OBJECT Q_OBJECT
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
@ -14,7 +14,7 @@ class GPTJObject : public QObject
public: public:
GPTJObject(); LLMObject();
bool loadModel(); bool loadModel();
bool isModelLoaded() const; bool isModelLoaded() const;
@ -39,7 +39,7 @@ private:
bool handleResponse(const std::string &response); bool handleResponse(const std::string &response);
private: private:
GPTJ *m_gptj; LLModel *m_llmodel;
std::string m_response; std::string m_response;
QString m_modelName; QString m_modelName;
QThread m_llmThread; QThread m_llmThread;
@ -84,7 +84,7 @@ private Q_SLOTS:
void responseStopped(); void responseStopped();
private: private:
GPTJObject *m_gptj; LLMObject *m_llmodel;
bool m_responseInProgress; bool m_responseInProgress;
private: private:

24
llmodel.h Normal file
View File

@ -0,0 +1,24 @@
#ifndef LLMODEL_H
#define LLMODEL_H
#include <string>
#include <functional>
#include <vector>
class LLModel {
public:
explicit LLModel() {}
virtual ~LLModel() {}
virtual bool loadModel(const std::string &modelPath, std::istream &fin) = 0;
virtual bool isModelLoaded() const = 0;
struct PromptContext {
std::vector<float> logits;
int32_t n_past = 0; // number of tokens in past conversation
};
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
float temp = 0.9f, int32_t n_batch = 9) = 0;
};
#endif // LLMODEL_H