1
use proc_macro::{self, TokenStream};
2

            
3
use proc_macro2::TokenStream as TokenStream2;
4
use quote::{format_ident, quote};
5
use syn::{parse_macro_input, Data, DataEnum, DeriveInput, Ident, Variant};
6

            
7
use crate::utils::generate::{generate_field_clones, generate_field_fills, generate_field_idents};
8

            
9
mod utils;
10

            
11
/// Generate the full match pattern for a variant
12
18
fn generate_match_pattern(variant: &Variant, root_ident: &Ident) -> TokenStream2 {
13
18
    let field_idents = generate_field_idents(&variant.fields);
14
18
    let variant_ident = &variant.ident;
15
18

            
16
18
    if field_idents.is_empty() {
17
        quote! {
18
            #root_ident::#variant_ident
19
        }
20
    } else {
21
18
        quote! {
22
18
            #root_ident::#variant_ident(#(#field_idents,)*)
23
18
        }
24
    }
25
18
}
26

            
27
/// Generate the code to get the children of a variant
28
9
fn generate_variant_children_match_arm(variant: &Variant, root_ident: &Ident) -> TokenStream2 {
29
9
    let field_clones = generate_field_clones(&variant.fields, root_ident);
30
9

            
31
9
    let match_pattern = generate_match_pattern(variant, root_ident);
32

            
33
9
    let clones = if field_clones.is_empty() {
34
1
        quote! {
35
1
            Vec::new()
36
1
        }
37
    } else {
38
8
        quote! {
39
8
            vec![#(#field_clones,)*].iter().flatten().cloned().collect::<Vec<_>>()
40
8
        }
41
    };
42

            
43
9
    let mach_arm = quote! {
44
9
         #match_pattern => {
45
9
            #clones
46
9
        }
47
9
    };
48
9

            
49
9
    mach_arm
50
9
}
51

            
52
/// Generate an implementation of `context` for a variant
53
9
fn generate_variant_context_match_arm(variant: &Variant, root_ident: &Ident) -> TokenStream2 {
54
9
    let variant_ident = &variant.ident;
55
9
    let children_ident = Ident::new("children", variant_ident.span());
56
9
    let field_fills = generate_field_fills(&variant.fields, root_ident, &children_ident);
57
9
    let error_ident = format_ident!("UniplateError{}", root_ident);
58
9
    let match_pattern = generate_match_pattern(variant, root_ident);
59
9

            
60
9
    if field_fills.is_empty() {
61
        quote! {
62
            #match_pattern => {
63
                Box::new(|_| Ok(#root_ident::#variant_ident))
64
            }
65
        }
66
    } else {
67
9
        quote! {
68
9
            #match_pattern => {
69
9
                Box::new(|children| {
70
9
                    if (children.len() != self.children().len()) {
71
9
                        return Err(#error_ident::WrongNumberOfChildren(self.children().len(), children.len()));
72
9
                    }
73
9

            
74
9
                    let mut #children_ident = children.clone();
75
9
                    Ok(#root_ident::#variant_ident(#(#field_fills,)*))
76
9
                })
77
9
            }
78
9
        }
79
    }
80
9
}
81

            
82
/// Derive the `Uniplate` trait for an arbitrary type
83
///
84
/// # WARNING
85
///
86
/// This is alpha code. It is not yet stable and some features are missing.
87
///
88
/// ## What works?
89
///
90
/// - Deriving `Uniplate` for enum types
91
/// - `Box<T>` and `Vec<T>` fields, including nested vectors
92
/// - Tuple fields, including nested tuples - e.g. `(Vec<T>, (Box<T>, i32))`
93
///
94
/// ## What does not work?
95
///
96
/// - Structs
97
/// - Unions
98
/// - Array fields
99
/// - Multiple type arguments - e.g. `MyType<T, R>`
100
/// - Any complex type arguments, e.g. `MyType<T: MyTrait1 + MyTrait2>`
101
/// - Any collection type other than `Vec`
102
/// - Any box type other than `Box`
103
///
104
/// # Usage
105
///
106
/// This macro is intended to replace a hand-coded implementation of the `Uniplate` trait.
107
/// Example:
108
///
109
/// ```rust
110
/// use uniplate_derive::Uniplate;
111
/// use uniplate::uniplate::Uniplate;
112
///
113
/// #[derive(PartialEq, Eq, Debug, Clone, Uniplate)]
114
/// enum MyEnum {
115
///    A(Box<MyEnum>),
116
///    B(Vec<MyEnum>),
117
///    C(i32),
118
/// }
119
///
120
/// let a = MyEnum::A(Box::new(MyEnum::C(42)));
121
/// let (children, context) = a.uniplate();
122
/// assert_eq!(children, vec![MyEnum::C(42)]);
123
/// assert_eq!(context(vec![MyEnum::C(42)]).unwrap(), a);
124
/// ```
125
///
126
#[proc_macro_derive(Uniplate)]
127
1
pub fn derive(macro_input: TokenStream) -> TokenStream {
128
1
    let input = parse_macro_input!(macro_input as DeriveInput);
129
1
    let root_ident = &input.ident;
130
1
    let data = &input.data;
131

            
132
1
    let children_impl: TokenStream2 = match data {
133
        Data::Struct(_) => unimplemented!("Structs currently not supported"), // ToDo support structs
134
        Data::Union(_) => unimplemented!("Unions currently not supported"),   // ToDo support unions
135
1
        Data::Enum(DataEnum { variants, .. }) => {
136
1
            let match_arms: Vec<TokenStream2> = variants
137
1
                .iter()
138
9
                .map(|vt| generate_variant_children_match_arm(vt, root_ident))
139
1
                .collect::<Vec<_>>();
140

            
141
1
            let match_statement = quote! {
142
1
                match self {
143
1
                    #(#match_arms)*
144
1
                }
145
1
            };
146

            
147
1
            match_statement
148
        }
149
    };
150

            
151
1
    let context_impl = match data {
152
        Data::Struct(_) => unimplemented!("Structs currently not supported"),
153
        Data::Union(_) => unimplemented!("Unions currently not supported"),
154
1
        Data::Enum(DataEnum { variants, .. }) => {
155
1
            let match_arms: Vec<TokenStream2> = variants
156
1
                .iter()
157
9
                .map(|vt| generate_variant_context_match_arm(vt, root_ident))
158
1
                .collect::<Vec<_>>();
159

            
160
1
            let match_statement = quote! {
161
1
                match self {
162
1
                    #(#match_arms)*
163
1
                }
164
1
            };
165

            
166
1
            match_statement
167
1
        }
168
1
    };
169
1

            
170
1
    let error_ident = format_ident!("UniplateError{}", root_ident);
171
1

            
172
1
    let output = quote! {
173
1
        use uniplate::uniplate::UniplateError as #error_ident;
174
1

            
175
1
        impl Uniplate for #root_ident {
176
1
            #[allow(unused_variables)]
177
1
            fn uniplate(&self) -> (Vec<#root_ident>, Box<dyn Fn(Vec<#root_ident>) -> Result<#root_ident, #error_ident> + '_>) {
178
1
                let context: Box<dyn Fn(Vec<#root_ident>) -> Result<#root_ident, #error_ident>> = #context_impl;
179
1

            
180
1
                let children: Vec<#root_ident> = #children_impl;
181
1

            
182
1
                (children, context)
183
1
            }
184
1
        }
185
1
    };
186
1

            
187
1
    // println!("Final macro output:\n{}", output.to_string());
188
1

            
189
1
    output.into()
190
1
}