From 98aedd21736b46169a3627d640d31743c1e4ae9b Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Mon, 8 May 2023 12:08:37 -0400 Subject: [PATCH] Match Helly's impl of kv cache. --- llmodel/mpt.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/llmodel/mpt.cpp b/llmodel/mpt.cpp index ffe3ebf0..a336b921 100644 --- a/llmodel/mpt.cpp +++ b/llmodel/mpt.cpp @@ -435,7 +435,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod return true; } - // load the model's weights from a file path bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) { @@ -523,10 +522,14 @@ bool mpt_eval( struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*ggml_element_size(cur)*n_embd)); struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*ggml_element_size(cur)*n_embd)); - // store key and value to memory - if (N >= 1) { + // TODO: qk_ln? (seems to be False in MPT-7B configs) + { + Vcur = ggml_transpose(ctx0, Vcur); + struct ggml_tensor * k = ggml_view_1d(ctx0, model.kv_self.k, N*n_embd, (ggml_element_size(model.kv_self.k)*n_embd)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctx0, model.kv_self.v, N*n_embd, (ggml_element_size(model.kv_self.v)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_2d(ctx0, model.kv_self.v, N, n_embd, + ( n_ctx)*ggml_element_size(model.kv_self.v), + (il*n_ctx)*ggml_element_size(model.kv_self.v)*n_embd + n_past*ggml_element_size(model.kv_self.v)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));