1
use std::collections::{HashMap, HashSet};
2
use std::fmt::{Display, Formatter};
3
use std::hash::Hash;
4
use std::sync::OnceLock;
5

            
6
use log::warn;
7

            
8
use crate::rule_engine::{get_rule_set_by_name, get_rules, Rule};
9
use crate::solver::SolverFamily;
10

            
11
/// A set of rules with a name, priority, and dependencies.
12
#[derive(Clone, Debug)]
13
pub struct RuleSet<'a> {
14
    /// The name of the rule set.
15
    pub name: &'a str,
16
    /// Order of the RuleSet. Used to establish a consistent order of operations when resolving rules.
17
    /// If two RuleSets overlap (contain the same rule but with different priorities), the RuleSet with the higher order will be used as the source of truth.
18
    pub order: u8,
19
    /// A map of rules to their priorities. This will be lazily initialized at runtime.
20
    rules: OnceLock<HashMap<&'a Rule<'a>, u8>>,
21
    /// The names of the rule sets that this rule set depends on.
22
    dependency_rs_names: &'a [&'a str],
23
    dependencies: OnceLock<HashSet<&'a RuleSet<'a>>>,
24
    /// The solver families that this rule set applies to.
25
    pub solver_families: &'a [SolverFamily],
26
}
27

            
28
impl<'a> RuleSet<'a> {
29
    pub const fn new(
30
        name: &'a str,
31
        order: u8,
32
        dependencies: &'a [&'a str],
33
        solver_families: &'a [SolverFamily],
34
    ) -> Self {
35
        Self {
36
            name,
37
            order,
38
            dependency_rs_names: dependencies,
39
            solver_families,
40
            rules: OnceLock::new(),
41
            dependencies: OnceLock::new(),
42
        }
43
    }
44

            
45
    /// Get the rules of this rule set, evaluating them lazily if necessary
46
    /// Returns a `&HashMap<&Rule, u8>` where the key is the rule and the value is the priority of the rule.
47
    pub fn get_rules(&self) -> &HashMap<&'a Rule<'a>, u8> {
48
        match self.rules.get() {
49
            None => {
50
                let rules = self.resolve_rules();
51
                let _ = self.rules.set(rules); // Try to set the rules, but ignore if it fails.
52

            
53
                // At this point, the rules cell is guaranteed to be set, so we can unwrap safely.
54
                // see: https://doc.rust-lang.org/stable/std/sync/struct.OnceLock.html#method.set
55
                #[allow(clippy::unwrap_used)]
56
                self.rules.get().unwrap()
57
            }
58
            Some(rules) => rules,
59
        }
60
    }
61

            
62
    /// Get the dependencies of this rule set, evaluating them lazily if necessary
63
    /// Returns a `&HashSet<&RuleSet>` of the rule sets that this rule set depends on.
64
    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
65
    pub fn get_dependencies(&self) -> &HashSet<&'static RuleSet> {
66
        match self.dependencies.get() {
67
            None => {
68
                let dependencies = self.resolve_dependencies();
69
                let _ = self.dependencies.set(dependencies); // Try to set the dependencies, but ignore if it fails.
70

            
71
                // At this point, the dependencies cell is guaranteed to be set, so we can unwrap safely.
72
                // see: https://doc.rust-lang.org/stable/std/sync/struct.OnceLock.html#method.set
73
                #[allow(clippy::unwrap_used)]
74
                self.dependencies.get().unwrap()
75
            }
76
            Some(dependencies) => dependencies,
77
        }
78
    }
79

            
80
    /// Get the dependencies of this rule set, including itself
81
    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
82
    pub fn with_dependencies(&self) -> HashSet<&'static RuleSet> {
83
        let mut deps = self.get_dependencies().clone();
84
        deps.insert(self);
85
        deps
86
    }
87

            
88
    /// Resolve the rules of this rule set ("reverse the arrows")
89
    fn resolve_rules(&self) -> HashMap<&'a Rule<'a>, u8> {
90
        let mut rules = HashMap::new();
91

            
92
        for rule in get_rules() {
93
            let mut found = false;
94
            let mut priority: u8 = 0;
95

            
96
            for (name, p) in rule.rule_sets {
97
                if *name == self.name {
98
                    found = true;
99
                    priority = *p;
100
                    break;
101
                }
102
            }
103

            
104
            if found {
105
                rules.insert(rule, priority);
106
            }
107
        }
108

            
109
        rules
110
    }
111

            
112
    /// Recursively resolve the dependencies of this rule set.
113
    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
114
    fn resolve_dependencies(&self) -> HashSet<&'static RuleSet> {
115
        let mut dependencies = HashSet::new();
116

            
117
        for dep in self.dependency_rs_names {
118
            match get_rule_set_by_name(dep) {
119
                None => {
120
                    warn!(
121
                        "Rule set {} depends on non-existent rule set {}",
122
                        &self.name, dep
123
                    );
124
                }
125
                Some(rule_set) => {
126
                    if !dependencies.contains(rule_set) {
127
                        // Prevent cycles
128
                        dependencies.insert(rule_set);
129
                        dependencies.extend(rule_set.resolve_dependencies());
130
                    }
131
                }
132
            }
133
        }
134

            
135
        dependencies
136
    }
137
}
138

            
139
impl<'a> PartialEq for RuleSet<'a> {
140
    fn eq(&self, other: &Self) -> bool {
141
        self.name == other.name
142
    }
143
}
144

            
145
impl<'a> Eq for RuleSet<'a> {}
146

            
147
impl<'a> Hash for RuleSet<'a> {
148
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
149
        self.name.hash(state);
150
    }
151
}
152

            
153
impl<'a> Display for RuleSet<'a> {
154
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
155
        let n_rules = self.get_rules().len();
156
        let solver_families = self
157
            .solver_families
158
            .iter()
159
            .map(|f| f.to_string())
160
            .collect::<Vec<String>>();
161

            
162
        write!(
163
            f,
164
            "RuleSet {{\n\
165
            \tname: {}\n\
166
            \torder: {}\n\
167
            \trules: {}\n\
168
            \tsolver_families: {:?}\n\
169
        }}",
170
            self.name, self.order, n_rules, solver_families
171
        )
172
    }
173
}