mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-02 09:40:42 +00:00
Move the saving of the tokens to the impl and not the callbacks' responsibility.
This commit is contained in:
parent 9a65f73392
commit a3253c4ab1
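The diff below centralizes one piece of bookkeeping: every prompt and response token is appended to the prompt context, and the oldest token is evicted once the window is full. Previously the callbacks in llm.cpp did this; after this commit the model implementations (GPTJ::prompt and LLamaModel::prompt) do it themselves before invoking the callbacks. As a minimal standalone sketch of the pattern being moved (PromptContext here is a stand-in that models only the two fields visible in the diff, and the 2048 capacity is an assumed placeholder, not taken from the commit):

#include <cstdint>
#include <vector>

// Stand-in for the project's prompt context; only the fields used by the
// token bookkeeping in this diff are modeled here.
struct PromptContext {
    std::vector<int32_t> tokens; // token IDs currently inside the context window
    size_t n_ctx = 2048;         // window capacity (assumed placeholder value)
};

// Append a token, evicting the oldest one once the window is full.
// This is the pattern the commit moves into the model implementations.
inline void saveToken(PromptContext &ctx, int32_t token) {
    if (ctx.tokens.size() == ctx.n_ctx)
        ctx.tokens.erase(ctx.tokens.begin());
    ctx.tokens.push_back(token);
}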
llm.cpp (9 lines changed)
@@ -265,10 +265,6 @@ QList<QString> LLMObject::modelList() const
 bool LLMObject::handlePrompt(int32_t token)
 {
-    if (s_ctx.tokens.size() == s_ctx.n_ctx)
-        s_ctx.tokens.erase(s_ctx.tokens.begin());
-    s_ctx.tokens.push_back(token);
-
     // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
     // the entire context window which we can reset on regenerate prompt
     ++m_promptResponseTokens;
@@ -289,11 +285,6 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
         return false;
     }

-    // Save the token to our prompt ctxt
-    if (s_ctx.tokens.size() == s_ctx.n_ctx)
-        s_ctx.tokens.erase(s_ctx.tokens.begin());
-    s_ctx.tokens.push_back(token);
-
     // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
     // the entire context window which we can reset on regenerate prompt
     ++m_promptResponseTokens;
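With the window maintenance gone, the llm.cpp callbacks keep only the per-response counting shown in the surviving context lines above. Reconstructed from those lines (the Q_UNUSED and the return statement are assumptions; the hunk is truncated before the end of the function), the simplified prompt callback presumably reduces to something like:

bool LLMObject::handlePrompt(int32_t token)
{
    Q_UNUSED(token); // assumed: the token is no longer needed once saving moves to the impl
    // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
    // the entire context window which we can reset on regenerate prompt
    ++m_promptResponseTokens;
    return !m_stopGenerating; // assumed: keep generating until the user stops
}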
@@ -753,9 +753,13 @@ void GPTJ::prompt(const std::string &prompt,
         }

         size_t tokens = batch_end - i;
-        for (size_t t = 0; t < tokens; ++t)
+        for (size_t t = 0; t < tokens; ++t) {
+            if (promptCtx.tokens.size() == promptCtx.n_ctx)
+                promptCtx.tokens.erase(promptCtx.tokens.begin());
+            promptCtx.tokens.push_back(batch.at(t));
             if (!promptCallback(batch.at(t)))
                 return;
+        }
         promptCtx.n_past += batch.size();
         i = batch_end;
     }
@@ -806,7 +810,14 @@ void GPTJ::prompt(const std::string &prompt,
         promptCtx.n_past += 1;
         // display text
         ++totalPredictions;
-        if (id == 50256 /*end of text*/ || !responseCallback(id, d_ptr->vocab.id_to_token[id]))
+
+        if (id == 50256 /*end of text*/)
             goto stop_generating;
+
+        if (promptCtx.tokens.size() == promptCtx.n_ctx)
+            promptCtx.tokens.erase(promptCtx.tokens.begin());
+        promptCtx.tokens.push_back(id);
+        if (!responseCallback(id, d_ptr->vocab.id_to_token[id]))
+            goto stop_generating;
     }
@@ -139,9 +139,13 @@ void LLamaModel::prompt(const std::string &prompt,
         }

         size_t tokens = batch_end - i;
-        for (size_t t = 0; t < tokens; ++t)
+        for (size_t t = 0; t < tokens; ++t) {
+            if (promptCtx.tokens.size() == promptCtx.n_ctx)
+                promptCtx.tokens.erase(promptCtx.tokens.begin());
+            promptCtx.tokens.push_back(batch.at(t));
             if (!promptCallback(batch.at(t)))
                 return;
+        }
         promptCtx.n_past += batch.size();
         i = batch_end;
     }
@@ -174,7 +178,13 @@ void LLamaModel::prompt(const std::string &prompt,
         promptCtx.n_past += 1;
         // display text
         ++totalPredictions;
-        if (id == llama_token_eos() || !responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
+        if (id == llama_token_eos())
             return;
+
+        if (promptCtx.tokens.size() == promptCtx.n_ctx)
+            promptCtx.tokens.erase(promptCtx.tokens.begin());
+        promptCtx.tokens.push_back(id);
+        if (!responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
+            return;
     }
 }
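The GPT-J and LLaMA hunks are deliberately identical: each backend now evicts and appends in place right before calling its callback. Purely as an illustration of that shared step (saveAndNotify is not part of the project, and the generic callback hides the backends' differing response-callback signatures), the duplicated logic could be factored as:

#include <cstdint>
#include <functional>

// Illustrative helper, not part of the commit: push a token into the context
// window (evicting the oldest entry when full) and forward it to a callback.
// Returns false when the callback asks the caller to stop generating.
template <typename Ctx>
bool saveAndNotify(Ctx &promptCtx, int32_t id,
                   const std::function<bool(int32_t)> &notify)
{
    if (promptCtx.tokens.size() == promptCtx.n_ctx)
        promptCtx.tokens.erase(promptCtx.tokens.begin());
    promptCtx.tokens.push_back(id);
    return notify(id);
}

A backend would wrap its own callback, for example: saveAndNotify(promptCtx, id, [&](int32_t t) { return responseCallback(t, llama_token_to_str(d_ptr->ctx, t)); }).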