diff --git a/gpt4all-bindings/java/README.md b/gpt4all-bindings/java/README.md
index 63b580f9..cea23846 100644
--- a/gpt4all-bindings/java/README.md
+++ b/gpt4all-bindings/java/README.md
@@ -12,12 +12,12 @@ You can add Java bindings into your Java project by adding the following depende
com.hexadevlabs
gpt4all-java-binding
- 1.1.2
+ 1.1.3
```
**Gradle**
```
-implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.2'
+implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.3'
```
To add the library dependency for another build system see [Maven Central Java bindings](https://central.sonatype.com/artifact/com.hexadevlabs/gpt4all-java-binding/).
@@ -103,3 +103,17 @@ class Example {
```
2. Not every AVX-only shared library is bundled with the JAR right now to reduce size. Only libgptj-avx is included.
If you are running into issues please let us know using the [gpt4all project issue tracker](https://github.com/nomic-ai/gpt4all/issues).
+
+3. For Windows the native library included in jar depends on specific Microsoft C and C++ (MSVC) runtime libraries which may not be installed on your system.
+If this is the case you can easily download and install the latest x64 Microsoft Visual C++ Redistributable package from https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170
+
+## Version history
+1. Version **1.1.2**:
   - Java bindings are compatible with gpt4all version 2.4.6
+ - Initial stable release with the initial feature set
+2. Version **1.1.3**:
   - Java bindings are compatible with gpt4all version 2.4.8
+ - Add static GPT4ALL_VERSION to signify gpt4all version of the bindings
+ - Add PromptIsTooLongException for prompts that are longer than context size.
+ - Replit model support to include Metal Mac hardware support.
+
\ No newline at end of file
diff --git a/gpt4all-bindings/java/TODO.md b/gpt4all-bindings/java/TODO.md
new file mode 100644
index 00000000..3c85bf71
--- /dev/null
+++ b/gpt4all-bindings/java/TODO.md
@@ -0,0 +1,2 @@
+1. Better Chat completions function.
+2. Chat completion that returns result in OpenAI compatible format.
diff --git a/gpt4all-bindings/java/pom.xml b/gpt4all-bindings/java/pom.xml
index 647b6876..454222fb 100644
--- a/gpt4all-bindings/java/pom.xml
+++ b/gpt4all-bindings/java/pom.xml
@@ -6,7 +6,7 @@
com.hexadevlabs
gpt4all-java-binding
- 1.1.2
+ 1.1.3
jar
@@ -56,6 +56,20 @@
5.9.2
test
+
+
+ org.mockito
+ mockito-junit-jupiter
+ 5.4.0
+ test
+
+
+
+ org.mockito
+ mockito-core
+ 5.4.0
+ test
+
@@ -103,7 +117,7 @@
${project.build.directory}/generated-resources
- C:\Users\felix\dev\gpt4all_java_bins\release_1_1_1_Jun8_2023
+ C:\Users\felix\dev\gpt4all_java_bins\release_1_1_3_Jun22_2023
diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java
index bf18315c..f3d0d674 100644
--- a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java
+++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java
@@ -1,6 +1,8 @@
package com.hexadevlabs.gpt4all;
import jnr.ffi.Pointer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
@@ -94,6 +96,10 @@ public class LLModel implements AutoCloseable {
return this;
}
+ /**
+ *
+ * @return GenerationConfig build instance of the config
+ */
public GenerationConfig build() {
return configToBuild;
}
@@ -102,6 +108,8 @@ public class LLModel implements AutoCloseable {
/**
* Shortcut for making GenerativeConfig builder.
+ *
+ * @return GenerationConfig.Builder - builder that can be used to make a GenerationConfig
*/
public static GenerationConfig.Builder config(){
return new GenerationConfig.Builder();
@@ -122,14 +130,32 @@ public class LLModel implements AutoCloseable {
*/
public static boolean OUTPUT_DEBUG = false;
+ private static final Logger logger = LoggerFactory.getLogger(LLModel.class);
+
+ /**
+     * The version of GPT4ALL that this binding is built for.
+ * The binding is guaranteed to work with this version of
+ * GPT4ALL native libraries. The binding may work for older
+ * versions but that is not guaranteed.
+ */
+ public static final String GPT4ALL_VERSION = "2.4.8";
+
protected static LLModelLibrary library;
protected Pointer model;
protected String modelName;
+ /**
+ * Package private default constructor, for testing purposes.
+ */
+ LLModel(){
+ }
+
public LLModel(Path modelPath) {
+ logger.info("Java bindings for gpt4all version: " + GPT4ALL_VERSION);
+
if(library==null) {
if (LIBRARY_SEARCH_PATH != null){
@@ -202,34 +228,7 @@ public class LLModel implements AutoCloseable {
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
- LLModelLibrary.ResponseCallback responseCallback = (int tokenID, Pointer response) -> {
-
- if(LLModel.OUTPUT_DEBUG)
- System.out.print("Response token " + tokenID + " " );
-
- long len = 0;
- byte nextByte;
- do{
- nextByte= response.getByte(len);
- len++;
- if(nextByte!=0) {
- bufferingForWholeGeneration.write(nextByte);
- if(streamToStdOut){
- bufferingForStdOutStream.write(nextByte);
- // Test if Buffer is UTF8 valid string.
- byte[] currentBytes = bufferingForStdOutStream.toByteArray();
- String validString = Util.getValidUtf8(currentBytes);
- if(validString!=null){ // is valid string
- System.out.print(validString);
- // reset the buffer for next utf8 sequence to buffer
- bufferingForStdOutStream.reset();
- }
- }
- }
- } while(nextByte != 0);
-
- return true; // continue generating
- };
+ LLModelLibrary.ResponseCallback responseCallback = getResponseCallback(streamToStdOut, bufferingForStdOutStream, bufferingForWholeGeneration);
library.llmodel_prompt(this.model,
prompt,
@@ -249,6 +248,56 @@ public class LLModel implements AutoCloseable {
return bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
}
+ /**
+ * Callback method to be used by prompt method as text is generated.
+ *
+ * @param streamToStdOut Should send generated text to standard out.
+ * @param bufferingForStdOutStream Output stream used for buffering bytes for standard output.
+ * @param bufferingForWholeGeneration Output stream used for buffering a complete generation.
+ * @return LLModelLibrary.ResponseCallback lambda function that is invoked by response callback.
+ */
+ static LLModelLibrary.ResponseCallback getResponseCallback(boolean streamToStdOut, ByteArrayOutputStream bufferingForStdOutStream, ByteArrayOutputStream bufferingForWholeGeneration) {
+ return (int tokenID, Pointer response) -> {
+
+ if(LLModel.OUTPUT_DEBUG)
+ System.out.print("Response token " + tokenID + " " );
+
+            // For all models, if the input sequence in tokens is longer than the model context length
+            // an error is generated.
+ if(tokenID==-1){
+ throw new PromptIsTooLongException(response.getString(0, 1000, StandardCharsets.UTF_8));
+ }
+
+ long len = 0;
+ byte nextByte;
+ do{
+ try {
+ nextByte = response.getByte(len);
+ } catch(IndexOutOfBoundsException e){
+ // Not sure if this can ever happen but just in case
+ // the generation does not terminate in a Null (0) value.
+ throw new RuntimeException("Empty array or not null terminated");
+ }
+ len++;
+ if(nextByte!=0) {
+ bufferingForWholeGeneration.write(nextByte);
+ if(streamToStdOut){
+ bufferingForStdOutStream.write(nextByte);
+ // Test if Buffer is UTF8 valid string.
+ byte[] currentBytes = bufferingForStdOutStream.toByteArray();
+ String validString = Util.getValidUtf8(currentBytes);
+ if(validString!=null){ // is valid string
+ System.out.print(validString);
+ // reset the buffer for next utf8 sequence to buffer
+ bufferingForStdOutStream.reset();
+ }
+ }
+ }
+ } while(nextByte != 0);
+
+ return true; // continue generating
+ };
+ }
public static class ChatCompletionResponse {
@@ -277,7 +326,7 @@ public class LLModel implements AutoCloseable {
* easier to process for chat UIs. It is not absolutely necessary as generate method
* may be directly used to make generations with gpt models.
*
- * @param messages List of Maps "role"->"user", "content"->"...", "role"-> "assistant"->"..."
+ * @param messages List of Maps "role"->"user", "content"->"...", "role"-> "assistant"->"..."
* @param generationConfig How to decode/process the generation.
* @param streamToStdOut Send tokens as they are calculated Standard output.
* @param outputFullPromptToStdOut Should full prompt built out of messages be sent to Standard output.
diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/PromptIsTooLongException.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/PromptIsTooLongException.java
new file mode 100644
index 00000000..82301696
--- /dev/null
+++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/PromptIsTooLongException.java
@@ -0,0 +1,7 @@
+package com.hexadevlabs.gpt4all;
+
+public class PromptIsTooLongException extends RuntimeException {
+ public PromptIsTooLongException(String message) {
+ super(message);
+ }
+}
diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java
index ce3c9619..4b8a978c 100644
--- a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java
+++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/Util.java
@@ -52,7 +52,8 @@ public class Util {
/**
* Copy over shared library files from resource package to
* target Temp directory.
- * Returns Path to the temp directory holding the shared libraries
+ *
+ * @return Path path to the temp directory holding the shared libraries
*/
public static Path copySharedLibraries() {
try {
@@ -78,15 +79,26 @@ public class Util {
"mpt-default",
"llamamodel-230511-default",
"llamamodel-230519-default",
- "llamamodel-mainline-default"
+ "llamamodel-mainline-default",
+ "llamamodel-mainline-metal",
+ "replit-mainline-default",
+ "replit-mainline-metal",
+ "ggml-metal.metal"
};
for (String libraryName : libraryNames) {
+ if(!isMac && (
+ libraryName.equals("replit-mainline-metal")
+ || libraryName.equals("llamamodel-mainline-metal")
+ || libraryName.equals("ggml-metal.metal"))
+ ) continue;
+
if(isWindows){
libraryName = libraryName + ".dll";
} else if(isMac){
- libraryName = "lib" + libraryName + ".dylib";
+ if(!libraryName.equals("ggml-metal.metal"))
+ libraryName = "lib" + libraryName + ".dylib";
} else if(isLinux) {
libraryName = "lib"+ libraryName + ".so";
}
diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java
new file mode 100644
index 00000000..8bc7c914
--- /dev/null
+++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java
@@ -0,0 +1,155 @@
+package com.hexadevlabs.gpt4all;
+
+
+import jnr.ffi.Memory;
+import jnr.ffi.Pointer;
+import jnr.ffi.Runtime;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mockito;
+
+import org.mockito.junit.jupiter.MockitoExtension;
+
+
+import java.io.ByteArrayOutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.*;
+
+/**
+ * These tests only test the Java implementation as the underlying backend can't be mocked.
+ * These tests do serve the purpose of validating the java bits that do
+ * not directly have to do with the function of the underlying gpt4all library.
+ */
+@ExtendWith(MockitoExtension.class)
+public class BasicTests {
+
+ @Test
+ public void simplePrompt(){
+
+ LLModel model = Mockito.spy(new LLModel());
+
+ LLModel.GenerationConfig config =
+ LLModel.config()
+ .withNPredict(20)
+ .build();
+
+ // The generate method will return "4"
+ doReturn("4").when( model ).generate(anyString(), eq(config), eq(true));
+
+ LLModel.ChatCompletionResponse response= model.chatCompletion(
+ List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
+ Map.of("role", "user", "content", "Add 2+2")), config, true, true);
+
+ assertTrue( response.choices.get(0).get("content").contains("4") );
+
+ // Verifies the prompt and response are certain length.
+ assertEquals( 224 , response.usage.totalTokens );
+ }
+
+ @Test
+ public void testResponseCallback(){
+
+ ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
+ ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
+
+ LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
+
+ // Get the runtime instance
+ Runtime runtime = Runtime.getSystemRuntime();
+
+ // Allocate memory for the byte array. Has to be null terminated
+
+ // UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
+ byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination
+
+ // Optional: Converting the byte array back to a String to print the character
+ String decodedString = new String(utf8ByteArray, 0, utf8ByteArray.length - 1, java.nio.charset.StandardCharsets.UTF_8);
+
+ Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
+
+ // Copy the byte array to the allocated memory
+ pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
+
+ responseCallback.invoke(1, pointer);
+
+ String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
+
+ assertEquals(decodedString, result);
+
+ }
+
+ @Test
+ public void testResponseCallbackTwoTokens(){
+
+ ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
+ ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
+
+ LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
+
+ // Get the runtime instance
+ Runtime runtime = Runtime.getSystemRuntime();
+
+ // Allocate memory for the byte array. Has to be null terminated
+
+ // UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
+ byte[] utf8ByteArray = { (byte) 0xF0, (byte) 0x9F, 0x00}; // Adding null termination
+ byte[] utf8ByteArray2 = { (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination
+
+ // Optional: Converting the byte array back to a String to print the character
+ Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
+
+ // Copy the byte array to the allocated memory
+ pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
+
+ responseCallback.invoke(1, pointer);
+ // Copy the byte array to the allocated memory
+ pointer.put(0, utf8ByteArray2, 0, utf8ByteArray2.length);
+
+ responseCallback.invoke(2, pointer);
+
+ String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
+
+ assertEquals("\uD83D\uDCA9", result);
+
+ }
+
+
+ @Test
+ public void testResponseCallbackExpectError(){
+
+ ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
+ ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
+
+ LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
+
+ // Get the runtime instance
+ Runtime runtime = Runtime.getSystemRuntime();
+
+ // UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
+ byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9}; // No null termination
+
+ Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
+
+ // Copy the byte array to the allocated memory
+ pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
+
+ Exception exception = assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer));
+
+ assertEquals("Empty array or not null terminated", exception.getMessage());
+
+ // With empty array
+ utf8ByteArray = new byte[0];
+ pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
+
+ Exception exceptionN = assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer));
+
+ assertEquals("Empty array or not null terminated", exceptionN.getMessage());
+
+ }
+
+}
diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example5.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example5.java
new file mode 100644
index 00000000..a2045cd0
--- /dev/null
+++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Example5.java
@@ -0,0 +1,47 @@
+package com.hexadevlabs.gpt4all;
+
+import java.nio.file.Path;
+
+public class Example5 {
+
+ public static void main(String[] args) {
+
+ // String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
+ // The emoji is poop emoji. The Unicode character is encoded as surrogate pair for Java string.
+ // LLM should correctly identify it as poop emoji in the description
+ //String prompt = "### Human:\nDescribe the meaning of this emoji \uD83D\uDCA9\n### Assistant:";
+ //String prompt = "### Human:\nOutput the unicode character of smiley face emoji\n### Assistant:";
+
+ // Optionally in case override to location of shared libraries is necessary
+ //LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
+ StringBuffer b = new StringBuffer();
+ b.append("The ".repeat(2060));
+ String prompt = b.toString();
+
+
+ String model = "ggml-vicuna-7b-1.1-q4_2.bin";
+ //String model = "ggml-gpt4all-j-v1.3-groovy.bin";
+ //String model = "ggml-mpt-7b-instruct.bin";
+ String basePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\";
+ //String basePath = "/Users/fzaslavs/Library/Application Support/nomic.ai/GPT4All/";
+
+ try (LLModel mptModel = new LLModel(Path.of(basePath + model))) {
+
+ LLModel.GenerationConfig config =
+ LLModel.config()
+ .withNPredict(4096)
+ .withRepeatLastN(64)
+ .build();
+
+ String result = mptModel.generate(prompt, config, true);
+
+ System.out.println("Code points:");
+ result.codePoints().forEach(System.out::println);
+
+
+ } catch (Exception e) {
+ System.out.println(e.getMessage());
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java
deleted file mode 100644
index c1b77860..00000000
--- a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/Tests.java
+++ /dev/null
@@ -1,68 +0,0 @@
-package com.hexadevlabs.gpt4all;
-
-import org.junit.jupiter.api.AfterAll;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Disabled;
-import org.junit.jupiter.api.Test;
-
-import java.nio.file.Path;
-import java.util.List;
-import java.util.Map;
-
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-/**
- * For now disabled until can run with latest library on windows
- */
-@Disabled
-public class Tests {
-
- static LLModel model;
- static String PATH_TO_MODEL="C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";
-
- @BeforeAll
- public static void before(){
- model = new LLModel(Path.of(PATH_TO_MODEL));
- }
-
- @Test
- public void simplePrompt(){
- LLModel.GenerationConfig config =
- LLModel.config()
- .withNPredict(20)
- .build();
- String prompt =
- "### Instruction: \n" +
- "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
- "### Prompt: \n" +
- "Add 2+2\n" +
- "### Response: 4\n" +
- "Multiply 4 * 5\n" +
- "### Response:";
- String generatedText = model.generate(prompt, config, true);
- assertTrue(generatedText.contains("20"));
- }
-
- @Test
- public void chatCompletion(){
- LLModel.GenerationConfig config =
- LLModel.config()
- .withNPredict(20)
- .build();
-
- LLModel.ChatCompletionResponse response= model.chatCompletion(
- List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
- Map.of("role", "user", "content", "Add 2+2")), config, true, true);
-
- assertTrue( response.choices.get(0).get("content").contains("4") );
-
- }
-
- @AfterAll
- public static void after() throws Exception {
- if(model != null)
- model.close();
- }
-
-
-}