Initial 1.0.0 Java-Bindings PR/release (#805)

* Initial 1.0.0 Java-Bindings PR/release

* Initial 1.1.0 Java-Bindings PR/release

* Add debug ability

* 1.1.2 release

---------

Co-authored-by: felix <felix@zaslavskiy.net>
Felix Zaslavskiy 2023-06-12 14:58:06 -04:00 committed by GitHub
parent 5cfb1bda89
commit 44bf91855d
12 changed files with 1090 additions and 0 deletions

README.md

@@ -61,6 +61,7 @@ Find the most up-to-date information on the [GPT4All Website](https://gpt4all.io)
* <a href="https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/typescript">:computer: Official Typescript Bindings</a>
* <a href="https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/golang">:computer: Official GoLang Bindings</a>
* <a href="https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/csharp">:computer: Official C# Bindings</a>
* <a href="https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/java">:computer: Official Java Bindings</a>
## Contributing

2
gpt4all-bindings/java/.gitignore vendored Normal file

@@ -0,0 +1,2 @@
# Make sure the native directory never gets committed to git for the project.
/src/main/resources/native

gpt4all-bindings/java/README.md

@@ -0,0 +1,105 @@
# Java bindings
Java bindings let you load a gpt4all library into your Java application and execute text
generation using an intuitive and easy-to-use API. No GPU is required because the gpt4all backend runs entirely on the CPU.
The gpt4all models are quantized to fit easily into system RAM, using about 4 to 7 GB of it.
## Getting Started
You can add the Java bindings to your project by adding the following dependency:
**Maven**
```xml
<dependency>
<groupId>com.hexadevlabs</groupId>
<artifactId>gpt4all-java-binding</artifactId>
<version>1.1.2</version>
</dependency>
```
**Gradle**
```groovy
implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.2'
```
To add the library dependency for another build system, see [Maven Central Java bindings](https://central.sonatype.com/artifact/com.hexadevlabs/gpt4all-java-binding/).
To download a model binary weights file, use a URL such as https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin.
For information about the other available models, see the [Model file list](https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-chat#manual-download-of-models).
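As a minimal sketch (using only the JDK's built-in `java.net.http` client; the `models/` target directory is illustrative), the weights file can be downloaded directly from Java:
```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.file.Files;
import java.nio.file.Path;

public class DownloadModel {
    public static void main(String[] args) throws IOException, InterruptedException {
        // Model URL from the section above; the local target path is just an example.
        String url = "https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin";
        Path target = Path.of("models", "ggml-gpt4all-j-v1.3-groovy.bin");
        Files.createDirectories(target.getParent());

        HttpClient client = HttpClient.newBuilder()
                .followRedirects(HttpClient.Redirect.NORMAL)
                .build();
        HttpRequest request = HttpRequest.newBuilder(URI.create(url)).GET().build();

        // Stream the response body straight to disk; model files are several GB,
        // so we avoid buffering them in memory.
        HttpResponse<Path> response =
                client.send(request, HttpResponse.BodyHandlers.ofFile(target));
        System.out.println("Saved model to " + response.body() + " (HTTP " + response.statusCode() + ")");
    }
}
```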
### Sample code
```java
import com.hexadevlabs.gpt4all.LLModel;

import java.nio.file.Path;

public class Example {
public static void main(String[] args) {
String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
// Replace the hardcoded path with the actual path where your model file resides
String modelFilePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";
try (LLModel model = new LLModel(Path.of(modelFilePath))) {
// May generate up to 4096 tokens but generally stops early
LLModel.GenerationConfig config = LLModel.config()
.withNPredict(4096).build();
// Will also stream to Standard out
String fullGeneration = model.generate(prompt, config, true);
} catch (Exception e) {
// Exceptions may generally happen if the model file fails to load
// for a number of reasons, such as the file not being found.
// It is possible that Java may not be able to dynamically load the native shared library or
// the llmodel shared library may not be able to dynamically load the backend
// implementation for the model file you provided.
//
// Once the LLModel class is successfully loaded into memory the text generation calls
// generally should not throw exceptions.
e.printStackTrace(); // Printed here, but in a production system you may want to take some other action.
}
}
}
```
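Note that `LLModel` implements `AutoCloseable`, so the try-with-resources block above releases the native model automatically when it exits.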
For a Maven-based sample project that uses this library, see the [Sample project](https://github.com/felix-zaslavskiy/gpt4all-java-bindings-sample).
### Additional considerations
#### Logger warnings
The Java bindings library may produce a warning:
```
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
```
This warning appears when no SLF4J binding is included in your project. The Java bindings use logging only for informational
purposes, so a logger is not essential to using the library correctly, and the warning can safely be ignored.
To add a simple logger as a Maven dependency, you can use:
```xml
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
<version>1.7.36</version>
</dependency>
```
#### Loading your native libraries
1. The Java bindings jar comes bundled with native library files for Windows, macOS, and Linux. These library files are
copied to a temporary directory and loaded at runtime. Advanced users who want to package the shared libraries into Docker containers,
or who want to use a custom build of the shared libraries instead of the ones bundled with the Java package, have the option
to load the libraries from a local directory by setting a static property to the location of the library files.
There are no compatibility guarantees when libraries are loaded this way, so be careful if you really want to do it.
For example:
```java
class Example {
public static void main(String[] args) {
// gpt4all native shared libraries location
LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
// ... use the library normally
}
}
```
2. To reduce size, not every AVX-only shared library is bundled with the jar right now; only libgptj-avx is included.
If you run into issues, please let us know using the gpt4all project issue tracker: https://github.com/nomic-ai/gpt4all/issues.

gpt4all-bindings/java/pom.xml

@@ -0,0 +1,202 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.hexadevlabs</groupId>
<artifactId>gpt4all-java-binding</artifactId>
<version>1.1.2</version>
<packaging>jar</packaging>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<name>${project.groupId}:${project.artifactId}</name>
<description>Java bindings for GPT4ALL LLM</description>
<url>https://github.com/nomic-ai/gpt4all</url>
<licenses>
<license>
<name>The Apache License, Version 2.0</name>
<url>https://github.com/nomic-ai/gpt4all/blob/main/LICENSE.txt</url>
</license>
</licenses>
<developers>
<developer>
<name>Felix Zaslavskiy</name>
<email>felixz@hexadevlabs.com</email>
<organizationUrl>https://github.com/felix-zaslavskiy/</organizationUrl>
</developer>
</developers>
<scm>
<connection>scm:git:git://github.com/nomic-ai/gpt4all.git</connection>
<developerConnection>scm:git:ssh://github.com/nomic-ai/gpt4all.git</developerConnection>
<url>https://github.com/nomic-ai/gpt4all/tree/main</url>
</scm>
<dependencies>
<dependency>
<groupId>com.github.jnr</groupId>
<artifactId>jnr-ffi</artifactId>
<version>2.2.13</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.36</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>5.9.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<distributionManagement>
<snapshotRepository>
<id>ossrh</id>
<url>https://s01.oss.sonatype.org/content/repositories/snapshots</url>
</snapshotRepository>
<repository>
<id>ossrh</id>
<url>https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
</repository>
</distributionManagement>
<build>
<resources>
<resource>
<directory>src/main/resources</directory>
</resource>
<resource>
<directory>${project.build.directory}/generated-resources</directory>
</resource>
</resources>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.0.0</version>
<configuration>
<forkCount>0</forkCount>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-resources-plugin</artifactId>
<version>3.3.1</version>
<executions>
<execution>
<id>copy-resources</id>
<!-- Run during the validate phase so the native libraries are copied before packaging -->
<phase>validate</phase>
<goals>
<goal>copy-resources</goal>
</goals>
<configuration>
<outputDirectory>${project.build.directory}/generated-resources</outputDirectory>
<resources>
<resource>
<directory>C:\Users\felix\dev\gpt4all_java_bins\release_1_1_1_Jun8_2023</directory>
</resource>
</resources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.sonatype.plugins</groupId>
<artifactId>nexus-staging-maven-plugin</artifactId>
<version>1.6.13</version>
<extensions>true</extensions>
<configuration>
<serverId>ossrh</serverId>
<nexusUrl>https://s01.oss.sonatype.org/</nexusUrl>
<autoReleaseAfterClose>true</autoReleaseAfterClose>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>2.2.1</version>
<executions>
<execution>
<id>attach-sources</id>
<goals>
<goal>jar-no-fork</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>3.5.0</version>
<executions>
<execution>
<id>attach-javadocs</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>1.5</version>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
<!--
<configuration>
<keyname>${gpg.keyname}</keyname>
<passphraseServerId>${gpg.keyname}</passphraseServerId>
</configuration>
-->
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<version>3.6.0</version>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
<archive>
<manifest>
<mainClass>com.hexadevlabs.gpt4allsample.Example4</mainClass>
</manifest>
</archive>
</configuration>
<executions>
<execution>
<id>make-assembly</id>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>

gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java

@@ -0,0 +1,349 @@
package com.hexadevlabs.gpt4all;
import jnr.ffi.Pointer;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class LLModel implements AutoCloseable {
/**
* Config used for how to decode LLM outputs.
* A higher temperature (closer to 1) gives more creative output,
* while a lower temperature (closer to 0) produces more precise output.
* <p>
* Use builder to set settings you want.
*/
public static class GenerationConfig extends LLModelLibrary.LLModelPromptContext {
private GenerationConfig() {
super(jnr.ffi.Runtime.getSystemRuntime());
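// Default decoding settings; any of these can be overridden via the Builder.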
logits_size.set(0);
tokens_size.set(0);
n_past.set(0);
n_ctx.set(1024);
n_predict.set(128);
top_k.set(40);
top_p.set(0.95);
temp.set(0.28);
n_batch.set(8);
repeat_penalty.set(1.1);
repeat_last_n.set(10);
context_erase.set(0.55);
}
public static class Builder {
private final GenerationConfig configToBuild;
public Builder() {
configToBuild = new GenerationConfig();
}
public Builder withNPast(int n_past) {
configToBuild.n_past.set(n_past);
return this;
}
public Builder withNCtx(int n_ctx) {
configToBuild.n_ctx.set(n_ctx);
return this;
}
public Builder withNPredict(int n_predict) {
configToBuild.n_predict.set(n_predict);
return this;
}
public Builder withTopK(int top_k) {
configToBuild.top_k.set(top_k);
return this;
}
public Builder withTopP(float top_p) {
configToBuild.top_p.set(top_p);
return this;
}
public Builder withTemp(float temp) {
configToBuild.temp.set(temp);
return this;
}
public Builder withNBatch(int n_batch) {
configToBuild.n_batch.set(n_batch);
return this;
}
public Builder withRepeatPenalty(float repeat_penalty) {
configToBuild.repeat_penalty.set(repeat_penalty);
return this;
}
public Builder withRepeatLastN(int repeat_last_n) {
configToBuild.repeat_last_n.set(repeat_last_n);
return this;
}
public Builder withContextErase(float context_erase) {
configToBuild.context_erase.set(context_erase);
return this;
}
public GenerationConfig build() {
return configToBuild;
}
}
}
/**
* Shortcut for making a GenerationConfig builder.
*/
public static GenerationConfig.Builder config(){
return new GenerationConfig.Builder();
}
/**
* This may be set before any LLModel instances are created to
* control where the native shared libraries are loaded from. This may be
* needed if setting the library search path by standard means is not possible.
*/
public static String LIBRARY_SEARCH_PATH;
/**
* Generally for debugging purposes only. When enabled, prints
* the numerical token IDs as they are generated, alongside the string representations,
* and also prints the processed input token IDs to standard out.
*/
public static boolean OUTPUT_DEBUG = false;
protected static LLModelLibrary library;
protected Pointer model;
protected String modelName;
public LLModel(Path modelPath) {
if(library==null) {
if (LIBRARY_SEARCH_PATH != null){
library = Util.loadSharedLibrary(LIBRARY_SEARCH_PATH);
library.llmodel_set_implementation_search_path(LIBRARY_SEARCH_PATH);
} else {
// Copy system libraries to Temp folder
Path tempLibraryDirectory = Util.copySharedLibraries();
library = Util.loadSharedLibrary(tempLibraryDirectory.toString());
library.llmodel_set_implementation_search_path(tempLibraryDirectory.toString() );
}
}
// modelType = type;
modelName = modelPath.getFileName().toString();
String modelPathAbs = modelPath.toAbsolutePath().toString();
LLModelLibrary.LLModelError error = new LLModelLibrary.LLModelError(jnr.ffi.Runtime.getSystemRuntime());
// Check if model file exists
if(!Files.exists(modelPath)){
throw new IllegalStateException("Model file does not exist: " + modelPathAbs);
}
// Create Model Struct. Will load dynamically the correct backend based on model type
model = library.llmodel_model_create2(modelPathAbs, "auto", error);
if(model == null) {
throw new IllegalStateException("Could not load gpt4all backend :" + error.message);
}
library.llmodel_loadModel(model, modelPathAbs);
if(!library.llmodel_isModelLoaded(model)){
throw new IllegalStateException("The model " + modelName + " could not be loaded");
}
}
public void setThreadCount(int nThreads) {
library.llmodel_setThreadCount(this.model, nThreads);
}
public int threadCount() {
return library.llmodel_threadCount(this.model);
}
/**
* Generate text after the prompt
*
* @param prompt The text prompt to complete
* @param generationConfig What generation settings to use while generating text
* @return String The complete generated text
*/
public String generate(String prompt, GenerationConfig generationConfig) {
return generate(prompt, generationConfig, false);
}
/**
* Generate text after the prompt
*
* @param prompt The text prompt to complete
* @param generationConfig What generation settings to use while generating text
* @param streamToStdOut Should the generation be streamed to standard output. Useful for troubleshooting.
* @return String The complete generated text
*/
public String generate(String prompt, GenerationConfig generationConfig, boolean streamToStdOut) {
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
LLModelLibrary.ResponseCallback responseCallback = (int tokenID, Pointer response) -> {
if(LLModel.OUTPUT_DEBUG)
System.out.print("Response token " + tokenID + " " );
long len = 0;
byte nextByte;
do{
nextByte= response.getByte(len);
len++;
if(nextByte!=0) {
bufferingForWholeGeneration.write(nextByte);
if(streamToStdOut){
bufferingForStdOutStream.write(nextByte);
// Test if Buffer is UTF8 valid string.
byte[] currentBytes = bufferingForStdOutStream.toByteArray();
String validString = Util.getValidUtf8(currentBytes);
if(validString!=null){ // is valid string
System.out.print(validString);
// reset the buffer for next utf8 sequence to buffer
bufferingForStdOutStream.reset();
}
}
}
} while(nextByte != 0);
return true; // continue generating
};
library.llmodel_prompt(this.model,
prompt,
(int tokenID) -> {
if(LLModel.OUTPUT_DEBUG)
System.out.println("token " + tokenID);
return true; // continue processing
},
responseCallback,
(boolean isRecalculating) -> {
if(LLModel.OUTPUT_DEBUG)
System.out.println("recalculating");
return isRecalculating; // continue generating
},
generationConfig);
return bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
}
public static class ChatCompletionResponse {
public String model;
public Usage usage;
public List<Map<String, String>> choices;
// Getters and setters
}
public static class Usage {
public int promptTokens;
public int completionTokens;
public int totalTokens;
// Getters and setters
}
public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages,
GenerationConfig generationConfig) {
return chatCompletion(messages, generationConfig, false, false);
}
/**
* chatCompletion formats the existing chat conversation into a template that is
* easier for chat UIs to process. It is not strictly necessary, as the generate
* method may be used directly to produce generations with gpt models.
*
* @param messages List of maps with keys "role" and "content", where the role is "system", "user", or "assistant".
* @param generationConfig How to decode/process the generation.
* @param streamToStdOut Whether to stream tokens to standard output as they are generated.
* @param outputFullPromptToStdOut Whether the full prompt built from the messages should be printed to standard output.
* @return ChatCompletionResponse containing usage stats and the generated text.
*/
public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages,
GenerationConfig generationConfig, boolean streamToStdOut,
boolean outputFullPromptToStdOut) {
String fullPrompt = buildPrompt(messages);
if(outputFullPromptToStdOut)
System.out.print(fullPrompt);
String generatedText = generate(fullPrompt, generationConfig, streamToStdOut);
ChatCompletionResponse response = new ChatCompletionResponse();
response.model = this.modelName;
Usage usage = new Usage();
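// Note: character counts are used here as a rough proxy for token counts.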
usage.promptTokens = fullPrompt.length();
usage.completionTokens = generatedText.length();
usage.totalTokens = fullPrompt.length() + generatedText.length();
response.usage = usage;
Map<String, String> message = new HashMap<>();
message.put("role", "assistant");
message.put("content", generatedText);
response.choices = List.of(message);
return response;
}
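/**
 * Builds a single instruction-style prompt from the chat messages:
 * system messages are prepended first, user messages are appended under
 * the prompt section, and assistant messages become "### Response:" lines.
 */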
protected static String buildPrompt(List<Map<String, String>> messages) {
StringBuilder fullPrompt = new StringBuilder();
for (Map<String, String> message : messages) {
if ("system".equals(message.get("role"))) {
String systemMessage = message.get("content") + "\n";
fullPrompt.append(systemMessage);
}
}
fullPrompt.append("### Instruction: \n" +
"The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
"### Prompt: ");
for (Map<String, String> message : messages) {
if ("user".equals(message.get("role"))) {
String userMessage = "\n" + message.get("content");
fullPrompt.append(userMessage);
}
if ("assistant".equals(message.get("role"))) {
String assistantMessage = "\n### Response: " + message.get("content");
fullPrompt.append(assistantMessage);
}
}
fullPrompt.append("\n### Response:");
return fullPrompt.toString();
}
@Override
public void close() throws Exception {
library.llmodel_model_destroy(model);
}
}

gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java

@@ -0,0 +1,79 @@
package com.hexadevlabs.gpt4all;
import jnr.ffi.Pointer;
import jnr.ffi.Struct;
import jnr.ffi.annotations.Delegate;
import jnr.ffi.annotations.Encoding;
import jnr.ffi.annotations.In;
import jnr.ffi.annotations.Out;
import jnr.ffi.types.u_int64_t;
/**
* The basic native library interface that provides all the LLM functions.
*/
public interface LLModelLibrary {
interface PromptCallback {
@Delegate
boolean invoke(int token_id);
}
interface ResponseCallback {
@Delegate
boolean invoke(int token_id, Pointer response);
}
interface RecalculateCallback {
@Delegate
boolean invoke(boolean is_recalculating);
}
class LLModelError extends Struct {
public final Struct.AsciiStringRef message = new Struct.AsciiStringRef();
public final int32_t status = new int32_t();
public LLModelError(jnr.ffi.Runtime runtime) {
super(runtime);
}
}
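/**
 * Mirrors the native llmodel_prompt_context C struct; the field order and
 * types must match the native declaration exactly for JNR-FFI to map it.
 */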
class LLModelPromptContext extends Struct {
public final Pointer logits = new Pointer();
public final ssize_t logits_size = new ssize_t();
public final Pointer tokens = new Pointer();
public final ssize_t tokens_size = new ssize_t();
public final int32_t n_past = new int32_t();
public final int32_t n_ctx = new int32_t();
public final int32_t n_predict = new int32_t();
public final int32_t top_k = new int32_t();
public final Float top_p = new Float();
public final Float temp = new Float();
public final int32_t n_batch = new int32_t();
public final Float repeat_penalty = new Float();
public final int32_t repeat_last_n = new int32_t();
public final Float context_erase = new Float();
public LLModelPromptContext(jnr.ffi.Runtime runtime) {
super(runtime);
}
}
Pointer llmodel_model_create2(String model_path, String build_variant, @Out LLModelError llmodel_error);
void llmodel_model_destroy(Pointer model);
boolean llmodel_loadModel(Pointer model, String model_path);
boolean llmodel_isModelLoaded(Pointer model);
@u_int64_t long llmodel_get_state_size(Pointer model);
@u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest);
@u_int64_t long llmodel_restore_state_data(Pointer model, Pointer src);
void llmodel_set_implementation_search_path(String path);
// Note: the annotation on ctx matters; using the wrong @In/@Out annotation here caused crashes.
void llmodel_prompt(Pointer model, @Encoding("UTF-8") String prompt,
PromptCallback prompt_callback,
ResponseCallback response_callback,
RecalculateCallback recalculate_callback,
@In LLModelPromptContext ctx);
void llmodel_setThreadCount(Pointer model, int n_threads);
int llmodel_threadCount(Pointer model);
}

gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java

@@ -0,0 +1,147 @@
package com.hexadevlabs.gpt4all;
import jnr.ffi.LibraryLoader;
import jnr.ffi.LibraryOption;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class Util {
private static final Logger logger = LoggerFactory.getLogger(Util.class);
private static final CharsetDecoder cs = StandardCharsets.UTF_8.newDecoder();
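/**
 * Loads the llmodel shared library via JNR-FFI, optionally from an
 * explicit search path instead of the default system library locations.
 */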
public static LLModelLibrary loadSharedLibrary(String librarySearchPath){
String libraryName = "llmodel";
Map<LibraryOption, Object> libraryOptions = new HashMap<>();
libraryOptions.put(LibraryOption.LoadNow, true); // load immediately instead of lazily (ie on first use)
libraryOptions.put(LibraryOption.IgnoreError, false); // calls shouldn't save last errno after call
if(librarySearchPath!=null) {
Map<String, List<String>> searchPaths = new HashMap<>();
searchPaths.put(libraryName, List.of(librarySearchPath));
return LibraryLoader.loadLibrary(LLModelLibrary.class,
libraryOptions,
searchPaths,
libraryName
);
}else {
return LibraryLoader.loadLibrary(LLModelLibrary.class,
libraryOptions,
libraryName
);
}
}
/**
* Copies the shared library files from the resource package to
* a temporary directory.
* @return Path to the temp directory holding the shared libraries.
*/
public static Path copySharedLibraries() {
try {
// Identify the OS and architecture
String osName = System.getProperty("os.name").toLowerCase();
boolean isWindows = osName.startsWith("windows");
boolean isMac = osName.startsWith("mac os x");
boolean isLinux = osName.startsWith("linux");
if(isWindows) osName = "windows";
if(isMac) osName = "macos";
if(isLinux) osName = "linux";
//String osArch = System.getProperty("os.arch");
// Create a temporary directory
Path tempDirectory = Files.createTempDirectory("nativeLibraries");
tempDirectory.toFile().deleteOnExit();
String[] libraryNames = {
"gptj-default",
"gptj-avxonly",
"llmodel",
"mpt-default",
"llamamodel-230511-default",
"llamamodel-230519-default",
"llamamodel-mainline-default"
};
for (String libraryName : libraryNames) {
if(isWindows){
libraryName = libraryName + ".dll";
} else if(isMac){
libraryName = "lib" + libraryName + ".dylib";
} else if(isLinux) {
libraryName = "lib"+ libraryName + ".so";
}
// Construct the resource path based on the OS and architecture
String nativeLibraryPath = "/native/" + osName + "/" + libraryName;
// Get the library resource as a stream
InputStream in = Util.class.getResourceAsStream(nativeLibraryPath);
if (in == null) {
throw new RuntimeException("Unable to find native library: " + nativeLibraryPath);
}
// Create a file in the temporary directory with the original library name
Path tempLibraryPath = tempDirectory.resolve(libraryName);
// Use Files.copy to copy the library to the temporary file
Files.copy(in, tempLibraryPath, StandardCopyOption.REPLACE_EXISTING);
// Close the input stream
in.close();
}
// Add shutdown hook to delete tempDir on JVM exit
// On Windows deleting dll files that are loaded into memory is not possible.
if(!isWindows) {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
Files.walk(tempDirectory)
.sorted(Comparator.reverseOrder())
.map(Path::toFile)
.forEach(file -> {
try {
Files.delete(file.toPath());
} catch (IOException e) {
logger.error("Deleting temp library file", e);
}
});
} catch (IOException e) {
logger.error("Deleting temp directory for libraries", e);
}
}));
}
return tempDirectory;
} catch (IOException e) {
throw new RuntimeException("Failed to load native libraries", e);
}
}
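/**
 * Returns the decoded string if the bytes form a complete, valid UTF-8
 * sequence, or null if decoding fails (for example when a multi-byte
 * character is still incomplete).
 */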
public static String getValidUtf8(byte[] bytes) {
try {
return cs.decode(ByteBuffer.wrap(bytes)).toString();
} catch (CharacterCodingException e) {
return null;
}
}
}

gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Example1.java

@@ -0,0 +1,30 @@
package com.hexadevlabs.gpt4all;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
/**
* GPTJ chat completion, multiple messages
*/
public class Example1 {
public static void main(String[] args) {
// Optionally in case override to location of shared libraries is necessary
//LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
try ( LLModel gptjModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin")) ){
LLModel.GenerationConfig config = LLModel.config()
.withNPredict(4096).build();
gptjModel.chatCompletion(
List.of(Map.of("role", "user", "content", "Add 2+2"),
Map.of("role", "assistant", "content", "4"),
Map.of("role", "user", "content", "Multiply 4 * 5")), config, true, true);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Example2.java

@@ -0,0 +1,31 @@
package com.hexadevlabs.gpt4all;
import java.nio.file.Path;
/**
* Generation with MPT model
*/
public class Example2 {
public static void main(String[] args) {
String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
// Optionally in case override to location of shared libraries is necessary
//LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
try (LLModel mptModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-mpt-7b-instruct.bin"))) {
LLModel.GenerationConfig config =
LLModel.config()
.withNPredict(4096)
.withRepeatLastN(64)
.build();
mptModel.generate(prompt, config, true);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Example3.java

@@ -0,0 +1,33 @@
package com.hexadevlabs.gpt4all;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
/**
* GPTJ chat completion with system message
*/
public class Example3 {
public static void main(String[] args) {
// Optionally in case override to location of shared libraries is necessary
//LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
try ( LLModel gptjModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin")) ){
LLModel.GenerationConfig config = LLModel.config()
.withNPredict(4096).build();
// String result = gptjModel.generate(prompt, config, true);
gptjModel.chatCompletion(
List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
Map.of("role", "user", "content", "Add 2+2")), config, true, true);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Example4.java

@@ -0,0 +1,43 @@
package com.hexadevlabs.gpt4all;
import java.nio.file.Path;
public class Example4 {
public static void main(String[] args) {
String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
// The emoji below is the poop emoji. Its Unicode character is encoded as a surrogate pair in the Java string.
// The LLM should correctly identify it as the poop emoji in its description.
//String prompt = "### Human:\nDescribe the meaning of this emoji \uD83D\uDCA9\n### Assistant:";
//String prompt = "### Human:\nOutput the unicode character of smiley face emoji\n### Assistant:";
// Optionally in case override to location of shared libraries is necessary
//LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
String model = "ggml-vicuna-7b-1.1-q4_2.bin";
//String model = "ggml-gpt4all-j-v1.3-groovy.bin";
//String model = "ggml-mpt-7b-instruct.bin";
String basePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\";
//String basePath = "/Users/fzaslavs/Library/Application Support/nomic.ai/GPT4All/";
try (LLModel mptModel = new LLModel(Path.of(basePath + model))) {
LLModel.GenerationConfig config =
LLModel.config()
.withNPredict(4096)
.withRepeatLastN(64)
.build();
String result = mptModel.generate(prompt, config, true);
System.out.println("Code points:");
result.codePoints().forEach(System.out::println);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java

@@ -0,0 +1,68 @@
package com.hexadevlabs.gpt4all;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* Disabled for now, until these tests can run with the latest library on Windows.
*/
@Disabled
public class Tests {
static LLModel model;
static String PATH_TO_MODEL="C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";
@BeforeAll
public static void before(){
model = new LLModel(Path.of(PATH_TO_MODEL));
}
@Test
public void simplePrompt(){
LLModel.GenerationConfig config =
LLModel.config()
.withNPredict(20)
.build();
String prompt =
"### Instruction: \n" +
"The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
"### Prompt: \n" +
"Add 2+2\n" +
"### Response: 4\n" +
"Multiply 4 * 5\n" +
"### Response:";
String generatedText = model.generate(prompt, config, true);
assertTrue(generatedText.contains("20"));
}
@Test
public void chatCompletion(){
LLModel.GenerationConfig config =
LLModel.config()
.withNPredict(20)
.build();
LLModel.ChatCompletionResponse response = model.chatCompletion(
List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
Map.of("role", "user", "content", "Add 2+2")), config, true, true);
assertTrue( response.choices.get(0).get("content").contains("4") );
}
@AfterAll
public static void after() throws Exception {
if(model != null)
model.close();
}
}