#include "chatllm.h"
#include "chat.h"
#include "chatgpt.h"
#include "modellist.h"
#include "network.h"
#include "mysettings.h"
#include "../gpt4all-backend/llmodel.h"

//#define DEBUG
//#define DEBUG_MODEL_LOADING

#define GPTJ_INTERNAL_STATE_VERSION 0
#define LLAMA_INTERNAL_STATE_VERSION 0
#define BERT_INTERNAL_STATE_VERSION 0
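// LLModelStore is a single-slot, blocking hand-off point for the one LLModelInfo that may be
// resident at a time: acquireModel() blocks until the slot is filled and takes ownership, and
// releaseModel() puts it back and wakes any waiting chat thread. The store is seeded with an
// empty LLModelInfo so the very first acquire succeeds immediately.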
class LLModelStore {
public:
    static LLModelStore *globalInstance();

    LLModelInfo acquireModel(); // will block until llmodel is ready
    void releaseModel(const LLModelInfo &info); // must be called when you are done

private:
    LLModelStore()
    {
        // seed with empty model
        m_availableModels.append(LLModelInfo());
    }
    ~LLModelStore() {}
    QVector<LLModelInfo> m_availableModels;
    QMutex m_mutex;
    QWaitCondition m_condition;
    friend class MyLLModelStore;
};

class MyLLModelStore : public LLModelStore { };
Q_GLOBAL_STATIC(MyLLModelStore, storeInstance)
LLModelStore *LLModelStore::globalInstance()
{
    return storeInstance();
}

LLModelInfo LLModelStore::acquireModel()
{
    QMutexLocker locker(&m_mutex);
    while (m_availableModels.isEmpty())
        m_condition.wait(locker.mutex());
    return m_availableModels.takeFirst();
}

void LLModelStore::releaseModel(const LLModelInfo &info)
{
    QMutexLocker locker(&m_mutex);
    m_availableModels.append(info);
    Q_ASSERT(m_availableModels.count() < 2);
    m_condition.wakeAll();
}
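// Each ChatLLM instance is moved onto its own QThread (m_llmThread) in the constructor below, so
// all prompt processing happens off the GUI thread; signals back to the GUI side therefore cross
// thread boundaries, and shouldBeLoadedChanged is explicitly queued for the same reason.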
ChatLLM::ChatLLM(Chat *parent, bool isServer)
    : QObject{nullptr}
    , m_promptResponseTokens(0)
    , m_promptTokens(0)
    , m_isRecalc(false)
    , m_shouldBeLoaded(true)
    , m_stopGenerating(false)
    , m_timer(nullptr)
    , m_isServer(isServer)
    , m_forceMetal(MySettings::globalInstance()->forceMetal())
    , m_reloadingToChangeVariant(false)
    , m_processedSystemPrompt(false)
    , m_restoreStateFromText(false)
{
    moveToThread(&m_llmThread);
    connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
    connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
    connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
        Qt::QueuedConnection); // explicitly queued
    connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
    connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
    connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
    connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged);

    // The following are blocking operations and will block the llm thread
    connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
        Qt::BlockingQueuedConnection);

    m_llmThread.setObjectName(parent->id());
    m_llmThread.start();
}
ChatLLM::~ChatLLM()
{
    m_stopGenerating = true;
    m_llmThread.quit();
    m_llmThread.wait();

    // The only time we should have a model loaded here is on shutdown
    // as we explicitly unload the model in all other circumstances
    if (isModelLoaded()) {
        delete m_llModelInfo.model;
        m_llModelInfo.model = nullptr;
    }
}

void ChatLLM::handleThreadStarted()
{
    m_timer = new TokenTimer(this);
    connect(m_timer, &TokenTimer::report, this, &ChatLLM::reportSpeed);
    emit threadStarted();
}
void ChatLLM::handleForceMetalChanged(bool forceMetal)
{
#if defined(Q_OS_MAC) && defined(__arm__)
    m_forceMetal = forceMetal;
    if (isModelLoaded() && m_shouldBeLoaded) {
        m_reloadingToChangeVariant = true;
        unloadModel();
        reloadModel();
        m_reloadingToChangeVariant = false;
    }
#endif
}

void ChatLLM::handleDeviceChanged()
{
    if (isModelLoaded() && m_shouldBeLoaded) {
        m_reloadingToChangeVariant = true;
        unloadModel();
        reloadModel();
        m_reloadingToChangeVariant = false;
    }
}
bool ChatLLM::loadDefaultModel()
{
    ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
    if (defaultModel.filename().isEmpty()) {
        emit modelLoadingError(QString("Could not find any model to load"));
        return false;
    }
    return loadModel(defaultModel);
}
bool ChatLLM::loadModel(const ModelInfo &modelInfo)
{
    // This is a complicated method because N different possible threads are interested in the outcome
    // of this method. Why? Because we have a main/gui thread trying to monitor the state of N different
    // possible chat threads all vying for a single resource - the currently loaded model - as the user
    // switches back and forth between chats. It is important for our main/gui thread to never block
    // but simultaneously always have up-to-date information with regards to which chat has the model
    // loaded and what the type and name of that model is. I've tried to comment extensively in this
    // method to provide an overview of what we're doing here.

    // We're already loaded with this model
    if (isModelLoaded() && this->modelInfo() == modelInfo)
        return true;

    bool isChatGPT = modelInfo.isChatGPT;
    QString filePath = modelInfo.dirpath + modelInfo.filename();
    QFileInfo fileInfo(filePath);

    // We have a live model, but it isn't the one we want
    bool alreadyAcquired = isModelLoaded();
    if (alreadyAcquired) {
        resetContext();
#if defined(DEBUG_MODEL_LOADING)
        qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
        delete m_llModelInfo.model;
        m_llModelInfo.model = nullptr;
        emit isModelLoadedChanged(false);
    } else if (!m_isServer) {
        // This is a blocking call that tries to retrieve the model we need from the model store.
        // If it succeeds, then we just have to restore state. If the store has never had a model
        // returned to it, then the modelInfo.model pointer should be null which will happen on startup
        m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
        qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
        // At this point it is possible that while we were blocked waiting to acquire the model from the
        // store, our state was changed to not be loaded. If this is the case, release the model
        // back into the store and quit loading
        if (!m_shouldBeLoaded) {
#if defined(DEBUG_MODEL_LOADING)
            qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
            LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
            m_llModelInfo = LLModelInfo();
            emit isModelLoadedChanged(false);
            return false;
        }

        // Check if the store just gave us exactly the model we were looking for
        if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
#if defined(DEBUG_MODEL_LOADING)
            qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
            restoreState();
            emit isModelLoadedChanged(true);
            setModelInfo(modelInfo);
            Q_ASSERT(!m_modelInfo.filename().isEmpty());
            if (m_modelInfo.filename().isEmpty())
                emit modelLoadingError(QString("Modelinfo is left null for %1").arg(modelInfo.filename()));
            else
                processSystemPrompt();
            return true;
        } else {
            // Release the memory since we have to switch to a different model.
#if defined(DEBUG_MODEL_LOADING)
            qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
            delete m_llModelInfo.model;
            m_llModelInfo.model = nullptr;
        }
    }

    // Guarantee we've released the previous model's memory
    Q_ASSERT(!m_llModelInfo.model);

    // Store the file info in the modelInfo in case we have an error loading
    m_llModelInfo.fileInfo = fileInfo;

    // Check if we've previously tried to load this file and failed/crashed
    if (MySettings::globalInstance()->attemptModelLoad() == filePath) {
        MySettings::globalInstance()->setAttemptModelLoad(QString()); // clear the flag
        if (!m_isServer)
            LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
        m_llModelInfo = LLModelInfo();
        emit modelLoadingError(QString("Previous attempt to load model resulted in crash for `%1` most likely due to insufficient memory. You should either remove this model or decrease your system RAM usage by closing other applications.").arg(modelInfo.filename()));
    }
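    // Note: for ChatGPT-style models the "model file" read below holds only the user's API key;
    // the remote model name is recovered from the filename by stripping the leading "chatgpt-" prefix.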
    if (fileInfo.exists()) {
        if (isChatGPT) {
            QString apiKey;
            QString chatGPTModel = fileInfo.completeBaseName().remove(0, 8); // remove the chatgpt- prefix
            {
                QFile file(filePath);
                file.open(QIODeviceBase::ReadOnly | QIODeviceBase::Text);
                QTextStream stream(&file);
                apiKey = stream.readAll();
                file.close();
            }
            m_llModelType = LLModelType::CHATGPT_;
            ChatGPT *model = new ChatGPT();
            model->setModelName(chatGPTModel);
            model->setAPIKey(apiKey);
            m_llModelInfo.model = model;
        } else {
            // TODO: make configurable in UI
            auto n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
            m_ctx.n_ctx = n_ctx;

            std::string buildVariant = "auto";
#if defined(Q_OS_MAC) && defined(__arm__)
            if (m_forceMetal)
                buildVariant = "metal";
#endif
            m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);

            if (m_llModelInfo.model) {
                // Update the settings that a model is being loaded and update the device list
                MySettings::globalInstance()->setAttemptModelLoad(filePath);

                // Pick the best match for the device
                QString actualDevice = m_llModelInfo.model->implementation().buildVariant() == "metal" ? "Metal" : "CPU";
                const QString requestedDevice = MySettings::globalInstance()->device();
                if (requestedDevice == "CPU") {
                    emit reportFallbackReason(""); // fallback not applicable
                } else {
                    const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx);
                    std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
                    LLModel::GPUDevice *device = nullptr;

                    if (!availableDevices.empty() && requestedDevice == "Auto" && availableDevices.front().type == 2 /*a discrete gpu*/) {
                        device = &availableDevices.front();
                    } else {
                        for (LLModel::GPUDevice &d : availableDevices) {
                            if (QString::fromStdString(d.name) == requestedDevice) {
                                device = &d;
                                break;
                            }
                        }
                    }

                    emit reportFallbackReason(""); // no fallback yet
                    std::string unavail_reason;
                    if (!device) {
                        // GPU not available
                    } else if (!m_llModelInfo.model->initializeGPUDevice(*device, &unavail_reason)) {
                        emit reportFallbackReason(QString::fromStdString("<br>" + unavail_reason));
                    } else {
                        actualDevice = QString::fromStdString(device->name);
                    }
                }
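                // From here on, 'actualDevice' is only a best guess; the load below may still fall
                // back to CPU if GPU initialization fails or the model/quantization is unsupported.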
                // Report which device we're actually using
                emit reportDevice(actualDevice);

                bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
                if (actualDevice == "CPU") {
                    // we asked llama.cpp to use the CPU
                } else if (!success) {
                    // llama_init_from_file returned nullptr
                    emit reportDevice("CPU");
                    emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
                    success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
                } else if (!m_llModelInfo.model->usingGPUDevice()) {
                    // ggml_vk_init was not called in llama.cpp
                    // We might have had to fall back to CPU after load if the model is not possible to accelerate,
                    // for instance if the quantization method is not supported on Vulkan yet
                    emit reportDevice("CPU");
                    emit reportFallbackReason("<br>model or quant has no GPU support");
                }

                MySettings::globalInstance()->setAttemptModelLoad(QString());
                if (!success) {
                    delete m_llModelInfo.model;
                    m_llModelInfo.model = nullptr;
                    if (!m_isServer)
                        LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
                    m_llModelInfo = LLModelInfo();
                    emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
                } else {
                    switch (m_llModelInfo.model->implementation().modelType()[0]) {
                    case 'L': m_llModelType = LLModelType::LLAMA_; break;
                    case 'G': m_llModelType = LLModelType::GPTJ_; break;
                    case 'B': m_llModelType = LLModelType::BERT_; break;
                    default:
                        {
                            delete m_llModelInfo.model;
                            m_llModelInfo.model = nullptr;
                            if (!m_isServer)
                                LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
                            m_llModelInfo = LLModelInfo();
                            emit modelLoadingError(QString("Could not determine model type for %1").arg(modelInfo.filename()));
                        }
                    }
                }
            } else {
                if (!m_isServer)
                    LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
                m_llModelInfo = LLModelInfo();
                emit modelLoadingError(QString("Could not load model due to invalid format for %1").arg(modelInfo.filename()));
            }
        }
#if defined(DEBUG_MODEL_LOADING)
        qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
        restoreState();
#if defined(DEBUG)
        qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
        fflush(stdout);
#endif
        emit isModelLoadedChanged(isModelLoaded());

        static bool isFirstLoad = true;
        if (isFirstLoad) {
            emit sendStartup();
            isFirstLoad = false;
        } else
            emit sendModelLoaded();
    } else {
        if (!m_isServer)
            LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
        m_llModelInfo = LLModelInfo();
        emit modelLoadingError(QString("Could not find file for model %1").arg(modelInfo.filename()));
    }

    if (m_llModelInfo.model) {
        setModelInfo(modelInfo);
        processSystemPrompt();
    }

    return m_llModelInfo.model;
}
bool ChatLLM::isModelLoaded() const
{
    return m_llModelInfo.model && m_llModelInfo.model->isModelLoaded();
}
std::string remove_leading_whitespace(const std::string &input) {
    auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) {
        return !std::isspace(c);
    });
    if (first_non_whitespace == input.end())
        return std::string();
    return std::string(first_non_whitespace, input.end());
}

std::string trim_whitespace(const std::string &input) {
    auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) {
        return !std::isspace(c);
    });
    if (first_non_whitespace == input.end())
        return std::string();
    auto last_non_whitespace = std::find_if(input.rbegin(), input.rend(), [](unsigned char c) {
        return !std::isspace(c);
    }).base();
    return std::string(first_non_whitespace, last_non_whitespace);
}
void ChatLLM::regenerateResponse()
{
    // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, n_past
    // counts prompt/response pairs rather than total tokens.
    if (m_llModelType == LLModelType::CHATGPT_)
        m_ctx.n_past -= 1;
    else
        m_ctx.n_past -= m_promptResponseTokens;
    m_ctx.n_past = std::max(0, m_ctx.n_past);
    m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
    m_promptResponseTokens = 0;
    m_promptTokens = 0;
    m_response = std::string();
    emit responseChanged(QString::fromStdString(m_response));
}

void ChatLLM::resetResponse()
{
    m_promptTokens = 0;
    m_promptResponseTokens = 0;
    m_response = std::string();
    emit responseChanged(QString::fromStdString(m_response));
}

void ChatLLM::resetContext()
{
    regenerateResponse();
    m_processedSystemPrompt = false;
    m_ctx = LLModel::PromptContext();
}

QString ChatLLM::response() const
{
    return QString::fromStdString(remove_leading_whitespace(m_response));
}

ModelInfo ChatLLM::modelInfo() const
{
    return m_modelInfo;
}

void ChatLLM::setModelInfo(const ModelInfo &modelInfo)
{
    m_modelInfo = modelInfo;
    emit modelInfoChanged(modelInfo);
}

void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo)
{
    loadModel(modelInfo);
}
bool ChatLLM::handlePrompt(int32_t token)
{
    // m_promptResponseTokens is related to the last prompt/response, not
    // the entire context window, which we can reset on regenerate prompt
#if defined(DEBUG)
    qDebug() << "prompt process" << m_llmThread.objectName() << token;
#endif
    ++m_promptTokens;
    ++m_promptResponseTokens;
    m_timer->start();
    return !m_stopGenerating;
}

bool ChatLLM::handleResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
    printf("%s", response.c_str());
    fflush(stdout);
#endif
    // check for error
    if (token < 0) {
        m_response.append(response);
        emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
        return false;
    }

    // m_promptResponseTokens is related to the last prompt/response, not
    // the entire context window, which we can reset on regenerate prompt
    ++m_promptResponseTokens;
    m_timer->inc();
    Q_ASSERT(!response.empty());
    m_response.append(response);
    emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
    return !m_stopGenerating;
}

bool ChatLLM::handleRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "recalculate" << m_llmThread.objectName() << isRecalc;
#endif
    if (m_isRecalc != isRecalc) {
        m_isRecalc = isRecalc;
        emit recalcChanged();
    }
    return !m_stopGenerating;
}
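// prompt() gathers the per-model sampling settings from MySettings and forwards them, together
// with the LocalDocs collection list, to promptInternal(), which does the actual generation.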
bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt)
{
    if (m_restoreStateFromText) {
        Q_ASSERT(m_state.isEmpty());
        processRestoreStateFromText();
    }

    if (!m_processedSystemPrompt)
        processSystemPrompt();
    const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
    const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
    const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
    const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
    const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
    const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
    const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
    const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
    return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, temp, n_batch,
        repeat_penalty, repeat_penalty_tokens);
}
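// If any LocalDocs collections are enabled, promptInternal() first retrieves matching snippets
// (a blocking queued call into the Database), prepends them under a "### Context:" header, and
// then substitutes the user prompt into the prompt template via .arg() before generating.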
bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
    int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
    int32_t repeat_penalty_tokens)
{
    if (!isModelLoaded())
        return false;

    QList<ResultInfo> databaseResults;
    const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
    if (!collectionList.isEmpty()) {
        emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
        emit databaseResultsChanged(databaseResults);
    }

    // Augment the prompt template with the results if any
    QList<QString> augmentedTemplate;
    if (!databaseResults.isEmpty())
        augmentedTemplate.append("### Context:");
    for (const ResultInfo &info : databaseResults)
        augmentedTemplate.append(info.text);
    augmentedTemplate.append(promptTemplate);
    QString instructPrompt = augmentedTemplate.join("\n").arg(prompt);

    int n_threads = MySettings::globalInstance()->threadCount();

    m_stopGenerating = false;
    auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
    auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1,
        std::placeholders::_2);
    auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
    emit promptProcessing();
    qint32 logitsBefore = m_ctx.logits.size();
    m_ctx.n_predict = n_predict;
    m_ctx.top_k = top_k;
    m_ctx.top_p = top_p;
    m_ctx.temp = temp;
    m_ctx.n_batch = n_batch;
    m_ctx.repeat_penalty = repeat_penalty;
    m_ctx.repeat_last_n = repeat_penalty_tokens;
    m_llModelInfo.model->setThreadCount(n_threads);
#if defined(DEBUG)
    printf("%s", qPrintable(instructPrompt));
    fflush(stdout);
#endif
    m_timer->start();
    m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
#if defined(DEBUG)
    printf("\n");
    fflush(stdout);
#endif
    m_timer->stop();
    std::string trimmed = trim_whitespace(m_response);
    if (trimmed != m_response) {
        m_response = trimmed;
        emit responseChanged(QString::fromStdString(m_response));
    }
    emit responseStopped();
    return true;
}
void ChatLLM::setShouldBeLoaded(bool b)
{
#if defined(DEBUG_MODEL_LOADING)
    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model;
#endif
    m_shouldBeLoaded = b; // atomic
    emit shouldBeLoadedChanged();
}

void ChatLLM::handleShouldBeLoadedChanged()
{
    if (m_shouldBeLoaded)
        reloadModel();
    else
        unloadModel();
}

void ChatLLM::forceUnloadModel()
{
    m_shouldBeLoaded = false; // atomic
    unloadModel();
}
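// unloadModel() snapshots the current model state into m_state via saveState() and then returns
// the model slot to the global store so another chat can pick it up; reloadModel() is the inverse.
// Both are no-ops for the server chat, which never owns the store's slot.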
void ChatLLM::unloadModel()
{
    if (!isModelLoaded() || m_isServer)
        return;

    saveState();
#if defined(DEBUG_MODEL_LOADING)
    qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
    LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
    m_llModelInfo = LLModelInfo();
    emit isModelLoadedChanged(false);
}

void ChatLLM::reloadModel()
{
    if (isModelLoaded() || m_isServer)
        return;

#if defined(DEBUG_MODEL_LOADING)
    qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
    const ModelInfo m = modelInfo();
    if (m.name().isEmpty())
        loadDefaultModel();
    else
        loadModel(m);
}
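// generateName() asks the model to summarize the conversation in three words. It prompts against a
// copy of the prompt context (ctx = m_ctx) so the chat's own state is not advanced, and
// handleNameResponse() stops generation once more than three words have been produced.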
void ChatLLM::generateName()
{
    Q_ASSERT(isModelLoaded());
    if (!isModelLoaded())
        return;

    QString instructPrompt("### Instruction:\n"
                           "Describe response above in three words.\n"
                           "### Response:\n");
    auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1);
    auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1,
        std::placeholders::_2);
    auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1);
    LLModel::PromptContext ctx = m_ctx;
#if defined(DEBUG)
    printf("%s", qPrintable(instructPrompt));
    fflush(stdout);
#endif
    m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
#if defined(DEBUG)
    printf("\n");
    fflush(stdout);
#endif
    std::string trimmed = trim_whitespace(m_nameResponse);
    if (trimmed != m_nameResponse) {
        m_nameResponse = trimmed;
        emit generatedNameChanged(QString::fromStdString(m_nameResponse));
    }
}
void ChatLLM::handleChatIdChanged(const QString &id)
{
    m_llmThread.setObjectName(id);
}

bool ChatLLM::handleNamePrompt(int32_t token)
{
#if defined(DEBUG)
    qDebug() << "name prompt" << m_llmThread.objectName() << token;
#endif
    Q_UNUSED(token);
    qt_noop();
    return !m_stopGenerating;
}

bool ChatLLM::handleNameResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
    qDebug() << "name response" << m_llmThread.objectName() << token << response;
#endif
    Q_UNUSED(token);

    m_nameResponse.append(response);
    emit generatedNameChanged(QString::fromStdString(m_nameResponse));
    QString gen = QString::fromStdString(m_nameResponse).simplified();
    QStringList words = gen.split(' ', Qt::SkipEmptyParts);
    return words.size() <= 3;
}

bool ChatLLM::handleNameRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
#endif
    Q_UNUSED(isRecalc);
    qt_noop();
    return true;
}
bool ChatLLM::handleSystemPrompt(int32_t token)
{
#if defined(DEBUG)
    qDebug() << "system prompt" << m_llmThread.objectName() << token << m_stopGenerating;
#endif
    Q_UNUSED(token);
    return !m_stopGenerating;
}

bool ChatLLM::handleSystemResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
    qDebug() << "system response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
    Q_UNUSED(token);
    Q_UNUSED(response);
    return false;
}

bool ChatLLM::handleSystemRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "system recalc" << m_llmThread.objectName() << isRecalc;
#endif
    Q_UNUSED(isRecalc);
    return false;
}
bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
{
#if defined(DEBUG)
    qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating;
#endif
    Q_UNUSED(token);
    return !m_stopGenerating;
}

bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
    qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
    Q_UNUSED(token);
    Q_UNUSED(response);
    return false;
}

bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
#endif
    Q_UNUSED(isRecalc);
    return false;
}
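// Rough layout of the serialized chat state (fields are gated on 'version' as the checks below show):
//   model type + internal state version (version > 1), response text, generated name,
//   prompt/response token count, [legacy response logits count for version <= 3], n_past,
//   n_ctx (version >= 7), logits, tokens, and finally the compressed model state blob.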
// This function serializes the cached model state to disk.
// We want to also serialize n_ctx, and read it at load time.
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
{
    if (version > 1) {
        stream << m_llModelType;
        switch (m_llModelType) {
        case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
        case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
        case BERT_: stream << BERT_INTERNAL_STATE_VERSION; break;
        default: Q_UNREACHABLE();
        }
    }
    stream << response();
    stream << generatedName();
    stream << m_promptResponseTokens;

    if (!serializeKV) {
#if defined(DEBUG)
        qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
#endif
        return stream.status() == QDataStream::Ok;
    }

    if (version <= 3) {
        int responseLogits = 0;
        stream << responseLogits;
    }
    stream << m_ctx.n_past;
    if (version >= 7) {
        stream << m_ctx.n_ctx;
    }
    stream << quint64(m_ctx.logits.size());
    stream.writeRawData(reinterpret_cast<const char *>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float));
    stream << quint64(m_ctx.tokens.size());
    stream.writeRawData(reinterpret_cast<const char *>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
    saveState();
    QByteArray compressed = qCompress(m_state);
    stream << compressed;
#if defined(DEBUG)
    qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
#endif
    return stream.status() == QDataStream::Ok;
}
bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
{
    if (version > 1) {
        int internalStateVersion;
        stream >> m_llModelType;
        stream >> internalStateVersion; // for future use
    }
    QString response;
    stream >> response;
    m_response = response.toStdString();
    QString nameResponse;
    stream >> nameResponse;
    m_nameResponse = nameResponse.toStdString();
    stream >> m_promptResponseTokens;

    // If we do not deserialize the KV or it is discarded, then we need to restore the state from the
    // text only. This will be a costly operation, but the chat has to be restored from the text archive
    // alone.
    m_restoreStateFromText = !deserializeKV || discardKV;

    if (!deserializeKV) {
#if defined(DEBUG)
        qDebug() << "deserialize" << m_llmThread.objectName();
#endif
        return stream.status() == QDataStream::Ok;
    }

    if (version <= 3) {
        int responseLogits;
        stream >> responseLogits;
    }

    int32_t n_past;
    stream >> n_past;
    if (!discardKV) m_ctx.n_past = n_past;

    if (version >= 7) {
        uint32_t n_ctx;
        stream >> n_ctx;
        if (!discardKV) m_ctx.n_ctx = n_ctx;
    }

    quint64 logitsSize;
    stream >> logitsSize;
    if (!discardKV) {
        m_ctx.logits.resize(logitsSize);
        stream.readRawData(reinterpret_cast<char *>(m_ctx.logits.data()), logitsSize * sizeof(float));
    } else {
        stream.skipRawData(logitsSize * sizeof(float));
    }

    quint64 tokensSize;
    stream >> tokensSize;
    if (!discardKV) {
        m_ctx.tokens.resize(tokensSize);
        stream.readRawData(reinterpret_cast<char *>(m_ctx.tokens.data()), tokensSize * sizeof(int));
    } else {
        stream.skipRawData(tokensSize * sizeof(int));
    }

    if (version > 0) {
        QByteArray compressed;
        stream >> compressed;
        if (!discardKV)
            m_state = qUncompress(compressed);
    } else {
        if (!discardKV) {
            stream >> m_state;
        } else {
            QByteArray state;
            stream >> state;
        }
    }

#if defined(DEBUG)
    qDebug() << "deserialize" << m_llmThread.objectName();
#endif
    return stream.status() == QDataStream::Ok;
}
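// saveState()/restoreState() snapshot and restore the backend's opaque state blob (the KV cache
// for local models). For ChatGPT models there is no KV cache, so the serialized context is the
// list of prompt/response strings instead.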
void ChatLLM::saveState()
{
    if (!isModelLoaded())
        return;

    if (m_llModelType == LLModelType::CHATGPT_) {
        m_state.clear();
        QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
        stream.setVersion(QDataStream::Qt_6_5);
        ChatGPT *chatGPT = static_cast<ChatGPT *>(m_llModelInfo.model);
        stream << chatGPT->context();
        return;
    }

    const size_t stateSize = m_llModelInfo.model->stateSize();
    m_state.resize(stateSize);
#if defined(DEBUG)
    qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
    m_llModelInfo.model->saveState(static_cast<uint8_t *>(reinterpret_cast<void *>(m_state.data())));
}

void ChatLLM::restoreState()
{
    if (!isModelLoaded())
        return;

    if (m_llModelType == LLModelType::CHATGPT_) {
        QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
        stream.setVersion(QDataStream::Qt_6_5);
        ChatGPT *chatGPT = static_cast<ChatGPT *>(m_llModelInfo.model);
        QList<QString> context;
        stream >> context;
        chatGPT->setContext(context);
        m_state.clear();
        m_state.squeeze();
        return;
    }

#if defined(DEBUG)
    qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif

    if (m_state.isEmpty())
        return;

    if (m_llModelInfo.model->stateSize() == m_state.size()) {
        m_llModelInfo.model->restoreState(static_cast<const uint8_t *>(reinterpret_cast<void *>(m_state.data())));
        m_processedSystemPrompt = true;
    } else {
        qWarning() << "restoring state from text because" << m_llModelInfo.model->stateSize() << "!=" << m_state.size() << "\n";
        m_restoreStateFromText = true;
    }

    m_state.clear();
    m_state.squeeze();
}
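// The system prompt is evaluated once per model load into a brand-new PromptContext; the response
// callback returns false, so the model only ingests the prompt and nothing it might generate is
// kept or shown.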
void ChatLLM::processSystemPrompt()
{
    Q_ASSERT(isModelLoaded());
    if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText || m_isServer)
        return;

    const std::string systemPrompt = MySettings::globalInstance()->modelSystemPrompt(m_modelInfo).toStdString();
    if (QString::fromStdString(systemPrompt).trimmed().isEmpty()) {
        m_processedSystemPrompt = true;
        return;
    }

    // Start with a whole new context
    m_stopGenerating = false;
    m_ctx = LLModel::PromptContext();

    auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
    auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
        std::placeholders::_2);
    auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);

    const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
    const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
    const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
    const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
    const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
    const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
    const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
    int n_threads = MySettings::globalInstance()->threadCount();
    m_ctx.n_predict = n_predict;
    m_ctx.top_k = top_k;
    m_ctx.top_p = top_p;
    m_ctx.temp = temp;
    m_ctx.n_batch = n_batch;
    m_ctx.repeat_penalty = repeat_penalty;
    m_ctx.repeat_last_n = repeat_penalty_tokens;
    m_llModelInfo.model->setThreadCount(n_threads);
#if defined(DEBUG)
    printf("%s", qPrintable(QString::fromStdString(systemPrompt)));
    fflush(stdout);
#endif
    m_llModelInfo.model->prompt(systemPrompt, promptFunc, responseFunc, recalcFunc, m_ctx);
#if defined(DEBUG)
    printf("\n");
    fflush(stdout);
#endif

    m_processedSystemPrompt = m_stopGenerating == false;
}
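// processRestoreStateFromText() rebuilds the model's context by replaying the saved transcript:
// stored prompts are re-wrapped in the prompt template and stored responses are fed through as-is,
// with generation suppressed. This is slower than restoring the binary state blob, but it works
// when that blob is missing or no longer matches the loaded model.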
void ChatLLM::processRestoreStateFromText()
{
    Q_ASSERT(isModelLoaded());
    if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
        return;

    m_isRecalc = true;
    emit recalcChanged();

    m_stopGenerating = false;
    m_ctx = LLModel::PromptContext();

    auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
    auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
        std::placeholders::_2);
    auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);

    const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
    const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
    const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
    const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
    const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
    const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
    const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
    const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
    int n_threads = MySettings::globalInstance()->threadCount();
    m_ctx.n_predict = n_predict;
    m_ctx.top_k = top_k;
    m_ctx.top_p = top_p;
    m_ctx.temp = temp;
    m_ctx.n_batch = n_batch;
    m_ctx.repeat_penalty = repeat_penalty;
    m_ctx.repeat_last_n = repeat_penalty_tokens;
    m_llModelInfo.model->setThreadCount(n_threads);
    for (auto pair : m_stateFromText) {
        const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
        m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
    }

    if (!m_stopGenerating) {
        m_restoreStateFromText = false;
        m_stateFromText.clear();
    }

    m_isRecalc = false;
    emit recalcChanged();
}