diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index 8e3e5ea2..ce7a6f57 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -9,6 +9,8 @@
 #include
 #include
 
+#define LLMODEL_MAX_PROMPT_BATCH 128
+
 class Dlhandle;
 
 class LLModel {
diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp
index dfc07b76..cd4ace04 100644
--- a/gpt4all-backend/llmodel_shared.cpp
+++ b/gpt4all-backend/llmodel_shared.cpp
@@ -52,6 +52,7 @@ void LLModel::prompt(const std::string &prompt,
 
     promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
     promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
+    promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
 
     // process the prompt in batches
     size_t i = 0;
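
Note: the new clamp sits just above the existing prompt loop in llmodel_shared.cpp, which already walks embd_inp in n_batch-sized chunks, so capping n_batch here caps the size of every single evaluation call. Below is a minimal standalone sketch of that effect; PromptContext is reduced to the two fields the sketch needs, the loop only mirrors the shape of the batching loop that follows the second hunk, and the printf stands in for the model evaluation, so none of it is the actual gpt4all implementation:

    // Sketch only: simplified stand-in types, not the actual gpt4all code.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    #define LLMODEL_MAX_PROMPT_BATCH 128

    struct PromptContext {     // reduced stand-in for LLModel::PromptContext
        int32_t n_batch = 512; // callers could previously request any size here
        int32_t n_past  = 0;
    };

    int main() {
        std::vector<int32_t> embd_inp(1000, 42); // stands in for the tokenized prompt
        PromptContext promptCtx;

        // The change under review: cap the caller-supplied batch size.
        promptCtx.n_batch = std::min(promptCtx.n_batch, (int32_t) LLMODEL_MAX_PROMPT_BATCH);

        // Consume the prompt in n_batch-sized chunks, as the loop after the
        // second hunk does; no chunk now exceeds 128 tokens.
        size_t i = 0;
        while (i < embd_inp.size()) {
            size_t batch_end = std::min(i + (size_t) promptCtx.n_batch, embd_inp.size());
            std::vector<int32_t> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
            // ... the model would evaluate `batch` here in the real code ...
            std::printf("eval %zu tokens at n_past=%d\n", batch.size(), promptCtx.n_past);
            promptCtx.n_past += (int32_t) batch.size();
            i = batch_end;
        }
        return 0;
    }

With the cap in place, a caller that sets n_batch to 512 on a 1000-token prompt gets eight evaluations of at most 128 tokens instead of two evaluations of 512, bounding the memory and latency of any single pass.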