diff --git a/src/lib.rs b/src/lib.rs index 0c199e5..940a2f2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,32 +6,41 @@ use syn::{Attribute, DeriveInput, Expr, Ident, Lit, Meta, parse_macro_input}; use std::fs; use std::path::{Path, PathBuf}; -#[proc_macro_derive(Embed, attributes(dir))] +#[derive(Debug, Clone, Copy)] +enum EmbedMode { + Bytes, + Str, +} + +#[proc_macro_derive(Embed, attributes(dir,mode))] pub fn embed(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let struct_name = &input.ident; - let attr = input + let dir_attr = input .attrs .iter() .find(|e| e.path().is_ident("dir")) .expect("No #[dir = \"...\"] attribute found"); - let base_path = PathBuf::from(extract_dir_path(attr)); + let mode_attr = input.attrs.iter().find(|e| e.path().is_ident("mode")); + + let mode = mode_attr.map(extract_mode).unwrap_or(EmbedMode::Bytes); + + let base_path = PathBuf::from(extract_dir_path(dir_attr)); let source_file = PathBuf::from(input.span().unwrap().file()); let source_dir = if let Some(parent) = source_file.parent() { parent } else { // HACK: when running in rust-analyzer i can't seem to get the parent dir. - return TokenStream::from(generate_impl(struct_name, Vec::new(), Vec::new())); + return TokenStream::from(generate_byte_impl(struct_name, Vec::new())); }; let absolue_path = source_dir.join(&base_path); let mut match_arms = Vec::new(); - let mut entries = Vec::new(); for entry in collect_files(&absolue_path) { let rel_path = entry @@ -42,18 +51,20 @@ pub fn embed(input: proc_macro::TokenStream) -> proc_macro::TokenStream { .replace("\\", "/"); let include_path = base_path.join(&rel_path); - let include_string = include_path.to_str(); + let include_string = include_path.to_str().unwrap(); - match_arms.push(quote! { - #rel_path => Some(include_bytes!(#include_string) as &'static [u8]), - }); + let arm = match mode { + EmbedMode::Bytes => generate_byte_arm(&rel_path, include_string), + EmbedMode::Str => generate_str_arm(&rel_path, include_string), + }; - entries.push(quote! { - (#rel_path, include_bytes!(#include_string) as &'static [u8]) - }); + match_arms.push(arm); } - let expanded = generate_impl(struct_name, match_arms, entries); + let expanded = match mode { + EmbedMode::Bytes => generate_byte_impl(struct_name, match_arms), + EmbedMode::Str => generate_str_impl(struct_name, match_arms), + }; proc_macro::TokenStream::from(expanded) } @@ -88,10 +99,36 @@ fn extract_dir_path(attr: &Attribute) -> String { } } -fn generate_impl( +fn extract_mode(attr: &Attribute) -> EmbedMode { + let meta = match &attr.meta { + Meta::NameValue(meta) => meta, + _ => panic!("Expected #[mode = \"bytes\"|\"str\"] as a name-value attribute."), + }; + + let expr_lit = match &meta.value { + Expr::Lit(expr_lit) => expr_lit, + _ => panic!("Expected #[mode = \"bytes\"|\"str\"] with a string literal."), + }; + + match &expr_lit.lit { + Lit::Str(str) => match str.value().as_str() { + "bytes" => EmbedMode::Bytes, + "str" => EmbedMode::Str, + other => panic!("Unknown mode: {other}. Use `bytes` or `str`."), + }, + _ => panic!("Expected #[mode = \"bytes\"|\"str\"] to be a string."), + } +} + +fn generate_byte_arm(rel: &str, include: &str) -> proc_macro2::TokenStream { + quote! { + #rel => Some(include_bytes!(#include)), + } +} + +fn generate_byte_impl( struct_name: &Ident, match_arms: Vec, - entries: Vec, ) -> proc_macro2::TokenStream { quote! { impl #struct_name { @@ -101,9 +138,27 @@ fn generate_impl( _ => None, } } + } + } +} - pub fn iter() -> impl Iterator { - [#(#entries),*].into_iter() +fn generate_str_arm(rel: &str, include: &str) -> proc_macro2::TokenStream { + quote! { + #rel => Some(include_str!(#include)), + } +} + +fn generate_str_impl( + struct_name: &Ident, + match_arms: Vec, +) -> proc_macro2::TokenStream { + quote! { + impl #struct_name { + pub fn get(name: &str) -> Option<&'static str> { + match name { + #(#match_arms)* + _ => None, + } } } } diff --git a/tests/basic.rs b/tests/byte.rs similarity index 81% rename from tests/basic.rs rename to tests/byte.rs index d890650..2c4cfd7 100644 --- a/tests/basic.rs +++ b/tests/byte.rs @@ -2,27 +2,28 @@ use dir_embed::Embed; #[derive(Embed)] #[dir = "./../testdata/"] +#[mode = "bytes"] pub struct Assets; #[test] -fn get() { +fn byte_get() { assert!(Assets::get("file1.txt").is_some()); } #[test] -fn get_missing() { +fn byte_get_missing() { assert!(Assets::get("missing.txt").is_none()); } #[test] -fn read_content() { +fn byte_read_content() { let content_should = "file1".as_bytes(); let content_is = Assets::get("file1.txt").unwrap(); assert_eq!(*content_is, *content_should); } #[test] -fn parse_string() { +fn byte_parse_string() { let file: &[u8] = Assets::get("file1.txt").expect("Can't find file"); let string = str::from_utf8(file).expect("Failed to parse file"); @@ -31,14 +32,13 @@ fn parse_string() { } #[test] -fn sub_directories_get() { +fn byte_sub_directories_get() { assert!(Assets::get("sub/file2.txt").is_some()); } #[test] -fn sub_directories_content() { +fn byte_sub_directories_content() { let content_should = "file2".as_bytes(); let content_is = Assets::get("sub/file2.txt").unwrap(); assert_eq!(*content_is, *content_should); } -