bump version a few more doc fixes.

add macos metal files
Add check for Prompt is too long.
add logging statement for gpt4all version of the binding
add version string, readme update
Add unit tests for Java code of the java bindings.
This commit is contained in:
felix 2023-06-18 00:50:12 -04:00 committed by AT
parent 0638b45b47
commit 4e274baee1
9 changed files with 336 additions and 104 deletions

View File

@ -12,12 +12,12 @@ You can add Java bindings into your Java project by adding the following depende
<dependency>
<groupId>com.hexadevlabs</groupId>
<artifactId>gpt4all-java-binding</artifactId>
<version>1.1.2</version>
<version>1.1.3</version>
</dependency>
```
**Gradle**
```
implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.2'
implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.3'
```
To add the library dependency for another build system see [Maven Central Java bindings](https://central.sonatype.com/artifact/com.hexadevlabs/gpt4all-java-binding/).
@ -103,3 +103,17 @@ class Example {
```
2. Not every AVX-only shared library is bundled with the JAR right now to reduce size. Only libgptj-avx is included.
If you are running into issues please let us know using the [gpt4all project issue tracker](https://github.com/nomic-ai/gpt4all/issues).
3. For Windows the native library included in jar depends on specific Microsoft C and C++ (MSVC) runtime libraries which may not be installed on your system.
If this is the case you can easily download and install the latest x64 Microsoft Visual C++ Redistributable package from https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170
## Version history
1. Version **1.1.2**:
- Java bindings is compatible with gpt4ll version 2.4.6
- Initial stable release with the initial feature set
2. Version **1.1.3**:
- Java bindings is compatible with gpt4all version 2.4.8
- Add static GPT4ALL_VERSION to signify gpt4all version of the bindings
- Add PromptIsTooLongException for prompts that are longer than context size.
- Replit model support to include Metal Mac hardware support.

View File

@ -0,0 +1,2 @@
1. Better Chat completions function.
2. Chat completion that returns result in OpenAI compatible format.

View File

@ -6,7 +6,7 @@
<groupId>com.hexadevlabs</groupId>
<artifactId>gpt4all-java-binding</artifactId>
<version>1.1.2</version>
<version>1.1.3</version>
<packaging>jar</packaging>
<properties>
@ -56,6 +56,20 @@
<version>5.9.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-junit-jupiter</artifactId>
<version>5.4.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>5.4.0</version>
<scope>test</scope>
</dependency>
</dependencies>
<distributionManagement>
@ -103,7 +117,7 @@
<outputDirectory>${project.build.directory}/generated-resources</outputDirectory>
<resources>
<resource>
<directory>C:\Users\felix\dev\gpt4all_java_bins\release_1_1_1_Jun8_2023</directory>
<directory>C:\Users\felix\dev\gpt4all_java_bins\release_1_1_3_Jun22_2023</directory>
</resource>
</resources>
</configuration>

View File

@ -1,6 +1,8 @@
package com.hexadevlabs.gpt4all;
import jnr.ffi.Pointer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
@ -94,6 +96,10 @@ public class LLModel implements AutoCloseable {
return this;
}
/**
*
* @return GenerationConfig build instance of the config
*/
public GenerationConfig build() {
return configToBuild;
}
@ -102,6 +108,8 @@ public class LLModel implements AutoCloseable {
/**
* Shortcut for making GenerativeConfig builder.
*
* @return GenerationConfig.Builder - builder that can be used to make a GenerationConfig
*/
public static GenerationConfig.Builder config(){
return new GenerationConfig.Builder();
@ -122,14 +130,32 @@ public class LLModel implements AutoCloseable {
*/
public static boolean OUTPUT_DEBUG = false;
private static final Logger logger = LoggerFactory.getLogger(LLModel.class);
/**
* Which version of GPT4ALL that this binding is built for.
* The binding is guaranteed to work with this version of
* GPT4ALL native libraries. The binding may work for older
* versions but that is not guaranteed.
*/
public static final String GPT4ALL_VERSION = "2.4.8";
protected static LLModelLibrary library;
protected Pointer model;
protected String modelName;
/**
* Package private default constructor, for testing purposes.
*/
LLModel(){
}
public LLModel(Path modelPath) {
logger.info("Java bindings for gpt4all version: " + GPT4ALL_VERSION);
if(library==null) {
if (LIBRARY_SEARCH_PATH != null){
@ -202,34 +228,7 @@ public class LLModel implements AutoCloseable {
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
LLModelLibrary.ResponseCallback responseCallback = (int tokenID, Pointer response) -> {
if(LLModel.OUTPUT_DEBUG)
System.out.print("Response token " + tokenID + " " );
long len = 0;
byte nextByte;
do{
nextByte= response.getByte(len);
len++;
if(nextByte!=0) {
bufferingForWholeGeneration.write(nextByte);
if(streamToStdOut){
bufferingForStdOutStream.write(nextByte);
// Test if Buffer is UTF8 valid string.
byte[] currentBytes = bufferingForStdOutStream.toByteArray();
String validString = Util.getValidUtf8(currentBytes);
if(validString!=null){ // is valid string
System.out.print(validString);
// reset the buffer for next utf8 sequence to buffer
bufferingForStdOutStream.reset();
}
}
}
} while(nextByte != 0);
return true; // continue generating
};
LLModelLibrary.ResponseCallback responseCallback = getResponseCallback(streamToStdOut, bufferingForStdOutStream, bufferingForWholeGeneration);
library.llmodel_prompt(this.model,
prompt,
@ -249,6 +248,56 @@ public class LLModel implements AutoCloseable {
return bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
}
/**
* Callback method to be used by prompt method as text is generated.
*
* @param streamToStdOut Should send generated text to standard out.
* @param bufferingForStdOutStream Output stream used for buffering bytes for standard output.
* @param bufferingForWholeGeneration Output stream used for buffering a complete generation.
* @return LLModelLibrary.ResponseCallback lambda function that is invoked by response callback.
*/
static LLModelLibrary.ResponseCallback getResponseCallback(boolean streamToStdOut, ByteArrayOutputStream bufferingForStdOutStream, ByteArrayOutputStream bufferingForWholeGeneration) {
return (int tokenID, Pointer response) -> {
if(LLModel.OUTPUT_DEBUG)
System.out.print("Response token " + tokenID + " " );
// For all models if input sequence in tokens is longer then model context length
// the error is generated.
if(tokenID==-1){
throw new PromptIsTooLongException(response.getString(0, 1000, StandardCharsets.UTF_8));
}
long len = 0;
byte nextByte;
do{
try {
nextByte = response.getByte(len);
} catch(IndexOutOfBoundsException e){
// Not sure if this can ever happen but just in case
// the generation does not terminate in a Null (0) value.
throw new RuntimeException("Empty array or not null terminated");
}
len++;
if(nextByte!=0) {
bufferingForWholeGeneration.write(nextByte);
if(streamToStdOut){
bufferingForStdOutStream.write(nextByte);
// Test if Buffer is UTF8 valid string.
byte[] currentBytes = bufferingForStdOutStream.toByteArray();
String validString = Util.getValidUtf8(currentBytes);
if(validString!=null){ // is valid string
System.out.print(validString);
// reset the buffer for next utf8 sequence to buffer
bufferingForStdOutStream.reset();
}
}
}
} while(nextByte != 0);
return true; // continue generating
};
}
public static class ChatCompletionResponse {
@ -277,7 +326,7 @@ public class LLModel implements AutoCloseable {
* easier to process for chat UIs. It is not absolutely necessary as generate method
* may be directly used to make generations with gpt models.
*
* @param messages List of Maps "role"->"user", "content"->"...", "role"-> "assistant"->"..."
* @param messages List of Maps "role"-&gt;"user", "content"-&gt;"...", "role"-&gt; "assistant"-&gt;"..."
* @param generationConfig How to decode/process the generation.
* @param streamToStdOut Send tokens as they are calculated Standard output.
* @param outputFullPromptToStdOut Should full prompt built out of messages be sent to Standard output.

View File

@ -0,0 +1,7 @@
package com.hexadevlabs.gpt4all;
public class PromptIsTooLongException extends RuntimeException {
public PromptIsTooLongException(String message) {
super(message);
}
}

View File

@ -52,7 +52,8 @@ public class Util {
/**
* Copy over shared library files from resource package to
* target Temp directory.
* Returns Path to the temp directory holding the shared libraries
*
* @return Path path to the temp directory holding the shared libraries
*/
public static Path copySharedLibraries() {
try {
@ -78,15 +79,26 @@ public class Util {
"mpt-default",
"llamamodel-230511-default",
"llamamodel-230519-default",
"llamamodel-mainline-default"
"llamamodel-mainline-default",
"llamamodel-mainline-metal",
"replit-mainline-default",
"replit-mainline-metal",
"ggml-metal.metal"
};
for (String libraryName : libraryNames) {
if(!isMac && (
libraryName.equals("replit-mainline-metal")
|| libraryName.equals("llamamodel-mainline-metal")
|| libraryName.equals("ggml-metal.metal"))
) continue;
if(isWindows){
libraryName = libraryName + ".dll";
} else if(isMac){
libraryName = "lib" + libraryName + ".dylib";
if(!libraryName.equals("ggml-metal.metal"))
libraryName = "lib" + libraryName + ".dylib";
} else if(isLinux) {
libraryName = "lib"+ libraryName + ".so";
}

View File

@ -0,0 +1,155 @@
package com.hexadevlabs.gpt4all;
import jnr.ffi.Memory;
import jnr.ffi.Pointer;
import jnr.ffi.Runtime;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;
/**
* These tests only test the Java implementation as the underlying backend can't be mocked.
* These tests do serve the purpose of validating the java bits that do
* not directly have to do with the function of the underlying gp4all library.
*/
@ExtendWith(MockitoExtension.class)
public class BasicTests {
@Test
public void simplePrompt(){
LLModel model = Mockito.spy(new LLModel());
LLModel.GenerationConfig config =
LLModel.config()
.withNPredict(20)
.build();
// The generate method will return "4"
doReturn("4").when( model ).generate(anyString(), eq(config), eq(true));
LLModel.ChatCompletionResponse response= model.chatCompletion(
List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
Map.of("role", "user", "content", "Add 2+2")), config, true, true);
assertTrue( response.choices.get(0).get("content").contains("4") );
// Verifies the prompt and response are certain length.
assertEquals( 224 , response.usage.totalTokens );
}
@Test
public void testResponseCallback(){
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
// Get the runtime instance
Runtime runtime = Runtime.getSystemRuntime();
// Allocate memory for the byte array. Has to be null terminated
// UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination
// Optional: Converting the byte array back to a String to print the character
String decodedString = new String(utf8ByteArray, 0, utf8ByteArray.length - 1, java.nio.charset.StandardCharsets.UTF_8);
Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
// Copy the byte array to the allocated memory
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
responseCallback.invoke(1, pointer);
String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
assertEquals(decodedString, result);
}
@Test
public void testResponseCallbackTwoTokens(){
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
// Get the runtime instance
Runtime runtime = Runtime.getSystemRuntime();
// Allocate memory for the byte array. Has to be null terminated
// UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
byte[] utf8ByteArray = { (byte) 0xF0, (byte) 0x9F, 0x00}; // Adding null termination
byte[] utf8ByteArray2 = { (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination
// Optional: Converting the byte array back to a String to print the character
Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
// Copy the byte array to the allocated memory
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
responseCallback.invoke(1, pointer);
// Copy the byte array to the allocated memory
pointer.put(0, utf8ByteArray2, 0, utf8ByteArray2.length);
responseCallback.invoke(2, pointer);
String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
assertEquals("\uD83D\uDCA9", result);
}
@Test
public void testResponseCallbackExpectError(){
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
// Get the runtime instance
Runtime runtime = Runtime.getSystemRuntime();
// UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9}; // No null termination
Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
// Copy the byte array to the allocated memory
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
Exception exception = assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer));
assertEquals("Empty array or not null terminated", exception.getMessage());
// With empty array
utf8ByteArray = new byte[0];
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
Exception exceptionN = assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer));
assertEquals("Empty array or not null terminated", exceptionN.getMessage());
}
}

View File

@ -0,0 +1,47 @@
package com.hexadevlabs.gpt4all;
import java.nio.file.Path;
public class Example5 {
public static void main(String[] args) {
// String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
// The emoji is poop emoji. The Unicode character is encoded as surrogate pair for Java string.
// LLM should correctly identify it as poop emoji in the description
//String prompt = "### Human:\nDescribe the meaning of this emoji \uD83D\uDCA9\n### Assistant:";
//String prompt = "### Human:\nOutput the unicode character of smiley face emoji\n### Assistant:";
// Optionally in case override to location of shared libraries is necessary
//LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
StringBuffer b = new StringBuffer();
b.append("The ".repeat(2060));
String prompt = b.toString();
String model = "ggml-vicuna-7b-1.1-q4_2.bin";
//String model = "ggml-gpt4all-j-v1.3-groovy.bin";
//String model = "ggml-mpt-7b-instruct.bin";
String basePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\";
//String basePath = "/Users/fzaslavs/Library/Application Support/nomic.ai/GPT4All/";
try (LLModel mptModel = new LLModel(Path.of(basePath + model))) {
LLModel.GenerationConfig config =
LLModel.config()
.withNPredict(4096)
.withRepeatLastN(64)
.build();
String result = mptModel.generate(prompt, config, true);
System.out.println("Code points:");
result.codePoints().forEach(System.out::println);
} catch (Exception e) {
System.out.println(e.getMessage());
throw new RuntimeException(e);
}
}
}

View File

@ -1,68 +0,0 @@
package com.hexadevlabs.gpt4all;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* For now disabled until can run with latest library on windows
*/
@Disabled
public class Tests {
static LLModel model;
static String PATH_TO_MODEL="C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";
@BeforeAll
public static void before(){
model = new LLModel(Path.of(PATH_TO_MODEL));
}
@Test
public void simplePrompt(){
LLModel.GenerationConfig config =
LLModel.config()
.withNPredict(20)
.build();
String prompt =
"### Instruction: \n" +
"The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
"### Prompt: \n" +
"Add 2+2\n" +
"### Response: 4\n" +
"Multiply 4 * 5\n" +
"### Response:";
String generatedText = model.generate(prompt, config, true);
assertTrue(generatedText.contains("20"));
}
@Test
public void chatCompletion(){
LLModel.GenerationConfig config =
LLModel.config()
.withNPredict(20)
.build();
LLModel.ChatCompletionResponse response= model.chatCompletion(
List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
Map.of("role", "user", "content", "Add 2+2")), config, true, true);
assertTrue( response.choices.get(0).get("content").contains("4") );
}
@AfterAll
public static void after() throws Exception {
if(model != null)
model.close();
}
}