From f13f4f4700f8b65c263f677b2a18f5460e26f47f Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Tue, 2 May 2023 11:19:17 -0400 Subject: [PATCH] Generate names via llm. --- chat.cpp | 14 +++++++++++++ chat.h | 2 ++ chatlistmodel.h | 18 +++++++++++++++- chatllm.cpp | 52 ++++++++++++++++++++++++++++++++++++++++++++++ chatllm.h | 9 ++++++++ qml/ChatDrawer.qml | 13 ++++++++++-- 6 files changed, 105 insertions(+), 3 deletions(-) diff --git a/chat.cpp b/chat.cpp index 697f74a4..2daca6e5 100644 --- a/chat.cpp +++ b/chat.cpp @@ -18,11 +18,13 @@ Chat::Chat(QObject *parent) connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::threadCountChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::syncThreadCount, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::recalcChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection); connect(this, &Chat::unloadRequested, m_llmodel, &ChatLLM::unload, Qt::QueuedConnection); connect(this, &Chat::reloadRequested, m_llmodel, &ChatLLM::reload, Qt::QueuedConnection); + connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection); // The following are blocking operations and will block the gui thread, therefore must be fast // to respond to @@ -77,6 +79,8 @@ void Chat::responseStopped() { m_responseInProgress = false; emit responseInProgressChanged(); + if (m_llmodel->generatedName().isEmpty()) + emit generateNameRequested(); } QString Chat::modelName() const @@ -128,3 +132,13 @@ void Chat::reload() { emit reloadRequested(); } + +void Chat::generatedNameChanged() +{ + // Only use the first three words maximum and remove newlines and extra spaces + 
QString gen = m_llmodel->generatedName().simplified();
+    QStringList words = gen.split(' ', Qt::SkipEmptyParts);
+    int wordCount = qMin(3, words.size());
+    m_name = words.mid(0, wordCount).join(' ');
+    emit nameChanged();
+}
diff --git a/chat.h b/chat.h
index 79c61aa9..8d7169c9 100644
--- a/chat.h
+++ b/chat.h
@@ -73,10 +73,12 @@ Q_SIGNALS:
     void recalcChanged();
     void unloadRequested();
     void reloadRequested();
+    void generateNameRequested();
 
 private Q_SLOTS:
     void responseStarted();
     void responseStopped();
+    void generatedNameChanged();
 
 private:
     ChatLLM *m_llmodel;
diff --git a/chatlistmodel.h b/chatlistmodel.h
index d5ad10ce..0c7b85c4 100644
--- a/chatlistmodel.h
+++ b/chatlistmodel.h
@@ -63,6 +63,8 @@ public:
         m_newChat = new Chat(this);
         connect(m_newChat->chatModel(), &ChatModel::countChanged, this,
             &ChatListModel::newChatCountChanged);
+        connect(m_newChat, &Chat::nameChanged,
+            this, &ChatListModel::nameChanged);
         beginInsertRows(QModelIndex(), 0, 0);
         m_chats.prepend(m_newChat);
@@ -147,10 +149,24 @@ private Q_SLOTS:
     void newChatCountChanged()
     {
         Q_ASSERT(m_newChat && m_newChat->chatModel()->count());
-        m_newChat->disconnect(this);
+        m_newChat->chatModel()->disconnect(this);
         m_newChat = nullptr;
     }
 
+    void nameChanged()
+    {
+        Chat *chat = qobject_cast<Chat *>(sender());
+        if (!chat)
+            return;
+
+        int row = m_chats.indexOf(chat);
+        if (row < 0 || row >= m_chats.size())
+            return;
+
+        QModelIndex index = createIndex(row, 0);
+        emit dataChanged(index, index, {NameRole});
+    }
+
 private:
     Chat* m_newChat;
     Chat* m_currentChat;
diff --git a/chatllm.cpp b/chatllm.cpp
index 2612042e..247410eb 100644
--- a/chatllm.cpp
+++ b/chatllm.cpp
@@ -300,3 +300,55 @@ void ChatLLM::reload()
 {
     loadModel();
 }
+
+void ChatLLM::generateName()
+{
+    Q_ASSERT(isModelLoaded());
+    if (!isModelLoaded())
+        return;
+
+    QString instructPrompt("### Instruction:\n"
+                           "Describe response above in three words.\n"
+                           "### Response:\n");
+    auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this,
std::placeholders::_1); + auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, + std::placeholders::_2); + auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1); + LLModel::PromptContext ctx = m_ctx; +#if defined(DEBUG) + printf("%s", qPrintable(instructPrompt)); + fflush(stdout); +#endif + m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx); +#if defined(DEBUG) + printf("\n"); + fflush(stdout); +#endif + std::string trimmed = trim_whitespace(m_nameResponse); + if (trimmed != m_nameResponse) { + m_nameResponse = trimmed; + emit generatedNameChanged(); + } +} + +bool ChatLLM::handleNamePrompt(int32_t token) +{ + Q_UNUSED(token); + qt_noop(); + return true; +} + +bool ChatLLM::handleNameResponse(int32_t token, const std::string &response) +{ + Q_UNUSED(token); + m_nameResponse.append(response); + emit generatedNameChanged(); + return true; +} + +bool ChatLLM::handleNameRecalculate(bool isRecalc) +{ + Q_UNUSED(isRecalc); + Q_UNREACHABLE(); + return true; +} diff --git a/chatllm.h b/chatllm.h index ae888f0f..e5c0acef 100644 --- a/chatllm.h +++ b/chatllm.h @@ -14,6 +14,7 @@ class ChatLLM : public QObject Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) + Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged) public: ChatLLM(); @@ -34,6 +35,8 @@ public: bool isRecalc() const { return m_isRecalc; } + QString generatedName() const { return QString::fromStdString(m_nameResponse); } + public Q_SLOTS: bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); @@ -41,6 +44,7 @@ public Q_SLOTS: void 
modelNameChangeRequested(const QString &modelName); void unload(); void reload(); + void generateName(); Q_SIGNALS: void isModelLoadedChanged(); @@ -53,6 +57,7 @@ Q_SIGNALS: void sendStartup(); void sendModelLoaded(); void sendResetContext(); + void generatedNameChanged(); private: void resetContextPrivate(); @@ -60,11 +65,15 @@ private: bool handlePrompt(int32_t token); bool handleResponse(int32_t token, const std::string &response); bool handleRecalculate(bool isRecalc); + bool handleNamePrompt(int32_t token); + bool handleNameResponse(int32_t token, const std::string &response); + bool handleNameRecalculate(bool isRecalc); private: LLModel::PromptContext m_ctx; LLModel *m_llmodel; std::string m_response; + std::string m_nameResponse; quint32 m_promptResponseTokens; quint32 m_responseLogits; QString m_modelName; diff --git a/qml/ChatDrawer.qml b/qml/ChatDrawer.qml index a1883634..882fcf09 100644 --- a/qml/ChatDrawer.qml +++ b/qml/ChatDrawer.qml @@ -84,7 +84,7 @@ Drawer { color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter border.width: isCurrent border.color: theme.backgroundLightest - TextArea { + TextField { id: chatName anchors.left: parent.left anchors.right: buttons.left @@ -96,8 +96,15 @@ Drawer { hoverEnabled: false // Disable hover events on the TextArea selectByMouse: false // Disable text selection in the TextArea font.pixelSize: theme.fontSizeLarger - text: name + text: readOnly ? 
metrics.elidedText : name horizontalAlignment: TextInput.AlignLeft + TextMetrics { + id: metrics + font: chatName.font + text: name + elide: Text.ElideRight + elideWidth: chatName.width - 25 + } background: Rectangle { color: "transparent" } @@ -111,6 +118,7 @@ Drawer { LLM.chatListModel.get(index).name = chatName.text chatName.focus = false chatName.readOnly = true + chatName.selectByMouse = false } TapHandler { onTapped: { @@ -139,6 +147,7 @@ Drawer { onClicked: { chatName.focus = true chatName.readOnly = false + chatName.selectByMouse = true } } Button {