mirror of https://github.com/nomic-ai/gpt4all
Preliminary support for chatgpt models.
parent
da3828af89
commit
dd27c10f54
@ -0,0 +1,206 @@
|
||||
#include "chatgpt.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#include <QThread>
|
||||
#include <QEventLoop>
|
||||
#include <QJsonDocument>
|
||||
#include <QJsonObject>
|
||||
#include <QJsonArray>
|
||||
|
||||
//#define DEBUG
|
||||
|
||||
ChatGPT::ChatGPT()
|
||||
: QObject(nullptr)
|
||||
, m_modelName("gpt-3.5-turbo")
|
||||
, m_ctx(nullptr)
|
||||
, m_responseCallback(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
bool ChatGPT::loadModel(const std::string &modelPath)
|
||||
{
|
||||
Q_UNUSED(modelPath);
|
||||
return true;
|
||||
}
|
||||
|
||||
void ChatGPT::setThreadCount(int32_t n_threads)
|
||||
{
|
||||
Q_UNUSED(n_threads);
|
||||
qt_noop();
|
||||
}
|
||||
|
||||
int32_t ChatGPT::threadCount()
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
ChatGPT::~ChatGPT()
|
||||
{
|
||||
}
|
||||
|
||||
bool ChatGPT::isModelLoaded() const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t ChatGPT::stateSize() const
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t ChatGPT::saveState(uint8_t *dest) const
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t ChatGPT::restoreState(const uint8_t *src)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
void ChatGPT::prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
PromptContext &promptCtx) {
|
||||
|
||||
Q_UNUSED(promptCallback);
|
||||
Q_UNUSED(recalculateCallback);
|
||||
|
||||
if (!isModelLoaded()) {
|
||||
std::cerr << "ChatGPT ERROR: prompt won't work with an unloaded model!\n";
|
||||
return;
|
||||
}
|
||||
|
||||
m_ctx = &promptCtx;
|
||||
m_responseCallback = responseCallback;
|
||||
|
||||
QJsonObject root;
|
||||
root.insert("model", m_modelName);
|
||||
root.insert("stream", true);
|
||||
root.insert("temperature", promptCtx.temp);
|
||||
root.insert("top_p", promptCtx.top_p);
|
||||
root.insert("max_tokens", 200);
|
||||
|
||||
QJsonArray messages;
|
||||
for (int i = 0; i < m_context.count() && i < promptCtx.n_past; ++i) {
|
||||
QJsonObject message;
|
||||
message.insert("role", i % 2 == 0 ? "assistant" : "user");
|
||||
message.insert("content", m_context.at(i));
|
||||
messages.append(message);
|
||||
}
|
||||
|
||||
QJsonObject promptObject;
|
||||
promptObject.insert("role", "user");
|
||||
promptObject.insert("content", QString::fromStdString(prompt));
|
||||
messages.append(promptObject);
|
||||
root.insert("messages", messages);
|
||||
|
||||
QJsonDocument doc(root);
|
||||
|
||||
#if defined(DEBUG)
|
||||
qDebug() << "ChatGPT::prompt begin network request" << qPrintable(doc.toJson());
|
||||
#endif
|
||||
|
||||
QEventLoop loop;
|
||||
QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
|
||||
const QString authorization = QString("Bearer %1").arg(m_apiKey);
|
||||
QNetworkRequest request(openaiUrl);
|
||||
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
|
||||
request.setRawHeader("Authorization", authorization.toUtf8());
|
||||
QNetworkReply *reply = m_networkManager.post(request, doc.toJson(QJsonDocument::Compact));
|
||||
connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit);
|
||||
connect(reply, &QNetworkReply::finished, this, &ChatGPT::handleFinished);
|
||||
connect(reply, &QNetworkReply::readyRead, this, &ChatGPT::handleReadyRead);
|
||||
connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPT::handleErrorOccurred);
|
||||
loop.exec();
|
||||
#if defined(DEBUG)
|
||||
qDebug() << "ChatGPT::prompt end network request";
|
||||
#endif
|
||||
|
||||
m_ctx->n_past += 1;
|
||||
m_context.append(QString::fromStdString(prompt));
|
||||
m_context.append(m_currentResponse);
|
||||
|
||||
m_ctx = nullptr;
|
||||
m_responseCallback = nullptr;
|
||||
m_currentResponse = QString();
|
||||
}
|
||||
|
||||
void ChatGPT::handleFinished()
|
||||
{
|
||||
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
|
||||
if (!reply)
|
||||
return;
|
||||
|
||||
QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
|
||||
Q_ASSERT(response.isValid());
|
||||
bool ok;
|
||||
int code = response.toInt(&ok);
|
||||
if (!ok || code != 200) {
|
||||
qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n")
|
||||
.arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString();
|
||||
}
|
||||
reply->deleteLater();
|
||||
}
|
||||
|
||||
void ChatGPT::handleReadyRead()
|
||||
{
|
||||
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
|
||||
if (!reply)
|
||||
return;
|
||||
|
||||
QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
|
||||
Q_ASSERT(response.isValid());
|
||||
bool ok;
|
||||
int code = response.toInt(&ok);
|
||||
if (!ok || code != 200) {
|
||||
m_responseCallback(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n")
|
||||
.arg(code).arg(reply->errorString()).arg(qPrintable(reply->readAll())).toStdString());
|
||||
return;
|
||||
}
|
||||
|
||||
while (reply->canReadLine()) {
|
||||
QString jsonData = reply->readLine().trimmed();
|
||||
if (jsonData.startsWith("data:"))
|
||||
jsonData.remove(0, 5);
|
||||
jsonData = jsonData.trimmed();
|
||||
if (jsonData.isEmpty())
|
||||
continue;
|
||||
if (jsonData == "[DONE]")
|
||||
continue;
|
||||
#if defined(DEBUG)
|
||||
qDebug() << "line" << qPrintable(jsonData);
|
||||
#endif
|
||||
QJsonParseError err;
|
||||
const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err);
|
||||
if (err.error != QJsonParseError::NoError) {
|
||||
m_responseCallback(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n")
|
||||
.arg(err.errorString()).toStdString());
|
||||
continue;
|
||||
}
|
||||
|
||||
const QJsonObject root = document.object();
|
||||
const QJsonArray choices = root.value("choices").toArray();
|
||||
const QJsonObject choice = choices.first().toObject();
|
||||
const QJsonObject delta = choice.value("delta").toObject();
|
||||
const QString content = delta.value("content").toString();
|
||||
Q_ASSERT(m_ctx);
|
||||
Q_ASSERT(m_responseCallback);
|
||||
m_responseCallback(0, content.toStdString());
|
||||
m_currentResponse += content;
|
||||
}
|
||||
}
|
||||
|
||||
void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code)
|
||||
{
|
||||
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
|
||||
if (!reply)
|
||||
return;
|
||||
|
||||
qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n")
|
||||
.arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString();
|
||||
}
|
@ -0,0 +1,52 @@
|
||||
#ifndef CHATGPT_H
|
||||
#define CHATGPT_H
|
||||
|
||||
#include <QObject>
|
||||
#include <QNetworkReply>
|
||||
#include <QNetworkRequest>
|
||||
#include <QNetworkAccessManager>
|
||||
#include "../gpt4all-backend/llmodel.h"
|
||||
|
||||
class ChatGPTPrivate;
|
||||
class ChatGPT : public QObject, public LLModel {
|
||||
Q_OBJECT
|
||||
public:
|
||||
ChatGPT();
|
||||
virtual ~ChatGPT();
|
||||
|
||||
bool loadModel(const std::string &modelPath) override;
|
||||
bool isModelLoaded() const override;
|
||||
size_t stateSize() const override;
|
||||
size_t saveState(uint8_t *dest) const override;
|
||||
size_t restoreState(const uint8_t *src) override;
|
||||
void prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
PromptContext &ctx) override;
|
||||
void setThreadCount(int32_t n_threads) override;
|
||||
int32_t threadCount() override;
|
||||
|
||||
void setModelName(const QString &modelName) { m_modelName = modelName; }
|
||||
void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; }
|
||||
|
||||
protected:
|
||||
void recalculateContext(PromptContext &promptCtx,
|
||||
std::function<bool(bool)> recalculate) override {}
|
||||
|
||||
private Q_SLOTS:
|
||||
void handleFinished();
|
||||
void handleReadyRead();
|
||||
void handleErrorOccurred(QNetworkReply::NetworkError code);
|
||||
|
||||
private:
|
||||
PromptContext *m_ctx;
|
||||
std::function<bool(int32_t, const std::string&)> m_responseCallback;
|
||||
QString m_modelName;
|
||||
QString m_apiKey;
|
||||
QList<QString> m_context;
|
||||
QString m_currentResponse;
|
||||
QNetworkAccessManager m_networkManager;
|
||||
};
|
||||
|
||||
#endif // CHATGPT_H
|
Loading…
Reference in New Issue