diff --git a/gpt4all-bindings/java/README.md b/gpt4all-bindings/java/README.md index 63b580f9..cea23846 100644 --- a/gpt4all-bindings/java/README.md +++ b/gpt4all-bindings/java/README.md @@ -12,12 +12,12 @@ You can add Java bindings into your Java project by adding the following depende com.hexadevlabs gpt4all-java-binding - 1.1.2 + 1.1.3 ``` **Gradle** ``` -implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.2' +implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.3' ``` To add the library dependency for another build system see [Maven Central Java bindings](https://central.sonatype.com/artifact/com.hexadevlabs/gpt4all-java-binding/). @@ -103,3 +103,17 @@ class Example { ``` 2. Not every AVX-only shared library is bundled with the JAR right now to reduce size. Only libgptj-avx is included. If you are running into issues please let us know using the [gpt4all project issue tracker](https://github.com/nomic-ai/gpt4all/issues). + +3. For Windows the native library included in jar depends on specific Microsoft C and C++ (MSVC) runtime libraries which may not be installed on your system. +If this is the case you can easily download and install the latest x64 Microsoft Visual C++ Redistributable package from https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170 + +## Version history +1. Version **1.1.2**: + - Java bindings are compatible with gpt4all version 2.4.6 + - Initial stable release with the initial feature set +2. Version **1.1.3**: + - Java bindings are compatible with gpt4all version 2.4.8 + - Add static GPT4ALL_VERSION to signify gpt4all version of the bindings + - Add PromptIsTooLongException for prompts that are longer than context size. + - Replit model support to include Metal Mac hardware support. + \ No newline at end of file diff --git a/gpt4all-bindings/java/TODO.md b/gpt4all-bindings/java/TODO.md new file mode 100644 index 00000000..3c85bf71 --- /dev/null +++ b/gpt4all-bindings/java/TODO.md @@ -0,0 +1,2 @@ +1. 
Better Chat completions function. +2. Chat completion that returns result in OpenAI compatible format. diff --git a/gpt4all-bindings/java/pom.xml b/gpt4all-bindings/java/pom.xml index 647b6876..454222fb 100644 --- a/gpt4all-bindings/java/pom.xml +++ b/gpt4all-bindings/java/pom.xml @@ -6,7 +6,7 @@ com.hexadevlabs gpt4all-java-binding - 1.1.2 + 1.1.3 jar @@ -56,6 +56,20 @@ 5.9.2 test + + + org.mockito + mockito-junit-jupiter + 5.4.0 + test + + + + org.mockito + mockito-core + 5.4.0 + test + @@ -103,7 +117,7 @@ ${project.build.directory}/generated-resources - C:\Users\felix\dev\gpt4all_java_bins\release_1_1_1_Jun8_2023 + C:\Users\felix\dev\gpt4all_java_bins\release_1_1_3_Jun22_2023 diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java index bf18315c..f3d0d674 100644 --- a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java +++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java @@ -1,6 +1,8 @@ package com.hexadevlabs.gpt4all; import jnr.ffi.Pointer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.ByteArrayOutputStream; import java.nio.charset.StandardCharsets; @@ -94,6 +96,10 @@ public class LLModel implements AutoCloseable { return this; } + /** + * + * @return GenerationConfig build instance of the config + */ public GenerationConfig build() { return configToBuild; } @@ -102,6 +108,8 @@ public class LLModel implements AutoCloseable { /** * Shortcut for making GenerativeConfig builder. 
+ * + * @return GenerationConfig.Builder - builder that can be used to make a GenerationConfig */ public static GenerationConfig.Builder config(){ return new GenerationConfig.Builder(); @@ -122,14 +130,32 @@ public class LLModel implements AutoCloseable { */ public static boolean OUTPUT_DEBUG = false; + private static final Logger logger = LoggerFactory.getLogger(LLModel.class); + + /** + * Which version of GPT4ALL that this binding is built for. + * The binding is guaranteed to work with this version of + * GPT4ALL native libraries. The binding may work for older + * versions but that is not guaranteed. + */ + public static final String GPT4ALL_VERSION = "2.4.8"; + protected static LLModelLibrary library; protected Pointer model; protected String modelName; + /** + * Package private default constructor, for testing purposes. + */ + LLModel(){ + } + public LLModel(Path modelPath) { + logger.info("Java bindings for gpt4all version: " + GPT4ALL_VERSION); + if(library==null) { if (LIBRARY_SEARCH_PATH != null){ @@ -202,34 +228,7 @@ public class LLModel implements AutoCloseable { ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream(); ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream(); - LLModelLibrary.ResponseCallback responseCallback = (int tokenID, Pointer response) -> { - - if(LLModel.OUTPUT_DEBUG) - System.out.print("Response token " + tokenID + " " ); - - long len = 0; - byte nextByte; - do{ - nextByte= response.getByte(len); - len++; - if(nextByte!=0) { - bufferingForWholeGeneration.write(nextByte); - if(streamToStdOut){ - bufferingForStdOutStream.write(nextByte); - // Test if Buffer is UTF8 valid string. 
- byte[] currentBytes = bufferingForStdOutStream.toByteArray(); - String validString = Util.getValidUtf8(currentBytes); - if(validString!=null){ // is valid string - System.out.print(validString); - // reset the buffer for next utf8 sequence to buffer - bufferingForStdOutStream.reset(); - } - } - } - } while(nextByte != 0); - - return true; // continue generating - }; + LLModelLibrary.ResponseCallback responseCallback = getResponseCallback(streamToStdOut, bufferingForStdOutStream, bufferingForWholeGeneration); library.llmodel_prompt(this.model, prompt, @@ -249,6 +248,56 @@ return bufferingForWholeGeneration.toString(StandardCharsets.UTF_8); } + /** + * Callback method to be used by prompt method as text is generated. + * + * @param streamToStdOut Should send generated text to standard out. + * @param bufferingForStdOutStream Output stream used for buffering bytes for standard output. + * @param bufferingForWholeGeneration Output stream used for buffering a complete generation. + * @return LLModelLibrary.ResponseCallback lambda function that is invoked by response callback. + */ + static LLModelLibrary.ResponseCallback getResponseCallback(boolean streamToStdOut, ByteArrayOutputStream bufferingForStdOutStream, ByteArrayOutputStream bufferingForWholeGeneration) { + return (int tokenID, Pointer response) -> { + + if(LLModel.OUTPUT_DEBUG) + System.out.print("Response token " + tokenID + " " ); + + // For all models if input sequence in tokens is longer than model context length + // the error is generated. + if(tokenID==-1){ + throw new PromptIsTooLongException(response.getString(0, 1000, StandardCharsets.UTF_8)); + } + + long len = 0; + byte nextByte; + do{ + try { + nextByte = response.getByte(len); + } catch(IndexOutOfBoundsException e){ + // Not sure if this can ever happen but just in case + // the generation does not terminate in a Null (0) value. 
+ throw new RuntimeException("Empty array or not null terminated"); + } + len++; + if(nextByte!=0) { + bufferingForWholeGeneration.write(nextByte); + if(streamToStdOut){ + bufferingForStdOutStream.write(nextByte); + // Test if Buffer is UTF8 valid string. + byte[] currentBytes = bufferingForStdOutStream.toByteArray(); + String validString = Util.getValidUtf8(currentBytes); + if(validString!=null){ // is valid string + System.out.print(validString); + // reset the buffer for next utf8 sequence to buffer + bufferingForStdOutStream.reset(); + } + } + } + } while(nextByte != 0); + + return true; // continue generating + }; + } public static class ChatCompletionResponse { @@ -277,7 +326,7 @@ public class LLModel implements AutoCloseable { * easier to process for chat UIs. It is not absolutely necessary as generate method * may be directly used to make generations with gpt models. * - * @param messages List of Maps "role"->"user", "content"->"...", "role"-> "assistant"->"..." + * @param messages List of Maps "role"->"user", "content"->"...", "role"-> "assistant"->"..." * @param generationConfig How to decode/process the generation. * @param streamToStdOut Send tokens as they are calculated Standard output. * @param outputFullPromptToStdOut Should full prompt built out of messages be sent to Standard output. 
diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/PromptIsTooLongException.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/PromptIsTooLongException.java new file mode 100644 index 00000000..82301696 --- /dev/null +++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/PromptIsTooLongException.java @@ -0,0 +1,7 @@ +package com.hexadevlabs.gpt4all; + +public class PromptIsTooLongException extends RuntimeException { + public PromptIsTooLongException(String message) { + super(message); + } +} diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java index ce3c9619..4b8a978c 100644 --- a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java +++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java @@ -52,7 +52,8 @@ public class Util { /** * Copy over shared library files from resource package to * target Temp directory. 
- * Returns Path to the temp directory holding the shared libraries + * + * @return Path path to the temp directory holding the shared libraries */ public static Path copySharedLibraries() { try { @@ -78,15 +79,26 @@ public class Util { "mpt-default", "llamamodel-230511-default", "llamamodel-230519-default", - "llamamodel-mainline-default" + "llamamodel-mainline-default", + "llamamodel-mainline-metal", + "replit-mainline-default", + "replit-mainline-metal", + "ggml-metal.metal" }; for (String libraryName : libraryNames) { + if(!isMac && ( + libraryName.equals("replit-mainline-metal") + || libraryName.equals("llamamodel-mainline-metal") + || libraryName.equals("ggml-metal.metal")) + ) continue; + if(isWindows){ libraryName = libraryName + ".dll"; } else if(isMac){ - libraryName = "lib" + libraryName + ".dylib"; + if(!libraryName.equals("ggml-metal.metal")) + libraryName = "lib" + libraryName + ".dylib"; } else if(isLinux) { libraryName = "lib"+ libraryName + ".so"; } diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java new file mode 100644 index 00000000..8bc7c914 --- /dev/null +++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java @@ -0,0 +1,155 @@ +package com.hexadevlabs.gpt4all; + + +import jnr.ffi.Memory; +import jnr.ffi.Pointer; +import jnr.ffi.Runtime; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mockito; + +import org.mockito.junit.jupiter.MockitoExtension; + + +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +/** + * These tests only test the Java implementation as the underlying backend can't be mocked. 
These tests do serve the purpose of validating the java bits that do + * not directly have to do with the function of the underlying gpt4all library. + */ +@ExtendWith(MockitoExtension.class) +public class BasicTests { + + @Test + public void simplePrompt(){ + + LLModel model = Mockito.spy(new LLModel()); + + LLModel.GenerationConfig config = + LLModel.config() + .withNPredict(20) + .build(); + + // The generate method will return "4" + doReturn("4").when( model ).generate(anyString(), eq(config), eq(true)); + + LLModel.ChatCompletionResponse response= model.chatCompletion( + List.of(Map.of("role", "system", "content", "You are a helpful assistant"), + Map.of("role", "user", "content", "Add 2+2")), config, true, true); + + assertTrue( response.choices.get(0).get("content").contains("4") ); + + // Verifies the prompt and response are certain length. + assertEquals( 224 , response.usage.totalTokens ); + } + + @Test + public void testResponseCallback(){ + + ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream(); + ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream(); + + LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration); + + // Get the runtime instance + Runtime runtime = Runtime.getSystemRuntime(); + + // Allocate memory for the byte array. 
Has to be null terminated + + // UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9 + byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination + + // Optional: Converting the byte array back to a String to print the character + String decodedString = new String(utf8ByteArray, 0, utf8ByteArray.length - 1, java.nio.charset.StandardCharsets.UTF_8); + + Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length); + + // Copy the byte array to the allocated memory + pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length); + + responseCallback.invoke(1, pointer); + + String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8); + + assertEquals(decodedString, result); + + } + + @Test + public void testResponseCallbackTwoTokens(){ + + ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream(); + ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream(); + + LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration); + + // Get the runtime instance + Runtime runtime = Runtime.getSystemRuntime(); + + // Allocate memory for the byte array. 
Has to be null terminated + + // UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9 + byte[] utf8ByteArray = { (byte) 0xF0, (byte) 0x9F, 0x00}; // Adding null termination + byte[] utf8ByteArray2 = { (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination + + // Optional: Converting the byte array back to a String to print the character + Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length); + + // Copy the byte array to the allocated memory + pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length); + + responseCallback.invoke(1, pointer); + // Copy the byte array to the allocated memory + pointer.put(0, utf8ByteArray2, 0, utf8ByteArray2.length); + + responseCallback.invoke(2, pointer); + + String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8); + + assertEquals("\uD83D\uDCA9", result); + + } + + + @Test + public void testResponseCallbackExpectError(){ + + ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream(); + ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream(); + + LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration); + + // Get the runtime instance + Runtime runtime = Runtime.getSystemRuntime(); + + // UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9 + byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9}; // No null termination + + Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length); + + // Copy the byte array to the allocated memory + pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length); + + Exception exception = assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer)); + + assertEquals("Empty array or not null terminated", exception.getMessage()); + + // With empty array + utf8ByteArray = new byte[0]; + pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length); + + Exception exceptionN = 
assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer)); + + assertEquals("Empty array or not null terminated", exceptionN.getMessage()); + + } + +} diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example5.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example5.java new file mode 100644 index 00000000..a2045cd0 --- /dev/null +++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example5.java @@ -0,0 +1,47 @@ +package com.hexadevlabs.gpt4all; + +import java.nio.file.Path; + +public class Example5 { + + public static void main(String[] args) { + + // String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:"; + // The emoji is poop emoji. The Unicode character is encoded as surrogate pair for Java string. + // LLM should correctly identify it as poop emoji in the description + //String prompt = "### Human:\nDescribe the meaning of this emoji \uD83D\uDCA9\n### Assistant:"; + //String prompt = "### Human:\nOutput the unicode character of smiley face emoji\n### Assistant:"; + + // Optionally in case override to location of shared libraries is necessary + //LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\"; + StringBuffer b = new StringBuffer(); + b.append("The ".repeat(2060)); + String prompt = b.toString(); + + + String model = "ggml-vicuna-7b-1.1-q4_2.bin"; + //String model = "ggml-gpt4all-j-v1.3-groovy.bin"; + //String model = "ggml-mpt-7b-instruct.bin"; + String basePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\"; + //String basePath = "/Users/fzaslavs/Library/Application Support/nomic.ai/GPT4All/"; + + try (LLModel mptModel = new LLModel(Path.of(basePath + model))) { + + LLModel.GenerationConfig config = + LLModel.config() + .withNPredict(4096) + .withRepeatLastN(64) + .build(); + + String result = mptModel.generate(prompt, config, true); + + System.out.println("Code points:"); + result.codePoints().forEach(System.out::println); + + + } 
catch (Exception e) { + System.out.println(e.getMessage()); + throw new RuntimeException(e); + } + } +} diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java deleted file mode 100644 index c1b77860..00000000 --- a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java +++ /dev/null @@ -1,68 +0,0 @@ -package com.hexadevlabs.gpt4all; - -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -import java.nio.file.Path; -import java.util.List; -import java.util.Map; - -import static org.junit.jupiter.api.Assertions.assertTrue; - -/** - * For now disabled until can run with latest library on windows - */ -@Disabled -public class Tests { - - static LLModel model; - static String PATH_TO_MODEL="C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin"; - - @BeforeAll - public static void before(){ - model = new LLModel(Path.of(PATH_TO_MODEL)); - } - - @Test - public void simplePrompt(){ - LLModel.GenerationConfig config = - LLModel.config() - .withNPredict(20) - .build(); - String prompt = - "### Instruction: \n" + - "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" + - "### Prompt: \n" + - "Add 2+2\n" + - "### Response: 4\n" + - "Multiply 4 * 5\n" + - "### Response:"; - String generatedText = model.generate(prompt, config, true); - assertTrue(generatedText.contains("20")); - } - - @Test - public void chatCompletion(){ - LLModel.GenerationConfig config = - LLModel.config() - .withNPredict(20) - .build(); - - LLModel.ChatCompletionResponse response= model.chatCompletion( - List.of(Map.of("role", "system", "content", "You are a helpful assistant"), - Map.of("role", "user", "content", "Add 2+2")), config, true, true); - - assertTrue( 
response.choices.get(0).get("content").contains("4") ); - - } - - @AfterAll - public static void after() throws Exception { - if(model != null) - model.close(); - } - - -}