mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-20 03:25:37 +00:00
bump version a few more doc fixes.
add macos metal files Add check for Prompt is too long. add logging statement for gpt4all version of the binding add version string, readme update Add unit tests for Java code of the java bindings.
This commit is contained in:
parent
0638b45b47
commit
4e274baee1
@ -12,12 +12,12 @@ You can add Java bindings into your Java project by adding the following depende
|
||||
<dependency>
|
||||
<groupId>com.hexadevlabs</groupId>
|
||||
<artifactId>gpt4all-java-binding</artifactId>
|
||||
<version>1.1.2</version>
|
||||
<version>1.1.3</version>
|
||||
</dependency>
|
||||
```
|
||||
**Gradle**
|
||||
```
|
||||
implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.2'
|
||||
implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.3'
|
||||
```
|
||||
|
||||
To add the library dependency for another build system see [Maven Central Java bindings](https://central.sonatype.com/artifact/com.hexadevlabs/gpt4all-java-binding/).
|
||||
@ -103,3 +103,17 @@ class Example {
|
||||
```
|
||||
2. Not every AVX-only shared library is bundled with the JAR right now to reduce size. Only libgptj-avx is included.
|
||||
If you are running into issues please let us know using the [gpt4all project issue tracker](https://github.com/nomic-ai/gpt4all/issues).
|
||||
|
||||
3. For Windows the native library included in jar depends on specific Microsoft C and C++ (MSVC) runtime libraries which may not be installed on your system.
|
||||
If this is the case you can easily download and install the latest x64 Microsoft Visual C++ Redistributable package from https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170
|
||||
|
||||
## Version history
|
||||
1. Version **1.1.2**:
|
||||
- Java bindings is compatible with gpt4ll version 2.4.6
|
||||
- Initial stable release with the initial feature set
|
||||
2. Version **1.1.3**:
|
||||
- Java bindings is compatible with gpt4all version 2.4.8
|
||||
- Add static GPT4ALL_VERSION to signify gpt4all version of the bindings
|
||||
- Add PromptIsTooLongException for prompts that are longer than context size.
|
||||
- Replit model support to include Metal Mac hardware support.
|
||||
|
2
gpt4all-bindings/java/TODO.md
Normal file
2
gpt4all-bindings/java/TODO.md
Normal file
@ -0,0 +1,2 @@
|
||||
1. Better Chat completions function.
|
||||
2. Chat completion that returns result in OpenAI compatible format.
|
@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.hexadevlabs</groupId>
|
||||
<artifactId>gpt4all-java-binding</artifactId>
|
||||
<version>1.1.2</version>
|
||||
<version>1.1.3</version>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<properties>
|
||||
@ -56,6 +56,20 @@
|
||||
<version>5.9.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-junit-jupiter</artifactId>
|
||||
<version>5.4.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<version>5.4.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<distributionManagement>
|
||||
@ -103,7 +117,7 @@
|
||||
<outputDirectory>${project.build.directory}/generated-resources</outputDirectory>
|
||||
<resources>
|
||||
<resource>
|
||||
<directory>C:\Users\felix\dev\gpt4all_java_bins\release_1_1_1_Jun8_2023</directory>
|
||||
<directory>C:\Users\felix\dev\gpt4all_java_bins\release_1_1_3_Jun22_2023</directory>
|
||||
</resource>
|
||||
</resources>
|
||||
</configuration>
|
||||
|
@ -1,6 +1,8 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import jnr.ffi.Pointer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
@ -94,6 +96,10 @@ public class LLModel implements AutoCloseable {
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return GenerationConfig build instance of the config
|
||||
*/
|
||||
public GenerationConfig build() {
|
||||
return configToBuild;
|
||||
}
|
||||
@ -102,6 +108,8 @@ public class LLModel implements AutoCloseable {
|
||||
|
||||
/**
|
||||
* Shortcut for making GenerativeConfig builder.
|
||||
*
|
||||
* @return GenerationConfig.Builder - builder that can be used to make a GenerationConfig
|
||||
*/
|
||||
public static GenerationConfig.Builder config(){
|
||||
return new GenerationConfig.Builder();
|
||||
@ -122,14 +130,32 @@ public class LLModel implements AutoCloseable {
|
||||
*/
|
||||
public static boolean OUTPUT_DEBUG = false;
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(LLModel.class);
|
||||
|
||||
/**
|
||||
* Which version of GPT4ALL that this binding is built for.
|
||||
* The binding is guaranteed to work with this version of
|
||||
* GPT4ALL native libraries. The binding may work for older
|
||||
* versions but that is not guaranteed.
|
||||
*/
|
||||
public static final String GPT4ALL_VERSION = "2.4.8";
|
||||
|
||||
protected static LLModelLibrary library;
|
||||
|
||||
protected Pointer model;
|
||||
|
||||
protected String modelName;
|
||||
|
||||
/**
|
||||
* Package private default constructor, for testing purposes.
|
||||
*/
|
||||
LLModel(){
|
||||
}
|
||||
|
||||
public LLModel(Path modelPath) {
|
||||
|
||||
logger.info("Java bindings for gpt4all version: " + GPT4ALL_VERSION);
|
||||
|
||||
if(library==null) {
|
||||
|
||||
if (LIBRARY_SEARCH_PATH != null){
|
||||
@ -202,34 +228,7 @@ public class LLModel implements AutoCloseable {
|
||||
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
|
||||
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
|
||||
|
||||
LLModelLibrary.ResponseCallback responseCallback = (int tokenID, Pointer response) -> {
|
||||
|
||||
if(LLModel.OUTPUT_DEBUG)
|
||||
System.out.print("Response token " + tokenID + " " );
|
||||
|
||||
long len = 0;
|
||||
byte nextByte;
|
||||
do{
|
||||
nextByte= response.getByte(len);
|
||||
len++;
|
||||
if(nextByte!=0) {
|
||||
bufferingForWholeGeneration.write(nextByte);
|
||||
if(streamToStdOut){
|
||||
bufferingForStdOutStream.write(nextByte);
|
||||
// Test if Buffer is UTF8 valid string.
|
||||
byte[] currentBytes = bufferingForStdOutStream.toByteArray();
|
||||
String validString = Util.getValidUtf8(currentBytes);
|
||||
if(validString!=null){ // is valid string
|
||||
System.out.print(validString);
|
||||
// reset the buffer for next utf8 sequence to buffer
|
||||
bufferingForStdOutStream.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
} while(nextByte != 0);
|
||||
|
||||
return true; // continue generating
|
||||
};
|
||||
LLModelLibrary.ResponseCallback responseCallback = getResponseCallback(streamToStdOut, bufferingForStdOutStream, bufferingForWholeGeneration);
|
||||
|
||||
library.llmodel_prompt(this.model,
|
||||
prompt,
|
||||
@ -249,6 +248,56 @@ public class LLModel implements AutoCloseable {
|
||||
return bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
/**
|
||||
* Callback method to be used by prompt method as text is generated.
|
||||
*
|
||||
* @param streamToStdOut Should send generated text to standard out.
|
||||
* @param bufferingForStdOutStream Output stream used for buffering bytes for standard output.
|
||||
* @param bufferingForWholeGeneration Output stream used for buffering a complete generation.
|
||||
* @return LLModelLibrary.ResponseCallback lambda function that is invoked by response callback.
|
||||
*/
|
||||
static LLModelLibrary.ResponseCallback getResponseCallback(boolean streamToStdOut, ByteArrayOutputStream bufferingForStdOutStream, ByteArrayOutputStream bufferingForWholeGeneration) {
|
||||
return (int tokenID, Pointer response) -> {
|
||||
|
||||
if(LLModel.OUTPUT_DEBUG)
|
||||
System.out.print("Response token " + tokenID + " " );
|
||||
|
||||
// For all models if input sequence in tokens is longer then model context length
|
||||
// the error is generated.
|
||||
if(tokenID==-1){
|
||||
throw new PromptIsTooLongException(response.getString(0, 1000, StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
long len = 0;
|
||||
byte nextByte;
|
||||
do{
|
||||
try {
|
||||
nextByte = response.getByte(len);
|
||||
} catch(IndexOutOfBoundsException e){
|
||||
// Not sure if this can ever happen but just in case
|
||||
// the generation does not terminate in a Null (0) value.
|
||||
throw new RuntimeException("Empty array or not null terminated");
|
||||
}
|
||||
len++;
|
||||
if(nextByte!=0) {
|
||||
bufferingForWholeGeneration.write(nextByte);
|
||||
if(streamToStdOut){
|
||||
bufferingForStdOutStream.write(nextByte);
|
||||
// Test if Buffer is UTF8 valid string.
|
||||
byte[] currentBytes = bufferingForStdOutStream.toByteArray();
|
||||
String validString = Util.getValidUtf8(currentBytes);
|
||||
if(validString!=null){ // is valid string
|
||||
System.out.print(validString);
|
||||
// reset the buffer for next utf8 sequence to buffer
|
||||
bufferingForStdOutStream.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
} while(nextByte != 0);
|
||||
|
||||
return true; // continue generating
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
public static class ChatCompletionResponse {
|
||||
@ -277,7 +326,7 @@ public class LLModel implements AutoCloseable {
|
||||
* easier to process for chat UIs. It is not absolutely necessary as generate method
|
||||
* may be directly used to make generations with gpt models.
|
||||
*
|
||||
* @param messages List of Maps "role"->"user", "content"->"...", "role"-> "assistant"->"..."
|
||||
* @param messages List of Maps "role"->"user", "content"->"...", "role"-> "assistant"->"..."
|
||||
* @param generationConfig How to decode/process the generation.
|
||||
* @param streamToStdOut Send tokens as they are calculated Standard output.
|
||||
* @param outputFullPromptToStdOut Should full prompt built out of messages be sent to Standard output.
|
||||
|
@ -0,0 +1,7 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
public class PromptIsTooLongException extends RuntimeException {
|
||||
public PromptIsTooLongException(String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
@ -52,7 +52,8 @@ public class Util {
|
||||
/**
|
||||
* Copy over shared library files from resource package to
|
||||
* target Temp directory.
|
||||
* Returns Path to the temp directory holding the shared libraries
|
||||
*
|
||||
* @return Path path to the temp directory holding the shared libraries
|
||||
*/
|
||||
public static Path copySharedLibraries() {
|
||||
try {
|
||||
@ -78,15 +79,26 @@ public class Util {
|
||||
"mpt-default",
|
||||
"llamamodel-230511-default",
|
||||
"llamamodel-230519-default",
|
||||
"llamamodel-mainline-default"
|
||||
"llamamodel-mainline-default",
|
||||
"llamamodel-mainline-metal",
|
||||
"replit-mainline-default",
|
||||
"replit-mainline-metal",
|
||||
"ggml-metal.metal"
|
||||
};
|
||||
|
||||
for (String libraryName : libraryNames) {
|
||||
|
||||
if(!isMac && (
|
||||
libraryName.equals("replit-mainline-metal")
|
||||
|| libraryName.equals("llamamodel-mainline-metal")
|
||||
|| libraryName.equals("ggml-metal.metal"))
|
||||
) continue;
|
||||
|
||||
if(isWindows){
|
||||
libraryName = libraryName + ".dll";
|
||||
} else if(isMac){
|
||||
libraryName = "lib" + libraryName + ".dylib";
|
||||
if(!libraryName.equals("ggml-metal.metal"))
|
||||
libraryName = "lib" + libraryName + ".dylib";
|
||||
} else if(isLinux) {
|
||||
libraryName = "lib"+ libraryName + ".so";
|
||||
}
|
||||
|
@ -0,0 +1,155 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
|
||||
import jnr.ffi.Memory;
|
||||
import jnr.ffi.Pointer;
|
||||
import jnr.ffi.Runtime;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
/**
|
||||
* These tests only test the Java implementation as the underlying backend can't be mocked.
|
||||
* These tests do serve the purpose of validating the java bits that do
|
||||
* not directly have to do with the function of the underlying gp4all library.
|
||||
*/
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
public class BasicTests {
|
||||
|
||||
@Test
|
||||
public void simplePrompt(){
|
||||
|
||||
LLModel model = Mockito.spy(new LLModel());
|
||||
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(20)
|
||||
.build();
|
||||
|
||||
// The generate method will return "4"
|
||||
doReturn("4").when( model ).generate(anyString(), eq(config), eq(true));
|
||||
|
||||
LLModel.ChatCompletionResponse response= model.chatCompletion(
|
||||
List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
|
||||
Map.of("role", "user", "content", "Add 2+2")), config, true, true);
|
||||
|
||||
assertTrue( response.choices.get(0).get("content").contains("4") );
|
||||
|
||||
// Verifies the prompt and response are certain length.
|
||||
assertEquals( 224 , response.usage.totalTokens );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResponseCallback(){
|
||||
|
||||
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
|
||||
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
|
||||
|
||||
LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
|
||||
|
||||
// Get the runtime instance
|
||||
Runtime runtime = Runtime.getSystemRuntime();
|
||||
|
||||
// Allocate memory for the byte array. Has to be null terminated
|
||||
|
||||
// UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
|
||||
byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination
|
||||
|
||||
// Optional: Converting the byte array back to a String to print the character
|
||||
String decodedString = new String(utf8ByteArray, 0, utf8ByteArray.length - 1, java.nio.charset.StandardCharsets.UTF_8);
|
||||
|
||||
Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
|
||||
|
||||
// Copy the byte array to the allocated memory
|
||||
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
|
||||
|
||||
responseCallback.invoke(1, pointer);
|
||||
|
||||
String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
|
||||
|
||||
assertEquals(decodedString, result);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResponseCallbackTwoTokens(){
|
||||
|
||||
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
|
||||
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
|
||||
|
||||
LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
|
||||
|
||||
// Get the runtime instance
|
||||
Runtime runtime = Runtime.getSystemRuntime();
|
||||
|
||||
// Allocate memory for the byte array. Has to be null terminated
|
||||
|
||||
// UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
|
||||
byte[] utf8ByteArray = { (byte) 0xF0, (byte) 0x9F, 0x00}; // Adding null termination
|
||||
byte[] utf8ByteArray2 = { (byte) 0x92, (byte) 0xA9, 0x00}; // Adding null termination
|
||||
|
||||
// Optional: Converting the byte array back to a String to print the character
|
||||
Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
|
||||
|
||||
// Copy the byte array to the allocated memory
|
||||
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
|
||||
|
||||
responseCallback.invoke(1, pointer);
|
||||
// Copy the byte array to the allocated memory
|
||||
pointer.put(0, utf8ByteArray2, 0, utf8ByteArray2.length);
|
||||
|
||||
responseCallback.invoke(2, pointer);
|
||||
|
||||
String result = bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
|
||||
|
||||
assertEquals("\uD83D\uDCA9", result);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testResponseCallbackExpectError(){
|
||||
|
||||
ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
|
||||
ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
|
||||
|
||||
LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(false, bufferingForStdOutStream, bufferingForWholeGeneration);
|
||||
|
||||
// Get the runtime instance
|
||||
Runtime runtime = Runtime.getSystemRuntime();
|
||||
|
||||
// UTF-8 Encoding of the character: 0xF0 0x9F 0x92 0xA9
|
||||
byte[] utf8ByteArray = {(byte) 0xF0, (byte) 0x9F, (byte) 0x92, (byte) 0xA9}; // No null termination
|
||||
|
||||
Pointer pointer = Memory.allocateDirect(runtime, utf8ByteArray.length);
|
||||
|
||||
// Copy the byte array to the allocated memory
|
||||
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
|
||||
|
||||
Exception exception = assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer));
|
||||
|
||||
assertEquals("Empty array or not null terminated", exception.getMessage());
|
||||
|
||||
// With empty array
|
||||
utf8ByteArray = new byte[0];
|
||||
pointer.put(0, utf8ByteArray, 0, utf8ByteArray.length);
|
||||
|
||||
Exception exceptionN = assertThrows(RuntimeException.class, () -> responseCallback.invoke(1, pointer));
|
||||
|
||||
assertEquals("Empty array or not null terminated", exceptionN.getMessage());
|
||||
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import java.nio.file.Path;
|
||||
|
||||
public class Example5 {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
// String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
|
||||
// The emoji is poop emoji. The Unicode character is encoded as surrogate pair for Java string.
|
||||
// LLM should correctly identify it as poop emoji in the description
|
||||
//String prompt = "### Human:\nDescribe the meaning of this emoji \uD83D\uDCA9\n### Assistant:";
|
||||
//String prompt = "### Human:\nOutput the unicode character of smiley face emoji\n### Assistant:";
|
||||
|
||||
// Optionally in case override to location of shared libraries is necessary
|
||||
//LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
|
||||
StringBuffer b = new StringBuffer();
|
||||
b.append("The ".repeat(2060));
|
||||
String prompt = b.toString();
|
||||
|
||||
|
||||
String model = "ggml-vicuna-7b-1.1-q4_2.bin";
|
||||
//String model = "ggml-gpt4all-j-v1.3-groovy.bin";
|
||||
//String model = "ggml-mpt-7b-instruct.bin";
|
||||
String basePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\";
|
||||
//String basePath = "/Users/fzaslavs/Library/Application Support/nomic.ai/GPT4All/";
|
||||
|
||||
try (LLModel mptModel = new LLModel(Path.of(basePath + model))) {
|
||||
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(4096)
|
||||
.withRepeatLastN(64)
|
||||
.build();
|
||||
|
||||
String result = mptModel.generate(prompt, config, true);
|
||||
|
||||
System.out.println("Code points:");
|
||||
result.codePoints().forEach(System.out::println);
|
||||
|
||||
|
||||
} catch (Exception e) {
|
||||
System.out.println(e.getMessage());
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
@ -1,68 +0,0 @@
|
||||
package com.hexadevlabs.gpt4all;
|
||||
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
/**
|
||||
* For now disabled until can run with latest library on windows
|
||||
*/
|
||||
@Disabled
|
||||
public class Tests {
|
||||
|
||||
static LLModel model;
|
||||
static String PATH_TO_MODEL="C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";
|
||||
|
||||
@BeforeAll
|
||||
public static void before(){
|
||||
model = new LLModel(Path.of(PATH_TO_MODEL));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void simplePrompt(){
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(20)
|
||||
.build();
|
||||
String prompt =
|
||||
"### Instruction: \n" +
|
||||
"The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
|
||||
"### Prompt: \n" +
|
||||
"Add 2+2\n" +
|
||||
"### Response: 4\n" +
|
||||
"Multiply 4 * 5\n" +
|
||||
"### Response:";
|
||||
String generatedText = model.generate(prompt, config, true);
|
||||
assertTrue(generatedText.contains("20"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void chatCompletion(){
|
||||
LLModel.GenerationConfig config =
|
||||
LLModel.config()
|
||||
.withNPredict(20)
|
||||
.build();
|
||||
|
||||
LLModel.ChatCompletionResponse response= model.chatCompletion(
|
||||
List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
|
||||
Map.of("role", "user", "content", "Add 2+2")), config, true, true);
|
||||
|
||||
assertTrue( response.choices.get(0).get("content").contains("4") );
|
||||
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
public static void after() throws Exception {
|
||||
if(model != null)
|
||||
model.close();
|
||||
}
|
||||
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user