feat: support role args (#69)

* feat: support role args

We can use role args to pass some additional arguments to the prompt.

```
- name: convert:json:yaml
  prompt: convert __ARG1__ below to __ARG2__
```

`:json:yaml` is `role args`. It has two args:

- arg1 `json`, it will replace __ARG1__ in prompt
- arg2 `yaml`, it will replace __ARG2__ in prompt

```
〉.role convert:json:yaml
name: convert:json:yaml
prompt: convert json below to yaml
temperature: null

〉.role convert:yaml:json
name: convert:yaml:json
prompt: convert yaml below to json
temperature: null
```

different role args,  will generate different prompts.

* small updates
This commit is contained in:
sigoden 2023-03-13 10:09:01 +08:00 committed by GitHub
parent 8fa2e2683b
commit 9c04455e36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 8 deletions

View File

@ -120,7 +120,7 @@ impl Config {
}
pub fn find_role(&self, name: &str) -> Option<Role> {
self.roles.iter().find(|v| v.name == name).cloned()
self.roles.iter().find(|v| v.match_name(name)).cloned()
}
pub fn config_dir() -> Result<PathBuf> {
@ -200,7 +200,8 @@ impl Config {
pub fn change_role(&mut self, name: &str) -> Result<String> {
match self.find_role(name) {
Some(role) => {
Some(mut role) => {
role.complete_prompt_args(name);
if let Some(conversation) = self.conversation.as_mut() {
conversation.update_role(&role)?;
}

View File

@ -9,10 +9,7 @@ const INPUT_PLACEHOLDER: &str = "__INPUT__";
pub struct Role {
/// Role name
pub name: String,
/// Prompt text send to ai for setting up a role.
///
/// If prmopt contains __INPUT___, it's embeded prompt
/// If prmopt don't contain __INPUT___, it's system prompt
/// Prompt text
pub prompt: String,
/// What sampling temperature to use, between 0 and 2
pub temperature: Option<f64>,
@ -35,6 +32,21 @@ impl Role {
self.prompt.contains(INPUT_PLACEHOLDER)
}
pub fn complete_prompt_args(&mut self, name: &str) {
self.name = name.to_string();
self.prompt = complete_prompt_args(&self.prompt, &self.name);
}
pub fn match_name(&self, name: &str) -> bool {
if self.name.contains(':') {
let role_name_parts: Vec<&str> = self.name.split(':').collect();
let name_parts: Vec<&str> = name.split(':').collect();
role_name_parts[0] == name_parts[0] && role_name_parts.len() == name_parts.len()
} else {
self.name == name
}
}
pub fn echo_messages(&self, content: &str) -> String {
if self.embeded() {
merge_prompt_content(&self.prompt, content)
@ -65,6 +77,31 @@ impl Role {
}
}
pub fn merge_prompt_content(prompt: &str, content: &str) -> String {
fn merge_prompt_content(prompt: &str, content: &str) -> String {
prompt.replace(INPUT_PLACEHOLDER, content)
}
fn complete_prompt_args(prompt: &str, name: &str) -> String {
let mut prompt = prompt.to_string();
for (i, arg) in name.split(':').skip(1).enumerate() {
prompt = prompt.replace(&format!("__ARG{}__", i + 1), arg);
}
prompt
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_merge_prompt_name() {
assert_eq!(
complete_prompt_args("convert __ARG1__", "convert:foo"),
"convert foo"
);
assert_eq!(
complete_prompt_args("convert __ARG1__ to __ARG2__", "convert:foo:bar"),
"convert foo to bar"
);
}
}

View File

@ -47,7 +47,8 @@ impl Repl {
fn create_completer(config: SharedConfig, commands: &[String]) -> DefaultCompleter {
let mut completion = commands.to_vec();
completion.extend(config.read().repl_completions());
let mut completer = DefaultCompleter::with_inclusions(&['.', '-', '_']).set_min_word_len(2);
let mut completer =
DefaultCompleter::with_inclusions(&['.', '-', '_', ':']).set_min_word_len(2);
completer.insert(completion.clone());
completer
}