diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f0b61df..b9182434 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,6 +58,7 @@ set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) qt_add_executable(chat main.cpp + chat.h chat.cpp chatmodel.h download.h download.cpp network.h network.cpp llm.h llm.cpp diff --git a/chat.cpp b/chat.cpp new file mode 100644 index 00000000..fbc66a34 --- /dev/null +++ b/chat.cpp @@ -0,0 +1 @@ +#include "chat.h" diff --git a/chat.h b/chat.h new file mode 100644 index 00000000..238a8695 --- /dev/null +++ b/chat.h @@ -0,0 +1,42 @@ +#ifndef CHAT_H +#define CHAT_H + +#include +#include + +#include "chatmodel.h" +#include "network.h" + +class Chat : public QObject +{ + Q_OBJECT + Q_PROPERTY(QString id READ id NOTIFY idChanged) + Q_PROPERTY(QString name READ name NOTIFY nameChanged) + Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged) + QML_ELEMENT + QML_UNCREATABLE("Only creatable from c++!") + +public: + explicit Chat(QObject *parent = nullptr) : QObject(parent) + { + m_id = Network::globalInstance()->generateUniqueId(); + m_name = tr("New Chat"); + m_chatModel = new ChatModel(this); + } + + QString id() const { return m_id; } + QString name() const { return m_name; } + ChatModel *chatModel() { return m_chatModel; } + +Q_SIGNALS: + void idChanged(); + void nameChanged(); + void chatModelChanged(); + +private: + QString m_id; + QString m_name; + ChatModel *m_chatModel; +}; + +#endif // CHAT_H diff --git a/chatmodel.h b/chatmodel.h new file mode 100644 index 00000000..1102b008 --- /dev/null +++ b/chatmodel.h @@ -0,0 +1,210 @@ +#ifndef CHATMODEL_H +#define CHATMODEL_H + +#include +#include + +struct ChatItem +{ + Q_GADGET + Q_PROPERTY(int id MEMBER id) + Q_PROPERTY(QString name MEMBER name) + Q_PROPERTY(QString value MEMBER value) + Q_PROPERTY(QString prompt MEMBER prompt) + Q_PROPERTY(QString newResponse MEMBER newResponse) + Q_PROPERTY(bool currentResponse MEMBER currentResponse) + Q_PROPERTY(bool stopped MEMBER stopped) + Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState) + Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState) + +public: + int id = 0; + QString name; + QString value; + QString prompt; + QString newResponse; + bool currentResponse = false; + bool stopped = false; + bool thumbsUpState = false; + bool thumbsDownState = false; +}; +Q_DECLARE_METATYPE(ChatItem) + +class ChatModel : public QAbstractListModel +{ + Q_OBJECT + Q_PROPERTY(int count READ count NOTIFY countChanged) + +public: + explicit ChatModel(QObject *parent = nullptr) : QAbstractListModel(parent) {} + + enum Roles { + IdRole = Qt::UserRole + 1, + NameRole, + ValueRole, + PromptRole, + NewResponseRole, + CurrentResponseRole, + StoppedRole, + ThumbsUpStateRole, + ThumbsDownStateRole + }; + + int rowCount(const QModelIndex &parent = QModelIndex()) const override + { + Q_UNUSED(parent) + return m_chatItems.size(); + } + + QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override + { + if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size()) + return QVariant(); + + const ChatItem &item = m_chatItems.at(index.row()); + switch (role) { + case IdRole: + return item.id; + case NameRole: + return item.name; + case ValueRole: + return item.value; + case PromptRole: + return item.prompt; + case NewResponseRole: + return item.newResponse; + case CurrentResponseRole: + return item.currentResponse; + case StoppedRole: + return item.stopped; + case ThumbsUpStateRole: + return item.thumbsUpState; + case ThumbsDownStateRole: + return item.thumbsDownState; + } + + return QVariant(); + } + + QHash roleNames() const override + { + QHash roles; + roles[IdRole] = "id"; + roles[NameRole] = "name"; + roles[ValueRole] = "value"; + roles[PromptRole] = "prompt"; + roles[NewResponseRole] = "newResponse"; + roles[CurrentResponseRole] = "currentResponse"; + roles[StoppedRole] = "stopped"; + roles[ThumbsUpStateRole] = "thumbsUpState"; + roles[ThumbsDownStateRole] = "thumbsDownState"; + return roles; + } + + Q_INVOKABLE void appendPrompt(const QString &name, const QString &value) + { + ChatItem item; + item.name = name; + item.value = value; + beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); + m_chatItems.append(item); + endInsertRows(); + emit countChanged(); + } + + Q_INVOKABLE void appendResponse(const QString &name, const QString &prompt) + { + ChatItem item; + item.id = m_chatItems.count(); // This is only relevant for responses + item.name = name; + item.prompt = prompt; + item.currentResponse = true; + beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); + m_chatItems.append(item); + endInsertRows(); + emit countChanged(); + } + + Q_INVOKABLE ChatItem get(int index) + { + if (index < 0 || index >= m_chatItems.size()) return ChatItem(); + return m_chatItems.at(index); + } + + Q_INVOKABLE void updateCurrentResponse(int index, bool b) + { + if (index < 0 || index >= m_chatItems.size()) return; + + ChatItem &item = m_chatItems[index]; + if (item.currentResponse != b) { + item.currentResponse = b; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole}); + } + } + + Q_INVOKABLE void updateStopped(int index, bool b) + { + if (index < 0 || index >= m_chatItems.size()) return; + + ChatItem &item = m_chatItems[index]; + if (item.stopped != b) { + item.stopped = b; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole}); + } + } + + Q_INVOKABLE void updateValue(int index, const QString &value) + { + if (index < 0 || index >= m_chatItems.size()) return; + + ChatItem &item = m_chatItems[index]; + if (item.value != value) { + item.value = value; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole}); + } + } + + Q_INVOKABLE void updateThumbsUpState(int index, bool b) + { + if (index < 0 || index >= m_chatItems.size()) return; + + ChatItem &item = m_chatItems[index]; + if (item.thumbsUpState != b) { + item.thumbsUpState = b; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole}); + } + } + + Q_INVOKABLE void updateThumbsDownState(int index, bool b) + { + if (index < 0 || index >= m_chatItems.size()) return; + + ChatItem &item = m_chatItems[index]; + if (item.thumbsDownState != b) { + item.thumbsDownState = b; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole}); + } + } + + Q_INVOKABLE void updateNewResponse(int index, const QString &newResponse) + { + if (index < 0 || index >= m_chatItems.size()) return; + + ChatItem &item = m_chatItems[index]; + if (item.newResponse != newResponse) { + item.newResponse = newResponse; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole}); + } + } + + int count() const { return m_chatItems.size(); } + +Q_SIGNALS: + void countChanged(); + +private: + + QList m_chatItems; +}; + +#endif // CHATMODEL_H diff --git a/llm.cpp b/llm.cpp index b7336f47..8e3abe0d 100644 --- a/llm.cpp +++ b/llm.cpp @@ -1,6 +1,8 @@ #include "llm.h" #include "download.h" #include "network.h" +#include "llmodel/gptj.h" +#include "llmodel/llamamodel.h" #include #include @@ -345,6 +347,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in LLM::LLM() : QObject{nullptr} + , m_currentChat(new Chat) , m_llmodel(new LLMObject) , m_responseInProgress(false) { diff --git a/llm.h b/llm.h index 089a63a4..ddcd7d48 100644 --- a/llm.h +++ b/llm.h @@ -3,8 +3,9 @@ #include #include -#include "llmodel/gptj.h" -#include "llmodel/llamamodel.h" + +#include "chat.h" +#include "llmodel/llmodel.h" class LLMObject : public QObject { @@ -24,6 +25,7 @@ public: void regenerateResponse(); void resetResponse(); void resetContext(); + void stopGenerating() { m_stopGenerating = true; } void setThreadCount(int32_t n_threads); int32_t threadCount(); @@ -83,6 +85,7 @@ class LLM : public QObject Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) + Q_PROPERTY(Chat *currentChat READ currentChat NOTIFY currentChatChanged) public: @@ -111,6 +114,8 @@ public: bool isRecalc() const; + Chat *currentChat() const { return m_currentChat; } + Q_SIGNALS: void isModelLoadedChanged(); void responseChanged(); @@ -126,12 +131,14 @@ Q_SIGNALS: void threadCountChanged(); void setThreadCountRequested(int32_t threadCount); void recalcChanged(); + void currentChatChanged(); private Q_SLOTS: void responseStarted(); void responseStopped(); private: + Chat *m_currentChat; LLMObject *m_llmodel; int32_t m_desiredThreadCount; bool m_responseInProgress; diff --git a/main.qml b/main.qml index 2a775156..a0174310 100644 --- a/main.qml +++ b/main.qml @@ -19,6 +19,7 @@ Window { } property string chatId: Network.generateUniqueId() + property var chatModel: LLM.currentChat.chatModel color: theme.textColor @@ -666,10 +667,6 @@ Window { anchors.bottomMargin: 30 ScrollBar.vertical.policy: ScrollBar.AlwaysOn - ListModel { - id: chatModel - } - Rectangle { anchors.fill: parent color: theme.backgroundLighter @@ -750,9 +747,9 @@ Window { if (thumbsDownState && !thumbsUpState && !responseHasChanged) return - newResponse = response - thumbsDownState = true - thumbsUpState = false + chatModel.updateNewResponse(index, response) + chatModel.updateThumbsUpState(index, false) + chatModel.updateThumbsDownState(index, true) Network.sendConversation(chatId, getConversationJson()); } } @@ -782,9 +779,9 @@ Window { if (thumbsUpState && !thumbsDownState) return - newResponse = "" - thumbsUpState = true - thumbsDownState = false + chatModel.updateNewResponse(index, "") + chatModel.updateThumbsUpState(index, true) + chatModel.updateThumbsDownState(index, false) Network.sendConversation(chatId, getConversationJson()); } } @@ -862,8 +859,8 @@ Window { } leftPadding: 50 onClicked: { - if (chatModel.count) - var listElement = chatModel.get(chatModel.count - 1) + var index = Math.max(0, chatModel.count - 1); + var listElement = chatModel.get(index); if (LLM.responseInProgress) { listElement.stopped = true @@ -872,12 +869,12 @@ Window { LLM.regenerateResponse() if (chatModel.count) { if (listElement.name === qsTr("Response: ")) { - listElement.currentResponse = true - listElement.stopped = false - listElement.value = LLM.response - listElement.thumbsUpState = false - listElement.thumbsDownState = false - listElement.newResponse = "" + chatModel.updateCurrentResponse(index, true); + chatModel.updateStopped(index, false); + chatModel.updateValue(index, LLM.response); + chatModel.updateThumbsUpState(index, false); + chatModel.updateThumbsDownState(index, false); + chatModel.updateNewResponse(index, ""); LLM.prompt(listElement.prompt, settingsDialog.promptTemplate, settingsDialog.maxLength, settingsDialog.topK, settingsDialog.topP, @@ -949,18 +946,14 @@ Window { LLM.stopGenerating() if (chatModel.count) { - var listElement = chatModel.get(chatModel.count - 1) - listElement.currentResponse = false - listElement.value = LLM.response + var index = Math.max(0, chatModel.count - 1); + var listElement = chatModel.get(index); + chatModel.updateCurrentResponse(index, false); + chatModel.updateValue(index, LLM.response); } var prompt = textInput.text + "\n" - chatModel.append({"name": qsTr("Prompt: "), "currentResponse": false, - "value": textInput.text}) - chatModel.append({"id": chatModel.count, "name": qsTr("Response: "), - "currentResponse": true, "value": "", "stopped": false, - "thumbsUpState": false, "thumbsDownState": false, - "newResponse": "", - "prompt": prompt}) + chatModel.appendPrompt(qsTr("Prompt: "), textInput.text); + chatModel.appendResponse(qsTr("Response: "), prompt); LLM.resetResponse() LLM.prompt(prompt, settingsDialog.promptTemplate, settingsDialog.maxLength,