mirror of https://github.com/nomic-ai/gpt4all
Initial 1.0.0 Java-Bindings PR/release (#805)
* Initial 1.0.0 Java-Bindings PR/release * Initial 1.1.0 Java-Bindings PR/release * Add debug ability * 1.1.2 release --------- Co-authored-by: felix <felix@zaslavskiy.net>rguo123/pypi-ver-bump
parent
5cfb1bda89
commit
44bf91855d
@ -0,0 +1,2 @@
|
||||
# Make sure native directory never gets committed to git for the project.
|
||||
/src/main/resources/native
|
@ -0,0 +1,105 @@
|
||||
# Java bindings
|
||||
|
||||
Java bindings let you load a gpt4all library into your Java application and execute text
|
||||
generation using an intuitive and easy-to-use API. No GPU is required because the gpt4all model executes on the CPU.
|
||||
The gpt4all models are quantized to easily fit into system RAM and use about 4 to 7GB of system RAM.
|
||||
|
||||
## Getting Started
|
||||
You can add the Java bindings to your Java project by adding a dependency to your project:
|
||||
|
||||
**Maven**
|
||||
```
|
||||
<dependency>
|
||||
<groupId>com.hexadevlabs</groupId>
|
||||
<artifactId>gpt4all-java-binding</artifactId>
|
||||
<version>1.1.2</version>
|
||||
</dependency>
|
||||
```
|
||||
**Gradle**
|
||||
```
|
||||
implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.2'
|
||||
```
|
||||
|
||||
To add the library dependency for another build system see [Maven Central Java bindings](https://central.sonatype.com/artifact/com.hexadevlabs/gpt4all-java-binding/).
|
||||
|
||||
To download a model binary weights file, use a URL such as https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin.
|
||||
|
||||
For information about other models available see [Model file list](https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-chat#manual-download-of-models).
|
||||
|
||||
### Sample code
|
||||
```java
|
||||
public class Example {
|
||||
public static void main(String[] args) {
|
||||
|
||||
String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
|
||||
|
||||
// Replace the hardcoded path with the actual path where your model file resides
|
||||
String modelFilePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";
|
||||
|
||||
try (LLModel model = new LLModel(Path.of(modelFilePath))) {
|
||||
|
||||
// May generate up to 4096 tokens but generally stops early
|
||||
LLModel.GenerationConfig config = LLModel.config()
|
||||
.withNPredict(4096).build();
|
||||
|
||||
// Will also stream to Standard out
|
||||
String fullGeneration = model.generate(prompt, config, true);
|
||||
|
||||
} catch (Exception e) {
|
||||
// Exception generally may happen if model file fails to load
|
||||
// for a number of reasons such as file not found.
|
||||
// It is possible that Java may not be able to dynamically load the native shared library or
|
||||
// the llmodel shared library may not be able to dynamically load the backend
|
||||
// implementation for the model file you provided.
|
||||
//
|
||||
// Once the LLModel class is successfully loaded into memory the text generation calls
|
||||
// generally should not throw exceptions.
|
||||
e.printStackTrace(); // Printing here but in production system you may want to take some action.
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
```
|
||||
|
||||
For a maven based sample project that uses this library see [Sample project](https://github.com/felix-zaslavskiy/gpt4all-java-bindings-sample)
|
||||
|
||||
### Additional considerations
|
||||
#### Logger warnings
|
||||
The Java bindings library may produce a warning:
|
||||
```
|
||||
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
|
||||
SLF4J: Defaulting to no-operation (NOP) logger implementation
|
||||
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
|
||||
```
|
||||
This warning appears if you don't have a SLF4J binding included in your project. The Java bindings only use logging for informational
|
||||
purposes, so logger is not essential to correctly use the library. You can ignore this warning if you don't have SLF4J bindings
|
||||
in your project.
|
||||
|
||||
To add a simple logger using maven dependency you may use:
|
||||
```
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-simple</artifactId>
|
||||
<version>1.7.36</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
#### Loading your native libraries
|
||||
1. Java bindings package jar comes bundled with native library files for Windows, macOS and Linux. These library files are
|
||||
copied to a temporary directory and loaded at runtime. For advanced users who may want to package shared libraries into Docker containers
|
||||
or want to use a custom build of the shared libraries and ignore the ones bundled with the Java package, they have the option
|
||||
to load libraries from your local directory by setting a static property to the location of library files.
|
||||
There are no guarantees of compatibility if used in such a way so be careful if you really want to do it.
|
||||
|
||||
For example:
|
||||
```java
|
||||
class Example {
|
||||
public static void main(String[] args) {
|
||||
// gpt4all native shared libraries location
|
||||
LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
|
||||
// ... use the library normally
|
||||
}
|
||||
}
|
||||
```
|
||||
2. To reduce size, not every AVX-only shared library is bundled with the jar right now. Only libgptj-avx is included.
|
||||
If you are running into issues please let us know using the gpt4all project issue tracker https://github.com/nomic-ai/gpt4all/issues.
|
@ -0,0 +1,202 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>com.hexadevlabs</groupId>
|
||||
<artifactId>gpt4all-java-binding</artifactId>
|
||||
<version>1.1.2</version>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>11</maven.compiler.source>
|
||||
<maven.compiler.target>11</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
|
||||
<name>${project.groupId}:${project.artifactId}</name>
|
||||
<description>Java bindings for GPT4ALL LLM</description>
|
||||
<url>https://github.com/nomic-ai/gpt4all</url>
|
||||
<licenses>
|
||||
<license>
|
||||
<name>The Apache License, Version 2.0</name>
|
||||
<url>https://github.com/nomic-ai/gpt4all/blob/main/LICENSE.txt</url>
|
||||
</license>
|
||||
</licenses>
|
||||
<developers>
|
||||
<developer>
|
||||
<name>Felix Zaslavskiy</name>
|
||||
<email>felixz@hexadevlabs.com</email>
|
||||
<organizationUrl>https://github.com/felix-zaslavskiy/</organizationUrl>
|
||||
</developer>
|
||||
</developers>
|
||||
<scm>
|
||||
<connection>scm:git:git://github.com/nomic-ai/gpt4all.git</connection>
|
||||
<developerConnection>scm:git:ssh://github.com/nomic-ai/gpt4all.git</developerConnection>
|
||||
<url>https://github.com/nomic-ai/gpt4all/tree/main</url>
|
||||
</scm>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.github.jnr</groupId>
|
||||
<artifactId>jnr-ffi</artifactId>
|
||||
<version>2.2.13</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
<version>1.7.36</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-api</artifactId>
|
||||
<version>5.9.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<distributionManagement>
|
||||
<snapshotRepository>
|
||||
<id>ossrh</id>
|
||||
<url>https://s01.oss.sonatype.org/content/repositories/snapshots</url>
|
||||
</snapshotRepository>
|
||||
<repository>
|
||||
<id>ossrh</id>
|
||||
<url>https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
|
||||
</repository>
|
||||
</distributionManagement>
|
||||
|
||||
<build>
|
||||
<resources>
|
||||
<resource>
|
||||
<directory>src/main/resources</directory>
|
||||
</resource>
|
||||
<resource>
|
||||
<directory>${project.build.directory}/generated-resources</directory>
|
||||
</resource>
|
||||
</resources>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<version>3.0.0</version>
|
||||
<configuration>
|
||||
<forkCount>0</forkCount>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-resources-plugin</artifactId>
|
||||
<version>3.3.1</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>copy-resources</id>
|
||||
<!-- Run early (validate phase) so the native libraries are staged before packaging -->
|
||||
<phase>validate</phase>
|
||||
<goals>
|
||||
<goal>copy-resources</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<outputDirectory>${project.build.directory}/generated-resources</outputDirectory>
|
||||
<resources>
|
||||
<resource>
|
||||
<!-- NOTE: developer-machine-specific path; native binaries must be staged here before building -->
<directory>C:\Users\felix\dev\gpt4all_java_bins\release_1_1_1_Jun8_2023</directory>
|
||||
</resource>
|
||||
</resources>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
|
||||
<plugin>
|
||||
<groupId>org.sonatype.plugins</groupId>
|
||||
<artifactId>nexus-staging-maven-plugin</artifactId>
|
||||
<version>1.6.13</version>
|
||||
<extensions>true</extensions>
|
||||
<configuration>
|
||||
<serverId>ossrh</serverId>
|
||||
<nexusUrl>https://s01.oss.sonatype.org/</nexusUrl>
|
||||
<autoReleaseAfterClose>true</autoReleaseAfterClose>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-source-plugin</artifactId>
|
||||
<version>2.2.1</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>attach-sources</id>
|
||||
<goals>
|
||||
<goal>jar-no-fork</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-javadoc-plugin</artifactId>
|
||||
<version>3.5.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>attach-javadocs</id>
|
||||
<goals>
|
||||
<goal>jar</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-gpg-plugin</artifactId>
|
||||
<version>1.5</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>sign-artifacts</id>
|
||||
<phase>verify</phase>
|
||||
<goals>
|
||||
<goal>sign</goal>
|
||||
</goals>
|
||||
<!--
|
||||
<configuration>
|
||||
<keyname>${gpg.keyname}</keyname>
|
||||
<passphraseServerId>${gpg.keyname}</passphraseServerId>
|
||||
</configuration>
|
||||
-->
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-assembly-plugin</artifactId>
|
||||
<version>3.6.0</version>
|
||||
<configuration>
|
||||
<descriptorRefs>
|
||||
<descriptorRef>jar-with-dependencies</descriptorRef>
|
||||
</descriptorRefs>
|
||||
<archive>
|
||||
<manifest>
|
||||
<mainClass>com.hexadevlabs.gpt4allsample.Example4</mainClass>
|
||||
</manifest>
|
||||
</archive>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>make-assembly</id>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>single</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
||||
|
||||
</build>
|
||||
|
||||
|
||||
</project>
|
@ -0,0 +1,349 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import jnr.ffi.Pointer;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class LLModel implements AutoCloseable {
|
||||
|
||||
/**
|
||||
* Config used for how to decode LLM outputs.
|
||||
* High temperature closer to 1 gives more creative outputs
|
||||
* while low temperature closer to 0 produce more precise outputs.
|
||||
* <p>
|
||||
* Use builder to set settings you want.
|
||||
*/
|
||||
public static class GenerationConfig extends LLModelLibrary.LLModelPromptContext {
|
||||
|
||||
private GenerationConfig() {
|
||||
super(jnr.ffi.Runtime.getSystemRuntime());
|
||||
logits_size.set(0);
|
||||
tokens_size.set(0);
|
||||
n_past.set(0);
|
||||
n_ctx.set(1024);
|
||||
n_predict.set(128);
|
||||
top_k.set(40);
|
||||
top_p.set(0.95);
|
||||
temp.set(0.28);
|
||||
n_batch.set(8);
|
||||
repeat_penalty.set(1.1);
|
||||
repeat_last_n.set(10);
|
||||
context_erase.set(0.55);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private final GenerationConfig configToBuild;
|
||||
|
||||
public Builder() {
|
||||
configToBuild = new GenerationConfig();
|
||||
}
|
||||
|
||||
public Builder withNPast(int n_past) {
|
||||
configToBuild.n_past.set(n_past);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withNCtx(int n_ctx) {
|
||||
configToBuild.n_ctx.set(n_ctx);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withNPredict(int n_predict) {
|
||||
configToBuild.n_predict.set(n_predict);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withTopK(int top_k) {
|
||||
configToBuild.top_k.set(top_k);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withTopP(float top_p) {
|
||||
configToBuild.top_p.set(top_p);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withTemp(float temp) {
|
||||
configToBuild.temp.set(temp);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withNBatch(int n_batch) {
|
||||
configToBuild.n_batch.set(n_batch);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withRepeatPenalty(float repeat_penalty) {
|
||||
configToBuild.repeat_penalty.set(repeat_penalty);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withRepeatLastN(int repeat_last_n) {
|
||||
configToBuild.repeat_last_n.set(repeat_last_n);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withContextErase(float context_erase) {
|
||||
configToBuild.context_erase.set(context_erase);
|
||||
return this;
|
||||
}
|
||||
|
||||
public GenerationConfig build() {
|
||||
return configToBuild;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Shortcut for making GenerativeConfig builder.
|
||||
*/
|
||||
public static GenerationConfig.Builder config(){
|
||||
return new GenerationConfig.Builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* This may be set before any Model instance classe are instantiated to
|
||||
* set where the model may be found. This may be needed if setting
|
||||
* library search path by standard means is not available.
|
||||
*/
|
||||
public static String LIBRARY_SEARCH_PATH;
|
||||
|
||||
|
||||
/**
|
||||
* Generally for debugging purposes only. Will print
|
||||
* the numerical tokens as they are generated instead of the string representations.
|
||||
* Will also print out the processed input tokens as numbers to standard out.
|
||||
*/
|
||||
public static boolean OUTPUT_DEBUG = false;
|
||||
|
||||
protected static LLModelLibrary library;
|
||||
|
||||
protected Pointer model;
|
||||
|
||||
protected String modelName;
|
||||
|
||||
public LLModel(Path modelPath) {
|
||||
|
||||
if(library==null) {
|
||||
|
||||
if (LIBRARY_SEARCH_PATH != null){
|
||||
library = Util.loadSharedLibrary(LIBRARY_SEARCH_PATH);
|
||||
library.llmodel_set_implementation_search_path(LIBRARY_SEARCH_PATH);
|
||||
} else {
|
||||
// Copy system libraries to Temp folder
|
||||
Path tempLibraryDirectory = Util.copySharedLibraries();
|
||||
library = Util.loadSharedLibrary(tempLibraryDirectory.toString());
|
||||
|
||||
library.llmodel_set_implementation_search_path(tempLibraryDirectory.toString() );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// modelType = type;
|
||||
modelName = modelPath.getFileName().toString();
|
||||
String modelPathAbs = modelPath.toAbsolutePath().toString();
|
||||
|
||||
LLModelLibrary.LLModelError error = new LLModelLibrary.LLModelError(jnr.ffi.Runtime.getSystemRuntime());
|
||||
|
||||
// Check if model file exists
|
||||
if(!Files.exists(modelPath)){
|
||||
throw new IllegalStateException("Model file does not exist: " + modelPathAbs);
|
||||
}
|
||||
|
||||
// Create Model Struct. Will load dynamically the correct backend based on model type
|
||||
model = library.llmodel_model_create2(modelPathAbs, "auto", error);
|
||||
|
||||
if(model == null) {
|
||||
throw new IllegalStateException("Could not load gpt4all backend :" + error.message);
|
||||
}
|
||||
library.llmodel_loadModel(model, modelPathAbs);
|
||||
|
||||
if(!library.llmodel_isModelLoaded(model)){
|
||||
throw new IllegalStateException("The model " + modelName + " could not be loaded");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public void setThreadCount(int nThreads) {
|
||||
library.llmodel_setThreadCount(this.model, nThreads);
|
||||
}
|
||||
|
||||
public int threadCount() {
|
||||
return library.llmodel_threadCount(this.model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate text after the prompt
|
||||
*
|
||||
* @param prompt The text prompt to complete
|
||||
* @param generationConfig What generation settings to use while generating text
|
||||
* @return String The complete generated text
|
||||
*/
|
||||
public String generate(String prompt, GenerationConfig generationConfig) {
|
||||
return generate(prompt, generationConfig, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate text after the prompt
|
||||
*
|
||||
* @param prompt The text prompt to complete
|
||||
* @param generationConfig What generation settings to use while generating text
|
||||
* @param streamToStdOut Should the generation be streamed to standard output. Useful for troubleshooting.
|
||||
* @return String The complete generated text
|
||||
*/
|
||||
public String generate(String prompt, GenerationConfig generationConfig, boolean streamToStdOut) {
|
||||
|
||||
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
|
||||
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
|
||||
|
||||
LLModelLibrary.ResponseCallback responseCallback = (int tokenID, Pointer response) -> {
|
||||
|
||||
if(LLModel.OUTPUT_DEBUG)
|
||||
System.out.print("Response token " + tokenID + " " );
|
||||
|
||||
long len = 0;
|
||||
byte nextByte;
|
||||
do{
|
||||
nextByte= response.getByte(len);
|
||||
len++;
|
||||
if(nextByte!=0) {
|
||||
bufferingForWholeGeneration.write(nextByte);
|
||||
if(streamToStdOut){
|
||||
bufferingForStdOutStream.write(nextByte);
|
||||
// Test if Buffer is UTF8 valid string.
|
||||
byte[] currentBytes = bufferingForStdOutStream.toByteArray();
|
||||
String validString = Util.getValidUtf8(currentBytes);
|
||||
if(validString!=null){ // is valid string
|
||||
System.out.print(validString);
|
||||
// reset the buffer for next utf8 sequence to buffer
|
||||
bufferingForStdOutStream.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
} while(nextByte != 0);
|
||||
|
||||
return true; // continue generating
|
||||
};
|
||||
|
||||
library.llmodel_prompt(this.model,
|
||||
prompt,
|
||||
(int tokenID) -> {
|
||||
if(LLModel.OUTPUT_DEBUG)
|
||||
System.out.println("token " + tokenID);
|
||||
return true; // continue processing
|
||||
},
|
||||
responseCallback,
|
||||
(boolean isRecalculating) -> {
|
||||
if(LLModel.OUTPUT_DEBUG)
|
||||
System.out.println("recalculating");
|
||||
return isRecalculating; // continue generating
|
||||
},
|
||||
generationConfig);
|
||||
|
||||
return bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
|
||||
|
||||
public static class ChatCompletionResponse {
|
||||
public String model;
|
||||
public Usage usage;
|
||||
public List<Map<String, String>> choices;
|
||||
|
||||
// Getters and setters
|
||||
}
|
||||
|
||||
public static class Usage {
|
||||
public int promptTokens;
|
||||
public int completionTokens;
|
||||
public int totalTokens;
|
||||
|
||||
// Getters and setters
|
||||
}
|
||||
|
||||
public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages,
|
||||
GenerationConfig generationConfig) {
|
||||
return chatCompletion(messages, generationConfig, false, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* chatCompletion formats the existing chat conversation into a template to be
|
||||
* easier to process for chat UIs. It is not absolutely necessary as generate method
|
||||
* may be directly used to make generations with gpt models.
|
||||
*
|
||||
* @param messages List of Maps "role"->"user", "content"->"...", "role"-> "assistant"->"..."
|
||||
* @param generationConfig How to decode/process the generation.
|
||||
* @param streamToStdOut Send tokens as they are calculated Standard output.
|
||||
* @param outputFullPromptToStdOut Should full prompt built out of messages be sent to Standard output.
|
||||
* @return ChatCompletionResponse contains stats and generated Text.
|
||||
*/
|
||||
public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages,
|
||||
GenerationConfig generationConfig, boolean streamToStdOut,
|
||||
boolean outputFullPromptToStdOut) {
|
||||
String fullPrompt = buildPrompt(messages);
|
||||
|
||||
if(outputFullPromptToStdOut)
|
||||
System.out.print(fullPrompt);
|
||||
|
||||
String generatedText = generate(fullPrompt, generationConfig, streamToStdOut);
|
||||
|
||||
ChatCompletionResponse response = new ChatCompletionResponse();
|
||||
response.model = this.modelName;
|
||||
|
||||
Usage usage = new Usage();
|
||||
usage.promptTokens = fullPrompt.length();
|
||||
usage.completionTokens = generatedText.length();
|
||||
usage.totalTokens = fullPrompt.length() + generatedText.length();
|
||||
response.usage = usage;
|
||||
|
||||
Map<String, String> message = new HashMap<>();
|
||||
message.put("role", "assistant");
|
||||
message.put("content", generatedText);
|
||||
|
||||
response.choices = List.of(message);
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
protected static String buildPrompt(List<Map<String, String>> messages) {
|
||||
StringBuilder fullPrompt = new StringBuilder();
|
||||
|
||||
for (Map<String, String> message : messages) {
|
||||
if ("system".equals(message.get("role"))) {
|
||||
String systemMessage = message.get("content") + "\n";
|
||||
fullPrompt.append(systemMessage);
|
||||
}
|
||||
}
|
||||
|
||||
fullPrompt.append("### Instruction: \n" +
|
||||
"The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
|
||||
"### Prompt: ");
|
||||
|
||||
for (Map<String, String> message : messages) {
|
||||
if ("user".equals(message.get("role"))) {
|
||||
String userMessage = "\n" + message.get("content");
|
||||
fullPrompt.append(userMessage);
|
||||
}
|
||||
if ("assistant".equals(message.get("role"))) {
|
||||
String assistantMessage = "\n### Response: " + message.get("content");
|
||||
fullPrompt.append(assistantMessage);
|
||||
}
|
||||
}
|
||||
|
||||
fullPrompt.append("\n### Response:");
|
||||
|
||||
return fullPrompt.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {
|
||||
library.llmodel_model_destroy(model);
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,79 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import jnr.ffi.Pointer;
|
||||
import jnr.ffi.Struct;
|
||||
import jnr.ffi.annotations.Delegate;
|
||||
import jnr.ffi.annotations.Encoding;
|
||||
import jnr.ffi.annotations.In;
|
||||
import jnr.ffi.annotations.Out;
|
||||
import jnr.ffi.types.u_int64_t;
|
||||
|
||||
|
||||
/**
|
||||
* The basic Native library interface the provides all the LLM functions.
|
||||
*/
|
||||
/**
 * The basic Native library interface that provides all the LLM functions.
 * Bound to the native "llmodel" shared library via JNR-FFI (see Util.loadSharedLibrary).
 * Method names and signatures must match the exported C symbols exactly.
 */
public interface LLModelLibrary {

    /** Called once per processed prompt token; return false to abort processing. */
    interface PromptCallback {
        @Delegate
        boolean invoke(int token_id);
    }

    /**
     * Called for each generated token. {@code response} points to a
     * NUL-terminated byte buffer on the native side; return false to stop generating.
     */
    interface ResponseCallback {
        @Delegate
        boolean invoke(int token_id, Pointer response);
    }

    /** Called when the native side is recalculating the context. */
    interface RecalculateCallback {
        @Delegate
        boolean invoke(boolean is_recalculating);
    }

    /**
     * Mirrors the native llmodel error struct; filled in by llmodel_model_create2
     * on failure. Field declaration order defines the native layout — do not reorder.
     */
    class LLModelError extends Struct {
        public final Struct.AsciiStringRef message = new Struct.AsciiStringRef();
        public final int32_t status = new int32_t();
        public LLModelError(jnr.ffi.Runtime runtime) {
            super(runtime);
        }
    }

    /**
     * Mirrors the native llmodel prompt-context struct passed to llmodel_prompt.
     * Field declaration order defines the native layout — do not reorder.
     * Extended by LLModel.GenerationConfig, which supplies the default values.
     */
    class LLModelPromptContext extends Struct {
        public final Pointer logits = new Pointer();
        public final ssize_t logits_size = new ssize_t();
        public final Pointer tokens = new Pointer();
        public final ssize_t tokens_size = new ssize_t();
        public final int32_t n_past = new int32_t();
        public final int32_t n_ctx = new int32_t();
        public final int32_t n_predict = new int32_t();
        public final int32_t top_k = new int32_t();
        public final Float top_p = new Float();
        public final Float temp = new Float();
        public final int32_t n_batch = new int32_t();
        public final Float repeat_penalty = new Float();
        public final int32_t repeat_last_n = new int32_t();
        public final Float context_erase = new Float();

        public LLModelPromptContext(jnr.ffi.Runtime runtime) {
            super(runtime);
        }
    }

    // Creates the native model struct; returns null on failure with details in llmodel_error.
    Pointer llmodel_model_create2(String model_path, String build_variant, @Out LLModelError llmodel_error);
    // Frees the native model struct created by llmodel_model_create2.
    void llmodel_model_destroy(Pointer model);
    boolean llmodel_loadModel(Pointer model, String model_path);
    boolean llmodel_isModelLoaded(Pointer model);
    @u_int64_t long llmodel_get_state_size(Pointer model);
    @u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest);
    @u_int64_t long llmodel_restore_state_data(Pointer model, Pointer src);

    // Sets the directory where backend implementation libraries are searched for.
    void llmodel_set_implementation_search_path(String path);

    // ctx was an @Out ... without @Out crash
    // NOTE(review): ctx is currently annotated @In; the comment above suggests the
    // annotation has been flipped before — confirm against the native signature.
    void llmodel_prompt(Pointer model, @Encoding("UTF-8") String prompt,
                        PromptCallback prompt_callback,
                        ResponseCallback response_callback,
                        RecalculateCallback recalculate_callback,
                        @In LLModelPromptContext ctx);
    void llmodel_setThreadCount(Pointer model, int n_threads);
    int llmodel_threadCount(Pointer model);
}
|
@ -0,0 +1,147 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import jnr.ffi.LibraryLoader;
|
||||
import jnr.ffi.LibraryOption;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.charset.CharacterCodingException;
|
||||
import java.nio.charset.CharsetDecoder;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class Util {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(Util.class);
|
||||
private static final CharsetDecoder cs = StandardCharsets.UTF_8.newDecoder();
|
||||
|
||||
public static LLModelLibrary loadSharedLibrary(String librarySearchPath){
|
||||
String libraryName = "llmodel";
|
||||
Map<LibraryOption, Object> libraryOptions = new HashMap<>();
|
||||
libraryOptions.put(LibraryOption.LoadNow, true); // load immediately instead of lazily (ie on first use)
|
||||
libraryOptions.put(LibraryOption.IgnoreError, false); // calls shouldn't save last errno after call
|
||||
|
||||
if(librarySearchPath!=null) {
|
||||
Map<String, List<String>> searchPaths = new HashMap<>();
|
||||
searchPaths.put(libraryName, List.of(librarySearchPath));
|
||||
|
||||
return LibraryLoader.loadLibrary(LLModelLibrary.class,
|
||||
libraryOptions,
|
||||
searchPaths,
|
||||
libraryName
|
||||
);
|
||||
}else {
|
||||
|
||||
return LibraryLoader.loadLibrary(LLModelLibrary.class,
|
||||
libraryOptions,
|
||||
libraryName
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy over shared library files from resource package to
|
||||
* target Temp directory.
|
||||
* Returns Path to the temp directory holding the shared libraries
|
||||
*/
|
||||
public static Path copySharedLibraries() {
|
||||
try {
|
||||
// Identify the OS and architecture
|
||||
String osName = System.getProperty("os.name").toLowerCase();
|
||||
boolean isWindows = osName.startsWith("windows");
|
||||
boolean isMac = osName.startsWith("mac os x");
|
||||
boolean isLinux = osName.startsWith("linux");
|
||||
if(isWindows) osName = "windows";
|
||||
if(isMac) osName = "macos";
|
||||
if(isLinux) osName = "linux";
|
||||
|
||||
//String osArch = System.getProperty("os.arch");
|
||||
|
||||
// Create a temporary directory
|
||||
Path tempDirectory = Files.createTempDirectory("nativeLibraries");
|
||||
tempDirectory.toFile().deleteOnExit();
|
||||
|
||||
String[] libraryNames = {
|
||||
"gptj-default",
|
||||
"gptj-avxonly",
|
||||
"llmodel",
|
||||
"mpt-default",
|
||||
"llamamodel-230511-default",
|
||||
"llamamodel-230519-default",
|
||||
"llamamodel-mainline-default"
|
||||
};
|
||||
|
||||
for (String libraryName : libraryNames) {
|
||||
|
||||
if(isWindows){
|
||||
libraryName = libraryName + ".dll";
|
||||
} else if(isMac){
|
||||
libraryName = "lib" + libraryName + ".dylib";
|
||||
} else if(isLinux) {
|
||||
libraryName = "lib"+ libraryName + ".so";
|
||||
}
|
||||
|
||||
// Construct the resource path based on the OS and architecture
|
||||
String nativeLibraryPath = "/native/" + osName + "/" + libraryName;
|
||||
|
||||
// Get the library resource as a stream
|
||||
InputStream in = Util.class.getResourceAsStream(nativeLibraryPath);
|
||||
if (in == null) {
|
||||
throw new RuntimeException("Unable to find native library: " + nativeLibraryPath);
|
||||
}
|
||||
|
||||
// Create a file in the temporary directory with the original library name
|
||||
Path tempLibraryPath = tempDirectory.resolve(libraryName);
|
||||
|
||||
// Use Files.copy to copy the library to the temporary file
|
||||
Files.copy(in, tempLibraryPath, StandardCopyOption.REPLACE_EXISTING);
|
||||
|
||||
// Close the input stream
|
||||
in.close();
|
||||
}
|
||||
|
||||
// Add shutdown hook to delete tempDir on JVM exit
|
||||
// On Windows deleting dll files that are loaded into memory is not possible.
|
||||
if(!isWindows) {
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
try {
|
||||
Files.walk(tempDirectory)
|
||||
.sorted(Comparator.reverseOrder())
|
||||
.map(Path::toFile)
|
||||
.forEach(file -> {
|
||||
try {
|
||||
Files.delete(file.toPath());
|
||||
} catch (IOException e) {
|
||||
logger.error("Deleting temp library file", e);
|
||||
}
|
||||
});
|
||||
} catch (IOException e) {
|
||||
logger.error("Deleting temp directory for libraries", e);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
return tempDirectory;
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to load native libraries", e);
|
||||
}
|
||||
}
|
||||
|
||||
public static String getValidUtf8(byte[] bytes) {
|
||||
try {
|
||||
return cs.decode(ByteBuffer.wrap(bytes)).toString();
|
||||
} catch (CharacterCodingException e) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* GPTJ chat completion, multiple messages
|
||||
*/
|
||||
public class Example1 {
|
||||
public static void main(String[] args) {
|
||||
|
||||
// Optionally in case override to location of shared libraries is necessary
|
||||
//LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
|
||||
|
||||
try ( LLModel gptjModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin")) ){
|
||||
|
||||
LLModel.GenerationConfig config = LLModel.config()
|
||||
.withNPredict(4096).build();
|
||||
|
||||
gptjModel.chatCompletion(
|
||||
List.of(Map.of("role", "user", "content", "Add 2+2"),
|
||||
Map.of("role", "assistant", "content", "4"),
|
||||
Map.of("role", "user", "content", "Multiply 4 * 5")), config, true, true);
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import java.nio.file.Path;
|
||||
|
||||
/**
|
||||
* Generation with MPT model
|
||||
*/
|
||||
public class Example2 {
|
||||
public static void main(String[] args) {
|
||||
|
||||
String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
|
||||
|
||||
// Optionally in case override to location of shared libraries is necessary
|
||||
//LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
|
||||
|
||||
try (LLModel mptModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-mpt-7b-instruct.bin"))) {
|
||||
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(4096)
|
||||
.withRepeatLastN(64)
|
||||
.build();
|
||||
|
||||
mptModel.generate(prompt, config, true);
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import jnr.ffi.LibraryLoader;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* GPTJ chat completion with system message
|
||||
*/
|
||||
public class Example3 {
|
||||
public static void main(String[] args) {
|
||||
|
||||
// Optionally in case override to location of shared libraries is necessary
|
||||
//LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
|
||||
|
||||
try ( LLModel gptjModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin")) ){
|
||||
|
||||
LLModel.GenerationConfig config = LLModel.config()
|
||||
.withNPredict(4096).build();
|
||||
|
||||
// String result = gptjModel.generate(prompt, config, true);
|
||||
gptjModel.chatCompletion(
|
||||
List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
|
||||
Map.of("role", "user", "content", "Add 2+2")), config, true, true);
|
||||
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,43 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import java.nio.file.Path;
|
||||
|
||||
public class Example4 {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
|
||||
// The emoji is poop emoji. The Unicode character is encoded as surrogate pair for Java string.
|
||||
// LLM should correctly identify it as poop emoji in the description
|
||||
//String prompt = "### Human:\nDescribe the meaning of this emoji \uD83D\uDCA9\n### Assistant:";
|
||||
//String prompt = "### Human:\nOutput the unicode character of smiley face emoji\n### Assistant:";
|
||||
|
||||
// Optionally in case override to location of shared libraries is necessary
|
||||
//LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
|
||||
|
||||
String model = "ggml-vicuna-7b-1.1-q4_2.bin";
|
||||
//String model = "ggml-gpt4all-j-v1.3-groovy.bin";
|
||||
//String model = "ggml-mpt-7b-instruct.bin";
|
||||
String basePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\";
|
||||
//String basePath = "/Users/fzaslavs/Library/Application Support/nomic.ai/GPT4All/";
|
||||
|
||||
try (LLModel mptModel = new LLModel(Path.of(basePath + model))) {
|
||||
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(4096)
|
||||
.withRepeatLastN(64)
|
||||
.build();
|
||||
|
||||
|
||||
String result = mptModel.generate(prompt, config, true);
|
||||
|
||||
System.out.println("Code points:");
|
||||
result.codePoints().forEach(System.out::println);
|
||||
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
/**
|
||||
* For now disabled until can run with latest library on windows
|
||||
*/
|
||||
@Disabled
|
||||
public class Tests {
|
||||
|
||||
static LLModel model;
|
||||
static String PATH_TO_MODEL="C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";
|
||||
|
||||
@BeforeAll
|
||||
public static void before(){
|
||||
model = new LLModel(Path.of(PATH_TO_MODEL));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void simplePrompt(){
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(20)
|
||||
.build();
|
||||
String prompt =
|
||||
"### Instruction: \n" +
|
||||
"The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
|
||||
"### Prompt: \n" +
|
||||
"Add 2+2\n" +
|
||||
"### Response: 4\n" +
|
||||
"Multiply 4 * 5\n" +
|
||||
"### Response:";
|
||||
String generatedText = model.generate(prompt, config, true);
|
||||
assertTrue(generatedText.contains("20"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void chatCompletion(){
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(20)
|
||||
.build();
|
||||
|
||||
LLModel.ChatCompletionResponse response= model.chatCompletion(
|
||||
List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
|
||||
Map.of("role", "user", "content", "Add 2+2")), config, true, true);
|
||||
|
||||
assertTrue( response.choices.get(0).get("content").contains("4") );
|
||||
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
public static void after() throws Exception {
|
||||
if(model != null)
|
||||
model.close();
|
||||
}
|
||||
|
||||
|
||||
}
|
Loading…
Reference in New Issue