refactor: smart document splitter (#662)

sigoden 2024-06-27 14:55:25 +08:00 committed by GitHub
parent f82524fd15
commit ec83167de6
2 changed files with 56 additions and 36 deletions
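
The change renames the per-path `loader_name` to `extension` and has every loader tag the documents it produces with a reserved `__extension__` metadata entry, so the splitter can later choose separators per document rather than per source path. A minimal sketch of that idea, using simplified stand-in types (`Document` and `tag_with_extension` are illustrative only, not aichat's actual API):

use std::collections::HashMap;

const EXTENSION_METADATA: &str = "__extension__";

// Simplified stand-in for aichat's RagDocument/RagMetadata (which use an IndexMap).
#[derive(Debug)]
struct Document {
    page_content: String,
    metadata: HashMap<String, String>,
}

// A loader records where the content came from and which extension the splitter
// should assume when choosing separators.
fn tag_with_extension(page_content: String, path: &str, extension: &str) -> Document {
    let mut metadata = HashMap::new();
    metadata.insert("path".to_string(), path.to_string());
    metadata.insert(EXTENSION_METADATA.to_string(), extension.to_string());
    Document { page_content, metadata }
}

fn main() {
    // Hypothetical URL; pass-through content keeps its real extension.
    let doc = tag_with_extension(
        "# Title\n\nBody".to_string(),
        "https://example.com/notes.md",
        "md",
    );
    println!("{doc:?}");
}

In the diff below, content that a loader command has already converted to plain text is tagged "txt", while pass-through content keeps its real extension; that is why load_url and load_with_command insert different values.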

View File

@@ -9,6 +9,7 @@ use tokio::io::AsyncWriteExt;
 pub const RECURSIVE_URL_LOADER: &str = "recursive_url";
 pub const URL_LOADER: &str = "url";
+pub const EXTENSION_METADATA: &str = "__extension__";
 
 lazy_static! {
     static ref CLIENT: Result<reqwest::Client> = {
@@ -22,35 +23,35 @@ lazy_static! {
 pub async fn load(
     loaders: &HashMap<String, String>,
     path: &str,
-    loader_name: &str,
+    extension: &str,
 ) -> Result<Vec<RagDocument>> {
-    if loader_name == RECURSIVE_URL_LOADER {
+    if extension == RECURSIVE_URL_LOADER {
         let loader_command = loaders
-            .get(loader_name)
-            .with_context(|| format!("RAG document loader '{loader_name}' not configured"))?;
-        let contents = run_loader_command(path, loader_name, loader_command)?;
+            .get(extension)
+            .with_context(|| format!("RAG document loader '{extension}' not configured"))?;
+        let contents = run_loader_command(path, extension, loader_command)?;
         let output = match parse_json_documents(&contents) {
             Some(v) => v,
             None => vec![RagDocument::new(contents)],
         };
         Ok(output)
     } else {
-        match loaders.get(loader_name) {
-            Some(loader_command) => load_with_command(path, loader_name, loader_command),
+        match loaders.get(extension) {
+            Some(loader_command) => load_with_command(path, extension, loader_command),
             None => {
-                if loader_name == URL_LOADER {
+                if extension == URL_LOADER {
                     load_url(loaders, path).await
                 } else {
-                    load_plain(path, loader_name).await
+                    load_plain(path, extension).await
                 }
             }
         }
     }
 }
 
-async fn load_plain(path: &str, loader_name: &str) -> Result<Vec<RagDocument>> {
+async fn load_plain(path: &str, extension: &str) -> Result<Vec<RagDocument>> {
     let contents = tokio::fs::read_to_string(path).await?;
-    if loader_name == "json" {
+    if extension == "json" {
         if let Some(documents) = parse_json_documents(&contents) {
             return Ok(documents);
         }
@@ -66,43 +67,55 @@ async fn load_url(loaders: &HashMap<String, String>, path: &str) -> Result<Vec<RagDocument>> {
         Err(ref err) => bail!("{err}"),
     };
     let mut res = client.get(path).send().await?;
-    let loader_name = path
+    let mut metadata: RagMetadata = Default::default();
+    metadata.insert("path".into(), path.to_string());
+    let extension = path
         .rsplit_once('/')
         .and_then(|(_, pair)| pair.rsplit_once('.').map(|(_, ext)| ext))
         .unwrap_or("txt");
-    let contents = match loaders.get(loader_name) {
+    let extension = extension.to_lowercase();
+    let document = match loaders.get(&extension) {
         Some(loader_command) => {
             let save_path = env::temp_dir()
-                .join(format!("aichat-download-{}.{loader_name}", sha256(path)))
+                .join(format!("aichat-download-{}.{extension}", sha256(path)))
                 .display()
                 .to_string();
             let mut save_file = tokio::fs::File::create(&save_path).await?;
             while let Some(chunk) = res.chunk().await? {
                 save_file.write_all(&chunk).await?;
             }
-            run_loader_command(&save_path, loader_name, loader_command)?
+            let contents = run_loader_command(&save_path, &extension, loader_command)?;
+            metadata.insert(EXTENSION_METADATA.into(), "txt".to_string());
+            RagDocument::new(contents).with_metadata(metadata)
         }
-        None => res.text().await?,
+        None => {
+            let contents = res.text().await?;
+            metadata.insert(EXTENSION_METADATA.into(), extension);
+            RagDocument::new(contents).with_metadata(metadata)
+        }
     };
-    let mut document = RagDocument::new(contents);
-    document.metadata.insert("path".into(), path.to_string());
     Ok(vec![document])
 }
 
 fn load_with_command(
     path: &str,
-    loader_name: &str,
+    extension: &str,
     loader_command: &str,
 ) -> Result<Vec<RagDocument>> {
-    let contents = run_loader_command(path, loader_name, loader_command)?;
+    let contents = run_loader_command(path, extension, loader_command)?;
     let mut document = RagDocument::new(contents);
     document.metadata.insert("path".into(), path.to_string());
+    document
+        .metadata
+        .insert(EXTENSION_METADATA.into(), "txt".to_string());
     Ok(vec![document])
 }
 
-fn run_loader_command(path: &str, loader_name: &str, loader_command: &str) -> Result<String> {
+fn run_loader_command(path: &str, extension: &str, loader_command: &str) -> Result<String> {
     let cmd_args = shell_words::split(loader_command).with_context(|| {
-        anyhow!("Invalid rag document loader '{loader_name}': `{loader_command}`")
+        anyhow!("Invalid rag document loader '{extension}': `{loader_command}`")
     })?;
     let mut use_stdout = true;
     let outpath = env::temp_dir()
@@ -184,7 +197,6 @@ fn parse_json_documents(data: &str) -> Option<Vec<RagDocument>> {
         "html",
         "markdown",
         "text",
-        "data",
     ]
     .into_iter()
    .map(|v| v.to_string())
@@ -196,7 +208,7 @@ fn parse_json_documents(data: &str) -> Option<Vec<RagDocument>> {
         if let Some(page_content) = obj.get(&key).and_then(|v| v.as_str()) {
             let page_content = page_content.to_string();
             obj.remove(&key);
-            let metadata: IndexMap<_, _> = obj
+            let mut metadata: IndexMap<_, _> = obj
                 .into_iter()
                 .map(|(k, v)| {
                     if let Value::String(v) = v {
@@ -206,6 +218,11 @@ fn parse_json_documents(data: &str) -> Option<Vec<RagDocument>> {
                     }
                 })
                 .collect();
+            if key == "markdown" {
+                metadata.insert(EXTENSION_METADATA.into(), "md".into());
+            } else if key == "html" {
+                metadata.insert(EXTENSION_METADATA.into(), "html".into());
+            }
             return Some(RagDocument {
                 page_content,
                 metadata,
@@ -337,7 +354,7 @@ mod tests {
         let mut metadata = IndexMap::new();
         metadata.insert("k1".into(), "1".into());
-        let data = r#"[{"k1": 1, "data": "foo" }]"#;
+        let data = r#"[{"k1": 1, "text": "foo" }]"#;
         assert_eq!(
             parse_json_documents(data).unwrap(),
             vec![RagDocument::new("foo").with_metadata(metadata.clone())]
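
On the splitting side (the `impl Rag` hunk below), each loaded document now resolves its own separators from that `__extension__` tag, falling back to the extension inferred from the source path. A rough sketch of that lookup with the same simplified types; the separator lists in `get_separators` here are purely illustrative, not aichat's actual tables:

use std::collections::HashMap;

const EXTENSION_METADATA: &str = "__extension__";

// Illustrative separator tables only; aichat defines the real ones in its splitter module.
fn get_separators(extension: &str) -> Vec<&'static str> {
    match extension {
        "md" | "markdown" => vec!["\n# ", "\n## ", "\n\n", "\n", " "],
        "html" | "htm" => vec!["</p>", "<br>", "\n\n", "\n", " "],
        _ => vec!["\n\n", "\n", " "],
    }
}

// Rough analogue of the flat_map body in the diff: consume the document's own tag,
// otherwise fall back to the extension derived from the path being indexed.
fn separators_for(metadata: &mut HashMap<String, String>, fallback: &str) -> Vec<&'static str> {
    let extension = metadata
        .remove(EXTENSION_METADATA) // aichat's RagMetadata is an IndexMap, removed via swap_remove
        .unwrap_or_else(|| fallback.to_string());
    get_separators(&extension)
}

fn main() {
    let mut tagged = HashMap::from([(EXTENSION_METADATA.to_string(), "md".to_string())]);
    let mut untagged: HashMap<String, String> = HashMap::new();
    println!("{:?}", separators_for(&mut tagged, "txt"));   // markdown separators
    println!("{:?}", separators_for(&mut untagged, "txt")); // fallback separators
}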

View File

@@ -272,20 +272,24 @@ impl Rag {
         if let Some(spinner) = &spinner {
             let _ = spinner.set_message(String::new());
         }
-        for (index, (path, loader_name)) in new_paths.into_iter().enumerate() {
+        for (index, (path, extension)) in new_paths.into_iter().enumerate() {
             println!("Loading {path} [{}/{new_paths_len}]", index + 1);
-            let documents = load(&loaders, &path, &loader_name)
+            let documents = load(&loaders, &path, &extension)
                 .await
                 .with_context(|| format!("Failed to load '{path}'"))?;
-            let separator = get_separators(&loader_name);
-            let splitter = RecursiveCharacterTextSplitter::new(
-                self.data.chunk_size,
-                self.data.chunk_overlap,
-                &separator,
-            );
             let splitted_documents: Vec<_> = documents
                 .into_iter()
-                .flat_map(|document| {
+                .flat_map(|mut document| {
+                    let extension = document
+                        .metadata
+                        .swap_remove(EXTENSION_METADATA)
+                        .unwrap_or_else(|| extension.clone());
+                    let separator = get_separators(&extension);
+                    let splitter = RecursiveCharacterTextSplitter::new(
+                        self.data.chunk_size,
+                        self.data.chunk_overlap,
+                        &separator,
+                    );
                     let metadata = document
                         .metadata
                         .iter()
@@ -299,7 +303,7 @@ impl Rag {
                     splitter.split_documents(&[document], &split_options)
                 })
                 .collect();
-            let display_path = if loader_name == RECURSIVE_URL_LOADER {
+            let display_path = if extension == RECURSIVE_URL_LOADER {
                 format!("{path}**")
             } else {
                 path
@@ -557,7 +561,6 @@ impl RagDocument {
         }
     }
 
-    #[allow(unused)]
     pub fn with_metadata(mut self, metadata: RagMetadata) -> Self {
         self.metadata = metadata;
         self