1use itertools::Itertools;
2use serde::{Deserialize, Serialize};
3use std::fmt::{Display, Formatter};
4use std::hash::Hash;
5use std::hash::Hasher;
6use ustr::Ustr;
7
8use uniplate::{Biplate, Tree, Uniplate};
9
10use crate::ast::pretty::pretty_vec;
11use crate::metadata::Metadata;
12
13use super::domains::HasDomain;
14use super::{Atom, Domain, Expression, Range, records::RecordValue};
15use super::{Moo, ReturnType, SetAttr, Typeable};
16
/// A constant value: an integer, a boolean, or an abstract literal (set, matrix, tuple or
/// record) whose elements are themselves `Literal`s.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate, Hash)]
// NOTE(review): presumably instructs the Uniplate derive to traverse into
// `AbstractLiteral<Literal>` fields — confirm against the uniplate derive docs.
#[uniplate(walk_into=[AbstractLiteral<Literal>])]
#[biplate(to=Atom)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=Expression)]
pub enum Literal {
    /// A 32-bit signed integer literal.
    Int(i32),
    /// A boolean literal.
    Bool(bool),
    /// An abstract literal whose elements are literals.
    // The variant name ends with the enum's name ("...Literal"), which
    // `clippy::enum_variant_names` flags; it is kept because it mirrors the wrapped
    // `AbstractLiteral` type.
    #[allow(clippy::enum_variant_names)]
    AbstractLiteral(AbstractLiteral<Literal>),
}
33
34impl HasDomain for Literal {
35 fn domain_of(&self) -> Domain {
36 match self {
37 Literal::Int(i) => Domain::Int(vec![Range::Single(*i)]),
38 Literal::Bool(_) => Domain::Bool,
39 Literal::AbstractLiteral(abstract_literal) => abstract_literal.domain_of(),
40 }
41 }
42}
43
/// Element types that may appear inside an [`AbstractLiteral`].
///
/// Implemented for [`Expression`] and [`Literal`]. The `'static` bound is relied on by the
/// `Biplate` impl for `AbstractLiteral<U>`, which compares `TypeId`s (and `TypeId::of`
/// requires `'static`).
pub trait AbstractLiteralValue:
    Clone + Eq + PartialEq + Display + Uniplate + Biplate<RecordValue<Self>> + 'static
{
}
impl AbstractLiteralValue for Expression {}
impl AbstractLiteralValue for Literal {}
51
/// A compound literal whose elements are of type `T` (an [`Expression`] or a [`Literal`]).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AbstractLiteral<T: AbstractLiteralValue> {
    /// A set literal; displayed as `{a,b,c}`.
    Set(Vec<T>),

    /// A matrix literal: its elements and its index domain; displayed as `[a,b;index_domain]`.
    ///
    /// n-dimensional matrix literals are represented as matrices nested inside matrices.
    Matrix(Vec<T>, Box<Domain>),

    /// A tuple literal; displayed as `(a,b)`.
    Tuple(Vec<T>),

    /// A record literal: a list of named entries; displayed as `{name: value,...}`.
    Record(Vec<RecordValue<T>>),
}
64
65impl AbstractLiteral<Expression> {
67 pub fn domain_of(&self) -> Option<Domain> {
68 match self {
69 AbstractLiteral::Set(items) => {
70 let item_domains: Vec<Domain> = items
72 .iter()
73 .map(|x| x.domain_of())
74 .collect::<Option<Vec<Domain>>>()?;
75
76 let mut item_domain_iter = item_domains.iter().cloned();
78 let first_item = item_domain_iter.next()?;
79 let item_domain = item_domains
80 .iter()
81 .try_fold(first_item, |x: Domain, y| x.union(y))
82 .expect("taking the union of all item domains of a set literal should succeed");
83
84 Some(Domain::Set(SetAttr::None, Box::new(item_domain)))
85 }
86
87 AbstractLiteral::Matrix(items, _) => {
88 let item_domains: Vec<Domain> = items
90 .iter()
91 .map(|x| x.domain_of())
92 .collect::<Option<Vec<Domain>>>()?;
93
94 let mut item_domain_iter = item_domains.iter().cloned();
96
97 let first_item = item_domain_iter.next()?;
98
99 let item_domain = item_domains
100 .iter()
101 .try_fold(first_item, |x: Domain, y| x.union(y))
102 .expect(
103 "taking the union of all item domains of a matrix literal should succeed",
104 );
105
106 let mut new_index_domain = vec![];
107
108 let mut e = Expression::AbstractLiteral(Metadata::new(), self.clone());
110 while let Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, idx)) = e {
111 assert!(
112 !matches!(idx.as_ref(), Domain::Matrix(_, _)),
113 "n-dimensional matrix literals should be represented as a matrix inside a matrix"
114 );
115 new_index_domain.push(idx.as_ref().clone());
116 e = elems[0].clone();
117 }
118 Some(Domain::Matrix(Box::new(item_domain), new_index_domain))
119 }
120 AbstractLiteral::Tuple(_) => None,
121 AbstractLiteral::Record(_) => None,
122 }
123 }
124}
125
126impl HasDomain for AbstractLiteral<Literal> {
127 fn domain_of(&self) -> Domain {
128 Domain::from_literal_vec(vec![Literal::AbstractLiteral(self.clone())])
129 .expect("abstract literals should be correctly typed")
130 }
131}
132
133impl Typeable for AbstractLiteral<Expression> {
134 fn return_type(&self) -> Option<ReturnType> {
135 match self {
136 AbstractLiteral::Set(items) if items.is_empty() => {
137 Some(ReturnType::Set(Box::new(ReturnType::Unknown)))
138 }
139 AbstractLiteral::Set(items) => {
140 let item_type = items[0].return_type()?;
141
142 let item_types: Option<Vec<ReturnType>> =
144 items.iter().map(|x| x.return_type()).collect();
145
146 let item_types = item_types?;
147
148 assert!(
149 item_types.iter().all(|x| x == &item_type),
150 "all items in a set should have the same type"
151 );
152
153 Some(ReturnType::Set(Box::new(item_type)))
154 }
155 AbstractLiteral::Matrix(items, _) if items.is_empty() => {
156 Some(ReturnType::Matrix(Box::new(ReturnType::Unknown)))
157 }
158 AbstractLiteral::Matrix(items, _) => {
159 let item_type = items[0].return_type()?;
160
161 let item_types: Option<Vec<ReturnType>> =
163 items.iter().map(|x| x.return_type()).collect();
164
165 let item_types = item_types?;
166
167 assert!(
168 item_types.iter().all(|x| x == &item_type),
169 "all items in a matrix should have the same type. items: {items} types: {types:#?}",
170 items = pretty_vec(items),
171 types = items
172 .iter()
173 .map(|x| x.return_type())
174 .collect::<Vec<Option<ReturnType>>>()
175 );
176
177 Some(ReturnType::Matrix(Box::new(item_type)))
178 }
179 AbstractLiteral::Tuple(items) => {
180 let mut item_types = vec![];
181 for item in items {
182 item_types.push(item.return_type()?);
183 }
184 Some(ReturnType::Tuple(item_types))
185 }
186 AbstractLiteral::Record(items) => {
187 let mut item_types = vec![];
188 for item in items {
189 item_types.push(item.value.return_type()?);
190 }
191 Some(ReturnType::Record(item_types))
192 }
193 }
194 }
195}
196
197impl<T> AbstractLiteral<T>
198where
199 T: AbstractLiteralValue,
200{
201 pub fn matrix_implied_indices(elems: Vec<T>) -> Self {
205 AbstractLiteral::Matrix(elems, Box::new(Domain::Int(vec![Range::UnboundedR(1)])))
206 }
207
208 pub fn unwrap_list(&self) -> Option<&Vec<T>> {
213 let AbstractLiteral::Matrix(elems, domain) = self else {
214 return None;
215 };
216
217 let Domain::Int(ranges) = domain.as_ref() else {
218 return None;
219 };
220
221 let [Range::UnboundedR(1)] = ranges[..] else {
222 return None;
223 };
224
225 Some(elems)
226 }
227}
228
229impl<T> Display for AbstractLiteral<T>
230where
231 T: AbstractLiteralValue,
232{
233 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
234 match self {
235 AbstractLiteral::Set(elems) => {
236 let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
237 write!(f, "{{{elems_str}}}")
238 }
239 AbstractLiteral::Matrix(elems, index_domain) => {
240 let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
241 write!(f, "[{elems_str};{index_domain}]")
242 }
243 AbstractLiteral::Tuple(elems) => {
244 let elems_str: String = elems.iter().map(|x| format!("{x}")).join(",");
245 write!(f, "({elems_str})")
246 }
247 AbstractLiteral::Record(entries) => {
248 let entries_str: String = entries
249 .iter()
250 .map(|entry| format!("{}: {}", entry.name, entry.value))
251 .join(",");
252 write!(f, "{{{entries_str}}}")
253 }
254 }
255 }
256}
257
258impl Hash for AbstractLiteral<Literal> {
259 fn hash<H: Hasher>(&self, state: &mut H) {
260 match self {
261 AbstractLiteral::Set(vec) => {
262 0.hash(state);
263 vec.hash(state);
264 }
265 AbstractLiteral::Matrix(elems, index_domain) => {
266 1.hash(state);
267 elems.hash(state);
268 index_domain.hash(state);
269 }
270 AbstractLiteral::Tuple(elems) => {
271 2.hash(state);
272 elems.hash(state);
273 }
274 AbstractLiteral::Record(entries) => {
275 3.hash(state);
276 entries.hash(state);
277 }
278 }
279 }
280}
281
/// Uniplate traversal over nested abstract literals.
///
/// The children of an abstract literal are the `AbstractLiteral<T>` values that `Biplate`
/// finds inside its elements (or record entries). A matrix's index domain is not traversed;
/// it is carried unchanged into the rebuilding closure.
impl<T> Uniplate for AbstractLiteral<T>
where
    T: AbstractLiteralValue + Biplate<AbstractLiteral<T>>,
{
    fn uniplate(&self) -> (Tree<Self>, Box<dyn Fn(Tree<Self>) -> Self>) {
        match self {
            // Each arm: collect children from the elements via Biplate, and rebuild the same
            // variant from the (possibly replaced) children.
            AbstractLiteral::Set(vec) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(vec);
                (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
            }
            AbstractLiteral::Matrix(elems, index_domain) => {
                // Cloned so the rebuilding closure owns its own copy of the index domain.
                let index_domain = index_domain.clone();
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
                )
            }
            AbstractLiteral::Tuple(elems) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(elems);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
                )
            }
            AbstractLiteral::Record(entries) => {
                let (f1_tree, f1_ctx) = <_ as Biplate<AbstractLiteral<T>>>::biplate(entries);
                (
                    f1_tree,
                    Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
                )
            }
        }
    }
}
318
319impl<U, To> Biplate<To> for AbstractLiteral<U>
320where
321 To: Uniplate,
322 U: AbstractLiteralValue + Biplate<AbstractLiteral<U>> + Biplate<To>,
323 RecordValue<U>: Biplate<AbstractLiteral<U>> + Biplate<To>,
324{
325 fn biplate(&self) -> (Tree<To>, Box<dyn Fn(Tree<To>) -> Self>) {
326 if std::any::TypeId::of::<To>() == std::any::TypeId::of::<AbstractLiteral<U>>() {
327 unsafe {
330 let self_to = std::mem::transmute::<&AbstractLiteral<U>, &To>(self).clone();
332 let tree = Tree::One(self_to);
333 let ctx = Box::new(move |x| {
334 let Tree::One(x) = x else {
335 panic!();
336 };
337
338 std::mem::transmute::<&To, &AbstractLiteral<U>>(&x).clone()
339 });
340
341 (tree, ctx)
342 }
343 } else {
344 match self {
346 AbstractLiteral::Set(vec) => {
347 let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(vec);
348 (f1_tree, Box::new(move |x| AbstractLiteral::Set(f1_ctx(x))))
349 }
350 AbstractLiteral::Matrix(elems, index_domain) => {
351 let index_domain = index_domain.clone();
352 let (f1_tree, f1_ctx) = <Vec<U> as Biplate<To>>::biplate(elems);
353 (
354 f1_tree,
355 Box::new(move |x| AbstractLiteral::Matrix(f1_ctx(x), index_domain.clone())),
356 )
357 }
358 AbstractLiteral::Tuple(elems) => {
359 let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(elems);
360 (
361 f1_tree,
362 Box::new(move |x| AbstractLiteral::Tuple(f1_ctx(x))),
363 )
364 }
365 AbstractLiteral::Record(entries) => {
366 let (f1_tree, f1_ctx) = <_ as Biplate<To>>::biplate(entries);
367 (
368 f1_tree,
369 Box::new(move |x| AbstractLiteral::Record(f1_ctx(x))),
370 )
371 }
372 }
373 }
374 }
375}
376
377impl TryFrom<Literal> for i32 {
378 type Error = &'static str;
379
380 fn try_from(value: Literal) -> Result<Self, Self::Error> {
381 match value {
382 Literal::Int(i) => Ok(i),
383 _ => Err("Cannot convert non-i32 literal to i32"),
384 }
385 }
386}
387
388impl TryFrom<Box<Literal>> for i32 {
389 type Error = &'static str;
390
391 fn try_from(value: Box<Literal>) -> Result<Self, Self::Error> {
392 (*value).try_into()
393 }
394}
395
396impl TryFrom<&Box<Literal>> for i32 {
397 type Error = &'static str;
398
399 fn try_from(value: &Box<Literal>) -> Result<Self, Self::Error> {
400 TryFrom::<&Literal>::try_from(value.as_ref())
401 }
402}
403
404impl TryFrom<&Moo<Literal>> for i32 {
405 type Error = &'static str;
406
407 fn try_from(value: &Moo<Literal>) -> Result<Self, Self::Error> {
408 TryFrom::<&Literal>::try_from(value.as_ref())
409 }
410}
411
412impl TryFrom<&Literal> for i32 {
413 type Error = &'static str;
414
415 fn try_from(value: &Literal) -> Result<Self, Self::Error> {
416 match value {
417 Literal::Int(i) => Ok(*i),
418 _ => Err("Cannot convert non-i32 literal to i32"),
419 }
420 }
421}
422
423impl TryFrom<Literal> for bool {
424 type Error = &'static str;
425
426 fn try_from(value: Literal) -> Result<Self, Self::Error> {
427 match value {
428 Literal::Bool(b) => Ok(b),
429 _ => Err("Cannot convert non-bool literal to bool"),
430 }
431 }
432}
433
434impl TryFrom<&Literal> for bool {
435 type Error = &'static str;
436
437 fn try_from(value: &Literal) -> Result<Self, Self::Error> {
438 match value {
439 Literal::Bool(b) => Ok(*b),
440 _ => Err("Cannot convert non-bool literal to bool"),
441 }
442 }
443}
444
445impl From<i32> for Literal {
446 fn from(i: i32) -> Self {
447 Literal::Int(i)
448 }
449}
450
451impl From<bool> for Literal {
452 fn from(b: bool) -> Self {
453 Literal::Bool(b)
454 }
455}
456
457impl From<Literal> for Ustr {
458 fn from(value: Literal) -> Self {
459 Ustr::from(&format!("{value}"))
461 }
462}
463
464impl AbstractLiteral<Expression> {
465 pub fn into_literals(self) -> Option<AbstractLiteral<Literal>> {
468 match self {
469 AbstractLiteral::Set(_) => todo!(),
470 AbstractLiteral::Matrix(items, domain) => {
471 let mut literals = vec![];
472 for item in items {
473 let literal = match item {
474 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
475 Expression::AbstractLiteral(_, abslit) => {
476 Some(Literal::AbstractLiteral(abslit.into_literals()?))
477 }
478 _ => None,
479 }?;
480 literals.push(literal);
481 }
482
483 Some(AbstractLiteral::Matrix(literals, domain))
484 }
485 AbstractLiteral::Tuple(items) => {
486 let mut literals = vec![];
487 for item in items {
488 let literal = match item {
489 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
490 Expression::AbstractLiteral(_, abslit) => {
491 Some(Literal::AbstractLiteral(abslit.into_literals()?))
492 }
493 _ => None,
494 }?;
495 literals.push(literal);
496 }
497
498 Some(AbstractLiteral::Tuple(literals))
499 }
500 AbstractLiteral::Record(entries) => {
501 let mut literals = vec![];
502 for entry in entries {
503 let literal = match entry.value {
504 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
505 Expression::AbstractLiteral(_, abslit) => {
506 Some(Literal::AbstractLiteral(abslit.into_literals()?))
507 }
508 _ => None,
509 }?;
510
511 literals.push((entry.name, literal));
512 }
513 Some(AbstractLiteral::Record(
514 literals
515 .into_iter()
516 .map(|(name, literal)| RecordValue {
517 name,
518 value: literal,
519 })
520 .collect(),
521 ))
522 }
523 }
524 }
525}
526
527impl Display for Literal {
529 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
530 match &self {
531 Literal::Int(i) => write!(f, "{i}"),
532 Literal::Bool(b) => write!(f, "{b}"),
533 Literal::AbstractLiteral(l) => write!(f, "{l:?}"),
534 }
535 }
536}
537
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{into_matrix, matrix};
    use uniplate::Uniplate;

    /// A catamorphism over a 2-d matrix literal (five inner matrices inside an outer one)
    /// must visit all six matrices, collecting one index domain from each.
    #[test]
    fn matrix_uniplate_universe() {
        let inner_matrix = Literal::AbstractLiteral(matrix![Literal::Bool(true);Domain::Bool]);
        let my_matrix: AbstractLiteral<Literal> = into_matrix![
            vec![inner_matrix; 5];
            Domain::Bool
        ];

        let collected: Vec<Domain> = my_matrix.cata(&move |node, child_results| {
            let mut domains: Vec<Domain> = child_results.into_iter().flatten().collect();
            if let AbstractLiteral::Matrix(_, index_domain) = node {
                domains.push(*index_domain);
            }
            domains
        });

        assert_eq!(collected, vec![Domain::Bool; 6]);
    }
}