Mirror of https://github.com/nomic-ai/gpt4all (synced 2024-11-06 09:20:33 +00:00)

Implement configurable context length (#1749)

Commit d1c56b8b28 (parent 7aa0f779de)
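For orientation before the diff: after this change the context length flows from the UI settings (or binding arguments) down to the backend instead of being hardcoded to 2048. A minimal sketch of the user-facing effect in the Python bindings; the model file name is illustrative, not part of this commit:

from gpt4all import GPT4All

# n_ctx is the maximum combined prompt/response window in tokens; it now
# also feeds the RAM/VRAM estimate used when selecting a device.
model = GPT4All("mistral-7b-instruct-v0.1.Q4_0.gguf", n_ctx=4096)
print(model.generate("Explain what a context window is.", max_tokens=128))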
@@ -714,8 +714,9 @@ Bert::~Bert() {
bert_free(d_ptr->ctx);
}

bool Bert::loadModel(const std::string &modelPath)
bool Bert::loadModel(const std::string &modelPath, int n_ctx)
{
(void)n_ctx;
d_ptr->ctx = bert_load_from_file(modelPath.c_str());
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->modelLoaded = d_ptr->ctx != nullptr;

@@ -728,8 +729,10 @@ bool Bert::isModelLoaded() const
return d_ptr->modelLoaded;
}

size_t Bert::requiredMem(const std::string &/*modelPath*/)
size_t Bert::requiredMem(const std::string &modelPath, int n_ctx)
{
(void)modelPath;
(void)n_ctx;
return 0;
}
@@ -18,9 +18,9 @@ public:

bool supportsEmbedding() const override { return true; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
@@ -676,7 +676,8 @@ GPTJ::GPTJ()
d_ptr->modelLoaded = false;
}

size_t GPTJ::requiredMem(const std::string &modelPath) {
size_t GPTJ::requiredMem(const std::string &modelPath, int n_ctx) {
(void)n_ctx;
gptj_model dummy_model;
gpt_vocab dummy_vocab;
size_t mem_req;

@@ -684,7 +685,8 @@ size_t GPTJ::requiredMem(const std::string &modelPath) {
return mem_req;
}

bool GPTJ::loadModel(const std::string &modelPath) {
bool GPTJ::loadModel(const std::string &modelPath, int n_ctx) {
(void)n_ctx;
std::mt19937 rng(time(NULL));
d_ptr->rng = rng;
@@ -17,9 +17,9 @@ public:

bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
@@ -120,7 +120,8 @@ struct llama_file_hparams {
enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
};

size_t LLamaModel::requiredMem(const std::string &modelPath) {
size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx) {
// TODO(cebtenzzre): update to GGUF
auto fin = std::ifstream(modelPath, std::ios::binary);
fin.seekg(0, std::ios_base::end);
size_t filesize = fin.tellg();

@@ -137,40 +138,31 @@ size_t LLamaModel::requiredMem(const std::string &modelPath) {
fin.read(reinterpret_cast<char*>(&hparams.n_layer), sizeof(hparams.n_layer));
fin.read(reinterpret_cast<char*>(&hparams.n_rot), sizeof(hparams.n_rot));
fin.read(reinterpret_cast<char*>(&hparams.ftype), sizeof(hparams.ftype));
const size_t n_ctx = 2048;
const size_t kvcache_element_size = 2; // fp16
const size_t est_kvcache_size = hparams.n_embd * hparams.n_layer * 2u * n_ctx * kvcache_element_size;
return filesize + est_kvcache_size;
}

bool LLamaModel::loadModel(const std::string &modelPath)
bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx)
{
gpt_params params;

// load the model
if (n_ctx < 8) {
std::cerr << "warning: minimum context size is 8, using minimum size.\n";
n_ctx = 8;
}

// -- load the model --

d_ptr->model_params = llama_model_default_params();

d_ptr->model_params.use_mmap = params.use_mmap;
d_ptr->model_params.use_mmap = params.use_mmap;
#if defined (__APPLE__)
d_ptr->model_params.use_mlock = true;
d_ptr->model_params.use_mlock = true;
#else
d_ptr->model_params.use_mlock = params.use_mlock;
d_ptr->model_params.use_mlock = params.use_mlock;
#endif

d_ptr->ctx_params = llama_context_default_params();

d_ptr->ctx_params.n_ctx = 2048;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.f16_kv = params.memory_f16;

// The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
// that we want this many logits so the state serializes consistently.
d_ptr->ctx_params.logits_all = true;

d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;

#ifdef GGML_USE_METAL
if (llama_verbose()) {
std::cerr << "llama.cpp: using Metal" << std::endl;

@@ -197,6 +189,28 @@ bool LLamaModel::loadModel(const std::string &modelPath)
return false;
}

const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
if (n_ctx > n_ctx_train) {
std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
<< n_ctx << " specified)\n";
}

// -- initialize the context --

d_ptr->ctx_params = llama_context_default_params();

d_ptr->ctx_params.n_ctx = n_ctx;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.f16_kv = params.memory_f16;

// The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
// that we want this many logits so the state serializes consistently.
d_ptr->ctx_params.logits_all = true;

d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;

d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
if (!d_ptr->ctx) {
#ifdef GGML_USE_KOMPUTE
@@ -17,9 +17,9 @@ public:

bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
@@ -138,7 +138,7 @@ const LLModel::Implementation* LLModel::Implementation::implementation(const cha
return nullptr;
}

LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant) {
LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant, int n_ctx) {
if (!has_at_least_minimal_hardware()) {
std::cerr << "LLModel ERROR: CPU does not support AVX\n";
return nullptr;

@@ -154,7 +154,11 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s
if(impl) {
LLModel* metalimpl = impl->m_construct();
metalimpl->m_implementation = impl;
size_t req_mem = metalimpl->requiredMem(modelPath);
/* TODO(cebtenzzre): after we fix requiredMem, we should change this to happen at
* load time, not construct time. right now n_ctx is incorrectly hardcoded 2048 in
* most (all?) places where this is called, causing underestimation of required
* memory. */
size_t req_mem = metalimpl->requiredMem(modelPath, n_ctx);
float req_to_total = (float) req_mem / (float) total_mem;
// on a 16GB M2 Mac a 13B q4_0 (0.52) works for me but a 13B q4_K_M (0.55) does not
if (req_to_total >= 0.53) {

@@ -165,6 +169,8 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s
}
}
}
#else
(void)n_ctx;
#endif

if (!impl) {
@@ -37,7 +37,7 @@ public:
static bool isImplementation(const Dlhandle&);
static const std::vector<Implementation>& implementationList();
static const Implementation *implementation(const char *fname, const std::string& buildVariant);
static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto");
static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto", int n_ctx = 2048);
static std::vector<GPUDevice> availableGPUDevices();
static void setImplementationsSearchPath(const std::string& path);
static const std::string& implementationsSearchPath();

@@ -74,9 +74,9 @@ public:

virtual bool supportsEmbedding() const = 0;
virtual bool supportsCompletion() const = 0;
virtual bool loadModel(const std::string &modelPath) = 0;
virtual bool loadModel(const std::string &modelPath, int n_ctx) = 0;
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath) = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx) = 0;
virtual size_t stateSize() const { return 0; }
virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
virtual size_t restoreState(const uint8_t */*src*/) { return 0; }
@@ -47,16 +47,16 @@ void llmodel_model_destroy(llmodel_model model) {
delete reinterpret_cast<LLModelWrapper*>(model);
}

size_t llmodel_required_mem(llmodel_model model, const char *model_path)
size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->requiredMem(model_path);
return wrapper->llModel->requiredMem(model_path, n_ctx);
}

bool llmodel_loadModel(llmodel_model model, const char *model_path)
bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->loadModel(model_path);
return wrapper->llModel->loadModel(model_path, n_ctx);
}

bool llmodel_isModelLoaded(llmodel_model model)
@@ -110,17 +110,19 @@ void llmodel_model_destroy(llmodel_model model);
* Estimate RAM requirement for a model file
* @param model A pointer to the llmodel_model instance.
* @param model_path A string representing the path to the model file.
* @param n_ctx Maximum size of context window
* @return size greater than 0 if the model was parsed successfully, 0 if file could not be parsed.
*/
size_t llmodel_required_mem(llmodel_model model, const char *model_path);
size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx);

/**
* Load a model from a file.
* @param model A pointer to the llmodel_model instance.
* @param model_path A string representing the path to the model file.
* @param n_ctx Maximum size of context window
* @return true if the model was loaded successfully, false otherwise.
*/
bool llmodel_loadModel(llmodel_model model, const char *model_path);
bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx);

/**
* Check if a model is loaded.
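The doc comments above describe the updated C entry points. A minimal ctypes sketch of calling them with the new n_ctx argument; the library name and model path are placeholders and error handling is omitted:

import ctypes

lib = ctypes.CDLL("libllmodel.so")  # placeholder: platform-specific library name

lib.llmodel_model_create2.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.POINTER(ctypes.c_char_p)]
lib.llmodel_model_create2.restype = ctypes.c_void_p
lib.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
lib.llmodel_required_mem.restype = ctypes.c_size_t
lib.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
lib.llmodel_loadModel.restype = ctypes.c_bool

err = ctypes.c_char_p()
path = b"/path/to/model.gguf"  # placeholder
model = lib.llmodel_model_create2(path, b"auto", ctypes.byref(err))

n_ctx = 4096
print("estimated bytes:", lib.llmodel_required_mem(model, path, n_ctx))
print("loaded:", lib.llmodel_loadModel(model, path, n_ctx))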
@@ -188,7 +188,7 @@ public class LLModel : ILLModel
/// <returns>true if the model was loaded successfully, false otherwise.</returns>
public bool Load(string modelPath)
{
return NativeMethods.llmodel_loadModel(_handle, modelPath);
return NativeMethods.llmodel_loadModel(_handle, modelPath, 2048);
}

protected void Destroy()

@@ -70,7 +70,8 @@ internal static unsafe partial class NativeMethods
[return: MarshalAs(UnmanagedType.I1)]
public static extern bool llmodel_loadModel(
[NativeTypeName("llmodel_model")] IntPtr model,
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path,
[NativeTypeName("int32_t")] int n_ctx);

[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]

@@ -39,7 +39,7 @@ public class Gpt4AllModelFactory : IGpt4AllModelFactory
var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error);
_logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
_logger.LogInformation("Model loading started");
var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath, 2048);
_logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
if (!loadedSuccessfully)
{
@@ -23,7 +23,7 @@ void* load_model(const char *fname, int n_threads) {
fprintf(stderr, "%s: error '%s'\n", __func__, new_error);
return nullptr;
}
if (!llmodel_loadModel(model, fname)) {
if (!llmodel_loadModel(model, fname, 2048)) {
llmodel_model_destroy(model);
return nullptr;
}

@@ -195,7 +195,7 @@ public class LLModel implements AutoCloseable {
if(model == null) {
throw new IllegalStateException("Could not load, gpt4all backend returned error: " + error.getValue().getString(0));
}
library.llmodel_loadModel(model, modelPathAbs);
library.llmodel_loadModel(model, modelPathAbs, 2048);

if(!library.llmodel_isModelLoaded(model)){
throw new IllegalStateException("The model " + modelName + " could not be loaded");

@@ -61,7 +61,7 @@ public interface LLModelLibrary {

Pointer llmodel_model_create2(String model_path, String build_variant, PointerByReference error);
void llmodel_model_destroy(Pointer model);
boolean llmodel_loadModel(Pointer model, String model_path);
boolean llmodel_loadModel(Pointer model, String model_path, int n_ctx);
boolean llmodel_isModelLoaded(Pointer model);
@u_int64_t long llmodel_get_state_size(Pointer model);
@u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest);
@@ -1,2 +1,2 @@
from .gpt4all import Embed4All, GPT4All # noqa
from .pyllmodel import LLModel # noqa
from .gpt4all import Embed4All as Embed4All, GPT4All as GPT4All
from .pyllmodel import LLModel as LLModel

@@ -69,6 +69,7 @@ class GPT4All:
allow_download: bool = True,
n_threads: Optional[int] = None,
device: Optional[str] = "cpu",
n_ctx: int = 2048,
verbose: bool = False,
):
"""

@@ -90,15 +91,16 @@ class GPT4All:
Default is "cpu".

Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model.
n_ctx: Maximum size of context window
verbose: If True, print debug messages.
"""
self.model_type = model_type
self.model = pyllmodel.LLModel()
# Retrieve model and download if allowed
self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
if device is not None:
if device != "cpu":
self.model.init_gpu(model_path=self.config["path"], device=device)
self.model.load_model(self.config["path"])
if device is not None and device != "cpu":
self.model.init_gpu(model_path=self.config["path"], device=device, n_ctx=n_ctx)
self.model.load_model(self.config["path"], n_ctx)
# Set n_threads
if n_threads is not None:
self.model.set_thread_count(n_threads)
@@ -1,3 +1,5 @@
from __future__ import annotations

import ctypes
import importlib.resources
import logging

@@ -7,6 +9,7 @@ import re
import subprocess
import sys
import threading
from enum import Enum
from queue import Queue
from typing import Callable, Iterable, List

@@ -72,9 +75,9 @@ llmodel.llmodel_model_create2.restype = ctypes.c_void_p
llmodel.llmodel_model_destroy.argtypes = [ctypes.c_void_p]
llmodel.llmodel_model_destroy.restype = None

llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
llmodel.llmodel_loadModel.restype = ctypes.c_bool
llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
llmodel.llmodel_required_mem.restype = ctypes.c_size_t
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool

@@ -114,7 +117,7 @@ llmodel.llmodel_set_implementation_search_path.restype = None
llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
llmodel.llmodel_threadCount.restype = ctypes.c_int32

llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode("utf-8"))
llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode())

llmodel.llmodel_available_gpu_devices.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(ctypes.c_int32)]
llmodel.llmodel_available_gpu_devices.restype = ctypes.POINTER(LLModelGPUDevice)
@@ -143,10 +146,16 @@ def _create_model(model_path: bytes) -> ctypes.c_void_p:
err = ctypes.c_char_p()
model = llmodel.llmodel_model_create2(model_path, b"auto", ctypes.byref(err))
if model is None:
raise ValueError(f"Unable to instantiate model: {err.decode()}")
s = err.value
raise ValueError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
return model


# Symbol to terminate from generator
class Sentinel(Enum):
TERMINATING_SYMBOL = 0


class LLModel:
"""
Base class and universal wrapper for GPT4All language models
@@ -173,12 +182,16 @@ class LLModel:
if self.model is not None:
self.llmodel_lib.llmodel_model_destroy(self.model)

def memory_needed(self, model_path: str) -> int:
model_path_enc = model_path.encode("utf-8")
self.model = _create_model(model_path_enc)
return llmodel.llmodel_required_mem(self.model, model_path_enc)
def memory_needed(self, model_path: str, n_ctx: int) -> int:
self.model = None
return self._memory_needed(model_path, n_ctx)

def list_gpu(self, model_path: str) -> list:
def _memory_needed(self, model_path: str, n_ctx: int) -> int:
if self.model is None:
self.model = _create_model(model_path.encode())
return llmodel.llmodel_required_mem(self.model, model_path.encode(), n_ctx)

def list_gpu(self, model_path: str, n_ctx: int) -> list[LLModelGPUDevice]:
"""
Lists available GPU devices that satisfy the model's memory requirements.

@@ -186,45 +199,41 @@ class LLModel:
----------
model_path : str
Path to the model.
n_ctx : int
Maximum size of context window

Returns
-------
list
A list of LLModelGPUDevice structures representing available GPU devices.
"""
if self.model is not None:
model_path_enc = model_path.encode("utf-8")
mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc)
else:
mem_required = self.memory_needed(model_path)
mem_required = self._memory_needed(model_path, n_ctx)
return self._list_gpu(mem_required)

def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]:
num_devices = ctypes.c_int32(0)
devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices))
if not devices_ptr:
raise ValueError("Unable to retrieve available GPU devices")
devices = [devices_ptr[i] for i in range(num_devices.value)]
return devices
return devices_ptr[:num_devices.value]

def init_gpu(self, model_path: str, device: str):
if self.model is not None:
model_path_enc = model_path.encode("utf-8")
mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc)
else:
mem_required = self.memory_needed(model_path)
device_enc = device.encode("utf-8")
success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device_enc)
def init_gpu(self, model_path: str, device: str, n_ctx: int):
mem_required = self._memory_needed(model_path, n_ctx)

success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode())
if not success:
# Retrieve all GPUs without considering memory requirements.
num_devices = ctypes.c_int32(0)
all_devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, 0, ctypes.byref(num_devices))
if not all_devices_ptr:
raise ValueError("Unable to retrieve list of all GPU devices")
all_gpus = [all_devices_ptr[i].name.decode('utf-8') for i in range(num_devices.value)]
all_gpus = [d.name.decode() for d in all_devices_ptr[:num_devices.value]]

# Retrieve GPUs that meet the memory requirements using list_gpu
available_gpus = [device.name.decode('utf-8') for device in self.list_gpu(model_path)]
available_gpus = [device.name.decode() for device in self._list_gpu(mem_required)]

# Identify GPUs that are unavailable due to insufficient memory or features
unavailable_gpus = set(all_gpus) - set(available_gpus)
unavailable_gpus = set(all_gpus).difference(available_gpus)

# Formulate the error message
error_msg = "Unable to initialize model on GPU: '{}'.".format(device)

@@ -232,7 +241,7 @@ class LLModel:
error_msg += "\nUnavailable GPUs due to insufficient memory or features: {}.".format(unavailable_gpus)
raise ValueError(error_msg)

def load_model(self, model_path: str) -> bool:
def load_model(self, model_path: str, n_ctx: int) -> bool:
"""
Load model from a file.
@@ -240,15 +249,16 @@ class LLModel:
----------
model_path : str
Model filepath
n_ctx : int
Maximum size of context window

Returns
-------
True if model loaded successfully, False otherwise
"""
model_path_enc = model_path.encode("utf-8")
self.model = _create_model(model_path_enc)
self.model = _create_model(model_path.encode())

llmodel.llmodel_loadModel(self.model, model_path_enc)
llmodel.llmodel_loadModel(self.model, model_path.encode(), n_ctx)

filename = os.path.basename(model_path)
self.model_name = os.path.splitext(filename)[0]

@@ -312,7 +322,7 @@ class LLModel:
raise ValueError("Text must not be None or empty")

embedding_size = ctypes.c_size_t()
c_text = ctypes.c_char_p(text.encode('utf-8'))
c_text = ctypes.c_char_p(text.encode())
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
embedding_array = [embedding_ptr[i] for i in range(embedding_size.value)]
llmodel.llmodel_free_embedding(embedding_ptr)

@@ -357,7 +367,7 @@ class LLModel:
prompt,
)

prompt_bytes = prompt.encode("utf-8")
prompt_bytes = prompt.encode()
prompt_ptr = ctypes.c_char_p(prompt_bytes)

self._set_context(

@@ -385,10 +395,7 @@ class LLModel:
def prompt_model_streaming(
self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
) -> Iterable[str]:
# Symbol to terminate from generator
TERMINATING_SYMBOL = object()

output_queue: Queue = Queue()
output_queue: Queue[str | Sentinel] = Queue()

# Put response tokens into an output queue
def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:

@@ -405,7 +412,7 @@ class LLModel:

def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
self.prompt_model(prompt, callback, **kwargs)
output_queue.put(TERMINATING_SYMBOL)
output_queue.put(Sentinel.TERMINATING_SYMBOL)

# Kick off llmodel_prompt in separate thread so we can return generator
# immediately

@@ -419,7 +426,7 @@ class LLModel:
# Generator
while True:
response = output_queue.get()
if response is TERMINATING_SYMBOL:
if isinstance(response, Sentinel):
break
yield response

@@ -442,7 +449,7 @@ class LLModel:
else:
# beginning of a byte sequence
if len(self.buffer) > 0:
decoded.append(self.buffer.decode('utf-8', 'replace'))
decoded.append(self.buffer.decode(errors='replace'))

self.buffer.clear()

@@ -451,7 +458,7 @@ class LLModel:

if self.buff_expecting_cont_bytes <= 0:
# received the whole sequence or an out of place continuation byte
decoded.append(self.buffer.decode('utf-8', 'replace'))
decoded.append(self.buffer.decode(errors='replace'))

self.buffer.clear()
self.buff_expecting_cont_bytes = 0
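Taken together, the pyllmodel changes above mean the low-level wrapper now requires an explicit context size. A minimal sketch (the model path is a placeholder):

from gpt4all.pyllmodel import LLModel

llm = LLModel()
# memory_needed() and load_model() now take n_ctx, so the KV-cache estimate
# matches the context the model will actually be loaded with.
print("bytes needed:", llm.memory_needed("/path/to/model.gguf", n_ctx=4096))
llm.load_model("/path/to/model.gguf", 4096)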
@@ -117,7 +117,7 @@ def test_empty_embedding():
def test_download_model(tmp_path: Path):
import gpt4all.gpt4all
old_default_dir = gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY
gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = tmp_path # temporary pytest directory to ensure a download happens
gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = str(tmp_path) # temporary pytest directory to ensure a download happens
try:
model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin')
model_path = tmp_path / model.config['filename']
@@ -28,7 +28,7 @@ Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
{
auto env = info.Env();
return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str()) ));
return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str(), 2048) ));

}
Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info)

@@ -161,7 +161,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
}
}

auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str());
auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), 2048);
if(!success) {
Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
return;
@@ -20,15 +20,17 @@ ChatGPT::ChatGPT()
{
}

size_t ChatGPT::requiredMem(const std::string &modelPath)
size_t ChatGPT::requiredMem(const std::string &modelPath, int n_ctx)
{
Q_UNUSED(modelPath);
Q_UNUSED(n_ctx);
return 0;
}

bool ChatGPT::loadModel(const std::string &modelPath)
bool ChatGPT::loadModel(const std::string &modelPath, int n_ctx)
{
Q_UNUSED(modelPath);
Q_UNUSED(n_ctx);
return true;
}

@@ -48,9 +48,9 @@ public:

bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
@@ -5,7 +5,7 @@
#include <QDataStream>

#define CHAT_FORMAT_MAGIC 0xF5D553CC
#define CHAT_FORMAT_VERSION 6
#define CHAT_FORMAT_VERSION 7

class MyChatListModel: public ChatListModel { };
Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance)
@@ -248,14 +248,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
m_llModelInfo.model = model;
} else {

// TODO: make configurable in UI
auto n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
m_ctx.n_ctx = n_ctx;

std::string buildVariant = "auto";
#if defined(Q_OS_MAC) && defined(__arm__)
if (m_forceMetal)
m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "metal");
else
m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "auto");
#else
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
buildVariant = "metal";
#endif
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);

if (m_llModelInfo.model) {
// Update the settings that a model is being loaded and update the device list

@@ -267,7 +269,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (requestedDevice == "CPU") {
emit reportFallbackReason(""); // fallback not applicable
} else {
const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString());
const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx);
std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
LLModel::GPUDevice *device = nullptr;

@@ -296,14 +298,14 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
// Report which device we're actually using
emit reportDevice(actualDevice);

bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
if (actualDevice == "CPU") {
// we asked llama.cpp to use the CPU
} else if (!success) {
// llama_init_from_file returned nullptr
emit reportDevice("CPU");
emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
success = m_llModelInfo.model->loadModel(filePath.toStdString());
success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
} else if (!m_llModelInfo.model->usingGPUDevice()) {
// ggml_vk_init was not called in llama.cpp
// We might have had to fallback to CPU after load if the model is not possible to accelerate

@@ -763,6 +765,8 @@ bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
return false;
}

// this function serialized the cached model state to disk.
// we want to also serialize n_ctx, and read it at load time.
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
{
if (version > 1) {

@@ -790,6 +794,9 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
stream << responseLogits;
}
stream << m_ctx.n_past;
if (version >= 6) {
stream << m_ctx.n_ctx;
}
stream << quint64(m_ctx.logits.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float));
stream << quint64(m_ctx.tokens.size());

@@ -839,6 +846,12 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
stream >> n_past;
if (!discardKV) m_ctx.n_past = n_past;

if (version >= 6) {
uint32_t n_ctx;
stream >> n_ctx;
if (!discardKV) m_ctx.n_ctx = n_ctx;
}

quint64 logitsSize;
stream >> logitsSize;
if (!discardKV) {
@@ -29,8 +29,8 @@ bool EmbeddingLLM::loadModel()
return false;
}

m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
bool success = m_model->loadModel(filePath.toStdString());
m_model = LLModel::Implementation::construct(filePath.toStdString());
bool success = m_model->loadModel(filePath.toStdString(), 2048);
if (!success) {
qWarning() << "WARNING: Could not load sbert";
delete m_model;
@@ -97,6 +97,17 @@ void ModelInfo::setPromptBatchSize(int s)
m_promptBatchSize = s;
}

int ModelInfo::contextLength() const
{
return MySettings::globalInstance()->modelContextLength(*this);
}

void ModelInfo::setContextLength(int l)
{
if (isClone) MySettings::globalInstance()->setModelContextLength(*this, l, isClone /*force*/);
m_contextLength = l;
}

double ModelInfo::repeatPenalty() const
{
return MySettings::globalInstance()->modelRepeatPenalty(*this);

@@ -274,6 +285,7 @@ ModelList::ModelList()
connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);;
connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings);

@@ -525,6 +537,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->maxLength();
case PromptBatchSizeRole:
return info->promptBatchSize();
case ContextLengthRole:
return info->contextLength();
case RepeatPenaltyRole:
return info->repeatPenalty();
case RepeatPenaltyTokensRole:

@@ -740,6 +754,7 @@ QString ModelList::clone(const ModelInfo &model)
updateData(id, ModelList::TopKRole, model.topK());
updateData(id, ModelList::MaxLengthRole, model.maxLength());
updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize());
updateData(id, ModelList::ContextLengthRole, model.contextLength());
updateData(id, ModelList::RepeatPenaltyRole, model.repeatPenalty());
updateData(id, ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens());
updateData(id, ModelList::PromptTemplateRole, model.promptTemplate());

@@ -1106,6 +1121,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
updateData(id, ModelList::MaxLengthRole, obj["maxLength"].toInt());
if (obj.contains("promptBatchSize"))
updateData(id, ModelList::PromptBatchSizeRole, obj["promptBatchSize"].toInt());
if (obj.contains("contextLength"))
updateData(id, ModelList::ContextLengthRole, obj["contextLength"].toInt());
if (obj.contains("repeatPenalty"))
updateData(id, ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble());
if (obj.contains("repeatPenaltyTokens"))

@@ -1198,6 +1215,8 @@ void ModelList::updateModelsFromSettings()
const int maxLength = settings.value(g + "/maxLength").toInt();
Q_ASSERT(settings.contains(g + "/promptBatchSize"));
const int promptBatchSize = settings.value(g + "/promptBatchSize").toInt();
Q_ASSERT(settings.contains(g + "/contextLength"));
const int contextLength = settings.value(g + "/contextLength").toInt();
Q_ASSERT(settings.contains(g + "/repeatPenalty"));
const double repeatPenalty = settings.value(g + "/repeatPenalty").toDouble();
Q_ASSERT(settings.contains(g + "/repeatPenaltyTokens"));

@@ -1216,6 +1235,7 @@ void ModelList::updateModelsFromSettings()
updateData(id, ModelList::TopKRole, topK);
updateData(id, ModelList::MaxLengthRole, maxLength);
updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize);
updateData(id, ModelList::ContextLengthRole, contextLength);
updateData(id, ModelList::RepeatPenaltyRole, repeatPenalty);
updateData(id, ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens);
updateData(id, ModelList::PromptTemplateRole, promptTemplate);
@@ -39,6 +39,7 @@ struct ModelInfo {
Q_PROPERTY(int topK READ topK WRITE setTopK)
Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength)
Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize)
Q_PROPERTY(int contextLength READ contextLength WRITE setContextLength)
Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate)

@@ -94,6 +95,8 @@ public:
void setMaxLength(int l);
int promptBatchSize() const;
void setPromptBatchSize(int s);
int contextLength() const;
void setContextLength(int l);
double repeatPenalty() const;
void setRepeatPenalty(double p);
int repeatPenaltyTokens() const;

@@ -112,6 +115,7 @@ private:
int m_topK = 40;
int m_maxLength = 4096;
int m_promptBatchSize = 128;
int m_contextLength = 2048;
double m_repeatPenalty = 1.18;
int m_repeatPenaltyTokens = 64;
QString m_promptTemplate = "### Human:\n%1\n### Assistant:\n";

@@ -227,6 +231,7 @@ public:
TopKRole,
MaxLengthRole,
PromptBatchSizeRole,
ContextLengthRole,
RepeatPenaltyRole,
RepeatPenaltyTokensRole,
PromptTemplateRole,

@@ -269,6 +274,7 @@ public:
roles[TopKRole] = "topK";
roles[MaxLengthRole] = "maxLength";
roles[PromptBatchSizeRole] = "promptBatchSize";
roles[ContextLengthRole] = "contextLength";
roles[RepeatPenaltyRole] = "repeatPenalty";
roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens";
roles[PromptTemplateRole] = "promptTemplate";
@@ -90,6 +90,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &model)
setModelTopK(model, model.m_topK);;
setModelMaxLength(model, model.m_maxLength);
setModelPromptBatchSize(model, model.m_promptBatchSize);
setModelContextLength(model, model.m_contextLength);
setModelRepeatPenalty(model, model.m_repeatPenalty);
setModelRepeatPenaltyTokens(model, model.m_repeatPenaltyTokens);
setModelPromptTemplate(model, model.m_promptTemplate);

@@ -280,6 +281,28 @@ void MySettings::setModelPromptBatchSize(const ModelInfo &m, int s, bool force)
emit promptBatchSizeChanged(m);
}

int MySettings::modelContextLength(const ModelInfo &m) const
{
QSettings setting;
setting.sync();
return setting.value(QString("model-%1").arg(m.id()) + "/contextLength", m.m_contextLength).toInt();
}

void MySettings::setModelContextLength(const ModelInfo &m, int l, bool force)
{
if (modelContextLength(m) == l && !force)
return;

QSettings setting;
if (m.m_contextLength == l && !m.isClone)
setting.remove(QString("model-%1").arg(m.id()) + "/contextLength");
else
setting.setValue(QString("model-%1").arg(m.id()) + "/contextLength", l);
setting.sync();
if (!force)
emit contextLengthChanged(m);
}

double MySettings::modelRepeatPenalty(const ModelInfo &m) const
{
QSettings setting;
@@ -1,6 +1,8 @@
#ifndef MYSETTINGS_H
#define MYSETTINGS_H

#include <cstdint>

#include <QObject>
#include <QMutex>

@@ -59,6 +61,8 @@ public:
Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &m, const QString &t, bool force = false);
QString modelSystemPrompt(const ModelInfo &m) const;
Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &m, const QString &p, bool force = false);
int modelContextLength(const ModelInfo &m) const;
Q_INVOKABLE void setModelContextLength(const ModelInfo &m, int s, bool force = false);

// Application settings
int threadCount() const;

@@ -79,6 +83,8 @@ public:
void setForceMetal(bool b);
QString device() const;
void setDevice(const QString &u);
int32_t contextLength() const;
void setContextLength(int32_t value);

// Release/Download settings
QString lastVersionStarted() const;

@@ -114,6 +120,7 @@ Q_SIGNALS:
void topKChanged(const ModelInfo &model);
void maxLengthChanged(const ModelInfo &model);
void promptBatchSizeChanged(const ModelInfo &model);
void contextLengthChanged(const ModelInfo &model);
void repeatPenaltyChanged(const ModelInfo &model);
void repeatPenaltyTokensChanged(const ModelInfo &model);
void promptTemplateChanged(const ModelInfo &model);
@@ -349,13 +349,61 @@ MySettingsTab {
rowSpacing: 10
columnSpacing: 10

Label {
id: contextLengthLabel
visible: !root.currentModelInfo.isChatGPT
text: qsTr("Context Length:")
font.pixelSize: theme.fontSizeLarge
color: theme.textColor
Layout.row: 0
Layout.column: 0
}
MyTextField {
id: contextLengthField
visible: !root.currentModelInfo.isChatGPT
text: root.currentModelInfo.contextLength
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Maximum combined prompt/response tokens before information is lost.\nUsing more context than the model was trained on will yield poor results.\nNOTE: Does not take effect until you RESTART GPT4All or SWITCH MODELS.")
ToolTip.visible: hovered
Layout.row: 0
Layout.column: 1
validator: IntValidator {
bottom: 1
}
Connections {
target: MySettings
function onContextLengthChanged() {
contextLengthField.text = root.currentModelInfo.contextLength;
}
}
Connections {
target: root
function onCurrentModelInfoChanged() {
contextLengthField.text = root.currentModelInfo.contextLength;
}
}
onEditingFinished: {
var val = parseInt(text)
if (!isNaN(val)) {
MySettings.setModelContextLength(root.currentModelInfo, val)
focus = false
} else {
text = root.currentModelInfo.contextLength
}
}
Accessible.role: Accessible.EditableText
Accessible.name: contextLengthLabel.text
Accessible.description: ToolTip.text
}

Label {
id: tempLabel
text: qsTr("Temperature:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 0
Layout.column: 0
Layout.row: 1
Layout.column: 2
}

MyTextField {

@@ -365,8 +413,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Temperature increases the chances of choosing less likely tokens.\nNOTE: Higher temperature gives more creative but less predictable outputs.")
ToolTip.visible: hovered
Layout.row: 0
Layout.column: 1
Layout.row: 1
Layout.column: 3
validator: DoubleValidator {
locale: "C"
}
@@ -400,8 +448,8 @@ MySettingsTab {
text: qsTr("Top P:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 0
Layout.column: 2
Layout.row: 2
Layout.column: 0
}
MyTextField {
id: topPField

@@ -410,8 +458,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Only the most likely tokens up to a total probability of top_p can be chosen.\nNOTE: Prevents choosing highly unlikely tokens, aka Nucleus Sampling")
ToolTip.visible: hovered
Layout.row: 0
Layout.column: 3
Layout.row: 2
Layout.column: 1
validator: DoubleValidator {
locale: "C"
}

@@ -446,8 +494,8 @@ MySettingsTab {
text: qsTr("Top K:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 1
Layout.column: 0
Layout.row: 2
Layout.column: 2
}
MyTextField {
id: topKField

@@ -457,8 +505,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Only the top K most likely tokens will be chosen from")
ToolTip.visible: hovered
Layout.row: 1
Layout.column: 1
Layout.row: 2
Layout.column: 3
validator: IntValidator {
bottom: 1
}

@@ -493,7 +541,7 @@ MySettingsTab {
text: qsTr("Max Length:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 1
Layout.row: 0
Layout.column: 2
}
MyTextField {

@@ -504,7 +552,7 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Maximum length of response in tokens")
ToolTip.visible: hovered
Layout.row: 1
Layout.row: 0
Layout.column: 3
validator: IntValidator {
bottom: 1
@@ -541,7 +589,7 @@ MySettingsTab {
text: qsTr("Prompt Batch Size:")
font.pixelSize: theme.fontSizeLarge
color: theme.textColor
Layout.row: 2
Layout.row: 1
Layout.column: 0
}
MyTextField {

@@ -552,7 +600,7 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Amount of prompt tokens to process at once.\nNOTE: Higher values can speed up reading prompts but will use more RAM")
ToolTip.visible: hovered
Layout.row: 2
Layout.row: 1
Layout.column: 1
validator: IntValidator {
bottom: 1

@@ -588,8 +636,8 @@ MySettingsTab {
text: qsTr("Repeat Penalty:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 2
Layout.column: 2
Layout.row: 3
Layout.column: 0
}
MyTextField {
id: repeatPenaltyField

@@ -599,8 +647,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Amount to penalize repetitiveness of the output")
ToolTip.visible: hovered
Layout.row: 2
Layout.column: 3
Layout.row: 3
Layout.column: 1
validator: DoubleValidator {
locale: "C"
}

@@ -636,7 +684,7 @@ MySettingsTab {
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 3
Layout.column: 0
Layout.column: 2
}
MyTextField {
id: repeatPenaltyTokenField

@@ -647,7 +695,7 @@ MySettingsTab {
ToolTip.text: qsTr("How far back in output to apply repeat penalty")
ToolTip.visible: hovered
Layout.row: 3
Layout.column: 1
Layout.column: 3
validator: IntValidator {
bottom: 1
}