Fix the context.

pull/520/head
Adam Treat 1 year ago
parent 801107a12c
commit 15ae0a4441

@ -64,7 +64,7 @@ bool LLMObject::isModelLoaded() const
return m_llmodel->isModelLoaded(); return m_llmodel->isModelLoaded();
} }
void LLMObject::resetResponse() void LLMObject::regenerateResponse()
{ {
s_ctx.n_past -= m_responseTokens; s_ctx.n_past -= m_responseTokens;
s_ctx.n_past = std::max(0, s_ctx.n_past); s_ctx.n_past = std::max(0, s_ctx.n_past);
@ -75,9 +75,17 @@ void LLMObject::resetResponse()
emit responseChanged(); emit responseChanged();
} }
void LLMObject::resetResponse()
{
m_responseTokens = 0;
m_responseLogits = 0;
m_response = std::string();
emit responseChanged();
}
void LLMObject::resetContext() void LLMObject::resetContext()
{ {
resetResponse(); regenerateResponse();
s_ctx = LLModel::PromptContext(); s_ctx = LLModel::PromptContext();
} }
@ -142,7 +150,6 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1); auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1);
emit responseStarted(); emit responseStarted();
qint32 logitsBefore = s_ctx.logits.size(); qint32 logitsBefore = s_ctx.logits.size();
qInfo() << instructPrompt << "\n";
m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx, n_predict, top_k, top_p, temp, n_batch); m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx, n_predict, top_k, top_p, temp, n_batch);
m_responseLogits += s_ctx.logits.size() - logitsBefore; m_responseLogits += s_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response); std::string trimmed = trim_whitespace(m_response);
@ -167,6 +174,7 @@ LLM::LLM()
connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection); connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection); connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection); connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
} }
@ -182,6 +190,11 @@ void LLM::prompt(const QString &prompt, const QString &prompt_template, int32_t
emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch); emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch);
} }
void LLM::regenerateResponse()
{
emit regenerateResponseRequested(); // blocking queued connection
}
void LLM::resetResponse() void LLM::resetResponse()
{ {
emit resetResponseRequested(); // blocking queued connection emit resetResponseRequested(); // blocking queued connection

@ -18,6 +18,7 @@ public:
bool loadModel(); bool loadModel();
bool isModelLoaded() const; bool isModelLoaded() const;
void regenerateResponse();
void resetResponse(); void resetResponse();
void resetContext(); void resetContext();
void stopGenerating() { m_stopGenerating = true; } void stopGenerating() { m_stopGenerating = true; }
@ -63,8 +64,9 @@ public:
Q_INVOKABLE bool isModelLoaded() const; Q_INVOKABLE bool isModelLoaded() const;
Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch); float temp, int32_t n_batch);
Q_INVOKABLE void resetContext(); Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void resetResponse(); Q_INVOKABLE void resetResponse();
Q_INVOKABLE void resetContext();
Q_INVOKABLE void stopGenerating(); Q_INVOKABLE void stopGenerating();
QString response() const; QString response() const;
@ -80,6 +82,7 @@ Q_SIGNALS:
void responseInProgressChanged(); void responseInProgressChanged();
void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch); float temp, int32_t n_batch);
void regenerateResponseRequested();
void resetResponseRequested(); void resetResponseRequested();
void resetContextRequested(); void resetContextRequested();
void modelNameChanged(); void modelNameChanged();

@ -666,7 +666,7 @@ Window {
if (LLM.responseInProgress) if (LLM.responseInProgress)
LLM.stopGenerating() LLM.stopGenerating()
else { else {
LLM.resetResponse() LLM.regenerateResponse()
if (chatModel.count) { if (chatModel.count) {
var listElement = chatModel.get(chatModel.count - 1) var listElement = chatModel.get(chatModel.count - 1)
if (listElement.name === qsTr("Response: ")) { if (listElement.name === qsTr("Response: ")) {

Loading…
Cancel
Save