From e94177ee9a26b7c027d064711f1d0a0586e5c7a4 Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Wed, 29 May 2024 10:51:00 -0400
Subject: [PATCH] llamamodel: fix embedding crash for >512 tokens after #2310 (#2383)

Signed-off-by: Jared Van Bortel
---
 gpt4all-backend/llamamodel.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index 35dd559d..e32aa582 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -386,7 +386,8 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
     bool isEmbedding = is_embedding_arch(llama_model_arch(d_ptr->model));
     const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
     if (isEmbedding) {
-        d_ptr->ctx_params.n_batch = n_ctx;
+        d_ptr->ctx_params.n_batch  = n_ctx;
+        d_ptr->ctx_params.n_ubatch = n_ctx;
     } else {
         if (n_ctx > n_ctx_train) {
             std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
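
Background note (not part of the patch): llama.cpp splits each logical batch of n_batch tokens into physical micro-batches of n_ubatch tokens, and the non-causal attention path used by embedding models requires the whole input to fit into a single micro-batch. Because n_ubatch still defaulted to 512, embedding inputs longer than 512 tokens tripped an assertion inside llama_decode; setting n_ubatch to n_ctx alongside n_batch avoids that. The following is a minimal standalone sketch of the equivalent context setup, assuming the llama.cpp C API from around the time of this change; the model file name and the n_ctx value are placeholders, not values taken from the patch.

// Sketch only: set up a llama.cpp context for an embedding model so that
// inputs longer than the default 512-token micro-batch can be decoded.
// The model path and n_ctx below are hypothetical placeholders.
#include <llama.h>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model *model = llama_load_model_from_file("nomic-embed-text-v1.5.f16.gguf", mparams);
    if (!model) return 1;

    const int n_ctx = 2048; // placeholder context size requested by the caller

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx      = n_ctx;
    cparams.n_batch    = n_ctx; // logical batch: how many tokens may be submitted per llama_decode call
    cparams.n_ubatch   = n_ctx; // physical micro-batch: must cover the whole input for
                                // non-causal (embedding) models, otherwise llama_decode asserts
    cparams.embeddings = true;  // extract embeddings rather than generate text

    llama_context *ctx = llama_new_context_with_model(model, cparams);
    if (!ctx) {
        llama_free_model(model);
        return 1;
    }

    // ... tokenize, llama_decode, and read embeddings here ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}

The key point the patch makes is that n_batch and n_ubatch must be raised together: bumping only n_batch (as before this fix) still leaves the 512-token micro-batch limit in place for non-causal models.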