Initial 1.0.0 Java-Bindings PR/release (#805)
* Initial 1.0.0 Java-Bindings PR/release
* Initial 1.1.0 Java-Bindings PR/release
* Add debug ability
* 1.1.2 release

Co-authored-by: felix <felix@zaslavskiy.net>

Parent: 5cfb1bda89
Commit: 44bf91855d
README.md

@@ -61,6 +61,7 @@ Find the most up-to-date information on the [GPT4All Website](https://gpt4all.io
 * <a href="https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/typescript">:computer: Official Typescript Bindings</a>
 * <a href="https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/golang">:computer: Official GoLang Bindings</a>
 * <a href="https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/csharp">:computer: Official C# Bindings</a>
+* <a href="https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/java">:computer: Official Java Bindings</a>

 ## Contributing
gpt4all-bindings/java/.gitignore (vendored, new file, 2 lines)

# Make sure the native directory never gets committed to git for the project.
/src/main/resources/native
gpt4all-bindings/java/README.md (new file, 105 lines)

# Java bindings

The Java bindings let you load a gpt4all library into your Java application and run text
generation through an intuitive, easy-to-use API. No GPU is required, because the gpt4all system executes on the CPU.
The gpt4all models are quantized to fit easily into system RAM, using about 4 to 7 GB of it.

## Getting Started

You can add the Java bindings to your Java project by declaring the dependency:

**Maven**
```
<dependency>
    <groupId>com.hexadevlabs</groupId>
    <artifactId>gpt4all-java-binding</artifactId>
    <version>1.1.2</version>
</dependency>
```

**Gradle**
```
implementation 'com.hexadevlabs:gpt4all-java-binding:1.1.2'
```

To add the library dependency for another build system, see the [Maven Central Java bindings](https://central.sonatype.com/artifact/com.hexadevlabs/gpt4all-java-binding/).

To download a model's binary weights file, use a URL such as https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin.

For information about the other available models, see the [Model file list](https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-chat#manual-download-of-models).

### Sample code

```java
public class Example {
    public static void main(String[] args) {

        String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";

        // Replace the hardcoded path with the actual path where your model file resides
        String modelFilePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";

        try (LLModel model = new LLModel(Path.of(modelFilePath))) {

            // May generate up to 4096 tokens but generally stops early
            LLModel.GenerationConfig config = LLModel.config()
                    .withNPredict(4096).build();

            // Will also stream to standard output
            String fullGeneration = model.generate(prompt, config, true);

        } catch (Exception e) {
            // An exception may happen if the model file fails to load
            // for a number of reasons, such as the file not being found.
            // It is also possible that Java cannot dynamically load the native shared library,
            // or that the llmodel shared library cannot dynamically load the backend
            // implementation for the model file you provided.
            //
            // Once the LLModel class is successfully loaded into memory, the text generation calls
            // generally should not throw exceptions.
            e.printStackTrace(); // Printed here, but in a production system you may want to take some action.
        }
    }
}
```

For a Maven-based sample project that uses this library, see the [Sample project](https://github.com/felix-zaslavskiy/gpt4all-java-bindings-sample).
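Besides `generate`, the bindings expose a `chatCompletion` helper (see `LLModel.chatCompletion` further down in this PR) that builds an instruction-style prompt from a list of role/content messages. A minimal sketch of calling it, assuming the same placeholder model path as above:

```java
import java.nio.file.Path;
import java.util.List;
import java.util.Map;

public class ChatExample {
    public static void main(String[] args) throws Exception {
        // Placeholder path; point this at your downloaded model file.
        String modelFilePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";

        try (LLModel model = new LLModel(Path.of(modelFilePath))) {
            LLModel.GenerationConfig config = LLModel.config().withNPredict(256).build();

            // Each message is a Map with "role" and "content" keys.
            LLModel.ChatCompletionResponse response = model.chatCompletion(
                    List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
                            Map.of("role", "user", "content", "Add 2+2")),
                    config, true, false);

            // The assistant's reply is the first (and only) choice.
            System.out.println(response.choices.get(0).get("content"));
        }
    }
}
```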
### Additional considerations
#### Logger warnings

If you don't have an SLF4J binding included in your project, the Java bindings library may produce a warning:
```
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
```
The Java bindings only use logging for informational purposes, so a logger is not essential for using the library correctly. You can ignore this warning if you don't have an SLF4J binding in your project.

To add a simple logger via a Maven dependency, you may use:
```
<dependency>
    <groupId>org.slf4j</groupId>
    <artifactId>slf4j-simple</artifactId>
    <version>1.7.36</version>
</dependency>
```

#### Loading your native libraries

1. The Java bindings jar comes bundled with native library files for Windows, macOS, and Linux. These library files are copied to a temporary directory and loaded at runtime. Advanced users who want to package the shared libraries into Docker containers, or who want to use a custom build of the shared libraries and ignore the ones bundled with the Java package, have the option to load the libraries from a local directory by setting a static property to the location of the library files. There are no compatibility guarantees if the bindings are used this way, so be careful if you really want to do it.

For example:
```java
class Example {
    public static void main(String[] args) {
        // gpt4all native shared libraries location
        LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";
        // ... use the library normally
    }
}
```
2. To reduce size, not every AVX-only shared library is bundled with the jar right now; only libgptj-avx is included. If you are running into issues, please let us know using the gpt4all project issue tracker: https://github.com/nomic-ai/gpt4all/issues.
gpt4all-bindings/java/pom.xml (new file, 202 lines)

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.hexadevlabs</groupId>
    <artifactId>gpt4all-java-binding</artifactId>
    <version>1.1.2</version>
    <packaging>jar</packaging>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>

    <name>${project.groupId}:${project.artifactId}</name>
    <description>Java bindings for GPT4ALL LLM</description>
    <url>https://github.com/nomic-ai/gpt4all</url>
    <licenses>
        <license>
            <name>The Apache License, Version 2.0</name>
            <url>https://github.com/nomic-ai/gpt4all/blob/main/LICENSE.txt</url>
        </license>
    </licenses>
    <developers>
        <developer>
            <name>Felix Zaslavskiy</name>
            <email>felixz@hexadevlabs.com</email>
            <organizationUrl>https://github.com/felix-zaslavskiy/</organizationUrl>
        </developer>
    </developers>
    <scm>
        <connection>scm:git:git://github.com/nomic-ai/gpt4all.git</connection>
        <developerConnection>scm:git:ssh://github.com/nomic-ai/gpt4all.git</developerConnection>
        <url>https://github.com/nomic-ai/gpt4all/tree/main</url>
    </scm>

    <dependencies>
        <dependency>
            <groupId>com.github.jnr</groupId>
            <artifactId>jnr-ffi</artifactId>
            <version>2.2.13</version>
        </dependency>

        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>1.7.36</version>
        </dependency>

        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-api</artifactId>
            <version>5.9.2</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <distributionManagement>
        <snapshotRepository>
            <id>ossrh</id>
            <url>https://s01.oss.sonatype.org/content/repositories/snapshots</url>
        </snapshotRepository>
        <repository>
            <id>ossrh</id>
            <url>https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
        </repository>
    </distributionManagement>

    <build>
        <resources>
            <resource>
                <directory>src/main/resources</directory>
            </resource>
            <resource>
                <directory>${project.build.directory}/generated-resources</directory>
            </resource>
        </resources>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-surefire-plugin</artifactId>
                <version>3.0.0</version>
                <configuration>
                    <forkCount>0</forkCount>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-resources-plugin</artifactId>
                <version>3.3.1</version>
                <executions>
                    <execution>
                        <id>copy-resources</id>
                        <!-- Here the phase you need -->
                        <phase>validate</phase>
                        <goals>
                            <goal>copy-resources</goal>
                        </goals>
                        <configuration>
                            <outputDirectory>${project.build.directory}/generated-resources</outputDirectory>
                            <resources>
                                <resource>
                                    <directory>C:\Users\felix\dev\gpt4all_java_bins\release_1_1_1_Jun8_2023</directory>
                                </resource>
                            </resources>
                        </configuration>
                    </execution>
                </executions>
            </plugin>

            <plugin>
                <groupId>org.sonatype.plugins</groupId>
                <artifactId>nexus-staging-maven-plugin</artifactId>
                <version>1.6.13</version>
                <extensions>true</extensions>
                <configuration>
                    <serverId>ossrh</serverId>
                    <nexusUrl>https://s01.oss.sonatype.org/</nexusUrl>
                    <autoReleaseAfterClose>true</autoReleaseAfterClose>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-source-plugin</artifactId>
                <version>2.2.1</version>
                <executions>
                    <execution>
                        <id>attach-sources</id>
                        <goals>
                            <goal>jar-no-fork</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-javadoc-plugin</artifactId>
                <version>3.5.0</version>
                <executions>
                    <execution>
                        <id>attach-javadocs</id>
                        <goals>
                            <goal>jar</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-gpg-plugin</artifactId>
                <version>1.5</version>
                <executions>
                    <execution>
                        <id>sign-artifacts</id>
                        <phase>verify</phase>
                        <goals>
                            <goal>sign</goal>
                        </goals>
                        <!--
                        <configuration>
                            <keyname>${gpg.keyname}</keyname>
                            <passphraseServerId>${gpg.keyname}</passphraseServerId>
                        </configuration>
                        -->
                    </execution>
                </executions>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>3.6.0</version>
                <configuration>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                    <archive>
                        <manifest>
                            <mainClass>com.hexadevlabs.gpt4allsample.Example4</mainClass>
                        </manifest>
                    </archive>
                </configuration>
                <executions>
                    <execution>
                        <id>make-assembly</id>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>

</project>
.../com/hexadevlabs/gpt4all/LLModel.java (new file, 349 lines)

package com.hexadevlabs.gpt4all;

import jnr.ffi.Pointer;

import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class LLModel implements AutoCloseable {

    /**
     * Config used for how to decode LLM outputs.
     * A high temperature (closer to 1) gives more creative outputs,
     * while a low temperature (closer to 0) produces more precise outputs.
     * <p>
     * Use the builder to set the settings you want.
     */
    public static class GenerationConfig extends LLModelLibrary.LLModelPromptContext {

        private GenerationConfig() {
            super(jnr.ffi.Runtime.getSystemRuntime());
            logits_size.set(0);
            tokens_size.set(0);
            n_past.set(0);
            n_ctx.set(1024);
            n_predict.set(128);
            top_k.set(40);
            top_p.set(0.95);
            temp.set(0.28);
            n_batch.set(8);
            repeat_penalty.set(1.1);
            repeat_last_n.set(10);
            context_erase.set(0.55);
        }

        public static class Builder {
            private final GenerationConfig configToBuild;

            public Builder() {
                configToBuild = new GenerationConfig();
            }

            public Builder withNPast(int n_past) {
                configToBuild.n_past.set(n_past);
                return this;
            }

            public Builder withNCtx(int n_ctx) {
                configToBuild.n_ctx.set(n_ctx);
                return this;
            }

            public Builder withNPredict(int n_predict) {
                configToBuild.n_predict.set(n_predict);
                return this;
            }

            public Builder withTopK(int top_k) {
                configToBuild.top_k.set(top_k);
                return this;
            }

            public Builder withTopP(float top_p) {
                configToBuild.top_p.set(top_p);
                return this;
            }

            public Builder withTemp(float temp) {
                configToBuild.temp.set(temp);
                return this;
            }

            public Builder withNBatch(int n_batch) {
                configToBuild.n_batch.set(n_batch);
                return this;
            }

            public Builder withRepeatPenalty(float repeat_penalty) {
                configToBuild.repeat_penalty.set(repeat_penalty);
                return this;
            }

            public Builder withRepeatLastN(int repeat_last_n) {
                configToBuild.repeat_last_n.set(repeat_last_n);
                return this;
            }

            public Builder withContextErase(float context_erase) {
                configToBuild.context_erase.set(context_erase);
                return this;
            }

            public GenerationConfig build() {
                return configToBuild;
            }
        }
    }

    /**
     * Shortcut for making a GenerationConfig builder.
     */
    public static GenerationConfig.Builder config(){
        return new GenerationConfig.Builder();
    }

    /**
     * This may be set before any Model instances are instantiated to
     * set where the native shared libraries may be found. This may be needed if setting
     * the library search path by standard means is not available.
     */
    public static String LIBRARY_SEARCH_PATH;

    /**
     * Generally for debugging purposes only. Will print
     * the numerical tokens as they are generated instead of the string representations.
     * Will also print out the processed input tokens as numbers to standard out.
     */
    public static boolean OUTPUT_DEBUG = false;

    protected static LLModelLibrary library;

    protected Pointer model;

    protected String modelName;

    public LLModel(Path modelPath) {

        if(library == null) {

            if (LIBRARY_SEARCH_PATH != null){
                library = Util.loadSharedLibrary(LIBRARY_SEARCH_PATH);
                library.llmodel_set_implementation_search_path(LIBRARY_SEARCH_PATH);
            } else {
                // Copy the bundled shared libraries to a temp folder
                Path tempLibraryDirectory = Util.copySharedLibraries();
                library = Util.loadSharedLibrary(tempLibraryDirectory.toString());

                library.llmodel_set_implementation_search_path(tempLibraryDirectory.toString());
            }

        }

        // modelType = type;
        modelName = modelPath.getFileName().toString();
        String modelPathAbs = modelPath.toAbsolutePath().toString();

        LLModelLibrary.LLModelError error = new LLModelLibrary.LLModelError(jnr.ffi.Runtime.getSystemRuntime());

        // Check if the model file exists
        if(!Files.exists(modelPath)){
            throw new IllegalStateException("Model file does not exist: " + modelPathAbs);
        }

        // Create the model struct. Will dynamically load the correct backend based on the model type.
        model = library.llmodel_model_create2(modelPathAbs, "auto", error);

        if(model == null) {
            throw new IllegalStateException("Could not load gpt4all backend: " + error.message);
        }
        library.llmodel_loadModel(model, modelPathAbs);

        if(!library.llmodel_isModelLoaded(model)){
            throw new IllegalStateException("The model " + modelName + " could not be loaded");
        }

    }

    public void setThreadCount(int nThreads) {
        library.llmodel_setThreadCount(this.model, nThreads);
    }

    public int threadCount() {
        return library.llmodel_threadCount(this.model);
    }

    /**
     * Generate text after the prompt
     *
     * @param prompt The text prompt to complete
     * @param generationConfig What generation settings to use while generating text
     * @return String The complete generated text
     */
    public String generate(String prompt, GenerationConfig generationConfig) {
        return generate(prompt, generationConfig, false);
    }

    /**
     * Generate text after the prompt
     *
     * @param prompt The text prompt to complete
     * @param generationConfig What generation settings to use while generating text
     * @param streamToStdOut Should the generation be streamed to standard output. Useful for troubleshooting.
     * @return String The complete generated text
     */
    public String generate(String prompt, GenerationConfig generationConfig, boolean streamToStdOut) {

        ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
        ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();

        LLModelLibrary.ResponseCallback responseCallback = (int tokenID, Pointer response) -> {

            if(LLModel.OUTPUT_DEBUG)
                System.out.print("Response token " + tokenID + " ");

            // Read the null-terminated byte sequence from native memory
            long len = 0;
            byte nextByte;
            do {
                nextByte = response.getByte(len);
                len++;
                if(nextByte != 0) {
                    bufferingForWholeGeneration.write(nextByte);
                    if(streamToStdOut){
                        bufferingForStdOutStream.write(nextByte);
                        // Test if the buffer is a valid UTF-8 string; multi-byte
                        // sequences may arrive split across callbacks.
                        byte[] currentBytes = bufferingForStdOutStream.toByteArray();
                        String validString = Util.getValidUtf8(currentBytes);
                        if(validString != null){ // is a valid string
                            System.out.print(validString);
                            // Reset the buffer for the next UTF-8 sequence to buffer
                            bufferingForStdOutStream.reset();
                        }
                    }
                }
            } while(nextByte != 0);

            return true; // continue generating
        };

        library.llmodel_prompt(this.model,
                prompt,
                (int tokenID) -> {
                    if(LLModel.OUTPUT_DEBUG)
                        System.out.println("token " + tokenID);
                    return true; // continue processing
                },
                responseCallback,
                (boolean isRecalculating) -> {
                    if(LLModel.OUTPUT_DEBUG)
                        System.out.println("recalculating");
                    return isRecalculating; // continue generating
                },
                generationConfig);

        return bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
    }

    public static class ChatCompletionResponse {
        public String model;
        public Usage usage;
        public List<Map<String, String>> choices;

        // Getters and setters
    }

    public static class Usage {
        public int promptTokens;
        public int completionTokens;
        public int totalTokens;

        // Getters and setters
    }

    public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages,
                                                 GenerationConfig generationConfig) {
        return chatCompletion(messages, generationConfig, false, false);
    }

    /**
     * chatCompletion formats the existing chat conversation into a template that is
     * easier to process for chat UIs. It is not strictly necessary, as the generate method
     * may be used directly to make generations with gpt models.
     *
     * @param messages List of Maps with a "role" key ("system", "user", or "assistant") and a "content" key holding the message text
     * @param generationConfig How to decode/process the generation.
     * @param streamToStdOut Send tokens as they are calculated to standard output.
     * @param outputFullPromptToStdOut Should the full prompt built out of the messages be sent to standard output.
     * @return ChatCompletionResponse contains stats and the generated text.
     */
    public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages,
                                                 GenerationConfig generationConfig, boolean streamToStdOut,
                                                 boolean outputFullPromptToStdOut) {
        String fullPrompt = buildPrompt(messages);

        if(outputFullPromptToStdOut)
            System.out.print(fullPrompt);

        String generatedText = generate(fullPrompt, generationConfig, streamToStdOut);

        ChatCompletionResponse response = new ChatCompletionResponse();
        response.model = this.modelName;

        Usage usage = new Usage();
        usage.promptTokens = fullPrompt.length();
        usage.completionTokens = generatedText.length();
        usage.totalTokens = fullPrompt.length() + generatedText.length();
        response.usage = usage;

        Map<String, String> message = new HashMap<>();
        message.put("role", "assistant");
        message.put("content", generatedText);

        response.choices = List.of(message);

        return response;
    }

    protected static String buildPrompt(List<Map<String, String>> messages) {
        StringBuilder fullPrompt = new StringBuilder();

        // System messages come first, before the instruction block
        for (Map<String, String> message : messages) {
            if ("system".equals(message.get("role"))) {
                String systemMessage = message.get("content") + "\n";
                fullPrompt.append(systemMessage);
            }
        }

        fullPrompt.append("### Instruction: \n" +
                "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
                "### Prompt: ");

        for (Map<String, String> message : messages) {
            if ("user".equals(message.get("role"))) {
                String userMessage = "\n" + message.get("content");
                fullPrompt.append(userMessage);
            }
            if ("assistant".equals(message.get("role"))) {
                String assistantMessage = "\n### Response: " + message.get("content");
                fullPrompt.append(assistantMessage);
            }
        }

        fullPrompt.append("\n### Response:");

        return fullPrompt.toString();
    }

    @Override
    public void close() throws Exception {
        library.llmodel_model_destroy(model);
    }

}
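The response callback above receives raw bytes, and a multi-byte UTF-8 sequence (an emoji, for instance) can arrive split across several callbacks; that is why stdout output is buffered until `Util.getValidUtf8` reports a complete sequence. A small self-contained sketch of the same buffering idea (the `Utf8StreamDemo` and `tryDecode` names are illustrative, not part of the bindings):

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.StandardCharsets;

class Utf8StreamDemo {
    public static void main(String[] args) {
        // U+1F4A9 is four bytes in UTF-8; feed them one at a time,
        // as a streaming native callback might deliver them.
        byte[] emoji = "\uD83D\uDCA9".getBytes(StandardCharsets.UTF_8);
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        for (byte b : emoji) {
            buffer.write(b);
            String decoded = tryDecode(buffer.toByteArray());
            if (decoded != null) {
                System.out.print(decoded); // printed only once all four bytes arrived
                buffer.reset();            // start buffering the next sequence
            }
        }
    }

    // Same idea as Util.getValidUtf8: null until the bytes form valid UTF-8.
    static String tryDecode(byte[] bytes) {
        try {
            return StandardCharsets.UTF_8.newDecoder()
                    .decode(ByteBuffer.wrap(bytes)).toString();
        } catch (CharacterCodingException e) {
            return null;
        }
    }
}
```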
.../com/hexadevlabs/gpt4all/LLModelLibrary.java (new file, 79 lines)

package com.hexadevlabs.gpt4all;

import jnr.ffi.Pointer;
import jnr.ffi.Struct;
import jnr.ffi.annotations.Delegate;
import jnr.ffi.annotations.Encoding;
import jnr.ffi.annotations.In;
import jnr.ffi.annotations.Out;
import jnr.ffi.types.u_int64_t;

/**
 * The basic native library interface that provides all the LLM functions.
 */
public interface LLModelLibrary {

    interface PromptCallback {
        @Delegate
        boolean invoke(int token_id);
    }

    interface ResponseCallback {
        @Delegate
        boolean invoke(int token_id, Pointer response);
    }

    interface RecalculateCallback {
        @Delegate
        boolean invoke(boolean is_recalculating);
    }

    class LLModelError extends Struct {
        public final Struct.AsciiStringRef message = new Struct.AsciiStringRef();
        public final int32_t status = new int32_t();
        public LLModelError(jnr.ffi.Runtime runtime) {
            super(runtime);
        }
    }

    class LLModelPromptContext extends Struct {
        public final Pointer logits = new Pointer();
        public final ssize_t logits_size = new ssize_t();
        public final Pointer tokens = new Pointer();
        public final ssize_t tokens_size = new ssize_t();
        public final int32_t n_past = new int32_t();
        public final int32_t n_ctx = new int32_t();
        public final int32_t n_predict = new int32_t();
        public final int32_t top_k = new int32_t();
        public final Float top_p = new Float();
        public final Float temp = new Float();
        public final int32_t n_batch = new int32_t();
        public final Float repeat_penalty = new Float();
        public final int32_t repeat_last_n = new int32_t();
        public final Float context_erase = new Float();

        public LLModelPromptContext(jnr.ffi.Runtime runtime) {
            super(runtime);
        }
    }

    Pointer llmodel_model_create2(String model_path, String build_variant, @Out LLModelError llmodel_error);
    void llmodel_model_destroy(Pointer model);
    boolean llmodel_loadModel(Pointer model, String model_path);
    boolean llmodel_isModelLoaded(Pointer model);
    @u_int64_t long llmodel_get_state_size(Pointer model);
    @u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest);
    @u_int64_t long llmodel_restore_state_data(Pointer model, Pointer src);

    void llmodel_set_implementation_search_path(String path);

    // ctx was an @Out parameter ... without @Out it crashes
    void llmodel_prompt(Pointer model, @Encoding("UTF-8") String prompt,
                        PromptCallback prompt_callback,
                        ResponseCallback response_callback,
                        RecalculateCallback recalculate_callback,
                        @In LLModelPromptContext ctx);
    void llmodel_setThreadCount(Pointer model, int n_threads);
    int llmodel_threadCount(Pointer model);
}
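For reference, this interface can be driven directly through jnr-ffi, which is essentially what `LLModel`'s constructor does under the hood. A rough sketch, assuming purely hypothetical library and model paths:

```java
import jnr.ffi.Pointer;

class DirectLibraryDemo {
    public static void main(String[] args) {
        // Hypothetical locations; substitute your own.
        String libDir = "/opt/gpt4all/lib";
        String modelFile = "/opt/gpt4all/models/ggml-gpt4all-j-v1.3-groovy.bin";

        LLModelLibrary library = Util.loadSharedLibrary(libDir);
        library.llmodel_set_implementation_search_path(libDir);

        LLModelLibrary.LLModelError error =
                new LLModelLibrary.LLModelError(jnr.ffi.Runtime.getSystemRuntime());

        // "auto" asks the native layer to pick the backend from the file contents.
        Pointer model = library.llmodel_model_create2(modelFile, "auto", error);
        if (model == null) {
            throw new IllegalStateException("Could not create model: " + error.message);
        }
        try {
            if (!library.llmodel_loadModel(model, modelFile)
                    || !library.llmodel_isModelLoaded(model)) {
                throw new IllegalStateException("Model failed to load");
            }
            System.out.println("Threads: " + library.llmodel_threadCount(model));
        } finally {
            library.llmodel_model_destroy(model);
        }
    }
}
```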
.../com/hexadevlabs/gpt4all/Util.java (new file, 147 lines)

package com.hexadevlabs.gpt4all;

import jnr.ffi.LibraryLoader;
import jnr.ffi.LibraryOption;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Util {

    private static final Logger logger = LoggerFactory.getLogger(Util.class);
    private static final CharsetDecoder cs = StandardCharsets.UTF_8.newDecoder();

    public static LLModelLibrary loadSharedLibrary(String librarySearchPath){
        String libraryName = "llmodel";
        Map<LibraryOption, Object> libraryOptions = new HashMap<>();
        libraryOptions.put(LibraryOption.LoadNow, true); // load immediately instead of lazily (i.e. on first use)
        libraryOptions.put(LibraryOption.IgnoreError, false); // do not ignore the native error state; save the last errno after each call

        if(librarySearchPath != null) {
            Map<String, List<String>> searchPaths = new HashMap<>();
            searchPaths.put(libraryName, List.of(librarySearchPath));

            return LibraryLoader.loadLibrary(LLModelLibrary.class,
                    libraryOptions,
                    searchPaths,
                    libraryName
            );
        } else {

            return LibraryLoader.loadLibrary(LLModelLibrary.class,
                    libraryOptions,
                    libraryName
            );
        }

    }

    /**
     * Copy the shared library files from the resource package to
     * a target temp directory.
     * Returns the Path to the temp directory holding the shared libraries.
     */
    public static Path copySharedLibraries() {
        try {
            // Identify the OS and architecture
            String osName = System.getProperty("os.name").toLowerCase();
            boolean isWindows = osName.startsWith("windows");
            boolean isMac = osName.startsWith("mac os x");
            boolean isLinux = osName.startsWith("linux");
            if(isWindows) osName = "windows";
            if(isMac) osName = "macos";
            if(isLinux) osName = "linux";

            //String osArch = System.getProperty("os.arch");

            // Create a temporary directory
            Path tempDirectory = Files.createTempDirectory("nativeLibraries");
            tempDirectory.toFile().deleteOnExit();

            String[] libraryNames = {
                    "gptj-default",
                    "gptj-avxonly",
                    "llmodel",
                    "mpt-default",
                    "llamamodel-230511-default",
                    "llamamodel-230519-default",
                    "llamamodel-mainline-default"
            };

            for (String libraryName : libraryNames) {

                if(isWindows){
                    libraryName = libraryName + ".dll";
                } else if(isMac){
                    libraryName = "lib" + libraryName + ".dylib";
                } else if(isLinux) {
                    libraryName = "lib" + libraryName + ".so";
                }

                // Construct the resource path based on the OS and architecture
                String nativeLibraryPath = "/native/" + osName + "/" + libraryName;

                // Get the library resource as a stream
                InputStream in = Util.class.getResourceAsStream(nativeLibraryPath);
                if (in == null) {
                    throw new RuntimeException("Unable to find native library: " + nativeLibraryPath);
                }

                // Create a file in the temporary directory with the original library name
                Path tempLibraryPath = tempDirectory.resolve(libraryName);

                // Use Files.copy to copy the library to the temporary file
                Files.copy(in, tempLibraryPath, StandardCopyOption.REPLACE_EXISTING);

                // Close the input stream
                in.close();
            }

            // Add a shutdown hook to delete the temp dir on JVM exit.
            // On Windows, deleting dll files that are loaded into memory is not possible.
            if(!isWindows) {
                Runtime.getRuntime().addShutdownHook(new Thread(() -> {
                    try {
                        Files.walk(tempDirectory)
                                .sorted(Comparator.reverseOrder())
                                .map(Path::toFile)
                                .forEach(file -> {
                                    try {
                                        Files.delete(file.toPath());
                                    } catch (IOException e) {
                                        logger.error("Deleting temp library file", e);
                                    }
                                });
                    } catch (IOException e) {
                        logger.error("Deleting temp directory for libraries", e);
                    }
                }));
            }

            return tempDirectory;
        } catch (IOException e) {
            throw new RuntimeException("Failed to load native libraries", e);
        }
    }

    public static String getValidUtf8(byte[] bytes) {
        try {
            return cs.decode(ByteBuffer.wrap(bytes)).toString();
        } catch (CharacterCodingException e) {
            return null;
        }
    }
}
.../com/hexadevlabs/gpt4all/Example1.java (new file, 30 lines)

package com.hexadevlabs.gpt4all;

import java.nio.file.Path;
import java.util.List;
import java.util.Map;

/**
 * GPTJ chat completion with multiple messages
 */
public class Example1 {
    public static void main(String[] args) {

        // Optionally, in case overriding the location of the shared libraries is necessary
        //LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";

        try ( LLModel gptjModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin")) ){

            LLModel.GenerationConfig config = LLModel.config()
                    .withNPredict(4096).build();

            gptjModel.chatCompletion(
                    List.of(Map.of("role", "user", "content", "Add 2+2"),
                            Map.of("role", "assistant", "content", "4"),
                            Map.of("role", "user", "content", "Multiply 4 * 5")), config, true, true);

        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
.../com/hexadevlabs/gpt4all/Example2.java (new file, 31 lines)

package com.hexadevlabs.gpt4all;

import java.nio.file.Path;

/**
 * Generation with the MPT model
 */
public class Example2 {
    public static void main(String[] args) {

        String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";

        // Optionally, in case overriding the location of the shared libraries is necessary
        //LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";

        try (LLModel mptModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-mpt-7b-instruct.bin"))) {

            LLModel.GenerationConfig config =
                    LLModel.config()
                            .withNPredict(4096)
                            .withRepeatLastN(64)
                            .build();

            mptModel.generate(prompt, config, true);

        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

}
.../com/hexadevlabs/gpt4all/Example3.java (new file, 33 lines)

package com.hexadevlabs.gpt4all;

import java.nio.file.Path;
import java.util.List;
import java.util.Map;

/**
 * GPTJ chat completion with a system message
 */
public class Example3 {
    public static void main(String[] args) {

        // Optionally, in case overriding the location of the shared libraries is necessary
        //LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";

        try ( LLModel gptjModel = new LLModel(Path.of("C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin")) ){

            LLModel.GenerationConfig config = LLModel.config()
                    .withNPredict(4096).build();

            // String result = gptjModel.generate(prompt, config, true);
            gptjModel.chatCompletion(
                    List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
                            Map.of("role", "user", "content", "Add 2+2")), config, true, true);

        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
.../com/hexadevlabs/gpt4all/Example4.java (new file, 43 lines)

package com.hexadevlabs.gpt4all;

import java.nio.file.Path;

public class Example4 {

    public static void main(String[] args) {

        String prompt = "### Human:\nWhat is the meaning of life\n### Assistant:";
        // The emoji below is the poop emoji. The Unicode character is encoded as a surrogate pair in the Java string.
        // The LLM should correctly identify it as the poop emoji in its description.
        //String prompt = "### Human:\nDescribe the meaning of this emoji \uD83D\uDCA9\n### Assistant:";
        //String prompt = "### Human:\nOutput the unicode character of smiley face emoji\n### Assistant:";

        // Optionally, in case overriding the location of the shared libraries is necessary
        //LLModel.LIBRARY_SEARCH_PATH = "C:\\Users\\felix\\gpt4all\\lib\\";

        String model = "ggml-vicuna-7b-1.1-q4_2.bin";
        //String model = "ggml-gpt4all-j-v1.3-groovy.bin";
        //String model = "ggml-mpt-7b-instruct.bin";
        String basePath = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\";
        //String basePath = "/Users/fzaslavs/Library/Application Support/nomic.ai/GPT4All/";

        try (LLModel mptModel = new LLModel(Path.of(basePath + model))) {

            LLModel.GenerationConfig config =
                    LLModel.config()
                            .withNPredict(4096)
                            .withRepeatLastN(64)
                            .build();

            String result = mptModel.generate(prompt, config, true);

            System.out.println("Code points:");
            result.codePoints().forEach(System.out::println);

        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
.../com/hexadevlabs/gpt4all/Tests.java (new file, 68 lines)

package com.hexadevlabs.gpt4all;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.nio.file.Path;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Disabled for now, until these tests can run with the latest library on Windows.
 */
@Disabled
public class Tests {

    static LLModel model;
    static String PATH_TO_MODEL = "C:\\Users\\felix\\AppData\\Local\\nomic.ai\\GPT4All\\ggml-gpt4all-j-v1.3-groovy.bin";

    @BeforeAll
    public static void before(){
        model = new LLModel(Path.of(PATH_TO_MODEL));
    }

    @Test
    public void simplePrompt(){
        LLModel.GenerationConfig config =
                LLModel.config()
                        .withNPredict(20)
                        .build();
        String prompt =
                "### Instruction: \n" +
                "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n" +
                "### Prompt: \n" +
                "Add 2+2\n" +
                "### Response: 4\n" +
                "Multiply 4 * 5\n" +
                "### Response:";
        String generatedText = model.generate(prompt, config, true);
        assertTrue(generatedText.contains("20"));
    }

    @Test
    public void chatCompletion(){
        LLModel.GenerationConfig config =
                LLModel.config()
                        .withNPredict(20)
                        .build();

        LLModel.ChatCompletionResponse response = model.chatCompletion(
                List.of(Map.of("role", "system", "content", "You are a helpful assistant"),
                        Map.of("role", "user", "content", "Add 2+2")), config, true, true);

        assertTrue( response.choices.get(0).get("content").contains("4") );

    }

    @AfterAll
    public static void after() throws Exception {
        if(model != null)
            model.close();
    }

}