Fix up for newer models on reset context. This fixes the model from totally failing after a reset context.

pull/913/head
Adam Treat 1 year ago
parent 825fa64c17
commit b36ea3dde5

@ -890,7 +890,7 @@ size_t GPTJ::restoreState(const uint8_t *src)
return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src); return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src);
} }
std::vector<LLModel::Token> GPTJ::tokenize(const std::string &str) const std::vector<LLModel::Token> GPTJ::tokenize(PromptContext &, const std::string &str) const
{ {
return ::gpt_tokenize(d_ptr->vocab, str); return ::gpt_tokenize(d_ptr->vocab, str);
} }

@ -27,7 +27,7 @@ private:
GPTJPrivate *d_ptr; GPTJPrivate *d_ptr;
protected: protected:
std::vector<Token> tokenize(const std::string&) const override; std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
Token sampleToken(PromptContext &ctx) const override; Token sampleToken(PromptContext &ctx) const override;
std::string_view tokenToString(Token) const override; std::string_view tokenToString(Token) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override; bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;

@ -90,7 +90,6 @@ struct LLamaPrivate {
llama_context *ctx = nullptr; llama_context *ctx = nullptr;
llama_context_params params; llama_context_params params;
int64_t n_threads = 0; int64_t n_threads = 0;
bool empty = true;
}; };
LLamaModel::LLamaModel() LLamaModel::LLamaModel()
@ -163,10 +162,11 @@ size_t LLamaModel::restoreState(const uint8_t *src)
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src)); return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
} }
std::vector<LLModel::Token> LLamaModel::tokenize(const std::string &str) const std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str) const
{ {
const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos());
std::vector<LLModel::Token> fres(str.size()+4); std::vector<LLModel::Token> fres(str.size()+4);
auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), fres.data(), fres.size(), d_ptr->empty); auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), fres.data(), fres.size(), useBOS);
fres.resize(fres_len); fres.resize(fres_len);
return fres; return fres;
} }
@ -187,7 +187,6 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
{ {
d_ptr->empty = false;
return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0; return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
} }

@ -27,7 +27,7 @@ private:
LLamaPrivate *d_ptr; LLamaPrivate *d_ptr;
protected: protected:
std::vector<Token> tokenize(const std::string&) const override; std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::string_view tokenToString(Token) const override; std::string_view tokenToString(Token) const override;
Token sampleToken(PromptContext& ctx) const override; Token sampleToken(PromptContext& ctx) const override;
bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override; bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;

@ -89,7 +89,7 @@ public:
protected: protected:
// These are pure virtual because subclasses need to implement as the default implementation of // These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions // 'prompt' above calls these functions
virtual std::vector<Token> tokenize(const std::string&) const = 0; virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
virtual std::string_view tokenToString(Token) const = 0; virtual std::string_view tokenToString(Token) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0; virtual Token sampleToken(PromptContext &ctx) const = 0;
virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0; virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;

@ -38,7 +38,7 @@ void LLModel::prompt(const std::string &prompt,
} }
// tokenize the prompt // tokenize the prompt
std::vector<Token> embd_inp = tokenize(prompt); std::vector<Token> embd_inp = tokenize(promptCtx, prompt);
// save the context size // save the context size
promptCtx.n_ctx = contextLength(); promptCtx.n_ctx = contextLength();

@ -815,7 +815,7 @@ size_t MPT::restoreState(const uint8_t *src)
return mpt_set_state_data(d_ptr->model, &d_ptr->rng, src); return mpt_set_state_data(d_ptr->model, &d_ptr->rng, src);
} }
std::vector<LLModel::Token> MPT::tokenize(const std::string &str) const std::vector<LLModel::Token> MPT::tokenize(PromptContext &, const std::string &str) const
{ {
return ::gpt_tokenize(d_ptr->vocab, str); return ::gpt_tokenize(d_ptr->vocab, str);
} }

@ -27,7 +27,7 @@ private:
MPTPrivate *d_ptr; MPTPrivate *d_ptr;
protected: protected:
std::vector<Token> tokenize(const std::string&) const override; std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::string_view tokenToString(Token) const override; std::string_view tokenToString(Token) const override;
Token sampleToken(PromptContext &ctx) const override; Token sampleToken(PromptContext &ctx) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override; bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;

@ -38,7 +38,7 @@ protected:
// We have to implement these as they are pure virtual in base class, but we don't actually use // We have to implement these as they are pure virtual in base class, but we don't actually use
// them as they are only called from the default implementation of 'prompt' which we override and // them as they are only called from the default implementation of 'prompt' which we override and
// completely replace // completely replace
std::vector<Token> tokenize(const std::string&) const override { return std::vector<Token>(); } std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
std::string_view tokenToString(Token) const override { return std::string_view(); } std::string_view tokenToString(Token) const override { return std::string_view(); }
Token sampleToken(PromptContext &ctx) const override { return -1; } Token sampleToken(PromptContext &ctx) const override { return -1; }
bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; } bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }

Loading…
Cancel
Save