conjure_cp_core/ast/
domains.rs

1#![warn(clippy::missing_errors_doc)]
2
3use conjure_cp_core::ast::SymbolTable;
4use itertools::{Itertools, izip};
5use serde::{Deserialize, Serialize};
6use std::{collections::BTreeSet, fmt::Display};
7use thiserror::Error;
8
9use crate::{ast::pretty::pretty_vec, domain_int, range};
10use uniplate::Uniplate;
11
12use super::{AbstractLiteral, Literal, Name, ReturnType, records::RecordEntry, types::Typeable};
13
14#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub enum Range<A>
16where
17    A: Ord,
18{
19    Single(A),
20    Bounded(A, A),
21
22    /// int(i..)
23    UnboundedR(A),
24
25    /// int(..i)
26    UnboundedL(A),
27}
28
29impl<A: Ord> Range<A> {
30    pub fn contains(&self, val: &A) -> bool {
31        match self {
32            Range::Single(x) => x == val,
33            Range::Bounded(x, y) => x <= val && val <= y,
34            Range::UnboundedR(x) => x >= val,
35            Range::UnboundedL(x) => x <= val,
36        }
37    }
38
39    /// Returns the lower bound of the range, if it has one
40    pub fn lower_bound(&self) -> Option<&A> {
41        match self {
42            Range::Single(a) => Some(a),
43            Range::Bounded(a, _) => Some(a),
44            Range::UnboundedR(a) => Some(a),
45            Range::UnboundedL(_) => None,
46        }
47    }
48}
49
50impl<A: Ord + Display> Display for Range<A> {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        match self {
53            Range::Single(i) => write!(f, "{i}"),
54            Range::Bounded(i, j) => write!(f, "{i}..{j}"),
55            Range::UnboundedR(i) => write!(f, "{i}.."),
56            Range::UnboundedL(i) => write!(f, "..{i}"),
57        }
58    }
59}
60
61#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Uniplate, Hash)]
62#[uniplate()]
63pub enum Domain {
64    Bool,
65
66    /// An integer domain.
67    ///
68    /// + If multiple ranges are inside the domain, the values in the domain are the union of these
69    ///   ranges.
70    ///
71    /// + If no ranges are given, the int domain is considered unconstrained, and can take any
72    ///   integer value.
73    Int(Vec<Range<i32>>),
74
75    /// An empty domain of the given type.
76    Empty(ReturnType),
77    Reference(Name),
78    Set(SetAttr, Box<Domain>),
79    /// A n-dimensional matrix with a value domain and n-index domains
80    Matrix(Box<Domain>, Vec<Domain>),
81    // A tuple of n domains (e.g. (int, bool))
82    Tuple(Vec<Domain>),
83
84    Record(Vec<RecordEntry>),
85}
86
87#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
88pub enum SetAttr {
89    None,
90    Size(i32),
91    MinSize(i32),
92    MaxSize(i32),
93    MinMaxSize(i32, i32),
94}
95impl Domain {
96    /// Returns true if `lit` is a member of the domain.
97    ///
98    /// # Errors
99    ///
100    /// - [`DomainOpError::InputContainsReference`] if the input domain is a reference or contains
101    ///   a reference, meaning that its members cannot be determined.
102    pub fn contains(&self, lit: &Literal) -> Result<bool, DomainOpError> {
103        // not adding a generic wildcard condition for all domains, so that this gives a compile
104        // error when a domain is added.
105        match (self, lit) {
106            (Domain::Empty(_), _) => Ok(false),
107            (Domain::Int(ranges), Literal::Int(x)) => {
108                // unconstrained int domain
109                if ranges.is_empty() {
110                    return Ok(true);
111                };
112
113                Ok(ranges.iter().any(|range| range.contains(x)))
114            }
115            (Domain::Int(_), _) => Ok(false),
116            (Domain::Bool, Literal::Bool(_)) => Ok(true),
117            (Domain::Bool, _) => Ok(false),
118            (Domain::Reference(_), _) => Err(DomainOpError::InputContainsReference),
119            (
120                Domain::Matrix(elem_domain, index_domains),
121                Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, idx_domain)),
122            ) => {
123                let mut index_domains = index_domains.clone();
124                if index_domains
125                    .pop()
126                    .expect("a matrix should have atleast one index domain")
127                    != **idx_domain
128                {
129                    return Ok(false);
130                };
131
132                // matrix literals are represented as nested 1d matrices, so the elements of
133                // the matrix literal will be the inner dimensions of the matrix.
134                let next_elem_domain = if index_domains.is_empty() {
135                    elem_domain.as_ref().clone()
136                } else {
137                    Domain::Matrix(elem_domain.clone(), index_domains)
138                };
139
140                for elem in elems {
141                    if !next_elem_domain.contains(elem)? {
142                        return Ok(false);
143                    }
144                }
145
146                Ok(true)
147            }
148            (
149                Domain::Tuple(elem_domains),
150                Literal::AbstractLiteral(AbstractLiteral::Tuple(literal_elems)),
151            ) => {
152                // for every element in the tuple literal, check if it is in the corresponding domain
153                for (elem_domain, elem) in itertools::izip!(elem_domains, literal_elems) {
154                    if !elem_domain.contains(elem)? {
155                        return Ok(false);
156                    }
157                }
158
159                Ok(true)
160            }
161            (
162                Domain::Set(_, domain),
163                Literal::AbstractLiteral(AbstractLiteral::Set(literal_elems)),
164            ) => {
165                for elem in literal_elems {
166                    if !domain.contains(elem)? {
167                        return Ok(false);
168                    }
169                }
170                Ok(true)
171            }
172            (
173                Domain::Record(entries),
174                Literal::AbstractLiteral(AbstractLiteral::Record(lit_entries)),
175            ) => {
176                for (entry, lit_entry) in itertools::izip!(entries, lit_entries) {
177                    if entry.name != lit_entry.name || !(entry.domain.contains(&lit_entry.value)?) {
178                        return Ok(false);
179                    }
180                }
181                Ok(true)
182            }
183
184            (Domain::Record(_), _) => Ok(false),
185
186            (Domain::Matrix(_, _), _) => Ok(false),
187
188            (Domain::Set(_, _), _) => Ok(false),
189
190            (Domain::Tuple(_), _) => Ok(false),
191        }
192    }
193
194    /// Returns a list of all possible values in the domain.
195    ///
196    /// # Errors
197    ///
198    /// - [`DomainOpError::InputNotInteger`] if the domain is not an integer domain.
199    /// - [`DomainOpError::InputUnbounded`] if the domain is unbounded.
200    pub fn values_i32(&self) -> Result<Vec<i32>, DomainOpError> {
201        if let Domain::Empty(ReturnType::Int) = self {
202            return Ok(vec![]);
203        }
204        let Domain::Int(ranges) = self else {
205            return Err(DomainOpError::InputNotInteger(self.return_type().unwrap()));
206        };
207
208        if ranges.is_empty() {
209            return Err(DomainOpError::InputUnbounded);
210        }
211
212        let mut values = vec![];
213        for range in ranges {
214            match range {
215                Range::Single(i) => {
216                    values.push(*i);
217                }
218                Range::Bounded(i, j) => {
219                    values.extend(*i..=*j);
220                }
221                Range::UnboundedR(_) | Range::UnboundedL(_) => {
222                    return Err(DomainOpError::InputUnbounded);
223                }
224            }
225        }
226
227        Ok(values)
228    }
229
230    /// Creates an [`Domain::Int`] containing the given integers.
231    ///
232    /// [`Domain::from_set_i32`] should be used instead where possible, as it is cheaper (it does
233    /// not need to sort its input).
234    ///
235    /// # Examples
236    ///
237    /// ```
238    /// use conjure_cp_core::ast::{Domain,Range};
239    /// use conjure_cp_core::{domain_int,range};
240    ///
241    /// let elements = vec![1,2,3,4,5];
242    ///
243    /// let domain = Domain::from_slice_i32(&elements);
244    ///
245    /// assert_eq!(domain,domain_int!(1..5));
246    /// ```
247    ///
248    /// ```
249    /// use conjure_cp_core::ast::{Domain,Range};
250    /// use conjure_cp_core::{domain_int,range};
251    ///
252    /// let elements = vec![1,2,4,5,7,8,9,10];
253    ///
254    /// let domain = Domain::from_slice_i32(&elements);
255    ///
256    ///
257    /// assert_eq!(domain,domain_int!(1..2,4..5,7..10));
258    /// ```
259    ///
260    /// ```
261    /// use conjure_cp_core::ast::{Domain,Range,ReturnType};
262    ///
263    /// let elements = vec![];
264    ///
265    /// let domain = Domain::from_slice_i32(&elements);
266    ///
267    /// assert!(matches!(domain,Domain::Empty(ReturnType::Int)))
268    /// ```
269    pub fn from_slice_i32(elements: &[i32]) -> Domain {
270        if elements.is_empty() {
271            return Domain::Empty(ReturnType::Int);
272        }
273
274        let set = BTreeSet::from_iter(elements.iter().copied());
275
276        Domain::from_set_i32(&set)
277    }
278
279    /// Creates an [`Domain::Int`] containing the given integers.
280    ///
281    /// # Examples
282    ///
283    /// ```
284    /// use conjure_cp_core::ast::{Domain,Range};
285    /// use conjure_cp_core::{domain_int,range};
286    /// use std::collections::BTreeSet;
287    ///
288    /// let elements = BTreeSet::from([1,2,3,4,5]);
289    ///
290    /// let domain = Domain::from_set_i32(&elements);
291    ///
292    /// assert_eq!(domain,domain_int!(1..5));
293    /// ```
294    ///
295    /// ```
296    /// use conjure_cp_core::ast::{Domain,Range};
297    /// use conjure_cp_core::{domain_int,range};
298    /// use std::collections::BTreeSet;
299    ///
300    /// let elements = BTreeSet::from([1,2,4,5,7,8,9,10]);
301    ///
302    /// let domain = Domain::from_set_i32(&elements);
303    ///
304    /// assert_eq!(domain,domain_int!(1..2,4..5,7..10));
305    /// ```
306    ///
307    /// ```
308    /// use conjure_cp_core::ast::{Domain,Range,ReturnType};
309    /// use std::collections::BTreeSet;
310    ///
311    /// let elements = BTreeSet::from([]);
312    ///
313    /// let domain = Domain::from_set_i32(&elements);
314    ///
315    /// assert!(matches!(domain,Domain::Empty(ReturnType::Int)))
316    /// ```
317    pub fn from_set_i32(elements: &BTreeSet<i32>) -> Domain {
318        if elements.is_empty() {
319            return Domain::Empty(ReturnType::Int);
320        }
321        if elements.len() == 1 {
322            return domain_int!(*elements.first().unwrap());
323        }
324
325        let mut elems_iter = elements.iter().copied();
326
327        let mut ranges: Vec<Range<i32>> = vec![];
328
329        // Loop over the elements in ascending order, turning all sequential runs of
330        // numbers into ranges.
331
332        // the bounds of the current run of numbers.
333        let mut lower = elems_iter
334            .next()
335            .expect("if we get here, elements should have => 2 elements");
336        let mut upper = lower;
337
338        for current in elems_iter {
339            // As elements is a BTreeSet, current is always strictly larger than lower.
340
341            if current == upper + 1 {
342                // current is part of the current run - we now have the run lower..current
343                //
344                upper = current;
345            } else {
346                // the run lower..upper has ended.
347                //
348                // Add the run lower..upper to the domain, and start a new run.
349
350                if lower == upper {
351                    ranges.push(range!(lower));
352                } else {
353                    ranges.push(range!(lower..upper));
354                }
355
356                lower = current;
357                upper = current;
358            }
359        }
360
361        // add the final run to the domain
362        if lower == upper {
363            ranges.push(range!(lower));
364        } else {
365            ranges.push(range!(lower..upper));
366        }
367
368        Domain::Int(ranges)
369    }
370
371    /// For a vector of literals, creates a domain that contains all the elements.
372    ///
373    /// The literals must all be of the same type.
374    ///
375    /// For abstract literals, this method merges the element domains of the literals, but not the
376    /// index domains. Thus, for fixed-sized abstract literals (matrices, tuples, records, etc.),
377    /// all literals in the vector must also have the same size / index domain:
378    ///
379    /// + Matrices: all literals must have the same index domain.
380    /// + Tuples: all literals must have the same number of elements.
381    /// + Records: all literals must have the same fields.
382    ///
383    /// # Errors
384    ///
385    /// - [DomainOpError::InputWrongType] if the input literals are of a different type to
386    ///   each-other, as described above.
387    ///
388    /// # Examples
389    ///
390    /// ```
391    /// use conjure_cp_core::ast::{Domain,Range,Literal,ReturnType};
392    ///
393    /// let domain = Domain::from_literal_vec(vec![]);
394    /// assert_eq!(domain,Ok(Domain::Empty(ReturnType::Unknown)));
395    /// ```
396    ///
397    /// ```
398    /// use conjure_cp_core::ast::{Domain,Range,Literal, AbstractLiteral};
399    /// use conjure_cp_core::{domain_int, range, matrix};
400    ///
401    /// // `[1,2;int(2..3)], [4,5; int(2..3)]` has domain
402    /// // `matrix indexed by [int(2..3)] of int(1..2,4..5)`
403    ///
404    /// let matrix_1 = Literal::AbstractLiteral(matrix![Literal::Int(1),Literal::Int(2);domain_int!(2..3)]);
405    /// let matrix_2 = Literal::AbstractLiteral(matrix![Literal::Int(4),Literal::Int(5);domain_int!(2..3)]);
406    ///
407    /// let domain = Domain::from_literal_vec(vec![matrix_1,matrix_2]);
408    ///
409    /// let expected_domain = Ok(Domain::Matrix(
410    ///     Box::new(domain_int!(1..2,4..5)),vec![domain_int!(2..3)]));
411    ///
412    /// assert_eq!(domain,expected_domain);
413    /// ```
414    ///
415    /// ```
416    /// use conjure_cp_core::ast::{Domain,Range,Literal, AbstractLiteral,DomainOpError};
417    /// use conjure_cp_core::{domain_int, range, matrix};
418    ///
419    /// // `[1,2;int(2..3)], [4,5; int(1..2)]` cannot be combined
420    /// // `matrix indexed by [int(2..3)] of int(1..2,4..5)`
421    ///
422    /// let matrix_1 = Literal::AbstractLiteral(matrix![Literal::Int(1),Literal::Int(2);domain_int!(2..3)]);
423    /// let matrix_2 = Literal::AbstractLiteral(matrix![Literal::Int(4),Literal::Int(5);domain_int!(1..2)]);
424    ///
425    /// let domain = Domain::from_literal_vec(vec![matrix_1,matrix_2]);
426    ///
427    /// assert_eq!(domain,Err(DomainOpError::InputWrongType));
428    /// ```
429    ///
430    /// ```
431    /// use conjure_cp_core::ast::{Domain,Range,Literal, AbstractLiteral};
432    /// use conjure_cp_core::{domain_int,range, matrix};
433    ///
434    /// // `[[1,2; int(1..2)];int(2)], [[4,5; int(1..2)]; int(2)]` has domain
435    /// // `matrix indexed by [int(2),int(1..2)] of int(1..2,4..5)`
436    ///
437    ///
438    /// let matrix_1 = Literal::AbstractLiteral(matrix![Literal::AbstractLiteral(matrix![Literal::Int(1),Literal::Int(2);domain_int!(1..2)]); domain_int!(2)]);
439    /// let matrix_2 = Literal::AbstractLiteral(matrix![Literal::AbstractLiteral(matrix![Literal::Int(4),Literal::Int(5);domain_int!(1..2)]); domain_int!(2)]);
440    ///
441    /// let domain = Domain::from_literal_vec(vec![matrix_1,matrix_2]);
442    ///
443    /// let expected_domain = Ok(Domain::Matrix(
444    ///     Box::new(domain_int!(1..2,4..5)),
445    ///     vec![domain_int!(2),domain_int!(1..2)]));
446    ///
447    /// assert_eq!(domain,expected_domain);
448    /// ```
449    ///
450    ///
451    pub fn from_literal_vec(literals: Vec<Literal>) -> Result<Domain, DomainOpError> {
452        // TODO: use proptest to test this better?
453
454        if literals.is_empty() {
455            return Ok(Domain::Empty(ReturnType::Unknown));
456        }
457
458        let first_literal = literals.first().unwrap();
459
460        match first_literal {
461            Literal::Int(_) => {
462                // check all literals are ints, then pass this to Domain::from_set_i32.
463                let mut ints = BTreeSet::new();
464                for lit in literals {
465                    let Literal::Int(i) = lit else {
466                        return Err(DomainOpError::InputWrongType);
467                    };
468
469                    ints.insert(i);
470                }
471
472                Ok(Domain::from_set_i32(&ints))
473            }
474            Literal::Bool(_) => {
475                // check all literals are bools
476                if literals.iter().any(|x| !matches!(x, Literal::Bool(_))) {
477                    Err(DomainOpError::InputWrongType)
478                } else {
479                    Ok(Domain::Bool)
480                }
481            }
482            Literal::AbstractLiteral(AbstractLiteral::Set(_)) => {
483                let mut all_elems = vec![];
484
485                for lit in literals {
486                    let Literal::AbstractLiteral(AbstractLiteral::Set(elems)) = lit else {
487                        return Err(DomainOpError::InputWrongType);
488                    };
489
490                    all_elems.extend(elems);
491                }
492                let elem_domain = Domain::from_literal_vec(all_elems)?;
493
494                Ok(Domain::Set(SetAttr::None, Box::new(elem_domain)))
495            }
496
497            l @ Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _)) => {
498                let mut first_index_domain = vec![];
499                // flatten index domains of n-d matrix into list
500                let mut l = l.clone();
501                while let Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, idx)) = l {
502                    assert!(
503                        !matches!(idx.as_ref(), Domain::Matrix(_, _)),
504                        "n-dimensional matrix literals should be represented as a matrix inside a matrix"
505                    );
506                    first_index_domain.push(idx.as_ref().clone());
507                    l = elems[0].clone();
508                }
509
510                let mut all_elems: Vec<Literal> = vec![];
511
512                // check types and index domains
513                for lit in &literals {
514                    let Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, idx)) = lit else {
515                        return Err(DomainOpError::InputContainsReference);
516                    };
517
518                    all_elems.extend(elems.clone());
519
520                    let mut index_domain = vec![idx.as_ref().clone()];
521                    let mut l = elems[0].clone();
522                    while let Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, idx)) = l {
523                        assert!(
524                            !matches!(idx.as_ref(), Domain::Matrix(_, _)),
525                            "n-dimensional matrix literals should be represented as a matrix inside a matrix"
526                        );
527                        index_domain.push(idx.as_ref().clone());
528                        l = elems[0].clone();
529                    }
530
531                    if index_domain != first_index_domain {
532                        return Err(DomainOpError::InputWrongType);
533                    }
534                }
535
536                // extract all the terminal elements (those that are not nested matrix literals) from the matrix literal.
537                let mut terminal_elements: Vec<Literal> = vec![];
538                while let Some(elem) = all_elems.pop() {
539                    if let Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) = elem {
540                        all_elems.extend(elems);
541                    } else {
542                        terminal_elements.push(elem);
543                    }
544                }
545
546                let element_domain = Domain::from_literal_vec(terminal_elements)?;
547
548                Ok(Domain::Matrix(Box::new(element_domain), first_index_domain))
549            }
550
551            Literal::AbstractLiteral(AbstractLiteral::Tuple(first_elems)) => {
552                let n_fields = first_elems.len();
553
554                // for each field, calculate the element domain and add it to this list
555                let mut elem_domains = vec![];
556
557                for i in 0..n_fields {
558                    let mut all_elems = vec![];
559                    for lit in &literals {
560                        let Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) = lit else {
561                            return Err(DomainOpError::InputContainsReference);
562                        };
563
564                        if elems.len() != n_fields {
565                            return Err(DomainOpError::InputContainsReference);
566                        }
567
568                        all_elems.push(elems[i].clone());
569                    }
570
571                    elem_domains.push(Domain::from_literal_vec(all_elems)?);
572                }
573
574                Ok(Domain::Tuple(elem_domains))
575            }
576
577            Literal::AbstractLiteral(AbstractLiteral::Record(first_elems)) => {
578                let n_fields = first_elems.len();
579                let field_names = first_elems.iter().map(|x| x.name.clone()).collect_vec();
580
581                // for each field, calculate the element domain and add it to this list
582                let mut elem_domains = vec![];
583
584                for i in 0..n_fields {
585                    let mut all_elems = vec![];
586                    for lit in &literals {
587                        let Literal::AbstractLiteral(AbstractLiteral::Record(elems)) = lit else {
588                            return Err(DomainOpError::InputContainsReference);
589                        };
590
591                        if elems.len() != n_fields {
592                            return Err(DomainOpError::InputContainsReference);
593                        }
594
595                        let elem = elems[i].clone();
596                        if elem.name != field_names[i] {
597                            return Err(DomainOpError::InputContainsReference);
598                        }
599
600                        all_elems.push(elem.value);
601                    }
602
603                    elem_domains.push(Domain::from_literal_vec(all_elems)?);
604                }
605
606                Ok(Domain::Record(
607                    izip!(field_names, elem_domains)
608                        .map(|(name, domain)| RecordEntry { name, domain })
609                        .collect(),
610                ))
611            }
612        }
613    }
614
615    /// Gets all the [`Literal`] values inside this domain.
616    ///
617    /// # Errors
618    ///
619    /// - [`DomainOpError::InputNotInteger`] if the domain is not an integer domain.
620    /// - [`DomainOpError::InputContainsReference`] if the domain is a reference or contains a
621    ///   reference, meaning that its values cannot be determined.
622    pub fn values(&self) -> Result<Vec<Literal>, DomainOpError> {
623        match self {
624            Domain::Empty(_) => Ok(vec![]),
625            Domain::Bool => Ok(vec![false.into(), true.into()]),
626            Domain::Int(_) => self
627                .values_i32()
628                .map(|xs| xs.iter().map(|x| Literal::Int(*x)).collect_vec()),
629
630            // ~niklasdewally: don't know how to define this for collections, so leaving it for
631            // now... However, it definitely can be done, as matrices can be indexed by matrices.
632            Domain::Set(_, _) => todo!(),
633            Domain::Matrix(_, _) => todo!(),
634            Domain::Reference(_) => Err(DomainOpError::InputContainsReference),
635            Domain::Tuple(_) => todo!(), // TODO: Can this be done?
636            Domain::Record(_) => todo!(),
637        }
638    }
639
640    /// Gets the length of this domain.
641    ///
642    /// # Errors
643    ///
644    /// - [`DomainOpError::InputUnbounded`] if the input domain is of infinite size.
645    /// - [`DomainOpError::InputContainsReference`] if the input domain is or contains a
646    ///   domain reference, meaning that its size cannot be determined.
647    pub fn length(&self) -> Result<usize, DomainOpError> {
648        self.values().map(|x| x.len())
649    }
650
651    /// Returns the domain that is the result of applying a binary operation to two integer domains.
652    ///
653    /// The given operator may return `None` if the operation is not defined for its arguments.
654    /// Undefined values will not be included in the resulting domain.
655    ///
656    /// # Errors
657    ///
658    /// - [`DomainOpError::InputUnbounded`] if either of the input domains are unbounded.
659    /// - [`DomainOpError::InputNotInteger`] if either of the input domains are not integers.
660    pub fn apply_i32(
661        &self,
662        op: fn(i32, i32) -> Option<i32>,
663        other: &Domain,
664    ) -> Result<Domain, DomainOpError> {
665        let vs1 = self.values_i32()?;
666        let vs2 = other.values_i32()?;
667
668        let mut set = BTreeSet::new();
669        for (v1, v2) in itertools::iproduct!(vs1, vs2) {
670            if let Some(v) = op(v1, v2) {
671                set.insert(v);
672            }
673        }
674
675        Ok(Domain::from_set_i32(&set))
676    }
677    /// Returns true if the domain is finite.
678    ///
679    /// # Errors
680    ///
681    /// - [`DomainOpError::InputContainsReference`] if the input domain is or contains a
682    ///   domain reference, meaning that its size cannot be determined.
683    pub fn is_finite(&self) -> Result<bool, DomainOpError> {
684        for domain in self.universe() {
685            if let Domain::Int(ranges) = domain {
686                if ranges.is_empty() {
687                    return Ok(false);
688                }
689
690                if ranges
691                    .iter()
692                    .any(|range| matches!(range, Range::UnboundedL(_) | Range::UnboundedR(_)))
693                {
694                    return Ok(false);
695                }
696            } else if let Domain::Reference(_) = domain {
697                return Err(DomainOpError::InputContainsReference);
698            }
699        }
700        Ok(true)
701    }
702
703    /// Resolves this domain to a ground domain, using the symbol table provided to resolve
704    /// references.
705    ///
706    /// A domain is ground iff it is not a domain reference, nor contains any domain references.
707    ///
708    /// See also: [`SymbolTable::resolve_domain`](crate::ast::SymbolTable::resolve_domain).
709    ///
710    /// # Panics
711    ///
712    /// + If a reference domain in `self` does not exist in the given symbol table.
713    pub fn resolve(mut self, symbols: &SymbolTable) -> Domain {
714        // FIXME: cannot use Uniplate::transform here due to reference lifetime shenanigans...
715        // dont see any reason why Uniplate::transform requires a closure that only uses borrows
716        // with a 'static lifetime... ~niklasdewally
717        // ..
718        // Also, still want to make the Uniplate variant which uses FnOnce not Fn with methods that
719        // take self instead of &self -- that would come in handy here!
720
721        let mut done_something = true;
722        while done_something {
723            done_something = false;
724            for (domain, ctx) in self.clone().contexts() {
725                if let Domain::Reference(name) = domain {
726                    self = ctx(symbols
727                        .resolve_domain(&name)
728                        .expect("domain reference should exist in the symbol table")
729                        .resolve(symbols));
730                    done_something = true;
731                }
732            }
733        }
734        self
735    }
736
737    /// Calculates the intersection of two domains.
738    ///
739    /// # Errors
740    ///
741    ///  - [`DomainOpError::InputUnbounded`] if either of the input domains are unbounded.
742    ///  - [`DomainOpError::InputWrongType`] if the input domains are different types, or are not
743    ///    integer or set domains.
744    pub fn intersect(&self, other: &Domain) -> Result<Domain, DomainOpError> {
745        // TODO: does not consider unbounded domains yet
746        // needs to be tested once comprehension rules are written
747
748        match (self, other) {
749            // one or more arguments is an empty int domain
750            (d @ Domain::Empty(ReturnType::Int), Domain::Int(_)) => Ok(d.clone()),
751            (Domain::Int(_), d @ Domain::Empty(ReturnType::Int)) => Ok(d.clone()),
752            (Domain::Empty(ReturnType::Int), d @ Domain::Empty(ReturnType::Int)) => Ok(d.clone()),
753
754            // one or more arguments is an empty set(int) domain
755            (Domain::Set(_, inner1), d @ Domain::Empty(ReturnType::Set(inner2)))
756                if matches!(**inner1, Domain::Int(_) | Domain::Empty(ReturnType::Int))
757                    && matches!(**inner2, ReturnType::Int) =>
758            {
759                Ok(d.clone())
760            }
761            (d @ Domain::Empty(ReturnType::Set(inner1)), Domain::Set(_, inner2))
762                if matches!(**inner1, ReturnType::Int)
763                    && matches!(**inner2, Domain::Int(_) | Domain::Empty(ReturnType::Int)) =>
764            {
765                Ok(d.clone())
766            }
767            (
768                d @ Domain::Empty(ReturnType::Set(inner1)),
769                Domain::Empty(ReturnType::Set(inner2)),
770            ) if matches!(**inner1, ReturnType::Int) && matches!(**inner2, ReturnType::Int) => {
771                Ok(d.clone())
772            }
773
774            // both arguments are non-empy
775            (Domain::Set(_, x), Domain::Set(_, y)) => {
776                Ok(Domain::Set(SetAttr::None, Box::new((*x).intersect(y)?)))
777            }
778
779            (Domain::Int(_), Domain::Int(_)) => {
780                let mut v: BTreeSet<i32> = BTreeSet::new();
781
782                let v1 = self.values_i32()?;
783                let v2 = other.values_i32()?;
784                for value1 in v1.iter() {
785                    if v2.contains(value1) && !v.contains(value1) {
786                        v.insert(*value1);
787                    }
788                }
789                Ok(Domain::from_set_i32(&v))
790            }
791            _ => Err(DomainOpError::InputWrongType),
792        }
793    }
794
795    /// Calculates the union of two domains.
796    ///
797    /// # Errors
798    ///
799    ///  - [`DomainOpError::InputUnbounded`] if either of the input domains are unbounded.
800    ///  - [`DomainOpError::InputWrongType`] if the input domains are different types, or are not
801    ///    integer set, or matrix domains. This is also thrown if the matrix domains that have
802    ///    different index domains.
803    ///    
804    pub fn union(&self, other: &Domain) -> Result<Domain, DomainOpError> {
805        // TODO: does not consider unbounded domains yet
806        // needs to be tested once comprehension rules are written
807        match (self, other) {
808            // one or more arguments is an empty integer domain
809            (Domain::Empty(ReturnType::Int), d @ Domain::Int(_)) => Ok(d.clone()),
810            (d @ Domain::Int(_), Domain::Empty(ReturnType::Int)) => Ok(d.clone()),
811            (Domain::Empty(ReturnType::Int), d @ Domain::Empty(ReturnType::Int)) => Ok(d.clone()),
812
813            // one or more arguments is an empty set(int) domain
814            (d @ Domain::Set(_, inner1), Domain::Empty(ReturnType::Set(inner2)))
815                if matches!(**inner1, Domain::Int(_) | Domain::Empty(ReturnType::Int))
816                    && matches!(**inner2, ReturnType::Int) =>
817            {
818                Ok(d.clone())
819            }
820            (Domain::Empty(ReturnType::Set(inner1)), d @ Domain::Set(_, inner2))
821                if matches!(**inner1, ReturnType::Int)
822                    && matches!(**inner2, Domain::Int(_) | Domain::Empty(ReturnType::Int)) =>
823            {
824                Ok(d.clone())
825            }
826            (
827                d @ Domain::Empty(ReturnType::Set(inner1)),
828                Domain::Empty(ReturnType::Set(inner2)),
829            ) if matches!(**inner1, ReturnType::Int) && matches!(**inner2, ReturnType::Int) => {
830                Ok(d.clone())
831            }
832
833            // both arguments are non empty
834            (Domain::Set(_, x), Domain::Set(_, y)) => {
835                Ok(Domain::Set(SetAttr::None, Box::new((*x).union(y)?)))
836            }
837
838            // union matrices only if they have the same index domain.
839            (Domain::Matrix(d1, idx1), Domain::Matrix(d2, idx2)) if idx1 == idx2 => Ok(
840                Domain::Matrix(Box::new(d1.union(d2.as_ref())?), idx1.clone()),
841            ),
842
843            (Domain::Int(_), Domain::Int(_)) => {
844                let mut v: BTreeSet<i32> = BTreeSet::new();
845                let v1 = self.values_i32()?;
846                let v2 = other.values_i32()?;
847
848                for value1 in v1.iter() {
849                    v.insert(*value1);
850                }
851
852                for value2 in v2.iter() {
853                    v.insert(*value2);
854                }
855
856                Ok(Domain::from_set_i32(&v))
857            }
858            _ => Err(DomainOpError::InputWrongType),
859        }
860    }
861}
862
863impl Display for Domain {
864    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
865        match self {
866            Domain::Bool => {
867                write!(f, "bool")
868            }
869            Domain::Int(vec) => {
870                let domain_ranges: String = vec.iter().map(|x| format!("{x}")).join(",");
871
872                if domain_ranges.is_empty() {
873                    write!(f, "int")
874                } else {
875                    write!(f, "int({domain_ranges})")
876                }
877            }
878            Domain::Reference(name) => write!(f, "{name}"),
879            Domain::Set(_, domain) => {
880                write!(f, "set of ({domain})")
881            }
882            Domain::Matrix(value_domain, index_domains) => {
883                write!(
884                    f,
885                    "matrix indexed by [{}] of {value_domain}",
886                    pretty_vec(&index_domains.iter().collect_vec())
887                )
888            }
889            Domain::Tuple(domains) => {
890                write!(
891                    f,
892                    "tuple of ({})",
893                    pretty_vec(&domains.iter().collect_vec())
894                )
895            }
896            Domain::Record(entries) => {
897                write!(
898                    f,
899                    "record of ({})",
900                    pretty_vec(
901                        &entries
902                            .iter()
903                            .map(|entry| format!("{}: {}", entry.name, entry.domain))
904                            .collect_vec()
905                    )
906                )
907            }
908            Domain::Empty(return_type) => write!(f, "empty({return_type:?}"),
909        }
910    }
911}
912
913impl Typeable for Domain {
914    fn return_type(&self) -> Option<ReturnType> {
915        match self {
916            Domain::Bool => Some(ReturnType::Bool),
917            Domain::Int(_) => Some(ReturnType::Int),
918            Domain::Empty(return_type) => Some(return_type.clone()),
919            Domain::Set(_, domain) => Some(ReturnType::Set(Box::new(domain.return_type()?))),
920            Domain::Reference(_) => None, // todo!("add ReturnType for Domain::Reference"),
921            Domain::Matrix(item_domain, index_domains) => {
922                assert!(
923                    !index_domains.is_empty(),
924                    "a matrix should have atleast one dimension"
925                );
926                let mut return_type = ReturnType::Matrix(Box::new(item_domain.return_type()?));
927
928                for _ in 0..(index_domains.len() - 1) {
929                    return_type = ReturnType::Matrix(Box::new(return_type));
930                }
931
932                Some(return_type)
933            }
934            Domain::Tuple(items) => {
935                let mut item_types = vec![];
936                for item in items {
937                    item_types.push(item.return_type()?);
938                }
939                Some(ReturnType::Tuple(item_types))
940            }
941            Domain::Record(items) => {
942                let mut item_types = vec![];
943                for item in items {
944                    item_types.push(item.domain.return_type()?);
945                }
946                Some(ReturnType::Record(item_types))
947            }
948        }
949    }
950}
951
/// An error thrown by an operation on domains.
///
/// Returned by fallible domain operations such as [`Domain::union`]. The enum is
/// `#[non_exhaustive]`, so callers matching on it must include a wildcard arm.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq, Error)]
#[allow(clippy::enum_variant_names)] // all variant names start with Input at the moment, but that is ok.
pub enum DomainOpError {
    /// The operation only supports bounded / finite domains, but was given an unbounded input domain.
    #[error(
        "The operation only supports bounded / finite domains, but was given an unbounded input domain."
    )]
    InputUnbounded,

    /// The operation only supports integer input domains, but was given an input domain of a
    /// different type. The payload is the return type of the offending input domain.
    #[error("The operation only supports integer input domains, but got a {0:?} input domain.")]
    InputNotInteger(ReturnType),

    /// The operation was given an input domain of the wrong type (e.g. mismatched kinds of
    /// domain, or a kind of domain the operation does not handle).
    #[error("The operation was given input domains of the wrong type.")]
    InputWrongType,

    /// The operation failed as the input domain contained a reference.
    #[error("The operation failed as the input domain contained a reference")]
    InputContainsReference,
}
976
977/// Types that have a [`Domain`].
978pub trait HasDomain {
979    /// Gets the [`Domain`] of `self`.
980    fn domain_of(&self) -> Domain;
981
982    /// Gets the [`Domain`] of `self`, replacing any references with their domains stored in from the symbol table.
983    ///
984    /// # Panics
985    ///
986    /// - If a symbol referenced in `self` does not exist in the symbol table.
987    fn resolved_domain_of(&self, symbol_table: &SymbolTable) -> Domain {
988        self.domain_of().resolve(symbol_table)
989    }
990}
991
992impl<T: HasDomain> Typeable for T {
993    fn return_type(&self) -> Option<ReturnType> {
994        self.domain_of().return_type()
995    }
996}
997
#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplying two `int(-2..1)` domains must yield an integer domain that does not contain
    /// the range `-4..4`.
    #[test]
    fn test_negative_product() {
        let lhs = Domain::Int(vec![Range::Bounded(-2, 1)]);
        let rhs = Domain::Int(vec![Range::Bounded(-2, 1)]);
        let product = lhs.apply_i32(|a, b| Some(a * b), &rhs).unwrap();

        assert!(matches!(product, Domain::Int(_)));
        if let Domain::Int(ranges) = product {
            assert!(!ranges.contains(&Range::Bounded(-4, 4)));
        }
    }

    /// Dividing two `int(-2..1)` domains (skipping division by zero) must yield an integer
    /// domain that does not contain the range `-4..4`.
    #[test]
    fn test_negative_div() {
        let lhs = Domain::Int(vec![Range::Bounded(-2, 1)]);
        let rhs = Domain::Int(vec![Range::Bounded(-2, 1)]);
        // A zero divisor produces `None`, so that pair contributes no value.
        let quotient = lhs.apply_i32(|a, b| (b != 0).then(|| a / b), &rhs).unwrap();

        assert!(matches!(quotient, Domain::Int(_)));
        if let Domain::Int(ranges) = quotient {
            assert!(!ranges.contains(&Range::Bounded(-4, 4)));
        }
    }
}