Move the saving of the tokens to the impl and not the callbacks responsibility.

This commit is contained in:
Adam Treat 2023-04-27 11:16:51 -04:00
parent 9a65f73392
commit a3253c4ab1
3 changed files with 25 additions and 13 deletions

View File

@ -265,10 +265,6 @@ QList<QString> LLMObject::modelList() const
bool LLMObject::handlePrompt(int32_t token)
{
if (s_ctx.tokens.size() == s_ctx.n_ctx)
s_ctx.tokens.erase(s_ctx.tokens.begin());
s_ctx.tokens.push_back(token);
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
++m_promptResponseTokens;
@ -289,11 +285,6 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
return false;
}
// Save the token to our prompt ctxt
if (s_ctx.tokens.size() == s_ctx.n_ctx)
s_ctx.tokens.erase(s_ctx.tokens.begin());
s_ctx.tokens.push_back(token);
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
++m_promptResponseTokens;

View File

@ -753,9 +753,13 @@ void GPTJ::prompt(const std::string &prompt,
}
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t)
for (size_t t = 0; t < tokens; ++t) {
if (promptCtx.tokens.size() == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(batch.at(t));
if (!promptCallback(batch.at(t)))
return;
}
promptCtx.n_past += batch.size();
i = batch_end;
}
@ -806,7 +810,14 @@ void GPTJ::prompt(const std::string &prompt,
promptCtx.n_past += 1;
// display text
++totalPredictions;
if (id == 50256 /*end of text*/ || !responseCallback(id, d_ptr->vocab.id_to_token[id]))
if (id == 50256 /*end of text*/)
goto stop_generating;
if (promptCtx.tokens.size() == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(id);
if (!responseCallback(id, d_ptr->vocab.id_to_token[id]))
goto stop_generating;
}

View File

@ -139,9 +139,13 @@ void LLamaModel::prompt(const std::string &prompt,
}
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t)
for (size_t t = 0; t < tokens; ++t) {
if (promptCtx.tokens.size() == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(batch.at(t));
if (!promptCallback(batch.at(t)))
return;
}
promptCtx.n_past += batch.size();
i = batch_end;
}
@ -174,7 +178,13 @@ void LLamaModel::prompt(const std::string &prompt,
promptCtx.n_past += 1;
// display text
++totalPredictions;
if (id == llama_token_eos() || !responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
if (id == llama_token_eos())
return;
if (promptCtx.tokens.size() == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(id);
if (!responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
return;
}
}