|
|
@ -73,6 +73,7 @@ struct starcoder_model {
|
|
|
|
llm_buffer eval_buf;
|
|
|
|
llm_buffer eval_buf;
|
|
|
|
llm_buffer scr0_buf;
|
|
|
|
llm_buffer scr0_buf;
|
|
|
|
llm_buffer scr1_buf;
|
|
|
|
llm_buffer scr1_buf;
|
|
|
|
|
|
|
|
llm_buffer work_buf;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
static bool kv_cache_init(
|
|
|
|
static bool kv_cache_init(
|
|
|
@ -452,7 +453,7 @@ bool starcoder_model_load(const std::string & fname, starcoder_model & model, gp
|
|
|
|
// - embd_w: the predicted logits for the next token
|
|
|
|
// - embd_w: the predicted logits for the next token
|
|
|
|
//
|
|
|
|
//
|
|
|
|
bool starcoder_eval(
|
|
|
|
bool starcoder_eval(
|
|
|
|
const starcoder_model & model,
|
|
|
|
starcoder_model & model,
|
|
|
|
const int n_threads,
|
|
|
|
const int n_threads,
|
|
|
|
const int n_past,
|
|
|
|
const int n_past,
|
|
|
|
const std::vector<gpt_vocab::id> & embd_inp,
|
|
|
|
const std::vector<gpt_vocab::id> & embd_inp,
|
|
|
@ -477,7 +478,6 @@ bool starcoder_eval(
|
|
|
|
|
|
|
|
|
|
|
|
struct ggml_context * ctx0 = ggml_init(eval_ctx_params);
|
|
|
|
struct ggml_context * ctx0 = ggml_init(eval_ctx_params);
|
|
|
|
struct ggml_cgraph gf = {};
|
|
|
|
struct ggml_cgraph gf = {};
|
|
|
|
gf.n_threads = n_threads;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
|
|
|
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
|
|
|
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
|
|
|
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
|
|
@ -730,7 +730,7 @@ bool starcoder_eval(
|
|
|
|
|
|
|
|
|
|
|
|
// run the computation
|
|
|
|
// run the computation
|
|
|
|
ggml_build_forward_expand(&gf, inpL);
|
|
|
|
ggml_build_forward_expand(&gf, inpL);
|
|
|
|
ggml_graph_compute (ctx0, &gf);
|
|
|
|
ggml_graph_compute_g4a(model.work_buf, &gf, n_threads);
|
|
|
|
|
|
|
|
|
|
|
|
//if (n_past%100 == 0) {
|
|
|
|
//if (n_past%100 == 0) {
|
|
|
|
// ggml_graph_print (&gf);
|
|
|
|
// ggml_graph_print (&gf);
|
|
|
|