mirror of https://github.com/nomic-ai/gpt4all
Bump version; a few more doc fixes.
Add macOS Metal files. Add a check for when the prompt is too long. Add a logging statement for the gpt4all version of the binding. Add a version string; update the readme. Add unit tests for the Java code of the Java bindings. (pull/1054/head)
parent
0638b45b47
commit
4e274baee1
@ -0,0 +1,2 @@
|
||||
1. Better chat completions function.
|
||||
2. Chat completion that returns results in an OpenAI-compatible format.
|
@ -0,0 +1,7 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
/**
 * Thrown when a prompt exceeds the length the underlying gpt4all model can
 * accept (see the "Prompt is too long" check in the binding).
 */
public class PromptIsTooLongException extends RuntimeException {

    /**
     * @param message description of the failure, typically including the offending prompt size
     */
    public PromptIsTooLongException(String message) {
        super(message);
    }

    /**
     * Backward-compatible addition: lets callers preserve the root cause
     * instead of losing it when translating lower-level failures.
     *
     * @param message description of the failure
     * @param cause   underlying exception that signalled the oversized prompt
     */
    public PromptIsTooLongException(String message, Throwable cause) {
        super(message, cause);
    }
}
|
@ -0,0 +1,155 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
|
||||
import jnr.ffi.Memory;
|
||||
import jnr.ffi.Pointer;
|
||||
import jnr.ffi.Runtime;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
/**
|
||||
* These tests only test the Java implementation as the underlying backend can't be mocked.
|
||||
* These tests do serve the purpose of validating the java bits that do
|
||||
* not directly have to do with the function of the underlying gp4all library.
|
||||
*/
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
public class BasicTests {
|
||||
|
||||
@Test
|
||||
public void simplePrompt(){
|
||||
|
||||
LLModel model = Mockito.spy(new LLModel());
|
||||
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(20)
|
||||
.build();
|
||||
|
||||
// The generate method will return "4"
|
||||
doReturn("4").when( model ).generate(anyString(), eq(config), eq(true));
|
||||
|
||||
LLModel.ChatCompletionResponse response= model.chatCompletion(
|
||||
List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
|
||||
Map.of("role", "user", "content", "Add 2+2")), config, true, true);
|
||||
|
||||
assertTrue( response.choices.get(0).get("content").contains("4") );
|
||||
|
||||
// Verifies the prompt and response are certain length.
|
||||
assertEquals( 224 , response.usage.totalTokens );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResponseCallback(){
|
||||
|
||||
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
|
||||
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
|
||||
|
||||
LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
|
||||
|
||||
// Get the runtime instance
|
||||
Runtime runtime = Runtime.getSystemRuntime();
|
||||
|
||||
// Allocate memory for the byte array. Has to be null terminated
|
||||
|
||||
// UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
|
||||
byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination
|
||||
|
||||
// Optional: Converting the byte array back to a String to print the character
|
||||
String decodedString = new String(utf8ByteArray, 0, utf8ByteArray.length - 1, java.nio.charset.StandardCharsets.UTF_8);
|
||||
|
||||
Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
|
||||
|
||||
// Copy the byte array to the allocated memory
|
||||
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
|
||||
|
||||
responseCallback.invoke(1, pointer);
|
||||
|
||||
String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
|
||||
|
||||
assertEquals(decodedString, result);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResponseCallbackTwoTokens(){
|
||||
|
||||
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
|
||||
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
|
||||
|
||||
LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
|
||||
|
||||
// Get the runtime instance
|
||||
Runtime runtime = Runtime.getSystemRuntime();
|
||||
|
||||
// Allocate memory for the byte array. Has to be null terminated
|
||||
|
||||
// UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
|
||||
byte[] utf8ByteArray = { (byte) 0xF0, (byte) 0x9F, 0x00}; // Adding null termination
|
||||
byte[] utf8ByteArray2 = { (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination
|
||||
|
||||
// Optional: Converting the byte array back to a String to print the character
|
||||
Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
|
||||
|
||||
// Copy the byte array to the allocated memory
|
||||
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
|
||||
|
||||
responseCallback.invoke(1, pointer);
|
||||
// Copy the byte array to the allocated memory
|
||||
pointer.put(0, utf8ByteArray2, 0, utf8ByteArray2.length);
|
||||
|
||||
responseCallback.invoke(2, pointer);
|
||||
|
||||
String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
|
||||
|
||||
assertEquals("\uD83D\uDCA9", result);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testResponseCallbackExpectError(){
|
||||
|
||||
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
|
||||
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
|
||||
|
||||
LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
|
||||
|
||||
// Get the runtime instance
|
||||
Runtime runtime = Runtime.getSystemRuntime();
|
||||
|
||||
// UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
|
||||
byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9}; // No null termination
|
||||
|
||||
Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
|
||||
|
||||
// Copy the byte array to the allocated memory
|
||||
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
|
||||
|
||||
Exception exception = assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer));
|
||||
|
||||
assertEquals("Empty array or not null terminated", exception.getMessage());
|
||||
|
||||
// With empty array
|
||||
utf8ByteArray = new byte[0];
|
||||
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
|
||||
|
||||
Exception exceptionN = assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer));
|
||||
|
||||
assertEquals("Empty array or not null terminated", exceptionN.getMessage());
|
||||
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import java.nio.file.Path;
|
||||
|
||||
public class Example5 {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
// String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
|
||||
// The emoji is poop emoji. The Unicode character is encoded as surrogate pair for Java string.
|
||||
// LLM should correctly identify it as poop emoji in the description
|
||||
//String prompt = "### Human:\nDescribe the meaning of this emoji \uD83D\uDCA9\n### Assistant:";
|
||||
//String prompt = "### Human:\nOutput the unicode character of smiley face emoji\n### Assistant:";
|
||||
|
||||
// Optionally in case override to location of shared libraries is necessary
|
||||
//LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
|
||||
StringBuffer b = new StringBuffer();
|
||||
b.append("The ".repeat(2060));
|
||||
String prompt = b.toString();
|
||||
|
||||
|
||||
String model = "ggml-vicuna-7b-1.1-q4_2.bin";
|
||||
//String model = "ggml-gpt4all-j-v1.3-groovy.bin";
|
||||
//String model = "ggml-mpt-7b-instruct.bin";
|
||||
String basePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\";
|
||||
//String basePath = "/Users/fzaslavs/Library/Application Support/nomic.ai/GPT4All/";
|
||||
|
||||
try (LLModel mptModel = new LLModel(Path.of(basePath + model))) {
|
||||
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(4096)
|
||||
.withRepeatLastN(64)
|
||||
.build();
|
||||
|
||||
String result = mptModel.generate(prompt, config, true);
|
||||
|
||||
System.out.println("Code points:");
|
||||
result.codePoints().forEach(System.out::println);
|
||||
|
||||
|
||||
} catch (Exception e) {
|
||||
System.out.println(e.getMessage());
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
@ -1,68 +0,0 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
/**
|
||||
* For now disabled until can run with latest library on windows
|
||||
*/
|
||||
@Disabled
|
||||
public class Tests {
|
||||
|
||||
static LLModel model;
|
||||
static String PATH_TO_MODEL="C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";
|
||||
|
||||
@BeforeAll
|
||||
public static void before(){
|
||||
model = new LLModel(Path.of(PATH_TO_MODEL));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void simplePrompt(){
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(20)
|
||||
.build();
|
||||
String prompt =
|
||||
"### Instruction: \n" +
|
||||
"The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
|
||||
"### Prompt: \n" +
|
||||
"Add 2+2\n" +
|
||||
"### Response: 4\n" +
|
||||
"Multiply 4 * 5\n" +
|
||||
"### Response:";
|
||||
String generatedText = model.generate(prompt, config, true);
|
||||
assertTrue(generatedText.contains("20"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void chatCompletion(){
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(20)
|
||||
.build();
|
||||
|
||||
LLModel.ChatCompletionResponse response= model.chatCompletion(
|
||||
List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
|
||||
Map.of("role", "user", "content", "Add 2+2")), config, true, true);
|
||||
|
||||
assertTrue( response.choices.get(0).get("content").contains("4") );
|
||||
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
public static void after() throws Exception {
|
||||
if(model != null)
|
||||
model.close();
|
||||
}
|
||||
|
||||
|
||||
}
|
Loading…
Reference in New Issue