mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-18 03:25:46 +00:00
4462d2d755
* chore: boilerplate, refactor in future * chore: boilerplate * feat: can compile successfully * document .gyp file * add src, test and fix gyp * progress on prompting and some helper methods * add destructor and basic prompting work, prepare download function * download function done * download function edits and adding documentation * fix bindings memory issue and add tests and specs * add more documentation and readme * add npmignore * Update README.md Signed-off-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com> * Update package.json - redundant scripts Signed-off-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com> --------- Signed-off-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com>
228 lines
8.1 KiB
C++
228 lines
8.1 KiB
C++
#include <napi.h>
|
|
#include <iostream>
|
|
#include "llmodel_c.h"
|
|
#include "llmodel.h"
|
|
#include "gptj.h"
|
|
#include "llamamodel.h"
|
|
#include "mpt.h"
|
|
#include "stdcapture.h"
|
|
|
|
/**
 * N-API wrapper exposing a gpt4all `llmodel_model` to JavaScript as the `LLModel` class.
 *
 * JS usage: `new LLModel(weightsPath)`. The backend (gptj / llama / mpt) is chosen from the
 * file's magic number and the weights are loaded synchronously in the constructor.
 */
class NodeModelWrapper : public Napi::ObjectWrap<NodeModelWrapper> {
public:
    /**
     * Defines the `LLModel` JS class and stores it on `exports`.
     * The persistent constructor reference is handed to the environment via SetInstanceData.
     */
    static Napi::Object Init(Napi::Env env, Napi::Object exports) {
        Napi::Function func = DefineClass(env, "LLModel", {
            InstanceMethod("type", &NodeModelWrapper::getType),
            InstanceMethod("name", &NodeModelWrapper::getName),
            InstanceMethod("stateSize", &NodeModelWrapper::StateSize),
            InstanceMethod("raw_prompt", &NodeModelWrapper::Prompt),
            InstanceMethod("setThreadCount", &NodeModelWrapper::SetThreadCount),
            InstanceMethod("threadCount", &NodeModelWrapper::ThreadCount),
            // IsModelLoaded was implemented but never registered, so JS could not call it.
            InstanceMethod("isModelLoaded", &NodeModelWrapper::IsModelLoaded),
        });

        Napi::FunctionReference* constructor = new Napi::FunctionReference();
        *constructor = Napi::Persistent(func);
        env.SetInstanceData(constructor);

        exports.Set("LLModel", func);
        return exports;
    }

    /** JS: `type()` — the detected backend name: "gptj", "llama" or "mpt". */
    Napi::Value getType(const Napi::CallbackInfo& info)
    {
        return Napi::String::New(info.Env(), type);
    }

    /**
     * JS: `new LLModel(weightsPath: string)`.
     *
     * Creates the backend matching the file's magic number and loads the weights.
     * Throws a JS error (and aborts construction) when the argument is not a string, the file
     * cannot be opened/read, its format is not recognized, or loading fails.
     */
    NodeModelWrapper(const Napi::CallbackInfo& info) : Napi::ObjectWrap<NodeModelWrapper>(info)
    {
        auto env = info.Env();

        if (!info[0].IsString()) {
            Napi::TypeError::New(env, "Expected a string path to the model weights").ThrowAsJavaScriptException();
            return;
        }
        std::string weights_path = info[0].As<Napi::String>().Utf8Value();
        const char *c_weights_path = weights_path.c_str();

        // Previously an unreadable file crashed in fread(NULL) and an unknown format returned an
        // uninitialized handle; both now surface as a JS error instead.
        std::string create_error;
        inference_ = create_model_set_type(c_weights_path, create_error);
        if (inference_ == nullptr) {
            Napi::Error::New(env, create_error).ThrowAsJavaScriptException();
            return;
        }

        auto success = llmodel_loadModel(inference_, c_weights_path);
        if (!success) {
            Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
            return;
        }

        // Display name is the file name without its directory ('/' or '\\' separators).
        // If there is no separator, npos + 1 wraps to 0 and the whole path is kept.
        name = weights_path.substr(weights_path.find_last_of("/\\") + 1);
    }

    ~NodeModelWrapper() {
        // NOTE: the native model is intentionally not released, so each LLModel instance leaks
        // its backend. Calling llmodel_model_destroy(inference_) here made the process exit with
        // code 3221226505 (0xC0000409); the root cause has not been identified.
        // TODO: find the cause and free the model.
        //llmodel_model_destroy(inference_);
    }

    /** JS: `isModelLoaded()` — forwards to llmodel_isModelLoaded. */
    Napi::Value IsModelLoaded(const Napi::CallbackInfo& info) {
        return Napi::Boolean::New(info.Env(), llmodel_isModelLoaded(inference_));
    }

    /** JS: `stateSize()` — forwards to llmodel_get_state_size, returned as a JS number. */
    Napi::Value StateSize(const Napi::CallbackInfo& info) {
        return Napi::Number::New(info.Env(), static_cast<int64_t>(llmodel_get_state_size(inference_)));
    }

    /**
     * JS: `raw_prompt(question: string, options?: object): string`.
     *
     * Runs generation synchronously on the calling (JS) thread and returns everything written
     * to std::cout while it ran, i.e. the response tokens emitted by response_callback.
     *
     * `options` may override these llmodel_prompt_context fields: n_past, n_ctx, n_predict,
     * top_k, n_batch, repeat_last_n (integers) and top_p, temp, repeat_penalty, context_erase
     * (floats). Properties that are absent or `undefined` keep the defaults below.
     * 'logits' and 'tokens' are rejected.
     *
     * Throws a JS error if `question` is not a string or an option is not a number.
     */
    Napi::Value Prompt(const Napi::CallbackInfo& info) {
        auto env = info.Env();

        std::string question;
        if (info[0].IsString()) {
            question = info[0].As<Napi::String>().Utf8Value();
        } else {
            Napi::Error::New(env, "invalid string argument").ThrowAsJavaScriptException();
            return env.Undefined();
        }

        //defaults copied from python bindings
        llmodel_prompt_context promptContext = {
            .logits = nullptr,
            .tokens = nullptr,
            .n_past = 0,
            .n_ctx = 1024,
            .n_predict = 128,
            .top_k = 40,
            .top_p = 0.9f,
            .temp = 0.72f,
            .n_batch = 8,
            .repeat_penalty = 1.0f,
            .repeat_last_n = 10,
            .context_erase = 0.5
        };

        if (info[1].IsObject())
        {
            auto inputObject = info[1].As<Napi::Object>();

            if (inputObject.Has("logits") || inputObject.Has("tokens")) {
                Napi::Error::New(env, "Invalid input: 'logits' or 'tokens' properties are not allowed").ThrowAsJavaScriptException();
                return env.Undefined();
            }

            // Stop at the first invalid option; a TypeError is then pending, so generation
            // must not run with partially-applied settings.
            const bool options_ok =
                read_number_option(env, inputObject, "n_past", promptContext.n_past) &&
                read_number_option(env, inputObject, "n_ctx", promptContext.n_ctx) &&
                read_number_option(env, inputObject, "n_predict", promptContext.n_predict) &&
                read_number_option(env, inputObject, "top_k", promptContext.top_k) &&
                read_number_option(env, inputObject, "top_p", promptContext.top_p) &&
                read_number_option(env, inputObject, "temp", promptContext.temp) &&
                read_number_option(env, inputObject, "n_batch", promptContext.n_batch) &&
                read_number_option(env, inputObject, "repeat_penalty", promptContext.repeat_penalty) &&
                read_number_option(env, inputObject, "repeat_last_n", promptContext.repeat_last_n) &&
                read_number_option(env, inputObject, "context_erase", promptContext.context_erase);
            if (!options_ok) {
                return env.Undefined();
            }
        }

        // TODO: support user-supplied JS callbacks. The C API takes plain function pointers
        // without a user-data argument, so a Napi::Function cannot be passed through directly.

        // For now, capture stdout: response_callback prints each token and CoutRedirect
        // collects it.
        // possible TODO: run this off the JS thread (e.g. Napi::AsyncWorker).
        CoutRedirect cr;
        llmodel_prompt(inference_, question.c_str(), &prompt_callback, &response_callback, &recalculate_callback, &promptContext);

        return Napi::String::New(env, cr.getString());
    }

    /** JS: `setThreadCount(n: number)` — throws if `n` is not a number. */
    void SetThreadCount(const Napi::CallbackInfo& info) {
        if (info[0].IsNumber()) {
            // Int32Value instead of Int64Value: thread counts fit in 32 bits, and this avoids an
            // implicit 64 -> 32 bit narrowing (parameter presumably int32_t in llmodel_c.h — confirm).
            llmodel_setThreadCount(inference_, info[0].As<Napi::Number>().Int32Value());
        } else {
            Napi::Error::New(info.Env(), "Could not set thread count: argument 1 is NaN").ThrowAsJavaScriptException();
        }
    }

    /** JS: `name()` — the file name of the loaded weights. */
    Napi::Value getName(const Napi::CallbackInfo& info) {
        return Napi::String::New(info.Env(), name);
    }

    /** JS: `threadCount()` — forwards to llmodel_threadCount. */
    Napi::Value ThreadCount(const Napi::CallbackInfo& info) {
        return Napi::Number::New(info.Env(), llmodel_threadCount(inference_));
    }

private:
    llmodel_model inference_ = nullptr; // backend handle; never freed (see destructor)
    std::string type;                   // "gptj", "llama" or "mpt"; set by create_model_set_type
    std::string name;                   // file name of the loaded weights

    // File magic numbers compared against the first 4 bytes of the weights file, read as a
    // native-endian uint32 (the hex values spell the ASCII tags noted on the right).
    static constexpr uint32_t kMagicGptj  = 0x67676d6c; // "ggml"
    static constexpr uint32_t kMagicLlama = 0x67676a74; // "ggjt"
    static constexpr uint32_t kMagicMpt   = 0x67676d6d; // "ggmm"

    // Reads the optional integer option `key` into `out`.
    // - A missing or `undefined` property leaves `out` unchanged.
    // - A non-number throws a JS TypeError and returns false.
    static bool read_number_option(Napi::Env env, const Napi::Object& options, const char* key, int32_t& out)
    {
        if (!options.Has(key)) {
            return true;
        }
        Napi::Value value = options.Get(key);
        if (value.IsUndefined()) {
            return true;
        }
        if (!value.IsNumber()) {
            Napi::TypeError::New(env, std::string("Invalid input: '") + key + "' must be a number").ThrowAsJavaScriptException();
            return false;
        }
        out = value.As<Napi::Number>().Int32Value();
        return true;
    }

    // Float overload of read_number_option, with the same rules.
    static bool read_number_option(Napi::Env env, const Napi::Object& options, const char* key, float& out)
    {
        if (!options.Has(key)) {
            return true;
        }
        Napi::Value value = options.Get(key);
        if (value.IsUndefined()) {
            return true;
        }
        if (!value.IsNumber()) {
            Napi::TypeError::New(env, std::string("Invalid input: '") + key + "' must be a number").ThrowAsJavaScriptException();
            return false;
        }
        out = value.As<Napi::Number>().FloatValue();
        return true;
    }

    // Response callback for llmodel_prompt: prints each token to stdout, which CoutRedirect
    // collects for Prompt to return. Returns false (presumably "stop") when tid == -1.
    static bool response_callback(int32_t tid, const char* resp)
    {
        if (tid != -1) {
            std::cout << resp;
            return true;
        }
        return false;
    }

    // Prompt-processing callback: always continue.
    static bool prompt_callback(int32_t /*tid*/) { return true; }

    // Recalculation callback: echoes the flag it is given.
    static bool recalculate_callback(bool isrecalculating) { return isrecalculating; }

    // Used instead of a generic C-library factory so that `type` can be recorded.
    // Side effect: sets `type` on success.
    // Returns a new backend handle, or nullptr with a message in `error` if the file cannot be
    // opened, its 4-byte header cannot be read, or the magic number is unrecognized.
    llmodel_model create_model_set_type(const char* c_weights_path, std::string& error)
    {
        FILE *f = fopen(c_weights_path, "rb");
        if (f == nullptr) {
            error = std::string("Could not open model file: ") + c_weights_path;
            return nullptr;
        }

        uint32_t magic = 0;
        const size_t items_read = fread(&magic, sizeof(magic), 1, f);
        fclose(f);
        if (items_read != 1) {
            error = "Invalid model file: could not read header";
            return nullptr;
        }

        llmodel_model model = nullptr;
        if (magic == kMagicGptj) {
            model = llmodel_gptj_create();
            type = "gptj";
        } else if (magic == kMagicLlama) {
            model = llmodel_llama_create();
            type = "llama";
        } else if (magic == kMagicMpt) {
            model = llmodel_mpt_create();
            type = "mpt";
        } else {
            error = "Invalid model file";
            return nullptr;
        }

        if (model == nullptr) {
            error = "Failed to create " + type + " model";
        }
        return model;
    }
};
|
|
|
|
//Exports Bindings
|
|
Napi::Object Init(Napi::Env env, Napi::Object exports) {
|
|
return NodeModelWrapper::Init(env, exports);
|
|
}
|
|
|
|
NODE_API_MODULE(NODE_GYP_MODULE_NAME, Init)
|