Generate names via llm.

This commit is contained in:
Adam Treat 2023-05-02 11:19:17 -04:00
parent a62fafc308
commit f13f4f4700
6 changed files with 105 additions and 3 deletions

View File

@ -18,11 +18,13 @@ Chat::Chat(QObject *parent)
connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::threadCountChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::threadCountChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::syncThreadCount, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::syncThreadCount, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::recalcChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::recalcChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection); connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
connect(this, &Chat::unloadRequested, m_llmodel, &ChatLLM::unload, Qt::QueuedConnection); connect(this, &Chat::unloadRequested, m_llmodel, &ChatLLM::unload, Qt::QueuedConnection);
connect(this, &Chat::reloadRequested, m_llmodel, &ChatLLM::reload, Qt::QueuedConnection); connect(this, &Chat::reloadRequested, m_llmodel, &ChatLLM::reload, Qt::QueuedConnection);
connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
// The following are blocking operations and will block the gui thread, therefore must be fast // The following are blocking operations and will block the gui thread, therefore must be fast
// to respond to // to respond to
@ -77,6 +79,8 @@ void Chat::responseStopped()
{ {
m_responseInProgress = false; m_responseInProgress = false;
emit responseInProgressChanged(); emit responseInProgressChanged();
if (m_llmodel->generatedName().isEmpty())
emit generateNameRequested();
} }
QString Chat::modelName() const QString Chat::modelName() const
@ -128,3 +132,13 @@ void Chat::reload()
{ {
emit reloadRequested(); emit reloadRequested();
} }
void Chat::generatedNameChanged()
{
// Only use the first three words maximum and remove newlines and extra spaces
QString gen = m_llmodel->generatedName().simplified();
QStringList words = gen.split(' ', Qt::SkipEmptyParts);
int wordCount = qMin(3, words.size());
m_name = words.mid(0, wordCount).join(' ');
emit nameChanged();
}

2
chat.h
View File

@ -73,10 +73,12 @@ Q_SIGNALS:
void recalcChanged(); void recalcChanged();
void unloadRequested(); void unloadRequested();
void reloadRequested(); void reloadRequested();
void generateNameRequested();
private Q_SLOTS: private Q_SLOTS:
void responseStarted(); void responseStarted();
void responseStopped(); void responseStopped();
void generatedNameChanged();
private: private:
ChatLLM *m_llmodel; ChatLLM *m_llmodel;

View File

@ -63,6 +63,8 @@ public:
m_newChat = new Chat(this); m_newChat = new Chat(this);
connect(m_newChat->chatModel(), &ChatModel::countChanged, connect(m_newChat->chatModel(), &ChatModel::countChanged,
this, &ChatListModel::newChatCountChanged); this, &ChatListModel::newChatCountChanged);
connect(m_newChat, &Chat::nameChanged,
this, &ChatListModel::nameChanged);
beginInsertRows(QModelIndex(), 0, 0); beginInsertRows(QModelIndex(), 0, 0);
m_chats.prepend(m_newChat); m_chats.prepend(m_newChat);
@ -147,10 +149,24 @@ private Q_SLOTS:
void newChatCountChanged() void newChatCountChanged()
{ {
Q_ASSERT(m_newChat && m_newChat->chatModel()->count()); Q_ASSERT(m_newChat && m_newChat->chatModel()->count());
m_newChat->disconnect(this); m_newChat->chatModel()->disconnect(this);
m_newChat = nullptr; m_newChat = nullptr;
} }
void nameChanged()
{
Chat *chat = qobject_cast<Chat *>(sender());
if (!chat)
return;
int row = m_chats.indexOf(chat);
if (row < 0 || row >= m_chats.size())
return;
QModelIndex index = createIndex(row, 0);
emit dataChanged(index, index, {NameRole});
}
private: private:
Chat* m_newChat; Chat* m_newChat;
Chat* m_currentChat; Chat* m_currentChat;

View File

@ -300,3 +300,55 @@ void ChatLLM::reload()
{ {
loadModel(); loadModel();
} }
void ChatLLM::generateName()
{
Q_ASSERT(isModelLoaded());
if (!isModelLoaded())
return;
QString instructPrompt("### Instruction:\n"
"Describe response above in three words.\n"
"### Response:\n");
auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1);
LLModel::PromptContext ctx = m_ctx;
#if defined(DEBUG)
printf("%s", qPrintable(instructPrompt));
fflush(stdout);
#endif
m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
#endif
std::string trimmed = trim_whitespace(m_nameResponse);
if (trimmed != m_nameResponse) {
m_nameResponse = trimmed;
emit generatedNameChanged();
}
}
/// Prompt-token callback used while generating a chat name.
///
/// Nothing is recorded for prompt tokens here; the function ignores its
/// argument and always returns true.
bool ChatLLM::handleNamePrompt(int32_t token)
{
    static_cast<void>(token); // intentionally unused
    return true;
}
bool ChatLLM::handleNameResponse(int32_t token, const std::string &response)
{
Q_UNUSED(token);
m_nameResponse.append(response);
emit generatedNameChanged();
return true;
}
bool ChatLLM::handleNameRecalculate(bool isRecalc)
{
Q_UNUSED(isRecalc);
Q_UNREACHABLE();
return true;
}

View File

@ -14,6 +14,7 @@ class ChatLLM : public QObject
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged)
public: public:
ChatLLM(); ChatLLM();
@ -34,6 +35,8 @@ public:
bool isRecalc() const { return m_isRecalc; } bool isRecalc() const { return m_isRecalc; }
QString generatedName() const { return QString::fromStdString(m_nameResponse); }
public Q_SLOTS: public Q_SLOTS:
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
@ -41,6 +44,7 @@ public Q_SLOTS:
void modelNameChangeRequested(const QString &modelName); void modelNameChangeRequested(const QString &modelName);
void unload(); void unload();
void reload(); void reload();
void generateName();
Q_SIGNALS: Q_SIGNALS:
void isModelLoadedChanged(); void isModelLoadedChanged();
@ -53,6 +57,7 @@ Q_SIGNALS:
void sendStartup(); void sendStartup();
void sendModelLoaded(); void sendModelLoaded();
void sendResetContext(); void sendResetContext();
void generatedNameChanged();
private: private:
void resetContextPrivate(); void resetContextPrivate();
@ -60,11 +65,15 @@ private:
bool handlePrompt(int32_t token); bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response); bool handleResponse(int32_t token, const std::string &response);
bool handleRecalculate(bool isRecalc); bool handleRecalculate(bool isRecalc);
bool handleNamePrompt(int32_t token);
bool handleNameResponse(int32_t token, const std::string &response);
bool handleNameRecalculate(bool isRecalc);
private: private:
LLModel::PromptContext m_ctx; LLModel::PromptContext m_ctx;
LLModel *m_llmodel; LLModel *m_llmodel;
std::string m_response; std::string m_response;
std::string m_nameResponse;
quint32 m_promptResponseTokens; quint32 m_promptResponseTokens;
quint32 m_responseLogits; quint32 m_responseLogits;
QString m_modelName; QString m_modelName;

View File

@ -84,7 +84,7 @@ Drawer {
color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter
border.width: isCurrent border.width: isCurrent
border.color: theme.backgroundLightest border.color: theme.backgroundLightest
TextArea { TextField {
id: chatName id: chatName
anchors.left: parent.left anchors.left: parent.left
anchors.right: buttons.left anchors.right: buttons.left
@ -96,8 +96,15 @@ Drawer {
hoverEnabled: false // Disable hover events on the TextArea hoverEnabled: false // Disable hover events on the TextArea
selectByMouse: false // Disable text selection in the TextArea selectByMouse: false // Disable text selection in the TextArea
font.pixelSize: theme.fontSizeLarger font.pixelSize: theme.fontSizeLarger
text: name text: readOnly ? metrics.elidedText : name
horizontalAlignment: TextInput.AlignLeft horizontalAlignment: TextInput.AlignLeft
TextMetrics {
id: metrics
font: chatName.font
text: name
elide: Text.ElideRight
elideWidth: chatName.width - 25
}
background: Rectangle { background: Rectangle {
color: "transparent" color: "transparent"
} }
@ -111,6 +118,7 @@ Drawer {
LLM.chatListModel.get(index).name = chatName.text LLM.chatListModel.get(index).name = chatName.text
chatName.focus = false chatName.focus = false
chatName.readOnly = true chatName.readOnly = true
chatName.selectByMouse = false
} }
TapHandler { TapHandler {
onTapped: { onTapped: {
@ -139,6 +147,7 @@ Drawer {
onClicked: { onClicked: {
chatName.focus = true chatName.focus = true
chatName.readOnly = false chatName.readOnly = false
chatName.selectByMouse = true
} }
} }
Button { Button {