@@ -687,14 +687,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
 
     n_predict = std::min(n_predict, d_ptr->model.hparams.n_ctx - (int) embd_inp.size());
-    ctx.n_past = std::min(ctx.n_past, 1024);
-    // n_batch = embd_inp.size();
-    std::cout << "The past was: " << ctx.n_past;
-    fflush(stdout);
-
-    std::vector<gpt_vocab::id> embd;
-    std::vector<gpt_vocab::id> resp;
+    ctx.n_past = std::min(ctx.n_past, d_ptr->model.hparams.n_ctx);
 
     // determine the required inference memory per token:
     static bool initialized = false;
@@ -704,69 +697,50 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
         initialized = true;
     }
 
-    for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) {
-        // predict
-        if (embd.size() > 0) {
-            const int64_t t_start_us = ggml_time_us();
-            if (!gptj_eval(d_ptr->model, d_ptr->n_threads, ctx.n_past, embd, ctx.logits, mem_per_token)) {
-                std::cerr << "GPT-J ERROR: Failed to predict\n";
-                return;
-            }
-            t_predict_us += ggml_time_us() - t_start_us;
+    // process the prompt in batches
+    size_t i = 0;
+    const int64_t t_start_prompt_us = ggml_time_us();
+    while (i < embd_inp.size()) {
+        size_t batch_end = std::min(i + n_batch, embd_inp.size());
+        std::vector<gpt_vocab::id> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, ctx.n_past, batch, ctx.logits, mem_per_token)) {
+            std::cerr << "GPT-J ERROR: Failed to process prompt\n";
+            return;
         }
+        ctx.n_past += batch.size();
+        i = batch_end;
+    }
+    t_prompt_us += ggml_time_us() - t_start_prompt_us;
 
-        ctx.n_past += embd.size();
-        embd.clear();
-        resp.clear();
-
-        if (i >= embd_inp.size()) {
-            t_prompt_us += ggml_time_us() - t_main_start_us;
-
-            // sample next token
-            const int n_vocab = d_ptr->model.hparams.n_vocab;
-            gpt_vocab::id id = 0;
-            {
-                const int64_t t_start_sample_us = ggml_time_us();
-                id = gpt_sample_top_k_top_p(d_ptr->vocab, ctx.logits.data() + (ctx.logits.size() - n_vocab), top_k, top_p, temp, d_ptr->rng);
-                t_sample_us += ggml_time_us() - t_start_sample_us;
-            }
+    // predict next tokens
+    int32_t totalPredictions = 0;
+    for (int i = 0; i < n_predict; i++) {
 
-            // add it to the context
-            embd.push_back(id);
-            if (id != 50256)
-                resp.push_back(id);
-        } else {
-            // if here, it means we are still processing the input prompt
-            for (int k = i; k < embd_inp.size(); k++) {
-                embd.push_back(embd_inp[k]);
-                if (embd.size() > n_batch) {
-                    break;
-                }
-            }
-            i += embd.size() - 1;
-        }
+        // sample next token
+        const int n_vocab = d_ptr->model.hparams.n_vocab;
+        gpt_vocab::id id = 0;
+        {
+            const int64_t t_start_sample_us = ggml_time_us();
+            id = gpt_sample_top_k_top_p(d_ptr->vocab, ctx.logits.data() + (ctx.logits.size() - n_vocab),
+                top_k, top_p, temp, d_ptr->rng);
+            t_sample_us += ggml_time_us() - t_start_sample_us;
+        }
 
-        // display text
-        for (auto id : resp) {
-            if (!response(d_ptr->vocab.id_to_token[id]))
-                goto stop_generating;
+        const int64_t t_start_predict_us = ggml_time_us();
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, ctx.n_past, { id }, ctx.logits, mem_per_token)) {
+            std::cerr << "GPT-J ERROR: Failed to predict next token\n";
+            return;
         }
+        t_predict_us += ggml_time_us() - t_start_predict_us;
+        ctx.n_past += 1;
 
-        // end of text token
-        if (embd.back() == 50256) {
+        // display text
+        ++totalPredictions;
+        if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[id]))
             break;
-        }
     }
 
 stop_generating:
-#if 0
+#if 1
     // report timing
     {
         const int64_t t_main_end_us = ggml_time_us();
@@ -774,7 +748,7 @@ stop_generating:
         std::cout << "GPT-J INFO: mem per token = " << mem_per_token << " bytes\n";
         std::cout << "GPT-J INFO: sample time = " << t_sample_us/1000.0f << " ms\n";
         std::cout << "GPT-J INFO: prompt time = " << t_prompt_us/1000.0f << " ms\n";
-        std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/n_past << " ms per token\n";
+        std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/totalPredictions << " ms per token\n";
        std::cout << "GPT-J INFO: total time = " << (t_main_end_us - t_main_start_us)/1000.0f << " ms\n";
        fflush(stdout);
    }
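Note on the change above: the removed code interleaved prompt ingestion and generation in one loop (feeding embd back through gptj_eval and echoing resp tokens), while the added code runs two phases. First the prompt is evaluated in chunks of at most n_batch tokens, advancing ctx.n_past by the size of each chunk; only then are up to n_predict tokens sampled and evaluated one at a time, each streamed through the response callback, with generation stopping at the end-of-text id 50256. The sketch below illustrates that two-phase pattern in isolation; Context, fake_eval, and the token values are illustrative placeholders, not part of the gpt4all/ggml API.

// Minimal sketch of the two-phase flow introduced by this diff, under simplified
// assumptions: fake_eval stands in for gptj_eval and only reports what it was fed.
#include <algorithm>
#include <cstdio>
#include <vector>

struct Context { int n_past = 0; };  // hypothetical stand-in for PromptContext

// Stand-in for gptj_eval: "evaluates" a batch of token ids and reports success.
static bool fake_eval(Context &ctx, const std::vector<int> &batch) {
    std::printf("eval %zu token(s) at n_past=%d\n", batch.size(), ctx.n_past);
    return true;
}

int main() {
    std::vector<int> prompt_tokens = {101, 102, 103, 104, 105, 106, 107};
    const size_t n_batch = 3;   // prompt chunk size, as in the new while loop
    const int n_predict = 4;    // generation budget, as in the new for loop
    Context ctx;

    // Phase 1: ingest the prompt in n_batch-sized chunks (mirrors the new while loop).
    size_t i = 0;
    while (i < prompt_tokens.size()) {
        size_t batch_end = std::min(i + n_batch, prompt_tokens.size());
        std::vector<int> batch(prompt_tokens.begin() + i, prompt_tokens.begin() + batch_end);
        if (!fake_eval(ctx, batch))
            return 1;
        ctx.n_past += (int) batch.size();   // past grows by the whole chunk
        i = batch_end;
    }

    // Phase 2: generate one token at a time (mirrors the new for loop).
    for (int t = 0; t < n_predict; t++) {
        int id = 42;                        // placeholder for gpt_sample_top_k_top_p()
        if (!fake_eval(ctx, { id }))
            return 1;
        ctx.n_past += 1;                    // past grows by exactly one per prediction
        if (id == 50256 /* end of text */)
            break;
    }
    return 0;
}

Chunking the prompt keeps each gptj_eval call bounded by n_batch while still covering the full prompt, and the one-token generation loop is what makes per-token streaming through the response callback and early stop on 50256 possible.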