From ed508abc9b91d0cc37fe6827173ddeddf15e5753 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Thu, 13 Apr 2023 22:15:40 -0400
Subject: [PATCH] Add an abstraction around gpt-j that will allow other arch
 models to be loaded in ui.

---
 CMakeLists.txt |  1 +
 gptj.h         | 14 ++++-------
 llm.cpp        | 64 +++++++++++++++++++++++++-------------------
 llm.h          |  8 +++----
 llmodel.h      | 24 +++++++++++++++++++
 5 files changed, 66 insertions(+), 45 deletions(-)
 create mode 100644 llmodel.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ae5a005a..5f2f1d5f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -15,6 +15,7 @@ qt_add_executable(chat
     main.cpp
     gptj.h gptj.cpp
     llm.h llm.cpp
+    llmodel.h
 )

 qt_add_qml_module(chat
diff --git a/gptj.h b/gptj.h
index b3a42406..dcac6231 100644
--- a/gptj.h
+++ b/gptj.h
@@ -2,25 +2,21 @@
 #define GPTJ_H

 #include
-#include
 #include
 #include
+#include "llmodel.h"

 class GPTJPrivate;
-class GPTJ {
+class GPTJ : public LLModel {
 public:
     GPTJ();
     ~GPTJ();

-    bool loadModel(const std::string &modelPath, std::istream &fin);
-    bool isModelLoaded() const;
-    struct PromptContext {
-        std::vector<float> logits;
-        int32_t n_past = 0; // number of tokens in past conversation
-    };
+    bool loadModel(const std::string &modelPath, std::istream &fin) override;
+    bool isModelLoaded() const override;
     void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
         PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
-        float temp = 0.9f, int32_t n_batch = 9);
+        float temp = 0.9f, int32_t n_batch = 9) override;

 private:
     GPTJPrivate *d_ptr;
diff --git a/llm.cpp b/llm.cpp
index 1efffcef..38ea62e3 100644
--- a/llm.cpp
+++ b/llm.cpp
@@ -14,19 +14,19 @@ LLM *LLM::globalInstance()
     return llmInstance();
 }

-static GPTJ::PromptContext s_ctx;
+static LLModel::PromptContext s_ctx;

-GPTJObject::GPTJObject()
+LLMObject::LLMObject()
     : QObject{nullptr}
-    , m_gptj(new GPTJ)
+    , m_llmodel(new GPTJ)
 {
     moveToThread(&m_llmThread);
-    connect(&m_llmThread, &QThread::started, this, &GPTJObject::loadModel);
+    connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);
     m_llmThread.setObjectName("llm thread");
     m_llmThread.start();
 }

-bool GPTJObject::loadModel()
+bool LLMObject::loadModel()
 {
     if (isModelLoaded())
         return true;
@@ -45,45 +45,45 @@
     if (info.exists()) {

         auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
-        m_gptj->loadModel(modelName.toStdString(), fin);
+        m_llmodel->loadModel(modelName.toStdString(), fin);
         emit isModelLoadedChanged();
     }

-    if (m_gptj) {
+    if (m_llmodel) {
         m_modelName = info.baseName().remove(0, 5); // remove the ggml- prefix
         emit modelNameChanged();
     }

-    return m_gptj;
+    return m_llmodel;
 }

-bool GPTJObject::isModelLoaded() const
+bool LLMObject::isModelLoaded() const
 {
-    return m_gptj->isModelLoaded();
+    return m_llmodel->isModelLoaded();
 }

-void GPTJObject::resetResponse()
+void LLMObject::resetResponse()
 {
     m_response = std::string();
     emit responseChanged();
 }

-void GPTJObject::resetContext()
+void LLMObject::resetContext()
 {
-    s_ctx = GPTJ::PromptContext();
+    s_ctx = LLModel::PromptContext();
 }

-QString GPTJObject::response() const
+QString LLMObject::response() const
 {
     return QString::fromStdString(m_response);
 }

-QString GPTJObject::modelName() const
+QString LLMObject::modelName() const
 {
     return m_modelName;
 }

-bool GPTJObject::handleResponse(const std::string &response)
+bool LLMObject::handleResponse(const std::string &response)
 {
 #if 0
     printf("%s", response.c_str());
@@ -96,38 +96,38 @@ bool GPTJObject::handleResponse(const std::string &response)
     return !m_stopGenerating;
 }

-bool GPTJObject::prompt(const QString &prompt)
+bool LLMObject::prompt(const QString &prompt)
 {
     if (!isModelLoaded())
         return false;

     m_stopGenerating = false;
-    auto func = std::bind(&GPTJObject::handleResponse, this, std::placeholders::_1);
+    auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1);
     emit responseStarted();
-    m_gptj->prompt(prompt.toStdString(), func, s_ctx, 4096 /*number of chars to predict*/);
+    m_llmodel->prompt(prompt.toStdString(), func, s_ctx, 4096 /*number of chars to predict*/);
     emit responseStopped();
     return true;
 }

 LLM::LLM()
     : QObject{nullptr}
-    , m_gptj(new GPTJObject)
+    , m_llmodel(new LLMObject)
     , m_responseInProgress(false)
 {
-    connect(m_gptj, &GPTJObject::isModelLoadedChanged, this, &LLM::isModelLoadedChanged, Qt::QueuedConnection);
-    connect(m_gptj, &GPTJObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection);
-    connect(m_gptj, &GPTJObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection);
-    connect(m_gptj, &GPTJObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
-    connect(m_gptj, &GPTJObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &LLMObject::isModelLoadedChanged, this, &LLM::isModelLoadedChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &LLMObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &LLMObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection);
+    connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
+    connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);

-    connect(this, &LLM::promptRequested, m_gptj, &GPTJObject::prompt, Qt::QueuedConnection);
-    connect(this, &LLM::resetResponseRequested, m_gptj, &GPTJObject::resetResponse, Qt::BlockingQueuedConnection);
-    connect(this, &LLM::resetContextRequested, m_gptj, &GPTJObject::resetContext, Qt::BlockingQueuedConnection);
+    connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
+    connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
+    connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
 }

 bool LLM::isModelLoaded() const
 {
-    return m_gptj->isModelLoaded();
+    return m_llmodel->isModelLoaded();
 }

 void LLM::prompt(const QString &prompt)
@@ -147,12 +147,12 @@ void LLM::resetContext()

 void LLM::stopGenerating()
 {
-    m_gptj->stopGenerating();
+    m_llmodel->stopGenerating();
 }

 QString LLM::response() const
 {
-    return m_gptj->response();
+    return m_llmodel->response();
 }

 void LLM::responseStarted()
@@ -169,7 +169,7 @@ void LLM::responseStopped()

 QString LLM::modelName() const
 {
-    return m_gptj->modelName();
+    return m_llmodel->modelName();
 }

 bool LLM::checkForUpdates() const
diff --git a/llm.h b/llm.h
index 808fa886..3740723d 100644
--- a/llm.h
+++ b/llm.h
@@ -5,7 +5,7 @@
 #include
 #include "gptj.h"

-class GPTJObject : public QObject
+class LLMObject : public QObject
 {
     Q_OBJECT
     Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
@@ -14,7 +14,7 @@ class GPTJObject : public QObject

 public:

-    GPTJObject();
+    LLMObject();

     bool loadModel();
     bool isModelLoaded() const;
@@ -39,7 +39,7 @@ private:
     bool handleResponse(const std::string &response);

 private:
-    GPTJ *m_gptj;
+    LLModel *m_llmodel;
     std::string m_response;
     QString m_modelName;
     QThread m_llmThread;
@@ -84,7 +84,7 @@ private Q_SLOTS:
     void responseStopped();

 private:
-    GPTJObject *m_gptj;
+    LLMObject *m_llmodel;
     bool m_responseInProgress;

 private:
diff --git a/llmodel.h b/llmodel.h
new file mode 100644
index 00000000..da52d190
--- /dev/null
+++ b/llmodel.h
@@ -0,0 +1,24 @@
+#ifndef LLMODEL_H
+#define LLMODEL_H
+
+#include
+#include
+#include
+
+class LLModel {
+public:
+    explicit LLModel() {}
+    virtual ~LLModel() {}
+
+    virtual bool loadModel(const std::string &modelPath, std::istream &fin) = 0;
+    virtual bool isModelLoaded() const = 0;
+    struct PromptContext {
+        std::vector<float> logits;
+        int32_t n_past = 0; // number of tokens in past conversation
+    };
+    virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
+        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
+        float temp = 0.9f, int32_t n_batch = 9) = 0;
+};
+
+#endif // LLMODEL_H
\ No newline at end of file
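
Note on the abstraction: with LLModel in place, LLMObject only ever touches the base-class pointer (m_llmodel), so adding another model architecture to the UI comes down to providing another LLModel subclass and constructing it instead of GPTJ. A minimal sketch of what such a subclass could look like follows; the class name LlamaModel, the file name llamamodel.h and the stub bodies are hypothetical and not part of this patch, only the LLModel interface they implement is.

// llamamodel.h -- a hypothetical second backend, shown only as an illustration
#ifndef LLAMAMODEL_H
#define LLAMAMODEL_H

#include <cstdint>
#include <functional>
#include <istream>
#include <string>

#include "llmodel.h"

// Another architecture only has to implement the three pure virtuals;
// LLMObject can then drive it through an LLModel* exactly as it drives GPTJ.
class LlamaModel : public LLModel {
public:
    bool loadModel(const std::string &modelPath, std::istream &fin) override {
        // A real backend would parse the model weights from 'fin' here.
        (void)modelPath;
        (void)fin;
        m_loaded = true;
        return m_loaded;
    }

    bool isModelLoaded() const override { return m_loaded; }

    void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
        float temp = 0.9f, int32_t n_batch = 9) override {
        // A real backend would tokenize, run inference and stream tokens back;
        // the callback's bool return value is the stop-generating signal.
        (void)ctx; (void)n_predict; (void)top_k; (void)top_p; (void)temp; (void)n_batch;
        response("(echo) " + prompt);
    }

private:
    bool m_loaded = false;
};

#endif // LLAMAMODEL_H

Under that assumption, the only UI-side change would be swapping the concrete type constructed in the LLMObject initializer list (m_llmodel(new GPTJ)) for the new subclass, since all of the Qt signal/slot plumbing in llm.cpp already goes through the LLModel interface.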