@@ -635,6 +635,7 @@ struct GPTJPrivate {
     gpt_vocab vocab;
     gptj_model model;
     int64_t n_threads = 0;
+    size_t mem_per_token = 0;
     std::mt19937 rng;
 };
@@ -662,6 +663,7 @@ bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
     d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     d_ptr->modelLoaded = true;
     fflush(stdout);
     return true;
 }
@@ -685,6 +687,7 @@ bool GPTJ::isModelLoaded() const
 
 void GPTJ::prompt(const std::string &prompt,
         std::function<bool(int32_t, const std::string&)> response,
+        std::function<bool(bool)> recalculate,
         PromptContext &promptCtx) {
 
     if (!isModelLoaded()) {
@@ -711,9 +714,9 @@ void GPTJ::prompt(const std::string &prompt,
     static bool initialized = false;
     static std::vector<gpt_vocab::id> p_instruct;
     static std::vector<gpt_vocab::id> r_instruct;
-    size_t mem_per_token = 0;
     if (!initialized) {
-        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits, mem_per_token);
+        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits,
+            d_ptr->mem_per_token);
         initialized = true;
     }
 
@@ -726,12 +729,17 @@ void GPTJ::prompt(const std::string &prompt,
 
         // Check if the context has run out...
         if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
-            std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+            std::cerr << "GPTJ: reached the end of the context window so resizing\n";
+            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
+            promptCtx.n_past = promptCtx.tokens.size();
+            recalculateContext(promptCtx, recalculate);
+            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, mem_per_token)) {
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
+                d_ptr->mem_per_token)) {
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
@@ -770,13 +778,18 @@ void GPTJ::prompt(const std::string &prompt,
 
         // Check if the context has run out...
         if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
-            // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
-            std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
+            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
+            // Erase the first percentage of context from the tokens...
+            std::cerr << "GPTJ: reached the end of the context window so resizing\n";
+            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
+            promptCtx.n_past = promptCtx.tokens.size();
+            recalculateContext(promptCtx, recalculate);
+            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
         const int64_t t_start_predict_us = ggml_time_us();
-        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, mem_per_token)) {
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
+                d_ptr->mem_per_token)) {
             std::cerr << "GPT-J ERROR: Failed to predict next token\n";
             return;
         }
@@ -807,3 +820,29 @@ stop_generating:
 
     return;
 }
+
+void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+{
+    size_t i = 0;
+    promptCtx.n_past = 0;
+    while (i < promptCtx.tokens.size()) {
+        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
+        std::vector<gpt_vocab::id> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
+
+        assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
+
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
+                d_ptr->mem_per_token)) {
+            std::cerr << "GPTJ ERROR: Failed to process prompt\n";
+            goto stop_generating;
+        }
+        promptCtx.n_past += batch.size();
+        if (!recalculate(true))
+            goto stop_generating;
+        i = batch_end;
+    }
+    assert(promptCtx.n_past == promptCtx.tokens.size());
+
+stop_generating:
+    recalculate(false);
+}
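
Note on the change (illustrative, not part of the patch): the hunks above swap the old clamp-and-warn overflow handling for recycle-and-replay. Once n_past would exceed n_ctx, the first contextErase fraction of the window is dropped and the surviving tokens are re-fed through gptj_eval in n_batch-sized chunks, so the model's state matches the shrunken window. A minimal self-contained C++ sketch of that scheme follows; ContextSketch, evalStub, and recycleContext are stand-in names, not the real gpt4all API.

// Standalone sketch of the recycle-and-replay scheme above.
// All names here (ContextSketch, evalStub, recycleContext) are
// illustrative stand-ins, not the real gptj_eval interface.
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

struct ContextSketch {
    std::vector<int> tokens;   // tokens currently in the window
    size_t n_past = 0;         // tokens already evaluated
    size_t n_ctx = 8;          // window capacity (tiny, for demo)
    size_t n_batch = 2;        // replay batch size
    float contextErase = 0.5f; // fraction dropped on overflow
};

// Stand-in for gptj_eval: pretend to process a batch of tokens.
static bool evalStub(size_t n_past, const std::vector<int> &batch) {
    std::printf("eval %zu token(s) at n_past=%zu\n", batch.size(), n_past);
    return true;
}

// Mirrors GPTJ::recalculateContext: drop the oldest slice, then
// replay the kept tokens in batches to rebuild the model state.
static void recycleContext(ContextSketch &ctx) {
    const size_t erasePoint = ctx.n_ctx * ctx.contextErase;
    ctx.tokens.erase(ctx.tokens.begin(), ctx.tokens.begin() + erasePoint);
    ctx.n_past = 0;
    for (size_t i = 0; i < ctx.tokens.size(); ) {
        size_t end = std::min(i + ctx.n_batch, ctx.tokens.size());
        std::vector<int> batch(ctx.tokens.begin() + i, ctx.tokens.begin() + end);
        if (!evalStub(ctx.n_past, batch))
            return;
        ctx.n_past += batch.size();
        i = end;
    }
    assert(ctx.n_past == ctx.tokens.size());
}

int main() {
    ContextSketch ctx;
    for (int t = 0; t < 10; ++t) {
        if (ctx.n_past + 1 > ctx.n_ctx) // window full: drop the oldest
            recycleContext(ctx);        // fraction and replay the rest
        ctx.tokens.push_back(t);
        evalStub(ctx.n_past, { t });
        ctx.n_past += 1;
    }
}

The replay loop is what the recalculate(true)/recalculate(false) callback brackets in the real code, giving the caller a hook to surface a "recalculating context" state while the window is rebuilt.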