1use crate::ast::declaration::serde::DeclarationPtrAsId;
2use serde_with::serde_as;
3use std::collections::{HashSet, VecDeque};
4use std::fmt::{Display, Formatter};
5use tracing::trace;
6
7use crate::ast::Atom;
8use crate::ast::Moo;
9use crate::ast::Name;
10use crate::ast::ReturnType;
11use crate::ast::SetAttr;
12use crate::ast::literals::AbstractLiteral;
13use crate::ast::literals::Literal;
14use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
15use crate::bug;
16use crate::metadata::Metadata;
17use conjure_cp_enum_compatibility_macro::document_compatibility;
18use itertools::Itertools;
19use serde::{Deserialize, Serialize};
20
21use uniplate::{Biplate, Uniplate};
22
23use super::ac_operators::ACOperatorKind;
24use super::categories::{Category, CategoryOf};
25use super::comprehension::Comprehension;
26use super::domains::HasDomain as _;
27use super::records::RecordValue;
28use super::{DeclarationPtr, Domain, Range, SubModel, Typeable};
29
// Compile-time guard: the build fails unless `Expression` has exactly the same
// size as `[u8; 104]`, i.e. 104 bytes.
static_assertions::assert_eq_size!([u8; 104], Expression);
53
#[document_compatibility]
/// A node of the expression tree.
///
/// Every variant stores its [`Metadata`] as the first field. Sub-expressions
/// are held behind [`Moo`] or inside a `Vec`.
#[serde_as]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
#[biplate(to=Metadata)]
#[biplate(to=Atom)]
#[biplate(to=DeclarationPtr)]
#[biplate(to=Name)]
#[biplate(to=Vec<Expression>)]
#[biplate(to=Option<Expression>)]
#[biplate(to=SubModel)]
#[biplate(to=Comprehension)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=Literal)]
pub enum Expression {
    /// An abstract literal whose elements are themselves expressions.
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
    /// A list of expressions printed as top-level expressions; typed as boolean,
    /// with no domain.
    Root(Metadata, Vec<Expression>),

    /// Printed as `{a @ b}`. Takes its domain and type from the first operand
    /// and is reported as unsafe by `is_safe`.
    Bubble(Metadata, Moo<Expression>, Moo<Expression>),

    /// A comprehension; its domain is delegated to the [`Comprehension`].
    Comprehension(Metadata, Moo<Comprehension>),

    /// Boolean-valued wrapper, printed as `DominanceRelation(e)`.
    DominanceRelation(Metadata, Moo<Expression>),
    /// Printed as `FromSolution(e)`; has the domain and type of the wrapped
    /// expression.
    FromSolution(Metadata, Moo<Expression>),

    /// A single [`Atom`].
    Atomic(Metadata, Atom),

    #[compatible(JsonInput)]
    /// Indexing of the subject by a list of index expressions. Reported as
    /// unsafe by `is_safe`.
    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    /// Like `UnsafeIndex`, but not reported as unsafe by `is_safe`.
    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    #[compatible(JsonInput)]
    /// Slicing of the subject. A `None` entry is printed as `..`; `domain_of`
    /// treats the first `None` as the sliced dimension. Reported as unsafe by
    /// `is_safe`.
    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// Like `UnsafeSlice`, but not reported as unsafe by `is_safe`.
    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// Boolean-valued; printed as `__inDomain(e,domain)`.
    InDomain(Metadata, Moo<Expression>, Domain),

    /// Integer-valued with domain `0..1`; printed as `toInt(e)`.
    ToInt(Metadata, Moo<Expression>),

    /// Wraps a [`SubModel`]; printed between braces.
    Scope(Metadata, Moo<SubModel>),

    #[compatible(JsonInput)]
    /// Absolute value; printed as `|a|`.
    Abs(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `sum(e)`. `domain_of` only yields a domain when the operand
    /// is a matrix literal.
    Sum(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `product(e)`. `domain_of` only yields a domain when the
    /// operand is a matrix literal.
    Product(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `min(e)`. `domain_of` only yields a domain when the operand
    /// is a matrix literal.
    Min(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `max(e)`. `domain_of` only yields a domain when the operand
    /// is a matrix literal.
    Max(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT)]
    /// Boolean-valued; printed as `!(e)`.
    Not(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT)]
    /// Boolean-valued; printed as `or(e)`.
    Or(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT)]
    /// Boolean-valued; printed as `and(e)`.
    And(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a) -> (b)`.
    Imply(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a) <-> (b)`.
    Iff(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Set-valued; printed as `(a union b)`.
    Union(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `a in b`.
    In(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Set-valued; printed as `(a intersect b)`.
    Intersect(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a supset b)`.
    Supset(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a supsetEq b)`.
    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a subset b)`.
    Subset(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a subsetEq b)`.
    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a = b)`.
    Eq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a != b)`.
    Neq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a >= b)`.
    Geq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a <= b)`.
    Leq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a > b)`.
    Gt(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `(a < b)`.
    Lt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Integer division, printed as `SafeDiv(a, b)`. `domain_of` always adds
    /// `0` to the computed domain.
    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Integer division, printed as `UnsafeDiv(a, b)`. Reported as unsafe by
    /// `is_safe`.
    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    /// Integer modulo, printed as `SafeMod(a,b)`. `domain_of` always adds `0`
    /// to the computed domain.
    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Integer modulo, printed as `a % b`. Reported as unsafe by `is_safe`.
    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Integer negation; printed as `-(a)`.
    Neg(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Integer exponentiation, printed as `UnsafePow(a, b)`. Reported as
    /// unsafe by `is_safe`.
    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// Integer exponentiation, printed as `SafePow(a, b)`.
    SafePow(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Boolean-valued; printed as `allDiff(e)`.
    AllDiff(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Integer subtraction; printed as `(a - b)`.
    Minus(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(Minion)]
    /// Boolean-valued constraint over two atoms; printed as `AbsEq(a,b)`.
    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Boolean-valued constraint over a list of atoms; printed as
    /// `__flat_alldiff(..)`.
    FlatAllDiff(Metadata, Vec<Atom>),

    #[compatible(Minion)]
    /// Boolean-valued constraint; printed as `SumGeq(atoms, atom)`.
    FlatSumGeq(Metadata, Vec<Atom>, Atom),

    #[compatible(Minion)]
    /// Boolean-valued constraint; printed as `SumLeq(atoms, atom)`.
    FlatSumLeq(Metadata, Vec<Atom>, Atom),

    #[compatible(Minion)]
    /// Boolean-valued constraint; printed as `Ineq(a, b, literal)`.
    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),

    #[compatible(Minion)]
    /// Boolean-valued constraint on a declaration and a literal; printed as
    /// `WatchedLiteral(name,literal)`. The declaration is serialised through
    /// `DeclarationPtrAsId`.
    FlatWatchedLiteral(
        Metadata,
        #[serde_as(as = "DeclarationPtrAsId")] DeclarationPtr,
        Literal,
    ),

    /// Boolean-valued constraint; printed as
    /// `FlatWeightedSumLeq(literals,atoms,total)`.
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    /// Boolean-valued constraint; printed as
    /// `FlatWeightedSumGeq(literals,atoms,total)`.
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Boolean-valued constraint over two atoms; printed as `MinusEq(a,b)`.
    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Boolean-valued constraint over three atoms; printed as
    /// `FlatProductEq(a,b,c)`.
    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Boolean-valued constraint over three atoms; printed as `DivEq(a, b, c)`.
    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Boolean-valued constraint over three atoms; printed as `ModEq(a, b, c)`.
    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Boolean-valued constraint over three atoms; printed as
    /// `MinionPow(a,b,c)`.
    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Boolean-valued constraint; printed as `Reify(e, atom)`.
    MinionReify(Metadata, Moo<Expression>, Atom),

    #[compatible(Minion)]
    /// Boolean-valued constraint; printed as `ReifyImply(e, atom)`.
    MinionReifyImply(Metadata, Moo<Expression>, Atom),

    #[compatible(Minion)]
    /// Boolean-valued constraint; printed as
    /// `__minion_w_inintervalset(atom,[..])`.
    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),

    #[compatible(Minion)]
    /// Boolean-valued constraint; printed as `__minion_w_inset(atom,..)`.
    MinionWInSet(Metadata, Atom, Vec<i32>),

    #[compatible(Minion)]
    /// Boolean-valued constraint; printed as
    /// `__minion_element_one([atoms],a,b)`.
    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Boolean-valued; printed as `name =aux e`. The declaration is serialised
    /// through `DeclarationPtrAsId`.
    AuxDeclaration(
        Metadata,
        #[serde_as(as = "DeclarationPtrAsId")] DeclarationPtr,
        Moo<Expression>,
    ),
}
475
476fn bounded_i32_domain_for_matrix_literal_monotonic(
483 e: &Expression,
484 op: fn(i32, i32) -> Option<i32>,
485) -> Option<Domain> {
486 let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
488
489 let expr = exprs.pop()?;
505 let Some(Domain::Int(ranges)) = expr.domain_of() else {
506 return None;
507 };
508
509 let (mut current_min, mut current_max) = range_vec_bounds_i32(&ranges)?;
510
511 for expr in exprs {
512 let Some(Domain::Int(ranges)) = expr.domain_of() else {
513 return None;
514 };
515
516 let (min, max) = range_vec_bounds_i32(&ranges)?;
517
518 let minmax = op(min, current_max)?;
520 let minmin = op(min, current_min)?;
521 let maxmin = op(max, current_min)?;
522 let maxmax = op(max, current_max)?;
523 let vals = [minmax, minmin, maxmin, maxmax];
524
525 current_min = *vals
526 .iter()
527 .min()
528 .expect("vals iterator should not be empty, and should have a minimum.");
529 current_max = *vals
530 .iter()
531 .max()
532 .expect("vals iterator should not be empty, and should have a maximum.");
533 }
534
535 if current_min == current_max {
536 Some(Domain::Int(vec![Range::Single(current_min)]))
537 } else {
538 Some(Domain::Int(vec![Range::Bounded(current_min, current_max)]))
539 }
540}
541
542fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
544 let mut min = i32::MAX;
545 let mut max = i32::MIN;
546 for r in ranges {
547 match r {
548 Range::Single(i) => {
549 if *i < min {
550 min = *i;
551 }
552 if *i > max {
553 max = *i;
554 }
555 }
556 Range::Bounded(i, j) => {
557 if *i < min {
558 min = *i;
559 }
560 if *j > max {
561 max = *j;
562 }
563 }
564 Range::UnboundedR(_) | Range::UnboundedL(_) => return None,
565 }
566 }
567 Some((min, max))
568}
569
570impl Expression {
571 pub fn domain_of(&self) -> Option<Domain> {
573 let ret = match self {
574 Expression::Union(_, a, b) => Some(Domain::Set(
575 SetAttr::None,
576 Box::new(a.domain_of()?.union(&b.domain_of()?).ok()?),
577 )),
578 Expression::Intersect(_, a, b) => Some(Domain::Set(
579 SetAttr::None,
580 Box::new(a.domain_of()?.intersect(&b.domain_of()?).ok()?),
581 )),
582 Expression::In(_, _, _) => Some(Domain::Bool),
583 Expression::Supset(_, _, _) => Some(Domain::Bool),
584 Expression::SupsetEq(_, _, _) => Some(Domain::Bool),
585 Expression::Subset(_, _, _) => Some(Domain::Bool),
586 Expression::SubsetEq(_, _, _) => Some(Domain::Bool),
587 Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
588 Expression::DominanceRelation(_, _) => Some(Domain::Bool),
589 Expression::FromSolution(_, expr) => expr.domain_of(),
590 Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
591 Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
592 match matrix.domain_of()? {
593 Domain::Matrix(elem_domain, _) => Some(*elem_domain),
594 Domain::Tuple(_) => None,
595 Domain::Record(_) => None,
596 _ => {
597 bug!("subject of an index operation should support indexing")
598 }
599 }
600 }
601 Expression::UnsafeSlice(_, matrix, indices)
602 | Expression::SafeSlice(_, matrix, indices) => {
603 let sliced_dimension = indices.iter().position(Option::is_none);
604
605 let Domain::Matrix(elem_domain, index_domains) = matrix.domain_of()? else {
606 bug!("subject of an index operation should be a matrix");
607 };
608
609 match sliced_dimension {
610 Some(dimension) => Some(Domain::Matrix(
611 elem_domain,
612 vec![index_domains[dimension].clone()],
613 )),
614
615 None => Some(*elem_domain),
617 }
618 }
619 Expression::InDomain(_, _, _) => Some(Domain::Bool),
620 Expression::Atomic(_, Atom::Reference(ptr)) => ptr.domain(),
621 Expression::Atomic(_, atom) => Some(atom.domain_of()),
622 Expression::Scope(_, _) => Some(Domain::Bool),
623 Expression::Sum(_, e) => {
624 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
625 }
626 Expression::Product(_, e) => {
627 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
628 }
629 Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
630 Some(if x < y { x } else { y })
631 }),
632 Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
633 Some(if x > y { x } else { y })
634 }),
635 Expression::UnsafeDiv(_, a, b) => a
636 .domain_of()?
637 .apply_i32(
638 |x, y| {
641 if y != 0 {
642 Some((x as f32 / y as f32).floor() as i32)
643 } else {
644 None
645 }
646 },
647 &b.domain_of()?,
648 )
649 .ok(),
650 Expression::SafeDiv(_, a, b) => {
651 let domain = a.domain_of()?.apply_i32(
654 |x, y| {
655 if y != 0 {
656 Some((x as f32 / y as f32).floor() as i32)
657 } else {
658 None
659 }
660 },
661 &b.domain_of()?,
662 );
663
664 match domain {
665 Ok(Domain::Int(ranges)) => {
666 let mut ranges = ranges;
667 ranges.push(Range::Single(0));
668 Some(Domain::Int(ranges))
669 }
670 Err(_) => todo!(),
671 _ => unreachable!(),
672 }
673 }
674 Expression::UnsafeMod(_, a, b) => a
675 .domain_of()?
676 .apply_i32(
677 |x, y| if y != 0 { Some(x % y) } else { None },
678 &b.domain_of()?,
679 )
680 .ok(),
681 Expression::SafeMod(_, a, b) => {
682 let domain = a.domain_of()?.apply_i32(
683 |x, y| if y != 0 { Some(x % y) } else { None },
684 &b.domain_of()?,
685 );
686
687 match domain {
688 Ok(Domain::Int(ranges)) => {
689 let mut ranges = ranges;
690 ranges.push(Range::Single(0));
691 Some(Domain::Int(ranges))
692 }
693 Err(_) => todo!(),
694 _ => unreachable!(),
695 }
696 }
697 Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
698 .domain_of()?
699 .apply_i32(
700 |x, y| {
701 if (x != 0 || y != 0) && y >= 0 {
702 Some(x.pow(y as u32))
703 } else {
704 None
705 }
706 },
707 &b.domain_of()?,
708 )
709 .ok(),
710 Expression::Root(_, _) => None,
711 Expression::Bubble(_, inner, _) => inner.domain_of(),
712 Expression::AuxDeclaration(_, _, _) => Some(Domain::Bool),
713 Expression::And(_, _) => Some(Domain::Bool),
714 Expression::Not(_, _) => Some(Domain::Bool),
715 Expression::Or(_, _) => Some(Domain::Bool),
716 Expression::Imply(_, _, _) => Some(Domain::Bool),
717 Expression::Iff(_, _, _) => Some(Domain::Bool),
718 Expression::Eq(_, _, _) => Some(Domain::Bool),
719 Expression::Neq(_, _, _) => Some(Domain::Bool),
720 Expression::Geq(_, _, _) => Some(Domain::Bool),
721 Expression::Leq(_, _, _) => Some(Domain::Bool),
722 Expression::Gt(_, _, _) => Some(Domain::Bool),
723 Expression::Lt(_, _, _) => Some(Domain::Bool),
724 Expression::FlatAbsEq(_, _, _) => Some(Domain::Bool),
725 Expression::FlatSumGeq(_, _, _) => Some(Domain::Bool),
726 Expression::FlatSumLeq(_, _, _) => Some(Domain::Bool),
727 Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::Bool),
728 Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::Bool),
729 Expression::FlatIneq(_, _, _, _) => Some(Domain::Bool),
730 Expression::AllDiff(_, _) => Some(Domain::Bool),
731 Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::Bool),
732 Expression::MinionReify(_, _, _) => Some(Domain::Bool),
733 Expression::MinionReifyImply(_, _, _) => Some(Domain::Bool),
734 Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::Bool),
735 Expression::MinionWInSet(_, _, _) => Some(Domain::Bool),
736 Expression::MinionElementOne(_, _, _, _) => Some(Domain::Bool),
737 Expression::Neg(_, x) => {
738 let Some(Domain::Int(mut ranges)) = x.domain_of() else {
739 return None;
740 };
741
742 for range in ranges.iter_mut() {
743 *range = match range {
744 Range::Single(x) => Range::Single(-*x),
745 Range::Bounded(x, y) => Range::Bounded(-*y, -*x),
746 Range::UnboundedR(i) => Range::UnboundedL(-*i),
747 Range::UnboundedL(i) => Range::UnboundedR(-*i),
748 };
749 }
750
751 Some(Domain::Int(ranges))
752 }
753 Expression::Minus(_, a, b) => a
754 .domain_of()?
755 .apply_i32(|x, y| Some(x - y), &b.domain_of()?)
756 .ok(),
757 Expression::FlatAllDiff(_, _) => Some(Domain::Bool),
758 Expression::FlatMinusEq(_, _, _) => Some(Domain::Bool),
759 Expression::FlatProductEq(_, _, _, _) => Some(Domain::Bool),
760 Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::Bool),
761 Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::Bool),
762 Expression::Abs(_, a) => a
763 .domain_of()?
764 .apply_i32(|a, _| Some(a.abs()), &a.domain_of()?)
765 .ok(),
766 Expression::MinionPow(_, _, _, _) => Some(Domain::Bool),
767 Expression::ToInt(_, _) => Some(Domain::Int(vec![Range::Bounded(0, 1)])),
768 };
769 match ret {
770 Some(Domain::Int(ranges)) if ranges.len() > 1 => {
773 let (min, max) = range_vec_bounds_i32(&ranges)?;
774 Some(Domain::Int(vec![Range::Bounded(min, max)]))
775 }
776 _ => ret,
777 }
778 }
779
780 pub fn get_meta(&self) -> Metadata {
781 let metas: VecDeque<Metadata> = self.children_bi();
782 metas[0].clone()
783 }
784
    /// Calls `transform_bi` with a closure that maps every `Metadata` it is
    /// given to a clone of `meta`.
    ///
    /// NOTE(review): this method takes `&self` and discards whatever
    /// `transform_bi` returns. If `transform_bi` returns the rewritten
    /// expression rather than mutating in place (TODO confirm against the
    /// uniplate API), this call has no observable effect.
    pub fn set_meta(&self, meta: Metadata) {
        self.transform_bi(&|_| meta.clone());
    }
788
789 pub fn is_safe(&self) -> bool {
796 for expr in self.universe() {
798 match expr {
799 Expression::UnsafeDiv(_, _, _)
800 | Expression::UnsafeMod(_, _, _)
801 | Expression::UnsafePow(_, _, _)
802 | Expression::UnsafeIndex(_, _, _)
803 | Expression::Bubble(_, _, _)
804 | Expression::UnsafeSlice(_, _, _) => {
805 return false;
806 }
807 _ => {}
808 }
809 }
810 true
811 }
812
813 pub fn is_clean(&self) -> bool {
814 let metadata = self.get_meta();
815 metadata.clean
816 }
817
818 pub fn set_clean(&mut self, bool_value: bool) {
819 let mut metadata = self.get_meta();
820 metadata.clean = bool_value;
821 self.set_meta(metadata);
822 }
823
824 pub fn is_associative_commutative_operator(&self) -> bool {
826 TryInto::<ACOperatorKind>::try_into(self).is_ok()
827 }
828
829 pub fn is_matrix_literal(&self) -> bool {
834 matches!(
835 self,
836 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
837 | Expression::Atomic(
838 _,
839 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
840 )
841 )
842 }
843
844 pub fn identical_atom_to(&self, other: &Expression) -> bool {
850 let atom1: Result<&Atom, _> = self.try_into();
851 let atom2: Result<&Atom, _> = other.try_into();
852
853 if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
854 atom2 == atom1
855 } else {
856 false
857 }
858 }
859
860 pub fn unwrap_list(self) -> Option<Vec<Expression>> {
865 match self {
866 Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
867 matrix.unwrap_list().cloned()
868 }
869 Expression::Atomic(
870 _,
871 Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
872 ) => matrix.unwrap_list().map(|elems| {
873 elems
874 .clone()
875 .into_iter()
876 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
877 .collect_vec()
878 }),
879 _ => None,
880 }
881 }
882
883 pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, Domain)> {
891 match self {
892 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
893 Some((elems, *domain))
894 }
895 Expression::Atomic(
896 _,
897 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
898 ) => Some((
899 elems
900 .into_iter()
901 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
902 .collect_vec(),
903 *domain,
904 )),
905
906 _ => None,
907 }
908 }
909
910 pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
915 match self {
916 Expression::Root(meta, mut children) => {
917 children.extend(exprs);
918 Expression::Root(meta, children)
919 }
920 _ => panic!("extend_root called on a non-Root expression"),
921 }
922 }
923
924 pub fn into_literal(self) -> Option<Literal> {
926 match self {
927 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
928 Expression::AbstractLiteral(_, abslit) => {
929 Some(Literal::AbstractLiteral(abslit.into_literals()?))
930 }
931 Expression::Neg(_, e) => {
932 let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
933 bug!("negated literal should be an int");
934 };
935
936 Some(Literal::Int(-i))
937 }
938
939 _ => None,
940 }
941 }
942
943 pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
945 TryFrom::try_from(self).ok()
946 }
947
948 pub fn universe_categories(&self) -> HashSet<Category> {
950 self.universe()
951 .into_iter()
952 .map(|x| x.category_of())
953 .collect()
954 }
955}
956
957impl TryFrom<&Expression> for i32 {
958 type Error = ();
959
960 fn try_from(value: &Expression) -> Result<Self, Self::Error> {
961 let Expression::Atomic(_, atom) = value else {
962 return Err(());
963 };
964
965 let Atom::Literal(lit) = atom else {
966 return Err(());
967 };
968
969 let Literal::Int(i) = lit else {
970 return Err(());
971 };
972
973 Ok(*i)
974 }
975}
976
977impl TryFrom<Expression> for i32 {
978 type Error = ();
979
980 fn try_from(value: Expression) -> Result<Self, Self::Error> {
981 TryFrom::<&Expression>::try_from(&value)
982 }
983}
984impl From<i32> for Expression {
985 fn from(i: i32) -> Self {
986 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
987 }
988}
989
990impl From<bool> for Expression {
991 fn from(b: bool) -> Self {
992 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
993 }
994}
995
996impl From<Atom> for Expression {
997 fn from(value: Atom) -> Self {
998 Expression::Atomic(Metadata::new(), value)
999 }
1000}
1001
1002impl From<Moo<Expression>> for Expression {
1003 fn from(val: Moo<Expression>) -> Self {
1004 val.as_ref().clone()
1005 }
1006}
1007
1008impl CategoryOf for Expression {
1009 fn category_of(&self) -> Category {
1010 let category = self.cata(&move |x,children| {
1012
1013 if let Some(max_category) = children.iter().max() {
1014 *max_category
1017 } else {
1018 let mut max_category = Category::Bottom;
1020
1021 if !Biplate::<SubModel>::universe_bi(&x).is_empty() {
1028 return Category::Decision;
1030 }
1031
1032 if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1034 && max_atom_category > max_category{
1036 max_category = max_atom_category;
1038 }
1039
1040 if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1042 && max_declaration_category > max_category{
1044 max_category = max_declaration_category;
1046 }
1047 max_category
1048
1049 }
1050 });
1051
1052 if cfg!(debug_assertions) {
1053 trace!(
1054 category= %category,
1055 expression= %self,
1056 "Called Expression::category_of()"
1057 );
1058 };
1059 category
1060 }
1061}
1062
1063impl Display for Expression {
1064 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1065 match &self {
1066 Expression::Union(_, box1, box2) => {
1067 write!(f, "({} union {})", box1.clone(), box2.clone())
1068 }
1069 Expression::In(_, e1, e2) => {
1070 write!(f, "{e1} in {e2}")
1071 }
1072 Expression::Intersect(_, box1, box2) => {
1073 write!(f, "({} intersect {})", box1.clone(), box2.clone())
1074 }
1075 Expression::Supset(_, box1, box2) => {
1076 write!(f, "({} supset {})", box1.clone(), box2.clone())
1077 }
1078 Expression::SupsetEq(_, box1, box2) => {
1079 write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
1080 }
1081 Expression::Subset(_, box1, box2) => {
1082 write!(f, "({} subset {})", box1.clone(), box2.clone())
1083 }
1084 Expression::SubsetEq(_, box1, box2) => {
1085 write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
1086 }
1087
1088 Expression::AbstractLiteral(_, l) => l.fmt(f),
1089 Expression::Comprehension(_, c) => c.fmt(f),
1090 Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1091 write!(f, "{e1}{}", pretty_vec(e2))
1092 }
1093 Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1094 let args = es
1095 .iter()
1096 .map(|x| match x {
1097 Some(x) => format!("{x}"),
1098 None => "..".into(),
1099 })
1100 .join(",");
1101
1102 write!(f, "{e1}[{args}]")
1103 }
1104 Expression::InDomain(_, e, domain) => {
1105 write!(f, "__inDomain({e},{domain})")
1106 }
1107 Expression::Root(_, exprs) => {
1108 write!(f, "{}", pretty_expressions_as_top_level(exprs))
1109 }
1110 Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
1111 Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
1112 Expression::Atomic(_, atom) => atom.fmt(f),
1113 Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
1114 Expression::Abs(_, a) => write!(f, "|{a}|"),
1115 Expression::Sum(_, e) => {
1116 write!(f, "sum({e})")
1117 }
1118 Expression::Product(_, e) => {
1119 write!(f, "product({e})")
1120 }
1121 Expression::Min(_, e) => {
1122 write!(f, "min({e})")
1123 }
1124 Expression::Max(_, e) => {
1125 write!(f, "max({e})")
1126 }
1127 Expression::Not(_, expr_box) => {
1128 write!(f, "!({})", expr_box.clone())
1129 }
1130 Expression::Or(_, e) => {
1131 write!(f, "or({e})")
1132 }
1133 Expression::And(_, e) => {
1134 write!(f, "and({e})")
1135 }
1136 Expression::Imply(_, box1, box2) => {
1137 write!(f, "({box1}) -> ({box2})")
1138 }
1139 Expression::Iff(_, box1, box2) => {
1140 write!(f, "({box1}) <-> ({box2})")
1141 }
1142 Expression::Eq(_, box1, box2) => {
1143 write!(f, "({} = {})", box1.clone(), box2.clone())
1144 }
1145 Expression::Neq(_, box1, box2) => {
1146 write!(f, "({} != {})", box1.clone(), box2.clone())
1147 }
1148 Expression::Geq(_, box1, box2) => {
1149 write!(f, "({} >= {})", box1.clone(), box2.clone())
1150 }
1151 Expression::Leq(_, box1, box2) => {
1152 write!(f, "({} <= {})", box1.clone(), box2.clone())
1153 }
1154 Expression::Gt(_, box1, box2) => {
1155 write!(f, "({} > {})", box1.clone(), box2.clone())
1156 }
1157 Expression::Lt(_, box1, box2) => {
1158 write!(f, "({} < {})", box1.clone(), box2.clone())
1159 }
1160 Expression::FlatSumGeq(_, box1, box2) => {
1161 write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1162 }
1163 Expression::FlatSumLeq(_, box1, box2) => {
1164 write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1165 }
1166 Expression::FlatIneq(_, box1, box2, box3) => write!(
1167 f,
1168 "Ineq({}, {}, {})",
1169 box1.clone(),
1170 box2.clone(),
1171 box3.clone()
1172 ),
1173 Expression::AllDiff(_, e) => {
1174 write!(f, "allDiff({e})")
1175 }
1176 Expression::Bubble(_, box1, box2) => {
1177 write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1178 }
1179 Expression::SafeDiv(_, box1, box2) => {
1180 write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1181 }
1182 Expression::UnsafeDiv(_, box1, box2) => {
1183 write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1184 }
1185 Expression::UnsafePow(_, box1, box2) => {
1186 write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1187 }
1188 Expression::SafePow(_, box1, box2) => {
1189 write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1190 }
1191 Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1192 write!(
1193 f,
1194 "DivEq({}, {}, {})",
1195 box1.clone(),
1196 box2.clone(),
1197 box3.clone()
1198 )
1199 }
1200 Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1201 write!(
1202 f,
1203 "ModEq({}, {}, {})",
1204 box1.clone(),
1205 box2.clone(),
1206 box3.clone()
1207 )
1208 }
1209 Expression::FlatWatchedLiteral(_, x, l) => {
1210 write!(f, "WatchedLiteral({x},{l})", x = &x.name() as &Name)
1211 }
1212 Expression::MinionReify(_, box1, box2) => {
1213 write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1214 }
1215 Expression::MinionReifyImply(_, box1, box2) => {
1216 write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1217 }
1218 Expression::MinionWInIntervalSet(_, atom, intervals) => {
1219 let intervals = intervals.iter().join(",");
1220 write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1221 }
1222 Expression::MinionWInSet(_, atom, values) => {
1223 let values = values.iter().join(",");
1224 write!(f, "__minion_w_inset({atom},{values})")
1225 }
1226 Expression::AuxDeclaration(_, decl, e) => {
1227 write!(f, "{} =aux {}", &decl.name() as &Name, e.clone())
1228 }
1229 Expression::UnsafeMod(_, a, b) => {
1230 write!(f, "{} % {}", a.clone(), b.clone())
1231 }
1232 Expression::SafeMod(_, a, b) => {
1233 write!(f, "SafeMod({},{})", a.clone(), b.clone())
1234 }
1235 Expression::Neg(_, a) => {
1236 write!(f, "-({})", a.clone())
1237 }
1238 Expression::Minus(_, a, b) => {
1239 write!(f, "({} - {})", a.clone(), b.clone())
1240 }
1241 Expression::FlatAllDiff(_, es) => {
1242 write!(f, "__flat_alldiff({})", pretty_vec(es))
1243 }
1244 Expression::FlatAbsEq(_, a, b) => {
1245 write!(f, "AbsEq({},{})", a.clone(), b.clone())
1246 }
1247 Expression::FlatMinusEq(_, a, b) => {
1248 write!(f, "MinusEq({},{})", a.clone(), b.clone())
1249 }
1250 Expression::FlatProductEq(_, a, b, c) => {
1251 write!(
1252 f,
1253 "FlatProductEq({},{},{})",
1254 a.clone(),
1255 b.clone(),
1256 c.clone()
1257 )
1258 }
1259 Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1260 write!(
1261 f,
1262 "FlatWeightedSumLeq({},{},{})",
1263 pretty_vec(cs),
1264 pretty_vec(vs),
1265 total.clone()
1266 )
1267 }
1268 Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1269 write!(
1270 f,
1271 "FlatWeightedSumGeq({},{},{})",
1272 pretty_vec(cs),
1273 pretty_vec(vs),
1274 total.clone()
1275 )
1276 }
1277 Expression::MinionPow(_, atom, atom1, atom2) => {
1278 write!(f, "MinionPow({atom},{atom1},{atom2})")
1279 }
1280 Expression::MinionElementOne(_, atoms, atom, atom1) => {
1281 let atoms = atoms.iter().join(",");
1282 write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1283 }
1284
1285 Expression::ToInt(_, expr) => {
1286 write!(f, "toInt({expr})")
1287 }
1288 }
1289 }
1290}
1291
1292impl Typeable for Expression {
1293 fn return_type(&self) -> Option<ReturnType> {
1294 match self {
1295 Expression::Union(_, subject, _) => {
1296 Some(ReturnType::Set(Box::new(subject.return_type()?)))
1297 }
1298 Expression::Intersect(_, subject, _) => {
1299 Some(ReturnType::Set(Box::new(subject.return_type()?)))
1300 }
1301 Expression::In(_, _, _) => Some(ReturnType::Bool),
1302 Expression::Supset(_, _, _) => Some(ReturnType::Bool),
1303 Expression::SupsetEq(_, _, _) => Some(ReturnType::Bool),
1304 Expression::Subset(_, _, _) => Some(ReturnType::Bool),
1305 Expression::SubsetEq(_, _, _) => Some(ReturnType::Bool),
1306 Expression::AbstractLiteral(_, lit) => lit.return_type(),
1307 Expression::UnsafeIndex(_, subject, _) | Expression::SafeIndex(_, subject, _) => {
1308 let mut elem_typ = subject.return_type()?;
1309 let ReturnType::Matrix(_) = elem_typ else {
1310 return None;
1311 };
1312
1313 while let ReturnType::Matrix(new_elem_typ) = elem_typ {
1315 elem_typ = *new_elem_typ;
1316 }
1317
1318 Some(elem_typ)
1319 }
1320 Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1321 Some(ReturnType::Matrix(Box::new(subject.return_type()?)))
1322 }
1323 Expression::InDomain(_, _, _) => Some(ReturnType::Bool),
1324 Expression::Comprehension(_, _) => None,
1325 Expression::Root(_, _) => Some(ReturnType::Bool),
1326 Expression::DominanceRelation(_, _) => Some(ReturnType::Bool),
1327 Expression::FromSolution(_, expr) => expr.return_type(),
1328 Expression::Atomic(_, atom) => atom.return_type(),
1329 Expression::Scope(_, scope) => scope.return_type(),
1330 Expression::Abs(_, _) => Some(ReturnType::Int),
1331 Expression::Sum(_, _) => Some(ReturnType::Int),
1332 Expression::Product(_, _) => Some(ReturnType::Int),
1333 Expression::Min(_, _) => Some(ReturnType::Int),
1334 Expression::Max(_, _) => Some(ReturnType::Int),
1335 Expression::Not(_, _) => Some(ReturnType::Bool),
1336 Expression::Or(_, _) => Some(ReturnType::Bool),
1337 Expression::Imply(_, _, _) => Some(ReturnType::Bool),
1338 Expression::Iff(_, _, _) => Some(ReturnType::Bool),
1339 Expression::And(_, _) => Some(ReturnType::Bool),
1340 Expression::Eq(_, _, _) => Some(ReturnType::Bool),
1341 Expression::Neq(_, _, _) => Some(ReturnType::Bool),
1342 Expression::Geq(_, _, _) => Some(ReturnType::Bool),
1343 Expression::Leq(_, _, _) => Some(ReturnType::Bool),
1344 Expression::Gt(_, _, _) => Some(ReturnType::Bool),
1345 Expression::Lt(_, _, _) => Some(ReturnType::Bool),
1346 Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
1347 Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
1348 Expression::FlatAllDiff(_, _) => Some(ReturnType::Bool),
1349 Expression::FlatSumGeq(_, _, _) => Some(ReturnType::Bool),
1350 Expression::FlatSumLeq(_, _, _) => Some(ReturnType::Bool),
1351 Expression::MinionDivEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1352 Expression::FlatIneq(_, _, _, _) => Some(ReturnType::Bool),
1353 Expression::AllDiff(_, _) => Some(ReturnType::Bool),
1354 Expression::Bubble(_, inner, _) => inner.return_type(),
1355 Expression::FlatWatchedLiteral(_, _, _) => Some(ReturnType::Bool),
1356 Expression::MinionReify(_, _, _) => Some(ReturnType::Bool),
1357 Expression::MinionReifyImply(_, _, _) => Some(ReturnType::Bool),
1358 Expression::MinionWInIntervalSet(_, _, _) => Some(ReturnType::Bool),
1359 Expression::MinionWInSet(_, _, _) => Some(ReturnType::Bool),
1360 Expression::MinionElementOne(_, _, _, _) => Some(ReturnType::Bool),
1361 Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
1362 Expression::UnsafeMod(_, _, _) => Some(ReturnType::Int),
1363 Expression::SafeMod(_, _, _) => Some(ReturnType::Int),
1364 Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1365 Expression::Neg(_, _) => Some(ReturnType::Int),
1366 Expression::UnsafePow(_, _, _) => Some(ReturnType::Int),
1367 Expression::SafePow(_, _, _) => Some(ReturnType::Int),
1368 Expression::Minus(_, _, _) => Some(ReturnType::Int),
1369 Expression::FlatAbsEq(_, _, _) => Some(ReturnType::Bool),
1370 Expression::FlatMinusEq(_, _, _) => Some(ReturnType::Bool),
1371 Expression::FlatProductEq(_, _, _, _) => Some(ReturnType::Bool),
1372 Expression::FlatWeightedSumLeq(_, _, _, _) => Some(ReturnType::Bool),
1373 Expression::FlatWeightedSumGeq(_, _, _, _) => Some(ReturnType::Bool),
1374 Expression::MinionPow(_, _, _, _) => Some(ReturnType::Bool),
1375 Expression::ToInt(_, _) => Some(ReturnType::Int),
1376 }
1377 }
1378}
1379
#[cfg(test)]
mod tests {

    use super::*;

    use crate::matrix_expr;

    /// Builds an atomic literal expression with fresh metadata.
    fn literal_expr(lit: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(lit))
    }

    /// Summing the constants 1 and 2 gives the single-value domain {3}.
    #[test]
    fn test_domain_of_constant_sum() {
        let one = literal_expr(Literal::Int(1));
        let two = literal_expr(Literal::Int(2));
        let sum = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![one, two]));

        let expected = Domain::Int(vec![Range::Single(3)]);
        assert_eq!(sum.domain_of(), Some(expected));
    }

    /// A sum containing a boolean operand has no domain.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let int_operand = literal_expr(Literal::Int(1));
        let bool_operand = literal_expr(Literal::Bool(true));
        let sum = Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![int_operand, bool_operand]),
        );

        assert_eq!(sum.domain_of(), None);
    }

    /// A sum over an empty matrix has no domain.
    #[test]
    fn test_domain_of_empty_sum() {
        let operands = Moo::new(matrix_expr![]);
        let sum = Expression::Sum(Metadata::new(), operands);

        assert_eq!(sum.domain_of(), None);
    }
}