Fix up for newer models on context reset. This stops the model from failing completely after a context reset.

Adam Treat 2023-06-04 19:31:00 -04:00
parent bdba2e8de6
commit 301d2fdbea
9 changed files with 11 additions and 12 deletions
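
The key change is in llamamodel.cpp: whether llama_tokenize() should prepend a BOS token used to be tracked by a sticky "bool empty" flag on LLamaPrivate, cleared on the first evalTokens() call and never set again. After a context reset the flag stayed false, the prompt was tokenized without a leading BOS, and newer llama models failed completely. Passing the PromptContext into tokenize() lets the backend derive the flag from the live context instead. Below is a minimal standalone sketch of the two behaviors, not the project's code; llama_token_bos() is stubbed for illustration, and the useBOS expression is copied from the diff.

#include <cassert>
#include <vector>

static int llama_token_bos() { return 1; } // stub; the real one comes from llama.cpp

struct PromptContext {
    std::vector<int> tokens; // token history kept across prompt() calls
    int n_past = 0;          // number of tokens the model has already evaluated
};

// Old behavior: a per-model flag, cleared on first eval and never reset.
struct OldModel {
    bool empty = true;
    bool useBOS() const { return empty; }
    void eval() { empty = false; }
};

// New behavior: derive the decision from the context itself.
static bool useBOS(const PromptContext &ctx) {
    return ctx.n_past == 0 &&
           (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos());
}

int main() {
    OldModel old;
    PromptContext ctx;
    assert(old.useBOS() && useBOS(ctx)); // first prompt: both prepend BOS

    old.eval();                          // model evaluates the first prompt
    ctx.tokens = {llama_token_bos(), 42};
    ctx.n_past = 2;

    ctx.tokens.clear();                  // context reset: history is cleared
    ctx.n_past = 0;

    assert(!old.useBOS()); // bug: stale flag, BOS is silently dropped
    assert(useBOS(ctx));   // fix: BOS is prepended again
}

The new expression also skips BOS when the retained history already starts with one, so re-prompting an unreset context does not insert a second BOS.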


@@ -890,7 +890,7 @@ size_t GPTJ::restoreState(const uint8_t *src)
     return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src);
 }
 
-std::vector<LLModel::Token> GPTJ::tokenize(const std::string &str) const
+std::vector<LLModel::Token> GPTJ::tokenize(PromptContext &, const std::string &str) const
 {
     return ::gpt_tokenize(d_ptr->vocab, str);
 }


@@ -27,7 +27,7 @@ private:
     GPTJPrivate *d_ptr;
 
 protected:
-    std::vector<Token> tokenize(const std::string&) const override;
+    std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
     Token sampleToken(PromptContext &ctx) const override;
     std::string_view tokenToString(Token) const override;
     bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;


@@ -90,7 +90,6 @@ struct LLamaPrivate {
     llama_context *ctx = nullptr;
     llama_context_params params;
     int64_t n_threads = 0;
-    bool empty = true;
 };
 
 LLamaModel::LLamaModel()
@@ -163,10 +162,11 @@ size_t LLamaModel::restoreState(const uint8_t *src)
     return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
 }
 
-std::vector<LLModel::Token> LLamaModel::tokenize(const std::string &str) const
+std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str) const
 {
+    const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos());
     std::vector<LLModel::Token> fres(str.size()+4);
-    auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), fres.data(), fres.size(), d_ptr->empty);
+    auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), fres.data(), fres.size(), useBOS);
     fres.resize(fres_len);
     return fres;
 }
@@ -187,7 +187,6 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
 }
 
 bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
 {
-    d_ptr->empty = false;
     return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
 }


@@ -27,7 +27,7 @@ private:
     LLamaPrivate *d_ptr;
 
 protected:
-    std::vector<Token> tokenize(const std::string&) const override;
+    std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
     std::string_view tokenToString(Token) const override;
     Token sampleToken(PromptContext& ctx) const override;
     bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;


@@ -89,7 +89,7 @@ public:
 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
-    virtual std::vector<Token> tokenize(const std::string&) const = 0;
+    virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
     virtual std::string_view tokenToString(Token) const = 0;
     virtual Token sampleToken(PromptContext &ctx) const = 0;
     virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
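
The comment above describes a template-method arrangement: the base class's default prompt() drives these four hooks, which is why tokenize() must now receive the PromptContext that prompt() owns. Here is a hedged sketch of that relationship; the class name and the body of prompt() are simplified illustrations, not the project's actual implementation.

#include <cstdint>
#include <string>
#include <string_view>
#include <vector>

class LLModelSketch {
public:
    using Token = int32_t;
    struct PromptContext {
        std::vector<Token> tokens; // token history, survives across calls
        int32_t n_past = 0;        // tokens the model has already evaluated
    };
    virtual ~LLModelSketch() = default;

    // Default prompt flow: tokenize with the live context, evaluate, sample.
    virtual void prompt(const std::string &text, PromptContext &ctx) {
        std::vector<Token> input = tokenize(ctx, text); // context-aware as of this commit
        if (!evalTokens(ctx, input))
            return;
        ctx.n_past += static_cast<int32_t>(input.size());
        Token next = sampleToken(ctx);
        (void)tokenToString(next); // a real loop streams this back to the caller
    }

protected:
    // The hooks each backend (GPT-J, LLaMA, MPT) implements.
    virtual std::vector<Token> tokenize(PromptContext &, const std::string &) const = 0;
    virtual std::string_view tokenToString(Token) const = 0;
    virtual Token sampleToken(PromptContext &ctx) const = 0;
    virtual bool evalTokens(PromptContext &ctx, const std::vector<Token> &tokens) const = 0;
};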


@@ -38,7 +38,7 @@ void LLModel::prompt(const std::string &prompt,
     }
 
     // tokenize the prompt
-    std::vector<Token> embd_inp = tokenize(prompt);
+    std::vector<Token> embd_inp = tokenize(promptCtx, prompt);
 
     // save the context size
     promptCtx.n_ctx = contextLength();


@@ -815,7 +815,7 @@ size_t MPT::restoreState(const uint8_t *src)
     return mpt_set_state_data(d_ptr->model, &d_ptr->rng, src);
 }
 
-std::vector<LLModel::Token> MPT::tokenize(const std::string &str) const
+std::vector<LLModel::Token> MPT::tokenize(PromptContext &, const std::string &str) const
 {
     return ::gpt_tokenize(d_ptr->vocab, str);
 }


@@ -27,7 +27,7 @@ private:
     MPTPrivate *d_ptr;
 
 protected:
-    std::vector<Token> tokenize(const std::string&) const override;
+    std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
     std::string_view tokenToString(Token) const override;
     Token sampleToken(PromptContext &ctx) const override;
     bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;


@@ -38,7 +38,7 @@ protected:
     // We have to implement these as they are pure virtual in base class, but we don't actually use
     // them as they are only called from the default implementation of 'prompt' which we override and
     // completely replace
-    std::vector<Token> tokenize(const std::string&) const override { return std::vector<Token>(); }
+    std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
     std::string_view tokenToString(Token) const override { return std::string_view(); }
     Token sampleToken(PromptContext &ctx) const override { return -1; }
     bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
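
As the comment notes, this backend replaces prompt() wholesale (it calls a remote API rather than running local inference), so the pure virtuals are satisfied with inert stubs. A hedged sketch of that override-and-stub pattern follows; RemoteModel and the trimmed-down base class are illustrative, not project code.

#include <cstdint>
#include <string>
#include <string_view>
#include <vector>

class BaseSketch {
public:
    using Token = int32_t;
    struct PromptContext { std::vector<Token> tokens; int32_t n_past = 0; };
    virtual ~BaseSketch() = default;

    // Default implementation drives the local-inference hooks below.
    virtual void prompt(const std::string &text, PromptContext &ctx) {
        std::vector<Token> input = tokenize(ctx, text);
        if (evalTokens(ctx, input))
            (void)tokenToString(sampleToken(ctx));
    }

protected:
    virtual std::vector<Token> tokenize(PromptContext &, const std::string &) const = 0;
    virtual std::string_view tokenToString(Token) const = 0;
    virtual Token sampleToken(PromptContext &ctx) const = 0;
    virtual bool evalTokens(PromptContext &ctx, const std::vector<Token> &tokens) const = 0;
};

// Remote-API model: prompt() is overridden and completely replaced, so the
// hooks are never called and can be stubbed exactly as the header above does.
class RemoteModel : public BaseSketch {
public:
    void prompt(const std::string &text, PromptContext &ctx) override {
        (void)text; (void)ctx; // would POST the prompt to the HTTP API instead
    }

protected:
    std::vector<Token> tokenize(PromptContext &, const std::string &) const override { return {}; }
    std::string_view tokenToString(Token) const override { return {}; }
    Token sampleToken(PromptContext &) const override { return -1; }
    bool evalTokens(PromptContext &, const std::vector<Token> &) const override { return false; }
};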