|
|
|
@ -5,12 +5,10 @@ package gpt4all
|
|
|
|
|
// #cgo darwin LDFLAGS: -framework Accelerate
|
|
|
|
|
// #cgo darwin CXXFLAGS: -std=c++17
|
|
|
|
|
// #cgo LDFLAGS: -lgpt4all -lm -lstdc++
|
|
|
|
|
// void* load_mpt_model(const char *fname, int n_threads);
|
|
|
|
|
// void* load_llama_model(const char *fname, int n_threads);
|
|
|
|
|
// void* load_gptj_model(const char *fname, int n_threads);
|
|
|
|
|
// void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
|
|
|
|
|
// void* load_gpt4all_model(const char *fname, int n_threads);
|
|
|
|
|
// void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
|
|
|
|
|
// float top_p, float temp, int n_batch,float ctx_erase);
|
|
|
|
|
// void gptj_free_model(void *state_ptr);
|
|
|
|
|
// void gpt4all_free_model(void *state_ptr);
|
|
|
|
|
// extern unsigned char getTokenCallback(void *, char *);
|
|
|
|
|
import "C"
|
|
|
|
|
import (
|
|
|
|
@ -28,16 +26,8 @@ type Model struct {
|
|
|
|
|
|
|
|
|
|
func New(model string, opts ...ModelOption) (*Model, error) {
|
|
|
|
|
ops := NewModelOptions(opts...)
|
|
|
|
|
var state unsafe.Pointer
|
|
|
|
|
|
|
|
|
|
switch ops.ModelType {
|
|
|
|
|
case LLaMAType:
|
|
|
|
|
state = C.load_llama_model(C.CString(model), C.int(ops.Threads))
|
|
|
|
|
case GPTJType:
|
|
|
|
|
state = C.load_gptj_model(C.CString(model), C.int(ops.Threads))
|
|
|
|
|
case MPTType:
|
|
|
|
|
state = C.load_mpt_model(C.CString(model), C.int(ops.Threads))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
state := C.load_gpt4all_model(C.CString(model), C.int(ops.Threads))
|
|
|
|
|
|
|
|
|
|
if state == nil {
|
|
|
|
|
return nil, fmt.Errorf("failed loading model")
|
|
|
|
@ -62,7 +52,7 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) {
|
|
|
|
|
}
|
|
|
|
|
out := make([]byte, po.Tokens)
|
|
|
|
|
|
|
|
|
|
C.gptj_model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize),
|
|
|
|
|
C.gpt4all_model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize),
|
|
|
|
|
C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase))
|
|
|
|
|
|
|
|
|
|
res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
|
|
|
|
@ -75,7 +65,7 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (l *Model) Free() {
|
|
|
|
|
C.gptj_free_model(l.state)
|
|
|
|
|
C.gpt4all_free_model(l.state)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (l *Model) SetTokenCallback(callback func(token string) bool) {
|
|
|
|
|