mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-16 06:13:09 +00:00
Fix up for newer models on reset context. This fixes the model from totally failing after a reset context.
This commit is contained in:
parent
bdba2e8de6
commit
301d2fdbea
@ -890,7 +890,7 @@ size_t GPTJ::restoreState(const uint8_t *src)
|
||||
return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src);
|
||||
}
|
||||
|
||||
std::vector<LLModel::Token> GPTJ::tokenize(const std::string &str) const
|
||||
std::vector<LLModel::Token> GPTJ::tokenize(PromptContext &, const std::string &str) const
|
||||
{
|
||||
return ::gpt_tokenize(d_ptr->vocab, str);
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ private:
|
||||
GPTJPrivate *d_ptr;
|
||||
|
||||
protected:
|
||||
std::vector<Token> tokenize(const std::string&) const override;
|
||||
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
||||
Token sampleToken(PromptContext &ctx) const override;
|
||||
std::string_view tokenToString(Token) const override;
|
||||
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
||||
|
@ -90,7 +90,6 @@ struct LLamaPrivate {
|
||||
llama_context *ctx = nullptr;
|
||||
llama_context_params params;
|
||||
int64_t n_threads = 0;
|
||||
bool empty = true;
|
||||
};
|
||||
|
||||
LLamaModel::LLamaModel()
|
||||
@ -163,10 +162,11 @@ size_t LLamaModel::restoreState(const uint8_t *src)
|
||||
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
|
||||
}
|
||||
|
||||
std::vector<LLModel::Token> LLamaModel::tokenize(const std::string &str) const
|
||||
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str) const
|
||||
{
|
||||
const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos());
|
||||
std::vector<LLModel::Token> fres(str.size()+4);
|
||||
auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), fres.data(), fres.size(), d_ptr->empty);
|
||||
auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), fres.data(), fres.size(), useBOS);
|
||||
fres.resize(fres_len);
|
||||
return fres;
|
||||
}
|
||||
@ -187,7 +187,6 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
|
||||
|
||||
bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
|
||||
{
|
||||
d_ptr->empty = false;
|
||||
return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,7 @@ private:
|
||||
LLamaPrivate *d_ptr;
|
||||
|
||||
protected:
|
||||
std::vector<Token> tokenize(const std::string&) const override;
|
||||
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
||||
std::string_view tokenToString(Token) const override;
|
||||
Token sampleToken(PromptContext& ctx) const override;
|
||||
bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;
|
||||
|
@ -89,7 +89,7 @@ public:
|
||||
protected:
|
||||
// These are pure virtual because subclasses need to implement as the default implementation of
|
||||
// 'prompt' above calls these functions
|
||||
virtual std::vector<Token> tokenize(const std::string&) const = 0;
|
||||
virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
|
||||
virtual std::string_view tokenToString(Token) const = 0;
|
||||
virtual Token sampleToken(PromptContext &ctx) const = 0;
|
||||
virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
|
||||
|
@ -38,7 +38,7 @@ void LLModel::prompt(const std::string &prompt,
|
||||
}
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<Token> embd_inp = tokenize(prompt);
|
||||
std::vector<Token> embd_inp = tokenize(promptCtx, prompt);
|
||||
|
||||
// save the context size
|
||||
promptCtx.n_ctx = contextLength();
|
||||
|
@ -815,7 +815,7 @@ size_t MPT::restoreState(const uint8_t *src)
|
||||
return mpt_set_state_data(d_ptr->model, &d_ptr->rng, src);
|
||||
}
|
||||
|
||||
std::vector<LLModel::Token> MPT::tokenize(const std::string &str) const
|
||||
std::vector<LLModel::Token> MPT::tokenize(PromptContext &, const std::string &str) const
|
||||
{
|
||||
return ::gpt_tokenize(d_ptr->vocab, str);
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ private:
|
||||
MPTPrivate *d_ptr;
|
||||
|
||||
protected:
|
||||
std::vector<Token> tokenize(const std::string&) const override;
|
||||
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
||||
std::string_view tokenToString(Token) const override;
|
||||
Token sampleToken(PromptContext &ctx) const override;
|
||||
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
||||
|
@ -38,7 +38,7 @@ protected:
|
||||
// We have to implement these as they are pure virtual in base class, but we don't actually use
|
||||
// them as they are only called from the default implementation of 'prompt' which we override and
|
||||
// completely replace
|
||||
std::vector<Token> tokenize(const std::string&) const override { return std::vector<Token>(); }
|
||||
std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
|
||||
std::string_view tokenToString(Token) const override { return std::string_view(); }
|
||||
Token sampleToken(PromptContext &ctx) const override { return -1; }
|
||||
bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
|
||||
|
Loading…
Reference in New Issue
Block a user