chore: improve code by getting prompt from config (#39)

This commit is contained in:
sigoden 2023-03-08 17:13:11 +08:00 committed by GitHub
parent ee81275431
commit 2539e24fe9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 31 deletions

View File

@ -28,9 +28,9 @@ impl ChatGptClient {
Ok(s)
}
pub fn send_message(&self, input: &str, prompt: Option<String>) -> Result<String> {
pub fn send_message(&self, input: &str) -> Result<String> {
self.runtime.block_on(async {
self.send_message_inner(input, prompt)
self.send_message_inner(input)
.await
.with_context(|| "Failed to send message")
})
@ -39,7 +39,6 @@ impl ChatGptClient {
pub fn send_message_streaming(
&self,
input: &str,
prompt: Option<String>,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
async fn watch_abort(abort: SharedAbortSignal) {
@ -53,7 +52,7 @@ impl ChatGptClient {
let abort = handler.get_abort();
self.runtime.block_on(async {
tokio::select! {
ret = self.send_message_streaming_inner(input, prompt, handler) => {
ret = self.send_message_streaming_inner(input, handler) => {
handler.done()?;
ret.with_context(|| "Failed to send message streaming")
}
@ -69,11 +68,11 @@ impl ChatGptClient {
})
}
async fn send_message_inner(&self, content: &str, prompt: Option<String>) -> Result<String> {
async fn send_message_inner(&self, content: &str) -> Result<String> {
if self.config.borrow().dry_run {
return Ok(combine(content, prompt));
return Ok(self.config.borrow().merge_prompt(content));
}
let builder = self.request_builder(content, prompt, false)?;
let builder = self.request_builder(content, false)?;
let data: Value = builder.send().await?.json().await?;
@ -87,14 +86,13 @@ impl ChatGptClient {
async fn send_message_streaming_inner(
&self,
content: &str,
prompt: Option<String>,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
if self.config.borrow().dry_run {
handler.text(&combine(content, prompt))?;
handler.text(&self.config.borrow().merge_prompt(content))?;
return Ok(());
}
let builder = self.request_builder(content, prompt, true)?;
let builder = self.request_builder(content, true)?;
let mut stream = builder.send().await?.bytes_stream().eventsource();
let mut virgin = true;
while let Some(part) = stream.next().await {
@ -134,14 +132,9 @@ impl ChatGptClient {
Ok(client)
}
fn request_builder(
&self,
content: &str,
prompt: Option<String>,
stream: bool,
) -> Result<RequestBuilder> {
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let user_message = json!({ "role": "user", "content": content });
let messages = match prompt {
let messages = match self.config.borrow().get_prompt() {
Some(prompt) => {
let system_message = json!({ "role": "system", "content": prompt.trim() });
json!([system_message, user_message])
@ -175,13 +168,6 @@ impl ChatGptClient {
}
}
fn combine(content: &str, prompt: Option<String>) -> String {
match prompt {
Some(prompt) => format!("{}\n{content}", prompt.trim()),
None => content.to_string(),
}
}
fn init_runtime() -> Result<Runtime> {
tokio::runtime::Builder::new_current_thread()
.enable_all()

View File

@ -178,6 +178,13 @@ impl Config {
})
}
pub fn merge_prompt(&self, content: &str) -> String {
match self.get_prompt() {
Some(prompt) => format!("{}\n{content}", prompt.trim()),
None => content.to_string(),
}
}
pub fn info(&self) -> Result<String> {
let file_info = |path: &Path| {
let state = if path.exists() { "" } else { " ⚠️" };

View File

@ -71,10 +71,9 @@ fn start_directive(
input: &str,
no_stream: bool,
) -> Result<()> {
let prompt = config.borrow().get_prompt();
let highlight = config.borrow().highlight && stdout().is_terminal();
let output = if no_stream {
let output = client.send_message(input, prompt)?;
let output = client.send_message(input)?;
if highlight {
let mut markdown_render = MarkdownRender::new();
println!("{}\n", markdown_render.render(&output).trim_end());
@ -90,7 +89,7 @@ fn start_directive(
abort_clone.set_ctrlc();
})
.expect("Error setting Ctrl-C handler");
let output = render_stream(input, None, &client, highlight, false, abort, wg.clone())?;
let output = render_stream(input, &client, highlight, false, abort, wg.clone())?;
wg.wait();
output
};

View File

@ -16,7 +16,6 @@ use std::thread::spawn;
pub fn render_stream(
input: &str,
prompt: Option<String>,
client: &ChatGptClient,
highlight: bool,
repl: bool,
@ -43,7 +42,7 @@ pub fn render_stream(
drop(wg);
ReplyStreamHandler::new(None, abort)
};
client.send_message_streaming(input, prompt, &mut stream_handler)?;
client.send_message_streaming(input, &mut stream_handler)?;
let buffer = stream_handler.get_buffer();
Ok(buffer.to_string())
}

View File

@ -49,11 +49,9 @@ impl ReplCmdHandler {
return Ok(());
}
let highlight = self.config.borrow().highlight;
let prompt = self.config.borrow().get_prompt();
let wg = WaitGroup::new();
let ret = render_stream(
&input,
prompt,
&self.client,
highlight,
true,