1#![allow(clippy::arc_with_non_send_sync)]
2
3mod expand_ac;
4
5use std::{
6 cell::RefCell,
7 collections::{BTreeSet, HashMap},
8 fmt::Display,
9 rc::Rc,
10 sync::{
11 Arc, Mutex, RwLock,
12 atomic::{AtomicBool, Ordering},
13 },
14};
15
16use expand_ac::add_return_expression_to_generator_model;
17use itertools::Itertools as _;
18use serde::{Deserialize, Serialize};
19use tracing::warn;
20use uniplate::{Biplate, Uniplate};
21
22use crate::{
23 ast::{
24 Atom, DeclarationKind,
25 serde::{HasId as _, ObjId},
26 },
27 bug,
28 context::Context,
29 into_matrix_expr, matrix_expr,
30 metadata::Metadata,
31 rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
32 solver::{Solver, SolverError},
33};
34
35use super::{
36 DeclarationPtr, Domain, Expression, Model, Moo, Name, Range, SubModel, SymbolTable,
37 ac_operators::ACOperatorKind,
38};
39
/// Selects the rewriter used for the internal models built while expanding a
/// comprehension.
///
/// When `true`, those models are rewritten with `rewrite_morph`; when `false`
/// (the default), `rewrite_naive` is used. The flag is read with
/// `Ordering::Relaxed` each time `Comprehension::expand_simple` or
/// `Comprehension::expand_ac` rewrites a model.
pub static USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS: AtomicBool = AtomicBool::new(false);
46
/// The kind of comprehension being built (presumably the operator the
/// comprehension is an argument of — confirm against the parser).
///
/// It is only consulted by `ComprehensionBuilder::with_return_value`, to decide
/// how guards that mention non-induction names are folded into the return
/// expression.
pub enum ComprehensionKind {
    /// Guards over non-induction names are not supported for this kind yet
    /// (`with_return_value` panics).
    Sum,
    /// Such guards are folded in as an implication: `guard -> expression`.
    And,
    /// Such guards are folded in by conjoining them with the expression.
    Or,
}
/// A comprehension `[return_expression | generators, guards]`, stored as two
/// submodels that share the same parent symbol table.
#[derive(Clone, PartialEq, Eq, Uniplate, Serialize, Deserialize, Debug)]
#[biplate(to=SubModel)]
#[biplate(to=Expression)]
pub struct Comprehension {
    /// Holds the return expression as its single root constraint. When built by
    /// `ComprehensionBuilder`, its symbol table declares each induction variable
    /// as a `given`.
    return_expression_submodel: SubModel,
    /// Declares the induction variables; its constraints are the guards that
    /// mention only induction variables.
    generator_submodel: SubModel,
    /// Names of the induction variables (in sorted order when built by
    /// `ComprehensionBuilder`, which collects them from a `BTreeSet`).
    induction_vars: Vec<Name>,
}
69
70impl Comprehension {
71 pub fn domain_of(&self) -> Option<Domain> {
72 let return_expr_domain = self
73 .return_expression_submodel
74 .clone()
75 .into_single_expression()
76 .domain_of()?;
77
78 Some(Domain::Matrix(
80 Box::new(return_expr_domain),
81 vec![Domain::Int(vec![Range::UnboundedR(1)])],
82 ))
83 }
84
85 pub fn expand_simple(self, symtab: &mut SymbolTable) -> Result<Vec<Expression>, SolverError> {
95 let minion = Solver::new(crate::solver::adaptors::Minion::new());
96 let mut model = Model::new(Arc::new(RwLock::new(Context::default())));
98
99 model.search_order = Some(self.induction_vars.clone());
101 *model.as_submodel_mut() = self.generator_submodel.clone();
102
103 let extra_rule_sets = &["Base", "Constant", "Bubble"];
107
108 let rule_sets =
109 resolve_rule_sets(crate::solver::SolverFamily::Minion, extra_rule_sets).unwrap();
110
111 let model = if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
112 rewrite_morph(model, &rule_sets, false)
113 } else {
114 rewrite_naive(&model, &rule_sets, false, false).unwrap()
115 };
116
117 let return_expression_submodel = self.return_expression_submodel.clone();
137 let mut return_expression_model = Model::new(Arc::new(RwLock::new(Context::default())));
138 *return_expression_model.as_submodel_mut() = return_expression_submodel;
139
140 let return_expression_model =
141 if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
142 rewrite_morph(return_expression_model, &rule_sets, false)
143 } else {
144 rewrite_naive(&return_expression_model, &rule_sets, false, false).unwrap()
145 };
146
147 let minion = minion.load_model(model.clone())?;
148
149 let values = Arc::new(Mutex::new(Vec::new()));
150 let values_ptr = Arc::clone(&values);
151
152 tracing::debug!(model=%model,comprehension=%self,"Minion solving comprehension (simple mode)");
153 minion.solve(Box::new(move |sols| {
154 let values = &mut *values_ptr.lock().unwrap();
156 values.push(sols);
157 true
158 }))?;
159
160 let values = values.lock().unwrap().clone();
161
162 let mut return_expressions = vec![];
163
164 for value in values {
165 let return_expression_submodel = return_expression_model.as_submodel().clone();
168 let child_symtab = return_expression_submodel.symbols().clone();
169 let return_expression = return_expression_submodel.into_single_expression();
170
171 let value: HashMap<_, _> = value
174 .into_iter()
175 .filter(|(n, _)| self.induction_vars.contains(n))
176 .collect();
177
178 let value_ptr = Arc::new(value);
179 let value_ptr_2 = Arc::clone(&value_ptr);
180
181 let return_expression = return_expression.transform_bi(&move |x: Atom| {
183 let Atom::Reference(ref ptr) = x else {
184 return x;
185 };
186
187 let Some(lit) = value_ptr_2.get(&ptr.name()) else {
189 return x;
190 };
191
192 Atom::Literal(lit.clone())
193 });
194
195 let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
205
206 for (name, decl) in child_symtab.into_iter_local() {
208 if value_ptr.get(&name).is_some()
210 && matches!(&decl.kind() as &DeclarationKind, DeclarationKind::Given(_))
211 {
212 continue;
213 }
214
215 let Name::Machine(_) = &name else {
216 bug!(
217 "the symbol table of the return expression of a comprehension should only contain machine names"
218 );
219 };
220
221 let id = decl.id();
222 let new_decl = symtab.gensym(&decl.domain().unwrap());
223
224 machine_name_translations.insert(id, new_decl);
225 }
226
227 #[allow(clippy::arc_with_non_send_sync)]
229 let return_expression = return_expression.transform_bi(&move |atom: Atom| {
230 if let Atom::Reference(ref decl) = atom
231 && let id = decl.id()
232 && let Some(new_decl) = machine_name_translations.get(&id)
233 {
234 Atom::Reference(new_decl.clone())
235 } else {
236 atom
237 }
238 });
239
240 return_expressions.push(return_expression);
241 }
242
243 Ok(return_expressions)
244 }
245
246 pub fn expand_ac(
257 self,
258 symtab: &mut SymbolTable,
259 ac_operator: ACOperatorKind,
260 ) -> Result<Vec<Expression>, SolverError> {
261 let induction_vars_2 = self.induction_vars.clone();
272 let generator_symtab_ptr = Rc::clone(self.generator_submodel.symbols_ptr_unchecked());
273 let return_expression =
274 self.clone()
275 .return_expression()
276 .transform_bi(&move |decl: DeclarationPtr| {
277 if induction_vars_2.contains(&decl.name()) {
279 (*generator_symtab_ptr)
282 .borrow()
283 .lookup_local(&decl.name())
284 .unwrap()
285 } else {
286 decl
287 }
288 });
289
290 let generator_submodel = add_return_expression_to_generator_model(
294 self.generator_submodel.clone(),
295 return_expression,
296 &ac_operator,
297 );
298
299 let mut generator_model = Model::new(Arc::new(RwLock::new(Context::default())));
303
304 *generator_model.as_submodel_mut() = generator_submodel;
305
306 generator_model.search_order = Some(self.induction_vars.clone());
308
309 let extra_rule_sets = &[
310 "Base",
311 "Constant",
312 "Bubble",
313 "Better_AC_Comprehension_Expansion",
314 ];
315
316 let rule_sets =
317 resolve_rule_sets(crate::solver::SolverFamily::Minion, extra_rule_sets).unwrap();
318
319 let generator_model = if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
320 rewrite_morph(generator_model, &rule_sets, false)
321 } else {
322 rewrite_naive(&generator_model, &rule_sets, false, false).unwrap()
323 };
324
325 let minion = Solver::new(crate::solver::adaptors::Minion::new());
326 let minion = minion.load_model(generator_model.clone());
327
328 let minion = match minion {
329 Err(e) => {
330 warn!(why=%e,model=%generator_model,"Loading generator model failed, failing expand_ac rule");
331 return Err(e);
332 }
333 Ok(minion) => minion,
334 };
335
336 let return_expression_submodel = self.return_expression_submodel.clone();
340 let mut return_expression_model = Model::new(Arc::new(RwLock::new(Context::default())));
341 *return_expression_model.as_submodel_mut() = return_expression_submodel;
342
343 let return_expression_model =
344 if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
345 rewrite_morph(return_expression_model, &rule_sets, false)
346 } else {
347 rewrite_naive(&return_expression_model, &rule_sets, false, false).unwrap()
348 };
349
350 let values = Arc::new(Mutex::new(Vec::new()));
351 let values_ptr = Arc::clone(&values);
352
353 tracing::debug!(model=%generator_model,comprehension=%self,"Minion solving comprehnesion (ac mode)");
357
358 minion.solve(Box::new(move |sols| {
359 let values = &mut *values_ptr.lock().unwrap();
361 values.push(sols);
362 true
363 }))?;
364
365 let values = values.lock().unwrap().clone();
366
367 let mut return_expressions = vec![];
368
369 for value in values {
370 let return_expression_submodel = return_expression_model.as_submodel().clone();
373 let child_symtab = return_expression_submodel.symbols().clone();
374 let return_expression = return_expression_submodel.into_single_expression();
375
376 let value: HashMap<_, _> = value
379 .into_iter()
380 .filter(|(n, _)| self.induction_vars.contains(n))
381 .collect();
382
383 let value_ptr = Arc::new(value);
384 let value_ptr_2 = Arc::clone(&value_ptr);
385
386 let return_expression = return_expression.transform_bi(&move |x: Atom| {
388 let Atom::Reference(ref ptr) = x else {
389 return x;
390 };
391
392 let Some(lit) = value_ptr_2.get(&ptr.name()) else {
394 return x;
395 };
396
397 Atom::Literal(lit.clone())
398 });
399
400 let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
410
411 for (name, decl) in child_symtab.into_iter_local() {
413 if value_ptr.get(&name).is_some()
415 && matches!(&decl.kind() as &DeclarationKind, DeclarationKind::Given(_))
416 {
417 continue;
418 }
419
420 let Name::Machine(_) = &name else {
421 bug!(
422 "the symbol table of the return expression of a comprehension should only contain machine names"
423 );
424 };
425
426 let id = decl.id();
427 let new_decl = symtab.gensym(&decl.domain().unwrap());
428
429 machine_name_translations.insert(id, new_decl);
430 }
431
432 #[allow(clippy::arc_with_non_send_sync)]
434 let return_expression = return_expression.transform_bi(&move |atom: Atom| {
435 if let Atom::Reference(ref decl) = atom
436 && let id = decl.id()
437 && let Some(new_decl) = machine_name_translations.get(&id)
438 {
439 Atom::Reference(new_decl.clone())
440 } else {
441 atom
442 }
443 });
444
445 return_expressions.push(return_expression);
446 }
447
448 Ok(return_expressions)
449 }
450
451 pub fn return_expression(self) -> Expression {
452 self.return_expression_submodel.into_single_expression()
453 }
454
455 pub fn replace_return_expression(&mut self, new_expr: Expression) {
456 let new_expr = match new_expr {
457 Expression::And(_, exprs) if (*exprs).clone().unwrap_list().is_some() => {
458 Expression::Root(Metadata::new(), (*exprs).clone().unwrap_list().unwrap())
459 }
460 expr => Expression::Root(Metadata::new(), vec![expr]),
461 };
462
463 *self.return_expression_submodel.root_mut_unchecked() = new_expr;
464 }
465
466 pub fn add_induction_guard(&mut self, guard: Expression) -> bool {
468 if self.is_induction_guard(&guard) {
469 self.generator_submodel.add_constraint(guard);
470 true
471 } else {
472 false
473 }
474 }
475
476 pub fn is_induction_guard(&self, expr: &Expression) -> bool {
478 is_induction_guard(&(self.induction_vars.clone().into_iter().collect()), expr)
479 }
480}
481
482impl Display for Comprehension {
483 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
484 let generators: String = self
485 .generator_submodel
486 .symbols()
487 .clone()
488 .into_iter_local()
489 .map(|(name, decl): (Name, DeclarationPtr)| {
490 let domain: Domain = decl.domain().unwrap();
491 (name, domain)
492 })
493 .map(|(name, domain): (Name, Domain)| format!("{name}: {domain}"))
494 .join(",");
495
496 let guards = self
497 .generator_submodel
498 .constraints()
499 .iter()
500 .map(|x| format!("{x}"))
501 .join(",");
502
503 let generators_and_guards = itertools::join([generators, guards], ",");
504
505 let expression = &self.return_expression_submodel;
506 write!(f, "[{expression} | {generators_and_guards}]")
507 }
508}
509
/// Incrementally builds a [`Comprehension`] from generators, guards and a
/// return expression (see `ComprehensionBuilder::with_return_value`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ComprehensionBuilder {
    /// All guards added so far. They are split into induction and
    /// non-induction guards only when the comprehension is built.
    guards: Vec<Expression>,
    /// Symbol table for the generator submodel: a child of the table passed to
    /// `ComprehensionBuilder::new`, holding the induction-variable declarations.
    generator_symboltable: Rc<RefCell<SymbolTable>>,
    /// Symbol table for the return-expression submodel: a separate child of the
    /// same parent, in which each induction variable is declared as a `given`.
    return_expr_symboltable: Rc<RefCell<SymbolTable>>,
    /// Names of the induction variables added via `generator`.
    induction_variables: BTreeSet<Name>,
}
521
522impl ComprehensionBuilder {
523 pub fn new(symbol_table_ptr: Rc<RefCell<SymbolTable>>) -> Self {
524 ComprehensionBuilder {
525 guards: vec![],
526 generator_symboltable: Rc::new(RefCell::new(SymbolTable::with_parent(
527 symbol_table_ptr.clone(),
528 ))),
529 return_expr_symboltable: Rc::new(RefCell::new(SymbolTable::with_parent(
530 symbol_table_ptr,
531 ))),
532 induction_variables: BTreeSet::new(),
533 }
534 }
535
536 pub fn generator_symboltable(&mut self) -> Rc<RefCell<SymbolTable>> {
538 Rc::clone(&self.generator_symboltable)
539 }
540
541 pub fn return_expr_symboltable(&mut self) -> Rc<RefCell<SymbolTable>> {
543 Rc::clone(&self.return_expr_symboltable)
544 }
545
546 pub fn guard(mut self, guard: Expression) -> Self {
547 self.guards.push(guard);
548 self
549 }
550
551 pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
552 let name = declaration.name().clone();
553 let domain = declaration.domain().unwrap();
554 assert!(!self.induction_variables.contains(&name));
555
556 self.induction_variables.insert(name.clone());
557
558 (*self.generator_symboltable)
560 .borrow_mut()
561 .insert(declaration);
562
563 (*self.return_expr_symboltable)
565 .borrow_mut()
566 .insert(DeclarationPtr::new_given(name, domain));
567
568 self
569 }
570
571 pub fn with_return_value(
576 self,
577 mut expression: Expression,
578 comprehension_kind: Option<ComprehensionKind>,
579 ) -> Comprehension {
580 let parent_symboltable = self
581 .generator_symboltable
582 .as_ref()
583 .borrow_mut()
584 .parent_mut_unchecked()
585 .clone()
586 .unwrap();
587 let mut generator_submodel = SubModel::new(parent_symboltable.clone());
588 let mut return_expression_submodel = SubModel::new(parent_symboltable);
589
590 *generator_submodel.symbols_ptr_unchecked_mut() = self.generator_symboltable;
591 *return_expression_submodel.symbols_ptr_unchecked_mut() = self.return_expr_symboltable;
592
593 let induction_variables = self.induction_variables;
596
597 let (mut induction_guards, mut other_guards): (Vec<_>, Vec<_>) = self
599 .guards
600 .into_iter()
601 .partition(|x| is_induction_guard(&induction_variables, x));
602
603 let induction_variables_2 = induction_variables.clone();
604 let generator_symboltable_ptr = generator_submodel.symbols_ptr_unchecked().clone();
605
606 induction_guards =
608 Biplate::<DeclarationPtr>::transform_bi(&induction_guards, &move |decl| {
609 if induction_variables_2.contains(&decl.name()) {
610 (*generator_symboltable_ptr)
611 .borrow()
612 .lookup_local(&decl.name())
613 .unwrap()
614 } else {
615 decl
616 }
617 })
618 .into_iter()
619 .collect_vec();
620
621 let induction_variables_2 = induction_variables.clone();
622 let return_expr_symboltable_ptr =
623 return_expression_submodel.symbols_ptr_unchecked().clone();
624
625 other_guards = Biplate::<DeclarationPtr>::transform_bi(&other_guards, &move |decl| {
627 if induction_variables_2.contains(&decl.name()) {
628 (*return_expr_symboltable_ptr)
629 .borrow()
630 .lookup_local(&decl.name())
631 .unwrap()
632 } else {
633 decl
634 }
635 })
636 .into_iter()
637 .collect_vec();
638
639 if !other_guards.is_empty() {
641 let comprehension_kind = comprehension_kind.expect(
642 "if any guards reference decision variables, a comprehension kind should be given",
643 );
644
645 let guard_expr = match other_guards.as_slice() {
646 [x] => x.clone(),
647 xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
648 };
649
650 expression = match comprehension_kind {
651 ComprehensionKind::And => {
652 Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
653 }
654 ComprehensionKind::Or => Expression::And(
655 Metadata::new(),
656 Moo::new(Expression::And(
657 Metadata::new(),
658 Moo::new(matrix_expr![guard_expr, expression]),
659 )),
660 ),
661
662 ComprehensionKind::Sum => {
663 panic!("guards that reference decision variables not yet implemented for sum");
664 }
665 }
666 }
667
668 generator_submodel.add_constraints(induction_guards);
669
670 return_expression_submodel.add_constraint(expression);
671
672 Comprehension {
673 return_expression_submodel,
674 generator_submodel,
675 induction_vars: induction_variables.into_iter().collect_vec(),
676 }
677 }
678}
679
680fn is_induction_guard(induction_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
682 guard
683 .universe_bi()
684 .iter()
685 .all(|x| induction_variables.contains(x))
686}