Implement configurable context length (#1749)

This commit is contained in:
Jared Van Bortel 2023-12-16 17:58:15 -05:00 committed by GitHub
parent 7aa0f779de
commit d1c56b8b28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 291 additions and 135 deletions

View File

@ -714,8 +714,9 @@ Bert::~Bert() {
bert_free(d_ptr->ctx);
}
bool Bert::loadModel(const std::string &modelPath)
bool Bert::loadModel(const std::string &modelPath, int n_ctx)
{
(void)n_ctx;
d_ptr->ctx = bert_load_from_file(modelPath.c_str());
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->modelLoaded = d_ptr->ctx != nullptr;
@ -728,8 +729,10 @@ bool Bert::isModelLoaded() const
return d_ptr->modelLoaded;
}
size_t Bert::requiredMem(const std::string &/*modelPath*/)
size_t Bert::requiredMem(const std::string &modelPath, int n_ctx)
{
(void)modelPath;
(void)n_ctx;
return 0;
}

View File

@ -18,9 +18,9 @@ public:
bool supportsEmbedding() const override { return true; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;

View File

@ -676,7 +676,8 @@ GPTJ::GPTJ()
d_ptr->modelLoaded = false;
}
size_t GPTJ::requiredMem(const std::string &modelPath) {
size_t GPTJ::requiredMem(const std::string &modelPath, int n_ctx) {
(void)n_ctx;
gptj_model dummy_model;
gpt_vocab dummy_vocab;
size_t mem_req;
@ -684,7 +685,8 @@ size_t GPTJ::requiredMem(const std::string &modelPath) {
return mem_req;
}
bool GPTJ::loadModel(const std::string &modelPath) {
bool GPTJ::loadModel(const std::string &modelPath, int n_ctx) {
(void)n_ctx;
std::mt19937 rng(time(NULL));
d_ptr->rng = rng;

View File

@ -17,9 +17,9 @@ public:
bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;

View File

@ -120,7 +120,8 @@ struct llama_file_hparams {
enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
};
size_t LLamaModel::requiredMem(const std::string &modelPath) {
size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx) {
// TODO(cebtenzzre): update to GGUF
auto fin = std::ifstream(modelPath, std::ios::binary);
fin.seekg(0, std::ios_base::end);
size_t filesize = fin.tellg();
@ -137,40 +138,31 @@ size_t LLamaModel::requiredMem(const std::string &modelPath) {
fin.read(reinterpret_cast<char*>(&hparams.n_layer), sizeof(hparams.n_layer));
fin.read(reinterpret_cast<char*>(&hparams.n_rot), sizeof(hparams.n_rot));
fin.read(reinterpret_cast<char*>(&hparams.ftype), sizeof(hparams.ftype));
const size_t n_ctx = 2048;
const size_t kvcache_element_size = 2; // fp16
const size_t est_kvcache_size = hparams.n_embd * hparams.n_layer * 2u * n_ctx * kvcache_element_size;
return filesize + est_kvcache_size;
}
bool LLamaModel::loadModel(const std::string &modelPath)
bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx)
{
gpt_params params;
// load the model
if (n_ctx < 8) {
std::cerr << "warning: minimum context size is 8, using minimum size.\n";
n_ctx = 8;
}
// -- load the model --
d_ptr->model_params = llama_model_default_params();
d_ptr->model_params.use_mmap = params.use_mmap;
d_ptr->model_params.use_mmap = params.use_mmap;
#if defined (__APPLE__)
d_ptr->model_params.use_mlock = true;
d_ptr->model_params.use_mlock = true;
#else
d_ptr->model_params.use_mlock = params.use_mlock;
d_ptr->model_params.use_mlock = params.use_mlock;
#endif
d_ptr->ctx_params = llama_context_default_params();
d_ptr->ctx_params.n_ctx = 2048;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.f16_kv = params.memory_f16;
// The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
// that we want this many logits so the state serializes consistently.
d_ptr->ctx_params.logits_all = true;
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
#ifdef GGML_USE_METAL
if (llama_verbose()) {
std::cerr << "llama.cpp: using Metal" << std::endl;
@ -197,6 +189,28 @@ bool LLamaModel::loadModel(const std::string &modelPath)
return false;
}
const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
if (n_ctx > n_ctx_train) {
std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
<< n_ctx << " specified)\n";
}
// -- initialize the context --
d_ptr->ctx_params = llama_context_default_params();
d_ptr->ctx_params.n_ctx = n_ctx;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.f16_kv = params.memory_f16;
// The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
// that we want this many logits so the state serializes consistently.
d_ptr->ctx_params.logits_all = true;
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
if (!d_ptr->ctx) {
#ifdef GGML_USE_KOMPUTE

View File

@ -17,9 +17,9 @@ public:
bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;

View File

@ -138,7 +138,7 @@ const LLModel::Implementation* LLModel::Implementation::implementation(const cha
return nullptr;
}
LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant) {
LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant, int n_ctx) {
if (!has_at_least_minimal_hardware()) {
std::cerr << "LLModel ERROR: CPU does not support AVX\n";
return nullptr;
@ -154,7 +154,11 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s
if(impl) {
LLModel* metalimpl = impl->m_construct();
metalimpl->m_implementation = impl;
size_t req_mem = metalimpl->requiredMem(modelPath);
/* TODO(cebtenzzre): after we fix requiredMem, we should change this to happen at
* load time, not construct time. right now n_ctx is incorrectly hardcoded 2048 in
* most (all?) places where this is called, causing underestimation of required
* memory. */
size_t req_mem = metalimpl->requiredMem(modelPath, n_ctx);
float req_to_total = (float) req_mem / (float) total_mem;
// on a 16GB M2 Mac a 13B q4_0 (0.52) works for me but a 13B q4_K_M (0.55) does not
if (req_to_total >= 0.53) {
@ -165,6 +169,8 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s
}
}
}
#else
(void)n_ctx;
#endif
if (!impl) {

View File

@ -37,7 +37,7 @@ public:
static bool isImplementation(const Dlhandle&);
static const std::vector<Implementation>& implementationList();
static const Implementation *implementation(const char *fname, const std::string& buildVariant);
static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto");
static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto", int n_ctx = 2048);
static std::vector<GPUDevice> availableGPUDevices();
static void setImplementationsSearchPath(const std::string& path);
static const std::string& implementationsSearchPath();
@ -74,9 +74,9 @@ public:
virtual bool supportsEmbedding() const = 0;
virtual bool supportsCompletion() const = 0;
virtual bool loadModel(const std::string &modelPath) = 0;
virtual bool loadModel(const std::string &modelPath, int n_ctx) = 0;
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath) = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx) = 0;
virtual size_t stateSize() const { return 0; }
virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
virtual size_t restoreState(const uint8_t */*src*/) { return 0; }

View File

@ -47,16 +47,16 @@ void llmodel_model_destroy(llmodel_model model) {
delete reinterpret_cast<LLModelWrapper*>(model);
}
size_t llmodel_required_mem(llmodel_model model, const char *model_path)
size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->requiredMem(model_path);
return wrapper->llModel->requiredMem(model_path, n_ctx);
}
bool llmodel_loadModel(llmodel_model model, const char *model_path)
bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->loadModel(model_path);
return wrapper->llModel->loadModel(model_path, n_ctx);
}
bool llmodel_isModelLoaded(llmodel_model model)

View File

@ -110,17 +110,19 @@ void llmodel_model_destroy(llmodel_model model);
* Estimate RAM requirement for a model file
* @param model A pointer to the llmodel_model instance.
* @param model_path A string representing the path to the model file.
* @param n_ctx Maximum size of context window
* @return size greater than 0 if the model was parsed successfully, 0 if file could not be parsed.
*/
size_t llmodel_required_mem(llmodel_model model, const char *model_path);
size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx);
/**
* Load a model from a file.
* @param model A pointer to the llmodel_model instance.
* @param model_path A string representing the path to the model file.
* @param n_ctx Maximum size of context window
* @return true if the model was loaded successfully, false otherwise.
*/
bool llmodel_loadModel(llmodel_model model, const char *model_path);
bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx);
/**
* Check if a model is loaded.

View File

@ -188,7 +188,7 @@ public class LLModel : ILLModel
/// <returns>true if the model was loaded successfully, false otherwise.</returns>
public bool Load(string modelPath)
{
return NativeMethods.llmodel_loadModel(_handle, modelPath);
return NativeMethods.llmodel_loadModel(_handle, modelPath, 2048);
}
protected void Destroy()

View File

@ -70,7 +70,8 @@ internal static unsafe partial class NativeMethods
[return: MarshalAs(UnmanagedType.I1)]
public static extern bool llmodel_loadModel(
[NativeTypeName("llmodel_model")] IntPtr model,
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path,
[NativeTypeName("int32_t")] int n_ctx);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]

View File

@ -39,7 +39,7 @@ public class Gpt4AllModelFactory : IGpt4AllModelFactory
var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error);
_logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
_logger.LogInformation("Model loading started");
var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath, 2048);
_logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
if (!loadedSuccessfully)
{

View File

@ -23,7 +23,7 @@ void* load_model(const char *fname, int n_threads) {
fprintf(stderr, "%s: error '%s'\n", __func__, new_error);
return nullptr;
}
if (!llmodel_loadModel(model, fname)) {
if (!llmodel_loadModel(model, fname, 2048)) {
llmodel_model_destroy(model);
return nullptr;
}

View File

@ -195,7 +195,7 @@ public class LLModel implements AutoCloseable {
if(model == null) {
throw new IllegalStateException("Could not load, gpt4all backend returned error: " + error.getValue().getString(0));
}
library.llmodel_loadModel(model, modelPathAbs);
library.llmodel_loadModel(model, modelPathAbs, 2048);
if(!library.llmodel_isModelLoaded(model)){
throw new IllegalStateException("The model " + modelName + " could not be loaded");

View File

@ -61,7 +61,7 @@ public interface LLModelLibrary {
Pointer llmodel_model_create2(String model_path, String build_variant, PointerByReference error);
void llmodel_model_destroy(Pointer model);
boolean llmodel_loadModel(Pointer model, String model_path);
boolean llmodel_loadModel(Pointer model, String model_path, int n_ctx);
boolean llmodel_isModelLoaded(Pointer model);
@u_int64_t long llmodel_get_state_size(Pointer model);
@u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest);

View File

@ -1,2 +1,2 @@
from .gpt4all import Embed4All, GPT4All # noqa
from .pyllmodel import LLModel # noqa
from .gpt4all import Embed4All as Embed4All, GPT4All as GPT4All
from .pyllmodel import LLModel as LLModel

View File

@ -69,6 +69,7 @@ class GPT4All:
allow_download: bool = True,
n_threads: Optional[int] = None,
device: Optional[str] = "cpu",
n_ctx: int = 2048,
verbose: bool = False,
):
"""
@ -90,15 +91,16 @@ class GPT4All:
Default is "cpu".
Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model.
n_ctx: Maximum size of context window
verbose: If True, print debug messages.
"""
self.model_type = model_type
self.model = pyllmodel.LLModel()
# Retrieve model and download if allowed
self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
if device is not None:
if device != "cpu":
self.model.init_gpu(model_path=self.config["path"], device=device)
self.model.load_model(self.config["path"])
if device is not None and device != "cpu":
self.model.init_gpu(model_path=self.config["path"], device=device, n_ctx=n_ctx)
self.model.load_model(self.config["path"], n_ctx)
# Set n_threads
if n_threads is not None:
self.model.set_thread_count(n_threads)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import ctypes
import importlib.resources
import logging
@ -7,6 +9,7 @@ import re
import subprocess
import sys
import threading
from enum import Enum
from queue import Queue
from typing import Callable, Iterable, List
@ -72,9 +75,9 @@ llmodel.llmodel_model_create2.restype = ctypes.c_void_p
llmodel.llmodel_model_destroy.argtypes = [ctypes.c_void_p]
llmodel.llmodel_model_destroy.restype = None
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
llmodel.llmodel_loadModel.restype = ctypes.c_bool
llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
llmodel.llmodel_required_mem.restype = ctypes.c_size_t
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
@ -114,7 +117,7 @@ llmodel.llmodel_set_implementation_search_path.restype = None
llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
llmodel.llmodel_threadCount.restype = ctypes.c_int32
llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode("utf-8"))
llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode())
llmodel.llmodel_available_gpu_devices.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(ctypes.c_int32)]
llmodel.llmodel_available_gpu_devices.restype = ctypes.POINTER(LLModelGPUDevice)
@ -143,10 +146,16 @@ def _create_model(model_path: bytes) -> ctypes.c_void_p:
err = ctypes.c_char_p()
model = llmodel.llmodel_model_create2(model_path, b"auto", ctypes.byref(err))
if model is None:
raise ValueError(f"Unable to instantiate model: {err.decode()}")
s = err.value
raise ValueError("Unable to instantiate model: {'null' if s is None else s.decode()}")
return model
# Symbol to terminate from generator
class Sentinel(Enum):
TERMINATING_SYMBOL = 0
class LLModel:
"""
Base class and universal wrapper for GPT4All language models
@ -173,12 +182,16 @@ class LLModel:
if self.model is not None:
self.llmodel_lib.llmodel_model_destroy(self.model)
def memory_needed(self, model_path: str) -> int:
model_path_enc = model_path.encode("utf-8")
self.model = _create_model(model_path_enc)
return llmodel.llmodel_required_mem(self.model, model_path_enc)
def memory_needed(self, model_path: str, n_ctx: int) -> int:
self.model = None
return self._memory_needed(model_path, n_ctx)
def list_gpu(self, model_path: str) -> list:
def _memory_needed(self, model_path: str, n_ctx: int) -> int:
if self.model is None:
self.model = _create_model(model_path.encode())
return llmodel.llmodel_required_mem(self.model, model_path.encode(), n_ctx)
def list_gpu(self, model_path: str, n_ctx: int) -> list[LLModelGPUDevice]:
"""
Lists available GPU devices that satisfy the model's memory requirements.
@ -186,45 +199,41 @@ class LLModel:
----------
model_path : str
Path to the model.
n_ctx : int
Maximum size of context window
Returns
-------
list
A list of LLModelGPUDevice structures representing available GPU devices.
"""
if self.model is not None:
model_path_enc = model_path.encode("utf-8")
mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc)
else:
mem_required = self.memory_needed(model_path)
mem_required = self._memory_needed(model_path, n_ctx)
return self._list_gpu(mem_required)
def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]:
num_devices = ctypes.c_int32(0)
devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices))
if not devices_ptr:
raise ValueError("Unable to retrieve available GPU devices")
devices = [devices_ptr[i] for i in range(num_devices.value)]
return devices
return devices_ptr[:num_devices.value]
def init_gpu(self, model_path: str, device: str):
if self.model is not None:
model_path_enc = model_path.encode("utf-8")
mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc)
else:
mem_required = self.memory_needed(model_path)
device_enc = device.encode("utf-8")
success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device_enc)
def init_gpu(self, model_path: str, device: str, n_ctx: int):
mem_required = self._memory_needed(model_path, n_ctx)
success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode())
if not success:
# Retrieve all GPUs without considering memory requirements.
num_devices = ctypes.c_int32(0)
all_devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, 0, ctypes.byref(num_devices))
if not all_devices_ptr:
raise ValueError("Unable to retrieve list of all GPU devices")
all_gpus = [all_devices_ptr[i].name.decode('utf-8') for i in range(num_devices.value)]
all_gpus = [d.name.decode() for d in all_devices_ptr[:num_devices.value]]
# Retrieve GPUs that meet the memory requirements using list_gpu
available_gpus = [device.name.decode('utf-8') for device in self.list_gpu(model_path)]
available_gpus = [device.name.decode() for device in self._list_gpu(mem_required)]
# Identify GPUs that are unavailable due to insufficient memory or features
unavailable_gpus = set(all_gpus) - set(available_gpus)
unavailable_gpus = set(all_gpus).difference(available_gpus)
# Formulate the error message
error_msg = "Unable to initialize model on GPU: '{}'.".format(device)
@ -232,7 +241,7 @@ class LLModel:
error_msg += "\nUnavailable GPUs due to insufficient memory or features: {}.".format(unavailable_gpus)
raise ValueError(error_msg)
def load_model(self, model_path: str) -> bool:
def load_model(self, model_path: str, n_ctx: int) -> bool:
"""
Load model from a file.
@ -240,15 +249,16 @@ class LLModel:
----------
model_path : str
Model filepath
n_ctx : int
Maximum size of context window
Returns
-------
True if model loaded successfully, False otherwise
"""
model_path_enc = model_path.encode("utf-8")
self.model = _create_model(model_path_enc)
self.model = _create_model(model_path.encode())
llmodel.llmodel_loadModel(self.model, model_path_enc)
llmodel.llmodel_loadModel(self.model, model_path.encode(), n_ctx)
filename = os.path.basename(model_path)
self.model_name = os.path.splitext(filename)[0]
@ -312,7 +322,7 @@ class LLModel:
raise ValueError("Text must not be None or empty")
embedding_size = ctypes.c_size_t()
c_text = ctypes.c_char_p(text.encode('utf-8'))
c_text = ctypes.c_char_p(text.encode())
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
embedding_array = [embedding_ptr[i] for i in range(embedding_size.value)]
llmodel.llmodel_free_embedding(embedding_ptr)
@ -357,7 +367,7 @@ class LLModel:
prompt,
)
prompt_bytes = prompt.encode("utf-8")
prompt_bytes = prompt.encode()
prompt_ptr = ctypes.c_char_p(prompt_bytes)
self._set_context(
@ -385,10 +395,7 @@ class LLModel:
def prompt_model_streaming(
self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
) -> Iterable[str]:
# Symbol to terminate from generator
TERMINATING_SYMBOL = object()
output_queue: Queue = Queue()
output_queue: Queue[str | Sentinel] = Queue()
# Put response tokens into an output queue
def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
@ -405,7 +412,7 @@ class LLModel:
def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
self.prompt_model(prompt, callback, **kwargs)
output_queue.put(TERMINATING_SYMBOL)
output_queue.put(Sentinel.TERMINATING_SYMBOL)
# Kick off llmodel_prompt in separate thread so we can return generator
# immediately
@ -419,7 +426,7 @@ class LLModel:
# Generator
while True:
response = output_queue.get()
if response is TERMINATING_SYMBOL:
if isinstance(response, Sentinel):
break
yield response
@ -442,7 +449,7 @@ class LLModel:
else:
# beginning of a byte sequence
if len(self.buffer) > 0:
decoded.append(self.buffer.decode('utf-8', 'replace'))
decoded.append(self.buffer.decode(errors='replace'))
self.buffer.clear()
@ -451,7 +458,7 @@ class LLModel:
if self.buff_expecting_cont_bytes <= 0:
# received the whole sequence or an out of place continuation byte
decoded.append(self.buffer.decode('utf-8', 'replace'))
decoded.append(self.buffer.decode(errors='replace'))
self.buffer.clear()
self.buff_expecting_cont_bytes = 0

View File

@ -117,7 +117,7 @@ def test_empty_embedding():
def test_download_model(tmp_path: Path):
import gpt4all.gpt4all
old_default_dir = gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY
gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = tmp_path # temporary pytest directory to ensure a download happens
gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = str(tmp_path) # temporary pytest directory to ensure a download happens
try:
model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin')
model_path = tmp_path / model.config['filename']

View File

@ -28,7 +28,7 @@ Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
{
auto env = info.Env();
return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str()) ));
return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str(), 2048) ));
}
Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info)
@ -161,7 +161,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
}
}
auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str());
auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), 2048);
if(!success) {
Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
return;

View File

@ -20,15 +20,17 @@ ChatGPT::ChatGPT()
{
}
size_t ChatGPT::requiredMem(const std::string &modelPath)
size_t ChatGPT::requiredMem(const std::string &modelPath, int n_ctx)
{
Q_UNUSED(modelPath);
Q_UNUSED(n_ctx);
return 0;
}
bool ChatGPT::loadModel(const std::string &modelPath)
bool ChatGPT::loadModel(const std::string &modelPath, int n_ctx)
{
Q_UNUSED(modelPath);
Q_UNUSED(n_ctx);
return true;
}

View File

@ -48,9 +48,9 @@ public:
bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;

View File

@ -5,7 +5,7 @@
#include <QDataStream>
#define CHAT_FORMAT_MAGIC 0xF5D553CC
#define CHAT_FORMAT_VERSION 6
#define CHAT_FORMAT_VERSION 7
class MyChatListModel: public ChatListModel { };
Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance)

View File

@ -248,14 +248,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
m_llModelInfo.model = model;
} else {
// TODO: make configurable in UI
auto n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
m_ctx.n_ctx = n_ctx;
std::string buildVariant = "auto";
#if defined(Q_OS_MAC) && defined(__arm__)
if (m_forceMetal)
m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "metal");
else
m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "auto");
#else
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
buildVariant = "metal";
#endif
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
if (m_llModelInfo.model) {
// Update the settings that a model is being loaded and update the device list
@ -267,7 +269,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (requestedDevice == "CPU") {
emit reportFallbackReason(""); // fallback not applicable
} else {
const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString());
const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx);
std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
LLModel::GPUDevice *device = nullptr;
@ -296,14 +298,14 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
// Report which device we're actually using
emit reportDevice(actualDevice);
bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
if (actualDevice == "CPU") {
// we asked llama.cpp to use the CPU
} else if (!success) {
// llama_init_from_file returned nullptr
emit reportDevice("CPU");
emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
success = m_llModelInfo.model->loadModel(filePath.toStdString());
success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
} else if (!m_llModelInfo.model->usingGPUDevice()) {
// ggml_vk_init was not called in llama.cpp
// We might have had to fallback to CPU after load if the model is not possible to accelerate
@ -763,6 +765,8 @@ bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
return false;
}
// this function serialized the cached model state to disk.
// we want to also serialize n_ctx, and read it at load time.
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
{
if (version > 1) {
@ -790,6 +794,9 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
stream << responseLogits;
}
stream << m_ctx.n_past;
if (version >= 6) {
stream << m_ctx.n_ctx;
}
stream << quint64(m_ctx.logits.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float));
stream << quint64(m_ctx.tokens.size());
@ -839,6 +846,12 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
stream >> n_past;
if (!discardKV) m_ctx.n_past = n_past;
if (version >= 6) {
uint32_t n_ctx;
stream >> n_ctx;
if (!discardKV) m_ctx.n_ctx = n_ctx;
}
quint64 logitsSize;
stream >> logitsSize;
if (!discardKV) {

View File

@ -29,8 +29,8 @@ bool EmbeddingLLM::loadModel()
return false;
}
m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
bool success = m_model->loadModel(filePath.toStdString());
m_model = LLModel::Implementation::construct(filePath.toStdString());
bool success = m_model->loadModel(filePath.toStdString(), 2048);
if (!success) {
qWarning() << "WARNING: Could not load sbert";
delete m_model;

View File

@ -97,6 +97,17 @@ void ModelInfo::setPromptBatchSize(int s)
m_promptBatchSize = s;
}
int ModelInfo::contextLength() const
{
return MySettings::globalInstance()->modelContextLength(*this);
}
void ModelInfo::setContextLength(int l)
{
if (isClone) MySettings::globalInstance()->setModelContextLength(*this, l, isClone /*force*/);
m_contextLength = l;
}
double ModelInfo::repeatPenalty() const
{
return MySettings::globalInstance()->modelRepeatPenalty(*this);
@ -274,6 +285,7 @@ ModelList::ModelList()
connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);;
connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings);
@ -525,6 +537,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->maxLength();
case PromptBatchSizeRole:
return info->promptBatchSize();
case ContextLengthRole:
return info->contextLength();
case RepeatPenaltyRole:
return info->repeatPenalty();
case RepeatPenaltyTokensRole:
@ -740,6 +754,7 @@ QString ModelList::clone(const ModelInfo &model)
updateData(id, ModelList::TopKRole, model.topK());
updateData(id, ModelList::MaxLengthRole, model.maxLength());
updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize());
updateData(id, ModelList::ContextLengthRole, model.contextLength());
updateData(id, ModelList::RepeatPenaltyRole, model.repeatPenalty());
updateData(id, ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens());
updateData(id, ModelList::PromptTemplateRole, model.promptTemplate());
@ -1106,6 +1121,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
updateData(id, ModelList::MaxLengthRole, obj["maxLength"].toInt());
if (obj.contains("promptBatchSize"))
updateData(id, ModelList::PromptBatchSizeRole, obj["promptBatchSize"].toInt());
if (obj.contains("contextLength"))
updateData(id, ModelList::ContextLengthRole, obj["contextLength"].toInt());
if (obj.contains("repeatPenalty"))
updateData(id, ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble());
if (obj.contains("repeatPenaltyTokens"))
@ -1198,6 +1215,8 @@ void ModelList::updateModelsFromSettings()
const int maxLength = settings.value(g + "/maxLength").toInt();
Q_ASSERT(settings.contains(g + "/promptBatchSize"));
const int promptBatchSize = settings.value(g + "/promptBatchSize").toInt();
Q_ASSERT(settings.contains(g + "/contextLength"));
const int contextLength = settings.value(g + "/contextLength").toInt();
Q_ASSERT(settings.contains(g + "/repeatPenalty"));
const double repeatPenalty = settings.value(g + "/repeatPenalty").toDouble();
Q_ASSERT(settings.contains(g + "/repeatPenaltyTokens"));
@ -1216,6 +1235,7 @@ void ModelList::updateModelsFromSettings()
updateData(id, ModelList::TopKRole, topK);
updateData(id, ModelList::MaxLengthRole, maxLength);
updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize);
updateData(id, ModelList::ContextLengthRole, contextLength);
updateData(id, ModelList::RepeatPenaltyRole, repeatPenalty);
updateData(id, ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens);
updateData(id, ModelList::PromptTemplateRole, promptTemplate);

View File

@ -39,6 +39,7 @@ struct ModelInfo {
Q_PROPERTY(int topK READ topK WRITE setTopK)
Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength)
Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize)
Q_PROPERTY(int contextLength READ contextLength WRITE setContextLength)
Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate)
@ -94,6 +95,8 @@ public:
void setMaxLength(int l);
int promptBatchSize() const;
void setPromptBatchSize(int s);
int contextLength() const;
void setContextLength(int l);
double repeatPenalty() const;
void setRepeatPenalty(double p);
int repeatPenaltyTokens() const;
@ -112,6 +115,7 @@ private:
int m_topK = 40;
int m_maxLength = 4096;
int m_promptBatchSize = 128;
int m_contextLength = 2048;
double m_repeatPenalty = 1.18;
int m_repeatPenaltyTokens = 64;
QString m_promptTemplate = "### Human:\n%1\n### Assistant:\n";
@ -227,6 +231,7 @@ public:
TopKRole,
MaxLengthRole,
PromptBatchSizeRole,
ContextLengthRole,
RepeatPenaltyRole,
RepeatPenaltyTokensRole,
PromptTemplateRole,
@ -269,6 +274,7 @@ public:
roles[TopKRole] = "topK";
roles[MaxLengthRole] = "maxLength";
roles[PromptBatchSizeRole] = "promptBatchSize";
roles[ContextLengthRole] = "contextLength";
roles[RepeatPenaltyRole] = "repeatPenalty";
roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens";
roles[PromptTemplateRole] = "promptTemplate";

View File

@ -90,6 +90,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &model)
setModelTopK(model, model.m_topK);;
setModelMaxLength(model, model.m_maxLength);
setModelPromptBatchSize(model, model.m_promptBatchSize);
setModelContextLength(model, model.m_contextLength);
setModelRepeatPenalty(model, model.m_repeatPenalty);
setModelRepeatPenaltyTokens(model, model.m_repeatPenaltyTokens);
setModelPromptTemplate(model, model.m_promptTemplate);
@ -280,6 +281,28 @@ void MySettings::setModelPromptBatchSize(const ModelInfo &m, int s, bool force)
emit promptBatchSizeChanged(m);
}
int MySettings::modelContextLength(const ModelInfo &m) const
{
QSettings setting;
setting.sync();
return setting.value(QString("model-%1").arg(m.id()) + "/contextLength", m.m_contextLength).toInt();
}
void MySettings::setModelContextLength(const ModelInfo &m, int l, bool force)
{
if (modelContextLength(m) == l && !force)
return;
QSettings setting;
if (m.m_contextLength == l && !m.isClone)
setting.remove(QString("model-%1").arg(m.id()) + "/contextLength");
else
setting.setValue(QString("model-%1").arg(m.id()) + "/contextLength", l);
setting.sync();
if (!force)
emit contextLengthChanged(m);
}
double MySettings::modelRepeatPenalty(const ModelInfo &m) const
{
QSettings setting;

View File

@ -1,6 +1,8 @@
#ifndef MYSETTINGS_H
#define MYSETTINGS_H
#include <cstdint>
#include <QObject>
#include <QMutex>
@ -59,6 +61,8 @@ public:
Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &m, const QString &t, bool force = false);
QString modelSystemPrompt(const ModelInfo &m) const;
Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &m, const QString &p, bool force = false);
int modelContextLength(const ModelInfo &m) const;
Q_INVOKABLE void setModelContextLength(const ModelInfo &m, int s, bool force = false);
// Application settings
int threadCount() const;
@ -79,6 +83,8 @@ public:
void setForceMetal(bool b);
QString device() const;
void setDevice(const QString &u);
int32_t contextLength() const;
void setContextLength(int32_t value);
// Release/Download settings
QString lastVersionStarted() const;
@ -114,6 +120,7 @@ Q_SIGNALS:
void topKChanged(const ModelInfo &model);
void maxLengthChanged(const ModelInfo &model);
void promptBatchSizeChanged(const ModelInfo &model);
void contextLengthChanged(const ModelInfo &model);
void repeatPenaltyChanged(const ModelInfo &model);
void repeatPenaltyTokensChanged(const ModelInfo &model);
void promptTemplateChanged(const ModelInfo &model);

View File

@ -349,13 +349,61 @@ MySettingsTab {
rowSpacing: 10
columnSpacing: 10
Label {
id: contextLengthLabel
visible: !root.currentModelInfo.isChatGPT
text: qsTr("Context Length:")
font.pixelSize: theme.fontSizeLarge
color: theme.textColor
Layout.row: 0
Layout.column: 0
}
MyTextField {
id: contextLengthField
visible: !root.currentModelInfo.isChatGPT
text: root.currentModelInfo.contextLength
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Maximum combined prompt/response tokens before information is lost.\nUsing more context than the model was trained on will yield poor results.\nNOTE: Does not take effect until you RESTART GPT4All or SWITCH MODELS.")
ToolTip.visible: hovered
Layout.row: 0
Layout.column: 1
validator: IntValidator {
bottom: 1
}
Connections {
target: MySettings
function onContextLengthChanged() {
contextLengthField.text = root.currentModelInfo.contextLength;
}
}
Connections {
target: root
function onCurrentModelInfoChanged() {
contextLengthField.text = root.currentModelInfo.contextLength;
}
}
onEditingFinished: {
var val = parseInt(text)
if (!isNaN(val)) {
MySettings.setModelContextLength(root.currentModelInfo, val)
focus = false
} else {
text = root.currentModelInfo.contextLength
}
}
Accessible.role: Accessible.EditableText
Accessible.name: contextLengthLabel.text
Accessible.description: ToolTip.text
}
Label {
id: tempLabel
text: qsTr("Temperature:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 0
Layout.column: 0
Layout.row: 1
Layout.column: 2
}
MyTextField {
@ -365,8 +413,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Temperature increases the chances of choosing less likely tokens.\nNOTE: Higher temperature gives more creative but less predictable outputs.")
ToolTip.visible: hovered
Layout.row: 0
Layout.column: 1
Layout.row: 1
Layout.column: 3
validator: DoubleValidator {
locale: "C"
}
@ -400,8 +448,8 @@ MySettingsTab {
text: qsTr("Top P:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 0
Layout.column: 2
Layout.row: 2
Layout.column: 0
}
MyTextField {
id: topPField
@ -410,8 +458,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Only the most likely tokens up to a total probability of top_p can be chosen.\nNOTE: Prevents choosing highly unlikely tokens, aka Nucleus Sampling")
ToolTip.visible: hovered
Layout.row: 0
Layout.column: 3
Layout.row: 2
Layout.column: 1
validator: DoubleValidator {
locale: "C"
}
@ -446,8 +494,8 @@ MySettingsTab {
text: qsTr("Top K:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 1
Layout.column: 0
Layout.row: 2
Layout.column: 2
}
MyTextField {
id: topKField
@ -457,8 +505,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Only the top K most likely tokens will be chosen from")
ToolTip.visible: hovered
Layout.row: 1
Layout.column: 1
Layout.row: 2
Layout.column: 3
validator: IntValidator {
bottom: 1
}
@ -493,7 +541,7 @@ MySettingsTab {
text: qsTr("Max Length:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 1
Layout.row: 0
Layout.column: 2
}
MyTextField {
@ -504,7 +552,7 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Maximum length of response in tokens")
ToolTip.visible: hovered
Layout.row: 1
Layout.row: 0
Layout.column: 3
validator: IntValidator {
bottom: 1
@ -541,7 +589,7 @@ MySettingsTab {
text: qsTr("Prompt Batch Size:")
font.pixelSize: theme.fontSizeLarge
color: theme.textColor
Layout.row: 2
Layout.row: 1
Layout.column: 0
}
MyTextField {
@ -552,7 +600,7 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Amount of prompt tokens to process at once.\nNOTE: Higher values can speed up reading prompts but will use more RAM")
ToolTip.visible: hovered
Layout.row: 2
Layout.row: 1
Layout.column: 1
validator: IntValidator {
bottom: 1
@ -588,8 +636,8 @@ MySettingsTab {
text: qsTr("Repeat Penalty:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 2
Layout.column: 2
Layout.row: 3
Layout.column: 0
}
MyTextField {
id: repeatPenaltyField
@ -599,8 +647,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Amount to penalize repetitiveness of the output")
ToolTip.visible: hovered
Layout.row: 2
Layout.column: 3
Layout.row: 3
Layout.column: 1
validator: DoubleValidator {
locale: "C"
}
@ -636,7 +684,7 @@ MySettingsTab {
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 3
Layout.column: 0
Layout.column: 2
}
MyTextField {
id: repeatPenaltyTokenField
@ -647,7 +695,7 @@ MySettingsTab {
ToolTip.text: qsTr("How far back in output to apply repeat penalty")
ToolTip.visible: hovered
Layout.row: 3
Layout.column: 1
Layout.column: 3
validator: IntValidator {
bottom: 1
}