diff --git a/README.md b/README.md
index 380fc51d..1c85b7f4 100644
--- a/README.md
+++ b/README.md
@@ -61,6 +61,7 @@ Find the most up-to-date information on the [GPT4All Website](https://gpt4all.io
 * :computer: Official Typescript Bindings
 * :computer: Official GoLang Bindings
 * :computer: Official C# Bindings
+* :computer: Official Java Bindings
 
 ## Contributing
diff --git a/gpt4all-bindings/java/.gitignore b/gpt4all-bindings/java/.gitignore
new file mode 100644
index 00000000..8c3a43d3
--- /dev/null
+++ b/gpt4all-bindings/java/.gitignore
@@ -0,0 +1,2 @@
+# Make sure the native directory never gets committed to git for the project.
+/src/main/resources/native
\ No newline at end of file
diff --git a/gpt4all-bindings/java/README.md b/gpt4all-bindings/java/README.md
new file mode 100644
index 00000000..fb9f523d
--- /dev/null
+++ b/gpt4all-bindings/java/README.md
@@ -0,0 +1,105 @@
+# Java bindings
+
+The Java bindings let you load a gpt4all library into your Java application and run text
+generation through an intuitive, easy-to-use API. No GPU is required because the gpt4all backend executes on the CPU.
+The gpt4all models are quantized to fit easily into system RAM and use about 4 to 7 GB of it.
+
+## Getting Started
+You can add the Java bindings to your Java project by declaring the dependency:
+
+**Maven**
+```
+<dependency>
+    <groupId>com.hexadevlabs</groupId>
+    <artifactId>gpt4all-java-binding</artifactId>
+    <version>1.1.2</version>
+</dependency>
+```
+**Gradle**
+```
+implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.2'
+```
+
+To add the library dependency for another build system, see [Maven Central Java bindings](https://central.sonatype.com/artifact/com.hexadevlabs/gpt4all-java-binding/).
+
+To download a model binary weights file, use a URL such as https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin.
+
+For information about other available models, see the [Model file list](https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-chat#manual-download-of-models).
+
+### Sample code
+```java
+import com.hexadevlabs.gpt4all.LLModel;
+
+import java.nio.file.Path;
+
+public class Example {
+    public static void main(String[] args) {
+
+        String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
+
+        // Replace the hardcoded path with the actual path where your model file resides
+        String modelFilePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";
+
+        try (LLModel model = new LLModel(Path.of(modelFilePath))) {
+
+            // May generate up to 4096 tokens but generally stops early
+            LLModel.GenerationConfig config = LLModel.config()
+                    .withNPredict(4096).build();
+
+            // Will also stream to standard output
+            String fullGeneration = model.generate(prompt, config, true);
+
+        } catch (Exception e) {
+            // An exception may be thrown if the model file fails to load
+            // for a number of reasons, such as the file not being found.
+            // It is also possible that Java cannot dynamically load the native shared library, or
+            // that the llmodel shared library cannot dynamically load the backend
+            // implementation for the model file you provided.
+            //
+            // Once the LLModel class is successfully loaded into memory, the text generation calls
+            // generally should not throw exceptions.
+            e.printStackTrace(); // Printed here, but in a production system you may want to take some action.
+        }
+    }
+
+}
+```
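+
+Note that `LLModel` implements `AutoCloseable`, so the try-with-resources block above releases the
+native model memory automatically when the block exits.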
+
+For a Maven-based sample project that uses this library, see the [Sample project](https://github.com/felix-zaslavskiy/gpt4all-java-bindings-sample).
+
+### Additional considerations
+#### Logger warnings
+The Java bindings library may produce this warning if you don't have an SLF4J binding included in your project:
+```
+SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
+SLF4J: Defaulting to no-operation (NOP) logger implementation
+SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
+```
+The Java bindings only use logging for informational purposes, so a logger is not essential for using
+the library correctly. You can ignore this warning if you don't have an SLF4J binding in your project.
+
+To add a simple logger, you may use this Maven dependency:
+```
+<dependency>
+    <groupId>org.slf4j</groupId>
+    <artifactId>slf4j-simple</artifactId>
+    <version>1.7.36</version>
+</dependency>
+```
+
+#### Loading your native libraries
+1. The Java bindings jar comes bundled with native library files for Windows, macOS, and Linux. These library files are
+copied to a temporary directory and loaded at runtime. Advanced users who want to package the shared libraries into Docker containers,
+or who want to use a custom build of the shared libraries and ignore the ones bundled with the jar, have the option
+to load libraries from a local directory by setting a static property to the location of the library files.
+There are no compatibility guarantees when the bindings are used this way, so be careful if you really want to do it.
+
+For example:
+```java
+class Example {
+    public static void main(String[] args) {
+        // gpt4all native shared libraries location
+        LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
+        // ... use the library normally
+    }
+}
+```
+2. To reduce size, not every AVX-only shared library is bundled with the jar right now; only libgptj-avx is included.
+If you are running into issues, please let us know using the gpt4all project issue tracker https://github.com/nomic-ai/gpt4all/issues.
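+
+#### Chat-style generation
+In addition to `generate`, the bindings expose a `chatCompletion` method that builds a single prompt
+from a list of role/content messages and returns the reply together with simple usage stats. A minimal
+sketch, assuming the same imports as the sample above plus `java.util.List` and `java.util.Map`, and a
+placeholder model path you would replace with your own:
+```java
+try (LLModel model = new LLModel(Path.of("C:\\path\\to\\ggml-gpt4all-j-v1.3-groovy.bin"))) {
+
+    LLModel.GenerationConfig config = LLModel.config().withNPredict(256).build();
+
+    // Each message is a map with a "role" ("system", "user" or "assistant") and a "content".
+    LLModel.ChatCompletionResponse response = model.chatCompletion(
+            List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
+                    Map.of("role", "user", "content", "Add 2+2")),
+            config, true, true);
+
+    // The assistant reply is the first (and only) choice.
+    System.out.println(response.choices.get(0).get("content"));
+} catch (Exception e) {
+    e.printStackTrace();
+}
+```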
diff --git a/gpt4all-bindings/java/pom.xml b/gpt4all-bindings/java/pom.xml
new file mode 100644
index 00000000..647b6876
--- /dev/null
+++ b/gpt4all-bindings/java/pom.xml
@@ -0,0 +1,202 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <groupId>com.hexadevlabs</groupId>
+    <artifactId>gpt4all-java-binding</artifactId>
+    <version>1.1.2</version>
+    <packaging>jar</packaging>
+
+    <properties>
+        <maven.compiler.source>11</maven.compiler.source>
+        <maven.compiler.target>11</maven.compiler.target>
+        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+    </properties>
+
+    <name>${project.groupId}:${project.artifactId}</name>
+    <description>Java bindings for GPT4ALL LLM</description>
+    <url>https://github.com/nomic-ai/gpt4all</url>
+    <licenses>
+        <license>
+            <name>The Apache License, Version 2.0</name>
+            <url>https://github.com/nomic-ai/gpt4all/blob/main/LICENSE.txt</url>
+        </license>
+    </licenses>
+
+    <developers>
+        <developer>
+            <name>Felix Zaslavskiy</name>
+            <email>felixz@hexadevlabs.com</email>
+            <organizationUrl>https://github.com/felix-zaslavskiy/</organizationUrl>
+        </developer>
+    </developers>
+
+    <scm>
+        <connection>scm:git:git://github.com/nomic-ai/gpt4all.git</connection>
+        <developerConnection>scm:git:ssh://github.com/nomic-ai/gpt4all.git</developerConnection>
+        <url>https://github.com/nomic-ai/gpt4all/tree/main</url>
+    </scm>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.github.jnr</groupId>
+            <artifactId>jnr-ffi</artifactId>
+            <version>2.2.13</version>
+        </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-api</artifactId>
+            <version>1.7.36</version>
+        </dependency>
+        <dependency>
+            <groupId>org.junit.jupiter</groupId>
+            <artifactId>junit-jupiter-api</artifactId>
+            <version>5.9.2</version>
+            <scope>test</scope>
+        </dependency>
+    </dependencies>
+
+    <distributionManagement>
+        <snapshotRepository>
+            <id>ossrh</id>
+            <url>https://s01.oss.sonatype.org/content/repositories/snapshots</url>
+        </snapshotRepository>
+        <repository>
+            <id>ossrh</id>
+            <url>https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
+        </repository>
+    </distributionManagement>
+
+    <build>
+        <resources>
+            <resource>
+                <directory>src/main/resources</directory>
+            </resource>
+            <resource>
+                <directory>${project.build.directory}/generated-resources</directory>
+            </resource>
+        </resources>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-surefire-plugin</artifactId>
+                <version>3.0.0</version>
+                <configuration>
+                    <forkCount>0</forkCount>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-resources-plugin</artifactId>
+                <version>3.3.1</version>
+                <executions>
+                    <execution>
+                        <id>copy-resources</id>
+                        <phase>validate</phase>
+                        <goals>
+                            <goal>copy-resources</goal>
+                        </goals>
+                        <configuration>
+                            <outputDirectory>${project.build.directory}/generated-resources</outputDirectory>
+                            <resources>
+                                <resource>
+                                    <directory>C:\Users\felix\dev\gpt4all_java_bins\release_1_1_1_Jun8_2023</directory>
+                                </resource>
+                            </resources>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.sonatype.plugins</groupId>
+                <artifactId>nexus-staging-maven-plugin</artifactId>
+                <version>1.6.13</version>
+                <extensions>true</extensions>
+                <configuration>
+                    <serverId>ossrh</serverId>
+                    <nexusUrl>https://s01.oss.sonatype.org/</nexusUrl>
+                    <autoReleaseAfterClose>true</autoReleaseAfterClose>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-source-plugin</artifactId>
+                <version>2.2.1</version>
+                <executions>
+                    <execution>
+                        <id>attach-sources</id>
+                        <goals>
+                            <goal>jar-no-fork</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-javadoc-plugin</artifactId>
+                <version>3.5.0</version>
+                <executions>
+                    <execution>
+                        <id>attach-javadocs</id>
+                        <goals>
+                            <goal>jar</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-gpg-plugin</artifactId>
+                <version>1.5</version>
+                <executions>
+                    <execution>
+                        <id>sign-artifacts</id>
+                        <phase>verify</phase>
+                        <goals>
+                            <goal>sign</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-assembly-plugin</artifactId>
+                <version>3.6.0</version>
+                <configuration>
+                    <descriptorRefs>
+                        <descriptorRef>jar-with-dependencies</descriptorRef>
+                    </descriptorRefs>
+                    <archive>
+                        <manifest>
+                            <mainClass>com.hexadevlabs.gpt4allsample.Example4</mainClass>
+                        </manifest>
+                    </archive>
+                </configuration>
+                <executions>
+                    <execution>
+                        <id>make-assembly</id>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>single</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
+</project>
\ No newline at end of file
diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java
new file mode 100644
index 00000000..d571b8d2
--- /dev/null
+++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java
@@ -0,0 +1,349 @@
+package com.hexadevlabs.gpt4all;
+
+import jnr.ffi.Pointer;
+
+import java.io.ByteArrayOutputStream;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class LLModel implements AutoCloseable {
+
+    /**
+     * Config used for how to decode LLM outputs.
+     * A high temperature (closer to 1) gives more creative outputs,
+     * while a low temperature (closer to 0) produces more precise outputs.
+     * <br>
+     * Use the builder to set the settings you want.
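+     * For example, a sketch of a configuration for longer, more creative generations
+     * (the values here are illustrative, not recommendations):
+     * <pre>{@code
+     * LLModel.GenerationConfig config = LLModel.config()
+     *         .withNPredict(4096)
+     *         .withTemp(0.9f)
+     *         .build();
+     * }</pre>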
+     */
+    public static class GenerationConfig extends LLModelLibrary.LLModelPromptContext {
+
+        private GenerationConfig() {
+            super(jnr.ffi.Runtime.getSystemRuntime());
+            logits_size.set(0);
+            tokens_size.set(0);
+            n_past.set(0);
+            n_ctx.set(1024);
+            n_predict.set(128);
+            top_k.set(40);
+            top_p.set(0.95);
+            temp.set(0.28);
+            n_batch.set(8);
+            repeat_penalty.set(1.1);
+            repeat_last_n.set(10);
+            context_erase.set(0.55);
+        }
+
+        public static class Builder {
+            private final GenerationConfig configToBuild;
+
+            public Builder() {
+                configToBuild = new GenerationConfig();
+            }
+
+            public Builder withNPast(int n_past) {
+                configToBuild.n_past.set(n_past);
+                return this;
+            }
+
+            public Builder withNCtx(int n_ctx) {
+                configToBuild.n_ctx.set(n_ctx);
+                return this;
+            }
+
+            public Builder withNPredict(int n_predict) {
+                configToBuild.n_predict.set(n_predict);
+                return this;
+            }
+
+            public Builder withTopK(int top_k) {
+                configToBuild.top_k.set(top_k);
+                return this;
+            }
+
+            public Builder withTopP(float top_p) {
+                configToBuild.top_p.set(top_p);
+                return this;
+            }
+
+            public Builder withTemp(float temp) {
+                configToBuild.temp.set(temp);
+                return this;
+            }
+
+            public Builder withNBatch(int n_batch) {
+                configToBuild.n_batch.set(n_batch);
+                return this;
+            }
+
+            public Builder withRepeatPenalty(float repeat_penalty) {
+                configToBuild.repeat_penalty.set(repeat_penalty);
+                return this;
+            }
+
+            public Builder withRepeatLastN(int repeat_last_n) {
+                configToBuild.repeat_last_n.set(repeat_last_n);
+                return this;
+            }
+
+            public Builder withContextErase(float context_erase) {
+                configToBuild.context_erase.set(context_erase);
+                return this;
+            }
+
+            public GenerationConfig build() {
+                return configToBuild;
+            }
+        }
+    }
+
+    /**
+     * Shortcut for making a GenerationConfig builder.
+     */
+    public static GenerationConfig.Builder config(){
+        return new GenerationConfig.Builder();
+    }
+
+    /**
+     * This may be set before any LLModel instances are created to
+     * set where the native shared libraries can be found. This may be needed if setting
+     * the library search path by standard means is not available.
+     */
+    public static String LIBRARY_SEARCH_PATH;
+
+
+    /**
+     * Generally for debugging purposes only. Will print
+     * the numerical tokens as they are generated instead of the string representations.
+     * Will also print the processed input tokens as numbers to standard out.
+     */
+    public static boolean OUTPUT_DEBUG = false;
+
+    protected static LLModelLibrary library;
+
+    protected Pointer model;
+
+    protected String modelName;
+
+    public LLModel(Path modelPath) {
+
+        if(library==null) {
+
+            if (LIBRARY_SEARCH_PATH != null){
+                library = Util.loadSharedLibrary(LIBRARY_SEARCH_PATH);
+                library.llmodel_set_implementation_search_path(LIBRARY_SEARCH_PATH);
+            } else {
+                // Copy the bundled shared libraries to a temp folder
+                Path tempLibraryDirectory = Util.copySharedLibraries();
+                library = Util.loadSharedLibrary(tempLibraryDirectory.toString());
+
+                library.llmodel_set_implementation_search_path(tempLibraryDirectory.toString());
+            }
+
+        }
+
+        // modelType = type;
+        modelName = modelPath.getFileName().toString();
+        String modelPathAbs = modelPath.toAbsolutePath().toString();
+
+        LLModelLibrary.LLModelError error = new LLModelLibrary.LLModelError(jnr.ffi.Runtime.getSystemRuntime());
+
+        // Check that the model file exists
+        if(!Files.exists(modelPath)){
+            throw new IllegalStateException("Model file does not exist: " + modelPathAbs);
+        }
+
+        // Create the model struct. Will dynamically load the correct backend based on the model type.
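+        // The "auto" build variant asks llmodel to pick the backend implementation;
+        // if creation fails, the LLModelError struct is filled with the details.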
+        model = library.llmodel_model_create2(modelPathAbs, "auto", error);
+
+        if(model == null) {
+            throw new IllegalStateException("Could not load gpt4all backend: " + error.message);
+        }
+        library.llmodel_loadModel(model, modelPathAbs);
+
+        if(!library.llmodel_isModelLoaded(model)){
+            throw new IllegalStateException("The model " + modelName + " could not be loaded");
+        }
+
+    }
+
+    public void setThreadCount(int nThreads) {
+        library.llmodel_setThreadCount(this.model, nThreads);
+    }
+
+    public int threadCount() {
+        return library.llmodel_threadCount(this.model);
+    }
+
+    /**
+     * Generate text after the prompt
+     *
+     * @param prompt The text prompt to complete
+     * @param generationConfig The generation settings to use while generating text
+     * @return String The complete generated text
+     */
+    public String generate(String prompt, GenerationConfig generationConfig) {
+        return generate(prompt, generationConfig, false);
+    }
+
+    /**
+     * Generate text after the prompt
+     *
+     * @param prompt The text prompt to complete
+     * @param generationConfig The generation settings to use while generating text
+     * @param streamToStdOut Whether the generation should be streamed to standard output. Useful for troubleshooting.
+     * @return String The complete generated text
+     */
+    public String generate(String prompt, GenerationConfig generationConfig, boolean streamToStdOut) {
+
+        ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
+        ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
+
+        LLModelLibrary.ResponseCallback responseCallback = (int tokenID, Pointer response) -> {
+
+            if(LLModel.OUTPUT_DEBUG)
+                System.out.print("Response token " + tokenID + " ");
+
+            long len = 0;
+            byte nextByte;
+            do{
+                nextByte = response.getByte(len);
+                len++;
+                if(nextByte!=0) {
+                    bufferingForWholeGeneration.write(nextByte);
+                    if(streamToStdOut){
+                        bufferingForStdOutStream.write(nextByte);
+                        // Test if the buffer is a valid UTF-8 string.
+                        byte[] currentBytes = bufferingForStdOutStream.toByteArray();
+                        String validString = Util.getValidUtf8(currentBytes);
+                        if(validString!=null){ // is a valid string
+                            System.out.print(validString);
+                            // reset the buffer to collect the next UTF-8 sequence
+                            bufferingForStdOutStream.reset();
+                        }
+                    }
+                }
+            } while(nextByte != 0);
+
+            return true; // continue generating
+        };
+
+        library.llmodel_prompt(this.model,
+                prompt,
+                (int tokenID) -> {
+                    if(LLModel.OUTPUT_DEBUG)
+                        System.out.println("token " + tokenID);
+                    return true; // continue processing
+                },
+                responseCallback,
+                (boolean isRecalculating) -> {
+                    if(LLModel.OUTPUT_DEBUG)
+                        System.out.println("recalculating");
+                    return isRecalculating; // continue generating
+                },
+                generationConfig);
+
+        return bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
+    }
+
+
+
+    public static class ChatCompletionResponse {
+        public String model;
+        public Usage usage;
+        public List<Map<String, String>> choices;
+
+        // Getters and setters
+    }
+
+    public static class Usage {
+        public int promptTokens;
+        public int completionTokens;
+        public int totalTokens;
+
+        // Getters and setters
+    }
+
+    public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages,
+                                                 GenerationConfig generationConfig) {
+        return chatCompletion(messages, generationConfig, false, false);
+    }
+
+    /**
+     * chatCompletion formats the existing chat conversation into a template that is
+     * easier to process for chat UIs.
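+     * For example, a conversation can be passed as (an illustrative sketch):
+     * <pre>{@code
+     * model.chatCompletion(
+     *         List.of(Map.of("role", "user", "content", "Add 2+2"),
+     *                 Map.of("role", "assistant", "content", "4"),
+     *                 Map.of("role", "user", "content", "Multiply 4 * 5")),
+     *         config);
+     * }</pre>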
+     * It is not absolutely necessary, as the generate method
+     * may be used directly to run generations with gpt models.
+     *
+     * @param messages List of maps, each with a "role" key ("system", "user" or "assistant") and a "content" key holding the message text.
+     * @param generationConfig How to decode/process the generation.
+     * @param streamToStdOut Whether to send tokens to standard output as they are generated.
+     * @param outputFullPromptToStdOut Whether the full prompt built from the messages should be sent to standard output.
+     * @return ChatCompletionResponse contains stats and the generated text.
+     */
+    public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages,
+                                                 GenerationConfig generationConfig, boolean streamToStdOut,
+                                                 boolean outputFullPromptToStdOut) {
+        String fullPrompt = buildPrompt(messages);
+
+        if(outputFullPromptToStdOut)
+            System.out.print(fullPrompt);
+
+        String generatedText = generate(fullPrompt, generationConfig, streamToStdOut);
+
+        ChatCompletionResponse response = new ChatCompletionResponse();
+        response.model = this.modelName;
+
+        Usage usage = new Usage();
+        usage.promptTokens = fullPrompt.length();
+        usage.completionTokens = generatedText.length();
+        usage.totalTokens = fullPrompt.length() + generatedText.length();
+        response.usage = usage;
+
+        Map<String, String> message = new HashMap<>();
+        message.put("role", "assistant");
+        message.put("content", generatedText);
+
+        response.choices = List.of(message);
+
+        return response;
+    }
+
+    protected static String buildPrompt(List<Map<String, String>> messages) {
+        StringBuilder fullPrompt = new StringBuilder();
+
+        for (Map<String, String> message : messages) {
+            if ("system".equals(message.get("role"))) {
+                String systemMessage = message.get("content") + "\n";
+                fullPrompt.append(systemMessage);
+            }
+        }
+
+        fullPrompt.append("### Instruction: \n" +
+                "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
+                "### Prompt: ");
+
+        for (Map<String, String> message : messages) {
+            if ("user".equals(message.get("role"))) {
+                String userMessage = "\n" + message.get("content");
+                fullPrompt.append(userMessage);
+            }
+            if ("assistant".equals(message.get("role"))) {
+                String assistantMessage = "\n### Response: " + message.get("content");
+                fullPrompt.append(assistantMessage);
+            }
+        }
+
+        fullPrompt.append("\n### Response:");
+
+        return fullPrompt.toString();
+    }
+
+    @Override
+    public void close() throws Exception {
+        library.llmodel_model_destroy(model);
+    }
+
+}
\ No newline at end of file
diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java
new file mode 100644
index 00000000..eca19f8c
--- /dev/null
+++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModelLibrary.java
@@ -0,0 +1,79 @@
+package com.hexadevlabs.gpt4all;
+
+import jnr.ffi.Pointer;
+import jnr.ffi.Struct;
+import jnr.ffi.annotations.Delegate;
+import jnr.ffi.annotations.Encoding;
+import jnr.ffi.annotations.In;
+import jnr.ffi.annotations.Out;
+import jnr.ffi.types.u_int64_t;
+
+
+/**
+ * The basic native library interface that provides all the LLM functions.
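+ * The function and struct declarations below mirror the C API (llmodel_c.h),
+ * which jnr-ffi binds to by name at runtime.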
+ */
+public interface LLModelLibrary {
+
+    interface PromptCallback {
+        @Delegate
+        boolean invoke(int token_id);
+    }
+
+    interface ResponseCallback {
+        @Delegate
+        boolean invoke(int token_id, Pointer response);
+    }
+
+    interface RecalculateCallback {
+        @Delegate
+        boolean invoke(boolean is_recalculating);
+    }
+
+    class LLModelError extends Struct {
+        public final Struct.AsciiStringRef message = new Struct.AsciiStringRef();
+        public final int32_t status = new int32_t();
+        public LLModelError(jnr.ffi.Runtime runtime) {
+            super(runtime);
+        }
+    }
+
+    class LLModelPromptContext extends Struct {
+        public final Pointer logits = new Pointer();
+        public final ssize_t logits_size = new ssize_t();
+        public final Pointer tokens = new Pointer();
+        public final ssize_t tokens_size = new ssize_t();
+        public final int32_t n_past = new int32_t();
+        public final int32_t n_ctx = new int32_t();
+        public final int32_t n_predict = new int32_t();
+        public final int32_t top_k = new int32_t();
+        public final Float top_p = new Float();
+        public final Float temp = new Float();
+        public final int32_t n_batch = new int32_t();
+        public final Float repeat_penalty = new Float();
+        public final int32_t repeat_last_n = new int32_t();
+        public final Float context_erase = new Float();
+
+        public LLModelPromptContext(jnr.ffi.Runtime runtime) {
+            super(runtime);
+        }
+    }
+
+    Pointer llmodel_model_create2(String model_path, String build_variant, @Out LLModelError llmodel_error);
+    void llmodel_model_destroy(Pointer model);
+    boolean llmodel_loadModel(Pointer model, String model_path);
+    boolean llmodel_isModelLoaded(Pointer model);
+    @u_int64_t long llmodel_get_state_size(Pointer model);
+    @u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest);
+    @u_int64_t long llmodel_restore_state_data(Pointer model, Pointer src);
+
+    void llmodel_set_implementation_search_path(String path);
+
+    // ctx was an @Out ... without @Out crash
+    void llmodel_prompt(Pointer model, @Encoding("UTF-8") String prompt,
+                        PromptCallback prompt_callback,
+                        ResponseCallback response_callback,
+                        RecalculateCallback recalculate_callback,
+                        @In LLModelPromptContext ctx);
+    void llmodel_setThreadCount(Pointer model, int n_threads);
+    int llmodel_threadCount(Pointer model);
+}
diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java
new file mode 100644
index 00000000..ce3c9619
--- /dev/null
+++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java
@@ -0,0 +1,147 @@
+package com.hexadevlabs.gpt4all;
+
+import jnr.ffi.LibraryLoader;
+import jnr.ffi.LibraryOption;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.charset.CharacterCodingException;
+import java.nio.charset.CharsetDecoder;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class Util {
+
+    private static final Logger logger = LoggerFactory.getLogger(Util.class);
+    private static final CharsetDecoder cs = StandardCharsets.UTF_8.newDecoder();
+
+    public static LLModelLibrary loadSharedLibrary(String librarySearchPath){
+        String libraryName = "llmodel";
+        Map<LibraryOption, Object> libraryOptions = new HashMap<>();
+        libraryOptions.put(LibraryOption.LoadNow, true); // load immediately instead of lazily (i.e. on first use)
+        libraryOptions.put(LibraryOption.IgnoreError, false); // calls should not save the last errno after the call
+
+        if(librarySearchPath!=null) {
+            Map<String, List<String>> searchPaths = new HashMap<>();
+            searchPaths.put(libraryName, List.of(librarySearchPath));
+
+            return LibraryLoader.loadLibrary(LLModelLibrary.class,
+                    libraryOptions,
+                    searchPaths,
+                    libraryName
+            );
+        } else {
+
+            return LibraryLoader.loadLibrary(LLModelLibrary.class,
+                    libraryOptions,
+                    libraryName
+            );
+        }
+
+    }
+
+    /**
+     * Copy the shared library files from the resource package to
+     * a target temp directory.
+     *
+     * @return Path to the temp directory holding the shared libraries
+     */
+    public static Path copySharedLibraries() {
+        try {
+            // Identify the OS and architecture
+            String osName = System.getProperty("os.name").toLowerCase();
+            boolean isWindows = osName.startsWith("windows");
+            boolean isMac = osName.startsWith("mac os x");
+            boolean isLinux = osName.startsWith("linux");
+            if(isWindows) osName = "windows";
+            if(isMac) osName = "macos";
+            if(isLinux) osName = "linux";
+
+            //String osArch = System.getProperty("os.arch");
+
+            // Create a temporary directory
+            Path tempDirectory = Files.createTempDirectory("nativeLibraries");
+            tempDirectory.toFile().deleteOnExit();
+
+            String[] libraryNames = {
+                    "gptj-default",
+                    "gptj-avxonly",
+                    "llmodel",
+                    "mpt-default",
+                    "llamamodel-230511-default",
+                    "llamamodel-230519-default",
+                    "llamamodel-mainline-default"
+            };
+
+            for (String libraryName : libraryNames) {
+
+                if(isWindows){
+                    libraryName = libraryName + ".dll";
+                } else if(isMac){
+                    libraryName = "lib" + libraryName + ".dylib";
+                } else if(isLinux) {
+                    libraryName = "lib" + libraryName + ".so";
+                }
+
+                // Construct the resource path based on the OS and architecture
+                String nativeLibraryPath = "/native/" + osName + "/" + libraryName;
+
+                // Get the library resource as a stream
+                InputStream in = Util.class.getResourceAsStream(nativeLibraryPath);
+                if (in == null) {
+                    throw new RuntimeException("Unable to find native library: " + nativeLibraryPath);
+                }
+
+                // Create a file in the temporary directory with the original library name
+                Path tempLibraryPath = tempDirectory.resolve(libraryName);
+
+                // Use Files.copy to copy the library to the temporary file
+                Files.copy(in, tempLibraryPath, StandardCopyOption.REPLACE_EXISTING);
+
+                // Close the input stream
+                in.close();
+            }
+
+            // Add a shutdown hook to delete the temp directory on JVM exit.
+            // On Windows, deleting dll files that are loaded into memory is not possible.
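+            // So the recursive cleanup below only runs on macOS and Linux.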
+            if(!isWindows) {
+                Runtime.getRuntime().addShutdownHook(new Thread(() -> {
+                    try {
+                        Files.walk(tempDirectory)
+                                .sorted(Comparator.reverseOrder())
+                                .map(Path::toFile)
+                                .forEach(file -> {
+                                    try {
+                                        Files.delete(file.toPath());
+                                    } catch (IOException e) {
+                                        logger.error("Error deleting temp library file", e);
+                                    }
+                                });
+                    } catch (IOException e) {
+                        logger.error("Error deleting temp directory for libraries", e);
+                    }
+                }));
+            }
+
+            return tempDirectory;
+        } catch (IOException e) {
+            throw new RuntimeException("Failed to load native libraries", e);
+        }
+    }
+
+    public static String getValidUtf8(byte[] bytes) {
+        try {
+            return cs.decode(ByteBuffer.wrap(bytes)).toString();
+        } catch (CharacterCodingException e) {
+            return null;
+        }
+    }
+}
diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example1.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example1.java
new file mode 100644
index 00000000..e8925a1f
--- /dev/null
+++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example1.java
@@ -0,0 +1,30 @@
+package com.hexadevlabs.gpt4all;
+
+import java.nio.file.Path;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * GPTJ chat completion, multiple messages
+ */
+public class Example1 {
+    public static void main(String[] args) {
+
+        // Optional: only needed if overriding the location of the shared libraries is necessary
+        //LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
+
+        try ( LLModel gptjModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin")) ){
+
+            LLModel.GenerationConfig config = LLModel.config()
+                    .withNPredict(4096).build();
+
+            gptjModel.chatCompletion(
+                    List.of(Map.of("role", "user", "content", "Add 2+2"),
+                            Map.of("role", "assistant", "content", "4"),
+                            Map.of("role", "user", "content", "Multiply 4 * 5")), config, true, true);
+
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
\ No newline at end of file
diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example2.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example2.java
new file mode 100644
index 00000000..35719d15
--- /dev/null
+++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example2.java
@@ -0,0 +1,31 @@
+package com.hexadevlabs.gpt4all;
+
+import java.nio.file.Path;
+
+/**
+ * Generation with the MPT model
+ */
+public class Example2 {
+    public static void main(String[] args) {
+
+        String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
+
+        // Optional: only needed if overriding the location of the shared libraries is necessary
+        //LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
+
+        try (LLModel mptModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-mpt-7b-instruct.bin"))) {
+
+            LLModel.GenerationConfig config =
+                    LLModel.config()
+                            .withNPredict(4096)
+                            .withRepeatLastN(64)
+                            .build();
+
+            mptModel.generate(prompt, config, true);
+
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+}
diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example3.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example3.java
new file mode 100644
index 00000000..fd842cdc
--- /dev/null
+++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example3.java
@@ -0,0 +1,33 @@
+package com.hexadevlabs.gpt4all;
+
+import java.nio.file.Path;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * GPTJ chat completion with a system message
+ */
+public class Example3 {
+    public static void main(String[] args) {
+
+        // Optional: only needed if overriding the location of the shared libraries is necessary
+        //LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
+
+        try ( LLModel gptjModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin")) ){
+
+            LLModel.GenerationConfig config = LLModel.config()
+                    .withNPredict(4096).build();
+
+            // String result = gptjModel.generate(prompt, config, true);
+            gptjModel.chatCompletion(
+                    List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
+                            Map.of("role", "user", "content", "Add 2+2")), config, true, true);
+
+
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example4.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example4.java
new file mode 100644
index 00000000..f8fc0029
--- /dev/null
+++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example4.java
@@ -0,0 +1,43 @@
+package com.hexadevlabs.gpt4all;
+
+import java.nio.file.Path;
+
+public class Example4 {
+
+    public static void main(String[] args) {
+
+        String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
+        // The emoji is the poop emoji. The Unicode character is encoded as a surrogate pair in the Java string.
+        // The LLM should correctly identify it as the poop emoji in the description.
+        //String prompt = "### Human:\nDescribe the meaning of this emoji \uD83D\uDCA9\n### Assistant:";
+        //String prompt = "### Human:\nOutput the unicode character of smiley face emoji\n### Assistant:";
+
+        // Optional: only needed if overriding the location of the shared libraries is necessary
+        //LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
+
+        String model = "ggml-vicuna-7b-1.1-q4_2.bin";
+        //String model = "ggml-gpt4all-j-v1.3-groovy.bin";
+        //String model = "ggml-mpt-7b-instruct.bin";
+        String basePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\";
+        //String basePath = "/Users/fzaslavs/Library/Application Support/nomic.ai/GPT4All/";
+
+        try (LLModel mptModel = new LLModel(Path.of(basePath + model))) {
+
+            LLModel.GenerationConfig config =
+                    LLModel.config()
+                            .withNPredict(4096)
+                            .withRepeatLastN(64)
+                            .build();
+
+            String result = mptModel.generate(prompt, config, true);
+
+            System.out.println("Code points:");
+            result.codePoints().forEach(System.out::println);
+
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java
new file mode 100644
index 00000000..c1b77860
--- /dev/null
+++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java
@@ -0,0 +1,68 @@
+package com.hexadevlabs.gpt4all;
+
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+import java.nio.file.Path;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Disabled for now, until the tests can run with the latest library on Windows.
+ */
+@Disabled
+public class Tests {
+
+    static LLModel model;
+    static String PATH_TO_MODEL = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";
+
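+    // The model is loaded once and shared across all tests, since loading the
+    // weights is by far the most expensive step.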
+    @BeforeAll
+    public static void before(){
+        model = new LLModel(Path.of(PATH_TO_MODEL));
+    }
+
+    @Test
+    public void simplePrompt(){
+        LLModel.GenerationConfig config =
+                LLModel.config()
+                        .withNPredict(20)
+                        .build();
+        String prompt =
+                "### Instruction: \n" +
+                "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
+                "### Prompt: \n" +
+                "Add 2+2\n" +
+                "### Response: 4\n" +
+                "Multiply 4 * 5\n" +
+                "### Response:";
+        String generatedText = model.generate(prompt, config, true);
+        assertTrue(generatedText.contains("20"));
+    }
+
+    @Test
+    public void chatCompletion(){
+        LLModel.GenerationConfig config =
+                LLModel.config()
+                        .withNPredict(20)
+                        .build();
+
+        LLModel.ChatCompletionResponse response = model.chatCompletion(
+                List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
+                        Map.of("role", "user", "content", "Add 2+2")), config, true, true);
+
+        assertTrue( response.choices.get(0).get("content").contains("4") );
+
+    }
+
+    @AfterAll
+    public static void after() throws Exception {
+        if(model != null)
+            model.close();
+    }
+
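+    // A model-free sanity check of the prompt template (a sketch: buildPrompt is
+    // static and does not load any model weights, so it can run anywhere).
+    @Test
+    public void buildPromptTemplate(){
+        String prompt = LLModel.buildPrompt(
+                List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
+                        Map.of("role", "user", "content", "Add 2+2")));
+        assertTrue(prompt.startsWith("You are a helpful assistant\n"));
+        assertTrue(prompt.contains("### Prompt: \nAdd 2+2"));
+        assertTrue(prompt.endsWith("\n### Response:"));
+    }
+
+}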