mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-16 06:13:09 +00:00
Fix up for newer models on reset context. This fixes the model from totally failing after a reset context.
This commit is contained in:
parent
bdba2e8de6
commit
301d2fdbea
@ -890,7 +890,7 @@ size_t GPTJ::restoreState(const uint8_t *src)
|
|||||||
return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src);
|
return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<LLModel::Token> GPTJ::tokenize(const std::string &str) const
|
std::vector<LLModel::Token> GPTJ::tokenize(PromptContext &, const std::string &str) const
|
||||||
{
|
{
|
||||||
return ::gpt_tokenize(d_ptr->vocab, str);
|
return ::gpt_tokenize(d_ptr->vocab, str);
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ private:
|
|||||||
GPTJPrivate *d_ptr;
|
GPTJPrivate *d_ptr;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::vector<Token> tokenize(const std::string&) const override;
|
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
||||||
Token sampleToken(PromptContext &ctx) const override;
|
Token sampleToken(PromptContext &ctx) const override;
|
||||||
std::string_view tokenToString(Token) const override;
|
std::string_view tokenToString(Token) const override;
|
||||||
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
||||||
|
@ -90,7 +90,6 @@ struct LLamaPrivate {
|
|||||||
llama_context *ctx = nullptr;
|
llama_context *ctx = nullptr;
|
||||||
llama_context_params params;
|
llama_context_params params;
|
||||||
int64_t n_threads = 0;
|
int64_t n_threads = 0;
|
||||||
bool empty = true;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
LLamaModel::LLamaModel()
|
LLamaModel::LLamaModel()
|
||||||
@ -163,10 +162,11 @@ size_t LLamaModel::restoreState(const uint8_t *src)
|
|||||||
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
|
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<LLModel::Token> LLamaModel::tokenize(const std::string &str) const
|
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str) const
|
||||||
{
|
{
|
||||||
|
const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos());
|
||||||
std::vector<LLModel::Token> fres(str.size()+4);
|
std::vector<LLModel::Token> fres(str.size()+4);
|
||||||
auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), fres.data(), fres.size(), d_ptr->empty);
|
auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), fres.data(), fres.size(), useBOS);
|
||||||
fres.resize(fres_len);
|
fres.resize(fres_len);
|
||||||
return fres;
|
return fres;
|
||||||
}
|
}
|
||||||
@ -187,7 +187,6 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
|
|||||||
|
|
||||||
bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
|
bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
|
||||||
{
|
{
|
||||||
d_ptr->empty = false;
|
|
||||||
return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
|
return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ private:
|
|||||||
LLamaPrivate *d_ptr;
|
LLamaPrivate *d_ptr;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::vector<Token> tokenize(const std::string&) const override;
|
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
||||||
std::string_view tokenToString(Token) const override;
|
std::string_view tokenToString(Token) const override;
|
||||||
Token sampleToken(PromptContext& ctx) const override;
|
Token sampleToken(PromptContext& ctx) const override;
|
||||||
bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;
|
bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;
|
||||||
|
@ -89,7 +89,7 @@ public:
|
|||||||
protected:
|
protected:
|
||||||
// These are pure virtual because subclasses need to implement as the default implementation of
|
// These are pure virtual because subclasses need to implement as the default implementation of
|
||||||
// 'prompt' above calls these functions
|
// 'prompt' above calls these functions
|
||||||
virtual std::vector<Token> tokenize(const std::string&) const = 0;
|
virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
|
||||||
virtual std::string_view tokenToString(Token) const = 0;
|
virtual std::string_view tokenToString(Token) const = 0;
|
||||||
virtual Token sampleToken(PromptContext &ctx) const = 0;
|
virtual Token sampleToken(PromptContext &ctx) const = 0;
|
||||||
virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
|
virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
|
||||||
|
@ -38,7 +38,7 @@ void LLModel::prompt(const std::string &prompt,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
std::vector<Token> embd_inp = tokenize(prompt);
|
std::vector<Token> embd_inp = tokenize(promptCtx, prompt);
|
||||||
|
|
||||||
// save the context size
|
// save the context size
|
||||||
promptCtx.n_ctx = contextLength();
|
promptCtx.n_ctx = contextLength();
|
||||||
|
@ -815,7 +815,7 @@ size_t MPT::restoreState(const uint8_t *src)
|
|||||||
return mpt_set_state_data(d_ptr->model, &d_ptr->rng, src);
|
return mpt_set_state_data(d_ptr->model, &d_ptr->rng, src);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<LLModel::Token> MPT::tokenize(const std::string &str) const
|
std::vector<LLModel::Token> MPT::tokenize(PromptContext &, const std::string &str) const
|
||||||
{
|
{
|
||||||
return ::gpt_tokenize(d_ptr->vocab, str);
|
return ::gpt_tokenize(d_ptr->vocab, str);
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ private:
|
|||||||
MPTPrivate *d_ptr;
|
MPTPrivate *d_ptr;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::vector<Token> tokenize(const std::string&) const override;
|
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
||||||
std::string_view tokenToString(Token) const override;
|
std::string_view tokenToString(Token) const override;
|
||||||
Token sampleToken(PromptContext &ctx) const override;
|
Token sampleToken(PromptContext &ctx) const override;
|
||||||
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
||||||
|
@ -38,7 +38,7 @@ protected:
|
|||||||
// We have to implement these as they are pure virtual in base class, but we don't actually use
|
// We have to implement these as they are pure virtual in base class, but we don't actually use
|
||||||
// them as they are only called from the default implementation of 'prompt' which we override and
|
// them as they are only called from the default implementation of 'prompt' which we override and
|
||||||
// completely replace
|
// completely replace
|
||||||
std::vector<Token> tokenize(const std::string&) const override { return std::vector<Token>(); }
|
std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
|
||||||
std::string_view tokenToString(Token) const override { return std::string_view(); }
|
std::string_view tokenToString(Token) const override { return std::string_view(); }
|
||||||
Token sampleToken(PromptContext &ctx) const override { return -1; }
|
Token sampleToken(PromptContext &ctx) const override { return -1; }
|
||||||
bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
|
bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
|
||||||
|
Loading…
Reference in New Issue
Block a user