mirror of
https://github.com/sigoden/aichat
synced 2024-11-18 09:28:27 +00:00
feat: support two types of role prompts (#52)
1. embeded prompt use __INPUT__ placeholder will generate one user message when send to gpt ``` - name: shell prompt: > I want you to act as a linux shell expert. Q: How to unzip a file A: unzip file.zip Q: __INPUT__ A: ``` 2. system prompt no __INPUT__ placeholder will generate on system message and one user message when send to gpt ``` - name: shell prompt: | I want you to act as a linux shell expert. I want you to answer only with bash code. Do not write explanations. ```
This commit is contained in:
parent
c45d71cdea
commit
9767c07eee
@ -1,93 +1,81 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use super::message::{Message, MessageRole, MESSAGE_EXTRA_TOKENS};
|
||||
use super::role::Role;
|
||||
use super::MAX_TOKENS;
|
||||
|
||||
use crate::utils::count_tokens;
|
||||
|
||||
use super::{MAX_TOKENS, MESSAGE_EXTRA_TOKENS};
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Conversation {
|
||||
pub tokens: usize,
|
||||
pub role: Option<Role>,
|
||||
pub messages: Vec<Message>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Message {
|
||||
pub role: MessageRole,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl Conversation {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(role: Option<Role>) -> Self {
|
||||
let tokens = if let Some(role) = role.as_ref() {
|
||||
role.consume_tokens()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
Self {
|
||||
tokens: 0,
|
||||
tokens,
|
||||
role,
|
||||
messages: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_chat(&mut self, input: &str, output: &str) -> Result<()> {
|
||||
self.messages.push(Message {
|
||||
role: MessageRole::User,
|
||||
content: input.to_string(),
|
||||
});
|
||||
pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> {
|
||||
let mut need_add_msg = true;
|
||||
let mut input_tokens = count_tokens(input);
|
||||
if self.messages.is_empty() {
|
||||
if let Some(role) = self.role.as_ref() {
|
||||
self.messages.extend(role.build_emssages(input));
|
||||
need_add_msg = false;
|
||||
}
|
||||
}
|
||||
if need_add_msg {
|
||||
self.messages.push(Message {
|
||||
role: MessageRole::User,
|
||||
content: input.to_string(),
|
||||
});
|
||||
input_tokens += MESSAGE_EXTRA_TOKENS;
|
||||
}
|
||||
self.messages.push(Message {
|
||||
role: MessageRole::Assistant,
|
||||
content: output.to_string(),
|
||||
});
|
||||
self.tokens += count_tokens(input) + count_tokens(output) + 2 * MESSAGE_EXTRA_TOKENS;
|
||||
self.tokens += input_tokens + count_tokens(output) + MESSAGE_EXTRA_TOKENS;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Readline prompt
|
||||
pub fn add_prompt(&mut self, prompt: &str) {
|
||||
self.messages.push(Message {
|
||||
role: MessageRole::System,
|
||||
content: prompt.into(),
|
||||
});
|
||||
self.tokens += count_tokens(prompt) + MESSAGE_EXTRA_TOKENS;
|
||||
}
|
||||
|
||||
pub fn echo_messages(&self, content: &str) -> String {
|
||||
let mut messages = self.messages.to_vec();
|
||||
messages.push(Message {
|
||||
role: MessageRole::User,
|
||||
content: content.into(),
|
||||
});
|
||||
let messages = self.build_emssages(content);
|
||||
serde_yaml::to_string(&messages).unwrap_or("Unable to echo message".into())
|
||||
}
|
||||
|
||||
pub fn build_emssages(&self, content: &str) -> Value {
|
||||
let mut messages: Vec<Value> = self.messages.iter().map(msg_to_value).collect();
|
||||
messages.push(msg_to_value(&Message {
|
||||
role: MessageRole::User,
|
||||
content: content.into(),
|
||||
}));
|
||||
json!(messages)
|
||||
pub fn build_emssages(&self, content: &str) -> Vec<Message> {
|
||||
let mut messages = self.messages.to_vec();
|
||||
let mut need_add_msg = true;
|
||||
if messages.is_empty() {
|
||||
if let Some(role) = self.role.as_ref() {
|
||||
messages = role.build_emssages(content);
|
||||
need_add_msg = false;
|
||||
}
|
||||
};
|
||||
if need_add_msg {
|
||||
messages.push(Message {
|
||||
role: MessageRole::User,
|
||||
content: content.into(),
|
||||
});
|
||||
}
|
||||
messages
|
||||
}
|
||||
|
||||
pub fn reamind_tokens(&self) -> usize {
|
||||
MAX_TOKENS.saturating_sub(self.tokens)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub enum MessageRole {
|
||||
System,
|
||||
Assistant,
|
||||
User,
|
||||
}
|
||||
|
||||
impl MessageRole {
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
MessageRole::System => "system",
|
||||
MessageRole::Assistant => "assistant",
|
||||
MessageRole::User => "user",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn msg_to_value(msg: &Message) -> Value {
|
||||
json!({ "role": msg.role.name(), "content": msg.content })
|
||||
}
|
||||
|
34
src/config/message.rs
Normal file
34
src/config/message.rs
Normal file
@ -0,0 +1,34 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const MESSAGE_EXTRA_TOKENS: usize = 6;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Message {
|
||||
pub role: MessageRole,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(content: &str) -> Self {
|
||||
Self {
|
||||
role: MessageRole::User,
|
||||
content: content.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessageRole {
|
||||
System,
|
||||
Assistant,
|
||||
User,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serde() {
|
||||
assert_eq!(
|
||||
serde_json::to_string(&Message::new("Hello World")).unwrap(),
|
||||
"{\"role\":\"user\",\"content\":\"Hello World\"}"
|
||||
)
|
||||
}
|
@ -1,14 +1,17 @@
|
||||
mod conversation;
|
||||
mod message;
|
||||
mod role;
|
||||
|
||||
use self::conversation::Conversation;
|
||||
use self::message::{Message, MESSAGE_EXTRA_TOKENS};
|
||||
use self::role::Role;
|
||||
|
||||
use crate::utils::{count_tokens, now};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use inquire::{Confirm, Text};
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
env,
|
||||
fs::{create_dir_all, read_to_string, File, OpenOptions},
|
||||
@ -19,12 +22,10 @@ use std::{
|
||||
};
|
||||
|
||||
const MAX_TOKENS: usize = 4096;
|
||||
const MESSAGE_EXTRA_TOKENS: usize = 6;
|
||||
const CONFIG_FILE_NAME: &str = "config.yaml";
|
||||
const ROLES_FILE_NAME: &str = "roles.yaml";
|
||||
const HISTORY_FILE_NAME: &str = "history.txt";
|
||||
const MESSAGE_FILE_NAME: &str = "messages.md";
|
||||
const TEMP_ROLE_NAME: &str = "P";
|
||||
const SET_COMPLETIONS: [&str; 9] = [
|
||||
".set api_key",
|
||||
".set temperature",
|
||||
@ -126,7 +127,7 @@ impl Config {
|
||||
format!("# CHAT:[{timestamp}]\n{input}\n--------\n{output}\n--------\n\n",)
|
||||
}
|
||||
Some(v) => {
|
||||
if v.name == TEMP_ROLE_NAME {
|
||||
if v.is_temp() {
|
||||
format!(
|
||||
"# CHAT:[{timestamp}]\n{}\n{input}\n--------\n{output}\n--------\n\n",
|
||||
v.prompt
|
||||
@ -166,7 +167,7 @@ impl Config {
|
||||
}
|
||||
match self.find_role(name) {
|
||||
Some(mut role) => {
|
||||
role.tokens = count_tokens(&role.prompt);
|
||||
role.tokens = role.consume_tokens();
|
||||
let output =
|
||||
serde_yaml::to_string(&role).unwrap_or("Unable to echo role details".into());
|
||||
self.role = Some(role);
|
||||
@ -178,12 +179,7 @@ impl Config {
|
||||
|
||||
pub fn create_temp_role(&mut self, prompt: &str) -> Result<()> {
|
||||
self.ensure_no_conversation()?;
|
||||
self.role = Some(Role {
|
||||
name: TEMP_ROLE_NAME.into(),
|
||||
prompt: prompt.into(),
|
||||
temperature: self.temperature,
|
||||
tokens: count_tokens(prompt),
|
||||
});
|
||||
self.role = Some(Role::new(prompt, self.temperature));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -198,33 +194,32 @@ impl Config {
|
||||
if let Some(conversation) = self.conversation.as_ref() {
|
||||
conversation.echo_messages(content)
|
||||
} else if let Some(role) = self.role.as_ref() {
|
||||
format!("{}\n{content}", role.prompt)
|
||||
role.echo_messages(content)
|
||||
} else {
|
||||
content.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_messages(&self, content: &str) -> Result<Value> {
|
||||
let tokens = count_tokens(content) + MESSAGE_EXTRA_TOKENS;
|
||||
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
|
||||
let content_tokens = count_tokens(content);
|
||||
let check_tokens = |tokens| {
|
||||
if tokens >= MAX_TOKENS {
|
||||
bail!("Exceed max tokens limit")
|
||||
}
|
||||
Ok(())
|
||||
};
|
||||
check_tokens(tokens)?;
|
||||
let user_message = json!({ "role": "user", "content": content });
|
||||
let value = if let Some(conversation) = self.conversation.as_ref() {
|
||||
check_tokens(tokens + conversation.tokens)?;
|
||||
let messages = if let Some(conversation) = self.conversation.as_ref() {
|
||||
check_tokens(content_tokens + conversation.tokens)?;
|
||||
conversation.build_emssages(content)
|
||||
} else if let Some(role) = self.role.as_ref() {
|
||||
check_tokens(tokens + role.tokens + MESSAGE_EXTRA_TOKENS)?;
|
||||
let system_message = json!({ "role": "system", "content": role.prompt });
|
||||
json!([system_message, user_message])
|
||||
check_tokens(content_tokens + role.tokens)?;
|
||||
role.build_emssages(content)
|
||||
} else {
|
||||
json!([user_message])
|
||||
let message = Message::new(content);
|
||||
check_tokens(content_tokens + MESSAGE_EXTRA_TOKENS)?;
|
||||
vec![message]
|
||||
};
|
||||
Ok(value)
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
pub fn info(&self) -> Result<String> {
|
||||
@ -327,11 +322,7 @@ impl Config {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
let mut conversation = Conversation::new();
|
||||
if let Some(role) = self.role.as_ref() {
|
||||
conversation.add_prompt(&role.prompt);
|
||||
}
|
||||
self.conversation = Some(conversation);
|
||||
self.conversation = Some(Conversation::new(self.role.clone()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -341,7 +332,7 @@ impl Config {
|
||||
|
||||
pub fn save_conversation(&mut self, input: &str, output: &str) -> Result<()> {
|
||||
if let Some(conversation) = self.conversation.as_mut() {
|
||||
conversation.add_chat(input, output)?;
|
||||
conversation.add_message(input, output)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@ -376,19 +367,6 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Role {
|
||||
/// Role name
|
||||
pub name: String,
|
||||
/// Prompt text send to ai for setting up a role
|
||||
pub prompt: String,
|
||||
/// What sampling temperature to use, between 0 and 2
|
||||
pub temperature: Option<f64>,
|
||||
/// Number of tokens
|
||||
#[serde(skip_deserializing)]
|
||||
pub tokens: usize,
|
||||
}
|
||||
|
||||
fn create_config_file(config_path: &Path) -> Result<()> {
|
||||
let confirm_map_err = |_| anyhow!("Not finish questionnaire, try again later.");
|
||||
let text_map_err = |_| anyhow!("An error happened when asking for your key, try again later.");
|
||||
|
89
src/config/role.rs
Normal file
89
src/config/role.rs
Normal file
@ -0,0 +1,89 @@
|
||||
use super::message::{Message, MessageRole, MESSAGE_EXTRA_TOKENS};
|
||||
|
||||
use crate::utils::count_tokens;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const TEMP_NAME: &str = "P";
|
||||
const INPUT_PLACEHOLDER: &str = "__INPUT__";
|
||||
const INPUT_PLACEHOLDER_TOKENS: usize = 3;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Role {
|
||||
/// Role name
|
||||
pub name: String,
|
||||
/// Prompt text send to ai for setting up a role.
|
||||
///
|
||||
/// If prmopt contains __INPUT___, it's embeded prompt
|
||||
/// If prmopt don't contain __INPUT___, it's system prompt
|
||||
pub prompt: String,
|
||||
/// What sampling temperature to use, between 0 and 2
|
||||
pub temperature: Option<f64>,
|
||||
/// Number of tokens
|
||||
///
|
||||
/// System prompt consume extra 6 tokens
|
||||
#[serde(skip_deserializing)]
|
||||
pub tokens: usize,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn new(prompt: &str, temperature: Option<f64>) -> Self {
|
||||
let mut value = Self {
|
||||
name: TEMP_NAME.into(),
|
||||
prompt: prompt.into(),
|
||||
temperature,
|
||||
tokens: 0,
|
||||
};
|
||||
value.tokens = value.consume_tokens();
|
||||
value
|
||||
}
|
||||
|
||||
pub fn is_temp(&self) -> bool {
|
||||
self.name == TEMP_NAME
|
||||
}
|
||||
|
||||
pub fn consume_tokens(&self) -> usize {
|
||||
if self.embeded() {
|
||||
count_tokens(&self.prompt) + MESSAGE_EXTRA_TOKENS - INPUT_PLACEHOLDER_TOKENS
|
||||
} else {
|
||||
count_tokens(&self.prompt) + 2 * MESSAGE_EXTRA_TOKENS
|
||||
}
|
||||
}
|
||||
|
||||
pub fn embeded(&self) -> bool {
|
||||
self.prompt.contains(INPUT_PLACEHOLDER)
|
||||
}
|
||||
|
||||
pub fn echo_messages(&self, content: &str) -> String {
|
||||
if self.embeded() {
|
||||
merge_prompt_content(&self.prompt, content)
|
||||
} else {
|
||||
format!("{}{content}", self.prompt)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_emssages(&self, content: &str) -> Vec<Message> {
|
||||
if self.embeded() {
|
||||
let content = merge_prompt_content(&self.prompt, content);
|
||||
vec![Message {
|
||||
role: MessageRole::User,
|
||||
content,
|
||||
}]
|
||||
} else {
|
||||
vec![
|
||||
Message {
|
||||
role: MessageRole::System,
|
||||
content: self.prompt.clone(),
|
||||
},
|
||||
Message {
|
||||
role: MessageRole::User,
|
||||
content: content.to_string(),
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn merge_prompt_content(prompt: &str, content: &str) -> String {
|
||||
prompt.replace(INPUT_PLACEHOLDER, content)
|
||||
}
|
Loading…
Reference in New Issue
Block a user