Adapt code

pull/791/head
mudler 1 year ago committed by AT
parent fca2578a81
commit 79cef86bec

@ -2,14 +2,11 @@
#include "../../gpt4all-backend/llmodel.h" #include "../../gpt4all-backend/llmodel.h"
#include "../../gpt4all-backend/llama.cpp/llama.h" #include "../../gpt4all-backend/llama.cpp/llama.h"
#include "../../gpt4all-backend/llmodel_c.cpp" #include "../../gpt4all-backend/llmodel_c.cpp"
#include "../../gpt4all-backend/mpt.h"
#include "../../gpt4all-backend/mpt.cpp"
#include "../../gpt4all-backend/llamamodel.h"
#include "../../gpt4all-backend/gptj.h"
#include "binding.h" #include "binding.h"
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include <cstddef>
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
#include <fstream> #include <fstream>
@ -19,46 +16,24 @@
#include <iostream> #include <iostream>
#include <unistd.h> #include <unistd.h>
void* load_mpt_model(const char *fname, int n_threads) { void* load_gpt4all_model(const char *fname, int n_threads) {
// load the model // load the model
auto gptj = llmodel_mpt_create(); auto gptj4all = llmodel_model_create(fname);
if (gptj4all == NULL ){
llmodel_setThreadCount(gptj, n_threads);
if (!llmodel_loadModel(gptj, fname)) {
return nullptr; return nullptr;
} }
llmodel_setThreadCount(gptj4all, n_threads);
return gptj; if (!llmodel_loadModel(gptj4all, fname)) {
}
void* load_llama_model(const char *fname, int n_threads) {
// load the model
auto gptj = llmodel_llama_create();
llmodel_setThreadCount(gptj, n_threads);
if (!llmodel_loadModel(gptj, fname)) {
return nullptr;
}
return gptj;
}
void* load_gptj_model(const char *fname, int n_threads) {
// load the model
auto gptj = llmodel_gptj_create();
llmodel_setThreadCount(gptj, n_threads);
if (!llmodel_loadModel(gptj, fname)) {
return nullptr; return nullptr;
} }
return gptj; return gptj4all;
} }
std::string res = ""; std::string res = "";
void * mm; void * mm;
void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
float top_p, float temp, int n_batch,float ctx_erase) float top_p, float temp, int n_batch,float ctx_erase)
{ {
llmodel_model* model = (llmodel_model*) m; llmodel_model* model = (llmodel_model*) m;
@ -120,8 +95,8 @@ void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_la
free(prompt_context); free(prompt_context);
} }
void gptj_free_model(void *state_ptr) { void gpt4all_free_model(void *state_ptr) {
llmodel_model* ctx = (llmodel_model*) state_ptr; llmodel_model* ctx = (llmodel_model*) state_ptr;
llmodel_llama_destroy(ctx); llmodel_model_destroy(*ctx);
} }

@ -4,16 +4,12 @@ extern "C" {
#include <stdbool.h> #include <stdbool.h>
void* load_mpt_model(const char *fname, int n_threads); void* load_gpt4all_model(const char *fname, int n_threads);
void* load_llama_model(const char *fname, int n_threads); void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
void* load_gptj_model(const char *fname, int n_threads);
void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
float top_p, float temp, int n_batch,float ctx_erase); float top_p, float temp, int n_batch,float ctx_erase);
void gptj_free_model(void *state_ptr); void gpt4all_free_model(void *state_ptr);
extern unsigned char getTokenCallback(void *, char *); extern unsigned char getTokenCallback(void *, char *);

@ -30,7 +30,7 @@ func main() {
fmt.Printf("Parsing program arguments failed: %s", err) fmt.Printf("Parsing program arguments failed: %s", err)
os.Exit(1) os.Exit(1)
} }
l, err := gpt4all.New(model, gpt4all.SetModelType(gpt4all.GPTJType), gpt4all.SetThreads(threads)) l, err := gpt4all.New(model, gpt4all.SetThreads(threads))
if err != nil { if err != nil {
fmt.Println("Loading the model failed:", err.Error()) fmt.Println("Loading the model failed:", err.Error())
os.Exit(1) os.Exit(1)

@ -5,12 +5,10 @@ package gpt4all
// #cgo darwin LDFLAGS: -framework Accelerate // #cgo darwin LDFLAGS: -framework Accelerate
// #cgo darwin CXXFLAGS: -std=c++17 // #cgo darwin CXXFLAGS: -std=c++17
// #cgo LDFLAGS: -lgpt4all -lm -lstdc++ // #cgo LDFLAGS: -lgpt4all -lm -lstdc++
// void* load_mpt_model(const char *fname, int n_threads); // void* load_gpt4all_model(const char *fname, int n_threads);
// void* load_llama_model(const char *fname, int n_threads); // void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
// void* load_gptj_model(const char *fname, int n_threads);
// void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
// float top_p, float temp, int n_batch,float ctx_erase); // float top_p, float temp, int n_batch,float ctx_erase);
// void gptj_free_model(void *state_ptr); // void gpt4all_free_model(void *state_ptr);
// extern unsigned char getTokenCallback(void *, char *); // extern unsigned char getTokenCallback(void *, char *);
import "C" import "C"
import ( import (
@ -28,16 +26,8 @@ type Model struct {
func New(model string, opts ...ModelOption) (*Model, error) { func New(model string, opts ...ModelOption) (*Model, error) {
ops := NewModelOptions(opts...) ops := NewModelOptions(opts...)
var state unsafe.Pointer
state := C.load_gpt4all_model(C.CString(model), C.int(ops.Threads))
switch ops.ModelType {
case LLaMAType:
state = C.load_llama_model(C.CString(model), C.int(ops.Threads))
case GPTJType:
state = C.load_gptj_model(C.CString(model), C.int(ops.Threads))
case MPTType:
state = C.load_mpt_model(C.CString(model), C.int(ops.Threads))
}
if state == nil { if state == nil {
return nil, fmt.Errorf("failed loading model") return nil, fmt.Errorf("failed loading model")
@ -62,7 +52,7 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) {
} }
out := make([]byte, po.Tokens) out := make([]byte, po.Tokens)
C.gptj_model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize), C.gpt4all_model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize),
C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase)) C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase))
res := C.GoString((*C.char)(unsafe.Pointer(&out[0]))) res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
@ -75,7 +65,7 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) {
} }
func (l *Model) Free() { func (l *Model) Free() {
C.gptj_free_model(l.state) C.gpt4all_free_model(l.state)
} }
func (l *Model) SetTokenCallback(callback func(token string) bool) { func (l *Model) SetTokenCallback(callback func(token string) bool) {

@ -13,15 +13,5 @@ var _ = Describe("LLama binding", func() {
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(model).To(BeNil()) Expect(model).To(BeNil())
}) })
It("fails with no model", func() {
model, err := New("not-existing", SetModelType(MPTType))
Expect(err).To(HaveOccurred())
Expect(model).To(BeNil())
})
It("fails with no model", func() {
model, err := New("not-existing", SetModelType(LLaMAType))
Expect(err).To(HaveOccurred())
Expect(model).To(BeNil())
})
}) })
}) })

@ -21,23 +21,13 @@ var DefaultOptions PredictOptions = PredictOptions{
var DefaultModelOptions ModelOptions = ModelOptions{ var DefaultModelOptions ModelOptions = ModelOptions{
Threads: 4, Threads: 4,
ModelType: GPTJType,
} }
type ModelOptions struct { type ModelOptions struct {
Threads int Threads int
ModelType ModelType
} }
type ModelOption func(p *ModelOptions) type ModelOption func(p *ModelOptions)
type ModelType int
const (
LLaMAType ModelType = 0
GPTJType ModelType = iota
MPTType ModelType = iota
)
// SetTokens sets the number of tokens to generate. // SetTokens sets the number of tokens to generate.
func SetTokens(tokens int) PredictOption { func SetTokens(tokens int) PredictOption {
return func(p *PredictOptions) { return func(p *PredictOptions) {
@ -110,13 +100,6 @@ func SetThreads(c int) ModelOption {
} }
} }
// SetModelType sets the model type.
func SetModelType(c ModelType) ModelOption {
return func(p *ModelOptions) {
p.ModelType = c
}
}
// Create a new ModelOptions object with the given options. // Create a new ModelOptions object with the given options.
func NewModelOptions(opts ...ModelOption) ModelOptions { func NewModelOptions(opts ...ModelOption) ModelOptions {
p := DefaultModelOptions p := DefaultModelOptions

Loading…
Cancel
Save