mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-06 09:20:33 +00:00
Add new C++ version of the chat model. Getting ready for chat history.
This commit is contained in:
parent
83609bf8a5
commit
bbffa7364b
@ -58,6 +58,7 @@ set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
qt_add_executable(chat
|
||||
main.cpp
|
||||
chat.h chat.cpp chatmodel.h
|
||||
download.h download.cpp
|
||||
network.h network.cpp
|
||||
llm.h llm.cpp
|
||||
|
42
chat.h
Normal file
42
chat.h
Normal file
@ -0,0 +1,42 @@
|
||||
#ifndef CHAT_H
|
||||
#define CHAT_H
|
||||
|
||||
#include <QObject>
|
||||
#include <QtQml>
|
||||
|
||||
#include "chatmodel.h"
|
||||
#include "network.h"
|
||||
|
||||
class Chat : public QObject
|
||||
{
|
||||
Q_OBJECT
|
||||
Q_PROPERTY(QString id READ id NOTIFY idChanged)
|
||||
Q_PROPERTY(QString name READ name NOTIFY nameChanged)
|
||||
Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
|
||||
QML_ELEMENT
|
||||
QML_UNCREATABLE("Only creatable from c++!")
|
||||
|
||||
public:
|
||||
explicit Chat(QObject *parent = nullptr) : QObject(parent)
|
||||
{
|
||||
m_id = Network::globalInstance()->generateUniqueId();
|
||||
m_name = tr("New Chat");
|
||||
m_chatModel = new ChatModel(this);
|
||||
}
|
||||
|
||||
QString id() const { return m_id; }
|
||||
QString name() const { return m_name; }
|
||||
ChatModel *chatModel() { return m_chatModel; }
|
||||
|
||||
Q_SIGNALS:
|
||||
void idChanged();
|
||||
void nameChanged();
|
||||
void chatModelChanged();
|
||||
|
||||
private:
|
||||
QString m_id;
|
||||
QString m_name;
|
||||
ChatModel *m_chatModel;
|
||||
};
|
||||
|
||||
#endif // CHAT_H
|
210
chatmodel.h
Normal file
210
chatmodel.h
Normal file
@ -0,0 +1,210 @@
|
||||
#ifndef CHATMODEL_H
|
||||
#define CHATMODEL_H
|
||||
|
||||
#include <QAbstractListModel>
|
||||
#include <QtQml>
|
||||
|
||||
struct ChatItem
|
||||
{
|
||||
Q_GADGET
|
||||
Q_PROPERTY(int id MEMBER id)
|
||||
Q_PROPERTY(QString name MEMBER name)
|
||||
Q_PROPERTY(QString value MEMBER value)
|
||||
Q_PROPERTY(QString prompt MEMBER prompt)
|
||||
Q_PROPERTY(QString newResponse MEMBER newResponse)
|
||||
Q_PROPERTY(bool currentResponse MEMBER currentResponse)
|
||||
Q_PROPERTY(bool stopped MEMBER stopped)
|
||||
Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState)
|
||||
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
|
||||
|
||||
public:
|
||||
int id = 0;
|
||||
QString name;
|
||||
QString value;
|
||||
QString prompt;
|
||||
QString newResponse;
|
||||
bool currentResponse = false;
|
||||
bool stopped = false;
|
||||
bool thumbsUpState = false;
|
||||
bool thumbsDownState = false;
|
||||
};
|
||||
Q_DECLARE_METATYPE(ChatItem)
|
||||
|
||||
class ChatModel : public QAbstractListModel
|
||||
{
|
||||
Q_OBJECT
|
||||
Q_PROPERTY(int count READ count NOTIFY countChanged)
|
||||
|
||||
public:
|
||||
explicit ChatModel(QObject *parent = nullptr) : QAbstractListModel(parent) {}
|
||||
|
||||
enum Roles {
|
||||
IdRole = Qt::UserRole + 1,
|
||||
NameRole,
|
||||
ValueRole,
|
||||
PromptRole,
|
||||
NewResponseRole,
|
||||
CurrentResponseRole,
|
||||
StoppedRole,
|
||||
ThumbsUpStateRole,
|
||||
ThumbsDownStateRole
|
||||
};
|
||||
|
||||
int rowCount(const QModelIndex &parent = QModelIndex()) const override
|
||||
{
|
||||
Q_UNUSED(parent)
|
||||
return m_chatItems.size();
|
||||
}
|
||||
|
||||
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
|
||||
{
|
||||
if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size())
|
||||
return QVariant();
|
||||
|
||||
const ChatItem &item = m_chatItems.at(index.row());
|
||||
switch (role) {
|
||||
case IdRole:
|
||||
return item.id;
|
||||
case NameRole:
|
||||
return item.name;
|
||||
case ValueRole:
|
||||
return item.value;
|
||||
case PromptRole:
|
||||
return item.prompt;
|
||||
case NewResponseRole:
|
||||
return item.newResponse;
|
||||
case CurrentResponseRole:
|
||||
return item.currentResponse;
|
||||
case StoppedRole:
|
||||
return item.stopped;
|
||||
case ThumbsUpStateRole:
|
||||
return item.thumbsUpState;
|
||||
case ThumbsDownStateRole:
|
||||
return item.thumbsDownState;
|
||||
}
|
||||
|
||||
return QVariant();
|
||||
}
|
||||
|
||||
QHash<int, QByteArray> roleNames() const override
|
||||
{
|
||||
QHash<int, QByteArray> roles;
|
||||
roles[IdRole] = "id";
|
||||
roles[NameRole] = "name";
|
||||
roles[ValueRole] = "value";
|
||||
roles[PromptRole] = "prompt";
|
||||
roles[NewResponseRole] = "newResponse";
|
||||
roles[CurrentResponseRole] = "currentResponse";
|
||||
roles[StoppedRole] = "stopped";
|
||||
roles[ThumbsUpStateRole] = "thumbsUpState";
|
||||
roles[ThumbsDownStateRole] = "thumbsDownState";
|
||||
return roles;
|
||||
}
|
||||
|
||||
Q_INVOKABLE void appendPrompt(const QString &name, const QString &value)
|
||||
{
|
||||
ChatItem item;
|
||||
item.name = name;
|
||||
item.value = value;
|
||||
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
|
||||
m_chatItems.append(item);
|
||||
endInsertRows();
|
||||
emit countChanged();
|
||||
}
|
||||
|
||||
Q_INVOKABLE void appendResponse(const QString &name, const QString &prompt)
|
||||
{
|
||||
ChatItem item;
|
||||
item.id = m_chatItems.count(); // This is only relevant for responses
|
||||
item.name = name;
|
||||
item.prompt = prompt;
|
||||
item.currentResponse = true;
|
||||
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
|
||||
m_chatItems.append(item);
|
||||
endInsertRows();
|
||||
emit countChanged();
|
||||
}
|
||||
|
||||
Q_INVOKABLE ChatItem get(int index)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return ChatItem();
|
||||
return m_chatItems.at(index);
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateCurrentResponse(int index, bool b)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.currentResponse != b) {
|
||||
item.currentResponse = b;
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole});
|
||||
}
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateStopped(int index, bool b)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.stopped != b) {
|
||||
item.stopped = b;
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole});
|
||||
}
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateValue(int index, const QString &value)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.value != value) {
|
||||
item.value = value;
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole});
|
||||
}
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateThumbsUpState(int index, bool b)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.thumbsUpState != b) {
|
||||
item.thumbsUpState = b;
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole});
|
||||
}
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateThumbsDownState(int index, bool b)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.thumbsDownState != b) {
|
||||
item.thumbsDownState = b;
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole});
|
||||
}
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateNewResponse(int index, const QString &newResponse)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.newResponse != newResponse) {
|
||||
item.newResponse = newResponse;
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
|
||||
}
|
||||
}
|
||||
|
||||
int count() const { return m_chatItems.size(); }
|
||||
|
||||
Q_SIGNALS:
|
||||
void countChanged();
|
||||
|
||||
private:
|
||||
|
||||
QList<ChatItem> m_chatItems;
|
||||
};
|
||||
|
||||
#endif // CHATMODEL_H
|
3
llm.cpp
3
llm.cpp
@ -1,6 +1,8 @@
|
||||
#include "llm.h"
|
||||
#include "download.h"
|
||||
#include "network.h"
|
||||
#include "llmodel/gptj.h"
|
||||
#include "llmodel/llamamodel.h"
|
||||
|
||||
#include <QCoreApplication>
|
||||
#include <QDir>
|
||||
@ -345,6 +347,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
|
||||
|
||||
LLM::LLM()
|
||||
: QObject{nullptr}
|
||||
, m_currentChat(new Chat)
|
||||
, m_llmodel(new LLMObject)
|
||||
, m_responseInProgress(false)
|
||||
{
|
||||
|
11
llm.h
11
llm.h
@ -3,8 +3,9 @@
|
||||
|
||||
#include <QObject>
|
||||
#include <QThread>
|
||||
#include "llmodel/gptj.h"
|
||||
#include "llmodel/llamamodel.h"
|
||||
|
||||
#include "chat.h"
|
||||
#include "llmodel/llmodel.h"
|
||||
|
||||
class LLMObject : public QObject
|
||||
{
|
||||
@ -24,6 +25,7 @@ public:
|
||||
void regenerateResponse();
|
||||
void resetResponse();
|
||||
void resetContext();
|
||||
|
||||
void stopGenerating() { m_stopGenerating = true; }
|
||||
void setThreadCount(int32_t n_threads);
|
||||
int32_t threadCount();
|
||||
@ -83,6 +85,7 @@ class LLM : public QObject
|
||||
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
|
||||
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
||||
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
|
||||
Q_PROPERTY(Chat *currentChat READ currentChat NOTIFY currentChatChanged)
|
||||
|
||||
public:
|
||||
|
||||
@ -111,6 +114,8 @@ public:
|
||||
|
||||
bool isRecalc() const;
|
||||
|
||||
Chat *currentChat() const { return m_currentChat; }
|
||||
|
||||
Q_SIGNALS:
|
||||
void isModelLoadedChanged();
|
||||
void responseChanged();
|
||||
@ -126,12 +131,14 @@ Q_SIGNALS:
|
||||
void threadCountChanged();
|
||||
void setThreadCountRequested(int32_t threadCount);
|
||||
void recalcChanged();
|
||||
void currentChatChanged();
|
||||
|
||||
private Q_SLOTS:
|
||||
void responseStarted();
|
||||
void responseStopped();
|
||||
|
||||
private:
|
||||
Chat *m_currentChat;
|
||||
LLMObject *m_llmodel;
|
||||
int32_t m_desiredThreadCount;
|
||||
bool m_responseInProgress;
|
||||
|
49
main.qml
49
main.qml
@ -19,6 +19,7 @@ Window {
|
||||
}
|
||||
|
||||
property string chatId: Network.generateUniqueId()
|
||||
property var chatModel: LLM.currentChat.chatModel
|
||||
|
||||
color: theme.textColor
|
||||
|
||||
@ -666,10 +667,6 @@ Window {
|
||||
anchors.bottomMargin: 30
|
||||
ScrollBar.vertical.policy: ScrollBar.AlwaysOn
|
||||
|
||||
ListModel {
|
||||
id: chatModel
|
||||
}
|
||||
|
||||
Rectangle {
|
||||
anchors.fill: parent
|
||||
color: theme.backgroundLighter
|
||||
@ -750,9 +747,9 @@ Window {
|
||||
if (thumbsDownState && !thumbsUpState && !responseHasChanged)
|
||||
return
|
||||
|
||||
newResponse = response
|
||||
thumbsDownState = true
|
||||
thumbsUpState = false
|
||||
chatModel.updateNewResponse(index, response)
|
||||
chatModel.updateThumbsUpState(index, false)
|
||||
chatModel.updateThumbsDownState(index, true)
|
||||
Network.sendConversation(chatId, getConversationJson());
|
||||
}
|
||||
}
|
||||
@ -782,9 +779,9 @@ Window {
|
||||
if (thumbsUpState && !thumbsDownState)
|
||||
return
|
||||
|
||||
newResponse = ""
|
||||
thumbsUpState = true
|
||||
thumbsDownState = false
|
||||
chatModel.updateNewResponse(index, "")
|
||||
chatModel.updateThumbsUpState(index, true)
|
||||
chatModel.updateThumbsDownState(index, false)
|
||||
Network.sendConversation(chatId, getConversationJson());
|
||||
}
|
||||
}
|
||||
@ -862,8 +859,8 @@ Window {
|
||||
}
|
||||
leftPadding: 50
|
||||
onClicked: {
|
||||
if (chatModel.count)
|
||||
var listElement = chatModel.get(chatModel.count - 1)
|
||||
var index = Math.max(0, chatModel.count - 1);
|
||||
var listElement = chatModel.get(index);
|
||||
|
||||
if (LLM.responseInProgress) {
|
||||
listElement.stopped = true
|
||||
@ -872,12 +869,12 @@ Window {
|
||||
LLM.regenerateResponse()
|
||||
if (chatModel.count) {
|
||||
if (listElement.name === qsTr("Response: ")) {
|
||||
listElement.currentResponse = true
|
||||
listElement.stopped = false
|
||||
listElement.value = LLM.response
|
||||
listElement.thumbsUpState = false
|
||||
listElement.thumbsDownState = false
|
||||
listElement.newResponse = ""
|
||||
chatModel.updateCurrentResponse(index, true);
|
||||
chatModel.updateStopped(index, false);
|
||||
chatModel.updateValue(index, LLM.response);
|
||||
chatModel.updateThumbsUpState(index, false);
|
||||
chatModel.updateThumbsDownState(index, false);
|
||||
chatModel.updateNewResponse(index, "");
|
||||
LLM.prompt(listElement.prompt, settingsDialog.promptTemplate,
|
||||
settingsDialog.maxLength,
|
||||
settingsDialog.topK, settingsDialog.topP,
|
||||
@ -949,18 +946,14 @@ Window {
|
||||
LLM.stopGenerating()
|
||||
|
||||
if (chatModel.count) {
|
||||
var listElement = chatModel.get(chatModel.count - 1)
|
||||
listElement.currentResponse = false
|
||||
listElement.value = LLM.response
|
||||
var index = Math.max(0, chatModel.count - 1);
|
||||
var listElement = chatModel.get(index);
|
||||
chatModel.updateCurrentResponse(index, false);
|
||||
chatModel.updateValue(index, LLM.response);
|
||||
}
|
||||
var prompt = textInput.text + "\n"
|
||||
chatModel.append({"name": qsTr("Prompt: "), "currentResponse": false,
|
||||
"value": textInput.text})
|
||||
chatModel.append({"id": chatModel.count, "name": qsTr("Response: "),
|
||||
"currentResponse": true, "value": "", "stopped": false,
|
||||
"thumbsUpState": false, "thumbsDownState": false,
|
||||
"newResponse": "",
|
||||
"prompt": prompt})
|
||||
chatModel.appendPrompt(qsTr("Prompt: "), textInput.text);
|
||||
chatModel.appendResponse(qsTr("Response: "), prompt);
|
||||
LLM.resetResponse()
|
||||
LLM.prompt(prompt, settingsDialog.promptTemplate,
|
||||
settingsDialog.maxLength,
|
||||
|
Loading…
Reference in New Issue
Block a user