llmodel: change tokenToString to not use string_view (#968)

fixes a definite use-after-free and likely avoids some other
potential ones - std::string will convert to a std::string_view
automatically but as soon as the std::string in question goes out of
scope it is already freed and the string_view is pointing at freed
memory - this is *mostly* fine if its returning a reference to the
tokenizer's internal vocab table but it's, imo, too easy to return a
reference to a dynamically constructed string with this as replit is
doing (and unfortunately needs to do to convert the internal whitespace
replacement symbol back to a space)
pull/981/head
Aaron Miller 1 year ago committed by GitHub
parent 726dcbd43d
commit 88616fde7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -907,7 +907,7 @@ LLModel::Token GPTJ::sampleToken(PromptContext &promptCtx) const
d_ptr->rng);
}
std::string_view GPTJ::tokenToString(Token id) const
std::string GPTJ::tokenToString(Token id) const
{
return d_ptr->vocab.id_to_token[id];
}

@ -29,7 +29,7 @@ private:
protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
Token sampleToken(PromptContext &ctx) const override;
std::string_view tokenToString(Token) const override;
std::string tokenToString(Token) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override;
const std::vector<Token>& endTokens() const override;

@ -177,7 +177,7 @@ std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::
return fres;
}
std::string_view LLamaModel::tokenToString(Token id) const
std::string LLamaModel::tokenToString(Token id) const
{
return llama_token_to_str(d_ptr->ctx, id);
}

@ -28,7 +28,7 @@ private:
protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::string_view tokenToString(Token) const override;
std::string tokenToString(Token) const override;
Token sampleToken(PromptContext& ctx) const override;
bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override;

@ -86,7 +86,7 @@ protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
virtual std::string_view tokenToString(Token) const = 0;
virtual std::string tokenToString(Token) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0;
virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
virtual int32_t contextLength() const = 0;

@ -121,7 +121,7 @@ void LLModel::prompt(const std::string &prompt,
if (id == token) return;
}
const std::string_view str = tokenToString(id);
const std::string str = tokenToString(id);
// Check if the provided str is part of our reverse prompts
bool foundPartialReversePrompt = false;

@ -820,7 +820,7 @@ std::vector<LLModel::Token> MPT::tokenize(PromptContext &, const std::string &st
return ::gpt_tokenize(d_ptr->vocab, str);
}
std::string_view MPT::tokenToString(Token id) const
std::string MPT::tokenToString(Token id) const
{
return d_ptr->vocab.id_to_token[id];
}

@ -28,7 +28,7 @@ private:
protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::string_view tokenToString(Token) const override;
std::string tokenToString(Token) const override;
Token sampleToken(PromptContext &ctx) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override;

@ -910,7 +910,7 @@ std::vector<LLModel::Token> Replit::tokenize(PromptContext &, const std::string
return replit_tokenizer_tokenize(d_ptr->vocab, str);
}
std::string_view Replit::tokenToString(LLModel::Token id) const
std::string Replit::tokenToString(LLModel::Token id) const
{
return replit_tokenizer_detokenize(d_ptr->vocab, {id});
}

@ -30,7 +30,7 @@ private:
protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::string_view tokenToString(Token) const override;
std::string tokenToString(Token) const override;
Token sampleToken(PromptContext &ctx) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override;

@ -39,7 +39,7 @@ protected:
// them as they are only called from the default implementation of 'prompt' which we override and
// completely replace
std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
std::string_view tokenToString(Token) const override { return std::string_view(); }
std::string tokenToString(Token) const override { return std::string(); }
Token sampleToken(PromptContext &ctx) const override { return -1; }
bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
int32_t contextLength() const override { return -1; }

Loading…
Cancel
Save