conjure_cp_core/ast/
literals.rs

1use itertools::Itertools;
2use serde::{Deserialize, Serialize};
3use std::fmt::{Display, Formatter};
4use std::hash::Hash;
5use std::hash::Hasher;
6use ustr::Ustr;
7
8use uniplate::{Biplate, Tree, Uniplate};
9
10use crate::ast::pretty::pretty_vec;
11use crate::metadata::Metadata;
12
13use super::domains::HasDomain;
14use super::{Atom, Domain, Expression, Range, records::RecordValue};
15use super::{Moo, ReturnType, SetAttr, Typeable};
16
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate, Hash)]
#[uniplate(walk_into=[AbstractLiteral<Literal>])]
#[biplate(to=Atom)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=Expression)]
/// A literal value, equivalent to constants in Conjure.
pub enum Literal {
    /// An integer constant.
    Int(i32),
    /// A boolean constant.
    Bool(bool),
    /// A compound constant (set, matrix, tuple or record) whose elements are all literals.
    // The variant name ends in `Literal` (mirroring the `AbstractLiteral` type it wraps), which
    // trips clippy's `enum_variant_names` lint; that naming is intentional, so the lint is allowed.
    #[allow(clippy::enum_variant_names)]
    AbstractLiteral(AbstractLiteral<Literal>),
}
33
34impl HasDomain for Literal {
35    fn domain_of(&self) -> Domain {
36        match self {
37            Literal::Int(i) => Domain::Int(vec![Range::Single(*i)]),
38            Literal::Bool(_) => Domain::Bool,
39            Literal::AbstractLiteral(abstract_literal) => abstract_literal.domain_of(),
40        }
41    }
42}
43
44// make possible values of an AbstractLiteral a closed world to make the trait bounds more sane (particularly in Uniplate instances!!)
45pub trait AbstractLiteralValue:
46    Clone + Eq + PartialEq + Display + Uniplate + Biplate<RecordValue<Self>> + 'static
47{
48}
49impl AbstractLiteralValue for Expression {}
50impl AbstractLiteralValue for Literal {}
51
/// A compound literal whose elements have type `T`.
///
/// In this module `T` is either [`Expression`] (elements may still be arbitrary expressions)
/// or [`Literal`] (every element is a constant); see [`AbstractLiteralValue`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AbstractLiteral<T: AbstractLiteralValue> {
    /// A set literal; displayed as `{a,b,c}`.
    Set(Vec<T>),

    /// A 1 dimensional matrix slice with an index domain.
    ///
    /// n-dimensional matrix literals are represented as a matrix whose elements are matrices.
    Matrix(Vec<T>, Box<Domain>),

    /// A tuple of values.
    Tuple(Vec<T>),

    /// A record literal: an ordered list of named fields.
    Record(Vec<RecordValue<T>>),
}
64
65// TODO: use HasDomain instead once Expression::domain_of returns Domain not Option<Domain>
66impl AbstractLiteral<Expression> {
67    pub fn domain_of(&self) -> Option<Domain> {
68        match self {
69            AbstractLiteral::Set(items) => {
70                // ensure that all items have a domain, or return None
71                let item_domains: Vec<Domain> = items
72                    .iter()
73                    .map(|x| x.domain_of())
74                    .collect::<Option<Vec<Domain>>>()?;
75
76                // union all item domains together
77                let mut item_domain_iter = item_domains.iter().cloned();
78                let first_item = item_domain_iter.next()?;
79                let item_domain = item_domains
80                    .iter()
81                    .try_fold(first_item, |x: Domain, y| x.union(y))
82                    .expect("taking the union of all item domains of a set literal should succeed");
83
84                Some(Domain::Set(SetAttr::None, Box::new(item_domain)))
85            }
86
87            AbstractLiteral::Matrix(items, _) => {
88                // ensure that all items have a domain, or return None
89                let item_domains: Vec<Domain> = items
90                    .iter()
91                    .map(|x| x.domain_of())
92                    .collect::<Option<Vec<Domain>>>()?;
93
94                // union all item domains together
95                let mut item_domain_iter = item_domains.iter().cloned();
96
97                let first_item = item_domain_iter.next()?;
98
99                let item_domain = item_domains
100                    .iter()
101                    .try_fold(first_item, |x: Domain, y| x.union(y))
102                    .expect(
103                        "taking the union of all item domains of a matrix literal should succeed",
104                    );
105
106                let mut new_index_domain = vec![];
107
108                // flatten index domains of n-d matrix into list
109                let mut e = Expression::AbstractLiteral(Metadata::new(), self.clone());
110                while let Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, idx)) = e {
111                    assert!(
112                        !matches!(idx.as_ref(), Domain::Matrix(_, _)),
113                        "n-dimensional matrix literals should be represented as a matrix inside a matrix"
114                    );
115                    new_index_domain.push(idx.as_ref().clone());
116                    e = elems[0].clone();
117                }
118                Some(Domain::Matrix(Box::new(item_domain), new_index_domain))
119            }
120            AbstractLiteral::Tuple(_) => None,
121            AbstractLiteral::Record(_) => None,
122        }
123    }
124}
125
126impl HasDomain for AbstractLiteral<Literal> {
127    fn domain_of(&self) -> Domain {
128        Domain::from_literal_vec(vec![Literal::AbstractLiteral(self.clone())])
129            .expect("abstract literals should be correctly typed")
130    }
131}
132
133impl Typeable for AbstractLiteral<Expression> {
134    fn return_type(&self) -> Option<ReturnType> {
135        match self {
136            AbstractLiteral::Set(items) if items.is_empty() => {
137                Some(ReturnType::Set(Box::new(ReturnType::Unknown)))
138            }
139            AbstractLiteral::Set(items) => {
140                let item_type = items[0].return_type()?;
141
142                // if any items do not have a type, return none.
143                let item_types: Option<Vec<ReturnType>> =
144                    items.iter().map(|x| x.return_type()).collect();
145
146                let item_types = item_types?;
147
148                assert!(
149                    item_types.iter().all(|x| x == &item_type),
150                    "all items in a set should have the same type"
151                );
152
153                Some(ReturnType::Set(Box::new(item_type)))
154            }
155            AbstractLiteral::Matrix(items, _) if items.is_empty() => {
156                Some(ReturnType::Matrix(Box::new(ReturnType::Unknown)))
157            }
158            AbstractLiteral::Matrix(items, _) => {
159                let item_type = items[0].return_type()?;
160
161                // if any items do not have a type, return none.
162                let item_types: Option<Vec<ReturnType>> =
163                    items.iter().map(|x| x.return_type()).collect();
164
165                let item_types = item_types?;
166
167                assert!(
168                    item_types.iter().all(|x| x == &item_type),
169                    "all items in a matrix should have the same type. items: {items} types: {types:#?}",
170                    items = pretty_vec(items),
171                    types = items
172                        .iter()
173                        .map(|x| x.return_type())
174                        .collect::<Vec<Option<ReturnType>>>()
175                );
176
177                Some(ReturnType::Matrix(Box::new(item_type)))
178            }
179            AbstractLiteral::Tuple(items) => {
180                let mut item_types = vec![];
181                for item in items {
182                    item_types.push(item.return_type()?);
183                }
184                Some(ReturnType::Tuple(item_types))
185            }
186            AbstractLiteral::Record(items) => {
187                let mut item_types = vec![];
188                for item in items {
189                    item_types.push(item.value.return_type()?);
190                }
191                Some(ReturnType::Record(item_types))
192            }
193        }
194    }
195}
196
197impl<T> AbstractLiteral<T>
198where
199    T: AbstractLiteralValue,
200{
201    /// Creates a matrix with elements `elems`, with domain `int(1..)`.
202    ///
203    /// This acts as a variable sized list.
204    pub fn matrix_implied_indices(elems: Vec<T>) -> Self {
205        AbstractLiteral::Matrix(elems, Box::new(Domain::Int(vec![Range::UnboundedR(1)])))
206    }
207
208    /// If the AbstractLiteral is a list, returns its elements.
209    ///
210    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
211    /// any explicitly specified domain.
212    pub fn unwrap_list(&self) -> Option<&Vec<T>> {
213        let AbstractLiteral::Matrix(elems, domain) = self else {
214            return None;
215        };
216
217        let Domain::Int(ranges) = domain.as_ref() else {
218            return None;
219        };
220
221        let [Range::UnboundedR(1)] = ranges[..] else {
222            return None;
223        };
224
225        Some(elems)
226    }
227}
228
229impl<T> Display for AbstractLiteral<T>
230where
231    T: AbstractLiteralValue,
232{
233    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
234        match self {
235            AbstractLiteral::Set(elems) => {
236                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
237                write!(f, "{{{elems_str}}}")
238            }
239            AbstractLiteral::Matrix(elems, index_domain) => {
240                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
241                write!(f, "[{elems_str};{index_domain}]")
242            }
243            AbstractLiteral::Tuple(elems) => {
244                let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
245                write!(f, "({elems_str})")
246            }
247            AbstractLiteral::Record(entries) => {
248                let entries_str: String = entries
249                    .iter()
250                    .map(|entry| format!("{}: {}", entry.name, entry.value))
251                    .join(",");
252                write!(f, "{{{entries_str}}}")
253            }
254        }
255    }
256}
257
258impl Hash for AbstractLiteral<Literal> {
259    fn hash<H: Hasher>(&self, state: &mut H) {
260        match self {
261            AbstractLiteral::Set(vec) => {
262                0.hash(state);
263                vec.hash(state);
264            }
265            AbstractLiteral::Matrix(elems, index_domain) => {
266                1.hash(state);
267                elems.hash(state);
268                index_domain.hash(state);
269            }
270            AbstractLiteral::Tuple(elems) => {
271                2.hash(state);
272                elems.hash(state);
273            }
274            AbstractLiteral::Record(entries) => {
275                3.hash(state);
276                entries.hash(state);
277            }
278        }
279    }
280}
281
/// Uniplate instance for abstract literals.
///
/// The direct children of an abstract literal are the `AbstractLiteral<T>` values found, via
/// `Biplate`, inside its elements (or record entries). A matrix's index domain is not
/// traversed; it is cloned into the rebuild closure and reattached unchanged.
impl<T> Uniplate for AbstractLiteral<T>
where
    T: AbstractLiteralValue + Biplate<AbstractLiteral<T>>,
{
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
        // walking into T
        match self {
            AbstractLiteral::Set(vec) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
                (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
            }
            AbstractLiteral::Matrix(elems, index_domain) => {
                // The rebuild closure owns its own copy of the index domain and clones it
                // again on each call, because the closure is `Fn` and may be called repeatedly.
                let index_domain = index_domain.clone();
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
                )
            }
            AbstractLiteral::Tuple(elems) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
                )
            }
            AbstractLiteral::Record(entries) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(entries);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
                )
            }
        }
    }
}
318
319impl<U, To> Biplate<To> for AbstractLiteral<U>
320where
321    To: Uniplate,
322    U: AbstractLiteralValue + Biplate<AbstractLiteral<U>> + Biplate<To>,
323    RecordValue<U>: Biplate<AbstractLiteral<U>> + Biplate<To>,
324{
325    fn biplate(&self) -> (Tree<To>, Box<dyn Fn(Tree<To>) -> Self>) {
326        if std::any::TypeId::of::<To>() == std::any::TypeId::of::<AbstractLiteral<U>>() {
327            // To ==From => return One(self)
328
329            unsafe {
330                // SAFETY: asserted the type equality above
331                let self_to = std::mem::transmute::<&AbstractLiteral<U>, &To>(self).clone();
332                let tree = Tree::One(self_to);
333                let ctx = Box::new(move |x| {
334                    let Tree::One(x) = x else {
335                        panic!();
336                    };
337
338                    std::mem::transmute::<&To, &AbstractLiteral<U>>(&x).clone()
339                });
340
341                (tree, ctx)
342            }
343        } else {
344            // walking into T
345            match self {
346                AbstractLiteral::Set(vec) => {
347                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
348                    (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
349                }
350                AbstractLiteral::Matrix(elems, index_domain) => {
351                    let index_domain = index_domain.clone();
352                    let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(elems);
353                    (
354                        f1_tree,
355                        Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
356                    )
357                }
358                AbstractLiteral::Tuple(elems) => {
359                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
360                    (
361                        f1_tree,
362                        Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
363                    )
364                }
365                AbstractLiteral::Record(entries) => {
366                    let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(entries);
367                    (
368                        f1_tree,
369                        Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
370                    )
371                }
372            }
373        }
374    }
375}
376
377impl TryFrom<Literal> for i32 {
378    type Error = &'static str;
379
380    fn try_from(value: Literal) -> Result<Self, Self::Error> {
381        match value {
382            Literal::Int(i) => Ok(i),
383            _ => Err("Cannot convert non-i32 literal to i32"),
384        }
385    }
386}
387
388impl TryFrom<Box<Literal>> for i32 {
389    type Error = &'static str;
390
391    fn try_from(value: Box<Literal>) -> Result<Self, Self::Error> {
392        (*value).try_into()
393    }
394}
395
396impl TryFrom<&Box<Literal>> for i32 {
397    type Error = &'static str;
398
399    fn try_from(value: &Box<Literal>) -> Result<Self, Self::Error> {
400        TryFrom::<&Literal>::try_from(value.as_ref())
401    }
402}
403
404impl TryFrom<&Moo<Literal>> for i32 {
405    type Error = &'static str;
406
407    fn try_from(value: &Moo<Literal>) -> Result<Self, Self::Error> {
408        TryFrom::<&Literal>::try_from(value.as_ref())
409    }
410}
411
412impl TryFrom<&Literal> for i32 {
413    type Error = &'static str;
414
415    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
416        match value {
417            Literal::Int(i) => Ok(*i),
418            _ => Err("Cannot convert non-i32 literal to i32"),
419        }
420    }
421}
422
423impl TryFrom<Literal> for bool {
424    type Error = &'static str;
425
426    fn try_from(value: Literal) -> Result<Self, Self::Error> {
427        match value {
428            Literal::Bool(b) => Ok(b),
429            _ => Err("Cannot convert non-bool literal to bool"),
430        }
431    }
432}
433
434impl TryFrom<&Literal> for bool {
435    type Error = &'static str;
436
437    fn try_from(value: &Literal) -> Result<Self, Self::Error> {
438        match value {
439            Literal::Bool(b) => Ok(*b),
440            _ => Err("Cannot convert non-bool literal to bool"),
441        }
442    }
443}
444
445impl From<i32> for Literal {
446    fn from(i: i32) -> Self {
447        Literal::Int(i)
448    }
449}
450
451impl From<bool> for Literal {
452    fn from(b: bool) -> Self {
453        Literal::Bool(b)
454    }
455}
456
457impl From<Literal> for Ustr {
458    fn from(value: Literal) -> Self {
459        // TODO: avoid the temporary-allocation of a string by format! here?
460        Ustr::from(&format!("{value}"))
461    }
462}
463
464impl AbstractLiteral<Expression> {
465    /// If all the elements are literals, returns this as an AbstractLiteral<Literal>.
466    /// Otherwise, returns `None`.
467    pub fn into_literals(self) -> Option<AbstractLiteral<Literal>> {
468        match self {
469            AbstractLiteral::Set(_) => todo!(),
470            AbstractLiteral::Matrix(items, domain) => {
471                let mut literals = vec![];
472                for item in items {
473                    let literal = match item {
474                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
475                        Expression::AbstractLiteral(_, abslit) => {
476                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
477                        }
478                        _ => None,
479                    }?;
480                    literals.push(literal);
481                }
482
483                Some(AbstractLiteral::Matrix(literals, domain))
484            }
485            AbstractLiteral::Tuple(items) => {
486                let mut literals = vec![];
487                for item in items {
488                    let literal = match item {
489                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
490                        Expression::AbstractLiteral(_, abslit) => {
491                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
492                        }
493                        _ => None,
494                    }?;
495                    literals.push(literal);
496                }
497
498                Some(AbstractLiteral::Tuple(literals))
499            }
500            AbstractLiteral::Record(entries) => {
501                let mut literals = vec![];
502                for entry in entries {
503                    let literal = match entry.value {
504                        Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
505                        Expression::AbstractLiteral(_, abslit) => {
506                            Some(Literal::AbstractLiteral(abslit.into_literals()?))
507                        }
508                        _ => None,
509                    }?;
510
511                    literals.push((entry.name, literal));
512                }
513                Some(AbstractLiteral::Record(
514                    literals
515                        .into_iter()
516                        .map(|(name, literal)| RecordValue {
517                            name,
518                            value: literal,
519                        })
520                        .collect(),
521                ))
522            }
523        }
524    }
525}
526
527// need display implementations for other types as well
528impl Display for Literal {
529    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
530        match &self {
531            Literal::Int(i) => write!(f, "{i}"),
532            Literal::Bool(b) => write!(f, "{b}"),
533            Literal::AbstractLiteral(l) => write!(f, "{l:?}"),
534        }
535    }
536}
537
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{into_matrix, matrix};
    use uniplate::Uniplate;

    /// `cata` should reach every matrix literal, including ones nested as elements.
    #[test]
    fn matrix_uniplate_universe() {
        // An outer matrix holding five copies of an inner matrix literal, all indexed by
        // `bool`: six matrices in total.
        let inner = Literal::AbstractLiteral(matrix![Literal::Bool(true);Domain::Bool]);
        let my_matrix: AbstractLiteral<Literal> = into_matrix![vec![inner; 5]; Domain::Bool];

        let actual_index_domains: Vec<Domain> = my_matrix.cata(&move |elem, children| {
            let mut collected: Vec<Domain> = children.into_iter().flatten().collect();
            if let AbstractLiteral::Matrix(_, index_domain) = elem {
                collected.push(*index_domain);
            }
            collected
        });

        assert_eq!(actual_index_domains, vec![Domain::Bool; 6]);
    }
}