1
use std::fmt::{Display, Formatter};
2

            
3
use derive_is_enum_variant::is_enum_variant;
4
use serde::{Deserialize, Serialize};
5

            
6
use enum_compatability_macro::document_compatibility;
7

            
8
use crate::ast::constants::Constant;
9
use crate::ast::symbol_table::{Name, SymbolTable};
10
use crate::ast::ReturnType;
11
use crate::metadata::Metadata;
12

            
13
use super::{Domain, Range};
14

            
15
#[document_compatibility]
16
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, is_enum_variant)]
17
#[non_exhaustive]
18
pub enum Expression {
19
    /**
20
     * Represents an empty expression
21
     * NB: we only expect this at the top level of a model (if there is no constraints)
22
     */
23
    Nothing,
24

            
25
    /// An expression representing "A is valid as long as B is true"
26
    /// Turns into a conjunction when it reaches a boolean context
27
    Bubble(Metadata, Box<Expression>, Box<Expression>),
28

            
29
    #[compatible(Minion, JsonInput)]
30
    Constant(Metadata, Constant),
31

            
32
    #[compatible(Minion, JsonInput, SAT)]
33
    Reference(Metadata, Name),
34

            
35
    #[compatible(Minion, JsonInput)]
36
    Sum(Metadata, Vec<Expression>),
37

            
38
    // /// Division after preventing division by zero, usually with a top-level constraint
39
    // #[compatible(Minion)]
40
    // SafeDiv(Metadata, Box<Expression>, Box<Expression>),
41
    // /// Division with a possibly undefined value (division by 0)
42
    // #[compatible(Minion, JsonInput)]
43
    // Div(Metadata, Box<Expression>, Box<Expression>),
44
    #[compatible(JsonInput)]
45
    Min(Metadata, Vec<Expression>),
46

            
47
    #[compatible(JsonInput, SAT)]
48
    Not(Metadata, Box<Expression>),
49

            
50
    #[compatible(JsonInput, SAT)]
51
    Or(Metadata, Vec<Expression>),
52

            
53
    #[compatible(JsonInput, SAT)]
54
    And(Metadata, Vec<Expression>),
55

            
56
    #[compatible(JsonInput)]
57
    Eq(Metadata, Box<Expression>, Box<Expression>),
58

            
59
    #[compatible(JsonInput)]
60
    Neq(Metadata, Box<Expression>, Box<Expression>),
61

            
62
    #[compatible(JsonInput)]
63
    Geq(Metadata, Box<Expression>, Box<Expression>),
64

            
65
    #[compatible(JsonInput)]
66
    Leq(Metadata, Box<Expression>, Box<Expression>),
67

            
68
    #[compatible(JsonInput)]
69
    Gt(Metadata, Box<Expression>, Box<Expression>),
70

            
71
    #[compatible(JsonInput)]
72
    Lt(Metadata, Box<Expression>, Box<Expression>),
73

            
74
    /// Division after preventing division by zero, usually with a bubble
75
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),
76

            
77
    /// Division with a possibly undefined value (division by 0)
78
    #[compatible(JsonInput)]
79
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),
80

            
81
    /* Flattened SumEq.
82
     *
83
     * Note: this is an intermediary step that's used in the process of converting from conjure model to minion.
84
     * This is NOT a valid expression in either Essence or minion.
85
     *
86
     * ToDo: This is a stop gap solution. Eventually it may be better to have multiple constraints instead? (gs248)
87
     */
88
    SumEq(Metadata, Vec<Expression>, Box<Expression>),
89

            
90
    // Flattened Constraints
91
    #[compatible(Minion)]
92
    SumGeq(Metadata, Vec<Expression>, Box<Expression>),
93

            
94
    #[compatible(Minion)]
95
    SumLeq(Metadata, Vec<Expression>, Box<Expression>),
96

            
97
    #[compatible(Minion)]
98
    DivEq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),
99

            
100
    #[compatible(Minion)]
101
    Ineq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),
102

            
103
    #[compatible(Minion)]
104
    AllDiff(Metadata, Vec<Expression>),
105
}
106

            
107
fn expr_vec_to_domain_i32(
108
    exprs: &Vec<Expression>,
109
    op: fn(i32, i32) -> Option<i32>,
110
    vars: &SymbolTable,
111
) -> Option<Domain> {
112
    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
113
    domains
114
        .into_iter()
115
        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
116
        .flatten()
117
}
118

            
119
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> (i32, i32) {
120
    let mut min = i32::MAX;
121
    let mut max = i32::MIN;
122
    for r in ranges {
123
        match r {
124
            Range::Single(i) => {
125
                if *i < min {
126
                    min = *i;
127
                }
128
                if *i > max {
129
                    max = *i;
130
                }
131
            }
132
            Range::Bounded(i, j) => {
133
                if *i < min {
134
                    min = *i;
135
                }
136
                if *j > max {
137
                    max = *j;
138
                }
139
            }
140
        }
141
    }
142
    (min, max)
143
}
144

            
145
impl Expression {
146
    /// Returns the possible values of the expression, recursing to leaf expressions
147
    pub fn domain_of(&self, vars: &SymbolTable) -> Option<Domain> {
148
        let ret = match self {
149
            Expression::Reference(_, name) => Some(vars.get(name)?.domain.clone()),
150
            Expression::Constant(_, Constant::Int(n)) => {
151
                Some(Domain::IntDomain(vec![Range::Single(*n)]))
152
            }
153
            Expression::Constant(_, Constant::Bool(_)) => Some(Domain::BoolDomain),
154
            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), vars),
155
            Expression::Min(_, exprs) => {
156
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x < y { x } else { y }), vars)
157
            }
158
            Expression::UnsafeDiv(_, a, b) | Expression::SafeDiv(_, a, b) => {
159
                a.domain_of(vars)?.apply_i32(
160
                    |x, y| if y != 0 { Some(x / y) } else { None },
161
                    &b.domain_of(vars)?,
162
                )
163
            }
164
            _ => todo!("Calculate domain of {:?}", self),
165
            // TODO: (flm8) Add support for calculating the domains of more expression types
166
        };
167
        match ret {
168
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
169
            // Once they support a full domain as we define it, we can remove this conversion
170
            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
171
                let (min, max) = range_vec_bounds_i32(&ranges);
172
                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
173
            }
174
            _ => ret,
175
        }
176
    }
177

            
178
    pub fn can_be_undefined(&self) -> bool {
179
        // TODO: there will be more false cases but we are being conservative
180
        match self {
181
            Expression::Reference(_, _) => false,
182
            Expression::Constant(_, Constant::Bool(_)) => false,
183
            Expression::Constant(_, Constant::Int(_)) => false,
184
            _ => true,
185
        }
186
    }
187

            
188
    pub fn return_type(&self) -> Option<ReturnType> {
189
        match self {
190
            Expression::Constant(_, Constant::Int(_)) => Some(ReturnType::Int),
191
            Expression::Constant(_, Constant::Bool(_)) => Some(ReturnType::Bool),
192
            Expression::Reference(_, _) => None,
193
            Expression::Sum(_, _) => Some(ReturnType::Int),
194
            Expression::Min(_, _) => Some(ReturnType::Int),
195
            Expression::Not(_, _) => Some(ReturnType::Bool),
196
            Expression::Or(_, _) => Some(ReturnType::Bool),
197
            Expression::And(_, _) => Some(ReturnType::Bool),
198
            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
199
            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
200
            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
201
            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
202
            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
203
            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
204
            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
205
            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
206
            Expression::SumEq(_, _, _) => Some(ReturnType::Bool),
207
            Expression::SumGeq(_, _, _) => Some(ReturnType::Bool),
208
            Expression::SumLeq(_, _, _) => Some(ReturnType::Bool),
209
            Expression::DivEq(_, _, _, _) => Some(ReturnType::Bool),
210
            Expression::Ineq(_, _, _, _) => Some(ReturnType::Bool),
211
            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
212
            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
213
            Expression::Nothing => None,
214
        }
215
    }
216

            
217
    pub fn is_clean(&self) -> bool {
218
        match self {
219
            Expression::Nothing => true,
220
            Expression::Constant(metadata, _) => metadata.clean,
221
            Expression::Reference(metadata, _) => metadata.clean,
222
            Expression::Sum(metadata, exprs) => metadata.clean,
223
            Expression::Min(metadata, exprs) => metadata.clean,
224
            Expression::Not(metadata, expr) => metadata.clean,
225
            Expression::Or(metadata, exprs) => metadata.clean,
226
            Expression::And(metadata, exprs) => metadata.clean,
227
            Expression::Eq(metadata, box1, box2) => metadata.clean,
228
            Expression::Neq(metadata, box1, box2) => metadata.clean,
229
            Expression::Geq(metadata, box1, box2) => metadata.clean,
230
            Expression::Leq(metadata, box1, box2) => metadata.clean,
231
            Expression::Gt(metadata, box1, box2) => metadata.clean,
232
            Expression::Lt(metadata, box1, box2) => metadata.clean,
233
            Expression::SumGeq(metadata, box1, box2) => metadata.clean,
234
            Expression::SumLeq(metadata, box1, box2) => metadata.clean,
235
            Expression::Ineq(metadata, box1, box2, box3) => metadata.clean,
236
            Expression::AllDiff(metadata, exprs) => metadata.clean,
237
            Expression::SumEq(metadata, exprs, expr) => metadata.clean,
238
            _ => false,
239
        }
240
    }
241

            
242
    pub fn set_clean(&mut self, bool_value: bool) {
243
        match self {
244
            Expression::Nothing => {}
245
            Expression::Constant(metadata, _) => metadata.clean = bool_value,
246
            Expression::Reference(metadata, _) => metadata.clean = bool_value,
247
            Expression::Sum(metadata, _) => {
248
                metadata.clean = bool_value;
249
            }
250
            Expression::Min(metadata, _) => {
251
                metadata.clean = bool_value;
252
            }
253
            Expression::Not(metadata, _) => {
254
                metadata.clean = bool_value;
255
            }
256
            Expression::Or(metadata, _) => {
257
                metadata.clean = bool_value;
258
            }
259
            Expression::And(metadata, _) => {
260
                metadata.clean = bool_value;
261
            }
262
            Expression::Eq(metadata, box1, box2) => {
263
                metadata.clean = bool_value;
264
            }
265
            Expression::Neq(metadata, _box1, _box2) => {
266
                metadata.clean = bool_value;
267
            }
268
            Expression::Geq(metadata, _box1, _box2) => {
269
                metadata.clean = bool_value;
270
            }
271
            Expression::Leq(metadata, _box1, _box2) => {
272
                metadata.clean = bool_value;
273
            }
274
            Expression::Gt(metadata, _box1, _box2) => {
275
                metadata.clean = bool_value;
276
            }
277
            Expression::Lt(metadata, _box1, _box2) => {
278
                metadata.clean = bool_value;
279
            }
280
            Expression::SumGeq(metadata, _box1, _box2) => {
281
                metadata.clean = bool_value;
282
            }
283
            Expression::SumLeq(metadata, _box1, _box2) => {
284
                metadata.clean = bool_value;
285
            }
286
            Expression::Ineq(metadata, _box1, _box2, _box3) => {
287
                metadata.clean = bool_value;
288
            }
289
            Expression::AllDiff(metadata, _exprs) => {
290
                metadata.clean = bool_value;
291
            }
292
            Expression::SumEq(metadata, _exprs, _expr) => {
293
                metadata.clean = bool_value;
294
            }
295
            Expression::Bubble(metadata, box1, box2) => {
296
                metadata.clean = bool_value;
297
            }
298
            Expression::SafeDiv(metadata, box1, box2) => {
299
                metadata.clean = bool_value;
300
            }
301
            Expression::UnsafeDiv(metadata, box1, box2) => {
302
                metadata.clean = bool_value;
303
            }
304
            Expression::DivEq(metadata, box1, box2, box3) => {
305
                metadata.clean = bool_value;
306
            }
307
        }
308
    }
309
}
310

            
311
fn display_expressions(expressions: &[Expression]) -> String {
312
    // if expressions.len() <= 3 {
313
    format!(
314
        "[{}]",
315
        expressions
316
            .iter()
317
            .map(|e| e.to_string())
318
            .collect::<Vec<String>>()
319
            .join(", ")
320
    )
321
    // } else {
322
    //     format!(
323
    //         "[{}..{}]",
324
    //         expressions[0],
325
    //         expressions[expressions.len() - 1]
326
    //     )
327
    // }
328
}
329

            
330
impl From<i32> for Expression {
331
    fn from(i: i32) -> Self {
332
        Expression::Constant(Metadata::new(), Constant::Int(i))
333
    }
334
}
335

            
336
impl From<bool> for Expression {
337
    fn from(b: bool) -> Self {
338
        Expression::Constant(Metadata::new(), Constant::Bool(b))
339
    }
340
}
341

            
342
impl Display for Expression {
343
    // TODO: (flm8) this will change once we implement a parser (two-way conversion)
344
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
345
        match &self {
346
            Expression::Constant(_, c) => match c {
347
                Constant::Bool(b) => write!(f, "{}", b),
348
                Constant::Int(i) => write!(f, "{}", i),
349
            },
350
            Expression::Reference(_, name) => match name {
351
                Name::MachineName(n) => write!(f, "_{}", n),
352
                Name::UserName(s) => write!(f, "{}", s),
353
            },
354
            Expression::Nothing => write!(f, "Nothing"),
355
            Expression::Sum(_, expressions) => {
356
                write!(f, "Sum({})", display_expressions(expressions))
357
            }
358
            Expression::Min(_, expressions) => {
359
                write!(f, "Min({})", display_expressions(expressions))
360
            }
361
            Expression::Not(_, expr_box) => {
362
                write!(f, "Not({})", expr_box.clone())
363
            }
364
            Expression::Or(_, expressions) => {
365
                write!(f, "Or({})", display_expressions(expressions))
366
            }
367
            Expression::And(_, expressions) => {
368
                write!(f, "And({})", display_expressions(expressions))
369
            }
370
            Expression::Eq(_, box1, box2) => {
371
                write!(f, "({} = {})", box1.clone(), box2.clone())
372
            }
373
            Expression::Neq(_, box1, box2) => {
374
                write!(f, "({} != {})", box1.clone(), box2.clone())
375
            }
376
            Expression::Geq(_, box1, box2) => {
377
                write!(f, "({} >= {})", box1.clone(), box2.clone())
378
            }
379
            Expression::Leq(_, box1, box2) => {
380
                write!(f, "({} <= {})", box1.clone(), box2.clone())
381
            }
382
            Expression::Gt(_, box1, box2) => {
383
                write!(f, "({} > {})", box1.clone(), box2.clone())
384
            }
385
            Expression::Lt(_, box1, box2) => {
386
                write!(f, "({} < {})", box1.clone(), box2.clone())
387
            }
388
            Expression::SumEq(_, expressions, expr_box) => {
389
                write!(
390
                    f,
391
                    "SumEq({}, {})",
392
                    display_expressions(expressions),
393
                    expr_box.clone()
394
                )
395
            }
396
            Expression::SumGeq(_, box1, box2) => {
397
                write!(f, "SumGeq({}, {})", display_expressions(box1), box2.clone())
398
            }
399
            Expression::SumLeq(_, box1, box2) => {
400
                write!(f, "SumLeq({}, {})", display_expressions(box1), box2.clone())
401
            }
402
            Expression::Ineq(_, box1, box2, box3) => write!(
403
                f,
404
                "Ineq({}, {}, {})",
405
                box1.clone(),
406
                box2.clone(),
407
                box3.clone()
408
            ),
409
            Expression::AllDiff(_, expressions) => {
410
                write!(f, "AllDiff({})", display_expressions(expressions))
411
            }
412
            Expression::Bubble(_, box1, box2) => {
413
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
414
            }
415
            Expression::SafeDiv(_, box1, box2) => {
416
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
417
            }
418
            Expression::UnsafeDiv(_, box1, box2) => {
419
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
420
            }
421
            Expression::DivEq(_, box1, box2, box3) => {
422
                write!(
423
                    f,
424
                    "DivEq({}, {}, {})",
425
                    box1.clone(),
426
                    box2.clone(),
427
                    box3.clone()
428
                )
429
            }
430
            #[allow(unreachable_patterns)]
431
            other => todo!("Implement display for {:?}", other),
432
        }
433
    }
434
}
435

            
436
#[cfg(test)]
mod tests {
    use crate::ast::DecisionVariable;

    use super::*;

    /// Integer constant with fresh metadata.
    fn int(value: i32) -> Expression {
        Expression::Constant(Metadata::new(), Constant::Int(value))
    }

    /// Reference to the machine-named variable `_0`.
    fn var_zero() -> Expression {
        Expression::Reference(Metadata::new(), Name::MachineName(0))
    }

    /// `Expression::Sum` over `terms` with fresh metadata.
    fn sum_of(terms: Vec<Expression>) -> Expression {
        Expression::Sum(Metadata::new(), terms)
    }

    /// Symbol table holding only `_0`, declared with the given domain.
    fn table_with_var_zero(domain: Domain) -> SymbolTable {
        let mut table = SymbolTable::new();
        table.insert(Name::MachineName(0), DecisionVariable::new(domain));
        table
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let sum = sum_of(vec![int(1), int(2)]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(3)]));
        assert_eq!(sum.domain_of(&SymbolTable::new()), expected);
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let flag = Expression::Constant(Metadata::new(), Constant::Bool(true));
        let sum = sum_of(vec![int(1), flag]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = sum_of(Vec::new());
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let vars = table_with_var_zero(Domain::IntDomain(vec![Range::Single(1)]));
        let expected = Some(Domain::IntDomain(vec![Range::Single(1)]));
        assert_eq!(var_zero().domain_of(&vars), expected);
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        assert_eq!(var_zero().domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let vars = table_with_var_zero(Domain::IntDomain(vec![Range::Single(1)]));
        let sum = sum_of(vec![var_zero(), var_zero()]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(2)]));
        assert_eq!(sum.domain_of(&vars), expected);
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let vars = table_with_var_zero(Domain::IntDomain(vec![Range::Bounded(1, 2)]));
        let sum = sum_of(vec![var_zero(), var_zero()]);
        let expected = Some(Domain::IntDomain(vec![Range::Bounded(2, 4)]));
        assert_eq!(sum.domain_of(&vars), expected);
    }
}