Remove duplicated state tracking for chatgpt.

The m_isChatGPT member held information that m_modelType already records: loadModel() now keeps only a local isChatGPT for resolving the model file, and every later check compares m_modelType against LLModelType::CHATGPT_.

Adam Treat 2023-06-20 14:02:46 -04:00 committed by AT
parent 7d2ce06029
commit 84ec4311e9
2 changed files with 6 additions and 8 deletions


@@ -96,7 +96,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_stopGenerating(false)
     , m_timer(nullptr)
     , m_isServer(isServer)
-    , m_isChatGPT(false)
 {
     moveToThread(&m_llmThread);
     connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
@@ -158,8 +157,8 @@ bool ChatLLM::loadModel(const QString &modelName)
     if (isModelLoaded() && m_modelName == modelName)
         return true;

-    m_isChatGPT = modelName.startsWith("chatgpt-");
-    QString filePath = modelFilePath(modelName, m_isChatGPT);
+    bool isChatGPT = modelName.startsWith("chatgpt-");
+    QString filePath = modelFilePath(modelName, isChatGPT);
     QFileInfo fileInfo(filePath);

     // We have a live model, but it isn't the one we want
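With the member flag gone, the local isChatGPT above only steers file resolution; the durable record of what kind of model is loaded lives in m_modelType. A minimal sketch of the assignment this relies on, assuming the construction site inside loadModel() looks roughly like this (only LLModelType::CHATGPT_ is confirmed by the hunks in this commit; the other names here are assumptions):

    // Sketch, not part of this diff: loadModel() records the model kind once.
    if (isChatGPT) {
        m_modelInfo.model = new ChatGPT();     // hypothetical construction call
        m_modelType = LLModelType::CHATGPT_;   // the value the later hunks test
    } else {
        // a local model would record its own enum value instead, e.g.:
        m_modelType = LLModelType::LLAMA_;     // assumed enum member
    }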
@@ -218,7 +217,7 @@ bool ChatLLM::loadModel(const QString &modelName)
     m_modelInfo.fileInfo = fileInfo;

     if (fileInfo.exists()) {
-        if (m_isChatGPT) {
+        if (isChatGPT) {
             QString apiKey;
             QString chatGPTModel = fileInfo.completeBaseName().remove(0, 8); // remove the chatgpt- prefix
             {
@@ -308,7 +307,7 @@ void ChatLLM::regenerateResponse()
 {
     // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning
     // of n_past is of the number of prompt/response pairs, rather than for total tokens.
-    if (m_isChatGPT)
+    if (m_modelType == LLModelType::CHATGPT_)
         m_ctx.n_past -= 1;
     else
         m_ctx.n_past -= m_promptResponseTokens;
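The comment in this hunk carries the reasoning: for ChatGPT, n_past counts prompt/response pairs, while for local models it counts tokens. A worked example of what one regeneration rolls back under each convention (the numbers are illustrative only):

    // Suppose the chat holds 3 prompt/response pairs totalling 120 tokens,
    // and the last exchange consumed 45 (m_promptResponseTokens == 45).
    if (m_modelType == LLModelType::CHATGPT_)
        m_ctx.n_past -= 1;                      // pairs: 3 -> 2
    else
        m_ctx.n_past -= m_promptResponseTokens; // tokens: 120 -> 75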
@@ -672,7 +671,7 @@ void ChatLLM::saveState()
     if (!isModelLoaded())
         return;

-    if (m_isChatGPT) {
+    if (m_modelType == LLModelType::CHATGPT_) {
         m_state.clear();
         QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
         stream.setVersion(QDataStream::Qt_6_5);
@@ -694,7 +693,7 @@ void ChatLLM::restoreState()
     if (!isModelLoaded() || m_state.isEmpty())
         return;

-    if (m_isChatGPT) {
+    if (m_modelType == LLModelType::CHATGPT_) {
         QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
         stream.setVersion(QDataStream::Qt_6_5);
         ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model);
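An API-backed model has no local KV cache to snapshot, so saveState() and restoreState() serialize the conversation itself through QDataStream. A hedged sketch of the round trip these two hunks imply, assuming ChatGPT exposes context()/setContext() accessors (they are not shown in this diff):

    // Sketch: saving the ChatGPT conversation into m_state
    m_state.clear();
    QDataStream out(&m_state, QIODeviceBase::WriteOnly);
    out.setVersion(QDataStream::Qt_6_5);
    out << chatGPT->context();                  // assumed accessor

    // Sketch: restoring it again
    QDataStream in(&m_state, QIODeviceBase::ReadOnly);
    in.setVersion(QDataStream::Qt_6_5);
    QList<QString> context;
    in >> context;
    chatGPT->setContext(context);               // assumed setter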


@@ -158,7 +158,6 @@ private:
     LLModelInfo m_modelInfo;
     LLModelType m_modelType;
     QString m_modelName;
-    bool m_isChatGPT;

     // The following are only accessed by this thread
     QString m_defaultModel;
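The final hunk, in the class declaration, deletes the now-redundant member and leaves m_modelType as the single source of truth for the model kind. If the comparison were needed in many more places, a small helper could spell it once; a hypothetical convenience, not part of this commit:

    // Hypothetical helper, not in this commit:
    bool isChatGPTModel() const { return m_modelType == LLModelType::CHATGPT_; }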