diff --git a/src/config/input.rs b/src/config/input.rs
index ee85c2b..1fce6d0 100644
--- a/src/config/input.rs
+++ b/src/config/input.rs
@@ -379,13 +379,14 @@ async fn load_paths(
         }
     }
     for file_url in remote_urls {
-        if is_image(&file_url) {
-            medias.push(file_url)
+        let (contents, extension) = fetch(&loaders, &file_url, true)
+            .await
+            .with_context(|| format!("Failed to load url '{file_url}'"))?;
+        if extension == MEDIA_URL_EXTENSION {
+            data_urls.insert(sha256(&contents), file_url);
+            medias.push(contents)
         } else {
-            let (text, _) = fetch(&loaders, &file_url)
-                .await
-                .with_context(|| format!("Failed to load url '{file_url}'"))?;
-            files.push((file_url, text));
+            files.push((file_url, contents));
         }
     }
     Ok((files, medias, data_urls))
@@ -416,22 +417,19 @@ fn resolve_local_path(path: &str) -> Option<String> {
 }
 
 fn is_image(path: &str) -> bool {
-    path_extension(path)
+    get_patch_extension(path)
         .map(|v| IMAGE_EXTS.contains(&v.as_str()))
         .unwrap_or_default()
 }
 
-fn read_media_to_data_url<P: AsRef<Path>>(image_path: P) -> Result<String> {
-    let image_path = image_path.as_ref();
-    let mime_type = match image_path.extension().and_then(|v| v.to_str()) {
-        Some(extension) => match extension {
-            "png" => "image/png",
-            "jpg" | "jpeg" => "image/jpeg",
-            "webp" => "image/webp",
-            "gif" => "image/gif",
-            _ => bail!("Unsupported media type"),
-        },
-        None => bail!("Unknown media type"),
+fn read_media_to_data_url(image_path: &str) -> Result<String> {
+    let extension = get_patch_extension(image_path).unwrap_or_default();
+    let mime_type = match extension.as_str() {
+        "png" => "image/png",
+        "jpg" | "jpeg" => "image/jpeg",
+        "webp" => "image/webp",
+        "gif" => "image/gif",
+        _ => bail!("Unexpected media type"),
     };
     let mut file = File::open(image_path)?;
     let mut buffer = Vec::new();
diff --git a/src/rag/loader.rs b/src/rag/loader.rs
index 0b0d935..c764d78 100644
--- a/src/rag/loader.rs
+++ b/src/rag/loader.rs
@@ -60,7 +60,7 @@ pub async fn load_file(
     loaders: &HashMap<String, String>,
     path: &str,
 ) -> Result<(String, RagMetadata)> {
-    let extension = path_extension(path).unwrap_or_else(|| DEFAULT_EXTENSION.into());
+    let extension = get_patch_extension(path).unwrap_or_else(|| DEFAULT_EXTENSION.into());
     match loaders.get(&extension) {
         Some(loader_command) => load_with_command(path, &extension, loader_command),
         None => load_plain(path, &extension).await,
@@ -71,7 +71,7 @@ pub async fn load_url(
     loaders: &HashMap<String, String>,
     path: &str,
 ) -> Result<(String, RagMetadata)> {
-    let (contents, extension) = fetch(loaders, path).await?;
+    let (contents, extension) = fetch(loaders, path, false).await?;
     let mut metadata: RagMetadata = Default::default();
     metadata.insert(PATH_METADATA.into(), path.into());
     metadata.insert(EXTENSION_METADATA.into(), extension);
diff --git a/src/utils/path.rs b/src/utils/path.rs
index 4e66337..83e51c7 100644
--- a/src/utils/path.rs
+++ b/src/utils/path.rs
@@ -42,7 +42,7 @@ pub async fn expand_glob_paths<T: AsRef<str>>(paths: &[T]) -> Result<Vec<String>> {
     Ok(new_paths)
 }
 
-pub fn path_extension(path: &str) -> Option<String> {
+pub fn get_patch_extension(path: &str) -> Option<String> {
     Path::new(&path)
         .extension()
        .map(|v| v.to_string_lossy().to_lowercase())
diff --git a/src/utils/request.rs b/src/utils/request.rs
index 8733b75..039bf75 100644
--- a/src/utils/request.rs
+++ b/src/utils/request.rs
@@ -8,6 +8,7 @@ use tokio::io::AsyncWriteExt;
 
 pub const URL_LOADER: &str = "url";
 pub const RECURSIVE_URL_LOADER: &str = "recursive_url";
+pub const MEDIA_URL_EXTENSION: &str = "media_url";
 pub const DEFAULT_EXTENSION: &str = "txt";
 
 lazy_static! {
@@ -19,7 +20,11 @@ lazy_static! {
     };
 }
 
-pub async fn fetch(loaders: &HashMap<String, String>, path: &str) -> Result<(String, String)> {
+pub async fn fetch(
+    loaders: &HashMap<String, String>,
+    path: &str,
+    allow_media: bool,
+) -> Result<(String, String)> {
     if let Some(loader_command) = loaders.get(URL_LOADER) {
         let contents = run_loader_command(path, URL_LOADER, loader_command)?;
         return Ok((contents, DEFAULT_EXTENSION.into()));
@@ -29,6 +34,9 @@ pub async fn fetch(loaders: &HashMap<String, String>, path: &str) -> Result<(String, String)> {
         Err(ref err) => bail!("{err}"),
     };
     let mut res = client.get(path).send().await?;
+    if !res.status().is_success() {
+        bail!("Invalid status: {}", res.status());
+    }
     let content_type = res
         .headers()
         .get(CONTENT_TYPE)
@@ -37,45 +45,71 @@ pub async fn fetch(loaders: &HashMap<String, String>, path: &str) -> Result<(String, String)> {
             Some((mime, _)) => mime.trim(),
             None => v,
         })
-        .unwrap_or_default();
-    let extension = match content_type {
-        "application/pdf" => "pdf",
-        "application/vnd.openxmlformats-officedocument.wordprocessingml.document" => "docx",
-        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" => "xlsx",
-        "application/vnd.openxmlformats-officedocument.presentationml.presentation" => "pptx",
-        "application/vnd.oasis.opendocument.text" => "odt",
-        "application/vnd.oasis.opendocument.spreadsheet" => "ods",
-        "application/vnd.oasis.opendocument.presentation" => "odp",
-        "application/rtf" => "rtf",
-        "text/html" => "html",
-        _ => path
+        .map(|v| v.to_string())
+        .unwrap_or_else(|| {
+            format!(
+                "_/{}",
+                get_patch_extension(path).unwrap_or_else(|| DEFAULT_EXTENSION.into())
+            )
+        });
+    let mut is_media = false;
+    let extension = match content_type.as_str() {
+        "application/pdf" => "pdf".into(),
+        "application/vnd.openxmlformats-officedocument.wordprocessingml.document" => "docx".into(),
+        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" => "xlsx".into(),
+        "application/vnd.openxmlformats-officedocument.presentationml.presentation" => {
+            "pptx".into()
+        }
+        "application/vnd.oasis.opendocument.text" => "odt".into(),
+        "application/vnd.oasis.opendocument.spreadsheet" => "ods".into(),
+        "application/vnd.oasis.opendocument.presentation" => "odp".into(),
+        "application/rtf" => "rtf".into(),
+        "text/javascript" => "js".into(),
+        "text/html" => "html".into(),
+        _ => content_type
             .rsplit_once('/')
-            .and_then(|(_, pair)| pair.rsplit_once('.').map(|(_, ext)| ext))
-            .unwrap_or(DEFAULT_EXTENSION),
+            .map(|(first, last)| {
+                if ["image", "video", "audio"].contains(&first) {
+                    is_media = true;
+                    MEDIA_URL_EXTENSION.into()
+                } else {
+                    last.to_lowercase()
+                }
+            })
+            .unwrap_or_else(|| DEFAULT_EXTENSION.into()),
     };
-    let extension = extension.to_lowercase();
-    let result = match loaders.get(&extension) {
-        Some(loader_command) => {
-            let save_path = temp_file("-download-", &format!(".{extension}"))
-                .display()
-                .to_string();
-            let mut save_file = tokio::fs::File::create(&save_path).await?;
-            let mut size = 0;
-            while let Some(chunk) = res.chunk().await? {
-                size += chunk.len();
-                save_file.write_all(&chunk).await?;
-            }
-            let contents = if size == 0 {
-                println!("{}", warning_text(&format!("No content at '{path}'")));
-                String::new()
-            } else {
-                run_loader_command(&save_path, &extension, loader_command)?
-            };
-            (contents, DEFAULT_EXTENSION.into())
+    let result = if is_media {
+        if !allow_media {
+            bail!("Unexpected media type")
         }
-        None => {
-            let contents = res.text().await?;
-            (contents, extension)
+        let image_bytes = res.bytes().await?;
+        let image_base64 = base64_encode(&image_bytes);
+        let contents = format!("data:{};base64,{}", content_type, image_base64);
+        (contents, extension)
+    } else {
+        match loaders.get(&extension) {
+            Some(loader_command) => {
+                let save_path = temp_file("-download-", &format!(".{extension}"))
+                    .display()
+                    .to_string();
+                let mut save_file = tokio::fs::File::create(&save_path).await?;
+                let mut size = 0;
+                while let Some(chunk) = res.chunk().await? {
+                    size += chunk.len();
+                    save_file.write_all(&chunk).await?;
+                }
+                let contents = if size == 0 {
+                    println!("{}", warning_text(&format!("No content at '{path}'")));
+                    String::new()
+                } else {
+                    run_loader_command(&save_path, &extension, loader_command)?
+                };
+                (contents, DEFAULT_EXTENSION.into())
+            }
+            None => {
+                let contents = res.text().await?;
+                (contents, extension)
+            }
         }
     };
     Ok(result)
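Note for reviewers: when allow_media is true and the response carries an image, video, or audio MIME type, the new branch of fetch returns an RFC 2397 data URL ("data:<mime>;base64,<payload>"), which load_paths then keys by sha256(&contents) in data_urls. Below is a minimal standalone sketch of the same construction; the base64 crate and the to_data_url name are assumptions for illustration only, the diff itself uses the project's base64_encode helper.

// Sketch only: mirrors the data-URL construction in the new `is_media` branch of fetch.
// Assumes base64 = "0.21" (or later) in Cargo.toml; not the project's own helper.
use base64::{engine::general_purpose::STANDARD, Engine as _};

fn to_data_url(content_type: &str, bytes: &[u8]) -> String {
    // Same layout the diff builds: data:<mime>;base64,<payload>
    format!("data:{};base64,{}", content_type, STANDARD.encode(bytes))
}

fn main() {
    // The first bytes of a PNG are enough to show the output shape.
    let bytes = [0x89u8, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
    let url = to_data_url("image/png", &bytes);
    assert!(url.starts_with("data:image/png;base64,"));
    println!("{url}");
}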