Mirror of https://github.com/nomic-ai/gpt4all (synced 2024-11-08 07:10:32 +00:00)
chat: faster KV shift, continue generating, fix stop sequences (#2781)

* Don't stop generating at end of context
* Use llama_kv_cache ops to shift context
* Fix and improve reverse prompt detection
* Replace prompt recalc callback with a flag to disallow context shift
parent 90de2d32f8
commit be66ec8ab5
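For orientation before the diff: the core of the change is that LLamaModel gains a shiftContext() method that frees room in llama.cpp's KV cache instead of re-evaluating the whole prompt, and the recalculate callback threaded through every prompt() call is replaced by an allowContextShift flag. Below is a minimal, standalone sketch of that KV-cache shift — not the committed code, only assuming the llama_kv_cache_seq_rm / llama_kv_cache_seq_add / llama_n_ctx calls that also appear in the diff itself; the hypothetical shiftContextSketch name and parameters exist only for illustration.

#include <llama.h>
#include <algorithm>
#include <vector>

// Sketch: drop the oldest tokens after the first n_keep (e.g. BOS), then slide the
// remaining KV-cache entries back so generation can continue without a full re-decode.
static void shiftContextSketch(llama_context *ctx, std::vector<llama_token> &tokens,
                               int &n_past, int n_keep, float contextErase)
{
    int n_ctx     = llama_n_ctx(ctx);
    int n_discard = std::min(n_past - n_keep, int(n_ctx * contextErase));
    if (n_discard <= 0)
        return;

    // remove the discarded range from sequence 0, then shift what follows back by n_discard
    llama_kv_cache_seq_rm (ctx, 0, n_keep,             n_keep + n_discard);
    llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_past, -n_discard);

    // keep the token bookkeeping in sync with the cache
    tokens.erase(tokens.begin() + n_keep, tokens.begin() + n_keep + n_discard);
    n_past = int(tokens.size());
}

The committed version (LLamaModel::shiftContext in the diff below) does the same thing directly on the model's PromptContext.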
@@ -33,7 +33,7 @@ set(LLMODEL_VERSION_PATCH 0)
set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
set(BUILD_SHARED_LIBS ON)
@@ -531,10 +531,7 @@ size_t LLamaModel::restoreState(const uint8_t *src)
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
{
    bool atStart = m_tokenize_last_token == -1;
    bool insertSpace = atStart || (
        llama_token_get_attr(d_ptr->model, m_tokenize_last_token)
            & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)
    );
    bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
    std::vector<LLModel::Token> fres(str.length() + 4);
    int32_t fres_len = llama_tokenize_gpt4all(
        d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
@@ -546,6 +543,12 @@ std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::
    return fres;
}

bool LLamaModel::isSpecialToken(Token id) const
{
    return llama_token_get_attr(d_ptr->model, id)
        & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN);
}

std::string LLamaModel::tokenToString(Token id) const
{
    std::vector<char> result(8, 0);
@@ -595,6 +598,30 @@ bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &toke
    return res == 0;
}

void LLamaModel::shiftContext(PromptContext &promptCtx)
{
    // infinite text generation via context shifting

    // erase up to n_ctx*contextErase tokens
    int n_keep = shouldAddBOS();
    int n_past = promptCtx.n_past;
    int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));

    assert(n_discard > 0);
    if (n_discard <= 0)
        return;

    std::cerr << "Llama: context full, swapping: n_past = " << n_past << ", n_keep = " << n_keep
              << ", n_discard = " << n_discard << "\n";

    // erase the first n_discard tokens from the context
    llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
    llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);

    promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
    promptCtx.n_past = promptCtx.tokens.size();
}

int32_t LLamaModel::contextLength() const
{
    return llama_n_ctx(d_ptr->ctx);
@@ -6,7 +6,6 @@

#include "llmodel.h"

#include <functional>
#include <memory>
#include <string>
#include <vector>
@@ -54,9 +53,11 @@ private:

protected:
    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override;
    bool isSpecialToken(Token id) const override;
    std::string tokenToString(Token id) const override;
    Token sampleToken(PromptContext &ctx) const override;
    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
    void shiftContext(PromptContext &promptCtx) override;
    int32_t contextLength() const override;
    const std::vector<Token> &endTokens() const override;
    bool shouldAddBOS() const override;
@@ -134,7 +134,7 @@ public:
        int32_t n_batch = 9;
        float repeat_penalty = 1.10f;
        int32_t repeat_last_n = 64; // last n tokens to penalize
        float contextErase = 0.75f; // percent of context to erase if we exceed the context window
        float contextErase = 0.5f; // percent of context to erase if we exceed the context window
    };

    using ProgressCallback = std::function<bool(float progress)>;
@@ -159,7 +159,7 @@ public:
        const std::string &promptTemplate,
        std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        bool allowContextShift,
        PromptContext &ctx,
        bool special = false,
        std::string *fakeReply = nullptr);
@@ -213,9 +213,11 @@ protected:
    // These are pure virtual because subclasses need to implement as the default implementation of
    // 'prompt' above calls these functions
    virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
    virtual bool isSpecialToken(Token id) const = 0;
    virtual std::string tokenToString(Token id) const = 0;
    virtual Token sampleToken(PromptContext &ctx) const = 0;
    virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
    virtual void shiftContext(PromptContext &promptCtx) = 0;
    virtual int32_t contextLength() const = 0;
    virtual const std::vector<Token> &endTokens() const = 0;
    virtual bool shouldAddBOS() const = 0;
@@ -232,10 +234,6 @@ protected:
        return -1;
    }

    // This is a helper function called from the default implementation of 'prompt' but it can be
    // shared by all base classes so it isn't virtual
    void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);

    const Implementation *m_implementation = nullptr;

    ProgressCallback m_progressCallback;
@@ -249,11 +247,11 @@ protected:

    bool decodePrompt(std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        bool allowContextShift,
        PromptContext &promptCtx,
        std::vector<Token> embd_inp);
    void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        bool allowContextShift,
        PromptContext &promptCtx);

    Token m_tokenize_last_token = -1; // not serialized
@@ -106,7 +106,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
        const char *prompt_template,
        llmodel_prompt_callback prompt_callback,
        llmodel_response_callback response_callback,
        llmodel_recalculate_callback recalculate_callback,
        bool allow_context_shift,
        llmodel_prompt_context *ctx,
        bool special,
        const char *fake_reply)
@@ -135,7 +135,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
    auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;

    // Call the C++ prompt method
    wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, recalculate_callback,
    wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
        wrapper->promptContext, special, fake_reply_p);

    // Update the C context by giving access to the wrappers raw pointers to std::vector data
@@ -74,13 +74,6 @@ typedef bool (*llmodel_prompt_callback)(int32_t token_id);
 */
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);

/**
 * Callback type for recalculation of context.
 * @param whether the model is recalculating the context.
 * @return a bool indicating whether the model should keep generating.
 */
typedef bool (*llmodel_recalculate_callback)(bool is_recalculating);

/**
 * Embedding cancellation callback for use with llmodel_embed.
 * @param batch_sizes The number of tokens in each batch that will be embedded.
@@ -175,7 +168,7 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
 * @param prompt_template A string representing the input prompt template.
 * @param prompt_callback A callback function for handling the processing of prompt.
 * @param response_callback A callback function for handling the generated response.
 * @param recalculate_callback A callback function for handling recalculation requests.
 * @param allow_context_shift Whether to allow shifting of context to make room for more input.
 * @param special True if special tokens in the prompt should be processed, false otherwise.
 * @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
 * @param ctx A pointer to the llmodel_prompt_context structure.
@@ -184,7 +177,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
        const char *prompt_template,
        llmodel_prompt_callback prompt_callback,
        llmodel_response_callback response_callback,
        llmodel_recalculate_callback recalculate_callback,
        bool allow_context_shift,
        llmodel_prompt_context *ctx,
        bool special,
        const char *fake_reply);
@@ -11,42 +11,9 @@
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>

// TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
// FIXME(jared): if recalculate returns false, we leave n_past<tokens.size() and do not tell the caller to stop
// FIXME(jared): if we get here during chat name or follow-up generation, bad things will happen when we try to restore
// the old prompt context afterwards
void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
{
    int n_keep = shouldAddBOS();
    const int32_t n_discard = (promptCtx.n_ctx - n_keep) * promptCtx.contextErase;

    // Erase the first percentage of context from the tokens
    std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
    promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);

    size_t i = n_keep;
    promptCtx.n_past = n_keep;
    while (i < promptCtx.tokens.size()) {
        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
        std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
        if (!evalTokens(promptCtx, batch)) {
            std::cerr << "LLModel ERROR: Failed to process prompt\n";
            goto stop_generating;
        }
        promptCtx.n_past += batch.size();
        if (!recalculate(true))
            goto stop_generating;
        i = batch_end;
    }
    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));

stop_generating:
    recalculate(false);
}
namespace ranges = std::ranges;

static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err)
{
@@ -75,7 +42,7 @@ void LLModel::prompt(const std::string &prompt,
        const std::string &promptTemplate,
        std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        bool allowContextShift,
        PromptContext &promptCtx,
        bool special,
        std::string *fakeReply)
@@ -92,12 +59,21 @@ void LLModel::prompt(const std::string &prompt,
        return;
    }

    // make sure token cache matches decode offset
    if (promptCtx.tokens.size() < promptCtx.n_past) {
    // sanity checks
    if (promptCtx.n_past > contextLength()) {
        std::ostringstream ss;
        ss << "expected n_past to be at most " << promptCtx.tokens.size() << ", got " << promptCtx.n_past;
        ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
        throw std::out_of_range(ss.str());
    }
    if (promptCtx.n_past > promptCtx.tokens.size()) {
        std::ostringstream ss;
        ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << promptCtx.tokens.size();
        throw std::out_of_range(ss.str());
    }

    promptCtx.n_ctx = contextLength();
    promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);

    if (promptCtx.n_past < promptCtx.tokens.size())
        promptCtx.tokens.resize(promptCtx.n_past);
    m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
@@ -149,15 +125,15 @@ void LLModel::prompt(const std::string &prompt,
    promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it

    // decode the user prompt
    if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp))
    if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
        return; // error

    // decode the assistant's reply, either generated or spoofed
    if (fakeReply == nullptr) {
        generateResponse(responseCallback, recalculateCallback, promptCtx);
        generateResponse(responseCallback, allowContextShift, promptCtx);
    } else {
        embd_inp = tokenize(promptCtx, *fakeReply, false);
        if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp))
        if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
            return; // error
    }

@@ -172,19 +148,16 @@ void LLModel::prompt(const std::string &prompt,
    }
    if (!asstSuffix.empty()) {
        embd_inp = tokenize(promptCtx, asstSuffix, true);
        decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
        decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
    }
}

// returns false on error
bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        bool allowContextShift,
        PromptContext &promptCtx,
        std::vector<Token> embd_inp) {
    // save the context size
    promptCtx.n_ctx = contextLength();

    if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
        responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
        std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
@@ -192,9 +165,14 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
        return false;
    }

    promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
    promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
    promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
    // FIXME(jared): There are mitigations for this situation, such as making room before
    // copying the prompt context, or restoring the KV cache when we restore the prompt
    // context.
    if (!allowContextShift && promptCtx.n_past + embd_inp.size() > promptCtx.n_ctx) {
        std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
                  << ", n_ctx=" << promptCtx.n_ctx << "\n";
        return false;
    }

    // process the prompt in batches
    size_t i = 0;
@@ -204,7 +182,8 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,

        // Check if the context has run out...
        if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
            recalculateContext(promptCtx, recalculateCallback);
            assert(allowContextShift);
            shiftContext(promptCtx);
            assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
        }

@@ -226,70 +205,170 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
    return true;
}

/*
 * If string s overlaps with the string key such that some prefix of the key is at the end
 * of the string, return the position in s where the first match starts. Otherwise, return
 * std::string::npos. Examples:
 * s = "bfo", key = "foo" -> 1
 * s = "fooa", key = "foo" -> npos
 */
static std::string::size_type stringsOverlap(const std::string &s, const std::string &key)
{
    if (s.empty() || key.empty())
        throw std::invalid_argument("arguments to stringsOverlap must not be empty");

    for (int start = std::max(0, int(s.size()) - int(key.size())); start < s.size(); start++) {
        if (s.compare(start, s.size(), key, 0, s.size() - start) == 0)
            return start;
    }
    return std::string::npos;
}

void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        bool allowContextShift,
        PromptContext &promptCtx) {
    static const char *stopSequences[] {
        "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context",
    };

    // Don't even start if there is no room
    if (!promptCtx.n_predict)
        return;
    if (!allowContextShift && promptCtx.n_past >= promptCtx.n_ctx) {
        std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << promptCtx.n_ctx
                  << "\n";
        return;
    }

    std::string cachedResponse;
    std::vector<Token> cachedTokens;
    std::unordered_set<std::string> reversePrompts
        = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
    int n_predicted = 0;

    // predict next tokens
    for (int i = 0; i < promptCtx.n_predict; i++) {
    // Predict next tokens
    for (bool stop = false; !stop;) {
        // Sample next token
        std::optional<Token> new_tok = sampleToken(promptCtx);
        std::string new_piece = tokenToString(new_tok.value());
        cachedTokens.push_back(new_tok.value());
        cachedResponse += new_piece;

        // sample next token
        auto id = sampleToken(promptCtx);
        auto accept = [this, &promptCtx, &cachedTokens, &new_tok, allowContextShift]() -> bool {
            // Shift context if out of space
            if (promptCtx.n_past >= promptCtx.n_ctx) {
                (void)allowContextShift;
                assert(allowContextShift);
                shiftContext(promptCtx);
                assert(promptCtx.n_past < promptCtx.n_ctx);
            }

        // Check if the context has run out...
        if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
            recalculateContext(promptCtx, recalculateCallback);
            assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
        }
            // Accept the token
            Token tok = std::exchange(new_tok, std::nullopt).value();
            if (!evalTokens(promptCtx, { tok })) {
                // TODO(jared): raise an exception
                std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
                return false;
            }

        if (!evalTokens(promptCtx, { id })) {
            std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
            return;
        }
            promptCtx.tokens.push_back(tok);
            promptCtx.n_past += 1;
            return true;
        };

        // display text
        // Check for EOS
        auto lengthLimit = std::string::npos;
        for (const auto token : endTokens()) {
            if (id == token) return;
        }

        const std::string str = tokenToString(id);

        // Check if the provided str is part of our reverse prompts
        bool foundPartialReversePrompt = false;
        const std::string completed = cachedResponse + std::string(str);
        if (reversePrompts.find(completed) != reversePrompts.end())
            return;

        // Check if it partially matches our reverse prompts and if so, cache
        for (const auto& s : reversePrompts) {
            if (s.compare(0, completed.size(), completed) == 0) {
                foundPartialReversePrompt = true;
                cachedResponse = completed;
                break;
            if (new_tok == token) {
                stop = true;
                lengthLimit = cachedResponse.size() - new_piece.size();
            }
        }

        // Regardless the token gets added to our cache
        cachedTokens.push_back(id);
        if (lengthLimit != std::string::npos) {
            // EOS matched
        } else if (!isSpecialToken(new_tok.value())) {
            // Check if the response contains a stop sequence
            for (const auto &p : stopSequences) {
                auto match = cachedResponse.find(p);
                if (match != std::string::npos) stop = true;
                lengthLimit = std::min(lengthLimit, match);
                if (match == 0) break;
            }

        // Continue if we have found a partial match
        if (foundPartialReversePrompt)
            continue;

        // Empty the cache
        for (auto t : cachedTokens) {
            promptCtx.tokens.push_back(t);
            promptCtx.n_past += 1;
            //TODO: Conversion to std::string can be avoided here...
            if (!responseCallback(t, std::string(tokenToString(t))))
                return;
            // Check if the response matches the start of a stop sequence
            if (lengthLimit == std::string::npos) {
                for (const auto &p : stopSequences) {
                    auto match = stringsOverlap(cachedResponse, p);
                    lengthLimit = std::min(lengthLimit, match);
                    if (match == 0) break;
                }
            }
        } else if (ranges::contains(stopSequences, new_piece)) {
            // Special tokens must exactly match a stop sequence
            stop = true;
            lengthLimit = cachedResponse.size() - new_piece.size();
        }

        // Optionally stop if the context will run out
        if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= promptCtx.n_ctx) {
            std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
                      << promptCtx.n_ctx << "\n";
            stop = true;
        }

        // Empty the cache, up to the length limit
        std::string::size_type responseLength = 0;
        while (!cachedTokens.empty()) {
            Token tok = cachedTokens.front();
            std::string piece = tokenToString(tok);

            // Stop if the piece (or part of it) does not fit within the length limit
            if (responseLength + (stop ? 1 : piece.size()) > lengthLimit)
                break;

            // Remove token from cache
            assert(cachedResponse.starts_with(piece));
            cachedTokens.erase(cachedTokens.begin(), cachedTokens.begin() + 1);
            cachedResponse.erase(cachedResponse.begin(), cachedResponse.begin() + piece.size());

            // Accept the token, if needed (not cached)
            if (cachedTokens.empty() && new_tok && !accept())
                return;

            // Send the token
            if (!responseCallback(tok, piece) || ++n_predicted >= promptCtx.n_predict) {
                stop = true;
                break;
            }

            // FIXME(jared): we could avoid printing partial stop sequences if we didn't have to
            // output token IDs and could cache a partial token for the next prompt call
            responseLength += piece.size();
        }
        assert(cachedTokens.empty() == cachedResponse.empty());

        // Accept the token, if needed (in cache)
        if (new_tok) {
            assert(!cachedTokens.empty() && cachedTokens.back() == new_tok);
            if (stop) {
                cachedTokens.pop_back();
            } else if (!accept()) {
                return;
            }
        }
        cachedTokens.clear();
    }

    auto &tokens = promptCtx.tokens;
    if (tokens.size() < cachedTokens.size()) {
        /* This is theoretically possible if the longest stop sequence is greater than
         * n_ctx * contextErase tokens. */
        throw std::runtime_error("shifted too much context, can't go back");
    }

    auto discard_start = tokens.end() - cachedTokens.size();
    assert(std::equal(discard_start, tokens.end(), cachedTokens.begin()));
    tokens.erase(discard_start, tokens.end());

    promptCtx.n_past -= cachedTokens.size();
}

void LLModel::embed(
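The stop-sequence fix above hinges on the stringsOverlap helper: generated text is buffered in cachedResponse and only flushed up to the point where a complete or partial stop sequence begins. As a rough, standalone illustration of the documented behavior (re-implemented here for clarity; the committed code is the stringsOverlap function shown above, which additionally rejects empty arguments — the overlapSketch name below is hypothetical):

#include <iostream>
#include <string>

// Return the index in s where a prefix of key starts at the end of s,
// or std::string::npos if the tail of s cannot begin a match.
static std::string::size_type overlapSketch(const std::string &s, const std::string &key)
{
    for (std::string::size_type start = s.size() > key.size() ? s.size() - key.size() : 0;
         start < s.size(); start++) {
        if (s.compare(start, s.size() - start, key, 0, s.size() - start) == 0)
            return start;
    }
    return std::string::npos;
}

int main()
{
    std::cout << overlapSketch("bfo", "foo") << '\n';                          // 1: "fo" could begin "foo"
    std::cout << (overlapSketch("fooa", "foo") == std::string::npos) << '\n';  // 1 (npos): "a" cannot begin "foo"
}

With this check, the generator can hold back the tail of the response that might still turn into a stop sequence and emit it only once the next token rules the match out.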
@@ -128,7 +128,6 @@ llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool

PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)

llmodel.llmodel_prompt.argtypes = [
@@ -137,7 +136,7 @@ llmodel.llmodel_prompt.argtypes = [
    ctypes.c_char_p,
    PromptCallback,
    ResponseCallback,
    RecalculateCallback,
    ctypes.c_bool,
    ctypes.POINTER(LLModelPromptContext),
    ctypes.c_bool,
    ctypes.c_char_p,
@@ -513,7 +512,7 @@ class LLModel:
            ctypes.c_char_p(prompt_template.encode()),
            PromptCallback(self._prompt_callback),
            ResponseCallback(self._callback_decoder(callback)),
            RecalculateCallback(self._recalculate_callback),
            True,
            self.context,
            special,
            ctypes.c_char_p(),
@@ -606,8 +605,3 @@ class LLModel:
    @staticmethod
    def _prompt_callback(token_id: int) -> bool:
        return True

    # Empty recalculate callback
    @staticmethod
    def _recalculate_callback(is_recalculating: bool) -> bool:
        return is_recalculating
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.16)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

if(APPLE)
@@ -31,7 +31,6 @@ project(gpt4all VERSION ${APP_VERSION_BASE} LANGUAGES CXX C)

set(CMAKE_AUTOMOC ON)
set(CMAKE_AUTORCC ON)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

option(GPT4ALL_TRANSLATIONS OFF "Build with translations")
option(GPT4ALL_LOCALHOST OFF "Build installer for localhost repo")
@@ -62,7 +62,7 @@ void Chat::connectLLM()
    connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
    connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
@@ -252,9 +252,9 @@ void Chat::serverNewPromptResponsePair(const QString &prompt)
    m_chatModel->appendResponse("Response: ", prompt);
}

bool Chat::isRecalc() const
bool Chat::restoringFromText() const
{
    return m_llmodel->isRecalc();
    return m_llmodel->restoringFromText();
}

void Chat::unloadAndDeleteLater()
@@ -320,10 +320,10 @@ void Chat::generatedQuestionFinished(const QString &question)
    emit generatedQuestionsChanged();
}

void Chat::handleRecalculating()
void Chat::handleRestoringFromText()
{
    Network::globalInstance()->trackChatEvent("recalc_context", { {"length", m_chatModel->count()} });
    emit recalcChanged();
    emit restoringFromTextChanged();
}

void Chat::handleModelLoadingError(const QString &error)
@@ -27,7 +27,7 @@ class Chat : public QObject
    Q_PROPERTY(QString response READ response NOTIFY responseChanged)
    Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
    Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
    Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
    Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
    Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged)
    Q_PROPERTY(ResponseState responseState READ responseState NOTIFY responseStateChanged)
    Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
@@ -88,7 +88,7 @@ public:
    ResponseState responseState() const;
    ModelInfo modelInfo() const;
    void setModelInfo(const ModelInfo &modelInfo);
    bool isRecalc() const;
    bool restoringFromText() const;

    Q_INVOKABLE void unloadModel();
    Q_INVOKABLE void reloadModel();
@@ -144,7 +144,7 @@ Q_SIGNALS:
    void processSystemPromptRequested();
    void modelChangeRequested(const ModelInfo &modelInfo);
    void modelInfoChanged();
    void recalcChanged();
    void restoringFromTextChanged();
    void loadDefaultModelRequested();
    void loadModelRequested(const ModelInfo &modelInfo);
    void generateNameRequested();
@@ -167,7 +167,7 @@ private Q_SLOTS:
    void responseStopped(qint64 promptResponseMs);
    void generatedNameChanged(const QString &name);
    void generatedQuestionFinished(const QString &question);
    void handleRecalculating();
    void handleRestoringFromText();
    void handleModelLoadingError(const QString &error);
    void handleTokenSpeedChanged(const QString &tokenSpeed);
    void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
@@ -90,13 +90,13 @@ void ChatAPI::prompt(const std::string &prompt,
        const std::string &promptTemplate,
        std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        bool allowContextShift,
        PromptContext &promptCtx,
        bool special,
        std::string *fakeReply) {

    Q_UNUSED(promptCallback);
    Q_UNUSED(recalculateCallback);
    Q_UNUSED(allowContextShift);
    Q_UNUSED(special);

    if (!isModelLoaded()) {
@@ -69,7 +69,7 @@ public:
        const std::string &promptTemplate,
        std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        bool allowContextShift,
        PromptContext &ctx,
        bool special,
        std::string *fakeReply) override;
@@ -97,38 +97,57 @@ protected:
    // them as they are only called from the default implementation of 'prompt' which we override and
    // completely replace

    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override {
    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override
    {
        (void)ctx;
        (void)str;
        (void)special;
        throw std::logic_error("not implemented");
    }

    std::string tokenToString(Token id) const override {
    bool isSpecialToken(Token id) const override
    {
        (void)id;
        throw std::logic_error("not implemented");
    }

    Token sampleToken(PromptContext &ctx) const override {
    std::string tokenToString(Token id) const override
    {
        (void)id;
        throw std::logic_error("not implemented");
    }

    Token sampleToken(PromptContext &ctx) const override
    {
        (void)ctx;
        throw std::logic_error("not implemented");
    }

    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override {
    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override
    {
        (void)ctx;
        (void)tokens;
        throw std::logic_error("not implemented");
    }

    int32_t contextLength() const override {
    void shiftContext(PromptContext &promptCtx) override
    {
        (void)promptCtx;
        throw std::logic_error("not implemented");
    }

    const std::vector<Token> &endTokens() const override {
    int32_t contextLength() const override
    {
        throw std::logic_error("not implemented");
    }

    bool shouldAddBOS() const override {
    const std::vector<Token> &endTokens() const override
    {
        throw std::logic_error("not implemented");
    }

    bool shouldAddBOS() const override
    {
        throw std::logic_error("not implemented");
    }
@@ -102,7 +102,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
    : QObject{nullptr}
    , m_promptResponseTokens(0)
    , m_promptTokens(0)
    , m_isRecalc(false)
    , m_restoringFromText(false)
    , m_shouldBeLoaded(false)
    , m_forceUnloadModel(false)
    , m_markedForDeletion(false)
@@ -712,17 +712,6 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
    return !m_stopGenerating;
}

bool ChatLLM::handleRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "recalculate" << m_llmThread.objectName() << isRecalc;
#endif
    if (m_isRecalc != isRecalc) {
        m_isRecalc = isRecalc;
        emit recalcChanged();
    }
    return !m_stopGenerating;
}
bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt)
{
    if (m_restoreStateFromText) {
@@ -776,7 +765,6 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
    auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
    auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1,
        std::placeholders::_2);
    auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
    emit promptProcessing();
    m_ctx.n_predict = n_predict;
    m_ctx.top_k = top_k;
@@ -796,10 +784,12 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
    m_timer->start();
    if (!docsContext.isEmpty()) {
        auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
        m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx);
        m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc,
            /*allowContextShift*/ true, m_ctx);
        m_ctx.n_predict = old_n_predict; // now we are ready for a response
    }
    m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
    m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
        /*allowContextShift*/ true, m_ctx);
#if defined(DEBUG)
    printf("\n");
    fflush(stdout);
@@ -904,10 +894,9 @@ void ChatLLM::generateName()
    auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
    auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1);
    auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2);
    auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1);
    LLModel::PromptContext ctx = m_ctx;
    m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(),
        promptFunc, responseFunc, recalcFunc, ctx);
        promptFunc, responseFunc, /*allowContextShift*/ false, ctx);
    std::string trimmed = trim_whitespace(m_nameResponse);
    if (trimmed != m_nameResponse) {
        m_nameResponse = trimmed;
@@ -944,15 +933,6 @@ bool ChatLLM::handleNameResponse(int32_t token, const std::string &response)
    return words.size() <= 3;
}

bool ChatLLM::handleNameRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
#endif
    Q_UNUSED(isRecalc);
    return true;
}

bool ChatLLM::handleQuestionPrompt(int32_t token)
{
#if defined(DEBUG)
@@ -991,15 +971,6 @@ bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response)
    return true;
}

bool ChatLLM::handleQuestionRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
#endif
    Q_UNUSED(isRecalc);
    return true;
}

void ChatLLM::generateQuestions(qint64 elapsed)
{
    Q_ASSERT(isModelLoaded());
@@ -1019,12 +990,11 @@ void ChatLLM::generateQuestions(qint64 elapsed)
    auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
    auto promptFunc = std::bind(&ChatLLM::handleQuestionPrompt, this, std::placeholders::_1);
    auto responseFunc = std::bind(&ChatLLM::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2);
    auto recalcFunc = std::bind(&ChatLLM::handleQuestionRecalculate, this, std::placeholders::_1);
    LLModel::PromptContext ctx = m_ctx;
    QElapsedTimer totalTime;
    totalTime.start();
    m_llModelInfo.model->prompt(suggestedFollowUpPrompt,
        promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
    m_llModelInfo.model->prompt(suggestedFollowUpPrompt, promptTemplate.toStdString(), promptFunc, responseFunc,
        /*allowContextShift*/ false, ctx);
    elapsed += totalTime.elapsed();
    emit responseStopped(elapsed);
}
@@ -1039,15 +1009,6 @@ bool ChatLLM::handleSystemPrompt(int32_t token)
    return !m_stopGenerating;
}

bool ChatLLM::handleSystemRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "system recalc" << m_llmThread.objectName() << isRecalc;
#endif
    Q_UNUSED(isRecalc);
    return false;
}

bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
{
#if defined(DEBUG)
@@ -1057,15 +1018,6 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
    return !m_stopGenerating;
}

bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
#endif
    Q_UNUSED(isRecalc);
    return false;
}

// this function serialized the cached model state to disk.
// we want to also serialize n_ctx, and read it at load time.
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
@@ -1268,7 +1220,6 @@ void ChatLLM::processSystemPrompt()
    m_ctx = LLModel::PromptContext();

    auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
    auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);

    const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
    const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
@@ -1294,7 +1245,7 @@ void ChatLLM::processSystemPrompt()
#endif
    auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
    // use "%1%2" and not "%1" to avoid implicit whitespace
    m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, recalcFunc, m_ctx, true);
    m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true);
    m_ctx.n_predict = old_n_predict;
#if defined(DEBUG)
    printf("\n");
@@ -1311,14 +1262,13 @@ void ChatLLM::processRestoreStateFromText()
    if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
        return;

    m_isRecalc = true;
    emit recalcChanged();
    m_restoringFromText = true;
    emit restoringFromTextChanged();

    m_stopGenerating = false;
    m_ctx = LLModel::PromptContext();

    auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
    auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);

    const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
    const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
@@ -1351,7 +1301,7 @@ void ChatLLM::processRestoreStateFromText()
        auto responseText = response.second.toStdString();

        m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
            recalcFunc, m_ctx, false, &responseText);
            /*allowContextShift*/ true, m_ctx, false, &responseText);
    }

    if (!m_stopGenerating) {
@@ -1359,8 +1309,8 @@ void ChatLLM::processRestoreStateFromText()
        m_stateFromText.clear();
    }

    m_isRecalc = false;
    emit recalcChanged();
    m_restoringFromText = false;
    emit restoringFromTextChanged();

    m_pristineLoadedState = false;
}
@@ -93,7 +93,7 @@ class Chat;
class ChatLLM : public QObject
{
    Q_OBJECT
    Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
    Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
    Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
    Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
    Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
@@ -121,7 +121,7 @@ public:
    ModelInfo modelInfo() const;
    void setModelInfo(const ModelInfo &info);

    bool isRecalc() const { return m_isRecalc; }
    bool restoringFromText() const { return m_restoringFromText; }

    void acquireModel();
    void resetModel();
@@ -172,7 +172,7 @@ public Q_SLOTS:
    void processRestoreStateFromText();

Q_SIGNALS:
    void recalcChanged();
    void restoringFromTextChanged();
    void loadedModelInfoChanged();
    void modelLoadingPercentageChanged(float);
    void modelLoadingError(const QString &error);
@@ -201,19 +201,14 @@ protected:
        int32_t repeat_penalty_tokens);
    bool handlePrompt(int32_t token);
    bool handleResponse(int32_t token, const std::string &response);
    bool handleRecalculate(bool isRecalc);
    bool handleNamePrompt(int32_t token);
    bool handleNameResponse(int32_t token, const std::string &response);
    bool handleNameRecalculate(bool isRecalc);
    bool handleSystemPrompt(int32_t token);
    bool handleSystemResponse(int32_t token, const std::string &response);
    bool handleSystemRecalculate(bool isRecalc);
    bool handleRestoreStateFromTextPrompt(int32_t token);
    bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response);
    bool handleRestoreStateFromTextRecalculate(bool isRecalc);
    bool handleQuestionPrompt(int32_t token);
    bool handleQuestionResponse(int32_t token, const std::string &response);
    bool handleQuestionRecalculate(bool isRecalc);
    void saveState();
    void restoreState();

@@ -236,7 +231,7 @@ private:
    QThread m_llmThread;
    std::atomic<bool> m_stopGenerating;
    std::atomic<bool> m_shouldBeLoaded;
    std::atomic<bool> m_isRecalc;
    std::atomic<bool> m_restoringFromText; // status indication
    std::atomic<bool> m_forceUnloadModel;
    std::atomic<bool> m_markedForDeletion;
    bool m_isServer;
@@ -834,7 +834,7 @@ Rectangle {
                to: 360
                duration: 1000
                loops: Animation.Infinite
                running: currentResponse && (currentChat.responseInProgress || currentChat.isRecalc)
                running: currentResponse && (currentChat.responseInProgress || currentChat.restoringFromText)
            }
        }
    }
@@ -867,13 +867,13 @@ Rectangle {
            color: theme.mutedTextColor
        }
        RowLayout {
            visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.isRecalc)
            visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.restoringFromText)
            Text {
                color: theme.mutedTextColor
                font.pixelSize: theme.fontSizeLarger
                text: {
                    if (currentChat.isRecalc)
                        return qsTr("recalculating context ...");
                    if (currentChat.restoringFromText)
                        return qsTr("restoring from text ...");
                    switch (currentChat.responseState) {
                    case Chat.ResponseStopped: return qsTr("response stopped ...");
                    case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: %1 ...").arg(currentChat.collectionList.join(", "));
@@ -1861,7 +1861,7 @@ Rectangle {
        }
    }
    function sendMessage() {
        if (textInput.text === "" || currentChat.responseInProgress || currentChat.isRecalc)
        if (textInput.text === "" || currentChat.responseInProgress || currentChat.restoringFromText)
            return

        currentChat.stopGenerating()