1
use thiserror::Error;
2

            
3
#[derive(Debug, PartialEq, Eq, Error)]
4
pub enum UniplateError {
5
    #[error("Could not reconstruct node because wrong number of children was provided. Expected {0} children, got {1}.")]
6
    WrongNumberOfChildren(usize, usize),
7
}
8

            
9
pub trait Uniplate
10
where
11
    Self: Sized + Clone + Eq,
12
{
13
    /// The `uniplate` function. Takes a node and produces a tuple of `(children, context)`, where:
14
    /// - children is a list of the node's direct descendants of the same type
15
    /// - context is a function to reconstruct the original node with a new list of children
16
    ///
17
    /// ## Warning
18
    ///
19
    /// The number of children passed to context must be the same as the number of children in
20
    /// the original node.
21
    /// If the number of children given is different, context returns `UniplateError::NotEnoughChildren`
22
    #[allow(clippy::type_complexity)]
23
    fn uniplate(
24
        &self,
25
    ) -> (
26
        Vec<Self>,
27
        Box<dyn Fn(Vec<Self>) -> Result<Self, UniplateError> + '_>,
28
    );
29

            
30
    /// Get all children of a node, including itself and all children.
31
    fn universe(&self) -> Vec<Self> {
32
        let mut results = vec![self.clone()];
33
        for child in self.children() {
34
            results.append(&mut child.universe());
35
        }
36
        results
37
    }
38

            
39
    /// Get the DIRECT children of a node.
40
    fn children(&self) -> Vec<Self> {
41
        self.uniplate().0
42
    }
43

            
44
    /// Reconstruct this node with the given children
45
    ///
46
    /// ## Arguments
47
    /// - children - a vector of the same type and same size as self.children()
48
    fn with_children(&self, children: Vec<Self>) -> Result<Self, UniplateError> {
49
        let context = self.uniplate().1;
50
        context(children)
51
    }
52

            
53
    /// Apply the given rule to all nodes bottom up.
54
    fn transform(&self, f: fn(Self) -> Self) -> Result<Self, UniplateError> {
55
        let (children, context) = self.uniplate();
56

            
57
        let mut new_children: Vec<Self> = Vec::new();
58
        for ch in children {
59
            let new_ch = ch.transform(f)?;
60
            new_children.push(new_ch);
61
        }
62

            
63
        let transformed = context(new_children)?;
64
        Ok(f(transformed))
65
    }
66

            
67
    /// Rewrite by applying a rule everywhere you can.
68
    fn rewrite(&self, f: fn(Self) -> Option<Self>) -> Result<Self, UniplateError> {
69
        let (children, context) = self.uniplate();
70

            
71
        let mut new_children: Vec<Self> = Vec::new();
72
        for ch in children {
73
            let new_ch = ch.rewrite(f)?;
74
            new_children.push(new_ch);
75
        }
76

            
77
        let node: Self = context(new_children)?;
78
        Ok(f(node.clone()).unwrap_or(node))
79
    }
80

            
81
    /// Perform a transformation on all the immediate children, then combine them back.
82
    /// This operation allows additional information to be passed downwards, and can be used to provide a top-down transformation.
83
    fn descend(&self, f: fn(Self) -> Self) -> Result<Self, UniplateError> {
84
        let (children, context) = self.uniplate();
85
        let children: Vec<Self> = children.into_iter().map(f).collect();
86

            
87
        context(children)
88
    }
89

            
90
    /// Perform a fold-like computation on each value.
91
    ///
92
    /// Working from the bottom up, this applies the given callback function to each nested
93
    /// component.
94
    ///
95
    /// Unlike [`transform`](Uniplate::transform), this returns an arbitrary type, and is not
96
    /// limited to T -> T transformations. In other words, it can transform a type into a new
97
    /// one.
98
    ///
99
    /// The meaning of the callback function is the following:
100
    ///
101
    ///   f(element_to_fold, folded_children) -> folded_element
102
    ///
103
    fn fold<T>(&self, op: fn(Self, Vec<T>) -> T) -> T {
104
        op(
105
            self.clone(),
106
            self.children().into_iter().map(|c| c.fold(op)).collect(),
107
        )
108
    }
109

            
110
    /// Get the nth one holed context.
111
    ///
112
    /// A uniplate context for type T has holes where all the nested T's should be.
113
    /// This is encoded as a function Vec<T> -> T.
114
    ///
115
    /// On the other hand, the nth one-holed context has only one hole where the nth nested
116
    /// instance of T would be.
117
    ///
118
    /// Eg. for some type:
119
    /// ```ignore
120
    /// enum Expr {
121
    ///     F(A,Expr,A,Expr,A),
122
    ///     G(Expr,A,A)
123
    /// }
124
    /// ```
125
    ///
126
    /// The 1st one-holed context of `F` (using 0-indexing) would be:
127
    /// ```ignore
128
    /// |HOLE| F(a,b,c,HOLE,e)
129
    /// ```
130
    ///
131
    /// Used primarily in the implementation of Zippers.
132
    fn one_holed_context(&self, n: usize) -> Option<Box<dyn Fn(Self) -> Self + '_>> {
133
        let (children, context) = self.uniplate();
134
        let number_of_elems = children.len();
135

            
136
        if n >= number_of_elems {
137
            return None;
138
        }
139

            
140
        Some(Box::new(move |x| {
141
            let mut children = children.clone();
142
            children[n] = x;
143
            #[allow(clippy::unwrap_used)]
144
            // We are directly replacing a child so there can't be an error
145
            context(children).unwrap()
146
        }))
147
    }
148
}