1
use std::collections::{HashMap, HashSet};
2
use std::fmt::Display;
3

            
4
use thiserror::Error;
5

            
6
use crate::rule_engine::{get_rule_set_by_name, get_rule_sets_for_solver_family, Rule, RuleSet};
7
use crate::solver::SolverFamily;
8

            
9
#[derive(Debug, Error)]
10
pub enum ResolveRulesError {
11
    RuleSetNotFound,
12
}
13

            
14
impl Display for ResolveRulesError {
15
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16
        match self {
17
            ResolveRulesError::RuleSetNotFound => write!(f, "Rule set not found."),
18
        }
19
    }
20
}
21

            
22
/// Helper function to get a rule set by name, or return an error if it doesn't exist.
23
///
24
/// # Arguments
25
/// - `rule_set_name` The name of the rule set to get.
26
///
27
/// # Returns
28
/// - The rule set with the given name or `RuleSetError::RuleSetNotFound` if it doesn't exist.
29
fn get_rule_set(rule_set_name: &str) -> Result<&'static RuleSet<'static>, ResolveRulesError> {
30
    match get_rule_set_by_name(rule_set_name) {
31
        Some(rule_set) => Ok(rule_set),
32
        None => Err(ResolveRulesError::RuleSetNotFound),
33
    }
34
}
35

            
36
/// Resolve a list of rule sets (and dependencies) by their names
37
///
38
/// # Arguments
39
/// - `rule_set_names` The names of the rule sets to resolve.
40
///
41
/// # Returns
42
/// - A list of the given rule sets and all of their dependencies, or error
43
///
44
#[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
45
pub fn rule_sets_by_names<'a>(
46
    rule_set_names: &Vec<String>,
47
) -> Result<HashSet<&'a RuleSet<'static>>, ResolveRulesError> {
48
    let mut rs_set: HashSet<&'static RuleSet<'static>> = HashSet::new();
49

            
50
    for rule_set_name in rule_set_names {
51
        let rule_set = get_rule_set(rule_set_name)?;
52
        let new_dependencies = rule_set.get_dependencies();
53
        rs_set.insert(rule_set);
54
        rs_set.extend(new_dependencies);
55
    }
56

            
57
    Ok(rs_set)
58
}
59

            
60
/// Resolves the final set of rule sets to apply based on target solver and extra rule set names.
61
///
62
/// # Arguments
63
/// - `target_solver` The solver to resolve the rule sets for.
64
/// - `extra_rs_names` The names of the extra rule sets to use
65
///
66
/// # Returns
67
/// - A vector of rule sets to apply.
68
///
69
#[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
70
pub fn resolve_rule_sets<'a>(
71
    target_solver: SolverFamily,
72
    extra_rs_names: &Vec<String>,
73
) -> Result<Vec<&'a RuleSet<'static>>, ResolveRulesError> {
74
    let mut ans = HashSet::new();
75

            
76
    for rs in get_rule_sets_for_solver_family(target_solver) {
77
        ans.extend(rs.with_dependencies());
78
    }
79

            
80
    ans.extend(rule_sets_by_names(extra_rs_names)?);
81
    Ok(ans.iter().cloned().collect())
82
}
83

            
84
/// Convert a list of rule sets into a final map of rules to their priorities.
85
///
86
/// # Arguments
87
/// - `rule_sets` The rule sets to get the rules from.
88
/// # Returns
89
/// - A map of rules to their priorities.
90
pub fn get_rule_priorities<'a>(
91
    rule_sets: &Vec<&'a RuleSet<'a>>,
92
) -> Result<HashMap<&'a Rule<'a>, u8>, ResolveRulesError> {
93
    let mut rule_priorities: HashMap<&'a Rule<'a>, (&'a RuleSet<'a>, u8)> = HashMap::new();
94

            
95
    for rs in rule_sets {
96
        for (rule, priority) in rs.get_rules() {
97
            if let Some((old_rs, _)) = rule_priorities.get(rule) {
98
                if rs.order >= old_rs.order {
99
                    rule_priorities.insert(rule, (&rs, *priority));
100
                }
101
            } else {
102
                rule_priorities.insert(rule, (&rs, *priority));
103
            }
104
        }
105
    }
106

            
107
    let mut ans: HashMap<&'a Rule<'a>, u8> = HashMap::new();
108
    for (rule, (_, priority)) in rule_priorities {
109
        ans.insert(rule, priority);
110
    }
111

            
112
    Ok(ans)
113
}
114

            
115
/// Compare two rules by their priorities and names.
116
///
117
/// Takes the rules and a map of rules to their priorities.
118
/// If rules are not in the map, they are assumed to have priority 0.
119
/// If the rules have the same priority, they are compared by their names.
120
///
121
/// # Arguments
122
/// - `a` first rule to compare.
123
/// - `b` second rule to compare.
124
/// - `rule_priorities` The priorities of the rules.
125
///
126
/// # Returns
127
/// - The ordering of the two rules.
128
pub fn rule_cmp<'a>(
129
    a: &Rule<'a>,
130
    b: &Rule<'a>,
131
    rule_priorities: &HashMap<&'a Rule<'a>, u8>,
132
) -> std::cmp::Ordering {
133
    let a_priority = *rule_priorities.get(a).unwrap_or(&0);
134
    let b_priority = *rule_priorities.get(b).unwrap_or(&0);
135

            
136
    if a_priority == b_priority {
137
        return a.name.cmp(b.name);
138
    }
139

            
140
    b_priority.cmp(&a_priority)
141
}
142

            
143
/// Get a final ordering of rules based on their priorities and names.
144
///
145
/// # Arguments
146
/// - `rule_priorities` The priorities of the rules.
147
///
148
/// # Returns
149
/// - A list of rules sorted by their priorities and names.
150
pub fn get_rules_vec<'a>(rule_priorities: &HashMap<&'a Rule<'a>, u8>) -> Vec<&'a Rule<'a>> {
151
    let mut rules: Vec<&'a Rule<'a>> = rule_priorities.keys().copied().collect();
152
    rules.sort_by(|a, b| rule_cmp(a, b, rule_priorities));
153
    rules
154
}