Mirror of https://github.com/sigoden/aichat, synced 2024-11-08 13:10:28 +00:00
refactor: smart document splitter (#662)
This commit is contained in:
parent f82524fd15
commit ec83167de6
@@ -9,6 +9,7 @@ use tokio::io::AsyncWriteExt;
 
 pub const RECURSIVE_URL_LOADER: &str = "recursive_url";
 pub const URL_LOADER: &str = "url";
+pub const EXTENSION_METADATA: &str = "__extension__";
 
 lazy_static! {
     static ref CLIENT: Result<reqwest::Client> = {
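The new EXTENSION_METADATA key ("__extension__") lets a loader record what kind of text it actually produced, so the splitter can later pick separators per document rather than per input path. A minimal sketch of the idea, using a stand-in Doc type instead of the crate's RagDocument:

use indexmap::IndexMap;

const EXTENSION_METADATA: &str = "__extension__";

// Stand-in for the crate's RagDocument: text plus a small metadata map.
struct Doc {
    page_content: String,
    metadata: IndexMap<String, String>,
}

fn main() {
    let mut metadata = IndexMap::new();
    metadata.insert("path".to_string(), "notes.md".to_string());
    // A loader records the effective type of the text it produced.
    metadata.insert(EXTENSION_METADATA.to_string(), "md".to_string());
    let doc = Doc {
        page_content: "# Title\n\nBody".to_string(),
        metadata,
    };

    // The splitter later reads the tag back, falling back to the
    // path-derived extension when the loader did not set one.
    let extension = doc
        .metadata
        .get(EXTENSION_METADATA)
        .map(String::as_str)
        .unwrap_or("txt");
    assert_eq!(extension, "md");
    assert!(doc.page_content.starts_with("# Title"));
}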
@@ -22,35 +23,35 @@ lazy_static! {
 pub async fn load(
     loaders: &HashMap<String, String>,
     path: &str,
-    loader_name: &str,
+    extension: &str,
 ) -> Result<Vec<RagDocument>> {
-    if loader_name == RECURSIVE_URL_LOADER {
+    if extension == RECURSIVE_URL_LOADER {
         let loader_command = loaders
-            .get(loader_name)
-            .with_context(|| format!("RAG document loader '{loader_name}' not configured"))?;
-        let contents = run_loader_command(path, loader_name, loader_command)?;
+            .get(extension)
+            .with_context(|| format!("RAG document loader '{extension}' not configured"))?;
+        let contents = run_loader_command(path, extension, loader_command)?;
         let output = match parse_json_documents(&contents) {
             Some(v) => v,
             None => vec![RagDocument::new(contents)],
         };
         Ok(output)
     } else {
-        match loaders.get(loader_name) {
-            Some(loader_command) => load_with_command(path, loader_name, loader_command),
+        match loaders.get(extension) {
+            Some(loader_command) => load_with_command(path, extension, loader_command),
             None => {
-                if loader_name == URL_LOADER {
+                if extension == URL_LOADER {
                     load_url(loaders, path).await
                 } else {
-                    load_plain(path, loader_name).await
+                    load_plain(path, extension).await
                 }
             }
         }
     }
 }
 
-async fn load_plain(path: &str, loader_name: &str) -> Result<Vec<RagDocument>> {
+async fn load_plain(path: &str, extension: &str) -> Result<Vec<RagDocument>> {
     let contents = tokio::fs::read_to_string(path).await?;
-    if loader_name == "json" {
+    if extension == "json" {
         if let Some(documents) = parse_json_documents(&contents) {
             return Ok(documents);
         }
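With the parameter renamed from loader_name to extension, the dispatch in load is keyed purely by the extension: the recursive_url pseudo-extension must have a configured loader command, any other configured extension runs its command via load_with_command, the url pseudo-extension falls back to an HTTP fetch, and everything else is read as plain text. A simplified, self-contained sketch of that ordering; the loader table here is hypothetical (the pdftotext entry only illustrates the "command template per extension" convention, it is not a statement about default configuration):

use std::collections::HashMap;

const RECURSIVE_URL_LOADER: &str = "recursive_url";
const URL_LOADER: &str = "url";

// Mirrors the branch order of load(): recursive_url first, then a configured
// loader for the extension, then the url fallback, then plain-text reading.
fn dispatch(loaders: &HashMap<String, String>, extension: &str) -> &'static str {
    if extension == RECURSIVE_URL_LOADER {
        "run the configured recursive_url loader (error if none is configured)"
    } else if loaders.contains_key(extension) {
        "run the loader command configured for this extension"
    } else if extension == URL_LOADER {
        "fetch the document over HTTP"
    } else {
        "read the file as plain text"
    }
}

fn main() {
    // Hypothetical table: extension -> external command template.
    let loaders = HashMap::from([("pdf".to_string(), "pdftotext $1 -".to_string())]);
    assert_eq!(dispatch(&loaders, "pdf"), "run the loader command configured for this extension");
    assert_eq!(dispatch(&loaders, "md"), "read the file as plain text");
    assert_eq!(dispatch(&loaders, "url"), "fetch the document over HTTP");
}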
@@ -66,43 +67,55 @@ async fn load_url(loaders: &HashMap<String, String>, path: &str) -> Result<Vec<R
         Err(ref err) => bail!("{err}"),
     };
     let mut res = client.get(path).send().await?;
-    let loader_name = path
+
+    let mut metadata: RagMetadata = Default::default();
+    metadata.insert("path".into(), path.to_string());
+
+    let extension = path
         .rsplit_once('/')
         .and_then(|(_, pair)| pair.rsplit_once('.').map(|(_, ext)| ext))
         .unwrap_or("txt");
-    let contents = match loaders.get(loader_name) {
+    let extension = extension.to_lowercase();
+    let document = match loaders.get(&extension) {
         Some(loader_command) => {
             let save_path = env::temp_dir()
-                .join(format!("aichat-download-{}.{loader_name}", sha256(path)))
+                .join(format!("aichat-download-{}.{extension}", sha256(path)))
                .display()
                .to_string();
            let mut save_file = tokio::fs::File::create(&save_path).await?;
            while let Some(chunk) = res.chunk().await? {
                save_file.write_all(&chunk).await?;
            }
-            run_loader_command(&save_path, loader_name, loader_command)?
+            let contents = run_loader_command(&save_path, &extension, loader_command)?;
+            metadata.insert(EXTENSION_METADATA.into(), "txt".to_string());
+            RagDocument::new(contents).with_metadata(metadata)
        }
-        None => res.text().await?,
+        None => {
+            let contents = res.text().await?;
+            metadata.insert(EXTENSION_METADATA.into(), extension);
+            RagDocument::new(contents).with_metadata(metadata)
+        }
    };
-    let mut document = RagDocument::new(contents);
-    document.metadata.insert("path".into(), path.to_string());
    Ok(vec![document])
}
 
 fn load_with_command(
     path: &str,
-    loader_name: &str,
+    extension: &str,
     loader_command: &str,
 ) -> Result<Vec<RagDocument>> {
-    let contents = run_loader_command(path, loader_name, loader_command)?;
+    let contents = run_loader_command(path, extension, loader_command)?;
     let mut document = RagDocument::new(contents);
     document.metadata.insert("path".into(), path.to_string());
+    document
+        .metadata
+        .insert(EXTENSION_METADATA.into(), "txt".to_string());
     Ok(vec![document])
 }
 
-fn run_loader_command(path: &str, loader_name: &str, loader_command: &str) -> Result<String> {
+fn run_loader_command(path: &str, extension: &str, loader_command: &str) -> Result<String> {
     let cmd_args = shell_words::split(loader_command).with_context(|| {
-        anyhow!("Invalid rag document loader '{loader_name}': `{loader_command}`")
+        anyhow!("Invalid rag document loader '{extension}': `{loader_command}`")
     })?;
     let mut use_stdout = true;
     let outpath = env::temp_dir()
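In load_url the extension is now derived from the URL path (last path segment, last dot, defaulting to txt), lower-cased before the loader lookup, and the resulting document is tagged: txt when an external loader already converted the download to plain text, otherwise the raw extension of the fetched resource. load_with_command tags its loader output as txt for the same reason. The path-to-extension step in isolation, with the extraction and lower-casing pulled together here purely for illustration and a few assumed example URLs:

// Same path-to-extension logic load_url uses: take everything after the last
// '/' and then after the last '.', defaulting to "txt", then lower-case it.
fn url_extension(path: &str) -> String {
    path.rsplit_once('/')
        .and_then(|(_, pair)| pair.rsplit_once('.').map(|(_, ext)| ext))
        .unwrap_or("txt")
        .to_lowercase()
}

fn main() {
    assert_eq!(url_extension("https://example.com/docs/guide.MD"), "md");
    assert_eq!(url_extension("https://example.com/docs/guide"), "txt");
    assert_eq!(url_extension("https://example.com/archive.tar.gz"), "gz");
}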
@@ -184,7 +197,6 @@ fn parse_json_documents(data: &str) -> Option<Vec<RagDocument>> {
        "html",
        "markdown",
        "text",
-        "data",
    ]
    .into_iter()
    .map(|v| v.to_string())
@@ -196,7 +208,7 @@ fn parse_json_documents(data: &str) -> Option<Vec<RagDocument>> {
        if let Some(page_content) = obj.get(&key).and_then(|v| v.as_str()) {
            let page_content = page_content.to_string();
            obj.remove(&key);
-            let metadata: IndexMap<_, _> = obj
+            let mut metadata: IndexMap<_, _> = obj
                .into_iter()
                .map(|(k, v)| {
                    if let Value::String(v) = v {
@@ -206,6 +218,11 @@ fn parse_json_documents(data: &str) -> Option<Vec<RagDocument>> {
                    }
                })
                .collect();
+            if key == "markdown" {
+                metadata.insert(EXTENSION_METADATA.into(), "md".into());
+            } else if key == "html" {
+                metadata.insert(EXTENSION_METADATA.into(), "html".into());
+            }
            return Some(RagDocument {
                page_content,
                metadata,
@@ -337,7 +354,7 @@ mod tests {
 
        let mut metadata = IndexMap::new();
        metadata.insert("k1".into(), "1".into());
-        let data = r#"[{"k1": 1, "data": "foo" }]"#;
+        let data = r#"[{"k1": 1, "text": "foo" }]"#;
        assert_eq!(
            parse_json_documents(data).unwrap(),
            vec![RagDocument::new("foo").with_metadata(metadata.clone())]
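parse_json_documents picks up two related changes: the data key is no longer accepted as a content field (which is why the test above now uses text), and when the matched key is markdown or html the document is tagged with the corresponding extension so it gets Markdown- or HTML-aware splitting later. A trimmed-down sketch of that key handling for a single JSON object; the crate recognizes more content keys than the three shown here, and its real function also handles arrays:

use indexmap::IndexMap;
use serde_json::Value;

const EXTENSION_METADATA: &str = "__extension__";

// Trimmed-down sketch: find a recognized content key, lift it out as the page
// content, keep the remaining fields as string metadata, and tag markdown/html
// content so the splitter can pick matching separators.
fn parse_one(data: &str) -> Option<(String, IndexMap<String, String>)> {
    let value: Value = serde_json::from_str(data).ok()?;
    let mut obj = match value {
        Value::Object(obj) => obj,
        _ => return None,
    };
    for key in ["html", "markdown", "text"] {
        let Some(page_content) = obj.get(key).and_then(|v| v.as_str()).map(|s| s.to_string()) else {
            continue;
        };
        obj.remove(key);
        let mut metadata: IndexMap<String, String> = obj
            .into_iter()
            .map(|(k, v)| match v {
                Value::String(s) => (k, s),
                other => (k, other.to_string()),
            })
            .collect();
        if key == "markdown" {
            metadata.insert(EXTENSION_METADATA.into(), "md".into());
        } else if key == "html" {
            metadata.insert(EXTENSION_METADATA.into(), "html".into());
        }
        return Some((page_content, metadata));
    }
    None
}

fn main() {
    let (content, metadata) = parse_one(r#"{"k1": 1, "markdown": "# Title"}"#).unwrap();
    assert_eq!(content, "# Title");
    assert_eq!(metadata.get(EXTENSION_METADATA).map(String::as_str), Some("md"));
}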
@@ -272,20 +272,24 @@ impl Rag {
        if let Some(spinner) = &spinner {
            let _ = spinner.set_message(String::new());
        }
-        for (index, (path, loader_name)) in new_paths.into_iter().enumerate() {
+        for (index, (path, extension)) in new_paths.into_iter().enumerate() {
            println!("Loading {path} [{}/{new_paths_len}]", index + 1);
-            let documents = load(&loaders, &path, &loader_name)
+            let documents = load(&loaders, &path, &extension)
                .await
                .with_context(|| format!("Failed to load '{path}'"))?;
-            let separator = get_separators(&loader_name);
-            let splitter = RecursiveCharacterTextSplitter::new(
-                self.data.chunk_size,
-                self.data.chunk_overlap,
-                &separator,
-            );
            let splitted_documents: Vec<_> = documents
                .into_iter()
-                .flat_map(|document| {
+                .flat_map(|mut document| {
+                    let extension = document
+                        .metadata
+                        .swap_remove(EXTENSION_METADATA)
+                        .unwrap_or_else(|| extension.clone());
+                    let separator = get_separators(&extension);
+                    let splitter = RecursiveCharacterTextSplitter::new(
+                        self.data.chunk_size,
+                        self.data.chunk_overlap,
+                        &separator,
+                    );
                    let metadata = document
                        .metadata
                        .iter()
@@ -299,7 +303,7 @@ impl Rag {
                    splitter.split_documents(&[document], &split_options)
                })
                .collect();
-            let display_path = if loader_name == RECURSIVE_URL_LOADER {
+            let display_path = if extension == RECURSIVE_URL_LOADER {
                format!("{path}**")
            } else {
                path
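This hunk is the core of the "smart" splitter: instead of building one RecursiveCharacterTextSplitter per input path, the indexing loop pops the __extension__ tag off each loaded document (falling back to the path-derived extension) and builds the splitter from separators chosen for that document's actual format, so for example a Markdown page discovered by the recursive_url crawler is split on Markdown boundaries rather than as generic text. A self-contained sketch of that per-document selection; the separator tables below are placeholders, not the crate's get_separators lists:

use indexmap::IndexMap;

const EXTENSION_METADATA: &str = "__extension__";

// Placeholder separator tables; get_separators in the crate has its own lists.
fn separators_for(extension: &str) -> Vec<&'static str> {
    match extension {
        "md" | "markdown" => vec!["\n## ", "\n### ", "\n\n", "\n", " "],
        "rs" => vec!["\nfn ", "\nimpl ", "\n\n", "\n", " "],
        _ => vec!["\n\n", "\n", " "],
    }
}

fn main() {
    // One document tagged by its loader, one without a tag.
    let mut tagged: IndexMap<String, String> = IndexMap::new();
    tagged.insert(EXTENSION_METADATA.into(), "md".into());
    let mut untagged: IndexMap<String, String> = IndexMap::new();
    untagged.insert("path".into(), "notes.txt".into());

    let path_extension = "txt".to_string(); // what the old code used for every document

    for mut metadata in [tagged, untagged] {
        // Same move as the refactor: pop the per-document tag, fall back to the
        // extension derived from the input path, then build the splitter from it.
        let extension = metadata
            .swap_remove(EXTENSION_METADATA)
            .unwrap_or_else(|| path_extension.clone());
        let separators = separators_for(&extension);
        println!("{extension}: splitting on {separators:?}");
    }
}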
@@ -557,7 +561,6 @@ impl RagDocument {
        }
    }
 
-    #[allow(unused)]
    pub fn with_metadata(mut self, metadata: RagMetadata) -> Self {
        self.metadata = metadata;
        self
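The #[allow(unused)] on with_metadata can be dropped because load_url now builds its documents through it (RagDocument::new(contents).with_metadata(metadata)). For completeness, the builder shape this relies on, again with a stand-in type rather than the crate's own:

use indexmap::IndexMap;

// Stand-in mirroring the builder-style setter shown in the hunk above.
#[derive(Debug, Default)]
struct Doc {
    page_content: String,
    metadata: IndexMap<String, String>,
}

impl Doc {
    fn new(page_content: impl Into<String>) -> Self {
        Self { page_content: page_content.into(), ..Default::default() }
    }
    fn with_metadata(mut self, metadata: IndexMap<String, String>) -> Self {
        self.metadata = metadata;
        self
    }
}

fn main() {
    let mut metadata = IndexMap::new();
    metadata.insert("path".to_string(), "https://example.com/guide.md".to_string());
    // The same call shape load_url now uses: construct, then attach metadata.
    let doc = Doc::new("# Guide").with_metadata(metadata);
    println!("{} bytes of content, metadata: {:?}", doc.page_content.len(), doc.metadata);
}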