From 4e274baee154eb65b02257733e927445442baf3b Mon Sep 17 00:00:00 2001 From: felix Date: Sun, 18 Jun 2023 00:50:12 -0400 Subject: [PATCH] bump version a few more doc fixes. add macos metal files Add check for Prompt is too long. add logging statement for gpt4all version of the binding add version string, readme update Add unit tests for Java code of the java bindings. --- gpt4all-bindings/java/README.md | 18 +- gpt4all-bindings/java/TODO.md | 2 + gpt4all-bindings/java/pom.xml | 18 +- .../java/com/hexadevlabs/gpt4all/LLModel.java | 91 +++++++--- .../gpt4all/PromptIsTooLongException.java | 7 + .../java/com/hexadevlabs/gpt4all/Util.java | 18 +- .../com/hexadevlabs/gpt4all/BasicTests.java | 155 ++++++++++++++++++ .../com/hexadevlabs/gpt4all/Example5.java | 47 ++++++ .../java/com/hexadevlabs/gpt4all/Tests.java | 68 -------- 9 files changed, 328 insertions(+), 96 deletions(-) create mode 100644 gpt4all-bindings/java/TODO.md create mode 100644 gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/PromptIsTooLongException.java create mode 100644 gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java create mode 100644 gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example5.java delete mode 100644 gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java diff --git a/gpt4all-bindings/java/README.md b/gpt4all-bindings/java/README.md index 63b580f9..cea23846 100644 --- a/gpt4all-bindings/java/README.md +++ b/gpt4all-bindings/java/README.md @@ -12,12 +12,12 @@ You can add Java bindings into your Java project by adding the following depende com.hexadevlabs gpt4all-java-binding - 1.1.2 + 1.1.3 ``` **Gradle** ``` -implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.2' +implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.3' ``` To add the library dependency for another build system see [Maven Central Java bindings](https://central.sonatype.com/artifact/com.hexadevlabs/gpt4all-java-binding/). @@ -103,3 +103,17 @@ class Example { ``` 2. Not every AVX-only shared library is bundled with the JAR right now to reduce size. Only libgptj-avx is included. If you are running into issues please let us know using the [gpt4all project issue tracker](https://github.com/nomic-ai/gpt4all/issues). + +3. For Windows the native library included in jar depends on specific Microsoft C and C++ (MSVC) runtime libraries which may not be installed on your system. +If this is the case you can easily download and install the latest x64 Microsoft Visual C++ Redistributable package from https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170 + +## Version history +1. Version **1.1.2**: + - Java bindings is compatible with gpt4ll version 2.4.6 + - Initial stable release with the initial feature set +2. Version **1.1.3**: + - Java bindings is compatible with gpt4all version 2.4.8 + - Add static GPT4ALL_VERSION to signify gpt4all version of the bindings + - Add PromptIsTooLongException for prompts that are longer than context size. + - Replit model support to include Metal Mac hardware support. + \ No newline at end of file diff --git a/gpt4all-bindings/java/TODO.md b/gpt4all-bindings/java/TODO.md new file mode 100644 index 00000000..3c85bf71 --- /dev/null +++ b/gpt4all-bindings/java/TODO.md @@ -0,0 +1,2 @@ +1. Better Chat completions function. +2. Chat completion that returns result in OpenAI compatible format. diff --git a/gpt4all-bindings/java/pom.xml b/gpt4all-bindings/java/pom.xml index 647b6876..454222fb 100644 --- a/gpt4all-bindings/java/pom.xml +++ b/gpt4all-bindings/java/pom.xml @@ -6,7 +6,7 @@ com.hexadevlabs gpt4all-java-binding - 1.1.2 + 1.1.3 jar @@ -56,6 +56,20 @@ 5.9.2 test + + + org.mockito + mockito-junit-jupiter + 5.4.0 + test + + + + org.mockito + mockito-core + 5.4.0 + test + @@ -103,7 +117,7 @@ ${project.build.directory}/generated-resources - C:\Users\felix\dev\gpt4all_java_bins\release_1_1_1_Jun8_2023 + C:\Users\felix\dev\gpt4all_java_bins\release_1_1_3_Jun22_2023 diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java index bf18315c..f3d0d674 100644 --- a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java +++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java @@ -1,6 +1,8 @@ package com.hexadevlabs.gpt4all; import jnr.ffi.Pointer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.ByteArrayOutputStream; import java.nio.charset.StandardCharsets; @@ -94,6 +96,10 @@ public class LLModel implements AutoCloseable { return this; } + /** + * + * @return GenerationConfig build instance of the config + */ public GenerationConfig build() { return configToBuild; } @@ -102,6 +108,8 @@ public class LLModel implements AutoCloseable { /** * Shortcut for making GenerativeConfig builder. + * + * @return GenerationConfig.Builder - builder that can be used to make a GenerationConfig */ public static GenerationConfig.Builder config(){ return new GenerationConfig.Builder(); @@ -122,14 +130,32 @@ public class LLModel implements AutoCloseable { */ public static boolean OUTPUT_DEBUG = false; + private static final Logger logger = LoggerFactory.getLogger(LLModel.class); + + /** + * Which version of GPT4ALL that this binding is built for. + * The binding is guaranteed to work with this version of + * GPT4ALL native libraries. The binding may work for older + * versions but that is not guaranteed. + */ + public static final String GPT4ALL_VERSION = "2.4.8"; + protected static LLModelLibrary library; protected Pointer model; protected String modelName; + /** + * Package private default constructor, for testing purposes. + */ + LLModel(){ + } + public LLModel(Path modelPath) { + logger.info("Java bindings for gpt4all version: " + GPT4ALL_VERSION); + if(library==null) { if (LIBRARY_SEARCH_PATH != null){ @@ -202,15 +228,56 @@ public class LLModel implements AutoCloseable { ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream(); ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream(); - LLModelLibrary.ResponseCallback responseCallback = (int tokenID, Pointer response) -> { + LLModelLibrary.ResponseCallback responseCallback = getResponseCallback(streamToStdOut, bufferingForStdOutStream, bufferingForWholeGeneration); + + library.llmodel_prompt(this.model, + prompt, + (int tokenID) -> { + if(LLModel.OUTPUT_DEBUG) + System.out.println("token " + tokenID); + return true; // continue processing + }, + responseCallback, + (boolean isRecalculating) -> { + if(LLModel.OUTPUT_DEBUG) + System.out.println("recalculating"); + return isRecalculating; // continue generating + }, + generationConfig); + + return bufferingForWholeGeneration.toString(StandardCharsets.UTF_8); + } + + /** + * Callback method to be used by prompt method as text is generated. + * + * @param streamToStdOut Should send generated text to standard out. + * @param bufferingForStdOutStream Output stream used for buffering bytes for standard output. + * @param bufferingForWholeGeneration Output stream used for buffering a complete generation. + * @return LLModelLibrary.ResponseCallback lambda function that is invoked by response callback. + */ + static LLModelLibrary.ResponseCallback getResponseCallback(boolean streamToStdOut, ByteArrayOutputStream bufferingForStdOutStream, ByteArrayOutputStream bufferingForWholeGeneration) { + return (int tokenID, Pointer response) -> { if(LLModel.OUTPUT_DEBUG) System.out.print("Response token " + tokenID + " " ); + // For all models if input sequence in tokens is longer then model context length + // the error is generated. + if(tokenID==-1){ + throw new PromptIsTooLongException(response.getString(0, 1000, StandardCharsets.UTF_8)); + } + long len = 0; byte nextByte; do{ - nextByte= response.getByte(len); + try { + nextByte = response.getByte(len); + } catch(IndexOutOfBoundsException e){ + // Not sure if this can ever happen but just in case + // the generation does not terminate in a Null (0) value. + throw new RuntimeException("Empty array or not null terminated"); + } len++; if(nextByte!=0) { bufferingForWholeGeneration.write(nextByte); @@ -230,27 +297,9 @@ public class LLModel implements AutoCloseable { return true; // continue generating }; - - library.llmodel_prompt(this.model, - prompt, - (int tokenID) -> { - if(LLModel.OUTPUT_DEBUG) - System.out.println("token " + tokenID); - return true; // continue processing - }, - responseCallback, - (boolean isRecalculating) -> { - if(LLModel.OUTPUT_DEBUG) - System.out.println("recalculating"); - return isRecalculating; // continue generating - }, - generationConfig); - - return bufferingForWholeGeneration.toString(StandardCharsets.UTF_8); } - public static class ChatCompletionResponse { public String model; public Usage usage; @@ -277,7 +326,7 @@ public class LLModel implements AutoCloseable { * easier to process for chat UIs. It is not absolutely necessary as generate method * may be directly used to make generations with gpt models. * - * @param messages List of Maps "role"->"user", "content"->"...", "role"-> "assistant"->"..." + * @param messages List of Maps "role"->"user", "content"->"...", "role"-> "assistant"->"..." * @param generationConfig How to decode/process the generation. * @param streamToStdOut Send tokens as they are calculated Standard output. * @param outputFullPromptToStdOut Should full prompt built out of messages be sent to Standard output. diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/PromptIsTooLongException.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/PromptIsTooLongException.java new file mode 100644 index 00000000..82301696 --- /dev/null +++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/PromptIsTooLongException.java @@ -0,0 +1,7 @@ +package com.hexadevlabs.gpt4all; + +public class PromptIsTooLongException extends RuntimeException { + public PromptIsTooLongException(String message) { + super(message); + } +} diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java index ce3c9619..4b8a978c 100644 --- a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java +++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java @@ -52,7 +52,8 @@ public class Util { /** * Copy over shared library files from resource package to * target Temp directory. - * Returns Path to the temp directory holding the shared libraries + * + * @return Path path to the temp directory holding the shared libraries */ public static Path copySharedLibraries() { try { @@ -78,15 +79,26 @@ public class Util { "mpt-default", "llamamodel-230511-default", "llamamodel-230519-default", - "llamamodel-mainline-default" + "llamamodel-mainline-default", + "llamamodel-mainline-metal", + "replit-mainline-default", + "replit-mainline-metal", + "ggml-metal.metal" }; for (String libraryName : libraryNames) { + if(!isMac && ( + libraryName.equals("replit-mainline-metal") + || libraryName.equals("llamamodel-mainline-metal") + || libraryName.equals("ggml-metal.metal")) + ) continue; + if(isWindows){ libraryName = libraryName + ".dll"; } else if(isMac){ - libraryName = "lib" + libraryName + ".dylib"; + if(!libraryName.equals("ggml-metal.metal")) + libraryName = "lib" + libraryName + ".dylib"; } else if(isLinux) { libraryName = "lib"+ libraryName + ".so"; } diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java new file mode 100644 index 00000000..8bc7c914 --- /dev/null +++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java @@ -0,0 +1,155 @@ +package com.hexadevlabs.gpt4all; + + +import jnr.ffi.Memory; +import jnr.ffi.Pointer; +import jnr.ffi.Runtime; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mockito; + +import org.mockito.junit.jupiter.MockitoExtension; + + +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +/** + * These tests only test the Java implementation as the underlying backend can't be mocked. + * These tests do serve the purpose of validating the java bits that do + * not directly have to do with the function of the underlying gp4all library. + */ +@ExtendWith(MockitoExtension.class) +public class BasicTests { + + @Test + public void simplePrompt(){ + + LLModel model = Mockito.spy(new LLModel()); + + LLModel.GenerationConfig config = + LLModel.config() + .withNPredict(20) + .build(); + + // The generate method will return "4" + doReturn("4").when( model ).generate(anyString(), eq(config), eq(true)); + + LLModel.ChatCompletionResponse response= model.chatCompletion( + List.of(Map.of("role", "system", "content", "You are a helpful assistant"), + Map.of("role", "user", "content", "Add 2+2")), config, true, true); + + assertTrue( response.choices.get(0).get("content").contains("4") ); + + // Verifies the prompt and response are certain length. + assertEquals( 224 , response.usage.totalTokens ); + } + + @Test + public void testResponseCallback(){ + + ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream(); + ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream(); + + LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration); + + // Get the runtime instance + Runtime runtime = Runtime.getSystemRuntime(); + + // Allocate memory for the byte array. Has to be null terminated + + // UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9 + byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination + + // Optional: Converting the byte array back to a String to print the character + String decodedString = new String(utf8ByteArray, 0, utf8ByteArray.length - 1, java.nio.charset.StandardCharsets.UTF_8); + + Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length); + + // Copy the byte array to the allocated memory + pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length); + + responseCallback.invoke(1, pointer); + + String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8); + + assertEquals(decodedString, result); + + } + + @Test + public void testResponseCallbackTwoTokens(){ + + ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream(); + ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream(); + + LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration); + + // Get the runtime instance + Runtime runtime = Runtime.getSystemRuntime(); + + // Allocate memory for the byte array. Has to be null terminated + + // UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9 + byte[] utf8ByteArray = { (byte) 0xF0, (byte) 0x9F, 0x00}; // Adding null termination + byte[] utf8ByteArray2 = { (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination + + // Optional: Converting the byte array back to a String to print the character + Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length); + + // Copy the byte array to the allocated memory + pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length); + + responseCallback.invoke(1, pointer); + // Copy the byte array to the allocated memory + pointer.put(0, utf8ByteArray2, 0, utf8ByteArray2.length); + + responseCallback.invoke(2, pointer); + + String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8); + + assertEquals("\uD83D\uDCA9", result); + + } + + + @Test + public void testResponseCallbackExpectError(){ + + ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream(); + ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream(); + + LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration); + + // Get the runtime instance + Runtime runtime = Runtime.getSystemRuntime(); + + // UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9 + byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9}; // No null termination + + Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length); + + // Copy the byte array to the allocated memory + pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length); + + Exception exception = assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer)); + + assertEquals("Empty array or not null terminated", exception.getMessage()); + + // With empty array + utf8ByteArray = new byte[0]; + pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length); + + Exception exceptionN = assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer)); + + assertEquals("Empty array or not null terminated", exceptionN.getMessage()); + + } + +} diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example5.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example5.java new file mode 100644 index 00000000..a2045cd0 --- /dev/null +++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example5.java @@ -0,0 +1,47 @@ +package com.hexadevlabs.gpt4all; + +import java.nio.file.Path; + +public class Example5 { + + public static void main(String[] args) { + + // String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:"; + // The emoji is poop emoji. The Unicode character is encoded as surrogate pair for Java string. + // LLM should correctly identify it as poop emoji in the description + //String prompt = "### Human:\nDescribe the meaning of this emoji \uD83D\uDCA9\n### Assistant:"; + //String prompt = "### Human:\nOutput the unicode character of smiley face emoji\n### Assistant:"; + + // Optionally in case override to location of shared libraries is necessary + //LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\"; + StringBuffer b = new StringBuffer(); + b.append("The ".repeat(2060)); + String prompt = b.toString(); + + + String model = "ggml-vicuna-7b-1.1-q4_2.bin"; + //String model = "ggml-gpt4all-j-v1.3-groovy.bin"; + //String model = "ggml-mpt-7b-instruct.bin"; + String basePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\"; + //String basePath = "/Users/fzaslavs/Library/Application Support/nomic.ai/GPT4All/"; + + try (LLModel mptModel = new LLModel(Path.of(basePath + model))) { + + LLModel.GenerationConfig config = + LLModel.config() + .withNPredict(4096) + .withRepeatLastN(64) + .build(); + + String result = mptModel.generate(prompt, config, true); + + System.out.println("Code points:"); + result.codePoints().forEach(System.out::println); + + + } catch (Exception e) { + System.out.println(e.getMessage()); + throw new RuntimeException(e); + } + } +} diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java deleted file mode 100644 index c1b77860..00000000 --- a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java +++ /dev/null @@ -1,68 +0,0 @@ -package com.hexadevlabs.gpt4all; - -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -import java.nio.file.Path; -import java.util.List; -import java.util.Map; - -import static org.junit.jupiter.api.Assertions.assertTrue; - -/** - * For now disabled until can run with latest library on windows - */ -@Disabled -public class Tests { - - static LLModel model; - static String PATH_TO_MODEL="C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin"; - - @BeforeAll - public static void before(){ - model = new LLModel(Path.of(PATH_TO_MODEL)); - } - - @Test - public void simplePrompt(){ - LLModel.GenerationConfig config = - LLModel.config() - .withNPredict(20) - .build(); - String prompt = - "### Instruction: \n" + - "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" + - "### Prompt: \n" + - "Add 2+2\n" + - "### Response: 4\n" + - "Multiply 4 * 5\n" + - "### Response:"; - String generatedText = model.generate(prompt, config, true); - assertTrue(generatedText.contains("20")); - } - - @Test - public void chatCompletion(){ - LLModel.GenerationConfig config = - LLModel.config() - .withNPredict(20) - .build(); - - LLModel.ChatCompletionResponse response= model.chatCompletion( - List.of(Map.of("role", "system", "content", "You are a helpful assistant"), - Map.of("role", "user", "content", "Add 2+2")), config, true, true); - - assertTrue( response.choices.get(0).get("content").contains("4") ); - - } - - @AfterAll - public static void after() throws Exception { - if(model != null) - model.close(); - } - - -}