Fix the context.

This commit is contained in:
Adam Treat 2023-04-17 14:11:41 -04:00
parent b0ce635338
commit f73fbf28a4
3 changed files with 21 additions and 5 deletions

19
llm.cpp
View File

@ -64,7 +64,7 @@ bool LLMObject::isModelLoaded() const
return m_llmodel->isModelLoaded();
}
void LLMObject::resetResponse()
void LLMObject::regenerateResponse()
{
s_ctx.n_past -= m_responseTokens;
s_ctx.n_past = std::max(0, s_ctx.n_past);
@ -75,9 +75,17 @@ void LLMObject::resetResponse()
emit responseChanged();
}
void LLMObject::resetResponse()
{
m_responseTokens = 0;
m_responseLogits = 0;
m_response = std::string();
emit responseChanged();
}
void LLMObject::resetContext()
{
resetResponse();
regenerateResponse();
s_ctx = LLModel::PromptContext();
}
@ -142,7 +150,6 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1);
emit responseStarted();
qint32 logitsBefore = s_ctx.logits.size();
qInfo() << instructPrompt << "\n";
m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx, n_predict, top_k, top_p, temp, n_batch);
m_responseLogits += s_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response);
@ -167,6 +174,7 @@ LLM::LLM()
connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
}
@ -182,6 +190,11 @@ void LLM::prompt(const QString &prompt, const QString &prompt_template, int32_t
emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch);
}
void LLM::regenerateResponse()
{
emit regenerateResponseRequested(); // blocking queued connection
}
void LLM::resetResponse()
{
emit resetResponseRequested(); // blocking queued connection

5
llm.h
View File

@ -18,6 +18,7 @@ public:
bool loadModel();
bool isModelLoaded() const;
void regenerateResponse();
void resetResponse();
void resetContext();
void stopGenerating() { m_stopGenerating = true; }
@ -63,8 +64,9 @@ public:
Q_INVOKABLE bool isModelLoaded() const;
Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch);
Q_INVOKABLE void resetContext();
Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void resetResponse();
Q_INVOKABLE void resetContext();
Q_INVOKABLE void stopGenerating();
QString response() const;
@ -80,6 +82,7 @@ Q_SIGNALS:
void responseInProgressChanged();
void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch);
void regenerateResponseRequested();
void resetResponseRequested();
void resetContextRequested();
void modelNameChanged();

View File

@ -666,7 +666,7 @@ Window {
if (LLM.responseInProgress)
LLM.stopGenerating()
else {
LLM.resetResponse()
LLM.regenerateResponse()
if (chatModel.count) {
var listElement = chatModel.get(chatModel.count - 1)
if (listElement.name === qsTr("Response: ")) {