1
use std::fmt::{self, Display, Formatter};
2
use std::hash::Hash;
3

            
4
use thiserror::Error;
5

            
6
use crate::ast::{Expression, SymbolTable};
7
use crate::metadata::Metadata;
8
use crate::model::Model;
9

            
10
#[derive(Debug, Error)]
11
pub enum ApplicationError {
12
    #[error("Rule is not applicable")]
13
    RuleNotApplicable,
14

            
15
    #[error("Could not calculate the expression domain")]
16
    DomainError,
17
}
18

            
19
/// The result of applying a rule to an expression.
20
///
21
/// Contains an expression to replace the original, a top-level constraint to add to the top of the constraint AST, and an expansion to the model symbol table.
22
#[non_exhaustive]
23
#[derive(Clone, Debug)]
24
pub struct Reduction {
25
    pub new_expression: Expression,
26
    pub new_top: Expression,
27
    pub symbols: SymbolTable,
28
}
29

            
30
/// The result of applying a rule to an expression.
31
/// Contains either a set of reduction instructions or an error.
32
pub type ApplicationResult = Result<Reduction, ApplicationError>;
33

            
34
impl Reduction {
35
    pub fn new(new_expression: Expression, new_top: Expression, symbols: SymbolTable) -> Self {
36
        Self {
37
            new_expression,
38
            new_top,
39
            symbols,
40
        }
41
    }
42

            
43
    /// Represents a reduction with no side effects on the model.
44
    pub fn pure(new_expression: Expression) -> Self {
45
        Self {
46
            new_expression,
47
            new_top: Expression::Nothing,
48
            symbols: SymbolTable::new(),
49
        }
50
    }
51

            
52
    /// Represents a reduction that also modifies the symbol table.
53
    pub fn with_symbols(new_expression: Expression, symbols: SymbolTable) -> Self {
54
        Self {
55
            new_expression,
56
            new_top: Expression::Nothing,
57
            symbols,
58
        }
59
    }
60

            
61
    /// Represents a reduction that also adds a top-level constraint to the model.
62
    pub fn with_top(new_expression: Expression, new_top: Expression) -> Self {
63
        Self {
64
            new_expression,
65
            new_top,
66
            symbols: SymbolTable::new(),
67
        }
68
    }
69

            
70
    // Apply side-effects (e.g. symbol table updates
71
    pub fn apply(self, model: &mut Model) {
72
        model.variables.extend(self.symbols); // Add new assignments to the symbol table
73
        if self.new_top.is_nothing() {
74
            model.constraints = self.new_expression.clone();
75
        } else {
76
            model.constraints = match self.new_expression {
77
                Expression::And(metadata, mut exprs) => {
78
                    // Avoid creating a nested conjunction
79
                    exprs.push(self.new_top.clone());
80
                    Expression::And(metadata.clone_dirty(), exprs)
81
                }
82
                _ => Expression::And(
83
                    Metadata::new(),
84
                    vec![self.new_expression.clone(), self.new_top],
85
                ),
86
            };
87
        }
88
    }
89
}
90

            
91
/**
92
 * A rule with a name, application function, and rule sets.
93
 *
94
 * # Fields
95
 * - `name` The name of the rule.
96
 * - `application` The function to apply the rule.
97
 * - `rule_sets` A list of rule set names and priorities that this rule is a part of. This is used to populate rulesets at runtime.
98
 */
99
#[derive(Clone, Debug)]
100
pub struct Rule<'a> {
101
    pub name: &'a str,
102
    pub application: fn(&Expression, &Model) -> ApplicationResult,
103
    pub rule_sets: &'a [(&'a str, u8)], // (name, priority). At runtime, we add the rule to rulesets
104
}
105

            
106
impl<'a> Rule<'a> {
107
    pub const fn new(
108
        name: &'a str,
109
        application: fn(&Expression, &Model) -> ApplicationResult,
110
        rule_sets: &'a [(&'static str, u8)],
111
    ) -> Self {
112
        Self {
113
            name,
114
            application,
115
            rule_sets,
116
        }
117
    }
118

            
119
    pub fn apply(&self, expr: &Expression, mdl: &Model) -> ApplicationResult {
120
        (self.application)(expr, mdl)
121
    }
122
}
123

            
124
impl<'a> Display for Rule<'a> {
125
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
126
        write!(f, "{}", self.name)
127
    }
128
}
129

            
130
impl<'a> PartialEq for Rule<'a> {
131
    fn eq(&self, other: &Self) -> bool {
132
        self.name == other.name
133
    }
134
}
135

            
136
/// Name equality is reflexive, symmetric and transitive, so rule equality is total.
impl Eq for Rule<'_> {}
137

            
138
impl<'a> Hash for Rule<'a> {
139
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
140
        self.name.hash(state);
141
    }
142
}