nodejs bindings (#602)

* chore: boilerplate, refactor in future

* chore: boilerplate

* feat: can compile succesfully

* document .gyp file

* add src, test and fix gyp

* progress on prompting and some helper methods

* add destructor and basic prompting work, prepare download function

* download function done

* download function edits and adding documentation

* fix bindings memory issue and add tests and specs

* add more documentation and readme

* add npmignore

* Update README.md

Signed-off-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com>

* Update package.json - redundant scripts

Signed-off-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com>

---------

Signed-off-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com>
This commit is contained in:
Jacob Nguyen 2023-05-22 14:55:22 -05:00 committed by GitHub
parent 22fdccbdc0
commit 394d9b4076
14 changed files with 757 additions and 1 deletions

View File

@ -0,0 +1,2 @@
node_modules/
build/

View File

@ -0,0 +1,3 @@
test/
spec/

View File

@ -1,4 +1,56 @@
### Javascript Bindings
The original [GPT4All typescript bindings](https://github.com/nomic-ai/gpt4all-ts) are now out of date.
The GPT4All community is looking for help in implemented a new set of bindings based on the C backend found [here](https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-backend).
- created by [jacoobes](https://github.com/jacoobes) and [nomic ai](https://home.nomic.ai) :D, for all to use.
- will maintain this repository when possible, new feature requests will be handled through nomic
### Build Instructions
- As of 05/21/2023, Tested on windows (MSVC) only. (somehow got it to work on MSVC 🤯)
- binding.gyp is compile config
### Requirements
- git
- [node.js >= 18.0.0](https://nodejs.org/en)
- [yarn](https://yarnpkg.com/)
- [node-gyp](https://github.com/nodejs/node-gyp)
- all of its requirements.
### Build
```sh
git clone https://github.com/nomic-ai/gpt4all.git
cd gpt4all-bindings/typescript
```
- The below shell commands assume the current working directory is `typescript`.
- To Build and Rebuild:
```sh
yarn
```
- llama.cpp git submodule for gpt4all can be possibly absent. If this is the case, make sure to run in llama.cpp parent directory
```sh
git submodule update --init --depth 1 --recursive`
```
### Test
```sh
yarn test
```
### Source Overview
#### src/
- Extra functions to help aid devex
- Typings for the native node addon
- the javascript interface
#### test/
- simple unit testings for some functions exported.
- more advanced ai testing is not handled
#### spec/
- Average look and feel of the api
- Should work assuming a model is installed locally in working directory
#### index.cc
- The bridge between nodejs and c. Where the bindings are.

View File

@ -0,0 +1,46 @@
{
"targets": [
{
"target_name": "gpt4allts", # gpt4all-ts will cause compile error
"cflags!": [ "-fno-exceptions" ],
"cflags_cc!": [ "-fno-exceptions" ],
"include_dirs": [
"<!@(node -p \"require('node-addon-api').include\")",
"../../gpt4all-backend/llama.cpp/", # need to include llama.cpp because the include paths for examples/common.h include llama.h relatively
"../../gpt4all-backend",
],
"sources": [ # is there a better way to do this
"../../gpt4all-backend/llama.cpp/examples/common.cpp",
"../../gpt4all-backend/llama.cpp/ggml.c",
"../../gpt4all-backend/llama.cpp/llama.cpp",
"../../gpt4all-backend/utils.cpp",
"../../gpt4all-backend/llmodel_c.cpp",
"../../gpt4all-backend/gptj.cpp",
"../../gpt4all-backend/llamamodel.cpp",
"../../gpt4all-backend/mpt.cpp",
"stdcapture.cc",
"index.cc",
],
"conditions": [
['OS=="mac"', {
'defines': [
'NAPI_CPP_EXCEPTIONS'
],
}],
['OS=="win"', {
'defines': [
'NAPI_CPP_EXCEPTIONS',
"__AVX2__" # allows SIMD: https://discord.com/channels/1076964370942267462/1092290790388150272/1107564673957630023
],
"msvs_settings": {
"VCCLCompilerTool": {
"AdditionalOptions": [
"/std:c++20",
"/EHsc"
],
},
},
}]
]
}]
}

View File

@ -0,0 +1,227 @@
#include <napi.h>
#include <iostream>
#include "llmodel_c.h"
#include "llmodel.h"
#include "gptj.h"
#include "llamamodel.h"
#include "mpt.h"
#include "stdcapture.h"
class NodeModelWrapper : public Napi::ObjectWrap<NodeModelWrapper> {
public:
static Napi::Object Init(Napi::Env env, Napi::Object exports) {
Napi::Function func = DefineClass(env, "LLModel", {
InstanceMethod("type", &NodeModelWrapper::getType),
InstanceMethod("name", &NodeModelWrapper::getName),
InstanceMethod("stateSize", &NodeModelWrapper::StateSize),
InstanceMethod("raw_prompt", &NodeModelWrapper::Prompt),
InstanceMethod("setThreadCount", &NodeModelWrapper::SetThreadCount),
InstanceMethod("threadCount", &NodeModelWrapper::ThreadCount),
});
Napi::FunctionReference* constructor = new Napi::FunctionReference();
*constructor = Napi::Persistent(func);
env.SetInstanceData(constructor);
exports.Set("LLModel", func);
return exports;
}
Napi::Value getType(const Napi::CallbackInfo& info)
{
return Napi::String::New(info.Env(), type);
}
NodeModelWrapper(const Napi::CallbackInfo& info) : Napi::ObjectWrap<NodeModelWrapper>(info)
{
auto env = info.Env();
std::string weights_path = info[0].As<Napi::String>().Utf8Value();
const char *c_weights_path = weights_path.c_str();
inference_ = create_model_set_type(c_weights_path);
auto success = llmodel_loadModel(inference_, c_weights_path);
if(!success) {
Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
return;
}
name = weights_path.substr(weights_path.find_last_of("/\\") + 1);
};
~NodeModelWrapper() {
// destroying the model manually causes exit code 3221226505, why?
// However, bindings seem to operate fine without destructing pointer
//llmodel_model_destroy(inference_);
}
Napi::Value IsModelLoaded(const Napi::CallbackInfo& info) {
return Napi::Boolean::New(info.Env(), llmodel_isModelLoaded(inference_));
}
Napi::Value StateSize(const Napi::CallbackInfo& info) {
// Implement the binding for the stateSize method
return Napi::Number::New(info.Env(), static_cast<int64_t>(llmodel_get_state_size(inference_)));
}
/**
* Generate a response using the model.
* @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt.
* @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response.
* @param recalculate_callback A callback function for handling recalculation requests.
* @param ctx A pointer to the llmodel_prompt_context structure.
*/
Napi::Value Prompt(const Napi::CallbackInfo& info) {
auto env = info.Env();
std::string question;
if(info[0].IsString()) {
question = info[0].As<Napi::String>().Utf8Value();
} else {
Napi::Error::New(env, "invalid string argument").ThrowAsJavaScriptException();
return env.Undefined();
}
//defaults copied from python bindings
llmodel_prompt_context promptContext = {
.logits = nullptr,
.tokens = nullptr,
.n_past = 0,
.n_ctx = 1024,
.n_predict = 128,
.top_k = 40,
.top_p = 0.9f,
.temp = 0.72f,
.n_batch = 8,
.repeat_penalty = 1.0f,
.repeat_last_n = 10,
.context_erase = 0.5
};
if(info[1].IsObject())
{
auto inputObject = info[1].As<Napi::Object>();
// Extract and assign the properties
if (inputObject.Has("logits") || inputObject.Has("tokens")) {
Napi::Error::New(env, "Invalid input: 'logits' or 'tokens' properties are not allowed").ThrowAsJavaScriptException();
return env.Undefined();
}
// Assign the remaining properties
if(inputObject.Has("n_past")) {
promptContext.n_past = inputObject.Get("n_past").As<Napi::Number>().Int32Value();
}
if(inputObject.Has("n_ctx")) {
promptContext.n_ctx = inputObject.Get("n_ctx").As<Napi::Number>().Int32Value();
}
if(inputObject.Has("n_predict")) {
promptContext.n_predict = inputObject.Get("n_predict").As<Napi::Number>().Int32Value();
}
if(inputObject.Has("top_k")) {
promptContext.top_k = inputObject.Get("top_k").As<Napi::Number>().Int32Value();
}
if(inputObject.Has("top_p")) {
promptContext.top_p = inputObject.Get("top_p").As<Napi::Number>().FloatValue();
}
if(inputObject.Has("temp")) {
promptContext.temp = inputObject.Get("temp").As<Napi::Number>().FloatValue();
}
if(inputObject.Has("n_batch")) {
promptContext.n_batch = inputObject.Get("n_batch").As<Napi::Number>().Int32Value();
}
if(inputObject.Has("repeat_penalty")) {
promptContext.repeat_penalty = inputObject.Get("repeat_penalty").As<Napi::Number>().FloatValue();
}
if(inputObject.Has("repeat_last_n")) {
promptContext.repeat_last_n = inputObject.Get("repeat_last_n").As<Napi::Number>().Int32Value();
}
if(inputObject.Has("context_erase")) {
promptContext.context_erase = inputObject.Get("context_erase").As<Napi::Number>().FloatValue();
}
}
// custom callbacks are weird with the gpt4all c bindings: I need to turn Napi::Functions into raw c function pointers,
// but it doesn't seem like its possible? (TODO, is it possible?)
// if(info[1].IsFunction()) {
// Napi::Callback cb = *info[1].As<Napi::Function>();
// }
// For now, simple capture of stdout
// possible TODO: put this on a libuv async thread. (AsyncWorker)
CoutRedirect cr;
llmodel_prompt(inference_, question.c_str(), &prompt_callback, &response_callback, &recalculate_callback, &promptContext);
return Napi::String::New(env, cr.getString());
}
void SetThreadCount(const Napi::CallbackInfo& info) {
if(info[0].IsNumber()) {
llmodel_setThreadCount(inference_, info[0].As<Napi::Number>().Int64Value());
} else {
Napi::Error::New(info.Env(), "Could not set thread count: argument 1 is NaN").ThrowAsJavaScriptException();
return;
}
}
Napi::Value getName(const Napi::CallbackInfo& info) {
return Napi::String::New(info.Env(), name);
}
Napi::Value ThreadCount(const Napi::CallbackInfo& info) {
return Napi::Number::New(info.Env(), llmodel_threadCount(inference_));
}
private:
llmodel_model inference_;
std::string type;
std::string name;
//wrapper cb to capture output into stdout.then, CoutRedirect captures this
// and writes it to a file
static bool response_callback(int32_t tid, const char* resp)
{
if(tid != -1) {
std::cout<<std::string(resp);
return true;
}
return false;
}
static bool prompt_callback(int32_t tid) { return true; }
static bool recalculate_callback(bool isrecalculating) { return isrecalculating; }
// Had to use this instead of the c library in order
// set the type of the model loaded.
// causes side effect: type is mutated;
llmodel_model create_model_set_type(const char* c_weights_path)
{
uint32_t magic;
llmodel_model model;
FILE *f = fopen(c_weights_path, "rb");
fread(&magic, sizeof(magic), 1, f);
if (magic == 0x67676d6c) {
model = llmodel_gptj_create();
type = "gptj";
}
else if (magic == 0x67676a74) {
model = llmodel_llama_create();
type = "llama";
}
else if (magic == 0x67676d6d) {
model = llmodel_mpt_create();
type = "mpt";
}
else {fprintf(stderr, "Invalid model file\n");}
fclose(f);
return model;
}
};
//Exports Bindings
Napi::Object Init(Napi::Env env, Napi::Object exports) {
return NodeModelWrapper::Init(env, exports);
}
NODE_API_MODULE(NODE_GYP_MODULE_NAME, Init)

View File

@ -0,0 +1,19 @@
{
"name": "gpt4all-ts",
"packageManager": "yarn@3.5.1",
"gypfile": true,
"scripts": {
"test": "node ./test/index.mjs"
},
"dependencies": {
"bindings": "^1.5.0",
"node-addon-api": "^6.1.0"
},
"devDependencies": {
"@types/node": "^20.1.5"
},
"engines": {
"node": ">= 18.x.x"
}
}

View File

@ -0,0 +1,32 @@
import { LLModel, prompt, createCompletion } from '../src/gpt4all.js'
const ll = new LLModel("./ggml-vicuna-7b-1.1-q4_2.bin");
try {
class Extended extends LLModel {
}
} catch(e) {
console.log("Extending from native class gone wrong " + e)
}
console.log("state size " + ll.stateSize())
console.log("thread count " + ll.threadCount());
ll.setThreadCount(5);
console.log("thread count " + ll.threadCount());
ll.setThreadCount(4);
console.log("thread count " + ll.threadCount());
console.log(createCompletion(
ll,
prompt`${"header"} ${"prompt"}`, {
verbose: true,
prompt: 'hello! Say something thought provoking.'
}
));

View File

@ -0,0 +1,162 @@
/// <reference types="node" />
declare module 'gpt4all-ts';
interface LLModelPromptContext {
// Size of the raw logits vector
logits_size: number;
// Size of the raw tokens vector
tokens_size: number;
// Number of tokens in past conversation
n_past: number;
// Number of tokens possible in context window
n_ctx: number;
// Number of tokens to predict
n_predict: number;
// Top k logits to sample from
top_k: number;
// Nucleus sampling probability threshold
top_p: number;
// Temperature to adjust model's output distribution
temp: number;
// Number of predictions to generate in parallel
n_batch: number;
// Penalty factor for repeated tokens
repeat_penalty: number;
// Last n tokens to penalize
repeat_last_n: number;
// Percent of context to erase if we exceed the context window
context_erase: number;
}
/**
* LLModel class representing a language model.
* This is a base class that provides common functionality for different types of language models.
*/
declare class LLModel {
//either 'gpt', mpt', or 'llama'
type() : ModelType;
//The name of the model
name(): ModelFile;
constructor(path: string);
/**
* Get the size of the internal state of the model.
* NOTE: This state data is specific to the type of model you have created.
* @return the size in bytes of the internal state of the model
*/
stateSize(): number;
/**
* Get the number of threads used for model inference.
* The default is the number of physical cores your computer has.
* @returns The number of threads used for model inference.
*/
threadCount() : number;
/**
* Set the number of threads used for model inference.
* @param newNumber The new number of threads.
*/
setThreadCount(newNumber: number): void;
/**
* Prompt the model with a given input and optional parameters.
* This is the raw output from std out.
* Use the prompt function exported for a value
* @param q The prompt input.
* @param params Optional parameters for the prompt context.
* @returns The result of the model prompt.
*/
raw_prompt(q: string, params?: Partial<LLModelPromptContext>) : unknown; //todo work on return type
}
interface DownloadController {
//Cancel the request to download from gpt4all website if this is called.
cancel: () => void;
//Convert the downloader into a promise, allowing people to await and manage its lifetime
promise: () => Promise<void>
}
export interface DownloadConfig {
/**
* location to download the model.
* Default is process.cwd(), or the current working directory
*/
location: string;
/**
* Debug mode -- check how long it took to download in seconds
*/
debug: boolean;
/**
* Default link = https://gpt4all.io/models`
* This property overrides the default.
*/
link?: string
}
/**
* Initiates the download of a model file of a specific model type.
* By default this downloads without waiting. use the controller returned to alter this behavior.
* @param {ModelFile[ModelType]} m - The model file to be downloaded.
* @param {Record<string, unknown>} op - options to pass into the downloader. Default is { location: (cwd), debug: false }.
* @returns {DownloadController} A DownloadController object that allows controlling the download process.
*/
declare function download(m: ModelFile[ModelType], op: { location: string, debug: boolean, link?:string }): DownloadController
type ModelType = 'gptj' | 'llama' | 'mpt';
/*
* A nice interface for intellisense of all possibly models.
*/
interface ModelFile {
'gptj': | "ggml-gpt4all-j-v1.3-groovy.bin"
| "ggml-gpt4all-j-v1.2-jazzy.bin"
| "ggml-gpt4all-j-v1.1-breezy.bin"
| "ggml-gpt4all-j.bin";
'llama':| "ggml-gpt4all-l13b-snoozy.bin"
| "ggml-vicuna-7b-1.1-q4_2.bin"
| "ggml-vicuna-13b-1.1-q4_2.bin"
| "ggml-wizardLM-7B.q4_2.bin"
| "ggml-stable-vicuna-13B.q4_2.bin"
| "ggml-nous-gpt4-vicuna-13b.bin"
'mpt': | "ggml-mpt-7b-base.bin"
| "ggml-mpt-7b-chat.bin"
| "ggml-mpt-7b-instruct.bin"
}
interface ExtendedOptions {
verbose?: boolean;
system?: string;
header?: string;
prompt: string;
promptEntries?: Record<string, unknown>
}
type PromptTemplate = (...args: string[]) => string;
declare function createCompletion(
model: LLModel,
pt: PromptTemplate,
options: LLModelPromptContext&ExtendedOptions
) : string
function prompt(
strings: TemplateStringsArray
): PromptTemplate
export { LLModel, LLModelPromptContext, ModelType, download, DownloadController, prompt, ExtendedOptions, createCompletion }

View File

@ -0,0 +1,112 @@
/// This file implements the gpt4all.d.ts file endings.
/// Written in commonjs to support both ESM and CJS projects.
const { LLModel } = require('bindings')('../build/Release/gpt4allts');
const { createWriteStream, existsSync } = require('fs');
const { join } = require('path');
const { performance } = require('node:perf_hooks');
// readChunks() reads from the provided reader and yields the results into an async iterable
// https://css-tricks.com/web-streams-everywhere-and-fetch-for-node-js/
function readChunks(reader) {
return {
async* [Symbol.asyncIterator]() {
let readResult = await reader.read();
while (!readResult.done) {
yield readResult.value;
readResult = await reader.read();
}
},
};
}
exports.LLModel = LLModel;
exports.download = function (
name,
options = { debug: false, location: process.cwd(), link: undefined }
) {
const abortController = new AbortController();
const signal = abortController.signal;
const pathToModel = join(options.location, name);
if(existsSync(pathToModel)) {
throw Error("Path to model already exists");
}
//wrapper function to get the readable stream from request
const fetcher = (name) => fetch(options.link ?? `https://gpt4all.io/models/${name}`, {
signal,
})
.then(res => {
if(!res.ok) {
throw Error("Could not find "+ name + " from " + `https://gpt4all.io/models/` )
}
return res.body.getReader()
})
//a promise that executes and writes to a stream. Resolves when done writing.
const res = new Promise((resolve, reject) => {
fetcher(name)
//Resolves an array of a reader and writestream.
.then(reader => [reader, createWriteStream(pathToModel)])
.then(
async ([readable, wstream]) => {
console.log('(CLI might hang) Downloading @ ', pathToModel);
let perf;
if(options.debug) {
perf = performance.now();
}
for await (const chunk of readChunks(readable)) {
wstream.write(chunk);
}
if(options.debug) {
console.log("Time taken: ", (performance.now()-perf).toFixed(2), " ms");
}
resolve();
}
).catch(reject);
});
return {
cancel : () => abortController.abort(),
promise: () => res
}
}
//https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Template_literals#tagged_templates
exports.prompt = function prompt(strings, ...keys) {
return (...values) => {
const dict = values[values.length - 1] || {};
const result = [strings[0]];
keys.forEach((key, i) => {
const value = Number.isInteger(key) ? values[key] : dict[key];
result.push(value, strings[i + 1]);
});
return result.join("");
};
}
exports.createCompletion = function (llmodel, promptMaker, options) {
//creating the keys to insert into promptMaker.
const entries = {
system: options.system ?? '',
header: options.header ?? "### Instruction: The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n### Prompt: ",
prompt: options.prompt,
...(options.promptEntries ?? {})
};
const fullPrompt = promptMaker(entries)+'\n### Response:';
if(options.verbose) {
console.log("sending prompt: " + `"${fullPrompt}"`)
}
return llmodel.raw_prompt(fullPrompt, options);
}

View File

@ -0,0 +1,14 @@
#include "stdcapture.h"
CoutRedirect::CoutRedirect() {
old = std::cout.rdbuf(buffer.rdbuf()); // redirect cout to buffer stream
}
std::string CoutRedirect::getString() {
return buffer.str(); // get string
}
CoutRedirect::~CoutRedirect() {
std::cout.rdbuf(old); // reverse redirect
}

View File

@ -0,0 +1,21 @@
//https://stackoverflow.com/questions/5419356/redirect-stdout-stderr-to-a-string
#ifndef COUTREDIRECT_H
#define COUTREDIRECT_H
#include <iostream>
#include <streambuf>
#include <string>
#include <sstream>
class CoutRedirect {
public:
CoutRedirect();
std::string getString();
~CoutRedirect();
private:
std::stringstream buffer;
std::streambuf* old;
};
#endif // COUTREDIRECT_H

View File

@ -0,0 +1,41 @@
import * as assert from 'node:assert'
import { prompt, download } from '../src/gpt4all.js'
{
const somePrompt = prompt`${"header"} Hello joe, my name is Ron. ${"prompt"}`;
assert.equal(
somePrompt({ header: 'oompa', prompt: 'holy moly' }),
'oompa Hello joe, my name is Ron. holy moly'
);
}
{
const indexedPrompt = prompt`${0}, ${1} ${0}`;
assert.equal(
indexedPrompt('hello', 'world'),
'hello, world hello'
);
assert.notEqual(
indexedPrompt(['hello', 'world']),
'hello, world hello'
);
}
{
assert.equal(
(prompt`${"header"} ${"prompt"}`)({ header: 'hello', prompt: 'poo' }), 'hello poo',
"Template prompt not equal"
);
}
assert.rejects(async () => download('poo.bin').promise());
console.log('OK')

View File

@ -0,0 +1,25 @@
# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY.
# yarn lockfile v1
"@types/node@^20.1.5":
version "20.1.5"
resolved "https://registry.yarnpkg.com/@types/node/-/node-20.1.5.tgz#e94b604c67fc408f215fcbf3bd84d4743bf7f710"
integrity sha512-IvGD1CD/nego63ySR7vrAKEX3AJTcmrAN2kn+/sDNLi1Ff5kBzDeEdqWDplK+0HAEoLYej137Sk0cUU8OLOlMg==
bindings@^1.5.0:
version "1.5.0"
resolved "https://registry.yarnpkg.com/bindings/-/bindings-1.5.0.tgz#10353c9e945334bc0511a6d90b38fbc7c9c504df"
integrity sha512-p2q/t/mhvuOj/UeLlV6566GD/guowlr0hHxClI0W9m7MWYkL1F0hLo+0Aexs9HSPCtR1SXQ0TD3MMKrXZajbiQ==
dependencies:
file-uri-to-path "1.0.0"
file-uri-to-path@1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/file-uri-to-path/-/file-uri-to-path-1.0.0.tgz#553a7b8446ff6f684359c445f1e37a05dacc33dd"
integrity sha512-0Zt+s3L7Vf1biwWZ29aARiVYLx7iMGnEUl9x33fbB/j3jR81u/O2LbqK+Bm1CDSNDKVtJ/YjwY7TUd5SkeLQLw==
node-addon-api@^6.1.0:
version "6.1.0"
resolved "https://registry.yarnpkg.com/node-addon-api/-/node-addon-api-6.1.0.tgz#ac8470034e58e67d0c6f1204a18ae6995d9c0d76"
integrity sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA==