From 70e6b451232c8b1a0b23430a111bc8041c19f217 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Wed, 26 Apr 2023 19:08:37 -0400
Subject: [PATCH] Don't crash when prompt is too large.

---
 llm.cpp                | 7 +++++++
 llmodel/gptj.cpp       | 7 +++++++
 llmodel/llamamodel.cpp | 4 +++-
 3 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/llm.cpp b/llm.cpp
index 228e29da..2e6c56b4 100644
--- a/llm.cpp
+++ b/llm.cpp
@@ -252,6 +252,13 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
     fflush(stdout);
 #endif
 
+    // check for error
+    if (token < 0) {
+        m_response.append(response);
+        emit responseChanged();
+        return false;
+    }
+
     // Save the token to our prompt ctxt
     if (s_ctx.tokens.size() == s_ctx.n_ctx)
         s_ctx.tokens.erase(s_ctx.tokens.begin());
diff --git a/llmodel/gptj.cpp b/llmodel/gptj.cpp
index 6f353d75..eef6d03a 100644
--- a/llmodel/gptj.cpp
+++ b/llmodel/gptj.cpp
@@ -707,6 +707,13 @@ void GPTJ::prompt(const std::string &prompt,
     // save the context size
     promptCtx.n_ctx = d_ptr->model.hparams.n_ctx;
 
+    if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
+        response(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
+        std::cerr << "GPT-J ERROR: The prompt is " << embd_inp.size() <<
+            " tokens and the context window is " << promptCtx.n_ctx << "!\n";
+        return;
+    }
+
     promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
     promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
diff --git a/llmodel/llamamodel.cpp b/llmodel/llamamodel.cpp
index 06e4aced..61318592 100644
--- a/llmodel/llamamodel.cpp
+++ b/llmodel/llamamodel.cpp
@@ -102,7 +102,9 @@ void LLamaModel::prompt(const std::string &prompt,
     promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
 
     if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
-        std::cerr << "LLAMA ERROR: prompt is too long\n";
+        response(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
+        std::cerr << "LLAMA ERROR: The prompt is " << embd_inp.size() <<
+            " tokens and the context window is " << promptCtx.n_ctx << "!\n";
         return;
     }