conjure_core/ast/
expressions.rs

1use std::collections::VecDeque;
2use std::fmt::{Display, Formatter};
3use std::sync::Arc;
4
5use itertools::Itertools;
6use serde::{Deserialize, Serialize};
7
8use crate::ast::literals::AbstractLiteral;
9use crate::ast::literals::Literal;
10use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
11use crate::ast::symbol_table::SymbolTable;
12use crate::ast::Atom;
13use crate::ast::Name;
14use crate::ast::ReturnType;
15use crate::ast::SetAttr;
16use crate::bug;
17use crate::metadata::Metadata;
18use enum_compatability_macro::document_compatibility;
19use uniplate::derive::Uniplate;
20use uniplate::{Biplate, Uniplate as _};
21
22use super::ac_operators::ACOperatorKind;
23use super::comprehension::Comprehension;
24use super::records::RecordValue;
25use super::{Domain, Range, SubModel, Typeable};
26
27// Ensure that this type doesn't get too big
28//
29// If you triggered this assertion, you either made a variant of this enum that is too big, or you
30// made Name,Literal,AbstractLiteral,Atom bigger, which made this bigger! To fix this, put some
31// stuff in boxes.
32//
33// Enums take the size of their largest variant, so an enum with mostly small variants and a few
34// large ones wastes memory... A larger Expression type also slows down Oxide.
35//
36// For more information, and more details on type sizes and how to measure them, see the commit
37// message for 6012de809 (perf: reduce size of AST types, 2025-06-18).
38//
39// You can also see type sizes in the rustdoc documentation, generated by ./tools/gen_docs.sh
40//
41// https://github.com/conjure-cp/conjure-oxide/commit/6012de8096ca491ded91ecec61352fdf4e994f2e
42
43// expect size of Expression to be 96 bytes
44static_assertions::assert_eq_size!([u8; 96], Expression);
45
46/// Represents different types of expressions used to define rules and constraints in the model.
47///
48/// The `Expression` enum includes operations, constants, and variable references
49/// used to build rules and conditions for the model.
50#[document_compatibility]
51#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
52#[uniplate(walk_into=[Atom,SubModel,AbstractLiteral<Expression>,Comprehension])]
53#[biplate(to=Metadata)]
54#[biplate(to=Atom,walk_into=[Expression,AbstractLiteral<Expression>,Vec<Expression>])]
55#[biplate(to=Name,walk_into=[Expression,Atom,AbstractLiteral<Expression>,Vec<Expression>])]
56#[biplate(to=Vec<Expression>)]
57#[biplate(to=Option<Expression>)]
58#[biplate(to=SubModel,walk_into=[Comprehension])]
59#[biplate(to=Comprehension)]
60#[biplate(to=AbstractLiteral<Expression>)]
61#[biplate(to=AbstractLiteral<Literal>,walk_into=[Atom])]
62#[biplate(to=RecordValue<Expression>,walk_into=[AbstractLiteral<Expression>])]
63#[biplate(to=RecordValue<Literal>,walk_into=[Atom,Literal,AbstractLiteral<Literal>,AbstractLiteral<Expression>])]
64#[biplate(to=Literal,walk_into=[Atom])]
65pub enum Expression {
66    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
67    /// The top of the model
68    Root(Metadata, Vec<Expression>),
69
70    /// An expression representing "A is valid as long as B is true"
71    /// Turns into a conjunction when it reaches a boolean context
72    Bubble(Metadata, Box<Expression>, Box<Expression>),
73
74    /// A comprehension.
75    ///
76    /// The inside of the comprehension opens a new scope.
77    Comprehension(Metadata, Box<Comprehension>),
78
79    /// Defines dominance ("Solution A is preferred over Solution B")
80    DominanceRelation(Metadata, Box<Expression>),
81    /// `fromSolution(name)` - Used in dominance relation definitions
82    FromSolution(Metadata, Box<Expression>),
83
84    Atomic(Metadata, Atom),
85
86    /// A matrix index.
87    ///
88    /// Defined iff the indices are within their respective index domains.
89    #[compatible(JsonInput)]
90    UnsafeIndex(Metadata, Box<Expression>, Vec<Expression>),
91
92    /// A safe matrix index.
93    ///
94    /// See [`Expression::UnsafeIndex`]
95    SafeIndex(Metadata, Box<Expression>, Vec<Expression>),
96
97    /// A matrix slice: `a[indices]`.
98    ///
99    /// One of the indicies may be `None`, representing the dimension of the matrix we want to take
100    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
101    /// `Some(1),None,Some(2)`.
102    ///
103    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
104    ///
105    /// Defined iff the defined indices are within their respective index domains.
106    #[compatible(JsonInput)]
107    UnsafeSlice(Metadata, Box<Expression>, Vec<Option<Expression>>),
108
109    /// A safe matrix slice: `a[indices]`.
110    ///
111    /// See [`Expression::UnsafeSlice`].
112    SafeSlice(Metadata, Box<Expression>, Vec<Option<Expression>>),
113
114    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
115    ///
116    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
117    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
118    /// `SafeSlice` respectively.
119    InDomain(Metadata, Box<Expression>, Domain),
120
121    /// `toInt(b)` casts boolean expression b to an integer.
122    ///
123    /// - If b is false, then `toInt(b) == 0`
124    ///
125    /// - If b is true, then `toInt(b) == 1`
126    ToInt(Metadata, Box<Expression>),
127
128    Scope(Metadata, Box<SubModel>),
129
130    /// `|x|` - absolute value of `x`
131    #[compatible(JsonInput)]
132    Abs(Metadata, Box<Expression>),
133
134    /// `sum(<vec_expr>)`
135    #[compatible(JsonInput)]
136    Sum(Metadata, Box<Expression>),
137
138    /// `a * b * c * ...`
139    #[compatible(JsonInput)]
140    Product(Metadata, Box<Expression>),
141
142    /// `min(<vec_expr>)`
143    #[compatible(JsonInput)]
144    Min(Metadata, Box<Expression>),
145
146    /// `max(<vec_expr>)`
147    #[compatible(JsonInput)]
148    Max(Metadata, Box<Expression>),
149
150    /// `not(a)`
151    #[compatible(JsonInput, SAT)]
152    Not(Metadata, Box<Expression>),
153
154    /// `or(<vec_expr>)`
155    #[compatible(JsonInput, SAT)]
156    Or(Metadata, Box<Expression>),
157
158    /// `and(<vec_expr>)`
159    #[compatible(JsonInput, SAT)]
160    And(Metadata, Box<Expression>),
161
162    /// Ensures that `a->b` (material implication).
163    #[compatible(JsonInput)]
164    Imply(Metadata, Box<Expression>, Box<Expression>),
165
166    /// `iff(a, b)` a <-> b
167    #[compatible(JsonInput)]
168    Iff(Metadata, Box<Expression>, Box<Expression>),
169
170    #[compatible(JsonInput)]
171    Union(Metadata, Box<Expression>, Box<Expression>),
172
173    #[compatible(JsonInput)]
174    In(Metadata, Box<Expression>, Box<Expression>),
175
176    #[compatible(JsonInput)]
177    Intersect(Metadata, Box<Expression>, Box<Expression>),
178
179    #[compatible(JsonInput)]
180    Supset(Metadata, Box<Expression>, Box<Expression>),
181
182    #[compatible(JsonInput)]
183    SupsetEq(Metadata, Box<Expression>, Box<Expression>),
184
185    #[compatible(JsonInput)]
186    Subset(Metadata, Box<Expression>, Box<Expression>),
187
188    #[compatible(JsonInput)]
189    SubsetEq(Metadata, Box<Expression>, Box<Expression>),
190
191    #[compatible(JsonInput)]
192    Eq(Metadata, Box<Expression>, Box<Expression>),
193
194    #[compatible(JsonInput)]
195    Neq(Metadata, Box<Expression>, Box<Expression>),
196
197    #[compatible(JsonInput)]
198    Geq(Metadata, Box<Expression>, Box<Expression>),
199
200    #[compatible(JsonInput)]
201    Leq(Metadata, Box<Expression>, Box<Expression>),
202
203    #[compatible(JsonInput)]
204    Gt(Metadata, Box<Expression>, Box<Expression>),
205
206    #[compatible(JsonInput)]
207    Lt(Metadata, Box<Expression>, Box<Expression>),
208
209    /// Division after preventing division by zero, usually with a bubble
210    SafeDiv(Metadata, Box<Expression>, Box<Expression>),
211
212    /// Division with a possibly undefined value (division by 0)
213    #[compatible(JsonInput)]
214    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),
215
216    /// Modulo after preventing mod 0, usually with a bubble
217    SafeMod(Metadata, Box<Expression>, Box<Expression>),
218
219    /// Modulo with a possibly undefined value (mod 0)
220    #[compatible(JsonInput)]
221    UnsafeMod(Metadata, Box<Expression>, Box<Expression>),
222
223    /// Negation: `-x`
224    #[compatible(JsonInput)]
225    Neg(Metadata, Box<Expression>),
226
227    /// Unsafe power`x**y` (possibly undefined)
228    ///
229    /// Defined when (X!=0 \\/ Y!=0) /\ Y>=0
230    #[compatible(JsonInput)]
231    UnsafePow(Metadata, Box<Expression>, Box<Expression>),
232
233    /// `UnsafePow` after preventing undefinedness
234    SafePow(Metadata, Box<Expression>, Box<Expression>),
235
236    /// `allDiff(<vec_expr>)`
237    #[compatible(JsonInput)]
238    AllDiff(Metadata, Box<Expression>),
239
240    /// Binary subtraction operator
241    ///
242    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
243    /// TODO: make this compatible with Set Difference calculations - need to change return type and domain for this expression and write a set comprehension rule.
244    /// have already edited minus_to_sum to prevent this from applying to sets
245    #[compatible(JsonInput)]
246    Minus(Metadata, Box<Expression>, Box<Expression>),
247
248    /// Ensures that x=|y| i.e. x is the absolute value of y.
249    ///
250    /// Low-level Minion constraint.
251    ///
252    /// # See also
253    ///
254    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
255    #[compatible(Minion)]
256    FlatAbsEq(Metadata, Box<Atom>, Box<Atom>),
257
258    /// Ensures that `alldiff([a,b,...])`.
259    ///
260    /// Low-level Minion constraint.
261    ///
262    /// # See also
263    ///
264    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#alldiff)
265    #[compatible(Minion)]
266    FlatAllDiff(Metadata, Vec<Atom>),
267
268    /// Ensures that sum(vec) >= x.
269    ///
270    /// Low-level Minion constraint.
271    ///
272    /// # See also
273    ///
274    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
275    #[compatible(Minion)]
276    FlatSumGeq(Metadata, Vec<Atom>, Atom),
277
278    /// Ensures that sum(vec) <= x.
279    ///
280    /// Low-level Minion constraint.
281    ///
282    /// # See also
283    ///
284    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
285    #[compatible(Minion)]
286    FlatSumLeq(Metadata, Vec<Atom>, Atom),
287
288    /// `ineq(x,y,k)` ensures that x <= y + k.
289    ///
290    /// Low-level Minion constraint.
291    ///
292    /// # See also
293    ///
294    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
295    #[compatible(Minion)]
296    FlatIneq(Metadata, Box<Atom>, Box<Atom>, Box<Literal>),
297
298    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
299    ///
300    /// Low-level Minion constraint.
301    ///
302    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
303    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
304    /// watched-and and watched-or.
305    ///
306    /// # See also
307    ///
308    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
309    /// + `rules::minion::boolean_literal_to_wliteral`.
310    #[compatible(Minion)]
311    FlatWatchedLiteral(Metadata, Name, Literal),
312
313    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
314    /// product of cs and xs.
315    ///
316    /// Low-level Minion constraint.
317    ///
318    /// Represents a weighted sum of the form `ax + by + cz + ...`
319    ///
320    /// # See also
321    ///
322    /// + [Minion
323    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
324    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Box<Atom>),
325
326    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
327    /// product of cs and xs.
328    ///
329    /// Low-level Minion constraint.
330    ///
331    /// Represents a weighted sum of the form `ax + by + cz + ...`
332    ///
333    /// # See also
334    ///
335    /// + [Minion
336    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
337    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Box<Atom>),
338
339    /// Ensures that x =-y, where x and y are atoms.
340    ///
341    /// Low-level Minion constraint.
342    ///
343    /// # See also
344    ///
345    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
346    #[compatible(Minion)]
347    FlatMinusEq(Metadata, Box<Atom>, Box<Atom>),
348
349    /// Ensures that x*y=z.
350    ///
351    /// Low-level Minion constraint.
352    ///
353    /// # See also
354    ///
355    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
356    #[compatible(Minion)]
357    FlatProductEq(Metadata, Box<Atom>, Box<Atom>, Box<Atom>),
358
359    /// Ensures that floor(x/y)=z. Always true when y=0.
360    ///
361    /// Low-level Minion constraint.
362    ///
363    /// # See also
364    ///
365    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
366    #[compatible(Minion)]
367    MinionDivEqUndefZero(Metadata, Box<Atom>, Box<Atom>, Box<Atom>),
368
369    /// Ensures that x%y=z. Always true when y=0.
370    ///
371    /// Low-level Minion constraint.
372    ///
373    /// # See also
374    ///
375    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
376    #[compatible(Minion)]
377    MinionModuloEqUndefZero(Metadata, Box<Atom>, Box<Atom>, Box<Atom>),
378
379    /// Ensures that `x**y = z`.
380    ///
381    /// Low-level Minion constraint.
382    ///
383    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is 1 if y
384    /// is odd and z is -1 if y is even).
385    ///
386    /// # See also
387    ///
388    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
389    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
390    MinionPow(Metadata, Box<Atom>, Box<Atom>, Box<Atom>),
391
392    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
393    /// variable.
394    ///
395    /// Low-level Minion constraint.
396    ///
397    /// # See also
398    ///
399    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
400    #[compatible(Minion)]
401    MinionReify(Metadata, Box<Expression>, Atom),
402
403    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
404    /// variable.
405    ///
406    /// Low-level Minion constraint.
407    ///
408    /// # See also
409    ///
410    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
411    #[compatible(Minion)]
412    MinionReifyImply(Metadata, Box<Expression>, Atom),
413
414    /// `w-inintervalset(x, [a1,a2, b1,b2, … ])` ensures that the value of x belongs to one of the
415    /// intervals {a1,…,a2}, {b1,…,b2} etc.
416    ///
417    /// The list of intervals must be given in numerical order.
418    ///
419    /// Low-level Minion constraint.
420    ///
421    /// # See also
422    ///>
423    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inintervalset)
424    #[compatible(Minion)]
425    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),
426
427    /// `w-inset(x, [v1, v2, … ])` ensures that the value of `x` is one of the explicitly given values `v1`, `v2`, etc.
428    ///
429    /// This constraint enforces membership in a specific set of discrete values rather than intervals.
430    ///
431    /// The list of values must be given in numerical order.
432    ///
433    /// Low-level Minion constraint.
434    ///
435    /// # See also
436    ///
437    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inset)
438    #[compatible(Minion)]
439    MinionWInSet(Metadata, Atom, Vec<i32>),
440
441    /// `element_one(vec, i, e)` specifies that `vec[i] = e`. This implies that i is
442    /// in the range `[1..len(vec)]`.
443    ///
444    /// Low-level Minion constraint.
445    ///
446    /// # See also
447    ///
448    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#element_one)
449    #[compatible(Minion)]
450    MinionElementOne(Metadata, Vec<Atom>, Box<Atom>, Box<Atom>),
451
452    /// Declaration of an auxiliary variable.
453    ///
454    /// As with Savile Row, we semantically distinguish this from `Eq`.
455    #[compatible(Minion)]
456    AuxDeclaration(Metadata, Name, Box<Expression>),
457}
458
459// for the given matrix literal, return a bounded domain from the min to max of applying op to each
460// child expression.
461//
462// Op must be monotonic.
463//
464// Returns none if unbounded
465fn bounded_i32_domain_for_matrix_literal_monotonic(
466    e: &Expression,
467    op: fn(i32, i32) -> Option<i32>,
468    symtab: &SymbolTable,
469) -> Option<Domain> {
470    // only care about the elements, not the indices
471    let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
472
473    // fold each element's domain into one using op.
474    //
475    // here, I assume that op is monotone. This means that the bounds of op([a1,a2],[b1,b2])  for
476    // the ranges [a1,a2], [b1,b2] will be
477    // [min(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2)),max(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2))].
478    //
479    // We used to not assume this, and work out the bounds by applying op on the Cartesian product
480    // of A and B; however, this caused a combinatorial explosion and my computer to run out of
481    // memory (on the hakank_eprime_xkcd test)...
482    //
483    // For example, to find the bounds of the intervals [1,4], [1,5] combined using op, we used to do
484    //  [min(op(1,1), op(1,2),op(1,3),op(1,4),op(1,5),op(2,1)..
485    //
486    // +,-,/,* are all monotone, so this assumption should be fine for now...
487
488    let expr = exprs.pop()?;
489    let Some(Domain::Int(ranges)) = expr.domain_of(symtab) else {
490        return None;
491    };
492
493    let (mut current_min, mut current_max) = range_vec_bounds_i32(&ranges)?;
494
495    for expr in exprs {
496        let Some(Domain::Int(ranges)) = expr.domain_of(symtab) else {
497            return None;
498        };
499
500        let (min, max) = range_vec_bounds_i32(&ranges)?;
501
502        // all the possible new values for current_min / current_max
503        let minmax = op(min, current_max)?;
504        let minmin = op(min, current_min)?;
505        let maxmin = op(max, current_min)?;
506        let maxmax = op(max, current_max)?;
507        let vals = [minmax, minmin, maxmin, maxmax];
508
509        current_min = *vals
510            .iter()
511            .min()
512            .expect("vals iterator should not be empty, and should have a minimum.");
513        current_max = *vals
514            .iter()
515            .max()
516            .expect("vals iterator should not be empty, and should have a maximum.");
517    }
518
519    if current_min == current_max {
520        Some(Domain::Int(vec![Range::Single(current_min)]))
521    } else {
522        Some(Domain::Int(vec![Range::Bounded(current_min, current_max)]))
523    }
524}
525
526// Returns none if unbounded
527fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
528    let mut min = i32::MAX;
529    let mut max = i32::MIN;
530    for r in ranges {
531        match r {
532            Range::Single(i) => {
533                if *i < min {
534                    min = *i;
535                }
536                if *i > max {
537                    max = *i;
538                }
539            }
540            Range::Bounded(i, j) => {
541                if *i < min {
542                    min = *i;
543                }
544                if *j > max {
545                    max = *j;
546                }
547            }
548            Range::UnboundedR(_) | Range::UnboundedL(_) => return None,
549        }
550    }
551    Some((min, max))
552}
553
554impl Expression {
555    /// Returns the possible values of the expression, recursing to leaf expressions
556    pub fn domain_of(&self, syms: &SymbolTable) -> Option<Domain> {
557        let ret = match self {
558            Expression::Union(_, a, b) => Some(Domain::Set(
559                SetAttr::None,
560                Box::new(a.domain_of(syms)?.union(&b.domain_of(syms)?).ok()?),
561            )),
562            Expression::Intersect(_, a, b) => Some(Domain::Set(
563                SetAttr::None,
564                Box::new(a.domain_of(syms)?.intersect(&b.domain_of(syms)?).ok()?),
565            )),
566            Expression::In(_, _, _) => Some(Domain::Bool),
567            Expression::Supset(_, _, _) => Some(Domain::Bool),
568            Expression::SupsetEq(_, _, _) => Some(Domain::Bool),
569            Expression::Subset(_, _, _) => Some(Domain::Bool),
570            Expression::SubsetEq(_, _, _) => Some(Domain::Bool),
571            Expression::AbstractLiteral(_, _) => None,
572            Expression::DominanceRelation(_, _) => Some(Domain::Bool),
573            Expression::FromSolution(_, expr) => expr.domain_of(syms),
574            Expression::Comprehension(_, comprehension) => comprehension.domain_of(syms),
575            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
576                match matrix.domain_of(syms)? {
577                    Domain::Matrix(elem_domain, _) => Some(*elem_domain),
578                    Domain::Tuple(_) => None,
579                    Domain::Record(_) => None,
580                    _ => {
581                        bug!("subject of an index operation should support indexing")
582                    }
583                }
584            }
585            Expression::UnsafeSlice(_, matrix, indices)
586            | Expression::SafeSlice(_, matrix, indices) => {
587                let sliced_dimension = indices.iter().position(Option::is_none);
588
589                let Domain::Matrix(elem_domain, index_domains) = matrix.domain_of(syms)? else {
590                    bug!("subject of an index operation should be a matrix");
591                };
592
593                match sliced_dimension {
594                    Some(dimension) => Some(Domain::Matrix(
595                        elem_domain,
596                        vec![index_domains[dimension].clone()],
597                    )),
598
599                    // same as index
600                    None => Some(*elem_domain),
601                }
602            }
603            Expression::InDomain(_, _, _) => Some(Domain::Bool),
604            Expression::Atomic(_, Atom::Reference(name)) => Some(syms.resolve_domain(name)?),
605            Expression::Atomic(_, Atom::Literal(Literal::Int(n))) => {
606                Some(Domain::Int(vec![Range::Single(*n)]))
607            }
608            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(Domain::Bool),
609            Expression::Atomic(_, Atom::Literal(Literal::AbstractLiteral(_))) => None,
610            Expression::Scope(_, _) => Some(Domain::Bool),
611            Expression::Sum(_, e) => {
612                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y), syms)
613            }
614            Expression::Product(_, e) => {
615                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y), syms)
616            }
617            Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(
618                e,
619                |x, y| Some(if x < y { x } else { y }),
620                syms,
621            ),
622            Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(
623                e,
624                |x, y| Some(if x > y { x } else { y }),
625                syms,
626            ),
627            Expression::UnsafeDiv(_, a, b) => a
628                .domain_of(syms)?
629                .apply_i32(
630                    // rust integer division is truncating; however, we want to always round down,
631                    // including for negative numbers.
632                    |x, y| {
633                        if y != 0 {
634                            Some((x as f32 / y as f32).floor() as i32)
635                        } else {
636                            None
637                        }
638                    },
639                    &b.domain_of(syms)?,
640                )
641                .ok(),
642            Expression::SafeDiv(_, a, b) => {
643                // rust integer division is truncating; however, we want to always round down
644                // including for negative numbers.
645                let domain = a.domain_of(syms)?.apply_i32(
646                    |x, y| {
647                        if y != 0 {
648                            Some((x as f32 / y as f32).floor() as i32)
649                        } else {
650                            None
651                        }
652                    },
653                    &b.domain_of(syms)?,
654                );
655
656                match domain {
657                    Ok(Domain::Int(ranges)) => {
658                        let mut ranges = ranges;
659                        ranges.push(Range::Single(0));
660                        Some(Domain::Int(ranges))
661                    }
662                    Err(_) => todo!(),
663                    _ => unreachable!(),
664                }
665            }
666            Expression::UnsafeMod(_, a, b) => a
667                .domain_of(syms)?
668                .apply_i32(
669                    |x, y| if y != 0 { Some(x % y) } else { None },
670                    &b.domain_of(syms)?,
671                )
672                .ok(),
673            Expression::SafeMod(_, a, b) => {
674                let domain = a.domain_of(syms)?.apply_i32(
675                    |x, y| if y != 0 { Some(x % y) } else { None },
676                    &b.domain_of(syms)?,
677                );
678
679                match domain {
680                    Ok(Domain::Int(ranges)) => {
681                        let mut ranges = ranges;
682                        ranges.push(Range::Single(0));
683                        Some(Domain::Int(ranges))
684                    }
685                    Err(_) => todo!(),
686                    _ => unreachable!(),
687                }
688            }
689            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
690                .domain_of(syms)?
691                .apply_i32(
692                    |x, y| {
693                        if (x != 0 || y != 0) && y >= 0 {
694                            Some(x.pow(y as u32))
695                        } else {
696                            None
697                        }
698                    },
699                    &b.domain_of(syms)?,
700                )
701                .ok(),
702            Expression::Root(_, _) => None,
703            Expression::Bubble(_, _, _) => None,
704            Expression::AuxDeclaration(_, _, _) => Some(Domain::Bool),
705            Expression::And(_, _) => Some(Domain::Bool),
706            Expression::Not(_, _) => Some(Domain::Bool),
707            Expression::Or(_, _) => Some(Domain::Bool),
708            Expression::Imply(_, _, _) => Some(Domain::Bool),
709            Expression::Iff(_, _, _) => Some(Domain::Bool),
710            Expression::Eq(_, _, _) => Some(Domain::Bool),
711            Expression::Neq(_, _, _) => Some(Domain::Bool),
712            Expression::Geq(_, _, _) => Some(Domain::Bool),
713            Expression::Leq(_, _, _) => Some(Domain::Bool),
714            Expression::Gt(_, _, _) => Some(Domain::Bool),
715            Expression::Lt(_, _, _) => Some(Domain::Bool),
716            Expression::FlatAbsEq(_, _, _) => Some(Domain::Bool),
717            Expression::FlatSumGeq(_, _, _) => Some(Domain::Bool),
718            Expression::FlatSumLeq(_, _, _) => Some(Domain::Bool),
719            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::Bool),
720            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::Bool),
721            Expression::FlatIneq(_, _, _, _) => Some(Domain::Bool),
722            Expression::AllDiff(_, _) => Some(Domain::Bool),
723            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::Bool),
724            Expression::MinionReify(_, _, _) => Some(Domain::Bool),
725            Expression::MinionReifyImply(_, _, _) => Some(Domain::Bool),
726            Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::Bool),
727            Expression::MinionWInSet(_, _, _) => Some(Domain::Bool),
728            Expression::MinionElementOne(_, _, _, _) => Some(Domain::Bool),
729            Expression::Neg(_, x) => {
730                let Some(Domain::Int(mut ranges)) = x.domain_of(syms) else {
731                    return None;
732                };
733
734                for range in ranges.iter_mut() {
735                    *range = match range {
736                        Range::Single(x) => Range::Single(-*x),
737                        Range::Bounded(x, y) => Range::Bounded(-*y, -*x),
738                        Range::UnboundedR(i) => Range::UnboundedL(-*i),
739                        Range::UnboundedL(i) => Range::UnboundedR(-*i),
740                    };
741                }
742
743                Some(Domain::Int(ranges))
744            }
745            Expression::Minus(_, a, b) => a
746                .domain_of(syms)?
747                .apply_i32(|x, y| Some(x - y), &b.domain_of(syms)?)
748                .ok(),
749            Expression::FlatAllDiff(_, _) => Some(Domain::Bool),
750            Expression::FlatMinusEq(_, _, _) => Some(Domain::Bool),
751            Expression::FlatProductEq(_, _, _, _) => Some(Domain::Bool),
752            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::Bool),
753            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::Bool),
754            Expression::Abs(_, a) => a
755                .domain_of(syms)?
756                .apply_i32(|a, _| Some(a.abs()), &a.domain_of(syms)?)
757                .ok(),
758            Expression::MinionPow(_, _, _, _) => Some(Domain::Bool),
759            Expression::ToInt(_, _) => Some(Domain::Int(vec![Range::Bounded(0, 1)])),
760        };
761        match ret {
762            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
763            // Once they support a full domain as we define it, we can remove this conversion
764            Some(Domain::Int(ranges)) if ranges.len() > 1 => {
765                let (min, max) = range_vec_bounds_i32(&ranges)?;
766                Some(Domain::Int(vec![Range::Bounded(min, max)]))
767            }
768            _ => ret,
769        }
770    }
771
772    pub fn get_meta(&self) -> Metadata {
773        let metas: VecDeque<Metadata> = self.children_bi();
774        metas[0].clone()
775    }
776
    /// Intended to replace the metadata of this expression with `meta`.
    ///
    /// NOTE(review): `transform_bi` returns a new expression rather than modifying its receiver
    /// (its result is bound to a new variable in the `biplate_to_names` test). Here that result is
    /// discarded. Because this method only borrows `self` immutably, calling it appears to have no
    /// effect on the receiver. Confirm with callers before relying on it. Also note that
    /// `transform_bi` presumably rewrites *every* `Metadata` reachable from this expression, not
    /// just the outermost one.
    pub fn set_meta(&self, meta: Metadata) {
        self.transform_bi(Arc::new(move |_| meta.clone()));
    }
780
781    /// Checks whether this expression is safe.
782    ///
783    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
784    ///
785    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
786    /// safe through the use of bubble rules.
787    pub fn is_safe(&self) -> bool {
788        // TODO: memoise in Metadata
789        for expr in self.universe() {
790            match expr {
791                Expression::UnsafeDiv(_, _, _)
792                | Expression::UnsafeMod(_, _, _)
793                | Expression::UnsafePow(_, _, _)
794                | Expression::UnsafeIndex(_, _, _)
795                | Expression::Bubble(_, _, _)
796                | Expression::UnsafeSlice(_, _, _) => {
797                    return false;
798                }
799                _ => {}
800            }
801        }
802        true
803    }
804
805    pub fn is_clean(&self) -> bool {
806        let metadata = self.get_meta();
807        metadata.clean
808    }
809
810    pub fn set_clean(&mut self, bool_value: bool) {
811        let mut metadata = self.get_meta();
812        metadata.clean = bool_value;
813        self.set_meta(metadata);
814    }
815
816    /// True if the expression is an associative and commutative operator
817    pub fn is_associative_commutative_operator(&self) -> bool {
818        TryInto::<ACOperatorKind>::try_into(self).is_ok()
819    }
820
821    /// True if the expression is a matrix literal.
822    ///
823    /// This is true for both forms of matrix literals: those with elements of type [`Literal`] and
824    /// [`Expression`].
825    pub fn is_matrix_literal(&self) -> bool {
826        matches!(
827            self,
828            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
829                | Expression::Atomic(
830                    _,
831                    Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
832                )
833        )
834    }
835
836    /// True iff self and other are both atomic and identical.
837    ///
838    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
839    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
840    /// entire subtrees of `self` and `other`.
841    pub fn identical_atom_to(&self, other: &Expression) -> bool {
842        let atom1: Result<&Atom, _> = self.try_into();
843        let atom2: Result<&Atom, _> = other.try_into();
844
845        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
846            atom2 == atom1
847        } else {
848            false
849        }
850    }
851
852    /// If the expression is a list, returns the inner expressions.
853    ///
854    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
855    /// any explicitly specified domain.
856    pub fn unwrap_list(self) -> Option<Vec<Expression>> {
857        match self {
858            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
859                matrix.unwrap_list().cloned()
860            }
861            Expression::Atomic(
862                _,
863                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
864            ) => matrix.unwrap_list().map(|elems| {
865                elems
866                    .clone()
867                    .into_iter()
868                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
869                    .collect_vec()
870            }),
871            _ => None,
872        }
873    }
874
875    /// If the expression is a matrix, gets it elements and index domain.
876    ///
877    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
878    ///
879    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
880    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
881    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
882    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, Domain)> {
883        match self {
884            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
885                Some((elems.clone(), *domain))
886            }
887            Expression::Atomic(
888                _,
889                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
890            ) => Some((
891                elems
892                    .clone()
893                    .into_iter()
894                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
895                    .collect_vec(),
896                *domain,
897            )),
898
899            _ => None,
900        }
901    }
902
903    /// For a Root expression, extends the inner vec with the given vec.
904    ///
905    /// # Panics
906    /// Panics if the expression is not Root.
907    pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
908        match self {
909            Expression::Root(meta, mut children) => {
910                children.extend(exprs);
911                Expression::Root(meta, children)
912            }
913            _ => panic!("extend_root called on a non-Root expression"),
914        }
915    }
916
917    /// Converts the expression to a literal, if possible.
918    pub fn into_literal(self) -> Option<Literal> {
919        match self {
920            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
921            Expression::AbstractLiteral(_, abslit) => {
922                Some(Literal::AbstractLiteral(abslit.clone().into_literals()?))
923            }
924            Expression::Neg(_, e) => {
925                let Literal::Int(i) = e.into_literal()? else {
926                    bug!("negated literal should be an int");
927                };
928
929                Some(Literal::Int(-i))
930            }
931
932            _ => None,
933        }
934    }
935
936    /// If this expression is an associative-commutative operator, return its [ACOperatorKind].
937    pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
938        TryFrom::try_from(self).ok()
939    }
940}
941
942impl From<i32> for Expression {
943    fn from(i: i32) -> Self {
944        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
945    }
946}
947
948impl From<bool> for Expression {
949    fn from(b: bool) -> Self {
950        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
951    }
952}
953
954impl From<Atom> for Expression {
955    fn from(value: Atom) -> Self {
956        Expression::Atomic(Metadata::new(), value)
957    }
958}
959
960impl From<Name> for Expression {
961    fn from(name: Name) -> Self {
962        Expression::Atomic(Metadata::new(), Atom::Reference(name))
963    }
964}
965
966impl From<Box<Expression>> for Expression {
967    fn from(val: Box<Expression>) -> Self {
968        val.as_ref().clone()
969    }
970}
971
972impl Display for Expression {
973    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
974        match &self {
975            Expression::Union(_, box1, box2) => {
976                write!(f, "({} union {})", box1.clone(), box2.clone())
977            }
978            Expression::In(_, e1, e2) => {
979                write!(f, "{} in {}", e1, e2)
980            }
981            Expression::Intersect(_, box1, box2) => {
982                write!(f, "({} intersect {})", box1.clone(), box2.clone())
983            }
984            Expression::Supset(_, box1, box2) => {
985                write!(f, "({} supset {})", box1.clone(), box2.clone())
986            }
987            Expression::SupsetEq(_, box1, box2) => {
988                write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
989            }
990            Expression::Subset(_, box1, box2) => {
991                write!(f, "({} subset {})", box1.clone(), box2.clone())
992            }
993            Expression::SubsetEq(_, box1, box2) => {
994                write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
995            }
996
997            Expression::AbstractLiteral(_, l) => l.fmt(f),
998            Expression::Comprehension(_, c) => c.fmt(f),
999            Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1000                write!(f, "{e1}{}", pretty_vec(e2))
1001            }
1002            Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1003                let args = es
1004                    .iter()
1005                    .map(|x| match x {
1006                        Some(x) => format!("{}", x),
1007                        None => "..".into(),
1008                    })
1009                    .join(",");
1010
1011                write!(f, "{e1}[{args}]")
1012            }
1013            Expression::InDomain(_, e, domain) => {
1014                write!(f, "__inDomain({e},{domain})")
1015            }
1016            Expression::Root(_, exprs) => {
1017                write!(f, "{}", pretty_expressions_as_top_level(exprs))
1018            }
1019            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({})", expr),
1020            Expression::FromSolution(_, expr) => write!(f, "FromSolution({})", expr),
1021            Expression::Atomic(_, atom) => atom.fmt(f),
1022            Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
1023            Expression::Abs(_, a) => write!(f, "|{}|", a),
1024            Expression::Sum(_, e) => {
1025                write!(f, "sum({e})")
1026            }
1027            Expression::Product(_, e) => {
1028                write!(f, "product({e})")
1029            }
1030            Expression::Min(_, e) => {
1031                write!(f, "min({e})")
1032            }
1033            Expression::Max(_, e) => {
1034                write!(f, "max({e})")
1035            }
1036            Expression::Not(_, expr_box) => {
1037                write!(f, "!({})", expr_box.clone())
1038            }
1039            Expression::Or(_, e) => {
1040                write!(f, "or({e})")
1041            }
1042            Expression::And(_, e) => {
1043                write!(f, "and({e})")
1044            }
1045            Expression::Imply(_, box1, box2) => {
1046                write!(f, "({}) -> ({})", box1, box2)
1047            }
1048            Expression::Iff(_, box1, box2) => {
1049                write!(f, "({}) <-> ({})", box1, box2)
1050            }
1051            Expression::Eq(_, box1, box2) => {
1052                write!(f, "({} = {})", box1.clone(), box2.clone())
1053            }
1054            Expression::Neq(_, box1, box2) => {
1055                write!(f, "({} != {})", box1.clone(), box2.clone())
1056            }
1057            Expression::Geq(_, box1, box2) => {
1058                write!(f, "({} >= {})", box1.clone(), box2.clone())
1059            }
1060            Expression::Leq(_, box1, box2) => {
1061                write!(f, "({} <= {})", box1.clone(), box2.clone())
1062            }
1063            Expression::Gt(_, box1, box2) => {
1064                write!(f, "({} > {})", box1.clone(), box2.clone())
1065            }
1066            Expression::Lt(_, box1, box2) => {
1067                write!(f, "({} < {})", box1.clone(), box2.clone())
1068            }
1069            Expression::FlatSumGeq(_, box1, box2) => {
1070                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1071            }
1072            Expression::FlatSumLeq(_, box1, box2) => {
1073                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1074            }
1075            Expression::FlatIneq(_, box1, box2, box3) => write!(
1076                f,
1077                "Ineq({}, {}, {})",
1078                box1.clone(),
1079                box2.clone(),
1080                box3.clone()
1081            ),
1082            Expression::AllDiff(_, e) => {
1083                write!(f, "allDiff({e})")
1084            }
1085            Expression::Bubble(_, box1, box2) => {
1086                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1087            }
1088            Expression::SafeDiv(_, box1, box2) => {
1089                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1090            }
1091            Expression::UnsafeDiv(_, box1, box2) => {
1092                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1093            }
1094            Expression::UnsafePow(_, box1, box2) => {
1095                write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1096            }
1097            Expression::SafePow(_, box1, box2) => {
1098                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1099            }
1100            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1101                write!(
1102                    f,
1103                    "DivEq({}, {}, {})",
1104                    box1.clone(),
1105                    box2.clone(),
1106                    box3.clone()
1107                )
1108            }
1109            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1110                write!(
1111                    f,
1112                    "ModEq({}, {}, {})",
1113                    box1.clone(),
1114                    box2.clone(),
1115                    box3.clone()
1116                )
1117            }
1118            Expression::FlatWatchedLiteral(_, x, l) => {
1119                write!(f, "WatchedLiteral({},{})", x, l)
1120            }
1121            Expression::MinionReify(_, box1, box2) => {
1122                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1123            }
1124            Expression::MinionReifyImply(_, box1, box2) => {
1125                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1126            }
1127            Expression::MinionWInIntervalSet(_, atom, intervals) => {
1128                let intervals = intervals.iter().join(",");
1129                write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1130            }
1131            Expression::MinionWInSet(_, atom, values) => {
1132                let values = values.iter().join(",");
1133                write!(f, "__minion_w_inset({atom},{values})")
1134            }
1135            Expression::AuxDeclaration(_, n, e) => {
1136                write!(f, "{} =aux {}", n, e.clone())
1137            }
1138            Expression::UnsafeMod(_, a, b) => {
1139                write!(f, "{} % {}", a.clone(), b.clone())
1140            }
1141            Expression::SafeMod(_, a, b) => {
1142                write!(f, "SafeMod({},{})", a.clone(), b.clone())
1143            }
1144            Expression::Neg(_, a) => {
1145                write!(f, "-({})", a.clone())
1146            }
1147            Expression::Minus(_, a, b) => {
1148                write!(f, "({} - {})", a.clone(), b.clone())
1149            }
1150            Expression::FlatAllDiff(_, es) => {
1151                write!(f, "__flat_alldiff({})", pretty_vec(es))
1152            }
1153            Expression::FlatAbsEq(_, a, b) => {
1154                write!(f, "AbsEq({},{})", a.clone(), b.clone())
1155            }
1156            Expression::FlatMinusEq(_, a, b) => {
1157                write!(f, "MinusEq({},{})", a.clone(), b.clone())
1158            }
1159            Expression::FlatProductEq(_, a, b, c) => {
1160                write!(
1161                    f,
1162                    "FlatProductEq({},{},{})",
1163                    a.clone(),
1164                    b.clone(),
1165                    c.clone()
1166                )
1167            }
1168            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1169                write!(
1170                    f,
1171                    "FlatWeightedSumLeq({},{},{})",
1172                    pretty_vec(cs),
1173                    pretty_vec(vs),
1174                    total.clone()
1175                )
1176            }
1177            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1178                write!(
1179                    f,
1180                    "FlatWeightedSumGeq({},{},{})",
1181                    pretty_vec(cs),
1182                    pretty_vec(vs),
1183                    total.clone()
1184                )
1185            }
1186            Expression::MinionPow(_, atom, atom1, atom2) => {
1187                write!(f, "MinionPow({},{},{})", atom, atom1, atom2)
1188            }
1189            Expression::MinionElementOne(_, atoms, atom, atom1) => {
1190                let atoms = atoms.iter().join(",");
1191                write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1192            }
1193
1194            Expression::ToInt(_, expr) => {
1195                write!(f, "toInt({expr})")
1196            }
1197        }
1198    }
1199}
1200
1201impl Typeable for Expression {
1202    fn return_type(&self) -> Option<ReturnType> {
1203        match self {
1204            Expression::Union(_, subject, _) => {
1205                Some(ReturnType::Set(Box::new(subject.return_type()?)))
1206            }
1207            Expression::Intersect(_, subject, _) => {
1208                Some(ReturnType::Set(Box::new(subject.return_type()?)))
1209            }
1210            Expression::In(_, _, _) => Some(ReturnType::Bool),
1211            Expression::Supset(_, _, _) => Some(ReturnType::Bool),
1212            Expression::SupsetEq(_, _, _) => Some(ReturnType::Bool),
1213            Expression::Subset(_, _, _) => Some(ReturnType::Bool),
1214            Expression::SubsetEq(_, _, _) => Some(ReturnType::Bool),
1215            Expression::AbstractLiteral(_, lit) => lit.return_type(),
1216            Expression::UnsafeIndex(_, subject, _) | Expression::SafeIndex(_, subject, _) => {
1217                Some(subject.return_type()?)
1218            }
1219            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1220                Some(ReturnType::Matrix(Box::new(subject.return_type()?)))
1221            }
1222            Expression::InDomain(_, _, _) => Some(ReturnType::Bool),
1223            Expression::Comprehension(_, _) => None,
1224            Expression::Root(_, _) => Some(ReturnType::Bool),
1225            Expression::DominanceRelation(_, _) => Some(ReturnType::Bool),
1226            Expression::FromSolution(_, expr) => expr.return_type(),
1227            Expression::Atomic(_, Atom::Literal(lit)) => lit.return_type(),
1228            Expression::Atomic(_, Atom::Reference(_)) => None,
1229            Expression::Scope(_, scope) => scope.return_type(),
1230            Expression::Abs(_, _) => Some(ReturnType::Int),
1231            Expression::Sum(_, _) => Some(ReturnType::Int),
1232            Expression::Product(_, _) => Some(ReturnType::Int),
1233            Expression::Min(_, _) => Some(ReturnType::Int),
1234            Expression::Max(_, _) => Some(ReturnType::Int),
1235            Expression::Not(_, _) => Some(ReturnType::Bool),
1236            Expression::Or(_, _) => Some(ReturnType::Bool),
1237            Expression::Imply(_, _, _) => Some(ReturnType::Bool),
1238            Expression::Iff(_, _, _) => Some(ReturnType::Bool),
1239            Expression::And(_, _) => Some(ReturnType::Bool),
1240            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
1241            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
1242            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
1243            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
1244            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
1245            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
1246            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
1247            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
1248            Expression::FlatAllDiff(_, _) => Some(ReturnType::Bool),
1249            Expression::FlatSumGeq(_, _, _) => Some(ReturnType::Bool),
1250            Expression::FlatSumLeq(_, _, _) => Some(ReturnType::Bool),
1251            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1252            Expression::FlatIneq(_, _, _, _) => Some(ReturnType::Bool),
1253            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
1254            Expression::Bubble(_, _, _) => None,
1255            Expression::FlatWatchedLiteral(_, _, _) => Some(ReturnType::Bool),
1256            Expression::MinionReify(_, _, _) => Some(ReturnType::Bool),
1257            Expression::MinionReifyImply(_, _, _) => Some(ReturnType::Bool),
1258            Expression::MinionWInIntervalSet(_, _, _) => Some(ReturnType::Bool),
1259            Expression::MinionWInSet(_, _, _) => Some(ReturnType::Bool),
1260            Expression::MinionElementOne(_, _, _, _) => Some(ReturnType::Bool),
1261            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
1262            Expression::UnsafeMod(_, _, _) => Some(ReturnType::Int),
1263            Expression::SafeMod(_, _, _) => Some(ReturnType::Int),
1264            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1265            Expression::Neg(_, _) => Some(ReturnType::Int),
1266            Expression::UnsafePow(_, _, _) => Some(ReturnType::Int),
1267            Expression::SafePow(_, _, _) => Some(ReturnType::Int),
1268            Expression::Minus(_, _, _) => Some(ReturnType::Int),
1269            Expression::FlatAbsEq(_, _, _) => Some(ReturnType::Bool),
1270            Expression::FlatMinusEq(_, _, _) => Some(ReturnType::Bool),
1271            Expression::FlatProductEq(_, _, _, _) => Some(ReturnType::Bool),
1272            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(ReturnType::Bool),
1273            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(ReturnType::Bool),
1274            Expression::MinionPow(_, _, _, _) => Some(ReturnType::Bool),
1275            Expression::ToInt(_, _) => Some(ReturnType::Int),
1276        }
1277    }
1278}
1279
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use crate::{ast::declaration::Declaration, matrix_expr};

    use super::*;

    /// Builds a symbol table holding a single variable `Name::Machine(0)` with the given domain.
    ///
    /// The result of `insert` is unwrapped so a rejected insertion fails the test instead of being
    /// silently ignored.
    fn symbols_with_machine_var(domain: Domain) -> SymbolTable {
        let mut vars = SymbolTable::new();
        vars.insert(Rc::new(Declaration::new_var(Name::Machine(0), domain)))
            .unwrap();
        vars
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(2)));
        let sum = Expression::Sum(
            Metadata::new(),
            Box::new(matrix_expr![c1.clone(), c2.clone()]),
        );
        assert_eq!(
            sum.domain_of(&SymbolTable::new()),
            Some(Domain::Int(vec![Range::Single(3)]))
        );
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(
            Metadata::new(),
            Box::new(matrix_expr![c1.clone(), c2.clone()]),
        );
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), Box::new(matrix_expr![]));
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(0)));
        let vars = symbols_with_machine_var(Domain::Int(vec![Range::Single(1)]));
        assert_eq!(
            reference.domain_of(&vars),
            Some(Domain::Int(vec![Range::Single(1)]))
        );
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(0)));
        assert_eq!(reference.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(0)));
        let vars = symbols_with_machine_var(Domain::Int(vec![Range::Single(1)]));
        let sum = Expression::Sum(
            Metadata::new(),
            Box::new(matrix_expr![reference.clone(), reference.clone()]),
        );
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::Int(vec![Range::Single(2)]))
        );
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(0)));
        let vars = symbols_with_machine_var(Domain::Int(vec![Range::Bounded(1, 2)]));
        let sum = Expression::Sum(
            Metadata::new(),
            Box::new(matrix_expr![reference.clone(), reference.clone()]),
        );
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::Int(vec![Range::Bounded(2, 4)]))
        );
    }

    #[test]
    fn biplate_to_names() {
        let expr = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(1)));
        let expected_expr = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(2)));
        let actual_expr = expr.transform_bi(Arc::new(move |x: Name| match x {
            Name::Machine(i) => Name::Machine(i + 1),
            n => n,
        }));
        assert_eq!(actual_expr, expected_expr);

        let expr = Expression::And(
            Metadata::new(),
            Box::new(matrix_expr![Expression::AuxDeclaration(
                Metadata::new(),
                Name::Machine(0),
                Box::new(Expression::Atomic(
                    Metadata::new(),
                    Atom::Reference(Name::Machine(1))
                ))
            )]),
        );
        let expected_expr = Expression::And(
            Metadata::new(),
            Box::new(matrix_expr![Expression::AuxDeclaration(
                Metadata::new(),
                Name::Machine(1),
                Box::new(Expression::Atomic(
                    Metadata::new(),
                    Atom::Reference(Name::Machine(2))
                ))
            )]),
        );

        let actual_expr = expr.transform_bi(Arc::new(move |x: Name| match x {
            Name::Machine(i) => Name::Machine(i + 1),
            n => n,
        }));
        assert_eq!(actual_expr, expected_expr);
    }
}