gpt4all/gpt4all-bindings/golang/example/main.go

83 lines
1.7 KiB
Go
Raw Normal View History

package main
import (
"bufio"
"flag"
"fmt"
"io"
"os"
"runtime"
"strings"
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
)
// threads and tokens hold the generation settings used by main: the number
// of worker threads handed to the model and the number of tokens to predict
// per prompt. The literal values below are only initial placeholders; main
// binds both variables to the -t and -n command-line flags, whose defaults
// overwrite them before they are read.
var threads, tokens = 4, 128
// main parses the command-line flags, loads the model at -m, and then runs
// an interactive loop. Each loop iteration reads a multi-line prompt from
// stdin, streams the generated tokens to stdout, and prints a blank-line
// separator.
//
// Flags:
//
//	-m  path to the model file (default ./models/7B/ggml-model-q4_0.bin)
//	-t  number of threads (default runtime.NumCPU())
//	-n  number of tokens to predict per prompt (default 512)
//
// Failures are reported on stderr and the program exits with status 1.
func main() {
	var model string
	flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
	flags.StringVar(&model, "m", "./models/7B/ggml-model-q4_0.bin", "path to q4_0.bin model file to load")
	flags.IntVar(&threads, "t", runtime.NumCPU(), "number of threads to use during computation")
	flags.IntVar(&tokens, "n", 512, "number of tokens to predict")
	// With flag.ExitOnError, Parse exits by itself on bad flags. This check
	// only matters if the error-handling mode is ever changed.
	if err := flags.Parse(os.Args[1:]); err != nil {
		fmt.Fprintf(os.Stderr, "Parsing program arguments failed: %s\n", err)
		os.Exit(1)
	}

	l, err := gpt4all.New(model, gpt4all.SetModelType(gpt4all.GPTJType), gpt4all.SetThreads(threads))
	if err != nil {
		fmt.Fprintln(os.Stderr, "Loading the model failed:", err.Error())
		os.Exit(1)
	}
	fmt.Printf("Model loaded successfully.\n")

	// Print every token as the callback receives it. NOTE(review): returning
	// true presumably tells the binding to keep generating; confirm against
	// the gpt4all Go binding.
	l.SetTokenCallback(func(token string) bool {
		fmt.Print(token)
		return true
	})

	reader := bufio.NewReader(os.Stdin)
	for {
		text := readMultiLineInput(reader)
		// A blank line entered straight away yields no prompt. Ask again
		// instead of sending an empty prompt to the model.
		if text == "" {
			continue
		}
		// The returned text is discarded; the token callback above has
		// (presumably) already printed it.
		if _, err := l.Predict(text, gpt4all.SetTokens(tokens), gpt4all.SetTopK(90), gpt4all.SetTopP(0.86)); err != nil {
			fmt.Fprintln(os.Stderr, "Prediction failed:", err)
			os.Exit(1)
		}
		fmt.Printf("\n\n")
	}
}
// readMultiLineInput prints a ">>> " prompt and reads lines from reader until
// an empty (whitespace-only) line is entered. It returns the collected lines
// joined together with their original line endings. The terminating blank
// line is not included.
//
// On end of input, any text already read is returned, including a final line
// without a trailing newline. This keeps piped prompts from being dropped.
// If end of input arrives before anything was read, the program exits with
// status 0. Any other read error is reported on stderr and the program exits
// with status 1.
func readMultiLineInput(reader *bufio.Reader) string {
	var lines []string
	fmt.Print(">>> ")
	for {
		line, err := reader.ReadString('\n')
		if err != nil && err != io.EOF {
			fmt.Fprintf(os.Stderr, "Reading the prompt failed: %s\n", err)
			os.Exit(1)
		}
		// ReadString returns io.EOF together with whatever partial line it
		// read, so line may still hold text that belongs to the prompt.
		atEOF := err == io.EOF
		if len(strings.TrimSpace(line)) == 0 {
			if atEOF && len(lines) == 0 {
				// Input is exhausted and nothing is pending, so stop.
				os.Exit(0)
			}
			break
		}
		lines = append(lines, line)
		if atEOF {
			// Hand back what was collected. The next call sees a bare
			// io.EOF with no pending lines and exits above.
			break
		}
	}
	return strings.Join(lines, "")
}