1use std::collections::VecDeque;
2use std::fmt::{Display, Formatter};
3use std::sync::Arc;
4
5use itertools::Itertools;
6use serde::{Deserialize, Serialize};
7
8use crate::ast::literals::AbstractLiteral;
9use crate::ast::literals::Literal;
10use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
11use crate::ast::symbol_table::SymbolTable;
12use crate::ast::Atom;
13use crate::ast::Name;
14use crate::ast::ReturnType;
15use crate::ast::SetAttr;
16use crate::bug;
17use crate::metadata::Metadata;
18use enum_compatability_macro::document_compatibility;
19use uniplate::derive::Uniplate;
20use uniplate::{Biplate, Uniplate as _};
21
22use super::ac_operators::ACOperatorKind;
23use super::comprehension::Comprehension;
24use super::records::RecordValue;
25use super::{Domain, Range, SubModel, Typeable};
26
// Compile-time size check: the build fails unless `Expression` is exactly 96 bytes, so any
// change to a variant's payload that grows (or shrinks) the enum is caught immediately.
static_assertions::assert_eq_size!([u8; 96], Expression);

/// A node of the expression tree.
///
/// Every variant stores its [`Metadata`] as the first field; [`Expression::get_meta`] reads
/// the first `Metadata` found by the derived `Biplate<Metadata>` traversal (which lists no
/// `walk_into` types). Variants whose name starts with `Unsafe`, plus `Bubble`, are reported
/// as unsafe by [`Expression::is_safe`]. Most `Flat*`/`Minion*` variants are marked
/// `#[compatible(Minion)]`.
#[document_compatibility]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
#[uniplate(walk_into=[Atom,SubModel,AbstractLiteral<Expression>,Comprehension])]
#[biplate(to=Metadata)]
#[biplate(to=Atom,walk_into=[Expression,AbstractLiteral<Expression>,Vec<Expression>])]
#[biplate(to=Name,walk_into=[Expression,Atom,AbstractLiteral<Expression>,Vec<Expression>])]
#[biplate(to=Vec<Expression>)]
#[biplate(to=Option<Expression>)]
#[biplate(to=SubModel,walk_into=[Comprehension])]
#[biplate(to=Comprehension)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=AbstractLiteral<Literal>,walk_into=[Atom])]
#[biplate(to=RecordValue<Expression>,walk_into=[AbstractLiteral<Expression>])]
#[biplate(to=RecordValue<Literal>,walk_into=[Atom,Literal,AbstractLiteral<Literal>,AbstractLiteral<Expression>])]
#[biplate(to=Literal,walk_into=[Atom])]
pub enum Expression {
    /// An abstract literal (for example a matrix literal) whose elements are expressions.
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
    /// The top-level list of expressions; printed with `pretty_expressions_as_top_level`
    /// and grown by [`Expression::extend_root`].
    Root(Metadata, Vec<Expression>),

    /// `{a @ b}`: an expression `a` paired with a second expression `b`.
    /// NOTE(review): presumably `b` is a condition under which `a` is well defined — confirm.
    Bubble(Metadata, Box<Expression>, Box<Expression>),

    /// A comprehension expression.
    Comprehension(Metadata, Box<Comprehension>),

    /// A dominance relation over the wrapped expression; its domain is `bool`.
    DominanceRelation(Metadata, Box<Expression>),
    /// `FromSolution(e)`; its domain is the domain of `e`.
    /// NOTE(review): presumably refers to the value of `e` in a previous solution — confirm.
    FromSolution(Metadata, Box<Expression>),

    /// A single atom: a literal or a reference to a declared name.
    Atomic(Metadata, Atom),

    #[compatible(JsonInput)]
    /// Indexing `subject[indices]`; reported as unsafe by `is_safe`.
    UnsafeIndex(Metadata, Box<Expression>, Vec<Expression>),

    /// Indexing `subject[indices]` that `is_safe` accepts as safe.
    SafeIndex(Metadata, Box<Expression>, Vec<Expression>),

    #[compatible(JsonInput)]
    /// Slicing `subject[i, .., k]`: a `None` index (printed `..`) marks the dimension that is
    /// kept. Reported as unsafe by `is_safe`.
    UnsafeSlice(Metadata, Box<Expression>, Vec<Option<Expression>>),

    /// Slicing `subject[i, .., k]` that `is_safe` accepts as safe; `None` indices are `..`.
    SafeSlice(Metadata, Box<Expression>, Vec<Option<Expression>>),

    /// `__inDomain(e, domain)`: a boolean-valued expression relating `e` to `domain`.
    /// NOTE(review): presumably "`e` takes a value in `domain`" — confirm.
    InDomain(Metadata, Box<Expression>, Domain),

    /// `toInt(e)`; its domain is the integer range `0..1`.
    /// NOTE(review): presumably converts a boolean to 0/1 — confirm.
    ToInt(Metadata, Box<Expression>),

    /// A nested sub-model; its domain is `bool`.
    Scope(Metadata, Box<SubModel>),

    #[compatible(JsonInput)]
    /// Absolute value `|e|`.
    Abs(Metadata, Box<Expression>),

    #[compatible(JsonInput)]
    /// `sum(e)`; `domain_of` only computes a domain when `e` is a matrix literal.
    Sum(Metadata, Box<Expression>),

    #[compatible(JsonInput)]
    /// `product(e)`; `domain_of` only computes a domain when `e` is a matrix literal.
    Product(Metadata, Box<Expression>),

    #[compatible(JsonInput)]
    /// `min(e)`; `domain_of` only computes a domain when `e` is a matrix literal.
    Min(Metadata, Box<Expression>),

    #[compatible(JsonInput)]
    /// `max(e)`; `domain_of` only computes a domain when `e` is a matrix literal.
    Max(Metadata, Box<Expression>),

    #[compatible(JsonInput, SAT)]
    /// Logical negation `!(e)`.
    Not(Metadata, Box<Expression>),

    #[compatible(JsonInput, SAT)]
    /// `or(e)`: disjunction over the elements of `e`.
    Or(Metadata, Box<Expression>),

    #[compatible(JsonInput, SAT)]
    /// `and(e)`: conjunction over the elements of `e`.
    And(Metadata, Box<Expression>),

    #[compatible(JsonInput)]
    /// Implication `(a) -> (b)`.
    Imply(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// Equivalence `(a) <-> (b)`.
    Iff(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// Set union `a union b`.
    Union(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// Membership test `a in b`.
    In(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// Set intersection `a intersect b`.
    Intersect(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// `a supset b`.
    Supset(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// `a supsetEq b`.
    SupsetEq(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// `a subset b`.
    Subset(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// `a subsetEq b`.
    SubsetEq(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// `a = b`.
    Eq(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// `a != b`.
    Neq(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// `a >= b`.
    Geq(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// `a <= b`.
    Leq(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// `a > b`.
    Gt(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// `a < b`.
    Lt(Metadata, Box<Expression>, Box<Expression>),

    /// Division that `is_safe` accepts; `domain_of` always adds 0 to its domain.
    /// NOTE(review): presumably 0 is the value when the divisor is 0 — confirm.
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// Division; `domain_of` rounds quotients towards negative infinity and treats a zero
    /// divisor as undefined. Reported as unsafe by `is_safe`.
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),

    /// Remainder that `is_safe` accepts; `domain_of` always adds 0 to its domain.
    SafeMod(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// Remainder `a % b`; a zero divisor is treated as undefined by `domain_of`. Reported as
    /// unsafe by `is_safe`.
    UnsafeMod(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// Unary negation `-(e)`.
    Neg(Metadata, Box<Expression>),

    #[compatible(JsonInput)]
    /// Exponentiation, printed `UnsafePow(a, b)`; `domain_of` treats `0^0` and negative
    /// exponents as undefined. Reported as unsafe by `is_safe`.
    UnsafePow(Metadata, Box<Expression>, Box<Expression>),

    /// Exponentiation that `is_safe` accepts.
    SafePow(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(JsonInput)]
    /// `allDiff(e)`. NOTE(review): presumably requires the elements of `e` to be pairwise
    /// distinct — confirm.
    AllDiff(Metadata, Box<Expression>),

    #[compatible(JsonInput)]
    /// Binary subtraction `(a - b)`.
    Minus(Metadata, Box<Expression>, Box<Expression>),

    #[compatible(Minion)]
    /// Flat constraint printed `AbsEq(a, b)`.
    /// NOTE(review): presumably one atom is the absolute value of the other — confirm order.
    FlatAbsEq(Metadata, Box<Atom>, Box<Atom>),

    #[compatible(Minion)]
    /// Flat constraint printed `__flat_alldiff([atoms])`.
    FlatAllDiff(Metadata, Vec<Atom>),

    #[compatible(Minion)]
    /// Flat constraint printed `SumGeq([atoms], total)`.
    /// NOTE(review): presumably `sum(atoms) >= total` — confirm.
    FlatSumGeq(Metadata, Vec<Atom>, Atom),

    #[compatible(Minion)]
    /// Flat constraint printed `SumLeq([atoms], total)`.
    /// NOTE(review): presumably `sum(atoms) <= total` — confirm.
    FlatSumLeq(Metadata, Vec<Atom>, Atom),

    #[compatible(Minion)]
    /// Flat constraint printed `Ineq(a, b, k)`.
    /// NOTE(review): presumably Minion's `ineq`, i.e. `a <= b + k` — confirm.
    FlatIneq(Metadata, Box<Atom>, Box<Atom>, Box<Literal>),

    #[compatible(Minion)]
    /// Flat constraint printed `WatchedLiteral(name,value)`.
    FlatWatchedLiteral(Metadata, Name, Literal),

    /// Flat constraint printed `FlatWeightedSumLeq([coefficients],[atoms],total)`.
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Box<Atom>),

    /// Flat constraint printed `FlatWeightedSumGeq([coefficients],[atoms],total)`.
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Box<Atom>),

    #[compatible(Minion)]
    /// Flat constraint printed `MinusEq(a,b)`.
    FlatMinusEq(Metadata, Box<Atom>, Box<Atom>),

    #[compatible(Minion)]
    /// Flat constraint printed `FlatProductEq(a,b,c)`.
    FlatProductEq(Metadata, Box<Atom>, Box<Atom>, Box<Atom>),

    #[compatible(Minion)]
    /// Minion constraint printed `DivEq(a, b, c)`.
    /// NOTE(review): the name suggests division by zero yields 0 — confirm.
    MinionDivEqUndefZero(Metadata, Box<Atom>, Box<Atom>, Box<Atom>),

    #[compatible(Minion)]
    /// Minion constraint printed `ModEq(a, b, c)`.
    /// NOTE(review): the name suggests modulo by zero yields 0 — confirm.
    MinionModuloEqUndefZero(Metadata, Box<Atom>, Box<Atom>, Box<Atom>),

    /// Minion constraint printed `MinionPow(a,b,c)`.
    MinionPow(Metadata, Box<Atom>, Box<Atom>, Box<Atom>),

    #[compatible(Minion)]
    /// Minion constraint printed `Reify(e, r)`.
    /// NOTE(review): presumably `r <-> e` — confirm.
    MinionReify(Metadata, Box<Expression>, Atom),

    #[compatible(Minion)]
    /// Minion constraint printed `ReifyImply(e, r)`.
    /// NOTE(review): presumably `r -> e` — confirm.
    MinionReifyImply(Metadata, Box<Expression>, Atom),

    #[compatible(Minion)]
    /// Minion constraint printed `__minion_w_inintervalset(atom,[values])`.
    /// NOTE(review): presumably consecutive values are interval bounds — confirm.
    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),

    #[compatible(Minion)]
    /// Minion constraint printed `__minion_w_inset(atom,values)`.
    MinionWInSet(Metadata, Atom, Vec<i32>),

    #[compatible(Minion)]
    /// Minion constraint printed `__minion_element_one([atoms],index,value)`.
    MinionElementOne(Metadata, Vec<Atom>, Box<Atom>, Box<Atom>),

    #[compatible(Minion)]
    /// `name =aux e`: introduces the auxiliary variable `name` defined by `e`.
    AuxDeclaration(Metadata, Name, Box<Expression>),
}
458
459fn bounded_i32_domain_for_matrix_literal_monotonic(
466 e: &Expression,
467 op: fn(i32, i32) -> Option<i32>,
468 symtab: &SymbolTable,
469) -> Option<Domain> {
470 let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
472
473 let expr = exprs.pop()?;
489 let Some(Domain::Int(ranges)) = expr.domain_of(symtab) else {
490 return None;
491 };
492
493 let (mut current_min, mut current_max) = range_vec_bounds_i32(&ranges)?;
494
495 for expr in exprs {
496 let Some(Domain::Int(ranges)) = expr.domain_of(symtab) else {
497 return None;
498 };
499
500 let (min, max) = range_vec_bounds_i32(&ranges)?;
501
502 let minmax = op(min, current_max)?;
504 let minmin = op(min, current_min)?;
505 let maxmin = op(max, current_min)?;
506 let maxmax = op(max, current_max)?;
507 let vals = [minmax, minmin, maxmin, maxmax];
508
509 current_min = *vals
510 .iter()
511 .min()
512 .expect("vals iterator should not be empty, and should have a minimum.");
513 current_max = *vals
514 .iter()
515 .max()
516 .expect("vals iterator should not be empty, and should have a maximum.");
517 }
518
519 if current_min == current_max {
520 Some(Domain::Int(vec![Range::Single(current_min)]))
521 } else {
522 Some(Domain::Int(vec![Range::Bounded(current_min, current_max)]))
523 }
524}
525
526fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
528 let mut min = i32::MAX;
529 let mut max = i32::MIN;
530 for r in ranges {
531 match r {
532 Range::Single(i) => {
533 if *i < min {
534 min = *i;
535 }
536 if *i > max {
537 max = *i;
538 }
539 }
540 Range::Bounded(i, j) => {
541 if *i < min {
542 min = *i;
543 }
544 if *j > max {
545 max = *j;
546 }
547 }
548 Range::UnboundedR(_) | Range::UnboundedL(_) => return None,
549 }
550 }
551 Some((min, max))
552}
553
554impl Expression {
555 pub fn domain_of(&self, syms: &SymbolTable) -> Option<Domain> {
557 let ret = match self {
558 Expression::Union(_, a, b) => Some(Domain::Set(
559 SetAttr::None,
560 Box::new(a.domain_of(syms)?.union(&b.domain_of(syms)?).ok()?),
561 )),
562 Expression::Intersect(_, a, b) => Some(Domain::Set(
563 SetAttr::None,
564 Box::new(a.domain_of(syms)?.intersect(&b.domain_of(syms)?).ok()?),
565 )),
566 Expression::In(_, _, _) => Some(Domain::Bool),
567 Expression::Supset(_, _, _) => Some(Domain::Bool),
568 Expression::SupsetEq(_, _, _) => Some(Domain::Bool),
569 Expression::Subset(_, _, _) => Some(Domain::Bool),
570 Expression::SubsetEq(_, _, _) => Some(Domain::Bool),
571 Expression::AbstractLiteral(_, _) => None,
572 Expression::DominanceRelation(_, _) => Some(Domain::Bool),
573 Expression::FromSolution(_, expr) => expr.domain_of(syms),
574 Expression::Comprehension(_, comprehension) => comprehension.domain_of(syms),
575 Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
576 match matrix.domain_of(syms)? {
577 Domain::Matrix(elem_domain, _) => Some(*elem_domain),
578 Domain::Tuple(_) => None,
579 Domain::Record(_) => None,
580 _ => {
581 bug!("subject of an index operation should support indexing")
582 }
583 }
584 }
585 Expression::UnsafeSlice(_, matrix, indices)
586 | Expression::SafeSlice(_, matrix, indices) => {
587 let sliced_dimension = indices.iter().position(Option::is_none);
588
589 let Domain::Matrix(elem_domain, index_domains) = matrix.domain_of(syms)? else {
590 bug!("subject of an index operation should be a matrix");
591 };
592
593 match sliced_dimension {
594 Some(dimension) => Some(Domain::Matrix(
595 elem_domain,
596 vec![index_domains[dimension].clone()],
597 )),
598
599 None => Some(*elem_domain),
601 }
602 }
603 Expression::InDomain(_, _, _) => Some(Domain::Bool),
604 Expression::Atomic(_, Atom::Reference(name)) => Some(syms.resolve_domain(name)?),
605 Expression::Atomic(_, Atom::Literal(Literal::Int(n))) => {
606 Some(Domain::Int(vec![Range::Single(*n)]))
607 }
608 Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(Domain::Bool),
609 Expression::Atomic(_, Atom::Literal(Literal::AbstractLiteral(_))) => None,
610 Expression::Scope(_, _) => Some(Domain::Bool),
611 Expression::Sum(_, e) => {
612 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y), syms)
613 }
614 Expression::Product(_, e) => {
615 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y), syms)
616 }
617 Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(
618 e,
619 |x, y| Some(if x < y { x } else { y }),
620 syms,
621 ),
622 Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(
623 e,
624 |x, y| Some(if x > y { x } else { y }),
625 syms,
626 ),
627 Expression::UnsafeDiv(_, a, b) => a
628 .domain_of(syms)?
629 .apply_i32(
630 |x, y| {
633 if y != 0 {
634 Some((x as f32 / y as f32).floor() as i32)
635 } else {
636 None
637 }
638 },
639 &b.domain_of(syms)?,
640 )
641 .ok(),
642 Expression::SafeDiv(_, a, b) => {
643 let domain = a.domain_of(syms)?.apply_i32(
646 |x, y| {
647 if y != 0 {
648 Some((x as f32 / y as f32).floor() as i32)
649 } else {
650 None
651 }
652 },
653 &b.domain_of(syms)?,
654 );
655
656 match domain {
657 Ok(Domain::Int(ranges)) => {
658 let mut ranges = ranges;
659 ranges.push(Range::Single(0));
660 Some(Domain::Int(ranges))
661 }
662 Err(_) => todo!(),
663 _ => unreachable!(),
664 }
665 }
666 Expression::UnsafeMod(_, a, b) => a
667 .domain_of(syms)?
668 .apply_i32(
669 |x, y| if y != 0 { Some(x % y) } else { None },
670 &b.domain_of(syms)?,
671 )
672 .ok(),
673 Expression::SafeMod(_, a, b) => {
674 let domain = a.domain_of(syms)?.apply_i32(
675 |x, y| if y != 0 { Some(x % y) } else { None },
676 &b.domain_of(syms)?,
677 );
678
679 match domain {
680 Ok(Domain::Int(ranges)) => {
681 let mut ranges = ranges;
682 ranges.push(Range::Single(0));
683 Some(Domain::Int(ranges))
684 }
685 Err(_) => todo!(),
686 _ => unreachable!(),
687 }
688 }
689 Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
690 .domain_of(syms)?
691 .apply_i32(
692 |x, y| {
693 if (x != 0 || y != 0) && y >= 0 {
694 Some(x.pow(y as u32))
695 } else {
696 None
697 }
698 },
699 &b.domain_of(syms)?,
700 )
701 .ok(),
702 Expression::Root(_, _) => None,
703 Expression::Bubble(_, _, _) => None,
704 Expression::AuxDeclaration(_, _, _) => Some(Domain::Bool),
705 Expression::And(_, _) => Some(Domain::Bool),
706 Expression::Not(_, _) => Some(Domain::Bool),
707 Expression::Or(_, _) => Some(Domain::Bool),
708 Expression::Imply(_, _, _) => Some(Domain::Bool),
709 Expression::Iff(_, _, _) => Some(Domain::Bool),
710 Expression::Eq(_, _, _) => Some(Domain::Bool),
711 Expression::Neq(_, _, _) => Some(Domain::Bool),
712 Expression::Geq(_, _, _) => Some(Domain::Bool),
713 Expression::Leq(_, _, _) => Some(Domain::Bool),
714 Expression::Gt(_, _, _) => Some(Domain::Bool),
715 Expression::Lt(_, _, _) => Some(Domain::Bool),
716 Expression::FlatAbsEq(_, _, _) => Some(Domain::Bool),
717 Expression::FlatSumGeq(_, _, _) => Some(Domain::Bool),
718 Expression::FlatSumLeq(_, _, _) => Some(Domain::Bool),
719 Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::Bool),
720 Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::Bool),
721 Expression::FlatIneq(_, _, _, _) => Some(Domain::Bool),
722 Expression::AllDiff(_, _) => Some(Domain::Bool),
723 Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::Bool),
724 Expression::MinionReify(_, _, _) => Some(Domain::Bool),
725 Expression::MinionReifyImply(_, _, _) => Some(Domain::Bool),
726 Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::Bool),
727 Expression::MinionWInSet(_, _, _) => Some(Domain::Bool),
728 Expression::MinionElementOne(_, _, _, _) => Some(Domain::Bool),
729 Expression::Neg(_, x) => {
730 let Some(Domain::Int(mut ranges)) = x.domain_of(syms) else {
731 return None;
732 };
733
734 for range in ranges.iter_mut() {
735 *range = match range {
736 Range::Single(x) => Range::Single(-*x),
737 Range::Bounded(x, y) => Range::Bounded(-*y, -*x),
738 Range::UnboundedR(i) => Range::UnboundedL(-*i),
739 Range::UnboundedL(i) => Range::UnboundedR(-*i),
740 };
741 }
742
743 Some(Domain::Int(ranges))
744 }
745 Expression::Minus(_, a, b) => a
746 .domain_of(syms)?
747 .apply_i32(|x, y| Some(x - y), &b.domain_of(syms)?)
748 .ok(),
749 Expression::FlatAllDiff(_, _) => Some(Domain::Bool),
750 Expression::FlatMinusEq(_, _, _) => Some(Domain::Bool),
751 Expression::FlatProductEq(_, _, _, _) => Some(Domain::Bool),
752 Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::Bool),
753 Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::Bool),
754 Expression::Abs(_, a) => a
755 .domain_of(syms)?
756 .apply_i32(|a, _| Some(a.abs()), &a.domain_of(syms)?)
757 .ok(),
758 Expression::MinionPow(_, _, _, _) => Some(Domain::Bool),
759 Expression::ToInt(_, _) => Some(Domain::Int(vec![Range::Bounded(0, 1)])),
760 };
761 match ret {
762 Some(Domain::Int(ranges)) if ranges.len() > 1 => {
765 let (min, max) = range_vec_bounds_i32(&ranges)?;
766 Some(Domain::Int(vec![Range::Bounded(min, max)]))
767 }
768 _ => ret,
769 }
770 }
771
772 pub fn get_meta(&self) -> Metadata {
773 let metas: VecDeque<Metadata> = self.children_bi();
774 metas[0].clone()
775 }
776
777 pub fn set_meta(&self, meta: Metadata) {
778 self.transform_bi(Arc::new(move |_| meta.clone()));
779 }
780
781 pub fn is_safe(&self) -> bool {
788 for expr in self.universe() {
790 match expr {
791 Expression::UnsafeDiv(_, _, _)
792 | Expression::UnsafeMod(_, _, _)
793 | Expression::UnsafePow(_, _, _)
794 | Expression::UnsafeIndex(_, _, _)
795 | Expression::Bubble(_, _, _)
796 | Expression::UnsafeSlice(_, _, _) => {
797 return false;
798 }
799 _ => {}
800 }
801 }
802 true
803 }
804
805 pub fn is_clean(&self) -> bool {
806 let metadata = self.get_meta();
807 metadata.clean
808 }
809
810 pub fn set_clean(&mut self, bool_value: bool) {
811 let mut metadata = self.get_meta();
812 metadata.clean = bool_value;
813 self.set_meta(metadata);
814 }
815
816 pub fn is_associative_commutative_operator(&self) -> bool {
818 TryInto::<ACOperatorKind>::try_into(self).is_ok()
819 }
820
821 pub fn is_matrix_literal(&self) -> bool {
826 matches!(
827 self,
828 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
829 | Expression::Atomic(
830 _,
831 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
832 )
833 )
834 }
835
836 pub fn identical_atom_to(&self, other: &Expression) -> bool {
842 let atom1: Result<&Atom, _> = self.try_into();
843 let atom2: Result<&Atom, _> = other.try_into();
844
845 if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
846 atom2 == atom1
847 } else {
848 false
849 }
850 }
851
852 pub fn unwrap_list(self) -> Option<Vec<Expression>> {
857 match self {
858 Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
859 matrix.unwrap_list().cloned()
860 }
861 Expression::Atomic(
862 _,
863 Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
864 ) => matrix.unwrap_list().map(|elems| {
865 elems
866 .clone()
867 .into_iter()
868 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
869 .collect_vec()
870 }),
871 _ => None,
872 }
873 }
874
875 pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, Domain)> {
883 match self {
884 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
885 Some((elems.clone(), *domain))
886 }
887 Expression::Atomic(
888 _,
889 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
890 ) => Some((
891 elems
892 .clone()
893 .into_iter()
894 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
895 .collect_vec(),
896 *domain,
897 )),
898
899 _ => None,
900 }
901 }
902
903 pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
908 match self {
909 Expression::Root(meta, mut children) => {
910 children.extend(exprs);
911 Expression::Root(meta, children)
912 }
913 _ => panic!("extend_root called on a non-Root expression"),
914 }
915 }
916
917 pub fn into_literal(self) -> Option<Literal> {
919 match self {
920 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
921 Expression::AbstractLiteral(_, abslit) => {
922 Some(Literal::AbstractLiteral(abslit.clone().into_literals()?))
923 }
924 Expression::Neg(_, e) => {
925 let Literal::Int(i) = e.into_literal()? else {
926 bug!("negated literal should be an int");
927 };
928
929 Some(Literal::Int(-i))
930 }
931
932 _ => None,
933 }
934 }
935
936 pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
938 TryFrom::try_from(self).ok()
939 }
940}
941
942impl From<i32> for Expression {
943 fn from(i: i32) -> Self {
944 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
945 }
946}
947
948impl From<bool> for Expression {
949 fn from(b: bool) -> Self {
950 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
951 }
952}
953
954impl From<Atom> for Expression {
955 fn from(value: Atom) -> Self {
956 Expression::Atomic(Metadata::new(), value)
957 }
958}
959
960impl From<Name> for Expression {
961 fn from(name: Name) -> Self {
962 Expression::Atomic(Metadata::new(), Atom::Reference(name))
963 }
964}
965
966impl From<Box<Expression>> for Expression {
967 fn from(val: Box<Expression>) -> Self {
968 val.as_ref().clone()
969 }
970}
971
972impl Display for Expression {
973 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
974 match &self {
975 Expression::Union(_, box1, box2) => {
976 write!(f, "({} union {})", box1.clone(), box2.clone())
977 }
978 Expression::In(_, e1, e2) => {
979 write!(f, "{} in {}", e1, e2)
980 }
981 Expression::Intersect(_, box1, box2) => {
982 write!(f, "({} intersect {})", box1.clone(), box2.clone())
983 }
984 Expression::Supset(_, box1, box2) => {
985 write!(f, "({} supset {})", box1.clone(), box2.clone())
986 }
987 Expression::SupsetEq(_, box1, box2) => {
988 write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
989 }
990 Expression::Subset(_, box1, box2) => {
991 write!(f, "({} subset {})", box1.clone(), box2.clone())
992 }
993 Expression::SubsetEq(_, box1, box2) => {
994 write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
995 }
996
997 Expression::AbstractLiteral(_, l) => l.fmt(f),
998 Expression::Comprehension(_, c) => c.fmt(f),
999 Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1000 write!(f, "{e1}{}", pretty_vec(e2))
1001 }
1002 Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1003 let args = es
1004 .iter()
1005 .map(|x| match x {
1006 Some(x) => format!("{}", x),
1007 None => "..".into(),
1008 })
1009 .join(",");
1010
1011 write!(f, "{e1}[{args}]")
1012 }
1013 Expression::InDomain(_, e, domain) => {
1014 write!(f, "__inDomain({e},{domain})")
1015 }
1016 Expression::Root(_, exprs) => {
1017 write!(f, "{}", pretty_expressions_as_top_level(exprs))
1018 }
1019 Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({})", expr),
1020 Expression::FromSolution(_, expr) => write!(f, "FromSolution({})", expr),
1021 Expression::Atomic(_, atom) => atom.fmt(f),
1022 Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
1023 Expression::Abs(_, a) => write!(f, "|{}|", a),
1024 Expression::Sum(_, e) => {
1025 write!(f, "sum({e})")
1026 }
1027 Expression::Product(_, e) => {
1028 write!(f, "product({e})")
1029 }
1030 Expression::Min(_, e) => {
1031 write!(f, "min({e})")
1032 }
1033 Expression::Max(_, e) => {
1034 write!(f, "max({e})")
1035 }
1036 Expression::Not(_, expr_box) => {
1037 write!(f, "!({})", expr_box.clone())
1038 }
1039 Expression::Or(_, e) => {
1040 write!(f, "or({e})")
1041 }
1042 Expression::And(_, e) => {
1043 write!(f, "and({e})")
1044 }
1045 Expression::Imply(_, box1, box2) => {
1046 write!(f, "({}) -> ({})", box1, box2)
1047 }
1048 Expression::Iff(_, box1, box2) => {
1049 write!(f, "({}) <-> ({})", box1, box2)
1050 }
1051 Expression::Eq(_, box1, box2) => {
1052 write!(f, "({} = {})", box1.clone(), box2.clone())
1053 }
1054 Expression::Neq(_, box1, box2) => {
1055 write!(f, "({} != {})", box1.clone(), box2.clone())
1056 }
1057 Expression::Geq(_, box1, box2) => {
1058 write!(f, "({} >= {})", box1.clone(), box2.clone())
1059 }
1060 Expression::Leq(_, box1, box2) => {
1061 write!(f, "({} <= {})", box1.clone(), box2.clone())
1062 }
1063 Expression::Gt(_, box1, box2) => {
1064 write!(f, "({} > {})", box1.clone(), box2.clone())
1065 }
1066 Expression::Lt(_, box1, box2) => {
1067 write!(f, "({} < {})", box1.clone(), box2.clone())
1068 }
1069 Expression::FlatSumGeq(_, box1, box2) => {
1070 write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1071 }
1072 Expression::FlatSumLeq(_, box1, box2) => {
1073 write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1074 }
1075 Expression::FlatIneq(_, box1, box2, box3) => write!(
1076 f,
1077 "Ineq({}, {}, {})",
1078 box1.clone(),
1079 box2.clone(),
1080 box3.clone()
1081 ),
1082 Expression::AllDiff(_, e) => {
1083 write!(f, "allDiff({e})")
1084 }
1085 Expression::Bubble(_, box1, box2) => {
1086 write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1087 }
1088 Expression::SafeDiv(_, box1, box2) => {
1089 write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1090 }
1091 Expression::UnsafeDiv(_, box1, box2) => {
1092 write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1093 }
1094 Expression::UnsafePow(_, box1, box2) => {
1095 write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1096 }
1097 Expression::SafePow(_, box1, box2) => {
1098 write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1099 }
1100 Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1101 write!(
1102 f,
1103 "DivEq({}, {}, {})",
1104 box1.clone(),
1105 box2.clone(),
1106 box3.clone()
1107 )
1108 }
1109 Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1110 write!(
1111 f,
1112 "ModEq({}, {}, {})",
1113 box1.clone(),
1114 box2.clone(),
1115 box3.clone()
1116 )
1117 }
1118 Expression::FlatWatchedLiteral(_, x, l) => {
1119 write!(f, "WatchedLiteral({},{})", x, l)
1120 }
1121 Expression::MinionReify(_, box1, box2) => {
1122 write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1123 }
1124 Expression::MinionReifyImply(_, box1, box2) => {
1125 write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1126 }
1127 Expression::MinionWInIntervalSet(_, atom, intervals) => {
1128 let intervals = intervals.iter().join(",");
1129 write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1130 }
1131 Expression::MinionWInSet(_, atom, values) => {
1132 let values = values.iter().join(",");
1133 write!(f, "__minion_w_inset({atom},{values})")
1134 }
1135 Expression::AuxDeclaration(_, n, e) => {
1136 write!(f, "{} =aux {}", n, e.clone())
1137 }
1138 Expression::UnsafeMod(_, a, b) => {
1139 write!(f, "{} % {}", a.clone(), b.clone())
1140 }
1141 Expression::SafeMod(_, a, b) => {
1142 write!(f, "SafeMod({},{})", a.clone(), b.clone())
1143 }
1144 Expression::Neg(_, a) => {
1145 write!(f, "-({})", a.clone())
1146 }
1147 Expression::Minus(_, a, b) => {
1148 write!(f, "({} - {})", a.clone(), b.clone())
1149 }
1150 Expression::FlatAllDiff(_, es) => {
1151 write!(f, "__flat_alldiff({})", pretty_vec(es))
1152 }
1153 Expression::FlatAbsEq(_, a, b) => {
1154 write!(f, "AbsEq({},{})", a.clone(), b.clone())
1155 }
1156 Expression::FlatMinusEq(_, a, b) => {
1157 write!(f, "MinusEq({},{})", a.clone(), b.clone())
1158 }
1159 Expression::FlatProductEq(_, a, b, c) => {
1160 write!(
1161 f,
1162 "FlatProductEq({},{},{})",
1163 a.clone(),
1164 b.clone(),
1165 c.clone()
1166 )
1167 }
1168 Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1169 write!(
1170 f,
1171 "FlatWeightedSumLeq({},{},{})",
1172 pretty_vec(cs),
1173 pretty_vec(vs),
1174 total.clone()
1175 )
1176 }
1177 Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1178 write!(
1179 f,
1180 "FlatWeightedSumGeq({},{},{})",
1181 pretty_vec(cs),
1182 pretty_vec(vs),
1183 total.clone()
1184 )
1185 }
1186 Expression::MinionPow(_, atom, atom1, atom2) => {
1187 write!(f, "MinionPow({},{},{})", atom, atom1, atom2)
1188 }
1189 Expression::MinionElementOne(_, atoms, atom, atom1) => {
1190 let atoms = atoms.iter().join(",");
1191 write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1192 }
1193
1194 Expression::ToInt(_, expr) => {
1195 write!(f, "toInt({expr})")
1196 }
1197 }
1198 }
1199}
1200
1201impl Typeable for Expression {
1202 fn return_type(&self) -> Option<ReturnType> {
1203 match self {
1204 Expression::Union(_, subject, _) => {
1205 Some(ReturnType::Set(Box::new(subject.return_type()?)))
1206 }
1207 Expression::Intersect(_, subject, _) => {
1208 Some(ReturnType::Set(Box::new(subject.return_type()?)))
1209 }
1210 Expression::In(_, _, _) => Some(ReturnType::Bool),
1211 Expression::Supset(_, _, _) => Some(ReturnType::Bool),
1212 Expression::SupsetEq(_, _, _) => Some(ReturnType::Bool),
1213 Expression::Subset(_, _, _) => Some(ReturnType::Bool),
1214 Expression::SubsetEq(_, _, _) => Some(ReturnType::Bool),
1215 Expression::AbstractLiteral(_, lit) => lit.return_type(),
1216 Expression::UnsafeIndex(_, subject, _) | Expression::SafeIndex(_, subject, _) => {
1217 Some(subject.return_type()?)
1218 }
1219 Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1220 Some(ReturnType::Matrix(Box::new(subject.return_type()?)))
1221 }
1222 Expression::InDomain(_, _, _) => Some(ReturnType::Bool),
1223 Expression::Comprehension(_, _) => None,
1224 Expression::Root(_, _) => Some(ReturnType::Bool),
1225 Expression::DominanceRelation(_, _) => Some(ReturnType::Bool),
1226 Expression::FromSolution(_, expr) => expr.return_type(),
1227 Expression::Atomic(_, Atom::Literal(lit)) => lit.return_type(),
1228 Expression::Atomic(_, Atom::Reference(_)) => None,
1229 Expression::Scope(_, scope) => scope.return_type(),
1230 Expression::Abs(_, _) => Some(ReturnType::Int),
1231 Expression::Sum(_, _) => Some(ReturnType::Int),
1232 Expression::Product(_, _) => Some(ReturnType::Int),
1233 Expression::Min(_, _) => Some(ReturnType::Int),
1234 Expression::Max(_, _) => Some(ReturnType::Int),
1235 Expression::Not(_, _) => Some(ReturnType::Bool),
1236 Expression::Or(_, _) => Some(ReturnType::Bool),
1237 Expression::Imply(_, _, _) => Some(ReturnType::Bool),
1238 Expression::Iff(_, _, _) => Some(ReturnType::Bool),
1239 Expression::And(_, _) => Some(ReturnType::Bool),
1240 Expression::Eq(_, _, _) => Some(ReturnType::Bool),
1241 Expression::Neq(_, _, _) => Some(ReturnType::Bool),
1242 Expression::Geq(_, _, _) => Some(ReturnType::Bool),
1243 Expression::Leq(_, _, _) => Some(ReturnType::Bool),
1244 Expression::Gt(_, _, _) => Some(ReturnType::Bool),
1245 Expression::Lt(_, _, _) => Some(ReturnType::Bool),
1246 Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
1247 Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
1248 Expression::FlatAllDiff(_, _) => Some(ReturnType::Bool),
1249 Expression::FlatSumGeq(_, _, _) => Some(ReturnType::Bool),
1250 Expression::FlatSumLeq(_, _, _) => Some(ReturnType::Bool),
1251 Expression::MinionDivEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1252 Expression::FlatIneq(_, _, _, _) => Some(ReturnType::Bool),
1253 Expression::AllDiff(_, _) => Some(ReturnType::Bool),
1254 Expression::Bubble(_, _, _) => None,
1255 Expression::FlatWatchedLiteral(_, _, _) => Some(ReturnType::Bool),
1256 Expression::MinionReify(_, _, _) => Some(ReturnType::Bool),
1257 Expression::MinionReifyImply(_, _, _) => Some(ReturnType::Bool),
1258 Expression::MinionWInIntervalSet(_, _, _) => Some(ReturnType::Bool),
1259 Expression::MinionWInSet(_, _, _) => Some(ReturnType::Bool),
1260 Expression::MinionElementOne(_, _, _, _) => Some(ReturnType::Bool),
1261 Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
1262 Expression::UnsafeMod(_, _, _) => Some(ReturnType::Int),
1263 Expression::SafeMod(_, _, _) => Some(ReturnType::Int),
1264 Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1265 Expression::Neg(_, _) => Some(ReturnType::Int),
1266 Expression::UnsafePow(_, _, _) => Some(ReturnType::Int),
1267 Expression::SafePow(_, _, _) => Some(ReturnType::Int),
1268 Expression::Minus(_, _, _) => Some(ReturnType::Int),
1269 Expression::FlatAbsEq(_, _, _) => Some(ReturnType::Bool),
1270 Expression::FlatMinusEq(_, _, _) => Some(ReturnType::Bool),
1271 Expression::FlatProductEq(_, _, _, _) => Some(ReturnType::Bool),
1272 Expression::FlatWeightedSumLeq(_, _, _, _) => Some(ReturnType::Bool),
1273 Expression::FlatWeightedSumGeq(_, _, _, _) => Some(ReturnType::Bool),
1274 Expression::MinionPow(_, _, _, _) => Some(ReturnType::Bool),
1275 Expression::ToInt(_, _) => Some(ReturnType::Int),
1276 }
1277 }
1278}
1279
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use crate::{ast::declaration::Declaration, matrix_expr};

    use super::*;

    #[test]
    fn test_domain_of_constant_sum() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(2)));
        let sum = Expression::Sum(Metadata::new(), Box::new(matrix_expr![c1, c2]));
        assert_eq!(
            sum.domain_of(&SymbolTable::new()),
            Some(Domain::Int(vec![Range::Single(3)]))
        );
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(Metadata::new(), Box::new(matrix_expr![c1, c2]));
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), Box::new(matrix_expr![]));
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(0)));
        let mut vars = SymbolTable::new();
        vars.insert(Rc::new(Declaration::new_var(
            Name::Machine(0),
            Domain::Int(vec![Range::Single(1)]),
        )))
        .unwrap();
        assert_eq!(
            reference.domain_of(&vars),
            Some(Domain::Int(vec![Range::Single(1)]))
        );
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(0)));
        assert_eq!(reference.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(0)));
        let mut vars = SymbolTable::new();
        vars.insert(Rc::new(Declaration::new_var(
            Name::Machine(0),
            Domain::Int(vec![Range::Single(1)]),
        )))
        .unwrap();
        let sum = Expression::Sum(
            Metadata::new(),
            Box::new(matrix_expr![reference.clone(), reference]),
        );
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::Int(vec![Range::Single(2)]))
        );
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(0)));
        let mut vars = SymbolTable::new();
        // Unwrap like the other tests, so a failed insertion fails the test loudly instead of
        // being silently ignored.
        vars.insert(Rc::new(Declaration::new_var(
            Name::Machine(0),
            Domain::Int(vec![Range::Bounded(1, 2)]),
        )))
        .unwrap();
        let sum = Expression::Sum(
            Metadata::new(),
            Box::new(matrix_expr![reference.clone(), reference]),
        );
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::Int(vec![Range::Bounded(2, 4)]))
        );
    }

    #[test]
    fn biplate_to_names() {
        let expr = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(1)));
        let expected_expr = Expression::Atomic(Metadata::new(), Atom::Reference(Name::Machine(2)));
        let actual_expr = expr.transform_bi(Arc::new(move |x: Name| match x {
            Name::Machine(i) => Name::Machine(i + 1),
            n => n,
        }));
        assert_eq!(actual_expr, expected_expr);

        let expr = Expression::And(
            Metadata::new(),
            Box::new(matrix_expr![Expression::AuxDeclaration(
                Metadata::new(),
                Name::Machine(0),
                Box::new(Expression::Atomic(
                    Metadata::new(),
                    Atom::Reference(Name::Machine(1))
                ))
            )]),
        );
        let expected_expr = Expression::And(
            Metadata::new(),
            Box::new(matrix_expr![Expression::AuxDeclaration(
                Metadata::new(),
                Name::Machine(1),
                Box::new(Expression::Atomic(
                    Metadata::new(),
                    Atom::Reference(Name::Machine(2))
                ))
            )]),
        );

        let actual_expr = expr.transform_bi(Arc::new(move |x: Name| match x {
            Name::Machine(i) => Name::Machine(i + 1),
            n => n,
        }));
        assert_eq!(actual_expr, expected_expr);
    }
}