llmodel: change tokenToString to not use string_view (#968)

fixes a definite use-after-free and likely avoids some other
potential ones - std::string will convert to a std::string_view
automatically but as soon as the std::string in question goes out of
scope it is already freed and the string_view is pointing at freed
memory - this is *mostly* fine if its returning a reference to the
tokenizer's internal vocab table but it's, imo, too easy to return a
reference to a dynamically constructed string with this as replit is
doing (and unfortunately needs to do to convert the internal whitespace
replacement symbol back to a space)
pull/981/head
Aaron Miller 1 year ago committed by GitHub
parent 726dcbd43d
commit 88616fde7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -907,7 +907,7 @@ LLModel::Token GPTJ::sampleToken(PromptContext &promptCtx) const
d_ptr->rng); d_ptr->rng);
} }
std::string_view GPTJ::tokenToString(Token id) const std::string GPTJ::tokenToString(Token id) const
{ {
return d_ptr->vocab.id_to_token[id]; return d_ptr->vocab.id_to_token[id];
} }

@ -29,7 +29,7 @@ private:
protected: protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override; std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
Token sampleToken(PromptContext &ctx) const override; Token sampleToken(PromptContext &ctx) const override;
std::string_view tokenToString(Token) const override; std::string tokenToString(Token) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override; bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override; int32_t contextLength() const override;
const std::vector<Token>& endTokens() const override; const std::vector<Token>& endTokens() const override;

@ -177,7 +177,7 @@ std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::
return fres; return fres;
} }
std::string_view LLamaModel::tokenToString(Token id) const std::string LLamaModel::tokenToString(Token id) const
{ {
return llama_token_to_str(d_ptr->ctx, id); return llama_token_to_str(d_ptr->ctx, id);
} }

@ -28,7 +28,7 @@ private:
protected: protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override; std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::string_view tokenToString(Token) const override; std::string tokenToString(Token) const override;
Token sampleToken(PromptContext& ctx) const override; Token sampleToken(PromptContext& ctx) const override;
bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override; bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override; int32_t contextLength() const override;

@ -86,7 +86,7 @@ protected:
// These are pure virtual because subclasses need to implement as the default implementation of // These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions // 'prompt' above calls these functions
virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0; virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
virtual std::string_view tokenToString(Token) const = 0; virtual std::string tokenToString(Token) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0; virtual Token sampleToken(PromptContext &ctx) const = 0;
virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0; virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
virtual int32_t contextLength() const = 0; virtual int32_t contextLength() const = 0;

@ -121,7 +121,7 @@ void LLModel::prompt(const std::string &prompt,
if (id == token) return; if (id == token) return;
} }
const std::string_view str = tokenToString(id); const std::string str = tokenToString(id);
// Check if the provided str is part of our reverse prompts // Check if the provided str is part of our reverse prompts
bool foundPartialReversePrompt = false; bool foundPartialReversePrompt = false;

@ -820,7 +820,7 @@ std::vector<LLModel::Token> MPT::tokenize(PromptContext &, const std::string &st
return ::gpt_tokenize(d_ptr->vocab, str); return ::gpt_tokenize(d_ptr->vocab, str);
} }
std::string_view MPT::tokenToString(Token id) const std::string MPT::tokenToString(Token id) const
{ {
return d_ptr->vocab.id_to_token[id]; return d_ptr->vocab.id_to_token[id];
} }

@ -28,7 +28,7 @@ private:
protected: protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override; std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::string_view tokenToString(Token) const override; std::string tokenToString(Token) const override;
Token sampleToken(PromptContext &ctx) const override; Token sampleToken(PromptContext &ctx) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override; bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override; int32_t contextLength() const override;

@ -910,7 +910,7 @@ std::vector<LLModel::Token> Replit::tokenize(PromptContext &, const std::string
return replit_tokenizer_tokenize(d_ptr->vocab, str); return replit_tokenizer_tokenize(d_ptr->vocab, str);
} }
std::string_view Replit::tokenToString(LLModel::Token id) const std::string Replit::tokenToString(LLModel::Token id) const
{ {
return replit_tokenizer_detokenize(d_ptr->vocab, {id}); return replit_tokenizer_detokenize(d_ptr->vocab, {id});
} }

@ -30,7 +30,7 @@ private:
protected: protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override; std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::string_view tokenToString(Token) const override; std::string tokenToString(Token) const override;
Token sampleToken(PromptContext &ctx) const override; Token sampleToken(PromptContext &ctx) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override; bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override; int32_t contextLength() const override;

@ -39,7 +39,7 @@ protected:
// them as they are only called from the default implementation of 'prompt' which we override and // them as they are only called from the default implementation of 'prompt' which we override and
// completely replace // completely replace
std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); } std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
std::string_view tokenToString(Token) const override { return std::string_view(); } std::string tokenToString(Token) const override { return std::string(); }
Token sampleToken(PromptContext &ctx) const override { return -1; } Token sampleToken(PromptContext &ctx) const override { return -1; }
bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; } bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
int32_t contextLength() const override { return -1; } int32_t contextLength() const override { return -1; }

Loading…
Cancel
Save