diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index 298010d8..de9ee0a6 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -88,6 +88,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_isRecalc(false)
     , m_chat(parent)
     , m_isServer(isServer)
+    , m_isChatGPT(false)
 {
     moveToThread(&m_llmThread);
     connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
@@ -138,8 +139,8 @@ bool ChatLLM::loadModel(const QString &modelName)
     if (isModelLoaded() && m_modelName == modelName)
         return true;
 
-    const bool isChatGPT = modelName.startsWith("chatgpt-");
-    QString filePath = modelFilePath(modelName, isChatGPT);
+    m_isChatGPT = modelName.startsWith("chatgpt-");
+    QString filePath = modelFilePath(modelName, m_isChatGPT);
     QFileInfo fileInfo(filePath);
 
     // We have a live model, but it isn't the one we want
@@ -198,7 +199,7 @@ bool ChatLLM::loadModel(const QString &modelName)
     m_modelInfo.fileInfo = fileInfo;
 
     if (fileInfo.exists()) {
-        if (isChatGPT) {
+        if (m_isChatGPT) {
             QString apiKey;
             QString chatGPTModel = fileInfo.completeBaseName().remove(0, 8); // remove the chatgpt- prefix
             {
@@ -260,7 +261,7 @@ bool ChatLLM::loadModel(const QString &modelName)
 
     if (m_modelInfo.model) {
         QString basename = fileInfo.completeBaseName();
-        setModelName(isChatGPT ? basename : basename.remove(0, 5)); // remove the ggml- prefix
+        setModelName(m_isChatGPT ? basename : basename.remove(0, 5)); // remove the ggml- prefix
     }
 
     return m_modelInfo.model;
@@ -273,7 +274,12 @@ bool ChatLLM::isModelLoaded() const
 
 void ChatLLM::regenerateResponse()
 {
-    m_ctx.n_past -= m_promptResponseTokens;
+    // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT,
+    // n_past counts prompt/response pairs rather than total tokens.
+    if (m_isChatGPT)
+        m_ctx.n_past -= 1;
+    else
+        m_ctx.n_past -= m_promptResponseTokens;
     m_ctx.n_past = std::max(0, m_ctx.n_past);
     // FIXME: This does not seem to be needed in my testing and llama models don't do it. Remove?
     m_ctx.logits.erase(m_ctx.logits.end() -= m_responseLogits, m_ctx.logits.end());
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index fc99d6ce..d5fb1f3b 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -115,6 +115,7 @@ protected:
     std::atomic<bool> m_shouldBeLoaded;
     bool m_isRecalc;
     bool m_isServer;
+    bool m_isChatGPT;
 };
 
 #endif // CHATLLM_H
diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml
index 7e81a6c7..92a18305 100644
--- a/gpt4all-chat/main.qml
+++ b/gpt4all-chat/main.qml
@@ -579,6 +579,16 @@ Window {
         anchors.fill: parent
         color: currentChat.isServer ? theme.backgroundDark : theme.backgroundLighter
 
+        Image {
+            visible: currentChat.isServer || currentChat.modelName.startsWith("chatgpt-")
+            anchors.fill: parent
+            sourceSize.width: 1024
+            sourceSize.height: 1024
+            fillMode: Image.PreserveAspectFit
+            opacity: 0.15
+            source: "qrc:/gpt4all/icons/network.svg"
+        }
+
         ListView {
             id: listView
             anchors.fill: parent
@@ -599,6 +609,7 @@ Window {
                         cursorVisible: currentResponse ? currentChat.responseInProgress : false
                         cursorPosition: text.length
                         background: Rectangle {
+                            opacity: 0.3
                             color: name === qsTr("Response: ") ? (currentChat.isServer ? theme.backgroundDarkest : theme.backgroundLighter)
                                 : (currentChat.isServer ? theme.backgroundDark : theme.backgroundLight)
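
For context, a minimal self-contained sketch of the rewind behavior introduced in regenerateResponse() above; the struct and function names below are hypothetical stand-ins for illustration only and are not part of the patch:

    #include <algorithm>

    // Stand-in for the n_past bookkeeping carried in the chat's prompt context.
    struct PromptContext { int n_past = 0; };

    // For the ChatGPT backend n_past counts prompt/response pairs, so a
    // regenerate rewinds exactly one pair; for local models it counts tokens,
    // so the full token count of the last prompt/response is subtracted.
    void rewindForRegenerate(PromptContext &ctx, bool isChatGPT, int promptResponseTokens)
    {
        ctx.n_past -= isChatGPT ? 1 : promptResponseTokens;
        ctx.n_past = std::max(0, ctx.n_past); // never rewind past the start of the context
    }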