Don't repeat the prompt in the response.

Adam Treat 2023-04-09 01:11:52 -04:00
parent 0903da3afa
commit 02e13737f3
2 changed files with 9 additions and 5 deletions


@@ -700,6 +700,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
     n_predict = std::min(n_predict, d_ptr->model.hparams.n_ctx - (int) embd_inp.size());
 
     std::vector<gpt_vocab::id> embd;
+    std::vector<gpt_vocab::id> resp;
 
     // determine the required inference memory per token:
     size_t mem_per_token = 0;
@@ -720,6 +721,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
 
         n_past += embd.size();
         embd.clear();
+        resp.clear();
 
         if (i >= embd_inp.size()) {
             // sample next token
@@ -738,6 +740,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
 
                 // add it to the context
                 embd.push_back(id);
+                resp.push_back(id);
             } else {
                 // if here, it means we are still processing the input prompt
                 for (int k = i; k < embd_inp.size(); k++) {
@@ -750,7 +753,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
             }
 
             // display text
-            for (auto id : embd) {
+            for (auto id : resp) {
                 if (!response(d_ptr->vocab.id_to_token[id]))
                     goto stop_generating;
             }
@@ -762,7 +765,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
     }
 stop_generating:
 
-#if 1
+#if 0
     // report timing
     {
         const int64_t t_main_end_us = ggml_time_us();
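The C++ hunks above all serve one idea: keep a second buffer of only the tokens the model sampled itself, and hand the display callback that buffer instead of the full evaluation batch, so the prompt tokens being fed back through the context are never echoed. Below is a minimal, self-contained sketch of that pattern. The token alias, the stand-in sampler, the toy vocabulary, and the on_response callback are all hypothetical placeholders for illustration, not part of the GPT-J implementation in this commit.

#include <cstdio>
#include <string>
#include <vector>

using token = int;

int main() {
    // Pretend-tokenized prompt and a toy vocabulary (both hypothetical).
    std::vector<token> prompt_tokens = {1, 2, 3};
    std::vector<std::string> id_to_token = {"", "Hello", " world", "!", " Hi", " there", "!\n"};

    // Display callback: returning false would stop generation, mirroring the
    // `if (!response(...)) goto stop_generating;` check in the diff.
    auto on_response = [&](token id) {
        std::printf("%s", id_to_token[id].c_str());
        return true;
    };

    std::vector<token> embd; // tokens fed to the model this step (prompt or sampled)
    std::vector<token> resp; // only freshly sampled tokens, i.e. the actual response

    size_t i = 0;            // position in the prompt
    token last = 3;          // stand-in sampler state
    const int n_predict = 3;

    for (size_t step = 0; step < prompt_tokens.size() + n_predict; ++step) {
        // (a real implementation would evaluate the model on `embd` here)
        embd.clear();
        resp.clear();

        if (i >= prompt_tokens.size()) {
            // Generating: a sampled token goes into both buffers.
            token id = ++last; // stand-in for real sampling
            embd.push_back(id);
            resp.push_back(id);
        } else {
            // Still consuming the prompt: feed it to the model, but not to `resp`.
            embd.push_back(prompt_tokens[i]);
            ++i;
        }

        // Display text: iterate over `resp`, not `embd`, so the prompt is never echoed.
        for (token id : resp)
            if (!on_response(id))
                return 0;
    }
    return 0; // prints " Hi there!\n" and never repeats "Hello world!"
}

The design point is that `embd` still carries everything the model needs for evaluation, while `resp` is a view of just the new output, which is all the UI should ever see.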


@@ -80,7 +80,7 @@ Window {
             model: chatModel
             delegate: TextArea {
                 text: currentResponse ? LLM.response : value
-                width: parent.width
+                width: listView.width
                 color: "#d1d5db"
                 wrapMode: Text.WordWrap
                 focus: false
@@ -204,11 +204,12 @@ Window {
                     listElement.currentResponse = false
                     listElement.value = LLM.response
                 }
+                var prompt = textInput.text + "\n"
                 chatModel.append({"name": qsTr("Prompt: "), "currentResponse": false, "value": textInput.text})
-                chatModel.append({"name": qsTr("Response: "), "currentResponse": true, "value": "", "prompt": textInput.text})
+                chatModel.append({"name": qsTr("Response: "), "currentResponse": true, "value": "", "prompt": prompt})
                 LLM.resetResponse()
-                LLM.prompt(textInput.text)
+                LLM.prompt(prompt)
                 textInput.text = ""
             }