Explicitly clear the kv cache each time we eval tokens to match n_past. (#1808)

AT 2024-01-03 13:06:08 -06:00 committed by GitHub
parent 2d566710e5
commit 96cee4f9ac


@@ -298,6 +298,8 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
 bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
 {
+    llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1);
+
     llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
     batch.n_tokens = tokens.size();
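
For context, the pattern this commit introduces is to remove any cached KV entries at positions >= ctx.n_past before decoding a new batch, so that if the caller has rewound n_past, stale cache state beyond that point cannot leak into the next evaluation. Below is a minimal, self-contained sketch of that pattern against the llama.cpp C API as it stood around early 2024 (llama_kv_cache_seq_rm, llama_batch_init, llama_decode); the eval_tokens function, its signature, and the batch-filling details are illustrative, not copied from the gpt4all source.

#include <cstdint>
#include <vector>
#include "llama.h"

// Evaluate `tokens` as a continuation starting at logical position `n_past`
// in sequence 0. Returns true on success.
static bool eval_tokens(llama_context *ctx, const std::vector<llama_token> &tokens, int32_t n_past)
{
    // Explicitly drop KV cache cells of sequence 0 in positions [n_past, end).
    // p1 == -1 means "to the end of the sequence", so the cache always matches
    // n_past even if a longer prompt was evaluated previously.
    llama_kv_cache_seq_rm(ctx, 0, n_past, -1);

    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
    batch.n_tokens = (int32_t) tokens.size();
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        batch.token[i]     = tokens[i];
        batch.pos[i]       = n_past + i;                // absolute position in the sequence
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;                         // everything belongs to sequence 0
        batch.logits[i]    = (i == batch.n_tokens - 1); // logits only for the last token
    }

    const bool ok = llama_decode(ctx, batch) == 0;
    llama_batch_free(batch);
    return ok;
}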