Make localdocs work with server mode.

recalcuatecontext_nonvirtual
Adam Treat authored 1 year ago · committed by AT
parent 8e89ceb54b
commit f62e439a2d

@@ -45,7 +45,6 @@ void Chat::connectLLM()
// Should be in same thread
connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
connect(LocalDocs::globalInstance(), &LocalDocs::receivedResult, this, &Chat::handleLocalDocsRetrieved, Qt::DirectConnection);
// Should be in different threads
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
@@ -101,52 +100,17 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens)
{
Q_ASSERT(m_results.isEmpty());
m_results.clear(); // just in case, but the assert above is important
m_responseInProgress = true;
m_responseState = Chat::LocalDocsRetrieval;
emit responseInProgressChanged();
emit responseStateChanged();
m_queuedPrompt.prompt = prompt;
m_queuedPrompt.prompt_template = prompt_template;
m_queuedPrompt.n_predict = n_predict;
m_queuedPrompt.top_k = top_k;
m_queuedPrompt.top_p = top_p;
m_queuedPrompt.temp = temp;
m_queuedPrompt.n_batch = n_batch;
m_queuedPrompt.repeat_penalty = repeat_penalty;
m_queuedPrompt.repeat_penalty_tokens = repeat_penalty_tokens;
LocalDocs::globalInstance()->requestRetrieve(m_id, m_collections, prompt);
}
void Chat::handleLocalDocsRetrieved(const QString &uid, const QList<ResultInfo> &results)
{
// If the uid doesn't match, then these are not our results
if (uid != m_id)
return;
// Store our results locally
m_results = results;
// Augment the prompt template with the results if any
QList<QString> augmentedTemplate;
if (!m_results.isEmpty())
augmentedTemplate.append("### Context:");
for (const ResultInfo &info : m_results)
augmentedTemplate.append(info.text);
augmentedTemplate.append(m_queuedPrompt.prompt_template);
emit promptRequested(
m_queuedPrompt.prompt,
augmentedTemplate.join("\n"),
m_queuedPrompt.n_predict,
m_queuedPrompt.top_k,
m_queuedPrompt.top_p,
m_queuedPrompt.temp,
m_queuedPrompt.n_batch,
m_queuedPrompt.repeat_penalty,
m_queuedPrompt.repeat_penalty_tokens,
prompt,
prompt_template,
n_predict,
top_k,
top_p,
temp,
n_batch,
repeat_penalty,
repeat_penalty_tokens,
LLM::globalInstance()->threadCount());
m_queuedPrompt = Prompt();
}
void Chat::regenerateResponse()
@@ -195,9 +159,14 @@ void Chat::handleModelLoadedChanged()
deleteLater();
}
QList<ResultInfo> Chat::results() const
{
return m_llmodel->results();
}
void Chat::promptProcessing()
{
m_responseState = !m_results.isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
m_responseState = !results().isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
emit responseStateChanged();
}
@@ -207,7 +176,7 @@ void Chat::responseStopped()
QList<QString> references;
QList<QString> referencesContext;
int validReferenceNumber = 1;
for (const ResultInfo &info : m_results) {
for (const ResultInfo &info : results()) {
if (info.file.isEmpty())
continue;
if (validReferenceNumber == 1)
@@ -241,7 +210,6 @@ void Chat::responseStopped()
m_chatModel->updateReferences(index, references.join("\n"), referencesContext);
emit responseChanged();
m_results.clear();
m_responseInProgress = false;
m_responseState = Chat::ResponseStopped;
emit responseInProgressChanged();
@@ -266,6 +234,10 @@ void Chat::setModelName(const QString &modelName)
void Chat::newPromptResponsePair(const QString &prompt)
{
m_responseInProgress = true;
m_responseState = Chat::LocalDocsRetrieval;
emit responseInProgressChanged();
emit responseStateChanged();
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
m_chatModel->appendPrompt(tr("Prompt: "), prompt);
m_chatModel->appendResponse(tr("Response: "), prompt);
@@ -274,7 +246,11 @@ void Chat::newPromptResponsePair(const QString &prompt)
void Chat::serverNewPromptResponsePair(const QString &prompt)
{
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
m_responseInProgress = true;
m_responseState = Chat::LocalDocsRetrieval;
emit responseInProgressChanged();
emit responseStateChanged();
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
m_chatModel->appendPrompt(tr("Prompt: "), prompt);
m_chatModel->appendResponse(tr("Response: "), prompt);
}
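
Net effect of the chat.cpp hunks: the queued-prompt state and the handleLocalDocsRetrieved round-trip are gone, prompt() becomes a plain forward to ChatLLM, and the LocalDocsRetrieval bookkeeping moves into newPromptResponsePair()/serverNewPromptResponsePair() so the server path gets it too. A minimal sketch of what prompt() reduces to, reconstructed from the hunk above rather than quoted verbatim:

```cpp
// Sketch only: after this commit Chat::prompt() just forwards the request;
// LocalDocs retrieval now happens inside ChatLLM::prompt() on the llm thread.
void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
                  int32_t top_k, float top_p, float temp, int32_t n_batch,
                  float repeat_penalty, int32_t repeat_penalty_tokens)
{
    emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp,
                         n_batch, repeat_penalty, repeat_penalty_tokens,
                         LLM::globalInstance()->threadCount());
}
```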

@@ -60,6 +60,8 @@ public:
Q_INVOKABLE void stopGenerating();
Q_INVOKABLE void newPromptResponsePair(const QString &prompt);
QList<ResultInfo> results() const;
QString response() const;
bool responseInProgress() const { return m_responseInProgress; }
QString responseState() const;
@@ -115,7 +117,6 @@ Q_SIGNALS:
void collectionListChanged();
private Q_SLOTS:
void handleLocalDocsRetrieved(const QString &uid, const QList<ResultInfo> &results);
void handleResponseChanged();
void handleModelLoadedChanged();
void promptProcessing();
@@ -125,24 +126,11 @@ private Q_SLOTS:
void handleModelNameChanged();
private:
struct Prompt {
QString prompt;
QString prompt_template;
int32_t n_predict;
int32_t top_k;
float top_p;
float temp;
int32_t n_batch;
float repeat_penalty;
int32_t repeat_penalty_tokens;
};
QString m_id;
QString m_name;
QString m_userName;
QString m_savedModelName;
QList<QString> m_collections;
QList<ResultInfo> m_results;
ChatModel *m_chatModel;
bool m_responseInProgress;
ResponseState m_responseState;
@@ -150,7 +138,6 @@ private:
ChatLLM *m_llmodel;
bool m_isServer;
bool m_shouldDeleteLater;
Prompt m_queuedPrompt;
};
#endif // CHAT_H

@@ -91,9 +91,15 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
moveToThread(&m_llmThread);
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, Qt::QueuedConnection);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
Qt::QueuedConnection); // explicitly queued
connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted);
// The following are blocking operations and will block the llm thread
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
Qt::BlockingQueuedConnection);
m_llmThread.setObjectName(m_chat->id());
m_llmThread.start();
}
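
The Qt::BlockingQueuedConnection above is the crux of the change: emitting requestRetrieveFromDB suspends the llm thread until Database::retrieveFromDB has finished executing in the database thread, which is what makes a raw out-parameter pointer safe to read right after the emit. A self-contained sketch of that pattern with generic names (Worker, Caller, fetch, request are illustrative, not from this patch):

```cpp
// main.cpp -- blocking queued connection with an out-parameter (build with AUTOMOC).
#include <QCoreApplication>
#include <QObject>
#include <QStringList>
#include <QThread>
#include <QDebug>

class Worker : public QObject
{
    Q_OBJECT
public Q_SLOTS:
    // Runs in the worker's thread; fills the caller-provided list.
    void fetch(const QString &query, QStringList *out) { out->append("hit for: " + query); }
};

class Caller : public QObject
{
    Q_OBJECT
Q_SIGNALS:
    void request(const QString &query, QStringList *out);
};

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);

    QThread thread;
    Worker worker;
    worker.moveToThread(&thread);
    thread.start();

    Caller caller;
    // The emit below blocks until Worker::fetch returns in the other thread,
    // so `results` is fully populated on the next line. Never use this form
    // when sender and receiver live on the same thread -- it deadlocks.
    QObject::connect(&caller, &Caller::request, &worker, &Worker::fetch,
                     Qt::BlockingQueuedConnection);

    QStringList results;
    emit caller.request("termination clause", &results); // blocks here
    qDebug() << results;                                  // ("hit for: termination clause")

    thread.quit();
    thread.wait();
    return 0;
}

#include "main.moc"
```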
@@ -386,7 +392,19 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
if (!isModelLoaded())
return false;
QString instructPrompt = prompt_template.arg(prompt);
m_results.clear();
const int retrievalSize = LocalDocs::globalInstance()->retrievalSize();
emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &m_results); // blocks
// Augment the prompt template with the results if any
QList<QString> augmentedTemplate;
if (!m_results.isEmpty())
augmentedTemplate.append("### Context:");
for (const ResultInfo &info : m_results)
augmentedTemplate.append(info.text);
augmentedTemplate.append(prompt_template);
QString instructPrompt = augmentedTemplate.join("\n").arg(prompt);
m_stopGenerating = false;
auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
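
It helps to see what the augmentation actually yields: the retrieved chunks are stacked under a "### Context:" heading above the unmodified template, and because the template's %1 placeholder survives the join, .arg(prompt) still lands the user prompt in its usual slot. A standalone illustration (the template string and chunk text are invented for the example):

```cpp
#include <QList>
#include <QString>
#include <QDebug>

int main()
{
    // Hypothetical values; GPT4All's real templates come from the UI settings.
    const QString prompt_template = "### Prompt:\n%1\n### Response:\n";
    const QString prompt = "What does the contract say about termination?";
    const QList<QString> chunks = { "Either party may terminate with 30 days notice." };

    // Same augmentation logic as ChatLLM::prompt() above.
    QList<QString> augmentedTemplate;
    if (!chunks.isEmpty())
        augmentedTemplate.append("### Context:");
    for (const QString &text : chunks)
        augmentedTemplate.append(text);
    augmentedTemplate.append(prompt_template);

    // Prints:
    //   ### Context:
    //   Either party may terminate with 30 days notice.
    //   ### Prompt:
    //   What does the contract say about termination?
    //   ### Response:
    qDebug().noquote() << augmentedTemplate.join("\n").arg(prompt);
    return 0;
}
```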

@@ -5,6 +5,7 @@
#include <QThread>
#include <QFileInfo>
#include "localdocs.h"
#include "../gpt4all-backend/llmodel.h"
enum LLModelType {
@@ -39,6 +40,7 @@ public:
void regenerateResponse();
void resetResponse();
void resetContext();
QList<ResultInfo> results() const { return m_results; }
void stopGenerating() { m_stopGenerating = true; }
@@ -85,6 +87,8 @@ Q_SIGNALS:
void stateChanged();
void threadStarted();
void shouldBeLoadedChanged();
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
protected:
bool handlePrompt(int32_t token);
@@ -111,6 +115,7 @@ protected:
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;
QList<ResultInfo> m_results;
bool m_isRecalc;
bool m_isServer;
bool m_isChatGPT;

@@ -892,7 +892,8 @@ bool Database::removeFolderFromWatch(const QString &path)
return m_watcher->removePath(path);
}
void Database::retrieveFromDB(const QString &uid, const QList<QString> &collections, const QString &text, int retrievalSize)
void Database::retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize,
QList<ResultInfo> *results)
{
#if defined(DEBUG)
qDebug() << "retrieveFromDB" << collections << text << retrievalSize;
@@ -904,7 +905,6 @@ void Database::retrieveFromDB(const QString &uid, const QList<QString> &collecti
return;
}
QList<ResultInfo> results;
while (q.next()) {
const int rowid = q.value(0).toInt();
const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd");
@@ -924,13 +924,12 @@ void Database::retrieveFromDB(const QString &uid, const QList<QString> &collecti
info.page = page;
info.from = from;
info.to = to;
results.append(info);
results->append(info);
#if defined(DEBUG)
qDebug() << "retrieve rowid:" << rowid
<< "chunk_text:" << chunk_text;
#endif
}
emit retrieveResult(uid, results);
}
void Database::cleanDB()
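
The new signature inverts the data flow: instead of tagging results with a uid and broadcasting retrieveResult for the matching chat to pick up, retrieveFromDB fills a caller-owned list, which is only safe because every caller now reaches it through the blocking connection. A hedged sketch of the call site's contract (collection name and query are invented):

```cpp
// Assumes requestRetrieveFromDB is connected to Database::retrieveFromDB with
// Qt::BlockingQueuedConnection, as ChatLLM's constructor does above.
QList<ResultInfo> results;
emit requestRetrieveFromDB({"my-docs"}, "termination clause",
                           LocalDocs::globalInstance()->retrievalSize(), &results); // blocks
for (const ResultInfo &info : results)
    qDebug() << info.file << info.page << info.text; // slot has already run
```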

@@ -43,13 +43,12 @@ public Q_SLOTS:
void scanDocuments(int folder_id, const QString &folder_path);
void addFolder(const QString &collection, const QString &path);
void removeFolder(const QString &collection, const QString &path);
void retrieveFromDB(const QString &uid, const QList<QString> &collections, const QString &text, int retrievalSize);
void retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void cleanDB();
void changeChunkSize(int chunkSize);
Q_SIGNALS:
void docsToScanChanged();
void retrieveResult(const QString &uid, const QList<ResultInfo> &result);
void collectionListUpdated(const QList<CollectionItem> &collectionList);
private Q_SLOTS:

@@ -24,12 +24,8 @@ LocalDocs::LocalDocs()
&Database::addFolder, Qt::QueuedConnection);
connect(this, &LocalDocs::requestRemoveFolder, m_database,
&Database::removeFolder, Qt::QueuedConnection);
connect(this, &LocalDocs::requestRetrieveFromDB, m_database,
&Database::retrieveFromDB, Qt::QueuedConnection);
connect(this, &LocalDocs::requestChunkSizeChange, m_database,
&Database::changeChunkSize, Qt::QueuedConnection);
connect(m_database, &Database::retrieveResult, this,
&LocalDocs::receivedResult, Qt::QueuedConnection);
connect(m_database, &Database::collectionListUpdated,
m_localDocsModel, &LocalDocsModel::handleCollectionListUpdated, Qt::QueuedConnection);
}
@@ -49,11 +45,6 @@ void LocalDocs::removeFolder(const QString &collection, const QString &path)
emit requestRemoveFolder(collection, path);
}
void LocalDocs::requestRetrieve(const QString &uid, const QList<QString> &collections, const QString &text)
{
emit requestRetrieveFromDB(uid, collections, text, m_retrievalSize);
}
int LocalDocs::chunkSize() const
{
return m_chunkSize;

@@ -20,7 +20,8 @@ public:
Q_INVOKABLE void addFolder(const QString &collection, const QString &path);
Q_INVOKABLE void removeFolder(const QString &collection, const QString &path);
void requestRetrieve(const QString &uid, const QList<QString> &collections, const QString &text);
Database *database() const { return m_database; }
int chunkSize() const;
void setChunkSize(int chunkSize);
@@ -31,9 +32,7 @@ public:
Q_SIGNALS:
void requestAddFolder(const QString &collection, const QString &path);
void requestRemoveFolder(const QString &collection, const QString &path);
void requestRetrieveFromDB(const QString &uid, const QList<QString> &collections, const QString &text, int retrievalSize);
void requestChunkSizeChange(int chunkSize);
void receivedResult(const QString &uid, const QList<ResultInfo> &result);
void localDocsModelChanged();
void chunkSizeChanged();
void retrievalSizeChanged();

@@ -51,6 +51,20 @@ static inline QJsonObject modelToJson(const ModelInfo &info)
return model;
}
static inline QJsonObject resultToJson(const ResultInfo &info)
{
QJsonObject result;
result.insert("file", info.file);
result.insert("title", info.title);
result.insert("author", info.author);
result.insert("date", info.date);
result.insert("text", info.text);
result.insert("page", info.page);
result.insert("from", info.from);
result.insert("to", info.to);
return result;
}
Server::Server(Chat *chat)
: ChatLLM(chat, true /*isServer*/)
, m_chat(chat)
@@ -298,7 +312,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
int promptTokens = 0;
int responseTokens = 0;
QList<QString> responses;
QList<QPair<QString, QList<ResultInfo>>> responses;
for (int i = 0; i < n; ++i) {
if (!prompt(actualPrompt,
promptTemplate,
@@ -317,7 +331,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
QString echoedPrompt = actualPrompt;
if (!echoedPrompt.endsWith("\n"))
echoedPrompt += "\n";
responses.append((echo ? QString("%1\n").arg(actualPrompt) : QString()) + response());
responses.append(qMakePair((echo ? QString("%1\n").arg(actualPrompt) : QString()) + response(), m_results));
if (!promptTokens)
promptTokens += m_promptTokens;
responseTokens += m_promptResponseTokens - m_promptTokens;
@@ -335,24 +349,36 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
if (isChat) {
int index = 0;
for (QString r : responses) {
for (const auto &r : responses) {
QString result = r.first;
QList<ResultInfo> infos = r.second;
QJsonObject choice;
choice.insert("index", index++);
choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
QJsonObject message;
message.insert("role", "assistant");
message.insert("content", r);
message.insert("content", result);
choice.insert("message", message);
QJsonArray references;
for (const auto &ref : infos)
references.append(resultToJson(ref));
choice.insert("references", references);
choices.append(choice);
}
} else {
int index = 0;
for (QString r : responses) {
for (const auto &r : responses) {
QString result = r.first;
QList<ResultInfo> infos = r.second;
QJsonObject choice;
choice.insert("text", r);
choice.insert("text", result);
choice.insert("index", index++);
choice.insert("logprobs", QJsonValue::Null); // We don't support
choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
QJsonArray references;
for (const auto &ref : infos)
references.append(resultToJson(ref));
choice.insert("references", references);
choices.append(choice);
}
}
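
With resultToJson wired into both the chat and completion branches, every choice in the server's response now carries a references array next to the usual OpenAI-style fields. An illustrative chat-mode response body (all values invented; the date follows the "yyyy, MMMM dd" format database.cpp produces):

```json
{
  "choices": [
    {
      "index": 0,
      "finish_reason": "stop",
      "message": {
        "role": "assistant",
        "content": "Either party may terminate with 30 days notice."
      },
      "references": [
        {
          "file": "contract.pdf",
          "title": "Service Agreement",
          "author": "Example Corp",
          "date": "2023, May 01",
          "text": "Either party may terminate this agreement with 30 days written notice.",
          "page": 4,
          "from": 0,
          "to": 0
        }
      ]
    }
  ]
}
```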
