From 88616fde7fb577beed028d35d21e3f37c477dcb0 Mon Sep 17 00:00:00 2001
From: Aaron Miller
Date: Tue, 13 Jun 2023 04:14:02 -0700
Subject: [PATCH] llmodel: change tokenToString to not use string_view (#968)

fixes a definite use-after-free and likely avoids some other potential
ones.

std::string will convert to a std::string_view automatically, but as
soon as the std::string in question goes out of scope it is already
freed and the string_view is pointing at freed memory. this is *mostly*
fine if it's returning a reference to the tokenizer's internal vocab
table, but it's, imo, too easy to return a reference to a dynamically
constructed string with this, as replit is doing (and unfortunately
needs to do, to convert the internal whitespace replacement symbol back
to a space).
---
 gpt4all-backend/gptj.cpp           | 2 +-
 gpt4all-backend/gptj_impl.h        | 2 +-
 gpt4all-backend/llamamodel.cpp     | 2 +-
 gpt4all-backend/llamamodel_impl.h  | 2 +-
 gpt4all-backend/llmodel.h          | 2 +-
 gpt4all-backend/llmodel_shared.cpp | 2 +-
 gpt4all-backend/mpt.cpp            | 2 +-
 gpt4all-backend/mpt_impl.h         | 2 +-
 gpt4all-backend/replit.cpp         | 2 +-
 gpt4all-backend/replit_impl.h      | 2 +-
 gpt4all-chat/chatgpt.h             | 2 +-
 11 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp
index 87769219..457bce7e 100644
--- a/gpt4all-backend/gptj.cpp
+++ b/gpt4all-backend/gptj.cpp
@@ -907,7 +907,7 @@ LLModel::Token GPTJ::sampleToken(PromptContext &promptCtx) const
         d_ptr->rng);
 }
 
-std::string_view GPTJ::tokenToString(Token id) const
+std::string GPTJ::tokenToString(Token id) const
 {
     return d_ptr->vocab.id_to_token[id];
 }
diff --git a/gpt4all-backend/gptj_impl.h b/gpt4all-backend/gptj_impl.h
index 3e82a79f..4dda3ad5 100644
--- a/gpt4all-backend/gptj_impl.h
+++ b/gpt4all-backend/gptj_impl.h
@@ -29,7 +29,7 @@ private:
 protected:
     std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
     Token sampleToken(PromptContext &ctx) const override;
-    std::string_view tokenToString(Token) const override;
+    std::string tokenToString(Token) const override;
     bool evalTokens(PromptContext &ctx, const std::vector<Token> &tokens) const override;
     int32_t contextLength() const override;
     const std::vector<Token>& endTokens() const override;
diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index a4d6b90b..4cdfd359 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -177,7 +177,7 @@ std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::
     return fres;
 }
 
-std::string_view LLamaModel::tokenToString(Token id) const
+std::string LLamaModel::tokenToString(Token id) const
 {
     return llama_token_to_str(d_ptr->ctx, id);
 }
diff --git a/gpt4all-backend/llamamodel_impl.h b/gpt4all-backend/llamamodel_impl.h
index c1cc1bd6..10404576 100644
--- a/gpt4all-backend/llamamodel_impl.h
+++ b/gpt4all-backend/llamamodel_impl.h
@@ -28,7 +28,7 @@ private:
 
 protected:
     std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
-    std::string_view tokenToString(Token) const override;
+    std::string tokenToString(Token) const override;
     Token sampleToken(PromptContext& ctx) const override;
     bool evalTokens(PromptContext& ctx, const std::vector<Token> &tokens) const override;
     int32_t contextLength() const override;
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index 7d06a901..ecd7d05b 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -86,7 +86,7 @@ protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
     virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
-    virtual std::string_view tokenToString(Token) const = 0;
+    virtual std::string tokenToString(Token) const = 0;
     virtual Token sampleToken(PromptContext &ctx) const = 0;
     virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<Token>& /*tokens*/) const = 0;
     virtual int32_t contextLength() const = 0;
diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp
index cb50c734..dfc07b76 100644
--- a/gpt4all-backend/llmodel_shared.cpp
+++ b/gpt4all-backend/llmodel_shared.cpp
@@ -121,7 +121,7 @@ void LLModel::prompt(const std::string &prompt,
             if (id == token) return;
         }
 
-        const std::string_view str = tokenToString(id);
+        const std::string str = tokenToString(id);
 
         // Check if the provided str is part of our reverse prompts
         bool foundPartialReversePrompt = false;
diff --git a/gpt4all-backend/mpt.cpp b/gpt4all-backend/mpt.cpp
index 87bc23ef..ec33c20c 100644
--- a/gpt4all-backend/mpt.cpp
+++ b/gpt4all-backend/mpt.cpp
@@ -820,7 +820,7 @@ std::vector<LLModel::Token> MPT::tokenize(PromptContext &, const std::string &st
     return ::gpt_tokenize(d_ptr->vocab, str);
 }
 
-std::string_view MPT::tokenToString(Token id) const
+std::string MPT::tokenToString(Token id) const
 {
     return d_ptr->vocab.id_to_token[id];
 }
diff --git a/gpt4all-backend/mpt_impl.h b/gpt4all-backend/mpt_impl.h
index ff03995c..ee0998c7 100644
--- a/gpt4all-backend/mpt_impl.h
+++ b/gpt4all-backend/mpt_impl.h
@@ -28,7 +28,7 @@ private:
 
 protected:
     std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
-    std::string_view tokenToString(Token) const override;
+    std::string tokenToString(Token) const override;
     Token sampleToken(PromptContext &ctx) const override;
     bool evalTokens(PromptContext &ctx, const std::vector<Token> &tokens) const override;
     int32_t contextLength() const override;
diff --git a/gpt4all-backend/replit.cpp b/gpt4all-backend/replit.cpp
index 79c8cd2c..b978aff3 100644
--- a/gpt4all-backend/replit.cpp
+++ b/gpt4all-backend/replit.cpp
@@ -910,7 +910,7 @@ std::vector<LLModel::Token> Replit::tokenize(PromptContext &, const std::string
     return replit_tokenizer_tokenize(d_ptr->vocab, str);
 }
 
-std::string_view Replit::tokenToString(LLModel::Token id) const
+std::string Replit::tokenToString(LLModel::Token id) const
 {
     return replit_tokenizer_detokenize(d_ptr->vocab, {id});
 }
diff --git a/gpt4all-backend/replit_impl.h b/gpt4all-backend/replit_impl.h
index 28b74933..0ff22aa4 100644
--- a/gpt4all-backend/replit_impl.h
+++ b/gpt4all-backend/replit_impl.h
@@ -30,7 +30,7 @@ private:
 
 protected:
     std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
-    std::string_view tokenToString(Token) const override;
+    std::string tokenToString(Token) const override;
     Token sampleToken(PromptContext &ctx) const override;
     bool evalTokens(PromptContext &ctx, const std::vector<Token> &tokens) const override;
     int32_t contextLength() const override;
diff --git a/gpt4all-chat/chatgpt.h b/gpt4all-chat/chatgpt.h
index 934bbbfb..4c8a123d 100644
--- a/gpt4all-chat/chatgpt.h
+++ b/gpt4all-chat/chatgpt.h
@@ -39,7 +39,7 @@ protected:
     // them as they are only called from the default implementation of 'prompt' which we override and
     // completely replace
     std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
-    std::string_view tokenToString(Token) const override { return std::string_view(); }
+    std::string tokenToString(Token) const override { return std::string(); }
     Token sampleToken(PromptContext &ctx) const override { return -1; }
     bool evalTokens(PromptContext &/*ctx*/, const std::vector<Token>& /*tokens*/) const override { return false; }
     int32_t contextLength() const override { return -1; }
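
Note: for readers unfamiliar with the bug class this patch fixes, below is a minimal
standalone C++ sketch (not from the repository; the function names and "tok_" token
text are made up) of why returning std::string_view here was a use-after-free, and
why returning std::string by value is safe:

    #include <iostream>
    #include <string>
    #include <string_view>

    // BROKEN: 's' is a local std::string; it converts implicitly to
    // std::string_view in the return statement and is then destroyed,
    // so the caller receives a view into freed memory.
    std::string_view tokenToStringDangling(int id) {
        std::string s = "tok_" + std::to_string(id); // dynamically constructed
        return s; // implicit string -> string_view; s is freed right here
    }

    // SAFE: returning std::string by value hands the caller its own copy,
    // whether the text came from a persistent vocab table or was built
    // on the fly (as Replit's detokenizer must do).
    std::string tokenToStringSafe(int id) {
        return "tok_" + std::to_string(id);
    }

    int main() {
        std::string_view bad = tokenToStringSafe(7); // careful: a view of a
        (void)bad;          // destroyed temporary -- reading it would be UB
        std::cout << tokenToStringSafe(7) << '\n';   // prints "tok_7"
        return 0;
    }

A string_view return is only safe when it refers to storage that outlives the call,
e.g. an entry in the tokenizer's internal vocab table; since replit_tokenizer_detokenize
must build a fresh string to map the whitespace replacement symbol back to a space,
by-value std::string is the safer contract for every backend.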