@ -6,7 +6,7 @@ package gpt4all
// #cgo darwin CXXFLAGS: -std=c++17
// #cgo LDFLAGS: -lgpt4all -lm -lstdc++ -ldl
// void* load_model(const char *fname, int n_threads);
// void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
// void model_prompt( const char *prompt, const char *prompt_template, int special, const char *fake_reply, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
// float top_p, float min_p, float temp, int n_batch,float ctx_erase);
// void free_model(void *state_ptr);
// extern unsigned char getTokenCallback(void *, char *);
@ -47,7 +47,7 @@ func New(model string, opts ...ModelOption) (*Model, error) {
return gpt , nil
}
func ( l * Model ) Predict ( text string , opts ... PredictOption ) ( string , error ) {
func ( l * Model ) Predict ( text , template , fakeReplyText string , opts ... PredictOption ) ( string , error ) {
po := NewPredictOptions ( opts ... )
@ -55,10 +55,14 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) {
if po . Tokens == 0 {
po . Tokens = 99999999
}
templateInput := C . CString ( template )
fakeReplyInput := C . CString ( fakeReplyText )
out := make ( [ ] byte , po . Tokens )
C . model_prompt ( input , l . state , ( * C . char ) ( unsafe . Pointer ( & out [ 0 ] ) ) , C . int ( po . RepeatLastN ) , C . float ( po . RepeatPenalty ) , C . int ( po . ContextSize ) ,
C . int ( po . Tokens ) , C . int ( po . TopK ) , C . float ( po . TopP ) , C . float ( po . MinP ) , C . float ( po . Temperature ) , C . int ( po . Batch ) , C . float ( po . ContextErase ) )
C . model_prompt ( input , templateInput , C . int ( po . Special ) , fakeReplyInput , l . state , ( * C . char ) ( unsafe . Pointer ( & out [ 0 ] ) ) ,
C . int ( po . RepeatLastN ) , C . float ( po . RepeatPenalty ) , C . int ( po . ContextSize ) , C . int ( po . Tokens ) ,
C . int ( po . TopK ) , C . float ( po . TopP ) , C . float ( po . MinP ) , C . float ( po . Temperature ) , C . int ( po . Batch ) ,
C . float ( po . ContextErase ) )
res := C . GoString ( ( * C . char ) ( unsafe . Pointer ( & out [ 0 ] ) ) )
res = strings . TrimPrefix ( res , " " )