diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp
index 16a8e88f..87769219 100644
--- a/gpt4all-backend/gptj.cpp
+++ b/gpt4all-backend/gptj.cpp
@@ -890,7 +890,7 @@ size_t GPTJ::restoreState(const uint8_t *src)
     return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src);
 }
 
-std::vector<LLModel::Token> GPTJ::tokenize(const std::string &str) const
+std::vector<LLModel::Token> GPTJ::tokenize(PromptContext &, const std::string &str) const
 {
     return ::gpt_tokenize(d_ptr->vocab, str);
 }
diff --git a/gpt4all-backend/gptj_impl.h b/gpt4all-backend/gptj_impl.h
index 270d65bb..3e82a79f 100644
--- a/gpt4all-backend/gptj_impl.h
+++ b/gpt4all-backend/gptj_impl.h
@@ -27,7 +27,7 @@ private:
     GPTJPrivate *d_ptr;
 
 protected:
-    std::vector<Token> tokenize(const std::string&) const override;
+    std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
     Token sampleToken(PromptContext &ctx) const override;
     std::string_view tokenToString(Token) const override;
     bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index e9822d33..66aacac4 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -90,7 +90,6 @@ struct LLamaPrivate {
     llama_context *ctx = nullptr;
     llama_context_params params;
     int64_t n_threads = 0;
-    bool empty = true;
 };
 
 LLamaModel::LLamaModel()
@@ -163,10 +162,11 @@ size_t LLamaModel::restoreState(const uint8_t *src)
     return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
 }
 
-std::vector<LLModel::Token> LLamaModel::tokenize(const std::string &str) const
+std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str) const
 {
+    const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos());
     std::vector<LLModel::Token> fres(str.size()+4);
-    auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), fres.data(), fres.size(), d_ptr->empty);
+    auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), fres.data(), fres.size(), useBOS);
     fres.resize(fres_len);
     return fres;
 }
@@ -187,7 +187,6 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
 
 bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
 {
-    d_ptr->empty = false;
     return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
 }
diff --git a/gpt4all-backend/llamamodel_impl.h b/gpt4all-backend/llamamodel_impl.h
index a39f6ffb..c1cc1bd6 100644
--- a/gpt4all-backend/llamamodel_impl.h
+++ b/gpt4all-backend/llamamodel_impl.h
@@ -27,7 +27,7 @@ private:
    LLamaPrivate *d_ptr;
 
 protected:
-    std::vector<Token> tokenize(const std::string&) const override;
+    std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
     std::string_view tokenToString(Token) const override;
     Token sampleToken(PromptContext& ctx) const override;
     bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index 634626e2..406f78c6 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -89,7 +89,7 @@ public:
 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
-    virtual std::vector<Token> tokenize(const std::string&) const = 0;
+    virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
     virtual std::string_view tokenToString(Token) const = 0;
     virtual Token sampleToken(PromptContext &ctx) const = 0;
     virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp
index c84adfe5..cb50c734 100644
--- a/gpt4all-backend/llmodel_shared.cpp
+++ b/gpt4all-backend/llmodel_shared.cpp
@@ -38,7 +38,7 @@ void LLModel::prompt(const std::string &prompt,
     }
 
     // tokenize the prompt
-    std::vector<Token> embd_inp = tokenize(prompt);
+    std::vector<Token> embd_inp = tokenize(promptCtx, prompt);
 
     // save the context size
     promptCtx.n_ctx = contextLength();
diff --git a/gpt4all-backend/mpt.cpp b/gpt4all-backend/mpt.cpp
index 018189a5..e0afdbae 100644
--- a/gpt4all-backend/mpt.cpp
+++ b/gpt4all-backend/mpt.cpp
@@ -815,7 +815,7 @@ size_t MPT::restoreState(const uint8_t *src)
     return mpt_set_state_data(d_ptr->model, &d_ptr->rng, src);
 }
 
-std::vector<LLModel::Token> MPT::tokenize(const std::string &str) const
+std::vector<LLModel::Token> MPT::tokenize(PromptContext &, const std::string &str) const
 {
     return ::gpt_tokenize(d_ptr->vocab, str);
 }
diff --git a/gpt4all-backend/mpt_impl.h b/gpt4all-backend/mpt_impl.h
index ec39c92c..ff03995c 100644
--- a/gpt4all-backend/mpt_impl.h
+++ b/gpt4all-backend/mpt_impl.h
@@ -27,7 +27,7 @@ private:
     MPTPrivate *d_ptr;
 
 protected:
-    std::vector<Token> tokenize(const std::string&) const override;
+    std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
     std::string_view tokenToString(Token) const override;
     Token sampleToken(PromptContext &ctx) const override;
     bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
diff --git a/gpt4all-chat/chatgpt.h b/gpt4all-chat/chatgpt.h
index 87abb01d..934bbbfb 100644
--- a/gpt4all-chat/chatgpt.h
+++ b/gpt4all-chat/chatgpt.h
@@ -38,7 +38,7 @@ protected:
     // We have to implement these as they are pure virtual in base class, but we don't actually use
    // them as they are only called from the default implementation of 'prompt' which we override and
     // completely replace
-    std::vector<Token> tokenize(const std::string&) const override { return std::vector<Token>(); }
+    std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
     std::string_view tokenToString(Token) const override { return std::string_view(); }
     Token sampleToken(PromptContext &ctx) const override { return -1; }
     bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
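
The key behavioral change above is in LLamaModel::tokenize: whether to prepend a BOS token is now derived from the PromptContext passed to tokenize, rather than from the stateful d_ptr->empty flag that evalTokens used to clear. The sketch below illustrates that decision in isolation; the PromptContext struct is a simplified stand-in for the real one, and kBosToken is an assumed id (the actual code asks llama_token_bos()), so treat it as an illustration rather than the library's API.

// Minimal sketch of the BOS decision derived from a (simplified) PromptContext:
// prepend BOS only at the very start of a fresh context, and only if the token
// history does not already begin with BOS.
#include <cstdint>
#include <vector>

using Token = int32_t;

struct PromptContext {
    std::vector<Token> tokens; // token history carried across prompt() calls
    int32_t n_past = 0;        // number of tokens already evaluated (KV cache)
};

constexpr Token kBosToken = 1; // assumed BOS id, for illustration only

bool shouldUseBOS(const PromptContext &ctx) {
    return ctx.n_past == 0 &&
           (ctx.tokens.empty() || ctx.tokens.front() != kBosToken);
}

int main() {
    PromptContext fresh;                    // new conversation: BOS gets prepended
    PromptContext resumed;
    resumed.n_past = 4;                     // tokens already evaluated: no BOS
    resumed.tokens = {kBosToken, 5, 6, 7};
    return (shouldUseBOS(fresh) && !shouldUseBOS(resumed)) ? 0 : 1;
}

Deriving the flag from the context is what allows the diff to drop the bool empty member from LLamaPrivate and the corresponding reset in evalTokens: tokenization no longer depends on hidden per-model state.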