mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-06 09:20:33 +00:00
Add llama.cpp support for loading llama based models in the gui. We now
support loading both gptj derived models and llama derived models.
This commit is contained in:
parent
00cb5fe2a5
commit
71b308e914
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -1,3 +1,6 @@
|
||||
[submodule "ggml"]
|
||||
path = ggml
|
||||
url = https://github.com/manyoso/ggml.git
|
||||
[submodule "llama.cpp"]
|
||||
path = llama.cpp
|
||||
url = https://github.com/manyoso/llama.cpp.git
|
||||
|
@ -28,15 +28,19 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
find_package(Qt6 6.2 COMPONENTS Quick Svg REQUIRED)
|
||||
|
||||
set(GGML_BUILD_EXAMPLES ON CACHE BOOL "ggml: build examples" FORCE)
|
||||
add_subdirectory(ggml)
|
||||
set(LLAMA_BUILD_EXAMPLES ON CACHE BOOL "llama: build examples" FORCE)
|
||||
set(BUILD_SHARED_LIBS ON FORCE)
|
||||
add_subdirectory(llama.cpp)
|
||||
|
||||
qt_add_executable(chat
|
||||
main.cpp
|
||||
download.h download.cpp
|
||||
gptj.h gptj.cpp
|
||||
llamamodel.h llamamodel.cpp
|
||||
llama.cpp/examples/common.cpp
|
||||
llm.h llm.cpp
|
||||
llmodel.h
|
||||
utils.h utils.cpp
|
||||
)
|
||||
|
||||
qt_add_qml_module(chat
|
||||
@ -72,7 +76,7 @@ target_compile_definitions(chat
|
||||
target_link_libraries(chat
|
||||
PRIVATE Qt6::Quick Qt6::Svg)
|
||||
target_link_libraries(chat
|
||||
PRIVATE ggml ggml_utils)
|
||||
PRIVATE llama)
|
||||
|
||||
set(COMPONENT_NAME_MAIN ${PROJECT_NAME})
|
||||
set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/install)
|
||||
|
8
gptj.cpp
8
gptj.cpp
@ -1,5 +1,5 @@
|
||||
#include "gptj.h"
|
||||
#include "ggml/ggml.h"
|
||||
#include "llama.cpp/ggml.h"
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
@ -644,6 +644,12 @@ GPTJ::GPTJ()
|
||||
d_ptr->modelLoaded = false;
|
||||
}
|
||||
|
||||
bool GPTJ::loadModel(const std::string &modelPath)
|
||||
{
|
||||
std::cerr << "GPTJ ERROR: loading gpt model from file unsupported!\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
|
||||
std::mt19937 rng(time(NULL));
|
||||
d_ptr->rng = rng;
|
||||
|
1
gptj.h
1
gptj.h
@ -12,6 +12,7 @@ public:
|
||||
GPTJ();
|
||||
~GPTJ();
|
||||
|
||||
bool loadModel(const std::string &modelPath) override;
|
||||
bool loadModel(const std::string &modelPath, std::istream &fin) override;
|
||||
bool isModelLoaded() const override;
|
||||
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||
|
1
llama.cpp
Submodule
1
llama.cpp
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit c8c2c524827be8fd681a63f0e5a697b0bf4c587b
|
160
llamamodel.cpp
Normal file
160
llamamodel.cpp
Normal file
@ -0,0 +1,160 @@
|
||||
#include "llamamodel.h"
|
||||
|
||||
#include "llama.cpp/examples/common.h"
|
||||
#include "llama.cpp/llama.h"
|
||||
#include "llama.cpp/ggml.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <unistd.h>
|
||||
#include <random>
|
||||
#include <thread>
|
||||
|
||||
struct LLamaPrivate {
|
||||
const std::string modelPath;
|
||||
bool modelLoaded;
|
||||
llama_context *ctx = nullptr;
|
||||
llama_context_params params;
|
||||
int64_t n_threads = 0;
|
||||
};
|
||||
|
||||
LLamaModel::LLamaModel()
|
||||
: d_ptr(new LLamaPrivate) {
|
||||
|
||||
d_ptr->modelLoaded = false;
|
||||
}
|
||||
|
||||
bool LLamaModel::loadModel(const std::string &modelPath, std::istream &fin)
|
||||
{
|
||||
std::cerr << "LLAMA ERROR: loading llama model from stream unsupported!\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool LLamaModel::loadModel(const std::string &modelPath)
|
||||
{
|
||||
// load the model
|
||||
d_ptr->params = llama_context_default_params();
|
||||
d_ptr->ctx = llama_init_from_file(modelPath.c_str(), d_ptr->params);
|
||||
if (!d_ptr->ctx) {
|
||||
std::cerr << "LLAMA ERROR: failed to load model from " << modelPath << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
d_ptr->modelLoaded = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
void LLamaModel::setThreadCount(int32_t n_threads) {
|
||||
d_ptr->n_threads = n_threads;
|
||||
}
|
||||
|
||||
int32_t LLamaModel::threadCount() {
|
||||
return d_ptr->n_threads;
|
||||
}
|
||||
|
||||
LLamaModel::~LLamaModel()
|
||||
{
|
||||
}
|
||||
|
||||
bool LLamaModel::isModelLoaded() const
|
||||
{
|
||||
return d_ptr->modelLoaded;
|
||||
}
|
||||
|
||||
void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||
PromptContext &promptCtx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
|
||||
|
||||
if (!isModelLoaded()) {
|
||||
std::cerr << "LLAMA ERROR: prompt won't work with an unloaded model!\n";
|
||||
return;
|
||||
}
|
||||
|
||||
gpt_params params;
|
||||
params.prompt = prompt;
|
||||
|
||||
// Add a space in front of the first character to match OG llama tokenizer behavior
|
||||
params.prompt.insert(0, 1, ' ');
|
||||
|
||||
// tokenize the prompt
|
||||
auto embd_inp = ::llama_tokenize(d_ptr->ctx, params.prompt, false);
|
||||
const int n_ctx = llama_n_ctx(d_ptr->ctx);
|
||||
|
||||
if ((int) embd_inp.size() > n_ctx - 4) {
|
||||
std::cerr << "LLAMA ERROR: prompt is too long\n";
|
||||
return;
|
||||
}
|
||||
|
||||
n_predict = std::min(n_predict, n_ctx - (int) embd_inp.size());
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx);
|
||||
|
||||
// number of tokens to keep when resetting context
|
||||
params.n_keep = (int)embd_inp.size();
|
||||
|
||||
// process the prompt in batches
|
||||
size_t i = 0;
|
||||
const int64_t t_start_prompt_us = ggml_time_us();
|
||||
while (i < embd_inp.size()) {
|
||||
size_t batch_end = std::min(i + n_batch, embd_inp.size());
|
||||
std::vector<llama_token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
|
||||
|
||||
if (promptCtx.n_past + batch.size() > n_ctx) {
|
||||
std::cerr << "eval n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
|
||||
std::cerr << "after n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
|
||||
}
|
||||
|
||||
if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
|
||||
std::cerr << "LLAMA ERROR: Failed to process prompt\n";
|
||||
return;
|
||||
}
|
||||
// We pass a null string for each token to see if the user has asked us to stop...
|
||||
size_t tokens = batch_end - i;
|
||||
for (size_t t = 0; t < tokens; ++t)
|
||||
if (!response(""))
|
||||
return;
|
||||
promptCtx.n_past += batch.size();
|
||||
i = batch_end;
|
||||
}
|
||||
|
||||
std::vector<llama_token> cachedTokens;
|
||||
|
||||
// predict next tokens
|
||||
int32_t totalPredictions = 0;
|
||||
for (int i = 0; i < n_predict; i++) {
|
||||
// sample next token
|
||||
llama_token id = llama_sample_top_p_top_k(d_ptr->ctx, {}, 0, top_k, top_p, temp, 1.0f);
|
||||
|
||||
if (promptCtx.n_past + 1 > n_ctx) {
|
||||
std::cerr << "eval 2 n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
|
||||
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
|
||||
std::cerr << "after 2 n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
|
||||
}
|
||||
|
||||
if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
|
||||
std::cerr << "LLAMA ERROR: Failed to predict next token\n";
|
||||
return;
|
||||
}
|
||||
cachedTokens.emplace_back(id);
|
||||
|
||||
for (int j = 0; j < cachedTokens.size(); ++j) {
|
||||
llama_token cachedToken = cachedTokens.at(j);
|
||||
promptCtx.n_past += 1;
|
||||
// display text
|
||||
++totalPredictions;
|
||||
if (id == llama_token_eos() || !response(llama_token_to_str(d_ptr->ctx, cachedToken)))
|
||||
goto stop_generating;
|
||||
}
|
||||
cachedTokens.clear();
|
||||
}
|
||||
|
||||
stop_generating:
|
||||
return;
|
||||
}
|
28
llamamodel.h
Normal file
28
llamamodel.h
Normal file
@ -0,0 +1,28 @@
|
||||
#ifndef LLAMAMODEL_H
|
||||
#define LLAMAMODEL_H
|
||||
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include "llmodel.h"
|
||||
|
||||
class LLamaPrivate;
|
||||
class LLamaModel : public LLModel {
|
||||
public:
|
||||
LLamaModel();
|
||||
~LLamaModel();
|
||||
|
||||
bool loadModel(const std::string &modelPath) override;
|
||||
bool loadModel(const std::string &modelPath, std::istream &fin) override;
|
||||
bool isModelLoaded() const override;
|
||||
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
|
||||
float temp = 0.0f, int32_t n_batch = 9) override;
|
||||
void setThreadCount(int32_t n_threads) override;
|
||||
int32_t threadCount() override;
|
||||
|
||||
private:
|
||||
LLamaPrivate *d_ptr;
|
||||
};
|
||||
|
||||
#endif // LLAMAMODEL_H
|
18
llm.cpp
18
llm.cpp
@ -47,20 +47,32 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
|
||||
return true;
|
||||
|
||||
if (isModelLoaded()) {
|
||||
resetContext();
|
||||
delete m_llmodel;
|
||||
m_llmodel = nullptr;
|
||||
emit isModelLoadedChanged();
|
||||
}
|
||||
|
||||
m_llmodel = new GPTJ;
|
||||
|
||||
bool isGPTJ = false;
|
||||
QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() +
|
||||
"ggml-" + modelName + ".bin";
|
||||
QFileInfo info(filePath);
|
||||
if (info.exists()) {
|
||||
|
||||
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
|
||||
m_llmodel->loadModel(modelName.toStdString(), fin);
|
||||
|
||||
uint32_t magic;
|
||||
fin.read((char *) &magic, sizeof(magic));
|
||||
fin.seekg(0);
|
||||
isGPTJ = magic == 0x67676d6c;
|
||||
if (isGPTJ) {
|
||||
m_llmodel = new GPTJ;
|
||||
m_llmodel->loadModel(modelName.toStdString(), fin);
|
||||
} else {
|
||||
m_llmodel = new LLamaModel;
|
||||
m_llmodel->loadModel(filePath.toStdString());
|
||||
}
|
||||
|
||||
emit isModelLoadedChanged();
|
||||
emit threadCountChanged();
|
||||
}
|
||||
|
1
llm.h
1
llm.h
@ -4,6 +4,7 @@
|
||||
#include <QObject>
|
||||
#include <QThread>
|
||||
#include "gptj.h"
|
||||
#include "llamamodel.h"
|
||||
|
||||
class LLMObject : public QObject
|
||||
{
|
||||
|
@ -10,6 +10,7 @@ public:
|
||||
explicit LLModel() {}
|
||||
virtual ~LLModel() {}
|
||||
|
||||
virtual bool loadModel(const std::string &modelPath) = 0;
|
||||
virtual bool loadModel(const std::string &modelPath, std::istream &fin) = 0;
|
||||
virtual bool isModelLoaded() const = 0;
|
||||
struct PromptContext {
|
||||
@ -19,8 +20,8 @@ public:
|
||||
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
|
||||
float temp = 0.9f, int32_t n_batch = 9) = 0;
|
||||
virtual void setThreadCount(int32_t n_threads);
|
||||
virtual int32_t threadCount();
|
||||
virtual void setThreadCount(int32_t n_threads) {}
|
||||
virtual int32_t threadCount() { return 1; }
|
||||
};
|
||||
|
||||
#endif // LLMODEL_H
|
||||
|
8
main.qml
8
main.qml
@ -70,7 +70,9 @@ Window {
|
||||
}
|
||||
|
||||
onActivated: {
|
||||
LLM.stopGenerating()
|
||||
LLM.modelName = comboBox.currentText
|
||||
chatModel.clear()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -775,7 +777,7 @@ Window {
|
||||
Accessible.description: qsTr("This is the list of prompt/response pairs comprising the actual conversation with the model")
|
||||
|
||||
delegate: TextArea {
|
||||
text: currentResponse ? LLM.response : value
|
||||
text: currentResponse ? LLM.response : (value ? value : "")
|
||||
width: listView.width
|
||||
color: "#d1d5db"
|
||||
wrapMode: Text.WordWrap
|
||||
@ -800,8 +802,8 @@ Window {
|
||||
anchors.leftMargin: 90
|
||||
anchors.top: parent.top
|
||||
anchors.topMargin: 5
|
||||
visible: currentResponse && LLM.response === "" && LLM.responseInProgress
|
||||
running: currentResponse && LLM.response === "" && LLM.responseInProgress
|
||||
visible: (currentResponse ? true : false) && LLM.response === "" && LLM.responseInProgress
|
||||
running: (currentResponse ? true : false) && LLM.response === "" && LLM.responseInProgress
|
||||
|
||||
Accessible.role: Accessible.Animation
|
||||
Accessible.name: qsTr("Busy indicator")
|
||||
|
257
utils.cpp
Normal file
257
utils.cpp
Normal file
@ -0,0 +1,257 @@
|
||||
#include "utils.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <regex>
|
||||
|
||||
void replace(std::string & str, const std::string & needle, const std::string & replacement) {
|
||||
size_t pos = 0;
|
||||
while ((pos = str.find(needle, pos)) != std::string::npos) {
|
||||
str.replace(pos, needle.length(), replacement);
|
||||
pos += replacement.length();
|
||||
}
|
||||
}
|
||||
|
||||
std::map<std::string, int32_t> json_parse(const std::string & fname) {
|
||||
std::map<std::string, int32_t> result;
|
||||
|
||||
// read file into string
|
||||
std::string json;
|
||||
{
|
||||
std::ifstream ifs(fname);
|
||||
if (!ifs) {
|
||||
fprintf(stderr, "Failed to open %s\n", fname.c_str());
|
||||
exit(1);
|
||||
}
|
||||
|
||||
json = std::string((std::istreambuf_iterator<char>(ifs)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
}
|
||||
|
||||
if (json[0] != '{') {
|
||||
return result;
|
||||
}
|
||||
|
||||
// parse json
|
||||
{
|
||||
bool has_key = false;
|
||||
bool in_token = false;
|
||||
|
||||
std::string str_key = "";
|
||||
std::string str_val = "";
|
||||
|
||||
int n = json.size();
|
||||
for (int i = 1; i < n; ++i) {
|
||||
if (!in_token) {
|
||||
if (json[i] == ' ') continue;
|
||||
if (json[i] == '"') {
|
||||
in_token = true;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (json[i] == '\\' && i+1 < n) {
|
||||
if (has_key == false) {
|
||||
str_key += json[i];
|
||||
} else {
|
||||
str_val += json[i];
|
||||
}
|
||||
++i;
|
||||
} else if (json[i] == '"') {
|
||||
if (has_key == false) {
|
||||
has_key = true;
|
||||
++i;
|
||||
while (json[i] == ' ') ++i;
|
||||
++i; // :
|
||||
while (json[i] == ' ') ++i;
|
||||
if (json[i] != '\"') {
|
||||
while (json[i] != ',' && json[i] != '}') {
|
||||
str_val += json[i++];
|
||||
}
|
||||
has_key = false;
|
||||
} else {
|
||||
in_token = true;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
has_key = false;
|
||||
}
|
||||
|
||||
::replace(str_key, "\\u0120", " " ); // \u0120 -> space
|
||||
::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
|
||||
::replace(str_key, "\\\"", "\""); // \\\" -> "
|
||||
|
||||
try {
|
||||
result[str_key] = std::stoi(str_val);
|
||||
} catch (...) {
|
||||
//fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
|
||||
|
||||
}
|
||||
str_key = "";
|
||||
str_val = "";
|
||||
in_token = false;
|
||||
continue;
|
||||
}
|
||||
if (has_key == false) {
|
||||
str_key += json[i];
|
||||
} else {
|
||||
str_val += json[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.size() == 0) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
||||
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
|
||||
|
||||
vocab.token_to_id = ::json_parse(fname);
|
||||
|
||||
for (const auto & kv : vocab.token_to_id) {
|
||||
vocab.id_to_token[kv.second] = kv.first;
|
||||
}
|
||||
|
||||
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
|
||||
|
||||
// print the vocabulary
|
||||
//for (auto kv : vocab.token_to_id) {
|
||||
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
|
||||
//}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double temp,
|
||||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
{
|
||||
const double scale = 1.0/temp;
|
||||
for (int i = 0; i < n_logits; ++i) {
|
||||
logits_id.push_back(std::make_pair(logits[i]*scale, i));
|
||||
}
|
||||
}
|
||||
|
||||
// find the top K tokens
|
||||
std::partial_sort(
|
||||
logits_id.begin(),
|
||||
logits_id.begin() + top_k, logits_id.end(),
|
||||
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
logits_id.resize(top_k);
|
||||
|
||||
double maxl = -INFINITY;
|
||||
for (const auto & kv : logits_id) {
|
||||
maxl = std::max(maxl, kv.first);
|
||||
}
|
||||
|
||||
// compute probs for the top K tokens
|
||||
std::vector<double> probs;
|
||||
probs.reserve(logits_id.size());
|
||||
|
||||
double sum = 0.0;
|
||||
for (const auto & kv : logits_id) {
|
||||
double p = exp(kv.first - maxl);
|
||||
probs.push_back(p);
|
||||
sum += p;
|
||||
}
|
||||
|
||||
// normalize the probs
|
||||
for (auto & p : probs) {
|
||||
p /= sum;
|
||||
}
|
||||
|
||||
if (top_p < 1.0f) {
|
||||
double cumsum = 0.0f;
|
||||
for (int i = 0; i < top_k; i++) {
|
||||
cumsum += probs[i];
|
||||
if (cumsum >= top_p) {
|
||||
top_k = i + 1;
|
||||
probs.resize(top_k);
|
||||
logits_id.resize(top_k);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
cumsum = 1.0/cumsum;
|
||||
for (int i = 0; i < (int) probs.size(); i++) {
|
||||
probs[i] *= cumsum;
|
||||
}
|
||||
}
|
||||
|
||||
//printf("\n");
|
||||
//for (int i = 0; i < (int) probs.size(); i++) {
|
||||
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
|
||||
//}
|
||||
//exit(0);
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
return logits_id[idx].second;
|
||||
}
|
83
utils.h
Normal file
83
utils.h
Normal file
@ -0,0 +1,83 @@
|
||||
// Various helper functions and utilities
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <thread>
|
||||
|
||||
//
|
||||
// CLI argument parsing
|
||||
//
|
||||
|
||||
struct gpt_params {
|
||||
int32_t seed = -1; // RNG seed
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_predict = 200; // new tokens to predict
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 40;
|
||||
float top_p = 0.9f;
|
||||
float temp = 0.9f;
|
||||
|
||||
int32_t n_batch = 8; // batch size for prompt processing
|
||||
|
||||
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
|
||||
std::string prompt;
|
||||
};
|
||||
|
||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
||||
|
||||
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
|
||||
|
||||
std::string gpt_random_prompt(std::mt19937 & rng);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
struct gpt_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
|
||||
std::map<token, id> token_to_id;
|
||||
std::map<id, token> id_to_token;
|
||||
};
|
||||
|
||||
void replace(std::string & str, const std::string & needle, const std::string & replacement);
|
||||
|
||||
// poor-man's JSON parsing
|
||||
std::map<std::string, int32_t> json_parse(const std::string & fname);
|
||||
|
||||
// split text into tokens
|
||||
//
|
||||
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
||||
//
|
||||
// Regex (Python):
|
||||
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||
//
|
||||
// Regex (C++):
|
||||
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
||||
//
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
|
||||
|
||||
// load the tokens from encoder.json
|
||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
|
||||
|
||||
// sample next token given probabilities for each embedding
|
||||
//
|
||||
// - consider only the top K tokens
|
||||
// - from them, consider only the top tokens with cumulative probability > P
|
||||
//
|
||||
// TODO: not sure if this implementation is correct
|
||||
// TODO: temperature is not implemented
|
||||
//
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
int top_k,
|
||||
double top_p,
|
||||
double temp,
|
||||
std::mt19937 & rng);
|
Loading…
Reference in New Issue
Block a user