// gpt4all/gpt4all-backend/starcoder.cpp
#define STARCODER_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#include "starcoder_impl.h"
#include "llama.h"
#include "llama-util.h"
#include "utils.h"
#include "llmodel_shared.h"
#include <cassert>
#include <cinttypes>
#include <iostream>
#include <sstream>
// File-local model name; get_model_type() hands this pointer out.
namespace {
const char * modelType_{"Starcoder"};
} // namespace

// Magic number expected in the first four bytes of a model file.
// The bytes 0x67 0x67 0x6d 0x6c are the ASCII characters "ggml".
#define STARCODER_MAGIC 0x67676d6c
// Model hyper-parameters.
// The loader reads these six 32-bit fields from the file header in this order and
// magic_match() reads the whole struct in a single call, so the member order and
// width must stay as they are.
// Defaults: presumably the SantaCoder configuration linked below — TODO confirm.
// https://huggingface.co/bigcode/gpt_bigcode-santacoder/blob/main/config.json
struct starcoder_hparams {
    int32_t n_vocab{49280}; // number of vocabulary entries
    int32_t n_ctx{2048};    // number of positions (rows of wpe, kv cache length)
    int32_t n_embd{2048};   // embedding width
    int32_t n_head{16};     // number of attention heads
    int32_t n_layer{24};    // number of transformer layers
    int32_t ftype{1};       // weight file type, mapped to a ggml_type by the loader
};
// Weights of one transformer layer.
// Every pointer starts out null so that a default-constructed layer never carries
// indeterminate pointers; the loader fills them in with tensors it allocates.
struct starcoder_layer {
    // normalization (gain / bias) before attention and before the MLP
    struct ggml_tensor * ln_1_g = nullptr;
    struct ggml_tensor * ln_1_b = nullptr;
    struct ggml_tensor * ln_2_g = nullptr;
    struct ggml_tensor * ln_2_b = nullptr;

    // attention: fused q/k/v projection and output projection
    struct ggml_tensor * c_attn_attn_w = nullptr;
    struct ggml_tensor * c_attn_attn_b = nullptr;
    struct ggml_tensor * c_attn_proj_w = nullptr;
    struct ggml_tensor * c_attn_proj_b = nullptr;

    // mlp: fully connected layer and projection back to n_embd
    struct ggml_tensor * c_mlp_fc_w = nullptr;
    struct ggml_tensor * c_mlp_fc_b = nullptr;
    struct ggml_tensor * c_mlp_proj_w = nullptr;
    struct ggml_tensor * c_mlp_proj_b = nullptr;
};
// Everything that belongs to one loaded model: hyper-parameters, weight tensors,
// the key/value cache and the buffers used during evaluation.
// Raw pointers default to null so a freshly constructed model is safe to inspect.
struct starcoder_model {
    starcoder_hparams hparams;

    // final normalization (gain / bias)
    struct ggml_tensor * ln_f_g = nullptr;
    struct ggml_tensor * ln_f_b = nullptr;

    struct ggml_tensor * wte     = nullptr; // token embedding (rows selected by token id)
    struct ggml_tensor * wpe     = nullptr; // position embedding (rows selected by position)
    struct ggml_tensor * lm_head = nullptr; // language model head

    std::vector<starcoder_layer> layers;

    // key + value memory
    llm_kv_cache kv_self;

    // ggml context in which the weight tensors are allocated
    struct ggml_context * ctx = nullptr;

    // tensor lookup by the name stored in the model file
    std::map<std::string, struct ggml_tensor *> tensors;

    llm_buffer eval_buf; // backing memory of the per-evaluation ggml context
    llm_buffer scr0_buf; // scratch buffer 0
    llm_buffer scr1_buf; // scratch buffer 1
    llm_buffer work_buf; // handed to ggml_graph_compute_g4a()
};
// Allocate the key/value cache for self-attention.
//
// - hparams: supplies n_embd and n_layer
// - cache:   receives the buffer, the ggml context and the k / v tensors
// - wtype:   element type of the cached keys and values
// - n_ctx:   number of positions to reserve per layer
//
// Returns false when the ggml context cannot be created.
static bool kv_cache_init(
        const struct starcoder_hparams & hparams,
        struct llm_kv_cache & cache,
        ggml_type wtype,
        int n_ctx) {
    // FIXME: sizing by one head (n_embd / n_head) instead of n_embd could support
    // a much smaller KV cache in the future.
    const int64_t n_mem      = static_cast<int64_t>(hparams.n_layer) * n_ctx;
    const int64_t n_elements = hparams.n_embd * n_mem;

    // room for k and v plus some slack for ggml bookkeeping
    const size_t buf_bytes = 2u*n_elements*ggml_type_size(wtype) + 2_MiB;
    cache.buf.resize(buf_bytes);

    struct ggml_init_params params;
    params.no_alloc   = false;
    params.mem_buffer = cache.buf.addr;
    params.mem_size   = cache.buf.size;

    cache.ctx = ggml_init(params);
    if (cache.ctx == nullptr) {
        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
        return false;
    }

    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
    return true;
}
// load the model's weights from a file
bool starcoder_model_load(const std::string & fname, starcoder_model & model, gpt_vocab & vocab, size_t *mem_req) {
printf("%s: loading model from '%s'\n", __func__, fname.c_str());
if (mem_req) {
*mem_req = 0;
}
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false;
}
// verify magic
{
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
}
}
// load hparams
{
auto & hparams = model.hparams;
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: ftype = %d\n", __func__, hparams.ftype);
printf("%s: qntvr = %d\n", __func__, qntvr);
hparams.ftype %= GGML_QNT_VERSION_FACTOR;
}
// load vocab
{
int32_t n_vocab = 0;
fin.read((char *) &n_vocab, sizeof(n_vocab));
if (n_vocab != model.hparams.n_vocab) {
fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
return false;
}
std::string word;
std::vector<char> buf(128);
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
buf.resize(len);
fin.read((char *) buf.data(), len);
word.assign(buf.data(), len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
// if (i < 10) fprintf(stderr, "%.s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
}
}
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
return false;
}
auto & ctx = model.ctx;
size_t ctx_size = 0;
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
const int head_dim = n_embd / hparams.n_head;
const int kv_heads = hparams.n_head; // 1 if MQA else hparams.n_head
const int kv_dim = kv_heads * head_dim;
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*((n_embd + 2*kv_dim)*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w // TODO:
ctx_size += n_layer*( (n_embd + 2*kv_dim)*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
// ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
// ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
ctx_size += (6 + 12*n_layer)*512; // object overhead
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
if (mem_req) {
const int n_embd = model.hparams.n_embd;
const int dim_head = n_embd / model.hparams.n_head;
const int dim_kv = dim_head;
const int n_layer = model.hparams.n_layer;
const int64_t n_mem = (int64_t)n_layer*model.hparams.n_ctx;
const int64_t n_elements = dim_kv * n_mem;
size_t kv_cache_size = 2u*n_elements*ggml_type_size(wtype) + 2_MiB;
*mem_req = ctx_size + kv_cache_size;
return false;
}
// create the ggml context
{
struct ggml_init_params params = {
.mem_size = ctx_size,
.mem_buffer = NULL,
.no_alloc = false,
};
model.ctx = ggml_init(params);
if (!model.ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
const int head_dim = n_embd / hparams.n_head;
const int kv_heads = hparams.n_head; // 1 if MQA else hparams.n_head
const int kv_dim = kv_heads * head_dim;
model.layers.resize(n_layer);
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
// map by name
model.tensors["model/ln_f/g"] = model.ln_f_g;
model.tensors["model/ln_f/b"] = model.ln_f_b;
model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe;
model.tensors["model/lm_head"] = model.lm_head;
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd + 2*kv_dim);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd + 2*kv_dim);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd); //TODO: 4*n_embd = config.n_inner
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
}
}
// key + value memory
{
const auto & hparams = model.hparams;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int64_t n_mem = n_layer*n_ctx;
if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F32, model.hparams.n_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
ggml_free(ctx);
return false;
}
const size_t memory_size = ggml_nbytes(model.kv_self.k) + ggml_nbytes(model.kv_self.v);
printf("%s: memory_size = %8.2f MB, n_mem = %" PRId64 "\n", __func__, memory_size/1024.0/1024.0, n_mem);
}
// load weights
{
size_t total_size = 0;
bool has_lm_head = false;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ttype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
if (fin.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[2] = { 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
fin.read(&name[0], length);
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
auto tensor = model.tensors[name.data()];
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
return false;
}
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file. got %d, expected %d\n",
__func__, name.data(), (int) ggml_nelements(tensor), nelements);
return false;
}
// for debugging
if (0) {
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
// GPT-2 models share the WTE tensor as the LM head
if (name == "model/wte" && has_lm_head == false) {
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
}
if (name == "model/lm_head") {
has_lm_head = true;
}
total_size += ggml_nbytes(tensor);
}
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
}
fin.close();
model.eval_buf.resize(1280u * 1024 * 1024);
model.scr0_buf.resize(256u * 1024 * 1024);
model.scr1_buf.resize(256u * 1024 * 1024);
return true;
}
// evaluate the transformer
//
// - model: the model
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted logits for the next token
//
bool starcoder_eval(
2023-07-27 23:56:47 +00:00
starcoder_model & model,
2023-07-06 20:34:04 +00:00
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token) {
const int N = embd_inp.size();
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
const size_t head_dim = n_embd / n_head;
struct ggml_init_params eval_ctx_params = {
.mem_size = model.eval_buf.size,
.mem_buffer = model.eval_buf.addr,
.no_alloc = false,
};
struct ggml_context * ctx0 = ggml_init(eval_ctx_params);
struct ggml_cgraph gf = {};
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
for (int i = 0; i < N; ++i) {
((int32_t *) position->data)[i] = n_past + i;
}
// wte + wpe
struct ggml_tensor * inpL =
ggml_add(ctx0,
ggml_get_rows(ctx0, model.wte, embd),
ggml_get_rows(ctx0, model.wpe, position));
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * cur;
ggml_set_scratch(ctx0, {0, model.scr0_buf.size, model.scr0_buf.addr, });
// norm
{
// [ 768, N]
cur = ggml_norm(ctx0, inpL);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
}
// attn
// [2304, 768] - model.layers[il].c_attn_attn_w
// [2304, 1] - model.layers[il].c_attn_attn_b
// [ 768, N] - cur (in)
// [2304, N] - cur (out)
//
// cur = attn_w*cur + attn_b
// [2304, N]
{
cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_attn_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
cur);
}
// self-attention
{
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
// store key and value to memory
if (N >= 1) {
struct ggml_tensor * k = ggml_view_1d(ctx0, model.kv_self.k, N*n_embd, (ggml_element_size(model.kv_self.k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.kv_self.v, N*n_embd, (ggml_element_size(model.kv_self.v)*n_embd)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
// [64, N, 12]
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
0, 2, 1, 3);
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
// [64, n_past + N, 12]
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd),
n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3); //TODO: need to be tiled
// GG: flash attention
//struct ggml_tensor * V =
// ggml_cpy(ctx0,
// ggml_permute(ctx0,
// ggml_reshape_3d(ctx0,
// ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
// n_embd/n_head, n_head, n_past + N),
// 1, 2, 0, 3),
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
// K * Q
// [n_past + N, N, 12]
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); //TODO: check if it broadcasts
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_scaled =
ggml_scale_inplace(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
);
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
struct ggml_tensor * V_trans =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, model.kv_self.v->type, n_past + N, n_embd/n_head, n_head));
// KQV = transpose(V) * KQ_soft_max
// [64, N, 12]
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
// KQV_merged = KQV.permute(0, 2, 1, 3)
// [64, 12, N]
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_embd, N)
// [768, N]
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
}
// projection
// [ 768, 768] - model.layers[il].c_attn_proj_w
// [ 768, 1] - model.layers[il].c_attn_proj_b
// [ 768, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
{
cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_proj_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
cur);
}
// add the input
cur = ggml_add(ctx0, cur, inpL);
struct ggml_tensor * inpFF = cur;
ggml_set_scratch(ctx0, {0, model.scr1_buf.size, model.scr1_buf.addr, });
// feed-forward network
{
// norm
{
cur = ggml_norm(ctx0, inpFF);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
}
// fully connected
// [3072, 768] - model.layers[il].c_mlp_fc_w
// [3072, 1] - model.layers[il].c_mlp_fc_b
// [ 768, N] - cur (in)
// [3072, N] - cur (out)
//
// cur = fc_w*cur + fc_b
// [3072, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_fc_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
cur);
// GELU activation
// [3072, N]
cur = ggml_gelu(ctx0, cur);
// projection
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
// [ 768, 1] - model.layers[il].c_mlp_proj_b
// [3072, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
cur);
}
// input for next layer
inpL = ggml_add(ctx0, cur, inpFF);
}
ggml_set_scratch(ctx0, {0, model.scr0_buf.size, model.scr0_buf.addr, });
// norm
{
// [ 768, N]
inpL = ggml_norm(ctx0, inpL);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
inpL = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.ln_f_g, inpL),
inpL),
ggml_repeat(ctx0, model.ln_f_b, inpL));
}
ggml_set_scratch(ctx0, { 0, 0, nullptr, });
// inpL = WTE * inpL
// [ 768, 50257] - model.lm_head
// [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
// logits -> probs
//inpL = ggml_soft_max_inplace(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
2023-07-27 23:56:47 +00:00
ggml_graph_compute_g4a(model.work_buf, &gf, n_threads);
2023-07-06 20:34:04 +00:00
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//printf("used_mem = %zu MB\n", ggml_used_mem(ctx0)/(1024*1024));
ggml_free(ctx0);
return true;
}
// Size in bytes of the fixed slot that holds the serialized rng in a saved state.
// Parenthesized so the macro expands as a single value inside any expression.
#define MAX_RNG_STATE (64*1024)
// Number of bytes starcoder_copy_state_data() writes for this model:
// rng length, fixed-size rng slot, kv buffer size, kv token count, kv buffer.
size_t starcoder_get_state_size(const starcoder_model &model) {
    size_t total = 0;
    total += sizeof(size_t);         // length of the serialized rng
    total += MAX_RNG_STATE;          // rng slot, always written in full
    total += sizeof(size_t);         // size of the kv buffer
    total += sizeof(int);            // number of tokens held in the kv cache
    total += model.kv_self.buf.size; // kv buffer contents
    return total;
}
// Serialize the rng and the kv cache into dest.
// dest must provide at least starcoder_get_state_size(model) bytes.
// Returns the number of bytes written.
size_t starcoder_copy_state_data(const starcoder_model &model, const std::mt19937 &rng, uint8_t *dest)
{
    uint8_t * cursor = dest;

    // rng: stored as text, preceded by its length and padded with zeros to a fixed slot
    {
        std::stringstream rng_ss;
        rng_ss << rng;
        const std::string rng_text = rng_ss.str();
        const size_t rng_size = rng_text.size();

        char rng_buf[MAX_RNG_STATE];
        memset(rng_buf, 0, MAX_RNG_STATE);
        memcpy(rng_buf, rng_text.data(), rng_size);

        memcpy(cursor, &rng_size, sizeof(rng_size));
        cursor += sizeof(rng_size);
        memcpy(cursor, rng_buf, MAX_RNG_STATE);
        cursor += MAX_RNG_STATE;
    }

    // kv cache: buffer size, token count, then the raw buffer
    {
        const size_t kv_size = model.kv_self.buf.size;
        const int    kv_ntok = model.kv_self.n;

        memcpy(cursor, &kv_size, sizeof(kv_size));
        cursor += sizeof(kv_size);
        memcpy(cursor, &kv_ntok, sizeof(kv_ntok));
        cursor += sizeof(kv_ntok);

        if (kv_size != 0) {
            memcpy(cursor, model.kv_self.buf.addr, kv_size);
            cursor += kv_size;
        }
    }

    const size_t written = cursor - dest;
    assert(written == starcoder_get_state_size(model));
    fflush(stdout);
    return written;
}
// Restore the rng and the kv cache from a blob produced by starcoder_copy_state_data().
//
// - model: model whose kv cache is overwritten
// - rng:   generator that receives the saved state
// - src:   start of the saved blob
//
// Returns the number of bytes consumed, or 0 when the blob is inconsistent with
// this model (rng length larger than its slot, or kv buffer size mismatch). The
// size checks are done at run time because assert() is compiled out in release
// builds and the copies below would otherwise run past the buffers.
size_t starcoder_set_state_data(starcoder_model *model, std::mt19937 *rng, const uint8_t *src)
{
    const uint8_t * in = src;

    // set rng
    {
        size_t rng_size;
        char rng_buf[MAX_RNG_STATE];

        memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size);
        memcpy(&rng_buf[0], in, sizeof(rng_buf)); in += sizeof(rng_buf);

        if (rng_size > sizeof(rng_buf)) {
            fprintf(stderr, "%s: invalid state (rng size %zu exceeds %zu)\n", __func__, rng_size, sizeof(rng_buf));
            return 0;
        }

        std::stringstream rng_ss;
        rng_ss.str(std::string(&rng_buf[0], rng_size));
        rng_ss >> *rng;
        assert(rng_ss.fail() == false);
    }

    // set kv cache
    {
        size_t kv_size;
        int kv_ntok;

        memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
        memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);

        if (kv_size) {
            assert(model->kv_self.buf.size == kv_size);
            if (model->kv_self.buf.size != kv_size) {
                fprintf(stderr, "%s: invalid state (kv cache size %zu, expected %zu)\n",
                        __func__, kv_size, (size_t) model->kv_self.buf.size);
                return 0;
            }

            void * k_data = model->kv_self.k->data; // remember data pointers
            void * v_data = model->kv_self.v->data; // because their value is stored in buf and overwritten by memcpy

            memcpy(model->kv_self.buf.addr, in, kv_size); in += kv_size;

            model->kv_self.k->data = k_data; // restore correct data pointers
            model->kv_self.v->data = v_data;
        }

        model->kv_self.n = kv_ntok;
    }

    const size_t nread = in - src;
    assert(nread == starcoder_get_state_size(*model));
    fflush(stdout);
    return nread;
}
// Private state of a Starcoder instance.
struct StarcoderPrivate {
    const std::string modelPath;      // not assigned anywhere in this file
    bool modelLoaded = false;         // set once loadModel() has succeeded
    gpt_vocab vocab;                  // token <-> id tables filled by the loader
    starcoder_model *model = nullptr; // allocated by the Starcoder constructor
    int64_t n_threads = 0;            // thread count passed to starcoder_eval()
    size_t mem_per_token = 0;         // filled in by starcoder_eval() while it is 0
    std::mt19937 rng;                 // generator used for sampling
};
// Create the private state together with an empty, not yet loaded model.
Starcoder::Starcoder()
    : d_ptr(new StarcoderPrivate)
{
    auto & priv = *d_ptr;
    priv.model = new starcoder_model;
    priv.model->ctx = nullptr; // no weight context until a model is loaded
    priv.modelLoaded = false;
}
// Release the weight context (if one was created) and the model object.
// NOTE(review): d_ptr itself is not released here — confirm its declared type takes care of that.
Starcoder::~Starcoder() {
    starcoder_model * const owned = d_ptr->model;
    if (owned->ctx != nullptr) {
        ggml_free(owned->ctx);
        owned->ctx = nullptr;
    }
    delete owned;
}
// Seed the sampler from the current time and load the model file at modelPath.
// Returns true on success; on failure an error is written to std::cerr.
bool Starcoder::loadModel(const std::string &modelPath)
{
    std::mt19937 rng(time(NULL));
    d_ptr->rng = rng;

    // load the model
    if (!starcoder_model_load(modelPath, *d_ptr->model, d_ptr->vocab, nullptr)) {
        std::cerr << "STARCODER ERROR: failed to load model from " << modelPath;
        return false;
    }

    // hardware_concurrency() may return 0 when the value cannot be determined;
    // use at most 4 threads but never fewer than 1
    d_ptr->n_threads = std::max(1, std::min(4, (int32_t) std::thread::hardware_concurrency()));
    d_ptr->modelLoaded = true;
    fflush(stdout);
    return true;
}
// True once loadModel() has completed successfully.
bool Starcoder::isModelLoaded() const
{
    const auto & priv = *d_ptr;
    return priv.modelLoaded;
}
// Estimate the memory needed for the model file at modelPath without loading it.
// The loader is run in its estimate-only mode (non-null mem_req), where it parses
// the header and vocab and reports the size; its return value is not meaningful
// in that mode. The result is 0 when the file cannot be opened or is rejected
// before the estimate is computed.
size_t Starcoder::requiredMem(const std::string &modelPath)
{
    starcoder_model dummy_model;
    gpt_vocab dummy_vocab;
    size_t mem_req = 0;
    starcoder_model_load(modelPath, dummy_model, dummy_vocab, &mem_req);
    return mem_req;
}
size_t Starcoder::stateSize() const
{
return starcoder_get_state_size(*d_ptr->model);
}
// Write the rng and kv cache state to dest; returns the number of bytes written.
size_t Starcoder::saveState(uint8_t *dest) const
{
    const auto & priv = *d_ptr;
    return starcoder_copy_state_data(*priv.model, priv.rng, dest);
}
// Restore the rng and kv cache state from src; returns what starcoder_set_state_data() reports.
size_t Starcoder::restoreState(const uint8_t *src)
{
    auto & priv = *d_ptr;
    return starcoder_set_state_data(priv.model, &priv.rng, src);
}
// Set the number of threads used for evaluation.
void Starcoder::setThreadCount(int32_t n_threads)
{
    auto & priv = *d_ptr;
    priv.n_threads = n_threads;
}
// Number of threads used for evaluation (stored as 64-bit, reported as 32-bit).
int32_t Starcoder::threadCount() const
{
    const auto & priv = *d_ptr;
    return static_cast<int32_t>(priv.n_threads);
}
// Split str into token ids using the loaded vocabulary; the prompt context is not used.
std::vector<LLModel::Token> Starcoder::tokenize(PromptContext &, const std::string &str) const
{
    auto & vocab = d_ptr->vocab;
    return ::gpt_tokenize(vocab, str);
}
// Sample the next token from promptCtx.logits with top-k / top-p / temperature,
// passing the last repeat_last_n tokens (or fewer if not available) for the repeat penalty.
LLModel::Token Starcoder::sampleToken(PromptContext &promptCtx) const
{
    auto & history = promptCtx.tokens;
    const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, history.size());
    auto * window = history.data() + history.size() - n_prev_toks; // tail of the history

    return gpt_sample_top_k_top_p(
        d_ptr->model->hparams.n_vocab,
        window,
        n_prev_toks,
        promptCtx.logits,
        promptCtx.top_k,
        promptCtx.top_p,
        promptCtx.temp,
        promptCtx.repeat_penalty,
        d_ptr->rng);
}
// Text of the token with the given id, or an empty string for an unknown id.
// A lookup is used instead of operator[] so that asking for an unknown id does not
// add an entry to the vocabulary table from within this const method.
// (assumes id_to_token is a std::map-like container with find() — TODO confirm)
std::string Starcoder::tokenToString(Token id) const
{
    const auto & table = d_ptr->vocab.id_to_token;
    const auto found = table.find(id);
    if (found == table.end()) {
        return std::string();
    }
    return found->second;
}
// Evaluate tokens on top of the ctx.n_past tokens already in the kv cache and
// store the logits of the last token in ctx.logits.
//
// No separate warm-up evaluation is run: starcoder_eval() itself records
// mem_per_token on the first call while the value is still 0. The former warm-up
// was guarded by a function-local static shared by every instance, and it wrote
// the dummy tokens { 0, 1, 2, 3 } into kv cache positions 0..3 and into ctx.logits.
bool Starcoder::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
{
    return starcoder_eval(*d_ptr->model, d_ptr->n_threads, ctx.n_past, tokens, ctx.logits, d_ptr->mem_per_token);
}
// Context length (n_ctx) of the model's hyper-parameters.
int32_t Starcoder::contextLength() const
{
    const auto & hp = d_ptr->model->hparams;
    return hp.n_ctx;
}
// Token ids that end generation: the single id 0.
const std::vector<LLModel::Token> &Starcoder::endTokens() const
{
    static const std::vector<LLModel::Token> kEndTokens{ 0 };
    return kEndTokens;
}
// Symbol export attribute for the functions below.
#if defined(_WIN32)
#define DLL_EXPORT __declspec(dllexport)
#else
#define DLL_EXPORT __attribute__ ((visibility ("default")))
#endif

// Functions exported with C linkage.
extern "C" {

// Always true for this implementation.
DLL_EXPORT bool is_g4a_backend_model_implementation() {
    return true;
}

// Name of the model type implemented by this file.
DLL_EXPORT const char *get_model_type() {
    return modelType_;
}

// Build variant string supplied through GGML_BUILD_VARIANT.
DLL_EXPORT const char *get_build_variant() {
    return GGML_BUILD_VARIANT;
}

// Decide whether the stream looks like a model file for this implementation:
// the magic must match and the vocabulary size must lie in [49100, 49290].
// The stream is advanced past whatever was read.
DLL_EXPORT bool magic_match(std::istream& f) {
    uint32_t magic = 0;
    f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
    if (!f || magic != STARCODER_MAGIC) {
        return false;
    }

    starcoder_hparams hparams;
    f.read(reinterpret_cast<char*>(&hparams), sizeof(hparams));
    if (!f) {
        // truncated header: hparams could still hold its defaults, whose n_vocab
        // is inside the accepted range, so the file must be rejected here
        return false;
    }

    if (!(hparams.n_vocab >= 49100 && hparams.n_vocab <= 49290)) {
        return false;
    }
    return true;
}

// Create a new, not yet loaded Starcoder instance; the caller owns the result.
DLL_EXPORT LLModel *construct() {
    return new Starcoder;
}

}