diff --git a/gpt4all-chat/chatgpt.cpp b/gpt4all-chat/chatgpt.cpp
index 0575ee8e..264b3849 100644
--- a/gpt4all-chat/chatgpt.cpp
+++ b/gpt4all-chat/chatgpt.cpp
@@ -86,13 +86,33 @@ void ChatGPT::prompt(const std::string &prompt,
     Q_UNUSED(promptCallback);
     Q_UNUSED(recalculateCallback);
     Q_UNUSED(special);
-    Q_UNUSED(fakeReply); // FIXME(cebtenzzre): I broke ChatGPT
 
     if (!isModelLoaded()) {
         std::cerr << "ChatGPT ERROR: prompt won't work with an unloaded model!\n";
         return;
     }
 
+    if (!promptCtx.n_past) { m_queuedPrompts.clear(); }
+    Q_ASSERT(promptCtx.n_past <= m_context.size());
+    m_context.resize(promptCtx.n_past);
+
+    // FIXME(cebtenzzre): We're assuming people don't try to use %2 with ChatGPT. What would that even mean?
+    m_queuedPrompts << QString::fromStdString(promptTemplate).arg(QString::fromStdString(prompt));
+
+    if (!promptCtx.n_predict && !fakeReply) {
+        return; // response explicitly suppressed, queue prompt for later
+    }
+
+    QString formattedPrompt = m_queuedPrompts.join("");
+    m_queuedPrompts.clear();
+
+    if (fakeReply) {
+        promptCtx.n_past += 1;
+        m_context.append(formattedPrompt);
+        m_context.append(QString::fromStdString(*fakeReply));
+        return;
+    }
+
     // FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
     // an error we need to be able to count the tokens in our prompt. The only way to do this is to use
     // the OpenAI tiktokken library or to implement our own tokenization function that matches precisely
@@ -104,8 +124,9 @@ void ChatGPT::prompt(const std::string &prompt,
     root.insert("temperature", promptCtx.temp);
     root.insert("top_p", promptCtx.top_p);
 
+    // conversation history
     QJsonArray messages;
-    for (int i = 0; i < m_context.count() && i < promptCtx.n_past; ++i) {
+    for (int i = 0; i < m_context.count(); ++i) {
         QJsonObject message;
         message.insert("role", i % 2 == 0 ? "assistant" : "user");
         message.insert("content", m_context.at(i));
@@ -114,7 +135,7 @@ void ChatGPT::prompt(const std::string &prompt,
 
     QJsonObject promptObject;
     promptObject.insert("role", "user");
-    promptObject.insert("content", QString::fromStdString(promptTemplate).arg(QString::fromStdString(prompt)));
+    promptObject.insert("content", formattedPrompt);
     messages.append(promptObject);
     root.insert("messages", messages);
 
@@ -138,7 +159,7 @@ void ChatGPT::prompt(const std::string &prompt,
     workerThread.wait();
 
     promptCtx.n_past += 1;
-    m_context.append(QString::fromStdString(prompt));
+    m_context.append(formattedPrompt);
     m_context.append(worker.currentResponse());
     m_responseCallback = nullptr;
 }
diff --git a/gpt4all-chat/chatgpt.h b/gpt4all-chat/chatgpt.h
index 2656c6f7..07ceba58 100644
--- a/gpt4all-chat/chatgpt.h
+++ b/gpt4all-chat/chatgpt.h
@@ -3,11 +3,14 @@
 
 #include
 
-#include
 #include
 #include
-#include
+#include
+#include
+#include
 #include
+
 #include "../gpt4all-backend/llmodel.h"
 
 class ChatGPT;
@@ -126,6 +129,7 @@ private:
     QString m_modelName;
     QString m_apiKey;
     QList<QString> m_context;
+    QStringList m_queuedPrompts;
 };
 
 #endif // CHATGPT_H
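
For reviewers, here is a minimal QtCore-only sketch of what the patched request construction now sends: the entire stored m_context history plus the joined prompt queue is serialized into an OpenAI-style "messages" array. The buildMessages() helper, the sample queued prompt, and the main() driver below are illustrative assumptions for this note only; they are not part of the patch.

    #include <QDebug>
    #include <QJsonArray>
    #include <QJsonDocument>
    #include <QJsonObject>
    #include <QList>
    #include <QString>
    #include <QStringList>

    // Mirrors the patched loop: every stored turn is sent, alternating roles,
    // and the joined queue is appended as the final "user" turn.
    static QJsonArray buildMessages(const QList<QString> &context, const QString &formattedPrompt)
    {
        QJsonArray messages;
        for (int i = 0; i < context.count(); ++i) {
            QJsonObject message;
            message.insert("role", i % 2 == 0 ? "assistant" : "user");
            message.insert("content", context.at(i));
            messages.append(message);
        }
        QJsonObject promptObject;
        promptObject.insert("role", "user");
        promptObject.insert("content", formattedPrompt);
        messages.append(promptObject);
        return messages;
    }

    int main()
    {
        // Prompts queued by earlier calls with n_predict == 0 are joined verbatim.
        QStringList queuedPrompts;
        queuedPrompts << QStringLiteral("### Human:\nHello\n### Assistant:\n");
        QString formattedPrompt = queuedPrompts.join(QString());

        QList<QString> context; // previously completed prompt/response pairs
        QJsonObject root;
        root.insert("messages", buildMessages(context, formattedPrompt));
        qDebug().noquote() << QJsonDocument(root).toJson(QJsonDocument::Indented);
        return 0;
    }

The point of the queue is visible here: calls that suppress the response (n_predict == 0 without a fakeReply) only accumulate text in m_queuedPrompts, and the next real request folds all of it into a single "user" message rather than dropping it.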