feat: support two types of role prompts (#52)

1. embeded prompt

use __INPUT__ placeholder
will generate one user message when send to gpt
```
- name: shell
  prompt: >
    I want you to act as a linux shell expert.
    Q: How to unzip a file
    A: unzip file.zip
    Q: __INPUT__
    A:
```

2. system prompt
no __INPUT__ placeholder
will generate on system message and one user message when send to gpt
```
- name: shell
  prompt: |
    I want you to act as a linux shell expert.
    I want you to answer only with bash code.
    Do not write explanations.
```
This commit is contained in:
sigoden 2023-03-09 19:10:02 +08:00 committed by GitHub
parent c45d71cdea
commit 9767c07eee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 192 additions and 103 deletions

View File

@ -1,93 +1,81 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use super::message::{Message, MessageRole, MESSAGE_EXTRA_TOKENS};
use super::role::Role;
use super::MAX_TOKENS;
use crate::utils::count_tokens;
use super::{MAX_TOKENS, MESSAGE_EXTRA_TOKENS};
use anyhow::Result;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Conversation {
pub tokens: usize,
pub role: Option<Role>,
pub messages: Vec<Message>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message {
pub role: MessageRole,
pub content: String,
}
impl Conversation {
pub fn new() -> Self {
pub fn new(role: Option<Role>) -> Self {
let tokens = if let Some(role) = role.as_ref() {
role.consume_tokens()
} else {
0
};
Self {
tokens: 0,
tokens,
role,
messages: vec![],
}
}
pub fn add_chat(&mut self, input: &str, output: &str) -> Result<()> {
self.messages.push(Message {
role: MessageRole::User,
content: input.to_string(),
});
pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> {
let mut need_add_msg = true;
let mut input_tokens = count_tokens(input);
if self.messages.is_empty() {
if let Some(role) = self.role.as_ref() {
self.messages.extend(role.build_emssages(input));
need_add_msg = false;
}
}
if need_add_msg {
self.messages.push(Message {
role: MessageRole::User,
content: input.to_string(),
});
input_tokens += MESSAGE_EXTRA_TOKENS;
}
self.messages.push(Message {
role: MessageRole::Assistant,
content: output.to_string(),
});
self.tokens += count_tokens(input) + count_tokens(output) + 2 * MESSAGE_EXTRA_TOKENS;
self.tokens += input_tokens + count_tokens(output) + MESSAGE_EXTRA_TOKENS;
Ok(())
}
/// Readline prompt
pub fn add_prompt(&mut self, prompt: &str) {
self.messages.push(Message {
role: MessageRole::System,
content: prompt.into(),
});
self.tokens += count_tokens(prompt) + MESSAGE_EXTRA_TOKENS;
}
pub fn echo_messages(&self, content: &str) -> String {
let mut messages = self.messages.to_vec();
messages.push(Message {
role: MessageRole::User,
content: content.into(),
});
let messages = self.build_emssages(content);
serde_yaml::to_string(&messages).unwrap_or("Unable to echo message".into())
}
pub fn build_emssages(&self, content: &str) -> Value {
let mut messages: Vec<Value> = self.messages.iter().map(msg_to_value).collect();
messages.push(msg_to_value(&Message {
role: MessageRole::User,
content: content.into(),
}));
json!(messages)
pub fn build_emssages(&self, content: &str) -> Vec<Message> {
let mut messages = self.messages.to_vec();
let mut need_add_msg = true;
if messages.is_empty() {
if let Some(role) = self.role.as_ref() {
messages = role.build_emssages(content);
need_add_msg = false;
}
};
if need_add_msg {
messages.push(Message {
role: MessageRole::User,
content: content.into(),
});
}
messages
}
pub fn reamind_tokens(&self) -> usize {
MAX_TOKENS.saturating_sub(self.tokens)
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum MessageRole {
System,
Assistant,
User,
}
impl MessageRole {
pub fn name(&self) -> &'static str {
match self {
MessageRole::System => "system",
MessageRole::Assistant => "assistant",
MessageRole::User => "user",
}
}
}
fn msg_to_value(msg: &Message) -> Value {
json!({ "role": msg.role.name(), "content": msg.content })
}

34
src/config/message.rs Normal file
View File

@ -0,0 +1,34 @@
use serde::{Deserialize, Serialize};
pub const MESSAGE_EXTRA_TOKENS: usize = 6;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message {
pub role: MessageRole,
pub content: String,
}
impl Message {
pub fn new(content: &str) -> Self {
Self {
role: MessageRole::User,
content: content.to_string(),
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
System,
Assistant,
User,
}
#[test]
fn test_serde() {
assert_eq!(
serde_json::to_string(&Message::new("Hello World")).unwrap(),
"{\"role\":\"user\",\"content\":\"Hello World\"}"
)
}

View File

@ -1,14 +1,17 @@
mod conversation;
mod message;
mod role;
use self::conversation::Conversation;
use self::message::{Message, MESSAGE_EXTRA_TOKENS};
use self::role::Role;
use crate::utils::{count_tokens, now};
use anyhow::{anyhow, bail, Context, Result};
use inquire::{Confirm, Text};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use serde::Deserialize;
use std::{
env,
fs::{create_dir_all, read_to_string, File, OpenOptions},
@ -19,12 +22,10 @@ use std::{
};
const MAX_TOKENS: usize = 4096;
const MESSAGE_EXTRA_TOKENS: usize = 6;
const CONFIG_FILE_NAME: &str = "config.yaml";
const ROLES_FILE_NAME: &str = "roles.yaml";
const HISTORY_FILE_NAME: &str = "history.txt";
const MESSAGE_FILE_NAME: &str = "messages.md";
const TEMP_ROLE_NAME: &str = "";
const SET_COMPLETIONS: [&str; 9] = [
".set api_key",
".set temperature",
@ -126,7 +127,7 @@ impl Config {
format!("# CHAT:[{timestamp}]\n{input}\n--------\n{output}\n--------\n\n",)
}
Some(v) => {
if v.name == TEMP_ROLE_NAME {
if v.is_temp() {
format!(
"# CHAT:[{timestamp}]\n{}\n{input}\n--------\n{output}\n--------\n\n",
v.prompt
@ -166,7 +167,7 @@ impl Config {
}
match self.find_role(name) {
Some(mut role) => {
role.tokens = count_tokens(&role.prompt);
role.tokens = role.consume_tokens();
let output =
serde_yaml::to_string(&role).unwrap_or("Unable to echo role details".into());
self.role = Some(role);
@ -178,12 +179,7 @@ impl Config {
pub fn create_temp_role(&mut self, prompt: &str) -> Result<()> {
self.ensure_no_conversation()?;
self.role = Some(Role {
name: TEMP_ROLE_NAME.into(),
prompt: prompt.into(),
temperature: self.temperature,
tokens: count_tokens(prompt),
});
self.role = Some(Role::new(prompt, self.temperature));
Ok(())
}
@ -198,33 +194,32 @@ impl Config {
if let Some(conversation) = self.conversation.as_ref() {
conversation.echo_messages(content)
} else if let Some(role) = self.role.as_ref() {
format!("{}\n{content}", role.prompt)
role.echo_messages(content)
} else {
content.to_string()
}
}
pub fn build_messages(&self, content: &str) -> Result<Value> {
let tokens = count_tokens(content) + MESSAGE_EXTRA_TOKENS;
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
let content_tokens = count_tokens(content);
let check_tokens = |tokens| {
if tokens >= MAX_TOKENS {
bail!("Exceed max tokens limit")
}
Ok(())
};
check_tokens(tokens)?;
let user_message = json!({ "role": "user", "content": content });
let value = if let Some(conversation) = self.conversation.as_ref() {
check_tokens(tokens + conversation.tokens)?;
let messages = if let Some(conversation) = self.conversation.as_ref() {
check_tokens(content_tokens + conversation.tokens)?;
conversation.build_emssages(content)
} else if let Some(role) = self.role.as_ref() {
check_tokens(tokens + role.tokens + MESSAGE_EXTRA_TOKENS)?;
let system_message = json!({ "role": "system", "content": role.prompt });
json!([system_message, user_message])
check_tokens(content_tokens + role.tokens)?;
role.build_emssages(content)
} else {
json!([user_message])
let message = Message::new(content);
check_tokens(content_tokens + MESSAGE_EXTRA_TOKENS)?;
vec![message]
};
Ok(value)
Ok(messages)
}
pub fn info(&self) -> Result<String> {
@ -327,11 +322,7 @@ impl Config {
return Ok(());
}
}
let mut conversation = Conversation::new();
if let Some(role) = self.role.as_ref() {
conversation.add_prompt(&role.prompt);
}
self.conversation = Some(conversation);
self.conversation = Some(Conversation::new(self.role.clone()));
Ok(())
}
@ -341,7 +332,7 @@ impl Config {
pub fn save_conversation(&mut self, input: &str, output: &str) -> Result<()> {
if let Some(conversation) = self.conversation.as_mut() {
conversation.add_chat(input, output)?;
conversation.add_message(input, output)?;
}
Ok(())
}
@ -376,19 +367,6 @@ impl Config {
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Role {
/// Role name
pub name: String,
/// Prompt text send to ai for setting up a role
pub prompt: String,
/// What sampling temperature to use, between 0 and 2
pub temperature: Option<f64>,
/// Number of tokens
#[serde(skip_deserializing)]
pub tokens: usize,
}
fn create_config_file(config_path: &Path) -> Result<()> {
let confirm_map_err = |_| anyhow!("Not finish questionnaire, try again later.");
let text_map_err = |_| anyhow!("An error happened when asking for your key, try again later.");

89
src/config/role.rs Normal file
View File

@ -0,0 +1,89 @@
use super::message::{Message, MessageRole, MESSAGE_EXTRA_TOKENS};
use crate::utils::count_tokens;
use serde::{Deserialize, Serialize};
const TEMP_NAME: &str = "";
const INPUT_PLACEHOLDER: &str = "__INPUT__";
const INPUT_PLACEHOLDER_TOKENS: usize = 3;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Role {
/// Role name
pub name: String,
/// Prompt text send to ai for setting up a role.
///
/// If prmopt contains __INPUT___, it's embeded prompt
/// If prmopt don't contain __INPUT___, it's system prompt
pub prompt: String,
/// What sampling temperature to use, between 0 and 2
pub temperature: Option<f64>,
/// Number of tokens
///
/// System prompt consume extra 6 tokens
#[serde(skip_deserializing)]
pub tokens: usize,
}
impl Role {
pub fn new(prompt: &str, temperature: Option<f64>) -> Self {
let mut value = Self {
name: TEMP_NAME.into(),
prompt: prompt.into(),
temperature,
tokens: 0,
};
value.tokens = value.consume_tokens();
value
}
pub fn is_temp(&self) -> bool {
self.name == TEMP_NAME
}
pub fn consume_tokens(&self) -> usize {
if self.embeded() {
count_tokens(&self.prompt) + MESSAGE_EXTRA_TOKENS - INPUT_PLACEHOLDER_TOKENS
} else {
count_tokens(&self.prompt) + 2 * MESSAGE_EXTRA_TOKENS
}
}
pub fn embeded(&self) -> bool {
self.prompt.contains(INPUT_PLACEHOLDER)
}
pub fn echo_messages(&self, content: &str) -> String {
if self.embeded() {
merge_prompt_content(&self.prompt, content)
} else {
format!("{}{content}", self.prompt)
}
}
pub fn build_emssages(&self, content: &str) -> Vec<Message> {
if self.embeded() {
let content = merge_prompt_content(&self.prompt, content);
vec![Message {
role: MessageRole::User,
content,
}]
} else {
vec![
Message {
role: MessageRole::System,
content: self.prompt.clone(),
},
Message {
role: MessageRole::User,
content: content.to_string(),
},
]
}
}
}
pub fn merge_prompt_content(prompt: &str, content: &str) -> String {
prompt.replace(INPUT_PLACEHOLDER, content)
}